summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r-- LICENSE 29
-rw-r--r-- README.md 1
-rw-r--r-- mxnet-dataiter.cabal 65
-rw-r--r-- src/MXNet/Core/IO/DataIter/Conduit.hs 59
-rw-r--r-- src/MXNet/Core/IO/DataIter/Streaming.hs 62
-rw-r--r-- src/MXNet/Core/IO/Internal.hs 6
-rw-r--r-- src/MXNet/Core/IO/Internal/TH.hs 103
-rw-r--r-- test/conduit.hs 38
-rw-r--r-- test/streaming.hs 39
9 files changed, 402 insertions, 0 deletions
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..26b7e73
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,29 @@
+BSD 3-Clause License
+
+Copyright (c) 2018, Jiasen Wu
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+* Neither the name of the copyright holder nor the names of its
+ contributors may be used to endorse or promote products derived from
+ this software without specific prior written permission.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..820cd35
--- /dev/null
+++ b/README.md
@@ -0,0 +1 @@
+# mxnet-dataiter
diff --git a/mxnet-dataiter.cabal b/mxnet-dataiter.cabal
new file mode 100644
index 0000000..61f7d07
--- /dev/null
+++ b/mxnet-dataiter.cabal
@@ -0,0 +1,65 @@
+-- This file has been generated from package.yaml by hpack version 0.20.0.
+--
+-- see: https://github.com/sol/hpack
+--
+-- hash: b16997d0d4bb4611f467a01d45d22fb2a96f5c0ebfb49dc0bdc8b482a563fe7d
+
+name: mxnet-dataiter
+version: 0.1.0.0
+synopsis: mxnet dataiters
+description: Providing the mxnet dataiters as Stream or Conduit
+homepage: https://github.com/pierric/mxnet-dataiter#readme
+bug-reports: https://github.com/pierric/mxnet-dataiter/issues
+author: Jiasen Wu
+maintainer: jiasenwu@hotmail.com
+copyright: 2018 Jiasen Wu
+license: BSD3
+license-file: LICENSE
+category: Machine Learning, AI
+build-type: Simple
+cabal-version: >= 1.10
+
+extra-source-files:
+ README.md
+
+source-repository head
+ type: git
+ location: https://github.com/pierric/mxnet-dataiter
+
+library
+ hs-source-dirs:
+ src
+ build-depends:
+ base >=4.7 && <5,
+ template-haskell >= 2.10.0.0,
+ streaming >= 0.1.4.5,
+ conduit >= 1.2 && < 1.3,
+ conduit-combinators >= 1.1.2 && < 1.3,
+ mxnet == 0.2.0.0,
+ mxnet-nn >= 0.0.1.3 && < 0.0.2
+ exposed-modules:
+ MXNet.Core.IO.Internal,
+ MXNet.Core.IO.DataIter.Streaming,
+ MXNet.Core.IO.DataIter.Conduit
+ other-modules:
+ Paths_mxnet_dataiter
+ MXNet.Core.IO.Internal.TH
+ default-extensions:
+ FlexibleContexts, DataKinds, TypeOperators
+ default-language: Haskell2010
+
+test-suite streaming
+ type: exitcode-stdio-1.0
+ main-is: streaming.hs
+ hs-source-dirs: test
+ build-depends: base>=4.7 && <5, hspec==2.*, streaming >= 0.1.4.5, mxnet == 0.2.0.0, mxnet-dataiter
+ default-language: Haskell2010
+ default-extensions: FlexibleContexts, DataKinds, TypeApplications
+
+test-suite conduit
+ type: exitcode-stdio-1.0
+ main-is: conduit.hs
+ hs-source-dirs: test
+ build-depends: base>=4.7 && <5, hspec==2.*, mxnet == 0.2.0.0, mxnet-dataiter
+ default-language: Haskell2010
+ default-extensions: FlexibleContexts, DataKinds, TypeApplications
\ No newline at end of file
diff --git a/src/MXNet/Core/IO/DataIter/Conduit.hs b/src/MXNet/Core/IO/DataIter/Conduit.hs
new file mode 100644
index 0000000..09b2ed0
--- /dev/null
+++ b/src/MXNet/Core/IO/DataIter/Conduit.hs
@@ -0,0 +1,59 @@
+{-# Language TypeFamilies #-}
+{-# LANGUAGE FlexibleInstances #-}
+module MXNet.Core.IO.DataIter.Conduit (
+ ConduitData,
+ Dataset(..),
+ imageRecordIter, mnistIter, csvIter, libSVMIter
+) where
+
+import Data.IORef
+import Data.Conduit
+import qualified Data.Conduit.Combinators as C
+import qualified Data.Conduit.List as CL
+import Control.Monad.IO.Class
+import MXNet.Core.Base
+import MXNet.Core.Base.NDArray (NDArray(..))
+import MXNet.Core.Base.Internal
+import qualified MXNet.Core.IO.Internal as I
+
+import MXNet.NN.Types (TrainM)
+import MXNet.NN.DataIter.Class
+
+-- | A dataset represented as a conduit source: it consumes no input, emits
+-- values of type @a@ while running in monad @m@, and finishes with @()@.
+newtype ConduitData m a = ConduitData
+    { getConduit :: ConduitM () a m ()
+    }
+
+-- | Expose the iterator created by 'I.imageRecordIter' from the given
+-- keyword arguments as a 'ConduitData' of @(data, label)@ 'NDArray' pairs.
+imageRecordIter
+    :: (MatchKVList kvs I.ImageRecordIter_Args, ShowKV kvs, DType a, MonadIO m)
+    => HMap kvs
+    -> ConduitData m (NDArray a, NDArray a)
+imageRecordIter kvArgs = makeIter I.imageRecordIter kvArgs
+
+-- | Expose the iterator created by 'I.mNISTIter' from the given keyword
+-- arguments as a 'ConduitData' of @(data, label)@ 'NDArray' pairs.
+mnistIter
+    :: (MatchKVList kvs I.MNISTIter_Args, ShowKV kvs, DType a, MonadIO m)
+    => HMap kvs
+    -> ConduitData m (NDArray a, NDArray a)
+mnistIter kvArgs = makeIter I.mNISTIter kvArgs
+
+-- | Expose the iterator created by 'I.cSVIter' from the given keyword
+-- arguments as a 'ConduitData' of @(data, label)@ 'NDArray' pairs.
+csvIter
+    :: (MatchKVList kvs I.CSVIter_Args, ShowKV kvs, DType a, MonadIO m)
+    => HMap kvs
+    -> ConduitData m (NDArray a, NDArray a)
+csvIter kvArgs = makeIter I.cSVIter kvArgs
+
+-- | Expose the iterator created by 'I.libSVMIter' from the given keyword
+-- arguments as a 'ConduitData' of @(data, label)@ 'NDArray' pairs.
+libSVMIter
+    :: (MatchKVList kvs I.LibSVMIter_Args, ShowKV kvs, DType a, MonadIO m)
+    => HMap kvs
+    -> ConduitData m (NDArray a, NDArray a)
+libSVMIter kvArgs = makeIter I.libSVMIter kvArgs
+
+-- | Wrap an MXNet data-iterator constructor as a conduit source.
+--
+-- @creator args@ is run when the source starts, producing the iterator
+-- handle. Each time 'mxDataIterNext' reports a non-zero result, one
+-- @(data, label)@ pair of 'NDArray's is read from the handle and yielded.
+--
+-- The handle is released with 'mxDataIterFree' once 'mxDataIterNext' returns
+-- 0. If downstream stops consuming before that point, the cleanup registered
+-- with 'addCleanup' releases the handle instead, so it is freed exactly once
+-- on either path. Neither path is exception safe.
+makeIter creator args = ConduitData $ do
+    iter <- liftIO (creator args)
+    let release = checked $ mxDataIterFree iter
+        fetch = do
+            dat <- checked $ mxDataIterGetData iter
+            lbl <- checked $ mxDataIterGetLabel iter
+            return (NDArray dat, NDArray lbl)
+        loop = do
+            valid <- liftIO $ checked $ mxDataIterNext iter
+            if valid == 0
+                then liftIO release
+                else do
+                    yieldM (liftIO fetch)
+                    loop
+        -- 'addCleanup' passes True when 'loop' ran to completion, in which
+        -- case 'loop' has already freed the handle itself.
+        finalize completed = if completed then return () else liftIO release
+    addCleanup finalize loop
+
+-- | A 'ConduitData' can only be used in the monad its conduit runs in.
+type instance DatasetConstraint (ConduitData n) m = n ~ m
+
+-- | 'Dataset' operations implemented on top of conduit sources.
+instance Monad m => Dataset (ConduitData m) where
+    -- Emit the elements of a list in order.
+    fromListD xs = ConduitData (CL.sourceList xs)
+    -- Pair up the outputs of two sources via 'ZipSource'.
+    zipD lhs rhs = ConduitData (getZipSource paired)
+      where
+        paired = (,) <$> ZipSource (getConduit lhs) <*> ZipSource (getConduit rhs)
+    -- Run the source to the end and count what it produced.
+    sizeD ds = runConduit (getConduit ds .| C.length)
+    -- Apply the action to every element and collect the results in a list.
+    forEachD ds action = sourceToList (getConduit ds .| CL.mapM action)
diff --git a/src/MXNet/Core/IO/DataIter/Streaming.hs b/src/MXNet/Core/IO/DataIter/Streaming.hs
new file mode 100644
index 0000000..9ad57af
--- /dev/null
+++ b/src/MXNet/Core/IO/DataIter/Streaming.hs
@@ -0,0 +1,62 @@
+{-# Language TypeFamilies #-}
+{-# Language FlexibleInstances #-}
+module MXNet.Core.IO.DataIter.Streaming (
+ StreamData,
+ Dataset(..),
+ imageRecordIter, mnistIter, csvIter, libSVMIter
+) where
+
+import Data.IORef
+import Streaming
+import Streaming.Prelude (Of(..), yield, length_, toList_)
+import qualified Streaming.Prelude as S
+import MXNet.Core.Base
+import MXNet.Core.Base.NDArray (NDArray(..))
+import MXNet.Core.Base.Internal
+import qualified MXNet.Core.IO.Internal as I
+
+import MXNet.NN.Types (TrainM)
+import MXNet.NN.DataIter.Class
+
+-- | A dataset represented as a 'Stream' of elements of type @a@ in monad
+-- @m@, finishing with @()@.
+newtype StreamData m a = StreamData
+    { getStream :: Stream (Of a) m ()
+    }
+
+-- | Expose the iterator created by 'I.imageRecordIter' from the given
+-- keyword arguments as a 'StreamData' of @(data, label)@ 'NDArray' pairs.
+imageRecordIter
+    :: (MatchKVList kvs I.ImageRecordIter_Args, ShowKV kvs, DType a, MonadIO m)
+    => HMap kvs
+    -> StreamData m (NDArray a, NDArray a)
+imageRecordIter kvArgs = makeIter I.imageRecordIter kvArgs
+
+-- | Expose the iterator created by 'I.mNISTIter' from the given keyword
+-- arguments as a 'StreamData' of @(data, label)@ 'NDArray' pairs.
+mnistIter
+    :: (MatchKVList kvs I.MNISTIter_Args, ShowKV kvs, DType a, MonadIO m)
+    => HMap kvs
+    -> StreamData m (NDArray a, NDArray a)
+mnistIter kvArgs = makeIter I.mNISTIter kvArgs
+
+-- | Expose the iterator created by 'I.cSVIter' from the given keyword
+-- arguments as a 'StreamData' of @(data, label)@ 'NDArray' pairs.
+csvIter
+    :: (MatchKVList kvs I.CSVIter_Args, ShowKV kvs, DType a, MonadIO m)
+    => HMap kvs
+    -> StreamData m (NDArray a, NDArray a)
+csvIter kvArgs = makeIter I.cSVIter kvArgs
+
+-- | Expose the iterator created by 'I.libSVMIter' from the given keyword
+-- arguments as a 'StreamData' of @(data, label)@ 'NDArray' pairs.
+libSVMIter
+    :: (MatchKVList kvs I.LibSVMIter_Args, ShowKV kvs, DType a, MonadIO m)
+    => HMap kvs
+    -> StreamData m (NDArray a, NDArray a)
+libSVMIter kvArgs = makeIter I.libSVMIter kvArgs
+
+-- | Wrap an MXNet data-iterator constructor as a stream.
+--
+-- @creator args@ is run when the stream starts, producing the iterator
+-- handle. Each time 'mxDataIterNext' reports a non-zero result, one
+-- @(data, label)@ pair of 'NDArray's is read from the handle and yielded.
+-- When 'mxDataIterNext' returns 0 the handle is released with
+-- 'mxDataIterFree' and the stream ends.
+--
+-- The handle is only released on exhaustion; a consumer that abandons the
+-- stream early leaves it allocated.
+makeIter creator args = StreamData $ do
+    iter <- liftIO (creator args)
+    let fetch = do
+            dat <- checked $ mxDataIterGetData iter
+            lbl <- checked $ mxDataIterGetLabel iter
+            return (NDArray dat, NDArray lbl)
+        loop = do
+            valid <- liftIO $ checked $ mxDataIterNext iter
+            if valid == 0
+                then liftIO (checked $ mxDataIterFree iter)
+                else do
+                    item <- liftIO fetch
+                    yield item
+                    loop
+    loop
+
+-- | A 'StreamData' can only be used in the monad its stream runs in.
+type instance DatasetConstraint (StreamData n) m = n ~ m
+
+-- | 'Dataset' operations implemented on top of "Streaming.Prelude".
+instance Monad m => Dataset (StreamData m) where
+    -- Emit the elements of a list in order.
+    fromListD xs = StreamData (S.each xs)
+    -- Pair up the elements of two streams.
+    zipD (StreamData lhs) (StreamData rhs) = StreamData (S.zip lhs rhs)
+    -- Run the stream to the end and count its elements.
+    sizeD (StreamData s) = length_ s
+    -- Apply the action to every element and collect the results in a list.
+    forEachD (StreamData s) action = toList_ (void (S.mapM action s))
diff --git a/src/MXNet/Core/IO/Internal.hs b/src/MXNet/Core/IO/Internal.hs
new file mode 100644
index 0000000..83921ec
--- /dev/null
+++ b/src/MXNet/Core/IO/Internal.hs
@@ -0,0 +1,6 @@
+{-# Language TemplateHaskell #-}
+module MXNet.Core.IO.Internal where
+
+import MXNet.Core.IO.Internal.TH
+
+$(registerDataIters)
\ No newline at end of file
diff --git a/src/MXNet/Core/IO/Internal/TH.hs b/src/MXNet/Core/IO/Internal/TH.hs
new file mode 100644
index 0000000..884e87b
--- /dev/null
+++ b/src/MXNet/Core/IO/Internal/TH.hs
@@ -0,0 +1,103 @@
+{-# Language TemplateHaskell #-}
+module MXNet.Core.IO.Internal.TH where
+
+import Data.List
+import Data.Char
+import Data.Bifunctor
+import Text.ParserCombinators.ReadP
+import Language.Haskell.TH
+
+import MXNet.Core.Base
+import MXNet.Core.Base.Internal
+
+-- Positional accessors for the 6-tuple returned by 'mxDataIterGetIterInfo'.
+-- NOTE(review): field meanings are taken from the accessor names (name,
+-- description, argument count, argument names, argument types, argument
+-- docs) -- confirm against the mxnet binding.
+
+-- | First component.
+diInfoName (name, _, _, _, _, _) = name
+-- | Second component.
+diInfoDesc (_, descr, _, _, _, _) = descr
+-- | Third component.
+diInfoArgc (_, _, argc, _, _, _) = argc
+-- | Fourth component.
+diInfoArgN (_, _, _, argNames, _, _) = argNames
+-- | Fifth component.
+diInfoArgT (_, _, _, _, argTypes, _) = argTypes
+-- | Sixth component.
+diInfoArgD (_, _, _, _, _, argDocs) = argDocs
+
+-- | Generate declarations for every data iterator MXNet reports.
+--
+-- At splice time this lists the iterator creators with 'mxListDataIters',
+-- queries each one with 'mxDataIterGetIterInfo', and hands its position in
+-- the list, its name and its de-duplicated (argument name, argument type)
+-- pairs to 'makeDataIter'. The resulting declarations are concatenated.
+registerDataIters :: Q [Dec]
+registerDataIters = do
+    described <- runIO $ do
+        creators <- mxListDataIters
+        mapM describe (zip [0 ..] creators)
+    decls <- mapM (\(ident, params) -> makeDataIter ident params) described
+    return (concat decls)
+  where
+    -- Pair a creator's list position and name with its argument list.
+    describe (idx, creator) = do
+        meta <- mxDataIterGetIterInfo creator
+        let params = nub (zip (diInfoArgN meta) (diInfoArgT meta))
+        return ((idx, diInfoName meta), params)
+
+-- | Generate the declarations for a single MXNet data iterator.
+--
+-- Given the iterator's position in the creator list, its name, and its
+-- (argument name, argument type description) pairs, this produces:
+--
+-- * a type synonym @<name>_Args@: the type-level list of @key ':= type@
+--   entries built from the parsed argument descriptions;
+-- * a type signature and a definition for a function named after the
+--   iterator with its first letter lower-cased, which dumps the supplied
+--   'HMap' to key/value strings and calls 'mxDataIterCreateIter'.
+--
+-- The generated function always takes the 'HMap' argument, also when the
+-- iterator declares no arguments, so that the signature agrees with the
+-- single-parameter clause generated below.
+makeDataIter :: (Integer, String) -> [(String, String)] -> Q [Dec]
+makeDataIter (index, name) rawArgs = do
+    let parsedArgs = map (second parseArgDesc) rawArgs
+        dname      = mkName (deCap name)
+        kvs        = mkName "kvs"
+        cstName    = mkName $ name ++ "_Args"
+        allargs    = mkName "allargs"
+        argList    = foldr add promotedNilT parsedArgs
+    cst <- tySynD cstName [] argList
+    sig <- sigD dname [t| (MatchKVList $(varT kvs) $(conT cstName), ShowKV $(varT kvs)) => HMap $(varT kvs) -> IO DataIterHandle |]
+    -- NOTE(review): the creator is looked up at run time by the position seen
+    -- at compile time; this assumes 'mxListDataIters' returns the same order
+    -- in both settings -- confirm.
+    fun <- funD dname [clause [varP allargs] (normalB [e| do{
+            args <- return (dump $(varE allargs));
+            len <- return (fromIntegral $ length args);
+            (keys, vals) <- return (unzip args);
+            crts <- mxListDataIters;
+            checked $ mxDataIterCreateIter (crts !! $(litE $ integerL index)) len keys vals;
+        } |]) []]
+    return [cst, sig, fun]
+  where
+    -- Lower-case the first character; total on the empty string.
+    deCap []       = []
+    deCap (x : xs) = toLower x : xs
+    -- Haskell type used for each parsed argument type.
+    toTyp ArgString        = [t| String |]
+    toTyp ArgInt           = [t| Int |]
+    toTyp ArgLong          = [t| Integer |]
+    toTyp ArgFloat         = [t| Float |]
+    toTyp ArgBool          = [t| Bool |]
+    toTyp ArgShape         = [t| [Int] |]
+    toTyp (ArgEnum _)      = [t| String |]
+    toTyp (ArgTuple inner) = [t| [$(toTyp inner)] |]
+    -- Prepend one @name ':= type@ entry to the promoted list.
+    add (nm, (at, _)) lst =
+        let item = [t| $(litT (strTyLit nm)) ':= $(toTyp at) |]
+        in appT (appT promotedConsT item) lst
+
+-- | Argument types recognised in an iterator's argument description.
+data ArgType
+    = ArgString
+    | ArgInt
+    | ArgLong
+    | ArgFloat
+    | ArgBool
+    | ArgShape
+    | ArgEnum [String]   -- ^ one of a listed set of alternatives
+    | ArgTuple ArgType   -- ^ tuple whose elements have the given type
+    deriving (Eq, Show)
+
+-- | Whether an argument is required or optional.
+data ArgOccr
+    = Required
+    | Optional
+    deriving (Eq, Show)
+
+-- | Parse an argument description such as @"int, required"@ with 'desc',
+-- taking the first result the parser reports.
+--
+-- Calls 'error' with the offending text when the parser yields no result.
+parseArgDesc :: String -> (ArgType, ArgOccr)
+parseArgDesc str
+    | (parsed, _) : _ <- readP_to_S desc str = parsed
+    | otherwise = error ("cannot parse arg desc: " ++ str)
+
+-- | One or more alphanumeric characters.
+alphaNum = many1 alnum
+  where
+    alnum = satisfy isAlphaNum
+
+-- | Single-quoted text made of alphanumerics and any of @/ _ - .@; the
+-- quotes are dropped from the result.
+quoted = between tick tick (many bodyChar)
+  where
+    tick = char '\''
+    bodyChar = satisfy (\c -> isAlphaNum c || c `elem` "/_-.")
+
+-- | A quoted string, number or alphanumeric word wrapped in square brackets.
+boxed = between (char '[') (char ']') inner
+  where
+    inner = quoted +++ number +++ alphaNum
+
+-- | Digits with an optional leading minus sign; only the digits are returned.
+number = do
+    optional (char '-')
+    many1 (satisfy isDigit)
+
+-- | A comma with optional surrounding whitespace.
+comma = do
+    skipSpaces
+    _ <- char ','
+    skipSpaces
+
+-- | A brace-delimited, comma-separated, non-empty list of words or quoted
+-- strings.
+enum = between (char '{') (char '}') alternatives
+  where
+    alternatives = sepBy1 (alphaNum +++ quoted) comma
+-- | Parser for the type part of an argument description.
+typ = choice
+    [ ArgString <$ string "string"
+    , ArgInt    <$ string "int"
+    , ArgInt    <$ string "int (non-negative)"
+    , ArgLong   <$ string "long"
+    , ArgLong   <$ string "long (non-negative)"
+    , ArgBool   <$ string "boolean"
+    , ArgFloat  <$ string "float"
+    , ArgShape  <$ string "Shape(tuple)"
+      -- "tuple of <T>": the element type is parsed recursively.
+    , fmap ArgTuple (string "tuple of" >> skipSpaces >> between (char '<') (char '>') typ)
+    , fmap ArgEnum enum
+    ]
+-- | Parser for the occurrence part of an argument description: either
+-- @required@, or @optional, default=<value>@ where the default value is
+-- parsed and discarded.
+occ = choice [required, optionalWithDefault]
+  where
+    required = Required <$ string "required"
+    optionalWithDefault = do
+        _ <- string "optional"
+        comma
+        _ <- string "default="
+        _ <- quoted +++ boxed +++ alphaNum +++ number
+        return Optional
+
+-- | Parser for a whole argument description: a type, a comma, then the
+-- occurrence.
+desc :: ReadP (ArgType, ArgOccr)
+desc = (,) <$> typ <* comma <*> occ
diff --git a/test/conduit.hs b/test/conduit.hs
new file mode 100644
index 0000000..e6b9e1c
--- /dev/null
+++ b/test/conduit.hs
@@ -0,0 +1,38 @@
+module Main where
+
+import Test.Hspec
+import MXNet.Core.Base
+
+import MXNet.Core.IO.DataIter.Conduit
+
+type DS = ConduitData IO (NDArray Float, NDArray Float)
+
+-- | Checks the number of elements the conduit-backed iterators produce for
+-- several batch sizes, using the data files under @test/data@.
+main :: IO ()
+main = hspec $ do
+    describe "MNISTIter" $ do
+        it "batch-size = 1" $
+            sizeD (mnistIter (add @"image" "test/data/train-images-idx3-ubyte"
+                             (add @"label" "test/data/train-labels-idx1-ubyte"
+                             (add @"batch_size" 1 nil))) :: DS)
+                `shouldReturn` 60000
+        it "batch-size = 32" $
+            sizeD (mnistIter (add @"image" "test/data/train-images-idx3-ubyte"
+                             (add @"label" "test/data/train-labels-idx1-ubyte"
+                             (add @"batch_size" 32 nil))) :: DS)
+                `shouldReturn` 1875
+        it "batch-size = 128" $
+            sizeD (mnistIter (add @"image" "test/data/train-images-idx3-ubyte"
+                             (add @"label" "test/data/train-labels-idx1-ubyte"
+                             (add @"batch_size" 128 nil))) :: DS)
+                `shouldReturn` 468
+    describe "ImageRecordIter" $ do
+        it "batch-size = 32" $
+            sizeD (imageRecordIter (add @"path_imgrec" "test/data/cifar10_val.rec"
+                                   (add @"data_shape" [3,28,28]
+                                   (add @"batch_size" 32 nil))) :: DS)
+                `shouldReturn` 313
+        it "batch-size = 128" $
+            sizeD (imageRecordIter (add @"path_imgrec" "test/data/cifar10_val.rec"
+                                   (add @"data_shape" [3,28,28]
+                                   (add @"batch_size" 128 nil))) :: DS)
+                `shouldReturn` 79
diff --git a/test/streaming.hs b/test/streaming.hs
new file mode 100644
index 0000000..5d70e98
--- /dev/null
+++ b/test/streaming.hs
@@ -0,0 +1,39 @@
+module Main where
+
+import Test.Hspec
+import Streaming.Prelude
+import MXNet.Core.Base
+
+import MXNet.Core.IO.DataIter.Streaming
+
+type DS = StreamData IO (NDArray Float, NDArray Float)
+
+-- | Checks the number of elements the stream-backed iterators produce for
+-- several batch sizes, using the data files under @test/data@.
+main :: IO ()
+main = hspec $ do
+    describe "MNISTIter" $ do
+        it "batch-size = 1" $
+            sizeD (mnistIter (add @"image" "test/data/train-images-idx3-ubyte"
+                             (add @"label" "test/data/train-labels-idx1-ubyte"
+                             (add @"batch_size" 1 nil))) :: DS)
+                `shouldReturn` 60000
+        it "batch-size = 32" $
+            sizeD (mnistIter (add @"image" "test/data/train-images-idx3-ubyte"
+                             (add @"label" "test/data/train-labels-idx1-ubyte"
+                             (add @"batch_size" 32 nil))) :: DS)
+                `shouldReturn` 1875
+        it "batch-size = 128" $
+            sizeD (mnistIter (add @"image" "test/data/train-images-idx3-ubyte"
+                             (add @"label" "test/data/train-labels-idx1-ubyte"
+                             (add @"batch_size" 128 nil))) :: DS)
+                `shouldReturn` 468
+    describe "ImageRecordIter" $ do
+        it "batch-size = 32" $
+            sizeD (imageRecordIter (add @"path_imgrec" "test/data/cifar10_val.rec"
+                                   (add @"data_shape" [3,28,28]
+                                   (add @"batch_size" 32 nil))) :: DS)
+                `shouldReturn` 313
+        it "batch-size = 128" $
+            sizeD (imageRecordIter (add @"path_imgrec" "test/data/cifar10_val.rec"
+                                   (add @"data_shape" [3,28,28]
+                                   (add @"batch_size" 128 nil))) :: DS)
+                `shouldReturn` 79