Skip to content

Commit

Permalink
Mock up Tom's Flattenable idea for Num over any TK
Browse files Browse the repository at this point in the history
  • Loading branch information
Mikolaj committed Dec 10, 2024
1 parent 58e33f0 commit b59b158
Show file tree
Hide file tree
Showing 4 changed files with 182 additions and 18 deletions.
3 changes: 2 additions & 1 deletion horde-ad.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,8 @@ common options
default-extensions: TypeAbstractions
other-extensions: UnboxedTuples, CPP, ViewPatterns, OverloadedLists,
DerivingStrategies, DeriveAnyClass, TupleSections,
UndecidableInstances, AllowAmbiguousTypes
UndecidableInstances, AllowAmbiguousTypes,
QuantifiedConstraints
ghc-options: -Wall -Wcompat -Wimplicit-prelude -Wmissing-home-modules -Widentities -Wredundant-constraints -Wmissing-export-lists -Wpartial-fields -Wunused-packages
ghc-options: -Wno-unticked-promoted-constructors -fprint-explicit-kinds
if impl(ghc >= 9.2)
Expand Down
160 changes: 156 additions & 4 deletions src/HordeAd/Core/HVectorOps.hs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE QuantifiedConstraints, UndecidableInstances #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-}
-- | A class containing array operations, with some extra algebraic operations
Expand Down Expand Up @@ -30,10 +30,21 @@ import GHC.TypeLits (KnownNat, SomeNat (..), sameNat, someNatVal, type (+))
import Type.Reflection (typeRep)
import Unsafe.Coerce (unsafeCoerce)

import Data.Array.Mixed.Shape (KnownShX (..), ssxFromShape)
import Data.Array.Mixed.Shape
(KnownShX (..), ssxAppend, ssxFromShape, ssxReplicate)
import Data.Array.Nested
(IShR, KnownShS (..), Rank, ShR (..), ShS (..), pattern (:$:), pattern ZSR)
import Data.Array.Nested.Internal.Shape (shrRank)
( IShR
, KnownShS (..)
, MapJust
, Rank
, Replicate
, ShR (..)
, ShS (..)
, pattern (:$:)
, pattern ZSR
, type (++)
)
import Data.Array.Nested.Internal.Shape (shCvtSX, shrRank, shsAppend)

import HordeAd.Core.TensorClass
import HordeAd.Core.TensorKind
Expand Down Expand Up @@ -121,6 +132,147 @@ addRepD a b = case (a, b) of
(DTKUntyped hv1, DTKUntyped hv2) ->
DTKUntyped $ V.zipWith addDynamic hv1 hv2


-- * Winding

type family UnWind tk where
UnWind (TKScalar r) =
TKScalar r
UnWind (TKR2 n (TKScalar r)) =
TKR2 n (TKScalar r)
UnWind (TKR2 n (TKR2 m x)) =
UnWind (TKR2 (n + m) x)
UnWind (TKR2 n (TKS2 sh x)) =
UnWind (TKX2 (Replicate n Nothing ++ MapJust sh) x)
UnWind (TKR2 n (TKX2 sh x)) =
UnWind (TKX2 (Replicate n Nothing ++ sh) x)
UnWind (TKR2 n (TKProduct y z)) =
TKProduct (UnWind (TKR2 n y)) (UnWind (TKR2 n z))
UnWind (TKS2 sh (TKScalar r)) =
TKS2 sh (TKScalar r)
UnWind (TKS2 sh (TKR2 m x)) =
UnWind (TKX2 (MapJust sh ++ Replicate m Nothing) x)
UnWind (TKS2 sh (TKS2 sh2 x)) =
UnWind (TKS2 (sh ++ sh2) x)
UnWind (TKS2 sh (TKX2 sh2 x)) =
UnWind (TKX2 (MapJust sh ++ sh2) x)
UnWind (TKS2 sh (TKProduct y z)) =
TKProduct (UnWind (TKS2 sh y)) (UnWind (TKS2 sh z))
UnWind (TKX2 sh (TKScalar r)) =
TKX2 sh (TKScalar r)
UnWind (TKX2 sh (TKR2 m x)) =
UnWind (TKX2 (sh ++ Replicate m Nothing) x)
UnWind (TKX2 sh (TKS2 sh2 x)) =
UnWind (TKX2 (sh ++ MapJust sh2) x)
UnWind (TKX2 sh (TKX2 sh2 x)) =
UnWind (TKX2 (sh ++ sh2) x)
UnWind (TKX2 sh (TKProduct y z)) =
TKProduct (UnWind (TKX2 sh y)) (UnWind (TKX2 sh z))
UnWind (TKProduct y z) =
TKProduct (UnWind y) (UnWind z)
UnWind TKUntyped =
TKUntyped

