Skip to content

Commit

Permalink
Merge pull request BindsNET#406 from BindsNET/hananel
Browse files Browse the repository at this point in the history
Move reward modulation inputs to GPU
  • Loading branch information
Hananel-Hazan authored Aug 16, 2020
2 parents 50aa51d + 735ccbd commit ffe74bb
Showing 1 changed file with 10 additions and 3 deletions.
13 changes: 10 additions & 3 deletions bindsnet/learning/learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
)
from ..utils import im2col_indices


class LearningRule(ABC):
# language=rst
"""
Expand Down Expand Up @@ -547,11 +548,17 @@ def _connection_update(self, **kwargs) -> None:

# Initialize eligibility, P^+, and P^-.
if not hasattr(self, "p_plus"):
self.p_plus = torch.zeros(batch_size, *self.source.shape)
self.p_plus = torch.zeros(
batch_size, *self.source.shape, device=self.source.s.device
)
if not hasattr(self, "p_minus"):
self.p_minus = torch.zeros(batch_size, *self.target.shape)
self.p_minus = torch.zeros(
batch_size, *self.target.shape, device=self.target.s.device
)
if not hasattr(self, "eligibility"):
self.eligibility = torch.zeros(batch_size, *self.connection.w.shape)
self.eligibility = torch.zeros(
batch_size, *self.connection.w.shape, device=self.connection.w.device
)

# Reshape pre- and post-synaptic spikes.
source_s = self.source.s.view(batch_size, -1).float()
Expand Down

0 comments on commit ffe74bb

Please sign in to comment.