diff --git a/clients/client-python/gravitino/api/expressions/partitions/partitions.py b/clients/client-python/gravitino/api/expressions/partitions/partitions.py new file mode 100644 index 00000000000..051c578ff8c --- /dev/null +++ b/clients/client-python/gravitino/api/expressions/partitions/partitions.py @@ -0,0 +1,228 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import List, Dict, Any, Optional +from abc import ABC, abstractmethod + +from gravitino.api.expressions.literals.literal import Literal +from gravitino.api.expressions.partitions.identity_partition import IdentityPartition +from gravitino.api.expressions.partitions.list_partition import ListPartition +from gravitino.api.expressions.partitions.partition import Partition +from gravitino.api.expressions.partitions.range_partition import RangePartition + + +class Partitions: + """The helper class for partition expressions.""" + + EMPTY_PARTITIONS: List[Partition] = [] + """ + An empty array of partitions + """ + + @staticmethod + def range( + name: str, + upper: Literal[Any], + lower: Literal[Any], + properties: Optional[Dict[str, str]], + ) -> RangePartition: + """ + Creates a range partition. + + Args: + name: The name of the partition. + upper: The upper bound of the partition. + lower: The lower bound of the partition. + properties: The properties of the partition. + + Returns: + The created partition. + """ + return RangePartitionImpl(name, upper, lower, properties) + + @staticmethod + def list( + name: str, + lists: List[List[Literal]], + properties: Optional[Dict[str, str]] = None, + ) -> ListPartition: + """ + Creates a list partition. + + Args: + name: The name of the partition. + lists: The values of the list partition. + properties: The properties of the partition. + + Returns: + The created partition. + """ + return ListPartitionImpl(name, lists, properties or {}) + + @staticmethod + def identity( + name: Optional[str], + field_names: List[List[str]], + values: List[Literal], + properties: Optional[Dict[str, str]] = None, + ) -> IdentityPartition: + """ + Creates an identity partition. + + The `values` must correspond to the `field_names`. + + Args: + name: The name of the partition. + field_names: The field names of the identity partition. + values: The value of the identity partition. + properties: The properties of the partition. + + Returns: + The created partition. + """ + return IdentityPartitionImpl(name, field_names, values, properties or {}) + + +class RangePartitionImpl(RangePartition): + def __init__( + self, + name: str, + upper: Literal, + lower: Literal, + properties: Optional[Dict[str, str]], + ): + self._name = name + self._upper = upper + self._lower = lower + self._properties = properties + + def upper(self) -> Literal: + """Returns the upper bound of the partition.""" + return self._upper + + def lower(self) -> Literal: + """Returns the lower bound of the partition.""" + return self._lower + + def name(self) -> str: + return self._name + + def properties(self) -> Optional[Dict[str, str]]: + return self._properties + + def __eq__(self, other: Any) -> bool: + if isinstance(other, RangePartitionImpl): + return ( + self._name == other._name + and self._upper == other._upper + and self._lower == other._lower + and self._properties == other._properties + ) + return False + + def __hash__(self) -> int: + return hash( + (self._name, self._upper, self._lower, frozenset(self._properties.items())) + ) + + +class ListPartitionImpl(ListPartition): + def __init__( + self, + name: str, + lists: List[List[Literal]], + properties: Optional[Dict[str, str]], + ): + self._name = name + self._lists = lists + self._properties = properties + + def lists(self) -> List[List[Literal]]: + """Returns the values of the list partition.""" + return self._lists + + def name(self) -> str: + return self._name + + def properties(self) -> Dict[str, str]: + return self._properties + + def __eq__(self, other: Any) -> bool: + if isinstance(other, ListPartitionImpl): + return ( + self._name == other._name + and self._lists == other._lists + and self._properties == other._properties + ) + return False + + def __hash__(self) -> int: + return hash( + ( + self._name, + tuple(tuple(l) for l in self._lists), + frozenset(self._properties.items()), + ) + ) + + +class IdentityPartitionImpl(IdentityPartition): + def __init__( + self, + name: Optional[str], + field_names: List[List[str]], + values: List[Literal], + properties: Dict[str, str], + ): + self._name = name + self._field_names = field_names + self._values = values + self._properties = properties + + def field_names(self) -> List[List[str]]: + """Returns the field names of the identity partition.""" + return self._field_names + + def values(self) -> List[Literal]: + """Returns the values of the identity partition.""" + return self._values + + def name(self) -> str: + return self._name + + def properties(self) -> Dict[str, str]: + return self._properties + + def __eq__(self, other: Any) -> bool: + if isinstance(other, IdentityPartitionImpl): + return ( + self._name == other._name + and self._field_names == other._field_names + and self._values == other._values + and self._properties == other._properties + ) + return False + + def __hash__(self) -> int: + return hash( + ( + self._name, + tuple(tuple(fn) for fn in self._field_names), + tuple(self._values), + frozenset(self._properties.items()), + ) + ) diff --git a/clients/client-python/tests/unittests/rel/test_partitions.py b/clients/client-python/tests/unittests/rel/test_partitions.py new file mode 100644 index 00000000000..ad586410c98 --- /dev/null +++ b/clients/client-python/tests/unittests/rel/test_partitions.py @@ -0,0 +1,82 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import unittest +from datetime import date, time, datetime +from decimal import Decimal + +from gravitino.api.expressions.literals.literals import Literals +from gravitino.api.expressions.partitions.partitions import Partitions +from gravitino.api.types.types import Types + + +class TestPartitions(unittest.TestCase): + def test_partitions(self): + # Test RangePartition + partition = Partitions.range( + "p0", Literals.NULL, Literals.integer_literal(6), {} + ) + self.assertEqual("p0", partition.name()) + self.assertEqual({}, partition.properties()) + self.assertEqual(Literals.NULL, partition.upper()) + self.assertEqual(Literals.integer_literal(6), partition.lower()) + + # Test ListPartition + partition = Partitions.list( + "p202204_California", + [ + [ + Literals.date_literal(date(2022, 4, 1)), + Literals.string_literal("Los Angeles"), + ], + [ + Literals.date_literal(date(2022, 4, 1)), + Literals.string_literal("San Francisco"), + ], + ], + {}, + ) + self.assertEqual("p202204_California", partition.name()) + self.assertEqual({}, partition.properties()) + self.assertEqual( + Literals.date_literal(date(2022, 4, 1)), partition.lists()[0][0] + ) + self.assertEqual( + Literals.string_literal("Los Angeles"), partition.lists()[0][1] + ) + self.assertEqual( + Literals.date_literal(date(2022, 4, 1)), partition.lists()[1][0] + ) + self.assertEqual( + Literals.string_literal("San Francisco"), partition.lists()[1][1] + ) + + # Test IdentityPartition + partition = Partitions.identity( + "dt=2008-08-08/country=us", + [["dt"], ["country"]], + [Literals.date_literal(date(2008, 8, 8)), Literals.string_literal("us")], + {"location": "/user/hive/warehouse/tpch_flat_orc_2.db/orders"}, + ) + self.assertEqual("dt=2008-08-08/country=us", partition.name()) + self.assertEqual( + {"location": "/user/hive/warehouse/tpch_flat_orc_2.db/orders"}, + partition.properties(), + ) + self.assertEqual(["dt"], partition.field_names()[0]) + self.assertEqual(["country"], partition.field_names()[1]) + self.assertEqual(Literals.date_literal(date(2008, 8, 8)), partition.values()[0]) + self.assertEqual(Literals.string_literal("us"), partition.values()[1])