# Additions from a patch against nengo_ocl/clra_gemv.py (369b7bf..8fd770b).
# `cl`, `Template`, `as_ascii`, `Plan`, `gemv_prog`, `ref_impl`,
# `bw_from_geometry` and `flops_from_geometry` are not defined in this
# excerpt; they must be provided by the enclosing module.


def one_thread_per_row_impl(p, num_items, max_consecutive_y_len=64):
    """Build a gemv plan in which each thread computes one row's dot product.

    The rows of all items are laid out consecutively in a 1-D global range
    when every item has the same ``y_len`` and that length is below
    ``max_consecutive_y_len``; otherwise a 2-D range
    ``(max_y_len, num_items)`` is used.

    Parameters
    ----------
    p : gemv_prog
        Program description. This function reads its ``A``, ``X``, ``Y``,
        ``Y_in``, ``A_js``, ``geometry``, ``queue``, ``tag``,
        ``float_alpha``, ``float_beta``, ``float_gamma`` and ``cl_beta``
        attributes (the template also sees everything in ``p.__dict__``).
    num_items : int
        Number of items to process; items ``0 .. num_items - 1`` are used.
    max_consecutive_y_len : int, optional
        Exclusive upper bound on ``y_len`` for the consecutive (1-D) layout.
        Defaults to 64, the value previously hard-coded here.

    Returns
    -------
    Plan
        A plan wrapping the compiled ``gemv`` kernel.

    Raises
    ------
    NotImplementedError
        If any of ``clra_alpha``, ``clra_gamma``, ``clra_beta``,
        ``cl_alpha`` or ``cl_gamma`` is set, if any ``p.A.stride0s`` entry
        is not 1, or if ``p.A_js`` is None.
    """
    # Scaling modes this kernel has no code path for.
    for attr in ('clra_alpha', 'clra_gamma', 'clra_beta',
                 'cl_alpha', 'cl_gamma'):
        if getattr(p, attr) is not None:
            raise NotImplementedError(
                'one_thread_per_row_impl does not support %s' % attr)
    if not all(s == 1 for s in p.A.stride0s):
        raise NotImplementedError(
            'one_thread_per_row_impl requires all A.stride0s == 1')

    assert num_items > 0
    items = range(num_items)

    assert p.float_alpha is not None
    assert p.float_gamma is not None

    if p.A_js is None:
        # -- easy probably, but not done
        raise NotImplementedError()

    cl_gstructure, textconf = p.cl_geometry_and_textconf(items, stride=1)

    y_lens = [p.geometry[ii].y_len for ii in items]
    max_y_len = max(y_lens)
    max_n_dots = max(len(p.geometry[ii].dots) for ii in items)
    all_same_y_len = len(set(y_lens)) == 1

    # Compute in consecutive threads if max_y_len is small,
    # and all y_len are the same.
    # TODO autotune (the threshold is exposed as max_consecutive_y_len)
    consecutive = all_same_y_len and max_y_len < max_consecutive_y_len
    if consecutive:
        gsize = (max_y_len * num_items,)
    else:
        gsize = (max_y_len, num_items)

    textconf.update({
        'max_n_dots': max_n_dots,
        'max_y_len': max_y_len,
        'consecutive': consecutive,
        'all_same_y_len': all_same_y_len,
    })
    textconf.update(p.__dict__)

    # NOTE(review): the inner `for (int j ...)` accumulation loop and the
    # `% if max_n_dots > 1:` that closes the outer loop were truncated in the
    # text this block was recovered from. They are reconstructed here from the
    # surrounding declarations (a_s1, n_i, a, x) -- confirm against upstream,
    # in particular that X is read with unit stride.
    text = """
    __kernel void gemv(
        const __global int* gstructure,
        const __global ${A.cl_buf.ctype}* A_data,
        const __global ${X.cl_buf.ctype}* X_data,
    % if cl_beta is not None:
        const __global ${cl_beta.ctype}* betas,
    % endif
        const __global ${Y_in.cl_buf.ctype}* Y_in_data,
        __global ${Y.cl_buf.ctype}* Y_data)
    {
    % if consecutive:
        const int global_i = get_global_id(0);

        const int item_i = global_i / ${max_y_len};  // Item index
        const int i = global_i - item_i * ${max_y_len};  // Row index within item
    % else:
        const int item_i = get_global_id(1);  // Item index
        const int i = get_global_id(0);  // Row index within item
    % endif

        const __global int* lstructure =
            gstructure + item_i * ${structure_vars_stride};

    % if not all_same_y_len:
        if (i >= ${y_len}) return;
    % endif

        ${Y.cl_buf.ctype} sum = 0;

    % if max_n_dots > 1:
        for (int ii = 0; ii < ${n_dot_products}; ii++)
        {
    % else:
        const int ii = 0;
    % endif

            const int a_s1 = ${a_s1};
            const int n_i = ${N_i};

            const __global ${A.cl_buf.ctype}* a = A_data + ${a_starts} + i;
            const __global ${X.cl_buf.ctype}* x = X_data + ${x_starts};

            for (int j = 0; j < n_i; j++)
            {
                sum += a[j * a_s1] * x[j];
            }

    % if max_n_dots > 1:
        }
    % endif

    % if float_gamma is not None:
        const ${Y.cl_buf.ctype} gamma = ${float_gamma};
    % else:
        const ${Y.cl_buf.ctype} gamma = 0;
    % endif

    % if float_beta is not None and float_beta != 0 :
        Y_data[${y_offset} + i] = ${float_alpha} * sum + ${float_beta} * Y_in_data[${y_in_starts} + i] + gamma;
    % elif cl_beta is not None:
        Y_data[${y_offset} + i] = ${float_alpha} * sum + betas[${bb}] * Y_in_data[${y_in_starts} + i] + gamma;
    % else:
        Y_data[${y_offset} + i] = ${float_alpha} * sum + gamma;
    % endif
    }
    """

    text = as_ascii(Template(text, output_encoding='ascii').render(**textconf))

    fn = cl.Program(p.queue.context, text).build().gemv

    # Argument order must match the kernel signature above; `betas` is only
    # part of the signature when p.cl_beta is set.
    full_args = [cl_gstructure, p.A.cl_buf, p.X.cl_buf]
    if p.cl_beta is not None:
        full_args.append(p.cl_beta)
    full_args.extend([p.Y_in.cl_buf, p.Y.cl_buf])

    fn.set_args(*[arr.data for arr in full_args])
    rval = Plan(p.queue, fn, gsize, None,
                name='clra_gemv.one_thread_per_row_impl',
                tag=p.tag,
                bw_per_call=bw_from_geometry(p.geometry, items),
                flops_per_call=flops_from_geometry(p.geometry, items),
                )
    rval.full_args = full_args  # prevent GC the args
    rval.description = p.geometry.summary(items)
    return rval


class plan_ref_gemv(gemv_prog):
    """gemv program that plans every item with ``ref_impl``."""

    def choose_plans(self):
        return [ref_impl(self, range(len(self.Y)))]


class plan_one_thread_per_row(gemv_prog):
    """gemv program that plans every item with ``one_thread_per_row_impl``."""

    def choose_plans(self):
        return [one_thread_per_row_impl(self, len(self.Y))]