Skip to content

Commit

Permalink
feat(partitions): add partitions and tests
Browse files: browse the repository at this point in the history
  • Loading branch information
unknowntpo committed Dec 31, 2024
1 parent 090e874 commit f08ba54
Show file tree
Hide file tree
Showing 2 changed files with 310 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,228 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from typing import List, Dict, Any, Optional
from abc import ABC, abstractmethod

from gravitino.api.expressions.literals.literal import Literal
from gravitino.api.expressions.partitions.identity_partition import IdentityPartition
from gravitino.api.expressions.partitions.list_partition import ListPartition
from gravitino.api.expressions.partitions.partition import Partition
from gravitino.api.expressions.partitions.range_partition import RangePartition


class Partitions:
    """Factory helpers for building partition expressions.

    Each static method constructs and returns the matching ``*Impl`` class
    defined in this module.
    """

    # Shared empty list of partitions. It is an ordinary mutable list, so
    # callers should treat it as read-only.
    EMPTY_PARTITIONS: List[Partition] = []

    @staticmethod
    def range(
        name: str,
        upper: Literal[Any],
        lower: Literal[Any],
        properties: Optional[Dict[str, str]] = None,
    ) -> RangePartition:
        """
        Creates a range partition.

        Args:
            name: The name of the partition.
            upper: The upper bound of the partition.
            lower: The lower bound of the partition.
            properties: The properties of the partition. ``None`` (the
                default) is treated as an empty dict, matching ``list`` and
                ``identity``.

        Returns:
            The created partition.
        """
        # Normalize None to {} so the partition's properties are always a
        # dict, exactly as the sibling factory methods do.
        return RangePartitionImpl(name, upper, lower, properties or {})

    @staticmethod
    def list(
        name: str,
        lists: List[List[Literal]],
        properties: Optional[Dict[str, str]] = None,
    ) -> ListPartition:
        """
        Creates a list partition.

        Args:
            name: The name of the partition.
            lists: The values of the list partition.
            properties: The properties of the partition. ``None`` is treated
                as an empty dict.

        Returns:
            The created partition.
        """
        return ListPartitionImpl(name, lists, properties or {})

    @staticmethod
    def identity(
        name: Optional[str],
        field_names: List[List[str]],
        values: List[Literal],
        properties: Optional[Dict[str, str]] = None,
    ) -> IdentityPartition:
        """
        Creates an identity partition.

        The `values` must correspond to the `field_names`.

        Args:
            name: The name of the partition.
            field_names: The field names of the identity partition.
            values: The value of the identity partition.
            properties: The properties of the partition. ``None`` is treated
                as an empty dict.

        Returns:
            The created partition.
        """
        return IdentityPartitionImpl(name, field_names, values, properties or {})


class RangePartitionImpl(RangePartition):
    """Range partition defined by a name, an upper and a lower bound, and
    optional properties."""

    def __init__(
        self,
        name: str,
        upper: Literal,
        lower: Literal,
        properties: Optional[Dict[str, str]],
    ):
        """
        Args:
            name: The name of the partition.
            upper: The upper bound of the partition.
            lower: The lower bound of the partition.
            properties: The properties of the partition; may be ``None``.
        """
        self._name = name
        self._upper = upper
        self._lower = lower
        self._properties = properties

    def upper(self) -> Literal:
        """Returns the upper bound of the partition."""
        return self._upper

    def lower(self) -> Literal:
        """Returns the lower bound of the partition."""
        return self._lower

    def name(self) -> str:
        """Returns the name of the partition."""
        return self._name

    def properties(self) -> Optional[Dict[str, str]]:
        """Returns the properties of the partition, or ``None`` if none were given."""
        return self._properties

    def __eq__(self, other: Any) -> bool:
        """Two range partitions are equal when name, both bounds and
        properties are all equal; any other type compares unequal."""
        if isinstance(other, RangePartitionImpl):
            return (
                self._name == other._name
                and self._upper == other._upper
                and self._lower == other._lower
                and self._properties == other._properties
            )
        return False

    def __hash__(self) -> int:
        """Hashes the same fields that ``__eq__`` compares."""
        # The constructor accepts properties=None, so guard the .items()
        # call; otherwise hashing such a partition raises AttributeError.
        # For non-None properties the hash is unchanged.
        props = self._properties
        hashable_props = None if props is None else frozenset(props.items())
        return hash((self._name, self._upper, self._lower, hashable_props))


