Merge pull request #83 from t-bz/fix_PyTorchModel_gradients
Fix handling of gradients and train/eval mode
roussel-ryan committed Jun 30, 2023
2 parents f756eb8 + dd20dc5 commit e43850f
Showing 3 changed files with 98 additions and 121 deletions.
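
To make the behavioral change concrete, here is a minimal usage sketch of the new fixed_model flag introduced by this commit. The import path follows lume_model/torch/model.py as shown in the diff; the model file name and the input/output variable dictionaries are hypothetical placeholders.

from lume_model.torch.model import PyTorchModel

# input_variables / output_variables are assumed to be the usual
# Dict[str, InputVariable] and Dict[str, OutputVariable] built elsewhere.
frozen = PyTorchModel(
    "surrogate.pt",        # hypothetical file created with torch.save()
    input_variables,
    output_variables,
)  # fixed_model defaults to True: eval mode, gradient computation deactivated

trainable = PyTorchModel(
    "surrogate.pt",
    input_variables,
    output_variables,
    fixed_model=False,     # skip eval()/requires_grad_(False) and leave the module as loaded
)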
3 changes: 3 additions & 0 deletions .gitignore
@@ -114,6 +114,9 @@ venv.bak/
.spyderproject
.spyproject

# PyCharm project settings
.idea

# Rope project settings
.ropeproject

173 changes: 76 additions & 97 deletions lume_model/torch/model.py
@@ -15,8 +15,8 @@ class PyTorchModel(BaseModel):
It is designed to implement the general behaviors expected for models used with
the pytorch lume-model tool kit.
By default we assume that these models are 'frozen' so we set requires_grad as
false and use the model in .eval() mode.
By default, we assume that these models are fixed, so we deactivate all gradients
and use the model in evaluation mode.
"""

def __init__(
@@ -27,42 +27,40 @@ def __init__(
input_transformers: Optional[List[ReversibleInputTransform]] = [],
output_transformers: Optional[List[ReversibleInputTransform]] = [],
output_format: Optional[Dict[str, str]] = {"type": "tensor"},
feature_order: Optional[list] = None,
output_order: Optional[list] = None,
feature_order: Optional[List[str]] = None,
output_order: Optional[List[str]] = None,
device: Optional[Union[torch.device, str]] = "cpu",
) -> None:
"""Initializes the model, stores inputs/outputs and determines the format
in which the model results will be output.
fixed_model: bool = True
):
"""Initializes the model.
Args:
model_file (str): Path to model file generated with torch.save()
input_variables (Dict[str, InputVariable]): list of model input variables
output_variables (Dict[str, OutputVariable]): list of model output variables
input_transformers: (List[ReversibleInputTransform]): list of transformer
objects to apply to input before passing to model
output_transformers: (List[ReversibleInputTransform]): list of transformer
objects to apply to output of model
output_format (Optional[dict]): Wrapper for interpreting outputs. This now handles
raw or softmax values, but should be expanded to accomodate misc
functions. Now, dictionary should look like:
{"type": Literal["raw", "string", "tensor", "variable"]}
feature_order: List[str]: list containing the names of features in the
order in which they are passed to the model
output_order: List[str]: list containing the names of outputs in the
order the model produces them
device (Optional[Union[torch.device, str]]): Device on which the
model will be evaluated. Defaults to "cpu".
Stores inputs/outputs and determines the format in which the model results will be output.
Args:
model_file: Path to model file generated with torch.save().
input_variables: List of model input variables.
output_variables: list of model output variables.
input_transformers: List of transformer objects to apply to input before passing
to model.
output_transformers: List of transformer objects to apply to output of model.
output_format: Wrapper for interpreting outputs. This now handles raw or softmax values,
but should be expanded to accommodate miscellaneous functions. Now, dictionary
should look like: {"type": Literal["raw", "string", "tensor", "variable"]}.
feature_order: List containing the names of features in the order in which they are
passed to the model.
output_order: List containing the names of outputs in the order the model
produces them.
fixed_model: If true, the model is put in evaluation mode and gradient computation
is deactivated.
device: Device on which the model will be evaluated. Defaults to "cpu".
"""
super(BaseModel, self).__init__()

# Save init
self.device = device
self.input_variables = input_variables
self.default_values = torch.tensor(
[var.default for var in input_variables.values()],
dtype=torch.double,
requires_grad=True
[var.default for var in input_variables.values()], dtype=torch.double
)
self.output_variables = output_variables
self._model_file = model_file
@@ -71,8 +69,7 @@ def __init__(
# make sure transformers are passed as lists
if not isinstance(input_transformers, list) or not isinstance(
output_transformers, list):
raise TypeError(
"In- and output transformers have to be passed as lists.")
raise TypeError("In- and output transformers have to be passed as lists.")
self._input_transformers = input_transformers
self._output_transformers = output_transformers

@@ -81,8 +78,9 @@ def __init__(
transformer.eval()

self._model = torch.load(model_file).double()
self._model.eval()
self._model.requires_grad = False
if fixed_model:
self._model.eval()
self._model.requires_grad_(False)

# move model, transformers and default values to device
self.to(self.device)
@@ -95,18 +93,17 @@ def features(self):
if self._feature_order is not None:
return self._feature_order
else:
# if there's no specified order, we make the assumption
# that the variables were passed in the desired order
# in the configuration file
# if there's no specified order, we make the assumption that the variables were passed
# in the desired order in the configuration file
return list(self.input_variables.keys())

@property
def outputs(self):
if self._output_order is not None:
return self._output_order
else:
# if there's no order specified, we assume it's the same as the
# order passed in the variables.yml file
# if there's no order specified, we assume it's the same as the order passed in the
# variables.yml file
return list(self.output_variables.keys())

@property
@@ -135,23 +132,20 @@ def output_transformers(

def evaluate(
self,
input_variables: Dict[str, Union[InputVariable, float, torch.Tensor]],
input_variables: Dict[str, Union[InputVariable, float, torch.Tensor]]
) -> Dict[str, Union[torch.Tensor, OutputVariable, float]]:
"""Evaluate model using new input variables.
"""Evaluates model using new input variables.
Args:
input_variables (Dict[str, InputVariable]): List of updated input
variables
input_variables: List of updated input variables.
Returns:
Dict[str, torch.Tensor]: Dictionary mapping var names to outputs
Dictionary mapping variable names to outputs.
"""
# all PyTorch models will follow the same process, the inputs
# are formatted, then converted to model features. Then they
# are passed through the model, and transformed again on the
# other side. The final dictionary is then converted into a
# useful form
# all PyTorch models will follow the same process, the inputs are formatted,
# then converted to model features. Then they are passed through the model,
# and transformed again on the other side. The final dictionary is then converted
# into a useful form
input_vals = self._prepare_inputs(input_variables)
input_vals = self._arrange_inputs(input_vals)
features = self._transform_inputs(input_vals)
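
As a rough illustration of this pipeline, a call might look like the following; the model instance and the variable names "x1", "x2" and "y1" are hypothetical.

# Inputs are keyed by variable name; values may be InputVariables, floats,
# or torch.Tensors. With the default output_format {"type": "tensor"} the
# results come back as torch.Tensors in untransformed (real) units.
results = model.evaluate({"x1": 0.5, "x2": 1.2})
print(results["y1"])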
@@ -165,63 +159,56 @@ def evaluate(
def _prepare_inputs(
self, input_variables: Dict[str, Union[InputVariable, float, torch.Tensor]]
) -> Dict[str, torch.Tensor]:
"""
Prepares the input variables dictionary as a format appropriate
to be passed to the transformers and updates the stored InputVariables
with new values
"""Prepares inputs to pass them to the transformers.
Prepares the input variables dictionary as a format appropriate to be passed to the
transformers and updates the stored InputVariables with new values.
Args:
input_variables (dict): Dictionary of input variable names to
variables in any format (InputVariable or raw values)
input_variables: Dictionary of input variable names to variables in any format
(InputVariable or raw values).
Returns:
dict (Dict[str, torch.Tensor]): dictionary of input variable
values to be passed to the transformers
Dictionary of input variable values to be passed to the transformers.
"""
# NOTE we only update the input variable if we receive a singular
# value, otherwise we don't know which value to assign so we just
# leave it
# NOTE we only update the input variable if we receive a singular value, otherwise we
# don't know which value to assign so we just leave it
model_vals = {}
for var_name, var in input_variables.items():
if isinstance(var, InputVariable):
model_vals[var_name] = torch.tensor(
var.value, dtype=torch.double, requires_grad=True,
device=self.device
var.value, dtype=torch.double, device=self.device
)
self.input_variables[var_name].value = var.value
elif isinstance(var, float):
model_vals[var_name] = torch.tensor(
var, dtype=torch.double, requires_grad=True,
device=self.device
var, dtype=torch.double, device=self.device
)
self.input_variables[var_name].value = var
elif isinstance(var, torch.Tensor):
var = var.double().squeeze().to(self.device)
if not var.requires_grad:
var.requires_grad = True
model_vals[var_name] = var
if var.dim() == 0:
self.input_variables[var_name].value = var.item()
else:
TypeError(
f"Unknown type {type(var)} passed to evaluate. Should be one of InputVariable, float or torch.Tensor"
f"Unknown type {type(var)} passed to evaluate."
f"Should be one of InputVariable, float or torch.Tensor."
)
return model_vals
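
A small sketch of the update rule described in the NOTE above, with a hypothetical variable name "x1": singular values update the stored InputVariable, while batched tensors leave it untouched.

import torch

model.evaluate({"x1": 0.7})                        # stored value becomes 0.7
model.evaluate({"x1": torch.tensor([0.1, 0.9])})   # stored value is left unchanged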

def _arrange_inputs(self, input_variables: Dict[str, torch.Tensor]) -> torch.Tensor:
"""
"""Enforces order of input variables.
Enforces the order of the input variables to be passed to the transformers
and models and updates the model with default values for any features that
are missing, maintaining the shape of the incoming features.
Args:
input_variables (dict): Dictionary of input variable names to raw
values of inputs
input_variables: Dictionary of input variable names to raw values of inputs.
Returns:
torch.Tensor: ordered tensor of input variables to be passed to the
transformers
Ordered tensor of input variables to be passed to the transformers.
"""
incoming_shape = list(input_variables.items())[0][1].unsqueeze(-1).shape
default_tensor = torch.tile(self.default_values, incoming_shape)
@@ -241,47 +228,40 @@ def _arrange_inputs(self, input_variables: Dict[str, torch.Tensor]) -> torch.Tensor:
return default_tensor
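
For illustration, assume hypothetical features "x1", "x2", "x3" with defaults 0.0, 1.0, 2.0. Supplying only "x2" fills the remaining columns with their defaults while keeping the batch dimension of the incoming value.

import torch

arranged = model._arrange_inputs({"x2": torch.tensor([0.5, 0.6])})
# ordered as self.features, e.g.:
# tensor([[0.0000, 0.5000, 2.0000],
#         [0.0000, 0.6000, 2.0000]], dtype=torch.float64)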

def _transform_inputs(self, input_values: torch.Tensor) -> torch.Tensor:
"""
Applies transformations to the inputs
"""Applies transformations to the inputs.
Args:
input_values (torch.Tensor): tensor of input variables to be passed
to the transformers
input_values: Tensor of input variables to be passed to the transformers.
Returns:
torch.Tensor: tensor of transformed input variables to be passed
to the model
Tensor of transformed input variables to be passed to the model.
"""
for transformer in self._input_transformers:
input_values = transformer(input_values)
return input_values

def _transform_outputs(self, model_output: torch.Tensor) -> torch.Tensor:
"""
Untransforms the model outputs to real units
"""Untransforms the model outputs to real units.
Args:
model_output (torch.Tensor): tensor of outputs from the model
model_output: Tensor of outputs from the model.
Returns:
Dict[str, torch.Tensor]: dictionary of variable name to tensor
of untransformed output variables
Dictionary of variable name to tensor of untransformed output variables.
"""
# NOTE do we need to sort these to reverse them?
for transformer in self._output_transformers:
model_output = transformer.untransform(model_output)
return model_output
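
A brief sketch of the transformer flow as currently written, using hypothetical transformer instances t1 and t2: input transformers are applied in list order, and output transformers are untransformed in list order as well, which is the behavior the NOTE above is questioning.

# input side: transform in list order
features = t1(input_values)
features = t2(features)
# output side: untransform in the same list order (not reversed)
output = t1.untransform(model_output)
output = t2.untransform(output)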

def _parse_outputs(self, model_output: torch.Tensor) -> Dict[str, torch.Tensor]:
"""
Constructs dictionary from model outputs
"""Constructs dictionary from model outputs.
Args:
model_output (torch.Tensor): transformed output from NN model
model_output: Transformed output from NN model.
Returns:
Dict[str, torch.Tensor]: dictionary of output variable name to output
value
Dictionary of output variable name to output value.
"""
# NOTE if we have shape [50,3,1] coming out of the model, our output
# dictionary should have shape [50,3]
@@ -295,24 +275,23 @@ def _parse_outputs(self, model_output: torch.Tensor) -> Dict[str, torch.Tensor]:
def _prepare_outputs(
self, predicted_output: Dict[str, torch.Tensor]
) -> Dict[str, Union[OutputVariable, torch.Tensor]]:
"""
"""Updates and returns outputs according to _output_format.
Updates the output variables within the model to reflect the new values
if we only have a singular data point.
Args:
predicted_output (Dict[str, torch.Tensor]): Dictionary of output
variable name to value
predicted_output: Dictionary of output variable name to value.
Returns:
Dict[str, Union[OutputVariable,torch.Tensor]]: Dictionary of output
variable name to output tensor or OutputVariable depending
on model's _ouptut_format
Dictionary of output variable name to output tensor or OutputVariable depending
on model's _output_format.
"""
for variable in self.output_variables.values():
if predicted_output[variable.name].dim() == 0:
if variable.variable_type == "scalar":
self.output_variables[variable.name].value = predicted_output[
variable.name
].item()
self.output_variables[variable.name].value = \
predicted_output[variable.name].item()
elif variable.variable_type == "image":
# OutputVariables should be numpy arrays so we need to convert
# the tensor to a numpy array
