Skip to content

Commit

Permalink
Merge pull request #257 from rabernat/switch-index-to-frozenset
Browse files Browse the repository at this point in the history
Switch Index to be a frozenset
  • Loading branch information
rabernat authored Jan 12, 2022
2 parents 4f8c052 + ff7cab3 commit 6194c72
Show file tree
Hide file tree
Showing 5 changed files with 13 additions and 31 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ repos:
- id: seed-isort-config

- repo: https://github.com/pre-commit/mirrors-mypy
rev: 'v0.910'
rev: 'v0.931'
hooks:
- id: mypy
exclude: tests
Expand Down
28 changes: 4 additions & 24 deletions pangeo_forge_recipes/patterns.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""
Filename / URL patterns.
"""

import inspect
from dataclasses import dataclass, field, replace
from enum import Enum
Expand All @@ -11,7 +10,7 @@
Callable,
ClassVar,
Dict,
Iterable,
FrozenSet,
Iterator,
List,
Optional,
Expand Down Expand Up @@ -68,7 +67,7 @@ class MergeDim:
operation: ClassVar[CombineOp] = CombineOp.MERGE


@dataclass(frozen=True)
@dataclass(frozen=True, order=True)
class DimIndex:
"""Object used to index a single dimension of a FilePattern or Recipe Chunks.
Expand All @@ -92,27 +91,8 @@ def __post_init__(self):
assert self.index < self.sequence_len


class Index(tuple):
"""A tuple of ``DimIndex`` objects.
The order of the indexes doesn't matter for comparision."""

def __new__(self, args: Iterable[DimIndex]):
# This validation really slows things down because we call Index a lot!
# if not all((isinstance(a, DimIndex) for a in args)):
# raise ValueError("All arguments must be DimIndex.")
# args_set = set(args)
# if len(set(args_set)) < len(tuple(args)):
# raise ValueError("Duplicate argument detected.")
return tuple.__new__(Index, args)

def __str__(self):
return ",".join(str(dim) for dim in self)

def __eq__(self, other):
return (set(self) == set(other)) and (len(self) == len(other))

def __hash__(self):
return hash(frozenset(self))
class Index(FrozenSet[DimIndex]):
pass


CombineDim = Union[MergeDim, ConcatDim]
Expand Down
2 changes: 2 additions & 0 deletions pangeo_forge_recipes/recipes/reference_hdf_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def scan_file(chunk_key: ChunkKey, config: HDFReferenceRecipe):
ref_fname = os.path.basename(fname + ".json")
with file_opener(fname, **config.netcdf_storage_options) as fp:
protocol = getattr(getattr(fp, "fs", None), "protocol", None) # make mypy happy
if protocol is None:
raise ValueError("Couldn't determine protocol")
target_url = unstrip_protocol(fname, protocol)
config.metadata_cache[ref_fname] = create_hdf5_reference(fp, target_url, fname)

Expand Down
4 changes: 2 additions & 2 deletions pangeo_forge_recipes/recipes/xarray_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,12 +49,12 @@


def _input_metadata_fname(input_key: InputKey) -> str:
key_str = "-".join([f"{k.name}_{k.index}" for k in input_key])
key_str = "-".join([f"{k.name}_{k.index}" for k in sorted(input_key)])
return "input-meta-" + key_str + ".json"


def _input_reference_fname(input_key: InputKey) -> str:
key_str = "-".join([f"{k.name}_{k.index}" for k in input_key])
key_str = "-".join([f"{k.name}_{k.index}" for k in sorted(input_key)])
return "input-reference-" + key_str + ".json"


Expand Down
8 changes: 4 additions & 4 deletions tests/test_patterns.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def test_pattern_from_file_sequence():
assert fp.nitems_per_input == {"time": None}
assert fp.concat_sequence_lens == {"time": None}
for key in fp:
assert fp[key] == file_sequence[key[0].index]
assert fp[key] == file_sequence[sorted(key)[0].index]


@pytest.mark.parametrize("pickle", [False, True])
Expand Down Expand Up @@ -115,17 +115,17 @@ def test_file_pattern_concat_merge(runtime_secrets, pickle, concat_merge_pattern
assert fp.concat_sequence_lens == {"time": None}
assert len(list(fp)) == 6
for key in fp:
expected_fname = format_function(time=times[key[1].index], variable=varnames[key[0].index])
for k in key:
if k.name == "time":
assert k.operation == CombineOp.CONCAT
assert k.sequence_len == 3
time_val = times[k.index]
if k.name == "variable":
assert k.operation == CombineOp.MERGE
assert k.sequence_len == 2
variable_val = varnames[k.index]
expected_fname = format_function(time=time_val, variable=variable_val)
assert fp[key] == expected_fname
# make sure key order doesn't matter
assert fp[key[::-1]] == expected_fname

if "fsspec_open_kwargs" in kwargs.keys():
assert fp.is_opendap is False
Expand Down

0 comments on commit 6194c72

Please sign in to comment.