@@ -281,6 +281,142 @@ static Dictionary<IRBlock*, IRBlock*> createPrimalRecomputeBlocks(
281
281
return recomputeBlockMap;
282
282
}
283
283
284
+ // Checks if list A is a subset of list B by comparing their primal count parameters.
285
+ //
286
+ // Parameters:
287
+ // indicesA - First list of IndexTrackingInfo to compare
288
+ // indicesB - Second list of IndexTrackingInfo to compare
289
+ //
290
+ // Returns:
291
+ // true if all indices in indicesA are present in indicesB, false otherwise
292
+ //
293
+ bool areIndicesSubsetOf (List<IndexTrackingInfo>& indicesA, List<IndexTrackingInfo>& indicesB)
294
+ {
295
+ if (indicesA.getCount () > indicesB.getCount ())
296
+ return false ;
297
+
298
+ for (Index ii = 0 ; ii < indicesA.getCount (); ii++)
299
+ {
300
+ if (indicesA[ii].primalCountParam != indicesB[ii].primalCountParam )
301
+ return false ;
302
+ }
303
+
304
+ return true ;
305
+ }
306
+
307
+ bool canInstBeStored (IRInst* inst)
308
+ {
309
+ // Cannot store insts whose value is a type or a witness table, or a function.
310
+ // These insts get lowered to target-specific logic, and cannot be
311
+ // stored into variables or context structs as normal values.
312
+ //
313
+ if (as<IRTypeType>(inst->getDataType ()) || as<IRWitnessTableType>(inst->getDataType ()) ||
314
+ as<IRTypeKind>(inst->getDataType ()) || as<IRFuncType>(inst->getDataType ()) ||
315
+ !inst->getDataType ())
316
+ return false ;
317
+
318
+ return true ;
319
+ }
320
+
321
+ // This is a helper that converts insts in a loop condition block into two if necessary,
322
+ // then replaces all uses 'outside' the loop region with the new insts. This is because
323
+ // insts in loop condition blocks can be used in two distinct regions (the loop body, and
324
+ // after the loop).
325
+ //
326
+ // We'll use CheckpointObject for the splitting, which is allowed on any value-typed inst.
327
+ //
328
+ void splitLoopConditionBlockInsts (
329
+ IRGlobalValueWithCode* func,
330
+ Dictionary<IRBlock*, List<IndexTrackingInfo>>& indexedBlockInfo)
331
+ {
332
+ // RefPtr<IRDominatorTree> domTree = computeDominatorTree(func);
333
+
334
+ // Collect primal loop condition blocks, and map differential blocks to their primal blocks.
335
+ List<IRBlock*> loopConditionBlocks;
336
+ Dictionary<IRBlock*, IRBlock*> diffBlockMap;
337
+ for (auto block : func->getBlocks ())
338
+ {
339
+ if (auto loop = as<IRLoop>(block->getTerminator ()))
340
+ {
341
+ auto loopConditionBlock = getLoopConditionBlock (loop);
342
+ if (isDifferentialBlock (loopConditionBlock))
343
+ {
344
+ auto diffDecor = loopConditionBlock->findDecoration <IRDifferentialInstDecoration>();
345
+ diffBlockMap[cast<IRBlock>(diffDecor->getPrimalInst ())] = loopConditionBlock;
346
+ }
347
+ else
348
+ loopConditionBlocks.add (loopConditionBlock);
349
+ }
350
+ }
351
+
352
+ // For each loop condition block, split the insts that are used in both the loop body and
353
+ // after the loop.
354
+ // Use the dominator tree to find uses of insts outside the loop body
355
+ //
356
+ // Essentially we want to split the uses dominated by the true block and the false block of the
357
+ // condition.
358
+ //
359
+ IRBuilder builder (func->getModule ());
360
+
361
+
362
+ List<IRUse*> loopUses;
363
+ List<IRUse*> afterLoopUses;
364
+
365
+ for (auto condBlock : loopConditionBlocks)
366
+ {
367
+ // For each inst in the primal condition block, check if it has uses inside the loop body
368
+ // as well as outside of it. (Use the indexedBlockInfo to perform the teets)
369
+ //
370
+ for (auto inst = condBlock->getFirstInst (); inst; inst = inst->getNextInst ())
371
+ {
372
+ // Skip terminators and insts that can't be stored
373
+ if (as<IRTerminatorInst>(inst) || !canInstBeStored (inst))
374
+ continue ;
375
+ // Shouldn't see any vars.
376
+ SLANG_ASSERT (!as<IRVar>(inst));
377
+
378
+ // Get the indices for the condition block
379
+ auto & condBlockIndices = indexedBlockInfo[condBlock];
380
+
381
+ loopUses.clear ();
382
+ afterLoopUses.clear ();
383
+
384
+ // Check all uses of this inst
385
+ for (auto use = inst->firstUse ; use; use = use->nextUse )
386
+ {
387
+ auto userBlock = getBlock (use->getUser ());
388
+ auto & userBlockIndices = indexedBlockInfo[userBlock];
389
+
390
+ // If all of the condBlock's indices are a subset of the userBlock's indices,
391
+ // then the userBlock is inside the loop.
392
+ //
393
+ bool isInLoop = areIndicesSubsetOf (condBlockIndices, userBlockIndices);
394
+
395
+ if (isInLoop)
396
+ loopUses.add (use);
397
+ else
398
+ afterLoopUses.add (use);
399
+ }
400
+
401
+ // If inst has uses both inside and after the loop, create a copy for after-loop uses
402
+ if (loopUses.getCount () > 0 && afterLoopUses.getCount () > 0 )
403
+ {
404
+ setInsertAfterOrdinaryInst (&builder, inst);
405
+ auto copy = builder.emitCheckpointObject (inst);
406
+
407
+ // Copy source location so that checkpoint reporting is accurate
408
+ copy->sourceLoc = inst->sourceLoc ;
409
+
410
+ // Replace after-loop uses with the copy
411
+ for (auto use : afterLoopUses)
412
+ {
413
+ builder.replaceOperand (use, copy);
414
+ }
415
+ }
416
+ }
417
+ }
418
+ }
419
+
284
420
RefPtr<HoistedPrimalsInfo> AutodiffCheckpointPolicyBase::processFunc (
285
421
IRGlobalValueWithCode* func,
286
422
Dictionary<IRBlock*, IRBlock*>& mapDiffBlockToRecomputeBlock,
@@ -1297,20 +1433,6 @@ bool areIndicesEqual(
1297
1433
return true ;
1298
1434
}
1299
1435
1300
- bool areIndicesSubsetOf (List<IndexTrackingInfo>& indicesA, List<IndexTrackingInfo>& indicesB)
1301
- {
1302
- if (indicesA.getCount () > indicesB.getCount ())
1303
- return false ;
1304
-
1305
- for (Index ii = 0 ; ii < indicesA.getCount (); ii++)
1306
- {
1307
- if (indicesA[ii].primalCountParam != indicesB[ii].primalCountParam )
1308
- return false ;
1309
- }
1310
-
1311
- return true ;
1312
- }
1313
-
1314
1436
static int getInstRegionNestLevel (
1315
1437
Dictionary<IRBlock*, List<IndexTrackingInfo>>& indexedBlockInfo,
1316
1438
IRBlock* defBlock,
@@ -1510,21 +1632,6 @@ static List<IndexTrackingInfo> maybeTrimIndices(
1510
1632
return result;
1511
1633
}
1512
1634
1513
- bool canInstBeStored (IRInst* inst)
1514
- {
1515
- // Cannot store insts whose value is a type or a witness table, or a function.
1516
- // These insts get lowered to target-specific logic, and cannot be
1517
- // stored into variables or context structs as normal values.
1518
- //
1519
- if (as<IRTypeType>(inst->getDataType ()) || as<IRWitnessTableType>(inst->getDataType ()) ||
1520
- as<IRTypeKind>(inst->getDataType ()) || as<IRFuncType>(inst->getDataType ()) ||
1521
- !inst->getDataType ())
1522
- return false ;
1523
-
1524
- return true ;
1525
- }
1526
-
1527
-
1528
1635
// / Legalizes all accesses to primal insts from recompute and diff blocks.
1529
1636
// /
1530
1637
RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability (
@@ -2104,6 +2211,39 @@ void buildIndexedBlocks(
2104
2211
}
2105
2212
}
2106
2213
2214
+ // This function simply turns all CheckpointObject insts into a 'no-op'.
2215
+ // i.e. simply replaces all uses of CheckpointObject with the original value.
2216
+ //
2217
+ // This operation is 'correct' because if CheckpointObject's operand is visible
2218
+ // in a block, then it is visible in all dominated blocks.
2219
+ //
2220
+ void lowerCheckpointObjectInsts (IRGlobalValueWithCode* func)
2221
+ {
2222
+ // For each block in the function
2223
+ for (auto block : func->getBlocks ())
2224
+ {
2225
+ // For each instruction in the block
2226
+ for (auto inst = block->getFirstInst (); inst;)
2227
+ {
2228
+ // Get next inst before potentially removing current one
2229
+ auto nextInst = inst->getNextInst ();
2230
+
2231
+ // Check if this is a CheckpointObject instruction
2232
+ if (auto copyInst = as<IRCheckpointObject>(inst))
2233
+ {
2234
+ // Replace all uses of the copy with the original value
2235
+ auto originalVal = copyInst->getVal ();
2236
+ copyInst->replaceUsesWith (originalVal);
2237
+
2238
+ // Remove the now unused copy instruction
2239
+ inst->removeAndDeallocate ();
2240
+ }
2241
+
2242
+ inst = nextInst;
2243
+ }
2244
+ }
2245
+ }
2246
+
2107
2247
// For each primal inst that is used in reverse blocks, decide if we should recompute or store
2108
2248
// its value, then make them accessible in reverse blocks based the decision.
2109
2249
//
@@ -2117,6 +2257,9 @@ RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func)
2117
2257
Dictionary<IRBlock*, List<IndexTrackingInfo>> indexedBlockInfo;
2118
2258
buildIndexedBlocks (indexedBlockInfo, func);
2119
2259
2260
+ // Split loop condition insts into two if necessary.
2261
+ splitLoopConditionBlockInsts (func, indexedBlockInfo);
2262
+
2120
2263
// Create recompute blocks for each region following the same control flow structure
2121
2264
// as in primal code.
2122
2265
//
@@ -2136,7 +2279,12 @@ RefPtr<HoistedPrimalsInfo> applyCheckpointPolicy(IRGlobalValueWithCode* func)
2136
2279
// Legalize the primal inst accesses by introducing local variables / arrays and emitting
2137
2280
// necessary load/store logic.
2138
2281
//
2139
- return ensurePrimalAvailability (primalsInfo, func, indexedBlockInfo);
2282
+ auto hoistedPrimalsInfo = ensurePrimalAvailability (primalsInfo, func, indexedBlockInfo);
2283
+
2284
+ // Lower CheckpointObject insts to a no-op.
2285
+ lowerCheckpointObjectInsts (func);
2286
+
2287
+ return hoistedPrimalsInfo;
2140
2288
}
2141
2289
2142
2290
void DefaultCheckpointPolicy::preparePolicy (IRGlobalValueWithCode* func)
@@ -2312,6 +2460,9 @@ static bool shouldStoreInst(IRInst* inst)
2312
2460
2313
2461
break ;
2314
2462
}
2463
+ case kIROp_CheckpointObject :
2464
+ // Special inst for when a value must be stored.
2465
+ return true ;
2315
2466
default :
2316
2467
break ;
2317
2468
}
0 commit comments