Skip to content

Commit

Permalink
Merge pull request #72 from jacquelinegarrahan/fix-keras-return
Browse files Browse the repository at this point in the history
Fix dict formatting
  • Loading branch information
jacquelinegarrahan authored Oct 12, 2022
2 parents 3b5855b + 80fd82c commit ac235a5
Showing 1 changed file with 14 additions and 18 deletions.
32 changes: 14 additions & 18 deletions lume_model/keras/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,7 @@ def evaluate(self, input_variables: Dict[str, InputVariable]) -> Dict[str, Outpu
Dict[str, OutputVariable]: List of output variables
"""
self.input_variables = {var.name: var for var in input_variables}

# convert list of input variables to dictionary
input_dictionary = {
input_variable.name: input_variable.value
for input_variable in input_variables
}
self.input_variables = input_variables

# converts from input_dict -> formatted input
formatted_input = self.format_input(input_dictionary)
Expand All @@ -97,37 +91,39 @@ def evaluate(self, input_variables: Dict[str, InputVariable]) -> Dict[str, Outpu
# prepare outputs will format return variables (dict-> variables)
return self._prepare_outputs(output)

def random_evaluate(self) -> List[OutputVariable]:
def random_evaluate(self) -> Dict[str, OutputVariable]:
"""Return a random evaluation of the model.
Returns:
List[OutputVariable]: List of outputs associated with random input
Dict[str, OutputVariable]: Outputs of random evaluation
"""
random_input = copy.deepcopy(self.input_variables)
for variable in self.input_variables:
if self.input_variables[variable].variable_type == "scalar":
for variable in self.input_variables.values():
if variable.variable_type == "scalar":
random_input[variable].value = np.random.uniform(
self.input_variables[variable].value_range[0],
self.input_variables[variable].value_range[1],
variable.value_range[0],
variable.value_range[1],
)

else:
random_input[variable].value = self.input_variables[variable].default
variable.value = variable.default

return self.evaluate(list(random_input.values()))
return self.evaluate(random_input)

def _prepare_outputs(
self, predicted_output: Dict[str, Any]
) -> List[OutputVariable]:
) -> Dict[str, OutputVariable]:
"""Prepares the model outputs to be served so that no additional manipulation
occurs in the BaseModel class.
Args:
model_outputs (dict): Dictionary of output variables to np.ndarrays of outputs
model_outputs (dict): Dictionary of output variables to np.ndarrays of
outputs
Returns:
List[OutputVariable]: List of output variables.
Dict[str, OutputVariable]: Dictionary of output variable name to output
variable.
"""
for variable in self.output_variables.values():
if variable.variable_type == "scalar":
Expand Down

0 comments on commit ac235a5

Please sign in to comment.