diff --git a/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py b/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
index f625a044d4..91cea3b158 100644
--- a/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
+++ b/fbgemm_gpu/experimental/gen_ai/bench/quantize_bench.py
@@ -31,29 +31,61 @@ def set_amd_env_vars() -> None:
     os.environ["PYTORCH_TUNABLEOP_MAX_WARMUP_DURATION_MS"] = "30"
 
 
-def get_llama_shapes() -> List[Tuple[int, int, int]]:
+def get_llama_shapes() -> List[Tuple[int, int, int, int]]:
     # Helper function that returns a list of shapes relevant to llama.
 
     llama_shapes = []
     for M in [1, 16, 32, 64, 96, 128, 16384]:
         # Add shapes for llama 70B
         llama_shapes += [
-            (M, 1280, 8192),
-            (M, 8192, 1024),
-            (M, 7168, 8192),
-            (M, 8192, 3584),
+            (1, M, 1280, 8192),
+            (1, M, 8192, 1024),
+            (1, M, 7168, 8192),
+            (1, M, 8192, 3584),
         ]
         # Add shapes for llama 405B
         llama_shapes += [
-            (M, 13312, 6656),
-            (M, 13312, 16384),
-            (M, 16384, 6656),
-            (M, 16384, 16384),
+            (1, M, 13312, 6656),
+            (1, M, 13312, 16384),
+            (1, M, 16384, 6656),
+            (1, M, 16384, 16384),
         ]
 
     return llama_shapes
 
 
+def get_ldm_shapes() -> List[Tuple[int, int, int, int]]:
+    # Helper function that returns a list of shapes relevant to ldm.
+    return [
+        (1, 1536, 3584, 3584),
+        (1, 8192, 9728, 3584),
+        (1, 8192, 3584, 9728),
+        (1, 8192, 3584, 3584),
+        (1, 4096, 3584, 3584),
+        (1, 768, 3584, 3584),
+        (1, 4096, 9728, 3584),
+        (1, 4096, 3584, 9728),
+        (1, 7200, 3584, 3584),
+        (1, 7200, 9728, 3584),
+        (1, 7200, 3584, 9728),
+        (1, 3600, 3584, 3584),
+        (1, 3600, 9728, 3584),
+        (1, 3600, 3584, 9728),
+        (1, 1536, 4096, 4096),
+        (1, 3600, 4096, 4096),
+        (1, 3600, 11008, 4096),
+        (1, 3600, 4096, 11008),
+        (1, 4096, 4096, 4096),
+        (1, 4096, 11008, 4096),
+        (1, 4096, 4096, 11008),
+        (1, 32768, 128, 8192),
+        (1, 32768, 8192, 1024),
+        (1, 32768, 8192, 3072),
+        (1, 32768, 3072, 8192),
+        (1, 32768, 1024, 8192),
+    ]
+
+
 def benchmark_grouped(
     quantize_ops: List[QuantizeOpBase],
     b: List[int],
@@ -297,6 +329,8 @@ def main(args: Any):
         B = [int(b) for b in args.B.strip().split(",")]
     if args.use_llama_shapes:
         MNK = get_llama_shapes()
+    elif args.use_ldm_shapes:
+        MNK = get_ldm_shapes()
     else:
         if args.M is None:
             M = [1, 4, 8, 16, 32, 64, 128, 2048, 4096, 8192, 16384]
@@ -419,6 +453,12 @@ def invoke_main() -> None:
         action="store_true",
         help="If set, benchmark using fixed shapes relevant to llama workloads.",
    )
+    parser.add_argument(
+        "--use_ldm_shapes",
+        default=False,
+        action="store_true",
+        help="If set, benchmark using fixed shapes relevant to ldm workloads.",
+    )
 
     args = parser.parse_args()
     main(args)
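
Note on the new shape format (a sketch, not part of the diff): every tuple now carries a leading fourth element, so consumers unpack four values instead of three. Reading that leading element as a group/batch count of 1 is an assumption based on the b: List[int] parameter of benchmark_grouped; the diff does not show the consuming code itself.

    # Minimal sketch, assuming the new tuple layout is (B, M, N, K) with B = 1.
    # describe_shapes is a hypothetical helper, not part of quantize_bench.py.
    from typing import List, Tuple

    def describe_shapes(shapes: List[Tuple[int, int, int, int]]) -> None:
        # Print each benchmark case; one GEMM problem per entry.
        for B, M, N, K in shapes:
            print(f"B={B} M={M} N={N} K={K}")

    describe_shapes(get_llama_shapes())  # or get_ldm_shapes()

With --use_ldm_shapes passed on the command line, main() assigns MNK = get_ldm_shapes(); otherwise --use_llama_shapes or the default M/N/K sweep applies, unchanged from before.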