From 0a0f548ab21045d3303564473eba39975972f539 Mon Sep 17 00:00:00 2001
From: Xuechen Li <12689993+lxuechen@users.noreply.github.com>
Date: Wed, 8 Jul 2020 15:04:05 -0400
Subject: [PATCH] Fix to device in brownian utils.

---
 tests/test_brownian_tree.py        | 38 +++++++++++++++++++++---------
 torchsde/brownian/brownian_path.py | 11 +++++----
 torchsde/brownian/brownian_tree.py | 20 ++++++++++------
 torchsde/brownian/utils.py         |  4 ++++
 4 files changed, 51 insertions(+), 22 deletions(-)

diff --git a/tests/test_brownian_tree.py b/tests/test_brownian_tree.py
index d941a55..7c5c239 100644
--- a/tests/test_brownian_tree.py
+++ b/tests/test_brownian_tree.py
@@ -30,37 +30,38 @@
 torch.set_default_dtype(torch.float64)
 
 D = 3
-BATCH_SIZE = 16384
+SMALL_BATCH_SIZE = 16
+LARGE_BATCH_SIZE = 16384
 REPS = 3
 ALPHA = 0.001
 
 
 class TestBrownianTree(TorchTestCase):
 
-    def _setUp(self, device=None):
+    def _setUp(self, batch_size, device=None):
         t0, t1 = torch.tensor([0., 1.]).to(device)
-        w0 = torch.zeros(BATCH_SIZE, D).to(device=device)
-        w1 = torch.randn(BATCH_SIZE, D).to(device=device)
+        w0 = torch.zeros(batch_size, D).to(device=device)
+        w1 = torch.randn(batch_size, D).to(device=device)
         t = torch.rand([]).to(device)
         self.t = t
         self.bm = BrownianTree(t0=t0, t1=t1, w0=w0, w1=w1, entropy=0)
 
     def test_basic_cpu(self):
-        self._setUp(device=torch.device('cpu'))
+        self._setUp(batch_size=SMALL_BATCH_SIZE, device=torch.device('cpu'))
         sample = self.bm(self.t)
-        self.assertEqual(sample.size(), (BATCH_SIZE, D))
+        self.assertEqual(sample.size(), (SMALL_BATCH_SIZE, D))
 
     def test_basic_gpu(self):
         if not torch.cuda.is_available():
             self.skipTest(reason='CUDA not available.')
 
-        self._setUp(device=torch.device('cuda'))
+        self._setUp(batch_size=SMALL_BATCH_SIZE, device=torch.device('cuda'))
         sample = self.bm(self.t)
-        self.assertEqual(sample.size(), (BATCH_SIZE, D))
+        self.assertEqual(sample.size(), (SMALL_BATCH_SIZE, D))
 
     def test_determinism(self):
-        self._setUp()
+        self._setUp(batch_size=SMALL_BATCH_SIZE)
         vals = [self.bm(self.t) for _ in range(REPS)]
         for val in vals[1:]:
             self.tensorAssertAllClose(val, vals[0])
@@ -73,8 +74,8 @@ def test_normality(self):
         for _ in range(REPS):
             w0_, w1_ = 0.0, npr.randn()
             # Use the same endpoint for the batch, so samples from same dist.
-            w0 = torch.tensor(w0_).repeat(BATCH_SIZE)
-            w1 = torch.tensor(w1_).repeat(BATCH_SIZE)
+            w0 = torch.tensor(w0_).repeat(LARGE_BATCH_SIZE)
+            w1 = torch.tensor(w1_).repeat(LARGE_BATCH_SIZE)
 
             bm = BrownianTree(t0=t0, t1=t1, w0=w0, w1=w1, pool_size=100, tol=1e-14)
             for _ in range(REPS):
@@ -89,6 +90,21 @@ def test_normality(self):
             _, pval = kstest(samples_, ref_dist.cdf)
             self.assertGreaterEqual(pval, ALPHA)
 
+    def test_to(self):
+        if not torch.cuda.is_available():
+            self.skipTest(reason='CUDA not available.')
+
+        self._setUp(batch_size=SMALL_BATCH_SIZE)
+        cache = self.bm.get_cache()
+        old = torch.cat(list(cache['ws_prev']) + list(cache['ws']) + list(cache['ws_post']), dim=0)
+
+        gpu = torch.device('cuda')
+        self.bm.to(gpu)
+        cache = self.bm.get_cache()
+        new = torch.cat(list(cache['ws_prev']) + list(cache['ws']) + list(cache['ws_post']), dim=0)
+        self.assertTrue(str(new.device).startswith('cuda'))
+        self.tensorAssertAllClose(old, new.cpu())
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/torchsde/brownian/brownian_path.py b/torchsde/brownian/brownian_path.py
index 2a49158..2df64cd 100644
--- a/torchsde/brownian/brownian_path.py
+++ b/torchsde/brownian/brownian_path.py
@@ -134,10 +134,7 @@ def __repr__(self):
         )
 
     def to(self, *args, **kwargs):
-        ws_new = blist.blist()
-        for w in self._ws:
-            ws_new.append(w.to(*args, **kwargs))
-        self._ws = ws_new
+        self._ws = utils.blist_to(self._ws, *args, **kwargs)
 
     @property
     def dtype(self):
@@ -153,3 +150,9 @@ def size(self):
 
     def __len__(self):
         return len(self._ts)
+
+    def get_cache(self):
+        return {
+            'ts': self._ts,
+            'ws': self._ws,
+        }
diff --git a/torchsde/brownian/brownian_tree.py b/torchsde/brownian/brownian_tree.py
index 7efabfa..387e21a 100644
--- a/torchsde/brownian/brownian_tree.py
+++ b/torchsde/brownian/brownian_tree.py
@@ -138,9 +138,9 @@ def __repr__(self):
         )
 
     def to(self, *args, **kwargs):
-        self._ws_prev = _list_to(self._ws_prev, *args, **kwargs)
-        self._ws_post = _list_to(self._ws_post, *args, **kwargs)
-        self._ws = _list_to(self._ws, *args, **kwargs)
+        self._ws_prev = utils.blist_to(self._ws_prev, *args, **kwargs)
+        self._ws_post = utils.blist_to(self._ws_post, *args, **kwargs)
+        self._ws = utils.blist_to(self._ws, *args, **kwargs)
 
     @property
     def dtype(self):
@@ -157,6 +157,16 @@ def size(self):
     def __len__(self):
         return len(self._ts) + len(self._ts_prev) + len(self._ts_post)
 
+    def get_cache(self):
+        return {
+            'ts_prev': self._ts_prev,
+            'ts': self._ts,
+            'ts_post': self._ts_post,
+            'ws_prev': self._ws_prev,
+            'ws': self._ws,
+            'ws_post': self._ws_post
+        }
+
 
 def _binary_search(t0, t1, w0, w1, t, parent, tol):
     seedv, seedl, seedr = parent.spawn(3)
@@ -211,7 +221,3 @@ def _create_cache(t0, t1, w0, w1, entropy, pool_size, k):
         seeds = new_seeds
 
     return ts, ws, seeds
-
-
-def _list_to(l, *args, **kwargs):
-    return [li.to(*args, **kwargs) for li in l]
diff --git a/torchsde/brownian/utils.py b/torchsde/brownian/utils.py
index 778717a..a84ea9f 100644
--- a/torchsde/brownian/utils.py
+++ b/torchsde/brownian/utils.py
@@ -94,3 +94,7 @@ def brownian_bridge(t0: float, t1: float, w0, w1, t: float, seed=None):
 
 def is_scalar(x):
     return isinstance(x, int) or isinstance(x, float) or (isinstance(x, torch.Tensor) and x.numel() == 1)
+
+
+def blist_to(l, *args, **kwargs):
+    return blist.blist([li.to(*args, **kwargs) for li in l])
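
As a reviewer aid (not part of the patch): a minimal sketch of how the pieces touched here
fit together, mirroring the new test_to test. The import path follows the file layout
above (torchsde/brownian/brownian_tree.py); names such as bm, cache, and ws are
illustrative only.

    import torch
    from torchsde.brownian.brownian_tree import BrownianTree

    # Build a small tree on CPU with the same constructor call as the tests.
    t0, t1 = torch.tensor([0., 1.])
    w0 = torch.zeros(16, 3)
    w1 = torch.randn(16, 3)
    bm = BrownianTree(t0=t0, t1=t1, w0=w0, w1=w1, entropy=0)

    # get_cache() (added by this patch) exposes the internal ts/ws containers.
    cache = bm.get_cache()
    ws = torch.cat(list(cache['ws_prev']) + list(cache['ws']) + list(cache['ws_post']), dim=0)
    print(ws.device)  # cpu

    # to() now routes through utils.blist_to, so the cached Brownian values move with
    # the object and the containers remain blists (the old _list_to returned plain lists).
    if torch.cuda.is_available():
        bm.to(torch.device('cuda'))
        cache = bm.get_cache()
        ws = torch.cat(list(cache['ws_prev']) + list(cache['ws']) + list(cache['ws_post']), dim=0)
        print(ws.device)  # cuda:0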