Skip to content

Commit

Permalink
[ reverse mode differentiation ] done reverse mode
Browse files Browse the repository at this point in the history
  • Loading branch information
dandoh committed Jul 27, 2020
1 parent 25809b8 commit 62c9be7
Show file tree
Hide file tree
Showing 11 changed files with 159 additions and 125 deletions.
4 changes: 2 additions & 2 deletions HashedExpression.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ cabal-version: 1.12
--
-- see: https://github.com/sol/hpack
--
-- hash: 4519a110756abaa652902d0183e8fdaa36524efb3ae39bcd679502eba5d75a05
-- hash: 158d5ed7ffd77d8520dc6eedabb8d47f1240293abc2485d2358bc8fa6a262c85

name: HashedExpression
version: 0.1.0.0
Expand Down Expand Up @@ -39,9 +39,9 @@ library
HashedExpression.Codegen
HashedExpression.Codegen.CSIMD
HashedExpression.Codegen.CSimple
HashedExpression.Differentiation.Exterior
HashedExpression.Differentiation.Exterior.Collect
HashedExpression.Differentiation.Exterior.Derivative
HashedExpression.Differentiation.Exterior.Partial
HashedExpression.Differentiation.Reverse
HashedExpression.Differentiation.Reverse.State
HashedExpression.Embed.FFTW
Expand Down
40 changes: 40 additions & 0 deletions src/HashedExpression/Differentiation/Exterior.hs
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
-- |
-- Module : HashedExpression.Differentiation.Exterior.Derivative
-- Copyright : (c) OCA 2020
-- License : MIT (see the LICENSE file)
-- Maintainer : [email protected]
-- Stability : provisional
-- Portability : unportable
module HashedExpression.Differentiation.Exterior where

import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import Data.Maybe (mapMaybe)
import HashedExpression.Differentiation.Exterior.Collect
import HashedExpression.Differentiation.Exterior.Derivative
import HashedExpression.Internal
import HashedExpression.Internal.Expression
import HashedExpression.Internal.Node

partialDerivativesMapByExterior :: Expression Scalar R -> (ExpressionMap, Map String NodeID)
partialDerivativesMapByExterior exp =
let (mp, rootID) = unwrap . collectDifferentials . derivativeAllVars $ exp
in (mp, partialDerivativesMap (mp, rootID))

-- | Return a map from variable name to the corresponding partial derivative node id
-- Partial derivatives in Expression Scalar Covector should be collected before passing to this function
partialDerivativesMap :: (ExpressionMap, NodeID) -> Map String NodeID
partialDerivativesMap (dfMp, dfId) =
case retrieveOp dfId dfMp of
Sum ns | retrieveElementType dfId dfMp == Covector -> Map.fromList $ mapMaybe getPartial ns
_ -> Map.fromList $ mapMaybe getPartial [dfId]
where
getPartial :: NodeID -> Maybe (String, NodeID)
getPartial nId
| MulD partialId dId <- retrieveOp nId dfMp,
DVar name <- retrieveOp dId dfMp =
Just (name, partialId)
| InnerProdD partialId dId <- retrieveOp nId dfMp,
DVar name <- retrieveOp dId dfMp =
Just (name, partialId)
| otherwise = Nothing
34 changes: 0 additions & 34 deletions src/HashedExpression/Differentiation/Exterior/Partial.hs

This file was deleted.

65 changes: 28 additions & 37 deletions src/HashedExpression/Differentiation/Reverse.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,11 @@ import HashedExpression.Internal.OperationSpec
import HashedExpression.Internal.Structure
import Prelude hiding ((^))

