Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add aten::index_reduce operator #1156

Merged
merged 27 commits into from
Jan 9, 2025
Merged
Changes from 1 commit
Commits
Show all changes
27 commits
Select commit. Hold shift + click to select a range.
92738af
First reduced version of index_reduce.
cfgfung Nov 25, 2024
cbd7b03
Implemented reduce_prod.
cfgfung Nov 26, 2024
55daf53
Removed unnecessary if cases.
cfgfung Dec 2, 2024
eb1ad53
Added two reduce operators - amin and amax.
cfgfung Dec 3, 2024
f77e325
Add reduce_mean op.
cfgfung Dec 4, 2024
9066b5d
Skip 3 test cases. These are due to precision errors and the differen…
cfgfung Dec 19, 2024
fe08ee4
Merge branch 'main' into reduce_index_v2
xytintel Jan 3, 2025
bae2544
Merge branch 'main' into reduce_index_v2
xytintel Jan 6, 2025
c439c15
Fix pointer bug & refine code
xytintel Jan 6, 2025
4ddc3af
Update XPUFallback.template
xytintel Jan 6, 2025
41abd1b
Update xpu_test_utils.py
xytintel Jan 6, 2025
68eb7ae
Update TensorAdvancedIndexing.cpp
xytintel Jan 6, 2025
7435995
Update native_functions.yaml
xytintel Jan 6, 2025
e7c5c16
Update native_functions.yaml
xytintel Jan 6, 2025
a4ffaee
Update skip_list_common.py
xytintel Jan 6, 2025
a6f3fc6
Update Indexing.cpp
xytintel Jan 6, 2025
3e0724a
Update TensorInfo.h
xytintel Jan 7, 2025
1e31b0b
Add syclMaxNumSubGroups
xytintel Jan 8, 2025
01549d1
Update Indexing.cpp
xytintel Jan 8, 2025
60381af
Merge branch 'main' into reduce_index_v2
xytintel Jan 8, 2025
743851f
Update Indexing.cpp
xytintel Jan 8, 2025
6cbb613
Update ScatterGatherKernels.cpp
xytintel Jan 8, 2025
a6c36d7
Update skip_list_common.py
xytintel Jan 9, 2025
be54630
Update skip_list_common.py
xytintel Jan 9, 2025
f186ec8
Merge branch 'main' into reduce_index_v2
xytintel Jan 9, 2025
63acc5c
Merge branch 'main' into reduce_index_v2
xytintel Jan 9, 2025
dfab49c
Update apply_torch_pr.py
xytintel Jan 9, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Implemented reduce_prod.
cfgfung committed Nov 26, 2024

Unverified

This commit is not signed, but one or more authors require that any commit attributed to them is signed.
commit cbd7b03d35a898b08a4d51aa3f4416543b349dd8
72 changes: 50 additions & 22 deletions src/ATen/native/xpu/TensorAdvancedIndexing.cpp
Original file line number Diff line number Diff line change
@@ -150,20 +150,18 @@ TORCH_IMPL_FUNC(index_reduce_xpu_out)
c10::impl::check_and_update_common_device(
common_device, source, "xpu::index_reduce_out", "source");
dim = maybe_wrap_dim(dim, self.dim());
int reduce_type = 0;
reduce == "prod"? reduce_type = 1 : reduce_type = 0;
reduce == "mean"? reduce_type = 2 : reduce_type = 0;
reduce == "amax"? reduce_type = 3 : reduce_type = 0;
reduce == "amin"? reduce_type = 4 : reduce_type = 0;
switch(reduce_type){
case 0: //invalid
TORCH_CHECK(false, "reduce argument must be one of the following choices: prod, mean, amax or amin. The choice was ", reduce, ".");
break;
case 1: //prod
//index_reduce_kernel(self, dim, index, source, include_self, ReductionType::PROD, reduce_multiply, result);
xpu::index_reduce_kernel(self, dim, index, source, include_self, ReductionType::PROD, result);
break;
case 2: //mean
int reduce_type = 0; //hard code to test reduce index

// if (reduce == "prod") {reduce_type = 1;}
// if (reduce == "mean") {reduce_type = 2;}
// if (reduce == "amax") {reduce_type = 3;}
// if (reduce == "amin") {reduce_type = 4;}


if (reduce == "prod") {
xpu::index_reduce_kernel(self, dim, index, source, include_self, ReductionType::PROD, result);
}
else if (reduce == "mean") {
xpu::index_reduce_kernel(self, dim, index, source, include_self, ReductionType::MEAN, result);
auto counts = include_self ? ones_like(result) : zeros_like(result);
counts.index_add_(dim, index, ones_like(source));
@@ -173,15 +171,45 @@ TORCH_IMPL_FUNC(index_reduce_xpu_out)
}
else {
result.div_(counts, "floor");
}
break;
// case 3: //amax
// xpu::index_reduce_kernel(self, dim, index, source, include_self, ReductionType::MAX, result);
// break;
// case 4: //amin
// xpu::index_reduce_kernel(self, dim, index, source, include_self, ReductionType::MIN, result);
// break;
}
}
else if (reduce == "amax") {
xpu::index_reduce_kernel(self, dim, index, source, include_self, ReductionType::MAX, result);
}
else if (reduce == "amin") {
xpu::index_reduce_kernel(self, dim, index, source, include_self, ReductionType::MIN, result);
} else {
TORCH_CHECK(false, "Only support prod, mean, amax or amin reduce operator. Input was ", reduce, ".");
}


// switch(reduce_type){
// case 0: //invalid
// TORCH_CHECK(false, "reduce argument must be one of the following choices: prod, mean, amax or amin. The choice was ", reduce, ".");
// break;
// case 1: //prod
// //index_reduce_kernel(self, dim, index, source, include_self, ReductionType::PROD, reduce_multiply, result);
// xpu::index_reduce_kernel(self, dim, index, source, include_self, ReductionType::PROD, result);
// break;
// case 2: //mean
// xpu::index_reduce_kernel(self, dim, index, source, include_self, ReductionType::MEAN, result);
// auto counts = include_self ? ones_like(result) : zeros_like(result);
// counts.index_add_(dim, index, ones_like(source));
// counts.masked_fill_(counts == 0, 1);
// if (result.is_floating_point() || result.is_complex()) {
// result.div_(counts);
// }
// else {
// result.div_(counts, "floor");
// }
// break;
// case 3: //amax
// xpu::index_reduce_kernel(self, dim, index, source, include_self, ReductionType::MAX, result);
// break;
// case 4: //amin
// xpu::index_reduce_kernel(self, dim, index, source, include_self, ReductionType::MIN, result);
// break;
// }
}

} // namespace native