Skip to content

Commit f7b9745

Browse files
Fix a bug with hoisting 'IRVar' insts that are used outside the loop (#6446)
* Fix a bug with hoisting 'IRVar' insts that are used outside the loop - We introduce a 'CheckpointObject' inst and use that to split loop state insts into two pieces (one for within-loop uses and one for outside-loop uses. - This allows the two kinds of uses to be handled separately by the hoisting mechanism - CheckpointObject is then lowered to a no-op after hoisting is complete. * Update slang-ir-autodiff-primal-hoist.cpp * Update slang-ir-autodiff-primal-hoist.cpp
1 parent a9f2f8a commit f7b9745

7 files changed

+336
-57
lines changed

source/slang/slang-ir-autodiff-primal-hoist.cpp

+181-30
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,142 @@ static Dictionary<IRBlock*, IRBlock*> createPrimalRecomputeBlocks(
281281
return recomputeBlockMap;
282282
}
283283

284+
// Checks if list A is a subset of list B by comparing their primal count parameters.
285+
//
286+
// Parameters:
287+
// indicesA - First list of IndexTrackingInfo to compare
288+
// indicesB - Second list of IndexTrackingInfo to compare
289+
//
290+
// Returns:
291+
// true if all indices in indicesA are present in indicesB, false otherwise
292+
//
293+
bool areIndicesSubsetOf(List<IndexTrackingInfo>& indicesA, List<IndexTrackingInfo>& indicesB)
294+
{
295+
if (indicesA.getCount() > indicesB.getCount())
296+
return false;
297+
298+
for (Index ii = 0; ii < indicesA.getCount(); ii++)
299+
{
300+
if (indicesA[ii].primalCountParam != indicesB[ii].primalCountParam)
301+
return false;
302+
}
303+
304+
return true;
305+
}
306+
307+
bool canInstBeStored(IRInst* inst)
308+
{
309+
// Cannot store insts whose value is a type or a witness table, or a function.
310+
// These insts get lowered to target-specific logic, and cannot be
311+
// stored into variables or context structs as normal values.
312+
//
313+
if (as<IRTypeType>(inst->getDataType()) || as<IRWitnessTableType>(inst->getDataType()) ||
314+
as<IRTypeKind>(inst->getDataType()) || as<IRFuncType>(inst->getDataType()) ||
315+
!inst->getDataType())
316+
return false;
317+
318+
return true;
319+
}
320+
321+
// This is a helper that converts insts in a loop condition block into two if necessary,
322+
// then replaces all uses 'outside' the loop region with the new insts. This is because
323+
// insts in loop condition blocks can be used in two distinct regions (the loop body, and
324+
// after the loop).
325+
//
326+
// We'll use CheckpointObject for the splitting, which is allowed on any value-typed inst.
327+
//
328+
void splitLoopConditionBlockInsts(
329+
IRGlobalValueWithCode* func,
330+
Dictionary<IRBlock*, List<IndexTrackingInfo>>& indexedBlockInfo)
331+
{
332+
// RefPtr<IRDominatorTree> domTree = computeDominatorTree(func);
333+
334+
// Collect primal loop condition blocks, and map differential blocks to their primal blocks.
335+
List<IRBlock*> loopConditionBlocks;
336+
Dictionary<IRBlock*, IRBlock*> diffBlockMap;
337+
for (auto block : func->getBlocks())
338+
{
339+
if (auto loop = as<IRLoop>(block->getTerminator()))
340+
{
341+
auto loopConditionBlock = getLoopConditionBlock(loop);
342+
if (isDifferentialBlock(loopConditionBlock))
343+
{
344+
auto diffDecor = loopConditionBlock->findDecoration<IRDifferentialInstDecoration>();
345+
diffBlockMap[cast<IRBlock>(diffDecor->getPrimalInst())] = loopConditionBlock;
346+
}
347+
else
348+
loopConditionBlocks.add(loopConditionBlock);
349+
}
350+
}
351+
352+
// For each loop condition block, split the insts that are used in both the loop body and
353+
// after the loop.
354+
// Use the dominator tree to find uses of insts outside the loop body
355+
//
356+
// Essentially we want to split the uses dominated by the true block and the false block of the
357+
// condition.
358+
//
359+
IRBuilder builder(func->getModule());
360+
361+
362+
List<IRUse*> loopUses;
363+
List<IRUse*> afterLoopUses;
364+
365+
for (auto condBlock : loopConditionBlocks)
366+
{
367+
// For each inst in the primal condition block, check if it has uses inside the loop body
368+
// as well as outside of it. (Use the indexedBlockInfo to perform the teets)
369+
//
370+
for (auto inst = condBlock->getFirstInst(); inst; inst = inst->getNextInst())
371+
{
372+
// Skip terminators and insts that can't be stored
373+
if (as<IRTerminatorInst>(inst) || !canInstBeStored(inst))
374+
continue;
375+
// Shouldn't see any vars.
376+
SLANG_ASSERT(!as<IRVar>(inst));
377+
378+
// Get the indices for the condition block
379+
auto& condBlockIndices = indexedBlockInfo[condBlock];
380+
381+
loopUses.clear();
382+
afterLoopUses.clear();
383+
384+
// Check all uses of this inst
385+
for (auto use = inst->firstUse; use; use = use->nextUse)
386+
{
387+
auto userBlock = getBlock(use->getUser());
388+
auto& userBlockIndices = indexedBlockInfo[userBlock];
389+
390+
// If all of the condBlock's indices are a subset of the userBlock's indices,
391+
// then the userBlock is inside the loop.
392+
//
393+
bool isInLoop = areIndicesSubsetOf(condBlockIndices, userBlockIndices);
394+
395+
if (isInLoop)
396+
loopUses.add(use);
397+
else
398+
afterLoopUses.add(use);
399+
}
400+
401+
// If inst has uses both inside and after the loop, create a copy for after-loop uses
402+
if (loopUses.getCount() > 0 && afterLoopUses.getCount() > 0)
403+
{
404+
setInsertAfterOrdinaryInst(&builder, inst);
405+
auto copy = builder.emitCheckpointObject(inst);
406+
407+
// Copy source location so that checkpoint reporting is accurate
408+
copy->sourceLoc = inst->sourceLoc;
409+
410+
// Replace after-loop uses with the copy
411+
for (auto use : afterLoopUses)
412+
{
413+
builder.replaceOperand(use, copy);
414+
}
415+
}
416+
}
417+
}
418+
}
419+
284420
RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc(
285421
IRGlobalValueWithCode* func,
286422
Dictionary<IRBlock*, IRBlock*>& mapDiffBlockToRecomputeBlock,
@@ -1297,20 +1433,6 @@ bool areIndicesEqual(
12971433
return true;
12981434
}
12991435

1300-
bool areIndicesSubsetOf(List<IndexTrackingInfo>& indicesA, List<IndexTrackingInfo>& indicesB)
1301-
{
1302-
if (indicesA.getCount() > indicesB.getCount())
1303-
return false;
1304-
1305-
for (Index ii = 0; ii < indicesA.getCount(); ii++)
1306-
{
1307-
if (indicesA[ii].primalCountParam != indicesB[ii].primalCountParam)
1308-
return false;
1309-
}
1310-
1311-
return true;
1312-
}
1313-
13141436
static int getInstRegionNestLevel(
13151437
Dictionary<IRBlock*, List<IndexTrackingInfo>>& indexedBlockInfo,
13161438
IRBlock* defBlock,
@@ -1510,21 +1632,6 @@ static List<IndexTrackingInfo> maybeTrimIndices(
15101632
return result;
15111633
}
15121634

1513-
bool canInstBeStored(IRInst* inst)
1514-
{
1515-
// Cannot store insts whose value is a type or a witness table, or a function.
1516-
// These insts get lowered to target-specific logic, and cannot be
1517-
// stored into variables or context structs as normal values.
1518-
//
1519-
if (as<IRTypeType>(inst->getDataType()) || as<IRWitnessTableType>(inst->getDataType()) ||
1520-
as<IRTypeKind>(inst->getDataType()) || as<IRFuncType>(inst->getDataType()) ||
1521-
!inst->getDataType())
1522-
return false;
1523-
1524-
return true;
1525-
}
1526-
1527-
15281635
/// Legalizes all accesses to primal insts from recompute and diff blocks.
15291636
///
15301637
RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
@@ -2104,6 +2211,39 @@ void buildIndexedBlocks(
21042211
}
21052212
}
21062213

2214+
// This function simply turns all CheckpointObject insts into a 'no-op'.
2215+
// i.e. simply replaces all uses of CheckpointObject with the original value.
2216+
//
2217+
// This operation is 'correct' because if CheckpointObject's operand is visible
2218+
// in a block, then it is visible in all dominated blocks.
2219+
//
2220+
void lowerCheckpointObjectInsts(IRGlobalValueWithCode* func)
2221+
{
2222+
// For each block in the function
2223+
for (auto block : func->getBlocks())
2224+
{
2225+
// For each instruction in the block
2226+
for (auto inst = block->getFirstInst(); inst;)
2227+
{
2228+
// Get next inst before potentially removing current one
2229+
auto nextInst = inst->getNextInst();
2230+
2231+
// Check if this is a CheckpointObject instruction
2232+
if (auto copyInst = as<IRCheckpointObject>(inst))
2233+
{
2234+
// Replace all uses of the copy with the original value
2235+
auto originalVal = copyInst->getVal();
2236+
copyInst->replaceUsesWith(originalVal);
2237+
2238+
// Remove the now unused copy instruction
2239+
inst->removeAndDeallocate();
2240+
}
2241+
2242+
inst = nextInst;
2243+
}
2244+
}
2245+
}
2246+
21072247
// For each primal inst that is used in reverse blocks, decide if we should recompute or store
21082248
// its value, then make them accessible in reverse blocks based the decision.
21092249
//
@@ -2117,6 +2257,9 @@ RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func)
21172257
Dictionary<IRBlock*, List<IndexTrackingInfo>> indexedBlockInfo;
21182258
buildIndexedBlocks(indexedBlockInfo, func);
21192259

2260+
// Split loop condition insts into two if necessary.
2261+
splitLoopConditionBlockInsts(func, indexedBlockInfo);
2262+
21202263
// Create recompute blocks for each region following the same control flow structure
21212264
// as in primal code.
21222265
//
@@ -2136,7 +2279,12 @@ RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func)
21362279
// Legalize the primal inst accesses by introducing local variables / arrays and emitting
21372280
// necessary load/store logic.
21382281
//
2139-
return ensurePrimalAvailability(primalsInfo, func, indexedBlockInfo);
2282+
auto hoistedPrimalsInfo = ensurePrimalAvailability(primalsInfo, func, indexedBlockInfo);
2283+
2284+
// Lower CheckpointObject insts to a no-op.
2285+
lowerCheckpointObjectInsts(func);
2286+
2287+
return hoistedPrimalsInfo;
21402288
}
21412289