unWindSTK :: STensorKindType y -> STensorKindType (UnWind y)
unWindSTK = \case
stk@STKScalar{} -> stk
stk@(STKR _ STKScalar{}) -> stk
STKR (SNat @n) (STKR (SNat @m) x) ->
unWindSTK $ STKR (SNat @(n + m)) x
STKR n (STKS sh x) ->
unWindSTK $ STKX (ssxReplicate n `ssxAppend` ssxFromShape (shCvtSX sh)) x
STKR n (STKX sh x) ->
unWindSTK $ STKX (ssxReplicate n `ssxAppend` sh) x
STKR n (STKProduct y z) ->
unWindSTK $ STKProduct (STKR n y) (STKR n z)
stk@(STKS _ STKScalar{}) -> stk
STKS sh (STKR m x) ->
unWindSTK
$ STKX (ssxFromShape (shCvtSX sh) `ssxAppend` ssxReplicate m) x
STKS sh (STKS sh2 x) ->
unWindSTK $ STKS (shsAppend sh sh2) x
STKS sh (STKX sh2 x) ->
unWindSTK $ STKX (ssxFromShape (shCvtSX sh) `ssxAppend` sh2) x
STKS sh (STKProduct y z) ->
unWindSTK $ STKProduct (STKS sh y) (STKS sh z)
stk@(STKX _ STKScalar{}) -> stk
STKX sh (STKR m x) ->
unWindSTK $ STKX (sh `ssxAppend` ssxReplicate m) x
STKX sh (STKS sh2 x) ->
unWindSTK $ STKX (sh `ssxAppend` ssxFromShape (shCvtSX sh2)) x
STKX sh (STKX sh2 x) ->
unWindSTK $ STKX (sh `ssxAppend` sh2) x
STKX sh (STKProduct y z) ->
unWindSTK $ STKProduct (STKX sh y) (STKX sh z)
STKProduct y z | (Dict, Dict) <- lemTensorKind1OfSTK (unWindSTK y)
, (Dict, Dict) <- lemTensorKind1OfSTK (unWindSTK z) ->
STKProduct (unWindSTK y) (unWindSTK z)
stk@STKUntyped -> stk

-- Alternatively the codomain could be RepD, which clearly indicates
-- what the normal form of UnWind is.
unWindShare :: (BaseTensor target, ShareTensor target)
=> STensorKindType y -> target y -> target (UnWind y)
unWindShare stk t = case stk of
STKScalar{} -> t
STKR _ STKScalar{} -> t
STKS _ STKScalar{} -> t
STKS sh (STKS sh2 x) | Dict <- lemTensorKindOfSTK x
, Dict <- lemTensorKindOfSTK (unWindSTK x) ->
withKnownShS sh $ withKnownShS sh2 $ withKnownShS (shsAppend sh sh2)
$ unWindShare (STKS (shsAppend sh sh2) x) (sunNest t)
STKS sh (STKProduct stk1 stk2) | Dict <- lemTensorKindOfSTK stk1
, Dict <- lemTensorKindOfSTK stk2
, Dict <- lemTensorKindOfSTK (unWindSTK stk1)
, Dict <- lemTensorKindOfSTK (unWindSTK stk2) ->
unWindShare (STKProduct (STKS sh stk1) (STKS sh stk2)) (sunzip t)
STKX _ STKScalar{} -> t
STKProduct stk1 stk2 | Dict <- lemTensorKindOfSTK stk1
, Dict <- lemTensorKindOfSTK stk2
, (Dict, Dict) <- lemTensorKind1OfSTK (unWindSTK stk1)
, (Dict, Dict) <- lemTensorKind1OfSTK (unWindSTK stk2) ->
let (t1, t2) = tunpair t
in tpair (unWindShare stk1 t1) (unWindShare stk2 t2)
STKUntyped -> t
_ -> error "TODO"

wind :: BaseTensor target
=> STensorKindType y -> target (UnWind y) -> target y
wind stk t = undefined

addWindShare ::
(ADReadyNoLet target, ShareTensor target)
=> STensorKindType y
-> target y -> target y -> target y
addWindShare stk a b = case stk of
STKScalar{} -> a + b
STKR SNat STKScalar{} -> a + b
STKS sh STKScalar{} -> withKnownShS sh $ a + b
STKX sh STKScalar{} -> withKnownShX sh $ a + b
STKProduct stk1 stk2 | Dict <- lemTensorKindOfSTK stk1
, Dict <- lemTensorKindOfSTK stk2 ->
let (a1, a2) = tunpair a
(b1, b2) = tunpair b
in tpair (addWindShare stk1 a1 b1) (addWindShare stk2 a2 b2)
STKUntyped ->
let va = tunvector a
vb = tunvector b
in dmkHVector $ V.zipWith addDynamic va vb
_ -> error "addWindShare: impossible normal form of UnWind"

