Skip to content

Commit a9f2f8a

Browse files
Fix UseGraph::replace (#6395)
* Fix `UseGraph::isTrivial()` test. * Fix. * Fix. * Refactor `UseGraph` and `UseChain` * Update slang-ir-autodiff-primal-hoist.cpp * Update all auto-diff locations that handle pointers to treat user pointers as regular values * Update test to use direct-SPIRV only --------- Co-authored-by: Yong He <[email protected]>
1 parent 1908392 commit a9f2f8a

12 files changed

+117
-74
lines changed

source/slang/slang-ir-autodiff-fwd.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -1777,7 +1777,7 @@ void insertTempVarForMutableParams(IRModule* module, IRFunc* func)
17771777

17781778
for (auto param : params)
17791779
{
1780-
auto ptrType = as<IRPtrTypeBase>(param->getDataType());
1780+
auto ptrType = asRelevantPtrType(param->getDataType());
17811781
auto tempVar = builder.emitVar(ptrType->getValueType());
17821782
param->replaceUsesWith(tempVar);
17831783
mapParamToTempVar[param] = tempVar;
@@ -2245,7 +2245,7 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam(
22452245
builder->emitDifferentialPairGetPrimal(diffPairParam),
22462246
builder->emitDifferentialPairGetDifferential(diffType, diffPairParam));
22472247
}
2248-
else if (auto pairPtrType = as<IRPtrTypeBase>(diffPairType))
2248+
else if (auto pairPtrType = asRelevantPtrType(diffPairType))
22492249
{
22502250
auto ptrInnerPairType = as<IRDifferentialPairTypeBase>(pairPtrType->getValueType());
22512251
// Make a local copy of the parameter for primal and diff parts.

source/slang/slang-ir-autodiff-primal-hoist.cpp

+38-50
Original file line numberDiff line numberDiff line change
@@ -1174,7 +1174,7 @@ IRVar* emitIndexedLocalVar(
11741174
SourceLoc location)
11751175
{
11761176
// Cannot store pointers. Case should have been handled by now.
1177-
SLANG_RELEASE_ASSERT(!as<IRPtrTypeBase>(baseType));
1177+
SLANG_RELEASE_ASSERT(!asRelevantPtrType(baseType));
11781178

11791179
// Cannot store types. Case should have been handled by now.
11801180
SLANG_RELEASE_ASSERT(!as<IRTypeType>(baseType));
@@ -1326,7 +1326,11 @@ static int getInstRegionNestLevel(
13261326

13271327
struct UseChain
13281328
{
1329+
// The chain of uses from the base use to the relevant use.
1330+
// However, this is stored in reverse order (so that the last use is the 'base use')
1331+
//
13291332
List<IRUse*> chain;
1333+
13301334
static List<UseChain> from(
13311335
IRUse* baseUse,
13321336
Func<bool, IRUse*> isRelevantUse,
@@ -1366,41 +1370,20 @@ struct UseChain
13661370
return result;
13671371
}
13681372

1369-
void replace(IROutOfOrderCloneContext* ctx, IRBuilder* builder, IRInst* inst)
1373+
// This function only replaces the inner links, not the base use.
1374+
void replaceInnerLinks(IROutOfOrderCloneContext* ctx, IRBuilder* builder)
13701375
{
13711376
SLANG_ASSERT(chain.getCount() > 0);
13721377

1373-
// Simple case: if there is only one use, then we can just replace it.
1374-
if (chain.getCount() == 1)
1375-
{
1376-
builder->replaceOperand(chain.getLast(), inst);
1377-
chain.clear();
1378-
return;
1379-
}
1380-
1381-
// Pop the last use, which is the base use that needs to be replaced.
1382-
auto baseUse = chain.getLast();
1383-
chain.removeLast();
1378+
const UIndex count = chain.getCount();
13841379

1385-
// Ensure that replacement inst is set as mapping for the baseUse.
1386-
ctx->cloneEnv.mapOldValToNew[baseUse->get()] = inst;
1387-
1388-
IRBuilder chainBuilder(builder->getModule());
1389-
setInsertAfterOrdinaryInst(&chainBuilder, inst);
1390-
1391-
chain.reverse();
1392-
chain.removeLast();
1393-
1394-
// Clone the rest of the chain.
1395-
for (auto& use : chain)
1380+
// Process the chain in reverse order (excluding the first and last elements).
1381+
// That is, iterate from count - 2 down to 1 (inclusive).
1382+
for (int i = ((int)count) - 2; i >= 1; i--)
13961383
{
1397-
ctx->cloneInstOutOfOrder(&chainBuilder, use->get());
1384+
IRUse* use = chain[i];
1385+
ctx->cloneInstOutOfOrder(builder, use->get());
13981386
}
1399-
1400-
// We won't actually replace the final use, because if there are multiple chains
1401-
// it can cause problems. The parent UseGraph will handle that.
1402-
1403-
chain.clear();
14041387
}
14051388

14061389
IRInst* getUser() const
@@ -1417,6 +1400,14 @@ struct UseGraph
14171400
//
14181401
OrderedDictionary<IRUse*, List<UseChain>> chainSets;
14191402

1403+
// Create a UseGraph from a base inst.
1404+
//
1405+
// `isRelevantUse` is a predicate that determines if a use is relevant. Traversal will stop at
1406+
// this use, and all chains to this use will be grouped together.
1407+
//
1408+
// `passthroughInst` is a predicate that determines if an inst should be looked through
1409+
// for uses.
1410+
//
14201411
static UseGraph from(
14211412
IRInst* baseInst,
14221413
Func<bool, IRUse*> isRelevantUse,
@@ -1445,36 +1436,33 @@ struct UseGraph
14451436
return result;
14461437
}
14471438

1448-
void replace(IRBuilder* builder, IRUse* use, IRInst* inst)
1439+
void replace(IRBuilder* builder, IRUse* relevantUse, IRInst* inst)
14491440
{
14501441
// Since we may have common nodes, we will use an out-of-order cloning context
14511442
// that can retroactively correct the uses as needed.
14521443
//
14531444
IROutOfOrderCloneContext ctx;
1454-
List<UseChain> chains = chainSets[use];
1455-
for (auto chain : chains)
1456-
{
1457-
chain.replace(&ctx, builder, inst);
1458-
}
1445+
List<UseChain> chains = chainSets[relevantUse];
14591446

1460-
if (!isTrivial())
1447+
// Link the first use of each chain to inst.
1448+
for (auto& chain : chains)
1449+
ctx.cloneEnv.mapOldValToNew[chain.chain.getLast()->get()] = inst;
1450+
1451+
// Process the inner links of each chain using the replacement.
1452+
for (auto& chain : chains)
14611453
{
1462-
builder->setInsertBefore(use->getUser());
1463-
auto lastInstInChain = ctx.cloneInstOutOfOrder(builder, use->get());
1454+
IRBuilder chainBuilder(builder->getModule());
1455+
setInsertAfterOrdinaryInst(&chainBuilder, inst);
14641456

1465-
// Replace the base use.
1466-
builder->replaceOperand(use, lastInstInChain);
1457+
chain.replaceInnerLinks(&ctx, builder);
14671458
}
1468-
}
14691459

1470-
bool isTrivial()
1471-
{
1472-
// We're trivial if there's only one chain, and it has only one use.
1473-
if (chainSets.getCount() != 1)
1474-
return false;
1460+
// Finally, replace the relevant use (i.e, "final use") with the new replacement inst.
1461+
builder->setInsertBefore(relevantUse->getUser());
1462+
auto lastInstInChain = ctx.cloneInstOutOfOrder(builder, relevantUse->get());
14751463

1476-
auto& chain = chainSets.getFirst().value;
1477-
return chain.getCount() == 1;
1464+
// Replace the base use.
1465+
builder->replaceOperand(relevantUse, lastInstInChain);
14781466
}
14791467

14801468
List<IRUse*> getUniqueUses() const
@@ -1668,7 +1656,7 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability(
16681656
return true;
16691657
}
16701658
else if (
1671-
as<IRPtrTypeBase>(instToStore->getDataType()) &&
1659+
asRelevantPtrType(instToStore->getDataType()) &&
16721660
!isDifferentialOrRecomputeBlock(defBlock))
16731661
{
16741662
return true;

source/slang/slang-ir-autodiff-rev.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ IRType* BackwardDiffTranscriberBase::transcribeParamTypeForPropagateFunc(
370370
auto diffPairType = tryGetDiffPairType(builder, paramType);
371371
if (diffPairType)
372372
{
373-
if (!as<IRPtrTypeBase>(diffPairType) && !as<IRDifferentialPtrPairType>(diffPairType))
373+
if (!asRelevantPtrType(diffPairType) && !as<IRDifferentialPtrPairType>(diffPairType))
374374
return builder->getInOutType(diffPairType);
375375
return diffPairType;
376376
}
@@ -514,7 +514,7 @@ InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRF
514514
{
515515
// As long as the primal parameter is not an out or constref type,
516516
// we need to fetch the primal value from the parameter.
517-
if (as<IRPtrTypeBase>(propagateParamType))
517+
if (asRelevantPtrType(propagateParamType))
518518
{
519519
primalArg = builder.emitLoad(param);
520520
}
@@ -544,7 +544,7 @@ InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRF
544544
}
545545
else
546546
{
547-
auto primalPtrType = as<IRPtrTypeBase>(primalParamType);
547+
auto primalPtrType = asRelevantPtrType(primalParamType);
548548
SLANG_RELEASE_ASSERT(primalPtrType);
549549
auto primalValueType = primalPtrType->getValueType();
550550
auto var = builder.emitVar(primalValueType);

source/slang/slang-ir-autodiff-transcriber-base.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -291,7 +291,7 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy
291291
if (isNoDiffType(origType))
292292
return nullptr;
293293

294-
if (auto ptrType = as<IRPtrTypeBase>(origType))
294+
if (auto ptrType = asRelevantPtrType(origType))
295295
return builder->getPtrType(
296296
origType->getOp(),
297297
differentiateType(builder, ptrType->getValueType()));
@@ -556,7 +556,7 @@ IRType* AutoDiffTranscriberBase::tryGetDiffPairType(IRBuilder* builder, IRType*
556556
if (isNoDiffType(originalType))
557557
return nullptr;
558558

559-
if (auto origPtrType = as<IRPtrTypeBase>(originalType))
559+
if (auto origPtrType = asRelevantPtrType(originalType))
560560
{
561561
if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType()))
562562
return builder->getPtrType(originalType->getOp(), diffPairValueType);

source/slang/slang-ir-autodiff-transpose.h

+3-4
Original file line numberDiff line numberDiff line change
@@ -619,7 +619,7 @@ struct DiffTransposePass
619619
if (isDifferentialInst(varInst) && tryGetPrimalTypeFromDiffInst(varInst))
620620
{
621621
if (auto ptrPrimalType =
622-
as<IRPtrTypeBase>(tryGetPrimalTypeFromDiffInst(varInst)))
622+
asRelevantPtrType(tryGetPrimalTypeFromDiffInst(varInst)))
623623
{
624624
varInst->insertAtEnd(firstRevDiffBlock);
625625

@@ -1119,7 +1119,7 @@ struct DiffTransposePass
11191119

11201120
auto getDiffPairType = [](IRType* type)
11211121
{
1122-
if (auto ptrType = as<IRPtrTypeBase>(type))
1122+
if (auto ptrType = asRelevantPtrType(type))
11231123
type = ptrType->getValueType();
11241124
return as<IRDifferentialPairType>(type);
11251125
};
@@ -1168,7 +1168,7 @@ struct DiffTransposePass
11681168
argRequiresLoad.add(false);
11691169
writebacks.add(DiffValWriteBack{instPair->getDiff(), tempVar});
11701170
}
1171-
else if (!as<IRPtrTypeBase>(arg->getDataType()) && getDiffPairType(arg->getDataType()))
1171+
else if (!asRelevantPtrType(arg->getDataType()) && getDiffPairType(arg->getDataType()))
11721172
{
11731173
// Normal differentiable input parameter will become an inout DiffPair parameter
11741174
// in the propagate func. The split logic has already prepared the initial value
@@ -1241,7 +1241,6 @@ struct DiffTransposePass
12411241
argRequiresLoad.add(false);
12421242
}
12431243

1244-
12451244
auto revFnType =
12461245
this->autodiffContext->transcriberSet.propagateTranscriber->differentiateFunctionType(
12471246
builder,

source/slang/slang-ir-autodiff-unzip.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -332,8 +332,8 @@ bool isIntermediateContextType(IRInst* type)
332332
case kIROp_Specialize:
333333
return isIntermediateContextType(as<IRSpecialize>(type)->getBase());
334334
default:
335-
if (as<IRPtrTypeBase>(type))
336-
return isIntermediateContextType(as<IRPtrTypeBase>(type)->getValueType());
335+
if (auto ptrType = asRelevantPtrType(type))
336+
return isIntermediateContextType(ptrType->getValueType());
337337
return false;
338338
}
339339
}

source/slang/slang-ir-autodiff-unzip.h

+3-4
Original file line numberDiff line numberDiff line change
@@ -75,16 +75,15 @@ struct DiffUnzipPass
7575
primalParam = primalParam->getNextParam())
7676
{
7777
auto type = primalParam->getFullType();
78-
if (auto ptrType = as<IRPtrTypeBase>(type))
78+
if (auto ptrType = asRelevantPtrType(type))
7979
{
8080
type = ptrType->getValueType();
8181
}
8282
if (auto pairType = as<IRDifferentialPairType>(type))
8383
{
8484
IRInst* diffType = diffTypeContext.getDiffTypeFromPairType(builder, pairType);
85-
if (as<IRPtrTypeBase>(primalParam->getFullType()))
86-
diffType =
87-
builder->getPtrType(primalParam->getFullType()->getOp(), (IRType*)diffType);
85+
if (auto ptrType = asRelevantPtrType(primalParam->getFullType()))
86+
diffType = builder->getPtrType(ptrType->getOp(), (IRType*)diffType);
8887
auto primalRef = builder->emitPrimalParamRef(primalParam);
8988
auto diffRef = builder->emitDiffParamRef((IRType*)diffType, primalParam);
9089
builder->markInstAsDifferential(diffRef, pairType->getValueType());

source/slang/slang-ir-autodiff.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ bool isNoDiffType(IRType* paramType)
135135

136136
paramType = attrType->getBaseType();
137137
}
138-
else if (auto ptrType = as<IRPtrTypeBase>(paramType))
138+
else if (auto ptrType = asRelevantPtrType(paramType))
139139
{
140140
paramType = ptrType->getValueType();
141141
}
@@ -184,7 +184,7 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor(
184184
IRStructKey* key)
185185
{
186186
IRInst* pairType = nullptr;
187-
if (auto basePtrType = as<IRPtrTypeBase>(baseInst->getDataType()))
187+
if (auto basePtrType = asRelevantPtrType(baseInst->getDataType()))
188188
{
189189
auto loweredType = lowerDiffPairType(builder, basePtrType->getValueType());
190190

@@ -203,7 +203,7 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor(
203203
baseInst,
204204
key));
205205
}
206-
else if (auto ptrType = as<IRPtrTypeBase>(pairType))
206+
else if (auto ptrType = asRelevantPtrType(pairType))
207207
{
208208
if (auto ptrInnerSpecializedType = as<IRSpecialize>(ptrType->getValueType()))
209209
{
@@ -240,7 +240,7 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor(
240240
baseInst,
241241
key));
242242
}
243-
else if (auto genericPtrType = as<IRPtrTypeBase>(genericType))
243+
else if (auto genericPtrType = asRelevantPtrType(genericType))
244244
{
245245
if (auto genericPairStructType = as<IRStructType>(genericPtrType->getValueType()))
246246
{
@@ -1646,7 +1646,7 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(
16461646
IRBuilder* builder,
16471647
IRInst* primalType)
16481648
{
1649-
if (auto ptrType = as<IRPtrTypeBase>(primalType))
1649+
if (auto ptrType = asRelevantPtrType(primalType))
16501650
return builder->getPtrType(
16511651
primalType->getOp(),
16521652
differentiateType(builder, ptrType->getValueType()));

source/slang/slang-ir-autodiff.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -604,7 +604,7 @@ inline bool isRelevantDifferentialPair(IRType* type)
604604
{
605605
return true;
606606
}
607-
else if (auto argPtrType = as<IRPtrTypeBase>(type))
607+
else if (auto argPtrType = asRelevantPtrType(type))
608608
{
609609
if (as<IRDifferentialPairType>(argPtrType->getValueType()))
610610
{

source/slang/slang-ir-util.cpp

+11-1
Original file line numberDiff line numberDiff line change
@@ -1528,14 +1528,24 @@ bool isOne(IRInst* inst)
15281528
}
15291529
}
15301530

1531+
IRPtrTypeBase* asRelevantPtrType(IRInst* inst)
1532+
{
1533+
if (auto ptrType = as<IRPtrTypeBase>(inst))
1534+
{
1535+
if (ptrType->getAddressSpace() != AddressSpace::UserPointer)
1536+
return ptrType;
1537+
}
1538+
return nullptr;
1539+
}
1540+
15311541
IRPtrTypeBase* isMutablePointerType(IRInst* inst)
15321542
{
15331543
switch (inst->getOp())
15341544
{
15351545
case kIROp_ConstRefType:
15361546
return nullptr;
15371547
default:
1538-
return as<IRPtrTypeBase>(inst);
1548+
return asRelevantPtrType(inst);
15391549
}
15401550
}
15411551

source/slang/slang-ir-util.h

+4
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,10 @@ bool isZero(IRInst* inst);
271271

272272
bool isOne(IRInst* inst);
273273

274+
// Casts inst to IRPtrTypeBase, excluding UserPointer address space.
275+
IRPtrTypeBase* asRelevantPtrType(IRInst* inst);
276+
277+
// Returns the pointer type if it is pointer type that is not a const ref or a user pointer.
274278
IRPtrTypeBase* isMutablePointerType(IRInst* inst);
275279

276280
void initializeScratchData(IRInst* inst);

0 commit comments

Comments
 (0)