Commit

Reschedule the matrix multiply performance app
abadams committed Sep 13, 2024
1 parent a65221b commit fa5f2a0
Showing 1 changed file with 54 additions and 20 deletions.
74 changes: 54 additions & 20 deletions test/performance/matrix_multiplication.cpp
@@ -30,44 +30,78 @@ int main(int argc, char **argv) {
ImageParam A(type_of<float>(), 2);
ImageParam B(type_of<float>(), 2);

Var x("x"), xi("xi"), xo("xo"), y("y"), yo("yo"), yi("yi"), yii("yii"), xii("xii");
Func matrix_mul("matrix_mul");

Var x("x"), y("y");
RDom k(0, matrix_size);
RVar ki;

Func matrix_mul("matrix_mul");

matrix_mul(x, y) += A(k, y) * B(x, k);

Func out;
out(x, y) = matrix_mul(x, y);
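// [Illustrative aside, not part of this commit] Written out, this computes
// out(x, y) = sum over k of A(k, y) * B(x, k): each output element is a dot
// product along the shared k dimension, i.e. a standard matrix product.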

Var xy;
// Now the schedule. Single-threaded, it hits 155 GFlops on Skylake-X
// i9-9960X with AVX-512 (80% of peak), and 87 GFlops with AVX2 (90% of
// peak).
//
// Using 16 threads (and no hyperthreading), hits 2080 GFlops (67% of peak)
// and 1310 GFlops (85% of peak) respectively.
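// [Illustrative arithmetic, not part of this commit; the ~3 GHz clock is an
// assumption] With two FMA units per core, single-core peak is roughly
// 2 * 16 lanes * 2 flops * 3 GHz ~= 192 GFlops for AVX-512 (155 / 192 ~= 80%)
// and 2 * 8 * 2 * 3 ~= 96 GFlops for AVX2 (87 / 96 ~= 90%); sixteen cores
// scale those peaks to roughly 3.1 and 1.5 TFlops, matching the 67% and 85%
// figures above.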

out.tile(x, y, xi, yi, 24, 32)
.fuse(x, y, xy)
.parallel(xy)
.split(yi, yi, yii, 4)
.vectorize(xi, 8)
const int vec = target.natural_vector_size<float>();

// Size the inner loop tiles to fit into the number of registers available
// on the target, using either 12 accumulator registers or 24.
const int inner_tile_x = 3 * vec;
const int inner_tile_y = (target.has_feature(Target::AVX512) || target.arch != Target::X86) ? 8 : 4;
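// [Illustrative arithmetic, not part of this commit] The accumulator tile is
// inner_tile_x / vec = 3 vector registers wide and inner_tile_y rows tall:
// 3 * 8 = 24 vector accumulators with AVX-512 (32 registers available), or
// 3 * 4 = 12 with AVX2 (16 registers available), leaving room for loads of A
// and B.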

// The shape of the outer tiling
const int tile_y = matrix_size / 4;
const int tile_k = matrix_size / 16;

Var xy("xy"), xi("xi"), yi("yi"), yii("yii");

out.tile(x, y, xi, yi, inner_tile_x, tile_y)
.split(yi, yi, yii, inner_tile_y)
.vectorize(xi, vec)
.unroll(xi)
.unroll(yii);
.unroll(yii)
.fuse(x, y, xy)
.parallel(xy);

RVar ko("ko"), ki("ki");
Var z("z");
matrix_mul.update().split(k, ko, ki, tile_k);

// Factor the reduction so that we can do outer blocking over the reduction
// dimension.
Func intm = matrix_mul.update().rfactor(ko, z);
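// [Illustrative sketch, not part of this commit] Conceptually, rfactor(ko, z)
// turns the single reduction into two stages, roughly:
//   intm(x, y, z) += A(z * tile_k + ki, y) * B(x, z * tile_k + ki);  // partial sums within one ko block
//   matrix_mul(x, y) += intm(x, y, ko);                              // combine the per-block partial sums
// so the ki accumulation can live in registers while ko becomes an outer
// blocking loop over the reduction dimension.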

matrix_mul.compute_at(out, yi)
.vectorize(x, 8)
intm.compute_at(matrix_mul, y)
.vectorize(x, vec)
.unroll(x)
.unroll(y);

matrix_mul.update(0)
.reorder(x, y, k)
.vectorize(x, 8)
intm.update(0)
.reorder(x, y, ki)
.vectorize(x, vec)
.unroll(x)
.unroll(y)
.unroll(k, 2);
.unroll(y);

matrix_mul.compute_at(out, xy)
.vectorize(x, vec)
.unroll(x);

matrix_mul.update()
.split(y, y, yi, inner_tile_y)
.reorder(x, yi, y, ko)
.vectorize(x, vec)
.unroll(x)
.unroll(yi);
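// [Illustrative sketch, not part of this commit; loop structure is paraphrased]
// The schedule above roughly produces:
//   parallel over fused (x, y) tiles of out (inner_tile_x wide, tile_y tall):
//     zero-init matrix_mul for the tile
//     for ko (blocks of tile_k along the reduction):
//       for each strip of inner_tile_y rows:
//         intm: for ki in the block, accumulate a 3-vector-wide by
//               inner_tile_y-row register tile of partial sums
//         add intm into matrix_mul (vectorized, unrolled)
//     copy matrix_mul into out (vectorized, unrolled)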

out
.bound(x, 0, matrix_size)
.bound(y, 0, matrix_size);

out.compile_jit();

Buffer<float> mat_A(matrix_size, matrix_size);
Buffer<float> mat_B(matrix_size, matrix_size);
Buffer<float> output(matrix_size, matrix_size);