update kernels in manual impls
heheda12345 committed Apr 20, 2023
1 parent e568ef6 commit a6514e7
Showing 6 changed files with 138 additions and 598 deletions.

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions artifacts/models/manual_attention/bs1/run.py
@@ -28,6 +28,7 @@ def forward(self, x, k, v):
SIZE_PER_HEAD = 64
n_warmup = 100
n_run = 100
torch.random.manual_seed(0)

if __name__ == "__main__":
batch_size = 1
@@ -41,6 +42,7 @@ def forward(self, x, k, v):
for i in range(n_warmup):
_ = model.forward(x, k, v)
torch.cuda.synchronize()
print(_)
# run
timer = Timer("ms")
enable_profile('V100')
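For context on the lines this hunk touches: the benchmark script above seeds PyTorch, runs warm-up iterations, synchronizes, prints the last warm-up output as a sanity check, and only then times the model with the artifact's `Timer`/`enable_profile` helpers. Below is a minimal standalone sketch of that pattern, with a hypothetical `benchmark` helper and generic `model`/`inputs` rather than the artifact's own utilities:

```python
import time
import torch

torch.random.manual_seed(0)  # fixed seed so printed outputs are comparable across runs

def benchmark(model, inputs, n_warmup=100, n_run=100):
    # Warm-up: exclude one-time costs (lazy init, caching, autotuning) from the measurement.
    for _ in range(n_warmup):
        out = model(*inputs)
    torch.cuda.synchronize()  # make sure all warm-up kernels have finished
    print(out)                # sanity-check the result, as the added print in the diff does

    start = time.time()
    for _ in range(n_run):
        model(*inputs)
    torch.cuda.synchronize()  # drain queued kernels before stopping the clock
    return (time.time() - start) / n_run * 1000.0  # average latency in ms
```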
2 changes: 1 addition & 1 deletion artifacts/models/manual_seq2seq/bs1/run.py
@@ -183,7 +183,7 @@ def read_bin(s, dtype=np.float32):
mask = masks[:, i:i+batch_size].contiguous()
torch.cuda.synchronize()
output_all, h = model(encoder_output, mask, h, c)
if i == 0: print(output_all)
if i == 0: print(output_all, h)
torch.cuda.synchronize()
# run
timer = Timer("ms")
@@ -654,36 +654,30 @@ extern void FusedKernel_float_float_float_float_float_float_float_float_float_fl
// - name: Dot_77_0 type: float shape: Shape{1, 3797}
extern "C" __global__ void Dot_float_float_float_cuda_Dot_77(float* __restrict__ input0, float* __restrict__ input1, float* __restrict__ output0)
{
const int num_tasks = 3797;
int block_start = 16 * 256 * blockIdx.x;
int warp_id = threadIdx.x >> 5;
int lane_id = threadIdx.x & 31;
int task_id_in_block = threadIdx.x >> 4;
int in_task_id = threadIdx.x & 15;
__shared__ float share_a[256];
__shared__ float share_b[2304];
{
{
const int num_tasks = 3797;
int block_start = 64 * 256 * blockIdx.x;
int warp_id = threadIdx.x >> 5;
int lane_id = threadIdx.x & 31;
int task_id_in_block = threadIdx.x >> 2;
int in_task_id = threadIdx.x & 3;

// 36 * 64
share_a[threadIdx.x] = input0[threadIdx.x];
float s = 0.0;
for (int k = 0; k < 256; k += 32) {
#pragma unroll
for (int i = warp_id; i < 64; i += 8) if (blockIdx.x * 64 + i < num_tasks) share_b[i * 36 + lane_id] = input1[block_start + i * 256 + k + lane_id];
__syncthreads();
// if (threadIdx.x == 0) { printf("shareb k=%d:", k); for (int i = 0; i < 32; i++) printf("%f ", share_b[i]); printf("\n");}
#pragma unroll
for (int j = in_task_id; j < 32; j += 4) s += share_a[k + j] * share_b[task_id_in_block * 36 + j];
__syncthreads();
}
s += __shfl_xor_sync(0xffffffff, s, 2);
s += __shfl_xor_sync(0xffffffff, s, 1);
if (in_task_id == 0 && blockIdx.x * 64 + task_id_in_block < num_tasks) output0[blockIdx.x * 64 + task_id_in_block] = s;
}

__shared__ float share_b[768]; // 36 * 64
share_a[threadIdx.x] = input0[threadIdx.x];
float s = 0.0;
for (int k = 0; k < 256; k += 32) {
#pragma unroll
for (int i = warp_id; i < 16; i += 8) if (blockIdx.x * 16 + i < num_tasks) share_b[i * 48 + lane_id] = input1[block_start + i * 256 + k + lane_id];
__syncthreads();
// if (threadIdx.x == 0) { printf("shareb k=%d:", k); for (int i = 0; i < 32; i++) printf("%f ", share_b[i]); printf("\n");}
#pragma unroll
for (int j = in_task_id; j < 32; j += 16) s += share_a[k + j] * share_b[task_id_in_block * 48 + j];
__syncthreads();
}

s += __shfl_xor_sync(0xffffffff, s, 8);
s += __shfl_xor_sync(0xffffffff, s, 4);
s += __shfl_xor_sync(0xffffffff, s, 2);
s += __shfl_xor_sync(0xffffffff, s, 1);
if (in_task_id == 0 && blockIdx.x * 16 + task_id_in_block < num_tasks) output0[blockIdx.x * 16 + task_id_in_block] = s;
}
extern void Dot_float_float_float_cuda_Dot_77_Call(const dim3 &grids, const dim3 &blocks, unsigned mem, cudaStream_t stream, float* __restrict__ input0, float* __restrict__ input1, float* __restrict__ output0) {
Dot_float_float_float_cuda_Dot_77<<<grids, blocks, mem, stream>>>(input0, input1, output0);
@@ -1546,7 +1540,7 @@ FusedKernel_float_float_float_float_float_float_float_float_float_float_float_fl
// name=@tmp_33
Reshape_float_float_cuda_Reshape_21_Call(dim3(1, 1, 1), dim3(256, 1, 1), 0, 0, Result_96_0, Reshape_76_0);
// name=Dot_77
Dot_float_float_float_cuda_Dot_77_Call(dim3(64, 1, 1), dim3(256, 1, 1), 0, 0, Reshape_76_0, Constant_11_0, Dot_77_0);
Dot_float_float_float_cuda_Dot_77_Call(dim3(238, 1, 1), dim3(256, 1, 1), 0, 0, Reshape_76_0, Constant_11_0, Dot_77_0);
// name=ElementWiseFused_103
FusedKernel_float_float_float_float_cuda_Add_Reshape_Add_1_Call(dim3(15, 1, 1), dim3(256, 1, 1), 0, 0, Dot_77_0, Reshape_78_0, GatherV2_82_0, Add_83_0);
// name=output_2
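As a rough sanity check on the launch-dimension change above: the updated kernel's indexing (`in_task_id = threadIdx.x & 15`, output written at `blockIdx.x * 16 + task_id_in_block`) appears to assign 16 threads to each of the 3797 output elements, i.e. 16 elements per 256-thread block, which needs ceil(3797 / 16) = 238 blocks, matching the new `dim3(238, 1, 1)` call. A worked version of that arithmetic (illustrative only, derived from the code shown above):

```python
import math

num_tasks = 3797              # length of the Dot_77_0 output (Shape{1, 3797})
threads_per_block = 256       # dim3(256, 1, 1)
threads_per_task = 16         # in_task_id = threadIdx.x & 15 -> 16 lanes reduce one element
tasks_per_block = threads_per_block // threads_per_task   # 16 elements per block

blocks = math.ceil(num_tasks / tasks_per_block)
print(tasks_per_block, blocks)  # 16 238 -> matches dim3(238, 1, 1) in the updated call
```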
2 changes: 1 addition & 1 deletion artifacts/models/manual_seq2seq/bs64/run.py
@@ -184,7 +184,7 @@ def read_bin(s, dtype=np.float32):
mask = masks[:, i:i+batch_size].contiguous()
torch.cuda.synchronize()
output_all, h = model(encoder_output, mask, h, c)
if i == 0: print(output_all)
if i == 0: print(output_all, h)
torch.cuda.synchronize()
# run
timer = Timer("ms")