Basic example
This notebook executes the basic examples in the agjax project readme.
import functools
import autograd.numpy as npa
import jax
import jax.numpy as jnp
import agjax
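Define a basic function using autograd, wrap it with `agjax.wrap_for_jax`, and use jax to compute its gradient. Then do the same with the experimental wrapper, which additionally needs the shape and dtype of the function's output (`result_shape_dtypes`), and compute a Jacobian with `jax.jacrev`.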
@agjax.wrap_for_jax
def fn(x, y):
    """Return ``x * cos(y)``, computed with autograd's numpy.

    The ``agjax.wrap_for_jax`` decorator makes this autograd-based function
    usable with jax transformations such as ``jax.grad``.
    """
    return x * npa.cos(y)
# Build a gradient function w.r.t. both positional arguments, then evaluate it.
grad_fn = jax.grad(fn, argnums=(0, 1))
grad = grad_fn(1.0, 0.0)
print(f"grad = {grad}")
grad = (Array(1., dtype=float32), Array(0., dtype=float32))
# The experimental wrapper is told the output's shape/dtype up front via an
# example array; here the output matches the length-5 inputs below.
wrapped_fn = agjax.experimental.wrap_for_jax(
    lambda x, y: x * npa.cos(y),
    result_shape_dtypes=jnp.ones((5,)),
)
x_vals = jnp.arange(5, dtype=float)
y_vals = jnp.arange(5, 10, dtype=float)
jacobian_fn = jax.jacrev(wrapped_fn, argnums=0)
jac = jacobian_fn(x_vals, y_vals)
print(f"jac = \n{jac}")
jac =
[[ 0.28366217 0. 0. 0. 0. ]
[ 0. 0.96017027 0. 0. 0. ]
[ 0. 0. 0.75390226 0. 0. ]
[ 0. 0. 0. -0.14550003 0. ]
[ 0. 0. 0. 0. -0.91113025]]
Define a function with a nondifferentiable argument and a nondifferentiable output (both strings), and compute its value and gradient.
@functools.partial(
    agjax.wrap_for_jax, nondiff_argnums=(2,), nondiff_outputnums=(1,)
)
def fn(x, y, string_arg):
    """Return ``(x * cos(y), string_arg * 2)``.

    Argument 2 (``string_arg``) and output 1 (the repeated string) are marked
    nondifferentiable via ``nondiff_argnums`` / ``nondiff_outputnums``, so
    jax only differentiates with respect to the remaining inputs/outputs.
    """
    return x * npa.cos(y), string_arg * 2
# has_aux=True: fn returns (differentiable value, auxiliary string output).
value_and_grad_fn = jax.value_and_grad(fn, argnums=(0, 1), has_aux=True)
(value, aux), grad = value_and_grad_fn(1.0, 0.0, "test")
for line in (f"value = {value}", f" aux = {aux}", f" grad = {grad}"):
    print(line)
value = 1.0
aux = testtest
grad = (Array(1., dtype=float32), Array(0., dtype=float32))