summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authornalex <>2020-09-15 13:08:00 (GMT)
committerhdiff <hdiff@hdiff.luite.com>2020-09-15 13:08:00 (GMT)
commitc41b3a425c3945d8d50f4a9d57ccfc9b92ef9492 (patch)
treed53b48acc6c8f19394ac012bdc26cd17a537bf0a
parent4ae02a1b360162664f710c72e0467cdd4c7ee438 (diff)
version 0.2.1.0HEAD0.2.1.0master
-rwxr-xr-xCHANGELOG.md6
-rw-r--r--safe-tensor.cabal5
-rw-r--r--src/Math/Tensor.hs5
-rw-r--r--src/Math/Tensor/LinearAlgebra/Equations.hs22
-rw-r--r--src/Math/Tensor/LinearAlgebra/Scalar.hs13
-rw-r--r--src/Math/Tensor/Safe.hs49
-rw-r--r--src/Math/Tensor/Safe/TH.hs15
-rw-r--r--src/Math/Tensor/Safe/Vector.hs6
8 files changed, 90 insertions, 31 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index e26a88d..b88e74e 100755
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,11 @@
# Changelog
+## [0.2.1.0] - 2020-07-20
+ * removeZeros is optional: Tensor addition and tensor contraction will not remove zero values afterwards in order to improve performance.
+ Zero values have to be removed by explicitly applying removeZeros to a tensor.
+ * all data types are now instances of NFData
+ * added functionality to read solution from an rref matrix in a reversed manner
+
## [0.2.0.0] - 2020-07-08
* Minor API adjustments
* major documentation improvements
diff --git a/safe-tensor.cabal b/safe-tensor.cabal
index 68fb4a5..d970a3c 100644
--- a/safe-tensor.cabal
+++ b/safe-tensor.cabal
@@ -4,10 +4,10 @@ cabal-version: 1.12
--
-- see: https://github.com/sol/hpack
--
--- hash: 01b45dfa508544cf3d4c145d847b74cfe33a819cd020b449016b3c1d38cf3afc
+-- hash: e0fcc8fb2efd3f5b62b7dd2ad65bca4521755f5226831bfe4553c23ab94904f6
name: safe-tensor
-version: 0.2.0.0
+version: 0.2.1.0
synopsis: Dependently typed tensor algebra
description: For an introduction to the library, see "Math.Tensor.Safe". For more information, see the README on GitHub at <https://github.com/nilsalex/safe-tensor#readme>
category: Math
@@ -53,6 +53,7 @@ library
base >=4.7 && <5
, constraints >=0.10 && <0.13
, containers >=0.6 && <0.7
+ , deepseq >=1.4 && <1.5
, hmatrix >=0.20 && <0.21
, mtl >=2.2 && <2.3
, singletons >=2.5 && <2.8
diff --git a/src/Math/Tensor.hs b/src/Math/Tensor.hs
index 11d3c4b..42357b7 100644
--- a/src/Math/Tensor.hs
+++ b/src/Math/Tensor.hs
@@ -97,6 +97,8 @@ import Data.Bifunctor (first)
import Data.List.NonEmpty (NonEmpty((:|)),sort)
+import Control.DeepSeq (NFData(rnf))
+
import Control.Monad.Except (MonadError, throwError)
-- |@'T'@ wraps around @'Tensor'@ and exposes only the value type @v@.
@@ -105,6 +107,9 @@ data T :: Type -> Type where
deriving instance Show v => Show (T v)
+instance NFData v => NFData (T v) where
+ rnf (T t) = rnf t
+
instance Functor T where
fmap f (T t) = T $ fmap f t
diff --git a/src/Math/Tensor/LinearAlgebra/Equations.hs b/src/Math/Tensor/LinearAlgebra/Equations.hs
index 7d2de12..acca194 100644
--- a/src/Math/Tensor/LinearAlgebra/Equations.hs
+++ b/src/Math/Tensor/LinearAlgebra/Equations.hs
@@ -25,7 +25,9 @@ module Math.Tensor.LinearAlgebra.Equations
, systemRank
, Solution
, fromRref
+ , fromRrefRev
, fromRow
+ , fromRowRev
, applySolution
, solveTensor
, solveSystem
@@ -73,6 +75,7 @@ import qualified Data.IntMap.Strict as IM
, difference
, intersectionWith
, mapKeys
+ , empty
)
import Data.List (nub, sort)
import Data.Ratio (numerator, denominator)
@@ -84,7 +87,7 @@ type Equation a = IM.IntMap a
-- |Extract linear equations from tensor components.
-- The equations are normalized, sorted, and made unique.
tensorToEquations :: Integral a => T (Poly Rational) -> [Equation a]
-tensorToEquations = nub . sort . fmap (equationFromRational . normalize . snd) . toListT
+tensorToEquations = nub . sort . filter (not . IM.null) . fmap (equationFromRational . normalize . snd) . toListT
-- |Extract linear equation with integral coefficients from polynomial
-- tensor component with rational coefficients.
@@ -97,6 +100,8 @@ equationFromRational (Affine x (Lin lin))
fac :: a
fac = IM.foldl' (\acc v -> lcm (fromIntegral (denominator v)) acc) 1 lin
lin' = IM.map (\v -> fromIntegral (numerator v) * (fac `div` fromIntegral (denominator v))) lin
+equationFromRational (Const c)
+ | c == 0 = IM.empty
equationFromRational _ = error "equation can only be extracted from linear scalar!"
-- |Convert list of equations to sparse matrix representation of the
@@ -147,6 +152,12 @@ fromRref ref = IM.fromList assocs
rows = HM.toLists ref
assocs = mapMaybe fromRow rows
+fromRrefRev :: HM.Matrix HM.Z -> Solution
+fromRrefRev ref = IM.fromList assocs
+ where
+ rows = fmap reverse $ HM.toLists ref
+ assocs = mapMaybe fromRowRev rows
+
-- |Read single substitution rule from single
-- row of reduced row echelon form.
fromRow :: forall a.Integral a => [a] -> Maybe (Int, Poly Rational)
@@ -158,6 +169,15 @@ fromRow xs = case assocs of
where
assocs = filter ((/=0). snd) $ zip [(1::Int)..] xs
+fromRowRev :: forall a.Integral a => [a] -> Maybe (Int, Poly Rational)
+fromRowRev xs = case assocs of
+ [] -> Nothing
+ [(i,_)] -> Just (i, Const 0)
+ (i, v):assocs' -> let assocs'' = fmap (\(i',v') -> (i', - fromIntegral @a @Rational v' / fromIntegral @a @Rational v)) assocs'
+ in Just (i, Affine 0 (Lin (IM.fromList assocs'')))
+ where
+ assocs = reverse $ filter ((/=0). snd) $ zip [(1::Int)..] xs
+
-- |Apply substitution rules to tensor component.
applySolution :: Solution -> Poly Rational -> Poly Rational
applySolution s (Affine x (Lin lin))
diff --git a/src/Math/Tensor/LinearAlgebra/Scalar.hs b/src/Math/Tensor/LinearAlgebra/Scalar.hs
index 68fffa5..a5d2a52 100644
--- a/src/Math/Tensor/LinearAlgebra/Scalar.hs
+++ b/src/Math/Tensor/LinearAlgebra/Scalar.hs
@@ -1,5 +1,7 @@
{-# LANGUAGE Safe #-}
{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE DeriveAnyClass #-}
-----------------------------------------------------------------------------
{-|
@@ -34,16 +36,19 @@ import qualified Data.IntMap.Strict as IM
, findMin
)
+import GHC.Generics (Generic)
+import Control.DeepSeq (NFData)
+
-- |Linear combination represented as mapping from
-- variable number to prefactor.
-newtype Lin a = Lin (IM.IntMap a) deriving (Show, Ord, Eq)
+newtype Lin a = Lin (IM.IntMap a) deriving (Show, Ord, Eq, Generic, NFData)
-- |Polynomial: Can be constant, affine, or something of higher
-- rank which is not yet implemented.
data Poly a = Const !a -- ^ constant value
| Affine !a !(Lin a) -- ^ constant value plus linear term
| NotSupported -- ^ higher rank
- deriving (Show, Ord, Eq)
+ deriving (Show, Ord, Eq, Generic, NFData)
-- |Produces an affine value \(c + a\cdot x_i\)
singletonPoly :: a -- ^ constant
@@ -107,8 +112,8 @@ shiftVars s (Affine a (Lin lin)) =
-- \mathrm{normalize}(c) = 1 \\
-- \mathrm{normalize}(c + a_1\cdot x_1 + a_2\cdot x_2 + \dots + a_n\cdot x_n) = \frac{c}{a_1} + 1\cdot x_1 + \frac{a_2}{a_1}\cdot x_2 + \dots + \frac{a_n}{a_1}\cdot x_n
-- \]
-normalize :: Fractional a => Poly a -> Poly a
-normalize (Const _) = Const 1
+normalize :: (Fractional a, Eq a) => Poly a -> Poly a
+normalize (Const c) = Const $ if c == 0 then 0 else 1
normalize NotSupported = NotSupported
normalize (Affine a (Lin lin)) = Affine (a/v) $ Lin $ IM.map (/v) lin
where
diff --git a/src/Math/Tensor/Safe.hs b/src/Math/Tensor/Safe.hs
index d8fce5a..76fb663 100644
--- a/src/Math/Tensor/Safe.hs
+++ b/src/Math/Tensor/Safe.hs
@@ -238,10 +238,12 @@ import Data.Singletons.Decide
)
import Data.Singletons.TypeLits (Nat, Symbol)
-import Data.Maybe (catMaybes)
+import Data.Maybe (mapMaybe)
import Data.Bifunctor (first,second)
import Data.List (foldl',groupBy,sortBy)
+import Control.DeepSeq (NFData(rnf))
+
-- |The @'Tensor'@ type parameterized by its generalized rank @r@ and
-- arbitrary value type @v@.
data Tensor :: Rank -> Type -> Type where
@@ -254,6 +256,11 @@ data Tensor :: Rank -> Type -> Type where
deriving instance Eq v => Eq (Tensor r v)
deriving instance Show v => Show (Tensor r v)
+instance NFData v => NFData (Tensor r v) where
+ rnf ZeroTensor = ()
+ rnf (Scalar v) = rnf v
+ rnf (Tensor ts) = rnf ts
+
instance Functor (Tensor r) where
fmap _ ZeroTensor = ZeroTensor
fmap f (Scalar s) = Scalar $ f s
@@ -294,11 +301,8 @@ removeZeros (Tensor ms) =
Tensor r v -> Tensor r' v -> Tensor r v
(&+) ZeroTensor t = t
(&+) t ZeroTensor = t
-(&+) (Scalar s) (Scalar s') =
- if s'' == 0 then ZeroTensor else Scalar s''
- where
- s'' = s + s'
-(&+) (Tensor xs) (Tensor xs') = removeZeros $ Tensor xs''
+(&+) (Scalar s) (Scalar s') = Scalar (s + s')
+(&+) (Tensor xs) (Tensor xs') = Tensor xs''
where
xs'' = unionWith (&+) id id xs xs'
(&+) _ _ = error "Cannot add scalar and tensor! Should have been caught by the type system!"
@@ -414,33 +418,34 @@ contract'' sr (Tensor ms) =
in case sv %== sv' of
SFalse ->
case contractTailDiffVProof sr of
- Sub Dict -> removeZeros $ Tensor $ fmap (fmap (contract'' st)) ms
+ Sub Dict -> Tensor $ fmap (fmap (contract'' st)) ms
STrue -> case si of
SICon sa -> case si' of
SICov sb -> case sa %== sb of
STrue ->
- let ms' = fmap (\(i, v) -> case v of
- Tensor vs ->
- case filter (\(i', _) -> i == i') vs of
- [] -> Nothing
- [(_, v')] -> Just v'
- _ -> error "duplicate key in tensor assoc list") ms
- ms'' = catMaybes ms' :: [Tensor (TailR (TailR r)) v]
+ let ms' = mapMaybe (\(i, v) -> case v of
+ Tensor vs ->
+ case filter (\(i', _) -> i == i') vs of
+ [] -> Nothing
+ [(_, v')] -> Just v'
+ _ -> error "duplicate key in tensor assoc list"
+ ZeroTensor -> Nothing)
+ ms :: [Tensor (TailR (TailR r)) v]
in case saneTailRProof sr of
Sub Dict ->
case saneTailRProof st of
Sub Dict ->
case contractTailSameVSameIProof sr of
- Sub Dict -> contract' st' $ foldl' (&+) ZeroTensor ms''
+ Sub Dict -> contract' st' $ foldl' (&+) ZeroTensor ms'
SFalse ->
case contractTailSameVDiffIProof sr of
- Sub Dict -> removeZeros $ Tensor $ fmap (fmap (contract'' st)) ms
+ Sub Dict -> Tensor $ fmap (fmap (contract'' st)) ms
SICon _ ->
case contractTailSameVNoCovProof sr of
- Sub Dict -> removeZeros $ Tensor $ fmap (fmap (contract'' st)) ms
+ Sub Dict -> Tensor $ fmap (fmap (contract'' st)) ms
SICov _ ->
case contractTailSameVNoConProof sr of
- Sub Dict -> removeZeros $ Tensor $ fmap (fmap (contract'' st)) ms
+ Sub Dict -> Tensor $ fmap (fmap (contract'' st)) ms
-- |Tensor contraction. Contracting a tensor is the identity function on non-contractible tensors.
-- Otherwise, the result is the contracted tensor with the contracted labels removed from the
@@ -612,14 +617,18 @@ toList (Tensor ms) =
in case st of
SNil ->
case sn of
- SS SZ -> fmap (\(i, Scalar s) -> (VCons i VNil, s)) ms
+ SS SZ -> mapMaybe (\(i, x) -> case x of
+ ZeroTensor -> Nothing
+ Scalar s -> Just (VCons i VNil, s)) ms
_ ->
case sn of
SS sm' ->
withSingI sm' $
case sm %~ sm' of
Proved Refl ->
- concatMap (\(i, v) -> case v of Tensor _ -> fmap (first (VCons i)) (withSingI st $ toList v)) ms
+ concatMap (\(i, v) -> case v of
+ Tensor _ -> fmap (first (VCons i)) (withSingI st $ toList v)
+ ZeroTensor -> []) ms
-- |Construct @'Tensor'@ from assocs list. Keys are length-typed vectors of indices. Generalized
-- rank is passed explicitly as singleton.
diff --git a/src/Math/Tensor/Safe/TH.hs b/src/Math/Tensor/Safe/TH.hs
index 1202c01..bf7ac84 100644
--- a/src/Math/Tensor/Safe/TH.hs
+++ b/src/Math/Tensor/Safe/TH.hs
@@ -17,6 +17,8 @@
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE CPP #-}
#if MIN_VERSION_base(4,14,0)
@@ -46,6 +48,9 @@ import Data.Singletons.TypeLits
import Data.List.NonEmpty (NonEmpty((:|)),sort,sortBy,(<|))
+import GHC.Generics (Generic,Generic1)
+import Control.DeepSeq (NFData,NFData1)
+
$(singletons [d|
data N where
Z :: N
@@ -56,6 +61,8 @@ $(singletons [d|
True -> Z
False -> S $ fromNat (pred n)
+ deriving instance Generic N
+ deriving instance NFData N
deriving instance Show N
deriving instance Eq N
instance Ord N where
@@ -86,9 +93,9 @@ $(singletons [d|
data VSpace a b = VSpace { vId :: a,
vDim :: b }
- deriving (Show, Ord, Eq)
+ deriving (Show, Ord, Eq, Generic, NFData, Generic1, NFData1)
- data Ix a = ICon a | ICov a deriving (Show, Ord, Eq)
+ data Ix a = ICon a | ICov a deriving (Show, Ord, Eq, Generic, NFData, Generic1, NFData1)
ixCompare :: Ord a => Ix a -> Ix a -> Ordering
ixCompare (ICon a) (ICon b) = compare a b
@@ -105,7 +112,7 @@ $(singletons [d|
data IList a = ConCov (NonEmpty a) (NonEmpty a) |
Cov (NonEmpty a) |
Con (NonEmpty a)
- deriving (Show, Ord, Eq)
+ deriving (Show, Ord, Eq, Generic, NFData, Generic1, NFData1)
type GRank s n = [(VSpace s n, IList s)]
type Rank = GRank Symbol Nat
@@ -361,7 +368,7 @@ $(singletons [d|
data TransRule a = TransCon (NonEmpty a) (NonEmpty a) |
TransCov (NonEmpty a) (NonEmpty a)
- deriving (Show, Eq)
+ deriving (Show, Eq, Generic, NFData, Generic1, NFData1)
saneTransRule :: Ord a => TransRule a -> Bool
saneTransRule tl =
diff --git a/src/Math/Tensor/Safe/Vector.hs b/src/Math/Tensor/Safe/Vector.hs
index 877f8f2..80a70bd 100644
--- a/src/Math/Tensor/Safe/Vector.hs
+++ b/src/Math/Tensor/Safe/Vector.hs
@@ -25,12 +25,18 @@ import Math.Tensor.Safe.TH
import Data.Kind (Type)
import Data.Singletons (Sing)
+import Control.DeepSeq (NFData(rnf))
+
data Vec :: N -> Type -> Type where
VNil :: Vec 'Z a
VCons :: a -> Vec n a -> Vec ('S n) a
deriving instance Show a => Show (Vec n a)
+instance NFData a => NFData (Vec n a) where
+ rnf VNil = ()
+ rnf (VCons x xs) = rnf x `seq` rnf xs
+
instance Eq a => Eq (Vec n a) where
VNil == VNil = True
(x `VCons` xs) == (y `VCons` ys) = x == y && xs == ys