Skip to content

Commit

Permalink
Add clra_gemv one_thread_per_row_impl
Browse files Browse the repository at this point in the history
Signed-off-by: Shaun Ren <[email protected]>
  • Loading branch information
shaunren committed Jul 11, 2016
1 parent 7a1b166 commit e996eb2
Showing 1 changed file with 164 additions and 0 deletions.
164 changes: 164 additions & 0 deletions nengo_ocl/clra_gemv.py
Original file line number Diff line number Diff line change
Expand Up @@ -1126,6 +1126,166 @@ def block_impl(p, items):
return [plan, plan_reduce]


def one_thread_per_row_impl(p, num_items):
    """Build a plan in which each work-item computes one output row.

    Every work-item accumulates the dot product(s) belonging to a single
    row of a single item and then writes
    ``float_alpha * sum + beta * Y_in + gamma`` to ``Y``.  The ``beta``
    term uses ``float_beta`` when it is set and non-zero, otherwise the
    per-item ``cl_beta`` buffer when one is present, and is dropped
    otherwise.

    When all items share the same ``y_len`` and that length is below 64,
    the rows of all items are laid out consecutively in a 1-D global range
    of ``max_y_len * num_items`` work-items.  Otherwise a 2-D global range
    of ``(max_y_len, num_items)`` is used.

    Parameters
    ----------
    p
        The gemv program description (``plan_one_thread_per_row`` passes
        itself, a ``gemv_prog`` subclass).  Its attributes are also
        exposed to the kernel template.
    num_items : int
        Number of items to compute; items ``0 .. num_items - 1`` are
        covered by the returned plan.

    Returns
    -------
    Plan
        A plan that launches the generated ``gemv`` kernel.

    Raises
    ------
    NotImplementedError
        If any of ``clra_alpha``, ``clra_gamma``, ``clra_beta``,
        ``cl_alpha`` or ``cl_gamma`` is set, if any entry of
        ``p.A.stride0s`` differs from 1, or if ``p.A_js`` is None.
    AssertionError
        If ``num_items`` is not positive, or if ``float_alpha`` or
        ``float_gamma`` is None.
    """
    # This kernel only handles scalar (float) alpha/gamma; reject every
    # array-valued coefficient up front, naming the offending attribute.
    for attr in ('clra_alpha', 'clra_gamma', 'clra_beta',
                 'cl_alpha', 'cl_gamma'):
        if getattr(p, attr) is not None:
            raise NotImplementedError(
                "one_thread_per_row_impl does not support %s" % attr)
    # The kernel addresses row ``i`` as ``a_starts + i``, i.e. with a unit
    # stride along the row dimension.
    if not all(s == 1 for s in p.A.stride0s):
        raise NotImplementedError(
            "one_thread_per_row_impl requires every A.stride0s entry "
            "to be 1")

    assert num_items > 0, "num_items must be positive"
    items = range(num_items)

    assert p.float_alpha is not None, "float_alpha is required"
    assert p.float_gamma is not None, "float_gamma is required"

    if p.A_js is None:
        # -- easy probably, but not done
        raise NotImplementedError(
            "one_thread_per_row_impl requires A_js to be set")

    cl_gstructure, textconf = p.cl_geometry_and_textconf(items, stride=1)

    y_lens = [p.geometry[ii].y_len for ii in items]
    max_y_len = max(y_lens)
    max_n_dots = max(len(p.geometry[ii].dots) for ii in items)
    all_same_y_len = len(set(y_lens)) == 1

    # Compute in consecutive threads if max_y_len is small,
    # and all y_len are the same.
    # TODO autotune
    if all_same_y_len and max_y_len < 64:
        consecutive = True
        gsize = (max_y_len * num_items,)
    else:
        consecutive = False
        gsize = (max_y_len, num_items)

    textconf.update({
        'max_n_dots': max_n_dots,
        'max_y_len': max_y_len,
        'consecutive': consecutive,
        'all_same_y_len': all_same_y_len,
    })
    # Expose the program's own attributes (A, X, Y, float_alpha, ...) to
    # the template.  Applied last, so they win on any key collision.
    textconf.update(p.__dict__)

    text = """
__kernel void gemv(
const __global int* gstructure,
const __global ${A.cl_buf.ctype}* A_data,
const __global ${X.cl_buf.ctype}* X_data,
% if cl_beta is not None:
const __global ${cl_beta.ctype}* betas,
% endif
const __global ${Y_in.cl_buf.ctype}* Y_in_data,
__global ${Y.cl_buf.ctype}* Y_data)
{
% if consecutive:
const int global_i = get_global_id(0);
const int item_i = global_i / ${max_y_len}; // Item index
const int i = global_i - item_i * ${max_y_len}; // Row index within item
% else:
const int item_i = get_global_id(1); // Item index
const int i = get_global_id(0); // Row index within item
% endif
const __global int* lstructure =
gstructure + item_i * ${structure_vars_stride};
% if not all_same_y_len:
if (i >= ${y_len}) return;
% endif
${Y.cl_buf.ctype} sum = 0;
% if max_n_dots > 1:
for (int ii = 0; ii < ${n_dot_products}; ii++)
{
% else:
const int ii = 0;
% endif
const int a_s1 = ${a_s1};
const int n_i = ${N_i};
const __global ${A.cl_buf.ctype}* a = A_data + ${a_starts} + i;
const __global ${X.cl_buf.ctype}* x = X_data + ${x_starts};
for (int j=0;j<n_i;j++)
sum += a[a_s1*j] * x[j];
% if max_n_dots > 1:
}
% endif
% if float_gamma is not None:
const ${Y.cl_buf.ctype} gamma = ${float_gamma};
% else:
const ${Y.cl_buf.ctype} gamma = 0;
% endif
% if float_beta is not None and float_beta != 0 :
Y_data[${y_offset} + i] = ${float_alpha} * sum + ${float_beta} * Y_in_data[${y_in_starts} + i] + gamma;
% elif cl_beta is not None:
Y_data[${y_offset} + i] = ${float_alpha} * sum + betas[${bb}] * Y_in_data[${y_in_starts} + i] + gamma;
% else:
Y_data[${y_offset} + i] = ${float_alpha} * sum + gamma;
% endif
}
"""

    text = as_ascii(Template(text, output_encoding='ascii').render(**textconf))

    fn = cl.Program(p.queue.context, text).build().gemv

    # Argument order must match the kernel signature above; ``betas`` is
    # only part of the signature when ``cl_beta`` is present.
    full_args = [
        cl_gstructure,
        p.A.cl_buf,
        p.X.cl_buf,
    ] + ([p.cl_beta] if p.cl_beta is not None else []) + [
        p.Y_in.cl_buf,
        p.Y.cl_buf,
    ]

    fn.set_args(*[arr.data for arr in full_args])
    rval = Plan(p.queue, fn, gsize, None,
                name='clra_gemv.one_thread_per_row_impl',
                tag=p.tag,
                bw_per_call=bw_from_geometry(p.geometry, items),
                flops_per_call=flops_from_geometry(p.geometry, items),
                )
    rval.full_args = full_args  # prevent GC the args
    rval.description = p.geometry.summary(items)
    return rval



class plan_ref_gemv(gemv_prog):
    """Gemv program whose only plan is the one built by ``ref_impl``."""

    def choose_plans(self):
        """Return a one-element list: a ``ref_impl`` plan over all items.

        The items passed are the indices ``0 .. len(self.Y) - 1``.
        """
        every_item = range(len(self.Y))
        reference_plan = ref_impl(self, every_item)
        return [reference_plan]
Expand Down Expand Up @@ -1191,3 +1351,7 @@ def choose_plans(self):
plans.append(remaining_plan)

return plans

class plan_one_thread_per_row(gemv_prog):
    """Gemv program whose only plan comes from ``one_thread_per_row_impl``."""

    def choose_plans(self):
        """Return a one-element list holding a plan for all ``len(self.Y)`` items."""
        item_count = len(self.Y)
        row_plan = one_thread_per_row_impl(self, item_count)
        return [row_plan]

0 comments on commit e996eb2

Please sign in to comment.