addShare ::
(ADReadyNoLet target, ShareTensor target)
=> STensorKindType y
-> target y -> target y -> target y
addShare stk a b =
let stk2 = unWindSTK stk
a2 = unWindShare stk a
b2 = unWindShare stk b
in wind stk $ addWindShare stk2 a2 b2


-- * Dynamic

addDynamic :: forall target.
(BaseTensor target, (forall y. TensorKind y => Show (target y)))
=> DynamicTensor target -> DynamicTensor target
Expand Down
4 changes: 4 additions & 0 deletions src/HordeAd/Core/TensorClass.hs
Original file line number Diff line number Diff line change
Expand Up @@ -838,6 +838,10 @@ class ( Num (IntOf target)
sunNest :: forall sh1 sh2 r.
(TensorKind1 r, KnownShS sh1, KnownShS sh2, KnownShS (sh1 ++ sh2))
=> target (TKS2 sh1 (TKS2 sh2 r)) -> target (TKS2 (sh1 ++ sh2) r)
szip :: target (TKProduct (TKS2 sh y) (TKS2 sh z))
-> target (TKS2 sh (TKProduct y z))
sunzip :: target (TKS2 sh (TKProduct y z))
-> target (TKProduct (TKS2 sh y) (TKS2 sh z))
sfromR :: (GoodScalar r, KnownShS sh, KnownNat (Rank sh))
=> target (TKR (Rank sh) r) -> target (TKS sh r)
sfromX :: ( KnownShS sh, KnownShX sh', Rank sh ~ Rank sh'
Expand Down
33 changes: 20 additions & 13 deletions src/HordeAd/Core/TensorKind.hs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
module HordeAd.Core.TensorKind
( -- * Singletons
STensorKindType(..), TensorKind(..)
, lemTensorKindOfSTK, sameTensorKind, sameSTK
, lemTensorKindOfSTK, lemTensorKind1OfSTK, sameTensorKind, sameSTK
, lemTensorKindOfBuild, lemTensorKind1OfBuild
, lemTensorKindOfAD, lemTensorKind1OfAD, lemBuildOfAD
, FullTensorKind(..), lemTensorKindOfFTK, buildFTK
Expand Down Expand Up @@ -108,17 +108,23 @@ instance TensorKind TKUntyped where
stensorKind = STKUntyped

lemTensorKindOfSTK :: STensorKindType y -> Dict TensorKind y
lemTensorKindOfSTK = \case
STKScalar _ -> Dict
STKR SNat x -> case lemTensorKindOfSTK x of
Dict -> Dict
STKS sh x -> case lemTensorKindOfSTK x of
Dict -> withKnownShS sh Dict
STKX sh x -> case lemTensorKindOfSTK x of
Dict -> withKnownShX sh Dict
STKProduct stk1 stk2 | Dict <- lemTensorKindOfSTK stk1
, Dict <- lemTensorKindOfSTK stk2 -> Dict
STKUntyped -> Dict
lemTensorKindOfSTK = fst . lemTensorKind1OfSTK

lemTensorKind1OfSTK :: STensorKindType y
-> ( Dict TensorKind y
, Dict Nested.Elt (RepORArray y) )
lemTensorKind1OfSTK = \case
STKScalar _ -> (Dict, Dict)
STKR SNat x -> case lemTensorKind1OfSTK x of
(Dict, Dict) -> (Dict, Dict)
STKS sh x -> case lemTensorKind1OfSTK x of
(Dict, Dict) -> withKnownShS sh (Dict, Dict)
STKX sh x -> case lemTensorKind1OfSTK x of
(Dict, Dict) -> withKnownShX sh (Dict, Dict)
STKProduct stk1 stk2 | (Dict, Dict) <- lemTensorKind1OfSTK stk1
, (Dict, Dict) <- lemTensorKind1OfSTK stk2 -> (Dict, Dict)
STKUntyped ->
(Dict, unsafeCoerce (Dict @Nested.Elt @Double)) -- never nested in arrays

sameTensorKind :: forall y1 y2. (TensorKind y1, TensorKind y2) => Maybe (y1 :~: y2)
sameTensorKind = sameSTK (stensorKind @y1) (stensorKind @y2)
Expand Down Expand Up @@ -307,7 +313,8 @@ type role RepN nominal
newtype RepN y = RepN {unRepN :: RepORArray y}

type GoodTKConstraint y =
( Default (RepORArray y), Show (RepORArray y), Nested.KnownElt (RepORArray y)
( Default (RepORArray y) -- TODO: remove
, Show (RepORArray y), Nested.KnownElt (RepORArray y)
, Num (RepORArray (ADTensorKind y)) )

-- A class so that the constraint can be represented by a single Dict.
Expand Down

0 comments on commit b59b158

Please sign in to comment.