doc[autograd]: improve autograd plugin readme #2565
base: develop
Conversation
**Greptile:** 1 file reviewed, 2 comments
**Example: A Simple Optimization**

```python
import autograd
autograd.grad(f)
```
There is also a `numpy` wrapper that can be similarly imported from `autograd.numpy`:

```py
import jax.numpy as jnp
jnp.sum(...)
```

becomes

```py
import autograd.numpy as anp
anp.sum(...)
```
`autograd` supports fewer features than `jax`.
For example, the `has_aux` option is not supported in the default `autograd.grad()` function, but one can write their own utilities to implement these features, as we show in the notebook examples.
We also have a `value_and_grad` function in `tidy3d.plugins.autograd.differential_operators` that is similar to `jax.value_and_grad` and supports `has_aux`.
Additionally, `autograd` has a `grad_with_aux` function that can be used to compute gradients while returning auxiliary values, similar to `jax.grad` with `has_aux`.

Otherwise, `jax` and `autograd` are very similar to each other in practice.
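A minimal sketch of what `has_aux`-style usage could look like with the `value_and_grad` helper mentioned above, assuming it follows the same calling convention as `jax.value_and_grad(..., has_aux=True)` (returning `((value, aux), gradient)`); the objective function and its auxiliary dictionary are purely illustrative:

```py
import autograd.numpy as anp

from tidy3d.plugins.autograd.differential_operators import value_and_grad

# Illustrative objective: returns the scalar to differentiate plus auxiliary data.
def objective(x):
    value = anp.sum(x**2)
    aux = {"norm": anp.sqrt(anp.sum(x**2))}  # diagnostics, not differentiated
    return value, aux

# Assumes a jax-like return convention of ((value, aux), gradient).
(val, aux), grad = value_and_grad(objective, has_aux=True)(anp.array([1.0, 2.0]))
print(val, aux["norm"], grad)
```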
### Migrating from `adjoint` plugin

Converting code from the `adjoint` plugin to the native autograd support is straightforward.

Instead of importing classes from the `tda` namespace, with names like `tda.Jax_`, we can just use regular `td.` classes.
```py
import tidy3d.plugins.adjoint as tda
tda.JaxStructure(...)
```

becomes

```py
import tidy3d as td
td.Structure(...)
```
These `td.` classes can be used directly in the differentiable objective functions.
Like before, only some fields are traceable for differentiation, and we outline the full list of supported fields in the feature roadmap below.

Furthermore, there is no need for the separated fields of `JaxSimulation`, so one can eliminate `output_monitors` and `input_structures` and put everything in `monitors` and `structures`, respectively.
`tidy3d` will automatically determine which structures and monitors are traced for differentiation.
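As a minimal sketch of that migration (the structures, monitor, source, and all numeric values below are illustrative placeholders, not taken from the README):

```py
import tidy3d as td

# Illustrative placeholder objects; geometry, materials, and values are arbitrary.
background = td.Structure(
    geometry=td.Box(center=(0, 0, 0), size=(1.0, 1.0, 0.5)),
    medium=td.Medium(permittivity=2.0),
)
traced = td.Structure(  # e.g. a structure whose geometry depends on traced parameters
    geometry=td.Box(center=(0, 0, 0), size=(0.5, 0.5, 0.22)),
    medium=td.Medium(permittivity=12.0),
)
monitor = td.FieldMonitor(center=(0, 0, 0.6), size=(1.0, 1.0, 0), freqs=[2e14], name="fields")

# adjoint plugin (deprecated): traced objects had to be split out, e.g.
#   tda.JaxSimulation(..., input_structures=[traced], output_monitors=[monitor], ...)
# native autograd: everything goes into the regular fields.
sim = td.Simulation(
    size=(2.0, 2.0, 2.0),
    run_time=1e-12,
    structures=[background, traced],
    monitors=[monitor],
    sources=[
        td.PointDipole(
            center=(0, 0, -0.6),
            source_time=td.GaussianPulse(freq0=2e14, fwidth=2e13),
            polarization="Ex",
        )
    ],
)
```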
```py
import autograd
import autograd.numpy as anp
import optax

import tidy3d as td

# 1. Function to create the simulation from parameters
def make_simulation(width):
    # ... (define sources, monitors, etc.)
    geometry = td.Box(size=(width, 0.5, 0.22))
    structure = td.Structure(geometry=geometry, medium=td.Medium(permittivity=12.0))
    sim = td.Simulation(
        # ... (simulation parameters)
        structures=[structure],
        # ...
    )
    return sim

# 2. Objective function returning a scalar
def objective_fn(width):
    sim = make_simulation(width)
    sim_data = td.web.run(sim, task_name="optimization_step")
    # Objective: maximize power in the fundamental mode
    mode_amps = sim_data["monitor_name"].amps.sel(direction="+", mode_index=0)
    return anp.sum(anp.abs(mode_amps.values)**2)

# 3. Get the value and gradient function
value_and_grad_fn = autograd.value_and_grad(objective_fn)

# 4. Optimization loop
params = anp.array([2.0])  # Initial width
optimizer = optax.adam(learning_rate=0.01)
opt_state = optimizer.init(params)

for i in range(20):
    value, gradient = value_and_grad_fn(params)
    updates, opt_state = optimizer.update(-gradient, opt_state, params)  # Use -gradient to maximize
    params = optax.apply_updates(params, updates)
    print(f"Step {i+1}: Value = {value:.4f}, Width = {params[0]:.3f}")
```
**Greptile:** logic: the optimization example uses `optax`, which is not listed as a dependency. Either add a note about `optax` being optional or show an alternative using plain `numpy` optimization.
**Reply:** good bot. we should probably mention optax somewhere as our "preferred" optimizer. I ran into this with the metalens notebook because I realized that if we introduce optax halfway through these notebooks, it could be annoying to install and come back to it.
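For reference, a minimal, optax-free sketch of the kind of plain-numpy alternative the bot suggests; the stand-in objective, learning rate, and step count are illustrative only:

```py
import autograd
import autograd.numpy as anp

# Stand-in objective so the sketch is self-contained; in the README example this
# would be the simulation-based objective_fn.
def objective_fn(width):
    return -anp.sum((width - 2.5) ** 2)

value_and_grad_fn = autograd.value_and_grad(objective_fn)

params = anp.array([2.0])
learning_rate = 0.01

for i in range(20):
    value, gradient = value_and_grad_fn(params)
    params = params + learning_rate * gradient  # plain gradient ascent, no optax needed
    print(f"Step {i+1}: Value = {value:.4f}, Width = {params[0]:.3f}")
```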
| **Geometry** | | |
| `Box` | `.center`, `.size` | Shape Optimization |
| `Cylinder` | `.center`, `.radius`, `.length` | Shape Optimization |
| `PolySlab` | `.vertices`, `.slab_bounds`, `dilation` | Shape Optimization |
**Greptile:** syntax: typo in the `PolySlab` attributes table row: `dilation` is missing its period prefix and should read `.dilation`:

| `PolySlab` | `.vertices`, `.slab_bounds`, `.dilation` | Shape Optimization |
**Reply:** good bot
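As context for the traceable geometry fields listed in the table rows above, a minimal sketch of how traced parameters could feed the `Box` and `PolySlab` fields; the parameterization, the splitting of `params`, and all numeric values are illustrative:

```py
import autograd.numpy as anp
import tidy3d as td

def make_structures(params):
    # params: 1D parameter array (traced when differentiated); splitting it into
    # a box width plus polygon vertices like this is purely illustrative
    width = params[0]
    vertices = anp.reshape(params[1:], (-1, 2))  # (N, 2) polygon vertices

    box = td.Structure(
        geometry=td.Box(center=(0, 0, 0), size=(width, 0.5, 0.22)),
        medium=td.Medium(permittivity=12.0),
    )
    slab = td.Structure(
        geometry=td.PolySlab(vertices=vertices, slab_bounds=(0.0, 0.22), axis=2),
        medium=td.Medium(permittivity=12.0),
    )
    return [box, slab]

# example call with 1 width parameter plus 3 polygon vertices
structures = make_structures(anp.array([1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0]))
```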
Diff Coverage (origin/develop...HEAD, staged and unstaged changes): no lines with coverage information in this diff.
some comments from the human reviewer (me):

- the intro makes a lot of mentions of the previous jax plugin. since we're removing it for 2.9, I suggest we just frame this document as if someone was starting from scratch: what is adjoint, autodiff, and very basically how does it work with python functions. We can perhaps remove this and add it to the "migration guide" at the end of the document (where we also link to the deprecated adjoint plugin docs?). Otherwise it might be confusing to mix up with the main point of the document.
- I actually think "basic workflow" should come before "how it works". it will be helpful to get right to the minimal code example I think. otherwise the text is a bit confusing to follow and people might give up.
- Where do I find out more about what's available in `plugins.autograd`? would be good to link?
- Perhaps some links to the specific notebook(s) that demonstrate(s) some of the specific features listed could be useful if we want to include it. Or at the very least, a link to the inverse design examples page.
- A few sharp bits I feel are worth mentioning:
  a. what is `.values`, `.data`, and when do I need to use them?
  b. `.item()` is useful (previously I was doing `np.sum(x.values)`). We should mention it?
  c. Can I autograd through `Job`, `Batch`? what about things like `web.upload()`?
- As far as adjoint plugin deprecation is concerned, maybe in a separate PR, we should probably unlink the deprecated adjoint docs from the main docs TOC. We can keep them around and link to them here, but there's no need for most users to be able to navigate there from the main Docs.
    sim_data = td.web.run(sim, task_name="optimization_step")
    # Objective: maximize power in the fundamental mode
    mode_amps = sim_data["monitor_name"].amps.sel(direction="+", mode_index=0)
    return anp.sum(anp.abs(mode_amps.values)**2)
**Reviewer:** should this be `.data` instead of `.values`?
        return anp.sum(anp.abs(mode_amps.values)**2)

    # 3. Get the value and gradient function
    value_and_grad_fn = autograd.value_and_grad(objective_fn)
**Reviewer:** should we show the example with the internal tidy3d `value_and_grad` instead of autograd's version?
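For concreteness, a sketch of what that swap could look like, assuming the tidy3d helper keeps the same `(value, gradient)` return convention as `autograd.value_and_grad`; the stand-in objective and parameters are illustrative:

```py
import autograd.numpy as anp

# Stand-in objective; in the README example this would be the simulation-based objective_fn.
def objective_fn(width):
    return anp.sum(width**2)

params = anp.array([2.0])

# autograd's version (as in the current example):
# value_and_grad_fn = autograd.value_and_grad(objective_fn)

# tidy3d's helper, which additionally supports has_aux:
from tidy3d.plugins.autograd.differential_operators import value_and_grad

value_and_grad_fn = value_and_grad(objective_fn)
value, gradient = value_and_grad_fn(params)
```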
**Reviewer:** looks good, left a few small comments/questions
Updating the autograd plugin readme in light of the adjoint plugin deprecation (the transition guide has moved to the bottom, since the adjoint plugin should not be the first thing people see here anymore). Also did a bunch of other formatting and general style/content updates to hopefully make it more digestible and comprehensive at the same time.
Greptile Summary: Comprehensive update to `tidy3d/plugins/autograd/README.md` to establish autograd as the primary automatic differentiation solution, replacing the deprecated adjoint plugin.