Skip to content

Commit

Permalink
Fix MixedCheckpointSchedule (#180)
Browse files Browse the repository at this point in the history
* Test if checkpoints are cleaned correctly for multistep
  • Loading branch information
Ig-dolci authored Nov 27, 2024
1 parent 5f46e16 commit 16e6434
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 38 deletions.
90 changes: 60 additions & 30 deletions pyadjoint/checkpointing.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from enum import Enum
import sys
from functools import singledispatchmethod
from checkpoint_schedules import Copy, Move, EndForward, EndReverse, Forward, Reverse, StorageType
from checkpoint_schedules import Copy, Move, EndForward, EndReverse, \
Forward, Reverse, StorageType, SingleMemoryStorageSchedule
# A callback interface allowing the user to provide a
# custom error message when disk checkpointing is not configured.
disk_checkpointing_callback = {}
Expand Down Expand Up @@ -78,6 +79,8 @@ def __init__(self, schedule, tape):
# Tell the tape to only checkpoint input data until told otherwise.
self.tape.latest_checkpoint = 0
self.end_timestep(-1)
self._keep_init_state_in_work = False
self._adj_deps_cleaned = False

def end_timestep(self, timestep):
"""Mark the end of one timestep when taping the forward model.
Expand Down Expand Up @@ -299,25 +302,50 @@ def _(self, cp_action, progress_bar, functional=None, **kwargs):
current_step.checkpoint(
_store_checkpointable_state, _store_adj_dependencies)

if (
(cp_action.write_adj_deps and cp_action.storage != StorageType.WORK)
or not cp_action.write_adj_deps
):
to_keep = set()
if step < (self.total_timesteps - 1):
next_step = self.tape.timesteps[step + 1]
# The checkpointable state set of the current step.
to_keep = next_step.checkpointable_state
if functional:
to_keep = to_keep.union([functional.block_variable])
for block in current_step:
# Remove unnecessary variables from previous steps.
for bv in block.get_outputs():
to_keep = set()
if step < (self.total_timesteps - 1):
next_step = self.tape.timesteps[step + 1]
# The checkpointable state set of the current step.
to_keep = next_step.checkpointable_state
if functional:
to_keep = to_keep.union([functional.block_variable])

for var in current_step.checkpointable_state - to_keep:
# Handle the case where step is 0
if step == 0 and var not in current_step._checkpoint:
# Ensure initialisation state is kept.
self._keep_init_state_in_work = True
break

# Handle the case for SingleMemoryStorageSchedule
if isinstance(self._schedule, SingleMemoryStorageSchedule):
if step > 1 and var not in self.tape.timesteps[step - 1].adjoint_dependencies:
var._checkpoint = None
continue

# Handle variables in the initial timestep
if (
var in self.tape.timesteps[0].checkpointable_state
and self._keep_init_state_in_work
):
continue

# Clear the checkpoint for other cases
var._checkpoint = None

for block in current_step:
# Remove unnecessary variables from previous steps.
for bv in block.get_outputs():
if (
(cp_action.write_adj_deps and cp_action.storage != StorageType.WORK)
or not cp_action.write_adj_deps
):
if bv not in to_keep:
bv._checkpoint = None
# Remove unnecessary variables from previous steps.
for var in (current_step.checkpointable_state - to_keep):
var._checkpoint = None
else:
if bv not in current_step.adjoint_dependencies.union(to_keep):
bv._checkpoint = None

step += 1
if cp_action.storage == StorageType.DISK:
# Activate disk checkpointing only in the checkpointing process.
Expand All @@ -333,22 +361,24 @@ def _(self, cp_action, progress_bar, markings, functional=None, **kwargs):
current_step = self.tape.timesteps[step]
for block in reversed(current_step):
block.evaluate_adj(markings=markings)
if not self._adj_deps_cleaned:
for out in block._outputs:
if not out.marked_in_path:
current_step.adjoint_dependencies.discard(out)
self._adj_deps_cleaned = True
# Output variables are used for the last time when running
# backwards.
to_keep = current_step.checkpointable_state
if functional:
to_keep = to_keep.union([functional.block_variable])
for block in current_step:
block.reset_adjoint_state()
for var in block.get_outputs():
var.checkpoint = None
var.reset_variables(("tlm",))
if not var.is_control:
var.reset_variables(("adjoint", "hessian"))
if cp_action.clear_adj_deps:
to_keep = current_step.checkpointable_state
if functional:
to_keep = to_keep.union([functional.block_variable])
for output in block.get_outputs():
if output not in to_keep:
output._checkpoint = None
for out in block.get_outputs():
out.reset_variables(("tlm",))
if not out.is_control:
out.reset_variables(("adjoint", "hessian"))
if cp_action.clear_adj_deps and out not in to_keep:
out._checkpoint = None

@process_operation.register(Copy)
def _(self, cp_action, progress_bar, **kwargs):
Expand Down
74 changes: 70 additions & 4 deletions tests/firedrake_adjoint/test_burgers_newton.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
set_log_level(CRITICAL)
continue_annotation()


def basics():
n = 30
mesh = UnitIntervalMesh(n)
Expand All @@ -21,13 +22,67 @@ def basics():
steps = int(end/float(timestep)) + 1
return mesh, timestep, steps


def Dt(u, u_, timestep):
return (u - u_)/timestep


def _check_forward(tape):
for current_step in tape.timesteps[1:-1]:
for block in current_step:
for deps in block.get_dependencies():
if (
deps not in tape.timesteps[0].checkpointable_state
and deps not in tape.timesteps[-1].checkpointable_state
):
assert deps._checkpoint is None
for out in block.get_outputs():
if out not in tape.timesteps[-1].checkpointable_state:
assert out._checkpoint is None


def _check_recompute(tape):
for current_step in tape.timesteps[1:-1]:
for block in current_step:
for deps in block.get_dependencies():
if deps not in tape.timesteps[0].checkpointable_state:
assert deps._checkpoint is None
for out in block.get_outputs():
assert out._checkpoint is None

for block in tape.timesteps[0]:
for out in block.get_outputs():
assert out._checkpoint is None
for block in tape.timesteps[len(tape.timesteps)-1]:
for deps in block.get_dependencies():
if (
deps not in tape.timesteps[0].checkpointable_state
and deps not in tape.timesteps[len(tape.timesteps)-1].adjoint_dependencies
):
assert deps._checkpoint is None


def _check_reverse(tape):
for step, current_step in enumerate(tape.timesteps):
if step > 0:
for block in current_step:
for deps in block.get_dependencies():
if deps not in tape.timesteps[0].checkpointable_state:
assert deps._checkpoint is None

for out in block.get_outputs():
assert out._checkpoint is None
assert out.adj_value is None

for block in current_step:
for out in block.get_outputs():
assert out._checkpoint is None


def J(ic, solve_type, timestep, steps, V):
u_ = Function(V)
u = Function(V)

u_ = Function(V, name="u_")
u = Function(V, name="u")
v = TestFunction(V)
u_.assign(ic)
nu = Constant(0.0001)
Expand Down Expand Up @@ -84,17 +139,28 @@ def test_burgers_newton(solve_type, checkpointing):
mesh = checkpointable_mesh(mesh)
x, = SpatialCoordinate(mesh)
V = FunctionSpace(mesh, "CG", 2)
ic = project(sin(2. * pi * x), V)
ic = project(sin(2. * pi * x), V, name="ic")
val = J(ic, solve_type, timestep, steps, V)
if checkpointing:
assert len(tape.timesteps) == steps
if checkpointing == "Revolve" or checkpointing == "Mixed":
_check_forward(tape)

Jhat = ReducedFunctional(val, Control(ic))
if checkpointing != "NoneAdjoint":
dJ = Jhat.derivative()
if checkpointing is not None:
# Check if the reverse checkpointing is working correctly.
if checkpointing == "Revolve" or checkpointing == "Mixed":
_check_reverse(tape)

# Recomputing the functional with a modified control variable
# before the recompute test.
Jhat(project(sin(pi*x), V))
if checkpointing:
# Check is the checkpointing is working correctly.
if checkpointing == "Revolve" or checkpointing == "Mixed":
_check_recompute(tape)

# Recompute test
assert(np.allclose(Jhat(ic), val))
Expand Down Expand Up @@ -143,4 +209,4 @@ def test_checkpointing_validity(solve_type, checkpointing):
Jhat = ReducedFunctional(val1, Control(ic))
assert len(tape.timesteps) == steps
assert np.allclose(val0, val1)
assert np.allclose(dJ0.dat.data_ro[:], Jhat.derivative().dat.data_ro[:])
assert np.allclose(dJ0.dat.data_ro[:], Jhat.derivative().dat.data_ro[:])
11 changes: 8 additions & 3 deletions tests/firedrake_adjoint/test_checkpointing_multistep.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

from firedrake import *
from firedrake.adjoint import *
from checkpoint_schedules import Revolve
from tests.firedrake_adjoint.test_burgers_newton import _check_forward, \
_check_recompute, _check_reverse
from checkpoint_schedules import MixedCheckpointSchedule, StorageType
import numpy as np
from collections import deque
continue_annotation()
Expand Down Expand Up @@ -42,15 +44,18 @@ def J(displacement_0):
def test_multisteps():
tape = get_working_tape()
tape.progress_bar = ProgressBar
tape.enable_checkpointing(Revolve(total_steps, 2))
tape.enable_checkpointing(MixedCheckpointSchedule(total_steps, 2, storage=StorageType.RAM))
displacement_0 = Function(V).assign(1.0)
val = J(displacement_0)
_check_forward(tape)
c = Control(displacement_0)
J_hat = ReducedFunctional(val, c)
dJ = J_hat.derivative()
_check_reverse(tape)
# Recomputing the functional with a modified control variable
# before the recompute test.
J_hat(Function(V).assign(0.5))
_check_recompute(tape)
# Recompute test
assert(np.allclose(J_hat(displacement_0), val))
# Test recompute adjoint-based gradient
Expand All @@ -70,7 +75,7 @@ def test_validity():
tape.clear_tape()

# With checkpointing.
tape.enable_checkpointing(Revolve(total_steps, 2))
tape.enable_checkpointing(MixedCheckpointSchedule(total_steps, 2, storage=StorageType.RAM))
val = J(displacement_0)
J_hat = ReducedFunctional(val, Control(displacement_0))
dJ = J_hat.derivative()
Expand Down
2 changes: 1 addition & 1 deletion tests/firedrake_adjoint/test_disk_checkpointing.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,5 +87,5 @@ def test_disk_checkpointing_error():
# check the raise of the exception
with pytest.raises(RuntimeError):
tape.enable_checkpointing(SingleDiskStorageSchedule())
assert disk_checkpointing_callback["firedrake"] == "Please call enable_disk_checkpointing() "\
assert disk_checkpointing_callback["firedrake"] == "Please call enable_disk_checkpointing() "\
"before checkpointing on the disk."

0 comments on commit 16e6434

Please sign in to comment.