diff --git a/torchrec/distributed/benchmark/benchmark_inference.py b/torchrec/distributed/benchmark/benchmark_inference.py index 54774f5ef..9f28394a7 100644 --- a/torchrec/distributed/benchmark/benchmark_inference.py +++ b/torchrec/distributed/benchmark/benchmark_inference.py @@ -439,11 +439,10 @@ def main() -> None: # Place all outputs under the datetime folder os.mkdir(output_dir) - # TODO: ROW_WISE and COLUMN_WISE are not supported yet BENCH_SHARDING_TYPES = [ ShardingType.TABLE_WISE, - # ShardingType.ROW_WISE, - # ShardingType.COLUMN_WISE, + ShardingType.ROW_WISE, + ShardingType.COLUMN_WISE, ] table_sizes = [