Skip to content

Commit

Permalink
Merge pull request #17 from McMasterU/differentiation_backward
Browse files Browse the repository at this point in the history
Differentiation reverse method
  • Loading branch information
dandoh authored Jul 27, 2020
2 parents e471d50 + 62c9be7 commit 0788eea
Show file tree
Hide file tree
Showing 27 changed files with 636 additions and 124 deletions.
11 changes: 7 additions & 4 deletions HashedExpression.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ cabal-version: 1.12
--
-- see: https://github.com/sol/hpack
--
-- hash: 53e3c8434f7f2a6206e7b49dcb5c186d6ea5ffb6931a3c138160d6bbae9c09e8
-- hash: 158d5ed7ffd77d8520dc6eedabb8d47f1240293abc2485d2358bc8fa6a262c85

name: HashedExpression
version: 0.1.0.0
Expand Down Expand Up @@ -39,11 +39,13 @@ library
HashedExpression.Codegen
HashedExpression.Codegen.CSIMD
HashedExpression.Codegen.CSimple
HashedExpression.Derivative
HashedExpression.Derivative.Partial
HashedExpression.Differentiation.Exterior
HashedExpression.Differentiation.Exterior.Collect
HashedExpression.Differentiation.Exterior.Derivative
HashedExpression.Differentiation.Reverse
HashedExpression.Differentiation.Reverse.State
HashedExpression.Embed.FFTW
HashedExpression.Internal
HashedExpression.Internal.CollectDifferential
HashedExpression.Internal.Collision
HashedExpression.Internal.Expression
HashedExpression.Internal.Hash
Expand Down Expand Up @@ -180,6 +182,7 @@ test-suite HashedExpression-test
InterpSpec
NormalizeSpec
ProblemSpec
ReverseDifferentiationSpec
StructureSpec
Var
Paths_HashedExpression
Expand Down
1 change: 0 additions & 1 deletion TODO.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ stack haddock --haddock-arguments "--odir=docs/"
### TODO add regression tests for examples
### TODO Better interface for Transformation's (in Inner.hs) needed?
- Make relationship between Transformation/Modification/Change clearer? Put in it's own module?
- toRecursiveSimplification and toRecursiveCollecting should be reduced to one function?
### TODO Make sure we don't introduce bugs doing CodeGen for FT
### TODO Maybe use Numeric Prelude for better Num class and then better VectorSpace
### TODO add cVariable1D name = variable1D (name ++ "Re") +: variable1D (name ++ "Im"), cVariable2D = ...
Expand Down
3 changes: 1 addition & 2 deletions app/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,9 @@ import Data.Map (empty, fromList, union)
import Data.Maybe (fromJust)
import Data.STRef.Strict
import qualified Data.Set as Set
import HashedExpression.Derivative.Partial
import Graphics.EasyPlot
import HashedExpression
import HashedExpression.Derivative
import HashedExpression.Differentiation.Exterior.Derivative
import HashedExpression.Interp
import HashedExpression.Operation
import qualified HashedExpression.Operation
Expand Down
4 changes: 2 additions & 2 deletions src/HashedExpression.hs
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,8 @@ module HashedExpression
)
where

