diff --git a/tests/transform_tests.py b/tests/transform_tests.py index 60c33e4f..d7ede562 100644 --- a/tests/transform_tests.py +++ b/tests/transform_tests.py @@ -5,12 +5,15 @@ import numpy as np import warnings -from flowkit import Sample, transforms +from flowkit import Sample, transforms, generate_transforms data1_fcs_path = 'data/gate_ref/data1.fcs' data1_sample = Sample(data1_fcs_path) data1_raw_events = data1_sample.get_events(source='raw') +# Sample with null channel +null_chan_sample = Sample(data1_fcs_path, null_channel_list=['FL1-H']) + test_data_range1 = np.linspace(0.0, 10.0, 101) @@ -190,3 +193,136 @@ def test_inverse_wsp_biex_transform(): np.testing.assert_array_almost_equal(test_data_range1, x, decimal=10) + def test_generate_transforms_defaults(self): + xform_lut = generate_transforms(data1_sample) + + self.assertEqual(len(xform_lut), len(data1_sample.pnn_labels)) + + # pick a fluoro channel label and check Transform type + fluoro_label = data1_sample.pnn_labels[data1_sample.fluoro_indices[0]] + self.assertIsInstance(xform_lut[fluoro_label], transforms.LogicleTransform) + + # verify time use max time + time_max = data1_sample.get_channel_events(data1_sample.time_index, source='raw').max() + self.assertEqual(xform_lut['Time'].param_t, time_max) + + def test_generate_transforms_default_asinh(self): + xform_lut = generate_transforms( + data1_sample, + fluoro_xform_class=transforms.AsinhTransform + ) + + self.assertEqual(len(xform_lut), len(data1_sample.pnn_labels)) + + # pick a fluoro channel label and check Transform type + fluoro_label = data1_sample.pnn_labels[data1_sample.fluoro_indices[0]] + self.assertIsInstance(xform_lut[fluoro_label], transforms.AsinhTransform) + + def test_generate_transforms_default_hyperlog(self): + xform_lut = generate_transforms( + data1_sample, + fluoro_xform_class=transforms.HyperlogTransform + ) + + self.assertEqual(len(xform_lut), len(data1_sample.pnn_labels)) + + # pick a fluoro channel label and check Transform type + fluoro_label = data1_sample.pnn_labels[data1_sample.fluoro_indices[0]] + self.assertIsInstance(xform_lut[fluoro_label], transforms.HyperlogTransform) + + def test_generate_transforms_default_log(self): + xform_lut = generate_transforms( + data1_sample, + fluoro_xform_class=transforms.LogTransform + ) + + self.assertEqual(len(xform_lut), len(data1_sample.pnn_labels)) + + # pick a fluoro channel label and check Transform type + fluoro_label = data1_sample.pnn_labels[data1_sample.fluoro_indices[0]] + self.assertIsInstance(xform_lut[fluoro_label], transforms.LogTransform) + + def test_generate_transforms_default_wsp_biex(self): + xform_lut = generate_transforms( + data1_sample, + fluoro_xform_class=transforms.WSPBiexTransform + ) + + self.assertEqual(len(xform_lut), len(data1_sample.pnn_labels)) + + # pick a fluoro channel label and check Transform type + fluoro_label = data1_sample.pnn_labels[data1_sample.fluoro_indices[0]] + self.assertIsInstance(xform_lut[fluoro_label], transforms.WSPBiexTransform) + + def test_generate_transforms_default_null_channel(self): + # null channel sample has 'FL1-H' nullified + xform_lut = generate_transforms(null_chan_sample) + + self.assertEqual(len(xform_lut), len(data1_sample.pnn_labels) - 1) + + # ensure null channel label is missing from xform LUT + null_label = 'FL1-H' + self.assertNotIn(null_label, xform_lut) + + def test_generate_transforms_transform_not_supported(self): + self.assertRaises( + NotImplementedError, + generate_transforms, + data1_sample, + fluoro_xform_class=transforms.RatioTransform + ) + + def test_generate_transforms_transform_instance_scatter(self): + # specify a Transform instance for scatter channels + scatter_xform = transforms.LogTransform(param_t=262144, param_m=4.1) + + xform_lut = generate_transforms( + data1_sample, + scatter_xform_class=scatter_xform + ) + + self.assertEqual(len(xform_lut), len(data1_sample.pnn_labels)) + + # pick a scatter channel label and check Transform type + scatter_label = data1_sample.pnn_labels[data1_sample.scatter_indices[0]] + scatter_xform = xform_lut[scatter_label] + self.assertIsInstance(scatter_xform, transforms.LogTransform) + + self.assertEqual(scatter_xform.param_m, 4.1) + + def test_generate_transforms_transform_instance_fluoro(self): + # specify a Transform instance for fluoro channels + fluoro_xform = transforms.AsinhTransform(param_t=262144, param_m=4.1, param_a=0.0) + + xform_lut = generate_transforms( + data1_sample, + fluoro_xform_class=fluoro_xform + ) + + self.assertEqual(len(xform_lut), len(data1_sample.pnn_labels)) + + # pick a fluoro channel label and check Transform type + fluoro_label = data1_sample.pnn_labels[data1_sample.fluoro_indices[0]] + fluoro_xform = xform_lut[fluoro_label] + self.assertIsInstance(fluoro_xform, transforms.AsinhTransform) + + self.assertEqual(fluoro_xform.param_m, 4.1) + + def test_generate_transforms_transform_instance_time(self): + # specify a Transform instance for time channel + # make up a max time to check the custom instance is returned + time_max = 123 + time_xform = transforms.LinearTransform(param_t=time_max, param_a=0.0) + + xform_lut = generate_transforms( + data1_sample, + time_xform_class=time_xform + ) + + self.assertEqual(len(xform_lut), len(data1_sample.pnn_labels)) + + # check Transform type & params + time_xform = xform_lut['Time'] + self.assertIsInstance(time_xform, transforms.LinearTransform) + + self.assertEqual(time_xform.param_t, time_max)