Skip to content

Commit

Permalink
tests for generate_transforms (closes #162)
Browse files Browse the repository at this point in the history
  • Loading branch information
whitews committed May 26, 2024
1 parent 54130c8 commit b424801
Showing 1 changed file with 137 additions and 1 deletion.
138 changes: 137 additions & 1 deletion tests/transform_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@
import numpy as np
import warnings

from flowkit import Sample, transforms
from flowkit import Sample, transforms, generate_transforms

data1_fcs_path = 'data/gate_ref/data1.fcs'
data1_sample = Sample(data1_fcs_path)
data1_raw_events = data1_sample.get_events(source='raw')

# Sample with null channel
null_chan_sample = Sample(data1_fcs_path, null_channel_list=['FL1-H'])

test_data_range1 = np.linspace(0.0, 10.0, 101)


Expand Down Expand Up @@ -190,3 +193,136 @@ def test_inverse_wsp_biex_transform():

np.testing.assert_array_almost_equal(test_data_range1, x, decimal=10)

def test_generate_transforms_defaults(self):
xform_lut = generate_transforms(data1_sample)

self.assertEqual(len(xform_lut), len(data1_sample.pnn_labels))

# pick a fluoro channel label and check Transform type
fluoro_label = data1_sample.pnn_labels[data1_sample.fluoro_indices[0]]
self.assertIsInstance(xform_lut[fluoro_label], transforms.LogicleTransform)

# verify time use max time
time_max = data1_sample.get_channel_events(data1_sample.time_index, source='raw').max()
self.assertEqual(xform_lut['Time'].param_t, time_max)

def test_generate_transforms_default_asinh(self):
xform_lut = generate_transforms(
data1_sample,
fluoro_xform_class=transforms.AsinhTransform
)

self.assertEqual(len(xform_lut), len(data1_sample.pnn_labels))

# pick a fluoro channel label and check Transform type
fluoro_label = data1_sample.pnn_labels[data1_sample.fluoro_indices[0]]
self.assertIsInstance(xform_lut[fluoro_label], transforms.AsinhTransform)

def test_generate_transforms_default_hyperlog(self):
xform_lut = generate_transforms(
data1_sample,
fluoro_xform_class=transforms.HyperlogTransform
)

self.assertEqual(len(xform_lut), len(data1_sample.pnn_labels))

# pick a fluoro channel label and check Transform type
fluoro_label = data1_sample.pnn_labels[data1_sample.fluoro_indices[0]]
self.assertIsInstance(xform_lut[fluoro_label], transforms.HyperlogTransform)

def test_generate_transforms_default_log(self):
xform_lut = generate_transforms(
data1_sample,
fluoro_xform_class=transforms.LogTransform
)

self.assertEqual(len(xform_lut), len(data1_sample.pnn_labels))

# pick a fluoro channel label and check Transform type
fluoro_label = data1_sample.pnn_labels[data1_sample.fluoro_indices[0]]
self.assertIsInstance(xform_lut[fluoro_label], transforms.LogTransform)

def test_generate_transforms_default_wsp_biex(self):
xform_lut = generate_transforms(
data1_sample,
fluoro_xform_class=transforms.WSPBiexTransform
)

self.assertEqual(len(xform_lut), len(data1_sample.pnn_labels))

# pick a fluoro channel label and check Transform type
fluoro_label = data1_sample.pnn_labels[data1_sample.fluoro_indices[0]]
self.assertIsInstance(xform_lut[fluoro_label], transforms.WSPBiexTransform)

def test_generate_transforms_default_null_channel(self):
# null channel sample has 'FL1-H' nullified
xform_lut = generate_transforms(null_chan_sample)

self.assertEqual(len(xform_lut), len(data1_sample.pnn_labels) - 1)

# ensure null channel label is missing from xform LUT
null_label = 'FL1-H'
self.assertNotIn(null_label, xform_lut)

def test_generate_transforms_transform_not_supported(self):
self.assertRaises(
NotImplementedError,
generate_transforms,
data1_sample,
fluoro_xform_class=transforms.RatioTransform
)

def test_generate_transforms_transform_instance_scatter(self):
# specify a Transform instance for scatter channels
scatter_xform = transforms.LogTransform(param_t=262144, param_m=4.1)

xform_lut = generate_transforms(
data1_sample,
scatter_xform_class=scatter_xform
)

self.assertEqual(len(xform_lut), len(data1_sample.pnn_labels))

# pick a scatter channel label and check Transform type
scatter_label = data1_sample.pnn_labels[data1_sample.scatter_indices[0]]
scatter_xform = xform_lut[scatter_label]
self.assertIsInstance(scatter_xform, transforms.LogTransform)

self.assertEqual(scatter_xform.param_m, 4.1)

def test_generate_transforms_transform_instance_fluoro(self):
# specify a Transform instance for fluoro channels
fluoro_xform = transforms.AsinhTransform(param_t=262144, param_m=4.1, param_a=0.0)

xform_lut = generate_transforms(
data1_sample,
fluoro_xform_class=fluoro_xform
)

self.assertEqual(len(xform_lut), len(data1_sample.pnn_labels))

# pick a fluoro channel label and check Transform type
fluoro_label = data1_sample.pnn_labels[data1_sample.fluoro_indices[0]]
fluoro_xform = xform_lut[fluoro_label]
self.assertIsInstance(fluoro_xform, transforms.AsinhTransform)

self.assertEqual(fluoro_xform.param_m, 4.1)

def test_generate_transforms_transform_instance_time(self):
# specify a Transform instance for time channel
# make up a max time to check the custom instance is returned
time_max = 123
time_xform = transforms.LinearTransform(param_t=time_max, param_a=0.0)

xform_lut = generate_transforms(
data1_sample,
time_xform_class=time_xform
)

self.assertEqual(len(xform_lut), len(data1_sample.pnn_labels))

# check Transform type & params
time_xform = xform_lut['Time']
self.assertIsInstance(time_xform, transforms.LinearTransform)

self.assertEqual(time_xform.param_t, time_max)

0 comments on commit b424801

Please sign in to comment.