Skip to content

Commit 2304f7c

Browse files
authored
fixes issue #80: filter_spatial() not correctly returning region information (#93)
* fixes issue #80: filter_spatial() not correctly returning region information - added tests for filtering using simple grid - added test to ensure region midpoints, idx_map, bbox_mask, and origins are the same - made catalog spatio_temporal_counts() check more strict
1 parent 2b421aa commit 2304f7c

File tree

2 files changed

+32
-2
lines changed

2 files changed

+32
-2
lines changed

csep/core/catalogs.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -574,7 +574,7 @@ def filter_spatial(self, region=None, update_stats=False, in_place=True):
574574
else:
575575
cls = self.__class__
576576
inst = cls(data=filtered, catalog_id=self.catalog_id, format=self.format, name=self.name,
577-
region=region, compute_stats=update_stats)
577+
region=self.region, compute_stats=update_stats)
578578
return inst
579579

580580
def apply_mct(self, m_main, event_epoch, mc=2.5):
@@ -855,6 +855,7 @@ def plot(self, ax=None, show=False, extent=None, set_global=False, plot_args=Non
855855
set_global=set_global, plot_args=plot_args)
856856
return ax
857857

858+
858859
class CSEPCatalog(AbstractBaseCatalog):
859860
"""
860861
Standard catalog class for PyCSEP catalog operations.

tests/test_catalog.py

+30-1
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,15 @@
11
import copy
22
import unittest
33
import os
4+
import itertools
45

56
import numpy
67

78
import csep
89
from csep.core import regions, forecasts
910
from csep.utils.time_utils import strptime_to_utc_epoch, strptime_to_utc_datetime
1011
from csep.core.catalogs import CSEPCatalog
12+
from csep.core.regions import CartesianGrid2D, Polygon, compute_vertices
1113

1214
def comcat_path():
1315
root_dir = os.path.dirname(os.path.abspath(__file__))
@@ -18,6 +20,18 @@ def comcat_path():
1820
class CatalogFiltering(unittest.TestCase):
1921
def setUp(self):
2022

23+
# create some arbitrary grid
24+
self.nx = 8
25+
self.ny = 10
26+
self.dh = 1
27+
x_points = numpy.arange(1.5, self.nx) * self.dh
28+
y_points = numpy.arange(1.5, self.ny) * self.dh
29+
30+
# spatial grid starts at (1.5, 1.5); so the event at (1, 1) should be removed.
31+
self.origins = list(itertools.product(x_points, y_points))
32+
self.num_nodes = len(self.origins)
33+
self.cart_grid = CartesianGrid2D([Polygon(bbox) for bbox in compute_vertices(self.origins, self.dh)], self.dh)
34+
2135
# define dummy cat
2236
date1 = strptime_to_utc_epoch('2009-01-01 00:00:00.0000')
2337
date2 = strptime_to_utc_epoch('2010-01-01 00:00:00.0000')
@@ -72,6 +86,21 @@ def test_filter_with_datetime(self):
7286
filtered_test_cat = test_cat.filter(filters, in_place=False)
7387
numpy.testing.assert_equal(numpy.array([b'1', b'2'], dtype='S256').T, filtered_test_cat.get_event_ids())
7488

89+
def test_filter_spatial(self):
90+
91+
test_cat = copy.deepcopy(self.test_cat1)
92+
filtered_test_cat = test_cat.filter_spatial(region=self.cart_grid)
93+
numpy.testing.assert_equal(numpy.array([b'2', b'3'], dtype='S256').T, filtered_test_cat.get_event_ids())
94+
95+
96+
def test_filter_spatial_in_place_return(self):
97+
test_cat = copy.deepcopy(self.test_cat1)
98+
filtered_test_cat = test_cat.filter_spatial(region=self.cart_grid, in_place=False)
99+
numpy.testing.assert_array_equal(filtered_test_cat.region.midpoints(), test_cat.region.midpoints())
100+
numpy.testing.assert_array_equal(filtered_test_cat.region.origins(), test_cat.region.origins())
101+
numpy.testing.assert_array_equal(filtered_test_cat.region.bbox_mask, test_cat.region.bbox_mask)
102+
numpy.testing.assert_array_equal(filtered_test_cat.region.idx_map, test_cat.region.idx_map)
103+
75104
def test_catalog_binning_and_filtering_acceptance(self):
76105
# create space-magnitude region
77106
region = regions.create_space_magnitude_region(
@@ -93,7 +122,7 @@ def test_catalog_binning_and_filtering_acceptance(self):
93122
# catalog filtered cumulative
94123
c = comcat.filter([f'magnitude >= {m_min}'], in_place=False)
95124
# catalog filtered incrementally
96-
c_int = comcat.filter([f'magnitude >= {m_min}', f'magnitude < {m_min + 0.09999999}'], in_place=False)
125+
c_int = comcat.filter([f'magnitude >= {m_min}', f'magnitude < {m_min + 0.1}'], in_place=False)
97126
# sum from overall data set
98127
gs = d.data[:, idm:].sum()
99128
# incremental counts

0 commit comments

Comments
 (0)