compute ::
-- |
partialDerivativesMapByReverse ::
Expression Scalar R ->
(ExpressionMap, Map String NodeID)
compute (Expression rootID mp) =
partialDerivativesMapByReverse (Expression rootID mp) =
let reverseTopoOrder = reverse $ topologicalSort (mp, rootID)
init = ComputeDState mp IM.empty Map.empty
-- Chain rule
Expand All @@ -46,6 +47,8 @@ compute (Expression rootID mp) =
Just ds -> perform (Nary specSum) ds
curMp <- gets contextMap
let (shape, et, op) = retrieveNode nID curMp
let one = introduceNode (shape, R, Const 1)
let zero = introduceNode (shape, R, Const 0)
case op of
Var name -> modifyPartialDerivativeMap (Map.insert name dN)
Const _ -> return ()
Expand All @@ -68,7 +71,7 @@ compute (Expression rootID mp) =
dX <- sNum (fromIntegral alpha) *. (from dN * (from x ^ (alpha -1)))
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
C -> do
dX <- sNum (fromIntegral alpha) *. conjugate (from x ^ (alpha - 1))
dX <- sNum (fromIntegral alpha) *. (from dN * conjugate (from x ^ (alpha - 1)))
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
Neg x -> do
dX <- negate $ from dN
Expand Down Expand Up @@ -126,38 +129,35 @@ compute (Expression rootID mp) =
dX <- from dN * sinh (from x)
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
Tanh x -> do
let one = introduceNode (shape, R, Const 1)
dX <- from dN * (one - tanh (from x) ^ 2)
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
Asin x -> do
dX <- error "TODO"
dX <- from dN * (one / sqrt (one - from x ^ 2))
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
Acos x -> do
dX <- error "TODO"
dX <- from dN * (- one / sqrt (one - from x ^ 2))
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
Atan x -> do
dX <- error "TODO"
dX <- from dN * (one / one + from x ^ 2)
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
Asinh x -> do
dX <- error "TODO"
dX <- from dN * (one / sqrt (one + from x ^ 2))
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
Acosh x -> do
dX <- error "TODO"
dX <- from dN * (one / sqrt (from x ^ 2 - one))
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
Atanh x -> do
dX <- error "TODO"
dX <- from dN * (one / sqrt (one - from x ^ 2))
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
RealImag re im -> do
dRe <- xRe $ from dN
modifyComputedPartsByParents (IM.insertWith (++) re [dRe])
dIm <- xIm $ from dN
modifyComputedPartsByParents (IM.insertWith (++) im [dIm])
RealPart reIm -> do
let zero = introduceNode (shape, R, Const 0)
dReIm <- from dN +: zero
modifyComputedPartsByParents (IM.insertWith (++) reIm [dReIm])
ImagPart reIm -> do
let zero = introduceNode (shape, R, Const 0)
dReIm <- zero +: from dN
modifyComputedPartsByParents (IM.insertWith (++) reIm [dReIm])
InnerProd x y -> do
Expand All @@ -172,7 +172,21 @@ compute (Expression rootID mp) =
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
dY <- conjugate (from dN) *. from x
modifyComputedPartsByParents (IM.insertWith (++) y [dY])
Piecewise {} -> undefined
Piecewise marks condition branches -> do
dCondition <- zero
modifyComputedPartsByParents (IM.insertWith (++) condition [dCondition])
let numBranches = length branches
forM_ (zip branches [0 ..]) $ \(branch, idx) -> case et of
R -> do
associate <- piecewise marks (from condition) (replicate idx zero ++ [one] ++ replicate (numBranches - idx - 1) zero)
dBranch <- from dN * from associate
modifyComputedPartsByParents (IM.insertWith (++) branch [dBranch])
C -> do
let zeroC = zero +: zero
let oneC = one +: zero
associate <- piecewise marks (from condition) (replicate idx zeroC ++ [oneC] ++ replicate (numBranches - idx - 1) zeroC)
dBranch <- from dN * conjugate (from associate)
modifyComputedPartsByParents (IM.insertWith (++) branch [dBranch])
Rotate amount x -> do
dX <- perform (Unary (specRotate (map negate amount))) [dN]
dX <- rotate (map negate amount) $ from dN
Expand All @@ -192,28 +206,5 @@ compute (Expression rootID mp) =
dX <- imFT (from dN) +: (- reFT (from dN))
modifyComputedPartsByParents (IM.insertWith (++) x [dX])
(_, res) = runState go init
in -- res = flip runStateT init $ do
-- forM_ reverseTopoOrder $ \nID -> do
-- undefined

