Skip to content

Commit d464fd7

Browse files
author
Alex Lapin
committed
Utilize SamplesBatch class
1 parent 8abac4f commit d464fd7

File tree

10 files changed

+133
-132
lines changed

10 files changed

+133
-132
lines changed

selene_sdk/evaluate_model.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -209,24 +209,17 @@ def evaluate(self):
209209
"""
210210
batch_losses = []
211211
all_predictions = []
212-
for (inputs, targets) in self._test_data:
213-
inputs = torch.Tensor(inputs)
214-
targets = torch.Tensor(targets[:, self._use_ixs])
212+
for samples_batch in self._test_data:
213+
inputs, targets = samples_batch.torch_inputs_and_targets(self.use_cuda)
214+
targets = targets[:, self._use_ixs]
215215

216-
if self.use_cuda:
217-
inputs = inputs.cuda()
218-
targets = targets.cuda()
219216
with torch.no_grad():
220-
inputs = Variable(inputs)
221-
targets = Variable(targets)
222-
223217
predictions = None
224218
if _is_lua_trained_model(self.model):
225219
predictions = self.model.forward(
226-
inputs.transpose(1, 2).contiguous().unsqueeze_(2))
220+
inputs.contiguous().unsqueeze_(2))
227221
else:
228-
predictions = self.model.forward(
229-
inputs.transpose(1, 2))
222+
predictions = self.model.forward(inputs)
230223
predictions = predictions[:, self._use_ixs]
231224
loss = self.criterion(predictions, targets)
232225

selene_sdk/samplers/file_samplers/bed_file_sampler.py

Lines changed: 24 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""
22
This module provides the BedFileSampler class.
33
"""
4+
from selene_sdk.samplers.samples_batch import SamplesBatch
45
import numpy as np
56

67
from .file_sampler import FileSampler
@@ -96,8 +97,8 @@ def sample(self, batch_size=1):
9697
9798
Returns
9899
-------
99-
sequences, targets : tuple(numpy.ndarray, numpy.ndarray)
100-
A tuple containing the numeric representation of the
100+
SamplesBatch
101+
A batch containing the numeric representation of the
101102
sequence examples and their corresponding labels. The
102103
shape of `sequences` will be
103104
:math:`B \\times L \\times N`, where :math:`B` is
@@ -163,8 +164,8 @@ def sample(self, batch_size=1):
163164
sequences = np.array(sequences)
164165
if self.targets_avail:
165166
targets = np.array(targets)
166-
return (sequences, targets)
167-
return sequences,
167+
return SamplesBatch(sequences, target_batch=targets)
168+
return SamplesBatch(sequences)
168169

169170
def get_data(self, batch_size, n_samples=None):
170171
"""
@@ -188,18 +189,21 @@ def get_data(self, batch_size, n_samples=None):
188189
and :math:`N` is the size of the sequence type's alphabet.
189190
190191
"""
192+
# TODO: Should this method return a collection of samples_batch.inputs()?
193+
191194
if not n_samples:
192195
n_samples = self.n_samples
193196
sequences = []
194197

195198
count = batch_size
196199
while count < n_samples:
197-
seqs, = self.sample(batch_size=batch_size)
198-
sequences.append(seqs)
200+
samples_batch = self.sample(batch_size=batch_size)
201+
sequences.append(samples_batch.sequence_batch())
199202
count += batch_size
200203
remainder = batch_size - (count - n_samples)
201-
seqs, = self.sample(batch_size=remainder)
202-
sequences.append(seqs)
204+
samples_batch = self.sample(batch_size=remainder)
205+
sequences.append(samples_batch.sequence_batch())
206+
203207
return sequences
204208

205209
def get_data_and_targets(self, batch_size, n_samples=None):
@@ -216,11 +220,11 @@ def get_data_and_targets(self, batch_size, n_samples=None):
216220
217221
Returns
218222
-------
219-
sequences_and_targets, targets_matrix : \
220-
tuple(list(tuple(numpy.ndarray, numpy.ndarray)), numpy.ndarray)
221-
Tuple containing the list of sequence-target pairs, as well
223+
batches, targets_matrix : \
224+
tuple(list(SamplesBatch), numpy.ndarray)
225+
Tuple containing the list of batches, as well
222226
as a single matrix with all targets in the same order.
223-
Note that `sequences_and_targets`'s sequence elements are of
227+
Note that the sequence elements of `batches` are of
224228
the shape :math:`B \\times L \\times N` and its target
225229
elements are of the shape :math:`B \\times F`, where
226230
:math:`B` is `batch_size`, :math:`L` is the sequence length,
@@ -236,18 +240,18 @@ def get_data_and_targets(self, batch_size, n_samples=None):
236240
"Please use `get_data` instead.")
237241
if not n_samples:
238242
n_samples = self.n_samples
239-
sequences_and_targets = []
243+
batches = []
240244
targets_mat = []
241245

