@@ -416,13 +416,18 @@ void __requireGLSLExtension(constexpr String preludeText);
416
416
__intrinsic_op($(kIROp_StaticAssert))
417
417
void static_assert(constexpr bool condition, NativeString errorMessage);
418
418
419
- /// Interface to denote types as differentiable.
420
- /// Allows for user-specified differential types as
421
- /// well as automatic generation, for when the associated type
422
- /// hasn't been declared explicitly.
423
- /// Note that the requirements must currently be defined in this exact order
424
- /// since the auto-diff pass relies on the order to grab the struct keys.
419
+ /// Represents a type that is differentiable for the purposes of automatic differentiation.
425
420
///
421
+ /// Implemented by builtin floating-point scalar types (`float`, `half`, `double`)
422
+ ///
423
+ /// vector<T, N>, matrix<T, N, M> and Array<T, N> automatically conform to
424
+ /// `IDifferentiable` if `T` conforms to `IDifferentiable`.
425
+ ///
426
+ /// @remarks Types that implement `IDifferentiable` can be used with the automatic differentiation
427
+ /// primitives `bwd_diff` and `fwd_diff` to load and store gradients of parameters.
428
+ /// @remarks This interface supports automatic synthesis of requirements. A struct that conforms to `IDifferentiable`
429
+ /// will have its `Differential`, `dzero()` and `dadd()` methods automatically synthesized based on its fields, if
430
+ /// they are not already defined.
426
431
__magic_type(DifferentiableType)
427
432
interface IDifferentiable
428
433
{
@@ -446,9 +451,13 @@ interface IDifferentiable
446
451
static Differential dmul(T, Differential);
447
452
};
448
453
449
- /// Represents a type that supports differentiation operations for pointer types.
450
- /// This interface is used to define operations that are specific to pointer types
451
- /// in the context of automatic differentiation.
454
+ /// @experimental
455
+ ///
456
+ /// Represents a type that supports differentiation operations for pointers, buffers and
457
+ /// any other types
458
+ ///
459
+ /// @remarks Support for this interface is still experimental and subject to change.
460
+ ///
452
461
__magic_type(DifferentiablePtrType)
453
462
interface IDifferentiablePtrType
454
463
{
@@ -458,8 +467,9 @@ interface IDifferentiablePtrType
458
467
459
468
460
469
/// Pair type that serves to wrap the primal and
461
- /// differential types of an arbitrary type T.
462
-
470
+ /// differential types of a differentiable value type
471
+ /// T that conforms to `IDifferentiable`.
472
+ ///
463
473
__generic<T : IDifferentiable>
464
474
__magic_type(DifferentialPairType)
465
475
__intrinsic_type($(kIROp_DifferentialPairUserCodeType))
@@ -528,6 +538,10 @@ struct DifferentialPair : IDifferentiable
528
538
}
529
539
};
530
540
541
+ /// Pair type that serves to wrap the primal and
542
+ /// differential types of a differentiable pointer type
543
+ /// T that conforms to `IDifferentiablePtrType`.
544
+ ///
531
545
__generic<T : IDifferentiablePtrType>
532
546
__magic_type(DifferentialPtrPairType)
533
547
__intrinsic_type($(kIROp_DifferentialPtrPairType))
0 commit comments