API

agjax - a jax wrapper for autograd-differentiable functions.

agjax.wrap_for_jax(fn: Callable[[Any], Any], nondiff_argnums: int | Tuple[int, ...] = (), nondiff_outputnums: int | Tuple[int, ...] = ()) → Callable[[Any], Any]

Wraps fn so that it can be differentiated by jax.

Arguments should be jax types; they are converted to numpy arrays before the underlying autograd-differentiable fn is called. Optionally, nondifferentiable arguments (i.e. those with respect to which fn cannot be differentiated) may be specified; these are passed to fn unchanged.

Similarly, differentiable outputs are converted to jax types. Outputs identified as nondifferentiable are returned unchanged.

Parameters:
  • fn – The autograd-differentiable function.

  • nondiff_argnums – The indices of the arguments with respect to which fn cannot be differentiated. These are passed to fn unchanged.

  • nondiff_outputnums – The indices of the outputs that cannot be differentiated. These are returned exactly as fn returns them.

Returns:

The wrapped function.

Defines a jax wrapper for autograd-differentiable functions.

agjax.experimental.wrapper.wrap_for_jax(fn: Callable[[Any], Any], result_shape_dtypes: Any, nondiff_argnums: int | Tuple[int, ...] = (), nondiff_outputnums: int | Tuple[int, ...] = ()) → Callable[[Any], Any]

Wraps fn so that it can be differentiated by jax.

The wrapped function is suitable for jax transformations such as grad, jit, vmap, and jacrev; this is achieved using jax.pure_callback.

Arguments to fn must be convertible to jax types, as must all outputs. The arguments to the wrapped function should be jax types, and the outputs will be jax types.

Arguments with respect to which fn need not be differentiated may be specified in nondiff_argnums, while outputs that need not be differentiated may be specified in nondiff_outputnums.

Parameters:
  • fn – The autograd-differentiable function.

  • result_shape_dtypes – A pytree matching the jax-converted output of fn. Specifically, the pytree structure, leaf shapes, and datatypes must match.

  • nondiff_argnums – The indices of the arguments with respect to which fn cannot be differentiated.

  • nondiff_outputnums – The indices of the outputs that cannot be differentiated.

Returns:

The wrapped function.