242246
count = batch_size
243247
while count < n_samples:
244-
seqs, tgts = self.sample(batch_size=batch_size)
245-
sequences_and_targets.append((seqs, tgts))
246-
targets_mat.append(tgts)
248+
samples_batch = self.sample(batch_size=batch_size)
249+
batches.append(samples_batch)
250+
targets_mat.append(samples_batch.targets())
247251
count += batch_size
248252
remainder = batch_size - (count - n_samples)
249-
seqs, tgts = self.sample(batch_size=remainder)
250-
sequences_and_targets.append((seqs, tgts))
251-
targets_mat.append(tgts)
253+
samples_batch = self.sample(batch_size=remainder)
254+
batches.append(samples_batch)
255+
targets_mat.append(samples_batch.targets())
252256
targets_mat = np.vstack(targets_mat).astype(int)
253-
return sequences_and_targets, targets_mat
257+
return batches, targets_mat

selene_sdk/samplers/file_samplers/file_sampler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from abc import ABCMeta
99
from abc import abstractmethod
1010

11+
from selene_sdk.samplers.samples_batch import SamplesBatch
12+
1113

1214
class FileSampler(metaclass=ABCMeta):
1315
"""
@@ -26,7 +28,7 @@ def __init__(self):
2628
"""
2729

2830
@abstractmethod
29-
def sample(self, batch_size=1):
31+
def sample(self, batch_size=1) -> SamplesBatch:
3032
"""
3133
Fetches a mini-batch of the data from the sampler.
3234

selene_sdk/samplers/file_samplers/mat_file_sampler.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import h5py
66
import numpy as np
77
import scipy.io
8+
from selene_sdk.samplers.samples_batch import SamplesBatch
89

910
from .file_sampler import FileSampler
1011

@@ -126,8 +127,8 @@ def sample(self, batch_size=1):
126127
127128
Returns
128129
-------
129-
sequences, targets : tuple(numpy.ndarray, numpy.ndarray)
130-
A tuple containing the numeric representation of the
130+
SamplesBatch
131+
A batch containing the numeric representation of the
131132
sequence examples and their corresponding labels. The
132133
shape of `sequences` will be
133134
:math:`B \\times L \\times N`, where :math:`B` is
@@ -166,8 +167,8 @@ def sample(self, batch_size=1):
166167
targets = self._sample_tgts[:, use_indices].astype(float)
167168
targets = np.transpose(
168169
targets, (1, 0))
169-
return (sequences, targets)
170-
return sequences,
170+
return SamplesBatch(sequences, target_batch=targets)
171+
return SamplesBatch(sequences)
171172

172173
def get_data(self, batch_size, n_samples=None):
173174
"""
@@ -190,18 +191,20 @@ def get_data(self, batch_size, n_samples=None):
190191
is `batch_size`, :math:`L` is the sequence length,
191192
and :math:`N` is the size of the sequence type's alphabet.
192193
"""
194+
# TODO: Should this method return a collection of samples_batch.inputs()?
195+
193196
if not n_samples:
194197
n_samples = self.n_samples
195198
sequences = []
196199

197200
count = batch_size
198201
while count < n_samples:
199-
seqs, = self.sample(batch_size=batch_size)
200-
sequences.append(seqs)
202+
samples_batch = self.sample(batch_size=batch_size)
203+
sequences.append(samples_batch.sequence_batch())
201204
count += batch_size
202205
remainder = batch_size - (count - n_samples)
203-
seqs, = self.sample(batch_size=remainder)
204-
sequences.append(seqs)
206+
samples_batch = self.sample(batch_size=remainder)
207+
sequences.append(samples_batch.sequence_batch())
205208
return sequences
206209

207210
def get_data_and_targets(self, batch_size, n_samples=None):
@@ -218,11 +221,11 @@ def get_data_and_targets(self, batch_size, n_samples=None):
218221
219222
Returns
220223
-------
221-
sequences_and_targets, targets_matrix : \
222-
tuple(list(tuple(numpy.ndarray, numpy.ndarray)), numpy.ndarray)
223-
Tuple containing the list of sequence-target pairs, as well
224+
batches, targets_matrix : \
225+
tuple(list(SamplesBatch), numpy.ndarray)
226+
Tuple containing the list of batches, as well
224227
as a single matrix with all targets in the same order.
225-
Note that `sequences_and_targets`'s sequence elements are of
228+
Note that the sequence elements of `batches` are of
226229
the shape :math:`B \\times L \\times N` and its target
227230
elements are of the shape :math:`B \\times F`, where
228231
:math:`B` is `batch_size`, :math:`L` is the sequence length,
@@ -237,19 +240,19 @@ def get_data_and_targets(self, batch_size, n_samples=None):
237240
"initialization. Please use `get_data` instead.")
238241
if not n_samples:
239242
n_samples = self.n_samples
240-
sequences_and_targets = []
243+
batches = []
241244
targets_mat = []
242245

243246
count = batch_size
244247
while count < n_samples:
245-
seqs, tgts = self.sample(batch_size=batch_size)
246-
sequences_and_targets.append((seqs, tgts))
247-
targets_mat.append(tgts)
248+
samples_batch = self.sample(batch_size=batch_size)
249+
batches.append(samples_batch)
250+
targets_mat.append(samples_batch.targets())
248251
count += batch_size
249252
remainder = batch_size - (count - n_samples)
250-
seqs, tgts = self.sample(batch_size=remainder)
251-
sequences_and_targets.append((seqs, tgts))
252-
targets_mat.append(tgts)
253+
samples_batch = self.sample(batch_size=remainder)
254+
batches.append(samples_batch)
255+
targets_mat.append(samples_batch.targets())
253256
# TODO: should not assume a fixed dtype for targets (currently cast to float)
254257
targets_mat = np.vstack(targets_mat).astype(float)
255-
return sequences_and_targets, targets_mat
258+
return batches, targets_mat

selene_sdk/samplers/intervals_sampler.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22
This module provides the `IntervalsSampler` class and supporting
33
methods.
44
"""
5-
from collections import namedtuple
65
import logging
76
import random
7+
from collections import namedtuple
88

