Basic example

Basic example

This notebook runs the basic examples from the agjax project README.

import functools
import autograd.numpy as npa
import jax
import jax.numpy as jnp

import agjax

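Define a basic function using autograd, wrap it, and use jax to compute the gradient. Then wrap a similar function with `agjax.experimental.wrap_for_jax`, which takes a `result_shape_dtypes` template for its output, and compute its Jacobian with `jax.jacrev`.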

@agjax.wrap_for_jax
def fn(x, y):
  """Return ``x * cos(y)``, computed with autograd's numpy.

  The ``agjax.wrap_for_jax`` decorator lets jax transformations such as
  ``jax.grad`` differentiate this autograd-based function.
  """
  cos_y = npa.cos(y)
  return x * cos_y

# Gradient of the wrapped function with respect to both x (arg 0) and y (arg 1).
grad_fn = jax.grad(fn, argnums=(0, 1))
grad = grad_fn(1.0, 0.0)
print(f"grad = {grad}")
grad = (Array(1., dtype=float32), Array(0., dtype=float32))
# The experimental wrapper is given a `result_shape_dtypes` template for the
# function's output; here an array of ones with shape (5,) is passed.
# NOTE(review): presumably only its shape/dtype matter, not its values — confirm.
wrapped_fn = agjax.experimental.wrap_for_jax(
  lambda x, y: x * npa.cos(y),
  result_shape_dtypes=jnp.ones((5,)),
)

x_vals = jnp.arange(5, dtype=float)
y_vals = jnp.arange(5, 10, dtype=float)

# The function is elementwise, so the Jacobian with respect to x is diagonal,
# with entries cos(y).
jacobian_fn = jax.jacrev(wrapped_fn, argnums=0)
jac = jacobian_fn(x_vals, y_vals)
print(f"jac = \n{jac}")
jac = 
[[ 0.28366217  0.          0.          0.          0.        ]
 [ 0.          0.96017027  0.          0.          0.        ]
 [ 0.          0.          0.75390226  0.          0.        ]
 [ 0.          0.          0.         -0.14550003  0.        ]
 [ 0.          0.          0.          0.         -0.91113025]]

Define a function with a nondifferentiable argument and a nondifferentiable output (both strings), then compute its value and gradient.

@functools.partial(
  agjax.wrap_for_jax,
  nondiff_argnums=(2,),
  nondiff_outputnums=(1,),
)
def fn(x, y, string_arg):
  """Return ``(x * cos(y), string_arg * 2)``.

  Argument 2 (``string_arg``) and output 1 (the doubled string) are declared
  nondifferentiable through ``nondiff_argnums`` / ``nondiff_outputnums``.
  """
  product = x * npa.cos(y)
  doubled = string_arg * 2
  return product, doubled

# has_aux=True: the wrapped fn returns (differentiable value, auxiliary string),
# so the call yields ((value, aux), grad) with grad taken w.r.t. x and y.
value_and_grad_fn = jax.value_and_grad(fn, argnums=(0, 1), has_aux=True)
(value, aux), grad = value_and_grad_fn(1.0, 0.0, "test")

print(f"value = {value}")
print(f"  aux = {aux}")
print(f" grad = {grad}")
value = 1.0
  aux = testtest
 grad = (Array(1., dtype=float32), Array(0., dtype=float32))