class ListPartitionImpl(ListPartition):
    """List partition defined by a name, a list of value lists, and
    properties."""

    def __init__(
        self,
        name: str,
        lists: List[List[Literal]],
        properties: Optional[Dict[str, str]],
    ):
        """
        Args:
            name: The name of the partition.
            lists: The values of the list partition.
            properties: The properties of the partition. ``None`` is stored
                as an empty dict.
        """
        self._name = name
        self._lists = lists
        # properties() is declared to return a dict and __hash__ calls
        # .items(), so normalize None here rather than failing later.
        self._properties = {} if properties is None else properties

    def lists(self) -> List[List[Literal]]:
        """Returns the values of the list partition."""
        return self._lists

    def name(self) -> str:
        """Returns the name of the partition."""
        return self._name

    def properties(self) -> Dict[str, str]:
        """Returns the properties of the partition."""
        return self._properties

    def __eq__(self, other: Any) -> bool:
        """Two list partitions are equal when name, lists and properties are
        all equal; any other type compares unequal."""
        if isinstance(other, ListPartitionImpl):
            return (
                self._name == other._name
                and self._lists == other._lists
                and self._properties == other._properties
            )
        return False

    def __hash__(self) -> int:
        """Hashes the same fields that ``__eq__`` compares."""
        # Lists and dicts are unhashable, so convert them to tuples and a
        # frozenset of items respectively.
        return hash(
            (
                self._name,
                tuple(tuple(row) for row in self._lists),
                frozenset(self._properties.items()),
            )
        )


class IdentityPartitionImpl(IdentityPartition):
    """Identity partition: a (possibly unnamed) partition described by its
    field names, the corresponding values, and a properties dict."""

    def __init__(
        self,
        name: Optional[str],
        field_names: List[List[str]],
        values: List[Literal],
        properties: Dict[str, str],
    ):
        """Stores the given name, field names, values and properties as-is."""
        self._properties = properties
        self._values = values
        self._field_names = field_names
        self._name = name

    def name(self) -> str:
        """Returns the name of the partition."""
        # NOTE(review): the constructor accepts Optional[str], so this can
        # return None despite the annotation. Confirm against the base class.
        return self._name

    def properties(self) -> Dict[str, str]:
        """Returns the properties of the partition."""
        return self._properties

    def field_names(self) -> List[List[str]]:
        """Returns the field names of the identity partition."""
        return self._field_names

    def values(self) -> List[Literal]:
        """Returns the values of the identity partition."""
        return self._values

    def __eq__(self, other: Any) -> bool:
        """Equal only to another ``IdentityPartitionImpl`` whose name, field
        names, values and properties all compare equal."""
        if not isinstance(other, IdentityPartitionImpl):
            return False
        return (
            self._name == other._name
            and self._field_names == other._field_names
            and self._values == other._values
            and self._properties == other._properties
        )

    def __hash__(self) -> int:
        """Hashes the same fields that ``__eq__`` compares."""
        # Convert the unhashable containers into hashable equivalents first.
        field_name_key = tuple(tuple(path) for path in self._field_names)
        value_key = tuple(self._values)
        property_key = frozenset(self._properties.items())
        return hash((self._name, field_name_key, value_key, property_key))
82 changes: 82 additions & 0 deletions clients/client-python/tests/unittests/rel/test_partitions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import unittest
from datetime import date, time, datetime
from decimal import Decimal

from gravitino.api.expressions.literals.literals import Literals
from gravitino.api.expressions.partitions.partitions import Partitions
from gravitino.api.types.types import Types


class TestPartitions(unittest.TestCase):
    """Checks that each ``Partitions`` factory returns a partition exposing
    the values it was built from."""

    def test_partitions(self):
        """Builds one range, one list and one identity partition and reads
        every accessor back."""
        # Range partition: NULL upper bound, integer lower bound.
        range_partition = Partitions.range(
            "p0", Literals.NULL, Literals.integer_literal(6), {}
        )
        self.assertEqual("p0", range_partition.name())
        self.assertEqual({}, range_partition.properties())
        self.assertEqual(Literals.NULL, range_partition.upper())
        self.assertEqual(Literals.integer_literal(6), range_partition.lower())

        # List partition: two rows sharing a date and differing by city.
        day = date(2022, 4, 1)
        cities = ["Los Angeles", "San Francisco"]
        list_partition = Partitions.list(
            "p202204_California",
            [
                [Literals.date_literal(day), Literals.string_literal(city)]
                for city in cities
            ],
            {},
        )
        self.assertEqual("p202204_California", list_partition.name())
        self.assertEqual({}, list_partition.properties())
        for row, city in enumerate(cities):
            self.assertEqual(
                Literals.date_literal(day), list_partition.lists()[row][0]
            )
            self.assertEqual(
                Literals.string_literal(city), list_partition.lists()[row][1]
            )

        # Identity partition: two single-part field names with their values.
        expected_properties = {
            "location": "/user/hive/warehouse/tpch_flat_orc_2.db/orders"
        }
        identity_partition = Partitions.identity(
            "dt=2008-08-08/country=us",
            [["dt"], ["country"]],
            [Literals.date_literal(date(2008, 8, 8)), Literals.string_literal("us")],
            # Pass a copy so the assertion compares against a distinct dict.
            dict(expected_properties),
        )
        self.assertEqual("dt=2008-08-08/country=us", identity_partition.name())
        self.assertEqual(expected_properties, identity_partition.properties())

        first_field, second_field = identity_partition.field_names()[:2]
        self.assertEqual(["dt"], first_field)
        self.assertEqual(["country"], second_field)
        self.assertEqual(
            Literals.date_literal(date(2008, 8, 8)), identity_partition.values()[0]
        )
        self.assertEqual(Literals.string_literal("us"), identity_partition.values()[1])

0 comments on commit f08ba54

Please sign in to comment.