Skip to content
This repository has been archived by the owner on Sep 13, 2023. It is now read-only.

Commit

Permalink
Fix declaring classes with abc fields with defaults (#625)
Browse files Browse the repository at this point in the history
  • Loading branch information
mike0sv authored Mar 1, 2023
1 parent daf022e commit 0939e5d
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 18 deletions.
1 change: 1 addition & 0 deletions mlem/cli/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def _add_examples(
allow_none=False,
default=None,
root_cls=root_cls,
force_not_set=False,
),
root_cls=root_cls,
parent_help=f"Element of {field.path}",
Expand Down
57 changes: 39 additions & 18 deletions mlem/cli/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,26 +198,34 @@ def _get_type_name_alias(type_):
return type_.__name__ if type_ is not None else "any"


def anything(type_):
def anything(type_, allow_none: bool):
"""Creates special type that is named as original type or collection type
It returns original object on creation and is needed for nice typename in cli option help
"""
return type(
_get_type_name_alias(type_), (), {"__new__": lambda cls, value: value}
)

def new(cls, value): # pylint: disable=unused-argument
"""Just return the value"""
if allow_none and value == "None":
return None
return value

return type(_get_type_name_alias(type_), (), {"__new__": new})


def optional(type_):
"""Creates special type that is named as original type or collection type
It allows use string `None` to indicate None value"""

def new(cls, value): # pylint: disable=unused-argument
"""Check if value is string None"""
if value == "None":
return None
return type_(value)

return type(
_get_type_name_alias(type_),
(),
{
"__new__": lambda cls, value: None
if value == "None"
else type_(value)
},
{"__new__": new},
)


Expand All @@ -231,6 +239,7 @@ def parse_type_field(
allow_none: bool,
default: Any,
root_cls: Type[BaseModel],
force_not_set: bool,
) -> Iterator[CliTypeField]:
"""Recursively creates CliTypeFields from field description"""
if is_list or is_mapping:
Expand Down Expand Up @@ -278,7 +287,7 @@ def parse_type_field(
allow_none=allow_none,
path=path,
type_=type_,
default=default,
default=default if not force_not_set else NOT_SET,
help=help_,
is_list=is_list,
is_mapping=is_mapping,
Expand Down Expand Up @@ -333,22 +342,24 @@ def iterate_type_fields(
display_as_type(field_type), __root__=(field_type, ...)
)
if field_type is Any:
field_type = anything(field_type)
field_type = anything(field_type, field.allow_none)

if not isinstance(field_type, type):
# skip too complicated stuff
continue

required = not force_not_req and bool(field.required)
yield from parse_type_field(
path=fullname,
type_=field_type,
help_=get_field_help(cls, name),
is_list=field.shape in LIST_LIKE_SHAPES,
is_mapping=field.shape in MAPPING_LIKE_SHAPES,
required=not force_not_req and bool(field.required),
required=required,
allow_none=field.allow_none,
default=field.default,
default=field.default if required else NOT_SET,
root_cls=root_cls,
force_not_set=force_not_req,
)


Expand Down Expand Up @@ -381,11 +392,18 @@ def _options_from_model(
continue
if issubclass(field.type_, MlemABC) and field.type_.__is_root__:
yield from _options_from_mlem_abc(
ctx, field, path, force_not_set=force_not_set
ctx,
field,
path,
force_not_set=force_not_set or field.default == NOT_SET,
)
continue

yield _option_from_field(field, path, force_not_set=force_not_set)
yield _option_from_field(
field,
path,
force_not_set=force_not_set or field.default == NOT_SET,
)


def _options_from_mlem_abc(
Expand Down Expand Up @@ -506,12 +524,12 @@ def _option_from_field(
"""Create cli option from field descriptor"""
type_ = override_type or field.type_
if force_not_set:
type_ = anything(type_)
type_ = anything(type_, field.allow_none)
elif field.allow_none:
type_ = optional(type_)
option = SetViaFileTyperOption(
param_decls=[f"--{path}", path.replace(".", "_")],
type=type_ if not force_not_set else anything(type_),
type=type_ if not force_not_set else anything(type_, field.allow_none),
required=field.required and not force_not_set,
default=field.default
if not field.is_list and not field.is_mapping and not force_not_set
Expand All @@ -531,7 +549,10 @@ def generator(ctx: CallContext):
cls = load_impl_ext(mlem_abc.abs_name, type_name=type_name)
except ImportError:
return
yield from _options_from_model(cls, ctx)
yield from _options_from_model(
cls,
ctx,
)

return generator

Expand Down
65 changes: 65 additions & 0 deletions tests/cli/test_declare.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from mlem.cli.declare import create_declare_mlem_object_subcommand, declare
from mlem.contrib.docker import DockerDirBuilder
from mlem.contrib.docker.base import DockerRegistry, RemoteRegistry
from mlem.contrib.docker.context import DockerBuildArgs
from mlem.contrib.fastapi import FastAPIServer
from mlem.contrib.heroku.meta import HerokuEnv
Expand Down Expand Up @@ -404,6 +405,70 @@ class RootListNested(_MockBuilder):
)


class MockOptionalFieldWithNonOptionalSubfield(_MockBuilder):
"""mock"""

f: Optional[SimpleValue] = None


all_test_params.append(
pytest.param(
MockOptionalFieldWithNonOptionalSubfield(),
"",
id="non_optional_subfield_empty",
)
)
all_test_params.append(
pytest.param(
MockOptionalFieldWithNonOptionalSubfield(f=SimpleValue(value="a")),
"--f.value a",
id="non_optional_subfield_full",
)
)


class ThreeValues(BaseModel):
value: str
with_def: str = "value"
opt: Optional[str] = None
with_def_model: SimpleValue = SimpleValue(value="value")
with_def_abc: DockerRegistry = DockerRegistry()


class MockOptionalFieldWithOptionalAndNonOptionalSubfield(_MockBuilder):
"""mock"""

f: Optional[ThreeValues] = None


all_test_params.append(
pytest.param(
MockOptionalFieldWithOptionalAndNonOptionalSubfield(),
"",
id="optional_and_non_optional_subfield_empty",
)
)
all_test_params.append(
pytest.param(
MockOptionalFieldWithOptionalAndNonOptionalSubfield(
f=ThreeValues(value="a")
),
"--f.value a",
id="optional_and_non_optional_subfield_full",
)
)

all_test_params.append(
pytest.param(
MockOptionalFieldWithOptionalAndNonOptionalSubfield(
f=ThreeValues(value="a", with_def_abc=RemoteRegistry(host="aaa"))
),
"--f.value a --f.with_def_abc remote --f.with_def_abc.host aaa",
id="optional_and_non_optional_subfield_full_abc",
)
)


@lru_cache()
def _declare_builder_command(type_: str):
create_declare_mlem_object_subcommand(
Expand Down

0 comments on commit 0939e5d

Please sign in to comment.