Skip to content

Commit

Permalink
Merge pull request #395 from firedrakeproject/JHopeCollins/indexed_re…
Browse files Browse the repository at this point in the history
…place

Split apart the index argument to the `replace_subject,trial,test` functions into `idx_new` and `idx_out`
  • Loading branch information
tommbendall authored Jul 26, 2023
2 parents caa0f40 + 172ee74 commit 2b311c3
Show file tree
Hide file tree
Showing 5 changed files with 482 additions and 210 deletions.
121 changes: 74 additions & 47 deletions gusto/fml/replacement.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,67 +13,88 @@
# ---------------------------------------------------------------------------- #
# A general routine for building the replacement dictionary
# ---------------------------------------------------------------------------- #
def _replace_dict(old, new, idx, replace_type):
def _replace_dict(old, new, old_idx, new_idx, replace_type):
"""
Build a dictionary to pass to the ufl.replace routine
The dictionary matches variables in the old term with those in the new
Does not check types unless indexing is required (leave type-checking to ufl.replace)
"""

replace_dict = {}

if type(old.ufl_element()) is MixedElement:

mixed_new = hasattr(new, "ufl_element") and type(new.ufl_element()) is MixedElement
indexable_new = type(new) is tuple or mixed_new

if indexable_new:
split_new = new if type(new) is tuple else split(new)

if len(split_new) != len(old.function_space()):
raise ValueError(f"new {replace_type} of type {new} must be same length"
+ f"as replaced mixed {replace_type} of type {old}")
mixed_old = type(old.ufl_element()) is MixedElement
mixed_new = hasattr(new, "ufl_element") and type(new.ufl_element()) is MixedElement

indexable_old = mixed_old
indexable_new = mixed_new or type(new) is tuple

if mixed_old:
split_old = split(old)
if indexable_new:
split_new = new if type(new) is tuple else split(new)

# check indices arguments are valid
if not indexable_old and old_idx is not None:
raise ValueError(f"old_idx should not be specified to replace_{replace_type}"
+ f" when replaced {replace_type} of type {old} is not mixed.")

if not indexable_new and new_idx is not None:
raise ValueError(f"new_idx should not be specified to replace_{replace_type} when"
+ f" new {replace_type} of type {new} is not mixed or indexable.")

if indexable_old and not indexable_new:
if old_idx is None:
raise ValueError(f"old_idx must be specified to replace_{replace_type} when replaced"
+ f" {replace_type} of type {old} is mixed and new {replace_type}"
+ f" of type {new} is not mixed or indexable.")

if indexable_new and not indexable_old:
if new_idx is None:
raise ValueError(f"new_idx must be specified to replace_{replace_type} when new"
+ f" {replace_type} of type {new} is mixed or indexable and"
+ f" old {replace_type} of type {old} is not mixed.")

if indexable_old and indexable_new:
# must be both True or both False
if (old_idx is None) ^ (new_idx is None):
raise ValueError("both or neither old_idx and new_idx must be specified to"
+ f" replace_{replace_type} when old {replace_type} of type"
+ f" {old} is mixed and new {replace_type} of type {new} is"
+ " mixed or indexable.")
if old_idx is None: # both indexes are none
if len(split_old) != len(split_new):
raise ValueError(f"if neither index is specified to replace_{replace_type}"
+ f" and both old {replace_type} of type {old} and new"
+ f" {replace_type} of type {new} are mixed or indexable"
+ f" then old of length {len(split_old)} and new of length {len(split_new)}"
+ " must be the same length.")

# make the replace_dict

if idx is None:
for k, v in zip(split(old), split_new):
replace_dict[k] = v
else:
replace_dict[split(old)[idx]] = split_new[idx]

else: # new is not indexable
if idx is None:
raise ValueError(f"idx must be specified to replace_{replace_type} when"
+ f" replaced {replace_type} of type {old} is mixed and"
+ f" new {replace_type} of type {new} is a single component")

replace_dict[split(old)[idx]] = new

else: # old is not mixed

mixed_new = hasattr(new, "ufl_element") and type(new.ufl_element()) is MixedElement
indexable_new = type(new) is tuple or mixed_new
replace_dict = {}

if indexable_new:
split_new = new if type(new) is tuple else split(new)
if not indexable_old and not indexable_new:
replace_dict[old] = new

if idx is None:
raise ValueError(f"idx must be specified to replace_{replace_type} when"
+ f" replaced {replace_type} of type {old} is not mixed"
+ f" and new {replace_type} of type {new} is indexable")
elif not indexable_old and indexable_new:
replace_dict[old] = split_new[new_idx]

replace_dict[old] = split_new[idx]
elif indexable_old and not indexable_new:
replace_dict[split_old[old_idx]] = new

else:
replace_dict[old] = new
elif indexable_old and indexable_new:
if old_idx is None: # replace everything
for k, v in zip(split_old, split_new):
replace_dict[k] = v
else: # idxs are given
replace_dict[split_old[old_idx]] = split_new[new_idx]

return replace_dict


# ---------------------------------------------------------------------------- #
# Replacement routines
# ---------------------------------------------------------------------------- #
def replace_test_function(new_test, idx=None):
def replace_test_function(new_test, old_idx=None, new_idx=None):
"""
A routine to replace the test function in a term with a new test function.
Expand All @@ -97,7 +118,9 @@ def repl(t):
:class:`Term`: the new term.
"""
old_test = t.form.arguments()[0]
replace_dict = _replace_dict(old_test, new_test, idx, 'test')
replace_dict = _replace_dict(old_test, new_test,
old_idx=old_idx, new_idx=new_idx,
replace_type='test')

try:
new_form = ufl.replace(t.form, replace_dict)
Expand All @@ -111,7 +134,7 @@ def repl(t):
return repl


def replace_trial_function(new_trial, idx=None):
def replace_trial_function(new_trial, old_idx=None, new_idx=None):
"""
A routine to replace the trial function in a term with a new expression.
Expand Down Expand Up @@ -140,7 +163,9 @@ def repl(t):
if len(t.form.arguments()) != 2:
raise TypeError('Trying to replace trial function of a form that is not linear')
old_trial = t.form.arguments()[1]
replace_dict = _replace_dict(old_trial, new_trial, idx, 'trial')
replace_dict = _replace_dict(old_trial, new_trial,
old_idx=old_idx, new_idx=new_idx,
replace_type='trial')

try:
new_form = ufl.replace(t.form, replace_dict)
Expand All @@ -154,7 +179,7 @@ def repl(t):
return repl


def replace_subject(new_subj, idx=None):
def replace_subject(new_subj, old_idx=None, new_idx=None):
"""
A routine to replace the subject in a term with a new variable.
Expand All @@ -180,7 +205,9 @@ def repl(t):
"""

old_subj = t.get(subject)
replace_dict = _replace_dict(old_subj, new_subj, idx, 'subject')
replace_dict = _replace_dict(old_subj, new_subj,
old_idx=old_idx, new_idx=new_idx,
replace_type='subject')

try:
new_form = ufl.replace(t.form, replace_dict)
Expand Down
Loading

0 comments on commit 2b311c3

Please sign in to comment.