Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/dev' into dev
Browse files Browse the repository at this point in the history
# Conflicts:
#	tests/external/test_external_models.py
  • Loading branch information
knutdrand committed Aug 30, 2024
2 parents e8bc805 + 2080612 commit b42cdc8
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 9 deletions.
1 change: 1 addition & 0 deletions .github/workflows/build_sphinx_website.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ jobs:
- name: install dependencies
run: |
python -m pip install --upgrade pip
pip install -e .
pip install -r requirements_dev.txt
- name: make html & commit the changes
Expand Down
4 changes: 4 additions & 0 deletions requirements_dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,7 @@ meteostat
pytest-mock
furo
myst-parser
earthengine-api
python-dotenv
myst_parser
furo
11 changes: 9 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,20 @@
'pooch',
'python-dateutil',
'meteostat',
'cyclopts', 'requests', 'pydantic', 'pyyaml',
'cyclopts', 'requests',
'pydantic>=2.0',
'pyyaml',
'geopandas', 'libpysal', 'docker',
'jax', 'jaxlib', 'blackjax', 'dynamax', 'flax', 'optax',
'scipy',
'fastapi',
'gitpython', 'earthengine-api', 'python-dotenv', 'rq', "python-multipart", "uvicorn",
'pydantic-geojson', 'annotated_types'
'pydantic-geojson', 'annotated_types',
'pycountry',
'unidecode',
'httpx',
'earthengine-api'

]

test_requirements = ['pytest>=3', "hypothesis"]
Expand Down
15 changes: 8 additions & 7 deletions tests/external/test_external_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@

from climate_health.api import get_model_from_directory_or_github_url
from climate_health.spatio_temporal_data.temporal_dataclass import DataSet
from climate_health.datatypes import ClimateHealthTimeSeries
from climate_health.datatypes import ClimateHealthTimeSeries,FullData

logging.basicConfig(level=logging.INFO)
from climate_health.external.external_model import get_model_from_yaml_file, run_command
from ..data_fixtures import full_data, train_data, future_climate_data
from ..data_fixtures import full_data, train_data, train_data_pop, future_climate_data
from climate_health.util import conda_available


Expand All @@ -35,7 +35,7 @@ def test_python_model_from_folder(models_path, train_data, future_climate_data):
assert results is not None


def get_dataset_from_yaml(yaml_path: Path):
def get_dataset_from_yaml(yaml_path: Path, datatype=ClimateHealthTimeSeries):
specs = yaml.load(yaml_path.read_text(), Loader=yaml.FullLoader)
if 'demo_data' in specs:
path = yaml_path.parent / specs['demo_data']
Expand All @@ -49,19 +49,20 @@ def get_dataset_from_yaml(yaml_path: Path):
df[to_name] = df[from_name]
#df['disease_cases'] = np.arange(len(df))

return DataSet.from_pandas(df, ClimateHealthTimeSeries)
return DataSet.from_pandas(df, datatype)



#@pytest.mark.skipif(not conda_available(), reason='requires conda')
@pytest.mark.parametrize('model_directory', ['ewars_Plus'])
#@pytest.mark.parametrize('model_directory', ['naive_python_model'])
def test_all_external_models_acceptance(model_directory, models_path, train_data, future_climate_data):
def test_all_external_models_acceptance(model_directory, models_path, train_data_pop, future_climate_data):
"""Only tests that the model can be initiated and that train and predict
can be called without anything failing"""
print("Running")
yaml_path = models_path / model_directory / 'config.yml'
model = get_model_from_yaml_file(yaml_path, working_dir=models_path / model_directory)
train_data = get_dataset_from_yaml(yaml_path)
train_data = get_dataset_from_yaml(yaml_path, FullData)
model.setup()
model.train(train_data)
#results = model.predict(future_climate_data)
Expand All @@ -72,7 +73,7 @@ def test_all_external_models_acceptance(model_directory, models_path, train_data
@pytest.mark.parametrize('model_directory', ['ewars_Plus'])
def test_external_model_predict(model_directory, models_path):
yaml_path = models_path / model_directory / 'config.yml'
train_data = get_dataset_from_yaml(yaml_path)
train_data = get_dataset_from_yaml(yaml_path, FullData)
model = get_model_from_yaml_file(yaml_path, working_dir=models_path / model_directory)
model.setup()
#model.setup()
Expand Down

0 comments on commit b42cdc8

Please sign in to comment.