import HashedExpression.Derivative
import HashedExpression.Internal.CollectDifferential
import HashedExpression.Differentiation.Exterior.Collect
import HashedExpression.Differentiation.Exterior.Derivative
import HashedExpression.Internal.Expression
import HashedExpression.Internal.Normalize
import HashedExpression.Interp
Expand Down
5 changes: 5 additions & 0 deletions src/HashedExpression/Codegen/CSimple.hs
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,11 @@ evaluating CSimpleCodegen {..} rootIDs =
]
RealPart arg -> for i (len n) [[I.i|#{n !! i} = #{arg `reAt` i};|]]
ImagPart arg -> for i (len n) [[I.i|#{n !! i} = #{arg `imAt` i};|]]
Conjugate arg ->
for i (len n) $
[ [I.i|#{n `reAt` i} = #{arg `reAt` i};|],
[I.i|#{n `imAt` i} = -#{arg `imAt` i};|]
]
InnerProd arg1 arg2
| et == R && null (shapeOf arg1) -> [[I.i|#{n !! nooffset} = #{arg1 !! nooffset} * #{arg2 !! nooffset};|]]
| et == R ->
Expand Down
36 changes: 0 additions & 36 deletions src/HashedExpression/Derivative/Partial.hs

This file was deleted.

40 changes: 40 additions & 0 deletions src/HashedExpression/Differentiation/Exterior.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
-- |
-- Module : HashedExpression.Differentiation.Exterior.Derivative
-- Copyright : (c) OCA 2020
-- License : MIT (see the LICENSE file)
-- Maintainer : [email protected]
-- Stability : provisional
-- Portability : unportable
module HashedExpression.Differentiation.Exterior where

import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Data.Maybe (mapMaybe)
import HashedExpression.Differentiation.Exterior.Collect
import HashedExpression.Differentiation.Exterior.Derivative
import HashedExpression.Internal
import HashedExpression.Internal.Expression
import HashedExpression.Internal.Node

partialDerivativesMapByExterior :: Expression Scalar R -> (ExpressionMap, Map String NodeID)
partialDerivativesMapByExterior exp =
let (mp, rootID) = unwrap . collectDifferentials . derivativeAllVars $ exp
in (mp, partialDerivativesMap (mp, rootID))

-- | Return a map from variable name to the corresponding partial derivative node id
-- Partial derivatives in Expression Scalar Covector should be collected before passing to this function
partialDerivativesMap :: (ExpressionMap, NodeID) -> Map String NodeID
partialDerivativesMap (dfMp, dfId) =
case retrieveOp dfId dfMp of
Sum ns | retrieveElementType dfId dfMp == Covector -> Map.fromList $ mapMaybe getPartial ns
_ -> Map.fromList $ mapMaybe getPartial [dfId]
where
getPartial :: NodeID -> Maybe (String, NodeID)
getPartial nId
| MulD partialId dId <- retrieveOp nId dfMp,
DVar name <- retrieveOp dId dfMp =
Just (name, partialId)
| InnerProdD partialId dId <- retrieveOp nId dfMp,
DVar name <- retrieveOp dId dfMp =
Just (name, partialId)
| otherwise = Nothing
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
-- |
-- Module : HashedExpression.Internal.CollectDifferential
-- Module : HashedExpression.Differentiation.Exterior.Collect
-- Copyright : (c) OCA 2020
-- License : MIT (see the LICENSE file)
-- Maintainer : [email protected]
Expand All @@ -8,7 +8,7 @@
--
-- This module exists solely to factor terms around their differentials. When properly factored, the term multiplying
-- a differential (say dx) is it's corresponding parital derivative (i.e derivative w.r.t x)
module HashedExpression.Internal.CollectDifferential
module HashedExpression.Differentiation.Exterior.Collect
( collectDifferentials,
)
where
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{-# LANGUAGE ScopedTypeVariables #-}

-- |
-- Module : HashedExpression.Derivative
-- Module : HashedExpression.Differentiation.Exterior.Derivative
-- Copyright : (c) OCA 2020
-- License : MIT (see the LICENSE file)
-- Maintainer : [email protected]
Expand All @@ -17,7 +17,7 @@
-- Computing an exterior derivative on an expression @Expression d R@ will result in a @Expression d Covector@, i.e a 'Covector' field
-- (also known as 1-form). This will contain 'dVar' terms representing where implicit differentiation has occurred. See 'CollectDifferential'
-- to factor like terms for producing partial derivatives
module HashedExpression.Derivative
module HashedExpression.Differentiation.Exterior.Derivative
( exteriorDerivative,
derivativeAllVars,
)
Expand Down
210 changes: 210 additions & 0 deletions src/HashedExpression/Differentiation/Reverse.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
-- |
-- Module : HashedExpression.Differentiation.Exterior.Collect
-- Copyright : (c) OCA 2020
-- License : MIT (see the LICENSE file)
-- Maintainer : [email protected]
-- Stability : provisional
-- Portability : unportable
--
-- Compute differentiations using reverse accumulation method
-- https://en.wikipedia.org/wiki/Automatic_differentiation#Reverse_accumulation
module HashedExpression.Differentiation.Reverse where

import Control.Monad.State.Strict
import qualified Data.IntMap.Strict as IM
import Data.List (foldl')
import Data.List.HT (removeEach)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import HashedExpression.Differentiation.Reverse.State
import HashedExpression.Internal
import HashedExpression.Internal.Expression
import HashedExpression.Internal.Hash
import HashedExpression.Internal.Node
import HashedExpression.Internal.OperationSpec
import HashedExpression.Internal.Structure
import Prelude hiding ((^))

-- |
partialDerivativesMapByReverse ::
Expression Scalar R ->
(ExpressionMap, Map String NodeID)
partialDerivativesMapByReverse (Expression rootID mp) =
let reverseTopoOrder = reverse $ topologicalSort (mp, rootID)
init = ComputeDState mp IM.empty Map.empty
-- Chain rule
go :: ComputeReverseM ()
go = forM_ reverseTopoOrder $ \nID -> do
--- NodeID of derivative w.r.t to current node: d(f) / d(nID)
dN <-
if nID == rootID
then sNum 1
else do
dPartsFromParent <- IM.lookup nID <$> gets computedPartsByParents
-- Sum all the derivative parts incurred by its parents
case dPartsFromParent of
Just [d] -> from d
Just ds -> perform (Nary specSum) ds
curMp <- gets contextMap
let (shape, et, op) = retrieveNode nID curMp
let one = introduceNode (shape, R, Const 1)
let zero = introduceNode (shape, R, Const 0)
case op of
Var name -> modifyPartialDerivativeMap (Map.insert name dN)
Const _ -> return ()
Sum args -> do
forM_ args $ \x -> do
let dX = dN
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
Mul args -> do
forM_ (removeEach args) $ \(x, rest) -> do
productRest <- perform (Nary specMul) rest
if et == R
then do
dX <- from dN * from productRest
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
else do
dX <- from dN * conjugate (from productRest)
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
Power alpha x -> case et of
R -> do
dX <- sNum (fromIntegral alpha) *. (from dN * (from x ^ (alpha -1)))
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
C -> do
dX <- sNum (fromIntegral alpha) *. (from dN * conjugate (from x ^ (alpha - 1)))
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
Neg x -> do
dX <- negate $ from dN
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
Scale scalar scalee -> do
case (retrieveElementType scalar curMp, retrieveElementType scalee curMp) of
(R, R) -> do
-- for scalar
dScalar <- from dN <.> from scalee
modifyComputedPartsByParents (IM.insertWith (++) scalar [dScalar])
-- for scalee
dScalee <- from scalar *. from dN
modifyComputedPartsByParents (IM.insertWith (++) scalee [dScalee])
(R, C) -> do
-- for scalar
dScalar <- xRe (from scalee) <.> xRe (from dN) + xIm (from scalee) <.> xIm (from dN)
modifyComputedPartsByParents (IM.insertWith (++) scalar [dScalar])
-- for scalee
dScalee <- from scalar *. from dN
modifyComputedPartsByParents (IM.insertWith (++) scalee [dScalee])
(C, C) -> do
-- for scalar
dScalar <- from dN <.> from scalee
modifyComputedPartsByParents (IM.insertWith (++) scalar [dScalar])
-- for scalee
dScalee <- conjugate (from scalar) *. from dN
modifyComputedPartsByParents (IM.insertWith (++) scalee [dScalee])
Div x y -> do
dX <- from dN / from y
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
dY <- from dN * from x * (from y ^ (-2))
modifyComputedPartsByParents (IM.insertWith (++) y [dY])
Sqrt x -> do
dX <- sNum 0.5 *. (from dN / sqrt (from x))
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
Sin x -> do
dX <- from dN * cos (from x)
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
Cos x -> do
dX <- from dN * (- sin (from x))
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
Tan x -> do
dX <- from dN * (cos (from x) ^ (-2))
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
Exp x -> do
dX <- from dN * exp (from x)
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
Log x -> do
dX <- from dN * (from x ^ (-1))
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
Sinh x -> do
dX <- from dN * cosh (from x)
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
Cosh x -> do
dX <- from dN * sinh (from x)
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
Tanh x -> do
dX <- from dN * (one - tanh (from x) ^ 2)
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
Asin x -> do
dX <- from dN * (one / sqrt (one - from x ^ 2))
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
Acos x -> do
dX <- from dN * (- one / sqrt (one - from x ^ 2))
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
Atan x -> do
dX <- from dN * (one / one + from x ^ 2)
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
Asinh x -> do
dX <- from dN * (one / sqrt (one + from x ^ 2))
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
Acosh x -> do
dX <- from dN * (one / sqrt (from x ^ 2 - one))
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
Atanh x -> do
dX <- from dN * (one / sqrt (one - from x ^ 2))
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
RealImag re im -> do
dRe <- xRe $ from dN
modifyComputedPartsByParents (IM.insertWith (++) re [dRe])
dIm <- xIm $ from dN
modifyComputedPartsByParents (IM.insertWith (++) im [dIm])
RealPart reIm -> do
dReIm <- from dN +: zero
modifyComputedPartsByParents (IM.insertWith (++) reIm [dReIm])
ImagPart reIm -> do
dReIm <- zero +: from dN
modifyComputedPartsByParents (IM.insertWith (++) reIm [dReIm])
InnerProd x y -> do
case et of
R -> do
dX <- from dN *. from y
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
dY <- from dN *. from x
modifyComputedPartsByParents (IM.insertWith (++) y [dY])
C -> do
dX <- from dN *. from y
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
dY <- conjugate (from dN) *. from x
modifyComputedPartsByParents (IM.insertWith (++) y [dY])
Piecewise marks condition branches -> do
dCondition <- zero
modifyComputedPartsByParents (IM.insertWith (++) condition [dCondition])
let numBranches = length branches
forM_ (zip branches [0 ..]) $ \(branch, idx) -> case et of
R -> do
associate <- piecewise marks (from condition) (replicate idx zero ++ [one] ++ replicate (numBranches - idx - 1) zero)
dBranch <- from dN * from associate
modifyComputedPartsByParents (IM.insertWith (++) branch [dBranch])
C -> do
let zeroC = zero +: zero
let oneC = one +: zero
associate <- piecewise marks (from condition) (replicate idx zeroC ++ [oneC] ++ replicate (numBranches - idx - 1) zeroC)
dBranch <- from dN * conjugate (from associate)
modifyComputedPartsByParents (IM.insertWith (++) branch [dBranch])
Rotate amount x -> do
dX <- perform (Unary (specRotate (map negate amount))) [dN]
dX <- rotate (map negate amount) $ from dN
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
ReFT x
| retrieveElementType x curMp == R -> do
dX <- reFT (from dN)
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
| otherwise -> do
dX <- reFT (from dN) +: (- imFT (from dN))
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
ImFT x
| retrieveElementType x curMp == R -> do
dX <- imFT (from dN)
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
| otherwise -> do
dX <- imFT (from dN) +: (- reFT (from dN))
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
(_, res) = runState go init
in (contextMap res, partialDerivativeMap res)

Loading

0 comments on commit 0788eea

Please sign in to comment.