From 1a25a5d5f0317ac2a530d388e697a234b3a2e1dd Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Tue, 25 Jul 2023 14:52:01 +0100 Subject: [PATCH 01/10] start tidying indexing for replace_* fml functions --- gusto/fml/replacement.py | 147 ++++++++++++++----- unit-tests/fml_tests/test_replace_subject.py | 6 +- 2 files changed, 112 insertions(+), 41 deletions(-) diff --git a/gusto/fml/replacement.py b/gusto/fml/replacement.py index 0280abc31..0e2029bee 100644 --- a/gusto/fml/replacement.py +++ b/gusto/fml/replacement.py @@ -13,7 +13,7 @@ # ---------------------------------------------------------------------------- # # A general routine for building the replacement dictionary # ---------------------------------------------------------------------------- # -def _replace_dict(old, new, idx, replace_type): +def _replace_dict(old, new, old_idx, new_idx, replace_type): """ Build a dictionary to pass to the ufl.replace routine The dictionary matches variables in the old term with those in the new @@ -21,51 +21,116 @@ def _replace_dict(old, new, idx, replace_type): Does not check types unless indexing is required (leave type-checking to ufl.replace) """ + mixed_old = type(old.ufl_element()) is MixedElement + mixed_new = hasattr(new, "ufl_element") and type(new.ufl_element()) is MixedElement + + indexable_old = mixed_old + indexable_new = mixed_new or type(new) is tuple + + # check indices arguments are valid + if not indexable_old and old_idx is not None: + raise ValueError(f"old_idx should not be specified to replace_{replace_type}" + + f" when replaced {replace_type} of type {old} is not mixed.") + + if not indexable_new and new_idx is not None: + raise ValueError(f"new_idx should not be specified to replace_{replace_type} when" + + f" new {replace_type} of type {new} is not mixed or indexable.") + + if indexable_old and not indexable_new: + if old_idx is None: + raise ValueError(f"old_idx must be specified to replace_{replace_type} when replaced" + + f" {replace_type} of type {old} is mixed and new {replace_type}" + + f" of type {new} is not mixed or indexable.") + + if indexable_new and not indexable_old: + if new_idx is None: + raise ValueError(f"new_idx must be specified to replace_{replace_type} when new" + + f" {replace_type} of type {new} is mixed or indexable and" + + f" old {replace_type} of type {old} is not mixed.") + + if indexable_old and indexable_new: + # must be both True or both False + if old_idx ^ new_idx: + raise ValueError(f"both or neither old_idx and new_idx must be specified to" + + f" replace_{replace_type} when old {replace_type} of type" + + f" {old} is mixed and new {replace_type} of type {new} is" + + f" mixed or indexable.") + if old_idx is None: # both indexes are none + if len(old) != len(new): + raise ValueError(f"if neither index is specified to replace_{replace_type}" + + f" and both old {replace_type} of type {old} and new" + + f" {replace_type} of type {new} are mixed or indexable" + + f" then old and new must be the same length.") + + # make the replace_dict + + if mixed_old: + split_old = split(old) + if indexable_new: + split_new = new if type(new) is tuple else split(new) + replace_dict = {} - if type(old.ufl_element()) is MixedElement: + # flat + if not indexable_old and not indexable_new: + replace_dict[old] = new + + elif not indexable_old and indexable_new: + replace_dict[old] = split_new[new_idx] + + elif indexable_old and not indexable_new: + replace_dict[split_old[old_idx]] = new + + elif indexable_old and indexable_new: + if old_idx is not None: + replace_dict[split_old[old_idx]] = split_new[new_idx] + else: # idxs are none + for k, v in zip(split_old, split_new): + replace_dict[k] = v + + # if type(old.ufl_element()) is MixedElement: - mixed_new = hasattr(new, "ufl_element") and type(new.ufl_element()) is MixedElement - indexable_new = type(new) is tuple or mixed_new + # mixed_new = hasattr(new, "ufl_element") and type(new.ufl_element()) is MixedElement + # indexable_new = type(new) is tuple or mixed_new - if indexable_new: - split_new = new if type(new) is tuple else split(new) + # if indexable_new: + # split_new = new if type(new) is tuple else split(new) - if len(split_new) != len(old.function_space()): - raise ValueError(f"new {replace_type} of type {new} must be same length" - + f"as replaced mixed {replace_type} of type {old}") + # if len(split_new) != len(old.function_space()): + # raise ValueError(f"new {replace_type} of type {new} must be same length" + # + f"as replaced mixed {replace_type} of type {old}") - if idx is None: - for k, v in zip(split(old), split_new): - replace_dict[k] = v - else: - replace_dict[split(old)[idx]] = split_new[idx] + # if idx is None: + # for k, v in zip(split(old), split_new): + # replace_dict[k] = v + # else: + # replace_dict[split(old)[idx]] = split_new[idx] - else: # new is not indexable - if idx is None: - raise ValueError(f"idx must be specified to replace_{replace_type} when" - + f" replaced {replace_type} of type {old} is mixed and" - + f" new {replace_type} of type {new} is a single component") + # else: # new is not indexable + # if idx is None: + # raise ValueError(f"idx must be specified to replace_{replace_type} when" + # + f" replaced {replace_type} of type {old} is mixed and" + # + f" new {replace_type} of type {new} is a single component") - replace_dict[split(old)[idx]] = new + # replace_dict[split(old)[idx]] = new - else: # old is not mixed + # else: # old is not mixed - mixed_new = hasattr(new, "ufl_element") and type(new.ufl_element()) is MixedElement - indexable_new = type(new) is tuple or mixed_new + # mixed_new = hasattr(new, "ufl_element") and type(new.ufl_element()) is MixedElement + # indexable_new = type(new) is tuple or mixed_new - if indexable_new: - split_new = new if type(new) is tuple else split(new) + # if indexable_new: + # split_new = new if type(new) is tuple else split(new) - if idx is None: - raise ValueError(f"idx must be specified to replace_{replace_type} when" - + f" replaced {replace_type} of type {old} is not mixed" - + f" and new {replace_type} of type {new} is indexable") + # if idx is None: + # raise ValueError(f"idx must be specified to replace_{replace_type} when" + # + f" replaced {replace_type} of type {old} is not mixed" + # + f" and new {replace_type} of type {new} is indexable") - replace_dict[old] = split_new[idx] + # replace_dict[old] = split_new[idx] - else: - replace_dict[old] = new + # else: + # replace_dict[old] = new return replace_dict @@ -73,7 +138,7 @@ def _replace_dict(old, new, idx, replace_type): # ---------------------------------------------------------------------------- # # Replacement routines # ---------------------------------------------------------------------------- # -def replace_test_function(new_test, idx=None): +def replace_test_function(new_test, new_idx=None): """ A routine to replace the test function in a term with a new test function. @@ -97,7 +162,9 @@ def repl(t): :class:`Term`: the new term. """ old_test = t.form.arguments()[0] - replace_dict = _replace_dict(old_test, new_test, idx, 'test') + replace_dict = _replace_dict(old_test, new_test, + old_idx=None, new_idx=new_idx, + replace_type='test') try: new_form = ufl.replace(t.form, replace_dict) @@ -111,7 +178,7 @@ def repl(t): return repl -def replace_trial_function(new_trial, idx=None): +def replace_trial_function(new_trial, new_idx=None): """ A routine to replace the trial function in a term with a new expression. @@ -140,7 +207,9 @@ def repl(t): if len(t.form.arguments()) != 2: raise TypeError('Trying to replace trial function of a form that is not linear') old_trial = t.form.arguments()[1] - replace_dict = _replace_dict(old_trial, new_trial, idx, 'trial') + replace_dict = _replace_dict(old_trial, new_trial, + old_idx=None, new_idx=new_idx, + replace_type='trial') try: new_form = ufl.replace(t.form, replace_dict) @@ -170,7 +239,7 @@ def repl(t): return repl -def replace_subject(new_subj, idx=None): +def replace_subject(new_subj, old_idx=None, new_idx=None): """ A routine to replace the subject in a term with a new variable. @@ -196,7 +265,9 @@ def repl(t): """ old_subj = t.get(subject) - replace_dict = _replace_dict(old_subj, new_subj, idx, 'subject') + replace_dict = _replace_dict(old_subj, new_subj, + old_idx=old_idx, new_idx=new_idx, + replace_type='subject') try: new_form = ufl.replace(t.form, replace_dict) diff --git a/unit-tests/fml_tests/test_replace_subject.py b/unit-tests/fml_tests/test_replace_subject.py index 4b1ade0f1..baaa8f1f3 100644 --- a/unit-tests/fml_tests/test_replace_subject.py +++ b/unit-tests/fml_tests/test_replace_subject.py @@ -5,7 +5,7 @@ from firedrake import (UnitSquareMesh, FunctionSpace, Function, TestFunction, VectorFunctionSpace, MixedFunctionSpace, dx, inner, TrialFunctions, TrialFunction, split) -from gusto.fml import (Label subject, replace_subject, replace_test_function, +from gusto.fml import (Label, subject, replace_subject, replace_test_function, replace_trial_function) import pytest @@ -28,7 +28,7 @@ def test_replace_subject(subject_type, replacement_type, function_or_indexed, re # only makes sense to replace a vector with a vector if (subject_type == 'vector') ^ (replacement_type == 'vector'): - pytest.skip("invalid option combination") + pytest.skip("Invalid vector option combination") # ------------------------------------------------------------------------ # # Set up @@ -39,7 +39,7 @@ def test_replace_subject(subject_type, replacement_type, function_or_indexed, re bar_label = Label("bar") # Create mesh, function space and forms - n = 3 + n = 2 mesh = UnitSquareMesh(n, n) V0 = FunctionSpace(mesh, "DG", 0) V1 = FunctionSpace(mesh, "CG", 1) From 2d2fe98a2224187821a5a9d6569ee3aed180649a Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Tue, 25 Jul 2023 15:30:34 +0100 Subject: [PATCH 02/10] correct replace if statements --- gusto/fml/replacement.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/gusto/fml/replacement.py b/gusto/fml/replacement.py index 0e2029bee..24cb2d99d 100644 --- a/gusto/fml/replacement.py +++ b/gusto/fml/replacement.py @@ -50,7 +50,7 @@ def _replace_dict(old, new, old_idx, new_idx, replace_type): if indexable_old and indexable_new: # must be both True or both False - if old_idx ^ new_idx: + if (old_idx is None) ^ (new_idx is None): raise ValueError(f"both or neither old_idx and new_idx must be specified to" + f" replace_{replace_type} when old {replace_type} of type" + f" {old} is mixed and new {replace_type} of type {new} is" @@ -71,7 +71,6 @@ def _replace_dict(old, new, old_idx, new_idx, replace_type): replace_dict = {} - # flat if not indexable_old and not indexable_new: replace_dict[old] = new @@ -82,11 +81,11 @@ def _replace_dict(old, new, old_idx, new_idx, replace_type): replace_dict[split_old[old_idx]] = new elif indexable_old and indexable_new: - if old_idx is not None: - replace_dict[split_old[old_idx]] = split_new[new_idx] - else: # idxs are none + if old_idx is None: # replace everything for k, v in zip(split_old, split_new): replace_dict[k] = v + else: # idxs are given + replace_dict[split_old[old_idx]] = split_new[new_idx] # if type(old.ufl_element()) is MixedElement: From d1de427a0375ca9c664f629e0b28842a18116e84 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Tue, 25 Jul 2023 17:09:25 +0100 Subject: [PATCH 03/10] parameterized replace test --- unit-tests/fml_tests/test_replace_subject.py | 192 ++++++++++++++++++- 1 file changed, 188 insertions(+), 4 deletions(-) diff --git a/unit-tests/fml_tests/test_replace_subject.py b/unit-tests/fml_tests/test_replace_subject.py index baaa8f1f3..a2865429a 100644 --- a/unit-tests/fml_tests/test_replace_subject.py +++ b/unit-tests/fml_tests/test_replace_subject.py @@ -2,13 +2,197 @@ Tests the replace_subject routine from labels.py """ -from firedrake import (UnitSquareMesh, FunctionSpace, Function, TestFunction, +from firedrake import (UnitSquareMesh, FunctionSpace, Function, TestFunction, TestFunctions, VectorFunctionSpace, MixedFunctionSpace, dx, inner, - TrialFunctions, TrialFunction, split) + TrialFunctions, TrialFunction, split, grad) from gusto.fml import (Label, subject, replace_subject, replace_test_function, - replace_trial_function) + replace_trial_function, drop) import pytest +from collections import namedtuple + +ReplaceArgs = namedtuple("ReplaceArgs", "subject idxs error") + +# some dummy labels +foo_label = Label("foo") +bar_label = Label("bar") + +nx = 2 +mesh = UnitSquareMesh(nx, nx) + +V0 = FunctionSpace(mesh, 'CG', 1) +V1 = FunctionSpace(mesh, 'DG', 1) + +W = V0*V1 + +subj = Function(V0) +v = TestFunction(V0) + +term1 = foo_label(subject(subj*v*dx, subj)) +term2 = bar_label(inner(grad(subj), grad(v))*dx) + +labelled_form = term1 + term2 + +argsets = [ + ReplaceArgs(Function(V0), {}, None), + ReplaceArgs(Function(V0), {'new_idx': 0}, ValueError), + ReplaceArgs(Function(V0), {'old_idx': 0}, ValueError), + ReplaceArgs(Function(W), {'new_idx': 0}, None), + ReplaceArgs(Function(W), {'new_idx': 1}, None), + ReplaceArgs(Function(W), {'old_idx': 0}, ValueError), + ReplaceArgs(Function(W), {'new_idx': 7}, IndexError), +] + + +@pytest.mark.parametrize('argset', argsets) +def test_replace_subject_params(argset): + arg = argset.subject + idxs = argset.idxs + error = argset.error + + if error is None: + new_form = labelled_form.label_map( + lambda t: t.has_label(foo_label), + map_if_true=replace_subject(arg, **idxs), + map_if_false=drop) + assert arg == new_form.form.coefficients()[0] + assert subj not in new_form.form.coefficients() + + else: + with pytest.raises(error): + new_form = labelled_form.label_map( + lambda t: t.has_label(foo_label), + map_if_true=replace_subject(arg, **idxs)) + + +def test_replace_subject_primal(): + # setup some basic labels + foo_label = Label("foo") + bar_label = Label("bar") + + # setup the mesh and function space + n = 2 + mesh = UnitSquareMesh(n, n) + V = FunctionSpace(mesh, "CG", 1) + + # set up the form + u = Function(V) + v = TestFunction(V) + + form1 = inner(u, v)*dx + form2 = inner(grad(u), grad(v))*dx + + term1 = foo_label(subject(form1, u)) + term2 = bar_label(form2) + + labelled_form = term1 + term2 + + # replace with another function + w = Function(V) + + # this should work + new_form = labelled_form.label_map( + lambda t: t.has_label(foo_label), + map_if_true=replace_subject(w)) + + # these should fail if given an index + with pytest.raises(ValueError): + new_form = labelled_form.label_map( + lambda t: t.has_label(foo_label), + map_if_true=replace_subject(w, new_idx=0)) + + with pytest.raises(ValueError): + new_form = labelled_form.label_map( + lambda t: t.has_label(foo_label), + map_if_true=replace_subject(w, old_idx=0)) + + with pytest.raises(ValueError): + new_form = labelled_form.label_map( + lambda t: t.has_label(foo_label), + map_if_true=replace_subject(w, old_idx=0, new_idx=0)) + + # replace with mixed component + wm = Function(V*V) + wms = split(wm) + wm0, wm1 = wms + + new_form = labelled_form.label_map( + lambda t: t.has_label(foo_label), + map_if_true=replace_subject(wm0)) + + new_form = labelled_form.label_map( + lambda t: t.has_label(foo_label), + map_if_true=replace_subject(wms, new_idx=0)) + + +def test_replace_subject_mixed(): + # setup some basic labels + foo_label = Label("foo") + bar_label = Label("bar") + + # setup the mesh and function space + n = 2 + mesh = UnitSquareMesh(n, n) + V0 = FunctionSpace(mesh, "CG", 1) + V1 = FunctionSpace(mesh, "DG", 1) + W = V0*V1 + + # set up the form + u = Function(W) + u0, u1 = split(u) + v0, v1 = TestFunctions(W) + + form1 = inner(u0, v0)*dx + form2 = inner(grad(u1), grad(v1))*dx + + term1 = foo_label(subject(form1, u)) + term2 = bar_label(form2) + + labelled_form = term1 + term2 + + # replace with another function + w = Function(W) + + # replace all parts of the subject + new_form = labelled_form.label_map( + lambda t: t.has_label(foo_label), + map_if_true=replace_subject(w)) + + # replace either part of the subject + new_form = labelled_form.label_map( + lambda t: t.has_label(foo_label), + map_if_true=replace_subject(w, old_idx=0, new_idx=0)) + + new_form = labelled_form.label_map( + lambda t: t.has_label(foo_label), + map_if_true=replace_subject(w, old_idx=1, new_idx=1)) + + # these should fail if given only one index + with pytest.raises(ValueError): + new_form = labelled_form.label_map( + lambda t: t.has_label(foo_label), + map_if_true=replace_subject(w, old_idx=1)) + + with pytest.raises(ValueError): + new_form = labelled_form.label_map( + lambda t: t.has_label(foo_label), + map_if_true=replace_subject(w, new_idx=1)) + + # try indexing only one + w0, w1 = split(w) + + # replace a specific part of the subject + new_form = labelled_form.label_map( + lambda t: t.has_label(foo_label), + map_if_true=replace_subject(w1, old_idx=0)) + + # replace with something from a primal space + wp = Function(V0) + new_form = labelled_form.label_map( + lambda t: t.has_label(foo_label), + map_if_true=replace_subject(wp, old_idx=1)) + + replace_funcs = [ pytest.param((Function, replace_subject), id="replace_subj"), pytest.param((TestFunction, replace_test_function), id="replace_test"), @@ -20,7 +204,7 @@ @pytest.mark.parametrize('replacement_type', ['normal', 'mixed', 'vector', 'tuple']) @pytest.mark.parametrize('function_or_indexed', ['function', 'indexed']) @pytest.mark.parametrize('replace_func', replace_funcs) -def test_replace_subject(subject_type, replacement_type, function_or_indexed, replace_func): +def old_test_replace_subject(subject_type, replacement_type, function_or_indexed, replace_func): # ------------------------------------------------------------------------ # # Only certain combinations of options are valid From 9ac02e27b3b2b2255c2e81d640de7afd66eecfb3 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Wed, 26 Jul 2023 09:58:26 +0100 Subject: [PATCH 04/10] remove old _replace_dict method and flake8 --- gusto/fml/replacement.py | 56 +++++----------------------------------- 1 file changed, 6 insertions(+), 50 deletions(-) diff --git a/gusto/fml/replacement.py b/gusto/fml/replacement.py index 24cb2d99d..bb12f8d3e 100644 --- a/gusto/fml/replacement.py +++ b/gusto/fml/replacement.py @@ -51,16 +51,16 @@ def _replace_dict(old, new, old_idx, new_idx, replace_type): if indexable_old and indexable_new: # must be both True or both False if (old_idx is None) ^ (new_idx is None): - raise ValueError(f"both or neither old_idx and new_idx must be specified to" + raise ValueError("both or neither old_idx and new_idx must be specified to" + f" replace_{replace_type} when old {replace_type} of type" + f" {old} is mixed and new {replace_type} of type {new} is" - + f" mixed or indexable.") - if old_idx is None: # both indexes are none + + " mixed or indexable.") + if old_idx is None: # both indexes are none if len(old) != len(new): raise ValueError(f"if neither index is specified to replace_{replace_type}" + f" and both old {replace_type} of type {old} and new" + f" {replace_type} of type {new} are mixed or indexable" - + f" then old and new must be the same length.") + + " then old and new must be the same length.") # make the replace_dict @@ -81,56 +81,12 @@ def _replace_dict(old, new, old_idx, new_idx, replace_type): replace_dict[split_old[old_idx]] = new elif indexable_old and indexable_new: - if old_idx is None: # replace everything + if old_idx is None: # replace everything for k, v in zip(split_old, split_new): replace_dict[k] = v - else: # idxs are given + else: # idxs are given replace_dict[split_old[old_idx]] = split_new[new_idx] - # if type(old.ufl_element()) is MixedElement: - - # mixed_new = hasattr(new, "ufl_element") and type(new.ufl_element()) is MixedElement - # indexable_new = type(new) is tuple or mixed_new - - # if indexable_new: - # split_new = new if type(new) is tuple else split(new) - - # if len(split_new) != len(old.function_space()): - # raise ValueError(f"new {replace_type} of type {new} must be same length" - # + f"as replaced mixed {replace_type} of type {old}") - - # if idx is None: - # for k, v in zip(split(old), split_new): - # replace_dict[k] = v - # else: - # replace_dict[split(old)[idx]] = split_new[idx] - - # else: # new is not indexable - # if idx is None: - # raise ValueError(f"idx must be specified to replace_{replace_type} when" - # + f" replaced {replace_type} of type {old} is mixed and" - # + f" new {replace_type} of type {new} is a single component") - - # replace_dict[split(old)[idx]] = new - - # else: # old is not mixed - - # mixed_new = hasattr(new, "ufl_element") and type(new.ufl_element()) is MixedElement - # indexable_new = type(new) is tuple or mixed_new - - # if indexable_new: - # split_new = new if type(new) is tuple else split(new) - - # if idx is None: - # raise ValueError(f"idx must be specified to replace_{replace_type} when" - # + f" replaced {replace_type} of type {old} is not mixed" - # + f" and new {replace_type} of type {new} is indexable") - - # replace_dict[old] = split_new[idx] - - # else: - # replace_dict[old] = new - return replace_dict From b3b0f4accf570643dc2d05efd33c820788d8b73e Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Wed, 26 Jul 2023 09:58:41 +0100 Subject: [PATCH 05/10] new replace_subject tests --- unit-tests/fml_tests/test_replace_subject.py | 385 ++++++------------- 1 file changed, 128 insertions(+), 257 deletions(-) diff --git a/unit-tests/fml_tests/test_replace_subject.py b/unit-tests/fml_tests/test_replace_subject.py index a2865429a..57d4feaac 100644 --- a/unit-tests/fml_tests/test_replace_subject.py +++ b/unit-tests/fml_tests/test_replace_subject.py @@ -2,16 +2,14 @@ Tests the replace_subject routine from labels.py """ -from firedrake import (UnitSquareMesh, FunctionSpace, Function, TestFunction, TestFunctions, - VectorFunctionSpace, MixedFunctionSpace, dx, inner, - TrialFunctions, TrialFunction, split, grad) -from gusto.fml import (Label, subject, replace_subject, replace_test_function, - replace_trial_function, drop) +from firedrake import (UnitSquareMesh, FunctionSpace, Function, TestFunction, + VectorFunctionSpace, dx, inner, split, grad) +from gusto.fml import (Label, subject, replace_subject, drop) import pytest from collections import namedtuple -ReplaceArgs = namedtuple("ReplaceArgs", "subject idxs error") +ReplaceArgs = namedtuple("ReplaceArgs", "new_subj idxs error") # some dummy labels foo_label = Label("foo") @@ -19,299 +17,172 @@ nx = 2 mesh = UnitSquareMesh(nx, nx) - V0 = FunctionSpace(mesh, 'CG', 1) V1 = FunctionSpace(mesh, 'DG', 1) - W = V0*V1 +Vv = VectorFunctionSpace(mesh, 'CG', 1) +Wv = Vv*V1 -subj = Function(V0) -v = TestFunction(V0) -term1 = foo_label(subject(subj*v*dx, subj)) -term2 = bar_label(inner(grad(subj), grad(v))*dx) +@pytest.fixture +def primal_form(): + primal_subj = Function(V0) + primal_test = TestFunction(V0) -labelled_form = term1 + term2 + primal_term1 = foo_label(subject(primal_subj*primal_test*dx, primal_subj)) + primal_term2 = bar_label(inner(grad(primal_subj), grad(primal_test))*dx) -argsets = [ - ReplaceArgs(Function(V0), {}, None), - ReplaceArgs(Function(V0), {'new_idx': 0}, ValueError), - ReplaceArgs(Function(V0), {'old_idx': 0}, ValueError), - ReplaceArgs(Function(W), {'new_idx': 0}, None), - ReplaceArgs(Function(W), {'new_idx': 1}, None), - ReplaceArgs(Function(W), {'old_idx': 0}, ValueError), - ReplaceArgs(Function(W), {'new_idx': 7}, IndexError), -] + return primal_term1 + primal_term2 -@pytest.mark.parametrize('argset', argsets) -def test_replace_subject_params(argset): - arg = argset.subject - idxs = argset.idxs - error = argset.error +def primal_argsets(): + argsets = [ + ReplaceArgs(Function(V0), {}, None), + ReplaceArgs(Function(V0), {'new_idx': 0}, ValueError), + ReplaceArgs(Function(V0), {'old_idx': 0}, ValueError), + ReplaceArgs(Function(W), {'new_idx': 0}, None), + ReplaceArgs(Function(W), {'new_idx': 1}, None), + ReplaceArgs(split(Function(W)), {'new_idx': 1}, None), + ReplaceArgs(Function(W), {'old_idx': 0}, ValueError), + ReplaceArgs(Function(W), {'new_idx': 7}, IndexError), + ] + return argsets - if error is None: - new_form = labelled_form.label_map( - lambda t: t.has_label(foo_label), - map_if_true=replace_subject(arg, **idxs), - map_if_false=drop) - assert arg == new_form.form.coefficients()[0] - assert subj not in new_form.form.coefficients() - else: - with pytest.raises(error): - new_form = labelled_form.label_map( - lambda t: t.has_label(foo_label), - map_if_true=replace_subject(arg, **idxs)) +@pytest.fixture +def mixed_form(): + mixed_subj = Function(W) + mixed_test = TestFunction(W) + mixed_subj0, mixed_subj1 = split(mixed_subj) + mixed_test0, mixed_test1 = split(mixed_test) -def test_replace_subject_primal(): - # setup some basic labels - foo_label = Label("foo") - bar_label = Label("bar") + mixed_term1 = foo_label(subject(mixed_subj0*mixed_test0*dx, mixed_subj)) + mixed_term2 = bar_label(inner(grad(mixed_subj1), grad(mixed_test1))*dx) - # setup the mesh and function space - n = 2 - mesh = UnitSquareMesh(n, n) - V = FunctionSpace(mesh, "CG", 1) + return mixed_term1 + mixed_term2 - # set up the form - u = Function(V) - v = TestFunction(V) - form1 = inner(u, v)*dx - form2 = inner(grad(u), grad(v))*dx +def mixed_argsets(): + argsets = [ + ReplaceArgs(Function(W), {}, None), + ReplaceArgs(Function(W), {'new_idx': 0, 'old_idx': 0}, None), + ReplaceArgs(Function(W), {'old_idx': 0}, ValueError), + ReplaceArgs(Function(W), {'new_idx': 0}, ValueError), + ReplaceArgs(Function(V0), {'old_idx': 0}, None), + ReplaceArgs(Function(V0), {'new_idx': 0}, ValueError), + ReplaceArgs(split(Function(W)), {'new_idx': 0, 'old_idx': 0}, None), + ] + return argsets - term1 = foo_label(subject(form1, u)) - term2 = bar_label(form2) - labelled_form = term1 + term2 +@pytest.fixture +def vector_form(): + vector_subj = Function(Vv) + vector_test = TestFunction(Vv) - # replace with another function - w = Function(V) + vector_term1 = foo_label(subject(inner(vector_subj, vector_test)*dx, vector_subj)) + vector_term2 = bar_label(inner(grad(vector_subj), grad(vector_test))*dx) - # this should work - new_form = labelled_form.label_map( - lambda t: t.has_label(foo_label), - map_if_true=replace_subject(w)) + return vector_term1 + vector_term2 - # these should fail if given an index - with pytest.raises(ValueError): - new_form = labelled_form.label_map( - lambda t: t.has_label(foo_label), - map_if_true=replace_subject(w, new_idx=0)) - with pytest.raises(ValueError): - new_form = labelled_form.label_map( - lambda t: t.has_label(foo_label), - map_if_true=replace_subject(w, old_idx=0)) - - with pytest.raises(ValueError): - new_form = labelled_form.label_map( - lambda t: t.has_label(foo_label), - map_if_true=replace_subject(w, old_idx=0, new_idx=0)) - - # replace with mixed component - wm = Function(V*V) - wms = split(wm) - wm0, wm1 = wms - - new_form = labelled_form.label_map( - lambda t: t.has_label(foo_label), - map_if_true=replace_subject(wm0)) +def vector_argsets(): + argsets = [ + ReplaceArgs(Function(Vv), {}, None), + ReplaceArgs(Function(V0), {'new_idx': 0}, ValueError), + ReplaceArgs(Function(V0), {'old_idx': 0}, ValueError), + ReplaceArgs(Function(Wv), {'new_idx': 0}, None), + ReplaceArgs(Function(Wv), {'new_idx': 1}, ValueError), + ReplaceArgs(split(Function(Wv)), {'new_idx': 0}, None), + ReplaceArgs(Function(W), {'old_idx': 0}, ValueError), + ReplaceArgs(Function(W), {'new_idx': 7}, IndexError), + ] + return argsets - new_form = labelled_form.label_map( - lambda t: t.has_label(foo_label), - map_if_true=replace_subject(wms, new_idx=0)) +@pytest.mark.parametrize('argset', primal_argsets()) +def test_replace_subject_primal(primal_form, argset): + new_subj = argset.new_subj + idxs = argset.idxs + error = argset.error -def test_replace_subject_mixed(): - # setup some basic labels - foo_label = Label("foo") - bar_label = Label("bar") + if error is None: + old_subj = primal_form.form.coefficients()[0] - # setup the mesh and function space - n = 2 - mesh = UnitSquareMesh(n, n) - V0 = FunctionSpace(mesh, "CG", 1) - V1 = FunctionSpace(mesh, "DG", 1) - W = V0*V1 + new_form = primal_form.label_map( + lambda t: t.has_label(foo_label), + map_if_true=replace_subject(new_subj, **idxs), + map_if_false=drop) - # set up the form - u = Function(W) - u0, u1 = split(u) - v0, v1 = TestFunctions(W) + # what if we only replace part of the subject? + if 'new_idx' in idxs: + split_new = new_subj if type(new_subj) is tuple else split(new_subj) + new_subj = split_new[idxs['new_idx']].ufl_operands[0] - form1 = inner(u0, v0)*dx - form2 = inner(grad(u1), grad(v1))*dx + assert new_subj in new_form.form.coefficients() + assert old_subj not in new_form.form.coefficients() - term1 = foo_label(subject(form1, u)) - term2 = bar_label(form2) + else: + with pytest.raises(error): + new_form = primal_form.label_map( + lambda t: t.has_label(foo_label), + map_if_true=replace_subject(new_subj, **idxs)) - labelled_form = term1 + term2 - # replace with another function - w = Function(W) +@pytest.mark.parametrize('argset', mixed_argsets()) +def test_replace_subject_mixed(mixed_form, argset): + new_subj = argset.new_subj + idxs = argset.idxs + error = argset.error - # replace all parts of the subject - new_form = labelled_form.label_map( - lambda t: t.has_label(foo_label), - map_if_true=replace_subject(w)) + if error is None: + old_subj = mixed_form.form.coefficients()[0] - # replace either part of the subject - new_form = labelled_form.label_map( - lambda t: t.has_label(foo_label), - map_if_true=replace_subject(w, old_idx=0, new_idx=0)) + new_form = mixed_form.label_map( + lambda t: t.has_label(foo_label), + map_if_true=replace_subject(new_subj, **idxs), + map_if_false=drop) - new_form = labelled_form.label_map( - lambda t: t.has_label(foo_label), - map_if_true=replace_subject(w, old_idx=1, new_idx=1)) + # what if we only replace part of the subject? + if 'new_idx' in idxs: + split_new = new_subj if type(new_subj) is tuple else split(new_subj) + new_subj = split_new[idxs['new_idx']].ufl_operands[0] - # these should fail if given only one index - with pytest.raises(ValueError): - new_form = labelled_form.label_map( - lambda t: t.has_label(foo_label), - map_if_true=replace_subject(w, old_idx=1)) + assert new_subj in new_form.form.coefficients() + assert old_subj not in new_form.form.coefficients() - with pytest.raises(ValueError): - new_form = labelled_form.label_map( - lambda t: t.has_label(foo_label), - map_if_true=replace_subject(w, new_idx=1)) - - # try indexing only one - w0, w1 = split(w) - - # replace a specific part of the subject - new_form = labelled_form.label_map( - lambda t: t.has_label(foo_label), - map_if_true=replace_subject(w1, old_idx=0)) - - # replace with something from a primal space - wp = Function(V0) - new_form = labelled_form.label_map( - lambda t: t.has_label(foo_label), - map_if_true=replace_subject(wp, old_idx=1)) - - -replace_funcs = [ - pytest.param((Function, replace_subject), id="replace_subj"), - pytest.param((TestFunction, replace_test_function), id="replace_test"), - pytest.param((TrialFunction, replace_trial_function), id="replace_trial") -] - - -@pytest.mark.parametrize('subject_type', ['normal', 'mixed', 'vector']) -@pytest.mark.parametrize('replacement_type', ['normal', 'mixed', 'vector', 'tuple']) -@pytest.mark.parametrize('function_or_indexed', ['function', 'indexed']) -@pytest.mark.parametrize('replace_func', replace_funcs) -def old_test_replace_subject(subject_type, replacement_type, function_or_indexed, replace_func): - - # ------------------------------------------------------------------------ # - # Only certain combinations of options are valid - # ------------------------------------------------------------------------ # - - # only makes sense to replace a vector with a vector - if (subject_type == 'vector') ^ (replacement_type == 'vector'): - pytest.skip("Invalid vector option combination") - - # ------------------------------------------------------------------------ # - # Set up - # ------------------------------------------------------------------------ # - - # Some basic labels - foo_label = Label("foo") - bar_label = Label("bar") - - # Create mesh, function space and forms - n = 2 - mesh = UnitSquareMesh(n, n) - V0 = FunctionSpace(mesh, "DG", 0) - V1 = FunctionSpace(mesh, "CG", 1) - V2 = VectorFunctionSpace(mesh, "DG", 0) - Vmixed = MixedFunctionSpace((V0, V1)) - - idx = None - - # ------------------------------------------------------------------------ # - # Choose subject - # ------------------------------------------------------------------------ # - - if subject_type == 'normal': - V = V0 - elif subject_type == 'mixed': - V = Vmixed - if replacement_type == 'normal': - idx = 0 - elif subject_type == 'vector': - V = V2 - else: - raise ValueError - - the_subject = Function(V) - not_subject = TrialFunction(V) - test = TestFunction(V) - - form_1 = inner(the_subject, test)*dx - form_2 = inner(not_subject, test)*dx - - term_1 = foo_label(subject(form_1, the_subject)) - term_2 = bar_label(form_2) - labelled_form = term_1 + term_2 - - # ------------------------------------------------------------------------ # - # Choose replacement - # ------------------------------------------------------------------------ # - - if replacement_type == 'normal': - V = V1 - elif replacement_type == 'mixed': - V = Vmixed - if subject_type != 'mixed': - idx = 0 - elif replacement_type == 'vector': - V = V2 - elif replacement_type == 'tuple': - V = Vmixed else: - raise ValueError - - FunctionType = replace_func[0] + with pytest.raises(error): + new_form = mixed_form.label_map( + lambda t: t.has_label(foo_label), + map_if_true=replace_subject(new_subj, **idxs)) - the_replacement = FunctionType(V) - if function_or_indexed == 'indexed' and replacement_type != 'vector': - the_replacement = split(the_replacement) +@pytest.mark.parametrize('argset', vector_argsets()) +def test_replace_subject_vector(vector_form, argset): + new_subj = argset.new_subj + idxs = argset.idxs + error = argset.error - if len(the_replacement) == 1: - the_replacement = the_replacement[0] + if error is None: + old_subj = vector_form.form.coefficients()[0] - if replacement_type == 'tuple': - the_replacement = TrialFunctions(Vmixed) - if subject_type == 'normal': - idx = 0 + new_form = vector_form.label_map( + lambda t: t.has_label(foo_label), + map_if_true=replace_subject(new_subj, **idxs), + map_if_false=drop) - # ------------------------------------------------------------------------ # - # Test replace_subject - # ------------------------------------------------------------------------ # + # what if we only replace part of the subject? + if 'new_idx' in idxs: + split_new = new_subj if type(new_subj) is tuple else split(new_subj) + new_subj = split_new[idxs['new_idx']].ufl_operands[0].ufl_operands[0] - replace_map = replace_func[1] + assert new_subj in new_form.form.coefficients() + assert old_subj not in new_form.form.coefficients() - if replace_map is replace_trial_function: - match_label = bar_label else: - match_label = subject - - labelled_form = labelled_form.label_map( - lambda t: t.has_label(match_label), - map_if_true=replace_map(the_replacement, idx=idx) - ) - - # also test indexed - if subject_type == 'mixed' and function_or_indexed == 'indexed': - idx = 0 - the_replacement = split(FunctionType(Vmixed))[idx] - - labelled_form = labelled_form.label_map( - lambda t: t.has_label(match_label), - map_if_true=replace_map(the_replacement, idx=idx) - ) + with pytest.raises(error): + new_form = vector_form.label_map( + lambda t: t.has_label(foo_label), + map_if_true=replace_subject(new_subj, **idxs)) From 79f96505c07a37db142bc55a52ed83771ad1dfd9 Mon Sep 17 00:00:00 2001 From: Josh Hope-Collins Date: Wed, 26 Jul 2023 11:43:41 +0100 Subject: [PATCH 06/10] correct replace_subject calls in tests and gusto --- gusto/fml/replacement.py | 15 +++--- gusto/time_discretisation.py | 58 +++++++++++------------ unit-tests/fml_tests/test_replace_perp.py | 2 +- 3 files changed, 38 insertions(+), 37 deletions(-) diff --git a/gusto/fml/replacement.py b/gusto/fml/replacement.py index bb12f8d3e..09d12d6c5 100644 --- a/gusto/fml/replacement.py +++ b/gusto/fml/replacement.py @@ -27,6 +27,11 @@ def _replace_dict(old, new, old_idx, new_idx, replace_type): indexable_old = mixed_old indexable_new = mixed_new or type(new) is tuple + if mixed_old: + split_old = split(old) + if indexable_new: + split_new = new if type(new) is tuple else split(new) + # check indices arguments are valid if not indexable_old and old_idx is not None: raise ValueError(f"old_idx should not be specified to replace_{replace_type}" @@ -56,19 +61,15 @@ def _replace_dict(old, new, old_idx, new_idx, replace_type): + f" {old} is mixed and new {replace_type} of type {new} is" + " mixed or indexable.") if old_idx is None: # both indexes are none - if len(old) != len(new): + if len(split_old) != len(split_new): raise ValueError(f"if neither index is specified to replace_{replace_type}" + f" and both old {replace_type} of type {old} and new" + f" {replace_type} of type {new} are mixed or indexable" - + " then old and new must be the same length.") + + f" then old of length {len(split_old)} and new of length {len(split_new)}" + + " must be the same length.") # make the replace_dict - if mixed_old: - split_old = split(old) - if indexable_new: - split_new = new if type(new) is tuple else split(new) - replace_dict = {} if not indexable_old and not indexable_new: diff --git a/gusto/time_discretisation.py b/gusto/time_discretisation.py index cbe6cad4b..c76ca41b3 100644 --- a/gusto/time_discretisation.py +++ b/gusto/time_discretisation.py @@ -185,7 +185,7 @@ def lhs(self): """Set up the discretisation's left hand side (the time derivative).""" l = self.residual.label_map( lambda t: t.has_label(time_derivative), - map_if_true=replace_subject(self.x_out, self.idx), + map_if_true=replace_subject(self.x_out, old_idx=self.idx), map_if_false=drop) return l.form @@ -195,7 +195,7 @@ def rhs(self): """Set up the time discretisation's right hand side.""" r = self.residual.label_map( all_terms, - map_if_true=replace_subject(self.x1, self.idx)) + map_if_true=replace_subject(self.x1, old_idx=self.idx)) r = r.label_map( lambda t: t.has_label(time_derivative), @@ -437,7 +437,7 @@ def lhs(self): """Set up the discretisation's left hand side (the time derivative).""" l = self.residual.label_map( lambda t: t.has_label(time_derivative), - map_if_true=replace_subject(self.x_out, self.idx), + map_if_true=replace_subject(self.x_out, old_idx=self.idx), map_if_false=drop) return l.form @@ -447,7 +447,7 @@ def rhs(self): """Set up the time discretisation's right hand side.""" r = self.residual.label_map( all_terms, - map_if_true=replace_subject(self.x1, self.idx)) + map_if_true=replace_subject(self.x1, old_idx=self.idx)) r = r.label_map( lambda t: t.has_label(time_derivative), @@ -603,7 +603,7 @@ def lhs(self): """Set up the discretisation's left hand side (the time derivative).""" l = self.residual.label_map( all_terms, - map_if_true=replace_subject(self.x_out, self.idx)) + map_if_true=replace_subject(self.x_out, old_idx=self.idx)) l = l.label_map(lambda t: t.has_label(time_derivative), map_if_false=lambda t: self.dt*t) @@ -614,7 +614,7 @@ def rhs(self): """Set up the time discretisation's right hand side.""" r = self.residual.label_map( lambda t: t.has_label(time_derivative), - map_if_true=replace_subject(self.x1, self.idx), + map_if_true=replace_subject(self.x1, old_idx=self.idx), map_if_false=drop) return r.form @@ -685,7 +685,7 @@ def lhs(self): """Set up the discretisation's left hand side (the time derivative).""" l = self.residual.label_map( all_terms, - map_if_true=replace_subject(self.x_out, self.idx)) + map_if_true=replace_subject(self.x_out, old_idx=self.idx)) l = l.label_map(lambda t: t.has_label(time_derivative), map_if_false=lambda t: self.theta*self.dt*t) @@ -696,7 +696,7 @@ def rhs(self): """Set up the time discretisation's right hand side.""" r = self.residual.label_map( all_terms, - map_if_true=replace_subject(self.x1, self.idx)) + map_if_true=replace_subject(self.x1, old_idx=self.idx)) r = r.label_map(lambda t: t.has_label(time_derivative), map_if_false=lambda t: -(1-self.theta)*self.dt*t) @@ -798,7 +798,7 @@ def lhs0(self): """Set up the discretisation's left hand side (the time derivative).""" l = self.residual.label_map( all_terms, - map_if_true=replace_subject(self.x_out, self.idx)) + map_if_true=replace_subject(self.x_out, old_idx=self.idx)) l = l.label_map(lambda t: t.has_label(time_derivative), map_if_false=lambda t: self.dt*t) @@ -809,7 +809,7 @@ def rhs0(self): """Set up the time discretisation's right hand side for inital BDF step.""" r = self.residual.label_map( lambda t: t.has_label(time_derivative), - map_if_true=replace_subject(self.x1, self.idx), + map_if_true=replace_subject(self.x1, old_idx=self.idx), map_if_false=drop) return r.form @@ -819,7 +819,7 @@ def lhs(self): """Set up the discretisation's left hand side (the time derivative).""" l = self.residual.label_map( all_terms, - map_if_true=replace_subject(self.x_out, self.idx)) + map_if_true=replace_subject(self.x_out, old_idx=self.idx)) l = l.label_map(lambda t: t.has_label(time_derivative), map_if_false=lambda t: (2/3)*self.dt*t) @@ -830,11 +830,11 @@ def rhs(self): """Set up the time discretisation's right hand side for BDF2 steps.""" xn = self.residual.label_map( lambda t: t.has_label(time_derivative), - map_if_true=replace_subject(self.x1, self.idx), + map_if_true=replace_subject(self.x1, old_idx=self.idx), map_if_false=drop) xnm1 = self.residual.label_map( lambda t: t.has_label(time_derivative), - map_if_true=replace_subject(self.xnm1, self.idx), + map_if_true=replace_subject(self.xnm1, old_idx=self.idx), map_if_false=drop) r = (4/3.) * xn - (1/3.) * xnm1 @@ -930,7 +930,7 @@ def lhs(self): """Set up the discretisation's left hand side (the time derivative) for the TR stage.""" l = self.residual.label_map( all_terms, - map_if_true=replace_subject(self.xnpg, self.idx)) + map_if_true=replace_subject(self.xnpg, old_idx=self.idx)) l = l.label_map(lambda t: t.has_label(time_derivative), map_if_false=lambda t: 0.5*self.gamma*self.dt*t) @@ -941,7 +941,7 @@ def rhs(self): """Set up the time discretisation's right hand side for the TR stage.""" r = self.residual.label_map( all_terms, - map_if_true=replace_subject(self.xn, self.idx)) + map_if_true=replace_subject(self.xn, old_idx=self.idx)) r = r.label_map(lambda t: t.has_label(time_derivative), map_if_false=lambda t: -0.5*self.gamma*self.dt*t) @@ -952,7 +952,7 @@ def lhs_bdf2(self): """Set up the discretisation's left hand side (the time derivative) for the BDF2 stage.""" l = self.residual.label_map( all_terms, - map_if_true=replace_subject(self.x_out, self.idx)) + map_if_true=replace_subject(self.x_out, old_idx=self.idx)) l = l.label_map(lambda t: t.has_label(time_derivative), map_if_false=lambda t: ((1.0-self.gamma)/(2.0-self.gamma))*self.dt*t) @@ -963,11 +963,11 @@ def rhs_bdf2(self): """Set up the time discretisation's right hand side for the BDF2 stage.""" xn = self.residual.label_map( lambda t: t.has_label(time_derivative), - map_if_true=replace_subject(self.xn, self.idx), + map_if_true=replace_subject(self.xn, old_idx=self.idx), map_if_false=drop) xnpg = self.residual.label_map( lambda t: t.has_label(time_derivative), - map_if_true=replace_subject(self.xnpg, self.idx), + map_if_true=replace_subject(self.xnpg, old_idx=self.idx), map_if_false=drop) r = (1.0/(self.gamma*(2.0-self.gamma)))*xnpg - ((1.0-self.gamma)**2/(self.gamma*(2.0-self.gamma))) * xn @@ -1020,7 +1020,7 @@ def rhs0(self): """Set up the discretisation's right hand side for initial forward euler step.""" r = self.residual.label_map( all_terms, - map_if_true=replace_subject(self.x1, self.idx)) + map_if_true=replace_subject(self.x1, old_idx=self.idx)) r = r.label_map(lambda t: t.has_label(time_derivative), map_if_false=lambda t: -self.dt*t) @@ -1036,9 +1036,9 @@ def rhs(self): """Set up the discretisation's right hand side for leapfrog steps.""" r = self.residual.label_map( lambda t: t.has_label(time_derivative), - map_if_false=replace_subject(self.x1, self.idx)) + map_if_false=replace_subject(self.x1, old_idx=self.idx)) r = r.label_map(lambda t: t.has_label(time_derivative), - map_if_true=replace_subject(self.xnm1, self.idx), + map_if_true=replace_subject(self.xnm1, old_idx=self.idx), map_if_false=lambda t: -2.0*self.dt*t) return r.form @@ -1143,7 +1143,7 @@ def rhs0(self): """Set up the discretisation's right hand side for initial forward euler step.""" r = self.residual.label_map( all_terms, - map_if_true=replace_subject(self.x[-1], self.idx)) + map_if_true=replace_subject(self.x[-1], old_idx=self.idx)) r = r.label_map(lambda t: t.has_label(time_derivative), map_if_false=lambda t: -self.dt*t) @@ -1158,13 +1158,13 @@ def lhs(self): def rhs(self): """Set up the discretisation's right hand side for Adams Bashforth steps.""" r = self.residual.label_map(all_terms, - map_if_true=replace_subject(self.x[-1], self.idx)) + map_if_true=replace_subject(self.x[-1], old_idx=self.idx)) r = r.label_map(lambda t: t.has_label(time_derivative), map_if_false=lambda t: -self.b[-1]*self.dt*t) for n in range(self.nlevels-1): rtemp = self.residual.label_map(lambda t: t.has_label(time_derivative), map_if_true=drop, - map_if_false=replace_subject(self.x[n], self.idx)) + map_if_false=replace_subject(self.x[n], old_idx=self.idx)) rtemp = rtemp.label_map(lambda t: t.has_label(time_derivative), map_if_false=lambda t: -self.dt*self.b[n]*t) r += rtemp @@ -1274,7 +1274,7 @@ def rhs0(self): """Set up the discretisation's right hand side for initial trapezoidal step.""" r = self.residual.label_map( all_terms, - map_if_true=replace_subject(self.x[-1], self.idx)) + map_if_true=replace_subject(self.x[-1], old_idx=self.idx)) r = r.label_map(lambda t: t.has_label(time_derivative), map_if_false=lambda t: -0.5*self.dt*t) @@ -1285,7 +1285,7 @@ def lhs0(self): """Set up the time discretisation's right hand side for initial trapezoidal step.""" l = self.residual.label_map( all_terms, - map_if_true=replace_subject(self.x_out, self.idx)) + map_if_true=replace_subject(self.x_out, old_idx=self.idx)) l = l.label_map(lambda t: t.has_label(time_derivative), map_if_false=lambda t: 0.5*self.dt*t) return l.form @@ -1295,7 +1295,7 @@ def lhs(self): """Set up the time discretisation's right hand side for Adams Moulton steps.""" l = self.residual.label_map( all_terms, - map_if_true=replace_subject(self.x_out, self.idx)) + map_if_true=replace_subject(self.x_out, old_idx=self.idx)) l = l.label_map(lambda t: t.has_label(time_derivative), map_if_false=lambda t: self.bl*self.dt*t) return l.form @@ -1304,13 +1304,13 @@ def lhs(self): def rhs(self): """Set up the discretisation's right hand side for Adams Moulton steps.""" r = self.residual.label_map(all_terms, - map_if_true=replace_subject(self.x[-1], self.idx)) + map_if_true=replace_subject(self.x[-1], old_idx=self.idx)) r = r.label_map(lambda t: t.has_label(time_derivative), map_if_false=lambda t: -self.br[-1]*self.dt*t) for n in range(self.nlevels-1): rtemp = self.residual.label_map(lambda t: t.has_label(time_derivative), map_if_true=drop, - map_if_false=replace_subject(self.x[n], self.idx)) + map_if_false=replace_subject(self.x[n], old_idx=self.idx)) rtemp = rtemp.label_map(lambda t: t.has_label(time_derivative), map_if_false=lambda t: -self.dt*self.br[n]*t) r += rtemp diff --git a/unit-tests/fml_tests/test_replace_perp.py b/unit-tests/fml_tests/test_replace_perp.py index 967780696..e429fda81 100644 --- a/unit-tests/fml_tests/test_replace_perp.py +++ b/unit-tests/fml_tests/test_replace_perp.py @@ -36,7 +36,7 @@ def test_replace_perp(): u, D = TrialFunctions(W) a = inner(u, w)*dx + D*p*dx - L = form.label_map(all_terms, replace_subject(U1, 0)) + L = form.label_map(all_terms, replace_subject(U1, old_idx=0, new_idx=0)) U2 = Function(W) solve(a == L.form, U2) From 27f2f3c2012c2c1383e9783ab581c3862257408a Mon Sep 17 00:00:00 2001 From: jshipton Date: Wed, 26 Jul 2023 15:27:30 +0100 Subject: [PATCH 07/10] we do need the old index for test and trial functions --- gusto/fml/replacement.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/gusto/fml/replacement.py b/gusto/fml/replacement.py index bb12f8d3e..a76e6ad03 100644 --- a/gusto/fml/replacement.py +++ b/gusto/fml/replacement.py @@ -93,7 +93,7 @@ def _replace_dict(old, new, old_idx, new_idx, replace_type): # ---------------------------------------------------------------------------- # # Replacement routines # ---------------------------------------------------------------------------- # -def replace_test_function(new_test, new_idx=None): +def replace_test_function(new_test, old_idx=None, new_idx=None): """ A routine to replace the test function in a term with a new test function. @@ -118,7 +118,7 @@ def repl(t): """ old_test = t.form.arguments()[0] replace_dict = _replace_dict(old_test, new_test, - old_idx=None, new_idx=new_idx, + old_idx=old_idx, new_idx=new_idx, replace_type='test') try: @@ -133,7 +133,7 @@ def repl(t): return repl -def replace_trial_function(new_trial, new_idx=None): +def replace_trial_function(new_trial, old_idx=None, new_idx=None): """ A routine to replace the trial function in a term with a new expression. @@ -163,7 +163,7 @@ def repl(t): raise TypeError('Trying to replace trial function of a form that is not linear') old_trial = t.form.arguments()[1] replace_dict = _replace_dict(old_trial, new_trial, - old_idx=None, new_idx=new_idx, + old_idx=old_idx, new_idx=new_idx, replace_type='trial') try: From ca840e35da0bda20fc261465ec40e3295790fef7 Mon Sep 17 00:00:00 2001 From: jshipton Date: Wed, 26 Jul 2023 15:27:58 +0100 Subject: [PATCH 08/10] add tests for replacing test and trial functions --- unit-tests/fml_tests/test_replace_subject.py | 251 ++++++++++++++++--- 1 file changed, 219 insertions(+), 32 deletions(-) diff --git a/unit-tests/fml_tests/test_replace_subject.py b/unit-tests/fml_tests/test_replace_subject.py index 57d4feaac..2f4ae17dc 100644 --- a/unit-tests/fml_tests/test_replace_subject.py +++ b/unit-tests/fml_tests/test_replace_subject.py @@ -3,13 +3,24 @@ """ from firedrake import (UnitSquareMesh, FunctionSpace, Function, TestFunction, + TestFunctions, TrialFunction, TrialFunctions, + Argument, VectorFunctionSpace, dx, inner, split, grad) -from gusto.fml import (Label, subject, replace_subject, drop) +from gusto.fml import (Label, subject, replace_subject, + replace_test_function, replace_trial_function, + drop, all_terms) import pytest from collections import namedtuple -ReplaceArgs = namedtuple("ReplaceArgs", "new_subj idxs error") +ReplaceSubjArgs = namedtuple("ReplaceSubjArgs", "new_subj idxs error") +ReplaceArgsArgs = namedtuple("ReplaceArgsArgs", "new_arg idxs error replace_function arg_idx") + +def ReplaceTestArgs(*args): + return ReplaceArgsArgs(*args, replace_test_function, 0) + +def ReplaceTrialArgs(*args): + return ReplaceArgsArgs(*args, replace_trial_function, 1) # some dummy labels foo_label = Label("foo") @@ -24,7 +35,7 @@ Wv = Vv*V1 -@pytest.fixture +@pytest.fixture() def primal_form(): primal_subj = Function(V0) primal_test = TestFunction(V0) @@ -35,16 +46,46 @@ def primal_form(): return primal_term1 + primal_term2 -def primal_argsets(): +def primal_subj_argsets(): + argsets = [ + ReplaceSubjArgs(Function(V0), {}, None), + ReplaceSubjArgs(Function(V0), {'new_idx': 0}, ValueError), + ReplaceSubjArgs(Function(V0), {'old_idx': 0}, ValueError), + ReplaceSubjArgs(Function(W), {'new_idx': 0}, None), + ReplaceSubjArgs(Function(W), {'new_idx': 1}, None), + ReplaceSubjArgs(split(Function(W)), {'new_idx': 1}, None), + ReplaceSubjArgs(Function(W), {'old_idx': 0}, ValueError), + ReplaceSubjArgs(Function(W), {'new_idx': 7}, IndexError) + ] + return argsets + + +def primal_test_argsets(): + argsets = [ + ReplaceTestArgs(TestFunction(V0), {}, None), + ReplaceTestArgs(TestFunction(V0), {'new_idx': 0}, ValueError), + ReplaceTestArgs(TestFunction(W), {'new_idx': 0}, None), + ReplaceTestArgs(TestFunction(W), {'new_idx': 1}, None), + ReplaceTestArgs(TestFunctions(W), {'new_idx': 1}, None), + ReplaceTestArgs(TestFunction(W), {'new_idx': 7}, IndexError) + ] + return argsets + + +def primal_trial_argsets(): argsets = [ - ReplaceArgs(Function(V0), {}, None), - ReplaceArgs(Function(V0), {'new_idx': 0}, ValueError), - ReplaceArgs(Function(V0), {'old_idx': 0}, ValueError), - ReplaceArgs(Function(W), {'new_idx': 0}, None), - ReplaceArgs(Function(W), {'new_idx': 1}, None), - ReplaceArgs(split(Function(W)), {'new_idx': 1}, None), - ReplaceArgs(Function(W), {'old_idx': 0}, ValueError), - ReplaceArgs(Function(W), {'new_idx': 7}, IndexError), + ReplaceTrialArgs(TrialFunction(V0), {}, None), + ReplaceTrialArgs(TrialFunction(V0), {'new_idx': 0}, ValueError), + ReplaceTrialArgs(TrialFunction(W), {'new_idx': 0}, None), + ReplaceTrialArgs(TrialFunction(W), {'new_idx': 1}, None), + ReplaceTrialArgs(TrialFunctions(W), {'new_idx': 1}, None), + ReplaceTrialArgs(TrialFunction(W), {'new_idx': 7}, IndexError), + ReplaceTrialArgs(Function(V0), {}, None), + ReplaceTrialArgs(Function(V0), {'new_idx': 0}, ValueError), + ReplaceTrialArgs(Function(W), {'new_idx': 0}, None), + ReplaceTrialArgs(Function(W), {'new_idx': 1}, None), + ReplaceTrialArgs(split(Function(W)), {'new_idx': 1}, None), + ReplaceTrialArgs(Function(W), {'new_idx': 7}, IndexError), ] return argsets @@ -63,15 +104,50 @@ def mixed_form(): return mixed_term1 + mixed_term2 -def mixed_argsets(): +def mixed_subj_argsets(): argsets = [ - ReplaceArgs(Function(W), {}, None), - ReplaceArgs(Function(W), {'new_idx': 0, 'old_idx': 0}, None), - ReplaceArgs(Function(W), {'old_idx': 0}, ValueError), - ReplaceArgs(Function(W), {'new_idx': 0}, ValueError), - ReplaceArgs(Function(V0), {'old_idx': 0}, None), - ReplaceArgs(Function(V0), {'new_idx': 0}, ValueError), - ReplaceArgs(split(Function(W)), {'new_idx': 0, 'old_idx': 0}, None), + ReplaceSubjArgs(Function(W), {}, None), + ReplaceSubjArgs(Function(W), {'new_idx': 0, 'old_idx': 0}, None), + ReplaceSubjArgs(Function(W), {'old_idx': 0}, ValueError), + ReplaceSubjArgs(Function(W), {'new_idx': 0}, ValueError), + ReplaceSubjArgs(Function(V0), {'old_idx': 0}, None), + ReplaceSubjArgs(Function(V0), {'new_idx': 0}, ValueError), + ReplaceSubjArgs(split(Function(W)), {'new_idx': 0, 'old_idx': 0}, None), + ] + return argsets + + +def mixed_test_argsets(): + argsets = [ + ReplaceTestArgs(TestFunction(W), {}, None), + ReplaceTestArgs(TestFunctions(W), {}, None), + ReplaceTestArgs(TestFunction(W), {'old_idx': 0, 'new_idx': 0}, None), + ReplaceTestArgs(TestFunctions(W), {'old_idx': 0}, ValueError), + ReplaceTestArgs(TestFunction(W), {'new_idx': 0}, ValueError), + #ReplaceTestArgs(TestFunctions(W), {'old_idx': 1, 'new_idx': 1}, None), + ReplaceTestArgs(TestFunction(V0), {'old_idx': 0}, None), + ReplaceTestArgs(TestFunctions(V0), {'new_idx': 1}, ValueError), + ReplaceTestArgs(TestFunction(W), {'old_idx': 7, 'new_idx': 7}, IndexError) + ] + return argsets + + +def mixed_trial_argsets(): + argsets = [ + ReplaceTrialArgs(TrialFunction(W), {}, None), + ReplaceTrialArgs(TrialFunctions(W), {}, None), + ReplaceTrialArgs(TrialFunction(W), {'old_idx': 0, 'new_idx': 0}, None), + #ReplaceTrialArgs(TrialFunctions(W), {'old_idx': 1, 'new_idx': 1}, None), + ReplaceTrialArgs(TrialFunction(V0), {'old_idx': 0}, None), + ReplaceTrialArgs(TrialFunctions(V0), {'new_idx': 1}, ValueError), + ReplaceTrialArgs(TrialFunction(W), {'old_idx': 7, 'new_idx': 7}, IndexError), + ReplaceTrialArgs(Function(W), {}, None), + ReplaceTrialArgs(split(Function(W)), {}, None), + ReplaceTrialArgs(Function(W), {'old_idx': 0, 'new_idx': 0}, None), + #ReplaceTrialArgs(Function(W), {'old_idx': 1, 'new_idx': 1}, None), + ReplaceTrialArgs(Function(V0), {'old_idx': 0}, None), + ReplaceTrialArgs(Function(V0), {'new_idx': 0}, ValueError), + ReplaceTrialArgs(Function(W), {'old_idx': 7, 'new_idx': 7}, IndexError), ] return argsets @@ -87,21 +163,36 @@ def vector_form(): return vector_term1 + vector_term2 -def vector_argsets(): +def vector_subj_argsets(): + argsets = [ + ReplaceSubjArgs(Function(Vv), {}, None), + ReplaceSubjArgs(Function(V0), {}, ValueError), + ReplaceSubjArgs(Function(Vv), {'new_idx': 0}, ValueError), + ReplaceSubjArgs(Function(Vv), {'old_idx': 0}, ValueError), + ReplaceSubjArgs(Function(Wv), {'new_idx': 0}, None), + ReplaceSubjArgs(Function(Wv), {'new_idx': 1}, ValueError), + ReplaceSubjArgs(split(Function(Wv)), {'new_idx': 0}, None), + ReplaceSubjArgs(Function(W), {'old_idx': 0}, ValueError), + ReplaceSubjArgs(Function(W), {'new_idx': 7}, IndexError), + ] + return argsets + + +def vector_test_argsets(): argsets = [ - ReplaceArgs(Function(Vv), {}, None), - ReplaceArgs(Function(V0), {'new_idx': 0}, ValueError), - ReplaceArgs(Function(V0), {'old_idx': 0}, ValueError), - ReplaceArgs(Function(Wv), {'new_idx': 0}, None), - ReplaceArgs(Function(Wv), {'new_idx': 1}, ValueError), - ReplaceArgs(split(Function(Wv)), {'new_idx': 0}, None), - ReplaceArgs(Function(W), {'old_idx': 0}, ValueError), - ReplaceArgs(Function(W), {'new_idx': 7}, IndexError), + ReplaceTestArgs(TestFunction(Vv), {}, None), + ReplaceTestArgs(TestFunction(V0), {}, ValueError), + ReplaceTestArgs(TestFunction(Vv), {'new_idx': 0}, ValueError), + ReplaceTestArgs(TestFunction(Wv), {'new_idx': 0}, None), + ReplaceTestArgs(TestFunction(Wv), {'new_idx': 1}, ValueError), + ReplaceTestArgs(TestFunctions(Wv), {'new_idx': 0}, None), + #ReplaceTestArgs(TestFunction(W), {'new_idx': 0}, None), + #ReplaceTestArgs(TestFunction(W), {'new_idx': 7}, IndexError), ] return argsets -@pytest.mark.parametrize('argset', primal_argsets()) +@pytest.mark.parametrize('argset', primal_subj_argsets()) def test_replace_subject_primal(primal_form, argset): new_subj = argset.new_subj idxs = argset.idxs @@ -130,7 +221,7 @@ def test_replace_subject_primal(primal_form, argset): map_if_true=replace_subject(new_subj, **idxs)) -@pytest.mark.parametrize('argset', mixed_argsets()) +@pytest.mark.parametrize('argset', mixed_subj_argsets()) def test_replace_subject_mixed(mixed_form, argset): new_subj = argset.new_subj idxs = argset.idxs @@ -159,7 +250,7 @@ def test_replace_subject_mixed(mixed_form, argset): map_if_true=replace_subject(new_subj, **idxs)) -@pytest.mark.parametrize('argset', vector_argsets()) +@pytest.mark.parametrize('argset', vector_subj_argsets()) def test_replace_subject_vector(vector_form, argset): new_subj = argset.new_subj idxs = argset.idxs @@ -186,3 +277,99 @@ def test_replace_subject_vector(vector_form, argset): new_form = vector_form.label_map( lambda t: t.has_label(foo_label), map_if_true=replace_subject(new_subj, **idxs)) + + +@pytest.mark.parametrize('argset', primal_test_argsets() + primal_trial_argsets()) +def test_replace_arg_primal(primal_form, argset): + new_arg = argset.new_arg + idxs = argset.idxs + error = argset.error + replace_function = argset.replace_function + arg_idx = argset.arg_idx + primal_form = primal_form.label_map(lambda t: t.has_label(subject), + replace_subject(TrialFunction(V0)), + drop) + + if error is None: + new_form = primal_form.label_map( + all_terms, + map_if_true=replace_function(new_arg, **idxs)) + + if 'new_idx' in idxs: + split_arg = new_arg if type(new_arg) is tuple else split(new_arg) + new_arg = split_arg[idxs['new_idx']].ufl_operands[0] + + if isinstance(new_arg, Argument): + assert new_form.form.arguments()[arg_idx] is new_arg + elif type(new_arg) is Function: + assert new_form.form.coefficients()[0] is new_arg + + else: + with pytest.raises(error): + new_form = primal_form.label_map( + all_terms, + map_if_true=replace_function(new_arg, **idxs)) + + +@pytest.mark.parametrize('argset', mixed_test_argsets() + mixed_trial_argsets()) +def test_replace_arg_mixed(mixed_form, argset): + new_arg = argset.new_arg + idxs = argset.idxs + error = argset.error + replace_function = argset.replace_function + arg_idx = argset.arg_idx + mixed_form = mixed_form.label_map(lambda t: t.has_label(subject), + replace_subject(TrialFunction(W)), + drop) + + if error is None: + new_form = mixed_form.label_map( + all_terms, + map_if_true=replace_function(new_arg, **idxs)) + + if 'new_idx' in idxs: + split_arg = new_arg if type(new_arg) is tuple else split(new_arg) + new_arg = split_arg[idxs['new_idx']].ufl_operands[0] + + if isinstance(new_arg, Argument): + assert new_form.form.arguments()[arg_idx] is new_arg + elif type(new_arg) is Function: + assert new_form.form.coefficients()[0] is new_arg + + else: + with pytest.raises(error): + new_form = mixed_form.label_map( + all_terms, + map_if_true=replace_function(new_arg, **idxs)) + + +@pytest.mark.parametrize('argset', vector_test_argsets()) +def test_replace_arg_vector(vector_form, argset): + new_arg = argset.new_arg + idxs = argset.idxs + error = argset.error + replace_function = argset.replace_function + arg_idx = argset.arg_idx + vector_form = vector_form.label_map(lambda t: t.has_label(subject), + replace_subject(TrialFunction(Vv)), + drop) + + if error is None: + new_form = vector_form.label_map( + all_terms, + map_if_true=replace_function(new_arg, **idxs)) + + if 'new_idx' in idxs: + split_arg = new_arg if type(new_arg) is tuple else split(new_arg) + new_arg = split_arg[idxs['new_idx']].ufl_operands[0] + + if isinstance(new_arg, Argument): + assert new_form.form.arguments()[arg_idx] is new_arg + elif type(new_arg) is Function: + assert new_form.form.coefficients()[0] is new_arg + + else: + with pytest.raises(error): + new_form = vector_form.label_map( + all_terms, + map_if_true=replace_function(new_arg, **idxs)) From 4c8592d94d80b1cbc85d2d5112b0d249f8467233 Mon Sep 17 00:00:00 2001 From: jshipton Date: Wed, 26 Jul 2023 15:35:19 +0100 Subject: [PATCH 09/10] fix lint --- unit-tests/fml_tests/test_replace_subject.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/unit-tests/fml_tests/test_replace_subject.py b/unit-tests/fml_tests/test_replace_subject.py index 2f4ae17dc..e3a93cf7f 100644 --- a/unit-tests/fml_tests/test_replace_subject.py +++ b/unit-tests/fml_tests/test_replace_subject.py @@ -16,12 +16,15 @@ ReplaceSubjArgs = namedtuple("ReplaceSubjArgs", "new_subj idxs error") ReplaceArgsArgs = namedtuple("ReplaceArgsArgs", "new_arg idxs error replace_function arg_idx") + def ReplaceTestArgs(*args): return ReplaceArgsArgs(*args, replace_test_function, 0) + def ReplaceTrialArgs(*args): return ReplaceArgsArgs(*args, replace_trial_function, 1) + # some dummy labels foo_label = Label("foo") bar_label = Label("bar") @@ -71,7 +74,7 @@ def primal_test_argsets(): ] return argsets - + def primal_trial_argsets(): argsets = [ ReplaceTrialArgs(TrialFunction(V0), {}, None), @@ -124,7 +127,7 @@ def mixed_test_argsets(): ReplaceTestArgs(TestFunction(W), {'old_idx': 0, 'new_idx': 0}, None), ReplaceTestArgs(TestFunctions(W), {'old_idx': 0}, ValueError), ReplaceTestArgs(TestFunction(W), {'new_idx': 0}, ValueError), - #ReplaceTestArgs(TestFunctions(W), {'old_idx': 1, 'new_idx': 1}, None), + # ReplaceTestArgs(TestFunctions(W), {'old_idx': 1, 'new_idx': 1}, None), ReplaceTestArgs(TestFunction(V0), {'old_idx': 0}, None), ReplaceTestArgs(TestFunctions(V0), {'new_idx': 1}, ValueError), ReplaceTestArgs(TestFunction(W), {'old_idx': 7, 'new_idx': 7}, IndexError) @@ -137,14 +140,14 @@ def mixed_trial_argsets(): ReplaceTrialArgs(TrialFunction(W), {}, None), ReplaceTrialArgs(TrialFunctions(W), {}, None), ReplaceTrialArgs(TrialFunction(W), {'old_idx': 0, 'new_idx': 0}, None), - #ReplaceTrialArgs(TrialFunctions(W), {'old_idx': 1, 'new_idx': 1}, None), + # ReplaceTrialArgs(TrialFunctions(W), {'old_idx': 1, 'new_idx': 1}, None), ReplaceTrialArgs(TrialFunction(V0), {'old_idx': 0}, None), ReplaceTrialArgs(TrialFunctions(V0), {'new_idx': 1}, ValueError), ReplaceTrialArgs(TrialFunction(W), {'old_idx': 7, 'new_idx': 7}, IndexError), ReplaceTrialArgs(Function(W), {}, None), ReplaceTrialArgs(split(Function(W)), {}, None), ReplaceTrialArgs(Function(W), {'old_idx': 0, 'new_idx': 0}, None), - #ReplaceTrialArgs(Function(W), {'old_idx': 1, 'new_idx': 1}, None), + # ReplaceTrialArgs(Function(W), {'old_idx': 1, 'new_idx': 1}, None), ReplaceTrialArgs(Function(V0), {'old_idx': 0}, None), ReplaceTrialArgs(Function(V0), {'new_idx': 0}, ValueError), ReplaceTrialArgs(Function(W), {'old_idx': 7, 'new_idx': 7}, IndexError), @@ -186,8 +189,8 @@ def vector_test_argsets(): ReplaceTestArgs(TestFunction(Wv), {'new_idx': 0}, None), ReplaceTestArgs(TestFunction(Wv), {'new_idx': 1}, ValueError), ReplaceTestArgs(TestFunctions(Wv), {'new_idx': 0}, None), - #ReplaceTestArgs(TestFunction(W), {'new_idx': 0}, None), - #ReplaceTestArgs(TestFunction(W), {'new_idx': 7}, IndexError), + # ReplaceTestArgs(TestFunction(W), {'new_idx': 0}, None), + # ReplaceTestArgs(TestFunction(W), {'new_idx': 7}, IndexError), ] return argsets From af4f6c80c1db5c1ee8360f541d5818b5035f2b15 Mon Sep 17 00:00:00 2001 From: jshipton Date: Wed, 26 Jul 2023 15:36:06 +0100 Subject: [PATCH 10/10] rename file --- .../fml_tests/{test_replace_subject.py => test_replacement.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename unit-tests/fml_tests/{test_replace_subject.py => test_replacement.py} (100%) diff --git a/unit-tests/fml_tests/test_replace_subject.py b/unit-tests/fml_tests/test_replacement.py similarity index 100% rename from unit-tests/fml_tests/test_replace_subject.py rename to unit-tests/fml_tests/test_replacement.py