Photonic inverse design example

In this example, we use the agjax wrapper with the ceviche_challenges photonic inverse design suite to carry out inverse design with jax. The ceviche_challenges package must be installed (for example, with pip install ceviche_challenges) for this notebook to run.

import ceviche_challenges
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as onp

import agjax

This example uses the waveguide bend challenge problem, whose aim is to design a structure that redirects light from a horizontal waveguide into a vertical waveguide. We begin by constructing the model for the waveguide bend. Its model.simulate method is differentiable with autograd, and we wrap it so that it can also be differentiated by jax.

spec = ceviche_challenges.waveguide_bend.prefabs.waveguide_bend_2umx2um_spec()
params = ceviche_challenges.waveguide_bend.prefabs.waveguide_bend_sim_params()
waveguide_bend_model = ceviche_challenges.waveguide_bend.model.WaveguideBendModel(params, spec)

# The simulate method has signature `fn(design) -> (s_params, fields)`. To use a
# jit-compatible wrapper, we must specify the shapes and dtypes of outputs.
s_params_shape = (
    len(waveguide_bend_model.output_wavelengths), 1, len(waveguide_bend_model.ports)
)
fields_shape = s_params_shape[:2] + waveguide_bend_model.shape
result_shape_dtypes = (
    jnp.ones(s_params_shape, dtype=complex),
    jnp.ones(fields_shape, dtype=complex),
)

# Wrap this function, marking the fields as a non-differentiable output.
jax_simulate_fn = agjax.experimental.wrap_for_jax(
    fn=waveguide_bend_model.simulate,
    result_shape_dtypes=result_shape_dtypes,
    nondiff_outputnums=1,
)
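agjax handles this wrapping for us. For intuition, here is one common way to make a non-jax function differentiable in jax: jax.pure_callback runs host-side NumPy code (also under jit), and jax.custom_vjp supplies a hand-written gradient. This is a minimal sketch under stated assumptions: a toy function f(x) = sum(x**2) stands in for the simulation, the names np_fn, np_vjp and wrapped are hypothetical, and this is not agjax's actual implementation.

```python
import jax
import jax.numpy as jnp
import numpy as onp


def np_fn(x):
    # Host-side (non-jax) function standing in for the simulation: f(x) = sum(x**2).
    return onp.asarray(onp.sum(onp.asarray(x) ** 2), dtype=onp.float32)


def np_vjp(x, g):
    # Hand-written vector-Jacobian product of np_fn: g * 2x.
    return (onp.asarray(g) * 2.0 * onp.asarray(x)).astype(onp.float32)


@jax.custom_vjp
def wrapped(x):
    # pure_callback needs the output shape/dtype up front, just as
    # wrap_for_jax needs result_shape_dtypes.
    out_spec = jax.ShapeDtypeStruct((), jnp.float32)
    return jax.pure_callback(np_fn, out_spec, x)


def wrapped_fwd(x):
    # Forward pass: return the value and save x as a residual for the backward pass.
    return wrapped(x), x


def wrapped_bwd(x, g):
    # Backward pass: call the host-side VJP and return one cotangent per input.
    grad_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    return (jax.pure_callback(np_vjp, grad_spec, x, g),)


wrapped.defvjp(wrapped_fwd, wrapped_bwd)

x_demo = jnp.arange(3.0)
demo_value = wrapped(x_demo)
demo_grad = jax.jit(jax.grad(wrapped))(x_demo)
```

Here demo_value is 5.0 and demo_grad is [0.0, 2.0, 4.0], matching the analytic gradient 2x.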

The simulate function has signature fn(design) -> (s_params, fields), where design is a 2D array giving the density (with values between 0 and 1) on a grid of locations in the design region. Densities of 0 and 1 correspond to the cladding and core materials, respectively, and intermediate values correspond to a blend of the two. The s_params have shape (num_wavelengths, num_excitation_ports, num_output_ports), while the fields have shape (num_wavelengths, num_excitation_ports, xnum, ynum). For the waveguide bend, the default simulation excites the fundamental waveguide mode in the horizontal input waveguide (port 1) at a single wavelength (1550 nm).
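To make the density convention concrete, here is a sketch of a linear blend between cladding and core permittivity. The permittivity values (roughly silica and silicon) and the linear interpolation are assumptions for illustration only; ceviche_challenges may map densities to permittivity differently, and blend_permittivity is a hypothetical helper.

```python
import numpy as onp

# Assumed, illustrative permittivities (roughly silica cladding, silicon core).
EPS_CLADDING = 2.1
EPS_CORE = 12.1


def blend_permittivity(density):
    # Hypothetical linear blend: density 0 -> cladding, density 1 -> core,
    # intermediate densities -> a proportional mix of the two.
    density = onp.asarray(density, dtype=float)
    return EPS_CLADDING + density * (EPS_CORE - EPS_CLADDING)


eps = blend_permittivity([0.0, 0.5, 1.0])
```

With these assumed values, eps is [2.1, 7.1, 12.1], so a 0.5-valued design sits halfway between the two materials.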

# Perform an example simulation, where the design is `1` everywhere.
design = jnp.ones(waveguide_bend_model.design_variable_shape)
s_params, fields = jax_simulate_fn(design)
assert s_params.shape == (1, 1, 2)
assert fields.shape[:-2] == (1, 1)

fig, ax = plt.subplots(1, 2, figsize=(6, 3))
ax[0].imshow(jnp.rot90(waveguide_bend_model.density(design)), cmap="gray")
ax[0].set_title("Density")
ax[0].axis("off")
ax[1].imshow(jnp.rot90(jnp.abs(fields[0, 0, :, :])), cmap="magma")
ax[1].set_title("Fields")
ax[1].axis("off")
[Figure: "Density" (left) and "Fields" (right) for the design that is 1 everywhere.]

Define a jax loss function that rewards proper waveguide bend behavior. An ideal design has high transmission into the vertical waveguide and low back-reflection into the horizontal waveguide.

def loss_fn(density):
  # A simple loss function that rewards high transmission and low reflection.
  s_params, fields = jax_simulate_fn(density)
  # Transmission is given by `s21`, reflection is given by `s11`.
  s11 = jnp.abs(s_params[:, 0, 0])
  s21 = jnp.abs(s_params[:, 0, 1])
  return jnp.mean(s11) - jnp.mean(s21), (s_params, fields)
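To see what values this loss takes, the following NumPy sketch applies the same formula to hand-picked, hypothetical S-parameter arrays of shape (num_wavelengths, num_excitation_ports, num_output_ports). Assuming the S-parameter magnitudes lie between 0 and 1, the loss ranges from -1 (for example, all light transmitted into the vertical waveguide) to +1 (for example, all light reflected back). The helper loss_from_s_params is hypothetical.

```python
import numpy as onp


def loss_from_s_params(s_params):
    # Same formula as loss_fn: mean |s11| minus mean |s21| over wavelengths.
    s11 = onp.abs(s_params[:, 0, 0])
    s21 = onp.abs(s_params[:, 0, 1])
    return onp.mean(s11) - onp.mean(s21)


# Hypothetical S-parameters for one wavelength and one excitation port.
ideal = onp.array([[[0.0, 1.0]]], dtype=complex)      # perfect bend
reflector = onp.array([[[1.0, 0.0]]], dtype=complex)  # total back-reflection
leaky = onp.array([[[0.2, 0.5j]]], dtype=complex)     # partial transmission

ideal_loss = loss_from_s_params(ideal)
reflector_loss = loss_from_s_params(reflector)
leaky_loss = loss_from_s_params(leaky)
```

These evaluate to -1.0, 1.0 and -0.3, respectively. Because only magnitudes enter the loss, the phase of s21 (here 0.5j) does not matter.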

Optimize a design using basic gradient descent, starting from an initial design that is 0.5 everywhere, i.e. an intermediate composition between cladding and core.

design = jnp.full(waveguide_bend_model.design_variable_shape, 0.5)
learning_rate = 0.1

@jax.jit
def step_fn(design):
  (loss, (s_params, fields)), grad = jax.value_and_grad(loss_fn, has_aux=True)(design)
  design = design - learning_rate * grad
  # Clip the design so that the permittivity everywhere in the design
  # region remains between that of cladding and core materials.
  design = jnp.clip(design, 0, 1)
  return design, (loss, s_params, fields)

loss_values = []
for _ in range(100):
  design, (loss, s_params, fields) = step_fn(design)
  loss_values.append(loss)
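The update in step_fn is projected gradient descent: an ordinary gradient step followed by clipping back into [0, 1]. This sketch applies the same update to a hypothetical toy quadratic loss, so it runs without ceviche_challenges. The target values are assumptions chosen so that some lie outside [0, 1]; those entries settle at the nearest bound.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy target; -0.5 and 1.7 cannot be reached inside [0, 1].
toy_target = jnp.array([-0.5, 0.3, 0.8, 1.7])
toy_learning_rate = 0.1


def toy_loss(x):
    # Simple quadratic standing in for the photonic loss.
    return jnp.sum((x - toy_target) ** 2)


@jax.jit
def toy_step(x):
    # Gradient step, then projection onto [0, 1], mirroring step_fn.
    value, grad = jax.value_and_grad(toy_loss)(x)
    x = jnp.clip(x - toy_learning_rate * grad, 0.0, 1.0)
    return x, value


toy_x = jnp.full(toy_target.shape, 0.5)
for _ in range(100):
    toy_x, toy_value = toy_step(toy_x)
```

After 100 steps, toy_x is approximately [0.0, 0.3, 0.8, 1.0]: reachable targets are matched, and the others are held at the bounds by the clipping.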

Plot the results of the optimization.

fig, ax = plt.subplots(1, 3, figsize=(10, 3))
ax[0].plot(loss_values)
ax[0].set_xlabel("step")
ax[0].set_ylabel("loss")
ax[1].imshow(jnp.rot90(waveguide_bend_model.density(design)), cmap="gray")
ax[1].set_title("Density")
ax[1].axis("off")
ax[2].imshow(jnp.rot90(jnp.abs(fields[0, 0, :, :])), cmap="magma")
ax[2].set_title("Fields")
ax[2].axis("off")
[Figure: loss versus step (left), optimized "Density" (center), and "Fields" (right).]