Provide implicit reparameterization gradients to parameters of ScalarFunctionWithInferredInverse.

This originally arose as an issue in #1270, where we wanted to define the bijector that interpolates between a sigmoid and the identity function, y = fn(x, a) = a * x + (1 - a) * sigmoid(x) (for a in [0, 1]), with a trainable coefficient `a`; the checked-in version of this bijector doesn't provide a gradient to `a`.

This version only accepts scalar parameters, because:

1. This was already complicated enough.
2. I *think* that supporting parameters of arbitrary rank would require the caller to tell us the rank of each parameter (so we know how much of it to treat as 'batch shape'), which seems like a mess.
3. Going vector-only instead of scalar-only would also work, and would allow faster math when there are many parameters, but it's awkward to box and unbox vectors in the common case where the parameter really is semantically a scalar.
4. In dire straits, you could simulate a fixed-size vector by passing multiple scalars.

I'm not certain this current API will ultimately be exactly the right thing, but that's what experimental is for. :-)

PiperOrigin-RevId: 369308653
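The implicit reparameterization trick the commit relies on can be sketched in plain Python, without TFP. This is a minimal illustration, not the library's implementation: `fn`, `inverse`, and `dinv_da` are hypothetical names, and bisection stands in for whatever root search the bijector actually uses. The key step is the implicit function theorem: since the inverse x(a) satisfies fn(x(a), a) = y, differentiating both sides in a gives dfn/dx * dx/da + dfn/da = 0, so dx/da = -(dfn/da) / (dfn/dx), even though x(a) was found numerically.

```python
import math


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def fn(x, a):
    # The commit's example bijector: interpolates between
    # sigmoid (a = 0) and the identity (a = 1).
    return a * x + (1.0 - a) * sigmoid(x)


def inverse(y, a, lo=-50.0, hi=50.0, iters=80):
    # fn is monotone increasing in x (dfn/dx > 0 for a in [0, 1]),
    # so bisection converges to the unique root x with fn(x, a) = y.
    for _ in range(iters):
        mid = 0.5 * (lo + hi)
        if fn(mid, a) < y:
            lo = mid
        else:
            hi = mid
    return 0.5 * (lo + hi)


def dinv_da(y, a):
    # Implicit function theorem applied to fn(x(a), a) = y:
    #   dfn/dx * dx/da + dfn/da = 0  =>  dx/da = -(dfn/da) / (dfn/dx).
    # Both partials are evaluated at the numerically inverted point.
    x = inverse(y, a)
    s = sigmoid(x)
    dfn_dx = a + (1.0 - a) * s * (1.0 - s)
    dfn_da = x - s
    return -dfn_da / dfn_dx
```

In a real framework this gradient would be attached to the inversion op (e.g. via a custom-gradient mechanism), which is what lets a trainable `a` receive gradients through the root search.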