Skip to content

Commit

Permalink
Merge pull request #72 from opencompl/sasha/conv-2d
Browse files Browse the repository at this point in the history
add conv2d linalg
  • Loading branch information
superlopuh authored Nov 6, 2023
2 parents 3277f9a + 390ab3b commit 26da1aa
Show file tree
Hide file tree
Showing 9 changed files with 358 additions and 1 deletion.
8 changes: 8 additions & 0 deletions kernels/conv2d_d1_s1_3x3/1x8x8x1xf64/Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Build all test binaries by default.
.DEFAULT_GOAL := all

# Shared toolchain and flag definitions for Snitch builds.
include ../../../snitch/Makefile.rules

# Test binaries for this kernel size: the linalg-lowered conv2d.
TESTS =
TESTS += linalg.x

# Common build/run rules shared by all kernels.
include ../../Makefile.kernels
128 changes: 128 additions & 0 deletions kernels/conv2d_d1_s1_3x3/1x8x8x1xf64/data.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
// Reference data for the conv2d_d1_s1_3x3 kernel on a 1x8x8x1 f64 problem.
// Presumably generated by ../gendata.py (which prints these same #defines) —
// do not edit the values by hand.

// Problem dimensions: batch N, channels C, input height/width H/W,
// filter count F, and output height/width NEW_H/NEW_W (valid padding,
// stride 1, 3x3 kernel: NEW_H = H - 2, NEW_W = W - 2).
#define N 1
#define C 1
#define H 8
#define W 8
#define F 1
#define NEW_H 6
#define NEW_W 6

// Input tensor, NCHW layout, flattened row-major.
const double X[N * C * H * W] = {
97.62700785,
430.37873274,
205.52675214,
89.76636599,
-152.69040132,
291.78822613,
-124.82557747,
783.54600156,
927.325521 ,
-233.11696235,
583.45007617,
57.78983951,
136.08912219,
851.19327659,
-857.9278836 ,
-825.7414006 ,
-959.56320512,
665.2396911 ,
556.3135019 ,
740.02429649,
957.23668447,
598.31712843,
-77.04127549,
561.05835257,
-763.45114826,
279.84204266,
-713.29342518,
889.3378341 ,
43.6966435 ,
-170.67612002,
-470.88877579,
548.46737887,
-87.69933557,
136.86789774,
-962.42039913,
235.27099415,
224.19144544,
233.86799375,
887.49615703,
363.64059821,
-280.98419885,
-125.9360924 ,
395.26239185,
-879.54905674,
333.53343089,
341.27573924,
-579.23487785,
-742.14740469,
-369.14329815,
-272.57845811,
140.39354084,
-122.79697308,
976.74767612,
-795.9103785 ,
-582.24648781,
-677.38096423,
306.21665093,
-493.41679492,
-67.37845429,
-511.148816 ,
-682.06083271,
-779.24971767,
312.65917893,
-723.6340973
};


// Filter weights, FCHW layout, flattened row-major (one 3x3 kernel).
const double Y[F * C * 3 * 3] = {
-606.83527664,
-262.54965868,
641.9864597 ,
-805.79744841,
675.889815 ,
-807.80318421,
952.91893003,
-62.6975967 ,
953.52217638
};


// Expected convolution output, NFHW layout, flattened row-major.
const double Z[N * F * NEW_H * NEW_W] = {
-1842042.26980262,
1582678.56948294,
609139.9014004 ,
746432.58436598,
1895797.51365202,
869973.49330121,
-778705.08964161,
426935.38403847,
-1697027.80827536,
724993.56385848,
-1558213.05499557,
-1418959.88780549,
1135473.31976214,
-1085577.38607162,
505127.29635872,
-432366.60882043,
487604.73402908,
-92500.41787501,
989545.90142042,
-1345889.25625288,
1730669.67927757,
-1421333.93465863,
-1279396.19699815,
350832.69483176,
-979378.49236044,
1014165.85956842,
556612.46644182,
-330178.92361772,
1227601.54362234,
-1576298.24125339,
715628.92071038,
-1131636.22002201,
-1462474.30698746,
879977.24758628,
-1821494.69964086,
-1188636.18765696
};


