Skip to content

Commit

Permalink
fix pooling gendata
Browse files Browse the repository at this point in the history
  • Loading branch information
superlopuh committed Oct 3, 2024
1 parent 92d26d3 commit a1d9448
Show file tree
Hide file tree
Showing 10 changed files with 34 additions and 28 deletions.
15 changes: 9 additions & 6 deletions kernels/pooling_nchw_max_d1_s2_3x3/gendata.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,15 @@ def sum_pool_data(
y_in = np.random.uniform(rmin, rmax, (n, c, new_h, new_w))
y_out = y_in.copy()

for row in range(0, H - pool_size[0] + 1, stride):
for col in range(0, W - pool_size[1] + 1, stride):
pooling_region = x[:, :, row : row + pool_size[0], col : col + pool_size[1]]
y_out[:, :, row // stride, col // stride] = np.max(
pooling_region, axis=(2, 3)
)
for row in range(new_h):
for col in range(new_w):
pooling_region = x[
:,
:,
row * stride : row * stride + pool_size[0],
col * stride : col * stride + pool_size[1],
]
y_out[:, :, row, col] = np.max(pooling_region, axis=(2, 3))

yield Define("N", n)
yield Define("C", c)
Expand Down
15 changes: 9 additions & 6 deletions kernels/pooling_nchw_sum_d1_s2_3x3/gendata.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,15 @@ def sum_pool_data(
y_in = np.random.uniform(rmin, rmax, (n, c, new_h, new_w))
y_out = y_in.copy()

for row in range(0, H - pool_size[0] + 1, stride):
for col in range(0, W - pool_size[1] + 1, stride):
pooling_region = x[:, :, row : row + pool_size[0], col : col + pool_size[1]]
y_out[:, :, row // stride, col // stride] = np.sum(
pooling_region, axis=(2, 3)
)
for row in range(new_h):
for col in range(new_w):
pooling_region = x[
:,
:,
row * stride : row * stride + pool_size[0],
col * stride : col * stride + pool_size[1],
]
y_out[:, :, row, col] = np.sum(pooling_region, axis=(2, 3))

yield Define("N", n)
yield Define("C", c)
Expand Down
6 changes: 3 additions & 3 deletions results/kernels.csv
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ matmul,4x16x8xf64,linalg_xdsl,708,1493,1490,2.811418685121107,0.0,512,578,1625,0
matmul_transb,4x16x16xf32,baseline,3386,4184,4181,2.539660056657224,1.4921875,0,706,1793,0.20850561134081513,0.3935340022296544,1794,1528,1024,0.5298287064382753,0,64,1.0,1.0,1,0.0,1794,0.5561066336019839,1432,0,0,0.42291789722386297,0,799,0.0,0.9527466036621383,0.0
matmul_transb,4x16x16xf32,snitch_stream,871,1660,1657,2.648367952522255,0.0,0,674,1785,0.7738231917336394,0.9519774011299436,708,0,0,0.8128587830080367,0,32,2.1325301204819276,2.1325301204819276,1,0.0,332,0.7793427230046949,94,0,0,0.1079219288174512,0,790,0.0,0.9207807118254879,0.0
matmul_transb,4x16x16xf32,snrt,849,1612,1609,2.648367952522255,0.0,0,674,1785,0.7938751472320377,0.9519774011299436,708,0,0,0.833922261484099,0,32,2.1325301204819276,2.1325301204819276,1,0.0,332,0.8924731182795699,40,0,0,0.04711425206124853,0,764,0.0,0.8810365135453475,0.0
pooling_nchw_max_d1_s2_3x3,4x4xf64,baseline,442,1213,1210,0.993103448275862,1.1192660550458715,0,145,144,0.32805429864253394,0.5370370370370371,270,122,109,0.6108597285067874,0,16,1.0,1.0,1,0.0,270,0.903010033444816,29,0,0,0.06561085972850679,0,772,0.0,0.6764705882352942,0.0
pooling_nchw_max_d1_s2_3x3,4x4xf64,linalg_xdsl,275,1023,1020,0.9943820224719101,0.0,0,178,177,0.6472727272727272,0.9888888888888889,180,0,0,0.6545454545454545,0,0,3.214285714285714,3.2142857142857144,1,0.0,56,0.5283018867924528,50,0,0,0.18181818181818182,0,749,0.0,0.8363636363636364,0.0
pooling_nchw_sum_d1_s2_3x3,4x4xf64,baseline,582,1341,1338,2.9767441860465116,1.1018518518518519,0,129,384,0.22164948453608246,0.5098814229249012,253,119,108,0.43470790378006874,0,16,1.0,1.0,1,0.0,253,0.9730769230769231,7,0,0,0.012027491408934709,0,760,0.0,0.44673539518900346,0.0
pooling_nchw_max_d1_s2_3x3,4x4xf64,baseline,584,1328,1325,0.995575221238938,1.1226415094339623,0,226,225,0.386986301369863,0.6330532212885154,357,119,106,0.6113013698630136,0,25,1.0,1.0,1,0.0,357,0.9153846153846154,33,0,0,0.05650684931506849,0,745,0.0,0.6678082191780821,0.0
pooling_nchw_max_d1_s2_3x3,4x4xf64,linalg_xdsl,275,1018,1015,0.9943820224719101,0.0,0,178,177,0.6472727272727272,0.9888888888888889,180,0,0,0.6545454545454545,0,0,3.214285714285714,3.2142857142857144,1,0.0,56,0.5283018867924528,50,0,0,0.18181818181818182,0,744,0.0,0.8363636363636364,0.0
pooling_nchw_sum_d1_s2_3x3,4x4xf64,baseline,902,1647,1644,2.985074626865672,1.1904761904761905,0,201,600,0.22283813747228381,0.6072507552870091,331,125,105,0.3669623059866962,0,25,1.0,1.0,1,0.0,331,0.914364640883978,31,0,0,0.03436807095343681,0,746,0.0,0.401330376940133,0.0
pooling_nchw_sum_d1_s2_3x3,4x4xf64,linalg_xdsl,271,1046,1043,2.6797752808988764,0.0,0,178,477,0.6568265682656826,0.9888888888888889,180,0,0,0.6642066420664207,0,0,3.214285714285714,3.2142857142857144,1,0.0,56,0.5384615384615384,48,0,0,0.17712177121771217,0,776,0.0,0.8413284132841328,0.0
relu,4x4xf64,baseline,142,892,889,0.9444444444444444,1.0,0,18,17,0.1267605633802817,0.36,50,16,16,0.352112676056338,0,16,1.0,1.0,1,0.0,50,0.8771929824561403,7,0,0,0.04929577464788732,0,751,0.0,0.4014084507042253,0.0
relu,4x4xf64,linalg_xdsl,72,817,814,0.9444444444444444,0.0,0,18,17,0.25,0.9,20,0,0,0.2777777777777778,0,0,3.333333333333333,3.3333333333333335,1,0.0,6,0.25,18,0,0,0.25,0,746,0.0,0.5277777777777778,0.0
Expand Down
6 changes: 3 additions & 3 deletions results/kernels.fast.csv
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@ matmul,4x16x8xf64,linalg_xdsl,708,1493,1490,2.811418685121107,0.0,512,578,1625,0
matmul_transb,4x16x16xf32,baseline,3386,4184,4181,2.539660056657224,1.4921875,0,706,1793,0.20850561134081513,0.3935340022296544,1794,1528,1024,0.5298287064382753,0,64,1.0,1.0,1,0.0,1794,0.5561066336019839,1432,0,0,0.42291789722386297,0,799,0.0,0.9527466036621383,0.0
matmul_transb,4x16x16xf32,snitch_stream,871,1660,1657,2.648367952522255,0.0,0,674,1785,0.7738231917336394,0.9519774011299436,708,0,0,0.8128587830080367,0,32,2.1325301204819276,2.1325301204819276,1,0.0,332,0.7793427230046949,94,0,0,0.1079219288174512,0,790,0.0,0.9207807118254879,0.0
matmul_transb,4x16x16xf32,snrt,849,1612,1609,2.648367952522255,0.0,0,674,1785,0.7938751472320377,0.9519774011299436,708,0,0,0.833922261484099,0,32,2.1325301204819276,2.1325301204819276,1,0.0,332,0.8924731182795699,40,0,0,0.04711425206124853,0,764,0.0,0.8810365135453475,0.0
pooling_nchw_max_d1_s2_3x3,4x4xf64,baseline,442,1213,1210,0.993103448275862,1.1192660550458715,0,145,144,0.32805429864253394,0.5370370370370371,270,122,109,0.6108597285067874,0,16,1.0,1.0,1,0.0,270,0.903010033444816,29,0,0,0.06561085972850679,0,772,0.0,0.6764705882352942,0.0
pooling_nchw_max_d1_s2_3x3,4x4xf64,linalg_xdsl,275,1023,1020,0.9943820224719101,0.0,0,178,177,0.6472727272727272,0.9888888888888889,180,0,0,0.6545454545454545,0,0,3.214285714285714,3.2142857142857144,1,0.0,56,0.5283018867924528,50,0,0,0.18181818181818182,0,749,0.0,0.8363636363636364,0.0
pooling_nchw_sum_d1_s2_3x3,4x4xf64,baseline,582,1341,1338,2.9767441860465116,1.1018518518518519,0,129,384,0.22164948453608246,0.5098814229249012,253,119,108,0.43470790378006874,0,16,1.0,1.0,1,0.0,253,0.9730769230769231,7,0,0,0.012027491408934709,0,760,0.0,0.44673539518900346,0.0
pooling_nchw_max_d1_s2_3x3,4x4xf64,baseline,584,1328,1325,0.995575221238938,1.1226415094339623,0,226,225,0.386986301369863,0.6330532212885154,357,119,106,0.6113013698630136,0,25,1.0,1.0,1,0.0,357,0.9153846153846154,33,0,0,0.05650684931506849,0,745,0.0,0.6678082191780821,0.0
pooling_nchw_max_d1_s2_3x3,4x4xf64,linalg_xdsl,275,1018,1015,0.9943820224719101,0.0,0,178,177,0.6472727272727272,0.9888888888888889,180,0,0,0.6545454545454545,0,0,3.214285714285714,3.2142857142857144,1,0.0,56,0.5283018867924528,50,0,0,0.18181818181818182,0,744,0.0,0.8363636363636364,0.0
pooling_nchw_sum_d1_s2_3x3,4x4xf64,baseline,902,1647,1644,2.985074626865672,1.1904761904761905,0,201,600,0.22283813747228381,0.6072507552870091,331,125,105,0.3669623059866962,0,25,1.0,1.0,1,0.0,331,0.914364640883978,31,0,0,0.03436807095343681,0,746,0.0,0.401330376940133,0.0
pooling_nchw_sum_d1_s2_3x3,4x4xf64,linalg_xdsl,271,1046,1043,2.6797752808988764,0.0,0,178,477,0.6568265682656826,0.9888888888888889,180,0,0,0.6642066420664207,0,0,3.214285714285714,3.2142857142857144,1,0.0,56,0.5384615384615384,48,0,0,0.17712177121771217,0,776,0.0,0.8413284132841328,0.0
relu,4x4xf64,baseline,142,892,889,0.9444444444444444,1.0,0,18,17,0.1267605633802817,0.36,50,16,16,0.352112676056338,0,16,1.0,1.0,1,0.0,50,0.8771929824561403,7,0,0,0.04929577464788732,0,751,0.0,0.4014084507042253,0.0
relu,4x4xf64,linalg_xdsl,72,817,814,0.9444444444444444,0.0,0,18,17,0.25,0.9,20,0,0,0.2777777777777778,0,0,3.333333333333333,3.3333333333333335,1,0.0,6,0.25,18,0,0,0.25,0,746,0.0,0.5277777777777778,0.0
Expand Down
4 changes: 2 additions & 2 deletions results/pivoted.csv
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ dense 8x8xf64,3206,3530,,2741,2723
fill 4x4xf64,50,,63,,
matmul 4x16x8xf64,2495,,708,,
matmul_transb 4x16x16xf32,3386,,,871,849
pooling_nchw_max_d1_s2_3x3 4x4xf64,442,,275,,
pooling_nchw_sum_d1_s2_3x3 4x4xf64,582,,271,,
pooling_nchw_max_d1_s2_3x3 4x4xf64,584,,275,,
pooling_nchw_sum_d1_s2_3x3 4x4xf64,902,,271,,
relu 4x4xf64,142,,72,,
relu 4x8xf32,297,210,,67,85
saxpy 64xf32,634,634,,,140
Expand Down
4 changes: 2 additions & 2 deletions results/pivoted.fast.csv
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ dense 8x8xf64,3206,3530,,2741,2723
fill 4x4xf64,50,,63,,
matmul 4x16x8xf64,2495,,708,,
matmul_transb 4x16x16xf32,3386,,,871,849
pooling_nchw_max_d1_s2_3x3 4x4xf64,442,,275,,
pooling_nchw_sum_d1_s2_3x3 4x4xf64,582,,271,,
pooling_nchw_max_d1_s2_3x3 4x4xf64,584,,275,,
pooling_nchw_sum_d1_s2_3x3 4x4xf64,902,,271,,
relu 4x4xf64,142,,72,,
relu 4x8xf32,297,210,,67,85
saxpy 64xf32,634,634,,,140
Expand Down
2 changes: 1 addition & 1 deletion results/pivoted_fpu.csv
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ dense 8x8xf64,0.20,0.18,,0.26,0.26
fill 4x4xf64,0.02,,0.29,,
matmul 4x16x8xf64,0.21,,0.82,,
matmul_transb 4x16x16xf32,0.21,,,0.77,0.79
pooling_nchw_max_d1_s2_3x3 4x4xf64,0.33,,0.65,,
pooling_nchw_max_d1_s2_3x3 4x4xf64,0.39,,0.65,,
pooling_nchw_sum_d1_s2_3x3 4x4xf64,0.22,,0.66,,
relu 4x4xf64,0.13,,0.25,,
relu 4x8xf32,0.33,0.16,,0.28,0.22
Expand Down
2 changes: 1 addition & 1 deletion results/pivoted_fpu.fast.csv
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ dense 8x8xf64,0.20,0.18,,0.26,0.26
fill 4x4xf64,0.02,,0.29,,
matmul 4x16x8xf64,0.21,,0.82,,
matmul_transb 4x16x16xf32,0.21,,,0.77,0.79
pooling_nchw_max_d1_s2_3x3 4x4xf64,0.33,,0.65,,
pooling_nchw_max_d1_s2_3x3 4x4xf64,0.39,,0.65,,
pooling_nchw_sum_d1_s2_3x3 4x4xf64,0.22,,0.66,,
relu 4x4xf64,0.13,,0.25,,
relu 4x8xf32,0.33,0.16,,0.28,0.22
Expand Down
4 changes: 2 additions & 2 deletions results/pivoted_ipc.csv
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ dense 8x8xf64,0.51,0.55,,0.39,0.33
fill 4x4xf64,0.46,,0.56,,
matmul 4x16x8xf64,0.56,,0.93,,
matmul_transb 4x16x16xf32,0.95,,,0.92,0.88
pooling_nchw_max_d1_s2_3x3 4x4xf64,0.68,,0.84,,
pooling_nchw_sum_d1_s2_3x3 4x4xf64,0.45,,0.84,,
pooling_nchw_max_d1_s2_3x3 4x4xf64,0.67,,0.84,,
pooling_nchw_sum_d1_s2_3x3 4x4xf64,0.40,,0.84,,
relu 4x4xf64,0.40,,0.53,,
relu 4x8xf32,0.57,0.51,,0.57,0.40
saxpy 64xf32,0.93,0.93,,,0.65
Expand Down
4 changes: 2 additions & 2 deletions results/pivoted_ipc.fast.csv
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ dense 8x8xf64,0.51,0.55,,0.39,0.33
fill 4x4xf64,0.46,,0.56,,
matmul 4x16x8xf64,0.56,,0.93,,
matmul_transb 4x16x16xf32,0.95,,,0.92,0.88
pooling_nchw_max_d1_s2_3x3 4x4xf64,0.68,,0.84,,
pooling_nchw_sum_d1_s2_3x3 4x4xf64,0.45,,0.84,,
pooling_nchw_max_d1_s2_3x3 4x4xf64,0.67,,0.84,,
pooling_nchw_sum_d1_s2_3x3 4x4xf64,0.40,,0.84,,
relu 4x4xf64,0.40,,0.53,,
relu 4x8xf32,0.57,0.51,,0.57,0.40
saxpy 64xf32,0.93,0.93,,,0.65
Expand Down

0 comments on commit a1d9448

Please sign in to comment.