diff --git a/utype/__init__.py b/utype/__init__.py index ba8e7be..2ec7bad 100644 --- a/utype/__init__.py +++ b/utype/__init__.py @@ -12,7 +12,7 @@ register_transformer = TypeTransformer.registry.register -VERSION = (0, 4, 1, None) +VERSION = (0, 5, 0, None) def _get_version(): diff --git a/utype/parser/base.py b/utype/parser/base.py index d0c45b7..1cc4ae2 100644 --- a/utype/parser/base.py +++ b/utype/parser/base.py @@ -1,5 +1,6 @@ import inspect import sys +import warnings from typing import Callable, Dict, List, Optional, Set, Tuple, Type, Union from ..utils import exceptions as exc diff --git a/utype/parser/cls.py b/utype/parser/cls.py index b7d59cc..6910591 100644 --- a/utype/parser/cls.py +++ b/utype/parser/cls.py @@ -2,7 +2,6 @@ import warnings from collections.abc import Mapping from functools import partial -from types import FunctionType from typing import Callable, Dict, Type, TypeVar from ..utils import exceptions as exc @@ -525,6 +524,17 @@ def __init__(_obj_self, _d: dict = None, **kwargs): return __init__ + @property + def schema_annotations(self): + # this is meant to be extended and override + # if the result is not None, it will become the x-annotation of the JSON schema output + data = dict() + if self.options.mode: + data.update(mode=self.options.mode) + if self.options.case_insensitive: + data.update(case_insensitive=self.options.case_insensitive) + return data + def init_dataclass( cls: Type[T], data, options: Options = None, context: RuntimeContext = None diff --git a/utype/parser/field.py b/utype/parser/field.py index e268209..55e7dab 100644 --- a/utype/parser/field.py +++ b/utype/parser/field.py @@ -319,6 +319,14 @@ def __call__(self, fn_or_cls, *args, **kwargs): setattr(fn_or_cls, "__field__", self) return fn_or_cls + @property + def schema_annotations(self): + return {} + + @property + def default_type(self): + return None + class Param(Field): def __init__( @@ -729,17 +737,18 @@ def is_case_insensitive(self, options: Options) -> bool: # 
return value() # return copy_value(value) - def get_default(self, options: Options, defer: bool = False): + def get_default(self, options: Options, defer: Optional[bool] = False): # options = options or self.options if options.no_default: return unprovided - if not defer: - if self.defer_default or options.defer_default: - return unprovided - else: - if not self.defer_default and not options.defer_default: - return unprovided + if isinstance(defer, bool): + if not defer: + if self.defer_default or options.defer_default: + return unprovided + else: + if not self.defer_default and not options.defer_default: + return unprovided if not unprovided(options.force_default): default = options.force_default @@ -763,10 +772,6 @@ def get_on_error(self, options: Options): return self.on_error return options.invalid_values - def get_example(self): - if not unprovided(self.field.example): - return self.field.example - def is_required(self, options: Options): if options.ignore_required or not self.required: return False @@ -1068,6 +1073,12 @@ def get_field(cls, annotation: Any, default, **kwargs): else: return default + @property + def schema_annotations(self): + # this is meant to be extended and override + # if the result is not None, it will become the x-annotation of the JSON schema output + return self.field.schema_annotations + @classmethod def generate( cls, @@ -1235,6 +1246,10 @@ def generate( if not dependencies and field.dependencies: dependencies = field.dependencies + if annotation is None: + # a place to inject + annotation = field.default_type + input_type = _cls.rule_cls.parse_annotation( annotation=annotation, constraints=field.constraints, diff --git a/utype/specs/json_schema.py b/utype/specs/json_schema.py index 1a599a6..92fb0bf 100644 --- a/utype/specs/json_schema.py +++ b/utype/specs/json_schema.py @@ -12,6 +12,8 @@ from ipaddress import IPv4Address, IPv6Address from typing import Optional, Type, Union, Dict from ..utils.datastructures import unprovided +from 
..utils.compat import JSON_TYPES +from enum import EnumMeta class JsonSchemaGenerator: @@ -118,6 +120,30 @@ def generate_for_type(self, t: type): return self.generate_for_dataclass(t) elif isinstance(t, LogicalType) and t.combinator: return self.generate_for_logical(t) + elif isinstance(t, EnumMeta): + base = t.__base__ + enum_type = None + enum_values = [] + enum_map = {} + for key, val in t.__members__.items(): + enum_values.append(val.value) + enum_map[key] = val.value + enum_type = type(val.value) + if not isinstance(base, EnumMeta): + enum_type = base + prim = self._get_primitive(enum_type) + fmt = self._get_format(enum_type) + data = { + "type": prim, + "enum": enum_values, + "x-annotation": { + "enums": enum_map + } + } + if fmt: + data.update(format=fmt) + return data + # default common type prim = self._get_primitive(t) fmt = self._get_format(t) @@ -138,6 +164,9 @@ def generate_for_logical(self, t: LogicalType): def _get_format(self, origin: type) -> Optional[str]: if not origin: return None + format = getattr(origin, 'format', None) + if format and isinstance(format, str): + return format for types, f in self.FORMAT_MAP.items(): if issubclass(origin, types): return f @@ -289,9 +318,21 @@ def generate_for_field(self, f: ParserField, options: Options = None) -> Optiona elif f.field.mode == 'w': data.update(writeOnly=True) if not unprovided(f.field.example) and f.field.example is not None: - data.update(examples=[f.field.example]) + example = f.field.example + if type(f.field.example) not in JSON_TYPES: + example = str(f.field.example) + data.update(examples=[example]) if f.aliases: - data.update(aliases=list(f.aliases)) + aliases = list(f.aliases) + if aliases: + # sort to stay identical + aliases.sort() + data.update(aliases=aliases) + annotations = f.schema_annotations + if annotations: + data.update({ + 'x-annotation': annotations + }) return data # todo: de-duplicate generated schema class like UserSchema['a'] @@ -337,6 +378,12 @@ def 
generate_for_dataclass(self, t): else: data.update(additionalProperties=addition) + annotations = parser.schema_annotations + if annotations: + data.update({ + 'x-annotation': annotations + }) + if isinstance(self.defs, dict): return {"$ref": f"{self.ref_prefix}{self.set_def(cls_name, t, data)}"} return data @@ -372,3 +419,11 @@ def generate_for_function(self, f): else: data.update(additionalParameters=addition) return data + +# REVERSE ACTION OF GENERATE: +# --- GENERATE Schema and types based on Json schema + + +class JsonSchemaParser: + def __init__(self, json_schema: dict): + pass diff --git a/utype/utils/compat.py b/utype/utils/compat.py index 1c8d76e..3527dc0 100644 --- a/utype/utils/compat.py +++ b/utype/utils/compat.py @@ -34,6 +34,7 @@ "is_classvar", "is_annotated", "evaluate_forward_ref", + 'JSON_TYPES' ] if sys.version_info < (3, 8): diff --git a/utype/utils/encode.py b/utype/utils/encode.py index cbfc6eb..e6c21ad 100644 --- a/utype/utils/encode.py +++ b/utype/utils/encode.py @@ -7,6 +7,7 @@ from .base import TypeRegistry import json from .datastructures import unprovided +from ipaddress import IPv4Address, IPv6Address, IPv4Network, IPv6Network encoder_registry = TypeRegistry('encoder', cache=True, shortcut='__encoder__') @@ -98,6 +99,16 @@ def from_datetime(data: Union[datetime, date]): return data.isoformat() +@register_encoder(IPv4Network, IPv4Address, IPv6Network, IPv6Address) +def from_ip(data): + return str(data) + + +@register_encoder(IPv4Network) +def from_datetime(data): + return str(data) + + @register_encoder(timedelta) def from_duration(data: timedelta): return duration_iso_string(data) diff --git a/utype/utils/example.py b/utype/utils/example.py new file mode 100644 index 0000000..5c2f9cd --- /dev/null +++ b/utype/utils/example.py @@ -0,0 +1,231 @@ +import inspect +import math +import warnings +import random +from decimal import Decimal +from enum import Enum, EnumMeta +from ..utils.datastructures import unprovided +from ..parser.field 
import ParserField
+from ..parser.base import BaseParser
+from ..parser.rule import LogicalType, Rule, SEQ_TYPES, MAP_TYPES
+from uuid import UUID
+from datetime import date, datetime, timedelta, time
+from typing import Type
+
+VALUE_TYPES = (str, int, float, Decimal, datetime, date, time, timedelta)
+
+
+def get_example_from_json_schema(schema):
+    """Placeholder: generating an example from a raw JSON schema is not implemented (returns None)."""
+    pass
+
+
+def _temporal_now(t: type):
+    """Return the current moment expressed as ``t`` (datetime / date / time / timedelta), else None."""
+    now = datetime.now()
+    if t == datetime:
+        return now
+    if t == date:
+        return now.date()
+    if t == time:
+        return now.time()
+    if t == timedelta:
+        # time elapsed since midnight
+        return timedelta(hours=now.hour, minutes=now.minute, seconds=now.second, microseconds=now.microsecond)
+    return None
+
+
+def get_example_from(t: type):
+    """
+    Build an example value for the type ``t``.
+
+    NoneType gives None, bool a random boolean, UUID a random uuid4 and an enum class one of
+    its members. A class whose ``__parser__`` is a BaseParser is instantiated with examples of
+    its fields, Rule subclasses go through get_example_from_rule() and LogicalType instances
+    through their own ``get_example()``. The temporal types yield the current moment;
+    anything else falls back to calling ``t()`` without arguments.
+    """
+    if t == type(None):
+        return None
+
+    if t == bool:
+        return random.choice([True, False])
+
+    if t == UUID:
+        import uuid
+        return uuid.uuid4()
+
+    if isinstance(t, EnumMeta):
+        # random.choice() indexes its argument, so the members view has to be copied to a list
+        return random.choice(list(t.__members__.values()))
+
+    parser = getattr(t, '__parser__', None)
+    if isinstance(parser, BaseParser):
+        return t(**get_example_from_parser(parser))
+
+    if inspect.isclass(t) and issubclass(t, Rule):
+        return get_example_from_rule(t)
+
+    if isinstance(t, LogicalType):
+        return t.get_example()
+
+    moment = _temporal_now(t)
+    if moment is not None:
+        # datetime() and date() cannot be called without arguments
+        return moment
+
+    return t()
+
+
+def get_example_from_field(field: ParserField):
+    """Return the example declared on the field, or generate one from the field's type."""
+    if not unprovided(field.field.example):
+        return field.field.example
+    return get_example_from(field.type)
+
+
+def get_example_from_parser(self):
+    """
+    Collect one example per field of the parser into a dict keyed by field name.
+
+    Best effort: a field whose example generation raises is reported with a warning and
+    left out of the result.
+    """
+    data = {}
+    for name, field in self.fields.items():
+        try:
+            example = get_example_from_field(field)
+        except Exception as e:
+            warnings.warn(f'{self.obj}: generate example for field: [{repr(name)}] failed with error: {e}')
+            continue
+        data[name] = example
+    return data
+
+
+def get_example_from_rule(cls: Type[Rule]):
+    """
+    Generate an example value for a Rule subclass from the constraints it declares.
+
+    Resolution order: ``const``, a random item of ``enum``, a string generated from ``regex``
+    (only when the optional ``exrex`` package can be imported), container origins
+    (SEQ_TYPES / MAP_TYPES) filled from ``__args__``, and finally a random value of the origin
+    type derived from the length, bound (``ge`` / ``gt`` / ``le`` / ``lt``), ``multiple_of``
+    and ``decimal_places`` attributes. Custom validators and converters are not consulted.
+    """
+    if hasattr(cls, 'const'):
+        return cls.const
+
+    if hasattr(cls, 'enum'):
+        return random.choice(cls.enum)
+
+    if hasattr(cls, 'regex'):
+        try:
+            import exrex  # noqa
+            return exrex.getone(cls.regex)
+        except (ModuleNotFoundError, AttributeError):
+            # exrex is optional: fall through to the generic generation below
+            pass
+
+    t = cls.__origin__
+
+    if t in SEQ_TYPES:
+        if cls.__args__:
+            if cls.__ellipsis_args__:
+                # tuple
+                values = [get_example_from(arg) for arg in cls.__args__]
+            else:
+                # only the first argument is used, giving a one-element container
+                # NOTE(review): a fixed-size tuple origin would need one item per argument - confirm
+                values = [get_example_from(cls.__args__[0])]
+            return t(values)
+
+    if t in MAP_TYPES:
+        if cls.__args__:
+            key_type = cls.__args__[0]
+            val_type = cls.__args__[1] if len(cls.__args__) > 1 else None
+            key = get_example_from(key_type)
+            val = get_example_from(val_type) if val_type else None
+            return t({key: val})
+
+    if t not in VALUE_TYPES:
+        return get_example_from(t)
+
+    multi_of = getattr(cls, 'multiple_of', None)
+
+    length = getattr(cls, 'length', None)
+    min_length = getattr(cls, 'min_length', None)
+    max_length = getattr(cls, 'max_length', None)
+
+    round_v = getattr(cls, 'decimal_places', None)
+
+    ge = getattr(cls, 'ge', None)
+    gt = getattr(cls, 'gt', None)
+    le = getattr(cls, 'le', None)
+    lt = getattr(cls, 'lt', None)
+    min_value = getattr(cls, 'ge', getattr(cls, 'gt', None))
+    max_value = getattr(cls, 'le', getattr(cls, 'lt', None))
+
+    if min_value is None:
+        if max_value is None:
+            # no bounds at all
+            moment = _temporal_now(t)
+            if moment is not None:
+                return moment
+
+            if isinstance(multi_of, (int, float)):
+                return multi_of * random.randint(1, 10)
+
+            val_length = length
+            if val_length is None:
+                min_len = min_length or 0
+                max_len = max_length or min_len + 10
+                val_length = int(min_len + (max_len - min_len) * random.random())
+
+            if not val_length:
+                return t()
+
+            if t == str:
+                import string
+                # choices() draws with replacement, so lengths above the alphabet size work too
+                return ''.join(random.choices(string.digits + string.ascii_letters, k=val_length))
+
+            if t in (float, Decimal):
+                if round_v and round_v > 0:
+                    # reserve room for the decimal point and the fractional digits
+                    val_length = max(val_length - round_v - 1, 1)
+                val = random.randint(10 ** (val_length - 1), 10 ** val_length - 1) + random.random()
+                if round_v is not None:
+                    val = round(val, round_v)
+                while len(str(val)) < val_length:
+                    val = float(str(val) + '1')
+                return t(val)
+
+            if t == int:
+                return random.randint(10 ** (val_length - 1), 10 ** val_length - 1)
+        else:
+            # only an upper bound: derive a lower bound from it
+            if t in (int, float, Decimal):
+                if isinstance(multi_of, (int, float)) and multi_of:
+                    # floor() keeps the multiple within the bound for negative values as well
+                    times = math.floor(max_value / multi_of)
+                    # prefer positive multiples, but never hand randint() an empty range
+                    return multi_of * random.randint(min(max(1, times - 3), times), times)
+
+                min_value = (max_value / 2) if max_value > 0 else max_value * 2
+            elif t in (datetime, date):
+                min_value = max_value - timedelta(days=1)  # noqa
+            elif t == timedelta:
+                min_value = min(timedelta(), max_value - timedelta(days=1))  # noqa
+            elif t == time:
+                min_value = time()
+
+    elif max_value is None:
+        # only a lower bound: derive an upper bound from it
+        if t in (int, float, Decimal):
+            if isinstance(multi_of, (int, float)) and multi_of:
+                times = math.ceil(min_value / multi_of)
+                return multi_of * random.randint(times, times + 5)
+
+            max_value = (min_value * 2) if min_value > 0 else min_value / 2
+        elif t in (datetime, date, timedelta):
+            max_value = min_value + timedelta(days=1)  # noqa
+        elif t == time:
+            max_value = time(23, 59, 59)
+
+    if max_value is not None and min_value is not None:
+        try:
+            delta = max_value - min_value
+            if isinstance(delta, int):
+                if isinstance(multi_of, (int, float)) and multi_of:
+                    max_times = math.floor(max_value / multi_of)
+                    min_times = math.ceil(min_value / multi_of)
+                    if min_times <= max_times:
+                        return multi_of * random.randint(min_times, max_times)
+                    # no multiple fits between the bounds: fall back to a plain in-range value
+
+                v = min_value + random.randint(0, delta)
+                # step off an exclusive bound
+                if v == gt:
+                    v = v + 1
+                elif v == lt:
+                    v = v - 1
+            else:
+                fraction = random.random()
+                if isinstance(delta, Decimal):
+                    # Decimal cannot be multiplied by a float
+                    fraction = Decimal(str(fraction))
+                v = min_value + delta * fraction
+
+            if not isinstance(v, t):
+                # convert only when needed: datetime, date and timedelta cannot be built
+                # from an instance of themselves
+                v = t(v)
+            if round_v is not None:
+                v = round(v, round_v)
+            return v
+        except TypeError:
+            # like str
+            seq = [bound for bound in (le, ge) if bound]
+            if seq:
+                return random.choice(seq)
+            if isinstance(min_value, str):
+                return min_value + '_' + str(random.randint(0, 9))
+    return t()