13 changes: 13 additions & 0 deletions kernels/conv2d_d1_s1_3x3/1x8x8x1xf64/data.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
// Declarations of the problem sizes and reference tensors for the
// 1x8x8x1 f64 conv2d kernel; definitions live in data.c.
#pragma once

// Problem dimensions (must match data.c): batch N, channels C, input H/W,
// filter count F, output NEW_H/NEW_W for a 3x3 valid, stride-1 convolution.
#define N 1
#define C 1
#define H 8
#define W 8
#define F 1
#define NEW_H 6
#define NEW_W 6

// X: input (NCHW), Y: 3x3 filter weights (FCHW), Z: expected output (NFHW).
extern const double X[N * C * H * W];
extern const double Y[F * C * 3 * 3];
extern const double Z[N * F * NEW_H * NEW_W];
1 change: 1 addition & 0 deletions kernels/conv2d_d1_s1_3x3/1x8x8x1xf64/linalg.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3932
12 changes: 12 additions & 0 deletions kernels/conv2d_d1_s1_3x3/1x8x8x1xf64/linalg.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@

// 2D convolution, NCHW input / FCHW filter layout, dilation 1, stride 1,
// 3x3 kernel. X is the 1x1x8x8 input, Y the 1x1x3x3 filter, and the
// 1x1x6x6 result is accumulated into Z in place (valid padding: 8-2=6).
func.func public @conv_2d_nchw_fchw_d1_s1_3x3(
%X: memref<1x1x8x8xf64>,
%Y: memref<1x1x3x3xf64>,
%Z: memref<1x1x6x6xf64>) -> () {
linalg.conv_2d_nchw_fchw {
dilations = dense<1> : vector<2xi64>,
strides = dense<1> : vector<2xi64>
} ins(%X, %Y : memref<1x1x8x8xf64>, memref<1x1x3x3xf64>)
outs(%Z : memref<1x1x6x6xf64>) -> ()
return
}
43 changes: 43 additions & 0 deletions kernels/conv2d_d1_s1_3x3/1x8x8x1xf64/main.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#include "data.h"

#include <snrt.h>

#include <math.h>

// Kernel provided via external definition
void conv_2d_nchw_fchw_d1_s1_3x3(double *x, double *y, double *z);

// Entry point run by every core in the Snitch cluster.
//   1. The DMA core copies X and Y from main memory into shared L1.
//   2. All cores meet at a hardware barrier.
//   3. Core 0 runs the conv2d kernel and checks the result against Z.
// Returns the number of mismatching output elements (0 on success).
int main() {
    // Allocate shared local memory
    // By avoiding allocators and bumping by a known offset a base pointer
    // (snrt_l1_next()) that is the same for all the cores in the cluster, we are
    // essentially providing the same memory regions to all the cores in this cluster.
    double *local_x = (double *)snrt_l1_next();
    double *local_y = local_x + N * C * H * W;
    double *local_z = local_y + F * C * 3 * 3;

    // Copy data in shared local memory
    if (snrt_is_dm_core()) {
        snrt_dma_start_1d(local_x, X, N * C * H * W * sizeof(double));
        snrt_dma_start_1d(local_y, Y, F * C * 3 * 3 * sizeof(double));
        // Fix: snrt_dma_start_1d only *starts* an asynchronous transfer.
        // Wait for both transfers to finish before releasing the compute
        // cores through the barrier, otherwise core 0 may read stale data.
        snrt_dma_wait_all();
    }

    snrt_cluster_hw_barrier();

    // Launch kernel: from this point on only core 0 is required to be alive.
    int thiscore = snrt_cluster_core_idx();
    if (thiscore != 0) return 0;

    // Bracket the kernel with mcycle markers for cycle measurement.
    (void)snrt_mcycle();
    conv_2d_nchw_fchw_d1_s1_3x3(local_x, local_y, local_z);
    (void)snrt_mcycle();

    // Correctness check against the reference output Z.
    int nerr = 0;
    for (int i = 0; i < N * F * NEW_H * NEW_W; i++) {
        double d = fabs(local_z[i] - Z[i]);
        // Negated comparison so NaNs also count as errors (NaN compares
        // false, making !(d <= tol) true).
        nerr += !(d <= 1E-2f);
    }
    return nerr;
}
150 changes: 150 additions & 0 deletions kernels/conv2d_d1_s1_3x3/gendata.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
#!/usr/bin/env python3