-- update :: (ExpressionMap, IM.IntMap [NodeID], IM.IntMap NodeID) -> NodeID -> (ExpressionMap, IM.IntMap [NodeID], IM.IntMap NodeID)
-- update (accMp, accsByID, dByID) nID =
-- let (shape, et, op) = retrieveNode nID accMp
-- derivativeCurrent
-- | nID == rootID = undefined
-- | otherwise = case IM.lookup nID accsByID of
-- Just [d] -> d
-- Just ds -> undefined
-- in undefined
(contextMap res, partialDerivativeMap res)

---
-- Re Im
--- (a + bi) <.> (c + di)
-- (a <.> c + b <.> d) + (b <.> c - a <.> d)
-- (Re *. c - Im *. d) (Re *. d + IM *. c)

-- (Re *. a + Im *. b) + (Re *. b - Im *. a)
--
in (contextMap res, partialDerivativeMap res)

-----------------------------------
12 changes: 11 additions & 1 deletion src/HashedExpression/Differentiation/Reverse/State.hs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import Data.List (foldl')
import Data.List.HT (removeEach)
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as Map
import GHC.Stack (HasCallStack)
import HashedExpression.Internal
import HashedExpression.Internal.Expression
import HashedExpression.Internal.Hash
Expand All @@ -29,28 +30,35 @@ data ComputeDState = ComputeDState
partialDerivativeMap :: Map String NodeID
}

-- |
modifyContextMap :: (ExpressionMap -> ExpressionMap) -> ComputeReverseM ()
modifyContextMap f = modify' $ \s -> s {contextMap = f (contextMap s)}

-- |
modifyComputedPartsByParents :: (IM.IntMap [NodeID] -> IM.IntMap [NodeID]) -> ComputeReverseM ()
modifyComputedPartsByParents f = modify' $ \s -> s {computedPartsByParents = f (computedPartsByParents s)}

-- |
modifyPartialDerivativeMap :: (Map String NodeID -> Map String NodeID) -> ComputeReverseM ()
modifyPartialDerivativeMap f = modify' $ \s -> s {partialDerivativeMap = f (partialDerivativeMap s)}

-- |
from :: NodeID -> ComputeReverseM NodeID
from = return

-- |
sNum :: Double -> ComputeReverseM NodeID
sNum val = introduceNode ([], R, Const val)

-- |
introduceNode :: Node -> ComputeReverseM NodeID
introduceNode node = do
mp <- gets contextMap
let nID = hashNode (checkHashFromMap mp) node
modify' $ \s -> s {contextMap = IM.insert nID node mp}
return nID

-- |
perform :: OperationSpec -> [NodeID] -> ComputeReverseM NodeID
perform spec operandIDs = do
mp <- gets contextMap
Expand All @@ -59,6 +67,7 @@ perform spec operandIDs = do
modify' $ \s -> s {contextMap = IM.insert nID node mp}
return nID

-- |
type ComputeReverseM a = State ComputeDState a

instance Num (ComputeReverseM NodeID) where
Expand All @@ -71,11 +80,12 @@ instance Num (ComputeReverseM NodeID) where
do
x <- operand
perform (Unary specNeg) [x]
(*) :: HasCallStack => ComputeReverseM NodeID -> ComputeReverseM NodeID -> ComputeReverseM NodeID
(*) operand1 operand2 =
do
x <- operand1
y <- operand2
perform (Nary specSum) [x, y]
perform (Nary specMul) [x, y]

instance Fractional (ComputeReverseM NodeID) where
(/) operand1 operand2 = do
Expand Down
2 changes: 1 addition & 1 deletion src/HashedExpression/Internal/OperationSpec.hs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ specRealImag =
decideShape x y = assertSame [x, y] x
decideET x y
| x == R && y == R = C
| otherwise = error "2 operands must be real"
| otherwise = error $ "2 operands must be real" ++ show [x, y]

specRealPart :: HasCallStack => UnarySpec
specRealPart =
Expand Down
13 changes: 13 additions & 0 deletions src/HashedExpression/Interp.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@
module HashedExpression.Interp
( Evaluable (..),
Approximable (..),
evaluate1DReal,
evaluate1DComplex,
evaluate2DReal,
evaluate2DComplex,
evaluate3DReal,
evaluate3DComplex,
)
where

