
Commit

Change 8bit optimizer blocksize 2048->256; additional bf16 support (#1365)

* Change 8bit optimizer blocksize 2048->256; additional bf16 support
* Update tolerances for 8bit optimizer tests
matthewdouglas authored Sep 20, 2024
1 parent d964546 commit aa57bd8
Showing 7 changed files with 88 additions and 52 deletions.
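For context, a minimal usage sketch of what this commit enables (illustrative, not part of the diff): blockwise 8-bit optimizer state is now quantized in 256-element blocks, and the momentum/RMSprop/Adagrad blockwise kernels gain bf16 gradient/parameter support. The sketch assumes a CUDA build containing this commit and uses bnb.optim.RMSprop8bit as one of the newly covered optimizers.

import torch
import bitsandbytes as bnb

model = torch.nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)
# blockwise 8-bit state (block_wise is assumed to default to True), now in 256-element blocks
opt = bnb.optim.RMSprop8bit(model.parameters(), lr=1e-3)

x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
model(x).float().pow(2).mean().backward()
opt.step()       # assumed to dispatch to the new crmsprop_8bit_blockwise_grad_bf16 kernel
opt.zero_grad()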
6 changes: 5 additions & 1 deletion bitsandbytes/functional.py
@@ -52,6 +52,7 @@ def prod(iterable):
"lamb": (
lib.cadam32bit_grad_fp32,
lib.cadam32bit_grad_fp16,
lib.cadam32bit_grad_bf16,
),
"ademamix": (
lib.cademamix32bit_grad_fp32,
@@ -96,10 +97,12 @@ def prod(iterable):
"momentum": (
lib.cmomentum_8bit_blockwise_grad_fp32,
lib.cmomentum_8bit_blockwise_grad_fp16,
lib.cmomentum_8bit_blockwise_grad_bf16,
),
"rmsprop": (
lib.crmsprop_8bit_blockwise_grad_fp32,
lib.crmsprop_8bit_blockwise_grad_fp16,
lib.crmsprop_8bit_blockwise_grad_bf16,
),
"lion": (
lib.clion_8bit_blockwise_grad_fp32,
@@ -109,6 +112,7 @@ def prod(iterable):
"adagrad": (
lib.cadagrad_8bit_blockwise_grad_fp32,
lib.cadagrad_8bit_blockwise_grad_fp16,
lib.cadagrad_8bit_blockwise_grad_bf16,
),
"ademamix": (
lib.cademamix_8bit_blockwise_grad_fp32,
@@ -398,7 +402,7 @@ def create_dynamic_map(signed=True, max_exponent_bits=7, total_bits=8):
data.append(0)

data.sort()
return Tensor(data)
return torch.tensor(data)


def create_quantile_map(A, total_bits=8):
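The create_dynamic_map hunk above switches the return value from Tensor(data) to torch.tensor(data). A small, hedged sketch for inspecting the resulting quantization code (assumes the current create_dynamic_map signature):

import torch
import bitsandbytes.functional as F

code = F.create_dynamic_map(signed=True)       # now built with torch.tensor(...)
print(type(code), code.numel())                # expected: torch.Tensor with 256 entries
print(float(code.min()), float(code.max()))    # dynamic map values are normalized to [-1, 1]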
2 changes: 1 addition & 1 deletion bitsandbytes/optim/ademamix.py
@@ -166,7 +166,7 @@ def init_state(self, group, p, gindex, pindex):
self.name2qmap["udynamic"] = state["qmap2"] = self.name2qmap["udynamic"].to(p.device)

n = p.numel()
blocks = (n // 2048) + bool(n % 2048)
blocks = (n // 256) + bool(n % 256)

state["absmax1"] = torch.zeros((2, blocks), dtype=torch.float32, device=p.device)
state["absmax2"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
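The init_state change above allocates the absmax buffers per 256-element block instead of per 2048-element block. A hedged back-of-the-envelope sketch of the resulting overhead, reusing the same block-count formula:

def num_blocks(n: int, blocksize: int) -> int:
    return (n // blocksize) + bool(n % blocksize)

n = 4096 * 4096                                  # one 4096x4096 weight matrix
for bs in (2048, 256):
    blocks = num_blocks(n, bs)
    absmax_bytes = 3 * 4 * blocks                # absmax1 is (2, blocks) fp32, absmax2 is (blocks,) fp32
    print(bs, blocks, f"{absmax_bytes / n:.4f} bytes/param")
# Roughly 0.006 bytes/param at blocksize 2048 vs 0.047 bytes/param at 256, still small next to
# the 3 bytes/param of uint8 state (state1 is (2, n), state2 is (n,)).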
8 changes: 4 additions & 4 deletions bitsandbytes/optim/optimizer.py
@@ -477,8 +477,8 @@ def init_state(self, group, p, gindex, pindex):

if config["block_wise"]:
n = p.numel()
blocks = n // 2048
blocks += 1 if n % 2048 > 0 else 0
blocks = n // 256
blocks += 1 if n % 256 > 0 else 0

state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
state["absmax2"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
@@ -699,8 +699,8 @@ def init_state(self, group, p, gindex, pindex):

if config["block_wise"]:
n = p.numel()
blocks = n // 2048
blocks += 1 if n % 2048 > 0 else 0
blocks = n // 256
blocks += 1 if n % 256 > 0 else 0

state["absmax1"] = torch.zeros((blocks,), dtype=torch.float32, device=p.device)
else:
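The same 256-element blocking now applies to the generic 2-state and 1-state init_state paths above. A hedged sketch for inspecting the allocated state after one step (assumes a CUDA build and the state key names used in this file):

import torch
import bitsandbytes as bnb

p = torch.nn.Parameter(torch.randn(1000, 300, device="cuda"))   # 300,000 elements (>= min_8bit_size)
opt = bnb.optim.Adam8bit([p])
p.grad = torch.randn_like(p)
opt.step()                                                      # state is created lazily on the first step

state = opt.state[p]
blocks = (p.numel() // 256) + bool(p.numel() % 256)             # 1172 blocks
print(state["state1"].dtype, tuple(state["absmax1"].shape))     # expected: torch.uint8, (1172,)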
42 changes: 26 additions & 16 deletions csrc/kernels.cu
@@ -3829,27 +3829,33 @@ template __global__ void kPreconditionOptimizer32bit1State<gtype, oname, 4096, 8

MAKE_PreconditionOptimizer32bit1State(MOMENTUM, half)
MAKE_PreconditionOptimizer32bit1State(MOMENTUM, float)
MAKE_PreconditionOptimizer32bit1State(MOMENTUM, __nv_bfloat16)
MAKE_PreconditionOptimizer32bit1State(RMSPROP, half)
MAKE_PreconditionOptimizer32bit1State(RMSPROP, float)
MAKE_PreconditionOptimizer32bit1State(RMSPROP, __nv_bfloat16)
MAKE_PreconditionOptimizer32bit1State(LION, half)
MAKE_PreconditionOptimizer32bit1State(LION, float)
MAKE_PreconditionOptimizer32bit1State(LION, __nv_bfloat16)
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, half)
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, __nv_bfloat16)

#define MAKE_Optimizer32bit1State(oname, gtype) \
template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \

MAKE_Optimizer32bit1State(MOMENTUM, half)
MAKE_Optimizer32bit1State(MOMENTUM, float)
MAKE_Optimizer32bit1State(MOMENTUM, __nv_bfloat16)
MAKE_Optimizer32bit1State(RMSPROP, half)
MAKE_Optimizer32bit1State(RMSPROP, float)
MAKE_Optimizer32bit1State(RMSPROP, __nv_bfloat16)
MAKE_Optimizer32bit1State(LION, half)
MAKE_Optimizer32bit1State(LION, float)
MAKE_Optimizer32bit1State(LION, __nv_bfloat16)
MAKE_Optimizer32bit1State(ADAGRAD, half)
MAKE_Optimizer32bit1State(ADAGRAD, float)
MAKE_Optimizer32bit1State(ADAGRAD, __nv_bfloat16)

#define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \
template __global__ void kPreconditionOptimizer32bit2State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
@@ -3950,6 +3956,8 @@ MAKE_optimizerStatic8bit2State(ADAM, float)

template __global__ void kPercentileClipping<float, 2048, 4>(float * __restrict__ g, float *gnorm_vec, int step, const int n);
template __global__ void kPercentileClipping<half, 2048, 4>(half * __restrict__ g, float *gnorm_vec, int step, const int n);
// template __global__ void kPercentileClipping<float, 128, 4>(float * __restrict__ g, float *gnorm_vec, int step, const int n);
// template __global__ void kPercentileClipping<half, 128, 4>(half * __restrict__ g, float *gnorm_vec, int step, const int n);

#define MAKE_kQuantizeBlockwise(dtype, blocksize, num_per_thread, stochastic, data_type_name) \
template __global__ void kQuantizeBlockwise<dtype, blocksize, num_per_thread, stochastic, data_type_name>(float * code, dtype * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n); \
@@ -4041,13 +4049,12 @@ template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block
float weight_decay, \
const float gnorm_scale, const bool skip_zeros, const int n); \

MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 2048, 8)
MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 2048, 8)
MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, __nv_bfloat16, 2048, 8)
MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, float, 2048, 8)
MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, half, 2048, 8)
MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, __nv_bfloat16, 2048, 8)

MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, float, 256, 1)
MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, half, 256, 1)
MAKE_OptimizerStatic8bit2StateBlockwise(ADAM, __nv_bfloat16, 256, 1)
MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, float, 256, 1)
MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, half, 256, 1)
MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, __nv_bfloat16, 256, 1)

#define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \
template __global__ void kOptimizerStatic8bit1StateBlockwise<gtype, oname, block_size, num_per_thread>( \
@@ -4059,15 +4066,18 @@ template __global__ void kOptimizerStatic8bit1StateBlockwise<gtype, oname, block
float weight_decay, \
const float gnorm_scale, const bool skip_zeros, const int n); \

MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(LION, __nv_bfloat16, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 2048, 8)
MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, float, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, half, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(MOMENTUM, __nv_bfloat16, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, float, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, half, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(RMSPROP, __nv_bfloat16, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(LION, float, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(LION, half, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(LION, __nv_bfloat16, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, float, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, half, 256, 1)
MAKE_OptimizerStatic8bit1StateBlockwise(ADAGRAD, __nv_bfloat16, 256, 1)

template __device__ void printnonzero<float>(float *A, int num_values, const char*strval);
template __device__ void printnonzero<half>(half *A, int num_values, const char*strval);
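The blockwise template instantiations above move from (block_size=2048, num_per_thread=8) to (256, 1). A hedged arithmetic sketch of what that means for launch geometry (the launcher in ops.cu is assumed to use block_size // num_per_thread threads per CUDA block):

n = 4096 * 4096
for block_size, num_per_thread in ((2048, 8), (256, 1)):
    threads = block_size // num_per_thread     # 256 threads per CUDA block in both configurations
    grid = -(-n // block_size)                 # ceil-div: CUDA blocks needed to cover the tensor
    print(block_size, num_per_thread, threads, grid)
# The thread count per CUDA block is unchanged; what changes is that each quantization block
# now spans 256 values (8x more blocks, each thread handling one element instead of eight).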
14 changes: 10 additions & 4 deletions csrc/ops.cu
@@ -191,10 +191,10 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g,
}
}

#define BLOCKSIZE_2STATE 2048
#define NUM_2STATE 8
#define BLOCKSIZE_1STATE 2048
#define NUM_1STATE 8
#define BLOCKSIZE_2STATE 256
#define NUM_2STATE 1
#define BLOCKSIZE_1STATE 256
#define NUM_1STATE 1

template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(
T* p,
@@ -818,13 +818,16 @@ MAKE_optimizer32bit(ADAM, float)
MAKE_optimizer32bit(ADAM, __nv_bfloat16)
MAKE_optimizer32bit(MOMENTUM, half)
MAKE_optimizer32bit(MOMENTUM, float)
MAKE_optimizer32bit(MOMENTUM, __nv_bfloat16)
MAKE_optimizer32bit(RMSPROP, half)
MAKE_optimizer32bit(RMSPROP, float)
MAKE_optimizer32bit(RMSPROP, __nv_bfloat16)
MAKE_optimizer32bit(LION, half)
MAKE_optimizer32bit(LION, float)
MAKE_optimizer32bit(LION, __nv_bfloat16)
MAKE_optimizer32bit(ADAGRAD, half)
MAKE_optimizer32bit(ADAGRAD, float)
MAKE_optimizer32bit(ADAGRAD, __nv_bfloat16)
MAKE_optimizer32bit(ADEMAMIX, half)
MAKE_optimizer32bit(ADEMAMIX, __nv_bfloat16)
MAKE_optimizer32bit(ADEMAMIX, float)
@@ -861,13 +864,16 @@ MAKE_optimizerStatic8bitBlockwise(float, ADAM);
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAM);
MAKE_optimizerStatic8bitBlockwise(half, MOMENTUM);
MAKE_optimizerStatic8bitBlockwise(float, MOMENTUM);
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, MOMENTUM);
MAKE_optimizerStatic8bitBlockwise(half, RMSPROP);
MAKE_optimizerStatic8bitBlockwise(float, RMSPROP);
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, RMSPROP);
MAKE_optimizerStatic8bitBlockwise(half, LION);
MAKE_optimizerStatic8bitBlockwise(float, LION);
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, LION);
MAKE_optimizerStatic8bitBlockwise(half, ADAGRAD);
MAKE_optimizerStatic8bitBlockwise(float, ADAGRAD);
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADAGRAD);
MAKE_optimizerStatic8bitBlockwise(half, ADEMAMIX);
MAKE_optimizerStatic8bitBlockwise(__nv_bfloat16, ADEMAMIX);
MAKE_optimizerStatic8bitBlockwise(float, ADEMAMIX);
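BLOCKSIZE_1STATE and BLOCKSIZE_2STATE drop from 2048 to 256 above. A hedged toy simulation of why finer blocks help: blockwise absmax quantization with a linear 8-bit code (the real kernels use the nonlinear dynamic map, so absolute numbers differ, but the trend is the same):

import torch

torch.manual_seed(0)
x = torch.randn(1 << 20) * torch.rand(1 << 20).pow(4)            # heavy-tailed state-like values
for bs in (2048, 256):
    blocks = x.view(-1, bs)
    absmax = blocks.abs().amax(dim=1, keepdim=True).clamp_min(1e-12)
    q = torch.round(blocks / absmax * 127).clamp(-127, 127)      # toy symmetric 8-bit code
    err = (q / 127 * absmax - blocks).abs().mean()
    print(bs, f"mean abs reconstruction error = {err:.2e}")
# Smaller blocks localize outliers: a single large value inflates the scale of 256 neighbours
# instead of 2048, so small state entries keep more resolution.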
14 changes: 10 additions & 4 deletions csrc/pythonInterface.cpp
@@ -103,19 +103,22 @@ void fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \
{ optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\

MAKE_BLOCKWISE8(adam, ADAM, half, fp16)
MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
MAKE_BLOCKWISE8(adam, ADAM, float, fp32)
MAKE_BLOCKWISE8(momentum, MOMENTUM, half, fp16)
MAKE_BLOCKWISE8(momentum, MOMENTUM, __nv_bfloat16, bf16)
MAKE_BLOCKWISE8(momentum, MOMENTUM, float, fp32)
MAKE_BLOCKWISE8(rmsprop, RMSPROP, half, fp16)
MAKE_BLOCKWISE8(rmsprop, RMSPROP, __nv_bfloat16, bf16)
MAKE_BLOCKWISE8(rmsprop, RMSPROP, float, fp32)
MAKE_BLOCKWISE8(adagrad, ADAGRAD, half, fp16)
MAKE_BLOCKWISE8(adagrad, ADAGRAD, __nv_bfloat16, bf16)
MAKE_BLOCKWISE8(adagrad, ADAGRAD, float, fp32)
MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
MAKE_BLOCKWISE8(lion, LION, half, fp16)
MAKE_BLOCKWISE8(lion, LION, float, fp32)
MAKE_BLOCKWISE8(lion, LION, __nv_bfloat16, bf16)
MAKE_BLOCKWISE8(ademamix, ADEMAMIX, __nv_bfloat16, bf16)
MAKE_BLOCKWISE8(lion, LION, float, fp32)
MAKE_BLOCKWISE8(ademamix, ADEMAMIX, half, fp16)
MAKE_BLOCKWISE8(ademamix, ADEMAMIX, __nv_bfloat16, bf16)
MAKE_BLOCKWISE8(ademamix, ADEMAMIX, float, fp32)


@@ -283,13 +286,16 @@ extern "C"

MAKE_CBLOCKWISE8(adam, ADAM, half, fp16)
MAKE_CBLOCKWISE8(adam, ADAM, float, fp32)
MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(momentum, MOMENTUM, half, fp16)
MAKE_CBLOCKWISE8(momentum, MOMENTUM, float, fp32)
MAKE_CBLOCKWISE8(momentum, MOMENTUM, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, half, fp16)
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, float, fp32)
MAKE_CBLOCKWISE8(rmsprop, RMSPROP, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, half, fp16)
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, float, fp32)
MAKE_CBLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(adagrad, ADAGRAD, __nv_bfloat16, bf16)
MAKE_CBLOCKWISE8(lion, LION, half, fp16)
MAKE_CBLOCKWISE8(lion, LION, float, fp32)
MAKE_CBLOCKWISE8(lion, LION, __nv_bfloat16, bf16)
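The MAKE_BLOCKWISE8 / MAKE_CBLOCKWISE8 macros above export one C entry point per (optimizer, dtype) pair. On the Python side, functional.py selects an entry point by gradient dtype; a simplified, hedged sketch of that dispatch (the real logic lives in bitsandbytes.functional and may differ in detail):

import torch

def pick_kernel(fns, grad_dtype):
    # fns mirrors the (fp32, fp16, bf16) tuples in str2optimizer8bit_blockwise, e.g.
    # (lib.cmomentum_8bit_blockwise_grad_fp32,
    #  lib.cmomentum_8bit_blockwise_grad_fp16,
    #  lib.cmomentum_8bit_blockwise_grad_bf16)
    index = {torch.float32: 0, torch.float16: 1, torch.bfloat16: 2}[grad_dtype]
    return fns[index]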
54 changes: 32 additions & 22 deletions tests/test_optim.py
@@ -74,10 +74,18 @@ def rm_path(path):
lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=k, t_beta3=k),
lambda pxx: bnb.optim.AdEMAMix(pxx, t_alpha=k, t_beta3=k),
)
str2optimizers["paged_ademamix_scheduled"] = (
lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=k, t_beta3=k),
lambda pxx: bnb.optim.PagedAdEMAMix(pxx, t_alpha=k, t_beta3=k),
)
str2optimizers["ademamix8bit_blockwise_scheduled"] = (
lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=100, t_beta3=100),
lambda pxx: bnb.optim.AdEMAMix8bit(pxx, t_alpha=100, t_beta3=100),
)
str2optimizers["paged_ademamix8bit_blockwise_scheduled"] = (
lambda pxx: bnb.optim.ademamix._ReferenceAdEMAMix(pxx, t_alpha=100, t_beta3=100),
lambda pxx: bnb.optim.PagedAdEMAMix8bit(pxx, t_alpha=100, t_beta3=100),
)

str2optimizers["lion"] = (Lion, bnb.optim.Lion)
str2optimizers["lion8bit"] = (Lion, lambda pxx: bnb.optim.Lion8bit(pxx, block_wise=False))
@@ -143,7 +151,7 @@ def rm_path(path):
str2statenames["paged_lion8bit_blockwise"] = [("exp_avg", "state1", "qmap1", "absmax1")]

str2statenames["ademamix"] = str2statenames["ademamix_scheduled"] = [("m1_m2", "state1"), ("nu", "state2")]
str2statenames["paged_ademamix"] = [("m1_m2", "state1"), ("nu", "state2")]
str2statenames["paged_ademamix"] = str2statenames["paged_ademamix_scheduled"] = [("m1_m2", "state1"), ("nu", "state2")]
str2statenames["ademamix8bit_blockwise"] = str2statenames["ademamix8bit_blockwise_scheduled"] = [
("m1_m2", "state1", "qmap1", "absmax1"),
("nu", "state2", "qmap2", "absmax2"),
@@ -164,6 +172,7 @@ def rm_path(path):
"ademamix",
"ademamix_scheduled",
"paged_ademamix",
"paged_ademamix_scheduled",
]


@@ -309,18 +318,15 @@ def test_global_config(dim1, dim2, gtype):
def test_optimizer8bit(dim1, dim2, gtype, optim_name):
torch.set_printoptions(precision=6)

if gtype == torch.bfloat16 and optim_name not in [
"adam8bit_blockwise",
"lion8bit_blockwise",
"ademamix8bit_blockwise",
]:
if gtype == torch.bfloat16 and "blockwise" not in optim_name:
pytest.skip()

if dim1 == 1 and dim2 == 1:
return
p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
p2 = p1.clone()
p1 = p1.float()
blocksize = 2048
blocksize = 256

torch_optimizer = str2optimizers[optim_name][0]([p1])
bnb_optimizer = str2optimizers[optim_name][1]([p2])
@@ -347,8 +353,7 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
torch_optimizer.step()

# since Lion can have pretty noisy updates where things lie at the boundary
# and AdEMAMix can diverge as well, allow up to 0.05% errors.
assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=int(p1.numel() * 5e-4))
assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=0)

dequant_states = []
for name1, name2, qmap, max_val in str2statenames[optim_name]:
@@ -392,11 +397,11 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):
err = torch.abs(p1 - p2)
relerr = err / (torch.abs(p1) + 1e-9)
if g.dtype == torch.bfloat16:
assert err.mean() < 0.00015
assert relerr.mean() < 0.0020 # 0.0016
assert err.mean() <= 0.00017
assert relerr.mean() <= 0.0016
else:
assert err.mean() < 0.00016 # 0.00012
assert relerr.mean() < 0.0016 # 0.0012
assert err.mean() < 0.00006
assert relerr.mean() < 0.0006

errors.append(err.mean().item())
relerrors.append(relerr.mean().item())
@@ -454,9 +459,9 @@ def test_optimizer8bit(dim1, dim2, gtype, optim_name):

num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0
assert num_not_close.sum().item() < 20
# since Lion can have pretty noisy updates where things lie at the boundary
# and AdEMAMix can also be noisy, allow up to 0.05%.
assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=int(p1.numel() * 5e-04))

# Lion can have pretty noisy updates where things lie at the boundary
assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=0)

# the parameters diverge quickly. Here we keep them close
# together so we can test against the Adam error
@@ -560,15 +565,19 @@ def test_adam_percentile_clipping(dim1, dim2, gtype, optim_bits):
optimizer_names_benchmark = [
"adam8bit_blockwise",
"paged_adam8bit_blockwise",
"paged_adamw8bit_blockwise",
"ademamix8bit_blockwise",
"paged_ademamix8bit_blockwise",
"ademamix8bit_blockwise_scheduled",
"paged_ademamix8bit_blockwise_scheduled",
"lion8bit_blockwise",
"paged_lion8bit_blockwise",
"paged_ademamix8bit_blockwise",
]


@pytest.mark.parametrize("dim1", [4096], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [4096], ids=id_formatter("dim2"))
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("gtype", [torch.float32, torch.bfloat16, torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("optim_name", optimizer_names_benchmark, ids=id_formatter("opt"))
@pytest.mark.benchmark
def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
@@ -580,8 +589,9 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):

g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
p1.grad = g
for i in range(k):
if i == k // 5:
total_steps = 500
for i in range(total_steps):
if i == total_steps // 5:
# 100 iterations for burn-in
torch.cuda.synchronize()
t0 = time.time()
Expand All @@ -591,8 +601,8 @@ def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
torch.cuda.synchronize()
s = time.time() - t0
print("")
params = (k - k // 5) * dim1 * dim2
print(optim_name, gtype, s / params)
params = (total_steps - total_steps // 5) * dim1 * dim2
print(optim_name, gtype, s, params, s / params)
# assert s < 3.9


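The benchmark above now runs 500 steps with the first 20% as burn-in and prints the elapsed time, the number of parameter updates, and their ratio. A hedged sketch of turning that printout into a throughput figure:

total_steps, dim1, dim2 = 500, 4096, 4096
timed_steps = total_steps - total_steps // 5     # 400 timed optimizer steps
params = timed_steps * dim1 * dim2               # parameter updates performed while timed
s = 0.9                                          # illustrative wall-clock seconds from the printout
print(f"{s / params:.3e} s per update, {params / s / 1e9:.2f} G updates/s")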
