Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/dev' into dev
Browse files Browse the repository at this point in the history
  • Loading branch information
knutdrand committed Sep 5, 2024
2 parents a380d04 + 0414122 commit b7184f2
Show file tree
Hide file tree
Showing 4 changed files with 32 additions and 2 deletions.
5 changes: 4 additions & 1 deletion climate_health/external/external_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,9 @@ def predict(self, future_data: IsSpatioTemporalDataSet[FeatureType]) -> IsSpatio
df = future_data.to_pandas()
df['disease_cases'] = np.nan

# todo: instead of using saved state for historic data, take the historic data as an argument to predict
# and send historic data and future data to the model as two separate data sets

new_pd = self._adapt_data(df)
if self.is_lagged:
new_pd = pd.concat([self._saved_state, new_pd]).sort_values(['location', 'time_period'])
Expand Down Expand Up @@ -413,7 +416,7 @@ def get_model_from_mlproject_file(mlproject_file):
adapters = config.get('adapters', None)
allowed_data_types = {'HealthData': HealthData}
data_type = allowed_data_types.get(config.get('data_type', None), None)
return ExternalMLflowModel(mlproject_file, name=name, adapters=adapters, data_type=data_type,
return ExternalMLflowModel(mlproject_file.parent, name=name, adapters=adapters, data_type=data_type,
working_dir=Path(mlproject_file).parent)


Expand Down
6 changes: 5 additions & 1 deletion climate_health/external/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,15 @@ def __init__(self, model_path: str, name: str=None, adapters=None, working_dir="
self._model_file_name = Path(model_path).name + ".model"
self.is_lagged = True
self._data_type = data_type
self._name = name

# Property exposing the model name that __init__ stored in self._name.
@property
def name(self):
return self._name

# Calling the instance does nothing except return the instance itself.
# NOTE(review): presumably this lets an already-constructed model be passed
# where callers expect a model class/factory to call — confirm against callers.
def __call__(self):
return self

def train(self, train_data: DataSet, extra_args=None):

if extra_args is None:
Expand All @@ -51,7 +55,7 @@ def train(self, train_data: DataSet, extra_args=None):
response = mlflow.projects.run(str(self.model_path), entry_point="train",
parameters={
"train_data": str(train_file_name),
"model_output_file": str(self._model_file_name)
"model": str(self._model_file_name)
},
build_image=True)
self._saved_state = new_pd
Expand Down
22 changes: 22 additions & 0 deletions external_models/deepar/MLproject
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# MLflow project file: declares the `train` and `predict` entry points and the shell commands they run.
# NOTE(review): "My Project" looks like a template placeholder — confirm whether a model-specific name is intended.
name: My Project

# Environment options below are all commented out, so this file declares no environment.
# python_env: python_env.yaml
# or
# conda_env: my_env.yaml
# or
# docker_env:
# image: mlflow-docker-example

entry_points:
# train: receives the training data path and the model path and forwards them, in that order, to `ch_modelling train`.
train:
parameters:
train_data: path
model: path
command: "ch_modelling train {train_data} {model}"
# predict: receives the model path plus historic data, future data and output file paths,
# and forwards them, in that order, to `ch_modelling predict`.
predict:
parameters:
model: path
historic_data: path
future_data: path
out_file: path
command: "ch_modelling predict {model} {historic_data} {future_data} {out_file}"
1 change: 1 addition & 0 deletions requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ earthengine-api
python-dotenv
myst_parser
furo
virtualenv

0 comments on commit b7184f2

Please sign in to comment.