summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorTrevorMcDonell <>2017-11-15 08:40:00 (GMT)
committerhdiff <hdiff@hdiff.luite.com>2017-11-15 08:40:00 (GMT)
commit095330aaf94b0d030cf25f1710ab3ef6ebba4b98 (patch)
tree11002454f87f98782604fb9c4fba830cbe84292e
parent45303a1f5ef94736490a1f59dc134a507f5b96b0 (diff)
version 0.4.0.00.4.0.0
-rw-r--r--CHANGELOG.md16
-rw-r--r--Foreign/CUDA/BLAS/Context.chs53
-rw-r--r--Foreign/CUDA/BLAS/Error.chs8
-rw-r--r--Foreign/CUDA/BLAS/Internal/C2HS.hs15
-rw-r--r--Foreign/CUDA/BLAS/Internal/Types.chs22
-rw-r--r--cbits/stubs.h59
-rw-r--r--cublas.cabal5
7 files changed, 161 insertions, 17 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index d5a29ff..94f7574 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,9 +2,23 @@
Notable changes to the project will be documented in this file.
-The format is based on [Keep a Changelog](http://keepachangelog.com/).
+The format is based on [Keep a Changelog](http://keepachangelog.com/) and the
+project adheres to the [Haskell Package Versioning
+Policy (PVP)](https://pvp.haskell.org)
+
+
+## [0.4.0.0] - 2017-11-15
+### Added
+ * `getMathMode`
+ * `setMathMode`
+
+### Fixed
+ * Build fix for CUDA-9
## 0.3.0.0 - 2017-08-24
* First version; replaces [bmsherman/cublas](https://github.com/bmsherman/cublas). Released on an unsuspecting world.
+
+[0.4.0.0]: https://github.com/tmcdonell/cublas/compare/release/0.3.0.0...0.4.0.0
+
diff --git a/Foreign/CUDA/BLAS/Context.chs b/Foreign/CUDA/BLAS/Context.chs
index b8dbf70..b258209 100644
--- a/Foreign/CUDA/BLAS/Context.chs
+++ b/Foreign/CUDA/BLAS/Context.chs
@@ -1,5 +1,6 @@
{-# LANGUAGE CPP #-}
{-# LANGUAGE ForeignFunctionInterface #-}
+{-# LANGUAGE TemplateHaskell #-}
-- |
-- Module : Foreign.CUDA.BLAS.Context
-- Copyright : [2014..2017] Trevor L. McDonell
@@ -17,11 +18,13 @@ module Foreign.CUDA.BLAS.Context (
create, destroy,
-- ** Utilities
- PointerMode(..), AtomicsMode(..),
+ PointerMode(..), AtomicsMode(..), MathMode(..),
setPointerMode,
getPointerMode,
setAtomicsMode,
getAtomicsMode,
+ setMathMode,
+ getMathMode,
) where
@@ -89,11 +92,9 @@ import Control.Monad ( liftM )
{-# INLINEABLE getPointerMode #-}
{# fun unsafe cublasGetPointerMode_v2 as getPointerMode
{ useHandle `Handle'
- , alloca- `PointerMode' peekPM*
+ , alloca- `PointerMode' peekEnum*
}
-> `()' checkStatus*- #}
- where
- peekPM = liftM cToEnum . peek
-- | Set whether cuBLAS library functions are allowed to use atomic functions,
@@ -118,9 +119,47 @@ import Control.Monad ( liftM )
{-# INLINEABLE getAtomicsMode #-}
{# fun unsafe cublasGetAtomicsMode as getAtomicsMode
{ useHandle `Handle'
- , alloca- `AtomicsMode' peekAM*
+ , alloca- `AtomicsMode' peekEnum*
}
-> `()' checkStatus*- #}
- where
- peekAM = liftM cToEnum . peek
+
+
+-- | Set whether cuBLAS library functions are allowed to use Tensor Core
+-- operations where available.
+--
+-- <http://docs.nvidia.com/cuda/cublas/index.html#cublassetmathmode>
+--
+-- @since 0.4.0.0@
+--
+{-# INLINEABLE setMathMode #-}
+#if CUDA_VERSION < 9000
+setMathMode :: Handle -> MathMode -> IO ()
+setMathMode _ _ = requireSDK 'setMathMode 9.0
+#else
+{# fun unsafe cublasSetMathMode as setMathMode
+ { useHandle `Handle'
+ , cFromEnum `MathMode'
+ }
+ -> `()' checkStatus*- #}
+#endif
+
+
+-- | Determine whether cuBLAS library functions are allowed to use Tensor Core
+-- operations where available.
+--
+-- <http://docs.nvidia.com/cuda/cublas/index.html#cublasgetmathmode>
+--
+-- @since 0.4.0.0@
+--
+{-# INLINEABLE getMathMode #-}
+#if CUDA_VERSION < 9000
+getMathMode :: Handle -> IO MathMode
+getMathMode _ = requireSDK 'getMathMode 9.0
+#else
+{# fun unsafe cublasGetMathMode as getMathMode
+ { useHandle `Handle'
+ , alloca- `MathMode' peekEnum*
+ }
+ -> `()' checkStatus*- #}
+#endif
diff --git a/Foreign/CUDA/BLAS/Error.chs b/Foreign/CUDA/BLAS/Error.chs
index 84b2f42..797ac26 100644
--- a/Foreign/CUDA/BLAS/Error.chs
+++ b/Foreign/CUDA/BLAS/Error.chs
@@ -21,6 +21,8 @@ import Foreign.CUDA.BLAS.Internal.C2HS
import Control.Exception
import Data.Typeable
import Foreign.C.Types
+import Language.Haskell.TH
+import Text.Printf
#include "cbits/stubs.h"
{# context lib="cublas" #}
@@ -70,6 +72,12 @@ instance Show CUBLASException where
cublasError :: String -> IO a
cublasError s = throwIO (UserError s)
+-- |
+-- A specially formatted error message
+--
+requireSDK :: Name -> Double -> IO a
+requireSDK n v = cublasError $ printf "'%s' requires at least cuda-%3.1f\n" (show n) v
+
-- | Return the results of a function on successful execution, otherwise throw
-- an exception with an error string associated with the return code
diff --git a/Foreign/CUDA/BLAS/Internal/C2HS.hs b/Foreign/CUDA/BLAS/Internal/C2HS.hs
index c6e784c..f811ab8 100644
--- a/Foreign/CUDA/BLAS/Internal/C2HS.hs
+++ b/Foreign/CUDA/BLAS/Internal/C2HS.hs
@@ -8,16 +8,13 @@
-- Portability : non-portable (GHC extensions)
--
-module Foreign.CUDA.BLAS.Internal.C2HS (
-
- -- * Conversion between C and Haskell types
- cIntConv, cFloatConv, cToBool, cFromBool, cToEnum, cFromEnum,
-
-) where
+module Foreign.CUDA.BLAS.Internal.C2HS
+ where
-- system
import Foreign
import Foreign.C
+import Control.Monad ( liftM )
-- Conversions -----------------------------------------------------------------
@@ -68,3 +65,9 @@ cToEnum = toEnum . cIntConv
cFromEnum :: (Enum e, Integral i) => e -> i
cFromEnum = cIntConv . fromEnum
+-- | Peek a C value into a Haskell enumeration
+--
+{-# INLINE peekEnum #-}
+peekEnum :: (Enum a, Integral b, Storable b) => Ptr b -> IO a
+peekEnum = liftM cToEnum . peek
+
diff --git a/Foreign/CUDA/BLAS/Internal/Types.chs b/Foreign/CUDA/BLAS/Internal/Types.chs
index 22890ce..d92385c 100644
--- a/Foreign/CUDA/BLAS/Internal/Types.chs
+++ b/Foreign/CUDA/BLAS/Internal/Types.chs
@@ -120,11 +120,31 @@ data Type
--
#if CUDA_VERSION < 8000
data GemmAlgorithm
-#else
+#elif CUDA_VERSION < 9000
{# enum cublasGemmAlgo_t as GemmAlgorithm
{ underscoreToCase
, CUBLAS_GEMM_DFALT as GemmDefault
}
with prefix="CUBLAS" deriving (Eq, Show) #}
+#else
+{# enum cublasGemmAlgo_t as GemmAlgorithm
+ { underscoreToCase
+ , CUBLAS_GEMM_DFALT as CUBLAS_GEMM_DFALT
+ , CUBLAS_GEMM_DEFAULT as GemmDefault
+ }
+ with prefix="CUBLAS" deriving (Eq, Show) #}
+#endif
+
+
+-- | Enum for default math mode / tensor math mode
+--
+#if CUDA_VERSION < 9000
+data MathMode
+#else
+{# enum cublasMath_t as MathMode
+ { CUBLAS_DEFAULT_MATH as DefaultMath
+ , CUBLAS_TENSOR_OP_MATH as TensorMath
+ }
+ with prefix="CUBLAS" deriving (Eq, Show) #}
#endif
diff --git a/cbits/stubs.h b/cbits/stubs.h
index 2a1026e..65787cc 100644
--- a/cbits/stubs.h
+++ b/cbits/stubs.h
@@ -14,5 +14,64 @@
#include <cuda.h>
#include <cublas_v2.h>
+/*
+ * We need to redeclare these functions for CUDA-9, as they are now hidden
+ * behind a #if defined(__cplusplus) guard.
+ */
+#if CUDA_VERSION >= 9000
+typedef struct __align__(2) {
+ unsigned short x;
+} __half;
+
+CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHgemm (cublasHandle_t handle,
+ cublasOperation_t transa,
+ cublasOperation_t transb,
+ int m,
+ int n,
+ int k,
+ const __half *alpha, /* host or device pointer */
+ const __half *A,
+ int lda,
+ const __half *B,
+ int ldb,
+ const __half *beta, /* host or device pointer */
+ __half *C,
+ int ldc);
+
+CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHgemmBatched (cublasHandle_t handle,
+ cublasOperation_t transa,
+ cublasOperation_t transb,
+ int m,
+ int n,
+ int k,
+ const __half *alpha, /* host or device pointer */
+ const __half *Aarray[],
+ int lda,
+ const __half *Barray[],
+ int ldb,
+ const __half *beta, /* host or device pointer */
+ __half *Carray[],
+ int ldc,
+ int batchCount);
+
+CUBLASAPI cublasStatus_t CUBLASWINAPI cublasHgemmStridedBatched (cublasHandle_t handle,
+ cublasOperation_t transa,
+ cublasOperation_t transb,
+ int m,
+ int n,
+ int k,
+ const __half *alpha, /* host or device pointer */
+ const __half *A,
+ int lda,
+ long long int strideA, /* purposely signed */
+ const __half *B,
+ int ldb,
+ long long int strideB,
+ const __half *beta, /* host or device pointer */
+ __half *C,
+ int ldc,
+ long long int strideC,
+ int batchCount);
+#endif /* CUDA_VERSION */
#endif /* C_STUBS_H */
diff --git a/cublas.cabal b/cublas.cabal
index b5a3b59..e7db438 100644
--- a/cublas.cabal
+++ b/cublas.cabal
@@ -1,5 +1,5 @@
name: cublas
-version: 0.3.0.0
+version: 0.4.0.0
synopsis: FFI bindings to the CUDA BLAS library
description:
The cuBLAS library is an implementation of BLAS (Basic Linear Algebra
@@ -57,6 +57,7 @@ library
, cuda >= 0.8
, half >= 0.1
, storable-complex >= 0.2
+ , template-haskell
build-tools:
c2hs >= 0.16
@@ -75,6 +76,6 @@ source-repository head
source-repository this
type: git
location: https://github.com/tmcdonell/cublas
- tag: 0.3.0.0
+ tag: 0.4.0.0
-- vim: nospell