diff --git a/tests/benchmarks/bench_transformer.py b/tests/benchmarks/bench_transformer.py index 0e2d5340e..4839c03e9 100644 --- a/tests/benchmarks/bench_transformer.py +++ b/tests/benchmarks/bench_transformer.py @@ -59,6 +59,11 @@ def bench_causal_lm(model_name, bs, genlen, dtype, backend, mode): inputs = tokenizer(input_string_batch, return_tensors='pt')['input_ids'].cuda() with torch.no_grad(), torch.autocast("cuda"): + # Temporary workaround for gpt-j + # gpt-j initializes tensors during the first forwasd pass + # which causes recompilation during the second forward pass + if model_name == 'EleutherAI/gpt-j-6B': + model(inputs) model = comp_backend.compile(model) latency = bench_gen_model(model, tokenizer, inputs, bs=bs, genlen=genlen) del model