Commit e232281

[skip ci] WIP on index_fill batch rule

1 parent 1af1ae2 commit e232281

3 files changed: +107 −5 lines

functorch/csrc/BatchRulesScatterOps.cpp (+106)
@@ -541,6 +541,110 @@ std::tuple<Tensor,optional<int64_t>> index_add_batch_rule(
   return std::make_tuple(at::stack(results), 0);
 }
 
+std::tuple<Tensor,optional<int64_t>> index_fill_int_scalar_batch_rule(
+    const Tensor& self, optional<int64_t> self_bdim,
+    int64_t dim,
+    const Tensor& index, optional<int64_t> index_bdim,
+    const Scalar& value) {
+
+  // std::cout << "index_fill_int_scalar_batch_rule:" << std::endl;
+  if (!index_bdim) {
+    // Handle scalar tensors... self, other can be scalar tensors
+    const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
+    auto self_ = moveBatchDimToFront(self, self_bdim);
+    if (self_logical_rank == 0) {
+      self_ = self_.unsqueeze(-1);
+    }
+    dim = maybe_wrap_dim(dim, self_logical_rank);
+
+    optional<int64_t> out_bdim = nullopt;
+    if (self_bdim) {
+      const auto batch_size = self.size(*self_bdim);
+      self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
+      dim = dim + 1;
+      out_bdim = 0;
+    }
+
+    // std::cout << "1 index_fill, self_: " << self_.sizes() << " index: " << index.sizes() << std::endl;
+    auto result = self_.index_fill(dim, index, value);
+    if (self_logical_rank == 0) {
+      result = result.squeeze(-1);
+    }
+    return std::make_tuple(result, out_bdim);
+  }
+
+  // SAME AS FOR index_add
+  // Index is batched. For-loop and stack is the best thing I can come up with
+  // right now. We really want a generalized index_fill kernel in PyTorch.
+  auto batch_size = get_bdim_size2(self, self_bdim, index, index_bdim);
+  std::vector<Tensor> results;
+  results.reserve(batch_size);
+  // std::cout << "2 index_fill loop: " << std::endl;
+  for (const auto i : c10::irange(0, batch_size)) {
+    const auto& self_slice = self_bdim.has_value() ?
+      self.select(*self_bdim, i) : self;
+    const auto& index_slice = index_bdim.has_value() ?
+      index.select(*index_bdim, i) : index;
+    // std::cout << i << " self_: " << self_slice.sizes() << " index: " << index_slice.sizes() << std::endl;
+    results.push_back(at::index_fill(self_slice, dim, index_slice, value));
+  }
+  return std::make_tuple(at::stack(results), 0);
+}
+
+std::tuple<Tensor,optional<int64_t>> index_fill_int_tensor_batch_rule(
+    const Tensor& self, optional<int64_t> self_bdim,
+    int64_t dim,
+    const Tensor& index, optional<int64_t> index_bdim,
+    const Tensor& value, optional<int64_t> value_bdim) {
+
+  // std::cout << "index_fill_int_tensor_batch_rule: "
+  //           << ((index_bdim) ? "true" : "false") << " "
+  //           << ((value_bdim) ? "true" : "false") << " "
+  //           << std::endl;
+  if (!index_bdim && !value_bdim) {
+    // Handle scalar tensors... self, other can be scalar tensors
+    const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
+    auto self_ = moveBatchDimToFront(self, self_bdim);
+    if (self_logical_rank == 0) {
+      self_ = self_.unsqueeze(-1);
+    }
+    dim = maybe_wrap_dim(dim, self_logical_rank);
+
+    optional<int64_t> out_bdim = nullopt;
+    if (self_bdim) {
+      const auto batch_size = self.size(*self_bdim);
+      self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
+      dim = dim + 1;
+      out_bdim = 0;
+    }
+    // std::cout << "1 index_fill, self_: " << self_.sizes() << " index: " << index.sizes() << std::endl;
+    auto result = self_.index_fill(dim, index, value);
+    if (self_logical_rank == 0) {
+      result = result.squeeze(-1);
+    }
+    return std::make_tuple(result, out_bdim);
+  }
+
+  // SAME AS FOR index_add
+  // Index is batched. For-loop and stack is the best thing I can come up with
+  // right now. We really want a generalized index_fill kernel in PyTorch.
+  auto batch_size = get_bdim_size3(self, self_bdim, index, index_bdim, value, value_bdim);
+  std::vector<Tensor> results;
+  results.reserve(batch_size);
+  // std::cout << "2 index_fill loop: " << std::endl;
+  for (const auto i : c10::irange(0, batch_size)) {
+    const auto& self_slice = self_bdim.has_value() ?
+      self.select(*self_bdim, i) : self;
+    const auto& index_slice = index_bdim.has_value() ?
+      index.select(*index_bdim, i) : index;
+    const auto& value_slice = value_bdim.has_value() ?
+      value.select(*value_bdim, i) : value;
+    // std::cout << i << " self_: " << self_slice.sizes() << " index: " << index_slice.sizes() << " value: " << value_slice.sizes() << std::endl;
+    results.push_back(at::index_fill(self_slice, dim, index_slice, value_slice));
+  }
+  return std::make_tuple(at::stack(results), 0);
+}
+
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   m.impl("index.Tensor", index_plumbing);
   m.impl("index_put_", index_put__plumbing);
@@ -550,6 +654,8 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   m.impl("index_copy", index_copy_decomp);
   m.impl("index_select", index_select_decomp);
   VMAP_SUPPORT("index_add", index_add_batch_rule);
+  VMAP_SUPPORT("index_fill.int_Scalar", index_fill_int_scalar_batch_rule);
+  VMAP_SUPPORT("index_fill.int_Tensor", index_fill_int_tensor_batch_rule);
   VMAP_SUPPORT("diagonal_scatter", diagonal_scatter_batch_rule);
   VMAP_SUPPORT("gather", gather_batch_rule);
   VMAP_SUPPORT("gather_backward", gather_backward_batch_rule);

test/test_ops.py (+1 −4)

@@ -513,7 +513,6 @@ def vjp_of_vjp(*args_and_cotangents):
         xfail('fmax'),
         xfail('fmin'),
         xfail('index_copy'),
-        xfail('index_fill'),
         xfail('linalg.det', ''),
         xfail('linalg.eigh'),
         xfail('linalg.householder_product'),
@@ -595,7 +594,6 @@ def test_vmapvjp(self, device, dtype, op):
         xfail('block_diag'),  # TODO: We expect this to fail in core, but it doesn't
         xfail('index_copy'),
         xfail('index_put'),
-        xfail('index_fill'),
         xfail('masked_fill'),
         xfail('masked_scatter'),
@@ -701,7 +699,6 @@ def test_vmapjvp(self, device, dtype, op):
         xfail('max', 'binary'),
         xfail('nn.functional.gaussian_nll_loss'),
         xfail('min', 'binary'),
-        xfail('index_fill'),
         xfail('index_put'),
         xfail('std_mean'),
         xfail('double', 'channels_last'),
@@ -760,7 +757,7 @@ def test_vmapjvpall(self, device, dtype, op):
         xfail('fmax'),
         xfail('fmin'),
         xfail('index_copy'),
-        xfail('index_fill'),
+        xfail('index_fill'),  # RuntimeError: aten::_unique hit the vmap fallback which is currently disabled
         xfail('linalg.cholesky'),
         xfail('linalg.cholesky_ex'),
         xfail('linalg.det'),

test/test_vmap.py (−1)

@@ -3181,7 +3181,6 @@ def test_vmap_exhaustive(self, device, dtype, op):
         xfail('gradient'),
         xfail('histogram'),
         xfail('hsplit'),
-        xfail('index_fill'),
         xfail('index_put'),
         xfail('isin'),
         xfail('linalg.cholesky'),
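
With these xfails removed, a quick manual check of the batched-index path could look like the sketch below. It is not part of the commit; it simply mirrors the for-loop-and-stack decomposition the new batch rule falls back to when the index carries a batch dimension (assumes functorch.vmap):

    import torch
    from functorch import vmap

    B, N = 3, 6
    x = torch.randn(B, N)
    idx = torch.randint(0, N, (B, 2))   # per-sample indices -> index is batched

    batched = vmap(lambda t, i: t.index_fill(0, i, -1.0))(x, idx)

    # Reference: per-sample index_fill, then stack -- the same decomposition
    # index_fill_int_scalar_batch_rule performs when index_bdim is set.
    expected = torch.stack([x[b].index_fill(0, idx[b], -1.0) for b in range(B)])
    assert torch.allclose(batched, expected)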
