
Commit

update
simran-arora committed Jan 4, 2025
1 parent 88e329f commit b108138
Showing 3 changed files with 65 additions and 47 deletions.
42 changes: 35 additions & 7 deletions kernels/torch_scaled/gentests.py
@@ -1,19 +1,26 @@
 import torch
 import torch.nn.functional as F
 import sys
+from tqdm import tqdm
 
 
-size=(16, 16)
+size=(4096, 4096)
 
 TESTNAME = sys.argv[1]
 
 if TESTNAME == 'ones':
-    x = torch.ones(size, dtype=torch.float16, device='cuda')
-    w = torch.ones(size, dtype=torch.float16, device='cuda').t() # Note: cuBLASLt float8 matmul requires column major for the second argument
+    x = 10 * torch.ones(size, dtype=torch.float32, device='cuda')
+    w = 10 * torch.ones(size, dtype=torch.float32, device='cuda').t() # Note: cuBLASLt float8 matmul requires column major for the second argument
 elif TESTNAME == 'randn':
     torch.random.manual_seed(42)
-    x = torch.randn(size, dtype=torch.float16, device='cuda')
-    w = torch.randn(size, dtype=torch.float16, device='cuda').t() # Note: cuBLASLt float8 matmul requires column major for the second argument
+    x = torch.randn(size, dtype=torch.float32, device='cuda') * 0.2
+    w = torch.randn(size, dtype=torch.float32, device='cuda').t() * 0.2
+elif TESTNAME == 'custom':
+    x = torch.arange(size[0] * size[1], dtype=torch.float32, device='cuda') / 100000.0
+    w = x.clone() # Since you want the same values
+    x = x.reshape(size).contiguous()
+    w = w.reshape(size).contiguous().t()
+    print(f"x.shape: {x.shape}, w.shape: {w.shape}")
 else:
     print('Invalid test name')
     sys.exit(0)
@@ -34,7 +41,9 @@ def compare_f8_mm(dtype=torch.float8_e4m3fn) -> None:
     x_f8, x_inv_s = to_float8_e4m3fn(x)
     w_f8, w_inv_s = to_float8_e4m3fn(w)
 
-    breakpoint()
+    print(f'x_inv_s: {x_inv_s[:10]}')
+    print(f'w_inv_s: {w_inv_s[:10]}')
+
     y = torch._scaled_mm(
         x_f8, w_f8,
         out_dtype=torch.bfloat16,
@@ -47,7 +56,26 @@ def compare_f8_mm(dtype=torch.float8_e4m3fn) -> None:
     cos_sim = F.cosine_similarity(torch.mm(x, w).reshape(-1), y.reshape(-1), dim=0)
     print(f'cos_sim {cos_sim.item():.4f}')
 
-o = compare_f8_mm()
+    # average of x
+    avg_x = x.abs().mean()
+    print(f'avg_x {avg_x.item():.4f}')
+
+    # average of w
+    avg_w = w.abs().mean()
+    print(f'avg_w {avg_w.item():.4f}')
+
+    # average of y
+    avg_y = y.abs().mean()
+    print(f'avg_y {avg_y.item():.4f}')
+
+    # max diff
+    max_diff = (torch.mm(x, w) - y).abs().max()
+    print(f'max_diff {max_diff.item():.4f}')
+
+    # average diff
+    avg_diff = (torch.mm(x, w) - y).abs().mean()
+    print(f'avg_diff {avg_diff.item():.4f}')
+
 
 if __name__ == "__main__":
     compare_f8_mm()
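
The quantization helper to_float8_e4m3fn called by compare_f8_mm lies outside the hunks shown above. As a reference point, here is a minimal sketch of the row-wise e4m3 quantizer such a helper typically implements; the function name, the clamp tolerance, and the per-row reduction are assumptions for illustration, not code from this commit.

    import torch

    def to_float8_e4m3fn_sketch(x: torch.Tensor, dtype=torch.float8_e4m3fn):
        # Per-row absolute maximum, kept away from zero so the division below is safe.
        finfo = torch.finfo(dtype)
        amax = x.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
        scale = finfo.max / amax                 # multiply to map each row into e4m3 range
        x_f8 = (x * scale).clamp(min=finfo.min, max=finfo.max).to(dtype)
        inv_scale = scale.reciprocal()           # per-row dequantization factor, amax / 448
        return x_f8, inv_scale

The torch._scaled_mm call in the hunk above is truncated in this view; if it follows the usual pattern, the returned inverse scales are what get passed as its scale_a and scale_b arguments (again an assumption, since those argument lines are not shown).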
Binary file added kernels/torch_scaled/scaled_matmul
70 changes: 30 additions & 40 deletions kernels/torch_scaled/scaled_matmul.cu
@@ -1,5 +1,6 @@
 #include "kittens.cuh"
 #include "prototype.cuh"
+#include <iomanip>
 
 using namespace kittens;
 using namespace kittens::prototype;
@@ -198,8 +199,10 @@ int run_benchmark(size_t M, size_t N, size_t K) {
     std::normal_distribution dis(0.0f, 1.0f);
 
     // Initialize matrices with random values
-    for (int i = 0; i < M * K; ++i) h_A[i] = dis(gen) * 10.0f;
-    for (int i = 0; i < K * N; ++i) h_B[i] = dis(gen) * 10.0f;
+    // for (int i = 0; i < M * K; ++i) h_A[i] = i / 100000.0f; // dis(gen) * 0.2f;
+    // for (int i = 0; i < K * N; ++i) h_B[i] = i / 100000.0f; // dis(gen) * 0.2f;
+    for (int i = 0; i < M * K; ++i) h_A[i] = dis(gen) * 0.2f;
+    for (int i = 0; i < K * N; ++i) h_B[i] = dis(gen) * 0.2f;
 
     std::cout << "Initialized matrices" << std::endl;
@@ -224,14 +227,6 @@ int run_benchmark(size_t M, size_t N, size_t K) {
 
     std::cout << "Allocated device memory" << std::endl;
 
-    // Convert to __nv_fp8_e4m3 and copy to device
-    __nv_fp8_e4m3 *h_A_fp8 = new __nv_fp8_e4m3[M * K];
-    __nv_fp8_e4m3 *h_B_fp8 = new __nv_fp8_e4m3[K * N];
-    for (int i = 0; i < M * K; ++i) h_A_fp8[i] = __nv_fp8_e4m3(h_A[i]);
-    for (int i = 0; i < K * N; ++i) h_B_fp8[i] = __nv_fp8_e4m3(h_B[i]);
-    for (int i = 0; i < M * K; ++i) h_A[i] = float(h_A_fp8[i]);
-    for (int i = 0; i < K * N; ++i) h_B[i] = float(h_B_fp8[i]);
-
     // Perform CPU matrix multiplication for reference
     if(true) cpu_gemm(h_A, h_B, h_C_ref, M, N, K);
     std::cout << "Performed CPU matrix multiplication" << std::endl;
@@ -244,53 +239,44 @@ int run_benchmark(size_t M, size_t N, size_t K) {
     __nv_fp8_e4m3 *h_A_fp8_scaled = new __nv_fp8_e4m3[M * K];
     __nv_fp8_e4m3 *h_B_fp8_scaled = new __nv_fp8_e4m3[K * N];
 
-    // fill h_scale_a by following to_float8_e4m3fn
-    for(int i = 0; i < M; i++) {
+    // row-wise scaling
+    for(int row = 0; row < M; row++) {
         float max_val = 0.0f;
-        for(int j = 0; j < K; j++) {
-            float abs_val = std::abs(h_A[i * K + j]);
+        for(int col = 0; col < K; col++) {
+            float abs_val = std::abs(h_A[row * K + col]);
             max_val = std::max(max_val, abs_val);
         }
-        h_scale_a[i] = 1.0f; //c_dtype(max_val / FP8_E4M3_MAX);
-
-        if ( i == 0 ) {
-            std::cout << "h_scale_a[" << i << "] = " << float(h_scale_a[i]) << ", max_val: " << max_val << std::endl;
+        h_scale_a[row] = c_dtype(max_val / FP8_E4M3_MAX);
+        if ( row < 10 ) {
+            std::cout << "h_scale_a[" << row << "] = " << float(h_scale_a[row]) << ", max_val: " << max_val << std::endl;
         }
     }
 
     // fill h_A_fp8_scaled by following to_float8_e4m3fn.
     for(int i = 0; i < M; i++) {
         for(int j = 0; j < K; j++) {
             h_A_fp8_scaled[i * K + j] = __nv_fp8_e4m3(h_A[i * K + j] / float(h_scale_a[i]));
-
-            if ( i == 0 && j == 0 ) {
-                std::cout << "h_A_fp8_scaled[" << i << "] = " << float(h_A_fp8_scaled[i * K + j]) << std::endl;
-            }
         }
     }
 
-    // fill h_scale_b by following to_float8_e4m3fn
-    for(int i = 0; i < N; i++) {
+    // column-wise scaling
+    for(int col = 0; col < N; col++) {
         float max_val = 0.0f;
-        for(int j = 0; j < K; j++) {
-            float abs_val = std::abs(h_B[j * N + i]);
+        for(int row = 0; row < K; row++) {
+            float abs_val = std::abs(h_B[row * N + col]);
             max_val = std::max(max_val, abs_val);
         }
-        h_scale_b[i] = 1.0f; //c_dtype(max_val / FP8_E4M3_MAX);
-        if ( i == 0 ) {
-            std::cout << "h_scale_b[" << i << "] = " << float(h_scale_b[i]) << ", max_val: " << max_val << std::endl;
+        h_scale_b[col] = c_dtype(max_val / FP8_E4M3_MAX);
+
+        if ( col < 10 ) {
+            std::cout << "h_scale_b[" << col << "] = " << float(h_scale_b[col]) << ", max_val: " << max_val << std::endl;
        }
     }
 
     // fill h_B_fp8_scaled by following to_float8_e4m3fn
    for(int i = 0; i < N; i++) {
         for(int j = 0; j < K; j++) {
             h_B_fp8_scaled[j * N + i] = __nv_fp8_e4m3(h_B[j * N + i] / float(h_scale_b[i]));
-
-            if ( i == 0 && j == 0 ) {
-                std::cout << "h_B_fp8_scaled[" << i << "] = " << float(h_B_fp8_scaled[i * N + j]) << std::endl;
-            }
         }
     }

@@ -362,6 +348,7 @@ int run_benchmark(size_t M, size_t N, size_t K) {
 
     // Check result
     float max_error = 0.0f, total_error = 0.0f, total_ref = 0.0f, total_ours=0.0f;
+    float input_a = 0.0f, input_b = 0.0f;
     int error_count = 0;
     printf("Num rows: %d, Num cols: %d\n", M, N);
     for (int i = 0; i < M * N; ++i) {
@@ -372,24 +359,27 @@ int run_benchmark(size_t M, size_t N, size_t K) {
             error_count++;
         }
         max_error = std::max(max_error, error);
-        total_ref += h_C_ref[i]*h_C_ref[i];
-        total_error += error*error;
-        total_ours += h_C[i]*h_C[i];
+        total_ref += std::abs(h_C_ref[i]);
+        total_error += error;
+        total_ours += std::abs(h_C[i]);
+        input_a += std::abs(h_A[i]);
+        input_b += std::abs(h_B[i]);
     }
 
+    std::cout << std::fixed << std::setprecision(6);
     std::cout << "Max error: " << max_error << std::endl;
     std::cout << "Average error: " << total_error / M / N << std::endl;
-    std::cout << "Average ref: " << total_ref / M / N << std::endl;
+    std::cout << "Average ref: " << total_ref / (M * N) << std::endl;
     std::cout << "Average ours: " << total_ours / M / N << std::endl;
+    std::cout << "Average input_a: " << input_a / M / N << std::endl;
+    std::cout << "Average input_b: " << input_b / M / N << std::endl;
     std::cout << "Error count: " << error_count << std::endl;
 
     // Clean up
     delete[] h_A;
     delete[] h_B;
     delete[] h_C;
     delete[] h_C_ref;
-    delete[] h_A_fp8;
-    delete[] h_B_fp8;
     delete[] h_C_out;
     cudaFree(d_A);
     cudaFree(d_B);
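
For orientation, the row-wise and column-wise scaling that the host code above now computes (h_scale_a[row] and h_scale_b[col], each amax / FP8_E4M3_MAX) corresponds to the following PyTorch emulation. It is a sketch under the assumption that the kernel rescales each output element by scale_a[i] * scale_b[j]; the kernel body itself is outside this diff, and the helper name is illustrative.

    import torch

    FP8_E4M3_MAX = 448.0  # largest finite e4m3 magnitude

    def scaled_fp8_gemm_emulation(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
        # Per-row scale for A and per-column scale for B, mirroring the host-side loops.
        scale_a = (A.abs().amax(dim=1, keepdim=True) / FP8_E4M3_MAX).clamp(min=1e-12)  # (M, 1)
        scale_b = (B.abs().amax(dim=0, keepdim=True) / FP8_E4M3_MAX).clamp(min=1e-12)  # (1, N)
        # Quantize to e4m3, then dequantize right away to emulate the precision loss.
        A_q = (A / scale_a).to(torch.float8_e4m3fn).to(torch.float32)
        B_q = (B / scale_b).to(torch.float8_e4m3fn).to(torch.float32)
        # Undo the scaling so C[i, j] recovers the magnitude of A[i, :] @ B[:, j].
        return (A_q @ B_q) * scale_a * scale_b

Compared against cpu_gemm on the unquantized inputs, the statistics printed above (max error, average absolute error, average absolute reference and output magnitudes) then measure how far the scaled fp8 path drifts from the fp32 reference.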
