from firedrake import *
from firedrake.adjoint import *
from firedrake.adjoint_utils.blocks.solving import SolveVarFormBlock
from checkpoint_schedules import MultistageCheckpointSchedule
import itertools
N = 100
mesh = UnitIntervalMesh(1)
space = FunctionSpace(mesh, "Lagrange", 1)
test = TestFunction(space)
trial = TrialFunction(space)
tape = get_working_tape()
tape.enable_checkpointing(MultistageCheckpointSchedule(N, 3, 0))
u = Function(space, name="u").interpolate(Constant(2.0))
continue_annotation()
for _ in tape.timestepper(iter(range(N))):
    u_ = Function(space)
    solve(inner(trial, test) * dx == inner(test, u + u) * dx, u_)
    u = u_
    del u_
pause_annotation()
del u
deps = set()
for block in tape._blocks:
    if isinstance(block, SolveVarFormBlock):
        for dep in itertools.chain(ufl.algorithms.extract_coefficients(block.lhs),
                                   ufl.algorithms.extract_coefficients(block.rhs)):
            deps.add(dep.count())
print(f"{len(deps)=}")
leads to output
len(deps)=100
Describe the bug
Firedrake `Block` subclasses reference variables via UFL expressions. This can prevent memory usage being reduced by checkpointing. Firedrake level version of dolfin-adjoint/pyadjoint#169.
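The mechanism can be illustrated without Firedrake. The following is only a toy sketch in plain Python: `Field`, `Expr`, `Block`, and `count_live_fields` are hypothetical stand-ins invented for this illustration, not Firedrake, UFL, or pyadjoint classes. It assumes that a block stores the expression it was built from and that the expression stores its operands, so each block keeps one old solution object alive even after the time loop has dropped its own reference.

```python
import gc
import weakref


class Field:
    """Hypothetical stand-in for a function that owns a large array of values."""

    def __init__(self, n):
        self.values = [0.0] * n


class Expr:
    """Hypothetical stand-in for a symbolic expression that stores its operands."""

    def __init__(self, *operands):
        self.operands = operands


class Block:
    """Hypothetical stand-in for a tape block that keeps the form it was built from."""

    def __init__(self, rhs):
        self.rhs = rhs


def count_live_fields(n_steps, keep_forms):
    """Run a toy time loop and count the Field objects that are still alive afterwards."""
    blocks = []  # plays the role of the tape
    u = Field(10)
    refs = [weakref.ref(u)]  # weak references do not keep the fields alive
    for _ in range(n_steps):
        rhs = Expr(u, u)  # analogous to `u + u` on the right-hand side
        u_new = Field(10)
        refs.append(weakref.ref(u_new))
        # The block either stores the expression (and hence `u`) or nothing.
        blocks.append(Block(rhs if keep_forms else None))
        u = u_new
        del u_new, rhs
    del u
    gc.collect()
    # `blocks` is still in scope here, so anything it references is counted as alive.
    return sum(ref() is not None for ref in refs)


if __name__ == "__main__":
    print(count_live_fields(100, keep_forms=True))
    print(count_live_fields(100, keep_forms=False))
```

In this sketch, storing the expression on each block keeps one field per step reachable from the list of blocks, whereas dropping the expression lets every field be freed once the loop moves on. Whether and how Firedrake blocks could avoid holding such references is the open question of this issue; the sketch does not propose a fix.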
Steps to Reproduce
Run the script at the top of this issue, which leads to the output len(deps)=100 shown there.