21422290
void DefaultCheckpointPolicy::preparePolicy(IRGlobalValueWithCode* func)
@@ -2312,6 +2460,9 @@ static bool shouldStoreInst(IRInst* inst)
23122460

23132461
break;
23142462
}
2463+
case kIROp_CheckpointObject:
2464+
// Special inst for when a value must be stored.
2465+
return true;
23152466
default:
23162467
break;
23172468
}

source/slang/slang-ir-inst-defs.h

+2
Original file line numberDiff line numberDiff line change
@@ -716,6 +716,8 @@ INST(BitNot, bitnot, 1, 0)
716716

717717
INST(Select, select, 3, 0)
718718

719+
INST(CheckpointObject, checkpointObj, 1, 0)
720+
719721
INST(GetStringHash, getStringHash, 1, 0)
720722

721723
INST(WaveGetActiveMask, waveGetActiveMask, 0, 0)

source/slang/slang-ir-insts.h

+18
Original file line numberDiff line numberDiff line change
@@ -2664,6 +2664,22 @@ struct IRDiscard : IRTerminatorInst
26642664
{
26652665
};
26662666

2667+
// Used for representing a distinct copy of an object.
2668+
// This will get lowered into a no-op in the backend,
2669+
// but is useful for IR transformations that need to consider
2670+
// different uses of an inst separately.
2671+
//
2672+
// For example, when we hoist primal insts out of a loop,
2673+
// we need to make distinct copies of the inst for its uses
2674+
// within the loop body and outside of it.
2675+
//
2676+
struct IRCheckpointObject : IRInst
2677+
{
2678+
IR_LEAF_ISA(CheckpointObject);
2679+
2680+
IRInst* getVal() { return getOperand(0); }
2681+
};
2682+
26672683
// Signals that this point in the code should be unreachable.
26682684
// We can/should emit a dataflow error if we can ever determine
26692685
// that a block ending in one of these can actually be
@@ -4408,6 +4424,8 @@ struct IRBuilder
44084424

44094425
IRInst* emitDiscard();
44104426

4427+
IRInst* emitCheckpointObject(IRInst* value);
4428+
44114429
IRInst* emitUnreachable();
44124430
IRInst* emitMissingReturn();
44134431

0 commit comments

Comments
 (0)