diff --git a/python/lsst/ap/association/__init__.py b/python/lsst/ap/association/__init__.py
index a999c1d4..f91a0347 100644
--- a/python/lsst/ap/association/__init__.py
+++ b/python/lsst/ap/association/__init__.py
@@ -20,6 +20,7 @@
# along with this program. If not, see .
from .version import *
+from .trailedSourceFilter import *
from .association import *
from .diaForcedSource import *
from .loadDiaCatalogs import *
diff --git a/python/lsst/ap/association/association.py b/python/lsst/ap/association/association.py
index 8df3daaa..a9c87002 100644
--- a/python/lsst/ap/association/association.py
+++ b/python/lsst/ap/association/association.py
@@ -32,6 +32,7 @@
import lsst.pex.config as pexConfig
import lsst.pipe.base as pipeBase
from lsst.utils.timer import timeMethod
+from .trailedSourceFilter import TrailedSourceFilterTask
# Enforce an error for unsafe column/array value setting in pandas.
pd.options.mode.chained_assignment = 'raise'
@@ -40,13 +41,27 @@
class AssociationConfig(pexConfig.Config):
"""Config class for AssociationTask.
"""
+
maxDistArcSeconds = pexConfig.Field(
dtype=float,
- doc='Maximum distance in arcseconds to test for a DIASource to be a '
- 'match to a DIAObject.',
+ doc="Maximum distance in arcseconds to test for a DIASource to be a "
+ "match to a DIAObject.",
default=1.0,
)
+ trailedSourceFilter = pexConfig.ConfigurableField(
+ target=TrailedSourceFilterTask,
+ doc="Subtask to remove long trailed sources based on catalog source "
+ "morphological measurements.",
+ )
+
+ doTrailedSourceFilter = pexConfig.Field(
+ doc="Run traildeSourceFilter to remove long trailed sources from "
+ "output catalog.",
+ dtype=bool,
+ default=True,
+ )
+
class AssociationTask(pipeBase.Task):
"""Associate DIAOSources into existing DIAObjects.
@@ -60,10 +75,16 @@ class AssociationTask(pipeBase.Task):
ConfigClass = AssociationConfig
_DefaultName = "association"
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ if self.config.doTrailedSourceFilter:
+ self.makeSubtask("trailedSourceFilter")
+
@timeMethod
def run(self,
diaSources,
- diaObjects):
+ diaObjects,
+ exposure_time=None):
"""Associate the new DiaSources with existing DiaObjects.
Parameters
@@ -72,22 +93,24 @@ def run(self,
New DIASources to be associated with existing DIAObjects.
diaObjects : `pandas.DataFrame`
Existing diaObjects from the Apdb.
+ exposure_time : `float`, optional
+ Exposure time from difference image.
Returns
-------
result : `lsst.pipe.base.Struct`
Results struct with components.
- - ``"matchedDiaSources"`` : DiaSources that were matched. Matched
+ - ``matchedDiaSources`` : DiaSources that were matched. Matched
Sources have their diaObjectId updated and set to the id of the
diaObject they were matched to. (`pandas.DataFrame`)
- - ``"unAssocDiaSources"`` : DiaSources that were not matched.
+ - ``unAssocDiaSources`` : DiaSources that were not matched.
Unassociated sources have their diaObject set to 0 as they
were not associated with any existing DiaObjects.
(`pandas.DataFrame`)
- - ``"nUpdatedDiaObjects"`` : Number of DiaObjects that were
+ - ``nUpdatedDiaObjects`` : Number of DiaObjects that were
matched to new DiaSources. (`int`)
- - ``"nUnassociatedDiaObjects"`` : Number of DiaObjects that were
+ - ``nUnassociatedDiaObjects`` : Number of DiaObjects that were
not matched a new DiaSource. (`int`)
"""
diaSources = self.check_dia_source_radec(diaSources)
@@ -98,7 +121,15 @@ def run(self,
nUpdatedDiaObjects=0,
nUnassociatedDiaObjects=0)
- matchResult = self.associate_sources(diaObjects, diaSources)
+ if self.config.doTrailedSourceFilter:
+ diaTrailedResult = self.trailedSourceFilter.run(diaSources, exposure_time)
+ matchResult = self.associate_sources(diaObjects, diaTrailedResult.diaSources)
+
+ self.log.info("%i DIASources exceed max_trail_length, dropping "
+ "from source catalog." % len(diaTrailedResult.trailedDiaSources))
+
+ else:
+ matchResult = self.associate_sources(diaObjects, diaSources)
mask = matchResult.diaSources["diaObjectId"] != 0
@@ -157,11 +188,11 @@ def associate_sources(self, dia_objects, dia_sources):
result : `lsst.pipe.base.Struct`
Results struct with components.
- - ``"diaSources"`` : Full set of diaSources both matched and not.
+ - ``diaSources`` : Full set of diaSources both matched and not.
(`pandas.DataFrame`)
- - ``"nUpdatedDiaObjects"`` : Number of DiaObjects that were
+ - ``nUpdatedDiaObjects`` : Number of DiaObjects that were
associated. (`int`)
- - ``"nUnassociatedDiaObjects"`` : Number of DiaObjects that were
+ - ``nUnassociatedDiaObjects`` : Number of DiaObjects that were
not matched a new DiaSource. (`int`)
"""
scores = self.score(
@@ -196,11 +227,11 @@ def score(self, dia_objects, dia_sources, max_dist):
result : `lsst.pipe.base.Struct`
Results struct with components:
- - ``"scores"``: array of floats of match quality updated DIAObjects
+ - ``scores``: array of floats of match quality updated DIAObjects
(array-like of `float`).
- - ``"obj_idxs"``: indexes of the matched DIAObjects in the catalog.
+ - ``obj_idxs``: indexes of the matched DIAObjects in the catalog.
(array-like of `int`)
- - ``"obj_ids"``: array of floats of match quality updated DIAObjects
+ - ``obj_ids``: array of floats of match quality updated DIAObjects
(array-like of `int`).
Default values for these arrays are
diff --git a/python/lsst/ap/association/diaPipe.py b/python/lsst/ap/association/diaPipe.py
index 1d3eaa47..30409eee 100644
--- a/python/lsst/ap/association/diaPipe.py
+++ b/python/lsst/ap/association/diaPipe.py
@@ -28,6 +28,10 @@
Currently loads directly from the Apdb rather than pre-loading.
"""
+__all__ = ("DiaPipelineConfig",
+ "DiaPipelineTask",
+ "DiaPipelineConnections")
+
import pandas as pd
import lsst.dax.apdb as daxApdb
@@ -44,10 +48,6 @@
PackageAlertsTask)
from lsst.ap.association.ssoAssociation import SolarSystemAssociationTask
-__all__ = ("DiaPipelineConfig",
- "DiaPipelineTask",
- "DiaPipelineConnections")
-
class DiaPipelineConnections(
pipeBase.PipelineTaskConnections,
@@ -367,8 +367,8 @@ def run(self,
loaderResult = self.diaCatalogLoader.run(diffIm, self.apdb)
# Associate new DiaSources with existing DiaObjects.
- assocResults = self.associator.run(diaSourceTable,
- loaderResult.diaObjects)
+ assocResults = self.associator.run(diaSourceTable, loaderResult.diaObjects,
+ exposure_time=diffIm.visitInfo.exposureTime)
if self.config.doSolarSystemAssociation:
ssoAssocResult = self.solarSystemAssociator.run(
assocResults.unAssocDiaSources,
diff --git a/python/lsst/ap/association/metrics.py b/python/lsst/ap/association/metrics.py
index 765d5d38..316c2349 100644
--- a/python/lsst/ap/association/metrics.py
+++ b/python/lsst/ap/association/metrics.py
@@ -156,11 +156,11 @@ def makeMeasurement(self, values):
A `dict` representation of the metadata. Each `dict` has the
following keys:
- ``"updatedObjects"``
+ ``updatedObjects``
The number of DIAObjects updated for this image (`int` or
`None`). May be `None` if the image was not
successfully associated.
- ``"unassociatedObjects"``
+ ``unassociatedObjects``
The number of DIAObjects not associated with a DiaSource in
this image (`int` or `None`). May be `None` if the image was
not successfully associated.
@@ -216,7 +216,7 @@ def makeMeasurement(self, values):
A `dict` representation of the metadata. Each `dict` has the
following key:
- ``"numTotalSolarSystemObjects"``
+ ``numTotalSolarSystemObjects``
The number of SolarSystemObjects within the observable detector
area (`int` or `None`). May be `None` if solar system
association was not attempted or the image was not
@@ -264,7 +264,7 @@ def makeMeasurement(self, values):
A `dict` representation of the metadata. Each `dict` has the
following key:
- ``"numAssociatedSsObjects"``
+ ``numAssociatedSsObjects``
The number of successfully associated SolarSystem Objects
(`int` or `None`). May be `None` if solar system association
was not attempted or the image was not successfully associated.
diff --git a/python/lsst/ap/association/skyBotEphemerisQuery.py b/python/lsst/ap/association/skyBotEphemerisQuery.py
index 3dde9d91..2cbb91c7 100644
--- a/python/lsst/ap/association/skyBotEphemerisQuery.py
+++ b/python/lsst/ap/association/skyBotEphemerisQuery.py
@@ -126,43 +126,43 @@ def run(self, visitInfos, visit):
details see
https://ssp.imcce.fr/webservices/skybot/api/conesearch/#output-results
- ``"Num"``
+ ``Num``
object number (`int`, optional)
- ``"Name"``
+ ``Name``
object name (`str`)
- ``"RA(h)"``
+ ``RA(h)``
RA in HMS (`str`)
- ``"DE(deg)"``
+ ``DE(deg)``
DEC in DMS (`str`)
- ``"Class"``
+ ``Class``
Minor planet classification (`str`)
- ``"Mv"``
+ ``Mv``
visual magnitude (`float`)
- ``"Err(arcsec)"``
+ ``Err(arcsec)``
position error (`float`)
- ``"d(arcsec)"``
+ ``d(arcsec)``
distance from exposure boresight (`float`)?
- ``"dRA(arcsec/h)"``
+ ``dRA(arcsec/h)``
proper motion in RA (`float`)
- ``"dDEC(arcsec/h)"``
+ ``dDEC(arcsec/h)``
proper motion in DEC (`float`)
- ``"Dg(ua)"``
+ ``Dg(ua)``
geocentric distance (`float`)
- ``"Dh(ua)"``
+ ``Dh(ua)``
heliocentric distance (`float`)
- ``"Phase(deg)"``
+ ``Phase(deg)``
phase angle (`float`)
- ``"SunElong(deg)"``
+ ``SunElong(deg)``
solar elongation (`float`)
- ``"ra"``
+ ``ra``
RA in decimal degrees (`float`)
- ``"dec"``
+ ``dec``
DEC in decimal degrees (`float`)
- ``"ssObjectId"``
+ ``ssObjectId``
unique minor planet ID for internal use (`int`). Shared
across catalogs; the pair ``(ssObjectId, visitId)`` is
globally unique.
- ``"visitId"``
+ ``visitId``
a copy of ``visit`` (`int`)
"""
# Grab the visitInfo from the raw to get the information needed on the
diff --git a/python/lsst/ap/association/trailedSourceFilter.py b/python/lsst/ap/association/trailedSourceFilter.py
new file mode 100644
index 00000000..6eb20920
--- /dev/null
+++ b/python/lsst/ap/association/trailedSourceFilter.py
@@ -0,0 +1,110 @@
+# This file is part of ap_association.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+__all__ = ("TrailedSourceFilterTask", "TrailedSourceFilterConfig")
+
+import lsst.pex.config as pexConfig
+import lsst.pipe.base as pipeBase
+from lsst.utils.timer import timeMethod
+
+
+class TrailedSourceFilterConfig(pexConfig.Config):
+ """Config class for TrailedSourceFilterTask.
+ """
+
+ max_trail_length = pexConfig.Field(
+ dtype=float,
+ doc="Length of long trailed sources to remove from the input catalog, "
+ "in arcseconds per second. Default comes from DMTN-199, which "
+ "requires removal of sources with trails longer than 10 "
+ "degrees/day, which is 36000/3600/24 arcsec/second, or roughly"
+ "0.416 arcseconds per second.",
+ default=36000/3600.0/24.0,
+ )
+
+
+class TrailedSourceFilterTask(pipeBase.Task):
+ """Find trailed sources in DIASources and filter them as per DMTN-199
+ guidelines.
+
+ This task checks the length of trailLength in the DIASource catalog using
+ a given arcsecond/second rate from max_trail_length and the exposure time.
+ The two values are used to calculate the maximum allowed trail length and
+ filters out any trail longer than the maximum. The max_trail_length is
+ outlined in DMTN-199 and determines the default value.
+ """
+
+ ConfigClass = TrailedSourceFilterConfig
+ _DefaultName = "trailedSourceFilter"
+
+ @timeMethod
+ def run(self, dia_sources, exposure_time):
+ """Remove trailed sources longer than ``config.max_trail_length`` from
+ the input catalog.
+
+ Parameters
+ ----------
+ dia_sources : `pandas.DataFrame`
+ New DIASources to be checked for trailed sources.
+ exposure_time : `float`
+ Exposure time from difference image.
+
+ Returns
+ -------
+ result : `lsst.pipe.base.Struct`
+ Results struct with components.
+
+ - ``dia_sources`` : DIASource table that is free from unwanted
+ trailed sources. (`pandas.DataFrame`)
+
+ - ``trailed_dia_sources`` : DIASources that have trails which
+ exceed max_trail_length/second*exposure_time.
+ (`pandas.DataFrame`)
+ """
+ trail_mask = self._check_dia_source_trail(dia_sources, exposure_time)
+
+ return pipeBase.Struct(
+ diaSources=dia_sources[~trail_mask].reset_index(drop=True),
+ trailedDiaSources=dia_sources[trail_mask].reset_index(drop=True))
+
+ def _check_dia_source_trail(self, dia_sources, exposure_time):
+ """Find DiaSources that have long trails.
+
+ Return a mask of sources with lengths greater than
+ ``config.max_trail_length`` multiplied by the exposure time.
+
+ Parameters
+ ----------
+ dia_sources : `pandas.DataFrame`
+ Input DIASources to check for trail lengths.
+ exposure_time : `float`
+ Exposure time from difference image.
+
+ Returns
+ -------
+ trail_mask : `pandas.DataFrame`
+ Boolean mask for DIASources which are greater than the
+ cutoff length.
+ """
+ trail_mask = (dia_sources.loc[:, "trailLength"].values[:]
+ >= (self.config.max_trail_length*exposure_time))
+
+ return trail_mask
diff --git a/tests/test_association_task.py b/tests/test_association_task.py
index 0a10af81..71ce7bbf 100644
--- a/tests/test_association_task.py
+++ b/tests/test_association_task.py
@@ -22,7 +22,6 @@
import numpy as np
import pandas as pd
import unittest
-
import lsst.geom as geom
import lsst.utils.tests
@@ -46,34 +45,48 @@ def setUp(self):
self.diaSources = pd.DataFrame(data=[
{"ra": 0.04*idx + scatter*rng.uniform(-1, 1),
"dec": 0.04*idx + scatter*rng.uniform(-1, 1),
- "diaSourceId": idx + 1 + self.nObjects, "diaObjectId": 0}
+ "diaSourceId": idx + 1 + self.nObjects, "diaObjectId": 0, "trailLength": 5.5*idx}
for idx in range(self.nSources)])
self.diaSourceZeroScatter = pd.DataFrame(data=[
{"ra": 0.04*idx,
"dec": 0.04*idx,
- "diaSourceId": idx + 1 + self.nObjects, "diaObjectId": 0}
+ "diaSourceId": idx + 1 + self.nObjects, "diaObjectId": 0, "trailLength": 5.5*idx}
for idx in range(self.nSources)])
+ self.exposure_time = 30.0
def test_run(self):
"""Test the full task by associating a set of diaSources to
existing diaObjects.
"""
- assocTask = AssociationTask()
- results = assocTask.run(self.diaSources, self.diaObjects)
+ config = AssociationTask.ConfigClass()
+ config.doTrailedSourceFilter = False
+ assocTask = AssociationTask(config=config)
+ results = assocTask.run(self.diaSources, self.diaObjects, exposure_time=self.exposure_time)
self.assertEqual(results.nUpdatedDiaObjects, len(self.diaObjects) - 1)
self.assertEqual(results.nUnassociatedDiaObjects, 1)
self.assertEqual(len(results.matchedDiaSources),
len(self.diaObjects) - 1)
self.assertEqual(len(results.unAssocDiaSources), 1)
- for test_obj_id, expected_obj_id in zip(
- results.matchedDiaSources["diaObjectId"].to_numpy(),
- [1, 2, 3, 4]):
- self.assertEqual(test_obj_id, expected_obj_id)
- for test_obj_id, expected_obj_id in zip(
- results.unAssocDiaSources["diaObjectId"].to_numpy(),
- [0]):
- self.assertEqual(test_obj_id, expected_obj_id)
+ np.testing.assert_array_equal(results.matchedDiaSources["diaObjectId"].values, [1, 2, 3, 4])
+ np.testing.assert_array_equal(results.unAssocDiaSources["diaObjectId"].values, [0])
+
+ def test_run_trailed_sources(self):
+ """Test the full task by associating a set of diaSources to
+ existing diaObjects when trailed sources are filtered.
+
+ This should filter out two of the five sources based on trail length,
+ leaving one unassociated diaSource and two associated diaSources.
+ """
+ assocTask = AssociationTask()
+ results = assocTask.run(self.diaSources, self.diaObjects, exposure_time=self.exposure_time)
+
+ self.assertEqual(results.nUpdatedDiaObjects, len(self.diaObjects) - 3)
+ self.assertEqual(results.nUnassociatedDiaObjects, 3)
+ self.assertEqual(len(results.matchedDiaSources), len(self.diaObjects) - 3)
+ self.assertEqual(len(results.unAssocDiaSources), 1)
+ np.testing.assert_array_equal(results.matchedDiaSources["diaObjectId"].values, [1, 2])
+ np.testing.assert_array_equal(results.unAssocDiaSources["diaObjectId"].values, [0])
def test_run_no_existing_objects(self):
"""Test the run method with a completely empty database.
@@ -81,7 +94,8 @@ def test_run_no_existing_objects(self):
assocTask = AssociationTask()
results = assocTask.run(
self.diaSources,
- pd.DataFrame(columns=["ra", "dec", "diaObjectId"]))
+ pd.DataFrame(columns=["ra", "dec", "diaObjectId", "trailLength"]),
+ exposure_time=self.exposure_time)
self.assertEqual(results.nUpdatedDiaObjects, 0)
self.assertEqual(results.nUnassociatedDiaObjects, 0)
self.assertEqual(len(results.matchedDiaSources), 0)
@@ -99,6 +113,7 @@ def test_associate_sources(self):
assoc_result.diaSources["diaObjectId"].to_numpy(),
[0, 1, 2, 3, 4]):
self.assertEqual(test_obj_id, expected_obj_id)
+ np.testing.assert_array_equal(assoc_result.diaSources["diaObjectId"].values, [0, 1, 2, 3, 4])
def test_score_and_match(self):
"""Test association between a set of sources and an existing
diff --git a/tests/test_diaPipe.py b/tests/test_diaPipe.py
index 5236e42c..766236c2 100644
--- a/tests/test_diaPipe.py
+++ b/tests/test_diaPipe.py
@@ -121,7 +121,7 @@ def solarSystemAssociator_run(self, unAssocDiaSources, solarSystemObjectTable, d
unAssocDiaSources=MagicMock(spec=pd.DataFrame()))
@lsst.utils.timer.timeMethod
- def associator_run(self, table, diaObjects):
+ def associator_run(self, table, diaObjects, exposure_time=None):
return lsst.pipe.base.Struct(nUpdatedDiaObjects=2, nUnassociatedDiaObjects=3,
matchedDiaSources=MagicMock(spec=pd.DataFrame()),
unAssocDiaSources=MagicMock(spec=pd.DataFrame()))
diff --git a/tests/test_trailedSourceFilter.py b/tests/test_trailedSourceFilter.py
new file mode 100644
index 00000000..62b6e7e0
--- /dev/null
+++ b/tests/test_trailedSourceFilter.py
@@ -0,0 +1,117 @@
+# This file is part of ap_association.
+#
+# Developed for the LSST Data Management System.
+# This product includes software developed by the LSST Project
+# (https://www.lsst.org).
+# See the COPYRIGHT file at the top-level directory of this distribution
+# for details of code ownership.
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program. If not, see .
+
+import unittest
+from lsst.ap.association import TrailedSourceFilterTask
+import numpy as np
+import pandas as pd
+import lsst.utils.tests
+
+
+class TestTrailedSourceFilterTask(unittest.TestCase):
+
+ def setUp(self):
+ """Create sets of diaSources.
+
+ The trail lengths of the dia sources are 0, 5.5, 11, 16.5, 21.5
+ arcseconds.
+ """
+ rng = np.random.default_rng(1234)
+ scatter = 0.1 / 3600
+ self.nSources = 5
+ self.diaSources = pd.DataFrame(data=[
+ {"ra": 0.04*idx + scatter*rng.uniform(-1, 1),
+ "dec": 0.04*idx + scatter*rng.uniform(-1, 1),
+ "diaSourceId": idx, "diaObjectId": 0, "trailLength": 5.5*idx}
+ for idx in range(self.nSources)])
+ self.exposure_time = 30.0
+
+ def test_run(self):
+ """Run trailedSourceFilterTask with the default max distance.
+
+ With the default settings and an exposure of 30 seconds, the max trail
+ length is 12.5 arcseconds. Two out of five of the diaSources will be
+ filtered out of the final results and put into results.trailedSources.
+ """
+ trailedSourceFilterTask = TrailedSourceFilterTask()
+ results = trailedSourceFilterTask.run(self.diaSources, self.exposure_time)
+
+ self.assertEqual(len(results.diaSources), 3)
+ np.testing.assert_array_equal(results.diaSources['diaSourceId'].values, [0, 1, 2])
+ np.testing.assert_array_equal(results.trailedDiaSources['diaSourceId'].values, [3, 4])
+
+ def test_run_short_max_trail(self):
+ """Run trailedSourceFilterTask with aggressive trail length cutoff
+
+ With a max_trail_length config of 0.01 arcseconds/second and an
+ exposure of 30 seconds,the max trail length is 0.3 arcseconds. Only the
+ source with a trail of 0 stays in the catalog and the rest are filtered
+ out and put into results.trailedSources.
+ """
+ config = TrailedSourceFilterTask.ConfigClass()
+ config.max_trail_length = 0.01
+ trailedSourceFilterTask = TrailedSourceFilterTask(config=config)
+ results = trailedSourceFilterTask.run(self.diaSources, self.exposure_time)
+
+ self.assertEqual(len(results.diaSources), 1)
+ np.testing.assert_array_equal(results.diaSources['diaSourceId'].values, [0])
+ np.testing.assert_array_equal(results.trailedDiaSources['diaSourceId'].values, [1, 2, 3, 4])
+
+ def test_run_no_trails(self):
+ """Run trailedSourceFilterTask with a long trail length so that
+ every source in the catalog is in the final diaSource catalog.
+
+ With a max_trail_length config of 10 arcseconds/second and an
+ exposure of 30 seconds,the max trail length is 300 arcseconds. All
+ sources in the initial catalog should be in the final diaSource
+ catalog.
+ """
+ config = TrailedSourceFilterTask.ConfigClass()
+ config.max_trail_length = 10.00
+ trailedSourceFilterTask = TrailedSourceFilterTask(config=config)
+ results = trailedSourceFilterTask.run(self.diaSources, self.exposure_time)
+
+ self.assertEqual(len(results.diaSources), 5)
+ self.assertEqual(len(results.trailedDiaSources), 0)
+ np.testing.assert_array_equal(results.diaSources["diaSourceId"].values, [0, 1, 2, 3, 4])
+ np.testing.assert_array_equal(results.trailedDiaSources["diaSourceId"].values, [])
+
+ def test_check_dia_source_trail(self):
+ """Test the source trail mask filter.
+
+ Test that the mask filter returns the expected mask array.
+ """
+ trailedSourceFilterTask = TrailedSourceFilterTask()
+ mask = trailedSourceFilterTask._check_dia_source_trail(self.diaSources, self.exposure_time)
+ np.testing.assert_array_equal(mask, [False, False, False, True, True])
+
+
+class MemoryTester(lsst.utils.tests.MemoryTestCase):
+ pass
+
+
+def setup_module(module):
+ lsst.utils.tests.init()
+
+
+if __name__ == "__main__":
+ lsst.utils.tests.init()
+ unittest.main()