Expand Down Expand Up @@ -278,6 +284,7 @@ instance Evaluable Scalar C (Complex Double) where
-- show the real and imaginary part of complex as x + i y
eval valMap (expZeroR mp arg1)
:+ eval valMap (expZeroR mp arg2)
Conjugate arg -> conjugate $ eval valMap (expZeroC mp arg)
InnerProd arg1 arg2 ->
-- evaluate the inner product in C
case retrieveShape arg1 mp of
Expand Down Expand Up @@ -472,6 +479,9 @@ evaluate1DComplex valMap (mp, n)
(:+)
(evaluate1DReal valMap $ (mp, arg1))
(evaluate1DReal valMap $ (mp, arg2))
Conjugate arg ->
let res = evaluate1DComplex valMap (mp, arg)
in fmap conjugate res
Piecewise marks conditionArg branchArgs ->
let cdt = evaluate1DReal valMap $ (mp, conditionArg)
branches =
Expand Down Expand Up @@ -643,6 +653,9 @@ evaluate2DComplex valMap (mp, n)
(:+)
(evaluate2DReal valMap $ (mp, arg1))
(evaluate2DReal valMap $ (mp, arg2))
Conjugate arg ->
let res = evaluate2DComplex valMap (mp, arg)
in fmap conjugate res
Piecewise marks conditionArg branchArgs ->
let cdt = evaluate2DReal valMap $ (mp, conditionArg)
branches =
Expand Down
25 changes: 3 additions & 22 deletions src/HashedExpression/Problem.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import qualified Data.Set as Set
import Debug.Trace (traceShowId)
import HashedExpression.Differentiation.Exterior.Collect
import HashedExpression.Differentiation.Exterior.Derivative
import HashedExpression.Differentiation.Exterior
import HashedExpression.Internal
import HashedExpression.Internal.Expression
import HashedExpression.Internal.Node
Expand Down Expand Up @@ -128,26 +129,6 @@ inf = 1 / 0

-------------------------------------------------------------------------------

-- | Return a map from variable name to the corresponding partial derivative node id
-- Partial derivatives in Expression Scalar Covector should be collected before passing to this function
partialDerivativeMaps :: (ExpressionMap, NodeID) -> Map String NodeID
partialDerivativeMaps (dfMp, dfId) =
case retrieveOp dfId dfMp of
Sum ns | retrieveElementType dfId dfMp == Covector -> Map.fromList $ mapMaybe getPartial ns
_ -> Map.fromList $ mapMaybe getPartial [dfId]
where
getPartial :: NodeID -> Maybe (String, NodeID)
getPartial nId
| MulD partialId dId <- retrieveOp nId dfMp,
DVar name <- retrieveOp dId dfMp =
Just (name, partialId)
| InnerProdD partialId dId <- retrieveOp nId dfMp,
DVar name <- retrieveOp dId dfMp =
Just (name, partialId)
| otherwise = Nothing

-------------------------------------------------------------------------------

-- | The statement of a constraint, including an 'ExpressionMap' subexpressions, the root 'NodeID' and its value
data ConstraintStatement
= -- | A lower bound constraint
Expand Down Expand Up @@ -342,7 +323,7 @@ constructProblemHelper obj names (Constraint constraints) = do
curMp <- get
let finalRelevantVars = filter (\(name, _) -> Set.member name varsSet) $ varNodesWithId curMp
let name2PartialDerivativeID :: Map String NodeID
name2PartialDerivativeID = partialDerivativeMaps (curMp, dfID)
name2PartialDerivativeID = partialDerivativesMap (curMp, dfID)
variables =
map
( \(name, varNodeID) ->
Expand All @@ -353,7 +334,7 @@ constructProblemHelper obj names (Constraint constraints) = do
let scalarConstraints =
map
( \(gID, dgID, (lb, ub)) ->
let name2PartialDerivativeID = partialDerivativeMaps (curMp, dgID)
let name2PartialDerivativeID = partialDerivativesMap (curMp, dgID)
in ScalarConstraint
{ constraintValueId = gID,
constraintPartialDerivatives = map (\(name, _) -> fromJust $ Map.lookup name name2PartialDerivativeID) finalRelevantVars,
Expand Down
Loading

0 comments on commit 62c9be7

Please sign in to comment.