Skip to content

Commit

Permalink
Add 'load_config' function, which wraps 'from_yaml_file' and checks t…
Browse files Browse the repository at this point in the history
…hat either 'variables' or 'derived_variables' are included and that if both are included, they don't contain the same variable names
  • Loading branch information
ealerskans committed Dec 20, 2024
1 parent f61a3b6 commit 554f869
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 3 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ The package can also be used as a python module to create datasets directly, for
import mllam_data_prep as mdp

config_path = "example.danra.yaml"
config = mdp.Config.from_yaml_file(config_path)
config = mdp.Config.load_config(config_path)
ds = mdp.create_dataset(config=config)
```

Expand Down
50 changes: 49 additions & 1 deletion mllam_data_prep/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,54 @@ class Config(dataclass_wizard.JSONWizard, dataclass_wizard.YAMLWizard):
class _(JSONWizard.Meta):
raise_on_unknown_json_key = True

@staticmethod
def load_config(*args, **kwargs):
"""
Wrapper function for `from_yaml_file` to load config file and validate that:
- either `variables` or `derived_variables` are present in the config
- if both `variables` and `derived_variables` are present, that they don't
add the same variables to the dataset
Parameters
----------
*args: Positional arguments for `from_yaml_file`
**kwargs: Keyword arguments for `from_yaml_file`
Returns
-------
config: Config
"""

# Load the config
config = Config.from_yaml_file(*args, **kwargs)

for input_dataset in config.inputs.values():
if not input_dataset.variables and not input_dataset.derived_variables:
raise InvalidConfigException(
"At least one of the keys `variables` and `derived_variables` must be included"
" in the input dataset."
)
elif input_dataset.variables and input_dataset.derived_variables:
# Check so that there are no overlapping variables
if isinstance(input_dataset.variables, list):
variable_vars = input_dataset.variables
elif isinstance(input_dataset.variables, dict):
variable_vars = input_dataset.variables.keys()
else:
raise TypeError(
f"Expected an instance of list or dict, but got {type(input_dataset.variables)}."
)
derived_variable_vars = input_dataset.derived_variables.keys()
common_vars = list(set(variable_vars) & set(derived_variable_vars))
if len(common_vars) > 0:
raise InvalidConfigException(
"Both `variables` and `derived_variables` include the following variables name(s):"
f" '{', '.join(common_vars)}'. This is not allowed. Make sure that there"
" are no overlapping variable names between `variables` and `derived_variables`,"
f" either by renaming or removing '{', '.join(common_vars)}' from one of them."
)
return config


if __name__ == "__main__":
import argparse
Expand All @@ -338,7 +386,7 @@ class _(JSONWizard.Meta):
)
args = argparser.parse_args()

config = Config.from_yaml_file(args.f)
config = Config.load_config(args.f)
import rich

rich.print(config)
2 changes: 1 addition & 1 deletion mllam_data_prep/create_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def create_dataset_zarr(fp_config, fp_zarr: str = None):
The path to the zarr file to write the dataset to. If not provided, the zarr file will be written
to the same directory as the config file with the extension changed to '.zarr'.
"""
config = Config.from_yaml_file(file=fp_config)
config = Config.load_config(file=fp_config)

ds = create_dataset(config=config)

Expand Down

0 comments on commit 554f869

Please sign in to comment.