From 0939e5d4ad6b49db96075c84cd1f1d6da21e5be4 Mon Sep 17 00:00:00 2001 From: Mikhail Sveshnikov Date: Wed, 1 Mar 2023 14:18:27 +0300 Subject: [PATCH] Fix declaring classes with abc fields with defaults (#625) --- mlem/cli/types.py | 1 + mlem/cli/utils.py | 57 +++++++++++++++++++++++----------- tests/cli/test_declare.py | 65 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 105 insertions(+), 18 deletions(-) diff --git a/mlem/cli/types.py b/mlem/cli/types.py index cba6b318..cec60c3e 100644 --- a/mlem/cli/types.py +++ b/mlem/cli/types.py @@ -33,6 +33,7 @@ def _add_examples( allow_none=False, default=None, root_cls=root_cls, + force_not_set=False, ), root_cls=root_cls, parent_help=f"Element of {field.path}", diff --git a/mlem/cli/utils.py b/mlem/cli/utils.py index 81355bec..b33c1d59 100644 --- a/mlem/cli/utils.py +++ b/mlem/cli/utils.py @@ -198,26 +198,34 @@ def _get_type_name_alias(type_): return type_.__name__ if type_ is not None else "any" -def anything(type_): +def anything(type_, allow_none: bool): """Creates special type that is named as original type or collection type It returns original object on creation and is needed for nice typename in cli option help """ - return type( - _get_type_name_alias(type_), (), {"__new__": lambda cls, value: value} - ) + + def new(cls, value): # pylint: disable=unused-argument + """Just return the value""" + if allow_none and value == "None": + return None + return value + + return type(_get_type_name_alias(type_), (), {"__new__": new}) def optional(type_): """Creates special type that is named as original type or collection type It allows use string `None` to indicate None value""" + + def new(cls, value): # pylint: disable=unused-argument + """Check if value is string None""" + if value == "None": + return None + return type_(value) + return type( _get_type_name_alias(type_), (), - { - "__new__": lambda cls, value: None - if value == "None" - else type_(value) - }, + {"__new__": new}, ) @@ -231,6 +239,7 @@ def parse_type_field( allow_none: bool, default: Any, root_cls: Type[BaseModel], + force_not_set: bool, ) -> Iterator[CliTypeField]: """Recursively creates CliTypeFields from field description""" if is_list or is_mapping: @@ -278,7 +287,7 @@ def parse_type_field( allow_none=allow_none, path=path, type_=type_, - default=default, + default=default if not force_not_set else NOT_SET, help=help_, is_list=is_list, is_mapping=is_mapping, @@ -333,22 +342,24 @@ def iterate_type_fields( display_as_type(field_type), __root__=(field_type, ...) ) if field_type is Any: - field_type = anything(field_type) + field_type = anything(field_type, field.allow_none) if not isinstance(field_type, type): # skip too complicated stuff continue + required = not force_not_req and bool(field.required) yield from parse_type_field( path=fullname, type_=field_type, help_=get_field_help(cls, name), is_list=field.shape in LIST_LIKE_SHAPES, is_mapping=field.shape in MAPPING_LIKE_SHAPES, - required=not force_not_req and bool(field.required), + required=required, allow_none=field.allow_none, - default=field.default, + default=field.default if required else NOT_SET, root_cls=root_cls, + force_not_set=force_not_req, ) @@ -381,11 +392,18 @@ def _options_from_model( continue if issubclass(field.type_, MlemABC) and field.type_.__is_root__: yield from _options_from_mlem_abc( - ctx, field, path, force_not_set=force_not_set + ctx, + field, + path, + force_not_set=force_not_set or field.default == NOT_SET, ) continue - yield _option_from_field(field, path, force_not_set=force_not_set) + yield _option_from_field( + field, + path, + force_not_set=force_not_set or field.default == NOT_SET, + ) def _options_from_mlem_abc( @@ -506,12 +524,12 @@ def _option_from_field( """Create cli option from field descriptor""" type_ = override_type or field.type_ if force_not_set: - type_ = anything(type_) + type_ = anything(type_, field.allow_none) elif field.allow_none: type_ = optional(type_) option = SetViaFileTyperOption( param_decls=[f"--{path}", path.replace(".", "_")], - type=type_ if not force_not_set else anything(type_), + type=type_ if not force_not_set else anything(type_, field.allow_none), required=field.required and not force_not_set, default=field.default if not field.is_list and not field.is_mapping and not force_not_set @@ -531,7 +549,10 @@ def generator(ctx: CallContext): cls = load_impl_ext(mlem_abc.abs_name, type_name=type_name) except ImportError: return - yield from _options_from_model(cls, ctx) + yield from _options_from_model( + cls, + ctx, + ) return generator diff --git a/tests/cli/test_declare.py b/tests/cli/test_declare.py index 9aedfd13..fd5d67b0 100644 --- a/tests/cli/test_declare.py +++ b/tests/cli/test_declare.py @@ -6,6 +6,7 @@ from mlem.cli.declare import create_declare_mlem_object_subcommand, declare from mlem.contrib.docker import DockerDirBuilder +from mlem.contrib.docker.base import DockerRegistry, RemoteRegistry from mlem.contrib.docker.context import DockerBuildArgs from mlem.contrib.fastapi import FastAPIServer from mlem.contrib.heroku.meta import HerokuEnv @@ -404,6 +405,70 @@ class RootListNested(_MockBuilder): ) +class MockOptionalFieldWithNonOptionalSubfield(_MockBuilder): + """mock""" + + f: Optional[SimpleValue] = None + + +all_test_params.append( + pytest.param( + MockOptionalFieldWithNonOptionalSubfield(), + "", + id="non_optional_subfield_empty", + ) +) +all_test_params.append( + pytest.param( + MockOptionalFieldWithNonOptionalSubfield(f=SimpleValue(value="a")), + "--f.value a", + id="non_optional_subfield_full", + ) +) + + +class ThreeValues(BaseModel): + value: str + with_def: str = "value" + opt: Optional[str] = None + with_def_model: SimpleValue = SimpleValue(value="value") + with_def_abc: DockerRegistry = DockerRegistry() + + +class MockOptionalFieldWithOptionalAndNonOptionalSubfield(_MockBuilder): + """mock""" + + f: Optional[ThreeValues] = None + + +all_test_params.append( + pytest.param( + MockOptionalFieldWithOptionalAndNonOptionalSubfield(), + "", + id="optional_and_non_optional_subfield_empty", + ) +) +all_test_params.append( + pytest.param( + MockOptionalFieldWithOptionalAndNonOptionalSubfield( + f=ThreeValues(value="a") + ), + "--f.value a", + id="optional_and_non_optional_subfield_full", + ) +) + +all_test_params.append( + pytest.param( + MockOptionalFieldWithOptionalAndNonOptionalSubfield( + f=ThreeValues(value="a", with_def_abc=RemoteRegistry(host="aaa")) + ), + "--f.value a --f.with_def_abc remote --f.with_def_abc.host aaa", + id="optional_and_non_optional_subfield_full_abc", + ) +) + + @lru_cache() def _declare_builder_command(type_: str): create_declare_mlem_object_subcommand(