99
import numpy as np
1010

11-
from .online_sampler import OnlineSampler
11+
from selene_sdk.samplers.samples_batch import SamplesBatch
1212
from ..utils import get_indices_and_probabilities
13+
from .online_sampler import OnlineSampler
1314

1415
logger = logging.getLogger(__name__)
1516

@@ -388,8 +389,8 @@ def sample(self, batch_size=1):
388389
389390
Returns
390391
-------
391-
sequences, targets : tuple(numpy.ndarray, numpy.ndarray)
392-
A tuple containing the numeric representation of the
392+
SamplesBatch
393+
A batch containing the numeric representation of the
393394
sequence examples and their corresponding labels. The
394395
shape of `sequences` will be
395396
:math:`B \\times L \\times N`, where :math:`B` is
@@ -426,4 +427,4 @@ def sample(self, batch_size=1):
426427
sequences[n_samples_drawn, :, :] = seq
427428
targets[n_samples_drawn, :] = seq_targets
428429
n_samples_drawn += 1
429-
return (sequences, targets)
430+
return SamplesBatch(sequences, target_batch=targets)

selene_sdk/samplers/multi_file_sampler.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,11 +186,11 @@ def get_test_set(self, batch_size, n_samples=None):
186186
187187
Returns
188188
-------
189-
sequences_and_targets, targets_matrix : \
190-
tuple(list(tuple(numpy.ndarray, numpy.ndarray)), numpy.ndarray)
191-
Tuple containing the list of sequence-target pairs, as well
189+
batches, targets_matrix : \
190+
tuple(list(SamplesBatch), numpy.ndarray)
191+
Tuple containing the list of batches, as well
192192
as a single matrix with all targets in the same order.
193-
Note that `sequences_and_targets`'s sequence elements are of
193+
Note that the sequence elements of `batches` are of
194194
the shape :math:`B \\times L \\times N` and its target
195195
elements are of the shape :math:`B \\times F`, where
196196
:math:`B` is `batch_size`, :math:`L` is the sequence length,

0 commit comments

Comments
 (0)