From 735ccbdb8fed0be009c42c563021208e4c6cabb3 Mon Sep 17 00:00:00 2001 From: Hananel Hazan Date: Fri, 14 Aug 2020 11:49:53 -0400 Subject: [PATCH] Move reward modulation to GPU --- bindsnet/learning/learning.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/bindsnet/learning/learning.py b/bindsnet/learning/learning.py index 9c582f78..63115165 100644 --- a/bindsnet/learning/learning.py +++ b/bindsnet/learning/learning.py @@ -13,6 +13,7 @@ ) from ..utils import im2col_indices + class LearningRule(ABC): # language=rst """ @@ -547,11 +548,17 @@ def _connection_update(self, **kwargs) -> None: # Initialize eligibility, P^+, and P^-. if not hasattr(self, "p_plus"): - self.p_plus = torch.zeros(batch_size, *self.source.shape) + self.p_plus = torch.zeros( + batch_size, *self.source.shape, device=self.source.s.device + ) if not hasattr(self, "p_minus"): - self.p_minus = torch.zeros(batch_size, *self.target.shape) + self.p_minus = torch.zeros( + batch_size, *self.target.shape, device=self.target.s.device + ) if not hasattr(self, "eligibility"): - self.eligibility = torch.zeros(batch_size, *self.connection.w.shape) + self.eligibility = torch.zeros( + batch_size, *self.connection.w.shape, device=self.connection.w.device + ) # Reshape pre- and post-synaptic spikes. source_s = self.source.s.view(batch_size, -1).float()