@@ -541,6 +541,110 @@ std::tuple<Tensor,optional<int64_t>> index_add_batch_rule(
   return std::make_tuple(at::stack(results), 0);
 }
 
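+// Batch rule for index_fill.int_Scalar. If `index` is not batched, we move the
+// batch dim of `self` to the front and call index_fill once; otherwise we fall
+// back to a per-sample loop and stack the results.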
+std::tuple<Tensor,optional<int64_t>> index_fill_int_scalar_batch_rule(
+    const Tensor& self, optional<int64_t> self_bdim,
+    int64_t dim,
+    const Tensor& index, optional<int64_t> index_bdim,
+    const Scalar& value) {
+  if (!index_bdim) {
+    // Handle scalar tensors: self can be a scalar tensor.
+    const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
+    auto self_ = moveBatchDimToFront(self, self_bdim);
+    if (self_logical_rank == 0) {
+      self_ = self_.unsqueeze(-1);
+    }
+    dim = maybe_wrap_dim(dim, self_logical_rank);
+
+    optional<int64_t> out_bdim = nullopt;
+    if (self_bdim) {
+      const auto batch_size = self.size(*self_bdim);
+      self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
+      dim = dim + 1;
+      out_bdim = 0;
+    }
+
+    auto result = self_.index_fill(dim, index, value);
+    if (self_logical_rank == 0) {
+      result = result.squeeze(-1);
+    }
+    return std::make_tuple(result, out_bdim);
+  }
+
+  // Same as for index_add: `index` is batched. A for-loop plus stack is the
+  // best we can do for now; a generalized index_fill kernel in PyTorch would
+  // let us avoid the loop.
+  auto batch_size = get_bdim_size2(self, self_bdim, index, index_bdim);
+  std::vector<Tensor> results;
+  results.reserve(batch_size);
+  for (const auto i : c10::irange(0, batch_size)) {
+    const auto& self_slice = self_bdim.has_value() ?
+      self.select(*self_bdim, i) : self;
+    const auto& index_slice = index_bdim.has_value() ?
+      index.select(*index_bdim, i) : index;
+    results.push_back(at::index_fill(self_slice, dim, index_slice, value));
+  }
+  return std::make_tuple(at::stack(results), 0);
+}
+
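+// Batch rule for index_fill.int_Tensor. Same structure as the Scalar overload,
+// but the fill value is a tensor that may itself be batched, so the fallback
+// loop also slices `value`.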
+std::tuple<Tensor,optional<int64_t>> index_fill_int_tensor_batch_rule(
+    const Tensor& self, optional<int64_t> self_bdim,
+    int64_t dim,
+    const Tensor& index, optional<int64_t> index_bdim,
+    const Tensor& value, optional<int64_t> value_bdim) {
+  if (!index_bdim && !value_bdim) {
+    // Handle scalar tensors: self can be a scalar tensor.
+    const auto self_logical_rank = rankWithoutBatchDim(self, self_bdim);
+    auto self_ = moveBatchDimToFront(self, self_bdim);
+    if (self_logical_rank == 0) {
+      self_ = self_.unsqueeze(-1);
+    }
+    dim = maybe_wrap_dim(dim, self_logical_rank);
+
+    optional<int64_t> out_bdim = nullopt;
+    if (self_bdim) {
+      const auto batch_size = self.size(*self_bdim);
+      self_ = ensure_has_bdim(self_, self_bdim.has_value(), batch_size);
+      dim = dim + 1;
+      out_bdim = 0;
+    }
+
+    auto result = self_.index_fill(dim, index, value);
+    if (self_logical_rank == 0) {
+      result = result.squeeze(-1);
+    }
+    return std::make_tuple(result, out_bdim);
+  }
+
+  // Same as for index_add: `index` (and possibly `value`) is batched. A
+  // for-loop plus stack is the best we can do for now; a generalized
+  // index_fill kernel in PyTorch would let us avoid the loop.
+  auto batch_size = get_bdim_size3(self, self_bdim, index, index_bdim, value, value_bdim);
+  std::vector<Tensor> results;
+  results.reserve(batch_size);
+  for (const auto i : c10::irange(0, batch_size)) {
+    const auto& self_slice = self_bdim.has_value() ?
+      self.select(*self_bdim, i) : self;
+    const auto& index_slice = index_bdim.has_value() ?
+      index.select(*index_bdim, i) : index;
+    const auto& value_slice = value_bdim.has_value() ?
+      value.select(*value_bdim, i) : value;
+    results.push_back(at::index_fill(self_slice, dim, index_slice, value_slice));
+  }
+  return std::make_tuple(at::stack(results), 0);
+}
+
 TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   m.impl("index.Tensor", index_plumbing);
   m.impl("index_put_", index_put__plumbing);
@@ -550,6 +654,8 @@ TORCH_LIBRARY_IMPL(aten, FT_BATCHED_KEY, m) {
   m.impl("index_copy", index_copy_decomp);
   m.impl("index_select", index_select_decomp);
   VMAP_SUPPORT("index_add", index_add_batch_rule);
+  VMAP_SUPPORT("index_fill.int_Scalar", index_fill_int_scalar_batch_rule);
+  VMAP_SUPPORT("index_fill.int_Tensor", index_fill_int_tensor_batch_rule);
   VMAP_SUPPORT("diagonal_scatter", diagonal_scatter_batch_rule);
   VMAP_SUPPORT("gather", gather_batch_rule);
   VMAP_SUPPORT("gather_backward", gather_backward_batch_rule);
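For reference, the batched-`index` fallback above is just a per-sample loop over the regular index_fill kernel. A minimal standalone sketch of that decomposition (not part of the patch; it assumes an ATen build and uses made-up shapes) looks like:

#include <ATen/ATen.h>
#include <vector>

int main() {
  const int64_t B = 3;                          // batch size
  auto self  = at::zeros({B, 5});               // batched self, bdim = 0
  auto index = at::randint(0, 5, {B, 2});       // batched index, bdim = 0
  std::vector<at::Tensor> results;
  results.reserve(B);
  for (int64_t i = 0; i < B; ++i) {
    // Slice out sample i and call the ordinary (unbatched) index_fill.
    results.push_back(at::index_fill(self.select(0, i), /*dim=*/0,
                                     index.select(0, i), /*value=*/1.0));
  }
  auto out = at::stack(results);                // shape [B, 5], bdim = 0
  return 0;
}

This mirrors what the loop in the batch rules does, plus the extra bookkeeping needed when only some of the arguments carry a batch dimension.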