{-# language FlexibleInstances, DeriveFunctor #-}
{-# language ScopedTypeVariables #-}
{-# language RankNTypes #-}
{-# language ViewPatterns #-}
{-# language FlexibleContexts #-}
{-# language BangPatterns #-}
{-# language TypeApplications #-}
{-# language MultiWayIf #-}
module Algorithm.SRTree.AD
( reverseModeArr
, reverseModeGraph
, forwardModeUniqueJac
) where
import Control.Monad (forM_, foldM, when)
import Control.Monad.ST ( runST )
import Data.Bifunctor (bimap, first, second)
import qualified Data.DList as DL
import Data.Massiv.Array hiding (forM_, map, replicate, zipWith)
import qualified Data.Massiv.Array as M
import qualified Data.Massiv.Array.Unsafe as UMA
import Data.Massiv.Core.Operations (unsafeLiftArray)
import Data.SRTree.Derivative ( derivative )
import Data.SRTree.Eval
( SRVector, evalFun, evalOp, SRMatrix, PVector, replicateAs )
import Data.SRTree.Internal
import Data.SRTree.Print (showExpr)
import Data.SRTree.Recursion ( cataM, cata, accu )
import qualified Data.Vector as V
import Debug.Trace (trace, traceShow)
import GHC.IO (unsafePerformIO)
import qualified Data.IntMap.Strict as IntMap
import Data.List ( foldl' )
import qualified Data.Vector.Storable as VS
import Control.Scheduler
import Data.Maybe ( fromJust )
import Control.Monad.State.Strict
import qualified Data.Map.Strict as Map
-- | Reverse-mode automatic differentiation over a hash-consed expression graph.
--
-- Two passes are made:
--
-- 1. Bottom-up ('cataM' with @insertKey@): every distinct node, keyed by its
--    constructor with children replaced by integer ids, receives a unique id
--    and its value over all samples is cached. Structurally identical subtrees
--    are therefore evaluated only once.
-- 2. Top-down (@calcGrad@): starting from the root with an all-ones adjoint,
--    adjoints are pushed to the children of every node along every path, and
--    the adjoints arriving at each 'Param' leaf are summed element-wise.
--
-- Returns the root's values (one per element of @ys@) and, for each parameter
-- index in @[0 .. VS.length theta - 1]@, the sum over samples of its adjoint.
-- A parameter that does not occur in the tree gets gradient @0@.
--
-- 'Var' index @-1@ reads @ys@, index @-2@ reads the error vector (which must
-- then be given as @Just@), and any other index selects a column of @xss@.
reverseModeGraph :: SRMatrix -> PVector -> Maybe PVector -> VS.Vector Double -> Fix SRTree -> (Array D Ix1 Double, VS.Vector Double)
reverseModeGraph xss ys mYErr theta tree =
    ( delay (cachedVal IntMap.! root)
    , VS.generate p paramGrad
    )
  where
    -- Gradient of parameter @ix@: its accumulated adjoint summed over samples.
    -- Parameters absent from the tree never receive an adjoint, hence 0
    -- (previously this was a partial 'Map.!' that crashed).
    paramGrad :: Int -> Double
    paramGrad ix = maybe 0 M.sum (Map.lookup (Param ix) cachedGrad)

    -- Only forced when the expression actually contains 'Var' (-2).
    yErr :: Array S Ix1 Double
    yErr = case mYErr of
             Just e  -> e
             Nothing -> error "reverseModeGraph: expression uses Var (-2) but no error vector was supplied"

    m    = M.size ys
    p    = VS.length theta
    comp = M.getComp xss

    -- Adjoint seeded at the root.
    one :: Array S Ix1 Double
    one = M.replicate comp m 1

    -- Pass 1: build the graph. The state is
    -- (node -> id, id -> node, id -> cached value, next free id); the
    -- catamorphism's result is the id of the root node.
    (root, (_, int2key, cachedVal, _)) =
      runState (cataM leftToRight insertKey tree)
               (Map.empty, IntMap.empty, IntMap.empty, 0)

    -- Pass 2: propagate adjoints; the state collects per-'Param' adjoints.
    cachedGrad :: Map.Map (SRTree Int) (Array S Ix1 Double)
    cachedGrad = execState (calcGrad root one) Map.empty

    -- Cached value of the node with the given id (pass 1 is complete here).
    valueOf :: Int -> Array S Ix1 Double
    valueOf i = cachedVal IntMap.! i

    -- Push adjoint @v@ of node @key@ towards the leaves. Shared nodes are
    -- revisited once per path that reaches them; contributions add up.
    calcGrad :: Int -> Array S Ix1 Double -> State (Map.Map (SRTree Int) (Array S Ix1 Double)) ()
    calcGrad key v =
      case int2key IntMap.! key of
        Bin op l r -> do
          let (dl, dr) = diff op v (valueOf l) (valueOf r) (valueOf key)
          calcGrad l dl
          calcGrad r dr
        Uni f t ->
          calcGrad t (M.computeAs S (M.zipWith (*) v (M.map (derivative f) (valueOf t))))
        node@(Param _) ->
          modify' (Map.insertWith addArr node v)
        _ -> pure ()

    addArr :: Array S Ix1 Double -> Array S Ix1 Double -> Array S Ix1 Double
    addArr new old = M.computeAs S (M.zipWith (+) new old)

    -- Adjoints of both children of @f `op` g@, given the node's adjoint @dx@,
    -- the child values @fx@ and @gy@, and the node's own value @v@.
    diff :: Op -> Array S Ix1 Double -> Array S Ix1 Double -> Array S Ix1 Double -> Array S Ix1 Double
         -> (Array S Ix1 Double, Array S Ix1 Double)
    diff Add dx _ _ _ = (dx, dx)
    diff Sub dx _ _ _ = (dx, M.computeAs S (M.map negate dx))
    diff Mul dx fx gy _ =
      ( M.computeAs S (M.zipWith (*) dx gy)
      , M.computeAs S (M.zipWith (*) dx fx) )
    diff Div dx _ gy v =
      ( M.computeAs S (M.zipWith (/) dx gy)
      , M.computeAs S (M.zipWith (*) dx (M.zipWith (\vi gi -> negate vi / gi) v gy)) )
    -- f ** g: d/df = g * v / f, d/dg = v * log f (NaNs replaced by 0).
    diff Power dx fx gy v =
      ( M.computeAs S (M.zipWith4 (\d f g vi -> fixNaN (d * g * vi / f)) dx fx gy v)
      , M.computeAs S (M.zipWith3 (\d f vi -> fixNaN (d * vi * log f)) dx fx v) )
    -- |f| ** g: d/df = f * g * v / f^2, d/dg = v * log |f| (NaNs replaced by 0).
    diff PowerAbs dx fx gy v =
      let absF = M.map abs fx
          fg   = M.computeAs S (M.zipWith (*) fx gy) :: Array S Ix1 Double
      in ( M.computeAs S (M.zipWith4 (\d fgi vi ai -> fixNaN (d * fgi * vi / (ai ^ (2 :: Int)))) dx fg v absF)
         , M.computeAs S (M.zipWith3 (\d ai vi -> fixNaN (d * vi * log ai)) dx absF v) )
    -- The d/df factor 1 / sqrt (1 + g^2) implies AQ f g = f / sqrt (1 + g^2),
    -- so d/dg = -f * g / (1 + g^2)^(3/2). The previous code used
    -- f * g * (d * s)^3, which cubed the upstream adjoint and dropped the sign.
    diff AQ dx fx gy _ =
      let s = M.map (\g -> recip (sqrt (1 + g ^ (2 :: Int)))) gy
      in ( M.computeAs S (M.zipWith (*) dx s)
         , M.computeAs S (M.zipWith4 (\d f g si -> negate (d * f * g * si ^ (3 :: Int))) dx fx gy s) )

    fixNaN :: RealFloat a => a -> a
    fixNaN x = if isNaN x then 0 else x

    -- Add a node to the graph (evaluating it) unless an identical node is
    -- already present, and return its id.
    insertKey key = do
      isCached <- gets (Map.member key . graph)
      when (not isCached) $ do
        ev <- evalKey key
        modify' (insKey key ev)
      gets (getKey key)

    -- Value of a node over all samples; children are read from the cache.
    evalKey (Var ix)
      | ix == -1  = pure ys
      | ix == -2  = pure yErr
      | otherwise = pure (M.computeAs S (xss <! ix))
    evalKey (Const c)    = pure (M.replicate comp m c)
    evalKey (Param ix)   = pure (M.replicate comp m (theta VS.! ix))
    evalKey (Uni f t)    = M.computeAs S . M.map (evalFun f) <$> gets (getVal t)
    evalKey (Bin op l r) = do
      xl <- gets (getVal l)
      xr <- gets (getVal r)
      pure (M.computeAs S (M.zipWith (evalOp op) xl xr))

    graph (k2i, _, _, _) = k2i

    insKey key ev (k2i, i2k, vals, next) =
      ( Map.insert key next k2i
      , IntMap.insert next key i2k
      , IntMap.insert next ev vals
      , next + 1 )

    getVal :: Int -> (a, b, IntMap.IntMap v, d) -> v
    getVal key (_, _, vals, _) = vals IntMap.! key

    getKey :: Ord k => k -> (Map.Map k v, b, c, d) -> v
    getKey key (k2i, _, _, _) = k2i Map.! key

    -- Sequence the children's effects left to right.
    leftToRight :: Applicative f => SRTree (f a) -> f (SRTree a)
    leftToRight (Uni f mt)    = Uni f <$> mt
    leftToRight (Bin f ml mr) = Bin f <$> ml <*> mr
    leftToRight (Var ix)      = pure (Var ix)
    leftToRight (Param ix)    = pure (Param ix)
    leftToRight (Const c)     = pure (Const c)
reverseModeArr :: SRMatrix
-> PVector
-> Maybe PVector
-> VS.Vector Double
-> [(Int, (Int, Int, Int, Double))]
-> IntMap.IntMap Int
-> (Array D Ix1 Double, Array S Ix1 Double)
reverseModeArr :: SRMatrix
-> Array S Ix1 Double
-> Maybe (Array S Ix1 Double)
-> Vector Double
-> [(Ix1, (Ix1, Ix1, Ix1, Double))]
-> IntMap Ix1
-> (Array D Ix1 Double, Array S Ix1 Double)
reverseModeArr SRMatrix
xss Array S Ix1 Double
ys Maybe (Array S Ix1 Double)
mYErr Vector Double
theta [(Ix1, (Ix1, Ix1, Ix1, Double))]
t IntMap Ix1
j2ix =
IO (Array D Ix1 Double, Array S Ix1 Double)
-> (Array D Ix1 Double, Array S Ix1 Double)
forall a. IO a -> a
unsafePerformIO (IO (Array D Ix1 Double, Array S Ix1 Double)
-> (Array D Ix1 Double, Array S Ix1 Double))
-> IO (Array D Ix1 Double, Array S Ix1 Double)
-> (Array D Ix1 Double, Array S Ix1 Double)
forall a b. (a -> b) -> a -> b
$ do
fwd <- Sz Ix2 -> Double -> IO (MArray (PrimState IO) S Ix2 Double)
forall r e ix (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m) =>
Sz ix -> e -> m (MArray (PrimState m) r ix e)
forall ix (m :: * -> *).
(Index ix, PrimMonad m) =>
Sz ix -> Double -> m (MArray (PrimState m) S ix Double)
M.newMArray (Ix1 -> Ix1 -> Sz Ix2
Sz2 Ix1
n Ix1
m) Double
0
partial <- M.newMArray (Sz2 n m) 0
jacob <- M.newMArray (Sz p) 0
val <- M.newMArray (Sz m) 0
let
stps = Integer
2
(a, b) = (0, m)
forward (a, b) fwd
calculateYHat (a, b) fwd val
reverseMode (a, b) fwd partial
combine (a, b) partial jacob
j <- UMA.unsafeFreeze (getComp xss) jacob
v <- UMA.unsafeFreeze (getComp xss) val
pure (delay v, j)
where
(Sz2 Ix1
m Ix1
_) = SRMatrix -> Sz Ix2
forall r ix e. Size r => Array r ix e -> Sz ix
forall ix e. Array S ix e -> Sz ix
M.size SRMatrix
xss
p :: Ix1
p = Vector Double -> Ix1
forall a. Storable a => Vector a -> Ix1
VS.length Vector Double
theta
n :: Ix1
n = [(Ix1, (Ix1, Ix1, Ix1, Double))] -> Ix1
forall a. [a] -> Ix1
forall (t :: * -> *) a. Foldable t => t a -> Ix1
length [(Ix1, (Ix1, Ix1, Ix1, Double))]
t
toLin :: Ix1 -> Ix1 -> Ix1
toLin Ix1
i Ix1
j = Ix1
iIx1 -> Ix1 -> Ix1
forall a. Num a => a -> a -> a
*Ix1
m Ix1 -> Ix1 -> Ix1
forall a. Num a => a -> a -> a
+ Ix1
j
yErr :: Array S Ix1 Double
yErr = Maybe (Array S Ix1 Double) -> Array S Ix1 Double
forall a. HasCallStack => Maybe a -> a
fromJust Maybe (Array S Ix1 Double)
mYErr
eps :: Double
eps = Double
1e-8
myForM_ :: [t] -> (t -> f a) -> f ()
myForM_ [] t -> f a
_ = () -> f ()
forall a. a -> f a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
myForM_ (!t
x:[t]
xs) t -> f a
f = do t -> f a
f t
x
[t] -> (t -> f a) -> f ()
myForM_ [t]
xs t -> f a
f
{-# INLINE myForM_ #-}
calculateYHat :: (Int, Int) -> MArray (PrimState IO) S Ix2 Double -> MArray (PrimState IO) S Ix1 Double -> IO ()
calculateYHat :: (Ix1, Ix1)
-> MArray (PrimState IO) S Ix2 Double
-> MArray (PrimState IO) S Ix1 Double
-> IO ()
calculateYHat (Ix1
a, Ix1
b) MArray (PrimState IO) S Ix2 Double
fwd MArray (PrimState IO) S Ix1 Double
yhat = [Ix1] -> (Ix1 -> IO ()) -> IO ()
forall {f :: * -> *} {t} {a}. Monad f => [t] -> (t -> f a) -> f ()
myForM_ [Ix1
a..Ix1
bIx1 -> Ix1 -> Ix1
forall a. Num a => a -> a -> a
-Ix1
1] ((Ix1 -> IO ()) -> IO ()) -> (Ix1 -> IO ()) -> IO ()
forall a b. (a -> b) -> a -> b
$ \Ix1
i -> do
vi <- MArray (PrimState IO) S Ix2 Double -> Ix2 -> IO Double
forall r e ix (m :: * -> *).
(Manifest r e, Index ix, PrimMonad m) =>
MArray (PrimState m) r ix e -> ix -> m e
UMA.unsafeRead MArray (PrimState IO) S Ix2 Double
fwd (Ix1
0 Ix1 -> Ix1 -> Ix2
:. Ix1
i)
UMA.unsafeWrite yhat i vi
{-# INLINE calculateYHat #-}
-- | Forward sweep over the sample columns @[a .. b-1]@.
--
-- Visits the entries of @t@ in reverse order. Each entry has the form
-- @(j, (arity, tag, ix, x))@. For each entry, the node's value for every
-- sample is written into row @j2ix ! j@ of @fwd@:
--
--   * arity 0, tag 0: input data — @ys@ when @ix == -1@, @yErr@ when
--     @ix == -2@, otherwise column @ix@ of @xss@;
--   * arity 0, tag 1: the parameter @theta ! ix@, broadcast to all samples;
--   * arity 0, tag 2: the constant @x@, broadcast to all samples;
--   * arity 1: function @toEnum tag@ applied to the row of key @2j+1@;
--   * arity 2: operator @toEnum tag@ applied to the rows of keys @2j+1@
--     and @2j+2@.
--
-- Any other shape is ignored. NOTE(review): reversing @t@ presumably puts
-- children before their parents — confirm against how @t@ is built.
forward :: (Int, Int) -> MArray (PrimState IO) S Ix2 Double -> IO ()
forward (lo, hi) fwd = myForM_ (Prelude.reverse t) evalNode
  where
    -- Run @body@ once per sample index in the half-open range [lo, hi).
    eachSample body = myForM_ [lo .. hi - 1] body

    evalNode (j, (arity, tag, ix, x)) =
      let dstRow  = j2ix IntMap.! j
          lhsRow  = j2ix IntMap.! (2 * j + 1)
          rhsRow  = j2ix IntMap.! (2 * j + 2)
          store i = UMA.unsafeWrite fwd (dstRow :. i)
          input i = case ix of
            (-1) -> ys M.! i
            (-2) -> yErr M.! i
            _    -> xss M.! (i :. ix)
       in case (arity, tag) of
            (0, 0) -> eachSample $ \i -> store i (input i)
            (0, 1) -> let v = theta VS.! ix
                       in eachSample $ \i -> store i v
            (0, 2) -> eachSample $ \i -> store i x
            (1, _) -> eachSample $ \i ->
              UMA.unsafeRead fwd (lhsRow :. i) >>= store i . evalFun (toEnum tag)
            (2, _) -> eachSample $ \i -> do
              l <- UMA.unsafeRead fwd (lhsRow :. i)
              r <- UMA.unsafeRead fwd (rhsRow :. i)
              store i (evalOp (toEnum tag) l r)
            _ -> pure ()
    {-# INLINE evalNode #-}
{-# INLINE forward #-}
-- | Reverse sweep over the sample columns @[a .. b-1]@.
--
-- Seeds row 0 of @partial@ with 1 for every sample, then walks @t@ in its
-- stored order. For an arity-1 node @j@, the child row (key @2j+1@) receives
-- the node's adjoint times the derivative of function @toEnum tag@ at the
-- child's forward value. For an arity-2 node, both child rows (keys @2j+1@
-- and @2j+2@) receive the adjoints returned by 'diff'. Child rows are
-- overwritten, not accumulated. All other nodes are skipped.
reverseMode :: (Int, Int) -> MArray (PrimState IO) S Ix2 Double -> MArray (PrimState IO) S Ix2 Double -> IO ()
reverseMode (lo, hi) fwd partial = do
  eachSample $ \i -> UMA.unsafeWrite partial (0 :. i) 1
  myForM_ t backprop
  where
    -- Run @body@ once per sample index in the half-open range [lo, hi).
    eachSample body = myForM_ [lo .. hi - 1] body

    backprop (j, (arity, tag, _, _)) =
      let ownRow = j2ix IntMap.! j
          lhsRow = j2ix IntMap.! (2 * j + 1)
          rhsRow = j2ix IntMap.! (2 * j + 2)
       in case arity of
            1 -> eachSample $ \i -> do
              v   <- UMA.unsafeRead fwd (lhsRow :. i)
              adj <- UMA.unsafeRead partial (ownRow :. i)
              UMA.unsafeWrite partial (lhsRow :. i) (adj * derivative (toEnum tag) v)
            2 -> eachSample $ \i -> do
              l   <- UMA.unsafeRead fwd (lhsRow :. i)
              r   <- UMA.unsafeRead fwd (rhsRow :. i)
              adj <- UMA.unsafeRead partial (ownRow :. i)
              let (adjL, adjR) = diff (toEnum tag) adj l r
              UMA.unsafeWrite partial (lhsRow :. i) adjL
              UMA.unsafeWrite partial (rhsRow :. i) adjR
            _ -> pure ()
    {-# INLINE backprop #-}
{-# INLINE reverseMode #-}
-- | Map NaN to 0. Every other value, including the infinities, is returned
-- unchanged.
fixNaN :: RealFloat a => a -> a
fixNaN x = if isNaN x then 0 else x
-- | Local reverse-mode rule for a binary operator.
--
-- @diff op dx fx gy@ takes the adjoint @dx@ of the node @fx \`op\` gy@ and the
-- forward values of its left (@fx@) and right (@gy@) operands. It returns the
-- adjoints for the left and right operand, in that order. For the power
-- operators, zero bases and/or exponents are special-cased, and 'fixNaN'
-- turns any NaN produced into 0.
diff :: Op -> Double -> Double -> Double -> (Double, Double)
diff Add dx _ _ = (dx, dx)
diff Sub dx _ _ = (dx, negate dx)
diff Mul dx fx gy = (dx * gy, dx * fx)
diff Div dx fx gy = (dx / gy, dx * (negate fx / (gy * gy)))
-- fx ** gy
diff Power 0 _ _ = (0, 0)
diff Power _ 0 0 = (0, 0)
diff Power dx fx 0 = (0, fixNaN $ dx * log fx)
-- Zero base: d/df f**g is 0 for g > 1 and exactly 1 for g == 1
-- (eps ** 0 == 1). For g < 1 it is unbounded and is approximated by
-- evaluating f**(g-1) at f = eps.
diff Power dx 0 gy = ( fixNaN $ dx * gy * (if gy <= 1 then eps ** (gy - 1) else 0)
                     , 0 )
diff Power dx fx gy = ( fixNaN $ dx * gy * fx ** (gy - 1)
                      , fixNaN $ dx * fx ** gy * log fx )
-- abs fx ** gy
diff PowerAbs 0 _ _ = (0, 0)
-- Zero base and zero exponent: mirror the Power case rather than letting
-- log (abs 0) produce an infinite adjoint for the exponent.
diff PowerAbs _ 0 0 = (0, 0)
diff PowerAbs dx fx 0 = (0, fixNaN $ dx * log (abs fx))
-- NOTE(review): unlike the Power case, the eps-approximation lands in the
-- exponent's slot here and lacks a factor of gy — confirm this is intended.
diff PowerAbs dx 0 gy = (0, fixNaN $ dx * if gy < 0 then eps ** gy else 0)
diff PowerAbs dx fx gy = ( fixNaN $ dx * gy * fx * abs fx ** (gy - 2)
                         , fixNaN $ dx * abs fx ** gy * log (abs fx) )
-- fx / sqrt (1 + gy^2):
--   d/dfx = (1 + gy^2)^(-1/2)
--   d/dgy = -fx * gy * (1 + gy^2)^(-3/2)   (the minus sign was missing)
diff AQ dx fx gy =
  let invNorm = recip (sqrt (gy * gy + 1))
      dLeft   = invNorm
      dRight  = negate (fx * gy * invNorm ^ (3 :: Int))
   in (dLeft * dx, dRight * dx)
{-# INLINE diff #-}
-- | Add parameter adjoints into the Jacobian accumulator.
--
-- For every entry @(j, (0, 1, ix, _))@ of @t@ (arity 0, tag 1), sum row
-- @j2ix ! j@ of @partial@ over the samples @[a .. b-1]@ and add the sum to
-- @jacob ! ix@. Entries of any other shape are ignored. The existing value
-- of @jacob ! ix@ is the starting point, so parameters that occur several
-- times accumulate.
combine :: (Int, Int) -> MArray (PrimState IO) S Ix2 Double -> MArray (PrimState IO) S Ix1 Double -> IO ()
combine (lo, hi) partial jacob = myForM_ t makeJacob
  where
    makeJacob (j, (0, 1, ix, _)) = do
      start <- UMA.unsafeRead jacob ix
      let row = j2ix IntMap.! j
          -- Force the running sum at every step. A lazy @pure (v + acc)@
          -- builds a chain of (hi - lo) thunks that is only collapsed by the
          -- final write. The summation order (v + acc) is unchanged.
          step acc i = do
            v <- UMA.unsafeRead partial (row :. i)
            pure $! v + acc
      total <- foldM step start [lo .. hi - 1]
      UMA.unsafeWrite jacob ix total
    makeJacob _ = pure ()
{-# INLINE combine #-}
-- | Forward-mode Jacobian of an expression with respect to its parameters.
--
-- Evaluates the tree on @xss@ with parameter values @theta@. For every
-- 'Param' leaf, it carries along the derivative of the current subexpression
-- with respect to that leaf. The result has one vector per 'Param'
-- occurrence, in left-to-right tree order: in every binary node, the left
-- operand's entries precede the right operand's. Repeated occurrences of the
-- same parameter are not merged; presumably this is what "Unique" refers to,
-- i.e. each parameter is expected to appear once.
forwardModeUniqueJac :: SRMatrix -> PVector -> Fix SRTree -> [PVector]
forwardModeUniqueJac xss theta = map (M.computeAs M.S) . DL.toList . snd . cata alg
  where
    one = replicateAs xss 1

    alg :: SRTree (Array D Ix1 Double, DL.DList (Array D Ix1 Double))
        -> (Array D Ix1 Double, DL.DList (Array D Ix1 Double))
    alg (Var ix)   = (xss <! ix, DL.empty)
    alg (Param ix) = (replicateAs xss (theta ! ix), DL.singleton one)
    alg (Const c)  = (replicateAs xss c, DL.empty)
    alg (Uni f (v, gs)) =
      let dv = derivative f v
       in (evalFun f v, DL.map (* dv) gs)
    alg (Bin Add (v1, l) (v2, r)) = (v1 + v2, DL.append l r)
    alg (Bin Sub (v1, l) (v2, r)) = (v1 - v2, DL.append l (DL.map negate r))
    alg (Bin Mul (v1, l) (v2, r)) =
      (v1 * v2, DL.append (DL.map (* v2) l) (DL.map (* v1) r))
    alg (Bin Div (v1, l) (v2, r)) =
      let dv = negate v1 / (v2 * v2)
       in (v1 / v2, DL.append (DL.map (/ v2) l) (DL.map (* dv) r))
    alg (Bin Power (v1, l) (v2, r)) =
      let dv1 = v1 ** (v2 - one)
          dv2 = v1 * log v1
       in (v1 ** v2, DL.map (* dv1) (DL.append (DL.map (* v2) l) (DL.map (* dv2) r)))
    -- The base's (left) gradients must come before the exponent's (right)
    -- ones, as in every other binary case; they used to be swapped here.
    alg (Bin PowerAbs (v1, l) (v2, r)) =
      let dv1 = abs v1 ** v2
          dl  = DL.map (* (v2 / v1)) l
          dr  = DL.map (* log (abs v1)) r
       in (abs v1 ** v2, DL.map (* dv1) (DL.append dl dr))
    alg (Bin AQ (v1, l) (v2, r)) =
      let denom = 1 + v2 * v2
          dl    = DL.map (* denom) l
          dr    = DL.map (* negate (v1 * v2)) r
       in (v1 / sqrt denom, DL.map (/ denom ** 1.5) (DL.append dl dr))