Commit

Fix shape computation
pfultz2 committed Jan 10, 2025
1 parent d24f634 commit 925baf1
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/targets/gpu/jit/reduce.cpp
@@ -425,8 +425,9 @@ struct fused_reduce_compiler : compiler<fused_reduce_compiler>
tuning_config tc;
auto shapes = to_shapes(ins->inputs());
tc.problem = to_value(shapes);
auto axes = op.to_value().at("axes").to_vector<std::size_t>();
auto input_shape = get_input_shape(shapes);
-auto reduce_shape = get_reduce_shape(shapes);
+auto reduce_shape = get_reduced_shape(input_shape, axes);
auto relements = reduce_shape.elements();
for(auto block_size:{64, 128, 256, 512, 1024})
{
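
For readers unfamiliar with the change, here is a minimal standalone sketch of what computing the reduce shape from the input shape and the axes might look like. It assumes that `get_reduced_shape` keeps the extents of the reduced axes and collapses every other dimension to 1, so that `elements()` gives the number of values folded into each output; the diff does not confirm this. `reduced_lens` and `element_count` are hypothetical names that use plain vectors in place of the library's shape type.

```cpp
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Hypothetical stand-in for get_reduced_shape (assumed semantics, not the
// library's code): keep the extent of each reduced axis, set all others to 1.
std::vector<std::size_t> reduced_lens(const std::vector<std::size_t>& input_lens,
                                      const std::vector<std::size_t>& axes)
{
    std::vector<std::size_t> lens(input_lens.size(), 1);
    for(auto axis : axes)
        lens.at(axis) = input_lens.at(axis); // at() throws on an out-of-range axis
    return lens;
}

// Hypothetical stand-in for shape::elements(): product of all extents.
std::size_t element_count(const std::vector<std::size_t>& lens)
{
    return std::accumulate(
        lens.begin(), lens.end(), std::size_t{1}, std::multiplies<std::size_t>{});
}
```

Under these assumptions, an input with lengths `{2, 3, 4}` reduced over axis 2 yields lengths `{1, 1, 4}`, so `relements` would be 4 regardless of how the remaining inputs of the fused op are shaped.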
