From 46bdc31eb393a87a9ea13eee5cba2825f7f420c4 Mon Sep 17 00:00:00 2001 From: "Nhan Thai (dandoh)" Date: Sat, 25 Jul 2020 23:45:18 +0700 Subject: [PATCH 1/4] [ refactor ] move computing differnetiations by exterioir derivative into one modules dir --- HashedExpression.cabal | 8 ++++---- TODO.md | 1 - app/Main.hs | 3 +-- src/HashedExpression.hs | 4 ++-- .../Exterior/Collect.hs} | 4 ++-- .../{ => Differentiation/Exterior}/Derivative.hs | 4 ++-- .../{Derivative => Differentiation/Exterior}/Partial.hs | 8 +++----- src/HashedExpression/Problem.hs | 4 ++-- test/CollectSpec.hs | 4 ++-- test/ProblemSpec.hs | 4 ++-- test/Spec.hs | 2 +- 11 files changed, 21 insertions(+), 25 deletions(-) rename src/HashedExpression/{Internal/CollectDifferential.hs => Differentiation/Exterior/Collect.hs} (97%) rename src/HashedExpression/{ => Differentiation/Exterior}/Derivative.hs (98%) rename src/HashedExpression/{Derivative => Differentiation/Exterior}/Partial.hs (87%) diff --git a/HashedExpression.cabal b/HashedExpression.cabal index 361062d8..444287e3 100644 --- a/HashedExpression.cabal +++ b/HashedExpression.cabal @@ -4,7 +4,7 @@ cabal-version: 1.12 -- -- see: https://github.com/sol/hpack -- --- hash: 53e3c8434f7f2a6206e7b49dcb5c186d6ea5ffb6931a3c138160d6bbae9c09e8 +-- hash: 38c67ea70677719cae6b279fc3bc55b2570d2fd4eedd6d4b3f50a830535d56a8 name: HashedExpression version: 0.1.0.0 @@ -39,11 +39,11 @@ library HashedExpression.Codegen HashedExpression.Codegen.CSIMD HashedExpression.Codegen.CSimple - HashedExpression.Derivative - HashedExpression.Derivative.Partial + HashedExpression.Differentiation.Exterior.Collect + HashedExpression.Differentiation.Exterior.Derivative + HashedExpression.Differentiation.Exterior.Partial HashedExpression.Embed.FFTW HashedExpression.Internal - HashedExpression.Internal.CollectDifferential HashedExpression.Internal.Collision HashedExpression.Internal.Expression HashedExpression.Internal.Hash diff --git a/TODO.md b/TODO.md index a6802a4a..512ec576 100644 --- a/TODO.md +++ b/TODO.md @@ -31,7 +31,6 @@ stack haddock --haddock-arguments "--odir=docs/" ### TODO add regression tests for examples ### TODO Better interface for Transformation's (in Inner.hs) needed? - Make relationship between Transformation/Modification/Change clearer? Put in it's own module? -- toRecursiveSimplification and toRecursiveCollecting should be reduced to one function? ### TODO Make sure we don't introduce bugs doing CodeGen for FT ### TODO Maybe use Numeric Prelude for better Num class and then better VectorSpace ### TODO add cVariable1D name = variable1D (name ++ "Re") +: variable1D (name ++ "Im"), cVariable2D = ... diff --git a/app/Main.hs b/app/Main.hs index ba28450c..20c8cbca 100644 --- a/app/Main.hs +++ b/app/Main.hs @@ -11,10 +11,9 @@ import Data.Map (empty, fromList, union) import Data.Maybe (fromJust) import Data.STRef.Strict import qualified Data.Set as Set -import HashedExpression.Derivative.Partial import Graphics.EasyPlot import HashedExpression -import HashedExpression.Derivative +import HashedExpression.Differentiation.Exterior.Derivative import HashedExpression.Interp import HashedExpression.Operation import qualified HashedExpression.Operation diff --git a/src/HashedExpression.hs b/src/HashedExpression.hs index eab56f20..eaba1f2a 100644 --- a/src/HashedExpression.hs +++ b/src/HashedExpression.hs @@ -57,8 +57,8 @@ module HashedExpression ) where -import HashedExpression.Derivative -import HashedExpression.Internal.CollectDifferential +import HashedExpression.Differentiation.Exterior.Collect +import HashedExpression.Differentiation.Exterior.Derivative import HashedExpression.Internal.Expression import HashedExpression.Internal.Normalize import HashedExpression.Interp diff --git a/src/HashedExpression/Internal/CollectDifferential.hs b/src/HashedExpression/Differentiation/Exterior/Collect.hs similarity index 97% rename from src/HashedExpression/Internal/CollectDifferential.hs rename to src/HashedExpression/Differentiation/Exterior/Collect.hs index 9386618a..b03b0949 100644 --- a/src/HashedExpression/Internal/CollectDifferential.hs +++ b/src/HashedExpression/Differentiation/Exterior/Collect.hs @@ -1,5 +1,5 @@ -- | --- Module : HashedExpression.Internal.CollectDifferential +-- Module : HashedExpression.Differentiation.Exterior.Collect -- Copyright : (c) OCA 2020 -- License : MIT (see the LICENSE file) -- Maintainer : anandc@mcmaster.ca @@ -8,7 +8,7 @@ -- -- This module exists solely to factor terms around their differentials. When properly factored, the term multiplying -- a differential (say dx) is it's corresponding parital derivative (i.e derivative w.r.t x) -module HashedExpression.Internal.CollectDifferential +module HashedExpression.Differentiation.Exterior.Collect ( collectDifferentials, ) where diff --git a/src/HashedExpression/Derivative.hs b/src/HashedExpression/Differentiation/Exterior/Derivative.hs similarity index 98% rename from src/HashedExpression/Derivative.hs rename to src/HashedExpression/Differentiation/Exterior/Derivative.hs index 96462aae..e5cbb291 100644 --- a/src/HashedExpression/Derivative.hs +++ b/src/HashedExpression/Differentiation/Exterior/Derivative.hs @@ -1,7 +1,7 @@ {-# LANGUAGE ScopedTypeVariables #-} -- | --- Module : HashedExpression.Derivative +-- Module : HashedExpression.Differentiation.Exterior.Derivative -- Copyright : (c) OCA 2020 -- License : MIT (see the LICENSE file) -- Maintainer : anandc@mcmaster.ca @@ -17,7 +17,7 @@ -- Computing an exterior derivative on an expression @Expression d R@ will result in a @Expression d Covector@, i.e a 'Covector' field -- (also known as 1-form). This will contain 'dVar' terms representing where implicit differentiation has occurred. See 'CollectDifferential' -- to factor like terms for producing partial derivatives -module HashedExpression.Derivative +module HashedExpression.Differentiation.Exterior.Derivative ( exteriorDerivative, derivativeAllVars, ) diff --git a/src/HashedExpression/Derivative/Partial.hs b/src/HashedExpression/Differentiation/Exterior/Partial.hs similarity index 87% rename from src/HashedExpression/Derivative/Partial.hs rename to src/HashedExpression/Differentiation/Exterior/Partial.hs index 44b18ceb..0d00abc1 100644 --- a/src/HashedExpression/Derivative/Partial.hs +++ b/src/HashedExpression/Differentiation/Exterior/Partial.hs @@ -1,17 +1,15 @@ -module HashedExpression.Derivative.Partial where +module HashedExpression.Differentiation.Exterior.Partial where import qualified Data.Set as Set import HashedExpression -import HashedExpression.Derivative +import HashedExpression.Differentiation.Exterior.Collect +import HashedExpression.Differentiation.Exterior.Derivative import HashedExpression.Internal -import HashedExpression.Internal.CollectDifferential import HashedExpression.Internal.Expression import HashedExpression.Internal.Node import HashedExpression.Internal.Utils import HashedExpression.Prettify --- TODO move to Derivative - -- | Compute partial derivative: ∂f / ∂x. -- Automatically performs 'exteriorDerivative' w.r.t a single variable, uses 'collectDifferentials' to -- factor terms and extracts the term corresponding to the partial derivative w.r.t the given variable, diff --git a/src/HashedExpression/Problem.hs b/src/HashedExpression/Problem.hs index 7170a358..4741470c 100644 --- a/src/HashedExpression/Problem.hs +++ b/src/HashedExpression/Problem.hs @@ -23,9 +23,9 @@ import qualified Data.Map as Map import Data.Maybe (fromJust, fromMaybe, mapMaybe) import qualified Data.Set as Set import Debug.Trace (traceShowId) -import HashedExpression.Derivative +import HashedExpression.Differentiation.Exterior.Collect +import HashedExpression.Differentiation.Exterior.Derivative import HashedExpression.Internal -import HashedExpression.Internal.CollectDifferential import HashedExpression.Internal.Expression import HashedExpression.Internal.Node import HashedExpression.Internal.Normalize diff --git a/test/CollectSpec.hs b/test/CollectSpec.hs index dab69315..ae0bd38d 100644 --- a/test/CollectSpec.hs +++ b/test/CollectSpec.hs @@ -9,9 +9,9 @@ import Data.Maybe (fromJust) import qualified Data.Set as Set import Data.Tuple.Extra (thd3) import Debug.Trace (traceShow) -import HashedExpression.Derivative +import HashedExpression.Differentiation.Exterior.Collect +import HashedExpression.Differentiation.Exterior.Derivative import HashedExpression.Internal (D_, ET_, unwrap) -import HashedExpression.Internal.CollectDifferential import HashedExpression.Internal.Expression import HashedExpression.Internal.Node import HashedExpression.Internal.Normalize diff --git a/test/ProblemSpec.hs b/test/ProblemSpec.hs index 114ce63a..b0d0c8d7 100644 --- a/test/ProblemSpec.hs +++ b/test/ProblemSpec.hs @@ -24,9 +24,9 @@ import qualified Data.Text as T import qualified Data.Text.IO as TIO import Debug.Trace (traceShowId) import GHC.IO.Exception (ExitCode (..)) -import HashedExpression.Derivative +import HashedExpression.Differentiation.Exterior.Collect +import HashedExpression.Differentiation.Exterior.Derivative import HashedExpression.Internal -import HashedExpression.Internal.CollectDifferential import HashedExpression.Internal.Expression import HashedExpression.Internal.Node import HashedExpression.Internal.Normalize (normalize) diff --git a/test/Spec.hs b/test/Spec.hs index b58d8d36..0eaa5c2e 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -10,7 +10,7 @@ import qualified Data.IntMap as IM import Data.Map (fromList, union) import Data.Maybe (fromJust) import qualified Data.Set as Set -import HashedExpression.Internal.CollectDifferential +import HashedExpression.Differentiation.Exterior.Collect import HashedExpression.Internal.Expression import HashedExpression.Internal.Normalize import HashedExpression.Internal.Utils From ffd64dd314cabe35fff097cbd055134d6d4843bc Mon Sep 17 00:00:00 2001 From: "Nhan Thai (dandoh)" Date: Mon, 27 Jul 2020 16:39:11 +0700 Subject: [PATCH 2/4] [ reverse mode differentiation ] init --- HashedExpression.cabal | 3 +- .../Differentiation/Reverse.hs | 218 ++++++++++++++++++ src/HashedExpression/Internal.hs | 24 +- src/HashedExpression/Internal/Expression.hs | 2 + src/HashedExpression/Internal/Hash.hs | 1 + src/HashedExpression/Internal/Node.hs | 21 +- .../Internal/OperationSpec.hs | 3 + src/HashedExpression/Internal/Pattern.hs | 2 +- src/HashedExpression/Internal/Structure.hs | 8 + 9 files changed, 259 insertions(+), 23 deletions(-) create mode 100644 src/HashedExpression/Differentiation/Reverse.hs diff --git a/HashedExpression.cabal b/HashedExpression.cabal index 444287e3..9703132d 100644 --- a/HashedExpression.cabal +++ b/HashedExpression.cabal @@ -4,7 +4,7 @@ cabal-version: 1.12 -- -- see: https://github.com/sol/hpack -- --- hash: 38c67ea70677719cae6b279fc3bc55b2570d2fd4eedd6d4b3f50a830535d56a8 +-- hash: 163558ccb287da47d260cf5e9b4f7cea2c489811279103812d69f235ac594245 name: HashedExpression version: 0.1.0.0 @@ -42,6 +42,7 @@ library HashedExpression.Differentiation.Exterior.Collect HashedExpression.Differentiation.Exterior.Derivative HashedExpression.Differentiation.Exterior.Partial + HashedExpression.Differentiation.Reverse HashedExpression.Embed.FFTW HashedExpression.Internal HashedExpression.Internal.Collision diff --git a/src/HashedExpression/Differentiation/Reverse.hs b/src/HashedExpression/Differentiation/Reverse.hs new file mode 100644 index 00000000..cfa33956 --- /dev/null +++ b/src/HashedExpression/Differentiation/Reverse.hs @@ -0,0 +1,218 @@ +-- | +-- Module : HashedExpression.Differentiation.Exterior.Collect +-- Copyright : (c) OCA 2020 +-- License : MIT (see the LICENSE file) +-- Maintainer : anandc@mcmaster.ca +-- Stability : provisional +-- Portability : unportable +-- +-- Compute differentiations using reverse accumulation method +-- https://en.wikipedia.org/wiki/Automatic_differentiation#Reverse_accumulation +module HashedExpression.Differentiation.Reverse where + +import Control.Monad.State.Strict +import qualified Data.IntMap.Strict as IM +import Data.List (foldl') +import Data.List.HT (removeEach) +import Data.Map.Strict (Map) +import qualified Data.Map.Strict as Map +import HashedExpression.Internal +import HashedExpression.Internal.Expression +import HashedExpression.Internal.Hash +import HashedExpression.Internal.Node +import HashedExpression.Internal.OperationSpec +import HashedExpression.Internal.Structure + +--child2ParentsMap :: (ExpressionMap, NodeID) -> IM.IntMap [NodeID] +--child2ParentsMap (mp, rootID) = +-- let parent2ChildEdges = expressionEdges (mp, rootID) +-- in foldl' (\acc (parent, child) -> IM.insertWith (++) child [parent] acc) IM.empty parent2ChildEdges + +data ComputeDState = ComputeDState + { contextMap :: ExpressionMap, + computedPartsByParents :: IM.IntMap [NodeID], + partialDerivativeMap :: Map String NodeID + } + +type ComputeReverseM a = State ComputeDState a + +modifyContextMap :: (ExpressionMap -> ExpressionMap) -> ComputeReverseM () +modifyContextMap f = modify' $ \s -> s {contextMap = f (contextMap s)} + +modifyComputedPartsByParents :: (IM.IntMap [NodeID] -> IM.IntMap [NodeID]) -> ComputeReverseM () +modifyComputedPartsByParents f = modify' $ \s -> s {computedPartsByParents = f (computedPartsByParents s)} + +modifyPartialDerivativeMap :: (Map String NodeID -> Map String NodeID) -> ComputeReverseM () +modifyPartialDerivativeMap f = modify' $ \s -> s {partialDerivativeMap = f (partialDerivativeMap s)} + +introduceNode :: Node -> ComputeReverseM NodeID +introduceNode node = do + mp <- gets contextMap + let nID = hashNode (checkHashFromMap mp) node + modify' $ \s -> s {contextMap = IM.insert nID node mp} + return nID + +perform :: OperationSpec -> [NodeID] -> ComputeReverseM NodeID +perform spec operandIDs = do + mp <- gets contextMap + let operands = map (\nID -> (nID, retrieveNode nID mp)) operandIDs + let (nID, node) = createEntry (checkHashFromMap mp) spec operands + modify' $ \s -> s {contextMap = IM.insert nID node mp} + return nID + +compute :: + Expression Scalar R -> + (ExpressionMap, Map String NodeID) +compute (Expression rootID mp) = + let reverseTopoOrder = reverse $ topologicalSort (mp, rootID) + init = ComputeDState mp IM.empty Map.empty + -- Chain rule + go :: ComputeReverseM () + go = forM_ reverseTopoOrder $ \nID -> do + --- NodeID of derivative w.r.t to current node: d(f) / d(nID) + dN <- + if nID == rootID + then introduceNode ([], R, Const 1) + else do + dPartsFromParent <- IM.lookup nID <$> gets computedPartsByParents + -- Sum all the derivative parts incurred by its parents + case dPartsFromParent of + Just [d] -> pure d + Just ds -> perform (Nary specSum) ds + curMp <- gets contextMap + let (shape, et, op) = retrieveNode nID curMp + case op of + Var name -> modifyPartialDerivativeMap (Map.insert name dN) + Const _ -> return () + Sum args -> do + forM_ args $ \x -> do + let dX = dN + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) + Mul args -> do + forM_ (removeEach args) $ \(x, rest) -> do + productRest <- perform (Nary specMul) rest + if et == R + then do + dX <- perform (Nary specMul) [dN, productRest] + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) + else do + conjugateRest <- perform (Unary specConjugate) [productRest] + dX <- perform (Nary specMul) [dN, conjugateRest] + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) + Power alpha x -> do + rest <- perform (Unary (specPower (alpha - 1))) [x] + conjugateRest <- perform (Unary specConjugate) [rest] + dX <- perform (Nary specMul) [dN, conjugateRest] + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) + Neg x -> do + dX <- perform (Unary specNeg) [dN] + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) + Scale scalar scalee -> do + case (retrieveElementType scalar curMp, retrieveElementType scalee curMp) of + (R, R) -> do + -- for scalar + dScalar <- perform (Binary specInnerProd) [dN, scalee] + modifyComputedPartsByParents (IM.insertWith (++) scalar [dScalar]) + -- for scalee + dScalee <- perform (Binary specScale) [scalar, dN] + modifyComputedPartsByParents (IM.insertWith (++) scalee [dScalee]) + (R, C) -> do + -- for scalar + reScalee <- perform (Unary specRealPart) [scalar] + imScalee <- perform (Unary specImagPart) [scalar] + reDN <- perform (Unary specRealPart) [dN] + imDN <- perform (Unary specImagPart) [dN] + sm1 <- perform (Binary specInnerProd) [reScalee, reDN] + sm2 <- perform (Binary specInnerProd) [imScalee, imDN] + dScalar <- perform (Nary specSum) [sm1, sm2] + modifyComputedPartsByParents (IM.insertWith (++) scalar [dScalar]) + -- for scalee + dScalee <- perform (Binary specScale) [scalar, dN] + modifyComputedPartsByParents (IM.insertWith (++) scalee [dScalee]) + (C, C) -> do + dScalar <- perform (Binary specInnerProd) [dN, scalee] + modifyComputedPartsByParents (IM.insertWith (++) scalar [dScalar]) + conjugateScalar <- perform (Unary specConjugate) [dScalar] + dScalee <- perform (Binary specScale) [conjugateScalar, dN] + modifyComputedPartsByParents (IM.insertWith (++) scalee [dScalee]) + Div x y -> do + dX <- perform (Binary specDiv) [dN, y] + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) + temp1 <- perform (Unary (specPower (-2))) [y] + temp2 <- perform (Unary specNeg) [temp1] + dY <- perform (Nary specMul) [dN, x, temp2] + modifyComputedPartsByParents (IM.insertWith (++) y [dY]) + Sqrt {} -> undefined + Sin {} -> undefined + Cos {} -> undefined + Tan {} -> undefined + Exp {} -> undefined + Log {} -> undefined + Sinh {} -> undefined + Cosh {} -> undefined + Tanh {} -> undefined + Asin {} -> undefined + Acos {} -> undefined + Atan {} -> undefined + Asinh {} -> undefined + Acosh {} -> undefined + Atanh {} -> undefined + -- + RealImag re im -> do + dRe <- perform (Unary specRealPart) [dN] + dIm <- perform (Unary specImagPart) [dN] + modifyComputedPartsByParents (IM.insertWith (++) re [dRe]) + modifyComputedPartsByParents (IM.insertWith (++) im [dIm]) + RealPart reIm -> do + zeroIm <- introduceNode (shape, R, Const 0) + dReIm <- perform (Binary specRealImag) [dN, zeroIm] + modifyComputedPartsByParents (IM.insertWith (++) reIm [dReIm]) + ImagPart reIm -> do + zeroRe <- introduceNode (shape, R, Const 0) + dReIm <- perform (Binary specRealImag) [zeroRe, dN] + modifyComputedPartsByParents (IM.insertWith (++) reIm [dReIm]) + InnerProd x y -> do + case et of + R -> do + dX <- perform (Binary specScale) [dN, y] + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) + dY <- perform (Binary specScale) [dN, x] + modifyComputedPartsByParents (IM.insertWith (++) y [dY]) + C -> do + dX <- perform (Binary specScale) [dN, y] + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) + conjugateDN <- perform (Unary specConjugate) [dN] + dY <- perform (Binary specScale) [conjugateDN, x] + modifyComputedPartsByParents (IM.insertWith (++) y [dY]) + Piecewise {} -> undefined + Rotate amount x -> do + dX <- perform (Unary (specRotate (map negate amount))) [dN] + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) + ReFT {} -> undefined + ImFT {} -> undefined + (_, res) = runState go init + in -- res = flip runStateT init $ do + -- forM_ reverseTopoOrder $ \nID -> do + -- undefined + + -- update :: (ExpressionMap, IM.IntMap [NodeID], IM.IntMap NodeID) -> NodeID -> (ExpressionMap, IM.IntMap [NodeID], IM.IntMap NodeID) + -- update (accMp, accsByID, dByID) nID = + -- let (shape, et, op) = retrieveNode nID accMp + -- derivativeCurrent + -- | nID == rootID = undefined + -- | otherwise = case IM.lookup nID accsByID of + -- Just [d] -> d + -- Just ds -> undefined + -- in undefined + (contextMap res, partialDerivativeMap res) + +--- +-- Re Im +--- (a + bi) <.> (c + di) +-- (a <.> c + b <.> d) + (b <.> c - a <.> d) +-- (Re *. c - Im *. d) (Re *. d + IM *. c) + +-- (Re *. a + Im *. b) + (Re *. b - Im *. a) +-- + +----------------------------------- diff --git a/src/HashedExpression/Internal.hs b/src/HashedExpression/Internal.hs index 37869ff7..ce2d8234 100644 --- a/src/HashedExpression/Internal.hs +++ b/src/HashedExpression/Internal.hs @@ -34,7 +34,6 @@ import HashedExpression.Internal.Expression import HashedExpression.Internal.Hash import HashedExpression.Internal.Node import HashedExpression.Internal.OperationSpec -import HashedExpression.Internal.Structure import HashedExpression.Internal.Utils import Prelude hiding ((^)) @@ -270,14 +269,21 @@ multipleTimes outK smp exp = go (outK - 1) exp (smp exp) -- a new root 'NodeID' (contained inside a 'ExpressionDiff') from a list of other 'Change'. By passing along the -- base 'ExpressionMap' in each 'Change', we can assure there's no overlap when generating new 'Node' -- --- ExpressionDiff in the MonadReader ExpressionMap +-- ExpressionDiff is MonadReader ExpressionMap type Change = ExpressionMap -> ExpressionDiff --- | The 'ExpressionDiff' when adding a constant is just the constant node (generate by 'aConst') +-- | The 'ExpressionDiff' when adding a constant const_ :: Shape -> Double -> Change -const_ shape val mp = ExpressionDiff mp n - where - (mp, n) = aConst shape val +const_ shape val mp = + let node = (shape, R, Const val) + nID = hashNode (checkHashFromMap mp) node + in case IM.lookup nID mp of + Just _ -> ExpressionDiff IM.empty nID + _ -> ExpressionDiff (IM.singleton nID node) nID + +--const_ shape val mp = ExpressionDiff mp n +-- where +-- (mp, n) = aConst shape val -- | The 'Change' created when adding a single 'Scalar' constant num_ :: Double -> Change @@ -375,12 +381,6 @@ data ExpressionDiff = ExpressionDiff } deriving (Eq, Ord, Show) --- | The 'ExpressionDiff' when adding a constant is just the constant node (generate by 'aConst') -diffConst :: Shape -> Double -> ExpressionDiff -diffConst shape val = ExpressionDiff mp n - where - (mp, n) = aConst shape val - dZeroWithShape :: Shape -> ExpressionDiff dZeroWithShape shape = ExpressionDiff mp n where diff --git a/src/HashedExpression/Internal/Expression.hs b/src/HashedExpression/Internal/Expression.hs index 7378abd1..0c15e163 100644 --- a/src/HashedExpression/Internal/Expression.hs +++ b/src/HashedExpression/Internal/Expression.hs @@ -170,6 +170,8 @@ data Op RealPart Arg | -- | extract imaginary from complex (transforms @Expression d C@ to @Expression d R@) ImagPart Arg + | -- | conjugate a complex expression + Conjugate Arg | -- | inner product operator, overload via 'InnerProductSpace' InnerProd Arg Arg | -- | piecewise function, overload via 'PiecewiseOp'. Evaluates 'ConditionArg' to select 'BranchArg' diff --git a/src/HashedExpression/Internal/Hash.hs b/src/HashedExpression/Internal/Hash.hs index 2805300f..4af90890 100644 --- a/src/HashedExpression/Internal/Hash.hs +++ b/src/HashedExpression/Internal/Hash.hs @@ -91,6 +91,7 @@ hash (shape, et, node) rehashNum = RealPart arg -> offsetHash 24 . hashString' $ show arg ImagPart arg -> offsetHash 25 . hashString' $ show arg RealImag arg1 arg2 -> offsetHash 26 . hashString' $ show arg1 ++ separator ++ show arg2 + Conjugate arg -> offsetHash 39 . hashString' $ show arg -- InnerProd arg1 arg2 -> offsetHash 27 . hashString' $ show arg1 ++ separator ++ show arg2 Piecewise marks arg branches -> diff --git a/src/HashedExpression/Internal/Node.hs b/src/HashedExpression/Internal/Node.hs index e06dc7fa..ea30e9cf 100644 --- a/src/HashedExpression/Internal/Node.hs +++ b/src/HashedExpression/Internal/Node.hs @@ -69,15 +69,16 @@ nodeTypeWeight node = RealPart {} -> 24 ImagPart {} -> 25 InnerProd {} -> 26 - Piecewise {} -> 27 - Rotate {} -> 28 - ReFT {} -> 29 - ImFT {} -> 30 - TwiceReFT {} -> 31 - TwiceImFT {} -> 32 - Scale {} -> 33 -- Right after RealImag - RealImag {} -> 34 -- At the end right after sum - Sum {} -> 35 -- Sum at the end + Conjugate {} -> 27 + Piecewise {} -> 28 + Rotate {} -> 29 + ReFT {} -> 30 + ImFT {} -> 31 + TwiceReFT {} -> 32 + TwiceImFT {} -> 33 + Scale {} -> 34 -- Right after RealImag + RealImag {} -> 35 -- At the end right after sum + Sum {} -> 36 -- Sum at the end ------------------------ DVar {} -> 101 DZero {} -> 102 @@ -121,6 +122,7 @@ opArgs node = RealImag arg1 arg2 -> [arg1, arg2] RealPart arg -> [arg] ImagPart arg -> [arg] + Conjugate arg -> [arg] InnerProd arg1 arg2 -> [arg1, arg2] Piecewise _ conditionArg branches -> conditionArg : branches Rotate _ arg -> [arg] @@ -164,6 +166,7 @@ mapOp f op = RealImag arg1 arg2 -> RealImag (f arg1) (f arg2) RealPart arg -> RealPart (f arg) ImagPart arg -> ImagPart (f arg) + Conjugate arg -> Conjugate (f arg) InnerProd arg1 arg2 -> InnerProd (f arg1) (f arg2) Piecewise marks conditionArg branches -> Piecewise marks (f conditionArg) (map f branches) Rotate am arg -> Rotate am (f arg) diff --git a/src/HashedExpression/Internal/OperationSpec.hs b/src/HashedExpression/Internal/OperationSpec.hs index c802658d..59d6d19d 100644 --- a/src/HashedExpression/Internal/OperationSpec.hs +++ b/src/HashedExpression/Internal/OperationSpec.hs @@ -176,6 +176,9 @@ specImagPart = decideET x | x == C = R | otherwise = error "Must be complex" + +specConjugate :: HasCallStack => UnarySpec +specConjugate = defaultUnary Conjugate [C] specInnerProd :: HasCallStack => BinarySpec specInnerProd = diff --git a/src/HashedExpression/Internal/Pattern.hs b/src/HashedExpression/Internal/Pattern.hs index 0900d6b5..eeadc625 100644 --- a/src/HashedExpression/Internal/Pattern.hs +++ b/src/HashedExpression/Internal/Pattern.hs @@ -790,7 +790,7 @@ buildFromPattern exp@(originalMp, originalN) match = buildFromPattern' (Just $ r error "Capture not in the Map Capture Int which should never happen" PHead pl -> head $ buildFromPatternList exp match pl PConst val -> case inferredShape of - Just shape -> diffConst shape val + Just shape -> const_ shape val originalMp _ -> error "Can't infer shape of the constant" PSumList ptl -> applyDiff' (Nary specSum) . buildFromPatternList exp match $ ptl diff --git a/src/HashedExpression/Internal/Structure.hs b/src/HashedExpression/Internal/Structure.hs index d8dc74af..f021e4c6 100644 --- a/src/HashedExpression/Internal/Structure.hs +++ b/src/HashedExpression/Internal/Structure.hs @@ -18,8 +18,16 @@ import qualified Data.Set as Set import Debug.Trace (traceShowId) import GHC.Exts (sortWith) import GHC.Stack (HasCallStack) +import HashedExpression.Internal import HashedExpression.Internal.Expression import HashedExpression.Internal.Hash import HashedExpression.Internal.Node import HashedExpression.Internal.OperationSpec import HashedExpression.Internal.Utils + +-- | Edges from parent to children +expressionEdges :: (ExpressionMap, NodeID) -> [(NodeID, NodeID)] +expressionEdges (mp, rootID) = + [ (nID, child) | nID <- topologicalSort (mp, rootID), child <- opArgs $ retrieveOp nID mp + ] + From 25809b8e0626ea4cdf30dcbe2f1571724fc8cf6d Mon Sep 17 00:00:00 2001 From: "Nhan Thai (dandoh)" Date: Mon, 27 Jul 2020 20:05:18 +0700 Subject: [PATCH 3/4] [ reverse mode differentiation ] implementation for each node --- HashedExpression.cabal | 4 +- src/HashedExpression/Codegen/CSimple.hs | 5 + .../Differentiation/Reverse.hs | 189 +++++++++--------- .../Differentiation/Reverse/State.hs | 186 +++++++++++++++++ src/HashedExpression/Internal.hs | 13 +- src/HashedExpression/Internal/Expression.hs | 3 + .../Internal/OperationSpec.hs | 4 +- src/HashedExpression/Internal/Pattern.hs | 3 + src/HashedExpression/Internal/Structure.hs | 1 - src/HashedExpression/Operation.hs | 2 + src/HashedExpression/Prettify.hs | 1 + test/ReverseDifferentiationSpec.hs | 43 ++++ 12 files changed, 350 insertions(+), 104 deletions(-) create mode 100644 src/HashedExpression/Differentiation/Reverse/State.hs create mode 100644 test/ReverseDifferentiationSpec.hs diff --git a/HashedExpression.cabal b/HashedExpression.cabal index 9703132d..ffa09b43 100644 --- a/HashedExpression.cabal +++ b/HashedExpression.cabal @@ -4,7 +4,7 @@ cabal-version: 1.12 -- -- see: https://github.com/sol/hpack -- --- hash: 163558ccb287da47d260cf5e9b4f7cea2c489811279103812d69f235ac594245 +-- hash: 4519a110756abaa652902d0183e8fdaa36524efb3ae39bcd679502eba5d75a05 name: HashedExpression version: 0.1.0.0 @@ -43,6 +43,7 @@ library HashedExpression.Differentiation.Exterior.Derivative HashedExpression.Differentiation.Exterior.Partial HashedExpression.Differentiation.Reverse + HashedExpression.Differentiation.Reverse.State HashedExpression.Embed.FFTW HashedExpression.Internal HashedExpression.Internal.Collision @@ -181,6 +182,7 @@ test-suite HashedExpression-test InterpSpec NormalizeSpec ProblemSpec + ReverseDifferentiationSpec StructureSpec Var Paths_HashedExpression diff --git a/src/HashedExpression/Codegen/CSimple.hs b/src/HashedExpression/Codegen/CSimple.hs index ddb2dcb0..46495028 100644 --- a/src/HashedExpression/Codegen/CSimple.hs +++ b/src/HashedExpression/Codegen/CSimple.hs @@ -219,6 +219,11 @@ evaluating CSimpleCodegen {..} rootIDs = ] RealPart arg -> for i (len n) [[I.i|#{n !! i} = #{arg `reAt` i};|]] ImagPart arg -> for i (len n) [[I.i|#{n !! i} = #{arg `imAt` i};|]] + Conjugate arg -> + for i (len n) $ + [ [I.i|#{n `reAt` i} = #{arg `reAt` i};|], + [I.i|#{n `imAt` i} = -#{arg `imAt` i};|] + ] InnerProd arg1 arg2 | et == R && null (shapeOf arg1) -> [[I.i|#{n !! nooffset} = #{arg1 !! nooffset} * #{arg2 !! nooffset};|]] | et == R -> diff --git a/src/HashedExpression/Differentiation/Reverse.hs b/src/HashedExpression/Differentiation/Reverse.hs index cfa33956..042aa27e 100644 --- a/src/HashedExpression/Differentiation/Reverse.hs +++ b/src/HashedExpression/Differentiation/Reverse.hs @@ -16,49 +16,14 @@ import Data.List (foldl') import Data.List.HT (removeEach) import Data.Map.Strict (Map) import qualified Data.Map.Strict as Map +import HashedExpression.Differentiation.Reverse.State import HashedExpression.Internal import HashedExpression.Internal.Expression import HashedExpression.Internal.Hash import HashedExpression.Internal.Node import HashedExpression.Internal.OperationSpec import HashedExpression.Internal.Structure - ---child2ParentsMap :: (ExpressionMap, NodeID) -> IM.IntMap [NodeID] ---child2ParentsMap (mp, rootID) = --- let parent2ChildEdges = expressionEdges (mp, rootID) --- in foldl' (\acc (parent, child) -> IM.insertWith (++) child [parent] acc) IM.empty parent2ChildEdges - -data ComputeDState = ComputeDState - { contextMap :: ExpressionMap, - computedPartsByParents :: IM.IntMap [NodeID], - partialDerivativeMap :: Map String NodeID - } - -type ComputeReverseM a = State ComputeDState a - -modifyContextMap :: (ExpressionMap -> ExpressionMap) -> ComputeReverseM () -modifyContextMap f = modify' $ \s -> s {contextMap = f (contextMap s)} - -modifyComputedPartsByParents :: (IM.IntMap [NodeID] -> IM.IntMap [NodeID]) -> ComputeReverseM () -modifyComputedPartsByParents f = modify' $ \s -> s {computedPartsByParents = f (computedPartsByParents s)} - -modifyPartialDerivativeMap :: (Map String NodeID -> Map String NodeID) -> ComputeReverseM () -modifyPartialDerivativeMap f = modify' $ \s -> s {partialDerivativeMap = f (partialDerivativeMap s)} - -introduceNode :: Node -> ComputeReverseM NodeID -introduceNode node = do - mp <- gets contextMap - let nID = hashNode (checkHashFromMap mp) node - modify' $ \s -> s {contextMap = IM.insert nID node mp} - return nID - -perform :: OperationSpec -> [NodeID] -> ComputeReverseM NodeID -perform spec operandIDs = do - mp <- gets contextMap - let operands = map (\nID -> (nID, retrieveNode nID mp)) operandIDs - let (nID, node) = createEntry (checkHashFromMap mp) spec operands - modify' $ \s -> s {contextMap = IM.insert nID node mp} - return nID +import Prelude hiding ((^)) compute :: Expression Scalar R -> @@ -72,12 +37,12 @@ compute (Expression rootID mp) = --- NodeID of derivative w.r.t to current node: d(f) / d(nID) dN <- if nID == rootID - then introduceNode ([], R, Const 1) + then sNum 1 else do dPartsFromParent <- IM.lookup nID <$> gets computedPartsByParents -- Sum all the derivative parts incurred by its parents case dPartsFromParent of - Just [d] -> pure d + Just [d] -> from d Just ds -> perform (Nary specSum) ds curMp <- gets contextMap let (shape, et, op) = retrieveNode nID curMp @@ -93,103 +58,139 @@ compute (Expression rootID mp) = productRest <- perform (Nary specMul) rest if et == R then do - dX <- perform (Nary specMul) [dN, productRest] + dX <- from dN * from productRest modifyComputedPartsByParents (IM.insertWith (++) x [dX]) else do - conjugateRest <- perform (Unary specConjugate) [productRest] - dX <- perform (Nary specMul) [dN, conjugateRest] + dX <- from dN * conjugate (from productRest) modifyComputedPartsByParents (IM.insertWith (++) x [dX]) - Power alpha x -> do - rest <- perform (Unary (specPower (alpha - 1))) [x] - conjugateRest <- perform (Unary specConjugate) [rest] - dX <- perform (Nary specMul) [dN, conjugateRest] - modifyComputedPartsByParents (IM.insertWith (++) x [dX]) + Power alpha x -> case et of + R -> do + dX <- sNum (fromIntegral alpha) *. (from dN * (from x ^ (alpha -1))) + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) + C -> do + dX <- sNum (fromIntegral alpha) *. conjugate (from x ^ (alpha - 1)) + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) Neg x -> do - dX <- perform (Unary specNeg) [dN] + dX <- negate $ from dN modifyComputedPartsByParents (IM.insertWith (++) x [dX]) Scale scalar scalee -> do case (retrieveElementType scalar curMp, retrieveElementType scalee curMp) of (R, R) -> do -- for scalar - dScalar <- perform (Binary specInnerProd) [dN, scalee] + dScalar <- from dN <.> from scalee modifyComputedPartsByParents (IM.insertWith (++) scalar [dScalar]) -- for scalee - dScalee <- perform (Binary specScale) [scalar, dN] + dScalee <- from scalar *. from dN modifyComputedPartsByParents (IM.insertWith (++) scalee [dScalee]) (R, C) -> do -- for scalar - reScalee <- perform (Unary specRealPart) [scalar] - imScalee <- perform (Unary specImagPart) [scalar] - reDN <- perform (Unary specRealPart) [dN] - imDN <- perform (Unary specImagPart) [dN] - sm1 <- perform (Binary specInnerProd) [reScalee, reDN] - sm2 <- perform (Binary specInnerProd) [imScalee, imDN] - dScalar <- perform (Nary specSum) [sm1, sm2] + dScalar <- xRe (from scalee) <.> xRe (from dN) + xIm (from scalee) <.> xIm (from dN) modifyComputedPartsByParents (IM.insertWith (++) scalar [dScalar]) -- for scalee - dScalee <- perform (Binary specScale) [scalar, dN] + dScalee <- from scalar *. from dN modifyComputedPartsByParents (IM.insertWith (++) scalee [dScalee]) (C, C) -> do - dScalar <- perform (Binary specInnerProd) [dN, scalee] + -- for scalar + dScalar <- from dN <.> from scalee modifyComputedPartsByParents (IM.insertWith (++) scalar [dScalar]) - conjugateScalar <- perform (Unary specConjugate) [dScalar] - dScalee <- perform (Binary specScale) [conjugateScalar, dN] + -- for scalee + dScalee <- conjugate (from scalar) *. from dN modifyComputedPartsByParents (IM.insertWith (++) scalee [dScalee]) Div x y -> do - dX <- perform (Binary specDiv) [dN, y] + dX <- from dN / from y modifyComputedPartsByParents (IM.insertWith (++) x [dX]) - temp1 <- perform (Unary (specPower (-2))) [y] - temp2 <- perform (Unary specNeg) [temp1] - dY <- perform (Nary specMul) [dN, x, temp2] + dY <- from dN * from x * (from y ^ (-2)) modifyComputedPartsByParents (IM.insertWith (++) y [dY]) - Sqrt {} -> undefined - Sin {} -> undefined - Cos {} -> undefined - Tan {} -> undefined - Exp {} -> undefined - Log {} -> undefined - Sinh {} -> undefined - Cosh {} -> undefined - Tanh {} -> undefined - Asin {} -> undefined - Acos {} -> undefined - Atan {} -> undefined - Asinh {} -> undefined - Acosh {} -> undefined - Atanh {} -> undefined - -- + Sqrt x -> do + dX <- sNum 0.5 *. (from dN / sqrt (from x)) + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) + Sin x -> do + dX <- from dN * cos (from x) + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) + Cos x -> do + dX <- from dN * (- sin (from x)) + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) + Tan x -> do + dX <- from dN * (cos (from x) ^ (-2)) + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) + Exp x -> do + dX <- from dN * exp (from x) + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) + Log x -> do + dX <- from dN * (from x ^ (-1)) + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) + Sinh x -> do + dX <- from dN * cosh (from x) + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) + Cosh x -> do + dX <- from dN * sinh (from x) + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) + Tanh x -> do + let one = introduceNode (shape, R, Const 1) + dX <- from dN * (one - tanh (from x) ^ 2) + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) + Asin x -> do + dX <- error "TODO" + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) + Acos x -> do + dX <- error "TODO" + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) + Atan x -> do + dX <- error "TODO" + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) + Asinh x -> do + dX <- error "TODO" + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) + Acosh x -> do + dX <- error "TODO" + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) + Atanh x -> do + dX <- error "TODO" + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) RealImag re im -> do - dRe <- perform (Unary specRealPart) [dN] - dIm <- perform (Unary specImagPart) [dN] + dRe <- xRe $ from dN modifyComputedPartsByParents (IM.insertWith (++) re [dRe]) + dIm <- xIm $ from dN modifyComputedPartsByParents (IM.insertWith (++) im [dIm]) RealPart reIm -> do - zeroIm <- introduceNode (shape, R, Const 0) - dReIm <- perform (Binary specRealImag) [dN, zeroIm] + let zero = introduceNode (shape, R, Const 0) + dReIm <- from dN +: zero modifyComputedPartsByParents (IM.insertWith (++) reIm [dReIm]) ImagPart reIm -> do - zeroRe <- introduceNode (shape, R, Const 0) - dReIm <- perform (Binary specRealImag) [zeroRe, dN] + let zero = introduceNode (shape, R, Const 0) + dReIm <- zero +: from dN modifyComputedPartsByParents (IM.insertWith (++) reIm [dReIm]) InnerProd x y -> do case et of R -> do - dX <- perform (Binary specScale) [dN, y] + dX <- from dN *. from y modifyComputedPartsByParents (IM.insertWith (++) x [dX]) - dY <- perform (Binary specScale) [dN, x] + dY <- from dN *. from x modifyComputedPartsByParents (IM.insertWith (++) y [dY]) C -> do - dX <- perform (Binary specScale) [dN, y] + dX <- from dN *. from y modifyComputedPartsByParents (IM.insertWith (++) x [dX]) - conjugateDN <- perform (Unary specConjugate) [dN] - dY <- perform (Binary specScale) [conjugateDN, x] + dY <- conjugate (from dN) *. from x modifyComputedPartsByParents (IM.insertWith (++) y [dY]) Piecewise {} -> undefined Rotate amount x -> do dX <- perform (Unary (specRotate (map negate amount))) [dN] + dX <- rotate (map negate amount) $ from dN modifyComputedPartsByParents (IM.insertWith (++) x [dX]) - ReFT {} -> undefined - ImFT {} -> undefined + ReFT x + | retrieveElementType x curMp == R -> do + dX <- reFT (from dN) + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) + | otherwise -> do + dX <- reFT (from dN) +: (- imFT (from dN)) + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) + ImFT x + | retrieveElementType x curMp == R -> do + dX <- imFT (from dN) + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) + | otherwise -> do + dX <- imFT (from dN) +: (- reFT (from dN)) + modifyComputedPartsByParents (IM.insertWith (++) x [dX]) (_, res) = runState go init in -- res = flip runStateT init $ do -- forM_ reverseTopoOrder $ \nID -> do diff --git a/src/HashedExpression/Differentiation/Reverse/State.hs b/src/HashedExpression/Differentiation/Reverse/State.hs new file mode 100644 index 00000000..d6edb4f6 --- /dev/null +++ b/src/HashedExpression/Differentiation/Reverse/State.hs @@ -0,0 +1,186 @@ +-- | +-- Module : HashedExpression.Differentiation.Exterior.Collect +-- Copyright : (c) OCA 2020 +-- License : MIT (see the LICENSE file) +-- Maintainer : anandc@mcmaster.ca +-- Stability : provisional +-- Portability : unportable +-- +-- Helper for reverse accumulation method +module HashedExpression.Differentiation.Reverse.State where + +import Control.Monad.State.Strict +import qualified Data.IntMap.Strict as IM +import Data.List (foldl') +import Data.List.HT (removeEach) +import Data.Map.Strict (Map) +import qualified Data.Map.Strict as Map +import HashedExpression.Internal +import HashedExpression.Internal.Expression +import HashedExpression.Internal.Hash +import HashedExpression.Internal.Node +import HashedExpression.Internal.OperationSpec +import HashedExpression.Internal.Structure +import Prelude hiding ((^)) + +data ComputeDState = ComputeDState + { contextMap :: ExpressionMap, + computedPartsByParents :: IM.IntMap [NodeID], + partialDerivativeMap :: Map String NodeID + } + +modifyContextMap :: (ExpressionMap -> ExpressionMap) -> ComputeReverseM () +modifyContextMap f = modify' $ \s -> s {contextMap = f (contextMap s)} + +modifyComputedPartsByParents :: (IM.IntMap [NodeID] -> IM.IntMap [NodeID]) -> ComputeReverseM () +modifyComputedPartsByParents f = modify' $ \s -> s {computedPartsByParents = f (computedPartsByParents s)} + +modifyPartialDerivativeMap :: (Map String NodeID -> Map String NodeID) -> ComputeReverseM () +modifyPartialDerivativeMap f = modify' $ \s -> s {partialDerivativeMap = f (partialDerivativeMap s)} + +from :: NodeID -> ComputeReverseM NodeID +from = return + +sNum :: Double -> ComputeReverseM NodeID +sNum val = introduceNode ([], R, Const val) + +introduceNode :: Node -> ComputeReverseM NodeID +introduceNode node = do + mp <- gets contextMap + let nID = hashNode (checkHashFromMap mp) node + modify' $ \s -> s {contextMap = IM.insert nID node mp} + return nID + +perform :: OperationSpec -> [NodeID] -> ComputeReverseM NodeID +perform spec operandIDs = do + mp <- gets contextMap + let operands = map (\nID -> (nID, retrieveNode nID mp)) operandIDs + let (nID, node) = createEntry (checkHashFromMap mp) spec operands + modify' $ \s -> s {contextMap = IM.insert nID node mp} + return nID + +type ComputeReverseM a = State ComputeDState a + +instance Num (ComputeReverseM NodeID) where + (+) operand1 operand2 = + do + x <- operand1 + y <- operand2 + perform (Nary specSum) [x, y] + negate operand = + do + x <- operand + perform (Unary specNeg) [x] + (*) operand1 operand2 = + do + x <- operand1 + y <- operand2 + perform (Nary specSum) [x, y] + +instance Fractional (ComputeReverseM NodeID) where + (/) operand1 operand2 = do + x <- operand1 + y <- operand2 + perform (Binary specDiv) [x, y] + + fromRational r = error "N/A" + +instance Floating (ComputeReverseM NodeID) where + sqrt operand = do + x <- operand + perform (Unary specSqrt) [x] + exp operand = do + x <- operand + perform (Unary specExp) [x] + log operand = do + x <- operand + perform (Unary specLog) [x] + sin operand = do + x <- operand + perform (Unary specSin) [x] + cos operand = do + x <- operand + perform (Unary specCos) [x] + tan operand = do + x <- operand + perform (Unary specTan) [x] + asin operand = do + x <- operand + perform (Unary specAsin) [x] + acos operand = do + x <- operand + perform (Unary specAcos) [x] + atan operand = do + x <- operand + perform (Unary specAtan) [x] + sinh operand = do + x <- operand + perform (Unary specSinh) [x] + cosh operand = do + x <- operand + perform (Unary specCosh) [x] + tanh operand = do + x <- operand + perform (Unary specTanh) [x] + asinh operand = do + x <- operand + perform (Unary specAsinh) [x] + acosh operand = do + x <- operand + perform (Unary specAcosh) [x] + atanh operand = do + x <- operand + perform (Unary specAtanh) [x] + +instance PowerOp (ComputeReverseM NodeID) Int where + (^) operand alpha = do + x <- operand + perform (Unary (specPower alpha)) [x] + +instance VectorSpaceOp (ComputeReverseM NodeID) (ComputeReverseM NodeID) where + scale operand1 operand2 = do + x <- operand1 + y <- operand2 + perform (Binary specScale) [x, y] + +instance ComplexRealOp (ComputeReverseM NodeID) (ComputeReverseM NodeID) where + (+:) operand1 operand2 = do + x <- operand1 + y <- operand2 + perform (Binary specRealImag) [x, y] + xRe operand1 = do + x <- operand1 + perform (Unary specRealPart) [x] + xIm operand1 = do + x <- operand1 + perform (Unary specImagPart) [x] + conjugate operand = do + x <- operand + perform (Unary specConjugate) [x] + +instance InnerProductSpaceOp (ComputeReverseM NodeID) (ComputeReverseM NodeID) (ComputeReverseM NodeID) where + (<.>) operand1 operand2 = do + x <- operand1 + y <- operand2 + perform (Binary specInnerProd) [x, y] + +instance RotateOp RotateAmount (ComputeReverseM NodeID) where + rotate ra operand = do + x <- operand + perform (Unary (specRotate ra)) [x] + +instance PiecewiseOp (ComputeReverseM NodeID) (ComputeReverseM NodeID) where + piecewise marks condition branches = do + conditionID <- condition + branchIDs <- sequence branches + perform (ConditionAry (specPiecewise marks)) $ conditionID : branchIDs + +reFT :: ComputeReverseM NodeID -> ComputeReverseM NodeID +reFT operand = do + x <- operand + perform (Unary specReFT) [x] + +imFT :: ComputeReverseM NodeID -> ComputeReverseM NodeID +imFT operand = do + x <- operand + perform (Unary specImFT) [x] diff --git a/src/HashedExpression/Internal.hs b/src/HashedExpression/Internal.hs index ce2d8234..3c8fa62d 100644 --- a/src/HashedExpression/Internal.hs +++ b/src/HashedExpression/Internal.hs @@ -277,9 +277,9 @@ const_ :: Shape -> Double -> Change const_ shape val mp = let node = (shape, R, Const val) nID = hashNode (checkHashFromMap mp) node - in case IM.lookup nID mp of - Just _ -> ExpressionDiff IM.empty nID - _ -> ExpressionDiff (IM.singleton nID node) nID + in case IM.lookup nID mp of + Just _ -> ExpressionDiff IM.empty nID + _ -> ExpressionDiff (IM.singleton nID node) nID --const_ shape val mp = ExpressionDiff mp n -- where @@ -305,9 +305,9 @@ instance Num Change where (+) change1 change2 mp = applyDiff mp (Nary specSum) [change1 mp, change2 mp] negate change mp = applyDiff mp (Unary specNeg) [change mp] (*) change1 change2 mp = applyDiff mp (Nary specMul) [change1 mp, change2 mp] - signum = error "The Change of signum is currently unimplemented" - abs = error "The Change of abs is currently unimplemented" - fromInteger = error "The change of fromInteger is currently unimplemented" + signum = error "signum change" + abs = error "abs change" + fromInteger = error "from integer" instance Fractional Change where (/) change1 change2 = change1 * (change2 ^ (-1)) @@ -341,6 +341,7 @@ instance ComplexRealOp Change Change where (+:) change1 change2 mp = applyDiff mp (Binary specRealImag) [change1 mp, change2 mp] xRe change1 mp = applyDiff mp (Unary specRealPart) [change1 mp] xIm change1 mp = applyDiff mp (Unary specImagPart) [change1 mp] + conjugate change mp = applyDiff mp (Unary specConjugate) [change mp] instance InnerProductSpaceOp Change Change Change where (<.>) change1 change2 mp = diff --git a/src/HashedExpression/Internal/Expression.hs b/src/HashedExpression/Internal/Expression.hs index 0c15e163..1997b49d 100644 --- a/src/HashedExpression/Internal/Expression.hs +++ b/src/HashedExpression/Internal/Expression.hs @@ -398,6 +398,9 @@ class ComplexRealOp r c | r -> c, c -> r where -- | extract imaginary part from complex data xIm :: c -> r + -- conjugate + conjugate :: c -> c + -- | Interface for Inner Product combinator for constructing 'Expression' types. Can be overloaded -- to support different functionality performed on 'Expresion' (such as evaluation, pattern matching, code generation) class InnerProductSpaceOp a b c | a b -> c where diff --git a/src/HashedExpression/Internal/OperationSpec.hs b/src/HashedExpression/Internal/OperationSpec.hs index 59d6d19d..640e3203 100644 --- a/src/HashedExpression/Internal/OperationSpec.hs +++ b/src/HashedExpression/Internal/OperationSpec.hs @@ -176,8 +176,8 @@ specImagPart = decideET x | x == C = R | otherwise = error "Must be complex" - -specConjugate :: HasCallStack => UnarySpec + +specConjugate :: HasCallStack => UnarySpec specConjugate = defaultUnary Conjugate [C] specInnerProd :: HasCallStack => BinarySpec diff --git a/src/HashedExpression/Internal/Pattern.hs b/src/HashedExpression/Internal/Pattern.hs index eeadc625..f520de97 100644 --- a/src/HashedExpression/Internal/Pattern.hs +++ b/src/HashedExpression/Internal/Pattern.hs @@ -229,6 +229,8 @@ data Pattern PRealPart Pattern | -- | pattern that has a imaginary part extraction operator applied to it PImagPart Pattern + | -- | + PConjugate Pattern | -- | pattern that has a inner product operator applied to it PInnerProd Pattern Pattern | -- | pattern that has a piecewise @@ -363,6 +365,7 @@ instance ComplexRealOp Pattern Pattern where (+:) = PRealImag xRe = PRealPart xIm = PImagPart + conjugate = PConjugate instance InnerProductSpaceOp Pattern Pattern Pattern where (<.>) = PInnerProd diff --git a/src/HashedExpression/Internal/Structure.hs b/src/HashedExpression/Internal/Structure.hs index f021e4c6..a43c741d 100644 --- a/src/HashedExpression/Internal/Structure.hs +++ b/src/HashedExpression/Internal/Structure.hs @@ -30,4 +30,3 @@ expressionEdges :: (ExpressionMap, NodeID) -> [(NodeID, NodeID)] expressionEdges (mp, rootID) = [ (nID, child) | nID <- topologicalSort (mp, rootID), child <- opArgs $ retrieveOp nID mp ] - diff --git a/src/HashedExpression/Operation.hs b/src/HashedExpression/Operation.hs index c31ff55a..d15401e8 100644 --- a/src/HashedExpression/Operation.hs +++ b/src/HashedExpression/Operation.hs @@ -138,6 +138,8 @@ instance (DimensionType d) => ComplexRealOp (Expression d R) (Expression d C) wh xRe = applyUnary specRealPart xIm :: Expression d C -> Expression d R xIm = applyUnary specImagPart + conjugate :: Expression d C -> Expression d C + conjugate = applyUnary specConjugate instance (InnerProductSpace d s) => diff --git a/src/HashedExpression/Prettify.hs b/src/HashedExpression/Prettify.hs index 75b231b4..c935205d 100644 --- a/src/HashedExpression/Prettify.hs +++ b/src/HashedExpression/Prettify.hs @@ -164,6 +164,7 @@ hiddenPrettify pastable (mp, n) = RealImag arg1 arg2 -> T.concat [innerPrettify arg1, "+:", innerPrettify arg2] RealPart arg -> T.concat ["Re", wrapParentheses $ innerPrettify arg] ImagPart arg -> T.concat ["Im", wrapParentheses $ innerPrettify arg] + Conjugate arg -> T.concat ["conjugate", wrapParentheses $ innerPrettify arg] InnerProd arg1 arg2 -> T.concat [innerPrettify arg1, "<.>", innerPrettify arg2] Piecewise marks conditionArg branches -> let printBranches = T.intercalate ", " . map innerPrettify $ branches diff --git a/test/ReverseDifferentiationSpec.hs b/test/ReverseDifferentiationSpec.hs new file mode 100644 index 00000000..d98f66f1 --- /dev/null +++ b/test/ReverseDifferentiationSpec.hs @@ -0,0 +1,43 @@ +module ReverseDifferentiationSpec where + +import Commons +import Control.Applicative (liftA2) +import Control.Monad +import Control.Monad (replicateM_, unless) +import qualified Data.IntMap.Strict as IM +import Data.List (group, sort) +import qualified Data.Map.Strict as Map +import Data.Maybe (fromJust) +import qualified Data.Set as Set +import Data.Tuple.Extra (thd3) +import Debug.Trace (traceShow) +import HashedExpression.Differentiation.Exterior.Collect +import HashedExpression.Differentiation.Exterior.Derivative +import HashedExpression.Differentiation.Reverse +import HashedExpression.Internal (D_, ET_, unwrap) +import HashedExpression.Internal.Expression +import HashedExpression.Internal.Node +import HashedExpression.Internal.Normalize +import HashedExpression.Internal.Structure +import HashedExpression.Internal.Utils +import HashedExpression.Interp +import HashedExpression.Operation hiding (product, sum) +import qualified HashedExpression.Operation +import HashedExpression.Prettify +import Test.HUnit (assertBool) +import Test.Hspec +import Test.QuickCheck +import Var +import Prelude hiding ((^)) +import qualified Prelude + +spec :: Spec +spec = + describe "Reverse differentiation spec" $ do + specify "Unit tests" $ do + let f = xRe ((x1 +: y1) <.> (y1 +: z1)) + showExp f + showExp $ collectDifferentials . derivativeAllVars $ f + let (mp, pd) = compute f + forM_ (Map.toList pd) $ \(name, pID) -> do + print $ name ++ ": " ++ debugPrint (mp, pID) From 62c9be7978a9ec7d65fac6e3eb161afd7e7b4734 Mon Sep 17 00:00:00 2001 From: "Nhan Thai (dandoh)" Date: Mon, 27 Jul 2020 22:44:37 +0700 Subject: [PATCH 4/4] [ reverse mode differentiation ] done reverse mode --- HashedExpression.cabal | 4 +- .../Differentiation/Exterior.hs | 40 ++++++++++++ .../Differentiation/Exterior/Partial.hs | 34 ---------- .../Differentiation/Reverse.hs | 65 ++++++++----------- .../Differentiation/Reverse/State.hs | 12 +++- .../Internal/OperationSpec.hs | 2 +- src/HashedExpression/Interp.hs | 13 ++++ src/HashedExpression/Problem.hs | 25 +------ test/Commons.hs | 43 ++++++------ test/ReverseDifferentiationSpec.hs | 44 ++++++++++--- test/Spec.hs | 2 + 11 files changed, 159 insertions(+), 125 deletions(-) create mode 100644 src/HashedExpression/Differentiation/Exterior.hs delete mode 100644 src/HashedExpression/Differentiation/Exterior/Partial.hs diff --git a/HashedExpression.cabal b/HashedExpression.cabal index ffa09b43..7ff74415 100644 --- a/HashedExpression.cabal +++ b/HashedExpression.cabal @@ -4,7 +4,7 @@ cabal-version: 1.12 -- -- see: https://github.com/sol/hpack -- --- hash: 4519a110756abaa652902d0183e8fdaa36524efb3ae39bcd679502eba5d75a05 +-- hash: 158d5ed7ffd77d8520dc6eedabb8d47f1240293abc2485d2358bc8fa6a262c85 name: HashedExpression version: 0.1.0.0 @@ -39,9 +39,9 @@ library HashedExpression.Codegen HashedExpression.Codegen.CSIMD HashedExpression.Codegen.CSimple + HashedExpression.Differentiation.Exterior HashedExpression.Differentiation.Exterior.Collect HashedExpression.Differentiation.Exterior.Derivative - HashedExpression.Differentiation.Exterior.Partial HashedExpression.Differentiation.Reverse HashedExpression.Differentiation.Reverse.State HashedExpression.Embed.FFTW diff --git a/src/HashedExpression/Differentiation/Exterior.hs b/src/HashedExpression/Differentiation/Exterior.hs new file mode 100644 index 00000000..f377d954 --- /dev/null +++ b/src/HashedExpression/Differentiation/Exterior.hs @@ -0,0 +1,40 @@ +-- | +-- Module : HashedExpression.Differentiation.Exterior.Derivative +-- Copyright : (c) OCA 2020 +-- License : MIT (see the LICENSE file) +-- Maintainer : anandc@mcmaster.ca +-- Stability : provisional +-- Portability : unportable +module HashedExpression.Differentiation.Exterior where + +import Data.Map.Strict (Map) +import qualified Data.Map.Strict as Map +import Data.Maybe (mapMaybe) +import HashedExpression.Differentiation.Exterior.Collect +import HashedExpression.Differentiation.Exterior.Derivative +import HashedExpression.Internal +import HashedExpression.Internal.Expression +import HashedExpression.Internal.Node + +partialDerivativesMapByExterior :: Expression Scalar R -> (ExpressionMap, Map String NodeID) +partialDerivativesMapByExterior exp = + let (mp, rootID) = unwrap . collectDifferentials . derivativeAllVars $ exp + in (mp, partialDerivativesMap (mp, rootID)) + +-- | Return a map from variable name to the corresponding partial derivative node id +-- Partial derivatives in Expression Scalar Covector should be collected before passing to this function +partialDerivativesMap :: (ExpressionMap, NodeID) -> Map String NodeID +partialDerivativesMap (dfMp, dfId) = + case retrieveOp dfId dfMp of + Sum ns | retrieveElementType dfId dfMp == Covector -> Map.fromList $ mapMaybe getPartial ns + _ -> Map.fromList $ mapMaybe getPartial [dfId] + where + getPartial :: NodeID -> Maybe (String, NodeID) + getPartial nId + | MulD partialId dId <- retrieveOp nId dfMp, + DVar name <- retrieveOp dId dfMp = + Just (name, partialId) + | InnerProdD partialId dId <- retrieveOp nId dfMp, + DVar name <- retrieveOp dId dfMp = + Just (name, partialId) + | otherwise = Nothing diff --git a/src/HashedExpression/Differentiation/Exterior/Partial.hs b/src/HashedExpression/Differentiation/Exterior/Partial.hs deleted file mode 100644 index 0d00abc1..00000000 --- a/src/HashedExpression/Differentiation/Exterior/Partial.hs +++ /dev/null @@ -1,34 +0,0 @@ -module HashedExpression.Differentiation.Exterior.Partial where - -import qualified Data.Set as Set -import HashedExpression -import HashedExpression.Differentiation.Exterior.Collect -import HashedExpression.Differentiation.Exterior.Derivative -import HashedExpression.Internal -import HashedExpression.Internal.Expression -import HashedExpression.Internal.Node -import HashedExpression.Internal.Utils -import HashedExpression.Prettify - --- | Compute partial derivative: ∂f / ∂x. --- Automatically performs 'exteriorDerivative' w.r.t a single variable, uses 'collectDifferentials' to --- factor terms and extracts the term corresponding to the partial derivative w.r.t the given variable, --- returning that term alone as a 'Expression' -partialDerivative :: - DimensionType d => - -- | base Expression - Expression Scalar R -> - -- | variable to take partial w.r.t - Expression d R -> - -- | term corresponding to partial - Expression d R -partialDerivative f mx = case maybeVariable mx of - Just (x, shape) -> - let df = exteriorDerivative (Set.fromList [x]) f - Expression nID mp = collectDifferentials df - in case retrieveOp nID mp of - DZero -> constWithShape shape 0 - MulD partialID _ -> wrap (mp, partialID) - InnerProdD partialID _ -> wrap (mp, partialID) - node -> error $ "This should not happen: " ++ show node - Nothing -> error "2nd argument is not a variable" diff --git a/src/HashedExpression/Differentiation/Reverse.hs b/src/HashedExpression/Differentiation/Reverse.hs index 042aa27e..8696895d 100644 --- a/src/HashedExpression/Differentiation/Reverse.hs +++ b/src/HashedExpression/Differentiation/Reverse.hs @@ -25,10 +25,11 @@ import HashedExpression.Internal.OperationSpec import HashedExpression.Internal.Structure import Prelude hiding ((^)) -compute :: +-- | +partialDerivativesMapByReverse :: Expression Scalar R -> (ExpressionMap, Map String NodeID) -compute (Expression rootID mp) = +partialDerivativesMapByReverse (Expression rootID mp) = let reverseTopoOrder = reverse $ topologicalSort (mp, rootID) init = ComputeDState mp IM.empty Map.empty -- Chain rule @@ -46,6 +47,8 @@ compute (Expression rootID mp) = Just ds -> perform (Nary specSum) ds curMp <- gets contextMap let (shape, et, op) = retrieveNode nID curMp + let one = introduceNode (shape, R, Const 1) + let zero = introduceNode (shape, R, Const 0) case op of Var name -> modifyPartialDerivativeMap (Map.insert name dN) Const _ -> return () @@ -68,7 +71,7 @@ compute (Expression rootID mp) = dX <- sNum (fromIntegral alpha) *. (from dN * (from x ^ (alpha -1))) modifyComputedPartsByParents (IM.insertWith (++) x [dX]) C -> do - dX <- sNum (fromIntegral alpha) *. conjugate (from x ^ (alpha - 1)) + dX <- sNum (fromIntegral alpha) *. (from dN * conjugate (from x ^ (alpha - 1))) modifyComputedPartsByParents (IM.insertWith (++) x [dX]) Neg x -> do dX <- negate $ from dN @@ -126,26 +129,25 @@ compute (Expression rootID mp) = dX <- from dN * sinh (from x) modifyComputedPartsByParents (IM.insertWith (++) x [dX]) Tanh x -> do - let one = introduceNode (shape, R, Const 1) dX <- from dN * (one - tanh (from x) ^ 2) modifyComputedPartsByParents (IM.insertWith (++) x [dX]) Asin x -> do - dX <- error "TODO" + dX <- from dN * (one / sqrt (one - from x ^ 2)) modifyComputedPartsByParents (IM.insertWith (++) x [dX]) Acos x -> do - dX <- error "TODO" + dX <- from dN * (- one / sqrt (one - from x ^ 2)) modifyComputedPartsByParents (IM.insertWith (++) x [dX]) Atan x -> do - dX <- error "TODO" + dX <- from dN * (one / one + from x ^ 2) modifyComputedPartsByParents (IM.insertWith (++) x [dX]) Asinh x -> do - dX <- error "TODO" + dX <- from dN * (one / sqrt (one + from x ^ 2)) modifyComputedPartsByParents (IM.insertWith (++) x [dX]) Acosh x -> do - dX <- error "TODO" + dX <- from dN * (one / sqrt (from x ^ 2 - one)) modifyComputedPartsByParents (IM.insertWith (++) x [dX]) Atanh x -> do - dX <- error "TODO" + dX <- from dN * (one / sqrt (one - from x ^ 2)) modifyComputedPartsByParents (IM.insertWith (++) x [dX]) RealImag re im -> do dRe <- xRe $ from dN @@ -153,11 +155,9 @@ compute (Expression rootID mp) = dIm <- xIm $ from dN modifyComputedPartsByParents (IM.insertWith (++) im [dIm]) RealPart reIm -> do - let zero = introduceNode (shape, R, Const 0) dReIm <- from dN +: zero modifyComputedPartsByParents (IM.insertWith (++) reIm [dReIm]) ImagPart reIm -> do - let zero = introduceNode (shape, R, Const 0) dReIm <- zero +: from dN modifyComputedPartsByParents (IM.insertWith (++) reIm [dReIm]) InnerProd x y -> do @@ -172,7 +172,21 @@ compute (Expression rootID mp) = modifyComputedPartsByParents (IM.insertWith (++) x [dX]) dY <- conjugate (from dN) *. from x modifyComputedPartsByParents (IM.insertWith (++) y [dY]) - Piecewise {} -> undefined + Piecewise marks condition branches -> do + dCondition <- zero + modifyComputedPartsByParents (IM.insertWith (++) condition [dCondition]) + let numBranches = length branches + forM_ (zip branches [0 ..]) $ \(branch, idx) -> case et of + R -> do + associate <- piecewise marks (from condition) (replicate idx zero ++ [one] ++ replicate (numBranches - idx - 1) zero) + dBranch <- from dN * from associate + modifyComputedPartsByParents (IM.insertWith (++) branch [dBranch]) + C -> do + let zeroC = zero +: zero + let oneC = one +: zero + associate <- piecewise marks (from condition) (replicate idx zeroC ++ [oneC] ++ replicate (numBranches - idx - 1) zeroC) + dBranch <- from dN * conjugate (from associate) + modifyComputedPartsByParents (IM.insertWith (++) branch [dBranch]) Rotate amount x -> do dX <- perform (Unary (specRotate (map negate amount))) [dN] dX <- rotate (map negate amount) $ from dN @@ -192,28 +206,5 @@ compute (Expression rootID mp) = dX <- imFT (from dN) +: (- reFT (from dN)) modifyComputedPartsByParents (IM.insertWith (++) x [dX]) (_, res) = runState go init - in -- res = flip runStateT init $ do - -- forM_ reverseTopoOrder $ \nID -> do - -- undefined - - -- update :: (ExpressionMap, IM.IntMap [NodeID], IM.IntMap NodeID) -> NodeID -> (ExpressionMap, IM.IntMap [NodeID], IM.IntMap NodeID) - -- update (accMp, accsByID, dByID) nID = - -- let (shape, et, op) = retrieveNode nID accMp - -- derivativeCurrent - -- | nID == rootID = undefined - -- | otherwise = case IM.lookup nID accsByID of - -- Just [d] -> d - -- Just ds -> undefined - -- in undefined - (contextMap res, partialDerivativeMap res) - ---- --- Re Im ---- (a + bi) <.> (c + di) --- (a <.> c + b <.> d) + (b <.> c - a <.> d) --- (Re *. c - Im *. d) (Re *. d + IM *. c) - --- (Re *. a + Im *. b) + (Re *. b - Im *. a) --- + in (contextMap res, partialDerivativeMap res) ------------------------------------ diff --git a/src/HashedExpression/Differentiation/Reverse/State.hs b/src/HashedExpression/Differentiation/Reverse/State.hs index d6edb4f6..a89c82df 100644 --- a/src/HashedExpression/Differentiation/Reverse/State.hs +++ b/src/HashedExpression/Differentiation/Reverse/State.hs @@ -15,6 +15,7 @@ import Data.List (foldl') import Data.List.HT (removeEach) import Data.Map.Strict (Map) import qualified Data.Map.Strict as Map +import GHC.Stack (HasCallStack) import HashedExpression.Internal import HashedExpression.Internal.Expression import HashedExpression.Internal.Hash @@ -29,21 +30,27 @@ data ComputeDState = ComputeDState partialDerivativeMap :: Map String NodeID } +-- | modifyContextMap :: (ExpressionMap -> ExpressionMap) -> ComputeReverseM () modifyContextMap f = modify' $ \s -> s {contextMap = f (contextMap s)} +-- | modifyComputedPartsByParents :: (IM.IntMap [NodeID] -> IM.IntMap [NodeID]) -> ComputeReverseM () modifyComputedPartsByParents f = modify' $ \s -> s {computedPartsByParents = f (computedPartsByParents s)} +-- | modifyPartialDerivativeMap :: (Map String NodeID -> Map String NodeID) -> ComputeReverseM () modifyPartialDerivativeMap f = modify' $ \s -> s {partialDerivativeMap = f (partialDerivativeMap s)} +-- | from :: NodeID -> ComputeReverseM NodeID from = return +-- | sNum :: Double -> ComputeReverseM NodeID sNum val = introduceNode ([], R, Const val) +-- | introduceNode :: Node -> ComputeReverseM NodeID introduceNode node = do mp <- gets contextMap @@ -51,6 +58,7 @@ introduceNode node = do modify' $ \s -> s {contextMap = IM.insert nID node mp} return nID +-- | perform :: OperationSpec -> [NodeID] -> ComputeReverseM NodeID perform spec operandIDs = do mp <- gets contextMap @@ -59,6 +67,7 @@ perform spec operandIDs = do modify' $ \s -> s {contextMap = IM.insert nID node mp} return nID +-- | type ComputeReverseM a = State ComputeDState a instance Num (ComputeReverseM NodeID) where @@ -71,11 +80,12 @@ instance Num (ComputeReverseM NodeID) where do x <- operand perform (Unary specNeg) [x] + (*) :: HasCallStack => ComputeReverseM NodeID -> ComputeReverseM NodeID -> ComputeReverseM NodeID (*) operand1 operand2 = do x <- operand1 y <- operand2 - perform (Nary specSum) [x, y] + perform (Nary specMul) [x, y] instance Fractional (ComputeReverseM NodeID) where (/) operand1 operand2 = do diff --git a/src/HashedExpression/Internal/OperationSpec.hs b/src/HashedExpression/Internal/OperationSpec.hs index 640e3203..6238ac85 100644 --- a/src/HashedExpression/Internal/OperationSpec.hs +++ b/src/HashedExpression/Internal/OperationSpec.hs @@ -159,7 +159,7 @@ specRealImag = decideShape x y = assertSame [x, y] x decideET x y | x == R && y == R = C - | otherwise = error "2 operands must be real" + | otherwise = error $ "2 operands must be real" ++ show [x, y] specRealPart :: HasCallStack => UnarySpec specRealPart = diff --git a/src/HashedExpression/Interp.hs b/src/HashedExpression/Interp.hs index 3bb61024..0ae2a250 100644 --- a/src/HashedExpression/Interp.hs +++ b/src/HashedExpression/Interp.hs @@ -10,6 +10,12 @@ module HashedExpression.Interp ( Evaluable (..), Approximable (..), + evaluate1DReal, + evaluate1DComplex, + evaluate2DReal, + evaluate2DComplex, + evaluate3DReal, + evaluate3DComplex, ) where @@ -278,6 +284,7 @@ instance Evaluable Scalar C (Complex Double) where -- show the real and imaginary part of complex as x + i y eval valMap (expZeroR mp arg1) :+ eval valMap (expZeroR mp arg2) + Conjugate arg -> conjugate $ eval valMap (expZeroC mp arg) InnerProd arg1 arg2 -> -- evaluate the inner product in C case retrieveShape arg1 mp of @@ -472,6 +479,9 @@ evaluate1DComplex valMap (mp, n) (:+) (evaluate1DReal valMap $ (mp, arg1)) (evaluate1DReal valMap $ (mp, arg2)) + Conjugate arg -> + let res = evaluate1DComplex valMap (mp, arg) + in fmap conjugate res Piecewise marks conditionArg branchArgs -> let cdt = evaluate1DReal valMap $ (mp, conditionArg) branches = @@ -643,6 +653,9 @@ evaluate2DComplex valMap (mp, n) (:+) (evaluate2DReal valMap $ (mp, arg1)) (evaluate2DReal valMap $ (mp, arg2)) + Conjugate arg -> + let res = evaluate2DComplex valMap (mp, arg) + in fmap conjugate res Piecewise marks conditionArg branchArgs -> let cdt = evaluate2DReal valMap $ (mp, conditionArg) branches = diff --git a/src/HashedExpression/Problem.hs b/src/HashedExpression/Problem.hs index 4741470c..80212bb1 100644 --- a/src/HashedExpression/Problem.hs +++ b/src/HashedExpression/Problem.hs @@ -25,6 +25,7 @@ import qualified Data.Set as Set import Debug.Trace (traceShowId) import HashedExpression.Differentiation.Exterior.Collect import HashedExpression.Differentiation.Exterior.Derivative +import HashedExpression.Differentiation.Exterior import HashedExpression.Internal import HashedExpression.Internal.Expression import HashedExpression.Internal.Node @@ -128,26 +129,6 @@ inf = 1 / 0 ------------------------------------------------------------------------------- --- | Return a map from variable name to the corresponding partial derivative node id --- Partial derivatives in Expression Scalar Covector should be collected before passing to this function -partialDerivativeMaps :: (ExpressionMap, NodeID) -> Map String NodeID -partialDerivativeMaps (dfMp, dfId) = - case retrieveOp dfId dfMp of - Sum ns | retrieveElementType dfId dfMp == Covector -> Map.fromList $ mapMaybe getPartial ns - _ -> Map.fromList $ mapMaybe getPartial [dfId] - where - getPartial :: NodeID -> Maybe (String, NodeID) - getPartial nId - | MulD partialId dId <- retrieveOp nId dfMp, - DVar name <- retrieveOp dId dfMp = - Just (name, partialId) - | InnerProdD partialId dId <- retrieveOp nId dfMp, - DVar name <- retrieveOp dId dfMp = - Just (name, partialId) - | otherwise = Nothing - -------------------------------------------------------------------------------- - -- | The statement of a constraint, including an 'ExpressionMap' subexpressions, the root 'NodeID' and its value data ConstraintStatement = -- | A lower bound constraint @@ -342,7 +323,7 @@ constructProblemHelper obj names (Constraint constraints) = do curMp <- get let finalRelevantVars = filter (\(name, _) -> Set.member name varsSet) $ varNodesWithId curMp let name2PartialDerivativeID :: Map String NodeID - name2PartialDerivativeID = partialDerivativeMaps (curMp, dfID) + name2PartialDerivativeID = partialDerivativesMap (curMp, dfID) variables = map ( \(name, varNodeID) -> @@ -353,7 +334,7 @@ constructProblemHelper obj names (Constraint constraints) = do let scalarConstraints = map ( \(gID, dgID, (lb, ub)) -> - let name2PartialDerivativeID = partialDerivativeMaps (curMp, dgID) + let name2PartialDerivativeID = partialDerivativesMap (curMp, dgID) in ScalarConstraint { constraintValueId = gID, constraintPartialDerivatives = map (\(name, _) -> fromJust $ Map.lookup name name2PartialDerivativeID) finalRelevantVars, diff --git a/test/Commons.hs b/test/Commons.hs index a8196545..4e3c8406 100644 --- a/test/Commons.hs +++ b/test/Commons.hs @@ -35,6 +35,9 @@ import Test.QuickCheck import Var import Prelude hiding ((^)) +sizeReduceFactor :: Int +sizeReduceFactor = 20 + -- | -- -- | Remove duplicate but also sort @@ -179,10 +182,10 @@ genScalarR :: genScalarR size | size == 0 = primitiveScalarR | otherwise = - let sub = genScalarR @default1D @default2D1 @default2D2 (size `div` 20) - subC = genScalarC @default1D @default2D1 @default2D2 (size `div` 20) - sub1D = gen1DR @default1D @default2D1 @default2D2 (size `div` 20) - sub2D = gen2DR @default1D @default2D1 @default2D2 (size `div` 20) + let sub = genScalarR @default1D @default2D1 @default2D2 (size `div` sizeReduceFactor) + subC = genScalarC @default1D @default2D1 @default2D2 (size `div` sizeReduceFactor) + sub1D = gen1DR @default1D @default2D1 @default2D2 (size `div` sizeReduceFactor) + sub2D = gen2DR @default1D @default2D1 @default2D2 (size `div` sizeReduceFactor) fromPiecewise = do numBranches <- elements [2, 3] branches <- vectorOf numBranches sub @@ -217,10 +220,10 @@ genScalarC :: genScalarC size | size == 0 = primitiveScalarC | otherwise = - let sub = genScalarC @default1D @default2D1 @default2D2 (size `div` 20) - subR = genScalarR @default1D @default2D1 @default2D2 (size `div` 20) - sub1D = gen1DC @default1D @default2D1 @default2D2 (size `div` 20) - sub2D = gen2DC @default1D @default2D1 @default2D2 (size `div` 20) + let sub = genScalarC @default1D @default2D1 @default2D2 (size `div` sizeReduceFactor) + subR = genScalarR @default1D @default2D1 @default2D2 (size `div` sizeReduceFactor) + sub1D = gen1DC @default1D @default2D1 @default2D2 (size `div` sizeReduceFactor) + sub2D = gen2DC @default1D @default2D1 @default2D2 (size `div` sizeReduceFactor) fromPiecewise = do numBranches <- elements [2, 3] branches <- vectorOf numBranches sub @@ -254,9 +257,9 @@ gen1DR :: gen1DR size | size == 0 = primitive1DR | otherwise = - let sub = gen1DR @n @default2D1 @default2D2 (size `div` 20) - subC = gen1DC @n @default2D1 @default2D2 (size `div` 20) - subScalar = genScalarR @n @default2D1 @default2D2 (size `div` 20) + let sub = gen1DR @n @default2D1 @default2D2 (size `div` sizeReduceFactor) + subC = gen1DC @n @default2D1 @default2D2 (size `div` sizeReduceFactor) + subScalar = genScalarR @n @default2D1 @default2D2 (size `div` sizeReduceFactor) fromPiecewise = do numBranches <- elements [2, 3] branches <- vectorOf numBranches sub @@ -294,9 +297,9 @@ gen1DC :: gen1DC size | size == 0 = primitive1DC | otherwise = - let sub = gen1DC @n @default2D1 @default2D2 (size `div` 20) - subR = gen1DR @n @default2D1 @default2D2 (size `div` 20) - subScalar = genScalarC @n @default2D1 @default2D2 (size `div` 20) + let sub = gen1DC @n @default2D1 @default2D2 (size `div` sizeReduceFactor) + subR = gen1DR @n @default2D1 @default2D2 (size `div` sizeReduceFactor) + subScalar = genScalarC @n @default2D1 @default2D2 (size `div` sizeReduceFactor) fromPiecewise = do numBranches <- elements [2, 3] branches <- vectorOf numBranches sub @@ -332,9 +335,9 @@ gen2DR :: gen2DR size | size == 0 = primitive2DR | otherwise = - let sub = gen2DR @default1D @m @n (size `div` 20) - subC = gen2DC @default1D @m @n (size `div` 20) - subScalar = genScalarR @default1D @m @n (size `div` 20) + let sub = gen2DR @default1D @m @n (size `div` sizeReduceFactor) + subC = gen2DC @default1D @m @n (size `div` sizeReduceFactor) + subScalar = genScalarR @default1D @m @n (size `div` sizeReduceFactor) fromPiecewise = do numBranches <- elements [2, 3] branches <- vectorOf numBranches sub @@ -373,9 +376,9 @@ gen2DC :: gen2DC size | size == 0 = primitive2DC | otherwise = - let sub = gen2DC @default1D @m @n (size `div` 20) - subR = gen2DR @default1D @m @n (size `div` 20) - subScalar = genScalarC @default1D @m @n (size `div` 20) + let sub = gen2DC @default1D @m @n (size `div` sizeReduceFactor) + subR = gen2DR @default1D @m @n (size `div` sizeReduceFactor) + subScalar = genScalarC @default1D @m @n (size `div` sizeReduceFactor) fromPiecewise = do numBranches <- elements [2, 3] branches <- vectorOf numBranches sub diff --git a/test/ReverseDifferentiationSpec.hs b/test/ReverseDifferentiationSpec.hs index d98f66f1..6322d3af 100644 --- a/test/ReverseDifferentiationSpec.hs +++ b/test/ReverseDifferentiationSpec.hs @@ -11,7 +11,7 @@ import Data.Maybe (fromJust) import qualified Data.Set as Set import Data.Tuple.Extra (thd3) import Debug.Trace (traceShow) -import HashedExpression.Differentiation.Exterior.Collect +import HashedExpression.Differentiation.Exterior import HashedExpression.Differentiation.Exterior.Derivative import HashedExpression.Differentiation.Reverse import HashedExpression.Internal (D_, ET_, unwrap) @@ -31,13 +31,41 @@ import Var import Prelude hiding ((^)) import qualified Prelude +prop_reverseMethodAndExteriorShouldBeSameValue :: SuiteScalarR -> Expectation +prop_reverseMethodAndExteriorShouldBeSameValue (Suite exp valMap) = do + -- print "---------------------" + -- showExp exp + let (eMP, eMap) = partialDerivativesMapByExterior exp + let (rMP, rMap) = partialDerivativesMapByReverse exp + forM_ (Map.toList $ zipMp eMap rMap) $ \(name, (eID, rID)) -> do + -- putStrLn $ "for: " ++ name + -- putStrLn $ debugPrint (eMP, eID) + -- putStrLn $ debugPrint (rMP, rID) + retrieveShape eID eMP `shouldBe` retrieveShape rID rMP + let shape = retrieveShape eID eMP + case shape of + [] -> do + let valE = eval valMap (Expression @Scalar @R eID eMP) + let valR = eval valMap (Expression @Scalar @R rID rMP) + valE `shouldApprox` valR + [sz] -> do + let valE = evaluate1DReal valMap (eMP, eID) + let valR = evaluate1DReal valMap (rMP, rID) + valE `shouldApprox` valR + [sz1, sz2] -> do + let valE = evaluate2DReal valMap (eMP, eID) + let valR = evaluate2DReal valMap (rMP, rID) + valE `shouldApprox` valR + spec :: Spec spec = describe "Reverse differentiation spec" $ do - specify "Unit tests" $ do - let f = xRe ((x1 +: y1) <.> (y1 +: z1)) - showExp f - showExp $ collectDifferentials . derivativeAllVars $ f - let (mp, pd) = compute f - forM_ (Map.toList pd) $ \(name, pID) -> do - print $ name ++ ": " ++ debugPrint (mp, pID) + specify "should be the same as exterior method" $ do + property prop_reverseMethodAndExteriorShouldBeSameValue + +-- let f = xRe ((x1 +: y1) <.> (y1 +: z1)) +-- showExp f +-- showExp $ collectDifferentials . derivativeAllVars $ f +-- let (mp, pd) = compute f +-- forM_ (Map.toList pd) $ \(name, pID) -> do +-- print $ name ++ ": " ++ debugPrint (mp, pID) diff --git a/test/Spec.hs b/test/Spec.hs index 0eaa5c2e..516561f5 100644 --- a/test/Spec.hs +++ b/test/Spec.hs @@ -21,6 +21,7 @@ import HashedExpression.Prettify import qualified InterpSpec import qualified NormalizeSpec import qualified ProblemSpec +import qualified ReverseDifferentiationSpec import qualified StructureSpec import Test.Hspec import Test.Hspec.Runner @@ -43,5 +44,6 @@ main = do describe "HashedInterpSpec" InterpSpec.spec describe "HashedCollectSpec" CollectSpec.spec describe "StructureSpec" StructureSpec.spec + describe "ReverseDifferentiationSpec" ReverseDifferentiationSpec.spec hspecWith defaultConfig {configQuickCheckMaxSuccess = Just 20} $ do describe "CSimpleSpec" CSimpleSpec.spec