summaryrefslogtreecommitdiff
path: root/Foreign/CUDA/BLAS/Level3.chs
diff options
context:
space:
mode:
Diffstat (limited to 'Foreign/CUDA/BLAS/Level3.chs')
-rw-r--r--Foreign/CUDA/BLAS/Level3.chs17
1 files changed, 17 insertions, 0 deletions
diff --git a/Foreign/CUDA/BLAS/Level3.chs b/Foreign/CUDA/BLAS/Level3.chs
index b2d3130..474c45c 100644
--- a/Foreign/CUDA/BLAS/Level3.chs
+++ b/Foreign/CUDA/BLAS/Level3.chs
@@ -133,6 +133,8 @@ module Foreign.CUDA.BLAS.Level3 (
dotEx,
dotcEx,
scalEx,
+ gemmBatchedEx,
+ gemmStridedBatchedEx,
) where
@@ -556,3 +558,18 @@ dotcEx _ _ _ _ _ _ _ _ _ _ _ = cublasError "'dotcEx' requires at least cuda-8.0"
scalEx :: Handle -> Int -> Ptr () -> Type -> DevicePtr () -> Type -> Int -> Type -> IO ()
scalEx _ _ _ _ _ _ _ _ = cublasError "'scalEx' requires at least cuda-8.0"
#endif
+#if CUDA_VERSION >= 9100
+
+{-# INLINEABLE gemmBatchedEx #-}
+{# fun unsafe cublasGemmBatchedEx as gemmBatchedEx { useHandle `Handle', cFromEnum `Operation', cFromEnum `Operation', `Int', `Int', `Int', castPtr `Ptr ()', useDevP `DevicePtr (DevicePtr ())', cFromEnum `Type', `Int', useDevP `DevicePtr ()', cFromEnum `Type', `Int', castPtr `Ptr ()', useDevP `DevicePtr ()', cFromEnum `Type', `Int', `Int', cFromEnum `Type', cFromEnum `GemmAlgorithm' } -> `()' checkStatus* #}
+
+{-# INLINEABLE gemmStridedBatchedEx #-}
+{# fun unsafe cublasGemmStridedBatchedEx as gemmStridedBatchedEx { useHandle `Handle', cFromEnum `Operation', cFromEnum `Operation', `Int', `Int', `Int', castPtr `Ptr ()', useDevP `DevicePtr ()', cFromEnum `Type', `Int', `Int64', useDevP `DevicePtr ()', cFromEnum `Type', `Int', `Int64', castPtr `Ptr ()', useDevP `DevicePtr ()', cFromEnum `Type', `Int', `Int64', `Int', cFromEnum `Type', cFromEnum `GemmAlgorithm' } -> `()' checkStatus* #}
+#else
+
+gemmBatchedEx :: Handle -> Operation -> Operation -> Int -> Int -> Int -> Ptr () -> DevicePtr (DevicePtr ()) -> Type -> Int -> DevicePtr () -> Type -> Int -> Ptr () -> DevicePtr () -> Type -> Int -> Int -> Type -> GemmAlgorithm -> IO ()
+gemmBatchedEx _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ = cublasError "'gemmBatchedEx' requires at least cuda-9.1"
+
+gemmStridedBatchedEx :: Handle -> Operation -> Operation -> Int -> Int -> Int -> Ptr () -> DevicePtr () -> Type -> Int -> Int64 -> DevicePtr () -> Type -> Int -> Int64 -> Ptr () -> DevicePtr () -> Type -> Int -> Int64 -> Int -> Type -> GemmAlgorithm -> IO ()
+gemmStridedBatchedEx _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ = cublasError "'gemmStridedBatchedEx' requires at least cuda-9.1"
+#endif