diff --git a/src/ATen/native/xpu/TensorAdvancedIndexing.cpp b/src/ATen/native/xpu/TensorAdvancedIndexing.cpp index 988cac622..99d50acbf 100644 --- a/src/ATen/native/xpu/TensorAdvancedIndexing.cpp +++ b/src/ATen/native/xpu/TensorAdvancedIndexing.cpp @@ -340,6 +340,15 @@ Tensor& XPUNativeFunctions::index_fill_( "index_fill_(): Converting complex Scalar to non-complex type is not supported"); } + TORCH_CHECK( + self.device() == index.device(), + "index_fill_(): self and index value tensors ", + "should have same device type, but got self tensor device type ", + self.device(), + " and index value ", + "tensor device type ", + index.device()); + // Handle the case when `self` is 0-dim Tensor self_nonzero_dim = (self.dim() == 0) ? self.unsqueeze(-1) : self; dim = at::maybe_wrap_dim(dim, self_nonzero_dim);