From 9dcbe3b281c98ba9a3bd192495bef42570095336 Mon Sep 17 00:00:00 2001
From: Jianhua Zheng
Date: Mon, 28 Oct 2024 10:31:16 +0800
Subject: [PATCH] using tensor device in eval distributed (#557)

* using tensor device in eval distributed

* tensor_to_rank0 device
---
 libai/evaluation/evaluator.py | 4 ++--
 libai/utils/distributed.py    | 1 -
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/libai/evaluation/evaluator.py b/libai/evaluation/evaluator.py
index 1414cdaa0..39c699c7a 100644
--- a/libai/evaluation/evaluator.py
+++ b/libai/evaluation/evaluator.py
@@ -203,12 +203,12 @@ def inference_on_dataset(

         # get valid sample
         valid_data = {
-            key: dist.tensor_to_rank0(value, to_local=True)[:valid_sample]
+            key: dist.tensor_to_rank0(value, device=value.placement.type, to_local=True)[:valid_sample]
             for key, value in data.items()
         }
         valid_outputs = {}
         for key, value in outputs.items():
-            value = dist.tensor_to_rank0(value, to_local=True)
+            value = dist.tensor_to_rank0(value, device=value.placement.type, to_local=True)
             if value.ndim > 1:
                 valid_outputs[key] = value[:valid_sample]  # Slice if it's batched output
             else:

diff --git a/libai/utils/distributed.py b/libai/utils/distributed.py
index f64479210..a76372956 100644
--- a/libai/utils/distributed.py
+++ b/libai/utils/distributed.py
@@ -471,7 +471,6 @@ def tton(tensor, local_only=False, ranks=None):

 def tensor_to_rank0(tensor, device="cuda", to_local=False):
     """Global tensor to rank0."""
-    # assert device in ["cpu", "cuda"], f"not supported for device:{device}"
     if tensor.is_global:
         # Consider if it's 2d mesh, ranks should be [[0]] instead of [0]
         placement = flow.placement(device, ranks=[0] if tensor.placement.ranks.ndim == 1 else [[0]])
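
Note on the change: tensor_to_rank0 defaults to device="cuda", so the old
call sites in inference_on_dataset would build a CUDA placement for rank 0
even when the global tensors being gathered live on CPU. Passing
device=value.placement.type keeps the rank-0 placement on the tensor's own
device. The sketch below illustrates that behavior with a simplified
stand-in for the helper; its to_global/to_local body and the sample tensor
are assumptions for illustration, not the exact LiBai implementation.

    # Minimal sketch (not part of the patch) of how tensor_to_rank0 behaves.
    # The body here is an assumption modeled on the hunk above.
    import oneflow as flow

    def tensor_to_rank0(tensor, device="cuda", to_local=False):
        """Gather a global tensor onto rank 0 on the requested device."""
        if tensor.is_global:
            # Consider if it's 2d mesh, ranks should be [[0]] instead of [0],
            # and the sbp list must have one entry per mesh dimension.
            nd = tensor.placement.ranks.ndim
            ranks = [0] if nd == 1 else [[0]]
            placement = flow.placement(device, ranks=ranks)
            tensor = tensor.to_global(
                placement=placement, sbp=[flow.sbp.broadcast] * nd
            )
            if to_local:
                tensor = tensor.to_local()
        return tensor

    # A CPU global tensor: with the old call sites, the "cuda" default would
    # route it through a CUDA placement even on a CPU-only evaluation run.
    x = flow.ones(
        4, placement=flow.placement("cpu", ranks=[0]), sbp=flow.sbp.broadcast
    )

    # With the patch, the device follows the tensor's own placement type:
    y = tensor_to_rank0(x, device=x.placement.type, to_local=True)
    assert y.device == flow.device("cpu")

Under these assumptions, a CPU-only evaluation no longer needs a CUDA
device just to collect per-rank results on rank 0.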