diff --git a/raf_native_functions.yaml b/raf_native_functions.yaml
index 778e5e7..90808f0 100644
--- a/raf_native_functions.yaml
+++ b/raf_native_functions.yaml
@@ -74,8 +74,6 @@ supported:
   - _index_put_impl_
   - inverse
   - isnan
-  - kl_div
-  - kl_div_backward
   - kthvalue
   - log
   - log10
diff --git a/ratex/csrc/aten_raf_type.cpp b/ratex/csrc/aten_raf_type.cpp
index e867051..e7c6677 100644
--- a/ratex/csrc/aten_raf_type.cpp
+++ b/ratex/csrc/aten_raf_type.cpp
@@ -1437,21 +1437,6 @@ at::Tensor LazyNativeFunctions::isnan(const at::Tensor& self) {
   return bridge::AtenFromLtcTensor(LazyTensor::isnan(bridge::raf_backend::GetLtcTensor(self)));
 }
 
-at::Tensor LazyNativeFunctions::kl_div(const at::Tensor& self, const at::Tensor& target,
-                                       int64_t reduction, bool log_target) {
-  LTC_FN_COUNTER("raf::");
-  return at::native::kl_div(self, target, reduction, log_target);
-}
-
-at::Tensor LazyNativeFunctions::kl_div_backward(const at::Tensor& grad_output,
-                                                const at::Tensor& self, const at::Tensor& target,
-                                                int64_t reduction, bool log_target) {
-  LTC_FN_COUNTER("raf::");
-  return bridge::AtenFromLtcTensor(LazyTensor::kl_div_backward(
-      bridge::raf_backend::GetLtcTensor(grad_output), bridge::raf_backend::GetLtcTensor(self),
-      bridge::raf_backend::GetLtcTensor(target), reduction, log_target));
-}
-
 std::tuple<at::Tensor, at::Tensor> LazyNativeFunctions::kthvalue(const at::Tensor& self, int64_t k,
                                                                  int64_t dim, bool keepdim) {
   LTC_FN_COUNTER("raf::");
@@ -1460,7 +1445,6 @@ std::tuple<at::Tensor, at::Tensor> LazyNativeFunctions::kthvalue(const at::Tenso
                                   bridge::AtenFromLtcTensor(std::get<1>(results)));
 }
 
-
 at::Tensor LazyNativeFunctions::le(const at::Tensor& self, const at::Scalar& other) {
   LTC_FN_COUNTER("raf::");
   return bridge::AtenFromLtcTensor(LazyTensor::le(bridge::raf_backend::GetLtcTensor(self), other));
@@ -1715,7 +1699,6 @@ at::Tensor LazyNativeFunctions::max_unpool2d(const at::Tensor& self, const at::T
       lazy_tensors::util::ToVector<int64_t>(output_size)));
 }
 
-
 at::Tensor LazyNativeFunctions::max_unpool3d(const at::Tensor& self, const at::Tensor& indices,
                                              at::IntArrayRef output_size, at::IntArrayRef stride,
                                              at::IntArrayRef padding) {
   LTC_FN_COUNTER("raf::");
   return bridge::AtenFromLtcTensor(LazyTensor::max_unpool(
       bridge::raf_backend::GetLtcTensor(self), bridge::raf_backend::GetLtcTensor(indices),
       lazy_tensors::util::ToVector<int64_t>(output_size)));
 }
@@ -1725,7 +1708,6 @@ at::Tensor LazyNativeFunctions::max_unpool3d(const at::Tensor& self, const at::T
       lazy_tensors::util::ToVector<int64_t>(output_size)));
 }
 
-
 at::Tensor LazyNativeFunctions::mean(const at::Tensor& self, c10::optional<at::ScalarType> dtype) {
   LTC_FN_COUNTER("raf::");
   LazyTensor self_tensor = bridge::raf_backend::GetLtcTensor(self);
@@ -1734,12 +1716,12 @@ at::Tensor LazyNativeFunctions::mean(const at::Tensor& self, c10::optional<at::S
-at::Tensor LazyNativeFunctions::mean(const at::Tensor& self, at::IntArrayRef dim, bool keepdim,
-                                     c10::optional<at::ScalarType> dtype) {
+at::Tensor LazyNativeFunctions::mean(const at::Tensor& self, at::OptionalIntArrayRef dim,
+                                     bool keepdim, c10::optional<at::ScalarType> dtype) {
   LTC_FN_COUNTER("raf::");
-  return bridge::AtenFromLtcTensor(LazyTensor::mean(bridge::raf_backend::GetLtcTensor(self),
-                                                    lazy_tensors::util::ToVector<int64_t>(dim),
-                                                    /*keep_reduced_dimensions=*/keepdim, dtype));
+  return bridge::AtenFromLtcTensor(LazyTensor::mean(
+      bridge::raf_backend::GetLtcTensor(self), lazy_tensors::util::ToVector<int64_t>(dim.value()),
+      /*keep_reduced_dimensions=*/keepdim, dtype));
 }
 
 at::Tensor LazyNativeFunctions::min(const at::Tensor& self) {
@@ -2547,12 +2529,12 @@ at::Tensor LazyNativeFunctions::sum(const at::Tensor& self, c10::optional<at::Sc
-at::Tensor LazyNativeFunctions::sum(const at::Tensor& self, at::IntArrayRef dim, bool keepdim,
-                                    c10::optional<at::ScalarType> dtype) {
+at::Tensor LazyNativeFunctions::sum(const at::Tensor& self, at::OptionalIntArrayRef dim,
+                                    bool keepdim, c10::optional<at::ScalarType> dtype) {
   LTC_FN_COUNTER("raf::");
-  return bridge::AtenFromLtcTensor(LazyTensor::sum(bridge::raf_backend::GetLtcTensor(self),
-                                                   lazy_tensors::util::ToVector<int64_t>(dim),
-                                                   keepdim, dtype));
+  return bridge::AtenFromLtcTensor(
+      LazyTensor::sum(bridge::raf_backend::GetLtcTensor(self),
+                      lazy_tensors::util::ToVector<int64_t>(dim.value()), keepdim, dtype));
 }
 
 std::tuple<at::Tensor, at::Tensor, at::Tensor> LazyNativeFunctions::svd(const at::Tensor& self,
@@ -2762,8 +2744,8 @@ at::Tensor LazyNativeFunctions::upsample_nearest2d(
 }
 
 at::Tensor LazyNativeFunctions::upsample_nearest2d_backward(
-    const at::Tensor& grad_output, at::OptionalIntArrayRef output_size,
-    at::IntArrayRef input_size, c10::optional<at::ArrayRef<double>> scale_factors) {
+    const at::Tensor& grad_output, at::OptionalIntArrayRef output_size, at::IntArrayRef input_size,
+    c10::optional<at::ArrayRef<double>> scale_factors) {
   LTC_FN_COUNTER("raf::");
   LazyTensor grad_output_tensor = bridge::raf_backend::GetLtcTensor(grad_output);
   if (grad_output_tensor.GetDevice().hw_type != DeviceType::TPU) {
diff --git a/ratex/jit/script.py b/ratex/jit/script.py
index 96e996d..26e53c3 100644
--- a/ratex/jit/script.py
+++ b/ratex/jit/script.py
@@ -260,7 +260,8 @@ def wrapper(*args, **kwargs):
         # TODO: use torch.jit.script
         assert len(args) == 1, f"Only support single input for now, but got {len(args)}"
         assert not kwargs, "Do not support kwargs yet"
-        shape_n_dtype = (list(args[0].shape), str(args[0].dtype).rsplit(".", maxsplit=1)[-1])
+        arg0_meta = args[0].to("meta")
+        shape_n_dtype = (list(arg0_meta.shape), str(arg0_meta.dtype).rsplit(".", maxsplit=1)[-1])
         cache_key = (hash_torch_module(module), str(shape_n_dtype))
         if cache_key in JIT_CACHE:
             # Cache hit.
diff --git a/ratex/lazy_tensors/computation_client/debug_macros.h b/ratex/lazy_tensors/computation_client/debug_macros.h
index 258b226..f43376a 100644
--- a/ratex/lazy_tensors/computation_client/debug_macros.h
+++ b/ratex/lazy_tensors/computation_client/debug_macros.h
@@ -15,13 +15,13 @@
 
 #define LTC_ERROR() LOG(ERROR)
 #define LTC_CHECK(c) CHECK(c)
-#define LTC_CHECK_OK(c) CHECK(c.ok())
-#define LTC_CHECK_EQ(a, b) CHECK_EQ(a, b)
-#define LTC_CHECK_NE(a, b) CHECK_NE(a, b)
-#define LTC_CHECK_LE(a, b) CHECK_LE(a, b)
-#define LTC_CHECK_GE(a, b) CHECK_GE(a, b)
-#define LTC_CHECK_LT(a, b) CHECK_LT(a, b)
-#define LTC_CHECK_GT(a, b) CHECK_GT(a, b)
+#define LTC_CHECK_OK(c) TORCH_CHECK(c.ok())
+#define LTC_CHECK_EQ(a, b) TORCH_CHECK_EQ(a, b)
+#define LTC_CHECK_NE(a, b) TORCH_CHECK_NE(a, b)
+#define LTC_CHECK_LE(a, b) TORCH_CHECK_LE(a, b)
+#define LTC_CHECK_GE(a, b) TORCH_CHECK_GE(a, b)
+#define LTC_CHECK_LT(a, b) TORCH_CHECK_LT(a, b)
+#define LTC_CHECK_GT(a, b) TORCH_CHECK_GT(a, b)
 
 template <typename T>
 T ConsumeValue(lazy_tensors::StatusOr<T>&& status) {
diff --git a/scripts/pinned_torch_nightly.txt b/scripts/pinned_torch_nightly.txt
index 8c2ae1e..7e1d1b6 100644
--- a/scripts/pinned_torch_nightly.txt
--- a/scripts/pinned_torch_nightly.txt
+++ b/scripts/pinned_torch_nightly.txt
@@ -1 +1 @@
-1.13.0.dev20220629
+1.13.0.dev20220801
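Note: a minimal sketch (separate from the patch) of the meta-device probe that the ratex/jit/script.py hunk switches to. Moving the example input to the "meta" device yields its shape and dtype without materializing data, which is all the cache key built from shape_n_dtype needs. The tensor below is a hypothetical stand-in for args[0].

import torch

x = torch.randn(4, 8)   # hypothetical stand-in for args[0]
x_meta = x.to("meta")   # meta tensor: carries shape and dtype only, no storage
shape_n_dtype = (list(x_meta.shape), str(x_meta.dtype).rsplit(".", maxsplit=1)[-1])
print(shape_n_dtype)    # ([4, 8], 'float32')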