Skip to content

Commit

Permalink
Linting and typing fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Moloq committed Oct 4, 2024
1 parent dbcd2e9 commit 4fc190b
Showing 1 changed file with 24 additions and 20 deletions.
44 changes: 24 additions & 20 deletions src/rra_climate_health/training/run_training.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
import itertools
from pathlib import Path
from typing import Any

import click
import itertools
import numpy as np
import pandas as pd
from pymer4.models.Lmer import Lmer
from sklearn.model_selection import GroupShuffleSplit,StratifiedGroupKFold
from sklearn.calibration import calibration_curve
from sklearn.metrics import mean_absolute_error, brier_score_loss, log_loss, mean_squared_error, precision_recall_curve, auc
from rra_tools import jobmon

from rra_climate_health import cli_options as clio
Expand Down Expand Up @@ -62,7 +58,7 @@ def model_training_main(
output_root: Path,
measure: str,
model_version: str,
submodel: list[tuple[str, str]] = [],
submodel: list[tuple[str, str]] | None = None,
) -> None:
cm_data = ClimateMalnutritionData(output_root / measure)
model_spec = cm_data.load_model_specification(model_version)
Expand All @@ -73,9 +69,12 @@ def model_training_main(
full_training_data = full_training_data.reset_index(drop=True)
full_training_data["intercept"] = 1.0

subset_mask = (full_training_data.sex_id == sex_id) & (
full_training_data.age_group_id == age_group_id
)
subset_mask = pd.Series(True, index=full_training_data.index) # noqa: FBT003
if submodel:
for var, value in submodel:
# Convert value to the type of the column in the training data and build subset mask
retyped_value = full_training_data[var].dtype.type(value)
subset_mask = (full_training_data[var] == retyped_value) & subset_mask

raw_df = full_training_data.loc[:, model_spec.raw_variables]
null_mask = raw_df.isna().any(axis=1)
Expand Down Expand Up @@ -116,7 +115,7 @@ def model_training_main(
"--submodel",
"-s",
multiple=True,
type = (str, str),
type=(str, str),
help="Submodel specification.",
)
def model_training_task(
Expand Down Expand Up @@ -158,22 +157,27 @@ def model_training(
training_data = cm_data.load_training_data(model_spec.version.training_data)

# Deal with submodels
submodel_vars = model_spec.submodel_vars or []
submodel_vars = [var.name for var in submodel_vars]
print('Training submodels by:', ", ".join([var for var in submodel_vars]))
submodel_vars = [var.name for var in (model_spec.submodel_vars or [])]
print("Training submodels by:", ", ".join(submodel_vars))
submodel_var_values = [training_data[var].unique() for var in submodel_vars]
#cross product of all submodel_var_values lists
submodel_specs = [list(zip(submodel_vars, values)) for values in itertools.product(*submodel_var_values)]
submodel_specs = [' --submodel '.join([f'{var} {val}' for var, val in spec]) for spec in submodel_specs]
node_args = {"submodel": submodel_specs} if submodel_vars else dict()
node_args['measure'] = [measure]
# cross product of all submodel_var_values lists
submodel_specs = [
list(zip(submodel_vars, values, strict=False))
for values in itertools.product(*submodel_var_values)
]
submodel_specs_strs = [
" --submodel ".join([f"{var} {val}" for var, val in spec])
for spec in submodel_specs
]
node_args = {"submodel": submodel_specs_strs} if submodel_vars else {}
node_args["measure"] = [measure]

print("Running model training for model version", model_version)

jobmon.run_parallel(
runner="sttask",
task_name="training",
node_args=node_args,
node_args=node_args, # type: ignore[arg-type]
task_args={
"output-root": output_root,
"model-version": model_version,
Expand All @@ -189,4 +193,4 @@ def model_training(
log_root=str(version_root),
)

print('Model training complete. Results can be found at', version_root)
print("Model training complete. Results can be found at", version_root)

0 comments on commit 4fc190b

Please sign in to comment.