Skip to content

Commit

Permalink
Use plan_one_thread_per_row_gemv in Simulator
Browse files Browse the repository at this point in the history
Signed-off-by: Shaun Ren <[email protected]>
  • Loading branch information
shaunren committed Aug 1, 2016
1 parent 98e2fa8 commit 41e9e9d
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 3 deletions.
2 changes: 1 addition & 1 deletion nengo_ocl/clra_gemv.py
Original file line number Diff line number Diff line change
Expand Up @@ -1355,7 +1355,7 @@ def choose_plans(self):

class plan_one_thread_per_row_gemv(gemv_prog):
def choose_plans(self):
return [one_thread_per_row_impl(self, len(self.Y))]
return [one_thread_per_row_impl(self, range(len(self.Y)))]

class plan_pretuned_gemv(gemv_prog):
PLANS = (one_thread_per_row_impl, reduce_impl, many_dots_impl, block_impl)
Expand Down
4 changes: 2 additions & 2 deletions nengo_ocl/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from nengo_ocl.raggedarray import RaggedArray
from nengo_ocl.clraggedarray import CLRaggedArray, to_device
from nengo_ocl.clra_gemv import plan_block_gemv
from nengo_ocl.clra_gemv import plan_pretuned_gemv, plan_one_thread_per_row_gemv
from nengo_ocl.clra_nonlinearities import (
plan_timeupdate, plan_reset, plan_copy, plan_slicedcopy,
plan_direct, plan_lif, plan_lif_rate,
Expand Down Expand Up @@ -471,7 +471,7 @@ def _sig_gemv(self, ops, A_js_fn, X_js_fn, Y_fn, Y_in_fn=None,
if callable(beta):
beta = RaggedArray([sidx[beta(o)] for o in ops], dtype=np.float32)

rval = plan_block_gemv(
rval = plan_one_thread_per_row_gemv(
self.queue, alpha, all_data, A_js, all_data, X_js, beta, Y,
Y_in=Y_in, gamma=gamma, tag=tag)
return rval.plans
Expand Down

0 comments on commit 41e9e9d

Please sign in to comment.