Skip to content

Commit

Permalink
Merge pull request #82 from t-bz/gpu_support
Browse files Browse the repository at this point in the history
Add GPU support
  • Loading branch information
roussel-ryan authored Jun 6, 2023
2 parents 7b4d47e + 8bbeeb9 commit f756eb8
Showing 1 changed file with 33 additions and 9 deletions.
42 changes: 33 additions & 9 deletions lume_model/torch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(
output_format: Optional[Dict[str, str]] = {"type": "tensor"},
feature_order: Optional[list] = None,
output_order: Optional[list] = None,
device: Optional[Union[torch.device, str]] = "cpu",
) -> None:
"""Initializes the model, stores inputs/outputs and determines the format
in which the model results will be output.
Expand All @@ -49,33 +50,43 @@ def __init__(
order in which they are passed to the model
output_order: List[str]: list containing the names of outputs in the
order the model produces them
device (Optional[Union[torch.device, str]]): Device on which the
model will be evaluated. Defaults to "cpu".
"""
super(BaseModel, self).__init__()

# Save init
self.device = device
self.input_variables = input_variables
self.default_values = torch.tensor(
[var.default for var in input_variables.values()],
dtype=torch.double,
requires_grad=True,
requires_grad=True
)
self.output_variables = output_variables
self._model_file = model_file
self._output_format = output_format

# make sure all of the transformers are in eval mode
# make sure transformers are passed as lists
if not isinstance(input_transformers, list) or not isinstance(
output_transformers, list):
raise TypeError(
"In- and output transformers have to be passed as lists.")
self._input_transformers = input_transformers
for transformer in self._input_transformers:
transformer.eval()
self._output_transformers = output_transformers
for transformer in self._output_transformers:

# put all transformers in eval mode
for transformer in self._input_transformers + self._output_transformers:
transformer.eval()

self._model = torch.load(model_file).double()
self._model.eval()
self._model.requires_grad = False

# move model, transformers and default values to device
self.to(self.device)

self._feature_order = feature_order
self._output_order = output_order

Expand Down Expand Up @@ -174,17 +185,18 @@ def _prepare_inputs(
for var_name, var in input_variables.items():
if isinstance(var, InputVariable):
model_vals[var_name] = torch.tensor(
var.value, dtype=torch.double, requires_grad=True
var.value, dtype=torch.double, requires_grad=True,
device=self.device
)
self.input_variables[var_name].value = var.value
elif isinstance(var, float):
model_vals[var_name] = torch.tensor(
var, dtype=torch.double, requires_grad=True
var, dtype=torch.double, requires_grad=True,
device=self.device
)
self.input_variables[var_name].value = var
elif isinstance(var, torch.Tensor):
var = var.double()
var = var.squeeze()
var = var.double().squeeze().to(self.device)
if not var.requires_grad:
var.requires_grad = True
model_vals[var_name] = var
Expand Down Expand Up @@ -339,3 +351,15 @@ def _update_image_limits(
self.output_variables[variable.name].y_max = predicted_output[
self.output_variables[variable.name].y_max_variable
].item()

def to(self, device: Union[torch.device, str]):
"""Updates the device for the model, transformers and default values.
Args:
device: Device on which the model will be evaluated.
"""
self._model.to(device)
for transformer in self._input_transformers + self._output_transformers:
transformer.to(device)
self.default_values = self.default_values.to(device)
self.device = device

0 comments on commit f756eb8

Please sign in to comment.