Skip to content

Commit

Permalink
add support for multi-returning-value functions in transform (gwastro…
Browse files Browse the repository at this point in the history
…#4301)

* add support for multi-value functions in transform

* fix cc issue

* Update transforms.py

* fix cc issue

* fix

* Update transforms.py

* Update transforms.py

* fix cc issue

* Update transforms.py

* add 7 classes

* fix cc issues

* move LISA stuff to another PR

* Update transforms.py
  • Loading branch information
WuShichao authored and lpathak97 committed Mar 13, 2024
1 parent 80dd6e5 commit 8326527
Showing 1 changed file with 93 additions and 0 deletions.
93 changes: 93 additions & 0 deletions pycbc/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,98 @@ def from_config(cls, cp, section, outputs):
return cls(inputs, outputs, transform_functions, jacobian=jacobian)


class CustomTransformMultiOutputs(CustomTransform):
"""Allows for any transform to be defined. Based on CustomTransform,
but also supports multi-returning value functions.
Parameters
----------
input_args : (list of) str
The names of the input parameters.
output_args : (list of) str
The names of the output parameters.
transform_functions : dict
Dictionary mapping input args to a string giving a function call;
e.g., ``{'q': 'q_from_mass1_mass2(mass1, mass2)'}``.
jacobian : str, optional
String giving a jacobian function. The function must be in terms of
the input arguments.
"""

name = "custom_multi"

def __init__(self, input_args, output_args, transform_functions,
jacobian=None):
super(CustomTransformMultiOutputs, self).__init__(
input_args, output_args, transform_functions, jacobian)

def transform(self, maps):
"""Applies the transform functions to the given maps object.
Parameters
----------
maps : dict, or FieldArray
Returns
-------
dict or FieldArray
A map object containing the transformed variables, along with the
original variables. The type of the output will be the same as the
input.
"""
if self.transform_functions is None:
raise NotImplementedError("no transform function(s) provided")
# copy values to scratch
self._copytoscratch(maps)
# ensure that we return the same data type in each dict
getslice = self._getslice(maps)
# evaluate the functions
# func[0] is the function itself, func[1] is the index,
# this supports multiple returning values function
out = {
p: self._scratch[func[0]][func[1]][getslice] if
len(self._scratch[func[0]]) > 1 else
self._scratch[func[0]][getslice]
for p, func in self.transform_functions.items()
}
return self.format_output(maps, out)

@classmethod
def from_config(cls, cp, section, outputs):
"""Loads a CustomTransformMultiOutputs from the given config file.
Example section:
.. code-block:: ini
[{section}-outvar1+outvar2]
name = custom_multi
inputs = inputvar1, inputvar2
outvar1, outvar2 = func1(inputs)
jacobian = func2(inputs)
"""
tag = outputs
outputs = list(outputs.split(VARARGS_DELIM))
all_vars = ", ".join(outputs)
inputs = map(str.strip,
cp.get_opt_tag(section, "inputs", tag).split(","))
# get the functions for each output
transform_functions = {}
output_index = slice(None, None, None)
for var in outputs:
# check if option can be cast as a float
try:
func = cp.get_opt_tag(section, var, tag)
except Exception:
func = cp.get_opt_tag(section, all_vars, tag)
output_index = slice(outputs.index(var), outputs.index(var)+1)
transform_functions[var] = [func, output_index]
s = "-".join([section, tag])
if cp.has_option(s, "jacobian"):
jacobian = cp.get_opt_tag(section, "jacobian", tag)
else:
jacobian = None
return cls(inputs, outputs, transform_functions, jacobian=jacobian)


#
# =============================================================================
#
Expand Down Expand Up @@ -2725,6 +2817,7 @@ def from_config(cls, cp, section, outputs,
# dictionary of all transforms
transforms = {
CustomTransform.name: CustomTransform,
CustomTransformMultiOutputs.name: CustomTransformMultiOutputs,
MchirpQToMass1Mass2.name: MchirpQToMass1Mass2,
Mass1Mass2ToMchirpQ.name: Mass1Mass2ToMchirpQ,
MchirpEtaToMass1Mass2.name: MchirpEtaToMass1Mass2,
Expand Down

0 comments on commit 8326527

Please sign in to comment.