import numpy as np
import argparse
import sys


# Maps floating-point bit width (as a string) to the C scalar type name.
C_TYPES = {
    "32": "float",
    "64": "double",
}

# Maps bit width to the matching NumPy dtype.
# NOTE(review): not referenced anywhere in this file's visible code —
# presumably kept for symmetry with C_TYPES/MLIR_TYPES; confirm before removing.
NUMPY_TYPES = {
    "32": np.single,
    "64": np.double,
}

# Maps bit width to the MLIR float type name.
MLIR_TYPES = {
    "32": "f32",
    "64": "f64",
}

# Template for an MLIR memref.global constant definition.
MEMREF_GLOBAL = """
memref.global constant @{symbol} : memref<{shape}x{type}> = dense<[
{initializer}
]>
"""


# Template for a C constant array definition.
ARRAY_GLOBAL = """
const {type} {symbol}[{shape}] = {{
{initializer}
}};
"""


def array_to_memref_initializer(array: np.ndarray):
    """Format the rows of *array* as the body of an MLIR dense<[...]> initializer.

    Each row is rendered via ``np.array2string`` with ``", "`` element
    separators and a two-space indent; rows are joined with ``",\\n"``.
    Fix: annotate the parameter as ``np.ndarray`` — ``np.array`` is a
    factory function, not a type.
    """
    return ",\n".join(f"  {np.array2string(row, separator=', ')}" for row in array)


def array_to_memref(array: np.ndarray, precision: int, shape=None, symbol=None):
    """Emit an MLIR ``memref.global`` constant definition for *array*.

    :param array: data to embed in the initializer.
    :param precision: float bit width (32 or 64), selects the MLIR type.
    :param shape: optional shape expression; defaults to the array's own
        dimensions joined with ``x`` (MLIR memref syntax).
    :param symbol: optional global symbol name; defaults to ``"array"``.

    Fix: annotate the parameter as ``np.ndarray`` — ``np.array`` is a
    factory function, not a type.
    """
    return MEMREF_GLOBAL.format(
        symbol=symbol or "array",
        type=MLIR_TYPES[str(precision)],
        shape=shape or "x".join(str(dim) for dim in array.shape),
        initializer=array_to_memref_initializer(array),
    )


def array_to_c_initializer(array: np.array):
return np.array2string(array.flatten(), separator=",\n").strip(" []")


def array_to_c(array: np.ndarray, *, precision: int, shape=None, symbol=None):
    """Emit a C ``const`` array definition for *array*.

    :param array: data to embed in the initializer.
    :param precision: float bit width (32 or 64), selects the C scalar type.
    :param shape: optional size expression; defaults to the array's own
        dimensions joined with ``*`` (a C constant expression).
    :param symbol: optional variable name; defaults to ``"array"``.

    Fix: annotate the parameter as ``np.ndarray`` — ``np.array`` is a
    factory function, not a type.
    """
    return ARRAY_GLOBAL.format(
        symbol=symbol or "array",
        type=C_TYPES[str(precision)],
        shape=shape or "*".join(str(dim) for dim in array.shape),
        initializer=array_to_c_initializer(array),
    )


