Rework ChainRules for DynamicExpressions #56

Open
avik-pal opened this issue Apr 29, 2024 · 1 comment
Labels
enhancement (New feature or request), good first issue (Good for newcomers)

Comments

avik-pal (Member)

DynamicExpressions supports ChainRules starting with v0.17 (SymbolicML/DynamicExpressions.jl#71). We can remove parts of our code by using `CRC.rrule_via_ad` instead. We still need to define our own rule because the layer updates the expression's nodes in place. Additionally, we need to extract the node parameters from the expression's tangent to form the final parameter gradient.
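A minimal sketch of this pattern, assuming only ChainRulesCore. `ToyExpr`, `toy_eval` and `apply_toy` are hypothetical stand-ins for the DynamicExpressions tree, `eval_tree_array` and the Lux layer call; they are not the real APIs. The evaluator's rule returns a structural `Tangent` for the expression, and the wrapper rule repeats the in-place constant update, delegates to that rule, and then pulls the constants' gradient out of the `Tangent` as the parameter gradient.

```julia
using ChainRulesCore

# HYPOTHETICAL stand-in for an expression tree: it only stores the numeric
# constants that the layer overwrites in place before each evaluation.
mutable struct ToyExpr
    constants::Vector{Float64}
end

# HYPOTHETICAL evaluator (plays the role of `eval_tree_array`): y = c₁ .* x .+ c₂
toy_eval(expr::ToyExpr, x::AbstractVector) = expr.constants[1] .* x .+ expr.constants[2]

# Upstream-style rule: the cotangent for `expr` is a thunked structural Tangent.
function ChainRulesCore.rrule(::typeof(toy_eval), expr::ToyExpr, x::AbstractVector)
    y = toy_eval(expr, x)
    c₁ = expr.constants[1]  # capture the value, not the mutable vector
    function toy_eval_pullback(Δ)
        Δu = unthunk(Δ)
        ∂expr = @thunk(Tangent{ToyExpr}(; constants = [sum(Δu .* x), sum(Δu)]))
        ∂x = @thunk(c₁ .* Δu)
        return NoTangent(), ∂expr, ∂x
    end
    return y, toy_eval_pullback
end

# HYPOTHETICAL layer call: copy `ps` into the tree in place, then evaluate.
# The in-place `copyto!` is why a hand-written rule is still needed here.
function apply_toy(expr::ToyExpr, x::AbstractVector, ps::AbstractVector)
    copyto!(expr.constants, ps)
    return toy_eval(expr, x)
end

function ChainRulesCore.rrule(::typeof(apply_toy), expr::ToyExpr, x::AbstractVector,
                              ps::AbstractVector)
    copyto!(expr.constants, ps)        # same in-place update as the primal
    y, pb = rrule(toy_eval, expr, x)   # delegate to the evaluator's rule
    function apply_toy_pullback(Δ)
        _, ∂expr, ∂x = pb(Δ)
        # The parameter gradient is the `constants` field of the expression tangent
        ∂ps = unthunk(∂expr).constants
        return NoTangent(), NoTangent(), ∂x, ∂ps
    end
    return y, apply_toy_pullback
end
```

For example, `rrule(apply_toy, ToyExpr([0.0, 0.0]), [1.0, 2.0, 3.0], [2.0, 0.5])` returns `y == [2.5, 4.5, 6.5]`, and calling its pullback on `ones(3)` yields `∂ps == [6.0, 3.0]`.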

avik-pal added the enhancement and good first issue labels on Apr 29, 2024
avik-pal (Member, Author)

avik-pal commented May 8, 2024

This needs some investigation: I wasn't able to unthunk the `Tangent` coming from SymbolicML/DynamicExpressions.jl#71.

This will need some further thought.

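```julia
using Lux, DynamicExpressions, ChainRulesCore, FastClosures
const CRC = ChainRulesCore
using ChainRulesCore: NoTangent

function Lux.__apply_dynamic_expression_rrule(
        de::Lux.DynamicExpressionsLayer, expr, operator_enum, x, ps)
    # The upstream rrule for `eval_tree_array` only exists from v0.17 onwards,
    # so fail before mutating `expr`
    @static if pkgversion(DynamicExpressions) < v"0.17"
        error("`DynamicExpressions` v0.17 or later is required for reverse mode to work.")
    end
    # In-place update: write the layer parameters into the expression's constant nodes
    Lux.__update_expression_constants!(expr, ps)
    # Reuse the upstream rule instead of differentiating the tree by hand
    (y, _), pb_f = CRC.rrule(eval_tree_array, expr, x, operator_enum; de.turbo, de.bumper)
    __∇apply_dynamic_expression = @closure Δ -> begin
        # `eval_tree_array` returns a `(y, flag)` tuple; only `y` receives a cotangent
        _, ∂expr, ∂x, ∂operator_enum = pb_f((Δ, nothing))
        # The parameter gradient is read from the expression tangent's `gradient` field
        ∂ps = CRC.unthunk(∂expr).gradient
        return NoTangent(), NoTangent(), NoTangent(), ∂operator_enum, ∂x, ∂ps, NoTangent()
    end
    return y, __∇apply_dynamic_expression
end
```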

This works, but we hit a clear regression with mixed precision. Once that is handled upstream, we may be able to use the upstream rrule directly.

avik-pal mentioned this issue on May 9, 2024
avik-pal transferred this issue from LuxDL/Lux.jl on Sep 5, 2024