Skip to content

Commit

Permalink
feat: DictConfig class for unstructured configs
Browse files Browse the repository at this point in the history
refactor: change Config.get_config behavior to return reference instead of copy
  • Loading branch information
Jeremy Silver committed Apr 22, 2024
1 parent 0c424f6 commit cba63ee
Show file tree
Hide file tree
Showing 7 changed files with 124 additions and 73 deletions.
6 changes: 3 additions & 3 deletions TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,15 @@

## v0.4.0

- `DictConfig` subclass of `Config` (unstructured configs loaded from JSON/TOML)
- `DCMixin` classmethod to coerce this class's settings or field settings to another type?
- `DataclassMixin` classmethod to coerce this class's settings or field settings to another type?
- E.g. to adapt new settings to a parent class (`CLIAdapterDataclass` example)

## v0.4.1

- `mkdocs` output in `package_data`?
- `_docs` subdirectory?
- Pre-commit hook to run `mkdocs build`
- Takse only a second, but could use file hashes to prevent redundant build
- Takes only a second, but could use file hashes to prevent redundant build, e.g. `sha1sum docs/*.md | sha1sum | head -c 40`
- Need some hook (post-tag?) to require the docs be up-to-date
- documentation
- Dataclass mixins/settings
Expand Down Expand Up @@ -42,6 +41,7 @@
- NOTE: the parsed values themselves have a `_trivia` attribute storing various formatting info
- Use field metadata (`help`?) as comment prior to the field
- For `None`, serialize as commented field?
- `JSON5Dataclass`?
- `ArgparseDataclass`
- Support subparsers
- Test subparsers, groups, mutually exclusive groups
Expand Down
20 changes: 13 additions & 7 deletions docs/config.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,22 +88,28 @@ def print_current_username():
>>> print_current_username()
admin

# update only the local copy
>>> cfg.database.username = 'test'
# update the config by mutating a local reference
>>> cfg.database.username = 'test1'
>>> print_current_username()
admin
test1

# update the config with another object
>>> from copy import deepcopy
>>> cfg2 = deepcopy(cfg)
>>> cfg2.database.username = 'test2'
>>> cfg2.update_config()

# update the global config
>>> cfg.update_config()
>>> cfg2.update_config()
>>> print_current_username()
test
test2
```

Sometimes it is useful to modify the configs temporarily:

```python
>>> print_current_username()
test
test2
>>> cfg.database.username = 'temporary'

# temporarily update global config with the local version
Expand All @@ -113,7 +119,7 @@ temporary

# global config reverts back to its value before 'as_config' was called
>>> print_current_username()
test
test2
```

## Details
Expand Down
51 changes: 26 additions & 25 deletions fancy_dataclass/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,30 +114,30 @@ def configure_argument(cls, parser: ArgumentParser, name: str) -> None:
parser: parser object to update with a new argument
name: Name of the argument to configure"""
kwargs: Dict[str, Any] = {}
field = cls.__dataclass_fields__[name] # type: ignore[attr-defined]
if field.metadata.get('parse_exclude', False): # exclude the argument from the parser
fld = cls.__dataclass_fields__[name] # type: ignore[attr-defined]
settings = ArgparseDataclassFieldSettings.coerce(cls._field_settings(fld))
if settings.parse_exclude: # exclude the argument from the parser
return
group_name = field.metadata.get('group')
if group_name is not None: # add argument to a group instead of the main parser
if (group_name := settings.group) is not None: # add argument to a group instead of the main parser
for group in getattr(parser, '_action_groups', []): # get argument group with the given name
if getattr(group, 'title', None) == group_name:
break
else: # group not found, so create it
group_kwargs = {}
if issubclass_safe(field.type, ArgparseDataclass): # get kwargs from nested ArgparseDataclass
group_kwargs = field.type.parser_kwargs()
if issubclass_safe(fld.type, ArgparseDataclass): # get kwargs from nested ArgparseDataclass
group_kwargs = fld.type.parser_kwargs()
group = parser.add_argument_group(group_name, **group_kwargs)
parser = group
if issubclass_safe(field.type, ArgparseDataclass):
if issubclass_safe(fld.type, ArgparseDataclass):
# recursively configure a nested ArgparseDataclass field
field.type.configure_parser(parser)
fld.type.configure_parser(parser)
return
# determine the type of the parser argument for the field
tp = field.metadata.get('type', field.type)
action = field.metadata.get('action', 'store')
tp = settings.type or fld.type
action = settings.action or 'store'
origin_type = get_origin(tp)
if origin_type is not None: # compound type
if type_is_optional(tp):
if type_is_optional(tp): # type: ignore[arg-type]
kwargs['default'] = None
if origin_type == ClassVar: # by default, exclude ClassVars from the parser
return
Expand All @@ -151,18 +151,19 @@ def configure_argument(cls, parser: ArgumentParser, name: str) -> None:
raise ValueError(f'cannot infer type of items in field {name!r}')
if issubclass_safe(origin_type, list) and (action == 'store'):
kwargs['nargs'] = '*' # allow multiple arguments by default
if issubclass_safe(tp, IntEnum): # use a bare int type
if issubclass_safe(tp, IntEnum): # type: ignore[arg-type]
# use a bare int type
tp = int
kwargs['type'] = tp
# determine the default value
if field.default == MISSING:
if field.default_factory != MISSING:
kwargs['default'] = field.default_factory()
if fld.default == MISSING:
if fld.default_factory != MISSING:
kwargs['default'] = fld.default_factory()
else:
kwargs['default'] = field.default
kwargs['default'] = fld.default
# get the names of the arguments associated with the field
if 'args' in field.metadata:
args = field.metadata['args']
args = settings.args
if args is not None:
if isinstance(args, str):
args = [args]
# argument is positional if it is explicitly given without a leading dash
Expand All @@ -171,26 +172,26 @@ def configure_argument(cls, parser: ArgumentParser, name: str) -> None:
# no default available, so make the field a required option
kwargs['required'] = True
else:
argname = field.name.replace('_', '-')
argname = fld.name.replace('_', '-')
positional = (tp is not bool) and ('default' not in kwargs)
if positional:
args = [argname]
else:
# use a single dash for 1-letter names
prefix = '-' if (len(field.name) == 1) else '--'
prefix = '-' if (len(fld.name) == 1) else '--'
args = [prefix + argname]
if field.metadata.get('args') and (not positional):
if args and (not positional):
# store the argument based on the name of the field, and not whatever flag name was provided
kwargs['dest'] = field.name
if field.type is bool: # use boolean flag instead of an argument
kwargs['dest'] = fld.name
if fld.type is bool: # use boolean flag instead of an argument
kwargs['action'] = 'store_true'
for key in ('type', 'required'):
with suppress(KeyError):
kwargs.pop(key)
# extract additional items from metadata
for key in cls.parser_argument_kwarg_names():
if key in field.metadata:
kwargs[key] = field.metadata[key]
if key in fld.metadata:
kwargs[key] = fld.metadata[key]
if (kwargs.get('action') == 'store_const'):
del kwargs['type']
parser.add_argument(*args, **kwargs)
Expand Down
75 changes: 52 additions & 23 deletions fancy_dataclass/config.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from abc import ABC, abstractmethod
from contextlib import contextmanager
from copy import deepcopy
from dataclasses import is_dataclass, make_dataclass
from pathlib import Path
from typing import ClassVar, Iterator, Optional, Type
from typing import Any, ClassVar, Dict, Iterator, Optional, Type

from typing_extensions import Self

Expand All @@ -13,9 +13,9 @@


class Config:
"""Base class for a collection of configurations.
"""Base class for storing a collection of configurations.
This uses a quasi-Singleton pattern by storing a class attribute with the current global configurations, which can be retrieved or updated by the user."""
Subclasses may store a class attribute, `_config`, with the current global configurations, which can be retrieved or updated by the user."""

_config: ClassVar[Optional[Self]] = None

Expand All @@ -25,7 +25,8 @@ def get_config(cls) -> Optional[Self]:
Returns:
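Global configuration object (`None` if not set); this is the stored object itself rather than a copy, so mutating it changes the global configuration"""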
return deepcopy(cls._config) # type: ignore[return-value]
# return deepcopy(cls._config) # type: ignore[return-value]
return cls._config # type: ignore[return-value]

@classmethod
def _set_config(cls, config: Optional[Self]) -> None:
Expand Down Expand Up @@ -53,7 +54,22 @@ def as_config(self) -> Iterator[None]:
type(self)._set_config(orig_config)


class ConfigDataclass(Config, DictDataclass, suppress_defaults=False):
class FileConfig(Config, ABC):
    """A collection of configurations that can be loaded from a file.

    This is an abstract base class: concrete subclasses must implement `load_config`."""

    @classmethod
    @abstractmethod
    def load_config(cls, path: AnyPath) -> Self:
        """Loads configurations from a file and sets them to be the global configurations for this class.

        Args:
            path: File from which to load configurations

        Returns:
            The newly loaded global configurations"""


class ConfigDataclass(DictDataclass, FileConfig, suppress_defaults=False):
"""A dataclass representing a collection of configurations.
The configurations can be loaded from a file, the type of which will be inferred from its extension.
Expand All @@ -75,33 +91,46 @@ def _wrap(tp: type) -> type:
return _wrap(dataclass_type_map(cls, _wrap)) # type: ignore[arg-type]

@classmethod
def _get_dataclass_type_for_extension(cls, ext: str) -> Type[FileSerializable]:
ext_lower = ext.lower()
def _get_dataclass_type_for_path(cls, path: AnyPath) -> Type[FileSerializable]:
p = Path(path)
if not p.suffix:
raise ValueError(f'filename {p} has no extension')
ext_lower = p.suffix.lower()
if ext_lower == '.json':
from fancy_dataclass.json import JSONDataclass
return JSONDataclass
elif ext_lower == '.toml':
from fancy_dataclass.toml import TOMLDataclass
return TOMLDataclass
else:
raise ValueError(f'unknown config file extension {ext!r}')
raise ValueError(f'unknown config file extension {p.suffix!r}')

@classmethod
def load_config(cls, path: AnyPath) -> Self:
"""Loads configurations from a file and sets them to be the global configurations for this class.
def load_config(cls, path: AnyPath) -> Self: # noqa: D102
tp = cls._get_dataclass_type_for_path(path)
new_cls: Type[FileSerializable] = ConfigDataclass._wrap_config_dataclass(tp, cls) # type: ignore
with open(path) as fp:
cfg: Self = coerce_to_dataclass(cls, new_cls._from_file(fp))
cfg.update_config()
return cfg

Args:
path: File from which to load configurations

Returns:
The newly loaded global configurations"""
p = Path(path)
ext = p.suffix
if not ext:
raise ValueError(f'filename {p} has no extension')
tp = cls._get_dataclass_type_for_extension(ext)
new_cls: Type[FileSerializable] = ConfigDataclass._wrap_config_dataclass(tp, cls) # type: ignore
with open(path) as f:
cfg: Self = coerce_to_dataclass(cls, new_cls._from_file(f))
class DictConfig(Config, Dict[Any, Any]):
    """A collection of configurations, stored as a Python dict.

    To impose a type schema on the configurations, use [`ConfigDataclass`][fancy_dataclass.config.ConfigDataclass] instead.

    The configurations can be loaded from a file, the type of which will be inferred from its extension.
    Supported file types are:

    - JSON
    - TOML
    """

    @classmethod
    def load_config(cls, path: AnyPath) -> Self:
        """Loads configurations from a file and sets them to be the global configurations for this class.

        Args:
            path: File from which to load configurations (its extension determines the file type)

        Returns:
            The newly loaded global configurations

        Raises:
            ValueError: If the filename has no extension, or the extension is not a supported type"""
        # choose the serializer (JSON or TOML) based on the file extension
        tp = ConfigDataclass._get_dataclass_type_for_path(path)
        # JSON and TOML are both specified as UTF-8, so do not rely on the platform's default encoding
        with open(path, encoding='utf-8') as fp:
            cfg = cls(tp._text_file_to_dict(fp))  # type: ignore[attr-defined]
        # install the loaded dict as the global configuration for this class
        cfg.update_config()
        return cfg
2 changes: 1 addition & 1 deletion fancy_dataclass/mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def _configure_mixin_settings(cls: Type['DataclassMixin'], **kwargs: Any) -> Non
cls.__settings__ = stype(**d)

def _configure_field_settings_type(cls: Type['DataclassMixin']) -> None:
"""Sets up the __field_settings_type__ attribute on a `DataclassMixin` subclass at definition type.
"""Sets up the __field_settings_type__ attribute on a `DataclassMixin` subclass at definition time.
This reconciles any such attributes inherited from multiple parent classes."""
stype = cls.__dict__.get('__field_settings_type__')
if stype is None:
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ disable_error_code = ["assignment"]

[[tool.mypy.overrides]]
module = "tests.test_config"
disable_error_code = ["misc", "union-attr"]
disable_error_code = ["index", "misc", "union-attr"]

[[tool.mypy.overrides]]
module = "tests.test_inheritance"
Expand Down
Loading

0 comments on commit cba63ee

Please sign in to comment.