if __name__ == "__main__":
    # Generate random conv2d input/filter data and the reference output,
    # printed as either C or MLIR global initializers.
    parser = argparse.ArgumentParser(
        prog="gendata.py",
        # NOTE(review): this description still says "BLAS matmul" but the
        # script generates conv2d data — consider updating the string.
        description="Generate literal initializers for a fictional BLAS matmul "
        "(matrix-matrix single precision multiplication) on 2d memrefs",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "-r",
        "--range",
        type=float,
        nargs=2,
        default=(-1000.0, 1000.0),
        help="uniform distribution range",
    )
    parser.add_argument("-m", "--rows", type=int, default=8, help="number of rows")
    parser.add_argument(
        "-n", "--columns", type=int, default=8, help="number of columns"
    )
    parser.add_argument(
        "--format", default="c", choices=["mlir", "c"], help="output format"
    )
    parser.add_argument(
        "--precision",
        type=int,
        default=64,
        choices=[32, 64],
        help="floating-point precision to use",
    )
    args = parser.parse_args()

    rmin, rmax = args.range
    n = 1  # n for number of elements in a batch
    c = 1  # c for channels
    h = args.rows  # h for height
    w = args.columns  # w for width
    f = 1  # number of features

    kernel_size = (3, 3)
    stride = 1

    # Fixed seed so regeneration reproduces the checked-in data files.
    np.random.seed(0)
    x = (
        np.random.uniform(rmin, rmax, n * c * h * w)
        .astype(np.float64)
        .reshape((n, c, h, w))
    )
    y = (
        np.random.uniform(rmin, rmax, f * c * kernel_size[0] * kernel_size[1])
        .astype(np.float64)
        .reshape((f, c, kernel_size[0], kernel_size[1]))
    )

    # Output size for valid padding: (input - kernel) // stride + 1.
    new_h = (h - kernel_size[0]) // stride + 1
    new_w = (w - kernel_size[1]) // stride + 1

    # Compute the reference convolution output (valid padding).
    z = np.zeros((n, f, new_h, new_w))

    for i in range(f):
        for row in range(0, h - kernel_size[0] + 1, stride):
            for col in range(0, w - kernel_size[1] + 1, stride):
                receptive_field = x[
                    :, :, row : row + kernel_size[0], col : col + kernel_size[1]
                ]
                # NOTE(review): np.sum here reduces over *all* axes of the
                # (n, c, kh, kw) receptive field, including the batch axis —
                # correct only because n == 1; would mix batches otherwise.
                z[:, i, row // stride, col // stride] = np.sum(
                    receptive_field * y[i, :, :, :]
                )

    # One value per line, never elided, for both output formats.
    printopts = {"linewidth": None, "threshold": sys.maxsize}
    if args.format == "c":
        fmt = array_to_c
        print(f"#define N {n}")
        print(f"#define C {c}")
        print(f"#define H {h}")
        print(f"#define W {w}")
        print(f"#define F {f}")
        print(f"#define NEW_H {new_h}")
        print(f"#define NEW_W {new_w}")
        # NOTE(review): "double " (with trailing space) is not a documented
        # np.set_printoptions formatter key (e.g. "float_kind" is) — this
        # entry looks like a no-op; confirm the intended key.
        printopts["formatter"] = {"double ": lambda x: f"{x:+}f"}
    else:
        assert args.format == "mlir"
        fmt = array_to_memref
        printopts["sign"] = "+"
    np.set_printoptions(**printopts)
    print(fmt(x, shape="N * C * H * W", precision=args.precision, symbol="X"))
    print(fmt(y, shape="F * C * 3 * 3", precision=args.precision, symbol="Y"))
    print(fmt(z, shape="N * F * NEW_H * NEW_W", precision=args.precision, symbol="Z"))
2 changes: 2 additions & 0 deletions snitch/Makefile.rules
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@ XDSLOPTFLAGS += -t riscv-asm
# mlir-opt lowering pipeline for the linalg kernels:
# linalg -> loops, affine -> std, then scf -> cf, with cleanup passes between.
MLIROPTFLAGS =
MLIROPTFLAGS += -opaque-pointers=0
MLIROPTFLAGS += --convert-linalg-to-loops
MLIROPTFLAGS += --lower-affine
MLIROPTFLAGS += --canonicalize
MLIROPTFLAGS += --convert-scf-to-cf
MLIROPTFLAGS += --canonicalize
MLIROPTFLAGS += --cse
Expand Down

0 comments on commit 26da1aa

Please sign in to comment.