-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #72 from opencompl/sasha/conv-2d
add conv2d linalg
- Loading branch information
Showing
9 changed files
with
358 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
.DEFAULT_GOAL := all | ||
|
||
include ../../../snitch/Makefile.rules | ||
|
||
TESTS = | ||
TESTS += linalg.x | ||
|
||
include ../../Makefile.kernels |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
#define N 1 | ||
#define C 1 | ||
#define H 8 | ||
#define W 8 | ||
#define F 1 | ||
#define NEW_H 6 | ||
#define NEW_W 6 | ||
|
||
const double X[N * C * H * W] = { | ||
97.62700785, | ||
430.37873274, | ||
205.52675214, | ||
89.76636599, | ||
-152.69040132, | ||
291.78822613, | ||
-124.82557747, | ||
783.54600156, | ||
927.325521 , | ||
-233.11696235, | ||
583.45007617, | ||
57.78983951, | ||
136.08912219, | ||
851.19327659, | ||
-857.9278836 , | ||
-825.7414006 , | ||
-959.56320512, | ||
665.2396911 , | ||
556.3135019 , | ||
740.02429649, | ||
957.23668447, | ||
598.31712843, | ||
-77.04127549, | ||
561.05835257, | ||
-763.45114826, | ||
279.84204266, | ||
-713.29342518, | ||
889.3378341 , | ||
43.6966435 , | ||
-170.67612002, | ||
-470.88877579, | ||
548.46737887, | ||
-87.69933557, | ||
136.86789774, | ||
-962.42039913, | ||
235.27099415, | ||
224.19144544, | ||
233.86799375, | ||
887.49615703, | ||
363.64059821, | ||
-280.98419885, | ||
-125.9360924 , | ||
395.26239185, | ||
-879.54905674, | ||
333.53343089, | ||
341.27573924, | ||
-579.23487785, | ||
-742.14740469, | ||
-369.14329815, | ||
-272.57845811, | ||
140.39354084, | ||
-122.79697308, | ||
976.74767612, | ||
-795.9103785 , | ||
-582.24648781, | ||
-677.38096423, | ||
306.21665093, | ||
-493.41679492, | ||
-67.37845429, | ||
-511.148816 , | ||
-682.06083271, | ||
-779.24971767, | ||
312.65917893, | ||
-723.6340973 | ||
}; | ||
|
||
|
||
const double Y[F * C * 3 * 3] = { | ||
-606.83527664, | ||
-262.54965868, | ||
641.9864597 , | ||
-805.79744841, | ||
675.889815 , | ||
-807.80318421, | ||
952.91893003, | ||
-62.6975967 , | ||
953.52217638 | ||
}; | ||
|
||
|
||
const double Z[N * F * NEW_H * NEW_W] = { | ||
-1842042.26980262, | ||
1582678.56948294, | ||
609139.9014004 , | ||
746432.58436598, | ||
1895797.51365202, | ||
869973.49330121, | ||
-778705.08964161, | ||
426935.38403847, | ||
-1697027.80827536, | ||
724993.56385848, | ||
-1558213.05499557, | ||
-1418959.88780549, | ||
1135473.31976214, | ||
-1085577.38607162, | ||
505127.29635872, | ||
-432366.60882043, | ||
487604.73402908, | ||
-92500.41787501, | ||
989545.90142042, | ||
-1345889.25625288, | ||
1730669.67927757, | ||
-1421333.93465863, | ||
-1279396.19699815, | ||
350832.69483176, | ||
-979378.49236044, | ||
1014165.85956842, | ||
556612.46644182, | ||
-330178.92361772, | ||
1227601.54362234, | ||
-1576298.24125339, | ||
715628.92071038, | ||
-1131636.22002201, | ||
-1462474.30698746, | ||
879977.24758628, | ||
-1821494.69964086, | ||
-1188636.18765696 | ||
}; | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
#pragma once | ||
|
||
#define N 1 | ||
#define C 1 | ||
#define H 8 | ||
#define W 8 | ||
#define F 1 | ||
#define NEW_H 6 | ||
#define NEW_W 6 | ||
|
||
extern const double X[N * C * H * W]; | ||
extern const double Y[F * C * 3 * 3]; | ||
extern const double Z[N * F * NEW_H * NEW_W]; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
3932 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
|
||
func.func public @conv_2d_nchw_fchw_d1_s1_3x3( | ||
%X: memref<1x1x8x8xf64>, | ||
%Y: memref<1x1x3x3xf64>, | ||
%Z: memref<1x1x6x6xf64>) -> () { | ||
linalg.conv_2d_nchw_fchw { | ||
dilations = dense<1> : vector<2xi64>, | ||
strides = dense<1> : vector<2xi64> | ||
} ins(%X, %Y : memref<1x1x8x8xf64>, memref<1x1x3x3xf64>) | ||
outs(%Z : memref<1x1x6x6xf64>) -> () | ||
return | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
#include "data.h" | ||
|
||
#include <snrt.h> | ||
|
||
#include <math.h> | ||
|
||
// Kernel provided via external definition | ||
void conv_2d_nchw_fchw_d1_s1_3x3(double *x, double *y, double *z); | ||
|
||
int main() { | ||
// Allocate shared local memory | ||
// By avoiding allocators and bumping by a known offset a base pointer | ||
// (snrt_l1_next()) that is the same for all the cores in the cluster, we are | ||
// essentially providing the same memory regions to all the cores in this cluster. | ||
double *local_x = (double *)snrt_l1_next(); | ||
double *local_y = local_x + N * C * H * W; | ||
double *local_z = local_y + F * C * 3 * 3; | ||
|
||
// Copy data in shared local memory | ||
if (snrt_is_dm_core()) { | ||
snrt_dma_start_1d(local_x, X, N * C * H * W * sizeof(double)); | ||
snrt_dma_start_1d(local_y, Y, F * C * 3 * 3 * sizeof(double)); | ||
} | ||
|
||
snrt_cluster_hw_barrier(); | ||
|
||
// Launch kernel: from this point on only core 0 is required to be alive. | ||
int thiscore = snrt_cluster_core_idx(); | ||
if (thiscore != 0) return 0; | ||
|
||
(void)snrt_mcycle(); | ||
conv_2d_nchw_fchw_d1_s1_3x3(local_x, local_y, local_z); | ||
(void)snrt_mcycle(); | ||
|
||
// Correctness check | ||
int nerr = 0; | ||
for (int i = 0; i < N * F * NEW_H * NEW_W; i++) { | ||
double d = fabs(local_z[i] - Z[i]); | ||
nerr += !(d <= 1E-2f); // Make sure to take into account NaNs (e.g.: happy path | ||
// on the taken branch) | ||
} | ||
return nerr; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
#!/usr/bin/env python3 | ||
|
||
import numpy as np | ||
import argparse | ||
import sys | ||
|
||
|
||
C_TYPES = { | ||
"32": "float", | ||
"64": "double", | ||
} | ||
|
||
NUMPY_TYPES = { | ||
"32": np.single, | ||
"64": np.double, | ||
} | ||
|
||
MLIR_TYPES = { | ||
"32": "f32", | ||
"64": "f64", | ||
} | ||
|
||
MEMREF_GLOBAL = """ | ||
memref.global constant @{symbol} : memref<{shape}x{type}> = dense<[ | ||
{initializer} | ||
]> | ||
""" | ||
|
||
|
||
ARRAY_GLOBAL = """ | ||
const {type} {symbol}[{shape}] = {{ | ||
{initializer} | ||
}}; | ||
""" | ||
|
||
|
||
def array_to_memref_initializer(array: np.array): | ||
return ",\n".join(f" {np.array2string(row, separator=', ')}" for row in array) | ||
|
||
|
||
def array_to_memref(array: np.array, precision: int, shape=None, symbol=None): | ||
return MEMREF_GLOBAL.format( | ||
symbol=symbol or "array", | ||
type=MLIR_TYPES[str(precision)], | ||
shape=shape or "x".join(str(dim) for dim in array.shape), | ||
initializer=array_to_memref_initializer(array), | ||
) | ||
|
||
|
||
def array_to_c_initializer(array: np.array): | ||
return np.array2string(array.flatten(), separator=",\n").strip(" []") | ||
|
||
|
||
def array_to_c(array: np.array, *, precision: int, shape=None, symbol=None): | ||
return ARRAY_GLOBAL.format( | ||
symbol=symbol or "array", | ||
type=C_TYPES[str(precision)], | ||
shape=shape or "*".join(str(dim) for dim in array.shape), | ||
initializer=array_to_c_initializer(array), | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser( | ||
prog="gendata.py", | ||
description="Generate literal initializers for a fictional BLAS matmul " | ||
"(matrix-matrix single precision multiplication) on 2d memrefs", | ||
formatter_class=argparse.ArgumentDefaultsHelpFormatter, | ||
) | ||
parser.add_argument( | ||
"-r", | ||
"--range", | ||
type=float, | ||
nargs=2, | ||
default=(-1000.0, 1000.0), | ||
help="uniform distribution range", | ||
) | ||
parser.add_argument("-m", "--rows", type=int, default=8, help="number of rows") | ||
parser.add_argument( | ||
"-n", "--columns", type=int, default=8, help="number of columns" | ||
) | ||
parser.add_argument( | ||
"--format", default="c", choices=["mlir", "c"], help="output format" | ||
) | ||
parser.add_argument( | ||
"--precision", | ||
type=int, | ||
default=64, | ||
choices=[32, 64], | ||
help="floating-point precision to use", | ||
) | ||
args = parser.parse_args() | ||
|
||
rmin, rmax = args.range | ||
n = 1 # n for number of elements in a batch | ||
c = 1 # c for channels | ||
h = args.rows # h for height | ||
w = args.columns # w for width | ||
f = 1 # number of features | ||
|
||
kernel_size = (3, 3) | ||
stride = 1 | ||
|
||
np.random.seed(0) | ||
x = ( | ||
np.random.uniform(rmin, rmax, n * c * h * w) | ||
.astype(np.float64) | ||
.reshape((n, c, h, w)) | ||
) | ||
y = ( | ||
np.random.uniform(rmin, rmax, f * c * kernel_size[0] * kernel_size[1]) | ||
.astype(np.float64) | ||
.reshape((f, c, kernel_size[0], kernel_size[1])) | ||
) | ||
|
||
new_h = (h - kernel_size[0]) // stride + 1 | ||
new_w = (w - kernel_size[1]) // stride + 1 | ||
|
||
# Perform the max pooling operation | ||
z = np.zeros((n, f, new_h, new_w)) | ||
|
||
for i in range(f): | ||
for row in range(0, h - kernel_size[0] + 1, stride): | ||
for col in range(0, w - kernel_size[1] + 1, stride): | ||
receptive_field = x[ | ||
:, :, row : row + kernel_size[0], col : col + kernel_size[1] | ||
] | ||
z[:, i, row // stride, col // stride] = np.sum( | ||
receptive_field * y[i, :, :, :] | ||
) | ||
|
||
printopts = {"linewidth": None, "threshold": sys.maxsize} | ||
if args.format == "c": | ||
fmt = array_to_c | ||
print(f"#define N {n}") | ||
print(f"#define C {c}") | ||
print(f"#define H {h}") | ||
print(f"#define W {w}") | ||
print(f"#define F {f}") | ||
print(f"#define NEW_H {new_h}") | ||
print(f"#define NEW_W {new_w}") | ||
printopts["formatter"] = {"double ": lambda x: f"{x:+}f"} | ||
else: | ||
assert args.format == "mlir" | ||
fmt = array_to_memref | ||
printopts["sign"] = "+" | ||
np.set_printoptions(**printopts) | ||
print(fmt(x, shape="N * C * H * W", precision=args.precision, symbol="X")) | ||
print(fmt(y, shape="F * C * 3 * 3", precision=args.precision, symbol="Y")) | ||
print(fmt(z, shape="N * F * NEW_H * NEW_W", precision=args.precision, symbol="Z")) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Submodule xdsl
updated
8 files