From 5e71f5491169d5a2d156491fd5cc537520b5446a Mon Sep 17 00:00:00 2001
From: AIREMetaBot
Date: Mon, 1 Aug 2022 13:19:26 +0000
Subject: [PATCH 1/4] [Compatible] Update Pinned PyTorch Nightly 1.13.0.dev20220801

---
 scripts/pinned_torch_nightly.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/scripts/pinned_torch_nightly.txt b/scripts/pinned_torch_nightly.txt
index 8c2ae1e..7e1d1b6 100644
--- a/scripts/pinned_torch_nightly.txt
+++ b/scripts/pinned_torch_nightly.txt
@@ -1 +1 @@
-1.13.0.dev20220629
+1.13.0.dev20220801

From 73a973559d984beb9b4590d3d6a70c42a6288b24 Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Mon, 1 Aug 2022 20:21:46 +0000
Subject: [PATCH 2/4] fix

---
 raf_native_functions.yaml                   |  2 -
 ratex/csrc/aten_raf_type.cpp                | 42 ++++++-------------
 .../computation_client/debug_macros.h       | 14 +++----
 3 files changed, 19 insertions(+), 39 deletions(-)

diff --git a/raf_native_functions.yaml b/raf_native_functions.yaml
index 778e5e7..90808f0 100644
--- a/raf_native_functions.yaml
+++ b/raf_native_functions.yaml
@@ -74,8 +74,6 @@ supported:
   - _index_put_impl_
   - inverse
   - isnan
-  - kl_div
-  - kl_div_backward
   - kthvalue
   - log
   - log10
diff --git a/ratex/csrc/aten_raf_type.cpp b/ratex/csrc/aten_raf_type.cpp
index e867051..e7c6677 100644
--- a/ratex/csrc/aten_raf_type.cpp
+++ b/ratex/csrc/aten_raf_type.cpp
@@ -1437,21 +1437,6 @@ at::Tensor LazyNativeFunctions::isnan(const at::Tensor& self) {
   return bridge::AtenFromLtcTensor(LazyTensor::isnan(bridge::raf_backend::GetLtcTensor(self)));
 }
 
-at::Tensor LazyNativeFunctions::kl_div(const at::Tensor& self, const at::Tensor& target,
-                                       int64_t reduction, bool log_target) {
-  LTC_FN_COUNTER("raf::");
-  return at::native::kl_div(self, target, reduction, log_target);
-}
-
-at::Tensor LazyNativeFunctions::kl_div_backward(const at::Tensor& grad_output,
-                                                const at::Tensor& self, const at::Tensor& target,
-                                                int64_t reduction, bool log_target) {
-  LTC_FN_COUNTER("raf::");
-  return bridge::AtenFromLtcTensor(LazyTensor::kl_div_backward(
-      bridge::raf_backend::GetLtcTensor(grad_output), bridge::raf_backend::GetLtcTensor(self),
-      bridge::raf_backend::GetLtcTensor(target), reduction, log_target));
-}
-
 std::tuple<at::Tensor, at::Tensor> LazyNativeFunctions::kthvalue(const at::Tensor& self, int64_t k,
                                                                  int64_t dim, bool keepdim) {
   LTC_FN_COUNTER("raf::");
@@ -1460,7 +1445,6 @@ std::tuple<at::Tensor, at::Tensor> LazyNativeFunctions::kthvalue(const at::Tenso
                          bridge::AtenFromLtcTensor(std::get<1>(results)));
 }
 
-
 at::Tensor LazyNativeFunctions::le(const at::Tensor& self, const at::Scalar& other) {
   LTC_FN_COUNTER("raf::");
   return bridge::AtenFromLtcTensor(LazyTensor::le(bridge::raf_backend::GetLtcTensor(self), other));
@@ -1715,7 +1699,6 @@ at::Tensor LazyNativeFunctions::max_unpool2d(const at::Tensor& self, const at::T
       lazy_tensors::util::ToVector<int64_t>(output_size)));
 }
 
-
 at::Tensor LazyNativeFunctions::max_unpool3d(const at::Tensor& self, const at::Tensor& indices,
                                              at::IntArrayRef output_size, at::IntArrayRef stride,
                                              at::IntArrayRef padding) {
@@ -1725,7 +1708,6 @@ at::Tensor LazyNativeFunctions::max_unpool3d(const at::Tensor& self, const at::T
   LTC_FN_COUNTER("raf::");
   return bridge::AtenFromLtcTensor(LazyTensor::max_unpool(
      bridge::raf_backend::GetLtcTensor(self), bridge::raf_backend::GetLtcTensor(indices),
      lazy_tensors::util::ToVector<int64_t>(output_size)));
 }
 
-
 at::Tensor LazyNativeFunctions::mean(const at::Tensor& self, c10::optional<at::ScalarType> dtype) {
   LTC_FN_COUNTER("raf::");
   LazyTensor self_tensor = bridge::raf_backend::GetLtcTensor(self);
@@ -1734,12 +1716,12 @@ at::Tensor LazyNativeFunctions::mean(const at::Tensor& self, c10::optional<at::S
-at::Tensor LazyNativeFunctions::mean(const at::Tensor& self, at::IntArrayRef dim, bool keepdim,
-                                     c10::optional<at::ScalarType> dtype) {
+at::Tensor LazyNativeFunctions::mean(const at::Tensor& self, at::OptionalIntArrayRef dim,
+                                     bool keepdim, c10::optional<at::ScalarType> dtype) {
LTC_FN_COUNTER("raf::"); - return bridge::AtenFromLtcTensor(LazyTensor::mean(bridge::raf_backend::GetLtcTensor(self), - lazy_tensors::util::ToVector(dim), - /*keep_reduced_dimensions=*/keepdim, dtype)); + return bridge::AtenFromLtcTensor(LazyTensor::mean( + bridge::raf_backend::GetLtcTensor(self), lazy_tensors::util::ToVector(dim.value()), + /*keep_reduced_dimensions=*/keepdim, dtype)); } at::Tensor LazyNativeFunctions::min(const at::Tensor& self) { @@ -2547,12 +2529,12 @@ at::Tensor LazyNativeFunctions::sum(const at::Tensor& self, c10::optional dtype) { +at::Tensor LazyNativeFunctions::sum(const at::Tensor& self, at::OptionalIntArrayRef dim, + bool keepdim, c10::optional dtype) { LTC_FN_COUNTER("raf::"); - return bridge::AtenFromLtcTensor(LazyTensor::sum(bridge::raf_backend::GetLtcTensor(self), - lazy_tensors::util::ToVector(dim), - keepdim, dtype)); + return bridge::AtenFromLtcTensor( + LazyTensor::sum(bridge::raf_backend::GetLtcTensor(self), + lazy_tensors::util::ToVector(dim.value()), keepdim, dtype)); } std::tuple LazyNativeFunctions::svd(const at::Tensor& self, @@ -2762,8 +2744,8 @@ at::Tensor LazyNativeFunctions::upsample_nearest2d( } at::Tensor LazyNativeFunctions::upsample_nearest2d_backward( - const at::Tensor& grad_output, at::OptionalIntArrayRef output_size, - at::IntArrayRef input_size, c10::optional> scale_factors) { + const at::Tensor& grad_output, at::OptionalIntArrayRef output_size, at::IntArrayRef input_size, + c10::optional> scale_factors) { LTC_FN_COUNTER("raf::"); LazyTensor grad_output_tensor = bridge::raf_backend::GetLtcTensor(grad_output); if (grad_output_tensor.GetDevice().hw_type != DeviceType::TPU) { diff --git a/ratex/lazy_tensors/computation_client/debug_macros.h b/ratex/lazy_tensors/computation_client/debug_macros.h index 258b226..f43376a 100644 --- a/ratex/lazy_tensors/computation_client/debug_macros.h +++ b/ratex/lazy_tensors/computation_client/debug_macros.h @@ -15,13 +15,13 @@ #define LTC_ERROR() LOG(ERROR) #define LTC_CHECK(c) CHECK(c) -#define LTC_CHECK_OK(c) CHECK(c.ok()) -#define LTC_CHECK_EQ(a, b) CHECK_EQ(a, b) -#define LTC_CHECK_NE(a, b) CHECK_NE(a, b) -#define LTC_CHECK_LE(a, b) CHECK_LE(a, b) -#define LTC_CHECK_GE(a, b) CHECK_GE(a, b) -#define LTC_CHECK_LT(a, b) CHECK_LT(a, b) -#define LTC_CHECK_GT(a, b) CHECK_GT(a, b) +#define LTC_CHECK_OK(c) TORCH_CHECK(c.ok()) +#define LTC_CHECK_EQ(a, b) TORCH_CHECK_EQ(a, b) +#define LTC_CHECK_NE(a, b) TORCH_CHECK_NE(a, b) +#define LTC_CHECK_LE(a, b) TORCH_CHECK_LE(a, b) +#define LTC_CHECK_GE(a, b) TORCH_CHECK_GE(a, b) +#define LTC_CHECK_LT(a, b) TORCH_CHECK_LT(a, b) +#define LTC_CHECK_GT(a, b) TORCH_CHECK_GT(a, b) template T ConsumeValue(lazy_tensors::StatusOr&& status) { From e1c36e7113125235bcda7344be852085628057bf Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Mon, 1 Aug 2022 22:29:41 +0000 Subject: [PATCH 3/4] fizx --- ratex/jit/script.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/ratex/jit/script.py b/ratex/jit/script.py index 96e996d..609b2ad 100644 --- a/ratex/jit/script.py +++ b/ratex/jit/script.py @@ -187,7 +187,7 @@ def saver(value): @ltc_timed("RAFTraceConvertModuleToRAF") @persist_cache_fn -def convert_module_to_raf(module, shape_n_dtype, args): +def convert_module_to_raf(module, shape_n_dtype, arg_np): """Convert the PyTorch module to RAF and apply necessary transformations. Parameters ---------- @@ -195,8 +195,8 @@ def convert_module_to_raf(module, shape_n_dtype, args): The PyTorch module to be converted. 
     shape_n_dtype : List[Tuple[int, torch.dtype]]
         The shape and dtype of the input tensor.
-    args : List[torch.Tensor]
-        The input tensors.
+    arg_np : np.ndarray
+        The input tensors in numpy array. Note that we do not support multiple arguments for now.
 
     Returns
     -------
@@ -213,7 +213,7 @@ def convert_module_to_raf(module, shape_n_dtype, args):
 
     # Must use *.clone(), otherwise the tensor will be removed from live tensors graph
     # because asnumpy() calls *.cpu()
-    record = model._internal(raf.array(asnumpy(args[0].clone())))
+    record = model._internal(raf.array(arg_np))
     mod = record.mod
     mod = AutoDiff([])(InferType()(mod))
     mod = DeadCodeElimination()(mod)
@@ -260,7 +260,8 @@ def wrapper(*args, **kwargs):
         # TODO: use torch.jit.script
         assert len(args) == 1, f"Only support single input for now, but got {len(args)}"
         assert not kwargs, "Do not support kwargs yet"
-        shape_n_dtype = (list(args[0].shape), str(args[0].dtype).rsplit(".", maxsplit=1)[-1])
+        arg0_np = asnumpy(args[0].clone())
+        shape_n_dtype = (list(arg0_np.shape), str(arg0_np.dtype).rsplit(".", maxsplit=1)[-1])
         cache_key = (hash_torch_module(module), str(shape_n_dtype))
         if cache_key in JIT_CACHE:
             # Cache hit.
@@ -273,7 +274,7 @@ def wrapper(*args, **kwargs):
                 inplace_update_map,
                 raf_params_shape,
                 raf_params_dtype,
-            ) = convert_module_to_raf(module, shape_n_dtype, args)
+            ) = convert_module_to_raf(module, shape_n_dtype, arg0_np)
             # Convert missing args
             params_keys = [to_raf_name(k) for k in params.keys()]
             for name in param_names:

From f6fa249f84ab3a59b895c8f58ef432292086d6d2 Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Mon, 1 Aug 2022 22:34:47 +0000
Subject: [PATCH 4/4] use meta

---
 ratex/jit/script.py | 14 +++++++-------
 1 file changed, 7 insertions(+), 7 deletions(-)

diff --git a/ratex/jit/script.py b/ratex/jit/script.py
index 609b2ad..26e53c3 100644
--- a/ratex/jit/script.py
+++ b/ratex/jit/script.py
@@ -187,7 +187,7 @@ def saver(value):
 
 @ltc_timed("RAFTraceConvertModuleToRAF")
 @persist_cache_fn
-def convert_module_to_raf(module, shape_n_dtype, arg_np):
+def convert_module_to_raf(module, shape_n_dtype, args):
     """Convert the PyTorch module to RAF and apply necessary transformations.
 
     Parameters
     ----------
     module : torch.nn.Module
         The PyTorch module to be converted.
     shape_n_dtype : List[Tuple[int, torch.dtype]]
         The shape and dtype of the input tensor.
-    arg_np : np.ndarray
-        The input tensors in numpy array. Note that we do not support multiple arguments for now.
+    args : List[torch.Tensor]
+        The input tensors.
 
     Returns
     -------
@@ -213,7 +213,7 @@ def convert_module_to_raf(module, shape_n_dtype, arg_np):
 
     # Must use *.clone(), otherwise the tensor will be removed from live tensors graph
     # because asnumpy() calls *.cpu()
-    record = model._internal(raf.array(arg_np))
+    record = model._internal(raf.array(asnumpy(args[0].clone())))
     mod = record.mod
     mod = AutoDiff([])(InferType()(mod))
     mod = DeadCodeElimination()(mod)
@@ -260,8 +260,8 @@ def wrapper(*args, **kwargs):
         # TODO: use torch.jit.script
         assert len(args) == 1, f"Only support single input for now, but got {len(args)}"
         assert not kwargs, "Do not support kwargs yet"
-        arg0_np = asnumpy(args[0].clone())
-        shape_n_dtype = (list(arg0_np.shape), str(arg0_np.dtype).rsplit(".", maxsplit=1)[-1])
+        arg0_meta = args[0].to("meta")
+        shape_n_dtype = (list(arg0_meta.shape), str(arg0_meta.dtype).rsplit(".", maxsplit=1)[-1])
         cache_key = (hash_torch_module(module), str(shape_n_dtype))
         if cache_key in JIT_CACHE:
             # Cache hit.
@@ -274,7 +274,7 @@ def wrapper(*args, **kwargs):
                 inplace_update_map,
                 raf_params_shape,
                 raf_params_dtype,
-            ) = convert_module_to_raf(module, shape_n_dtype, arg0_np)
+            ) = convert_module_to_raf(module, shape_n_dtype, args)
             # Convert missing args
             params_keys = [to_raf_name(k) for k in params.keys()]
             for name in param_names:
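
Why [PATCH 4/4] switches to the meta device: `args[0].to("meta")` carries only shape and dtype
metadata and materializes no data, so the JIT cache key can be computed without the
device-to-host copy that `asnumpy(args[0].clone())` forces. A minimal standalone sketch of the
same idea follows; the helper name `make_cache_key` and the module-hash literal are
illustrative only and are not part of ratex:

    # Hypothetical sketch, not ratex code: derive a JIT cache key from a tensor
    # without copying its data to the host, mirroring the intent of [PATCH 4/4].
    import torch

    def make_cache_key(module_hash, tensor):
        # A meta tensor keeps shape/dtype only, so no device sync happens here.
        meta = tensor.to("meta")
        shape_n_dtype = (list(meta.shape), str(meta.dtype).rsplit(".", maxsplit=1)[-1])
        return (module_hash, str(shape_n_dtype))

    key = make_cache_key("deadbeef", torch.randn(4, 8))
    print(key)  # ('deadbeef', "([4, 8], 'float32')")

Two calls with tensors of the same shape and dtype produce the same key, which is what lets the
wrapper hit JIT_CACHE without tracing the module again.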