Skip to content

Commit 8abac4f

Browse files
author
Alex Lapin
committed
Add SamplesBatch class
1 parent 9bfa1fd commit 8abac4f

File tree

6 files changed

+154
-0
lines changed

6 files changed

+154
-0
lines changed

selene_sdk/predict/tests/__init__.py

Whitespace-only changes.

selene_sdk/samplers/samples_batch.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,106 @@
1+
import numpy as np
2+
import torch
3+
4+
5+
class SamplesBatch:
    """
    Container for one batch of NN inputs and (optionally) targets.

    Values are stored as numpy.ndarrays; `torch_inputs_and_targets()` converts
    them to torch.Tensors.

    Inputs are kept in a dict keyed by name, which can be used if you are
    providing more than just a `sequence_batch` to the NN. The sequence batch
    is always stored under the key ``"sequence_batch"``.

    NOTE: If you store just a sequence as an input to the model, then
    `inputs()` and `torch_inputs_and_targets()` will return only the batch of
    sequences rather than a dict.

    """

    _SEQUENCE_LABEL = "sequence_batch"

    def __init__(
        self,
        sequence_batch: np.ndarray,
        other_input_batches=None,
        target_batch: np.ndarray = None,
    ) -> None:
        """
        Parameters
        ----------
        sequence_batch : numpy.ndarray
            Batch of sequences, expected by `torch_inputs_and_targets()` to
            have a shape of [batch_size, sequence_length, alphabet_size].
        other_input_batches : dict[str, numpy.ndarray] or None, optional
            Additional named inputs. The mapping is shallow-copied, so the
            caller's dict is never modified. An entry stored under the key
            ``"sequence_batch"`` is overwritten by `sequence_batch`.
        target_batch : numpy.ndarray or None, optional
            Targets for the batch, if any.

        """
        # `None` (rather than a shared `dict()` default) avoids the mutable
        # default argument pitfall; `.copy()` keeps the caller's mapping type
        # and protects it from the insertion below.
        if other_input_batches is None:
            self._input_batches = {}
        else:
            self._input_batches = other_input_batches.copy()
        self._input_batches[self._SEQUENCE_LABEL] = sequence_batch
        self._target_batch = target_batch

    def sequence_batch(self) -> np.ndarray:
        """Returns the stored sequence batch (not converted to torch) with a
        shape of [batch_size, sequence_length, alphabet_size].
        """
        return self._input_batches[self._SEQUENCE_LABEL]

    def inputs(self):
        """Based on the size of inputs dictionary, returns either just the
        sequence or the whole dictionary.

        Returns
        -------
        numpy.ndarray or dict[str, numpy.ndarray]
            numpy.ndarray is returned when inputs contain just the sequence
            batch. dict[str, numpy.ndarray] is returned when there are
            multiple inputs; this is the internal dictionary itself, not a
            copy.

        NOTE: Sequence batch has a shape of
        [batch_size, sequence_length, alphabet_size].
        """
        if len(self._input_batches) == 1:
            return self.sequence_batch()

        return self._input_batches

    def targets(self):
        """Returns target batch if it is present.

        Returns
        -------
        numpy.ndarray or None
            The stored targets, or None when no targets were provided.

        """
        return self._target_batch

    def torch_inputs_and_targets(self, use_cuda: bool):
        """
        Returns inputs and targets in torch.Tensor format.

        Based on the size of inputs dictionary, returns either just the
        sequence or the whole dictionary.

        Parameters
        ----------
        use_cuda : bool
            When True, every returned tensor is moved with `.cuda()`.

        Returns
        -------
        inputs, targets :\
        tuple(torch.Tensor or dict[str, torch.Tensor], torch.Tensor or None)
            For `inputs`:
            torch.Tensor is returned when inputs contain just the sequence
            batch. dict[str, torch.Tensor] is returned when there are
            multiple inputs.
            `targets` is None when no target batch was provided.

        NOTE: Returned sequence batch has a shape of
        [batch_size, alphabet_size, sequence_length].

        """
        all_inputs = {
            key: self._to_tensor(value, use_cuda)
            for key, value in self._input_batches.items()
        }

        # Transpose the sequences to satisfy NN convolution input format
        # (which is [batch_size, channels_size, sequence_length]).
        all_inputs[self._SEQUENCE_LABEL] = all_inputs[
            self._SEQUENCE_LABEL
        ].transpose(1, 2)

        inputs = (
            all_inputs if len(all_inputs) > 1 else all_inputs[self._SEQUENCE_LABEL]
        )

        targets = None
        if self._target_batch is not None:
            targets = self._to_tensor(self._target_batch, use_cuda)

        return inputs, targets

    @staticmethod
    def _to_tensor(array, use_cuda: bool):
        """Converts `array` with `torch.Tensor(...)`, moving it with `.cuda()`
        when `use_cuda` is set."""
        tensor = torch.Tensor(array)
        return tensor.cuda() if use_cuda else tensor

selene_sdk/samplers/tests/__init__.py

Whitespace-only changes.
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import unittest
2+
3+
import numpy as np
4+
import torch
5+
from selene_sdk.samplers.samples_batch import SamplesBatch
6+
7+
8+
class TestSamplesBatch(unittest.TestCase):
    """Checks SamplesBatch accessors and its numpy-to-torch conversion."""

    def test_single_input(self):
        """A lone sequence input comes back as a bare array / tensor."""
        batch = SamplesBatch(np.ones((6, 200, 4)), target_batch=np.ones(20))

        numpy_inputs = batch.inputs()
        self.assertIsInstance(numpy_inputs, np.ndarray)
        self.assertSequenceEqual(numpy_inputs.shape, (6, 200, 4))

        tensor_inputs, _ = batch.torch_inputs_and_targets(use_cuda=False)
        self.assertIsInstance(tensor_inputs, torch.Tensor)
        # The torch view has the last two axes swapped.
        self.assertSequenceEqual(tensor_inputs.shape, (6, 4, 200))

    def test_multiple_inputs(self):
        """With an extra named input, both accessors return a two-entry dict."""
        batch = SamplesBatch(
            np.ones((6, 200, 4)),
            other_input_batches={"something": np.ones(10)},
            target_batch=np.ones(20),
        )

        numpy_inputs = batch.inputs()
        self.assertIsInstance(numpy_inputs, dict)
        self.assertEqual(len(numpy_inputs), 2)
        self.assertSequenceEqual(numpy_inputs["sequence_batch"].shape, (6, 200, 4))

        tensor_inputs, _ = batch.torch_inputs_and_targets(use_cuda=False)
        self.assertIsInstance(tensor_inputs, dict)
        self.assertEqual(len(tensor_inputs), 2)
        self.assertSequenceEqual(tensor_inputs["sequence_batch"].shape, (6, 4, 200))

    def test_has_target(self):
        """Targets are exposed as an ndarray and converted to a tensor."""
        batch = SamplesBatch(np.ones((6, 200, 4)), target_batch=np.ones(20))
        numpy_targets = batch.targets()
        self.assertIsInstance(numpy_targets, np.ndarray)
        _, tensor_targets = batch.torch_inputs_and_targets(use_cuda=False)
        self.assertIsInstance(tensor_targets, torch.Tensor)

    def test_no_target(self):
        """Without a target batch, both accessors report None."""
        batch = SamplesBatch(np.ones((6, 200, 4)))
        self.assertIsNone(batch.targets())
        _, tensor_targets = batch.torch_inputs_and_targets(use_cuda=False)
        self.assertIsNone(tensor_targets)

selene_sdk/sequences/tests/__init__.py

Whitespace-only changes.

selene_sdk/targets/tests/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)