/**
* Copyright (c) Glow Contributors. See CONTRIBUTORS file.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "glow/Optimizer/IROptimizer/IROptimizer.h"
#include "glow/Optimizer/IROptimizer/IRFunctionPassManager.h"
#include "glow/Optimizer/IROptimizer/IRFunctionPasses.h"
#include "glow/Backend/Backend.h"
#include "glow/Graph/Graph.h"
#include "glow/IR/IR.h"
#include "glow/IR/IRBuilder.h"
#include "glow/IR/IRUtils.h"
#include "glow/IR/Instrs.h"
#include "glow/Support/Debug.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include <unordered_map>
#include <unordered_set>
#define DEBUG_TYPE "ir-optimizer"
using namespace glow;
using llvm::cast;
using llvm::dyn_cast;
using llvm::isa;
namespace {
/// Live interval of a memory buffer.
/// It represents a sequence of instructions [begin, end) where this buffer
/// holds a value.
struct Interval {
/// Index of the interval begin. Typically this is the index of the
/// instruction, which overwrites the buffer.
size_t begin_;
/// Index of the interval end. Typically, it is the last use of the current
/// value held in the buffer.
size_t end_;
/// True if the value does not change between begin and end, e.g.
/// due to @inout use. In most cases, the value does not change for the
/// duration of a single live interval.
bool sameValue_{true};
Interval(size_t begin, size_t end, bool sameValue = true)
: begin_(begin), end_(end), sameValue_(sameValue) {}
bool operator==(const Interval &other) const {
return begin_ == other.begin_ && end_ == other.end_ &&
sameValue_ == other.sameValue_;
}
std::string str() const {
std::string s;
llvm::raw_string_ostream sb{s};
sb << "[" << begin_ << ", " << end_ << ", " << sameValue_ << ")";
return sb.str();
}
};
/// A helper class for numbering instructions, used by the live intervals
/// analysis. It follows the approach of LLVM's linear scan register
/// allocator and assigns different numbers to the read and write slots of
/// the same instruction, which allows for an easy construction of a very
/// precise set of live intervals.
class LiveIntervalsInstructionNumbering {
using NumberedInstructionMap = std::vector<Instruction *>;
using InstructionNumbersMap = std::unordered_map<const Instruction *, size_t>;
/// Maps the number to an instruction.
NumberedInstructionMap numToInstr_;
/// Maps an instruction to its number.
InstructionNumbersMap instrToNum_;
public:
/// Virtual slot numbers to be used for instruction numbering. They help to
/// distinguish reads from writes and make comparison of live intervals
/// easier. LLVM uses a similar approach for its linear scan register
/// allocator.
///
/// For an instruction with number N, its @in operands would be considered
/// to be at (N+READ_SLOT), its @out operands would be at (N+WRITE_SLOT).
enum SLOTS {
READ_SLOT = 0,
WRITE_SLOT = 2,
MAX_SLOT = 4,
};
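// An illustrative worked example (derived from the definitions above, not
// part of the numbering logic itself): with MAX_SLOT == 4, instructions
// get base numbers 0, 4, 8, ... For the third instruction (base 8), @in
// operands are numbered 8 (8 + READ_SLOT) and @out operands 10
// (8 + WRITE_SLOT), so reads always precede writes of the same
// instruction in interval space.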
LiveIntervalsInstructionNumbering(IRFunction &M) {
auto &instrs = M.getInstrs();
size_t instIdx = 0;
numToInstr_.reserve(instrs.size());
for (auto &I : instrs) {
numToInstr_.push_back(&I);
instrToNum_[&I] = instIdx;
instIdx += MAX_SLOT;
}
}
/// \returns the base number of the instruction.
/// It is the same for all slots of a given instruction.
static int64_t getInstrBaseNumber(int64_t idx) {
return idx / MAX_SLOT * MAX_SLOT;
}
/// \returns true if \p idx is the instruction number of the read slot of the
/// instruction.
static bool isReadSlotNumber(int64_t idx) {
return idx % MAX_SLOT == READ_SLOT;
}
/// \returns true if \p idx is the instruction number of a write slot of the
/// instruction.
static bool isWriteSlotNumber(int64_t idx) {
return idx % MAX_SLOT == WRITE_SLOT;
}
/// \returns the instruction number of the read slot of the instruction with
/// number \p idx.
static int64_t getInstrReadSlotNumber(int64_t idx) {
return getInstrBaseNumber(idx) + READ_SLOT;
}
/// \returns the instruction number of the write slot of the instruction with
/// number \p idx.
static int64_t getInstrWriteSlotNumber(int64_t idx) {
return getInstrBaseNumber(idx) + WRITE_SLOT;
}
/// \returns the number of the instruction, or -1 if the instruction is not
/// numbered.
int64_t getInstrNumber(const Instruction *I) const {
auto result = instrToNum_.find(I);
if (result == instrToNum_.end()) {
return -1;
}
return (int64_t)result->second;
}
/// \returns the instruction with a given number.
Instruction *getInstr(size_t instrNumber) const {
assert(instrNumber / MAX_SLOT < numToInstr_.size());
return numToInstr_[instrNumber / MAX_SLOT];
}
};
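// A minimal usage sketch of the numbering helper (illustrative only; M and
// I stand for an arbitrary IRFunction and one of its instructions):
//   LiveIntervalsInstructionNumbering numbering(M);
//   int64_t num = numbering.getInstrNumber(I); // -1 if I was added later
//   if (num >= 0) {
//     Instruction *same = numbering.getInstr((size_t)num); // same == I
//   }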
} // namespace
llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Interval &I) {
os << I.str();
return os;
}
/// Set of intervals for a single memory buffer. If there is only one write
/// into a memory buffer, the set contains a single interval. If there are
/// multiple writes, it contains multiple live intervals, one per write.
using Intervals = llvm::SmallVector<Interval, 4>;
/// Mapping from a memory buffer to its live intervals.
using LiveIntervalsMap = std::unordered_map<const Value *, Intervals>;
/// Set of instructions.
using InstructionPtrSet = std::unordered_set<Instruction *>;
/// Hoists Dealloc instructions right after their last use.
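/// For example (illustrative pseudo-IR, not exact Glow textual IR), a
/// dealloc separated from the last use of its buffer:
///   %a = allocactivation
///   op @in %a            ; last use of %a
///   otherop              ; unrelated to %a
///   deallocactivation %a
/// is rewritten so that the dealloc immediately follows the last use:
///   %a = allocactivation
///   op @in %a
///   deallocactivation %a
///   otherop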
static bool hoistDealloc(IRFunction &M) {
bool changed = false;
// Maps activation instructions to their last non-dealloc user.
std::unordered_map<Value *, Instruction *> lastUser;
// Dealloc instructions in the current function.
llvm::SetVector<Instruction *> deallocs;
auto &instrs = M.getInstrs();
// Collect the dealloc instructions and record the last non-dealloc user of
// each allocation.
for (auto &I : instrs) {
if (isa<DeallocActivationInst>(&I)) {
// Collect dealloc instructions.
deallocs.insert(&I);
changed = true;
continue;
}
if (auto alloc = dyn_cast<AllocActivationInst>(&I)) {
lastUser[alloc] = &I;
continue;
}
for (int i = 0, e = I.getNumOperands(); i < e; i++) {
auto op = I.getOperand(i).first;
// Consider any use of a tensor_view to also be a use of its source
// tensor. This is required to make sure that the lifetime of a
// tensor_view is always enclosed inside the lifetime of its source
// tensor.
if (auto *alloc = getAllocationOrigin(op)) {
lastUser[alloc] = &I;
continue;
}
}
}
// Now that we've found the last user of each allocated buffer, we can hoist
// the dealloc instructions.
for (auto it = deallocs.begin(), e = deallocs.end(); it != e;
/* increment below */) {
auto *curr = *it;
++it;
auto *da = dyn_cast<DeallocActivationInst>(&*curr);
if (!da) {
continue;
}
auto *alloc = cast<AllocActivationInst>(getOrigin(da->getSrc()));
auto *where = lastUser[alloc];
if (std::next(where->getIterator()) == curr->getIterator()) {
// No need to move the instruction, because the last use was
// right before the deallocation.
continue;
}
// Get the instruction after where or at the end.
if (std::next(where->getIterator()) != instrs.end()) {
where = &*std::next(where->getIterator());
M.moveInstruction(where, curr);
} else {
// Append at the end.
M.removeInstruction(curr);
M.pushInstr(curr);
}
changed = true;
}
return changed;
}
/// Sink Alloc instructions right before their first use.
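/// For example (illustrative pseudo-IR, not exact Glow textual IR):
///   %a = allocactivation
///   otherop              ; unrelated to %a
///   op @out %a           ; first use of %a
/// is rewritten so that the alloc immediately precedes the first use:
///   otherop
///   %a = allocactivation
///   op @out %a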
static bool sinkAllocas(IRFunction &M) {
bool changed = false;
/// A set of allocas to reschedule.
InstructionPtrSet allocs;
auto &instrs = M.getInstrs();
// Remove all of the allocas.
for (auto it = instrs.begin(), e = instrs.end(); it != e;) {
auto *I = &*it;
++it;
auto *aa = dyn_cast<AllocActivationInst>(I);
if (!aa) {
continue;
}
allocs.insert(aa);
M.removeInstruction(I);
changed = true;
}
// Place all of the allocas in the right place:
for (auto &I : instrs) {
for (int i = 0, e = I.getNumOperands(); i < e; i++) {
auto op = I.getOperand(i).first;
auto aa = dyn_cast<AllocActivationInst>(getOrigin(op));
if (!aa) {
continue;
}
auto A = allocs.find(aa);
if (A == allocs.end()) {
continue;
}
allocs.erase(A);
M.insertInstruction(&I, aa);
changed = true;
if (allocs.empty()) {
return changed;
}
}
}
assert(allocs.empty() && "Forgot to insert some allocas!");
return changed;
}
/// Sink tensorview instructions right before their first use.
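/// The logic mirrors sinkAllocas, with one twist handled in the loop below:
/// a sunk tensorview may itself be an operand of another tensorview, so the
/// first instruction inserted at each step is re-scanned. For example
/// (illustrative pseudo-IR):
///   %v = tensorview %a
///   otherop              ; unrelated to %v
///   op @in %v            ; first use of %v
/// is rewritten to:
///   otherop
///   %v = tensorview %a
///   op @in %v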
static bool sinkTensorViews(IRFunction &M) {
bool changed = false;
// A set of tensorviews to reschedule.
std::unordered_set<TensorViewInst *> tensorviews;
auto &instrs = M.getInstrs();
// Remove all of the tensorviews.
for (auto it = instrs.begin(), e = instrs.end(); it != e;) {
auto *I = &*it;
++it;
auto *tv = dyn_cast<TensorViewInst>(I);
if (!tv) {
continue;
}
// Ignore tensorviews that are unused.
if (!tv->hasUsers()) {
continue;
}
tensorviews.insert(tv);
M.removeInstruction(I);
changed = true;
}
// Place all of the tensorviews in the right place:
for (auto it = instrs.begin(), e = instrs.end(); it != e;) {
// Holds the next value for the iterator.
auto nextIt = instrs.end();
auto *I = &*it;
for (int i = 0, f = I->getNumOperands(); i < f; i++) {
auto op = I->getOperand(i).first;
auto tv = dyn_cast<TensorViewInst>(op);
if (!tv) {
continue;
}
auto TV = tensorviews.find(tv);
if (TV == tensorviews.end()) {
continue;
}
auto inserted = M.insertInstruction(I, tv);
changed = true;
tensorviews.erase(TV);
if (tensorviews.empty()) {
return changed;
}
if (nextIt == instrs.end()) {
// Remember and re-scan the first inserted instruction as it may use
// another tensor_view.
nextIt = inserted;
}
}
// If no insertions were made, move to the next instruction.
if (nextIt == instrs.end()) {
nextIt = ++it;
}
it = nextIt;
}
assert(tensorviews.empty() && "Forgot to insert some tensorviews!");
return changed;
}
/// Delete alloc instructions that have no readers or writers.
static bool deleteDeadAllocs(IRFunction &M) {
bool changed = false;
auto &instrs = M.getInstrs();
// Remove all unused tensor views, walking backwards so that views that
// depend on other views are erased first (instructions are in topological
// order).
// Note that this must run before the alloc removal below, so that dead
// tensor views no longer count as users of their allocs.
for (auto it = instrs.rbegin(); it != instrs.rend();) {
// Remember the current instruction and advance the iterator.
auto *I = &*it++;
if (isa<TensorViewInst>(I) && I->getNumUsers() == 0) {
// Remove a tensor view. It may make other tensor views preceding it
// eligible for a removal as well.
M.eraseInstruction(I);
changed = true;
}
}
// Remove all unused allocs and their corresponding deallocs.
// Iterate over the instructions in reverse order to erase deallocs before
// their respective allocs; otherwise `DeallocActivationInst::getAlloc()`
// would return already-erased allocs.
for (auto it = instrs.rbegin(); it != instrs.rend();) {
auto *I = &*it++;
const auto *DA = dyn_cast<const DeallocActivationInst>(I);
if (DA && DA->getAlloc()->getNumUsers() < 2) {
M.eraseInstruction(I);
changed = true;
continue;
}
if (isa<AllocActivationInst>(I) && !I->hasUsers()) {
M.eraseInstruction(I);
changed = true;
}
}
return changed;
}
// Replace all users of some value with another value, but don't touch the
// dealloc instruction, because we need to preserve the well-formedness of
// the IR.
static void replaceAllNonDeallocUsersWith(Value *val, Value *with) {
assert(val != with && "Replacing value with self");
auto &users = val->getUsers();
// We use a vector here because changing the operands of the user changes the
// uselist, and this invalidates the iterator.
llvm::SmallVector<Use, 6> usersVec(users.begin(), users.end());
for (auto &U : usersVec) {
auto *I = U.get();
// Ignore the instruction itself (e.g. when creating a view and then
// replacing all uses of the original with the view).
if (I == with) {
continue;
}
// Ignore dealloc instrs.
if (isa<DeallocActivationInst>(I)) {
continue;
}
assert(U.getOperand().first->getType() == with->getType() &&
"Operand type does not match replacement type.");
U.setOperand(with);
}
}
/// \returns true if Value \p V has more than one writer, ignoring any
/// instructions in \p ignoredInstructions.
static bool hasMultipleWriters(const Value *V,
InstructionPtrSet ignoredInstructions) {
bool foundWriter = false;
for (const auto &U : ValueUses(V)) {
Instruction *user = U.get();
// Ignore deallocs.
if (isa<DeallocActivationInst>(user)) {
continue;
}
// Ignore readers.
if (U.getOperand().second == OperandKind::In) {
continue;
}
// Ignore others provided.
if (ignoredInstructions.find(user) != ignoredInstructions.end()) {
continue;
}
// Already found another writer.
if (foundWriter) {
return true;
}
foundWriter = true;
}
return false;
}
/// \returns the pointer to the single writer that writes into this value \p V,
/// or nullptr if the number of writers is not exactly one.
static Instruction *getSingleWriter(const Value *V) {
Instruction *singleUser = nullptr;
for (const auto &U : ValueUses(V)) {
Instruction *user = U.get();
// Ignore deallocs.
if (isa<DeallocActivationInst>(user)) {
continue;
}
auto op = U.getOperand();
// Ignore the readers.
if (op.second == OperandKind::In) {
continue;
}
// Multiple users.
if (singleUser) {
return nullptr;
}
singleUser = user;
}
return singleUser;
}
/// Marks non-mutable weights as constants.
bool makeWeightsConst(IRFunction &M) {
bool changed = false;
// For each weight:
for (auto *W : M.getWeights()) {
if (!W->isConstant()) {
continue;
}
bool readOnly = true;
// For each instruction that uses the weight:
for (const auto &U : ValueUses(W)) {
auto kind = U.getOperand().second;
// Check if all of the users are read-only.
if (kind != OperandKind::In) {
readOnly = false;
break;
}
}
// Mark the constant as read only.
if (readOnly) {
W->setMutability(WeightVar::MutabilityKind::Constant);
changed = true;
} else {
assert(W->getMutability() != WeightVar::MutabilityKind::Constant &&
"Const cannot be written into.");
}
}
return changed;
}
#ifndef NDEBUG
/// Dump a live intervals map.
static void LLVM_ATTRIBUTE_UNUSED dump(IRFunction &M,
LiveIntervalsMap &intervalsMap) {
llvm::outs() << "\nDumping live intervals map:\n";
for (const auto &I : intervalsMap) {
llvm::outs() << "\nValue " << I.first->getName();
llvm::outs() << "\n";
for (const auto &Interval : I.second) {
llvm::outs() << Interval << " ";
}
llvm::outs() << "\n";
}
}
#endif
/// Compute live intervals for each mutable location represented by
/// Value which is either an AllocActivationInst or a WeightVar.
/// Each such value is mapped to a list of intervals where it is alive.
/// Each interval starts at the point of definition and ends at the last use
/// of the current value, which is assigned at the beginning of the current
/// interval. If there are multiple writes to the same mutable memory
/// location, each such assignment results in a new interval.
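/// An illustrative example (slot numbers assume MAX_SLOT == 4, so
/// instructions get base numbers 0, 4, 8, ...): for a buffer %b that is
/// written by instruction 0, read by instruction 1, and fully overwritten
/// by instruction 2, this computes two intervals:
///   inst 0: op @out %b  -> open interval [2, 3) at the write slot (0 + 2)
///   inst 1: op @in  %b  -> extend it to [2, 5) using the read slot (4 + 0)
///   inst 2: op @out %b  -> start a new interval [10, 11) at slot (8 + 2)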
static void calculateLiveIntervals(const IRFunction &M,
LiveIntervalsMap &liveness) {
assert(liveness.empty() &&
"This function should be called with empty liveness map");
auto const &instrs = M.getInstrs();
unsigned instIdx = 0;
// Compute the [start..end) intervals for each alloc activation in our basic
// block. Notice that we ignore Dealloc instructions in our analysis.
for (auto it = instrs.begin(), e = instrs.end(); it != e;
++it, instIdx += LiveIntervalsInstructionNumbering::MAX_SLOT) {
auto *I = &*it;
// Ignore deallocations in our liveness calculation.
if (isa<DeallocActivationInst>(I)) {
continue;
}
// Ignore tensorview instructions, because they are just aliases
// and do not represent a read or write, even though formally they
// are reads due to the @in src parameter.
if (isa<TensorViewInst>(I)) {
continue;
}
auto instOperands = I->getOperands();
llvm::SmallVector<Instruction::Operand, 8> sortedOperands(
instOperands.begin(), instOperands.end());
// Sort operands so that:
// - all operands referencing the same Value are grouped together.
// - operands related to the same Value are always in the following
// order: In, InOut, Out.
//
// This ordering ensures that we process reads before writes.
std::sort(sortedOperands.begin(), sortedOperands.end());
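// For illustration: operands {(%b, Out), (%a, Out), (%a, In)} sort to
// {(%a, In), (%a, Out), (%b, Out)}, assuming %a orders before %b. The
// relative order of distinct values is unspecified; only the grouping and
// the In-before-Out order within each group matter here.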
for (int i = 0, f = sortedOperands.size(); i < f; i++) {
auto op = sortedOperands[i].first;
auto opKind = sortedOperands[i].second;
// Look through tensorviews. As a result, all operations on tensorviews
// are accounted for as operations on their origins.
auto opOrigin = getOrigin(op);
Value *loc = dyn_cast<AllocActivationInst>(opOrigin);
if (!loc) {
loc = dyn_cast<WeightVar>(opOrigin);
}
// Bail if the operand is not an AllocActivationInst or a WeightVar.
if (!loc) {
continue;
}
// Determine if this is a write to a subview of the tensor, i.e. a write
// to a tensorview with any non-zero offsets. We treat such partial writes
// in the same way as InOut: they're not the end of an interval, but they
// also obviously modify (part of) the value.
const bool isPartialWrite =
(opKind == OperandKind::Out) &&
(op->getType()->size() < loc->getType()->size());
unsigned opIdx;
if (opKind == OperandKind::Out && !isPartialWrite) {
opIdx =
LiveIntervalsInstructionNumbering::getInstrWriteSlotNumber(instIdx);
} else {
opIdx =
LiveIntervalsInstructionNumbering::getInstrReadSlotNumber(instIdx);
}
auto found = liveness.find(loc);
if (found == liveness.end()) {
// Create a new interval.
liveness[loc].push_back(Interval(opIdx, opIdx + 1));
// If it is the first use, it should be either an input variable or
// a write.
// FIXME: Remove InOut!
assert((isa<TensorViewInst>(I) || isa<WeightVar>(opOrigin) ||
opKind == OperandKind::Out || opKind == OperandKind::InOut) &&
"First reference inside a live interval should be either an "
"input variable or a write");
continue;
}
auto &intervals = found->second;
// Extend the last interval, but only if the current use is not a
// complete (@out) write; a complete write starts a new interval below.
if (opKind != OperandKind::Out) {
intervals.back().end_ = opIdx + 1;
}
// How should @inout operands be handled?
// They cannot be treated as the end of one interval and the beginning of
// a new one, because this would imply that the new interval completely
// overwrites the buffer, which is not true in general.
// Nor can @inout operands be considered simple reads, because that would
// mean the value does not change for the duration of the interval, which
// is not the case. To handle this, @inout operands are considered part of
// the existing interval, but the sameValue_ flag is set to false to
// indicate that the value is not guaranteed to be the same inside the
// interval. Note: partial writes have similar properties and so are
// treated in the same way.
if (opKind == OperandKind::InOut || isPartialWrite) {
intervals.back().sameValue_ = false;
}
// No need to create a new interval if it is not a write, or if it is a
// partial write.
if (opKind != OperandKind::Out || isPartialWrite)
continue;
opIdx =
LiveIntervalsInstructionNumbering::getInstrWriteSlotNumber(instIdx);
// This instruction modifies the memory location.
// Therefore, end the current active live interval
// for this memory location and begin a new one.
intervals.push_back(Interval(opIdx, opIdx + 1));
}
}
for (auto &Entry : liveness) {
auto *ML = Entry.first;
auto &IL = Entry.second;
if (isa<WeightVar>(ML)) {
assert(!IL.empty() && "Live interval list cannot be empty");
// Extend the last interval until the end of the program
// to express that all mutable weights are used outside.
IL.back().end_ = instIdx;
}
}
}
/// Given a set of intervals, return the interval covering
/// a given instruction, or the end iterator if no interval covers it.
static Intervals::iterator getEnclosingInterval(Intervals &liveIntervals,
size_t instIdx) {
for (auto I = liveIntervals.begin(), E = liveIntervals.end(); I != E; ++I) {
if (I->begin_ <= instIdx && instIdx < I->end_) {
return I;
}
}
return liveIntervals.end();
}
/// Returns true if RHS is enclosed inside LHS.
static bool isEnclosedInside(Interval &lhs, Interval &rhs) {
return lhs.begin_ <= rhs.begin_ && rhs.end_ <= lhs.end_;
}
/// \returns true if any interval from \p intervals overlaps with interval \p I.
static bool hasOverlappingIntervals(Intervals &intervals, Interval I) {
for (const auto &curI : intervals) {
if (std::max(curI.begin_, I.begin_) < std::min(curI.end_, I.end_)) {
return true;
}
}
return false;
}
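// For illustration: [2, 6) and [5, 9) overlap, because max(2, 5) == 5 is
// less than min(6, 9) == 6; [2, 5) and [5, 9) do not overlap, because the
// intervals are half-open.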
/// Helper function to get a compatible value to replace \p val with \p with.
/// This function casts \p with if necessary. When a cast needs to be created
/// it will be inserted before \p Before. Moreover, when that happens, the
/// second element of the returned pair is true, false otherwise.
///
/// \returns A pair with the first element being the Value matching val's type,
/// using \p with's content and the second element being the status of
/// whether a cast was inserted or not.
static std::pair<Value *, bool>
getCompatibleValueForReplacement(IRBuilder &B, Instruction *Before,
const Value &val, Value &with) {
if (val.getType() == with.getType()) {
return std::make_pair(&with, false);
}
Value *replacement = getOrigin(&with);
if (val.getType() == replacement->getType()) {
return std::make_pair(replacement, false);
}
// Perform a cast to make the types match.
std::vector<dim_t> offsets(replacement->dims().size(), 0);
auto *tv = B.createTensorViewInst(with.getName(), replacement, val.getType(),
offsets);
assert(tv->getType()->size() == with.getType()->size() &&
"Replacement must have same number of elements as original.");
B.getIRFunction().moveInstruction(Before, tv);
replacement = tv;
return std::make_pair(replacement, true);
}
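// For illustration (hypothetical types): if val has type float<4 x 2> and
// with's origin has type float<8>, both hold 8 elements, so a zero-offset
// tensorview over the origin with val's type is created and returned as
// the replacement, with the "created" flag set to true.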
/// Moves an interval from one interval list to another.
static void moveInterval(Intervals &from, Intervals &to, Interval &interval) {
auto fromIt = std::find(from.begin(), from.end(), interval);
assert(fromIt != from.end() && "Interval should exist in the from list");
// If the interval is enclosed inside one of the destination intervals,
// there is nothing to add; otherwise, append it to the destination list.
bool isEnclosed = false;
for (auto &I : to) {
if (isEnclosedInside(I, interval)) {
isEnclosed = true;
break;
}
}
if (!isEnclosed) {
to.push_back(interval);
// Let sort find a right position for it.
// std::sort(to.begin(), to.end());
}
// Delete from the from list.
from.erase(fromIt);
}
/// Replace all uses of \p val by \p with inside interval \p liveInterval.
/// While replacing the uses, if we don't find a definition before
/// the first use and \p fixUpFirstUseIfNoDef is true, this method will
/// create a proper definition for \p with.
///
/// \p fixUpFirstUseIfNoDef must only be used when we are extending destination
/// live-ranges upward. Also, \p fixUpFirstUseIfNoDef must only be used if
/// we extend the live-range of \p with toward the live-range of a WeightVar.
/// If fixUpFirstUseIfNoDef is required in other situations, that means the
/// input IR is wrong and that we have a bug somewhere else.
static void replaceAllUsesInsideIntervalWith(
IRBuilder &B, Value *val, Value *with, const Interval &liveInterval,
IRFunction &M, const LiveIntervalsInstructionNumbering &instrNumbering,
bool fixUpFirstUseIfNoDef) {
auto &instrs = M.getInstrs();
auto valOrigin = getOrigin(val);
unsigned instIdx = 0;
bool sawDefinitionBeforeFirstUse = false;
Instruction *firstUse = nullptr;
for (auto it = instrNumbering.getInstr(liveInterval.begin_)->getIterator(),
e = instrs.end();
it != e && instIdx <= liveInterval.end_; ++it) {
auto *I = &*it;
if (isa<DeallocActivationInst>(I)) {
continue;
}
// Ignore any new instructions that were not present when the instruction
// numbering was performed.
auto instNum = instrNumbering.getInstrNumber(I);
if (instNum < 0) {
continue;
}
instIdx = instNum;
bool sawDefinition = false;
// This is an instruction inside the interval.
// Iterate over all operands and perform replacements.
for (int i = 0, f = I->getNumOperands(); i < f; i++) {
auto op = I->getOperand(i).first;
auto opOrigin = getOrigin(op);
auto opKind = I->getOperand(i).second;
// Is the operand the value we are looking for?
if (opOrigin != valOrigin) {
continue;
}
size_t opIdx = static_cast<size_t>(
(opKind == OperandKind::In)
? LiveIntervalsInstructionNumbering::getInstrReadSlotNumber(
instIdx)
: LiveIntervalsInstructionNumbering::getInstrWriteSlotNumber(
instIdx));
// Skip operands outside of the interval.
if (opIdx < liveInterval.begin_ || opIdx >= liveInterval.end_) {
continue;
}
std::pair<Value *, bool> replacementAndHasCreated =
getCompatibleValueForReplacement(B, I, *op, *with);
auto *replacement = replacementAndHasCreated.first;
// If we inserted a cast of with and didn't see any use
// of with yet, this is our first use.
if (replacementAndHasCreated.second && !firstUse) {
assert(llvm::isa<Instruction>(replacement) &&
"A newly created replacement must be an instruction");
firstUse = llvm::cast<Instruction>(replacement);
}
DEBUG_GLOW(llvm::dbgs()
<< "Replacing inside instruction " << opIdx << "\n";
llvm::dbgs() << "before: "; I->dump(llvm::dbgs());
llvm::dbgs() << "\n");
// Don't account for InOut definitions, because the In part of that
// definition is going to be undefined if we didn't see any
// definition yet.
sawDefinition |= opKind == OperandKind::Out;
if (!firstUse &&
(opKind == OperandKind::In || opKind == OperandKind::InOut)) {
firstUse = I;
}
// Replace the old value by the new value.
I->setOperand(i, replacement);
DEBUG_GLOW(llvm::dbgs() << "after: "; I->dump(llvm::dbgs());
llvm::dbgs() << "\n");
}
sawDefinitionBeforeFirstUse |= (sawDefinition && !firstUse);
}
// We found a use without a definition first and have been asked to
// fix those situations.
// Insert a copy to initialize "with" with val.
if (firstUse && !sawDefinitionBeforeFirstUse && fixUpFirstUseIfNoDef) {
std::pair<Value *, bool> replacementAndHasCreated =
getCompatibleValueForReplacement(B, firstUse, *with, *val);
auto *fixupInit = B.createCopyInst(firstUse->getName().str() + ".fixup",
with, replacementAndHasCreated.first);
M.moveInstruction(firstUse, fixupInit);
}
}
/// Erase all instructions from the \p erasedInstructions set.
static bool eraseInstructions(IRFunction &M,
InstructionPtrSet &erasedInstructions) {
bool changed = false;
for (auto it : erasedInstructions) {
DEBUG_GLOW(llvm::dbgs() << "Deleting instruction :"; it->dump(llvm::dbgs());
llvm::dbgs() << "\n");
M.eraseInstruction(it);
changed = true;
}
return changed;
}
/// \returns true if writes into this memory location are observable from
/// outside.
static bool isObservable(Value *V) { return isa<WeightVar>(getOrigin(V)); }
namespace {
/// A helper class for sharing the buffers used by a given instruction.
class BufferSharingOptimizer {
/// Current function.
IRFunction &M_;
/// Current IRBuilder
IRBuilder &builder_;
/// The instruction numbering to be used.
const LiveIntervalsInstructionNumbering &instrNumbering_;
/// Current instruction.
Instruction *instr_;
/// The number of the current instruction.
size_t instrIdx_;
/// The source argument.
Value *src_;
/// The destination argument.
Value *dest_;
/// The origin of the source argument.
Value *srcOrigin_;
/// The origin of the destination argument.
Value *destOrigin_;
/// List of live intervals for the source buffer.
Intervals &srcIntervals_;
/// List of live intervals for the destination buffer.
Intervals &destIntervals_;
/// The live interval of the source buffer, which covers the current
/// instruction.
Interval *srcInterval_;
/// The live interval of the destination buffer, which covers the current
/// instruction.
Interval *destInterval_;
/// Check if instr_ is a candidate for copy propagation.
/// That is, instr_ is a copy and neither the source nor the destination
/// is redefined in the related intervals.
/// Intervals are split at each definition, except for inout uses.
/// Thus redefinitions can only happen when a value is redefined by
/// inout operands or by a partial write.
/// This is tracked by the sameValue_ field on each interval.
bool isCopyPropagation() const {
return isa<CopyInst>(instr_) && srcInterval_->sameValue_ &&
destInterval_->sameValue_;
}
/// Pick the buffer that can be reused. To make a decision, check
/// which intervals intersect with each other. In most cases, the buffers
/// can be combined if their live intervals do not overlap.
///
/// \returns the buffer that can be
/// reused, or nullptr if none of the buffers can be reused.
Value *getBufferToBeReused() {
// Do not try to combine observables.
if (isObservable(destOrigin_) && isObservable(srcOrigin_)) {
return nullptr;
}
// Check if dest or src live interval is the last live interval of
// an observable memory location.
bool isDestLastIntervalOfObservable =
isObservable(destOrigin_) && *destInterval_ == destIntervals_.back();
bool isSrcLastIntervalOfObservable =
isObservable(srcOrigin_) && *srcInterval_ == srcIntervals_.back();
// A value X cannot reuse the buffer of another value Y,
// if the live interval of X overlaps with any live intervals of Y.
// The only exception is the copy instruction, where the live interval
// of the destination may be merged into a live interval of the source
// if they have the same value.
// If dest interval overlaps with any srcIntervals, it cannot be replaced.
bool destIntvalCannotBeReplaced =
!isCopyPropagation() &&
hasOverlappingIntervals(srcIntervals_, *destInterval_);
// If the src interval overlaps with any dest intervals, it cannot be replaced.
bool srcIntervalCannotBeReplaced =
hasOverlappingIntervals(destIntervals_, *srcInterval_);
if (!isDestLastIntervalOfObservable && !isSrcLastIntervalOfObservable &&
!destIntvalCannotBeReplaced && !srcIntervalCannotBeReplaced) {
// There are no restrictions and the intervals can be combined in either
// order. Use a heuristic to pick the better way to combine them.
// Try to reuse the interval of an observable memory location, because
// it does not increase the memory usage.
// TODO: If it would introduce a last write into an observable, do not
// do it.
if (isObservable(srcOrigin_) && !isObservable(destOrigin_)) {
// Use src buffer for dest.
return srcOrigin_;
}
if (isObservable(destOrigin_) && !isObservable(srcOrigin_)) {
// Use dest buffer for src.
return destOrigin_;
}
// Avoid sharing a buffer if there is a single
// live interval in the interval list. After replacement
// this whole buffer can be eliminated.
if (srcIntervals_.size() == 1 && destIntervals_.size() != 1) {
// Use dest buffer for src.
return destOrigin_;
}
if (destIntervals_.size() == 1 && srcIntervals_.size() != 1) {
// Use src buffer for dest.
return srcOrigin_;
}
// Just use src buffer for dest by default.
return srcOrigin_;
}
// Try to check if buffers can be shared by using
// src instead of dest inside the live interval of dest.
// This is possible only if src is not live between the current
// instruction and the end of dest's live interval.
if (isDestLastIntervalOfObservable || destIntvalCannotBeReplaced) {
// Dest cannot be replaced by src, because src is being mutated while
// dest is alive, or because dest contains the last write into
// an observable memory location.
// Try instead to replace src by dest in the live interval of src.
// This is possible if src is not live anywhere inside the current
// dest's live interval ending at the current instruction.
// Bail, because src cannot be replaced by dest, because dest is being
// mutated while src is alive or because src contains the last write