From a10d27e8ecfe534edab4cd7b7ff1c3935265feae Mon Sep 17 00:00:00 2001 From: Ken Kroenlein Date: Wed, 13 Mar 2024 12:07:14 -0600 Subject: [PATCH 1/2] Implement GemdQuery obejcts for table building --- src/citrine/__version__.py | 2 +- src/citrine/gemd_queries/__init__.py | 0 src/citrine/gemd_queries/criteria.py | 162 ++++++++++++++++++ src/citrine/gemd_queries/filter.py | 223 +++++++++++++++++++++++++ src/citrine/gemd_queries/gemd_query.py | 64 +++++++ src/citrine/resources/table_config.py | 4 + tests/gemd_query/__init__.py | 0 tests/gemd_query/test_gemd_query.py | 45 +++++ tests/utils/factories.py | 55 ++++++ 9 files changed, 554 insertions(+), 1 deletion(-) create mode 100644 src/citrine/gemd_queries/__init__.py create mode 100644 src/citrine/gemd_queries/criteria.py create mode 100644 src/citrine/gemd_queries/filter.py create mode 100644 src/citrine/gemd_queries/gemd_query.py create mode 100644 tests/gemd_query/__init__.py create mode 100644 tests/gemd_query/test_gemd_query.py diff --git a/src/citrine/__version__.py b/src/citrine/__version__.py index 131942e76..f5f41e567 100644 --- a/src/citrine/__version__.py +++ b/src/citrine/__version__.py @@ -1 +1 @@ -__version__ = "3.0.2" +__version__ = "3.1.0" diff --git a/src/citrine/gemd_queries/__init__.py b/src/citrine/gemd_queries/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/citrine/gemd_queries/criteria.py b/src/citrine/gemd_queries/criteria.py new file mode 100644 index 000000000..d4f0c3847 --- /dev/null +++ b/src/citrine/gemd_queries/criteria.py @@ -0,0 +1,162 @@ +"""Definitions for GemdQuery objects, and their sub-objects.""" +from typing import List, Type + +from gemd.enumeration.base_enumeration import BaseEnumeration + +from citrine._serialization.serializable import Serializable +from citrine._serialization.polymorphic_serializable import PolymorphicSerializable +from citrine._serialization import properties +from citrine.gemd_queries.filter import PropertyFilterType + +__all__ = ['MaterialClassification', 'TextSearchType', + 'Criteria', + 'AndOperator', 'OrOperator', + 'PropertiesCriteria', 'NameCriteria', 'MaterialRunClassificationCriteria', + 'MaterialTemplatesCriteria' + ] + + +class MaterialClassification(BaseEnumeration): + """A classification based on where in a Material History you find a Material.""" + + ATOMIC_INGREDIENT = "atomic_ingredient" + INTERMEDIATE_INGREDIENT = "intermediate_ingredient" + TERMINAL_MATERIAL = "terminal_material" + + +class TextSearchType(BaseEnumeration): + """The style of text search to run.""" + + EXACT = "exact" + PREFIX = "prefix" + SUFFIX = "suffix" + SUBSTRING = "substring" + + +class Criteria(PolymorphicSerializable): + """Abstract concept of a criteria to apply when searching for materials.""" + + @classmethod + def get_type(cls, data) -> Type[Serializable]: + """Return the subtype.""" + classes: List[Type[Criteria]] = [ + AndOperator, OrOperator, + PropertiesCriteria, NameCriteria, MaterialRunClassificationCriteria, + MaterialTemplatesCriteria + ] + return {klass.typ: klass for klass in classes}[data['type']] + + +class AndOperator(Serializable['AndOperator'], Criteria): + """ + Combine multiple criteria, requiring EACH to be true for a match. + + Parameters + ---------- + criteria: Criteria + List of conditions all responses must satisfy (i.e., joined with an AND). + + """ + + criteria = properties.List(properties.Object(Criteria), "criteria") + typ = properties.String('type', default="and_operator", deserializable=False) + + +class OrOperator(Serializable['OrOperator'], Criteria): + """ + Combine multiple criteria, requiring ANY to be true for a match. + + Parameters + ---------- + criteria: Criteria + List of conditions, at least one of which must match (i.e., joined with an OR). + + """ + + criteria = properties.List(properties.Object(Criteria), "criteria") + typ = properties.String('type', default="or_operator", deserializable=False) + + +class PropertiesCriteria(Serializable['PropertiesCriteria'], Criteria): + """ + Look for materials with a particular Property and optionally Value types & ranges. + + Parameters + ---------- + property_templates_filter: Set[UUID] + The citrine IDs of the property templates matches must reference. + value_type_filter: Optional[PropertyFilterType] + The value range matches must conform to. + + """ + + property_templates_filter = properties.Set(properties.UUID, "property_templates_filter") + value_type_filter = properties.Optional( + properties.Object(PropertyFilterType), "value_type_filter" + ) + typ = properties.String('type', default="properties_criteria", deserializable=False) + + +class NameCriteria(Serializable['NameCriteria'], Criteria): + """ + Look for materials with particular names. + + Parameters + ---------- + name: str + The name the returned objects must have. + search_type: TextSearchType + What kind of string match to use (exact, substring, ...). + + """ + + name = properties.String('name') + search_type = properties.Enumeration(TextSearchType, 'search_type') + typ = properties.String('type', default="name_criteria", deserializable=False) + + +class MaterialRunClassificationCriteria( + Serializable['MaterialRunClassificationCriteria'], + Criteria +): + """ + Look for materials with particular classification, defined by MaterialClassification. + + Parameters + ---------- + classifications: Set[MaterialClassification] + The classification, based on where in a material history an object appears. + + """ + + classifications = properties.Set( + properties.Enumeration(MaterialClassification), 'classifications' + ) + typ = properties.String( + 'type', + default="material_run_classification_criteria", + deserializable=False + ) + + +class MaterialTemplatesCriteria(Serializable['MaterialTemplatesCriteria'], Criteria): + """ + Look for materials with particular Material Templates and tags. + + This has a similar behavior to the old [[MaterialRunByTemplate]] Row definition + + Parameters + ---------- + material_templates_identifiers: Set[UUID] + Which material templates to filter by. + tag_filters: Set[str] + Which tags to filter by. + + """ + + material_templates_identifiers = properties.Set( + properties.UUID, + "material_templates_identifiers" + ) + tag_filters = properties.Set(properties.String, 'tag_filters') + typ = properties.String('type', default="material_template_criteria", deserializable=False) diff --git a/src/citrine/gemd_queries/filter.py b/src/citrine/gemd_queries/filter.py new file mode 100644 index 000000000..ebc739ae9 --- /dev/null +++ b/src/citrine/gemd_queries/filter.py @@ -0,0 +1,223 @@ +"""Definitions for GemdQuery objects, and their sub-objects.""" +from typing import List, Type + +from gemd.enumeration.base_enumeration import BaseEnumeration + +from citrine._serialization.serializable import Serializable +from citrine._serialization.polymorphic_serializable import PolymorphicSerializable +from citrine._serialization import properties + +__all__ = ['PropertyFilterType', + 'RealFilter', 'IntegerFilter', + 'MaterialClassification', 'TextSearchType', + ] + + +class RealFilter(Serializable['RealFilter']): + """ + A general filter for Real/Continuous Values. + + Parameters + ---------- + unit: str + The units associated with the floating point values for this filter. + lower_filter: str + The lower bound on this filter range. + upper_filter: str + The upper bound on this filter range. + lower_is_inclusive: bool + Whether the lower bound value included in the valid range. + upper_is_inclusive: bool + Whether the upper bound value included in the valid range. + + """ + + unit = properties.String('unit') + lower_filter = properties.Optional(properties.Float, 'lower_filter') + upper_filter = properties.Optional(properties.Float, 'upper_filter') + lower_is_inclusive = properties.Boolean('lower_is_inclusive') + upper_is_inclusive = properties.Boolean('upper_is_inclusive') + + +class IntegerFilter(Serializable['IntegerFilter']): + """ + A general filter for Integer/Discrete Values. + + Parameters + ---------- + lower_filter: str + The lower bound on this filter range. + upper_filter: str + The upper bound on this filter range. + lower_is_inclusive: bool + Whether the lower bound value included in the valid range. + upper_is_inclusive: bool + Whether the upper bound value included in the valid range. + + """ + + lower_filter = properties.Optional(properties.Float, 'lower_filter') + upper_filter = properties.Optional(properties.Float, 'upper_filter') + lower_is_inclusive = properties.Boolean('lower_is_inclusive') + upper_is_inclusive = properties.Boolean('upper_is_inclusive') + + +class MaterialClassification(BaseEnumeration): + """A classification based on where in a Material History you find a Material.""" + + ATOMIC_INGREDIENT = "atomic_ingredient" + INTERMEDIATE_INGREDIENT = "intermediate_ingredient" + TERMINAL_MATERIAL = "terminal_material" + + +class TextSearchType(BaseEnumeration): + """The style of text search to run.""" + + EXACT = "exact" + PREFIX = "prefix" + SUFFIX = "suffix" + SUBSTRING = "substring" + + +class PropertyFilterType(PolymorphicSerializable): + """Abstract concept of a criteria to apply when searching for materials.""" + + @classmethod + def get_type(cls, data) -> Type[Serializable]: + """Return the subtype.""" + classes: List[Type[PropertyFilterType]] = [ + NominalCategoricalFilter, + NominalRealFilter, NormalRealFilter, UniformRealFilter, + NominalIntegerFilter, UniformIntegerFilter, + AllRealFilter, AllIntegerFilter + ] + return {klass.typ: klass for klass in classes}[data['type']] + + +class NominalCategoricalFilter(Serializable['NominalCategoricalFilter'], PropertyFilterType): + """ + Filter based upon a fixed list of Categorical Values. + + Parameters + ---------- + categories: Set[str] + Which categorical values match. + + """ + + categories = properties.Set(properties.String, 'categories') + typ = properties.String('type', default="nominal_categorical_filter", deserializable=False) + + +class NominalRealFilter(Serializable['NominalRealFilter'], PropertyFilterType): + """ + Filter for Nominal Reals that fit certain constraints. + + Parameters + ---------- + values: Set[RealFilter] + What value filter to use. + + """ + + values = properties.Object(RealFilter, 'values') + typ = properties.String('type', default="nominal_real_filter", deserializable=False) + + +class NormalRealFilter(Serializable['NormalRealFilter'], PropertyFilterType): + """ + Filter for Normal Reals that fit certain constraints. + + Parameters + ---------- + values: Set[RealFilter] + What value filter to use. + + """ + + values = properties.Object(RealFilter, 'values') + typ = properties.String('type', default="normal_real_filter", deserializable=False) + + +class UniformRealFilter(Serializable['UniformRealFilter'], PropertyFilterType): + """ + Filter for Uniform Reals that fit certain constraints. + + Parameters + ---------- + values: Set[RealFilter] + What value filter to use. + + """ + + values = properties.Object(RealFilter, 'values') + typ = properties.String('type', default="uniform_real_filter", deserializable=False) + + +class NominalIntegerFilter(Serializable['NominalIntegerFilter'], PropertyFilterType): + """ + Filter for Nominal Integers that fit certain constraints. + + Parameters + ---------- + values: Set[IntegerFilter] + What value filter to use. + + """ + + values = properties.Object(IntegerFilter, 'values') + typ = properties.String('type', default="nominal_integer_filter", deserializable=False) + + +class UniformIntegerFilter(Serializable['UniformIntegerFilter'], PropertyFilterType): + """ + Filter for Uniform Integers that fit certain constraints. + + Parameters + ---------- + values: Set[IntegerFilter] + What value filter to use. + + """ + + values = properties.Object(IntegerFilter, 'values') + typ = properties.String('type', default="uniform_integer_filter", deserializable=False) + + +class AllRealFilter(Serializable['AllRealFilter'], PropertyFilterType): + """ + Filter for any real value that fits certain constraints. + + Parameters + ---------- + lower: str + The lower bound on this filter range. + upper: str + The upper bound on this filter range. + unit: str + The units associated with the floating point values for this filter. + + """ + + lower = properties.Float('lower') + upper = properties.Float('upper') + unit = properties.String('units') + typ = properties.String('type', default="all_real_filter", deserializable=False) + + +class AllIntegerFilter(Serializable['AllIntegerFilter'], PropertyFilterType): + """ + Filter for any integer value that fits certain constraints. + + Parameters + ---------- + lower: str + The lower bound on this filter range. + upper: str + The upper bound on this filter range. + + """ + + lower = properties.Float('lower') + upper = properties.Float('upper') + typ = properties.String('type', default="all_integer_filter", deserializable=False) diff --git a/src/citrine/gemd_queries/gemd_query.py b/src/citrine/gemd_queries/gemd_query.py new file mode 100644 index 000000000..28b3deaff --- /dev/null +++ b/src/citrine/gemd_queries/gemd_query.py @@ -0,0 +1,64 @@ +"""Definitions for GemdQuery objects, and their sub-objects.""" +from gemd.enumeration.base_enumeration import BaseEnumeration + +from citrine._serialization.serializable import Serializable +from citrine._serialization import properties +from citrine.gemd_queries.criteria import Criteria + + +class GemdObjectType(BaseEnumeration): + """The style of text search to run.""" + + # An old defect has some old GemdQuery values stored with invalid enums + # The synonyms will allow invalid old values to be read, but not emitted + MEASUREMENT_TEMPLATE_TYPE = "measurement_template", "MEASUREMENT_TEMPLATE_TYPE" + MATERIAL_TEMPLATE_TYPE = "material_template", "MATERIAL_TEMPLATE_TYPE" + PROCESS_TEMPLATE_TYPE = "process_template", "PROCESS_TEMPLATE_TYPE" + PROPERTY_TEMPLATE_TYPE = "property_template", "PROPERTY_TEMPLATE_TYPE" + CONDITION_TEMPLATE_TYPE = "condition_template", "CONDITION_TEMPLATE_TYPE" + PARAMETER_TEMPLATE_TYPE = "parameter_template", "PARAMETER_TEMPLATE_TYPE" + PROCESS_RUN_TYPE = "process_run", "PROCESS_RUN_TYPE" + PROCESS_SPEC_TYPE = "process_spec", "PROCESS_SPEC_TYPE" + MATERIAL_RUN_TYPE = "material_run", "MATERIAL_RUN_TYPE" + MATERIAL_SPEC_TYPE = "material_spec", "MATERIAL_SPEC_TYPE" + INGREDIENT_RUN_TYPE = "ingredient_run", "INGREDIENT_RUN_TYPE" + INGREDIENT_SPEC_TYPE = "ingredient_spec", "INGREDIENT_SPEC_TYPE" + MEASUREMENT_RUN_TYPE = "measurement_run", "MEASUREMENT_RUN_TYPE" + MEASUREMENT_SPEC_TYPE = "measurement_spec", "MEASUREMENT_SPEC_TYPE" + + +class GemdQuery(Serializable['GemdQuery']): + """ + This describes what data objects to fetch (or graph of data objects). + + Parameters + ---------- + criteria: Criteria + List of conditions all responses must satisfy (i.e., joined with an AND). + datasets: UUID + Set of datasets to look in for matching objects. + object_types: GemdObjectType + Classes of objects to consider when searching. + schema_version: Int + What version of the query schema this package represents. + + """ + + criteria = properties.List(properties.Object(Criteria), "criteria", default=[]) + datasets = properties.Set(properties.UUID, "datasets", default=set()) + object_types = properties.Set( + properties.Enumeration(GemdObjectType), + 'object_types', + default={x for x in GemdObjectType} + ) + schema_version = properties.Integer('schema_version', default=1) + + @classmethod + def _pre_build(cls, data: dict) -> dict: + """Run data modification before building.""" + version = data.get('schema_version') + if data.get('schema_version') != 1: + raise ValueError( + f"This version of the library only supports schema_version 1, not '{version}'" + ) + return data diff --git a/src/citrine/resources/table_config.py b/src/citrine/resources/table_config.py index 4099d461a..5a6357c05 100644 --- a/src/citrine/resources/table_config.py +++ b/src/citrine/resources/table_config.py @@ -14,6 +14,7 @@ from citrine._utils.functions import format_escaped_url from citrine.resources.data_concepts import CITRINE_SCOPE, _make_link_by_uid from citrine.resources.process_template import ProcessTemplate +from citrine.gemd_queries.gemd_query import GemdQuery from citrine.gemtables.columns import Column, MeanColumn, IdentityColumn, OriginalUnitsColumn, \ ConcatColumn from citrine.gemtables.rows import Row @@ -53,6 +54,8 @@ class TableConfig(Resource["TableConfig"]): List of row definitions that define the rows of the table columns: list[Column] Column definitions, which describe how the variables are shaped into the table + gemd_query: Optional[GemdQuery] + The query used to define the materials underpinning this table """ @@ -79,6 +82,7 @@ def _get_dups(lst: List) -> List: variables = properties.List(properties.Object(Variable), "variables") rows = properties.List(properties.Object(Row), "rows") columns = properties.List(properties.Object(Column), "columns") + gemd_query = properties.Optional(properties.Object(GemdQuery), "gemd_query") def __init__(self, name: str, *, description: str, datasets: List[UUID], variables: List[Variable], rows: List[Row], columns: List[Column]): diff --git a/tests/gemd_query/__init__.py b/tests/gemd_query/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/gemd_query/test_gemd_query.py b/tests/gemd_query/test_gemd_query.py new file mode 100644 index 000000000..b302da482 --- /dev/null +++ b/tests/gemd_query/test_gemd_query.py @@ -0,0 +1,45 @@ +from uuid import uuid4 +import pytest + +from citrine.gemd_queries.criteria import PropertiesCriteria +from citrine.gemd_queries.filter import AllRealFilter +from citrine.gemd_queries.gemd_query import GemdQuery + +from tests.utils.factories import GemdQueryDataFactory + + +def test_gemd_query_version(): + valid = GemdQueryDataFactory() + assert GemdQuery.build(valid) is not None + + invalid = GemdQueryDataFactory() + invalid['schema_version'] = 2 + with pytest.raises(ValueError): + GemdQuery.build(invalid) + + +def test_criteria_rebuild(): + value_filter = AllRealFilter() + value_filter.unit = 'm' + value_filter.lower = 0 + value_filter.upper = 1 + + crit = PropertiesCriteria() + crit.property_templates_filter = {uuid4()} + crit.value_type_filter = value_filter + + query = GemdQuery() + query.criteria.append(crit) + query.datasets.add(uuid4()) + query.object_types = {'material_run'} + + query_copy = GemdQuery.build(query.dump()) + + assert len(query.criteria) == len(query_copy.criteria) + assert query.criteria[0].property_templates_filter == query_copy.criteria[0].property_templates_filter + assert query.criteria[0].value_type_filter.unit == query_copy.criteria[0].value_type_filter.unit + assert query.criteria[0].value_type_filter.lower == query_copy.criteria[0].value_type_filter.lower + assert query.criteria[0].value_type_filter.upper == query_copy.criteria[0].value_type_filter.upper + assert query.datasets == query_copy.datasets + assert query.object_types == query_copy.object_types + assert query.schema_version == query_copy.schema_version diff --git a/tests/utils/factories.py b/tests/utils/factories.py index 487da4276..56cb858dd 100644 --- a/tests/utils/factories.py +++ b/tests/utils/factories.py @@ -7,6 +7,9 @@ from random import randrange, random import factory +from citrine.gemd_queries.gemd_query import * +from citrine.gemd_queries.criteria import * +from citrine.gemd_queries.filter import * from citrine.informatics.scores import LIScore from citrine.resources.dataset import Dataset from citrine.resources.file_link import _Uploader @@ -114,6 +117,57 @@ class ListGemTableVersionsDataFactory(factory.DictFactory): tables[2]["version"] = 2 +class PropertiesCriteriaDataFactory(factory.DictFactory): + type = PropertiesCriteria.typ + property_templates_filter = factory.List([factory.Faker('uuid4')]) + classifications = factory.Faker('random_element', elements=[str(x) for x in MaterialClassification]) + + +class NameCriteriaDataFactory(factory.DictFactory): + type = NameCriteria.typ + name = factory.Faker('word') + search_type = factory.Faker('random_element', elements=[str(x) for x in TextSearchType]) + + +class MaterialRunClassificationCriteriaDataFactory(factory.DictFactory): + type = MaterialRunClassificationCriteria.typ + classifications = factory.Faker( + 'random_elements', + elements=[str(x) for x in MaterialClassification], + unique=True + ) + + +class MaterialTemplatesCriteriaDataFactory(factory.DictFactory): + type = MaterialTemplatesCriteria.typ + material_templates_identifiers = factory.List([factory.Faker('uuid4')]) + tag_filters = factory.Faker('words', unique=True) + + +class AndOperatorCriteriaDataFactory(factory.DictFactory): + type = AndOperator.typ + criteria = factory.List([ + factory.SubFactory(NameCriteriaDataFactory), + factory.SubFactory(MaterialRunClassificationCriteriaDataFactory), + factory.SubFactory(MaterialTemplatesCriteriaDataFactory) + ]) + + +class OrOperatorCriteriaDataFactory(factory.DictFactory): + type = OrOperator.typ + criteria = factory.List([ + factory.SubFactory(PropertiesCriteriaDataFactory), + factory.SubFactory(AndOperatorCriteriaDataFactory) + ]) + + +class GemdQueryDataFactory(factory.DictFactory): + criteria = factory.List([factory.SubFactory(OrOperatorCriteriaDataFactory)]) + datasets = factory.List([factory.Faker('uuid4')]) + object_types = factory.List([str(x) for x in GemdObjectType]) + schema_version = 1 + + class TableConfigJSONDataFactory(factory.DictFactory): """ This is simply the JSON Blob stored in an Table Config Version""" name = factory.Faker("company") @@ -122,6 +176,7 @@ class TableConfigJSONDataFactory(factory.DictFactory): columns = [] variables = [] datasets = [] + gemd_query = factory.SubFactory(GemdQueryDataFactory) class TableConfigVersionJSONDataFactory(factory.DictFactory): From ef55b55f3410bdbe2430bf1833525a3253cfdd16 Mon Sep 17 00:00:00 2001 From: Ken Kroenlein Date: Wed, 13 Mar 2024 18:09:50 -0600 Subject: [PATCH 2/2] PR feedback --- src/citrine/gemd_queries/criteria.py | 5 +- src/citrine/gemd_queries/filter.py | 182 +++------------------------ tests/utils/factories.py | 41 +++++- 3 files changed, 58 insertions(+), 170 deletions(-) diff --git a/src/citrine/gemd_queries/criteria.py b/src/citrine/gemd_queries/criteria.py index d4f0c3847..a020f53de 100644 --- a/src/citrine/gemd_queries/criteria.py +++ b/src/citrine/gemd_queries/criteria.py @@ -9,10 +9,9 @@ from citrine.gemd_queries.filter import PropertyFilterType __all__ = ['MaterialClassification', 'TextSearchType', - 'Criteria', 'AndOperator', 'OrOperator', - 'PropertiesCriteria', 'NameCriteria', 'MaterialRunClassificationCriteria', - 'MaterialTemplatesCriteria' + 'PropertiesCriteria', 'NameCriteria', + 'MaterialRunClassificationCriteria', 'MaterialTemplatesCriteria' ] diff --git a/src/citrine/gemd_queries/filter.py b/src/citrine/gemd_queries/filter.py index ebc739ae9..4cd7b018c 100644 --- a/src/citrine/gemd_queries/filter.py +++ b/src/citrine/gemd_queries/filter.py @@ -1,82 +1,11 @@ """Definitions for GemdQuery objects, and their sub-objects.""" from typing import List, Type -from gemd.enumeration.base_enumeration import BaseEnumeration - from citrine._serialization.serializable import Serializable from citrine._serialization.polymorphic_serializable import PolymorphicSerializable from citrine._serialization import properties -__all__ = ['PropertyFilterType', - 'RealFilter', 'IntegerFilter', - 'MaterialClassification', 'TextSearchType', - ] - - -class RealFilter(Serializable['RealFilter']): - """ - A general filter for Real/Continuous Values. - - Parameters - ---------- - unit: str - The units associated with the floating point values for this filter. - lower_filter: str - The lower bound on this filter range. - upper_filter: str - The upper bound on this filter range. - lower_is_inclusive: bool - Whether the lower bound value included in the valid range. - upper_is_inclusive: bool - Whether the upper bound value included in the valid range. - - """ - - unit = properties.String('unit') - lower_filter = properties.Optional(properties.Float, 'lower_filter') - upper_filter = properties.Optional(properties.Float, 'upper_filter') - lower_is_inclusive = properties.Boolean('lower_is_inclusive') - upper_is_inclusive = properties.Boolean('upper_is_inclusive') - - -class IntegerFilter(Serializable['IntegerFilter']): - """ - A general filter for Integer/Discrete Values. - - Parameters - ---------- - lower_filter: str - The lower bound on this filter range. - upper_filter: str - The upper bound on this filter range. - lower_is_inclusive: bool - Whether the lower bound value included in the valid range. - upper_is_inclusive: bool - Whether the upper bound value included in the valid range. - - """ - - lower_filter = properties.Optional(properties.Float, 'lower_filter') - upper_filter = properties.Optional(properties.Float, 'upper_filter') - lower_is_inclusive = properties.Boolean('lower_is_inclusive') - upper_is_inclusive = properties.Boolean('upper_is_inclusive') - - -class MaterialClassification(BaseEnumeration): - """A classification based on where in a Material History you find a Material.""" - - ATOMIC_INGREDIENT = "atomic_ingredient" - INTERMEDIATE_INGREDIENT = "intermediate_ingredient" - TERMINAL_MATERIAL = "terminal_material" - - -class TextSearchType(BaseEnumeration): - """The style of text search to run.""" - - EXACT = "exact" - PREFIX = "prefix" - SUFFIX = "suffix" - SUBSTRING = "substring" +__all__ = ['AllRealFilter', 'AllIntegerFilter', 'NominalCategoricalFilter'] class PropertyFilterType(PolymorphicSerializable): @@ -87,103 +16,11 @@ def get_type(cls, data) -> Type[Serializable]: """Return the subtype.""" classes: List[Type[PropertyFilterType]] = [ NominalCategoricalFilter, - NominalRealFilter, NormalRealFilter, UniformRealFilter, - NominalIntegerFilter, UniformIntegerFilter, AllRealFilter, AllIntegerFilter ] return {klass.typ: klass for klass in classes}[data['type']] -class NominalCategoricalFilter(Serializable['NominalCategoricalFilter'], PropertyFilterType): - """ - Filter based upon a fixed list of Categorical Values. - - Parameters - ---------- - categories: Set[str] - Which categorical values match. - - """ - - categories = properties.Set(properties.String, 'categories') - typ = properties.String('type', default="nominal_categorical_filter", deserializable=False) - - -class NominalRealFilter(Serializable['NominalRealFilter'], PropertyFilterType): - """ - Filter for Nominal Reals that fit certain constraints. - - Parameters - ---------- - values: Set[RealFilter] - What value filter to use. - - """ - - values = properties.Object(RealFilter, 'values') - typ = properties.String('type', default="nominal_real_filter", deserializable=False) - - -class NormalRealFilter(Serializable['NormalRealFilter'], PropertyFilterType): - """ - Filter for Normal Reals that fit certain constraints. - - Parameters - ---------- - values: Set[RealFilter] - What value filter to use. - - """ - - values = properties.Object(RealFilter, 'values') - typ = properties.String('type', default="normal_real_filter", deserializable=False) - - -class UniformRealFilter(Serializable['UniformRealFilter'], PropertyFilterType): - """ - Filter for Uniform Reals that fit certain constraints. - - Parameters - ---------- - values: Set[RealFilter] - What value filter to use. - - """ - - values = properties.Object(RealFilter, 'values') - typ = properties.String('type', default="uniform_real_filter", deserializable=False) - - -class NominalIntegerFilter(Serializable['NominalIntegerFilter'], PropertyFilterType): - """ - Filter for Nominal Integers that fit certain constraints. - - Parameters - ---------- - values: Set[IntegerFilter] - What value filter to use. - - """ - - values = properties.Object(IntegerFilter, 'values') - typ = properties.String('type', default="nominal_integer_filter", deserializable=False) - - -class UniformIntegerFilter(Serializable['UniformIntegerFilter'], PropertyFilterType): - """ - Filter for Uniform Integers that fit certain constraints. - - Parameters - ---------- - values: Set[IntegerFilter] - What value filter to use. - - """ - - values = properties.Object(IntegerFilter, 'values') - typ = properties.String('type', default="uniform_integer_filter", deserializable=False) - - class AllRealFilter(Serializable['AllRealFilter'], PropertyFilterType): """ Filter for any real value that fits certain constraints. @@ -201,7 +38,7 @@ class AllRealFilter(Serializable['AllRealFilter'], PropertyFilterType): lower = properties.Float('lower') upper = properties.Float('upper') - unit = properties.String('units') + unit = properties.String('unit') typ = properties.String('type', default="all_real_filter", deserializable=False) @@ -221,3 +58,18 @@ class AllIntegerFilter(Serializable['AllIntegerFilter'], PropertyFilterType): lower = properties.Float('lower') upper = properties.Float('upper') typ = properties.String('type', default="all_integer_filter", deserializable=False) + + +class NominalCategoricalFilter(Serializable['NominalCategoricalFilter'], PropertyFilterType): + """ + Filter based upon a fixed list of Categorical Values. + + Parameters + ---------- + categories: Set[str] + Which categorical values match. + + """ + + categories = properties.Set(properties.String, 'categories') + typ = properties.String('type', default="nominal_categorical_filter", deserializable=False) diff --git a/tests/utils/factories.py b/tests/utils/factories.py index 56cb858dd..bc1eba530 100644 --- a/tests/utils/factories.py +++ b/tests/utils/factories.py @@ -117,16 +117,51 @@ class ListGemTableVersionsDataFactory(factory.DictFactory): tables[2]["version"] = 2 +class RealFilterDataFactory(factory.DictFactory): + type = AllRealFilter.typ + unit = 'dimensionless' + + class Params: + midpoint = factory.Faker("pyfloat") + + lower = factory.LazyAttribute(lambda o: min(0., 2. * o.midpoint) + random() * o.midpoint) + upper = factory.LazyAttribute(lambda o: max(0., 2. * o.midpoint) - random() * o.midpoint) + + +class IntegerFilterDataFactory(factory.DictFactory): + type = AllIntegerFilter.typ + + class Params: + midpoint = factory.Faker("pyint") + + lower = factory.LazyAttribute(lambda o: randrange(min(0, 2 * o.midpoint), o.midpoint + 1)) + upper = factory.LazyAttribute(lambda o: randrange(o.midpoint, max(0, 2 * o.midpoint) + 1)) + + +class CategoryFilterDataFactory(factory.DictFactory): + type = NominalCategoricalFilter.typ + categories = factory.Faker('words', unique=True) + + class PropertiesCriteriaDataFactory(factory.DictFactory): type = PropertiesCriteria.typ property_templates_filter = factory.List([factory.Faker('uuid4')]) - classifications = factory.Faker('random_element', elements=[str(x) for x in MaterialClassification]) + value_type_filter = factory.SubFactory(RealFilterDataFactory) + classifications = factory.Faker('enum', enum_cls=MaterialClassification) + + class Params: + integer = factory.Trait( + value_type_filter=factory.SubFactory(IntegerFilterDataFactory) + ) + category = factory.Trait( + value_type_filter=factory.SubFactory(CategoryFilterDataFactory) + ) class NameCriteriaDataFactory(factory.DictFactory): type = NameCriteria.typ name = factory.Faker('word') - search_type = factory.Faker('random_element', elements=[str(x) for x in TextSearchType]) + search_type = factory.Faker('enum', enum_cls=TextSearchType) class MaterialRunClassificationCriteriaDataFactory(factory.DictFactory): @@ -157,6 +192,8 @@ class OrOperatorCriteriaDataFactory(factory.DictFactory): type = OrOperator.typ criteria = factory.List([ factory.SubFactory(PropertiesCriteriaDataFactory), + factory.SubFactory(PropertiesCriteriaDataFactory, integer=True), + factory.SubFactory(PropertiesCriteriaDataFactory, category=True), factory.SubFactory(AndOperatorCriteriaDataFactory) ])