author     Troels Henriksen <>              2019-03-13 12:25:00 (GMT)
committer  hdiff <hdiff@hdiff.luite.com>    2019-03-13 12:25:00 (GMT)
commit     dc312109d1c96e44afee330d5618c126d15093e1 (patch)
tree       8e2bf9efca10f453ad6ada3d1509f70bee750fe8
parent     1c411af78f73eac393062a920c8ddf21f97a51c3 (diff)
version 0.9.1 (HEAD, 0.9.1, master)
-rw-r--r--  futhark.cabal  249
-rw-r--r--  rts/c/cuda.h  535
-rw-r--r--  rts/c/free_list.h  110
-rw-r--r--  rts/c/opencl.h  248
-rw-r--r--  rts/csharp/opencl.cs  12
-rw-r--r--  rts/python/__init__.py  0
-rw-r--r--  rts/python/opencl.py  5
-rw-r--r--  src/Futhark/Analysis/DataDependencies.hs  1
-rw-r--r--  src/Futhark/Analysis/HORepresentation/MapNest.hs  1
-rw-r--r--  src/Futhark/Analysis/HORepresentation/SOAC.hs  4
-rw-r--r--  src/Futhark/Analysis/Metrics.hs  4
-rw-r--r--  src/Futhark/Analysis/PrimExp.hs  74
-rw-r--r--  src/Futhark/Analysis/PrimExp/Convert.hs  2
-rw-r--r--  src/Futhark/Analysis/Range.hs  1
-rw-r--r--  src/Futhark/Analysis/SymbolTable.hs  5
-rw-r--r--  src/Futhark/Analysis/Usage.hs  1
-rw-r--r--  src/Futhark/Analysis/UsageTable.hs  8
-rw-r--r--  src/Futhark/CLI/Bench.hs (renamed from src/futhark-bench.hs)  128
-rw-r--r--  src/Futhark/CLI/C.hs (renamed from src/futhark-c.hs)  4
-rw-r--r--  src/Futhark/CLI/CSOpenCL.hs (renamed from src/futhark-csopencl.hs)  4
-rw-r--r--  src/Futhark/CLI/CSharp.hs (renamed from src/futhark-cs.hs)  4
-rw-r--r--  src/Futhark/CLI/CUDA.hs  43
-rw-r--r--  src/Futhark/CLI/Datacmp.hs  29
-rw-r--r--  src/Futhark/CLI/Dataset.hs (renamed from src/futhark-dataset.hs)  4
-rw-r--r--  src/Futhark/CLI/Dev.hs  398
-rw-r--r--  src/Futhark/CLI/Doc.hs (renamed from src/futhark-doc.hs)  5
-rw-r--r--  src/Futhark/CLI/Misc.hs  31
-rw-r--r--  src/Futhark/CLI/OpenCL.hs (renamed from src/futhark-opencl.hs)  4
-rw-r--r--  src/Futhark/CLI/Pkg.hs (renamed from src/futhark-pkg.hs)  62
-rw-r--r--  src/Futhark/CLI/PyOpenCL.hs (renamed from src/futhark-pyopencl.hs)  4
-rw-r--r--  src/Futhark/CLI/Python.hs (renamed from src/futhark-py.hs)  4
-rw-r--r--  src/Futhark/CLI/REPL.hs  419
-rw-r--r--  src/Futhark/CLI/Run.hs  143
-rw-r--r--  src/Futhark/CLI/Test.hs (renamed from src/futhark-test.hs)  160
-rw-r--r--  src/Futhark/CodeGen/Backends/CCUDA.hs  277
-rw-r--r--  src/Futhark/CodeGen/Backends/CCUDA/Boilerplate.hs  256
-rw-r--r--  src/Futhark/CodeGen/Backends/COpenCL.hs  25
-rw-r--r--  src/Futhark/CodeGen/Backends/COpenCL/Boilerplate.hs  47
-rw-r--r--  src/Futhark/CodeGen/Backends/CSOpenCL.hs  25
-rw-r--r--  src/Futhark/CodeGen/Backends/CSOpenCL/Boilerplate.hs  21
-rw-r--r--  src/Futhark/CodeGen/Backends/GenericC.hs  7
-rw-r--r--  src/Futhark/CodeGen/Backends/GenericCSharp.hs  2806
-rw-r--r--  src/Futhark/CodeGen/Backends/PyOpenCL.hs  23
-rw-r--r--  src/Futhark/CodeGen/Backends/PyOpenCL/Boilerplate.hs  6
-rw-r--r--  src/Futhark/CodeGen/ImpCode.hs  13
-rw-r--r--  src/Futhark/CodeGen/ImpCode/Kernels.hs  66
-rw-r--r--  src/Futhark/CodeGen/ImpCode/OpenCL.hs  12
-rw-r--r--  src/Futhark/CodeGen/ImpGen.hs  78
-rw-r--r--  src/Futhark/CodeGen/ImpGen/CUDA.hs  14
-rw-r--r--  src/Futhark/CodeGen/ImpGen/Kernels.hs  1067
-rw-r--r--  src/Futhark/CodeGen/ImpGen/Kernels/Base.hs  960
-rw-r--r--  src/Futhark/CodeGen/ImpGen/Kernels/SegRed.hs  601
-rw-r--r--  src/Futhark/CodeGen/ImpGen/Kernels/ToOpenCL.hs  324
-rw-r--r--  src/Futhark/CodeGen/ImpGen/Kernels/Transpose.hs  5
-rw-r--r--  src/Futhark/CodeGen/ImpGen/Sequential.hs  2
-rw-r--r--  src/Futhark/CodeGen/OpenCL/Kernels.hs  2
-rw-r--r--  src/Futhark/Compiler.hs  21
-rw-r--r--  src/Futhark/Compiler/CLI.hs  9
-rw-r--r--  src/Futhark/Compiler/Program.hs  1
-rw-r--r--  src/Futhark/Doc/Generator.hs  22
-rw-r--r--  src/Futhark/Doc/Html.hs  2
-rw-r--r--  src/Futhark/FreshNames.hs  4
-rw-r--r--  src/Futhark/Internalise.hs  62
-rw-r--r--  src/Futhark/Internalise/Defunctionalise.hs  84
-rw-r--r--  src/Futhark/Internalise/Defunctorise.hs  7
-rw-r--r--  src/Futhark/Internalise/Monad.hs  3
-rw-r--r--  src/Futhark/Internalise/Monomorphise.hs  28
-rw-r--r--  src/Futhark/Internalise/TypesValues.hs  11
-rw-r--r--  src/Futhark/Optimise/CSE.hs  1
-rw-r--r--  src/Futhark/Optimise/Fusion.hs  5
-rw-r--r--  src/Futhark/Optimise/Fusion/Composing.hs  1
-rw-r--r--  src/Futhark/Optimise/Fusion/LoopKernel.hs  1
-rw-r--r--  src/Futhark/Optimise/InPlaceLowering.hs  4
-rw-r--r--  src/Futhark/Optimise/InPlaceLowering/SubstituteIndices.hs  1
-rw-r--r--  src/Futhark/Optimise/MemoryBlockMerging/Types.hs  4
-rw-r--r--  src/Futhark/Optimise/Simplify.hs  2
-rw-r--r--  src/Futhark/Optimise/Simplify/ClosedForm.hs  1
-rw-r--r--  src/Futhark/Optimise/Simplify/Lore.hs  1
-rw-r--r--  src/Futhark/Optimise/Simplify/Rule.hs  8
-rw-r--r--  src/Futhark/Optimise/Simplify/Rules.hs  1
-rw-r--r--  src/Futhark/Optimise/TileLoops.hs  5
-rw-r--r--  src/Futhark/Optimise/TileLoops/RegTiling3D.hs  3
-rw-r--r--  src/Futhark/Pass/ExpandAllocations.hs  118
-rw-r--r--  src/Futhark/Pass/ExplicitAllocations.hs  90
-rw-r--r--  src/Futhark/Pass/ExtractKernels.hs  213
-rw-r--r--  src/Futhark/Pass/ExtractKernels/BlockedKernel.hs  66
-rw-r--r--  src/Futhark/Pass/ExtractKernels/ISRWIM.hs  1
-rw-r--r--  src/Futhark/Pass/ExtractKernels/Intragroup.hs  12
-rw-r--r--  src/Futhark/Pass/ExtractKernels/Kernelise.hs  1
-rw-r--r--  src/Futhark/Pass/ExtractKernels/Segmented.hs  814
-rw-r--r--  src/Futhark/Pass/ExtractKernels/Split.hs  41
-rw-r--r--  src/Futhark/Pass/KernelBabysitting.hs  1
-rw-r--r--  src/Futhark/Pkg/Info.hs  4
-rw-r--r--  src/Futhark/Pkg/Types.hs  5
-rw-r--r--  src/Futhark/Representation/AST/Attributes/TypeOf.hs  1
-rw-r--r--  src/Futhark/Representation/AST/Syntax.hs  14
-rw-r--r--  src/Futhark/Representation/AST/Syntax/Core.hs  10
-rw-r--r--  src/Futhark/Representation/Aliases.hs  4
-rw-r--r--  src/Futhark/Representation/ExplicitMemory/Simplify.hs  1
-rw-r--r--  src/Futhark/Representation/Kernels/Kernel.hs  98
-rw-r--r--  src/Futhark/Representation/Kernels/Simplify.hs  32
-rw-r--r--  src/Futhark/Representation/Kernels/Sizes.hs  4
-rw-r--r--  src/Futhark/Representation/SOACS/Simplify.hs  1
-rw-r--r--  src/Futhark/Test.hs  144
-rw-r--r--  src/Futhark/Test/Values.hs  38
-rw-r--r--  src/Futhark/Tools.hs  1
-rw-r--r--  src/Futhark/Transform/FirstOrderTransform.hs  1
-rw-r--r--  src/Futhark/Transform/Rename.hs  1
-rw-r--r--  src/Futhark/TypeCheck.hs  17
-rw-r--r--  src/Futhark/Util.hs  20
-rw-r--r--  src/Futhark/Util/Log.hs  4
-rw-r--r--  src/Futhark/Util/Options.hs  22
-rw-r--r--  src/Language/Futhark.hs  9
-rw-r--r--  src/Language/Futhark/Attributes.hs  297
-rw-r--r--  src/Language/Futhark/Core.hs  28
-rw-r--r--  src/Language/Futhark/Interpreter.hs  80
-rw-r--r--  src/Language/Futhark/Parser.hs  7
-rw-r--r--  src/Language/Futhark/Parser/Parser.y  2
-rw-r--r--  src/Language/Futhark/Pretty.hs  19
-rw-r--r--  src/Language/Futhark/Semantic.hs  35
-rw-r--r--  src/Language/Futhark/Syntax.hs  122
-rw-r--r--  src/Language/Futhark/Traversals.hs  52
-rw-r--r--  src/Language/Futhark/TypeChecker.hs  83
-rw-r--r--  src/Language/Futhark/TypeChecker/Monad.hs  7
-rw-r--r--  src/Language/Futhark/TypeChecker/Terms.hs  326
-rw-r--r--  src/Language/Futhark/TypeChecker/Types.hs  147
-rw-r--r--  src/Language/Futhark/Warnings.hs  4
-rw-r--r--  src/futhark.hs  475
-rw-r--r--  src/futharki.hs  471
-rw-r--r--  src/wrapper.hs  29
130 files changed, 7875 insertions(+), 5732 deletions(-)
diff --git a/futhark.cabal b/futhark.cabal
index a14b78f..d85de0f 100644
--- a/futhark.cabal
+++ b/futhark.cabal
@@ -2,10 +2,10 @@
--
-- see: https://github.com/sol/hpack
--
--- hash: a75ffbf819d567c2108977d336c20ff7f11963b071dc0d9e581faa2f792dba99
+-- hash: d926568e0952a1cecfa1efa7cf687ddc7cd27d1c642a929d85ca5b0e8465b408
name: futhark
-version: 0.8.1
+version: 0.9.1
synopsis: An optimising compiler for a functional, array-oriented language.
description: See the website at https://futhark-lang.org
category: Language
@@ -23,6 +23,8 @@ extra-source-files:
futlib/prelude.fut
futlib/soacs.fut
futlib/zip.fut
+ rts/c/cuda.h
+ rts/c/free_list.h
rts/c/lock.h
rts/c/opencl.h
rts/c/panic.h
@@ -37,7 +39,6 @@ extra-source-files:
rts/csharp/reader.cs
rts/csharp/scalar.cs
rts/futhark-doc/style.css
- rts/python/__init__.py
rts/python/memory.py
rts/python/opencl.py
rts/python/panic.py
@@ -49,6 +50,63 @@ source-repository head
location: https://github.com/diku-dk/futhark
library
+ hs-source-dirs:
+ src
+ ghc-options: -Wall -Wcompat -Wredundant-constraints -Wincomplete-record-updates -Wmissing-export-lists
+ build-depends:
+ aeson
+ , ansi-terminal >=0.6.3.1
+ , array >=0.4
+ , base >=4 && <5
+ , bifunctors >=5.4.2
+ , binary >=0.8.3
+ , blaze-html >=0.9.0.1
+ , bytestring >=0.10.8
+ , containers >=0.5
+ , data-binary-ieee754 >=0.1
+ , directory >=1.3.0.0
+ , directory-tree >=0.12.1
+ , dlist >=0.6.0.1
+ , extra >=1.5.3
+ , file-embed >=0.0.9
+ , filepath >=1.4.1.1
+ , free >=4.12.4
+ , gitrev >=1.2.0
+ , haskeline
+ , http-client >=0.5.7.0
+ , http-client-tls >=0.3.5.1
+ , http-conduit >=2.2.4
+ , language-c-quote >=0.12
+ , mainland-pretty >=0.6.1
+ , markdown >=0.1.16
+ , megaparsec >=7.0.1
+ , mtl >=2.2.1
+ , neat-interpolation >=0.3
+ , parallel >=3.2.1.0
+ , parser-combinators >=1.0.0
+ , process >=1.4.3.0
+ , process-extras >=0.7.2
+ , random
+ , raw-strings-qq >=1.1
+ , regex-tdfa >=1.2
+ , srcloc >=0.4
+ , template-haskell >=2.11.1
+ , temporary
+ , text >=1.2.2.2
+ , th-lift-instances >=0.1.11
+ , time >=1.6.0.1
+ , transformers >=0.3
+ , vector >=0.12
+ , vector-binary-instances >=0.2.2.0
+ , versions >=3.3.1
+ , zip-archive >=0.3.1.1
+ , zlib >=0.6.1.2
+ build-tools:
+ alex
+ , happy
+ if !impl(ghc >= 8.0)
+ build-depends:
+ semigroups ==0.18.*
exposed-modules:
Futhark.Actions
Futhark.Analysis.AlgSimplify
@@ -69,6 +127,25 @@ library
Futhark.Analysis.UsageTable
Futhark.Binder
Futhark.Binder.Class
+ Futhark.CLI.Bench
+ Futhark.CLI.C
+ Futhark.CLI.CSharp
+ Futhark.CLI.CSOpenCL
+ Futhark.CLI.CUDA
+ Futhark.CLI.Datacmp
+ Futhark.CLI.Dataset
+ Futhark.CLI.Dev
+ Futhark.CLI.Doc
+ Futhark.CLI.Misc
+ Futhark.CLI.OpenCL
+ Futhark.CLI.Pkg
+ Futhark.CLI.PyOpenCL
+ Futhark.CLI.Python
+ Futhark.CLI.REPL
+ Futhark.CLI.Run
+ Futhark.CLI.Test
+ Futhark.CodeGen.Backends.CCUDA
+ Futhark.CodeGen.Backends.CCUDA.Boilerplate
Futhark.CodeGen.Backends.COpenCL
Futhark.CodeGen.Backends.COpenCL.Boilerplate
Futhark.CodeGen.Backends.CSOpenCL
@@ -94,7 +171,10 @@ library
Futhark.CodeGen.ImpCode.OpenCL
Futhark.CodeGen.ImpCode.Sequential
Futhark.CodeGen.ImpGen
+ Futhark.CodeGen.ImpGen.CUDA
Futhark.CodeGen.ImpGen.Kernels
+ Futhark.CodeGen.ImpGen.Kernels.Base
+ Futhark.CodeGen.ImpGen.Kernels.SegRed
Futhark.CodeGen.ImpGen.Kernels.ToOpenCL
Futhark.CodeGen.ImpGen.Kernels.Transpose
Futhark.CodeGen.ImpGen.OpenCL
@@ -177,6 +257,7 @@ library
Futhark.Pass.ExtractKernels.ISRWIM
Futhark.Pass.ExtractKernels.Kernelise
Futhark.Pass.ExtractKernels.Segmented
+ Futhark.Pass.ExtractKernels.Split
Futhark.Pass.FirstOrderTransform
Futhark.Pass.KernelBabysitting
Futhark.Pass.ResolveAssertions
@@ -254,65 +335,10 @@ library
Language.Futhark.Parser.Parser
Language.Futhark.Parser.Lexer
Paths_futhark
- hs-source-dirs:
- src
- ghc-options: -Wall -Wcompat -Wredundant-constraints -Wincomplete-record-updates -Wmissing-export-lists
- build-depends:
- ansi-terminal >=0.6.3.1
- , array >=0.4
- , base >=4 && <5
- , bifunctors >=5.4.2
- , binary >=0.8.3
- , blaze-html >=0.9.0.1
- , bytestring >=0.10.8
- , containers >=0.5
- , data-binary-ieee754 >=0.1
- , directory >=1.3.0.0
- , directory-tree >=0.12.1
- , dlist >=0.6.0.1
- , extra >=1.5.3
- , file-embed >=0.0.9
- , filepath >=1.4.1.1
- , free >=4.12.4
- , gitrev >=1.2.0
- , http-client >=0.5.7.0
- , http-client-tls >=0.3.5.1
- , http-conduit >=2.2.4
- , language-c-quote >=0.12
- , mainland-pretty >=0.6.1
- , markdown >=0.1.16
- , megaparsec >=7.0.1
- , mtl >=2.2.1
- , neat-interpolation >=0.3
- , parallel >=3.2.1.0
- , parser-combinators >=1.0.0
- , process >=1.4.3.0
- , process-extras >=0.7.2
- , raw-strings-qq >=1.1
- , regex-tdfa >=1.2
- , srcloc >=0.4
- , template-haskell >=2.11.1
- , text >=1.2.2.2
- , th-lift-instances >=0.1.11
- , time >=1.6.0.1
- , transformers >=0.3
- , vector >=0.12
- , vector-binary-instances >=0.2.2.0
- , versions >=3.3.1
- , zip-archive >=0.3.1.1
- , zlib >=0.6.1.2
- build-tools:
- alex
- , happy
- if !impl(ghc >= 8.0)
- build-depends:
- semigroups ==0.18.*
default-language: Haskell2010
executable futhark
main-is: src/futhark.hs
- other-modules:
- Paths_futhark
ghc-options: -Wall -Wcompat -Wredundant-constraints -Wincomplete-record-updates -Wmissing-export-lists -threaded -rtsopts "-with-rtsopts=-N -qg"
build-depends:
aeson
@@ -334,6 +360,7 @@ executable futhark
, free >=4.12.4
, futhark
, gitrev >=1.2.0
+ , haskeline
, http-client >=0.5.7.0
, http-client-tls >=0.3.5.1
, http-conduit >=2.2.4
@@ -365,13 +392,15 @@ executable futhark
if !impl(ghc >= 8.0)
build-depends:
semigroups ==0.18.*
+ other-modules:
+ Paths_futhark
default-language: Haskell2010
executable futhark-bench
- main-is: src/futhark-bench.hs
+ main-is: src/wrapper.hs
other-modules:
Paths_futhark
- ghc-options: -Wall -Wcompat -Wredundant-constraints -Wincomplete-record-updates -Wmissing-export-lists -threaded -rtsopts "-with-rtsopts=-N -qg"
+ ghc-options: -Wall -Wcompat -Wredundant-constraints -Wincomplete-record-updates -Wmissing-export-lists
build-depends:
aeson
, ansi-terminal >=0.6.3.1
@@ -390,8 +419,8 @@ executable futhark-bench
, file-embed >=0.0.9
, filepath >=1.4.1.1
, free >=4.12.4
- , futhark
, gitrev >=1.2.0
+ , haskeline
, http-client >=0.5.7.0
, http-client-tls >=0.3.5.1
, http-conduit >=2.2.4
@@ -426,10 +455,10 @@ executable futhark-bench
default-language: Haskell2010
executable futhark-c
- main-is: src/futhark-c.hs
+ main-is: src/wrapper.hs
other-modules:
Paths_futhark
- ghc-options: -Wall -Wcompat -Wredundant-constraints -Wincomplete-record-updates -Wmissing-export-lists -threaded -rtsopts "-with-rtsopts=-N -qg"
+ ghc-options: -Wall -Wcompat -Wredundant-constraints -Wincomplete-record-updates -Wmissing-export-lists
build-depends:
aeson
, ansi-terminal >=0.6.3.1
@@ -448,8 +477,8 @@ executable futhark-c
, file-embed >=0.0.9
, filepath >=1.4.1.1
, free >=4.12.4
- , futhark
, gitrev >=1.2.0
+ , haskeline
, http-client >=0.5.7.0
, http-client-tls >=0.3.5.1
, http-conduit >=2.2.4
@@ -484,10 +513,10 @@ executable futhark-c
default-language: Haskell2010
executable futhark-cs
- main-is: src/futhark-cs.hs
+ main-is: src/wrapper.hs
other-modules:
Paths_futhark
- ghc-options: -Wall -Wcompat -Wredundant-constraints -Wincomplete-record-updates -Wmissing-export-lists -threaded -rtsopts "-with-rtsopts=-N -qg"
+ ghc-options: -Wall -Wcompat -Wredundant-constraints -Wincomplete-record-updates -Wmissing-export-lists
build-depends:
aeson
, ansi-terminal >=0.6.3.1
@@ -506,8 +535,8 @@ executable futhark-cs
, file-embed >=0.0.9
, filepath >=1.4.1.1
, free >=4.12.4
- , futhark
, gitrev >=1.2.0
+ , haskeline
, http-client >=0.5.7.0
, http-client-tls >=0.3.5.1
, http-conduit >=2.2.4
@@ -542,10 +571,10 @@ executable futhark-cs
default-language: Haskell2010
executable futhark-csopencl
- main-is: src/futhark-csopencl.hs
+ main-is: src/wrapper.hs
other-modules:
Paths_futhark
- ghc-options: -Wall -Wcompat -Wredundant-constraints -Wincomplete-record-updates -Wmissing-export-lists -threaded -rtsopts "-with-rtsopts=-N -qg"
+ ghc-options: -Wall -Wcompat -Wredundant-constraints -Wincomplete-record-updates -Wmissing-export-lists
build-depends:
aeson
, ansi-terminal >=0.6.3.1
@@ -564,8 +593,8 @@ executable futhark-csopencl
, file-embed >=0.0.9
, filepath >=1.4.1.1
, free >=4.12.4
- , futhark
, gitrev >=1.2.0
+ , haskeline
, http-client >=0.5.7.0
, http-client-tls >=0.3.5.1
, http-conduit >=2.2.4
@@ -600,10 +629,10 @@ executable futhark-csopencl
default-language: Haskell2010
executable futhark-dataset
- main-is: src/futhark-dataset.hs
+ main-is: src/wrapper.hs
other-modules:
Paths_futhark
- ghc-options: -Wall -Wcompat -Wredundant-constraints -Wincomplete-record-updates -Wmissing-export-lists -threaded -rtsopts "-with-rtsopts=-N -qg"
+ ghc-options: -Wall -Wcompat -Wredundant-constraints -Wincomplete-record-updates -Wmissing-export-lists
build-depends:
aeson
, ansi-terminal >=0.6.3.1
@@ -622,8 +651,8 @@ executable futhark-dataset
, file-embed >=0.0.9
, filepath >=1.4.1.1
, free >=4.12.4
- , futhark
, gitrev >=1.2.0
+ , haskeline
, http-client >=0.5.7.0
, http-client-tls >=0.3.5.1
, http-conduit >=2.2.4
@@ -658,10 +687,10 @@ executable futhark-dataset
default-language: Haskell2010
executable futhark-doc
- main-is: src/futhark-doc.hs
+ main-is: src/wrapper.hs
other-modules:
Paths_futhark
- ghc-options: -Wall -Wcompat -Wredundant-constraints -Wincomplete-record-updates -Wmissing-export-lists -threaded -rtsopts "-with-rtsopts=-N -qg"
+ ghc-options: -Wall -Wcompat -Wredundant-constraints -Wincomplete-record-updates -Wmissing-export-lists
build-depends:
aeson
, ansi-terminal >=0.6.3.1
@@ -680,8 +709,8 @@ executable futhark-doc
, file-embed >=0.0.9
, filepath >=1.4.1.1
, free >=4.12.4
- , futhark
, gitrev >=1.2.0
+ , haskeline
, http-client >=0.5.7.0
, http-client-tls >=0.3.5.1
, http-conduit >=2.2.4
@@ -716,10 +745,10 @@ executable futhark-doc
default-language: Haskell2010
executable futhark-opencl
- main-is: src/futhark-opencl.hs
+ main-is: src/wrapper.hs
other-modules:
Paths_futhark
- ghc-options: -Wall -Wcompat -Wredundant-constraints -Wincomplete-record-updates -Wmissing-export-lists -threaded -rtsopts "-with-rtsopts=-N -qg"
+ ghc-options: -Wall -Wcompat -Wredundant-constraints -Wincomplete-record-updates -Wmissing-export-lists
build-depends:
aeson
, ansi-terminal >=0.6.3.1
@@ -738,8 +767,8 @@ executable futhark-opencl
, file-embed >=0.0.9
, filepath >=1.4.1.1
, free >=4.12.4
- , futhark
, gitrev >=1.2.0
+ , haskeline
, http-client >=0.5.7.0
, http-client-tls >=0.3.5.1
, http-conduit >=2.2.4
@@ -774,10 +803,10 @@ executable futhark-opencl
default-language: Haskell2010
executable futhark-pkg
- main-is: src/futhark-pkg.hs
+ main-is: src/wrapper.hs
other-modules:
Paths_futhark
- ghc-options: -Wall -Wcompat -Wredundant-constraints -Wincomplete-record-updates -Wmissing-export-lists -threaded -rtsopts "-with-rtsopts=-N -qg"
+ ghc-options: -Wall -Wcompat -Wredundant-constraints -Wincomplete-record-updates -Wmissing-export-lists
build-depends:
aeson
, ansi-terminal >=0.6.3.1
@@ -796,8 +825,8 @@ executable futhark-pkg
, file-embed >=0.0.9
, filepath >=1.4.1.1
, free >=4.12.4
- , futhark
, gitrev >=1.2.0
+ , haskeline
, http-client >=0.5.7.0
, http-client-tls >=0.3.5.1
, http-conduit >=2.2.4
@@ -832,10 +861,10 @@ executable futhark-pkg
default-language: Haskell2010
executable futhark-py
- main-is: src/futhark-py.hs
+ main-is: src/wrapper.hs
other-modules:
Paths_futhark
- ghc-options: -Wall -Wcompat -Wredundant-constraints -Wincomplete-record-updates -Wmissing-export-lists -threaded -rtsopts "-with-rtsopts=-N -qg"
+ ghc-options: -Wall -Wcompat -Wredundant-constraints -Wincomplete-record-updates -Wmissing-export-lists
build-depends:
aeson
, ansi-terminal >=0.6.3.1
@@ -854,8 +883,8 @@ executable futhark-py
, file-embed >=0.0.9
, filepath >=1.4.1.1
, free >=4.12.4
- , futhark
, gitrev >=1.2.0
+ , haskeline
, http-client >=0.5.7.0
, http-client-tls >=0.3.5.1
, http-conduit >=2.2.4
@@ -890,10 +919,10 @@ executable futhark-py
default-language: Haskell2010
executable futhark-pyopencl
- main-is: src/futhark-pyopencl.hs
+ main-is: src/wrapper.hs
other-modules:
Paths_futhark
- ghc-options: -Wall -Wcompat -Wredundant-constraints -Wincomplete-record-updates -Wmissing-export-lists -threaded -rtsopts "-with-rtsopts=-N -qg"
+ ghc-options: -Wall -Wcompat -Wredundant-constraints -Wincomplete-record-updates -Wmissing-export-lists
build-depends:
aeson
, ansi-terminal >=0.6.3.1
@@ -912,8 +941,8 @@ executable futhark-pyopencl
, file-embed >=0.0.9
, filepath >=1.4.1.1
, free >=4.12.4
- , futhark
, gitrev >=1.2.0
+ , haskeline
, http-client >=0.5.7.0
, http-client-tls >=0.3.5.1
, http-conduit >=2.2.4
@@ -948,10 +977,10 @@ executable futhark-pyopencl
default-language: Haskell2010
executable futhark-test
- main-is: src/futhark-test.hs
+ main-is: src/wrapper.hs
other-modules:
Paths_futhark
- ghc-options: -Wall -Wcompat -Wredundant-constraints -Wincomplete-record-updates -Wmissing-export-lists -threaded -rtsopts "-with-rtsopts=-N -qg"
+ ghc-options: -Wall -Wcompat -Wredundant-constraints -Wincomplete-record-updates -Wmissing-export-lists
build-depends:
aeson
, ansi-terminal >=0.6.3.1
@@ -970,8 +999,8 @@ executable futhark-test
, file-embed >=0.0.9
, filepath >=1.4.1.1
, free >=4.12.4
- , futhark
, gitrev >=1.2.0
+ , haskeline
, http-client >=0.5.7.0
, http-client-tls >=0.3.5.1
, http-conduit >=2.2.4
@@ -1009,9 +1038,10 @@ executable futharki
main-is: src/futharki.hs
other-modules:
Paths_futhark
- ghc-options: -Wall -Wcompat -Wredundant-constraints -Wincomplete-record-updates -Wmissing-export-lists -threaded -rtsopts "-with-rtsopts=-N -qg"
+ ghc-options: -Wall -Wcompat -Wredundant-constraints -Wincomplete-record-updates -Wmissing-export-lists
build-depends:
- ansi-terminal >=0.6.3.1
+ aeson
+ , ansi-terminal >=0.6.3.1
, array >=0.4
, base >=4 && <5
, bifunctors >=5.4.2
@@ -1027,7 +1057,6 @@ executable futharki
, file-embed >=0.0.9
, filepath >=1.4.1.1
, free >=4.12.4
- , futhark
, gitrev >=1.2.0
, haskeline
, http-client >=0.5.7.0
@@ -1043,10 +1072,12 @@ executable futharki
, parser-combinators >=1.0.0
, process >=1.4.3.0
, process-extras >=0.7.2
+ , random
, raw-strings-qq >=1.1
, regex-tdfa >=1.2
, srcloc >=0.4
, template-haskell >=2.11.1
+ , temporary
, text >=1.2.2.2
, th-lift-instances >=0.1.11
, time >=1.6.0.1
@@ -1064,25 +1095,13 @@ executable futharki
test-suite unit
type: exitcode-stdio-1.0
main-is: futhark_tests.hs
- other-modules:
- Futhark.Analysis.ScalExpTests
- Futhark.Optimise.AlgSimplifyTests
- Futhark.Pkg.SolveTests
- Futhark.Representation.AST.Attributes.RearrangeTests
- Futhark.Representation.AST.Attributes.ReshapeTests
- Futhark.Representation.AST.AttributesTests
- Futhark.Representation.AST.Syntax.CoreTests
- Futhark.Representation.AST.SyntaxTests
- Futhark.Representation.PrimitiveTests
- Language.Futhark.CoreTests
- Language.Futhark.SyntaxTests
- Paths_futhark
hs-source-dirs:
unittests
ghc-options: -Wall -Wcompat -Wredundant-constraints -Wincomplete-record-updates -Wmissing-export-lists
build-depends:
HUnit
, QuickCheck >=2.8
+ , aeson
, ansi-terminal >=0.6.3.1
, array >=0.4
, base >=4 && <5
@@ -1101,6 +1120,7 @@ test-suite unit
, free >=4.12.4
, futhark
, gitrev >=1.2.0
+ , haskeline
, http-client >=0.5.7.0
, http-client-tls >=0.3.5.1
, http-conduit >=2.2.4
@@ -1114,6 +1134,7 @@ test-suite unit
, parser-combinators >=1.0.0
, process >=1.4.3.0
, process-extras >=0.7.2
+ , random
, raw-strings-qq >=1.1
, regex-tdfa >=1.2
, srcloc >=0.4
@@ -1121,6 +1142,7 @@ test-suite unit
, tasty-hunit
, tasty-quickcheck
, template-haskell >=2.11.1
+ , temporary
, text >=1.2.2.2
, th-lift-instances >=0.1.11
, time >=1.6.0.1
@@ -1133,4 +1155,17 @@ test-suite unit
if !impl(ghc >= 8.0)
build-depends:
semigroups ==0.18.*
+ other-modules:
+ Futhark.Analysis.ScalExpTests
+ Futhark.Optimise.AlgSimplifyTests
+ Futhark.Pkg.SolveTests
+ Futhark.Representation.AST.Attributes.RearrangeTests
+ Futhark.Representation.AST.Attributes.ReshapeTests
+ Futhark.Representation.AST.AttributesTests
+ Futhark.Representation.AST.Syntax.CoreTests
+ Futhark.Representation.AST.SyntaxTests
+ Futhark.Representation.PrimitiveTests
+ Language.Futhark.CoreTests
+ Language.Futhark.SyntaxTests
+ Paths_futhark
default-language: Haskell2010
diff --git a/rts/c/cuda.h b/rts/c/cuda.h
new file mode 100644
index 0000000..c7c2f24
--- /dev/null
+++ b/rts/c/cuda.h
@@ -0,0 +1,535 @@
+/* Simple CUDA runtime framework */
+
+#define CUDA_SUCCEED(x) cuda_api_succeed(x, #x, __FILE__, __LINE__)
+#define NVRTC_SUCCEED(x) nvrtc_api_succeed(x, #x, __FILE__, __LINE__)
+
+static inline void cuda_api_succeed(CUresult res, const char *call,
+ const char *file, int line)
+{
+ if (res != CUDA_SUCCESS) {
+ const char *err_str;
+ cuGetErrorString(res, &err_str);
+ if (err_str == NULL) { err_str = "Unknown"; }
+ panic(-1, "%s:%d: CUDA call\n %s\nfailed with error code %d (%s)\n",
+ file, line, call, res, err_str);
+ }
+}
+
+static inline void nvrtc_api_succeed(nvrtcResult res, const char *call,
+ const char *file, int line)
+{
+ if (res != NVRTC_SUCCESS) {
+ const char *err_str = nvrtcGetErrorString(res);
+ panic(-1, "%s:%d: NVRTC call\n %s\nfailed with error code %d (%s)\n",
+ file, line, call, res, err_str);
+ }
+}
+
+struct cuda_config {
+ int debugging;
+ int logging;
+ const char *preferred_device;
+
+ const char *dump_program_to;
+ const char *load_program_from;
+
+ const char *dump_ptx_to;
+ const char *load_ptx_from;
+
+ size_t default_block_size;
+ size_t default_grid_size;
+ size_t default_tile_size;
+ size_t default_threshold;
+
+ int default_block_size_changed;
+ int default_grid_size_changed;
+ int default_tile_size_changed;
+
+ int num_sizes;
+ const char **size_names;
+ const char **size_vars;
+ size_t *size_values;
+ const char **size_classes;
+};
+
+void cuda_config_init(struct cuda_config *cfg,
+ int num_sizes,
+ const char *size_names[],
+ const char *size_vars[],
+ size_t *size_values,
+ const char *size_classes[])
+{
+ cfg->debugging = 0;
+ cfg->logging = 0;
+ cfg->preferred_device = "";
+
+ cfg->dump_program_to = NULL;
+ cfg->load_program_from = NULL;
+
+ cfg->dump_ptx_to = NULL;
+ cfg->load_ptx_from = NULL;
+
+ cfg->default_block_size = 256;
+ cfg->default_grid_size = 128;
+ cfg->default_tile_size = 32;
+ cfg->default_threshold = 32*1024;
+
+ cfg->default_block_size_changed = 0;
+ cfg->default_grid_size_changed = 0;
+ cfg->default_tile_size_changed = 0;
+
+ cfg->num_sizes = num_sizes;
+ cfg->size_names = size_names;
+ cfg->size_vars = size_vars;
+ cfg->size_values = size_values;
+ cfg->size_classes = size_classes;
+}
+
+struct cuda_context {
+ CUdevice dev;
+ CUcontext cu_ctx;
+ CUmodule module;
+
+ struct cuda_config cfg;
+
+ struct free_list free_list;
+
+ size_t max_block_size;
+ size_t max_grid_size;
+ size_t max_tile_size;
+ size_t max_threshold;
+
+ size_t lockstep_width;
+};
+
+#define CU_DEV_ATTR(x) (CU_DEVICE_ATTRIBUTE_##x)
+#define device_query(dev,attrib) _device_query(dev, CU_DEV_ATTR(attrib))
+static int _device_query(CUdevice dev, CUdevice_attribute attrib)
+{
+ int val;
+ CUDA_SUCCEED(cuDeviceGetAttribute(&val, attrib, dev));
+ return val;
+}
+
+#define CU_FUN_ATTR(x) (CU_FUNC_ATTRIBUTE_##x)
+#define function_query(fn,attrib) _function_query(dev, CU_FUN_ATTR(attrib))
+static int _function_query(CUfunction dev, CUfunction_attribute attrib)
+{
+ int val;
+ CUDA_SUCCEED(cuFuncGetAttribute(&val, attrib, dev));
+ return val;
+}
+
+void set_preferred_device(struct cuda_config *cfg, const char *s)
+{
+ cfg->preferred_device = s;
+}
+
+static int cuda_device_setup(struct cuda_context *ctx)
+{
+ char name[256];
+ int count, chosen = -1, best_cc = -1;
+ int cc_major_best, cc_minor_best;
+ int cc_major, cc_minor;
+ CUdevice dev;
+
+ CUDA_SUCCEED(cuDeviceGetCount(&count));
+ if (count == 0) { return 1; }
+
+ // XXX: Current device selection policy is to choose the device with the
+ // highest compute capability (if no preferred device is set).
+ // This should maybe be changed, since greater compute capability is not
+ // necessarily an indicator of better performance.
+ for (int i = 0; i < count; i++) {
+ CUDA_SUCCEED(cuDeviceGet(&dev, i));
+
+ cc_major = device_query(dev, COMPUTE_CAPABILITY_MAJOR);
+ cc_minor = device_query(dev, COMPUTE_CAPABILITY_MINOR);
+
+ CUDA_SUCCEED(cuDeviceGetName(name, sizeof(name)/sizeof(name[0]) - 1, dev));
+ name[sizeof(name)/sizeof(name[0])] = 0;
+
+ if (ctx->cfg.debugging) {
+ fprintf(stderr, "Device #%d: name=\"%s\", compute capability=%d.%d\n",
+ i, name, cc_major, cc_minor);
+ }
+
+ if (device_query(dev, COMPUTE_MODE) == CU_COMPUTEMODE_PROHIBITED) {
+ if (ctx->cfg.debugging) {
+ fprintf(stderr, "Device #%d is compute-prohibited, ignoring\n", i);
+ }
+ continue;
+ }
+
+ if (best_cc == -1 || cc_major > cc_major_best ||
+ (cc_major == cc_major_best && cc_minor > cc_minor_best)) {
+ best_cc = i;
+ cc_major_best = cc_major;
+ cc_minor_best = cc_minor;
+ }
+
+ if (chosen == -1 && strstr(name, ctx->cfg.preferred_device) == name) {
+ chosen = i;
+ }
+ }
+
+ if (chosen == -1) { chosen = best_cc; }
+ if (chosen == -1) { return 1; }
+
+ if (ctx->cfg.debugging) {
+ fprintf(stderr, "Using device #%d\n", chosen);
+ }
+
+ CUDA_SUCCEED(cuDeviceGet(&ctx->dev, chosen));
+ return 0;
+}
+
+static char *concat_fragments(const char *src_fragments[])
+{
+ size_t src_len = 0;
+ const char **p;
+
+ for (p = src_fragments; *p; p++) {
+ src_len += strlen(*p);
+ }
+
+ char *src = malloc(src_len + 1);
+ size_t n = 0;
+ for (p = src_fragments; *p; p++) {
+ strcpy(src + n, *p);
+ n += strlen(*p);
+ }
+
+ return src;
+}
+
+static const char *cuda_nvrtc_get_arch(CUdevice dev)
+{
+ struct {
+ int major;
+ int minor;
+ const char *arch_str;
+ } static const x[] = {
+ { 3, 0, "compute_30" },
+ { 3, 2, "compute_32" },
+ { 3, 5, "compute_35" },
+ { 3, 7, "compute_37" },
+ { 5, 0, "compute_50" },
+ { 5, 2, "compute_52" },
+ { 5, 3, "compute_53" },
+ { 6, 0, "compute_60" },
+ { 6, 1, "compute_61" },
+ { 6, 2, "compute_62" },
+ { 7, 0, "compute_70" },
+ { 7, 2, "compute_72" }
+ };
+
+ int major = device_query(dev, COMPUTE_CAPABILITY_MAJOR);
+ int minor = device_query(dev, COMPUTE_CAPABILITY_MINOR);
+
+ int chosen = -1;
+ for (int i = 0; i < sizeof(x)/sizeof(x[0]); i++) {
+ if (x[i].major < major || (x[i].major == major && x[i].minor <= minor)) {
+ chosen = i;
+ } else {
+ break;
+ }
+ }
+
+ if (chosen == -1) {
+ panic(-1, "Unsupported compute capability %d.%d\n", major, minor);
+ }
+ return x[chosen].arch_str;
+}
+
+static char *cuda_nvrtc_build(struct cuda_context *ctx, const char *src)
+{
+ nvrtcProgram prog;
+ NVRTC_SUCCEED(nvrtcCreateProgram(&prog, src, "futhark-cuda", 0, NULL, NULL));
+
+ size_t n_opts, i = 0, i_dyn, n_opts_alloc = 20 + ctx->cfg.num_sizes;
+ const char **opts = malloc(n_opts_alloc * sizeof(const char *));
+ opts[i++] = "-arch";
+ opts[i++] = cuda_nvrtc_get_arch(ctx->dev);
+ opts[i++] = "-default-device";
+ if (ctx->cfg.debugging) {
+ opts[i++] = "-G";
+ opts[i++] = "-lineinfo";
+ } else {
+ opts[i++] = "--disable-warnings";
+ }
+ i_dyn = i;
+ for (size_t j = 0; j < ctx->cfg.num_sizes; j++) {
+ opts[i++] = msgprintf("-D%s=%zu", ctx->cfg.size_vars[j],
+ ctx->cfg.size_values[j]);
+ }
+ opts[i++] = msgprintf("-DLOCKSTEP_WIDTH=%zu", ctx->lockstep_width);
+ opts[i++] = msgprintf("-DMAX_THREADS_PER_BLOCK=%zu", ctx->max_block_size);
+ n_opts = i;
+
+ if (ctx->cfg.debugging) {
+ fprintf(stderr, "NVRTC compile options:\n");
+ for (size_t j = 0; j < n_opts; j++) {
+ fprintf(stderr, "\t%s\n", opts[j]);
+ }
+ fprintf(stderr, "\n");
+ }
+
+ nvrtcResult res = nvrtcCompileProgram(prog, n_opts, opts);
+ if (res != NVRTC_SUCCESS) {
+ size_t log_size;
+ if (nvrtcGetProgramLogSize(prog, &log_size) == NVRTC_SUCCESS) {
+ char *log = malloc(log_size);
+ if (nvrtcGetProgramLog(prog, log) == NVRTC_SUCCESS) {
+ fprintf(stderr,"Compilation log:\n%s\n", log);
+ }
+ free(log);
+ }
+ NVRTC_SUCCEED(res);
+ }
+
+ for (i = i_dyn; i < n_opts; i++) { free((char *)opts[i]); }
+ free(opts);
+
+ char *ptx;
+ size_t ptx_size;
+ NVRTC_SUCCEED(nvrtcGetPTXSize(prog, &ptx_size));
+ ptx = malloc(ptx_size);
+ NVRTC_SUCCEED(nvrtcGetPTX(prog, ptx));
+
+ NVRTC_SUCCEED(nvrtcDestroyProgram(&prog));
+
+ return ptx;
+}
+
+static void cuda_size_setup(struct cuda_context *ctx)
+{
+ if (ctx->cfg.default_block_size > ctx->max_block_size) {
+ if (ctx->cfg.default_block_size_changed) {
+ fprintf(stderr,
+ "Note: Device limits default block size to %zu (down from %zu).\n",
+ ctx->max_block_size, ctx->cfg.default_block_size);
+ }
+ ctx->cfg.default_block_size = ctx->max_block_size;
+ }
+ if (ctx->cfg.default_grid_size > ctx->max_grid_size) {
+ if (ctx->cfg.default_grid_size_changed) {
+ fprintf(stderr,
+ "Note: Device limits default grid size to %zu (down from %zu).\n",
+ ctx->max_grid_size, ctx->cfg.default_grid_size);
+ }
+ ctx->cfg.default_grid_size = ctx->max_grid_size;
+ }
+ if (ctx->cfg.default_tile_size > ctx->max_tile_size) {
+ if (ctx->cfg.default_tile_size_changed) {
+ fprintf(stderr,
+ "Note: Device limits default tile size to %zu (down from %zu).\n",
+ ctx->max_tile_size, ctx->cfg.default_tile_size);
+ }
+ ctx->cfg.default_tile_size = ctx->max_tile_size;
+ }
+
+ for (int i = 0; i < ctx->cfg.num_sizes; i++) {
+ const char *size_class, *size_name;
+ size_t *size_value, max_value, default_value;
+
+ size_class = ctx->cfg.size_classes[i];
+ size_value = &ctx->cfg.size_values[i];
+ size_name = ctx->cfg.size_names[i];
+
+ if (strstr(size_class, "group_size") == size_class) {
+ max_value = ctx->max_block_size;
+ default_value = ctx->cfg.default_block_size;
+ } else if (strstr(size_class, "num_groups") == size_class) {
+ max_value = ctx->max_grid_size;
+ default_value = ctx->cfg.default_grid_size;
+ } else if (strstr(size_class, "tile_size") == size_class) {
+ max_value = ctx->max_tile_size;
+ default_value = ctx->cfg.default_tile_size;
+ } else if (strstr(size_class, "threshold") == size_class) {
+ max_value = ctx->max_threshold;
+ default_value = ctx->cfg.default_threshold;
+ } else {
+ panic(1, "Unknown size class for size '%s': %s\n", size_name, size_class);
+ }
+
+ if (*size_value == 0) {
+ *size_value = default_value;
+ } else if (max_value > 0 && *size_value > max_value) {
+ fprintf(stderr, "Note: Device limits %zu to %zu (down from %zu)\n",
+ size_name, max_value, *size_value);
+ *size_value = max_value;
+ }
+ }
+}
+
+static void dump_string_to_file(const char *file, const char *buf)
+{
+ FILE *f = fopen(file, "w");
+ assert(f != NULL);
+ assert(fputs(buf, f) != EOF);
+ assert(fclose(f) == 0);
+}
+
+static void load_string_from_file(const char *file, char **obuf, size_t *olen)
+{
+ char *buf;
+ size_t len;
+ FILE *f = fopen(file, "r");
+
+ assert(f != NULL);
+ assert(fseek(f, 0, SEEK_END) == 0);
+ len = ftell(f);
+ assert(fseek(f, 0, SEEK_SET) == 0);
+
+ buf = malloc(len + 1);
+ assert(fread(buf, 1, len, f) == len);
+ buf[len] = 0;
+ *obuf = buf;
+ if (olen != NULL) {
+ *olen = len;
+ }
+
+ assert(fclose(f) == 0);
+}
+
+static void cuda_module_setup(struct cuda_context *ctx,
+ const char *src_fragments[])
+{
+ char *ptx = NULL, *src = NULL;
+
+ if (ctx->cfg.load_ptx_from == NULL && ctx->cfg.load_program_from == NULL) {
+ src = concat_fragments(src_fragments);
+ ptx = cuda_nvrtc_build(ctx, src);
+ } else if (ctx->cfg.load_ptx_from == NULL) {
+ load_string_from_file(ctx->cfg.load_program_from, &src, NULL);
+ ptx = cuda_nvrtc_build(ctx, src);
+ } else {
+ if (ctx->cfg.load_program_from != NULL) {
+ fprintf(stderr,
+ "WARNING: Loading PTX from %s instead of C code from %s\n",
+ ctx->cfg.load_ptx_from, ctx->cfg.load_program_from);
+ }
+
+ load_string_from_file(ctx->cfg.load_ptx_from, &ptx, NULL);
+ }
+
+ if (ctx->cfg.dump_program_to != NULL) {
+ if (src == NULL) {
+ src = concat_fragments(src_fragments);
+ }
+ dump_string_to_file(ctx->cfg.dump_program_to, src);
+ }
+ if (ctx->cfg.dump_ptx_to != NULL) {
+ dump_string_to_file(ctx->cfg.dump_ptx_to, ptx);
+ }
+
+ CUDA_SUCCEED(cuModuleLoadData(&ctx->module, ptx));
+
+ free(ptx);
+ if (src != NULL) {
+ free(src);
+ }
+}
+
+void cuda_setup(struct cuda_context *ctx, const char *src_fragments[])
+{
+ CUDA_SUCCEED(cuInit(0));
+
+ if (cuda_device_setup(ctx) != 0) {
+ panic(-1, "No suitable CUDA device found.\n");
+ }
+ CUDA_SUCCEED(cuCtxCreate(&ctx->cu_ctx, 0, ctx->dev));
+
+ free_list_init(&ctx->free_list);
+
+ ctx->max_block_size = device_query(ctx->dev, MAX_THREADS_PER_BLOCK);
+ ctx->max_grid_size = device_query(ctx->dev, MAX_GRID_DIM_X);
+ ctx->max_tile_size = sqrt(ctx->max_block_size);
+ ctx->max_threshold = 0;
+ ctx->lockstep_width = device_query(ctx->dev, WARP_SIZE);
+
+ cuda_size_setup(ctx);
+ cuda_module_setup(ctx, src_fragments);
+}
+
+CUresult cuda_free_all(struct cuda_context *ctx);
+
+void cuda_cleanup(struct cuda_context *ctx)
+{
+ CUDA_SUCCEED(cuda_free_all(ctx));
+ CUDA_SUCCEED(cuModuleUnload(ctx->module));
+ CUDA_SUCCEED(cuCtxDestroy(ctx->cu_ctx));
+}
+
+CUresult cuda_alloc(struct cuda_context *ctx, size_t min_size,
+ const char *tag, CUdeviceptr *mem_out)
+{
+ if (min_size < sizeof(int)) {
+ min_size = sizeof(int);
+ }
+
+ size_t size;
+ if (free_list_find(&ctx->free_list, tag, &size, mem_out) == 0) {
+ if (size >= min_size) {
+ return CUDA_SUCCESS;
+ } else {
+ CUresult res = cuMemFree(*mem_out);
+ if (res != CUDA_SUCCESS) {
+ return res;
+ }
+ }
+ }
+
+ CUresult res = cuMemAlloc(mem_out, min_size);
+ while (res == CUDA_ERROR_OUT_OF_MEMORY) {
+ CUdeviceptr mem;
+ if (free_list_first(&ctx->free_list, &mem) == 0) {
+ res = cuMemFree(mem);
+ if (res != CUDA_SUCCESS) {
+ return res;
+ }
+ } else {
+ break;
+ }
+ res = cuMemAlloc(mem_out, min_size);
+ }
+
+ return res;
+}
+
+CUresult cuda_free(struct cuda_context *ctx, CUdeviceptr mem,
+ const char *tag)
+{
+ size_t size;
+ CUdeviceptr existing_mem;
+
+ // If there is already a block with this tag, then remove it.
+ if (free_list_find(&ctx->free_list, tag, &size, &existing_mem) == 0) {
+ CUresult res = cuMemFree(existing_mem);
+ if (res != CUDA_SUCCESS) {
+ return res;
+ }
+ }
+
+ CUresult res = cuMemGetAddressRange(NULL, &size, mem);
+ if (res == CUDA_SUCCESS) {
+ free_list_insert(&ctx->free_list, size, mem, tag);
+ }
+
+ return res;
+}
+
+CUresult cuda_free_all(struct cuda_context *ctx) {
+ CUdeviceptr mem;
+ free_list_pack(&ctx->free_list);
+ while (free_list_first(&ctx->free_list, &mem) == 0) {
+ CUresult res = cuMemFree(mem);
+ if (res != CUDA_SUCCESS) {
+ return res;
+ }
+ }
+
+ return CUDA_SUCCESS;
+}
+
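As an aside on the allocator above: cuda_alloc and cuda_free implement tag-based buffer reuse on top of the free list, so repeated allocations from the same program point can recycle device memory instead of round-tripping through the driver. The sketch below is not part of the commit; it assumes a struct cuda_context already initialised with cuda_setup, and the tag string is purely illustrative.

/* Hypothetical usage of the tagged allocator from rts/c/cuda.h. */
static void roundtrip_example(struct cuda_context *ctx) {
  CUdeviceptr xs_mem;
  size_t bytes = 1024 * sizeof(int);

  /* May hand back a previously freed block carrying the same tag. */
  CUDA_SUCCEED(cuda_alloc(ctx, bytes, "xs_mem", &xs_mem));

  /* ... launch kernels that read and write xs_mem ... */

  /* Parks the block in the free list rather than releasing it to the
     driver, so the next cuda_alloc with this tag can reuse it. */
  CUDA_SUCCEED(cuda_free(ctx, xs_mem, "xs_mem"));

  /* At shutdown, cuda_cleanup() calls cuda_free_all() to release
     everything still held by the free list. */
}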
diff --git a/rts/c/free_list.h b/rts/c/free_list.h
new file mode 100644
index 0000000..e59758a
--- /dev/null
+++ b/rts/c/free_list.h
@@ -0,0 +1,110 @@
+/* Free list management */
+
+/* An entry in the free list. May be invalid, to avoid having to
+ deallocate entries as soon as they are removed. There is also a
+ tag, to help with memory reuse. */
+struct free_list_entry {
+ size_t size;
+ fl_mem_t mem;
+ const char *tag;
+ unsigned char valid;
+};
+
+struct free_list {
+ struct free_list_entry *entries; // Pointer to entries.
+ int capacity; // Number of entries.
+ int used; // Number of valid entries.
+};
+
+void free_list_init(struct free_list *l) {
+ l->capacity = 30; // Picked arbitrarily.
+ l->used = 0;
+ l->entries = malloc(sizeof(struct free_list_entry) * l->capacity);
+ for (int i = 0; i < l->capacity; i++) {
+ l->entries[i].valid = 0;
+ }
+}
+
+/* Remove invalid entries from the free list. */
+void free_list_pack(struct free_list *l) {
+ int p = 0;
+ for (int i = 0; i < l->capacity; i++) {
+ if (l->entries[i].valid) {
+ l->entries[p] = l->entries[i];
+ p++;
+ }
+ }
+ // Now p == l->used.
+ l->entries = realloc(l->entries, l->used * sizeof(struct free_list_entry));
+ l->capacity = l->used;
+}
+
+void free_list_destroy(struct free_list *l) {
+ assert(l->used == 0);
+ free(l->entries);
+}
+
+int free_list_find_invalid(struct free_list *l) {
+ int i;
+ for (i = 0; i < l->capacity; i++) {
+ if (!l->entries[i].valid) {
+ break;
+ }
+ }
+ return i;
+}
+
+void free_list_insert(struct free_list *l, size_t size, fl_mem_t mem, const char *tag) {
+ int i = free_list_find_invalid(l);
+
+ if (i == l->capacity) {
+ // List is full; so we have to grow it.
+ int new_capacity = l->capacity * 2 * sizeof(struct free_list_entry);
+ l->entries = realloc(l->entries, new_capacity);
+ for (int j = 0; j < l->capacity; j++) {
+ l->entries[j+l->capacity].valid = 0;
+ }
+ l->capacity *= 2;
+ }
+
+ // Now 'i' points to the first invalid entry.
+ l->entries[i].valid = 1;
+ l->entries[i].size = size;
+ l->entries[i].mem = mem;
+ l->entries[i].tag = tag;
+
+ l->used++;
+}
+
+/* Find and remove a memory block of at least the desired size and
+ tag. Returns 0 on success. */
+int free_list_find(struct free_list *l, const char *tag, size_t *size_out, fl_mem_t *mem_out) {
+ int i;
+ for (i = 0; i < l->capacity; i++) {
+ if (l->entries[i].valid && l->entries[i].tag == tag) {
+ l->entries[i].valid = 0;
+ *size_out = l->entries[i].size;
+ *mem_out = l->entries[i].mem;
+ l->used--;
+ return 0;
+ }
+ }
+
+ return 1;
+}
+
+/* Remove the first block in the free list. Returns 0 if a block was
+ removed, and nonzero if the free list was already empty. */
+int free_list_first(struct free_list *l, fl_mem_t *mem_out) {
+ for (int i = 0; i < l->capacity; i++) {
+ if (l->entries[i].valid) {
+ l->entries[i].valid = 0;
+ *mem_out = l->entries[i].mem;
+ l->used--;
+ return 0;
+ }
+ }
+
+ return 1;
+}
+
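The header above factors the free list out of the OpenCL runtime so it can be shared with the CUDA backend; it is generic in fl_mem_t, which the including runtime is expected to define (for example as cl_mem or CUdeviceptr). A hedged sketch of the intended call pattern, using an illustrative tag name, follows.

/* Sketch only: assumes fl_mem_t was typedef'd before including free_list.h.
   Tags are compared by pointer identity, so static string literals are used. */
static void free_list_example(fl_mem_t some_block) {
  struct free_list fl;
  free_list_init(&fl);

  /* Cache a block under a tag for later reuse. */
  free_list_insert(&fl, 4096, some_block, "xs_mem");

  /* A hit removes the entry and returns its size and memory handle. */
  fl_mem_t mem;
  size_t size;
  if (free_list_find(&fl, "xs_mem", &size, &mem) == 0) {
    /* mem == some_block, size == 4096 */
  }

  free_list_pack(&fl);  /* compact away invalidated slots */

  /* Drain remaining entries (the caller owns freeing them), then destroy. */
  while (free_list_first(&fl, &mem) == 0) {
    /* release mem with whatever API allocated it */
  }
  free_list_destroy(&fl);
}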
diff --git a/rts/c/opencl.h b/rts/c/opencl.h
index 6225593..1142d96 100644
--- a/rts/c/opencl.h
+++ b/rts/c/opencl.h
@@ -1,15 +1,5 @@
/* The simple OpenCL runtime framework used by Futhark. */
-#define CL_USE_DEPRECATED_OPENCL_1_2_APIS
-
-#define CL_SILENCE_DEPRECATION // For macOS.
-
-#ifdef __APPLE__
- #include <OpenCL/cl.h>
-#else
- #include <CL/cl.h>
-#endif
-
#define OPENCL_SUCCEED_FATAL(e) opencl_succeed_fatal(e, #e, __FILE__, __LINE__)
#define OPENCL_SUCCEED_NONFATAL(e) opencl_succeed_nonfatal(e, #e, __FILE__, __LINE__)
// Take care not to override an existing error.
@@ -40,6 +30,8 @@ struct opencl_config {
const char* dump_program_to;
const char* load_program_from;
+ const char* dump_binary_to;
+ const char* load_binary_from;
size_t default_group_size;
size_t default_num_groups;
@@ -51,17 +43,17 @@ struct opencl_config {
int num_sizes;
const char **size_names;
+ const char **size_vars;
size_t *size_values;
const char **size_classes;
- const char **size_entry_points;
};
void opencl_config_init(struct opencl_config *cfg,
int num_sizes,
const char *size_names[],
+ const char *size_vars[],
size_t *size_values,
- const char *size_classes[],
- const char *size_entry_points[]) {
+ const char *size_classes[]) {
cfg->debugging = 0;
cfg->logging = 0;
cfg->preferred_device_num = 0;
@@ -69,10 +61,16 @@ void opencl_config_init(struct opencl_config *cfg,
cfg->preferred_device = "";
cfg->dump_program_to = NULL;
cfg->load_program_from = NULL;
+ cfg->dump_binary_to = NULL;
+ cfg->load_binary_from = NULL;
+
+ // The following are dummy sizes that mean the concrete defaults
+ // will be set during initialisation via hardware-inspection-based
+ // heuristics.
+ cfg->default_group_size = 0;
+ cfg->default_num_groups = 0;
+ cfg->default_tile_size = 0;
- cfg->default_group_size = 256;
- cfg->default_num_groups = 128;
- cfg->default_tile_size = 32;
cfg->default_threshold = 32*1024;
cfg->default_group_size_changed = 0;
@@ -80,117 +78,9 @@ void opencl_config_init(struct opencl_config *cfg,
cfg->num_sizes = num_sizes;
cfg->size_names = size_names;
+ cfg->size_vars = size_vars;
cfg->size_values = size_values;
cfg->size_classes = size_classes;
- cfg->size_entry_points = size_entry_points;
-}
-
-/* An entry in the free list. May be invalid, to avoid having to
- deallocate entries as soon as they are removed. There is also a
- tag, to help with memory reuse. */
-struct opencl_free_list_entry {
- size_t size;
- cl_mem mem;
- const char *tag;
- unsigned char valid;
-};
-
-struct opencl_free_list {
- struct opencl_free_list_entry *entries; // Pointer to entries.
- int capacity; // Number of entries.
- int used; // Number of valid entries.
-};
-
-void free_list_init(struct opencl_free_list *l) {
- l->capacity = 30; // Picked arbitrarily.
- l->used = 0;
- l->entries = malloc(sizeof(struct opencl_free_list_entry) * l->capacity);
- for (int i = 0; i < l->capacity; i++) {
- l->entries[i].valid = 0;
- }
-}
-
-/* Remove invalid entries from the free list. */
-void free_list_pack(struct opencl_free_list *l) {
- int p = 0;
- for (int i = 0; i < l->capacity; i++) {
- if (l->entries[i].valid) {
- l->entries[p] = l->entries[i];
- p++;
- }
- }
- // Now p == l->used.
- l->entries = realloc(l->entries, l->used * sizeof(struct opencl_free_list_entry));
- l->capacity = l->used;
-}
-
-void free_list_destroy(struct opencl_free_list *l) {
- assert(l->used == 0);
- free(l->entries);
-}
-
-int free_list_find_invalid(struct opencl_free_list *l) {
- int i;
- for (i = 0; i < l->capacity; i++) {
- if (!l->entries[i].valid) {
- break;
- }
- }
- return i;
-}
-
-void free_list_insert(struct opencl_free_list *l, size_t size, cl_mem mem, const char *tag) {
- int i = free_list_find_invalid(l);
-
- if (i == l->capacity) {
- // List is full; so we have to grow it.
- int new_capacity = l->capacity * 2 * sizeof(struct opencl_free_list_entry);
- l->entries = realloc(l->entries, new_capacity);
- for (int j = 0; j < l->capacity; j++) {
- l->entries[j+l->capacity].valid = 0;
- }
- l->capacity *= 2;
- }
-
- // Now 'i' points to the first invalid entry.
- l->entries[i].valid = 1;
- l->entries[i].size = size;
- l->entries[i].mem = mem;
- l->entries[i].tag = tag;
-
- l->used++;
-}
-
-/* Find and remove a memory block of at least the desired size and
- tag. Returns 0 on success. */
-int free_list_find(struct opencl_free_list *l, const char *tag, size_t *size_out, cl_mem *mem_out) {
- int i;
- for (i = 0; i < l->capacity; i++) {
- if (l->entries[i].valid && l->entries[i].tag == tag) {
- l->entries[i].valid = 0;
- *size_out = l->entries[i].size;
- *mem_out = l->entries[i].mem;
- l->used--;
- return 0;
- }
- }
-
- return 1;
-}
-
-/* Remove the first block in the free list. Returns 0 if a block was
- removed, and nonzero if the free list was already empty. */
-int free_list_first(struct opencl_free_list *l, cl_mem *mem_out) {
- for (int i = 0; i < l->capacity; i++) {
- if (l->entries[i].valid) {
- l->entries[i].valid = 0;
- *mem_out = l->entries[i].mem;
- l->used--;
- return 0;
- }
- }
-
- return 1;
}
struct opencl_context {
@@ -200,7 +90,7 @@ struct opencl_context {
struct opencl_config cfg;
- struct opencl_free_list free_list;
+ struct free_list free_list;
size_t max_group_size;
size_t max_num_groups;
@@ -235,6 +125,30 @@ static char *strclone(const char *str) {
return copy;
}
+// Read a file into a NUL-terminated string; returns NULL on error.
+static char* slurp_file(const char *filename, size_t *size) {
+ char *s;
+ FILE *f = fopen(filename, "rb"); // To avoid Windows messing with linebreaks.
+ if (f == NULL) return NULL;
+ fseek(f, 0, SEEK_END);
+ size_t src_size = ftell(f);
+ fseek(f, 0, SEEK_SET);
+ s = (char*) malloc(src_size + 1);
+ if (fread(s, 1, src_size, f) != src_size) {
+ free(s);
+ s = NULL;
+ } else {
+ s[src_size] = '\0';
+ }
+ fclose(f);
+
+ if (size) {
+ *size = src_size;
+ }
+
+ return s;
+}
+
static const char* opencl_error_string(unsigned int err)
{
switch (err) {
@@ -575,6 +489,9 @@ static cl_program setup_opencl_with_command_queue(struct opencl_context *ctx,
size_t max_tile_size = sqrt(max_group_size);
+ // Make sure this function is defined.
+ post_opencl_setup(ctx, &device_option);
+
if (max_group_size < ctx->cfg.default_group_size) {
if (ctx->cfg.default_group_size_changed) {
fprintf(stderr, "Note: Device limits default group size to %zu (down from %zu).\n",
@@ -626,9 +543,6 @@ static cl_program setup_opencl_with_command_queue(struct opencl_context *ctx,
}
}
- // Make sure this function is defined.
- post_opencl_setup(ctx, &device_option);
-
if (ctx->lockstep_width == 0) {
ctx->lockstep_width = 1;
}
@@ -644,14 +558,8 @@ static cl_program setup_opencl_with_command_queue(struct opencl_context *ctx,
// Maybe we have to read OpenCL source from somewhere else (used for debugging).
if (ctx->cfg.load_program_from != NULL) {
- FILE *f = fopen(ctx->cfg.load_program_from, "r");
- assert(f != NULL);
- fseek(f, 0, SEEK_END);
- src_size = ftell(f);
- fseek(f, 0, SEEK_SET);
- fut_opencl_src = malloc(src_size);
- assert(fread(fut_opencl_src, 1, src_size, f) == src_size);
- fclose(f);
+ fut_opencl_src = slurp_file(ctx->cfg.load_program_from, NULL);
+ assert(fut_opencl_src != NULL);
} else {
// Build the OpenCL program. First we have to concatenate all the fragments.
for (const char **src = srcs; src && *src; src++) {
@@ -680,29 +588,63 @@ static cl_program setup_opencl_with_command_queue(struct opencl_context *ctx,
fclose(f);
}
- prog = clCreateProgramWithSource(ctx->ctx, 1, src_ptr, &src_size, &error);
- assert(error == 0);
+ if (ctx->cfg.load_binary_from == NULL) {
+ prog = clCreateProgramWithSource(ctx->ctx, 1, src_ptr, &src_size, &error);
+ assert(error == 0);
- int compile_opts_size = 1024;
- for (int i = 0; i < ctx->cfg.num_sizes; i++) {
- compile_opts_size += strlen(ctx->cfg.size_names[i]) + 20;
- }
- char *compile_opts = malloc(compile_opts_size);
+ int compile_opts_size = 1024;
+ for (int i = 0; i < ctx->cfg.num_sizes; i++) {
+ compile_opts_size += strlen(ctx->cfg.size_names[i]) + 20;
+ }
+ char *compile_opts = malloc(compile_opts_size);
- int w = snprintf(compile_opts, compile_opts_size,
- "-DLOCKSTEP_WIDTH=%d ",
- (int)ctx->lockstep_width);
+ int w = snprintf(compile_opts, compile_opts_size,
+ "-DLOCKSTEP_WIDTH=%d ",
+ (int)ctx->lockstep_width);
- for (int i = 0; i < ctx->cfg.num_sizes; i++) {
- w += snprintf(compile_opts+w, compile_opts_size-w,
- "-D%s=%d ", ctx->cfg.size_names[i],
- (int)ctx->cfg.size_values[i]);
+ for (int i = 0; i < ctx->cfg.num_sizes; i++) {
+ w += snprintf(compile_opts+w, compile_opts_size-w,
+ "-D%s=%d ",
+ ctx->cfg.size_vars[i],
+ (int)ctx->cfg.size_values[i]);
+ }
+
+ OPENCL_SUCCEED_FATAL(build_opencl_program(prog, device_option.device, compile_opts));
+
+ free(compile_opts);
+ } else {
+ size_t binary_size;
+ unsigned char *fut_opencl_bin =
+ (unsigned char*) slurp_file(ctx->cfg.load_binary_from, &binary_size);
+ assert(fut_opencl_src != NULL);
+ const unsigned char *binaries[1] = { fut_opencl_bin };
+ cl_int status = 0;
+
+ prog = clCreateProgramWithBinary(ctx->ctx, 1, &device_option.device,
+ &binary_size, binaries,
+ &status, &error);
+
+ OPENCL_SUCCEED_FATAL(status);
+ OPENCL_SUCCEED_FATAL(error);
}
- OPENCL_SUCCEED_FATAL(build_opencl_program(prog, device_option.device, compile_opts));
- free(compile_opts);
free(fut_opencl_src);
+ if (ctx->cfg.dump_binary_to != NULL) {
+ size_t binary_size;
+ OPENCL_SUCCEED_FATAL(clGetProgramInfo(prog, CL_PROGRAM_BINARY_SIZES,
+ sizeof(size_t), &binary_size, NULL));
+ unsigned char *binary = malloc(binary_size);
+ unsigned char *binaries[1] = { binary };
+ OPENCL_SUCCEED_FATAL(clGetProgramInfo(prog, CL_PROGRAM_BINARIES,
+ sizeof(unsigned char*), binaries, NULL));
+
+ FILE *f = fopen(ctx->cfg.dump_binary_to, "w");
+ assert(f != NULL);
+ fwrite(binary, sizeof(char), binary_size, f);
+ fclose(f);
+ }
+
return prog;
}
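The new dump_binary_to and load_binary_from fields make it possible to cache the device-specific program binary between runs, skipping OpenCL source compilation. Below is a minimal sketch of how a host program might set them before running the setup code above; the field names come from this diff, while the size arrays passed to opencl_config_init are left empty purely for brevity.

/* Hedged sketch: caching the compiled OpenCL program binary across runs. */
struct opencl_config cfg;
opencl_config_init(&cfg, 0, NULL, NULL, NULL, NULL);

/* First run: build from source, then write the program binary to disk. */
cfg.dump_binary_to = "program.bin";

/* Later runs: bypass source compilation and load the binary via
   clCreateProgramWithBinary instead. */
cfg.load_binary_from = "program.bin";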
diff --git a/rts/csharp/opencl.cs b/rts/csharp/opencl.cs
index dced81a..e90d5b2 100644
--- a/rts/csharp/opencl.cs
+++ b/rts/csharp/opencl.cs
@@ -38,6 +38,7 @@ public struct OpenCLConfig
public int NumSizes;
public string[] SizeNames;
+ public string[] SizeVars;
public int[] SizeValues;
public string[] SizeClasses;
}
@@ -362,6 +363,7 @@ private OpenCLFreeList OpenCLFreeListInit()
private void OpenCLConfigInit(out OpenCLConfig cfg,
int num_sizes,
string[] size_names,
+ string[] size_vars,
int[] size_values,
string[] size_classes)
{
@@ -379,6 +381,7 @@ private void OpenCLConfigInit(out OpenCLConfig cfg,
cfg.NumSizes = num_sizes;
cfg.SizeNames = size_names;
+ cfg.SizeVars = size_vars;
cfg.SizeValues = size_values;
cfg.SizeClasses = size_classes;
}
@@ -873,7 +876,7 @@ private CLProgramHandle SetupOpenCL(ref FutharkContext ctx,
for (int i = 0; i < ctx.OpenCL.Cfg.NumSizes; i++) {
compile_opts += String.Format("-D{0}={1} ",
- ctx.OpenCL.Cfg.SizeNames[i],
+ ctx.OpenCL.Cfg.SizeVars[i],
ctx.OpenCL.Cfg.SizeValues[i]);
}
@@ -897,11 +900,8 @@ private void FutharkConfigPrintSizes()
int n = FutharkGetNumSizes();
for (int i = 0; i < n; i++)
{
- if (FutharkGetSizeEntry(i) == EntryPoint)
- {
- Console.WriteLine("{0} ({1})", FutharkGetSizeName(i),
- FutharkGetSizeClass(i));
- }
+ Console.WriteLine("{0} ({1})", FutharkGetSizeName(i),
+ FutharkGetSizeClass(i));
}
Environment.Exit(0);
}
diff --git a/rts/python/__init__.py b/rts/python/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/rts/python/__init__.py
+++ /dev/null
diff --git a/rts/python/opencl.py b/rts/python/opencl.py
index 1594c27..46f6278 100644
--- a/rts/python/opencl.py
+++ b/rts/python/opencl.py
@@ -164,10 +164,13 @@ def initialise_opencl_object(self,
else:
self.sizes[k] = v['value']
+ # XXX: we perform only a subset of z-encoding here. Really, the
+ # compiler should provide us with the variables to which
+ # parameters are mapped.
if (len(program_src) >= 0):
return cl.Program(self.ctx, program_src).build(
["-DLOCKSTEP_WIDTH={}".format(lockstep_width)]
- + ["-D{}={}".format(s,v) for (s,v) in self.sizes.items()])
+ + ["-D{}={}".format(s.replace('z', 'zz').replace('.', 'zi'),v) for (s,v) in self.sizes.items()])
def opencl_alloc(self, min_size, tag):
min_size = 1 if min_size == 0 else min_size
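For reference, the partial z-encoding in the change above maps 'z' to "zz" and '.' to "zi", so a size name such as "main.group_size" becomes the macro name "mainzigroup_sizze". The sketch below is purely illustrative and not part of the commit; the full encoding lives in the compiler.

#include <stdio.h>
#include <string.h>

/* Illustrative only: same character mapping as the Python one-liner above. */
static void z_encode(const char *in, char *out) {
  for (; *in != '\0'; in++) {
    if (*in == 'z')      { memcpy(out, "zz", 2); out += 2; }
    else if (*in == '.') { memcpy(out, "zi", 2); out += 2; }
    else                 { *out++ = *in; }
  }
  *out = '\0';
}

int main(void) {
  char buf[64];
  z_encode("main.group_size", buf);
  printf("%s\n", buf);  /* prints: mainzigroup_sizze */
  return 0;
}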
diff --git a/src/Futhark/Analysis/DataDependencies.hs b/src/Futhark/Analysis/DataDependencies.hs
index 97094b5..b6bf5ac 100644
--- a/src/Futhark/Analysis/DataDependencies.hs
+++ b/src/Futhark/Analysis/DataDependencies.hs
@@ -7,7 +7,6 @@ module Futhark.Analysis.DataDependencies
)
where
-import Data.Semigroup ((<>))
import qualified Data.Map.Strict as M
import qualified Data.Set as S
diff --git a/src/Futhark/Analysis/HORepresentation/MapNest.hs b/src/Futhark/Analysis/HORepresentation/MapNest.hs
index f507697..8d98c78 100644
--- a/src/Futhark/Analysis/HORepresentation/MapNest.hs
+++ b/src/Futhark/Analysis/HORepresentation/MapNest.hs
@@ -16,7 +16,6 @@ where
import Control.Monad
import Data.List
import Data.Maybe
-import Data.Semigroup ((<>))
import qualified Data.Map.Strict as M
import qualified Data.Set as S
diff --git a/src/Futhark/Analysis/HORepresentation/SOAC.hs b/src/Futhark/Analysis/HORepresentation/SOAC.hs
index a8cd9bb..164fd14 100644
--- a/src/Futhark/Analysis/HORepresentation/SOAC.hs
+++ b/src/Futhark/Analysis/HORepresentation/SOAC.hs
@@ -72,7 +72,6 @@ import Data.Foldable as Foldable
import Data.Maybe
import Data.Monoid ((<>))
import qualified Data.Sequence as Seq
-import qualified Data.Semigroup as Sem
import qualified Futhark.Representation.AST as Futhark
import Futhark.Representation.SOACS.SOAC
@@ -126,14 +125,13 @@ instance Substitute ArrayTransform where
newtype ArrayTransforms = ArrayTransforms (Seq.Seq ArrayTransform)
deriving (Eq, Ord, Show)
-instance Sem.Semigroup ArrayTransforms where
+instance Semigroup ArrayTransforms where
ts1 <> ts2 = case viewf ts2 of
t :< ts2' -> (ts1 |> t) <> ts2'
EmptyF -> ts1
instance Monoid ArrayTransforms where
mempty = noTransforms
- mappend = (Sem.<>)
instance Substitute ArrayTransforms where
substituteNames substs (ArrayTransforms ts) =
diff --git a/src/Futhark/Analysis/Metrics.hs b/src/Futhark/Analysis/Metrics.hs
index 32f0793..3201815 100644
--- a/src/Futhark/Analysis/Metrics.hs
+++ b/src/Futhark/Analysis/Metrics.hs
@@ -22,7 +22,6 @@ import qualified Data.Text as T
import Data.String
import Data.List
import qualified Data.Map.Strict as M
-import qualified Data.Semigroup as Sem
import Futhark.Representation.AST
@@ -48,12 +47,11 @@ instance OpMetrics () where
newtype CountMetrics = CountMetrics [([Text], Text)]
-instance Sem.Semigroup CountMetrics where
+instance Semigroup CountMetrics where
CountMetrics x <> CountMetrics y = CountMetrics $ x <> y
instance Monoid CountMetrics where
mempty = CountMetrics mempty
- mappend = (Sem.<>)
actualMetrics :: CountMetrics -> AstMetrics
actualMetrics (CountMetrics metrics) =
diff --git a/src/Futhark/Analysis/PrimExp.hs b/src/Futhark/Analysis/PrimExp.hs
index 5835237..b48d73e 100644
--- a/src/Futhark/Analysis/PrimExp.hs
+++ b/src/Futhark/Analysis/PrimExp.hs
@@ -8,6 +8,7 @@ module Futhark.Analysis.PrimExp
, coerceIntPrimExp
, true
, false
+ , constFoldPrimExp
, module Futhark.Representation.Primitive
, (.&&.), (.||.), (.<.), (.<=.), (.>.), (.>=.), (.==.), (.&.), (.|.), (.^.)
@@ -22,7 +23,9 @@ import Futhark.Representation.Primitive
import Futhark.Util.IntegralExp
import Futhark.Util.Pretty
--- | A primitive expression parametrised over the representation of free variables.
+-- | A primitive expression parametrised over the representation of
+-- free variables. Note that the 'Functor', 'Traversable', and 'Num'
+-- instances perform automatic (but simple) constant folding.
data PrimExp v = LeafExp v PrimType
| ValueExp PrimValue
| BinOpExp BinOp (PrimExp v) (PrimExp v)
@@ -66,7 +69,7 @@ instance Traversable PrimExp where
traverse _ (ValueExp v) =
pure $ ValueExp v
traverse f (BinOpExp op x y) =
- BinOpExp op <$> traverse f x <*> traverse f y
+ constFoldPrimExp <$> (BinOpExp op <$> traverse f x <*> traverse f y)
traverse f (CmpOpExp op x y) =
CmpOpExp op <$> traverse f x <*> traverse f y
traverse f (ConvOpExp op x) =
@@ -79,6 +82,29 @@ instance Traversable PrimExp where
instance FreeIn v => FreeIn (PrimExp v) where
freeIn = foldMap freeIn
+-- | Perform quick and dirty constant folding on the top level of a
+-- PrimExp. This is necessary because we want to consider
+-- e.g. equality modulo constant folding.
+constFoldPrimExp :: PrimExp v -> PrimExp v
+constFoldPrimExp (BinOpExp Add{} x y)
+ | zeroIshExp x = y
+ | zeroIshExp y = x
+constFoldPrimExp (BinOpExp Sub{} x y)
+ | zeroIshExp y = x
+constFoldPrimExp (BinOpExp Mul{} x y)
+ | oneIshExp x = y
+ | oneIshExp y = x
+constFoldPrimExp (BinOpExp SDiv{} x y)
+ | oneIshExp y = x
+constFoldPrimExp (BinOpExp SQuot{} x y)
+ | oneIshExp y = x
+constFoldPrimExp (BinOpExp UDiv{} x y)
+ | oneIshExp y = x
+constFoldPrimExp (BinOpExp bop (ValueExp x) (ValueExp y))
+ | Just z <- doBinOp bop x y =
+ ValueExp z
+constFoldPrimExp e = e
+
-- The Num instance performs a little bit of magic: whenever an
-- expression and a constant is combined with a binary operator, the
-- type of the constant may be changed to be the type of the
@@ -93,32 +119,13 @@ instance FreeIn v => FreeIn (PrimExp v) where
-- expressions to constants so that the above works. However, it is
-- still a bit of a hack.
instance Pretty v => Num (PrimExp v) where
- x + y | zeroIshExp x = y
- | zeroIshExp y = x
- | IntType t <- primExpType x,
- Just z <- constFold (doBinOp $ Add t) x y = z
- | FloatType t <- primExpType x,
- Just z <- constFold (doBinOp $ FAdd t) x y = z
- | Just z <- msum [asIntOp Add x y, asFloatOp FAdd x y] = z
+ x + y | Just z <- msum [asIntOp Add x y, asFloatOp FAdd x y] = constFoldPrimExp z
| otherwise = numBad "+" (x,y)
- x - y | zeroIshExp y = x
- | IntType t <- primExpType x,
- Just z <- constFold (doBinOp $ Sub t) x y = z
- | FloatType t <- primExpType x,
- Just z <- constFold (doBinOp $ FSub t) x y = z
- | Just z <- msum [asIntOp Sub x y, asFloatOp FSub x y] = z
+ x - y | Just z <- msum [asIntOp Sub x y, asFloatOp FSub x y] = constFoldPrimExp z
| otherwise = numBad "-" (x,y)
- x * y | zeroIshExp x = x
- | zeroIshExp y = y
- | oneIshExp x = y
- | oneIshExp y = x
- | IntType t <- primExpType x,
- Just z <- constFold (doBinOp $ Mul t) x y = z
- | FloatType t <- primExpType x,
- Just z <- constFold (doBinOp $ FMul t) x y = z
- | Just z <- msum [asIntOp Mul x y, asFloatOp FMul x y] = z
+ x * y | Just z <- msum [asIntOp Mul x y, asFloatOp FMul x y] = constFoldPrimExp z
| otherwise = numBad "*" (x,y)
abs x | IntType t <- primExpType x = UnOpExp (Abs t) x
@@ -131,18 +138,17 @@ instance Pretty v => Num (PrimExp v) where
fromInteger = fromInt32 . fromInteger
instance Pretty v => IntegralExp (PrimExp v) where
- x `div` y | oneIshExp y = x
- | Just z <- msum [asIntOp SDiv x y, asFloatOp FDiv x y] = z
+ x `div` y | Just z <- msum [asIntOp SDiv x y, asFloatOp FDiv x y] = constFoldPrimExp z
| otherwise = numBad "div" (x,y)
x `mod` y | Just z <- msum [asIntOp SMod x y] = z
| otherwise = numBad "mod" (x,y)
x `quot` y | oneIshExp y = x
- | Just z <- msum [asIntOp SQuot x y] = z
+ | Just z <- msum [asIntOp SQuot x y] = constFoldPrimExp z
| otherwise = numBad "quot" (x,y)
- x `rem` y | Just z <- msum [asIntOp SRem x y] = z
+ x `rem` y | Just z <- msum [asIntOp SRem x y] = constFoldPrimExp z
| otherwise = numBad "rem" (x,y)
sgn (ValueExp (IntValue i)) = Just $ signum $ valueIntegral i
@@ -220,13 +226,6 @@ asFloatExp t (ValueExp (IntValue v)) =
asFloatExp _ _ =
Nothing
-constFold :: (PrimValue -> PrimValue -> Maybe PrimValue)
- -> PrimExp v -> PrimExp v
- -> Maybe (PrimExp v)
-constFold f x y = do x' <- valueExp x
- y' <- valueExp y
- ValueExp <$> f x' y'
-
numBad :: Pretty a => String -> a -> b
numBad s x =
error $ "Invalid argument to PrimExp method " ++ s ++ ": " ++ pretty x
@@ -280,11 +279,6 @@ oneIshExp :: PrimExp v -> Bool
oneIshExp (ValueExp v) = oneIsh v
oneIshExp _ = False
--- | Is the expression a constant value?
-valueExp :: PrimExp v -> Maybe PrimValue
-valueExp (ValueExp v) = Just v
-valueExp _ = Nothing
-
-- | If the given 'PrimExp' is a constant of the wrong integer type,
-- coerce it to the given integer type. This is a workaround for an
-- issue in the 'Num' instance.
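
As the comments above describe, constant folding is now centralised in constFoldPrimExp, and the Num/IntegralExp instances, traverse, and replaceInPrimExp simply apply it to whatever BinOpExp they construct. A hedged sketch of the effect, assuming the constructors exported by Futhark.Representation.Primitive at this commit and a Pretty instance for the chosen variable type; not taken from the Futhark test suite:

import Futhark.Analysis.PrimExp

-- x + 0: the Num instance builds BinOpExp (Add Int32) ... and
-- constFoldPrimExp immediately rewrites it back to the leaf.
folded :: PrimExp String
folded = LeafExp "x" (IntType Int32) + ValueExp (IntValue (Int32Value 0))
-- folded is just LeafExp "x" (IntType Int32).
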
diff --git a/src/Futhark/Analysis/PrimExp/Convert.hs b/src/Futhark/Analysis/PrimExp/Convert.hs
index 3c3058a..afe2b2a 100644
--- a/src/Futhark/Analysis/PrimExp/Convert.hs
+++ b/src/Futhark/Analysis/PrimExp/Convert.hs
@@ -92,7 +92,7 @@ replaceInPrimExp f (LeafExp v pt) =
replaceInPrimExp _ (ValueExp v) =
ValueExp v
replaceInPrimExp f (BinOpExp bop pe1 pe2) =
- BinOpExp bop (replaceInPrimExp f pe1) (replaceInPrimExp f pe2)
+ constFoldPrimExp $ BinOpExp bop (replaceInPrimExp f pe1) (replaceInPrimExp f pe2)
replaceInPrimExp f (CmpOpExp cop pe1 pe2) =
CmpOpExp cop (replaceInPrimExp f pe1) (replaceInPrimExp f pe2)
replaceInPrimExp f (UnOpExp uop pe) =
diff --git a/src/Futhark/Analysis/Range.hs b/src/Futhark/Analysis/Range.hs
index d3741a0..15eee8a 100644
--- a/src/Futhark/Analysis/Range.hs
+++ b/src/Futhark/Analysis/Range.hs
@@ -12,7 +12,6 @@ module Futhark.Analysis.Range
import qualified Data.Map.Strict as M
import Control.Monad.Reader
-import Data.Semigroup ((<>))
import Data.List
import qualified Futhark.Analysis.ScalExp as SE
diff --git a/src/Futhark/Analysis/SymbolTable.hs b/src/Futhark/Analysis/SymbolTable.hs
index 571cc8f..0f7265c 100644
--- a/src/Futhark/Analysis/SymbolTable.hs
+++ b/src/Futhark/Analysis/SymbolTable.hs
@@ -59,12 +59,10 @@ import Control.Monad
import Control.Monad.Reader
import Data.Ord
import Data.Maybe
-import Data.Semigroup ((<>))
import Data.List hiding (elem, lookup)
import qualified Data.List as L
import qualified Data.Set as S
import qualified Data.Map.Strict as M
-import qualified Data.Semigroup as Sem
import Prelude hiding (elem, lookup)
@@ -88,7 +86,7 @@ data SymbolTable lore = SymbolTable {
-- loop?
}
-instance Sem.Semigroup (SymbolTable lore) where
+instance Semigroup (SymbolTable lore) where
table1 <> table2 =
SymbolTable { loopDepth = max (loopDepth table1) (loopDepth table2)
, bindings = bindings table1 <> bindings table2
@@ -98,7 +96,6 @@ instance Sem.Semigroup (SymbolTable lore) where
instance Monoid (SymbolTable lore) where
mempty = empty
- mappend = (Sem.<>)
empty :: SymbolTable lore
empty = SymbolTable 0 M.empty mempty
diff --git a/src/Futhark/Analysis/Usage.hs b/src/Futhark/Analysis/Usage.hs
index 8f0355b..f8afdc7 100644
--- a/src/Futhark/Analysis/Usage.hs
+++ b/src/Futhark/Analysis/Usage.hs
@@ -8,7 +8,6 @@ module Futhark.Analysis.Usage
)
where
-import Data.Semigroup ((<>))
import Data.Foldable
import qualified Data.Set as S
diff --git a/src/Futhark/Analysis/UsageTable.hs b/src/Futhark/Analysis/UsageTable.hs
index 00cac1d..e8efc05 100644
--- a/src/Futhark/Analysis/UsageTable.hs
+++ b/src/Futhark/Analysis/UsageTable.hs
@@ -27,10 +27,8 @@ import Control.Arrow (first)
import Data.Bits
import qualified Data.Foldable as Foldable
import Data.List (foldl')
-import Data.Semigroup ((<>))
import qualified Data.Map.Strict as M
import qualified Data.Set as S
-import qualified Data.Semigroup as Sem
import Prelude hiding (lookup)
@@ -40,13 +38,12 @@ import Futhark.Representation.AST
newtype UsageTable = UsageTable (M.Map VName Usages)
deriving (Eq, Show)
-instance Sem.Semigroup UsageTable where
+instance Semigroup UsageTable where
UsageTable table1 <> UsageTable table2 =
UsageTable $ M.unionWith (<>) table1 table2
instance Monoid UsageTable where
mempty = empty
- mappend = (Sem.<>)
instance Substitute UsageTable where
substituteNames subst (UsageTable table)
@@ -115,12 +112,11 @@ inResultUsage name = UsageTable $ M.singleton name inResultU
newtype Usages = Usages Int
deriving (Eq, Ord, Show)
-instance Sem.Semigroup Usages where
+instance Semigroup Usages where
Usages x <> Usages y = Usages $ x .|. y
instance Monoid Usages where
mempty = Usages 0
- mappend = (Sem.<>)
consumedU, inResultU, presentU :: Usages
consumedU = Usages 1
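
The Usages newtype above is a plain bit set: (<>) is bitwise or, and the three flag constants occupy distinct bits. A self-contained sketch of the same idea, using a made-up Flags type rather than Futhark code:

import Data.Bits ((.|.), (.&.))

newtype Flags = Flags Int
  deriving (Eq, Show)

-- Combining usage information is bitwise or.
instance Semigroup Flags where
  Flags x <> Flags y = Flags (x .|. y)

instance Monoid Flags where
  mempty = Flags 0

consumed, inResult :: Flags
consumed = Flags 1
inResult = Flags 2

-- Membership test: is a given flag set?
has :: Flags -> Flags -> Bool
has (Flags set) (Flags flag) = set .&. flag /= 0
-- (consumed <> inResult) `has` consumed == True
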
diff --git a/src/futhark-bench.hs b/src/Futhark/CLI/Bench.hs
index d4a8d8e..5845a81 100644
--- a/src/futhark-bench.hs
+++ b/src/Futhark/CLI/Bench.hs
@@ -3,9 +3,8 @@
{-# LANGUAGE FlexibleContexts #-}
-- | Simple tool for benchmarking Futhark programs. Use the @--json@
-- flag for machine-readable output.
-module Main (main) where
+module Futhark.CLI.Bench (main) where
-import Control.Concurrent
import Control.Monad
import Control.Monad.Except
import qualified Data.ByteString.Char8 as SBS
@@ -13,7 +12,6 @@ import qualified Data.ByteString.Lazy.Char8 as LBS
import qualified Data.Map as M
import Data.Either
import Data.Maybe
-import Data.Semigroup ((<>))
import Data.List
import qualified Data.Text as T
import qualified Data.Text.IO as T
@@ -32,10 +30,12 @@ import Text.Printf
import Text.Regex.TDFA
import Futhark.Test
+import Futhark.Util (pmapIO)
import Futhark.Util.Options
data BenchOptions = BenchOptions
- { optCompiler :: String
+ { optBackend :: String
+ , optFuthark :: String
, optRunner :: String
, optRuns :: Int
, optExtraOptions :: [String]
@@ -44,11 +44,12 @@ data BenchOptions = BenchOptions
, optSkipCompilation :: Bool
, optExcludeCase :: [String]
, optIgnoreFiles :: [Regex]
+ , optEntryPoint :: Maybe String
}
initialBenchOptions :: BenchOptions
-initialBenchOptions = BenchOptions "futhark-c" "" 10 [] Nothing (-1) False
- ["nobench", "disable"] []
+initialBenchOptions = BenchOptions "c" "futhark" "" 10 [] Nothing (-1) False
+ ["nobench", "disable"] [] Nothing
-- | The name we use for compiled programs.
binaryName :: FilePath -> FilePath
@@ -81,23 +82,6 @@ encodeBenchResults rs =
BenchResult prog r <- rs
return $ T.pack prog JSON..= M.singleton ("datasets" :: T.Text) (DataResults r)
-fork :: (a -> IO b) -> a -> IO (MVar b)
-fork f x = do cell <- newEmptyMVar
- void $ forkIO $ do result <- f x
- putMVar cell result
- return cell
-
-pmapIO :: (a -> IO b) -> [a] -> IO [b]
-pmapIO f elems = go elems []
- where
- go [] res = return res
- go xs res = do
- numThreads <- getNumCapabilities
- let (e,es) = splitAt numThreads xs
- mvars <- mapM (fork f) e
- result <- mapM takeMVar mvars
- go es (result ++ res)
-
runBenchmarks :: BenchOptions -> [FilePath] -> IO ()
runBenchmarks opts paths = do
-- We force line buffering to ensure that we produce running output.
@@ -126,9 +110,9 @@ anyFailed = any failedBenchResult
failedResult _ = False
anyFailedToCompile :: [SkipReason] -> Bool
-anyFailedToCompile = elem FailedToCompile
+anyFailedToCompile = not . all (==Skipped)
-data SkipReason = Skipped | FailedToCompile
+data SkipReason = Skipped | FailedToCompile | ReferenceFailed
deriving (Eq)
compileBenchmark :: BenchOptions -> (FilePath, ProgramTest)
@@ -147,24 +131,32 @@ compileBenchmark opts (program, spec) =
return $ Left FailedToCompile
else do
putStr $ "Compiling " ++ program ++ "...\n"
- (futcode, _, futerr) <- liftIO $ readProcessWithExitCode compiler
- [program, "-o", binaryName program] ""
-
- case futcode of
- ExitSuccess -> return $ Right (program, cases)
- ExitFailure 127 -> do putStrLn $ "Failed:\n" ++ progNotFound compiler
- return $ Left FailedToCompile
- ExitFailure _ -> do putStrLn "Failed:\n"
- SBS.putStrLn futerr
- return $ Left FailedToCompile
+
+ ref_res <- runExceptT $ ensureReferenceOutput futhark "c" program cases
+ case ref_res of
+ Left err -> do
+ putStrLn "Reference output generation failed:\n"
+ print err
+ return $ Left ReferenceFailed
+
+ Right () -> do
+ (futcode, _, futerr) <- liftIO $ readProcessWithExitCode futhark
+ [optBackend opts, program, "-o", binaryName program] ""
+
+ case futcode of
+ ExitSuccess -> return $ Right (program, cases)
+ ExitFailure 127 -> do putStrLn $ "Failed:\n" ++ progNotFound futhark
+ return $ Left FailedToCompile
+ ExitFailure _ -> do putStrLn "Failed:\n"
+ SBS.putStrLn futerr
+ return $ Left FailedToCompile
_ ->
return $ Left Skipped
- where compiler = optCompiler opts
-
- hasRuns (InputOutputs _ runs) = not $ null runs
+ where hasRuns (InputOutputs _ runs) = not $ null runs
+ futhark = optFuthark opts
runBenchmark :: BenchOptions -> (FilePath, [InputOutputs]) -> IO [BenchResult]
-runBenchmark opts (program, cases) = mapM forInputOutputs cases
+runBenchmark opts (program, cases) = mapM forInputOutputs $ filter relevant cases
where forInputOutputs (InputOutputs entry_name runs) = do
putStr $ "Results for " ++ program' ++ ":\n"
BenchResult program' . catMaybes <$>
@@ -173,6 +165,8 @@ runBenchmark opts (program, cases) = mapM forInputOutputs cases
then program
else program ++ ":" ++ T.unpack entry_name
+ relevant = maybe (const True) (==) (optEntryPoint opts) . T.unpack . iosEntryPoint
+
pad_to = foldl max 0 $ concatMap (map (length . runDescription) . iosTestRuns) cases
reportResult :: [RunResult] -> IO ()
@@ -203,15 +197,18 @@ runBenchmarkCase _ _ _ _ (TestRun _ _ RunTimeFailure{} _ _) =
runBenchmarkCase opts _ _ _ (TestRun tags _ _ _ _)
| any (`elem` tags) $ optExcludeCase opts =
return Nothing
-runBenchmarkCase opts program entry pad_to (TestRun _ input_spec (Succeeds expected_spec) _ dataset_desc) =
+runBenchmarkCase opts program entry pad_to tr@(TestRun _ input_spec (Succeeds expected_spec) _ dataset_desc) =
-- We store the runtime in a temporary file.
withSystemTempFile "futhark-bench" $ \tmpfile h -> do
hClose h -- We will be writing and reading this ourselves.
input <- getValuesBS dir input_spec
- let getValuesAndBS vs = do
+ let getValuesAndBS (SuccessValues vs) = do
vs' <- getValues dir vs
bs <- getValuesBS dir vs
return (LBS.toStrict bs, vs')
+ getValuesAndBS SuccessGenerateValues =
+ getValuesAndBS $ SuccessValues $ InFile $
+ testRunReferenceOutput program entry tr
maybe_expected <- maybe (return Nothing) (fmap Just . getValuesAndBS) expected_spec
let options = optExtraOptions opts ++ ["-e", T.unpack entry,
"-t", tmpfile,
@@ -236,22 +233,21 @@ runBenchmarkCase opts program entry pad_to (TestRun _ input_spec (Succeeds expec
readProcessWithExitCode to_run to_run_args $
LBS.toStrict input
- fmap (Just . DataResult dataset_desc) $ runBenchM $ case run_res of
- Just (progCode, output, progerr) ->
- do
- case maybe_expected of
- Nothing ->
- didNotFail program progCode $ T.decodeUtf8 progerr
- Just expected ->
- compareResult program expected =<<
- runResult program progCode output progerr
- runtime_result <- io $ T.readFile tmpfile
- runtimes <- case mapM readRuntime $ T.lines runtime_result of
- Just runtimes -> return $ map RunResult runtimes
- Nothing -> itWentWrong $ "Runtime file has invalid contents:\n" <> runtime_result
-
- io $ reportResult runtimes
- return (runtimes, T.decodeUtf8 progerr)
+ fmap (Just . DataResult dataset_desc) $ runBenchM $ case run_res of
+ Just (progCode, output, progerr) -> do
+ case maybe_expected of
+ Nothing ->
+ didNotFail program progCode $ T.decodeUtf8 progerr
+ Just expected ->
+ compareResult program expected =<<
+ runResult program progCode output progerr
+ runtime_result <- io $ T.readFile tmpfile
+ runtimes <- case mapM readRuntime $ T.lines runtime_result of
+ Just runtimes -> return $ map RunResult runtimes
+ Nothing -> itWentWrong $ "Runtime file has invalid contents:\n" <> runtime_result
+
+ io $ reportResult runtimes
+ return (runtimes, T.decodeUtf8 progerr)
Nothing ->
itWentWrong $ T.pack $ "Execution exceeded " ++ show (optTimeout opts) ++ " seconds."
@@ -297,7 +293,7 @@ compareResult :: (MonadError T.Text m, MonadIO m) =>
FilePath -> (SBS.ByteString, [Value]) -> (SBS.ByteString, [Value])
-> m ()
compareResult program (expected_bs, expected_vs) (actual_bs, actual_vs) =
- case compareValues actual_vs expected_vs of
+ case compareValues1 actual_vs expected_vs of
Just mismatch -> do
let actualf = program `replaceExtension` "actual"
expectedf = program `replaceExtension` "expected"
@@ -321,11 +317,14 @@ commandLineOptions = [
Left $ error $ "'" ++ n ++ "' is not a non-negative integer.")
"RUNS")
"Run each test case this many times."
- , Option [] ["compiler"]
- (ReqArg (\prog ->
- Right $ \config -> config { optCompiler = prog })
+ , Option [] ["backend"]
+ (ReqArg (\backend -> Right $ \config -> config { optBackend = backend })
"PROGRAM")
"The compiler used (defaults to 'futhark-c')."
+ , Option [] ["futhark"]
+ (ReqArg (\prog -> Right $ \config -> config { optFuthark = prog })
+ "PROGRAM")
+ "The binary used for operations (defaults to 'futhark')."
, Option [] ["runner"]
(ReqArg (\prog -> Right $ \config -> config { optRunner = prog }) "PROGRAM")
"The program used to run the Futhark-generated programs (defaults to nothing)."
@@ -364,11 +363,16 @@ commandLineOptions = [
config { optIgnoreFiles = makeRegex s : optIgnoreFiles config })
"REGEX")
"Ignore files matching this regular expression."
+ , Option "e" ["entry-point"]
+ (ReqArg (\s -> Right $ \config ->
+ config { optEntryPoint = Just s })
+ "NAME")
+ "Only run this entry point."
]
where max_timeout :: Int
max_timeout = maxBound `div` 1000000
-main :: IO ()
+main :: String -> [String] -> IO ()
main = mainWithOptions initialBenchOptions commandLineOptions "options... programs..." $ \progs config ->
Just $ runBenchmarks config progs
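
The new -e/--entry-point flag is applied through the `relevant` filter above, which uses the `maybe (const True) (==)` idiom: no flag means every entry point passes, a flag means exact match on the entry point name. A small stand-alone sketch of that idiom, with invented names:

relevantNames :: Maybe String -> [String] -> [String]
relevantNames wanted = filter (maybe (const True) (==) wanted)

-- relevantNames Nothing       ["main", "bench"] == ["main", "bench"]
-- relevantNames (Just "main") ["main", "bench"] == ["main"]
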
diff --git a/src/futhark-c.hs b/src/Futhark/CLI/C.hs
index 5a251cf..9b4f28c 100644
--- a/src/futhark-c.hs
+++ b/src/Futhark/CLI/C.hs
@@ -1,5 +1,5 @@
{-# LANGUAGE FlexibleContexts #-}
-module Main (main) where
+module Futhark.CLI.C (main) where
import Control.Monad.IO.Class
import System.FilePath
@@ -12,7 +12,7 @@ import Futhark.Util.Pretty (prettyText)
import Futhark.Compiler.CLI
import Futhark.Util
-main :: IO ()
+main :: String -> [String] -> IO ()
main = compilerMain () []
"Compile sequential C" "Generate sequential C code from optimised Futhark program."
sequentialCpuPipeline $ \() mode outpath prog -> do
diff --git a/src/futhark-csopencl.hs b/src/Futhark/CLI/CSOpenCL.hs
index e2df75e..a3862b4 100644
--- a/src/futhark-csopencl.hs
+++ b/src/Futhark/CLI/CSOpenCL.hs
@@ -1,5 +1,5 @@
{-# LANGUAGE FlexibleContexts #-}
-module Main (main) where
+module Futhark.CLI.CSOpenCL (main) where
import Control.Monad.IO.Class
import Data.Maybe (fromMaybe)
@@ -15,7 +15,7 @@ import Futhark.Util.Pretty (prettyText)
import Futhark.Compiler.CLI
import Futhark.Util
-main :: IO ()
+main :: String -> [String] -> IO ()
main = compilerMain () []
"Compile OpenCL C#" "Generate OpenCL C# code from optimised Futhark program."
gpuPipeline $ \() mode outpath prog -> do
diff --git a/src/futhark-cs.hs b/src/Futhark/CLI/CSharp.hs
index 96cc81d..2aa5c0e 100644
--- a/src/futhark-cs.hs
+++ b/src/Futhark/CLI/CSharp.hs
@@ -1,5 +1,5 @@
{-# LANGUAGE FlexibleContexts #-}
-module Main (main) where
+module Futhark.CLI.CSharp (main) where
import Control.Monad.IO.Class
import Data.Maybe (fromMaybe)
@@ -15,7 +15,7 @@ import Futhark.Util.Pretty (prettyText)
import Futhark.Compiler.CLI
import Futhark.Util
-main :: IO ()
+main :: String -> [String] -> IO ()
main = compilerMain () []
"Compile sequential C#" "Generate sequential C# code from optimised Futhark program."
sequentialCpuPipeline $ \() mode outpath prog -> do
diff --git a/src/Futhark/CLI/CUDA.hs b/src/Futhark/CLI/CUDA.hs
new file mode 100644
index 0000000..fe51dca
--- /dev/null
+++ b/src/Futhark/CLI/CUDA.hs
@@ -0,0 +1,43 @@
+{-# LANGUAGE FlexibleContexts #-}
+module Futhark.CLI.CUDA (main) where
+
+import Control.Monad.IO.Class
+import System.FilePath
+import System.Exit
+
+import Futhark.Pipeline
+import Futhark.Passes
+import qualified Futhark.CodeGen.Backends.CCUDA as CCUDA
+import Futhark.Util
+import Futhark.Util.Pretty (prettyText)
+import Futhark.Compiler.CLI
+
+main :: String -> [String] -> IO ()
+main = compilerMain () []
+ "Compile CUDA" "Generate CUDA/C code from optimised Futhark program."
+ gpuPipeline $ \() mode outpath prog -> do
+ cprog <- either (`internalError` prettyText prog) return =<<
+ CCUDA.compileProg prog
+ let cpath = outpath `addExtension` "c"
+ hpath = outpath `addExtension` "h"
+ extra_options = [ "-lcuda"
+ , "-lnvrtc"
+ ]
+ case mode of
+ ToLibrary -> do
+ let (header, impl) = CCUDA.asLibrary cprog
+ liftIO $ writeFile hpath header
+ liftIO $ writeFile cpath impl
+ ToExecutable -> do
+ liftIO $ writeFile cpath $ CCUDA.asExecutable cprog
+ let args = [cpath, "-O3", "-std=c99", "-lm", "-o", outpath]
+ ++ extra_options
+ ret <- liftIO $ runProgramWithExitCode "gcc" args ""
+ case ret of
+ Left err ->
+ externalErrorS $ "Failed to run gcc: " ++ show err
+ Right (ExitFailure code, _, gccerr) ->
+ externalErrorS $ "gcc failed with code " ++
+ show code ++ ":\n" ++ gccerr
+ Right (ExitSuccess, _, _) ->
+ return ()
diff --git a/src/Futhark/CLI/Datacmp.hs b/src/Futhark/CLI/Datacmp.hs
new file mode 100644
index 0000000..45bcdde
--- /dev/null
+++ b/src/Futhark/CLI/Datacmp.hs
@@ -0,0 +1,29 @@
+{-# LANGUAGE OverloadedStrings #-}
+module Futhark.CLI.Datacmp (main) where
+
+import qualified Data.ByteString.Lazy.Char8 as BS
+import System.Exit
+import System.IO
+
+import Futhark.Test.Values
+import Futhark.Util.Options
+
+main :: String -> [String] -> IO ()
+main = mainWithOptions () [] "<file> <file>" f
+ where f [file_a, file_b] () = Just $ do
+ vs_a_maybe <- readValues <$> BS.readFile file_a
+ vs_b_maybe <- readValues <$> BS.readFile file_b
+ case (vs_a_maybe, vs_b_maybe) of
+ (Nothing, _) ->
+ error $ "Error reading values from " ++ file_a
+ (_, Nothing) ->
+ error $ "Error reading values from " ++ file_b
+ (Just vs_a, Just vs_b) ->
+ case compareValues vs_a vs_b of
+ [] -> return ()
+ es -> do
+ mapM_ (hPrint stderr) es
+ exitWith $ ExitFailure 2
+
+ f _ _ =
+ Nothing
diff --git a/src/futhark-dataset.hs b/src/Futhark/CLI/Dataset.hs
index 584e0fc..6ed15fa 100644
--- a/src/futhark-dataset.hs
+++ b/src/Futhark/CLI/Dataset.hs
@@ -1,7 +1,7 @@
{-# LANGUAGE OverloadedStrings #-}
-- | Randomly generate Futhark input files containing values of a
-- specified type and shape.
-module Main (main) where
+module Futhark.CLI.Dataset (main) where
import Control.Monad
import Control.Monad.ST
@@ -25,7 +25,7 @@ import Language.Futhark.Pretty ()
import Futhark.Test.Values
import Futhark.Util.Options
-main :: IO ()
+main :: String -> [String] -> IO ()
main = mainWithOptions initialDataOptions commandLineOptions "options..." f
where f [] config
| null $ optOrders config = Just $ do
diff --git a/src/Futhark/CLI/Dev.hs b/src/Futhark/CLI/Dev.hs
new file mode 100644
index 0000000..9094fec
--- /dev/null
+++ b/src/Futhark/CLI/Dev.hs
@@ -0,0 +1,398 @@
+{-# LANGUAGE RankNTypes #-}
+-- | Futhark Compiler Driver
+module Futhark.CLI.Dev (main) where
+
+import Data.Maybe
+import Control.Category (id)
+import Control.Monad
+import Control.Monad.State
+import qualified Data.Text.IO as T
+import System.IO
+import System.Exit
+import System.Console.GetOpt
+
+import Prelude hiding (id)
+
+import Futhark.Pass
+import Futhark.Actions
+import Futhark.Compiler
+import Language.Futhark.Parser (parseFuthark)
+import Futhark.Util.Options
+import Futhark.Pipeline
+import qualified Futhark.Representation.SOACS as SOACS
+import Futhark.Representation.SOACS (SOACS)
+import qualified Futhark.Representation.Kernels as Kernels
+import Futhark.Representation.Kernels (Kernels)
+import qualified Futhark.Representation.ExplicitMemory as ExplicitMemory
+import Futhark.Representation.ExplicitMemory (ExplicitMemory)
+import Futhark.Representation.AST (Prog, pretty)
+import Futhark.TypeCheck (Checkable)
+import qualified Futhark.Util.Pretty as PP
+
+import Futhark.Internalise.Defunctorise as Defunctorise
+import Futhark.Internalise.Monomorphise as Monomorphise
+import Futhark.Internalise.Defunctionalise as Defunctionalise
+import Futhark.Optimise.InliningDeadFun
+import Futhark.Optimise.CSE
+import Futhark.Optimise.Fusion
+import Futhark.Pass.FirstOrderTransform
+import Futhark.Pass.Simplify
+import Futhark.Optimise.InPlaceLowering
+import Futhark.Optimise.DoubleBuffer
+import Futhark.Optimise.TileLoops
+import Futhark.Optimise.Unstream
+import Futhark.Pass.KernelBabysitting
+import Futhark.Pass.ExtractKernels
+import Futhark.Pass.ExpandAllocations
+import Futhark.Pass.ExplicitAllocations
+import Futhark.Passes
+
+-- | What to do with the program after it has been read.
+data FutharkPipeline = PrettyPrint
+ -- ^ Just print it.
+ | TypeCheck
+ -- ^ Run the type checker; print type errors.
+ | Pipeline [UntypedPass]
+ -- ^ Run this pipeline.
+ | Defunctorise
+ -- ^ Partially evaluate away the module language.
+ | Monomorphise
+ -- ^ Defunctorise and monomorphise.
+ | Defunctionalise
+ -- ^ Defunctorise, monomorphise, and defunctionalise.
+
+data Config = Config { futharkConfig :: FutharkConfig
+ , futharkPipeline :: FutharkPipeline
+ -- ^ Nothing is distinct from an empty pipeline -
+ -- it means we don't even run the internaliser.
+ , futharkAction :: UntypedAction
+ }
+
+
+-- | Get a Futhark pipeline from the configuration - an empty one if
+-- none exists.
+getFutharkPipeline :: Config -> [UntypedPass]
+getFutharkPipeline = toPipeline . futharkPipeline
+ where toPipeline (Pipeline p) = p
+ toPipeline _ = []
+
+data UntypedPassState = SOACS (Prog SOACS.SOACS)
+ | Kernels (Prog Kernels.Kernels)
+ | ExplicitMemory (Prog ExplicitMemory.ExplicitMemory)
+
+getSOACSProg :: UntypedPassState -> Maybe (Prog SOACS.SOACS)
+getSOACSProg (SOACS prog) = Just prog
+getSOACSProg _ = Nothing
+
+class Representation s where
+ -- | A human-readable description of the representation expected or
+ -- contained, usable for error messages.
+ representation :: s -> String
+
+instance Representation UntypedPassState where
+ representation (SOACS _) = "SOACS"
+ representation (Kernels _) = "Kernels"
+ representation (ExplicitMemory _) = "ExplicitMemory"
+
+instance PP.Pretty UntypedPassState where
+ ppr (SOACS prog) = PP.ppr prog
+ ppr (Kernels prog) = PP.ppr prog
+ ppr (ExplicitMemory prog) = PP.ppr prog
+
+newtype UntypedPass = UntypedPass (UntypedPassState
+ -> PipelineConfig
+ -> FutharkM UntypedPassState)
+
+data UntypedAction = SOACSAction (Action SOACS)
+ | KernelsAction (Action Kernels)
+ | ExplicitMemoryAction (Action ExplicitMemory)
+ | PolyAction (Action SOACS) (Action Kernels) (Action ExplicitMemory)
+
+untypedActionName :: UntypedAction -> String
+untypedActionName (SOACSAction a) = actionName a
+untypedActionName (KernelsAction a) = actionName a
+untypedActionName (ExplicitMemoryAction a) = actionName a
+untypedActionName (PolyAction a _ _) = actionName a
+
+instance Representation UntypedAction where
+ representation (SOACSAction _) = "SOACS"
+ representation (KernelsAction _) = "Kernels"
+ representation (ExplicitMemoryAction _) = "ExplicitMemory"
+ representation PolyAction{} = "<any>"
+
+newConfig :: Config
+newConfig = Config newFutharkConfig (Pipeline []) $ PolyAction printAction printAction printAction
+
+changeFutharkConfig :: (FutharkConfig -> FutharkConfig)
+ -> Config -> Config
+changeFutharkConfig f cfg = cfg { futharkConfig = f $ futharkConfig cfg }
+
+type FutharkOption = FunOptDescr Config
+
+passOption :: String -> UntypedPass -> String -> [String] -> FutharkOption
+passOption desc pass short long =
+ Option short long
+ (NoArg $ Right $ \cfg ->
+ cfg { futharkPipeline = Pipeline $ getFutharkPipeline cfg ++ [pass] })
+ desc
+
+explicitMemoryProg :: String -> UntypedPassState -> FutharkM (Prog ExplicitMemory.ExplicitMemory)
+explicitMemoryProg _ (ExplicitMemory prog) =
+ return prog
+explicitMemoryProg name rep =
+ externalErrorS $ "Pass " ++ name ++
+ " expects ExplicitMemory representation, but got " ++ representation rep
+
+soacsProg :: String -> UntypedPassState -> FutharkM (Prog SOACS.SOACS)
+soacsProg _ (SOACS prog) =
+ return prog
+soacsProg name rep =
+ externalErrorS $ "Pass " ++ name ++
+ " expects SOACS representation, but got " ++ representation rep
+
+kernelsProg :: String -> UntypedPassState -> FutharkM (Prog Kernels.Kernels)
+kernelsProg _ (Kernels prog) =
+ return prog
+kernelsProg name rep =
+ externalErrorS $
+ "Pass " ++ name ++" expects Kernels representation, but got " ++ representation rep
+
+typedPassOption :: (Checkable fromlore, Checkable tolore) =>
+ (String -> UntypedPassState -> FutharkM (Prog fromlore))
+ -> (Prog tolore -> UntypedPassState)
+ -> Pass fromlore tolore
+ -> String
+ -> FutharkOption
+typedPassOption getProg putProg pass short =
+ passOption (passDescription pass) (UntypedPass perform) short long
+ where perform s config = do
+ prog <- getProg (passName pass) s
+ putProg <$> runPasses (onePass pass) config prog
+
+ long = [passLongOption pass]
+
+soacsPassOption :: Pass SOACS SOACS -> String -> FutharkOption
+soacsPassOption =
+ typedPassOption soacsProg SOACS
+
+kernelsPassOption :: Pass Kernels Kernels -> String -> FutharkOption
+kernelsPassOption =
+ typedPassOption kernelsProg Kernels
+
+explicitMemoryPassOption :: Pass ExplicitMemory ExplicitMemory -> String -> FutharkOption
+explicitMemoryPassOption =
+ typedPassOption explicitMemoryProg ExplicitMemory
+
+simplifyOption :: String -> FutharkOption
+simplifyOption short =
+ passOption (passDescription pass) (UntypedPass perform) short long
+ where perform (SOACS prog) config =
+ SOACS <$> runPasses (onePass simplifySOACS) config prog
+ perform (Kernels prog) config =
+ Kernels <$> runPasses (onePass simplifyKernels) config prog
+ perform (ExplicitMemory prog) config =
+ ExplicitMemory <$> runPasses (onePass simplifyExplicitMemory) config prog
+
+ long = [passLongOption pass]
+ pass = simplifySOACS
+
+cseOption :: String -> FutharkOption
+cseOption short =
+ passOption (passDescription pass) (UntypedPass perform) short long
+ where perform (SOACS prog) config =
+ SOACS <$> runPasses (onePass $ performCSE True) config prog
+ perform (Kernels prog) config =
+ Kernels <$> runPasses (onePass $ performCSE True) config prog
+ perform (ExplicitMemory prog) config =
+ ExplicitMemory <$> runPasses (onePass $ performCSE False) config prog
+
+ long = [passLongOption pass]
+ pass = performCSE True :: Pass SOACS SOACS
+
+pipelineOption :: (UntypedPassState -> Maybe (Prog fromlore))
+ -> String
+ -> (Prog tolore -> UntypedPassState)
+ -> String
+ -> Pipeline fromlore tolore
+ -> String
+ -> [String]
+ -> FutharkOption
+pipelineOption getprog repdesc repf desc pipeline =
+ passOption desc $ UntypedPass pipelinePass
+ where pipelinePass rep config =
+ case getprog rep of
+ Just prog ->
+ repf <$> runPasses pipeline config prog
+ Nothing ->
+ externalErrorS $ "Expected " ++ repdesc ++ " representation, but got " ++
+ representation rep
+
+soacsPipelineOption :: String -> Pipeline SOACS SOACS -> String -> [String]
+ -> FutharkOption
+soacsPipelineOption = pipelineOption getSOACSProg "SOACS" SOACS
+
+kernelsPipelineOption :: String -> Pipeline SOACS Kernels -> String -> [String]
+ -> FutharkOption
+kernelsPipelineOption = pipelineOption getSOACSProg "Kernels" Kernels
+
+explicitMemoryPipelineOption :: String -> Pipeline SOACS ExplicitMemory -> String -> [String]
+ -> FutharkOption
+explicitMemoryPipelineOption = pipelineOption getSOACSProg "ExplicitMemory" ExplicitMemory
+
+commandLineOptions :: [FutharkOption]
+commandLineOptions =
+ [ Option "v" ["verbose"]
+ (OptArg (Right . changeFutharkConfig . incVerbosity) "FILE")
+ "Print verbose output on standard error; wrong program to FILE."
+ , Option [] ["Werror"]
+ (NoArg $ Right $ changeFutharkConfig $ \opts -> opts { futharkWerror = True })
+ "Treat warnings as errors."
+
+ , Option "t" ["type-check"]
+ (NoArg $ Right $ \opts ->
+ opts { futharkPipeline = TypeCheck })
+ "Type-check the program and print errors on standard error."
+
+ , Option [] ["pretty-print"]
+ (NoArg $ Right $ \opts ->
+ opts { futharkPipeline = PrettyPrint })
+ "Parse and pretty-print the AST of the given program."
+
+ , Option [] ["compile-imperative"]
+ (NoArg $ Right $ \opts ->
+ opts { futharkAction = ExplicitMemoryAction impCodeGenAction })
+ "Translate program into the imperative IL and write it on standard output."
+ , Option [] ["compile-imperative-kernels"]
+ (NoArg $ Right $ \opts ->
+ opts { futharkAction = ExplicitMemoryAction kernelImpCodeGenAction })
+ "Translate program into the imperative IL with kernels and write it on standard output."
+ , Option [] ["range-analysis"]
+ (NoArg $ Right $ \opts -> opts { futharkAction = PolyAction rangeAction rangeAction rangeAction })
+ "Print the program with range annotations added."
+ , Option "p" ["print"]
+ (NoArg $ Right $ \opts -> opts { futharkAction = PolyAction printAction printAction printAction })
+ "Prettyprint the resulting internal representation on standard output (default action)."
+ , Option "m" ["metrics"]
+ (NoArg $ Right $ \opts -> opts { futharkAction = PolyAction metricsAction metricsAction metricsAction })
+ "Print AST metrics of the resulting internal representation on standard output."
+ , Option [] ["defunctorise"]
+ (NoArg $ Right $ \opts -> opts { futharkPipeline = Defunctorise })
+ "Partially evaluate all module constructs and print the residual program."
+ , Option [] ["monomorphise"]
+ (NoArg $ Right $ \opts -> opts { futharkPipeline = Monomorphise })
+ "Monomorphise the program."
+ , Option [] ["defunctionalise"]
+ (NoArg $ Right $ \opts -> opts { futharkPipeline = Defunctionalise })
+ "Defunctionalise the program."
+ , typedPassOption soacsProg Kernels firstOrderTransform "f"
+ , soacsPassOption fuseSOACs "o"
+ , soacsPassOption inlineAndRemoveDeadFunctions []
+ , kernelsPassOption inPlaceLowering []
+ , kernelsPassOption babysitKernels []
+ , kernelsPassOption tileLoops []
+ , kernelsPassOption unstream []
+ , typedPassOption soacsProg Kernels extractKernels []
+
+ , typedPassOption kernelsProg ExplicitMemory explicitAllocations "a"
+
+ , explicitMemoryPassOption doubleBuffer []
+ , explicitMemoryPassOption expandAllocations []
+
+ , cseOption []
+ , simplifyOption "e"
+
+ , soacsPipelineOption "Run the default optimised pipeline"
+ standardPipeline "s" ["standard"]
+ , kernelsPipelineOption "Run the default optimised kernels pipeline"
+ kernelsPipeline [] ["kernels"]
+ , explicitMemoryPipelineOption "Run the full GPU compilation pipeline"
+ gpuPipeline [] ["gpu"]
+ , explicitMemoryPipelineOption "Run the sequential CPU compilation pipeline"
+ sequentialCpuPipeline [] ["cpu"]
+ ]
+
+incVerbosity :: Maybe FilePath -> FutharkConfig -> FutharkConfig
+incVerbosity file cfg =
+ cfg { futharkVerbose = (v, file `mplus` snd (futharkVerbose cfg)) }
+ where v = case fst $ futharkVerbose cfg of
+ NotVerbose -> Verbose
+ Verbose -> VeryVerbose
+ VeryVerbose -> VeryVerbose
+
+-- | Entry point. Non-interactive, except when reading interpreter
+-- input from standard input.
+main :: String -> [String] -> IO ()
+main = mainWithOptions newConfig commandLineOptions "options... program" compile
+ where compile [file] config =
+ Just $ do
+ res <- runFutharkM (m file config) $
+ fst $ futharkVerbose $ futharkConfig config
+ case res of
+ Left err -> do
+ dumpError (futharkConfig config) err
+ exitWith $ ExitFailure 2
+ Right () -> return ()
+ compile _ _ =
+ Nothing
+ m file config =
+ case futharkPipeline config of
+ TypeCheck -> do
+ -- No pipeline; just read the program and type check
+ (warnings, _, _) <- readProgram file
+ liftIO $ hPutStr stderr $ show warnings
+ PrettyPrint -> liftIO $ do
+ maybe_prog <- parseFuthark file <$> T.readFile file
+ case maybe_prog of
+ Left err -> fail $ show err
+ Right prog -> putStrLn $ pretty prog
+ Defunctorise -> do
+ (_, imports, src) <- readProgram file
+ liftIO $ mapM_ (putStrLn . pretty) $
+ evalState (Defunctorise.transformProg imports) src
+ Monomorphise -> do
+ (_, imports, src) <- readProgram file
+ liftIO $ mapM_ (putStrLn . pretty) $ flip evalState src $
+ Defunctorise.transformProg imports
+ >>= Monomorphise.transformProg
+ Defunctionalise -> do
+ (_, imports, src) <- readProgram file
+ liftIO $ mapM_ (putStrLn . pretty) $ flip evalState src $
+ Defunctorise.transformProg imports
+ >>= Monomorphise.transformProg
+ >>= Defunctionalise.transformProg
+ Pipeline{} -> do
+ prog <- runPipelineOnProgram (futharkConfig config) id file
+ runPolyPasses config prog
+
+runPolyPasses :: Config -> SOACS.Prog -> FutharkM ()
+runPolyPasses config prog = do
+ prog' <- foldM (runPolyPass pipeline_config) (SOACS prog) (getFutharkPipeline config)
+ case (prog', futharkAction config) of
+ (SOACS soacs_prog, SOACSAction action) ->
+ actionProcedure action soacs_prog
+ (Kernels kernels_prog, KernelsAction action) ->
+ actionProcedure action kernels_prog
+ (ExplicitMemory mem_prog, ExplicitMemoryAction action) ->
+ actionProcedure action mem_prog
+
+ (SOACS soacs_prog, PolyAction soacs_action _ _) ->
+ actionProcedure soacs_action soacs_prog
+ (Kernels kernels_prog, PolyAction _ kernels_action _) ->
+ actionProcedure kernels_action kernels_prog
+ (ExplicitMemory mem_prog, PolyAction _ _ mem_action) ->
+ actionProcedure mem_action mem_prog
+
+ (_, action) ->
+ externalErrorS $ "Action " <>
+ untypedActionName action <>
+ " expects " ++ representation action ++ " representation, but got " ++
+ representation prog' ++ "."
+ where pipeline_config =
+ PipelineConfig { pipelineVerbose = fst (futharkVerbose $ futharkConfig config) > NotVerbose
+ , pipelineValidate = True
+ }
+
+runPolyPass :: PipelineConfig
+ -> UntypedPassState -> UntypedPass -> FutharkM UntypedPassState
+runPolyPass pipeline_config s (UntypedPass f) =
+ f s pipeline_config
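
Futhark.CLI.Dev keeps the program in an UntypedPassState sum and lets each pass or action demand the representation it needs, failing with a readable message otherwise (soacsProg, kernelsProg, explicitMemoryProg above). A minimal sketch of that dispatch shape, with made-up names rather than the actual Futhark types:

data Rep = RepA Int | RepB String

describeRep :: Rep -> String
describeRep RepA{} = "A"
describeRep RepB{} = "B"

-- A "pass" that only accepts representation A.
needsA :: String -> Rep -> Either String Int
needsA _    (RepA x) = Right x
needsA name rep      =
  Left ("Pass " ++ name ++ " expects A representation, but got " ++ describeRep rep)
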
diff --git a/src/futhark-doc.hs b/src/Futhark/CLI/Doc.hs
index 4e40a39..5a89aab 100644
--- a/src/futhark-doc.hs
+++ b/src/Futhark/CLI/Doc.hs
@@ -2,13 +2,12 @@
{-# LANGUAGE QuasiQuotes #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE OverloadedStrings #-}
-module Main (main) where
+module Futhark.CLI.Doc (main) where
import Control.Monad.IO.Class (liftIO)
import Control.Monad.State
import Data.FileEmbed
import Data.List
-import Data.Semigroup ((<>))
import System.FilePath
import System.Directory (createDirectoryIfMissing)
import System.Console.GetOpt
@@ -25,7 +24,7 @@ import Language.Futhark.Syntax (progDoc, DocComment(..))
import Futhark.Util.Options
import Futhark.Util (directoryContents, trim)
-main :: IO ()
+main :: String -> [String] -> IO ()
main = mainWithOptions initialDocConfig commandLineOptions "options... -o outdir programs..." f
where f [dir] config = Just $ do
res <- runFutharkM (m config dir) Verbose
diff --git a/src/Futhark/CLI/Misc.hs b/src/Futhark/CLI/Misc.hs
new file mode 100644
index 0000000..891aac0
--- /dev/null
+++ b/src/Futhark/CLI/Misc.hs
@@ -0,0 +1,31 @@
+{-# LANGUAGE FlexibleContexts #-}
+-- Various small subcommands that are too simple to deserve their own file.
+module Futhark.CLI.Misc
+ ( mainCheck
+ )
+where
+
+import Control.Monad.State
+import System.IO
+import System.Exit
+
+import Futhark.Compiler
+import Futhark.Util.Options
+import Futhark.Pipeline
+
+runFutharkM' :: FutharkM () -> IO ()
+runFutharkM' m = do
+ res <- runFutharkM m NotVerbose
+ case res of
+ Left err -> do
+ dumpError newFutharkConfig err
+ exitWith $ ExitFailure 2
+ Right () -> return ()
+
+mainCheck :: String -> [String] -> IO ()
+mainCheck = mainWithOptions () [] "program" $ \args () ->
+ case args of
+ [file] -> Just $ runFutharkM' $ check file
+ _ -> Nothing
+ where check file = do (warnings, _, _) <- readProgram file
+ liftIO $ hPutStr stderr $ show warnings
diff --git a/src/futhark-opencl.hs b/src/Futhark/CLI/OpenCL.hs
index 8f01899..618b4ec 100644
--- a/src/futhark-opencl.hs
+++ b/src/Futhark/CLI/OpenCL.hs
@@ -1,5 +1,5 @@
{-# LANGUAGE FlexibleContexts #-}
-module Main (main) where
+module Futhark.CLI.OpenCL (main) where
import Control.Monad.IO.Class
import System.FilePath
@@ -13,7 +13,7 @@ import Futhark.Util
import Futhark.Util.Pretty (prettyText)
import Futhark.Compiler.CLI
-main :: IO ()
+main :: String -> [String] -> IO ()
main = compilerMain () []
"Compile OpenCL" "Generate OpenCL/C code from optimised Futhark program."
gpuPipeline $ \() mode outpath prog -> do
diff --git a/src/futhark-pkg.hs b/src/Futhark/CLI/Pkg.hs
index c22d670..16f2237 100644
--- a/src/futhark-pkg.hs
+++ b/src/Futhark/CLI/Pkg.hs
@@ -1,6 +1,6 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
-module Main (main) where
+module Futhark.CLI.Pkg (main) where
import Control.Monad.IO.Class
import Control.Monad.State
@@ -51,7 +51,7 @@ installInDir (BuildList bl) dir = do
-- claims to encode filepaths with '/' directory separators no
-- matter the host OS.
when (".." `elem` Posix.splitPath (Zip.eRelativePath entry)) $
- fail $ "Zip archive for " <> pdir <> " contains suspicuous path: " <>
+ fail $ "Zip archive for " <> pdir <> " contains suspicious path: " <>
Zip.eRelativePath entry
let f = pdir </> makeRelative from_dir (Zip.eRelativePath entry)
createDirectoryIfMissing True $ takeDirectory f
@@ -189,21 +189,22 @@ instance MonadLogger PkgM where
runPkgM :: PkgConfig -> PkgM a -> IO a
runPkgM cfg (PkgM m) = evalStateT (runReaderT m cfg) mempty
-cmdMain :: String -> ([String] -> PkgConfig -> Maybe (IO ())) -> IO ()
+cmdMain :: String -> ([String] -> PkgConfig -> Maybe (IO ()))
+ -> String -> [String] -> IO ()
cmdMain = mainWithOptions (PkgConfig False) options
where options = [ Option "v" ["verbose"]
(NoArg $ Right $ \cfg -> cfg { pkgVerbose = True })
"Write running diagnostics to stderr."]
-doFmt :: IO ()
-doFmt = mainWithOptions () [] "fmt" $ \args () ->
+doFmt :: String -> [String] -> IO ()
+doFmt = mainWithOptions () [] "" $ \args () ->
case args of
[] -> Just $ do
m <- parsePkgManifestFromFile futharkPkg
T.writeFile futharkPkg $ prettyPkgManifest m
_ -> Nothing
-doCheck :: IO ()
+doCheck :: String -> [String] -> IO ()
doCheck = cmdMain "check" $ \args cfg ->
case args of
[] -> Just $ runPkgM cfg $ do
@@ -231,8 +232,8 @@ doCheck = cmdMain "check" $ \args cfg ->
exitFailure
_ -> Nothing
-doSync :: IO ()
-doSync = cmdMain "sync" $ \args cfg ->
+doSync :: String -> [String] -> IO ()
+doSync = cmdMain "" $ \args cfg ->
case args of
[] -> Just $ runPkgM cfg $ do
m <- getPkgManifest
@@ -240,8 +241,8 @@ doSync = cmdMain "sync" $ \args cfg ->
installBuildList (commented $ manifestPkgPath m) bl
_ -> Nothing
-doAdd :: IO ()
-doAdd = cmdMain "add PKGPATH" $ \args cfg ->
+doAdd :: String -> [String] -> IO ()
+doAdd = cmdMain "PKGPATH" $ \args cfg ->
case args of
[p, v] | Right v' <- parseVersion $ T.pack v -> Just $ runPkgM cfg $ doAdd' (T.pack p) v'
[p] -> Just $ runPkgM cfg $
@@ -284,8 +285,8 @@ doAdd = cmdMain "add PKGPATH" $ \args cfg ->
putPkgManifest m'
liftIO $ T.putStrLn "Remember to run 'futhark-pkg sync'."
-doRemove :: IO ()
-doRemove = cmdMain "remove PKGPATH" $ \args cfg ->
+doRemove :: String -> [String] -> IO ()
+doRemove = cmdMain "PKGPATH" $ \args cfg ->
case args of
[p] -> Just $ runPkgM cfg $ doRemove' $ T.pack p
_ -> Nothing
@@ -300,8 +301,8 @@ doRemove = cmdMain "remove PKGPATH" $ \args cfg ->
putPkgManifest m'
liftIO $ T.putStrLn $ "Removed " <> p <> " " <> prettySemVer (requiredPkgRev r) <> "."
-doInit :: IO ()
-doInit = cmdMain "create PKGPATH" $ \args cfg ->
+doInit :: String -> [String] -> IO ()
+doInit = cmdMain "PKGPATH" $ \args cfg ->
case args of
[p] -> Just $ runPkgM cfg $ doCreate' $ T.pack p
_ -> Nothing
@@ -318,8 +319,8 @@ doInit = cmdMain "create PKGPATH" $ \args cfg ->
putPkgManifest $ newPkgManifest $ Just p
liftIO $ T.putStrLn $ "Wrote " <> T.pack futharkPkg <> "."
-doUpgrade :: IO ()
-doUpgrade = cmdMain "upgrade" $ \args cfg ->
+doUpgrade :: String -> [String] -> IO ()
+doUpgrade = cmdMain "" $ \args cfg ->
case args of
[] -> Just $ runPkgM cfg $ do
m <- getPkgManifest
@@ -337,8 +338,8 @@ doUpgrade = cmdMain "upgrade" $ \args cfg ->
return req { requiredPkgRev = v
, requiredHash = Just h }
-doVersions :: IO ()
-doVersions = cmdMain "versions PKGPATH" $ \args cfg ->
+doVersions :: String -> [String] -> IO ()
+doVersions = cmdMain "PKGPATH" $ \args cfg ->
case args of
[p] -> Just $ runPkgM cfg $ doVersions' $ T.pack p
_ -> Nothing
@@ -346,15 +347,14 @@ doVersions = cmdMain "versions PKGPATH" $ \args cfg ->
mapM_ (liftIO . T.putStrLn . prettySemVer) . M.keys . pkgVersions
<=< lookupPackage
-main :: IO ()
-main = do
+main :: String -> [String] -> IO ()
+main prog args = do
-- Ensure that we can make HTTPS requests.
setGlobalManager =<< newManager tlsManagerSettings
-- Avoid Git asking for credentials. We prefer failure.
liftIO $ setEnv "GIT_TERMINAL_PROMPT" "0"
- args <- getArgs
let commands = [ ("add",
(doAdd, "Add another required package to futhark.pkg."))
, ("check",
@@ -374,14 +374,18 @@ main = do
]
usage = "options... <" <> intercalate "|" (map fst commands) <> ">"
case args of
- cmd : args' | Just (m, _) <- lookup cmd commands -> withArgs args' m
- _ -> mainWithOptions () [] usage $ \_ () -> Just $ do
- let k = maximum (map (length . fst) commands) + 3
- usageMsg $ T.unlines $
- ["<command> ...:", "", "Commands:"] ++
- [ " " <> T.pack cmd <> T.pack (replicate (k - length cmd) ' ') <> desc
- | (cmd, (_, desc)) <- commands ]
+ cmd : args' | Just (m, _) <- lookup cmd commands ->
+ m (unwords [prog, cmd]) args'
+ _ -> do
+ let bad _ () = Just $ do
+ let k = maximum (map (length . fst) commands) + 3
+ usageMsg $ T.unlines $
+ ["<command> ...:", "", "Commands:"] ++
+ [ " " <> T.pack cmd <> T.pack (replicate (k - length cmd) ' ') <> desc
+ | (cmd, (_, desc)) <- commands ]
+
+ mainWithOptions () [] usage bad prog args
where usageMsg s = do
- T.putStrLn $ "Usage: futhark-pkg [--version] [--help] " <> s
+ T.putStrLn $ "Usage: " <> T.pack prog <> " [--version] [--help] " <> s
exitFailure
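
The main change to futhark-pkg above is that every subcommand now takes the program name and argument list explicitly instead of calling getArgs/withArgs itself, so the dispatcher can pass along a compound name such as "futhark-pkg add". A self-contained sketch of that dispatch pattern; greet and shout are invented commands:

import Data.Char (toUpper)
import Data.List (intercalate)

type Cmd = String -> [String] -> IO ()

commands :: [(String, Cmd)]
commands =
  [ ("greet", \prog args -> putStrLn (prog ++ ": hello " ++ unwords args))
  , ("shout", \prog args -> putStrLn (prog ++ ": " ++ map toUpper (unwords args)))
  ]

dispatch :: Cmd
dispatch prog (cmd : rest)
  | Just run <- lookup cmd commands = run (unwords [prog, cmd]) rest
dispatch prog _ =
  putStrLn ("Usage: " ++ prog ++ " <" ++ intercalate "|" (map fst commands) ++ "> ...")
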
diff --git a/src/futhark-pyopencl.hs b/src/Futhark/CLI/PyOpenCL.hs
index cdadd21..38258e9 100644
--- a/src/futhark-pyopencl.hs
+++ b/src/Futhark/CLI/PyOpenCL.hs
@@ -1,5 +1,5 @@
{-# LANGUAGE FlexibleContexts #-}
-module Main (main) where
+module Futhark.CLI.PyOpenCL (main) where
import Control.Monad.IO.Class
import System.FilePath
@@ -11,7 +11,7 @@ import qualified Futhark.CodeGen.Backends.PyOpenCL as PyOpenCL
import Futhark.Util.Pretty (prettyText)
import Futhark.Compiler.CLI
-main :: IO ()
+main :: String -> [String] -> IO ()
main = compilerMain () []
"Compile PyOpenCL" "Generate Python + OpenCL code from optimised Futhark program."
gpuPipeline $ \() mode outpath prog -> do
diff --git a/src/futhark-py.hs b/src/Futhark/CLI/Python.hs
index 727f628..7a25ae8 100644
--- a/src/futhark-py.hs
+++ b/src/Futhark/CLI/Python.hs
@@ -1,5 +1,5 @@
{-# LANGUAGE FlexibleContexts #-}
-module Main (main) where
+module Futhark.CLI.Python (main) where
import Control.Monad.IO.Class
import System.FilePath
@@ -11,7 +11,7 @@ import qualified Futhark.CodeGen.Backends.SequentialPython as SequentialPy
import Futhark.Util.Pretty (prettyText)
import Futhark.Compiler.CLI
-main :: IO ()
+main :: String -> [String] -> IO ()
main = compilerMain () []
"Compile sequential Python" "Generate sequential Python code from optimised Futhark program."
sequentialCpuPipeline $ \() mode outpath prog -> do
diff --git a/src/Futhark/CLI/REPL.hs b/src/Futhark/CLI/REPL.hs
new file mode 100644
index 0000000..18788ed
--- /dev/null
+++ b/src/Futhark/CLI/REPL.hs
@@ -0,0 +1,419 @@
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE QuasiQuotes #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+module Futhark.CLI.REPL (main) where
+
+import Control.Monad.Free.Church
+import Control.Exception
+import Data.Char
+import Data.List
+import Data.Loc
+import Data.Maybe
+import Data.Version
+import Control.Monad
+import Control.Monad.IO.Class
+import Control.Monad.State
+import Control.Monad.Except
+import qualified Data.Text as T
+import qualified Data.Text.IO as T
+import NeatInterpolation (text)
+import System.Directory
+import System.FilePath
+import System.Console.GetOpt
+import System.IO
+import qualified System.Console.Haskeline as Haskeline
+
+import Language.Futhark
+import Language.Futhark.Parser hiding (EOF)
+import qualified Language.Futhark.TypeChecker as T
+import qualified Language.Futhark.Semantic as T
+import Futhark.MonadFreshNames
+import Futhark.Version
+import Futhark.Compiler
+import Futhark.Pipeline
+import Futhark.Util.Options
+import Futhark.Util (toPOSIX, maybeHead)
+
+import qualified Language.Futhark.Interpreter as I
+
+banner :: String
+banner = unlines [
+ "|// |\\ | |\\ |\\ /",
+ "|/ | \\ |\\ |\\ |/ /",
+ "| | \\ |/ | |\\ \\",
+ "| | \\ | | | \\ \\"
+ ]
+
+main :: String -> [String] -> IO ()
+main = mainWithOptions interpreterConfig options "options..." run
+ where run [] _ = Just repl
+ run _ _ = Nothing
+
+data StopReason = EOF | Stop | Exit | Load FilePath
+
+repl :: IO ()
+repl = do
+ putStr banner
+ putStrLn $ "Version " ++ showVersion version ++ "."
+ putStrLn "Copyright (C) DIKU, University of Copenhagen, released under the ISC license."
+ putStrLn ""
+ putStrLn "Run :help for a list of commands."
+ putStrLn ""
+
+ let toploop s = do
+ (stop, s') <- runStateT (runExceptT $ runFutharkiM $ forever readEvalPrint) s
+ case stop of
+ Left Stop -> finish s'
+ Left EOF -> finish s'
+ Left Exit -> finish s'
+ Left (Load file) -> do
+ liftIO $ T.putStrLn $ "Loading " <> T.pack file
+ maybe_new_state <-
+ liftIO $ newFutharkiState (futharkiCount s) $ Just file
+ case maybe_new_state of
+ Right new_state -> toploop new_state
+ Left err -> do liftIO $ putStrLn err
+ toploop s'
+ Right _ -> return ()
+
+ finish s = do
+ quit <- confirmQuit
+ if quit then return () else toploop s
+
+ maybe_init_state <- liftIO $ newFutharkiState 0 Nothing
+ case maybe_init_state of
+ Left err -> error $ "Failed to initialise interpreter state: " ++ err
+ Right init_state -> Haskeline.runInputT Haskeline.defaultSettings $ toploop init_state
+
+ putStrLn "Leaving futharki."
+
+confirmQuit :: Haskeline.InputT IO Bool
+confirmQuit = do
+ c <- Haskeline.getInputChar "Quit futharki? (y/n) "
+ case c of
+ Nothing -> return True -- EOF
+ Just 'y' -> return True
+ Just 'n' -> return False
+ _ -> confirmQuit
+
+newtype InterpreterConfig = InterpreterConfig { interpreterEntryPoint :: Name }
+
+interpreterConfig :: InterpreterConfig
+interpreterConfig = InterpreterConfig defaultEntryPoint
+
+options :: [FunOptDescr InterpreterConfig]
+options = [ Option "e" ["entry-point"]
+ (ReqArg (\entry -> Right $ \config ->
+ config { interpreterEntryPoint = nameFromString entry })
+ "NAME")
+ "The entry point to execute."
+ ]
+
+data FutharkiState =
+ FutharkiState { futharkiImports :: Imports
+ , futharkiNameSource :: VNameSource
+ , futharkiCount :: Int
+ , futharkiEnv :: (T.Env, I.Ctx)
+ , futharkiBreaking :: Maybe Loc
+ -- ^ Are we currently stopped at a breakpoint?
+ , futharkiSkipBreaks :: [Loc]
+ -- ^ Skip breakpoints at these locations.
+ , futharkiLoaded :: Maybe FilePath
+ -- ^ The currently loaded file.
+ }
+
+newFutharkiState :: Int -> Maybe FilePath -> IO (Either String FutharkiState)
+newFutharkiState count maybe_file = runExceptT $ do
+ (imports, src, tenv, ienv) <- case maybe_file of
+
+ Nothing -> do
+ -- Load the builtins through the type checker.
+ (_, imports, src) <- badOnLeft =<< runExceptT (readLibrary [])
+ -- Then into the interpreter.
+ ienv <- foldM (\ctx -> badOnLeft <=< runInterpreter' . I.interpretImport ctx)
+ I.initialCtx $ map (fmap fileProg) imports
+
+ -- Then make the prelude available in the type checker.
+ (tenv, d, src') <- badOnLeft $ T.checkDec imports src T.initialEnv
+ (T.mkInitialImport ".") $ mkOpen "/futlib/prelude"
+ -- Then in the interpreter.
+ ienv' <- badOnLeft =<< runInterpreter' (I.interpretDec ienv d)
+ return (imports, src', tenv, ienv')
+
+ Just file -> do
+ (ws, imports, src) <-
+ badOnLeft =<< liftIO (runExceptT (readProgram file)
+ `Haskeline.catch` \(err::IOException) ->
+ return (Left (ExternalError (T.pack $ show err))))
+ liftIO $ hPrint stderr ws
+
+ let imp = T.mkInitialImport "."
+ ienv1 <- foldM (\ctx -> badOnLeft <=< runInterpreter' . I.interpretImport ctx) I.initialCtx $
+ map (fmap fileProg) imports
+ (tenv1, d1, src') <- badOnLeft $ T.checkDec imports src T.initialEnv imp $
+ mkOpen "/futlib/prelude"
+ (tenv2, d2, src'') <- badOnLeft $ T.checkDec imports src' tenv1 imp $
+ mkOpen $ toPOSIX $ dropExtension file
+ ienv2 <- badOnLeft =<< runInterpreter' (I.interpretDec ienv1 d1)
+ ienv3 <- badOnLeft =<< runInterpreter' (I.interpretDec ienv2 d2)
+ return (imports, src'', tenv2, ienv3)
+
+ return FutharkiState { futharkiImports = imports
+ , futharkiNameSource = src
+ , futharkiCount = count
+ , futharkiEnv = (tenv, ienv)
+ , futharkiBreaking = Nothing
+ , futharkiSkipBreaks = mempty
+ , futharkiLoaded = maybe_file
+ }
+ where badOnLeft :: Show err => Either err a -> ExceptT String IO a
+ badOnLeft (Right x) = return x
+ badOnLeft (Left err) = throwError $ show err
+
+getPrompt :: FutharkiM String
+getPrompt = do
+ i <- gets futharkiCount
+ return $ "[" ++ show i ++ "]> "
+
+mkOpen :: FilePath -> UncheckedDec
+mkOpen f = OpenDec (ModImport f NoInfo noLoc) noLoc
+
+-- The ExceptT part is more of a continuation, really.
+newtype FutharkiM a =
+ FutharkiM { runFutharkiM :: ExceptT StopReason (StateT FutharkiState (Haskeline.InputT IO)) a }
+ deriving (Functor, Applicative, Monad,
+ MonadState FutharkiState, MonadIO, MonadError StopReason)
+
+readEvalPrint :: FutharkiM ()
+readEvalPrint = do
+ prompt <- getPrompt
+ line <- inputLine prompt
+ breaking <- gets futharkiBreaking
+ case T.uncons line of
+ Nothing
+ | isJust breaking -> throwError Stop
+ | otherwise -> return ()
+
+ Just (':', command) -> do
+ let (cmdname, rest) = T.break isSpace command
+ arg = T.dropWhileEnd isSpace $ T.dropWhile isSpace rest
+ case filter ((cmdname `T.isPrefixOf`) . fst) commands of
+ [] -> liftIO $ T.putStrLn $ "Unknown command '" <> cmdname <> "'"
+ [(_, (cmdf, _))] -> cmdf arg
+ matches -> liftIO $ T.putStrLn $ "Ambiguous command; could be one of " <>
+ mconcat (intersperse ", " (map fst matches))
+
+ _ -> do
+ -- Read a declaration or expression.
+ maybe_dec_or_e <- parseDecOrExpIncrM (inputLine " ") prompt line
+
+ case maybe_dec_or_e of
+ Left err -> liftIO $ print err
+ Right (Left d) -> onDec d
+ Right (Right e) -> onExp e
+ modify $ \s -> s { futharkiCount = futharkiCount s + 1 }
+ where inputLine prompt = do
+ inp <- FutharkiM $ lift $ lift $ Haskeline.getInputLine prompt
+ case inp of
+ Just s -> return $ T.pack s
+ Nothing -> throwError EOF
+
+getIt :: FutharkiM (Imports, VNameSource, T.Env, I.Ctx)
+getIt = do
+ imports <- gets futharkiImports
+ src <- gets futharkiNameSource
+ (tenv, ienv) <- gets futharkiEnv
+ return (imports, src, tenv, ienv)
+
+onDec :: UncheckedDec -> FutharkiM ()
+onDec d = do
+ (imports, src, tenv, ienv) <- getIt
+ cur_import <- T.mkInitialImport . fromMaybe "." <$> gets futharkiLoaded
+
+ -- Most of the complexity here concerns the dealing with the fact
+ -- that 'import "foo"' is a declaration. We have to involve a lot
+ -- of machinery to load this external code before executing the
+ -- declaration itself.
+ let basis = Basis imports src ["/futlib/prelude"]
+ mkImport = uncurry $ T.mkImportFrom cur_import
+ imp_r <- runExceptT $ readImports basis (map mkImport $ decImports d)
+
+ case imp_r of
+ Left e -> liftIO $ print e
+ Right (_, imports', src') ->
+ case T.checkDec imports' src' tenv cur_import d of
+ Left e -> liftIO $ print e
+ Right (tenv', d', src'') -> do
+ let new_imports = filter ((`notElem` map fst imports) . fst) imports'
+ int_r <- runInterpreter $ do
+ let onImport ienv' (s, imp) =
+ I.interpretImport ienv' (s, T.fileProg imp)
+ ienv' <- foldM onImport ienv new_imports
+ I.interpretDec ienv' d'
+ case int_r of
+ Left err -> liftIO $ print err
+ Right ienv' -> modify $ \s -> s { futharkiEnv = (tenv', ienv')
+ , futharkiImports = imports'
+ , futharkiNameSource = src''
+ }
+
+onExp :: UncheckedExp -> FutharkiM ()
+onExp e = do
+ (imports, src, tenv, ienv) <- getIt
+ case showErr (T.checkExp imports src tenv e) of
+ Left err -> liftIO $ putStrLn err
+ Right (_, e') -> do
+ r <- runInterpreter $ I.interpretExp ienv e'
+ case r of
+ Left err -> liftIO $ print err
+ Right v -> liftIO $ putStrLn $ pretty v
+ where showErr :: Show a => Either a b -> Either String b
+ showErr = either (Left . show) Right
+
+runInterpreter :: F I.ExtOp a -> FutharkiM (Either I.InterpreterError a)
+runInterpreter m = runF m (return . Right) intOp
+ where
+ intOp (I.ExtOpError err) =
+ return $ Left err
+ intOp (I.ExtOpTrace w v c) = do
+ liftIO $ putStrLn $ "Trace at " ++ locStr w ++ ": " ++ v
+ c
+ intOp (I.ExtOpBreak w ctx tenv c) = do
+ s <- get
+
+ -- Are we supposed to skip this breakpoint?
+ let loc = maybe noLoc locOf $ maybeHead w
+
+ -- We do not want recursive breakpoints. It could work fine
+ -- technically, but is probably too confusing to be useful.
+ unless (isJust (futharkiBreaking s) || loc `elem` futharkiSkipBreaks s) $ do
+ liftIO $ putStrLn $ "Breaking at " ++ intercalate " -> " (map locStr w) ++ "."
+ liftIO $ putStrLn "<Enter> to continue."
+
+ -- Note the cleverness to preserve the Haskeline session (for
+ -- line history and such).
+ (stop, s') <-
+ FutharkiM $ lift $ lift $
+ runStateT (runExceptT $ runFutharkiM $ forever readEvalPrint)
+ s { futharkiEnv = (tenv, ctx)
+ , futharkiCount = futharkiCount s + 1
+ , futharkiBreaking = Just loc }
+
+ case stop of
+ Left (Load file) -> throwError $ Load file
+ _ -> do liftIO $ putStrLn "Continuing..."
+ put s { futharkiCount = futharkiCount s'
+ , futharkiSkipBreaks = futharkiSkipBreaks s' <> futharkiSkipBreaks s }
+
+ c
+
+runInterpreter' :: MonadIO m => F I.ExtOp a -> m (Either I.InterpreterError a)
+runInterpreter' m = runF m (return . Right) intOp
+ where intOp (I.ExtOpError err) = return $ Left err
+ intOp (I.ExtOpTrace w v c) = do
+ liftIO $ putStrLn $ "Trace at " ++ locStr w ++ ": " ++ v
+ c
+ intOp (I.ExtOpBreak _ _ _ c) = c
+
+type Command = T.Text -> FutharkiM ()
+
+loadCommand :: Command
+loadCommand file = do
+ loaded <- gets futharkiLoaded
+ case (T.null file, loaded) of
+ (True, Just loaded') -> throwError $ Load loaded'
+ (True, Nothing) -> liftIO $ T.putStrLn "No file specified and no file previously loaded."
+ (False, _) -> throwError $ Load $ T.unpack file
+
+genTypeCommand :: (Show err1, Show err2) =>
+ (String -> T.Text -> Either err1 a)
+ -> (Imports -> VNameSource -> T.Env -> a -> Either err2 b)
+ -> (b -> String)
+ -> Command
+genTypeCommand f g h e = do
+ prompt <- getPrompt
+ case f prompt e of
+ Left err -> liftIO $ print err
+ Right e' -> do
+ imports <- gets futharkiImports
+ src <- gets futharkiNameSource
+ (tenv, _) <- gets futharkiEnv
+ case g imports src tenv e' of
+ Left err -> liftIO $ print err
+ Right x -> liftIO $ putStrLn $ h x
+
+typeCommand :: Command
+typeCommand = genTypeCommand parseExp T.checkExp $ \(ps, e) ->
+ pretty e <> concatMap ((" "<>) . pretty) ps <>
+ " : " <> pretty (typeOf e)
+
+mtypeCommand :: Command
+mtypeCommand = genTypeCommand parseModExp T.checkModExp $ pretty . fst
+
+unbreakCommand :: Command
+unbreakCommand _ = do
+ breaking <- gets futharkiBreaking
+ case breaking of
+ Nothing -> liftIO $ putStrLn "Not currently stopped at a breakpoint."
+ Just loc -> do modify $ \s -> s { futharkiSkipBreaks = loc : futharkiSkipBreaks s }
+ throwError Stop
+
+pwdCommand :: Command
+pwdCommand _ = liftIO $ putStrLn =<< getCurrentDirectory
+
+cdCommand :: Command
+cdCommand dir
+ | T.null dir = liftIO $ putStrLn "Usage: ':cd <dir>'."
+ | otherwise =
+ liftIO $ setCurrentDirectory (T.unpack dir)
+ `Haskeline.catch` \(err::IOException) -> print err
+
+helpCommand :: Command
+helpCommand _ = liftIO $ forM_ commands $ \(cmd, (_, desc)) -> do
+ T.putStrLn $ ":" <> cmd
+ T.putStrLn $ T.replicate (1+T.length cmd) "-"
+ T.putStr desc
+ T.putStrLn ""
+ T.putStrLn ""
+
+quitCommand :: Command
+quitCommand _ = throwError Exit
+
+commands :: [(T.Text, (Command, T.Text))]
+commands = [("load", (loadCommand, [text|
+Load a Futhark source file. Usage:
+
+ > :load foo.fut
+
+If the loading succeeds, any subsequently entered expressions will have
+access to the definitions (such as function definitions) in the source
+file.
+
+Only one source file can be loaded at a time. Using the :load command a
+second time will replace the previously loaded file. It will also replace
+any declarations entered at the REPL.
+
+|])),
+ ("type", (typeCommand, [text|
+Show the type of an expression, which must fit on a single line.
+|])),
+ ("mtype", (mtypeCommand, [text|
+Show the type of a module expression, which must fit on a single line.
+|])),
+ ("unbreak", (unbreakCommand, [text|
+Skip all future occurrences of the current breakpoint.
+|])),
+ ("pwd", (pwdCommand, [text|
+Print the current working directory.
+|])),
+ ("cd", (cdCommand, [text|
+Change the current working directory.
+|])),
+ ("help", (helpCommand, [text|
+Print a list of commands and a description of their behaviour.
+|])),
+ ("quit", (quitCommand, [text|
+Quit futharki.
+|]))]
diff --git a/src/Futhark/CLI/Run.hs b/src/Futhark/CLI/Run.hs
new file mode 100644
index 0000000..1674307
--- /dev/null
+++ b/src/Futhark/CLI/Run.hs
@@ -0,0 +1,143 @@
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+module Futhark.CLI.Run (main) where
+
+import Control.Monad.Free.Church
+import Control.Exception
+import Data.Array
+import Data.List
+import Data.Loc
+import Data.Maybe
+import qualified Data.Map as M
+import Control.Monad
+import Control.Monad.IO.Class
+import Control.Monad.Except
+import qualified Data.Text as T
+import qualified Data.Text.IO as T
+import System.FilePath
+import System.Exit
+import System.Console.GetOpt
+import System.IO
+import qualified System.Console.Haskeline as Haskeline
+
+import Prelude
+
+import Language.Futhark
+import Language.Futhark.Parser hiding (EOF)
+import qualified Language.Futhark.TypeChecker as T
+import qualified Language.Futhark.Semantic as T
+import Futhark.Compiler
+import Futhark.Pipeline
+import Futhark.Util.Options
+import Futhark.Util (toPOSIX)
+
+import qualified Language.Futhark.Interpreter as I
+
+main :: String -> [String] -> IO ()
+main = mainWithOptions interpreterConfig options "options... program" run
+ where run [prog] config = Just $ interpret config prog
+ run _ _ = Nothing
+
+interpret :: InterpreterConfig -> FilePath -> IO ()
+interpret config fp = do
+ pr <- newFutharkiState config fp
+ (tenv, ienv) <- case pr of Left err -> do hPutStrLn stderr err
+ exitFailure
+ Right env -> return env
+
+ let entry = interpreterEntryPoint config
+ vr <- parseValues "stdin" <$> T.getContents
+
+ inps <-
+ case vr of
+ Left err -> do
+ hPutStrLn stderr $ "Error when reading input: " ++ show err
+ exitFailure
+ Right vs
+ | Just vs' <- mapM convertValue vs ->
+ return vs'
+ | otherwise -> do
+ hPutStrLn stderr "Error when reading input: irregular array."
+ exitFailure
+
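+  -- Look up the entry point in the type environment to learn its
+  -- return type, which we need in order to print the result (in
+  -- particular empty arrays) correctly.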
+ (fname, ret) <-
+ case M.lookup (T.Term, entry) $ T.envNameMap tenv of
+ Just fname
+ | Just (T.BoundV _ t) <- M.lookup (qualLeaf fname) $ T.envVtable tenv ->
+ return (fname, toStructural $ snd $ unfoldFunType t)
+ _ -> do hPutStrLn stderr $ "Invalid entry point: " ++ pretty entry
+ exitFailure
+
+ r <- runInterpreter' $ I.interpretFunction ienv (qualLeaf fname) inps
+ case r of
+ Left err -> do hPrint stderr err
+ exitFailure
+ Right res ->
+ case (I.fromTuple res, isTupleRecord ret) of
+ (Just vs, Just ts) -> zipWithM_ putValue vs ts
+ _ -> putValue res ret
+
+putValue :: I.Value -> TypeBase () () -> IO ()
+putValue v t
+ | I.isEmptyArray v =
+ putStrLn $ "empty(" ++ pretty (stripArray 1 t) ++ ")"
+ | otherwise = putStrLn $ pretty v
+
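+-- | Convert a parsed value to the interpreter's representation.
+-- Produces Nothing for irregular arrays.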
+convertValue :: Value -> Maybe I.Value
+convertValue (PrimValue p) = Just $ I.ValuePrim p
+convertValue (ArrayValue arr _) = I.mkArray =<< mapM convertValue (elems arr)
+
+data InterpreterConfig =
+ InterpreterConfig { interpreterEntryPoint :: Name
+ , interpreterPrintWarnings :: Bool
+ }
+
+interpreterConfig :: InterpreterConfig
+interpreterConfig = InterpreterConfig defaultEntryPoint True
+
+options :: [FunOptDescr InterpreterConfig]
+options = [ Option "e" ["entry-point"]
+ (ReqArg (\entry -> Right $ \config ->
+ config { interpreterEntryPoint = nameFromString entry })
+ "NAME")
+ "The entry point to execute."
+ , Option "w" ["no-warnings"]
+ (NoArg $ Right $ \config -> config { interpreterPrintWarnings = False })
+ "Do not print warnings."
+ ]
+
+newFutharkiState :: InterpreterConfig -> FilePath
+ -> IO (Either String (T.Env, I.Ctx))
+newFutharkiState cfg file = runExceptT $ do
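+  -- Read and type-check the program, interpret its imports, and then
+  -- bring first the prelude and then the program itself into scope by
+  -- interpreting synthetic 'open' declarations.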
+ (ws, imports, src) <-
+ badOnLeft =<< liftIO (runExceptT (readProgram file)
+ `Haskeline.catch` \(err::IOException) ->
+ return (Left (ExternalError (T.pack $ show err))))
+ when (interpreterPrintWarnings cfg) $
+ liftIO $ hPrint stderr ws
+
+ let imp = T.mkInitialImport "."
+ ienv1 <- foldM (\ctx -> badOnLeft <=< runInterpreter' . I.interpretImport ctx) I.initialCtx $
+ map (fmap fileProg) imports
+ (tenv1, d1, src') <- badOnLeft $ T.checkDec imports src T.initialEnv imp $
+ mkOpen "/futlib/prelude"
+ (tenv2, d2, _) <- badOnLeft $ T.checkDec imports src' tenv1 imp $
+ mkOpen $ toPOSIX $ dropExtension file
+ ienv2 <- badOnLeft =<< runInterpreter' (I.interpretDec ienv1 d1)
+ ienv3 <- badOnLeft =<< runInterpreter' (I.interpretDec ienv2 d2)
+ return (tenv2, ienv3)
+ where badOnLeft :: Show err => Either err a -> ExceptT String IO a
+ badOnLeft (Right x) = return x
+ badOnLeft (Left err) = throwError $ show err
+
+mkOpen :: FilePath -> UncheckedDec
+mkOpen f = OpenDec (ModImport f NoInfo noLoc) noLoc
+
+runInterpreter' :: MonadIO m => F I.ExtOp a -> m (Either I.InterpreterError a)
+runInterpreter' m = runF m (return . Right) intOp
+ where intOp (I.ExtOpError err) = return $ Left err
+ intOp (I.ExtOpTrace w v c) = do
+ liftIO $ putStrLn $ "Trace at " ++ locStr w ++ ": " ++ v
+ c
+ intOp (I.ExtOpBreak _ _ _ c) = c
diff --git a/src/futhark-test.hs b/src/Futhark/CLI/Test.hs
index 9444221..b5883d4 100644
--- a/src/futhark-test.hs
+++ b/src/Futhark/CLI/Test.hs
@@ -1,7 +1,7 @@
{-# LANGUAGE OverloadedStrings, FlexibleContexts, LambdaCase #-}
-- | This program is a convenience utility for running the Futhark
-- test suite, and its test programs.
-module Main (main) where
+module Futhark.CLI.Test (main) where
import Control.Applicative.Lift (runErrors, failure, Errors, Lift(..))
import Control.Concurrent
@@ -13,7 +13,6 @@ import qualified Data.ByteString as SBS
import qualified Data.ByteString.Lazy as LBS
import Data.List
-import Data.Semigroup ((<>))
import qualified Data.Map.Strict as M
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
@@ -99,13 +98,13 @@ optimisedProgramMetrics programs pipeline program =
check "--gpu"
where check opt = do
(code, output, err) <-
- io $ readProcessWithExitCode (configTypeChecker programs) [opt, "--metrics", program] ""
+ io $ readProcessWithExitCode (configFuthark programs) ["dev", opt, "--metrics", program] ""
let output' = T.decodeUtf8 output
case code of
ExitSuccess
| [(m, [])] <- reads $ T.unpack output' -> return m
| otherwise -> throwError $ "Could not read metrics output:\n" <> output'
- ExitFailure 127 -> throwError $ progNotFound $ T.pack $ configTypeChecker programs
+ ExitFailure 127 -> throwError $ progNotFound $ T.pack $ configFuthark programs
ExitFailure _ -> throwError $ T.decodeUtf8 err
testMetrics :: ProgConfig -> FilePath -> StructureTest -> TestM ()
@@ -137,59 +136,62 @@ runTestCase :: TestCase -> TestM ()
runTestCase (TestCase mode program testcase progs) =
case testAction testcase of
- CompileTimeFailure expected_error -> do
- let typeChecker = configTypeChecker progs
- context ("Type-checking with " <> T.pack typeChecker) $ do
+ CompileTimeFailure expected_error ->
+ context (mconcat ["Type-checking with '", T.pack futhark,
+ " check ", T.pack program, "'"]) $ do
(code, _, err) <-
- io $ readProcessWithExitCode typeChecker ["-t", program] ""
+ io $ readProcessWithExitCode futhark ["check", program] ""
case code of
ExitSuccess -> throwError "Expected failure\n"
- ExitFailure 127 -> throwError $ progNotFound $ T.pack typeChecker
+ ExitFailure 127 -> throwError $ progNotFound $ T.pack futhark
ExitFailure 1 -> throwError $ T.decodeUtf8 err
ExitFailure _ -> checkError expected_error err
RunCases _ _ warnings | mode == TypeCheck -> do
- let typeChecker = configTypeChecker progs
- options = ["-t", program] ++ configExtraCompilerOptions progs
- context ("Type-checking with " <> T.pack typeChecker) $ do
- (code, _, err) <- io $ readProcessWithExitCode typeChecker options ""
+ let options = ["check", program] ++ configExtraCompilerOptions progs
+ context (mconcat ["Type-checking with '", T.pack futhark,
+ " check ", T.pack program, "'"]) $ do
+ (code, _, err) <- io $ readProcessWithExitCode futhark options ""
testWarnings warnings err
case code of
ExitSuccess -> return ()
- ExitFailure 127 -> throwError $ progNotFound $ T.pack typeChecker
+ ExitFailure 127 -> throwError $ progNotFound $ T.pack futhark
ExitFailure _ -> throwError $ T.decodeUtf8 err
RunCases ios structures warnings -> do
-- Compile up-front and reuse same executable for several entry points.
- let compiler = configCompiler progs
- interpreter = configInterpreter progs
+ let backend = configBackend progs
extra_options = configExtraCompilerOptions progs
+ unless (mode == Compile) $
+ context "Generating reference outputs" $
+ ensureReferenceOutput futhark "c" program ios
unless (mode == Interpreted) $
- context ("Compiling with " <> T.pack compiler) $ do
- compileTestProgram extra_options compiler program warnings
+ context ("Compiling with --backend=" <> T.pack backend) $ do
+ compileTestProgram extra_options futhark backend program warnings
mapM_ (testMetrics progs program) structures
unless (mode == Compile) $
context "Running compiled program" $
accErrors_ $ map (runCompiledEntry program progs) ios
unless (mode == Compile || mode == Compiled) $
- context ("Interpreting with " <> T.pack interpreter) $
- accErrors_ $ map (runInterpretedEntry interpreter program) ios
+ context "Interpreting" $
+ accErrors_ $ map (runInterpretedEntry futhark program) ios
+ where futhark = configFuthark progs
runInterpretedEntry :: String -> FilePath -> InputOutputs -> TestM()
-runInterpretedEntry futharki program (InputOutputs entry run_cases) =
+runInterpretedEntry futhark program (InputOutputs entry run_cases) =
let dir = takeDirectory program
- runInterpretedCase run@(TestRun _ inputValues expectedResult index _) =
+ runInterpretedCase run@(TestRun _ inputValues _ index _) =
unless ("compiled" `elem` runTags run) $
context ("Entry point: " <> entry
<> "; dataset: " <> T.pack (runDescription run)) $ do
input <- T.unlines . map prettyText <$> getValues dir inputValues
- expectedResult' <- getExpectedResult dir expectedResult
+ expectedResult' <- getExpectedResult program entry run
(code, output, err) <-
- io $ readProcessWithExitCode futharki ["-e", T.unpack entry, program] $
+ io $ readProcessWithExitCode futhark ["run", "-e", T.unpack entry, program] $
T.encodeUtf8 input
case code of
- ExitFailure 127 -> throwError $ progNotFound $ T.pack futharki
+ ExitFailure 127 -> throwError $ progNotFound $ T.pack futhark
_ -> compareResult entry index program expectedResult'
=<< runResult program code output err
@@ -198,33 +200,27 @@ runInterpretedEntry futharki program (InputOutputs entry run_cases) =
runCompiledEntry :: FilePath -> ProgConfig -> InputOutputs -> TestM ()
runCompiledEntry program progs (InputOutputs entry run_cases) =
- -- Explicitly prefixing the current directory is necessary for
- -- readProcessWithExitCode to find the binary when binOutputf has
- -- no path component.
+ -- Explicitly prefixing the current directory is necessary for
+ -- readProcessWithExitCode to find the binary when binOutputf has
+ -- no path component.
let binOutputf = dropExtension program
- dir = takeDirectory program
binpath = "." </> binOutputf
entry_options = ["-e", T.unpack entry]
runner = configRunner progs
extra_options = configExtraOptions progs
- (to_run, to_run_args)
- | null runner = (binpath, entry_options ++ extra_options)
- | otherwise = (runner, binpath : entry_options ++ extra_options)
- runCompiledCase run@(TestRun _ inputValues expectedResult index _) =
+ runCompiledCase run@(TestRun _ inputValues _ index _) =
context ("Entry point: " <> entry
<> "; dataset: " <> T.pack (runDescription run)) $ do
-
- input <- getValuesBS dir inputValues
- expectedResult' <- getExpectedResult dir expectedResult
+ expected <- getExpectedResult program entry run
(progCode, output, progerr) <-
- io $ readProcessWithExitCode to_run to_run_args $ LBS.toStrict input
- compareResult entry index program expectedResult'
+ runProgram runner extra_options program entry inputValues
+ compareResult entry index program expected
=<< runResult program progCode output progerr
in context ("Running " <> T.pack (unwords $ binpath : entry_options ++ extra_options)) $
- accErrors_ $ map runCompiledCase run_cases
+ accErrors_ $ map runCompiledCase run_cases
checkError :: ExpectedError -> SBS.ByteString -> TestM ()
checkError (ThisError regex_s regex) err
@@ -245,41 +241,26 @@ runResult program ExitSuccess stdout_s _ =
runResult _ (ExitFailure code) _ stderr_s =
return $ ErrorResult code stderr_s
-getExpectedResult :: MonadIO m =>
- FilePath -> ExpectedResult Values
- -> m (ExpectedResult [Value])
-getExpectedResult dir (Succeeds (Just vals)) = Succeeds . Just <$> getValues dir vals
-getExpectedResult _ (Succeeds Nothing) = return $ Succeeds Nothing
-getExpectedResult _ (RunTimeFailure err) = return $ RunTimeFailure err
-
-compileTestProgram :: [String] -> String -> FilePath -> [WarningTest] -> TestM ()
-compileTestProgram extra_options futharkc program warnings = do
- (futcode, _, futerr) <- io $ readProcessWithExitCode futharkc options ""
+compileTestProgram :: [String] -> FilePath -> String -> FilePath -> [WarningTest] -> TestM ()
+compileTestProgram extra_options futhark backend program warnings = do
+ (_, futerr) <- compileProgram extra_options futhark backend program
testWarnings warnings futerr
- case futcode of
- ExitFailure 127 -> throwError $ progNotFound $ T.pack futharkc
- ExitFailure _ -> throwError $ T.decodeUtf8 futerr
- ExitSuccess -> return ()
- where binOutputf = dropExtension program
- options = [program, "-o", binOutputf] ++ extra_options
compareResult :: T.Text -> Int -> FilePath -> ExpectedResult [Value] -> RunResult
-> TestM ()
compareResult _ _ _ (Succeeds Nothing) SuccessResult{} =
return ()
compareResult entry index program (Succeeds (Just expectedResult)) (SuccessResult actualResult) =
- case compareValues actualResult expectedResult of
- Just mismatches ->
- let reportMismatch mismatch = do
- let actualf = program <.> T.unpack entry <.> show index <.> "actual"
- expectedf = program <.> T.unpack entry <.> show index <.> "expected"
- io $ SBS.writeFile actualf $
- T.encodeUtf8 $ T.unlines $ map prettyText actualResult
- io $ SBS.writeFile expectedf $
- T.encodeUtf8 $ T.unlines $ map prettyText expectedResult
- throwError $ T.pack actualf <> " and " <> T.pack expectedf <>
- " do not match:\n" <> T.pack (show mismatch) <> "\n"
- in mapM_ reportMismatch mismatches
+ case compareValues1 actualResult expectedResult of
+ Just mismatch -> do
+ let actualf = program <.> T.unpack entry <.> show index <.> "actual"
+ expectedf = program <.> T.unpack entry <.> show index <.> "expected"
+ io $ SBS.writeFile actualf $
+ T.encodeUtf8 $ T.unlines $ map prettyText actualResult
+ io $ SBS.writeFile expectedf $
+ T.encodeUtf8 $ T.unlines $ map prettyText expectedResult
+ throwError $ T.pack actualf <> " and " <> T.pack expectedf <>
+ " do not match:\n" <> T.pack (show mismatch) <> "\n"
Nothing ->
return ()
compareResult _ _ _ (RunTimeFailure expectedError) (ErrorResult _ actualError) =
@@ -435,7 +416,7 @@ runTests config paths = do
Failure s -> do
when isTTY moveCursorToTableTop
clear
- T.putStrLn $ (T.pack (inRed $ testCaseProgram test) <> ":\n") <> T.concat s
+ T.putStr $ (T.pack (inRed $ testCaseProgram test) <> ":\n") <> T.unlines s
when isTTY spaceTable
getResults $ ts' { testStatusFail = testStatusFail ts' + 1
, testStatusRunPass = testStatusRunPass ts'
@@ -487,9 +468,8 @@ defaultConfig = TestConfig { configTestMode = Everything
, configExclude = [ "disable" ]
, configPrograms =
ProgConfig
- { configCompiler = "futhark-c"
- , configInterpreter = "futharki"
- , configTypeChecker = "futhark"
+ { configBackend = "c"
+ , configFuthark = "futhark"
, configRunner = ""
, configExtraOptions = []
, configExtraCompilerOptions = []
@@ -498,9 +478,8 @@ defaultConfig = TestConfig { configTestMode = Everything
}
data ProgConfig = ProgConfig
- { configCompiler :: FilePath
- , configInterpreter :: FilePath
- , configTypeChecker :: FilePath
+ { configBackend :: String
+ , configFuthark :: FilePath
, configRunner :: FilePath
, configExtraCompilerOptions :: [String]
, configExtraOptions :: [String]
@@ -511,17 +490,13 @@ data ProgConfig = ProgConfig
changeProgConfig :: (ProgConfig -> ProgConfig) -> TestConfig -> TestConfig
changeProgConfig f config = config { configPrograms = f $ configPrograms config }
-setCompiler :: FilePath -> ProgConfig -> ProgConfig
-setCompiler compiler config =
- config { configCompiler = compiler }
-
-setInterpreter :: FilePath -> ProgConfig -> ProgConfig
-setInterpreter interpreter config =
- config { configInterpreter = interpreter }
+setBackend :: FilePath -> ProgConfig -> ProgConfig
+setBackend backend config =
+ config { configBackend = backend }
-setTypeChecker :: FilePath -> ProgConfig -> ProgConfig
-setTypeChecker typeChecker config =
- config { configTypeChecker = typeChecker }
+setFuthark :: FilePath -> ProgConfig -> ProgConfig
+setFuthark futhark config =
+ config { configFuthark = futhark }
setRunner :: FilePath -> ProgConfig -> ProgConfig
setRunner runner config =
@@ -559,15 +534,12 @@ commandLineOptions = [
, Option [] ["no-terminal", "notty"]
(NoArg $ Right $ \config -> config { configLineOutput = True })
"Provide simpler line-based output."
- , Option [] ["typechecker"]
- (ReqArg (Right . changeProgConfig . setTypeChecker) "PROGRAM")
- "What to run for type-checking (defaults to 'futhark')."
- , Option [] ["compiler"]
- (ReqArg (Right . changeProgConfig . setCompiler) "PROGRAM")
- "What to run for code generation (defaults to 'futhark-c')."
- , Option [] ["interpreter"]
- (ReqArg (Right . changeProgConfig . setInterpreter) "PROGRAM")
- "What to run for interpretation (defaults to 'futharki')."
+ , Option [] ["backend"]
+ (ReqArg (Right . changeProgConfig . setBackend) "BACKEND")
+ "Backend used for compilation (defaults to 'c')."
+ , Option [] ["futhark"]
+ (ReqArg (Right . changeProgConfig . setFuthark) "PROGRAM")
+ "Program to run for subcommands (defaults to 'futhark')."
, Option [] ["runner"]
(ReqArg (Right . changeProgConfig . setRunner) "PROGRAM")
"The program used to run the Futhark-generated programs (defaults to nothing)."
@@ -585,6 +557,6 @@ commandLineOptions = [
"Pass this option to the compiler (or typechecker if in -t mode)."
]
-main :: IO ()
+main :: String -> [String] -> IO ()
main = mainWithOptions defaultConfig commandLineOptions "options... programs..." $ \progs config ->
Just $ runTests config progs
diff --git a/src/Futhark/CodeGen/Backends/CCUDA.hs b/src/Futhark/CodeGen/Backends/CCUDA.hs
new file mode 100644
index 0000000..0d42139
--- /dev/null
+++ b/src/Futhark/CodeGen/Backends/CCUDA.hs
@@ -0,0 +1,277 @@
+{-# LANGUAGE QuasiQuotes #-}
+{-# LANGUAGE TupleSections #-}
+module Futhark.CodeGen.Backends.CCUDA
+ ( compileProg
+ , GC.CParts(..)
+ , GC.asLibrary
+ , GC.asExecutable
+ ) where
+
+import qualified Language.C.Quote.OpenCL as C
+import Data.List
+
+import qualified Futhark.CodeGen.Backends.GenericC as GC
+import qualified Futhark.CodeGen.ImpGen.CUDA as ImpGen
+import Futhark.Error
+import Futhark.Representation.ExplicitMemory hiding (GetSize, CmpSizeLe, GetSizeMax)
+import Futhark.MonadFreshNames
+import Futhark.CodeGen.ImpCode.OpenCL
+import Futhark.CodeGen.Backends.CCUDA.Boilerplate
+import Futhark.CodeGen.Backends.GenericC.Options
+
+import Data.Maybe (catMaybes)
+
+compileProg :: MonadFreshNames m => Prog ExplicitMemory -> m (Either InternalError GC.CParts)
+compileProg prog = do
+ res <- ImpGen.compileProg prog
+ case res of
+ Left err -> return $ Left err
+ Right (Program cuda_code cuda_prelude kernel_names _ sizes prog') ->
+ let extra = generateBoilerplate cuda_code cuda_prelude
+ kernel_names sizes
+ in Right <$> GC.compileProg operations extra cuda_includes
+ [Space "device", Space "local", DefaultSpace] cliOptions prog'
+ where
+ operations :: GC.Operations OpenCL ()
+ operations = GC.Operations
+ { GC.opsWriteScalar = writeCUDAScalar
+ , GC.opsReadScalar = readCUDAScalar
+ , GC.opsAllocate = allocateCUDABuffer
+ , GC.opsDeallocate = deallocateCUDABuffer
+ , GC.opsCopy = copyCUDAMemory
+ , GC.opsStaticArray = staticCUDAArray
+ , GC.opsMemoryType = cudaMemoryType
+ , GC.opsCompiler = callKernel
+ , GC.opsFatMemory = True
+ }
+ cuda_includes = unlines [ "#include <cuda.h>"
+ , "#include <nvrtc.h>"
+ ]
+
+cliOptions :: [Option]
+cliOptions = [ Option { optionLongName = "dump-cuda"
+ , optionShortName = Nothing
+ , optionArgument = RequiredArgument
+ , optionAction = [C.cstm|futhark_context_config_dump_program_to(cfg, optarg);|]
+ }
+ , Option { optionLongName = "load-cuda"
+ , optionShortName = Nothing
+ , optionArgument = RequiredArgument
+ , optionAction = [C.cstm|futhark_context_config_load_program_from(cfg, optarg);|]
+ }
+ , Option { optionLongName = "dump-ptx"
+ , optionShortName = Nothing
+ , optionArgument = RequiredArgument
+ , optionAction = [C.cstm|futhark_context_config_dump_ptx_to(cfg, optarg);|]
+ }
+ , Option { optionLongName = "load-ptx"
+ , optionShortName = Nothing
+ , optionArgument = RequiredArgument
+ , optionAction = [C.cstm|futhark_context_config_load_ptx_from(cfg, optarg);|]
+ }
+ , Option { optionLongName = "print-sizes"
+ , optionShortName = Nothing
+ , optionArgument = NoArgument
+ , optionAction = [C.cstm|{
+ int n = futhark_get_num_sizes();
+ for (int i = 0; i < n; i++) {
+ printf("%s (%s)\n", futhark_get_size_name(i),
+ futhark_get_size_class(i));
+ }
+ exit(0);
+ }|]
+ }
+ ]
+
+writeCUDAScalar :: GC.WriteScalar OpenCL ()
+writeCUDAScalar mem idx t "device" _ val = do
+ val' <- newVName "write_tmp"
+ GC.stm [C.cstm|{$ty:t $id:val' = $exp:val;
+ CUDA_SUCCEED(
+ cuMemcpyHtoD($exp:mem + $exp:idx,
+ &$id:val',
+ sizeof($ty:t)));
+ }|]
+writeCUDAScalar _ _ _ space _ _ =
+ fail $ "Cannot write to '" ++ space ++ "' memory space."
+
+readCUDAScalar :: GC.ReadScalar OpenCL ()
+readCUDAScalar mem idx t "device" _ = do
+ val <- newVName "read_res"
+ GC.decl [C.cdecl|$ty:t $id:val;|]
+ GC.stm [C.cstm|CUDA_SUCCEED(
+ cuMemcpyDtoH(&$id:val,
+ $exp:mem + $exp:idx,
+ sizeof($ty:t)));
+ |]
+ return [C.cexp|$id:val|]
+readCUDAScalar _ _ _ space _ =
+  fail $ "Cannot read from '" ++ space ++ "' memory space."
+
+allocateCUDABuffer :: GC.Allocate OpenCL ()
+allocateCUDABuffer mem size tag "device" =
+ GC.stm [C.cstm|CUDA_SUCCEED(cuda_alloc(&ctx->cuda, $exp:size, $exp:tag, &$exp:mem));|]
+allocateCUDABuffer _ _ _ "local" = return ()
+allocateCUDABuffer _ _ _ space =
+ fail $ "Cannot allocate in '" ++ space ++ "' memory space."
+
+deallocateCUDABuffer :: GC.Deallocate OpenCL ()
+deallocateCUDABuffer mem tag "device" =
+ GC.stm [C.cstm|CUDA_SUCCEED(cuda_free(&ctx->cuda, $exp:mem, $exp:tag));|]
+deallocateCUDABuffer _ _ "local" = return ()
+deallocateCUDABuffer _ _ space =
+ fail $ "Cannot deallocate in '" ++ space ++ "' memory space."
+
+copyCUDAMemory :: GC.Copy OpenCL ()
+copyCUDAMemory dstmem dstidx dstSpace srcmem srcidx srcSpace nbytes = do
+ fn <- memcpyFun dstSpace srcSpace
+ GC.stm [C.cstm|CUDA_SUCCEED(
+ $id:fn($exp:dstmem + $exp:dstidx,
+ $exp:srcmem + $exp:srcidx,
+ $exp:nbytes));
+ |]
+ where
+ memcpyFun DefaultSpace (Space "device") = return "cuMemcpyDtoH"
+ memcpyFun (Space "device") DefaultSpace = return "cuMemcpyHtoD"
+ memcpyFun (Space "device") (Space "device") = return "cuMemcpy"
+ memcpyFun _ _ = fail $ "Cannot copy to '" ++ show dstSpace
+ ++ "' from '" ++ show srcSpace ++ "'."
+
+staticCUDAArray :: GC.StaticArray OpenCL ()
+staticCUDAArray name "device" t vals = do
+ let ct = GC.primTypeToCType t
+ vals' = [[C.cinit|$exp:v|] | v <- map GC.compilePrimValue vals]
+ num_elems = length vals
+ name_realtype <- newVName $ baseString name ++ "_realtype"
+ GC.libDecl [C.cedecl|static $ty:ct $id:name_realtype[$int:num_elems] = {$inits:vals'};|]
+ -- Fake a memory block.
+ GC.contextField (pretty name) [C.cty|struct memblock_device|] Nothing
+ -- During startup, copy the data to where we need it.
+ GC.atInit [C.cstm|{
+ ctx->$id:name.references = NULL;
+ ctx->$id:name.size = 0;
+ CUDA_SUCCEED(cuMemAlloc(&ctx->$id:name.mem,
+ ($int:num_elems > 0 ? $int:num_elems : 1)*sizeof($ty:ct)));
+ if ($int:num_elems > 0) {
+ CUDA_SUCCEED(cuMemcpyHtoD(ctx->$id:name.mem, $id:name_realtype,
+ $int:num_elems*sizeof($ty:ct)));
+ }
+ }|]
+ GC.item [C.citem|struct memblock_device $id:name = ctx->$id:name;|]
+staticCUDAArray _ space _ _ =
+  fail $ "CUDA backend cannot create static array in '" ++ space
+         ++ "' memory space."
+
+cudaMemoryType :: GC.MemoryType OpenCL ()
+cudaMemoryType "device" = return [C.cty|typename CUdeviceptr|]
+cudaMemoryType "local" = pure [C.cty|unsigned char|] -- dummy type
+cudaMemoryType space =
+ fail $ "CUDA backend does not support '" ++ space ++ "' memory space."
+
+callKernel :: GC.OpCompiler OpenCL ()
+callKernel (HostCode c) = GC.compileCode c
+callKernel (GetSize v key) =
+ GC.stm [C.cstm|$id:v = ctx->sizes.$id:key;|]
+callKernel (CmpSizeLe v key x) = do
+ x' <- GC.compileExp x
+ GC.stm [C.cstm|$id:v = ctx->sizes.$id:key <= $exp:x';|]
+callKernel (GetSizeMax v size_class) =
+ let field = "max_" ++ cudaSizeClass size_class
+ in GC.stm [C.cstm|$id:v = ctx->cuda.$id:field;|]
+ where
+ cudaSizeClass (SizeThreshold _) = "threshold"
+ cudaSizeClass SizeGroup = "block_size"
+ cudaSizeClass SizeNumGroups = "grid_size"
+ cudaSizeClass SizeTile = "tile_size"
+callKernel (LaunchKernel name args num_blocks block_size) = do
+ args_arr <- newVName "kernel_args"
+ time_start <- newVName "time_start"
+ time_end <- newVName "time_end"
+ (args', shared_vars) <- unzip <$> mapM mkArgs args
+ let (shared_sizes, shared_offsets) = unzip $ catMaybes shared_vars
+ shared_offsets_sc = mkOffsets shared_sizes
+ shared_args = zip shared_offsets shared_offsets_sc
+ shared_tot = last shared_offsets_sc
+ mapM_ (\(arg,offset) ->
+ GC.decl [C.cdecl|unsigned int $id:arg = $exp:offset;|]
+ ) shared_args
+
+ (grid_x, grid_y, grid_z) <- mkDims <$> mapM GC.compileExp num_blocks
+ (block_x, block_y, block_z) <- mkDims <$> mapM GC.compileExp block_size
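+  -- CUDA restricts the y and z grid dimensions much more than the x
+  -- dimension, so if either is too large we permute it into the x
+  -- position; for three-dimensional launches the permutation is also
+  -- passed to the kernel as extra arguments.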
+ let perm_args
+ | length num_blocks == 3 = [ [C.cinit|&perm[0]|], [C.cinit|&perm[1]|], [C.cinit|&perm[2]|] ]
+ | otherwise = []
+ let args'' = perm_args ++ [ [C.cinit|&$id:a|] | a <- args' ]
+ sizes_nonzero = expsNotZero [grid_x, grid_y, grid_z,
+ block_x, block_y, block_z]
+ GC.stm [C.cstm|
+ if ($exp:sizes_nonzero) {
+ int perm[3] = { 0, 1, 2 };
+
+ if ($exp:grid_y > (1<<16)) {
+ perm[1] = perm[0];
+ perm[0] = 1;
+ }
+
+ if ($exp:grid_z > (1<<16)) {
+ perm[2] = perm[0];
+ perm[0] = 2;
+ }
+
+ size_t grid[3];
+ grid[perm[0]] = $exp:grid_x;
+ grid[perm[1]] = $exp:grid_y;
+ grid[perm[2]] = $exp:grid_z;
+
+ void *$id:args_arr[] = { $inits:args'' };
+ typename int64_t $id:time_start = 0, $id:time_end = 0;
+ if (ctx->debugging) {
+ fprintf(stderr, "Launching %s with grid size (", $string:name);
+ $stms:(printSizes [grid_x, grid_y, grid_z])
+ fprintf(stderr, ") and block size (");
+ $stms:(printSizes [block_x, block_y, block_z])
+ fprintf(stderr, ").\n");
+ $id:time_start = get_wall_time();
+ }
+ CUDA_SUCCEED(
+ cuLaunchKernel(ctx->$id:name,
+ grid[0], grid[1], grid[2],
+ $exp:block_x, $exp:block_y, $exp:block_z,
+ $exp:shared_tot, NULL,
+ $id:args_arr, NULL));
+ if (ctx->debugging) {
+ CUDA_SUCCEED(cuCtxSynchronize());
+ $id:time_end = get_wall_time();
+ fprintf(stderr, "Kernel %s runtime: %ldus\n",
+ $string:name, $id:time_end - $id:time_start);
+ }
+ }|]
+ where
+ mkDims [] = ([C.cexp|0|] , [C.cexp|0|], [C.cexp|0|])
+ mkDims [x] = (x, [C.cexp|1|], [C.cexp|1|])
+ mkDims [x,y] = (x, y, [C.cexp|1|])
+ mkDims (x:y:z:_) = (x, y, z)
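+    -- Shared memory arguments share one dynamically allocated block:
+    -- each size is rounded up to a multiple of 8 bytes, the offsets
+    -- are computed by a prefix sum, and the final element is the
+    -- total number of bytes to allocate.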
+ addExp x y = [C.cexp|$exp:x + $exp:y|]
+ alignExp e = [C.cexp|$exp:e + ((8 - ($exp:e % 8)) % 8)|]
+ mkOffsets = scanl (\a b -> a `addExp` alignExp b) [C.cexp|0|]
+ expNotZero e = [C.cexp|$exp:e != 0|]
+ expAnd a b = [C.cexp|$exp:a && $exp:b|]
+ expsNotZero = foldl expAnd [C.cexp|1|] . map expNotZero
+ mkArgs (ValueKArg e t) =
+ (,Nothing) <$> GC.compileExpToName "kernel_arg" t e
+ mkArgs (MemKArg v) = do
+ v' <- GC.rawMem v
+ arg <- newVName "kernel_arg"
+ GC.decl [C.cdecl|typename CUdeviceptr $id:arg = $exp:v';|]
+ return (arg, Nothing)
+ mkArgs (SharedMemoryKArg (Count c)) = do
+ num_bytes <- GC.compileExp c
+ size <- newVName "shared_size"
+ offset <- newVName "shared_offset"
+ GC.decl [C.cdecl|unsigned int $id:size = $exp:num_bytes;|]
+ return (offset, Just (size, offset))
+
+ printSizes =
+ intercalate [[C.cstm|fprintf(stderr, ", ");|]] . map printSize
+ printSize e =
+ [[C.cstm|fprintf(stderr, "%zu", $exp:e);|]]
diff --git a/src/Futhark/CodeGen/Backends/CCUDA/Boilerplate.hs b/src/Futhark/CodeGen/Backends/CCUDA/Boilerplate.hs
new file mode 100644
index 0000000..67af1f7
--- /dev/null
+++ b/src/Futhark/CodeGen/Backends/CCUDA/Boilerplate.hs
@@ -0,0 +1,256 @@
+{-# LANGUAGE QuasiQuotes #-}
+{-# LANGUAGE TemplateHaskell #-}
+module Futhark.CodeGen.Backends.CCUDA.Boilerplate
+ (
+ generateBoilerplate
+ ) where
+
+import qualified Language.C.Quote.OpenCL as C
+
+import qualified Futhark.CodeGen.Backends.GenericC as GC
+import Futhark.Representation.ExplicitMemory hiding (GetSize, CmpSizeLe, GetSizeMax)
+import Futhark.CodeGen.ImpCode.OpenCL
+import Futhark.Util (chunk, zEncodeString)
+
+import qualified Data.Map as M
+import Data.FileEmbed (embedStringFile)
+
+
+
+generateBoilerplate :: String -> String -> [String]
+ -> M.Map Name SizeClass
+ -> GC.CompilerM OpenCL () ()
+generateBoilerplate cuda_program cuda_prelude kernel_names sizes = do
+ GC.earlyDecls [C.cunit|
+ $esc:("#include <cuda.h>")
+ $esc:("#include <nvrtc.h>")
+ $esc:("typedef CUdeviceptr fl_mem_t;")
+ $esc:free_list_h
+ $esc:cuda_h
+ const char *cuda_program[] = {$inits:fragments, NULL};
+ |]
+
+ generateSizeFuns sizes
+ cfg <- generateConfigFuns sizes
+ generateContextFuns cfg kernel_names sizes
+ where
+ cuda_h = $(embedStringFile "rts/c/cuda.h")
+ free_list_h = $(embedStringFile "rts/c/free_list.h")
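+    -- The embedded CUDA program is split into 2000-character fragments
+    -- to stay below C compilers' limits on string literal length; the
+    -- NULL entry marks the end of the array.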
+ fragments = map (\s -> [C.cinit|$string:s|])
+ $ chunk 2000 (cuda_prelude ++ cuda_program)
+
+generateSizeFuns :: M.Map Name SizeClass -> GC.CompilerM OpenCL () ()
+generateSizeFuns sizes = do
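+  -- size_names are the user-facing names of the tunable sizes, while
+  -- size_vars are their z-encoded forms, which are valid C identifiers.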
+ let size_name_inits = map (\k -> [C.cinit|$string:(pretty k)|]) $ M.keys sizes
+ size_var_inits = map (\k -> [C.cinit|$string:(zEncodeString (pretty k))|]) $ M.keys sizes
+ size_class_inits = map (\c -> [C.cinit|$string:(pretty c)|]) $ M.elems sizes
+ num_sizes = M.size sizes
+
+ GC.libDecl [C.cedecl|static const char *size_names[] = { $inits:size_name_inits };|]
+ GC.libDecl [C.cedecl|static const char *size_vars[] = { $inits:size_var_inits };|]
+ GC.libDecl [C.cedecl|static const char *size_classes[] = { $inits:size_class_inits };|]
+
+ GC.publicDef_ "get_num_sizes" GC.InitDecl $ \s ->
+ ([C.cedecl|int $id:s(void);|],
+ [C.cedecl|int $id:s(void) {
+ return $int:num_sizes;
+ }|])
+
+ GC.publicDef_ "get_size_name" GC.InitDecl $ \s ->
+ ([C.cedecl|const char* $id:s(int);|],
+ [C.cedecl|const char* $id:s(int i) {
+ return size_names[i];
+ }|])
+
+ GC.publicDef_ "get_size_class" GC.InitDecl $ \s ->
+ ([C.cedecl|const char* $id:s(int);|],
+ [C.cedecl|const char* $id:s(int i) {
+ return size_classes[i];
+ }|])
+
+generateConfigFuns :: M.Map Name SizeClass -> GC.CompilerM OpenCL () String
+generateConfigFuns sizes = do
+ let size_decls = map (\k -> [C.csdecl|size_t $id:k;|]) $ M.keys sizes
+ num_sizes = M.size sizes
+ GC.libDecl [C.cedecl|struct sizes { $sdecls:size_decls };|]
+ cfg <- GC.publicDef "context_config" GC.InitDecl $ \s ->
+ ([C.cedecl|struct $id:s;|],
+ [C.cedecl|struct $id:s { struct cuda_config cu_cfg;
+ size_t sizes[$int:num_sizes];
+ };|])
+
+ let size_value_inits = map (\i -> [C.cstm|cfg->sizes[$int:i] = 0;|])
+ [0..M.size sizes-1]
+ GC.publicDef_ "context_config_new" GC.InitDecl $ \s ->
+ ([C.cedecl|struct $id:cfg* $id:s(void);|],
+ [C.cedecl|struct $id:cfg* $id:s(void) {
+ struct $id:cfg *cfg = malloc(sizeof(struct $id:cfg));
+ if (cfg == NULL) {
+ return NULL;
+ }
+
+ $stms:size_value_inits
+ cuda_config_init(&cfg->cu_cfg, $int:num_sizes,
+ size_names, size_vars,
+ cfg->sizes, size_classes);
+ return cfg;
+ }|])
+
+ GC.publicDef_ "context_config_free" GC.InitDecl $ \s ->
+ ([C.cedecl|void $id:s(struct $id:cfg* cfg);|],
+ [C.cedecl|void $id:s(struct $id:cfg* cfg) {
+ free(cfg);
+ }|])
+
+ GC.publicDef_ "context_config_set_debugging" GC.InitDecl $ \s ->
+ ([C.cedecl|void $id:s(struct $id:cfg* cfg, int flag);|],
+ [C.cedecl|void $id:s(struct $id:cfg* cfg, int flag) {
+ cfg->cu_cfg.logging = cfg->cu_cfg.debugging = flag;
+ }|])
+
+ GC.publicDef_ "context_config_set_logging" GC.InitDecl $ \s ->
+ ([C.cedecl|void $id:s(struct $id:cfg* cfg, int flag);|],
+ [C.cedecl|void $id:s(struct $id:cfg* cfg, int flag) {
+ cfg->cu_cfg.logging = flag;
+ }|])
+
+ GC.publicDef_ "context_config_set_device" GC.InitDecl $ \s ->
+ ([C.cedecl|void $id:s(struct $id:cfg* cfg, const char *s);|],
+ [C.cedecl|void $id:s(struct $id:cfg* cfg, const char *s) {
+ set_preferred_device(&cfg->cu_cfg, s);
+ }|])
+
+ GC.publicDef_ "context_config_dump_program_to" GC.InitDecl $ \s ->
+ ([C.cedecl|void $id:s(struct $id:cfg* cfg, const char *path);|],
+ [C.cedecl|void $id:s(struct $id:cfg* cfg, const char *path) {
+ cfg->cu_cfg.dump_program_to = path;
+ }|])
+
+ GC.publicDef_ "context_config_load_program_from" GC.InitDecl $ \s ->
+ ([C.cedecl|void $id:s(struct $id:cfg* cfg, const char *path);|],
+ [C.cedecl|void $id:s(struct $id:cfg* cfg, const char *path) {
+ cfg->cu_cfg.load_program_from = path;
+ }|])
+
+ GC.publicDef_ "context_config_dump_ptx_to" GC.InitDecl $ \s ->
+ ([C.cedecl|void $id:s(struct $id:cfg* cfg, const char *path);|],
+ [C.cedecl|void $id:s(struct $id:cfg* cfg, const char *path) {
+ cfg->cu_cfg.dump_ptx_to = path;
+ }|])
+
+ GC.publicDef_ "context_config_load_ptx_from" GC.InitDecl $ \s ->
+ ([C.cedecl|void $id:s(struct $id:cfg* cfg, const char *path);|],
+ [C.cedecl|void $id:s(struct $id:cfg* cfg, const char *path) {
+ cfg->cu_cfg.load_ptx_from = path;
+ }|])
+
+ GC.publicDef_ "context_config_set_default_block_size" GC.InitDecl $ \s ->
+ ([C.cedecl|void $id:s(struct $id:cfg* cfg, int size);|],
+ [C.cedecl|void $id:s(struct $id:cfg* cfg, int size) {
+ cfg->cu_cfg.default_block_size = size;
+ cfg->cu_cfg.default_block_size_changed = 1;
+ }|])
+
+ GC.publicDef_ "context_config_set_default_grid_size" GC.InitDecl $ \s ->
+ ([C.cedecl|void $id:s(struct $id:cfg* cfg, int num);|],
+ [C.cedecl|void $id:s(struct $id:cfg* cfg, int num) {
+ cfg->cu_cfg.default_grid_size = num;
+ }|])
+
+ GC.publicDef_ "context_config_set_default_tile_size" GC.InitDecl $ \s ->
+ ([C.cedecl|void $id:s(struct $id:cfg* cfg, int num);|],
+ [C.cedecl|void $id:s(struct $id:cfg* cfg, int size) {
+ cfg->cu_cfg.default_tile_size = size;
+ cfg->cu_cfg.default_tile_size_changed = 1;
+ }|])
+
+ GC.publicDef_ "context_config_set_default_threshold" GC.InitDecl $ \s ->
+ ([C.cedecl|void $id:s(struct $id:cfg* cfg, int num);|],
+ [C.cedecl|void $id:s(struct $id:cfg* cfg, int size) {
+ cfg->cu_cfg.default_threshold = size;
+ }|])
+
+ GC.publicDef_ "context_config_set_size" GC.InitDecl $ \s ->
+ ([C.cedecl|int $id:s(struct $id:cfg* cfg, const char *size_name, size_t size_value);|],
+ [C.cedecl|int $id:s(struct $id:cfg* cfg, const char *size_name, size_t size_value) {
+
+ for (int i = 0; i < $int:num_sizes; i++) {
+ if (strcmp(size_name, size_names[i]) == 0) {
+ cfg->sizes[i] = size_value;
+ return 0;
+ }
+ }
+ return 1;
+ }|])
+ return cfg
+
+generateContextFuns :: String -> [String]
+ -> M.Map Name SizeClass
+ -> GC.CompilerM OpenCL () ()
+generateContextFuns cfg kernel_names sizes = do
+ final_inits <- GC.contextFinalInits
+ (fields, init_fields) <- GC.contextContents
+ let kernel_fields = map (\k -> [C.csdecl|typename CUfunction $id:k;|])
+ kernel_names
+
+ ctx <- GC.publicDef "context" GC.InitDecl $ \s ->
+ ([C.cedecl|struct $id:s;|],
+ [C.cedecl|struct $id:s {
+ int detail_memory;
+ int debugging;
+ typename lock_t lock;
+ char *error;
+ $sdecls:fields
+ $sdecls:kernel_fields
+ struct cuda_context cuda;
+ struct sizes sizes;
+ };|])
+
+ let set_sizes = zipWith (\i k -> [C.cstm|ctx->sizes.$id:k = cfg->sizes[$int:i];|])
+ [(0::Int)..] $ M.keys sizes
+
+ GC.publicDef_ "context_new" GC.InitDecl $ \s ->
+ ([C.cedecl|struct $id:ctx* $id:s(struct $id:cfg* cfg);|],
+ [C.cedecl|struct $id:ctx* $id:s(struct $id:cfg* cfg) {
+ struct $id:ctx* ctx = malloc(sizeof(struct $id:ctx));
+ if (ctx == NULL) {
+ return NULL;
+ }
+ ctx->debugging = ctx->detail_memory = cfg->cu_cfg.debugging;
+
+ ctx->cuda.cfg = cfg->cu_cfg;
+ create_lock(&ctx->lock);
+ $stms:init_fields
+
+ cuda_setup(&ctx->cuda, cuda_program);
+ $stms:(map (loadKernelByName) kernel_names)
+
+ $stms:final_inits
+ $stms:set_sizes
+ return ctx;
+ }|])
+
+ GC.publicDef_ "context_free" GC.InitDecl $ \s ->
+ ([C.cedecl|void $id:s(struct $id:ctx* ctx);|],
+ [C.cedecl|void $id:s(struct $id:ctx* ctx) {
+ cuda_cleanup(&ctx->cuda);
+ free_lock(&ctx->lock);
+ free(ctx);
+ }|])
+
+ GC.publicDef_ "context_sync" GC.InitDecl $ \s ->
+ ([C.cedecl|int $id:s(struct $id:ctx* ctx);|],
+ [C.cedecl|int $id:s(struct $id:ctx* ctx) {
+ CUDA_SUCCEED(cuCtxSynchronize());
+ return 0;
+ }|])
+
+ GC.publicDef_ "context_get_error" GC.InitDecl $ \s ->
+ ([C.cedecl|char* $id:s(struct $id:ctx* ctx);|],
+ [C.cedecl|char* $id:s(struct $id:ctx* ctx) {
+ return ctx->error;
+ }|])
+ where
+ loadKernelByName name =
+ [C.cstm|CUDA_SUCCEED(cuModuleGetFunction(&ctx->$id:name,
+ ctx->cuda.module, $string:name));|]
diff --git a/src/Futhark/CodeGen/Backends/COpenCL.hs b/src/Futhark/CodeGen/Backends/COpenCL.hs
index a6df4cb..bea6af5 100644
--- a/src/Futhark/CodeGen/Backends/COpenCL.hs
+++ b/src/Futhark/CodeGen/Backends/COpenCL.hs
@@ -91,16 +91,24 @@ cliOptions = [ Option { optionLongName = "platform"
, optionArgument = RequiredArgument
, optionAction = [C.cstm|futhark_context_config_load_program_from(cfg, optarg);|]
}
+ , Option { optionLongName = "dump-opencl-binary"
+ , optionShortName = Nothing
+ , optionArgument = RequiredArgument
+ , optionAction = [C.cstm|futhark_context_config_dump_binary_to(cfg, optarg);|]
+ }
+ , Option { optionLongName = "load-opencl-binary"
+ , optionShortName = Nothing
+ , optionArgument = RequiredArgument
+ , optionAction = [C.cstm|futhark_context_config_load_binary_from(cfg, optarg);|]
+ }
, Option { optionLongName = "print-sizes"
, optionShortName = Nothing
, optionArgument = NoArgument
, optionAction = [C.cstm|{
int n = futhark_get_num_sizes();
for (int i = 0; i < n; i++) {
- if (strcmp(futhark_get_size_entry(i), entry_point) == 0) {
- printf("%s (%s)\n", futhark_get_size_name(i),
- futhark_get_size_class(i));
- }
+ printf("%s (%s)\n", futhark_get_size_name(i),
+ futhark_get_size_class(i));
}
exit(0);
}|]
@@ -276,11 +284,11 @@ callKernel (GetSizeMax v size_class) =
callKernel (HostCode c) =
GC.compileCode c
-callKernel (LaunchKernel name args kernel_size workgroup_size) = do
+callKernel (LaunchKernel name args num_workgroups workgroup_size) = do
zipWithM_ setKernelArg [(0::Int)..] args
- kernel_size' <- mapM GC.compileExp kernel_size
+ num_workgroups' <- mapM GC.compileExp num_workgroups
workgroup_size' <- mapM GC.compileExp workgroup_size
- launchKernel name kernel_size' workgroup_size'
+ launchKernel name num_workgroups' workgroup_size'
where setKernelArg i (ValueKArg e bt) = do
v <- GC.compileExpToName "kernel_arg" bt e
GC.stm [C.cstm|
@@ -301,7 +309,7 @@ callKernel (LaunchKernel name args kernel_size workgroup_size) = do
launchKernel :: C.ToExp a =>
String -> [a] -> [a] -> GC.CompilerM op s ()
-launchKernel kernel_name kernel_dims workgroup_dims = do
+launchKernel kernel_name num_workgroups workgroup_dims = do
global_work_size <- newVName "global_work_size"
time_start <- newVName "time_start"
time_end <- newVName "time_end"
@@ -336,6 +344,7 @@ launchKernel kernel_name kernel_dims workgroup_dims = do
}
}|]
where kernel_rank = length kernel_dims
+ kernel_dims = zipWith multExp num_workgroups workgroup_dims
kernel_dims' = map toInit kernel_dims
workgroup_dims' = map toInit workgroup_dims
total_elements = foldl multExp [C.cexp|1|] kernel_dims
diff --git a/src/Futhark/CodeGen/Backends/COpenCL/Boilerplate.hs b/src/Futhark/CodeGen/Backends/COpenCL/Boilerplate.hs
index 2478e7e..e2df848 100644
--- a/src/Futhark/CodeGen/Backends/COpenCL/Boilerplate.hs
+++ b/src/Futhark/CodeGen/Backends/COpenCL/Boilerplate.hs
@@ -15,10 +15,10 @@ import qualified Language.C.Quote.OpenCL as C
import Futhark.CodeGen.ImpCode.OpenCL
import qualified Futhark.CodeGen.Backends.GenericC as GC
import Futhark.CodeGen.OpenCL.Kernels
-import Futhark.Util (chunk)
+import Futhark.Util (chunk, zEncodeString)
generateBoilerplate :: String -> String -> [String] -> [PrimType]
- -> M.Map VName (SizeClass, Name)
+ -> M.Map Name SizeClass
-> GC.CompilerM OpenCL () ()
generateBoilerplate opencl_code opencl_prelude kernel_names types sizes = do
final_inits <- GC.contextFinalInits
@@ -29,13 +29,13 @@ generateBoilerplate opencl_code opencl_prelude kernel_names types sizes = do
GC.earlyDecls top_decls
let size_name_inits = map (\k -> [C.cinit|$string:(pretty k)|]) $ M.keys sizes
- size_class_inits = map (\(c,_) -> [C.cinit|$string:(pretty c)|]) $ M.elems sizes
- size_entry_points_inits = map (\(_,e) -> [C.cinit|$string:(pretty e)|]) $ M.elems sizes
+ size_var_inits = map (\k -> [C.cinit|$string:(zEncodeString (pretty k))|]) $ M.keys sizes
+ size_class_inits = map (\c -> [C.cinit|$string:(pretty c)|]) $ M.elems sizes
num_sizes = M.size sizes
GC.libDecl [C.cedecl|static const char *size_names[] = { $inits:size_name_inits };|]
+ GC.libDecl [C.cedecl|static const char *size_vars[] = { $inits:size_var_inits };|]
GC.libDecl [C.cedecl|static const char *size_classes[] = { $inits:size_class_inits };|]
- GC.libDecl [C.cedecl|static const char *size_entry_points[] = { $inits:size_entry_points_inits };|]
GC.publicDef_ "get_num_sizes" GC.InitDecl $ \s ->
([C.cedecl|int $id:s(void);|],
@@ -55,12 +55,6 @@ generateBoilerplate opencl_code opencl_prelude kernel_names types sizes = do
return size_classes[i];
}|])
- GC.publicDef_ "get_size_entry" GC.InitDecl $ \s ->
- ([C.cedecl|const char* $id:s(int);|],
- [C.cedecl|const char* $id:s(int i) {
- return size_entry_points[i];
- }|])
-
let size_decls = map (\k -> [C.csdecl|size_t $id:k;|]) $ M.keys sizes
GC.libDecl [C.cedecl|struct sizes { $sdecls:size_decls };|]
cfg <- GC.publicDef "context_config" GC.InitDecl $ \s ->
@@ -80,7 +74,8 @@ generateBoilerplate opencl_code opencl_prelude kernel_names types sizes = do
$stms:size_value_inits
opencl_config_init(&cfg->opencl, $int:num_sizes,
- size_names, cfg->sizes, size_classes, size_entry_points);
+ size_names, size_vars,
+ cfg->sizes, size_classes);
return cfg;
}|])
@@ -126,6 +121,19 @@ generateBoilerplate opencl_code opencl_prelude kernel_names types sizes = do
cfg->opencl.load_program_from = path;
}|])
+
+ GC.publicDef_ "context_config_dump_binary_to" GC.InitDecl $ \s ->
+ ([C.cedecl|void $id:s(struct $id:cfg* cfg, const char *path);|],
+ [C.cedecl|void $id:s(struct $id:cfg* cfg, const char *path) {
+ cfg->opencl.dump_binary_to = path;
+ }|])
+
+ GC.publicDef_ "context_config_load_binary_from" GC.InitDecl $ \s ->
+ ([C.cedecl|void $id:s(struct $id:cfg* cfg, const char *path);|],
+ [C.cedecl|void $id:s(struct $id:cfg* cfg, const char *path) {
+ cfg->opencl.load_binary_from = path;
+ }|])
+
GC.publicDef_ "context_config_set_default_group_size" GC.InitDecl $ \s ->
([C.cedecl|void $id:s(struct $id:cfg* cfg, int size);|],
[C.cedecl|void $id:s(struct $id:cfg* cfg, int size) {
@@ -317,17 +325,28 @@ void post_opencl_setup(struct opencl_context *ctx, struct opencl_device_option *
$stms:(map sizeHeuristicsCode sizeHeuristicsTable)
}|]]
+ free_list_h = $(embedStringFile "rts/c/free_list.h")
openCL_h = $(embedStringFile "rts/c/opencl.h")
program_fragments = opencl_program_fragments ++ [[C.cinit|NULL|]]
openCL_boilerplate = [C.cunit|
+
+ $esc:("#define CL_USE_DEPRECATED_OPENCL_1_2_APIS")
+ $esc:("#define CL_SILENCE_DEPRECATION // For macOS.")
+ $esc:("#ifdef __APPLE__")
+ $esc:(" #include <OpenCL/cl.h>")
+ $esc:("#else")
+ $esc:(" #include <CL/cl.h>")
+ $esc:("#endif")
+ $esc:("typedef cl_mem fl_mem_t;")
+ $esc:free_list_h
$esc:openCL_h
const char *opencl_program[] = {$inits:program_fragments};|]
loadKernelByName :: String -> C.Stm
loadKernelByName name = [C.cstm|{
ctx->$id:name = clCreateKernel(prog, $string:name, &error);
- assert(error == 0);
+ OPENCL_SUCCEED_FATAL(error);
if (ctx->debugging) {
fprintf(stderr, "Created kernel %s.\n", $string:name);
}
@@ -373,7 +392,7 @@ sizeHeuristicsCode (SizeHeuristic platform_name device_type which what) =
[C.cstm|
if ($exp:which' == 0 &&
strstr(option->platform_name, $string:platform_name) != NULL &&
- option->device_type == $exp:(clDeviceType device_type)) {
+ (option->device_type & $exp:(clDeviceType device_type)) == $exp:(clDeviceType device_type)) {
$stm:get_size
}|]
where clDeviceType DeviceGPU = [C.cexp|CL_DEVICE_TYPE_GPU|]
diff --git a/src/Futhark/CodeGen/Backends/CSOpenCL.hs b/src/Futhark/CodeGen/Backends/CSOpenCL.hs
index e5cb894..401111c 100644
--- a/src/Futhark/CodeGen/Backends/CSOpenCL.hs
+++ b/src/Futhark/CodeGen/Backends/CSOpenCL.hs
@@ -17,6 +17,7 @@ import Futhark.CodeGen.Backends.GenericCSharp.AST
import Futhark.CodeGen.Backends.GenericCSharp.Options
import Futhark.CodeGen.Backends.GenericCSharp.Definitions
import Futhark.Util.Pretty(pretty)
+import Futhark.Util (zEncodeString)
import Futhark.MonadFreshNames hiding (newVName')
@@ -119,24 +120,32 @@ cliOptions = [ Option { optionLongName = "platform"
callKernel :: CS.OpCompiler Imp.OpenCL ()
callKernel (Imp.GetSize v key) =
CS.stm $ Reassign (Var (CS.compileName v)) $
- Field (Var "Ctx.Sizes") $ pretty key
+ Field (Var "Ctx.Sizes") $ zEncodeString $ pretty key
callKernel (Imp.GetSizeMax v size_class) =
CS.stm $ Reassign (Var (CS.compileName v)) $
- Var $ "max_" ++ pretty size_class
+ Field (Var "Ctx.OpenCL") $
+ case size_class of Imp.SizeGroup -> "MaxGroupSize"
+ Imp.SizeNumGroups -> "MaxNumGroups"
+ Imp.SizeTile -> "MaxTileSize"
+ Imp.SizeThreshold{} -> "MaxThreshold"
callKernel (Imp.HostCode c) = CS.compileCode c
-callKernel (Imp.LaunchKernel name args kernel_size workgroup_size) = do
- kernel_size' <- mapM CS.compileExp kernel_size
- let total_elements = foldl mult_exp (Integer 1) kernel_size'
- let cond = BinOp "!=" total_elements (Integer 0)
+callKernel (Imp.LaunchKernel name args num_workgroups workgroup_size) = do
+ num_workgroups' <- mapM CS.compileExp num_workgroups
workgroup_size' <- mapM CS.compileExp workgroup_size
- body <- CS.collect $ launchKernel name kernel_size' workgroup_size' args
+ let kernel_size = zipWith mult_exp num_workgroups' workgroup_size'
+ total_elements = foldl mult_exp (Integer 1) kernel_size
+ cond = BinOp "!=" total_elements (Integer 0)
+ body <- CS.collect $ launchKernel name kernel_size workgroup_size' args
CS.stm $ If cond body []
where mult_exp = BinOp "*"
-callKernel _ = undefined
+callKernel (Imp.CmpSizeLe v key x) = do
+ x' <- CS.compileExp x
+ CS.stm $ Reassign (Var (CS.compileName v)) $
+ BinOp "<=" (Field (Var "Ctx.Sizes") (zEncodeString $ pretty key)) x'
launchKernel :: String -> [CSExp] -> [CSExp] -> [Imp.KernelArg] -> CS.CompilerM op s ()
launchKernel kernel_name kernel_dims workgroup_dims args = do
diff --git a/src/Futhark/CodeGen/Backends/CSOpenCL/Boilerplate.hs b/src/Futhark/CodeGen/Backends/CSOpenCL/Boilerplate.hs
index 9787174..64f73e3 100644
--- a/src/Futhark/CodeGen/Backends/CSOpenCL/Boilerplate.hs
+++ b/src/Futhark/CodeGen/Backends/CSOpenCL/Boilerplate.hs
@@ -11,7 +11,7 @@ import Futhark.CodeGen.ImpCode.OpenCL hiding (Index, If)
import Futhark.CodeGen.Backends.GenericCSharp as CS
import Futhark.CodeGen.Backends.GenericCSharp.AST as AST
import Futhark.CodeGen.OpenCL.Kernels
-
+import Futhark.Util (zEncodeString)
intT, longT, stringT, intArrayT, stringArrayT :: CSType
intT = Primitive $ CSInt Int32T
@@ -21,7 +21,7 @@ intArrayT = Composite $ ArrayT intT
stringArrayT = Composite $ ArrayT stringT
generateBoilerplate :: String -> String -> [String] -> [PrimType]
- -> M.Map VName (SizeClass, Name)
+ -> M.Map Name SizeClass
-> CS.CompilerM OpenCL () ()
generateBoilerplate opencl_code opencl_prelude kernel_names types sizes = do
final_inits <- CS.contextFinalInits
@@ -34,17 +34,16 @@ generateBoilerplate opencl_code opencl_prelude kernel_names types sizes = do
CS.stm $ AssignTyped stringArrayT (Var "SizeNames")
(Just $ Collection "string[]" (map (String . pretty) $ M.keys sizes))
- CS.stm $ AssignTyped stringArrayT (Var "SizeClasses")
- (Just $ Collection "string[]" (map (String . pretty . fst) $ M.elems sizes))
+ CS.stm $ AssignTyped stringArrayT (Var "SizeVars")
+ (Just $ Collection "string[]" (map (String . zEncodeString . pretty) $ M.keys sizes))
- CS.stm $ AssignTyped stringArrayT (Var "SizeEntryPoints")
- (Just $ Collection "string[]" (map (String . pretty . snd) $ M.elems sizes))
+ CS.stm $ AssignTyped stringArrayT (Var "SizeClasses")
+ (Just $ Collection "string[]" (map (String . pretty) $ M.elems sizes))
let get_num_sizes = CS.publicName "GetNumSizes"
let get_size_name = CS.publicName "GetSizeName"
let get_size_class = CS.publicName "GetSizeClass"
- let get_size_entry = CS.publicName "GetSizeEntry"
CS.stm $ CS.privateFunDef get_num_sizes intT []
@@ -53,8 +52,6 @@ generateBoilerplate opencl_code opencl_prelude kernel_names types sizes = do
[ Return $ Index (Var "SizeNames") (IdxExp $ Var "i") ]
CS.stm $ CS.privateFunDef get_size_class (Primitive StringT) [(intT, "i")]
[ Return $ Index (Var "SizeClasses") (IdxExp $ Var "i") ]
- CS.stm $ CS.privateFunDef get_size_entry (Primitive StringT) [(intT, "i")]
- [ Return $ Index (Var "SizeEntryPoints") (IdxExp $ Var "i") ]
let cfg = CS.publicName "ContextConfig"
let new_cfg = CS.publicName "ContextConfigNew"
@@ -69,7 +66,7 @@ generateBoilerplate opencl_code opencl_prelude kernel_names types sizes = do
let cfg_set_default_threshold = CS.publicName "ContextConfigSetDefaultThreshold"
let cfg_set_size = CS.publicName "ContextConfigSetSize"
- CS.stm $ StructDef "Sizes" (map (\k -> (intT, pretty k)) $ M.keys sizes)
+ CS.stm $ StructDef "Sizes" (map (\k -> (intT, zEncodeString $ pretty k)) $ M.keys sizes)
CS.stm $ StructDef cfg [ (CustomT "OpenCLConfig", "OpenCL")
, (intArrayT, "Sizes")]
@@ -78,7 +75,7 @@ generateBoilerplate opencl_code opencl_prelude kernel_names types sizes = do
[ Assign tmp_cfg $ CS.simpleInitClass cfg []
, Reassign (Field tmp_cfg "Sizes") (Collection "int[]" (replicate (M.size sizes) (Integer 0)))
, Exp $ CS.simpleCall "OpenCLConfigInit" [ Out $ Field tmp_cfg "OpenCL", (Integer . toInteger) $ M.size sizes
- , Var "SizeNames", Field tmp_cfg "Sizes", Var "SizeClasses" ]
+ , Var "SizeNames", Var "SizeVars", Field tmp_cfg "Sizes", Var "SizeClasses" ]
, Return tmp_cfg
]
@@ -152,7 +149,7 @@ generateBoilerplate opencl_code opencl_prelude kernel_names types sizes = do
let set_required_types = [Reassign (Var "RequiredTypes") (AST.Bool True)
| FloatType Float64 `elem` types]
- set_sizes = zipWith (\i k -> Reassign (Field (Var "Ctx.Sizes") (pretty k))
+ set_sizes = zipWith (\i k -> Reassign (Field (Var "Ctx.Sizes") (zEncodeString $ pretty k))
(Index (Var "Cfg.Sizes") (IdxExp $ (Integer . toInteger) i)))
[(0::Int)..] $ M.keys sizes
diff --git a/src/Futhark/CodeGen/Backends/GenericC.hs b/src/Futhark/CodeGen/Backends/GenericC.hs
index c94d9e5..b6d8c8c 100644
--- a/src/Futhark/CodeGen/Backends/GenericC.hs
+++ b/src/Futhark/CodeGen/Backends/GenericC.hs
@@ -77,7 +77,6 @@ import Data.List
import Data.Loc
import Data.Maybe
import Data.FileEmbed
-import qualified Data.Semigroup as Sem
import Text.Printf
import qualified Language.C.Syntax as C
@@ -228,13 +227,12 @@ newtype CompilerAcc op s = CompilerAcc {
accItems :: DL.DList C.BlockItem
}
-instance Sem.Semigroup (CompilerAcc op s) where
+instance Semigroup (CompilerAcc op s) where
CompilerAcc items1 <> CompilerAcc items2 =
CompilerAcc (items1<>items2)
instance Monoid (CompilerAcc op s) where
mempty = CompilerAcc mempty
- mappend = (Sem.<>)
envOpCompiler :: CompilerEnv op s -> OpCompiler op s
envOpCompiler = opsCompiler . envOperations
@@ -349,6 +347,9 @@ collect' m = pass $ do
item :: C.BlockItem -> CompilerM op s ()
item x = tell $ mempty { accItems = DL.singleton x }
+instance C.ToIdent Name where
+ toIdent = C.toIdent . zEncodeString . nameToString
+
instance C.ToIdent VName where
toIdent = C.toIdent . zEncodeString . pretty
diff --git a/src/Futhark/CodeGen/Backends/GenericCSharp.hs b/src/Futhark/CodeGen/Backends/GenericCSharp.hs
index dbaba8c..72bf940 100644
--- a/src/Futhark/CodeGen/Backends/GenericCSharp.hs
+++ b/src/Futhark/CodeGen/Backends/GenericCSharp.hs
@@ -1,1404 +1,1402 @@
-{-# LANGUAGE OverloadedStrings, GeneralizedNewtypeDeriving, LambdaCase #-}
-{-# LANGUAGE TupleSections #-}
--- | A generic C# code generator which is polymorphic in the type
--- of the operations. Concretely, we use this to handle both
--- sequential and OpenCL C# code.
-module Futhark.CodeGen.Backends.GenericCSharp
- ( compileProg
- , Constructor (..)
- , emptyConstructor
-
- , assignScalarPointer
- , toIntPtr
- , compileName
- , compileDim
- , compileExp
- , compileCode
- , compilePrimValue
- , compilePrimType
- , compilePrimTypeExt
- , compilePrimTypeToAST
- , compilePrimTypeToASText
- , contextFinalInits
- , debugReport
-
- , Operations (..)
- , defaultOperations
-
- , unpackDim
-
- , CompilerM (..)
- , OpCompiler
- , WriteScalar
- , ReadScalar
- , Allocate
- , Copy
- , StaticArray
- , EntryOutput
- , EntryInput
-
- , CompilerEnv(..)
- , CompilerState(..)
- , stm
- , stms
- , atInit
- , staticMemDecl
- , staticMemAlloc
- , addMemberDecl
- , beforeParse
- , collect'
- , collect
- , simpleCall
- , callMethod
- , simpleInitClass
- , parametrizedCall
-
- , copyMemoryDefaultSpace
- , consoleErrorWrite
- , consoleErrorWriteLine
- , consoleWrite
- , consoleWriteLine
-
- , publicName
- , sizeOf
- , privateFunDef
- , publicFunDef
- , getDefaultDecl
- ) where
-
-import Control.Monad.Identity
-import Control.Monad.State
-import Control.Monad.Reader
-import Control.Monad.Writer
-import Control.Monad.RWS
-import Control.Arrow((&&&))
-import Data.Maybe
-import Data.List
-import qualified Data.Map.Strict as M
-import qualified Data.Semigroup as Sem
-
-import Futhark.Representation.Primitive hiding (Bool)
-import Futhark.MonadFreshNames
-import Futhark.Representation.AST.Syntax (Space(..))
-import qualified Futhark.CodeGen.ImpCode as Imp
-import Futhark.CodeGen.Backends.GenericCSharp.AST
-import Futhark.CodeGen.Backends.GenericCSharp.Options
-import Futhark.CodeGen.Backends.GenericCSharp.Definitions
-import Futhark.Util.Pretty(pretty)
-import Futhark.Util (zEncodeString)
-import Futhark.Representation.AST.Attributes (builtInFunctions)
-import Text.Printf (printf)
-
--- | A substitute expression compiler, tried before the main
--- compilation function.
-type OpCompiler op s = op -> CompilerM op s ()
-
--- | Write a scalar to the given memory block with the given index and
--- in the given memory space.
-type WriteScalar op s = VName -> CSExp -> PrimType -> Imp.SpaceId -> CSExp
- -> CompilerM op s ()
-
--- | Read a scalar from the given memory block with the given index and
--- in the given memory space.
-type ReadScalar op s = VName -> CSExp -> PrimType -> Imp.SpaceId
- -> CompilerM op s CSExp
-
--- | Allocate a memory block of the given size in the given memory
--- space, saving a reference in the given variable name.
-type Allocate op s = VName -> CSExp -> Imp.SpaceId
- -> CompilerM op s ()
-
--- | Copy from one memory block to another.
-type Copy op s = VName -> CSExp -> Imp.Space ->
- VName -> CSExp -> Imp.Space ->
- CSExp -> PrimType ->
- CompilerM op s ()
-
--- | Create a static array of values - initialised at load time.
-type StaticArray op s = VName -> Imp.SpaceId -> PrimType -> [PrimValue] -> CompilerM op s ()
-
--- | Construct the C# array being returned from an entry point.
-type EntryOutput op s = VName -> Imp.SpaceId ->
- PrimType -> Imp.Signedness ->
- [Imp.DimSize] ->
- CompilerM op s CSExp
-
--- | Unpack the array being passed to an entry point.
-type EntryInput op s = VName -> Imp.MemSize -> Imp.SpaceId ->
- PrimType -> Imp.Signedness ->
- [Imp.DimSize] ->
- CSExp ->
- CompilerM op s ()
-
-data Operations op s = Operations { opsWriteScalar :: WriteScalar op s
- , opsReadScalar :: ReadScalar op s
- , opsAllocate :: Allocate op s
- , opsCopy :: Copy op s
- , opsStaticArray :: StaticArray op s
- , opsCompiler :: OpCompiler op s
- , opsEntryOutput :: EntryOutput op s
- , opsEntryInput :: EntryInput op s
- , opsSyncRun :: CSStmt
- }
-
--- | A set of operations that fail for every operation involving
--- non-default memory spaces. Uses plain pointers and @malloc@ for
--- memory management.
-defaultOperations :: Operations op s
-defaultOperations = Operations { opsWriteScalar = defWriteScalar
- , opsReadScalar = defReadScalar
- , opsAllocate = defAllocate
- , opsCopy = defCopy
- , opsStaticArray = defStaticArray
- , opsCompiler = defCompiler
- , opsEntryOutput = defEntryOutput
- , opsEntryInput = defEntryInput
- , opsSyncRun = defSyncRun
- }
- where defWriteScalar _ _ _ _ _ =
- fail "Cannot write to non-default memory space because I am dumb"
- defReadScalar _ _ _ _ =
- fail "Cannot read from non-default memory space"
- defAllocate _ _ _ =
- fail "Cannot allocate in non-default memory space"
- defCopy _ _ _ _ _ _ _ _ =
- fail "Cannot copy to or from non-default memory space"
- defStaticArray _ _ _ _ =
- fail "Cannot create static array in non-default memory space"
- defCompiler _ =
- fail "The default compiler cannot compile extended operations"
- defEntryOutput _ _ _ _ =
- fail "Cannot return array not in default memory space"
- defEntryInput _ _ _ _ =
- fail "Cannot accept array not in default memory space"
- defSyncRun =
- Pass
-
-data CompilerEnv op s = CompilerEnv {
- envOperations :: Operations op s
- , envFtable :: M.Map Name [Imp.Type]
-}
-
-data CompilerAcc op s = CompilerAcc {
- accItems :: [CSStmt]
- , accFreedMem :: [VName]
- }
-
-instance Sem.Semigroup (CompilerAcc op s) where
- CompilerAcc items1 freed1 <> CompilerAcc items2 freed2 =
- CompilerAcc (items1<>items2) (freed1<>freed2)
-
-instance Monoid (CompilerAcc op s) where
- mempty = CompilerAcc mempty mempty
- mappend = (Sem.<>)
-
-envOpCompiler :: CompilerEnv op s -> OpCompiler op s
-envOpCompiler = opsCompiler . envOperations
-
-envReadScalar :: CompilerEnv op s -> ReadScalar op s
-envReadScalar = opsReadScalar . envOperations
-
-envWriteScalar :: CompilerEnv op s -> WriteScalar op s
-envWriteScalar = opsWriteScalar . envOperations
-
-envAllocate :: CompilerEnv op s -> Allocate op s
-envAllocate = opsAllocate . envOperations
-
-envCopy :: CompilerEnv op s -> Copy op s
-envCopy = opsCopy . envOperations
-
-envStaticArray :: CompilerEnv op s -> StaticArray op s
-envStaticArray = opsStaticArray . envOperations
-
-envEntryOutput :: CompilerEnv op s -> EntryOutput op s
-envEntryOutput = opsEntryOutput . envOperations
-
-envEntryInput :: CompilerEnv op s -> EntryInput op s
-envEntryInput = opsEntryInput . envOperations
-
-envSyncFun :: CompilerEnv op s -> CSStmt
-envSyncFun = opsSyncRun . envOperations
-
-newCompilerEnv :: Imp.Functions op -> Operations op s -> CompilerEnv op s
-newCompilerEnv (Imp.Functions funs) ops =
- CompilerEnv { envOperations = ops
- , envFtable = ftable <> builtinFtable
- }
- where ftable = M.fromList $ map funReturn funs
- funReturn (name, Imp.Function _ outparams _ _ _ _) = (name, paramsTypes outparams)
- builtinFtable = M.map (map Imp.Scalar . snd) builtInFunctions
-
-data CompilerState s = CompilerState {
- compNameSrc :: VNameSource
- , compBeforeParse :: [CSStmt]
- , compInit :: [CSStmt]
- , compStaticMemDecls :: [CSStmt]
- , compStaticMemAllocs :: [CSStmt]
- , compDebugItems :: [CSStmt]
- , compUserState :: s
- , compMemberDecls :: [CSStmt]
- , compAssignedVars :: [VName]
- , compDeclaredMem :: [(VName, Space)]
-}
-
-newCompilerState :: VNameSource -> s -> CompilerState s
-newCompilerState src s = CompilerState { compNameSrc = src
- , compBeforeParse = []
- , compInit = []
- , compStaticMemDecls = []
- , compStaticMemAllocs = []
- , compDebugItems = []
- , compMemberDecls = []
- , compUserState = s
- , compAssignedVars = []
- , compDeclaredMem = []
- }
-
-newtype CompilerM op s a = CompilerM (RWS (CompilerEnv op s) (CompilerAcc op s) (CompilerState s) a)
- deriving (Functor, Applicative, Monad,
- MonadState (CompilerState s),
- MonadReader (CompilerEnv op s),
- MonadWriter (CompilerAcc op s))
-
-instance MonadFreshNames (CompilerM op s) where
- getNameSource = gets compNameSrc
- putNameSource src = modify $ \s -> s { compNameSrc = src }
-
-collect :: CompilerM op s () -> CompilerM op s [CSStmt]
-collect m = pass $ do
- ((), w) <- listen m
- return (accItems w,
- const w { accItems = mempty} )
-
-collect' :: CompilerM op s a -> CompilerM op s (a, [CSStmt])
-collect' m = pass $ do
- (x, w) <- listen m
- return ((x, accItems w),
- const w { accItems = mempty})
-
-beforeParse :: CSStmt -> CompilerM op s ()
-beforeParse x = modify $ \s ->
- s { compBeforeParse = compBeforeParse s ++ [x] }
-
-atInit :: CSStmt -> CompilerM op s ()
-atInit x = modify $ \s ->
- s { compInit = compInit s ++ [x] }
-
-staticMemDecl :: CSStmt -> CompilerM op s ()
-staticMemDecl x = modify $ \s ->
- s { compStaticMemDecls = compStaticMemDecls s ++ [x] }
-
-staticMemAlloc :: CSStmt -> CompilerM op s ()
-staticMemAlloc x = modify $ \s ->
- s { compStaticMemAllocs = compStaticMemAllocs s ++ [x] }
-
-addMemberDecl :: CSStmt -> CompilerM op s ()
-addMemberDecl x = modify $ \s ->
- s { compMemberDecls = compMemberDecls s ++ [x] }
-
-contextFinalInits :: CompilerM op s [CSStmt]
-contextFinalInits = gets compInit
-
-item :: CSStmt -> CompilerM op s ()
-item x = tell $ mempty { accItems = [x] }
-
-stm :: CSStmt -> CompilerM op s ()
-stm = item
-
-stms :: [CSStmt] -> CompilerM op s ()
-stms = mapM_ stm
-
-debugReport :: CSStmt -> CompilerM op s ()
-debugReport x = modify $ \s ->
- s { compDebugItems = compDebugItems s ++ [x] }
-
-getVarAssigned :: VName -> CompilerM op s Bool
-getVarAssigned vname =
- elem vname <$> gets compAssignedVars
-
-setVarAssigned :: VName -> CompilerM op s ()
-setVarAssigned vname = modify $ \s ->
- s { compAssignedVars = vname : compAssignedVars s}
-
-futharkFun :: String -> String
-futharkFun s = "futhark_" ++ zEncodeString s
-
-paramsTypes :: [Imp.Param] -> [Imp.Type]
-paramsTypes = map paramType
-
-paramType :: Imp.Param -> Imp.Type
-paramType (Imp.MemParam _ space) = Imp.Mem (Imp.ConstSize 0) space
-paramType (Imp.ScalarParam _ t) = Imp.Scalar t
-
-compileOutput :: Imp.Param -> (CSExp, CSType)
-compileOutput = nameFun &&& typeFun
- where nameFun = Var . compileName . Imp.paramName
- typeFun = compileType . paramType
-
-getDefaultDecl :: Imp.Param -> CSStmt
-getDefaultDecl (Imp.MemParam v DefaultSpace) =
- Assign (Var $ compileName v) $ simpleCall "allocateMem" [Integer 0]
-getDefaultDecl (Imp.MemParam v _) =
- AssignTyped (CustomT "OpenCLMemblock") (Var $ compileName v) (Just $ simpleCall "EmptyMemblock" [Var "Ctx.EMPTY_MEM_HANDLE"])
-getDefaultDecl (Imp.ScalarParam v Cert) =
- Assign (Var $ compileName v) $ Bool True
-getDefaultDecl (Imp.ScalarParam v t) =
- Assign (Var $ compileName v) $ simpleInitClass (compilePrimType t) []
-
-
-runCompilerM :: Imp.Functions op -> Operations op s
- -> VNameSource
- -> s
- -> CompilerM op s a
- -> a
-runCompilerM prog ops src userstate (CompilerM m) =
- fst $ evalRWS m (newCompilerEnv prog ops) (newCompilerState src userstate)
-
-standardOptions :: [Option]
-standardOptions = [
- Option { optionLongName = "write-runtime-to"
- , optionShortName = Just 't'
- , optionArgument = RequiredArgument
- , optionAction =
- [
- If (BinOp "!=" (Var "RuntimeFile") Null)
- [Exp $ simpleCall "RuntimeFile.Close" []] []
- , Reassign (Var "RuntimeFile") $
- simpleInitClass "FileStream" [Var "optarg", Var "FileMode.Create"]
- , Reassign (Var "RuntimeFileWriter") $
- simpleInitClass "StreamWriter" [Var "RuntimeFile"]
- ]
- },
- Option { optionLongName = "runs"
- , optionShortName = Just 'r'
- , optionArgument = RequiredArgument
- , optionAction =
- [ Reassign (Var "NumRuns") $ simpleCall "Convert.ToInt32" [Var "optarg"]
- , Reassign (Var "DoWarmupRun") $ Bool True
- ]
- },
- Option { optionLongName = "entry-point"
- , optionShortName = Just 'e'
- , optionArgument = RequiredArgument
- , optionAction =
- [ Reassign (Var "EntryPoint") $ Var "optarg" ]
- }
- ]
-
--- | The class generated by the code generator must have a
--- constructor, although it can be vacuous.
-data Constructor = Constructor [CSFunDefArg] [CSStmt]
-
--- | A constructor that takes no arguments and does nothing.
-emptyConstructor :: Constructor
-emptyConstructor = Constructor [(Composite $ ArrayT $ Primitive StringT, "args")] []
-
-constructorToConstructorDef :: Constructor -> String -> [CSStmt] -> CSStmt
-constructorToConstructorDef (Constructor params body) name at_init =
- ConstructorDef $ ClassConstructor name params $ body <> at_init
-
-
-compileProg :: MonadFreshNames m =>
- Maybe String
- -> Constructor
- -> [CSStmt]
- -> [CSStmt]
- -> Operations op s
- -> s
- -> CompilerM op s ()
- -> [CSStmt]
- -> [Space]
- -> [Option]
- -> Imp.Functions op
- -> m String
-compileProg module_name constructor imports defines ops userstate boilerplate pre_timing _ options prog@(Imp.Functions funs) = do
- src <- getNameSource
- let prog' = runCompilerM prog ops src userstate compileProg'
- let imports' = [ Using Nothing "System"
- , Using Nothing "System.Diagnostics"
- , Using Nothing "System.Collections"
- , Using Nothing "System.Collections.Generic"
- , Using Nothing "System.IO"
- , Using Nothing "System.Linq"
- , Using Nothing "System.Runtime.InteropServices"
- , Using Nothing "static System.ValueTuple"
- , Using Nothing "static System.Convert"
- , Using Nothing "static System.Math"
- , Using Nothing "System.Numerics"
- , Using Nothing "Mono.Options" ] ++ imports
-
- return $ pretty (CSProg $ imports' ++ prog')
- where compileProg' = do
- definitions <- mapM compileFunc funs
- opencl_boilerplate <- collect boilerplate
- compBeforeParses <- gets compBeforeParse
- compInits <- gets compInit
- staticDecls <- gets compStaticMemDecls
- staticAllocs <- gets compStaticMemAllocs
- extraMemberDecls <- gets compMemberDecls
- let member_decls' = member_decls ++ extraMemberDecls ++ staticDecls
- let at_inits' = at_inits ++ compBeforeParses ++ parse_options ++ compInits ++ staticAllocs
-
-
- case module_name of
- Just name -> do
- entry_points <- mapM (compileEntryFun pre_timing) $ filter (Imp.functionEntry . snd) funs
- let constructor' = constructorToConstructorDef constructor name at_inits'
- return [ Namespace name [ClassDef $ PublicClass name $ member_decls' ++
- constructor' : defines' ++ opencl_boilerplate ++
- map PrivateFunDef definitions ++
- map PublicFunDef entry_points ]]
-
-
- Nothing -> do
- let name = "FutharkInternal"
- let constructor' = constructorToConstructorDef constructor name at_inits'
- (entry_point_defs, entry_point_names, entry_points) <-
- unzip3 <$> mapM (callEntryFun pre_timing)
- (filter (Imp.functionEntry . snd) funs)
-
- debug_ending <- gets compDebugItems
- return [Namespace name ((ClassDef $
- PublicClass name $
- member_decls' ++
- constructor' : defines' ++
- opencl_boilerplate ++
- map PrivateFunDef (definitions ++ entry_point_defs) ++
- [PublicFunDef $ Def "InternalEntry" VoidT [] $ selectEntryPoint entry_point_names entry_points ++ debug_ending
- ]
- ) :
- [ClassDef $ PublicClass "Program"
- [StaticFunDef $ Def "Main" VoidT [(string_arrayT,"args")] main_entry]])
- ]
-
-
-
- string_arrayT = Composite $ ArrayT $ Primitive StringT
- main_entry :: [CSStmt]
- main_entry = [ Assign (Var "internalInstance") (simpleInitClass "FutharkInternal" [Var "args"])
- , Exp $ simpleCall "internalInstance.InternalEntry" []
- ]
-
- member_decls =
- [ AssignTyped (CustomT "FileStream") (Var "RuntimeFile") Nothing
- , AssignTyped (CustomT "StreamWriter") (Var "RuntimeFileWriter") Nothing
- , AssignTyped (Primitive BoolT) (Var "DoWarmupRun") Nothing
- , AssignTyped (Primitive $ CSInt Int32T) (Var "NumRuns") Nothing
- , AssignTyped (Primitive StringT) (Var "EntryPoint") Nothing
- ]
-
- at_inits = [ Reassign (Var "DoWarmupRun") (Bool False)
- , Reassign (Var "NumRuns") (Integer 1)
- , Reassign (Var "EntryPoint") (String "main")
- , Exp $ simpleCall "ValueReader" []
- ]
-
- defines' = [ Escape csScalar
- , Escape csMemory
- , Escape csPanic
- , Escape csExceptions
- , Escape csReader] ++ defines
-
- parse_options =
- generateOptionParser (standardOptions ++ options)
-
- selectEntryPoint entry_point_names entry_points =
- [ Assign (Var "EntryPoints") $
- Collection "Dictionary<string, Action>" $ zipWith Pair (map String entry_point_names) entry_points,
- If (simpleCall "!EntryPoints.ContainsKey" [Var "EntryPoint"])
- [ Exp $ simpleCall "Console.Error.WriteLine"
- [simpleCall "string.Format"
- [ String "No entry point '{0}'. Select another with --entry point. Options are:\n{1}"
- , Var "EntryPoint"
- , simpleCall "string.Join"
- [ String "\n"
- , Field (Var "EntryPoints") "Keys" ]]]
- , Exp $ simpleCall "Environment.Exit" [Integer 1]]
- [ Assign (Var "entryPointFun") $
- Index (Var "EntryPoints") (IdxExp $ Var "EntryPoint")
- , Exp $ simpleCall "entryPointFun.Invoke" []]
- ]
-
-
-compileFunc :: (Name, Imp.Function op) -> CompilerM op s CSFunDef
-compileFunc (fname, Imp.Function _ outputs inputs body _ _) = do
- body' <- blockScope $ compileCode body
- let inputs' = map compileTypedInput inputs
- let outputs' = map compileOutput outputs
- let outputDecls = map getDefaultDecl outputs
- let (ret, retType) = unzip outputs'
- let retType' = tupleOrSingleT retType
- let ret' = [Return $ tupleOrSingle ret]
-
- case outputs of
- [] -> return $ Def (futharkFun . nameToString $ fname) VoidT inputs' (outputDecls++body')
- _ -> return $ Def (futharkFun . nameToString $ fname) retType' inputs' (outputDecls++body'++ret')
-
-
-compileTypedInput :: Imp.Param -> (CSType, String)
-compileTypedInput input = (typeFun input, nameFun input)
- where nameFun = compileName . Imp.paramName
- typeFun = compileType . paramType
-
-tupleOrSingleEntryT :: [CSType] -> CSType
-tupleOrSingleEntryT [e] = e
-tupleOrSingleEntryT es = Composite $ SystemTupleT es
-
-tupleOrSingleEntry :: [CSExp] -> CSExp
-tupleOrSingleEntry [e] = e
-tupleOrSingleEntry es = CreateSystemTuple es
-
-tupleOrSingleT :: [CSType] -> CSType
-tupleOrSingleT [e] = e
-tupleOrSingleT es = Composite $ TupleT es
-
-tupleOrSingle :: [CSExp] -> CSExp
-tupleOrSingle [e] = e
-tupleOrSingle es = Tuple es
-
-assignScalarPointer :: CSExp -> CSExp -> CSStmt
-assignScalarPointer e ptr =
- AssignTyped (PointerT VoidT) ptr (Just $ Addr e)
-
--- | A 'Call' where the function is a variable and every argument is a
--- simple 'Arg'.
-simpleCall :: String -> [CSExp] -> CSExp
-simpleCall fname = Call (Var fname) . map simpleArg
-
--- | Like 'simpleCall', but the function name is parametrised by a type
--- argument, as in @foo<int>@.
-parametrizedCall :: String -> String -> [CSExp] -> CSExp
-parametrizedCall fname primtype = Call (Var fname') . map simpleArg
- where fname' = concat [fname, "<", primtype, ">"]
-
-simpleArg :: CSExp -> CSArg
-simpleArg = Arg Nothing
-
--- | A method call on an object, where every argument is a simple 'Arg'.
-callMethod :: CSExp -> String -> [CSExp] -> CSExp
-callMethod object method = CallMethod object (Var method) . map simpleArg
-
-simpleInitClass :: String -> [CSExp] -> CSExp
-simpleInitClass fname = CreateObject (Var fname) . map simpleArg
-
-compileName :: VName -> String
-compileName = zEncodeString . pretty
-
-compileType :: Imp.Type -> CSType
-compileType (Imp.Scalar p) = compilePrimTypeToAST p
-compileType (Imp.Mem _ space) = rawMemCSType space
-
-compilePrimTypeToAST :: PrimType -> CSType
-compilePrimTypeToAST (IntType Int8) = Primitive $ CSInt Int8T
-compilePrimTypeToAST (IntType Int16) = Primitive $ CSInt Int16T
-compilePrimTypeToAST (IntType Int32) = Primitive $ CSInt Int32T
-compilePrimTypeToAST (IntType Int64) = Primitive $ CSInt Int64T
-compilePrimTypeToAST (FloatType Float32) = Primitive $ CSFloat FloatT
-compilePrimTypeToAST (FloatType Float64) = Primitive $ CSFloat DoubleT
-compilePrimTypeToAST Imp.Bool = Primitive BoolT
-compilePrimTypeToAST Imp.Cert = Primitive BoolT
-
-compilePrimTypeToASText :: PrimType -> Imp.Signedness -> CSType
-compilePrimTypeToASText (IntType Int8) Imp.TypeUnsigned = Primitive $ CSUInt UInt8T
-compilePrimTypeToASText (IntType Int16) Imp.TypeUnsigned = Primitive $ CSUInt UInt16T
-compilePrimTypeToASText (IntType Int32) Imp.TypeUnsigned = Primitive $ CSUInt UInt32T
-compilePrimTypeToASText (IntType Int64) Imp.TypeUnsigned = Primitive $ CSUInt UInt64T
-compilePrimTypeToASText (IntType Int8) _ = Primitive $ CSInt Int8T
-compilePrimTypeToASText (IntType Int16) _ = Primitive $ CSInt Int16T
-compilePrimTypeToASText (IntType Int32) _ = Primitive $ CSInt Int32T
-compilePrimTypeToASText (IntType Int64) _ = Primitive $ CSInt Int64T
-compilePrimTypeToASText (FloatType Float32) _ = Primitive $ CSFloat FloatT
-compilePrimTypeToASText (FloatType Float64) _ = Primitive $ CSFloat DoubleT
-compilePrimTypeToASText Imp.Bool _ = Primitive BoolT
-compilePrimTypeToASText Imp.Cert _ = Primitive BoolT
-
-compileDim :: Imp.DimSize -> CSExp
-compileDim (Imp.ConstSize i) = Integer $ toInteger i
-compileDim (Imp.VarSize v) = Var $ compileName v
-
-unpackDim :: CSExp -> Imp.DimSize -> Int32 -> CompilerM op s ()
-unpackDim arr_name (Imp.ConstSize c) i = do
- let shape_name = Field arr_name "Item2" -- array tuples are currently (data array * dimension array)
- let constant_c = Integer $ toInteger c
- let constant_i = Integer $ toInteger i
- stm $ Assert (BinOp "==" constant_c (Index shape_name $ IdxExp constant_i)) [String "constant dimension wrong"]
-
-unpackDim arr_name (Imp.VarSize var) i = do
- let shape_name = Field arr_name "Item2"
- let src = Index shape_name $ IdxExp $ Integer $ toInteger i
- let dest = Var $ compileName var
- isAssigned <- getVarAssigned var
- if isAssigned
- then
- stm $ Reassign dest $ Cast (Primitive $ CSInt Int32T) src
- else do
- stm $ Assign dest $ Cast (Primitive $ CSInt Int32T) src
- setVarAssigned var
-
-entryPointOutput :: Imp.ExternalValue -> CompilerM op s CSExp
-entryPointOutput (Imp.OpaqueValue _ vs) =
- CreateSystemTuple <$> mapM (entryPointOutput . Imp.TransparentValue) vs
-
-entryPointOutput (Imp.TransparentValue (Imp.ScalarValue bt ept name)) =
- return $ cast $ Var $ compileName name
- where cast = compileTypecastExt bt ept
-
-entryPointOutput (Imp.TransparentValue (Imp.ArrayValue mem _ Imp.DefaultSpace bt ept dims)) = do
- let src = Var $ compileName mem
- let createTuple = "createTuple_" ++ compilePrimTypeExt bt ept
- return $ simpleCall createTuple [src, CreateArray (Primitive $ CSInt Int64T) $ map compileDim dims]
-
-entryPointOutput (Imp.TransparentValue (Imp.ArrayValue mem _ (Imp.Space sid) bt ept dims)) = do
- unRefMem mem (Imp.Space sid)
- pack_output <- asks envEntryOutput
- pack_output mem sid bt ept dims
-
-entryPointInput :: (Int, Imp.ExternalValue, CSExp) -> CompilerM op s ()
-entryPointInput (i, Imp.OpaqueValue _ vs, e) =
- mapM_ entryPointInput $ zip3 (repeat i) (map Imp.TransparentValue vs) $
- map (\idx -> Field e $ "Item" ++ show (idx :: Int)) [1..]
-
-entryPointInput (_, Imp.TransparentValue (Imp.ScalarValue bt _ name), e) = do
- let vname' = Var $ compileName name
- cast = compileTypecast bt
- stm $ Assign vname' (cast e)
-
-entryPointInput (_, Imp.TransparentValue (Imp.ArrayValue mem memsize Imp.DefaultSpace bt _ dims), e) = do
- zipWithM_ (unpackDim e) dims [0..]
- let arrayData = Field e "Item1"
- let dest = Var $ compileName mem
- unwrap_call = simpleCall "unwrapArray" [arrayData, sizeOf $ compilePrimTypeToAST bt]
- case memsize of
- Imp.VarSize sizevar ->
- stm $ Assign (Var $ compileName sizevar) $ Field e "Item2.Length"
- Imp.ConstSize _ ->
- return ()
- stm $ Assign dest unwrap_call
-
-entryPointInput (_, Imp.TransparentValue (Imp.ArrayValue mem memsize (Imp.Space sid) bt ept dims), e) = do
- unpack_input <- asks envEntryInput
- unpack <- collect $ unpack_input mem memsize sid bt ept dims e
- stms unpack
-
-extValueDescName :: Imp.ExternalValue -> String
-extValueDescName (Imp.TransparentValue v) = extName $ valueDescName v
-extValueDescName (Imp.OpaqueValue desc []) = extName $ zEncodeString desc
-extValueDescName (Imp.OpaqueValue desc (v:_)) =
- extName $ zEncodeString desc ++ "_" ++ pretty (baseTag (valueDescVName v))
-
-extName :: String -> String
-extName = (++"_ext")
-
-sizeOf :: CSType -> CSExp
-sizeOf t = simpleCall "sizeof" [(Var . pretty) t]
-
-publicFunDef :: String -> CSType -> [(CSType, String)] -> [CSStmt] -> CSStmt
-publicFunDef s t args stmts = PublicFunDef $ Def s t args stmts
-
-privateFunDef :: String -> CSType -> [(CSType, String)] -> [CSStmt] -> CSStmt
-privateFunDef s t args stmts = PrivateFunDef $ Def s t args stmts
-
-valueDescName :: Imp.ValueDesc -> String
-valueDescName = compileName . valueDescVName
-
-valueDescVName :: Imp.ValueDesc -> VName
-valueDescVName (Imp.ScalarValue _ _ vname) = vname
-valueDescVName (Imp.ArrayValue vname _ _ _ _ _) = vname
-
-consoleWrite :: String -> [CSExp] -> CSExp
-consoleWrite str exps = simpleCall "Console.Write" $ String str:exps
-
-consoleWriteLine :: String -> [CSExp] -> CSExp
-consoleWriteLine str exps = simpleCall "Console.WriteLine" $ String str:exps
-
-consoleErrorWrite :: String -> [CSExp] -> CSExp
-consoleErrorWrite str exps = simpleCall "Console.Error.Write" $ String str:exps
-
-consoleErrorWriteLine :: String -> [CSExp] -> CSExp
-consoleErrorWriteLine str exps = simpleCall "Console.Error.WriteLine" $ String str:exps
-
-readFun :: PrimType -> Imp.Signedness -> String
-readFun (FloatType Float32) _ = "ReadF32"
-readFun (FloatType Float64) _ = "ReadF64"
-readFun (IntType Int8) Imp.TypeUnsigned = "ReadU8"
-readFun (IntType Int16) Imp.TypeUnsigned = "ReadU16"
-readFun (IntType Int32) Imp.TypeUnsigned = "ReadU32"
-readFun (IntType Int64) Imp.TypeUnsigned = "ReadU64"
-readFun (IntType Int8) Imp.TypeDirect = "ReadI8"
-readFun (IntType Int16) Imp.TypeDirect = "ReadI16"
-readFun (IntType Int32) Imp.TypeDirect = "ReadI32"
-readFun (IntType Int64) Imp.TypeDirect = "ReadI64"
-readFun Imp.Bool _ = "ReadBool"
-readFun Cert _ = error "readFun: cert"
-
-readBinFun :: PrimType -> Imp.Signedness -> String
-readBinFun (FloatType Float32) _bin_ = "ReadBinF32"
-readBinFun (FloatType Float64) _bin_ = "ReadBinF64"
-readBinFun (IntType Int8) Imp.TypeUnsigned = "ReadBinU8"
-readBinFun (IntType Int16) Imp.TypeUnsigned = "ReadBinU16"
-readBinFun (IntType Int32) Imp.TypeUnsigned = "ReadBinU32"
-readBinFun (IntType Int64) Imp.TypeUnsigned = "ReadBinU64"
-readBinFun (IntType Int8) Imp.TypeDirect = "ReadBinI8"
-readBinFun (IntType Int16) Imp.TypeDirect = "ReadBinI16"
-readBinFun (IntType Int32) Imp.TypeDirect = "ReadBinI32"
-readBinFun (IntType Int64) Imp.TypeDirect = "ReadBinI64"
-readBinFun Imp.Bool _ = "ReadBinBool"
-readBinFun Cert _ = error "readFun: cert"
-
--- The value returned will be used when reading binary arrays, to indicate what
--- the expected type is
--- Key into the FUTHARK_PRIMTYPES dict.
-readTypeEnum :: PrimType -> Imp.Signedness -> String
-readTypeEnum (IntType Int8) Imp.TypeUnsigned = "u8"
-readTypeEnum (IntType Int16) Imp.TypeUnsigned = "u16"
-readTypeEnum (IntType Int32) Imp.TypeUnsigned = "u32"
-readTypeEnum (IntType Int64) Imp.TypeUnsigned = "u64"
-readTypeEnum (IntType Int8) Imp.TypeDirect = "i8"
-readTypeEnum (IntType Int16) Imp.TypeDirect = "i16"
-readTypeEnum (IntType Int32) Imp.TypeDirect = "i32"
-readTypeEnum (IntType Int64) Imp.TypeDirect = "i64"
-readTypeEnum (FloatType Float32) _ = "f32"
-readTypeEnum (FloatType Float64) _ = "f64"
-readTypeEnum Imp.Bool _ = "bool"
-readTypeEnum Cert _ = error "readTypeEnum: cert"
-
-readInput :: Imp.ExternalValue -> CSStmt
-readInput (Imp.OpaqueValue desc _) =
- Throw $ simpleInitClass "Exception" [String $ "Cannot read argument of type " ++ desc ++ "."]
-
-readInput decl@(Imp.TransparentValue (Imp.ScalarValue bt ept _)) =
- let read_func = Var $ readFun bt ept
- read_bin_func = Var $ readBinFun bt ept
- type_enum = String $ readTypeEnum bt ept
- bt' = compilePrimTypeExt bt ept
- readScalar = initializeGenericFunction "ReadScalar" bt'
- in Assign (Var $ extValueDescName decl) $ simpleCall readScalar [type_enum, read_func, read_bin_func]
-
--- TODO: If the type identifier of 'Float32' is changed, currently the error
--- messages for reading binary input will not use this new name. This is also a
--- problem for the C runtime system.
-readInput decl@(Imp.TransparentValue (Imp.ArrayValue _ _ _ bt ept dims)) =
- let rank' = Var $ show $ length dims
- type_enum = String $ readTypeEnum bt ept
- bt' = compilePrimTypeExt bt ept
- read_func = Var $ readFun bt ept
- readArray = initializeGenericFunction "ReadArray" bt'
- in Assign (Var $ extValueDescName decl) $ simpleCall readArray [rank', type_enum, read_func]
-
-initializeGenericFunction :: String -> String -> String
-initializeGenericFunction fun tp = fun ++ "<" ++ tp ++ ">"
-
-
-printPrimStm :: CSExp -> CSStmt
-printPrimStm val = Exp $ simpleCall "WriteValue" [val]
-
-formatString :: String -> [CSExp] -> CSExp
-formatString fmt contents =
- simpleCall "String.Format" $ String fmt : contents
-
-printStm :: Imp.ValueDesc -> CSExp -> CSExp -> CompilerM op s CSStmt
-printStm Imp.ScalarValue{} _ e =
- return $ printPrimStm e
-printStm (Imp.ArrayValue _ _ _ _ _ []) ind e = do
- let e' = Index e (IdxExp (PostUnOp "++" ind))
- return $ printPrimStm e'
-
-printStm (Imp.ArrayValue mem memsize space bt ept (outer:shape)) ind e = do
- ptr <- newVName "shapePtr"
- first <- newVName "printFirst"
- let size = callMethod (CreateArray (Primitive $ CSInt Int32T) $ map compileDim $ outer:shape)
- "Aggregate" [ Integer 1
- , Lambda (Tuple [Var "acc", Var "val"])
- [Exp $ BinOp "*" (Var "acc") (Var "val")]
- ]
- emptystr = "empty(" ++ ppArrayType bt (length shape) ++ ")"
-
- printelem <- printStm (Imp.ArrayValue mem memsize space bt ept shape) ind e
- return $
- If (BinOp "==" size (Integer 0))
- [puts emptystr]
- [ Assign (Var $ pretty first) $ Var "true"
- , puts "["
- , For (pretty ptr) (compileDim outer)
- [ If (simpleCall "!" [Var $ pretty first]) [puts ", "] []
- , printelem
- , Reassign (Var $ pretty first) $ Var "false"
- ]
- , puts "]"
- ]
-
- where ppArrayType :: PrimType -> Int -> String
- ppArrayType t 0 = prettyPrimType ept t
- ppArrayType t n = "[]" ++ ppArrayType t (n-1)
-
- prettyPrimType Imp.TypeUnsigned (IntType Int8) = "u8"
- prettyPrimType Imp.TypeUnsigned (IntType Int16) = "u16"
- prettyPrimType Imp.TypeUnsigned (IntType Int32) = "u32"
- prettyPrimType Imp.TypeUnsigned (IntType Int64) = "u64"
- prettyPrimType _ t = pretty t
-
- puts s = Exp $ simpleCall "Console.Write" [String s]
-
-printValue :: [(Imp.ExternalValue, CSExp)] -> CompilerM op s [CSStmt]
-printValue = fmap concat . mapM (uncurry printValue')
- -- We copy non-host arrays to the host before printing. This is
- -- done in a hacky way - we assume the value has a .get()-method
- -- that returns an equivalent Numpy array. This works for CSOpenCL,
- -- but we will probably need yet another plugin mechanism here in
- -- the future.
- where printValue' (Imp.OpaqueValue desc _) _ =
- return [Exp $ simpleCall "Console.Write"
- [String $ "#<opaque " ++ desc ++ ">"]]
- printValue' (Imp.TransparentValue r@Imp.ScalarValue{}) e = do
- p <- printStm r (Integer 0) e
- return [p, Exp $ simpleCall "Console.Write" [String "\n"]]
- printValue' (Imp.TransparentValue r@Imp.ArrayValue{}) e = do
- tuple <- newVName "resultArr"
- i <- newVName "arrInd"
- let i' = Var $ compileName i
- p <- printStm r i' (Var $ compileName tuple)
- let e' = Var $ pretty e
- return [ Assign (Var $ compileName tuple) (Field e' "Item1")
- , Assign i' (Integer 0)
- , p
- , Exp $ simpleCall "Console.Write" [String "\n"]]
-
-prepareEntry :: (Name, Imp.Function op) -> CompilerM op s
- (String, [(CSType, String)], CSType, [CSStmt], [CSStmt], [CSStmt], [CSStmt],
- [(Imp.ExternalValue, CSExp)], [CSStmt])
-prepareEntry (fname, Imp.Function _ outputs inputs _ results args) = do
- let (output_types, output_paramNames) = unzip $ map compileTypedInput outputs
- funTuple = tupleOrSingle $ fmap Var output_paramNames
-
-
- (_, sizeDecls) <- collect' $ forM args declsfunction
-
- (argexps_mem_copies, prepare_run) <- collect' $ forM inputs $ \case
- Imp.MemParam name space -> do
- -- A program might write to its input parameters, so create a new memory
- -- block and copy the source there. This way the program can be run more
- -- than once.
- name' <- newVName $ baseString name <> "_copy"
- copy <- asks envCopy
- allocate <- asks envAllocate
-
- let size = Var (compileName name ++ "_nbytes")
- dest = name'
- src = name
- offset = Integer 0
- case space of
- DefaultSpace ->
- stm $ Reassign (Var (compileName name'))
- (simpleCall "allocateMem" [size]) -- FIXME
- Space sid ->
- allocate name' size sid
- copy dest offset space src offset space size (IntType Int64) -- FIXME
- return $ Just (compileName name')
- _ -> return Nothing
-
- prepareIn <- collect $ mapM_ entryPointInput $ zip3 [0..] args $
- map (Var . extValueDescName) args
- (res, prepareOut) <- collect' $ mapM entryPointOutput results
-
- let mem_copies = mapMaybe liftMaybe $ zip argexps_mem_copies inputs
- mem_copy_inits = map initCopy mem_copies
-
- argexps_lib = map (compileName . Imp.paramName) inputs
- argexps_bin = zipWith fromMaybe argexps_lib argexps_mem_copies
- fname' = futharkFun (nameToString fname)
- arg_types = map (fst . compileTypedInput) inputs
- inputs' = zip arg_types (map extValueDescName args)
- output_type = tupleOrSingleEntryT output_types
- call_lib = [Reassign funTuple $ simpleCall fname' (fmap Var argexps_lib)]
- call_bin = [Reassign funTuple $ simpleCall fname' (fmap Var argexps_bin)]
- prepareIn' = prepareIn ++ mem_copy_inits ++ sizeDecls
-
- return (nameToString fname, inputs', output_type,
- prepareIn', call_lib, call_bin, prepareOut,
- zip results res, prepare_run)
-
- where liftMaybe (Just a, b) = Just (a,b)
- liftMaybe _ = Nothing
-
- initCopy (varName, Imp.MemParam _ space) = declMem' varName space
- initCopy _ = Pass
-
- valueDescFun (Imp.ArrayValue mem _ Imp.DefaultSpace _ _ _) =
- stm $ Assign (Var $ compileName mem ++ "_nbytes") (Var $ compileName mem ++ ".Length")
- valueDescFun (Imp.ArrayValue mem _ (Imp.Space _) bt _ dims) =
- stm $ Assign (Var $ compileName mem ++ "_nbytes") $ foldr (BinOp "*" . compileDim) (sizeOf $ compilePrimTypeToAST bt) dims
- valueDescFun _ = stm Pass
-
- declsfunction (Imp.TransparentValue v) = valueDescFun v
- declsfunction (Imp.OpaqueValue _ vs) = mapM_ valueDescFun vs
-
-copyMemoryDefaultSpace :: VName -> CSExp -> VName -> CSExp -> CSExp ->
- CompilerM op s ()
-copyMemoryDefaultSpace destmem destidx srcmem srcidx nbytes =
- stm $ Exp $ simpleCall "Buffer.BlockCopy" [ Var (compileName srcmem), srcidx
- , Var (compileName destmem), destidx,
- nbytes]
-
-compileEntryFun :: [CSStmt] -> (Name, Imp.Function op)
- -> CompilerM op s CSFunDef
-compileEntryFun pre_timing entry@(_,Imp.Function _ outputs _ _ results args) = do
- let params = map (getType &&& extValueDescName) args
- let outputType = tupleOrSingleEntryT $ map getType results
-
- (fname', _, _, prepareIn, body_lib, _, prepareOut, res, _) <- prepareEntry entry
- let ret = Return $ tupleOrSingleEntry $ map snd res
- let outputDecls = map getDefaultDecl outputs
- do_run = body_lib ++ pre_timing
- (do_run_with_timing, close_runtime_file) <- addTiming do_run
-
- let do_warmup_run = If (Var "DoWarmupRun") do_run []
- do_num_runs = For "i" (Var "NumRuns") do_run_with_timing
-
- return $ Def fname' outputType params $
- prepareIn ++ outputDecls ++ [do_warmup_run, do_num_runs, close_runtime_file] ++ prepareOut ++ [ret]
-
- where getType :: Imp.ExternalValue -> CSType
- getType (Imp.OpaqueValue _ valueDescs) =
- let valueDescs' = map getType' valueDescs
- in Composite $ SystemTupleT valueDescs'
- getType (Imp.TransparentValue valueDesc) =
- getType' valueDesc
-
- getType' :: Imp.ValueDesc -> CSType
- getType' (Imp.ScalarValue primtype signedness _) =
- compilePrimTypeToASText primtype signedness
- getType' (Imp.ArrayValue _ _ _ primtype signedness _) =
- let t = compilePrimTypeToASText primtype signedness
- in Composite $ SystemTupleT [Composite $ ArrayT t, Composite $ ArrayT $ Primitive $ CSInt Int64T]
-
-
-callEntryFun :: [CSStmt] -> (Name, Imp.Function op)
- -> CompilerM op s (CSFunDef, String, CSExp)
-callEntryFun pre_timing entry@(fname, Imp.Function _ outputs _ _ _ decl_args) =
- if any isOpaque decl_args then
- return (Def fname' VoidT [] [exitException], nameToString fname, Var fname')
- else do
- (_, _, _, prepareIn, _, body_bin, prepare_out, res, prepare_run) <- prepareEntry entry
- let str_input = map readInput decl_args
-
- let outputDecls = map getDefaultDecl outputs
- exitcall = [
- Exp $ simpleCall "Console.Error.WriteLine" [formatString "Assertion.{0} failed" [Var "e"]]
- , Exp $ simpleCall "Environment.Exit" [Integer 1]
- ]
- except' = Catch (Var "Exception") exitcall
- do_run = body_bin ++ pre_timing
- (do_run_with_timing, close_runtime_file) <- addTiming do_run
-
- -- We ignore overflow errors and the like for executable entry
- -- points. These are (somewhat) well-defined in Futhark.
-
- let maybe_free =
- [If (BinOp "<" (Var "i") (BinOp "-" (Var "NumRuns") (Integer 1)))
- prepare_out []]
-
- do_warmup_run =
- If (Var "DoWarmupRun") (prepare_run ++ do_run ++ prepare_out) []
-
- do_num_runs =
- For "i" (Var "NumRuns") (prepare_run ++ do_run_with_timing ++ maybe_free)
-
- str_output <- printValue res
-
- return (Def fname' VoidT [] $
- str_input ++ prepareIn ++ outputDecls ++
- [Try [do_warmup_run, do_num_runs] [except']] ++
- [close_runtime_file] ++
- str_output,
-
- nameToString fname,
-
- Var fname')
-
- where fname' = "entry_" ++ nameToString fname
- isOpaque Imp.TransparentValue{} = False
- isOpaque _ = True
-
- exitException = Throw $ simpleInitClass "Exception" [String $ "The function " ++ nameToString fname ++ " is not available as an entry function."]
-
-addTiming :: [CSStmt] -> CompilerM s op ([CSStmt], CSStmt)
-addTiming statements = do
- syncFun <- asks envSyncFun
-
- return ([ Assign (Var "StopWatch") $ simpleInitClass "Stopwatch" []
- , syncFun
- , Exp $ simpleCall "StopWatch.Start" [] ] ++
- statements ++
- [ syncFun
- , Exp $ simpleCall "StopWatch.Stop" []
- , Assign (Var "timeElapsed") $ asMicroseconds (Var "StopWatch")
- , If (not_null (Var "RuntimeFile")) [print_runtime] []
- ]
- , If (not_null (Var "RuntimeFile")) [
- Exp $ simpleCall "RuntimeFileWriter.Close" [] ,
- Exp $ simpleCall "RuntimeFile.Close" []
- ] []
- )
-
- where print_runtime = Exp $ simpleCall "RuntimeFileWriter.WriteLine" [ callMethod (Var "timeElapsed") "ToString" [] ]
- not_null var = BinOp "!=" var Null
- asMicroseconds watch =
- BinOp "/" (Field watch "ElapsedTicks")
- (BinOp "/" (Field (Var "TimeSpan") "TicksPerMillisecond") (Integer 1000))
-
-compileUnOp :: Imp.UnOp -> String
-compileUnOp op =
- case op of
- Not -> "!"
- Complement{} -> "~"
- Abs{} -> "Math.Abs" -- actually write these helpers
- FAbs{} -> "Math.Abs"
- SSignum{} -> "ssignum"
- USignum{} -> "usignum"
-
-compileBinOpLike :: Monad m =>
- Imp.Exp -> Imp.Exp
- -> CompilerM op s (CSExp, CSExp, String -> m CSExp)
-compileBinOpLike x y = do
- x' <- compileExp x
- y' <- compileExp y
- let simple s = return $ BinOp s x' y'
- return (x', y', simple)
-
--- | The C# type name corresponding to a 'PrimType'.
-compilePrimType :: PrimType -> String
-compilePrimType t =
- case t of
- IntType Int8 -> "sbyte"
- IntType Int16 -> "short"
- IntType Int32 -> "int"
- IntType Int64 -> "long"
- FloatType Float32 -> "float"
- FloatType Float64 -> "double"
- Imp.Bool -> "bool"
- Cert -> "bool"
-
--- | The C# type name corresponding to a 'PrimType', taking signedness into account.
-compilePrimTypeExt :: PrimType -> Imp.Signedness -> String
-compilePrimTypeExt t ept =
- case (t, ept) of
- (IntType Int8, Imp.TypeUnsigned) -> "byte"
- (IntType Int16, Imp.TypeUnsigned) -> "ushort"
- (IntType Int32, Imp.TypeUnsigned) -> "uint"
- (IntType Int64, Imp.TypeUnsigned) -> "ulong"
- (IntType Int8, _) -> "sbyte"
- (IntType Int16, _) -> "short"
- (IntType Int32, _) -> "int"
- (IntType Int64, _) -> "long"
- (FloatType Float32, _) -> "float"
- (FloatType Float64, _) -> "double"
- (Imp.Bool, _) -> "bool"
- (Cert, _) -> "byte"
-
--- | A cast to the C# type corresponding to a 'PrimType', taking signedness into account.
-compileTypecastExt :: PrimType -> Imp.Signedness -> (CSExp -> CSExp)
-compileTypecastExt t ept =
- let t' = case (t, ept) of
- (IntType Int8 , Imp.TypeUnsigned)-> Primitive $ CSUInt UInt8T
- (IntType Int16 , Imp.TypeUnsigned)-> Primitive $ CSUInt UInt16T
- (IntType Int32 , Imp.TypeUnsigned)-> Primitive $ CSUInt UInt32T
- (IntType Int64 , Imp.TypeUnsigned)-> Primitive $ CSUInt UInt64T
- (IntType Int8 , _)-> Primitive $ CSInt Int8T
- (IntType Int16 , _)-> Primitive $ CSInt Int16T
- (IntType Int32 , _)-> Primitive $ CSInt Int32T
- (IntType Int64 , _)-> Primitive $ CSInt Int64T
- (FloatType Float32, _)-> Primitive $ CSFloat FloatT
- (FloatType Float64, _)-> Primitive $ CSFloat DoubleT
- (Imp.Bool , _)-> Primitive BoolT
- (Cert, _)-> Primitive $ CSInt Int8T
- in Cast t'
-
--- | A cast to the C# type corresponding to a 'PrimType'.
-compileTypecast :: PrimType -> (CSExp -> CSExp)
-compileTypecast t =
- let t' = case t of
- IntType Int8 -> Primitive $ CSInt Int8T
- IntType Int16 -> Primitive $ CSInt Int16T
- IntType Int32 -> Primitive $ CSInt Int32T
- IntType Int64 -> Primitive $ CSInt Int64T
- FloatType Float32 -> Primitive $ CSFloat FloatT
- FloatType Float64 -> Primitive $ CSFloat DoubleT
- Imp.Bool -> Primitive BoolT
- Cert -> Primitive $ CSInt Int8T
- in Cast t'
-
--- | Compile a 'PrimValue' to a C# constant expression.
-compilePrimValue :: Imp.PrimValue -> CSExp
-compilePrimValue (IntValue (Int8Value v)) =
- Cast (Primitive $ CSInt Int8T) $ Integer $ toInteger v
-compilePrimValue (IntValue (Int16Value v)) =
- Cast (Primitive $ CSInt Int16T) $ Integer $ toInteger v
-compilePrimValue (IntValue (Int32Value v)) =
- Cast (Primitive $ CSInt Int32T) $ Integer $ toInteger v
-compilePrimValue (IntValue (Int64Value v)) =
- Cast (Primitive $ CSInt Int64T) $ Integer $ toInteger v
-compilePrimValue (FloatValue (Float32Value v))
- | isInfinite v =
- if v > 0 then Var "Single.PositiveInfinity" else Var "Single.NegativeInfinity"
- | isNaN v =
- Var "Single.NaN"
- | otherwise = Cast (Primitive $ CSFloat FloatT) (Float $ fromRational $ toRational v)
-compilePrimValue (FloatValue (Float64Value v))
- | isInfinite v =
- if v > 0 then Var "Double.PositiveInfinity" else Var "Double.NegativeInfinity"
- | isNaN v =
- Var "Double.NaN"
- | otherwise = Cast (Primitive $ CSFloat DoubleT) (Float $ fromRational $ toRational v)
-compilePrimValue (BoolValue v) = Bool v
-compilePrimValue Checked = Bool True
-
-compileExp :: Imp.Exp -> CompilerM op s CSExp
-
-compileExp (Imp.ValueExp v) = return $ compilePrimValue v
-
-compileExp (Imp.LeafExp (Imp.ScalarVar vname) _) =
- return $ Var $ compileName vname
-
-compileExp (Imp.LeafExp (Imp.SizeOf t) _) =
- return $ (compileTypecast $ IntType Int32) (Integer $ primByteSize t)
-
-compileExp (Imp.LeafExp (Imp.Index src (Imp.Count iexp) (IntType Int8) DefaultSpace _) _) = do
- let src' = compileName src
- iexp' <- compileExp iexp
- return $ Cast (Primitive $ CSInt Int8T) (Index (Var src') (IdxExp iexp'))
-
-compileExp (Imp.LeafExp (Imp.Index src (Imp.Count iexp) bt DefaultSpace _) _) = do
- iexp' <- compileExp iexp
- let bt' = compilePrimType bt
- return $ simpleCall ("indexArray_" ++ bt') [Var $ compileName src, iexp']
-
-compileExp (Imp.LeafExp (Imp.Index src (Imp.Count iexp) restype (Imp.Space space) _) _) =
- join $ asks envReadScalar
- <*> pure src <*> compileExp iexp
- <*> pure restype <*> pure space
-
-compileExp (Imp.BinOpExp op x y) = do
- (x', y', simple) <- compileBinOpLike x y
- case op of
- FAdd{} -> simple "+"
- FSub{} -> simple "-"
- FMul{} -> simple "*"
- FDiv{} -> simple "/"
- LogAnd{} -> simple "&&"
- LogOr{} -> simple "||"
- _ -> return $ simpleCall (pretty op) [x', y']
-
-compileExp (Imp.ConvOpExp conv x) = do
- x' <- compileExp x
- return $ simpleCall (pretty conv) [x']
-
-compileExp (Imp.CmpOpExp cmp x y) = do
- (x', y', simple) <- compileBinOpLike x y
- case cmp of
- CmpEq{} -> simple "=="
- FCmpLt{} -> simple "<"
- FCmpLe{} -> simple "<="
- _ -> return $ simpleCall (pretty cmp) [x', y']
-
-compileExp (Imp.UnOpExp op exp1) =
- PreUnOp (compileUnOp op) <$> compileExp exp1
-
-compileExp (Imp.FunExp h args _) =
- simpleCall (futharkFun (pretty h)) <$> mapM compileExp args
-
-compileCode :: Imp.Code op -> CompilerM op s ()
-
-compileCode Imp.DebugPrint{} =
- return ()
-
-compileCode (Imp.Op op) =
- join $ asks envOpCompiler <*> pure op
-
-compileCode (Imp.If cond tb fb) = do
- cond' <- compileExp cond
- tb' <- blockScope $ compileCode tb
- fb' <- blockScope $ compileCode fb
- stm $ If cond' tb' fb'
-
-compileCode (c1 Imp.:>>: c2) = do
- compileCode c1
- compileCode c2
-
-compileCode (Imp.While cond body) = do
- cond' <- compileExp cond
- body' <- blockScope $ compileCode body
- stm $ While cond' body'
-
-compileCode (Imp.For i it bound body) = do
- bound' <- compileExp bound
- let i' = compileName i
- body' <- blockScope $ compileCode body
- counter <- pretty <$> newVName "counter"
- one <- pretty <$> newVName "one"
- stm $ Assign (Var i') $ compileTypecast (IntType it) (Integer 0)
- stm $ Assign (Var one) $ compileTypecast (IntType it) (Integer 1)
- stm $ For counter bound' $ body' ++
- [AssignOp "+" (Var i') (Var one)]
-
-
-compileCode (Imp.SetScalar vname exp1) = do
- let name' = Var $ compileName vname
- exp1' <- compileExp exp1
- stm $ Reassign name' exp1'
-
-compileCode (Imp.DeclareMem v space) = declMem v space
-
-compileCode (Imp.DeclareScalar v Cert) =
- stm $ Assign (Var $ compileName v) $ Bool True
-compileCode (Imp.DeclareScalar v t) =
- stm $ AssignTyped t' (Var $ compileName v) Nothing
- where t' = compilePrimTypeToAST t
-
-compileCode (Imp.DeclareArray name DefaultSpace t vs) =
- stms [Assign (Var $ "init_"++name') $
- simpleCall "unwrapArray"
- [
- CreateArray (compilePrimTypeToAST t) (map compilePrimValue vs)
- , simpleCall "sizeof" [Var $ compilePrimType t]
- ]
- , Assign (Var name') $ Var ("init_"++name')
- ]
- where name' = compileName name
-
-
-compileCode (Imp.DeclareArray name (Space space) t vs) =
- join $ asks envStaticArray <*>
- pure name <*> pure space <*> pure t <*> pure vs
-
-compileCode (Imp.Comment s code) = do
- code' <- blockScope $ compileCode code
- stm $ Comment s code'
-
-compileCode (Imp.Assert e (Imp.ErrorMsg parts) (loc,locs)) = do
- e' <- compileExp e
- let onPart (i, Imp.ErrorString s) = return (printFormatArg i, String s)
- onPart (i, Imp.ErrorInt32 x) = (printFormatArg i,) <$> compileExp x
- (formatstrs, formatargs) <- unzip <$> mapM onPart (zip ([1..] :: [Integer]) parts)
- stm $ Assert e' $ (String $ "Error at {0}:\n" <> concat formatstrs) : (String stacktrace : formatargs)
- where stacktrace = intercalate " -> " (reverse $ map locStr $ loc:locs)
- printFormatArg = printf "{%d}"
-
-compileCode (Imp.Call dests fname args) = do
- args' <- mapM compileArg args
- let dests' = tupleOrSingle $ fmap Var (map compileName dests)
- fname' = futharkFun (pretty fname)
- call' = simpleCall fname' args'
- -- If the function returns nothing (is called only for side
- -- effects), take care not to assign to an empty tuple.
- stm $ if null dests
- then Exp call'
- else Reassign dests' call'
- where compileArg (Imp.MemArg m) = return $ Var $ compileName m
- compileArg (Imp.ExpArg e) = compileExp e
-
-compileCode (Imp.SetMem dest src DefaultSpace) = do
- let src' = Var (compileName src)
- let dest' = Var (compileName dest)
- stm $ Reassign dest' src'
-
-compileCode (Imp.SetMem dest src _) = do
- let src' = Var (compileName src)
- let dest' = Var (compileName dest)
- stm $ Exp $ simpleCall "MemblockSetDevice" [Ref $ Var "Ctx", Ref dest', Ref src', String (compileName src)]
-
-compileCode (Imp.Allocate name (Imp.Count e) DefaultSpace) = do
- e' <- compileExp e
- let allocate' = simpleCall "allocateMem" [e']
- let name' = Var (compileName name)
- stm $ Reassign name' allocate'
-
-compileCode (Imp.Allocate name (Imp.Count e) (Imp.Space space)) =
- join $ asks envAllocate
- <*> pure name
- <*> compileExp e
- <*> pure space
-
-compileCode (Imp.Free name space) = do
- unRefMem name space
- tell $ mempty { accFreedMem = [name] }
-
-compileCode (Imp.Copy dest (Imp.Count destoffset) DefaultSpace src (Imp.Count srcoffset) DefaultSpace (Imp.Count size)) = do
- destoffset' <- compileExp destoffset
- srcoffset' <- compileExp srcoffset
- let dest' = Var (compileName dest)
- let src' = Var (compileName src)
- size' <- compileExp size
- stm $ Exp $ simpleCall "Buffer.BlockCopy" [src', srcoffset', dest', destoffset', size']
-
-compileCode (Imp.Copy dest (Imp.Count destoffset) destspace src (Imp.Count srcoffset) srcspace (Imp.Count size)) = do
- copy <- asks envCopy
- join $ copy
- <$> pure dest <*> compileExp destoffset <*> pure destspace
- <*> pure src <*> compileExp srcoffset <*> pure srcspace
- <*> compileExp size <*> pure (IntType Int64) -- FIXME
-
-compileCode (Imp.Write dest (Imp.Count idx) elemtype DefaultSpace _ elemexp) = do
- idx' <- compileExp idx
- elemexp' <- compileExp elemexp
- let dest' = Var $ compileName dest
- let elemtype' = compileTypecast elemtype
- let ctype = elemtype' elemexp'
- stm $ Exp $ simpleCall "writeScalarArray" [dest', idx', ctype]
-
-compileCode (Imp.Write dest (Imp.Count idx) elemtype (Imp.Space space) _ elemexp) =
- join $ asks envWriteScalar
- <*> pure dest
- <*> compileExp idx
- <*> pure elemtype
- <*> pure space
- <*> compileExp elemexp
-
-compileCode Imp.Skip = return ()
-
-blockScope :: CompilerM op s () -> CompilerM op s [CSStmt]
-blockScope = fmap snd . blockScope'
-
-blockScope' :: CompilerM op s a -> CompilerM op s (a, [CSStmt])
-blockScope' m = do
- old_allocs <- gets compDeclaredMem
- (x, items) <- pass $ do
- (x, w) <- listen m
- let items = accItems w
- return ((x, items), const mempty)
- new_allocs <- gets $ filter (`notElem` old_allocs) . compDeclaredMem
- modify $ \s -> s { compDeclaredMem = old_allocs }
- releases <- collect $ mapM_ (uncurry unRefMem) new_allocs
- return (x, items <> releases)
-
-unRefMem :: VName -> Space -> CompilerM op s ()
-unRefMem mem (Space "device") =
- (stm . Exp) $ simpleCall "MemblockUnrefDevice" [ Ref $ Var "Ctx"
- , (Ref . Var . compileName) mem
- , (String . compileName) mem]
-unRefMem _ DefaultSpace = stm Pass
-unRefMem _ (Space "local") = stm Pass
-unRefMem _ (Space _) = fail "The default compiler cannot compile unRefMem for other spaces"
-
-
--- | Public names must have a consistent prefix.
-publicName :: String -> String
-publicName s = "Futhark" ++ s
-
-declMem :: VName -> Space -> CompilerM op s ()
-declMem name space = do
- modify $ \s -> s { compDeclaredMem = (name, space) : compDeclaredMem s}
- stm $ declMem' (compileName name) space
-
-declMem' :: String -> Space -> CSStmt
-declMem' name DefaultSpace =
- AssignTyped (Composite $ ArrayT $ Primitive ByteT) (Var name) Nothing
-declMem' name (Space _) =
- AssignTyped (CustomT "OpenCLMemblock") (Var name) (Just $ simpleCall "EmptyMemblock" [Var "Ctx.EMPTY_MEM_HANDLE"])
-
-rawMemCSType :: Space -> CSType
-rawMemCSType DefaultSpace = Composite $ ArrayT $ Primitive ByteT
-rawMemCSType (Space _) = CustomT "OpenCLMemblock"
-
-toIntPtr :: CSExp -> CSExp
-toIntPtr e = simpleInitClass "IntPtr" [e]
+{-# LANGUAGE OverloadedStrings, GeneralizedNewtypeDeriving, LambdaCase #-}
+{-# LANGUAGE TupleSections #-}
+-- | A generic C# code generator which is polymorphic in the type
+-- of the operations. Concretely, we use this to handle both
+-- sequential and OpenCL C# code.
+module Futhark.CodeGen.Backends.GenericCSharp
+ ( compileProg
+ , Constructor (..)
+ , emptyConstructor
+
+ , assignScalarPointer
+ , toIntPtr
+ , compileName
+ , compileDim
+ , compileExp
+ , compileCode
+ , compilePrimValue
+ , compilePrimType
+ , compilePrimTypeExt
+ , compilePrimTypeToAST
+ , compilePrimTypeToASText
+ , contextFinalInits
+ , debugReport
+
+ , Operations (..)
+ , defaultOperations
+
+ , unpackDim
+
+ , CompilerM (..)
+ , OpCompiler
+ , WriteScalar
+ , ReadScalar
+ , Allocate
+ , Copy
+ , StaticArray
+ , EntryOutput
+ , EntryInput
+
+ , CompilerEnv(..)
+ , CompilerState(..)
+ , stm
+ , stms
+ , atInit
+ , staticMemDecl
+ , staticMemAlloc
+ , addMemberDecl
+ , beforeParse
+ , collect'
+ , collect
+ , simpleCall
+ , callMethod
+ , simpleInitClass
+ , parametrizedCall
+
+ , copyMemoryDefaultSpace
+ , consoleErrorWrite
+ , consoleErrorWriteLine
+ , consoleWrite
+ , consoleWriteLine
+
+ , publicName
+ , sizeOf
+ , privateFunDef
+ , publicFunDef
+ , getDefaultDecl
+ ) where
+
+import Control.Monad.Identity
+import Control.Monad.State
+import Control.Monad.Reader
+import Control.Monad.Writer
+import Control.Monad.RWS
+import Control.Arrow((&&&))
+import Data.Maybe
+import Data.List
+import qualified Data.Map.Strict as M
+
+import Futhark.Representation.Primitive hiding (Bool)
+import Futhark.MonadFreshNames
+import Futhark.Representation.AST.Syntax (Space(..))
+import qualified Futhark.CodeGen.ImpCode as Imp
+import Futhark.CodeGen.Backends.GenericCSharp.AST
+import Futhark.CodeGen.Backends.GenericCSharp.Options
+import Futhark.CodeGen.Backends.GenericCSharp.Definitions
+import Futhark.Util.Pretty(pretty)
+import Futhark.Util (zEncodeString)
+import Futhark.Representation.AST.Attributes (builtInFunctions)
+import Text.Printf (printf)
+
+-- | A substitute expression compiler, tried before the main
+-- compilation function.
+type OpCompiler op s = op -> CompilerM op s ()
+
+-- | Write a scalar to the given memory block with the given index and
+-- in the given memory space.
+type WriteScalar op s = VName -> CSExp -> PrimType -> Imp.SpaceId -> CSExp
+ -> CompilerM op s ()
+
+-- | Read a scalar from the given memory block with the given index and
+-- in the given memory space.
+type ReadScalar op s = VName -> CSExp -> PrimType -> Imp.SpaceId
+ -> CompilerM op s CSExp
+
+-- | Allocate a memory block of the given size in the given memory
+-- space, saving a reference in the given variable name.
+type Allocate op s = VName -> CSExp -> Imp.SpaceId
+ -> CompilerM op s ()
+
+-- | Copy from one memory block to another.
+type Copy op s = VName -> CSExp -> Imp.Space ->
+ VName -> CSExp -> Imp.Space ->
+ CSExp -> PrimType ->
+ CompilerM op s ()
+
+-- | Create a static array of values - initialised at load time.
+type StaticArray op s = VName -> Imp.SpaceId -> PrimType -> [PrimValue] -> CompilerM op s ()
+
+-- | Construct the C# array being returned from an entry point.
+type EntryOutput op s = VName -> Imp.SpaceId ->
+ PrimType -> Imp.Signedness ->
+ [Imp.DimSize] ->
+ CompilerM op s CSExp
+
+-- | Unpack the array being passed to an entry point.
+type EntryInput op s = VName -> Imp.MemSize -> Imp.SpaceId ->
+ PrimType -> Imp.Signedness ->
+ [Imp.DimSize] ->
+ CSExp ->
+ CompilerM op s ()
+
+data Operations op s = Operations { opsWriteScalar :: WriteScalar op s
+ , opsReadScalar :: ReadScalar op s
+ , opsAllocate :: Allocate op s
+ , opsCopy :: Copy op s
+ , opsStaticArray :: StaticArray op s
+ , opsCompiler :: OpCompiler op s
+ , opsEntryOutput :: EntryOutput op s
+ , opsEntryInput :: EntryInput op s
+ , opsSyncRun :: CSStmt
+ }
+
+-- | A set of operations that fail for every operation involving
+-- non-default memory spaces. Uses plain pointers and @malloc@ for
+-- memory management.
+defaultOperations :: Operations op s
+defaultOperations = Operations { opsWriteScalar = defWriteScalar
+ , opsReadScalar = defReadScalar
+ , opsAllocate = defAllocate
+ , opsCopy = defCopy
+ , opsStaticArray = defStaticArray
+ , opsCompiler = defCompiler
+ , opsEntryOutput = defEntryOutput
+ , opsEntryInput = defEntryInput
+ , opsSyncRun = defSyncRun
+ }
+ where defWriteScalar _ _ _ _ _ =
+ fail "Cannot write to non-default memory space because I am dumb"
+ defReadScalar _ _ _ _ =
+ fail "Cannot read from non-default memory space"
+ defAllocate _ _ _ =
+ fail "Cannot allocate in non-default memory space"
+ defCopy _ _ _ _ _ _ _ _ =
+ fail "Cannot copy to or from non-default memory space"
+ defStaticArray _ _ _ _ =
+ fail "Cannot create static array in non-default memory space"
+ defCompiler _ =
+ fail "The default compiler cannot compile extended operations"
+ defEntryOutput _ _ _ _ =
+ fail "Cannot return array not in default memory space"
+ defEntryInput _ _ _ _ =
+ fail "Cannot accept array not in default memory space"
+ defSyncRun =
+ Pass
+
+data CompilerEnv op s = CompilerEnv {
+ envOperations :: Operations op s
+ , envFtable :: M.Map Name [Imp.Type]
+}
+
+data CompilerAcc op s = CompilerAcc {
+ accItems :: [CSStmt]
+ , accFreedMem :: [VName]
+ }
+
+instance Semigroup (CompilerAcc op s) where
+ CompilerAcc items1 freed1 <> CompilerAcc items2 freed2 =
+ CompilerAcc (items1<>items2) (freed1<>freed2)
+
+instance Monoid (CompilerAcc op s) where
+ mempty = CompilerAcc mempty mempty
+
+envOpCompiler :: CompilerEnv op s -> OpCompiler op s
+envOpCompiler = opsCompiler . envOperations
+
+envReadScalar :: CompilerEnv op s -> ReadScalar op s
+envReadScalar = opsReadScalar . envOperations
+
+envWriteScalar :: CompilerEnv op s -> WriteScalar op s
+envWriteScalar = opsWriteScalar . envOperations
+
+envAllocate :: CompilerEnv op s -> Allocate op s
+envAllocate = opsAllocate . envOperations
+
+envCopy :: CompilerEnv op s -> Copy op s
+envCopy = opsCopy . envOperations
+
+envStaticArray :: CompilerEnv op s -> StaticArray op s
+envStaticArray = opsStaticArray . envOperations
+
+envEntryOutput :: CompilerEnv op s -> EntryOutput op s
+envEntryOutput = opsEntryOutput . envOperations
+
+envEntryInput :: CompilerEnv op s -> EntryInput op s
+envEntryInput = opsEntryInput . envOperations
+
+envSyncFun :: CompilerEnv op s -> CSStmt
+envSyncFun = opsSyncRun . envOperations
+
+newCompilerEnv :: Imp.Functions op -> Operations op s -> CompilerEnv op s
+newCompilerEnv (Imp.Functions funs) ops =
+ CompilerEnv { envOperations = ops
+ , envFtable = ftable <> builtinFtable
+ }
+ where ftable = M.fromList $ map funReturn funs
+ funReturn (name, Imp.Function _ outparams _ _ _ _) = (name, paramsTypes outparams)
+ builtinFtable = M.map (map Imp.Scalar . snd) builtInFunctions
+
+data CompilerState s = CompilerState {
+ compNameSrc :: VNameSource
+ , compBeforeParse :: [CSStmt]
+ , compInit :: [CSStmt]
+ , compStaticMemDecls :: [CSStmt]
+ , compStaticMemAllocs :: [CSStmt]
+ , compDebugItems :: [CSStmt]
+ , compUserState :: s
+ , compMemberDecls :: [CSStmt]
+ , compAssignedVars :: [VName]
+ , compDeclaredMem :: [(VName, Space)]
+}
+
+newCompilerState :: VNameSource -> s -> CompilerState s
+newCompilerState src s = CompilerState { compNameSrc = src
+ , compBeforeParse = []
+ , compInit = []
+ , compStaticMemDecls = []
+ , compStaticMemAllocs = []
+ , compDebugItems = []
+ , compMemberDecls = []
+ , compUserState = s
+ , compAssignedVars = []
+ , compDeclaredMem = []
+ }
+
+newtype CompilerM op s a = CompilerM (RWS (CompilerEnv op s) (CompilerAcc op s) (CompilerState s) a)
+ deriving (Functor, Applicative, Monad,
+ MonadState (CompilerState s),
+ MonadReader (CompilerEnv op s),
+ MonadWriter (CompilerAcc op s))
+
+instance MonadFreshNames (CompilerM op s) where
+ getNameSource = gets compNameSrc
+ putNameSource src = modify $ \s -> s { compNameSrc = src }
+
+collect :: CompilerM op s () -> CompilerM op s [CSStmt]
+collect m = pass $ do
+ ((), w) <- listen m
+ return (accItems w,
+ const w { accItems = mempty} )
+
+collect' :: CompilerM op s a -> CompilerM op s (a, [CSStmt])
+collect' m = pass $ do
+ (x, w) <- listen m
+ return ((x, accItems w),
+ const w { accItems = mempty})
+
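+-- For illustration (a hedged sketch using only names defined in this module):
+-- @body_stmts <- collect (stm some_stmt)@ yields @[some_stmt]@ without
+-- emitting it at the current position; the freed-memory bookkeeping in the
+-- accumulator still propagates.  @collect'@ additionally returns the
+-- sub-computation's result.
+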
+beforeParse :: CSStmt -> CompilerM op s ()
+beforeParse x = modify $ \s ->
+ s { compBeforeParse = compBeforeParse s ++ [x] }
+
+atInit :: CSStmt -> CompilerM op s ()
+atInit x = modify $ \s ->
+ s { compInit = compInit s ++ [x] }
+
+staticMemDecl :: CSStmt -> CompilerM op s ()
+staticMemDecl x = modify $ \s ->
+ s { compStaticMemDecls = compStaticMemDecls s ++ [x] }
+
+staticMemAlloc :: CSStmt -> CompilerM op s ()
+staticMemAlloc x = modify $ \s ->
+ s { compStaticMemAllocs = compStaticMemAllocs s ++ [x] }
+
+addMemberDecl :: CSStmt -> CompilerM op s ()
+addMemberDecl x = modify $ \s ->
+ s { compMemberDecls = compMemberDecls s ++ [x] }
+
+contextFinalInits :: CompilerM op s [CSStmt]
+contextFinalInits = gets compInit
+
+item :: CSStmt -> CompilerM op s ()
+item x = tell $ mempty { accItems = [x] }
+
+stm :: CSStmt -> CompilerM op s ()
+stm = item
+
+stms :: [CSStmt] -> CompilerM op s ()
+stms = mapM_ stm
+
+debugReport :: CSStmt -> CompilerM op s ()
+debugReport x = modify $ \s ->
+ s { compDebugItems = compDebugItems s ++ [x] }
+
+getVarAssigned :: VName -> CompilerM op s Bool
+getVarAssigned vname =
+ elem vname <$> gets compAssignedVars
+
+setVarAssigned :: VName -> CompilerM op s ()
+setVarAssigned vname = modify $ \s ->
+ s { compAssignedVars = vname : compAssignedVars s}
+
+futharkFun :: String -> String
+futharkFun s = "futhark_" ++ zEncodeString s
+
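+-- For illustration (assuming z-encoding leaves plain alphanumeric names
+-- unchanged): @futharkFun "main"@ yields @"futhark_main"@.
+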
+paramsTypes :: [Imp.Param] -> [Imp.Type]
+paramsTypes = map paramType
+
+paramType :: Imp.Param -> Imp.Type
+paramType (Imp.MemParam _ space) = Imp.Mem (Imp.ConstSize 0) space
+paramType (Imp.ScalarParam _ t) = Imp.Scalar t
+
+compileOutput :: Imp.Param -> (CSExp, CSType)
+compileOutput = nameFun &&& typeFun
+ where nameFun = Var . compileName . Imp.paramName
+ typeFun = compileType . paramType
+
+getDefaultDecl :: Imp.Param -> CSStmt
+getDefaultDecl (Imp.MemParam v DefaultSpace) =
+ Assign (Var $ compileName v) $ simpleCall "allocateMem" [Integer 0]
+getDefaultDecl (Imp.MemParam v _) =
+ AssignTyped (CustomT "OpenCLMemblock") (Var $ compileName v) (Just $ simpleCall "EmptyMemblock" [Var "Ctx.EMPTY_MEM_HANDLE"])
+getDefaultDecl (Imp.ScalarParam v Cert) =
+ Assign (Var $ compileName v) $ Bool True
+getDefaultDecl (Imp.ScalarParam v t) =
+ Assign (Var $ compileName v) $ simpleInitClass (compilePrimType t) []
+
+
+runCompilerM :: Imp.Functions op -> Operations op s
+ -> VNameSource
+ -> s
+ -> CompilerM op s a
+ -> a
+runCompilerM prog ops src userstate (CompilerM m) =
+ fst $ evalRWS m (newCompilerEnv prog ops) (newCompilerState src userstate)
+
+standardOptions :: [Option]
+standardOptions = [
+ Option { optionLongName = "write-runtime-to"
+ , optionShortName = Just 't'
+ , optionArgument = RequiredArgument
+ , optionAction =
+ [
+ If (BinOp "!=" (Var "RuntimeFile") Null)
+ [Exp $ simpleCall "RuntimeFile.Close" []] []
+ , Reassign (Var "RuntimeFile") $
+ simpleInitClass "FileStream" [Var "optarg", Var "FileMode.Create"]
+ , Reassign (Var "RuntimeFileWriter") $
+ simpleInitClass "StreamWriter" [Var "RuntimeFile"]
+ ]
+ },
+ Option { optionLongName = "runs"
+ , optionShortName = Just 'r'
+ , optionArgument = RequiredArgument
+ , optionAction =
+ [ Reassign (Var "NumRuns") $ simpleCall "Convert.ToInt32" [Var "optarg"]
+ , Reassign (Var "DoWarmupRun") $ Bool True
+ ]
+ },
+ Option { optionLongName = "entry-point"
+ , optionShortName = Just 'e'
+ , optionArgument = RequiredArgument
+ , optionAction =
+ [ Reassign (Var "EntryPoint") $ Var "optarg" ]
+ }
+ ]
+
+-- | The class generated by the code generator must have a
+-- constructor, although it can be vacuous.
+data Constructor = Constructor [CSFunDefArg] [CSStmt]
+
+-- | A constructor that takes only the command-line arguments and does nothing with them.
+emptyConstructor :: Constructor
+emptyConstructor = Constructor [(Composite $ ArrayT $ Primitive StringT, "args")] []
+
+constructorToConstructorDef :: Constructor -> String -> [CSStmt] -> CSStmt
+constructorToConstructorDef (Constructor params body) name at_init =
+ ConstructorDef $ ClassConstructor name params $ body <> at_init
+
+
+compileProg :: MonadFreshNames m =>
+ Maybe String
+ -> Constructor
+ -> [CSStmt]
+ -> [CSStmt]
+ -> Operations op s
+ -> s
+ -> CompilerM op s ()
+ -> [CSStmt]
+ -> [Space]
+ -> [Option]
+ -> Imp.Functions op
+ -> m String
+compileProg module_name constructor imports defines ops userstate boilerplate pre_timing _ options prog@(Imp.Functions funs) = do
+ src <- getNameSource
+ let prog' = runCompilerM prog ops src userstate compileProg'
+ let imports' = [ Using Nothing "System"
+ , Using Nothing "System.Diagnostics"
+ , Using Nothing "System.Collections"
+ , Using Nothing "System.Collections.Generic"
+ , Using Nothing "System.IO"
+ , Using Nothing "System.Linq"
+ , Using Nothing "System.Runtime.InteropServices"
+ , Using Nothing "static System.ValueTuple"
+ , Using Nothing "static System.Convert"
+ , Using Nothing "static System.Math"
+ , Using Nothing "System.Numerics"
+ , Using Nothing "Mono.Options" ] ++ imports
+
+ return $ pretty (CSProg $ imports' ++ prog')
+ where compileProg' = do
+ definitions <- mapM compileFunc funs
+ opencl_boilerplate <- collect boilerplate
+ compBeforeParses <- gets compBeforeParse
+ compInits <- gets compInit
+ staticDecls <- gets compStaticMemDecls
+ staticAllocs <- gets compStaticMemAllocs
+ extraMemberDecls <- gets compMemberDecls
+ let member_decls' = member_decls ++ extraMemberDecls ++ staticDecls
+ let at_inits' = at_inits ++ compBeforeParses ++ parse_options ++ compInits ++ staticAllocs
+
+
+ case module_name of
+ Just name -> do
+ entry_points <- mapM (compileEntryFun pre_timing) $ filter (Imp.functionEntry . snd) funs
+ let constructor' = constructorToConstructorDef constructor name at_inits'
+ return [ Namespace name [ClassDef $ PublicClass name $ member_decls' ++
+ constructor' : defines' ++ opencl_boilerplate ++
+ map PrivateFunDef definitions ++
+ map PublicFunDef entry_points ]]
+
+
+ Nothing -> do
+ let name = "FutharkInternal"
+ let constructor' = constructorToConstructorDef constructor name at_inits'
+ (entry_point_defs, entry_point_names, entry_points) <-
+ unzip3 <$> mapM (callEntryFun pre_timing)
+ (filter (Imp.functionEntry . snd) funs)
+
+ debug_ending <- gets compDebugItems
+ return [Namespace name ((ClassDef $
+ PublicClass name $
+ member_decls' ++
+ constructor' : defines' ++
+ opencl_boilerplate ++
+ map PrivateFunDef (definitions ++ entry_point_defs) ++
+ [PublicFunDef $ Def "InternalEntry" VoidT [] $ selectEntryPoint entry_point_names entry_points ++ debug_ending
+ ]
+ ) :
+ [ClassDef $ PublicClass "Program"
+ [StaticFunDef $ Def "Main" VoidT [(string_arrayT,"args")] main_entry]])
+ ]
+
+
+
+ string_arrayT = Composite $ ArrayT $ Primitive StringT
+ main_entry :: [CSStmt]
+ main_entry = [ Assign (Var "internalInstance") (simpleInitClass "FutharkInternal" [Var "args"])
+ , Exp $ simpleCall "internalInstance.InternalEntry" []
+ ]
+
+ member_decls =
+ [ AssignTyped (CustomT "FileStream") (Var "RuntimeFile") Nothing
+ , AssignTyped (CustomT "StreamWriter") (Var "RuntimeFileWriter") Nothing
+ , AssignTyped (Primitive BoolT) (Var "DoWarmupRun") Nothing
+ , AssignTyped (Primitive $ CSInt Int32T) (Var "NumRuns") Nothing
+ , AssignTyped (Primitive StringT) (Var "EntryPoint") Nothing
+ ]
+
+ at_inits = [ Reassign (Var "DoWarmupRun") (Bool False)
+ , Reassign (Var "NumRuns") (Integer 1)
+ , Reassign (Var "EntryPoint") (String "main")
+ , Exp $ simpleCall "ValueReader" []
+ ]
+
+ defines' = [ Escape csScalar
+ , Escape csMemory
+ , Escape csPanic
+ , Escape csExceptions
+ , Escape csReader] ++ defines
+
+ parse_options =
+ generateOptionParser (standardOptions ++ options)
+
+ selectEntryPoint entry_point_names entry_points =
+ [ Assign (Var "EntryPoints") $
+ Collection "Dictionary<string, Action>" $ zipWith Pair (map String entry_point_names) entry_points,
+ If (simpleCall "!EntryPoints.ContainsKey" [Var "EntryPoint"])
+ [ Exp $ simpleCall "Console.Error.WriteLine"
+ [simpleCall "string.Format"
+ [ String "No entry point '{0}'. Select another with --entry point. Options are:\n{1}"
+ , Var "EntryPoint"
+ , simpleCall "string.Join"
+ [ String "\n"
+ , Field (Var "EntryPoints") "Keys" ]]]
+ , Exp $ simpleCall "Environment.Exit" [Integer 1]]
+ [ Assign (Var "entryPointFun") $
+ Index (Var "EntryPoints") (IdxExp $ Var "EntryPoint")
+ , Exp $ simpleCall "entryPointFun.Invoke" []]
+ ]
+
+
+compileFunc :: (Name, Imp.Function op) -> CompilerM op s CSFunDef
+compileFunc (fname, Imp.Function _ outputs inputs body _ _) = do
+ body' <- blockScope $ compileCode body
+ let inputs' = map compileTypedInput inputs
+ let outputs' = map compileOutput outputs
+ let outputDecls = map getDefaultDecl outputs
+ let (ret, retType) = unzip outputs'
+ let retType' = tupleOrSingleT retType
+ let ret' = [Return $ tupleOrSingle ret]
+
+ case outputs of
+ [] -> return $ Def (futharkFun . nameToString $ fname) VoidT inputs' (outputDecls++body')
+ _ -> return $ Def (futharkFun . nameToString $ fname) retType' inputs' (outputDecls++body'++ret')
+
+
+compileTypedInput :: Imp.Param -> (CSType, String)
+compileTypedInput input = (typeFun input, nameFun input)
+ where nameFun = compileName . Imp.paramName
+ typeFun = compileType . paramType
+
+tupleOrSingleEntryT :: [CSType] -> CSType
+tupleOrSingleEntryT [e] = e
+tupleOrSingleEntryT es = Composite $ SystemTupleT es
+
+tupleOrSingleEntry :: [CSExp] -> CSExp
+tupleOrSingleEntry [e] = e
+tupleOrSingleEntry es = CreateSystemTuple es
+
+tupleOrSingleT :: [CSType] -> CSType
+tupleOrSingleT [e] = e
+tupleOrSingleT es = Composite $ TupleT es
+
+tupleOrSingle :: [CSExp] -> CSExp
+tupleOrSingle [e] = e
+tupleOrSingle es = Tuple es
+
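+-- For illustration, straight from the equations above:
+--
+--   tupleOrSingle [Var "x"]          ~> Var "x"
+--   tupleOrSingle [Var "x", Var "y"] ~> Tuple [Var "x", Var "y"]
+--
+-- so single-result functions avoid the tuple wrapping, while the *Entry
+-- variants use the SystemTupleT/CreateSystemTuple constructors for the
+-- public entry-point interface instead.
+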
+assignScalarPointer :: CSExp -> CSExp -> CSStmt
+assignScalarPointer e ptr =
+ AssignTyped (PointerT VoidT) ptr (Just $ Addr e)
+
+-- | A 'Call' where the function is a variable and every argument is a
+-- simple 'Arg'.
+simpleCall :: String -> [CSExp] -> CSExp
+simpleCall fname = Call (Var fname) . map simpleArg
+
+-- | Like 'simpleCall', but the called function is given an explicit type
+-- parameter, as in @fname<primtype>(...)@.
+parametrizedCall :: String -> String -> [CSExp] -> CSExp
+parametrizedCall fname primtype = Call (Var fname') . map simpleArg
+ where fname' = concat [fname, "<", primtype, ">"]
+
+simpleArg :: CSExp -> CSArg
+simpleArg = Arg Nothing
+
+-- | A 'CallMethod' on an object, where every argument is a simple 'Arg'.
+callMethod :: CSExp -> String -> [CSExp] -> CSExp
+callMethod object method = CallMethod object (Var method) . map simpleArg
+
+simpleInitClass :: String -> [CSExp] -> CSExp
+simpleInitClass fname = CreateObject (Var fname) . map simpleArg
+
+compileName :: VName -> String
+compileName = zEncodeString . pretty
+
+compileType :: Imp.Type -> CSType
+compileType (Imp.Scalar p) = compilePrimTypeToAST p
+compileType (Imp.Mem _ space) = rawMemCSType space
+
+compilePrimTypeToAST :: PrimType -> CSType
+compilePrimTypeToAST (IntType Int8) = Primitive $ CSInt Int8T
+compilePrimTypeToAST (IntType Int16) = Primitive $ CSInt Int16T
+compilePrimTypeToAST (IntType Int32) = Primitive $ CSInt Int32T
+compilePrimTypeToAST (IntType Int64) = Primitive $ CSInt Int64T
+compilePrimTypeToAST (FloatType Float32) = Primitive $ CSFloat FloatT
+compilePrimTypeToAST (FloatType Float64) = Primitive $ CSFloat DoubleT
+compilePrimTypeToAST Imp.Bool = Primitive BoolT
+compilePrimTypeToAST Imp.Cert = Primitive BoolT
+
+compilePrimTypeToASText :: PrimType -> Imp.Signedness -> CSType
+compilePrimTypeToASText (IntType Int8) Imp.TypeUnsigned = Primitive $ CSUInt UInt8T
+compilePrimTypeToASText (IntType Int16) Imp.TypeUnsigned = Primitive $ CSUInt UInt16T
+compilePrimTypeToASText (IntType Int32) Imp.TypeUnsigned = Primitive $ CSUInt UInt32T
+compilePrimTypeToASText (IntType Int64) Imp.TypeUnsigned = Primitive $ CSUInt UInt64T
+compilePrimTypeToASText (IntType Int8) _ = Primitive $ CSInt Int8T
+compilePrimTypeToASText (IntType Int16) _ = Primitive $ CSInt Int16T
+compilePrimTypeToASText (IntType Int32) _ = Primitive $ CSInt Int32T
+compilePrimTypeToASText (IntType Int64) _ = Primitive $ CSInt Int64T
+compilePrimTypeToASText (FloatType Float32) _ = Primitive $ CSFloat FloatT
+compilePrimTypeToASText (FloatType Float64) _ = Primitive $ CSFloat DoubleT
+compilePrimTypeToASText Imp.Bool _ = Primitive BoolT
+compilePrimTypeToASText Imp.Cert _ = Primitive BoolT
+
+compileDim :: Imp.DimSize -> CSExp
+compileDim (Imp.ConstSize i) = Integer $ toInteger i
+compileDim (Imp.VarSize v) = Var $ compileName v
+
+unpackDim :: CSExp -> Imp.DimSize -> Int32 -> CompilerM op s ()
+unpackDim arr_name (Imp.ConstSize c) i = do
+ let shape_name = Field arr_name "Item2" -- array tuples are currently (data array * dimension array)
+ let constant_c = Integer $ toInteger c
+ let constant_i = Integer $ toInteger i
+ stm $ Assert (BinOp "==" constant_c (Index shape_name $ IdxExp constant_i)) [String "constant dimension wrong"]
+
+unpackDim arr_name (Imp.VarSize var) i = do
+ let shape_name = Field arr_name "Item2"
+ let src = Index shape_name $ IdxExp $ Integer $ toInteger i
+ let dest = Var $ compileName var
+ isAssigned <- getVarAssigned var
+ if isAssigned
+ then
+ stm $ Reassign dest $ Cast (Primitive $ CSInt Int32T) src
+ else do
+ stm $ Assign dest $ Cast (Primitive $ CSInt Int32T) src
+ setVarAssigned var
+
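+-- Entry-point arrays arrive as a (data array, shape array) pair, so the code
+-- above reads dimension @i@ from @Item2@.  For illustration (variable names
+-- hypothetical), the first use of a variable dimension @n@ emits
+--
+--   Assign (Var "n") (Cast (Primitive (CSInt Int32T)) (Index (Field arr "Item2") (IdxExp (Integer 0))))
+--
+-- and later uses emit a 'Reassign' instead.
+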
+entryPointOutput :: Imp.ExternalValue -> CompilerM op s CSExp
+entryPointOutput (Imp.OpaqueValue _ vs) =
+ CreateSystemTuple <$> mapM (entryPointOutput . Imp.TransparentValue) vs
+
+entryPointOutput (Imp.TransparentValue (Imp.ScalarValue bt ept name)) =
+ return $ cast $ Var $ compileName name
+ where cast = compileTypecastExt bt ept
+
+entryPointOutput (Imp.TransparentValue (Imp.ArrayValue mem _ Imp.DefaultSpace bt ept dims)) = do
+ let src = Var $ compileName mem
+ let createTuple = "createTuple_" ++ compilePrimTypeExt bt ept
+ return $ simpleCall createTuple [src, CreateArray (Primitive $ CSInt Int64T) $ map compileDim dims]
+
+entryPointOutput (Imp.TransparentValue (Imp.ArrayValue mem _ (Imp.Space sid) bt ept dims)) = do
+ unRefMem mem (Imp.Space sid)
+ pack_output <- asks envEntryOutput
+ pack_output mem sid bt ept dims
+
+entryPointInput :: (Int, Imp.ExternalValue, CSExp) -> CompilerM op s ()
+entryPointInput (i, Imp.OpaqueValue _ vs, e) =
+ mapM_ entryPointInput $ zip3 (repeat i) (map Imp.TransparentValue vs) $
+ map (\idx -> Field e $ "Item" ++ show (idx :: Int)) [1..]
+
+entryPointInput (_, Imp.TransparentValue (Imp.ScalarValue bt _ name), e) = do
+ let vname' = Var $ compileName name
+ cast = compileTypecast bt
+ stm $ Assign vname' (cast e)
+
+entryPointInput (_, Imp.TransparentValue (Imp.ArrayValue mem memsize Imp.DefaultSpace bt _ dims), e) = do
+ zipWithM_ (unpackDim e) dims [0..]
+ let arrayData = Field e "Item1"
+ let dest = Var $ compileName mem
+ unwrap_call = simpleCall "unwrapArray" [arrayData, sizeOf $ compilePrimTypeToAST bt]
+ case memsize of
+ Imp.VarSize sizevar ->
+ stm $ Assign (Var $ compileName sizevar) $ Field e "Item2.Length"
+ Imp.ConstSize _ ->
+ return ()
+ stm $ Assign dest unwrap_call
+
+entryPointInput (_, Imp.TransparentValue (Imp.ArrayValue mem memsize (Imp.Space sid) bt ept dims), e) = do
+ unpack_input <- asks envEntryInput
+ unpack <- collect $ unpack_input mem memsize sid bt ept dims e
+ stms unpack
+
+extValueDescName :: Imp.ExternalValue -> String
+extValueDescName (Imp.TransparentValue v) = extName $ valueDescName v
+extValueDescName (Imp.OpaqueValue desc []) = extName $ zEncodeString desc
+extValueDescName (Imp.OpaqueValue desc (v:_)) =
+ extName $ zEncodeString desc ++ "_" ++ pretty (baseTag (valueDescVName v))
+
+extName :: String -> String
+extName = (++"_ext")
+
+sizeOf :: CSType -> CSExp
+sizeOf t = simpleCall "sizeof" [(Var . pretty) t]
+
+publicFunDef :: String -> CSType -> [(CSType, String)] -> [CSStmt] -> CSStmt
+publicFunDef s t args stmts = PublicFunDef $ Def s t args stmts
+
+privateFunDef :: String -> CSType -> [(CSType, String)] -> [CSStmt] -> CSStmt
+privateFunDef s t args stmts = PrivateFunDef $ Def s t args stmts
+
+valueDescName :: Imp.ValueDesc -> String
+valueDescName = compileName . valueDescVName
+
+valueDescVName :: Imp.ValueDesc -> VName
+valueDescVName (Imp.ScalarValue _ _ vname) = vname
+valueDescVName (Imp.ArrayValue vname _ _ _ _ _) = vname
+
+consoleWrite :: String -> [CSExp] -> CSExp
+consoleWrite str exps = simpleCall "Console.Write" $ String str:exps
+
+consoleWriteLine :: String -> [CSExp] -> CSExp
+consoleWriteLine str exps = simpleCall "Console.WriteLine" $ String str:exps
+
+consoleErrorWrite :: String -> [CSExp] -> CSExp
+consoleErrorWrite str exps = simpleCall "Console.Error.Write" $ String str:exps
+
+consoleErrorWriteLine :: String -> [CSExp] -> CSExp
+consoleErrorWriteLine str exps = simpleCall "Console.Error.WriteLine" $ String str:exps
+
+readFun :: PrimType -> Imp.Signedness -> String
+readFun (FloatType Float32) _ = "ReadF32"
+readFun (FloatType Float64) _ = "ReadF64"
+readFun (IntType Int8) Imp.TypeUnsigned = "ReadU8"
+readFun (IntType Int16) Imp.TypeUnsigned = "ReadU16"
+readFun (IntType Int32) Imp.TypeUnsigned = "ReadU32"
+readFun (IntType Int64) Imp.TypeUnsigned = "ReadU64"
+readFun (IntType Int8) Imp.TypeDirect = "ReadI8"
+readFun (IntType Int16) Imp.TypeDirect = "ReadI16"
+readFun (IntType Int32) Imp.TypeDirect = "ReadI32"
+readFun (IntType Int64) Imp.TypeDirect = "ReadI64"
+readFun Imp.Bool _ = "ReadBool"
+readFun Cert _ = error "readFun: cert"
+
+readBinFun :: PrimType -> Imp.Signedness -> String
+readBinFun (FloatType Float32) _bin_ = "ReadBinF32"
+readBinFun (FloatType Float64) _bin_ = "ReadBinF64"
+readBinFun (IntType Int8) Imp.TypeUnsigned = "ReadBinU8"
+readBinFun (IntType Int16) Imp.TypeUnsigned = "ReadBinU16"
+readBinFun (IntType Int32) Imp.TypeUnsigned = "ReadBinU32"
+readBinFun (IntType Int64) Imp.TypeUnsigned = "ReadBinU64"
+readBinFun (IntType Int8) Imp.TypeDirect = "ReadBinI8"
+readBinFun (IntType Int16) Imp.TypeDirect = "ReadBinI16"
+readBinFun (IntType Int32) Imp.TypeDirect = "ReadBinI32"
+readBinFun (IntType Int64) Imp.TypeDirect = "ReadBinI64"
+readBinFun Imp.Bool _ = "ReadBinBool"
+readBinFun Cert _ = error "readBinFun: cert"
+
+-- The value returned is used when reading binary arrays, to indicate the
+-- expected element type.  It is a key into the FUTHARK_PRIMTYPES dict.
+readTypeEnum :: PrimType -> Imp.Signedness -> String
+readTypeEnum (IntType Int8) Imp.TypeUnsigned = "u8"
+readTypeEnum (IntType Int16) Imp.TypeUnsigned = "u16"
+readTypeEnum (IntType Int32) Imp.TypeUnsigned = "u32"
+readTypeEnum (IntType Int64) Imp.TypeUnsigned = "u64"
+readTypeEnum (IntType Int8) Imp.TypeDirect = "i8"
+readTypeEnum (IntType Int16) Imp.TypeDirect = "i16"
+readTypeEnum (IntType Int32) Imp.TypeDirect = "i32"
+readTypeEnum (IntType Int64) Imp.TypeDirect = "i64"
+readTypeEnum (FloatType Float32) _ = "f32"
+readTypeEnum (FloatType Float64) _ = "f64"
+readTypeEnum Imp.Bool _ = "bool"
+readTypeEnum Cert _ = error "readTypeEnum: cert"
+
+readInput :: Imp.ExternalValue -> CSStmt
+readInput (Imp.OpaqueValue desc _) =
+ Throw $ simpleInitClass "Exception" [String $ "Cannot read argument of type " ++ desc ++ "."]
+
+readInput decl@(Imp.TransparentValue (Imp.ScalarValue bt ept _)) =
+ let read_func = Var $ readFun bt ept
+ read_bin_func = Var $ readBinFun bt ept
+ type_enum = String $ readTypeEnum bt ept
+ bt' = compilePrimTypeExt bt ept
+ readScalar = initializeGenericFunction "ReadScalar" bt'
+ in Assign (Var $ extValueDescName decl) $ simpleCall readScalar [type_enum, read_func, read_bin_func]
+
+-- TODO: If the type identifier of 'Float32' is changed, the error messages
+-- for reading binary input will currently not use the new name. This is also
+-- a problem for the C runtime system.
+readInput decl@(Imp.TransparentValue (Imp.ArrayValue _ _ _ bt ept dims)) =
+ let rank' = Var $ show $ length dims
+ type_enum = String $ readTypeEnum bt ept
+ bt' = compilePrimTypeExt bt ept
+ read_func = Var $ readFun bt ept
+ readArray = initializeGenericFunction "ReadArray" bt'
+ in Assign (Var $ extValueDescName decl) $ simpleCall readArray [rank', type_enum, read_func]
+
+initializeGenericFunction :: String -> String -> String
+initializeGenericFunction fun tp = fun ++ "<" ++ tp ++ ">"
+
+
+printPrimStm :: CSExp -> CSStmt
+printPrimStm val = Exp $ simpleCall "WriteValue" [val]
+
+formatString :: String -> [CSExp] -> CSExp
+formatString fmt contents =
+ simpleCall "String.Format" $ String fmt : contents
+
+printStm :: Imp.ValueDesc -> CSExp -> CSExp -> CompilerM op s CSStmt
+printStm Imp.ScalarValue{} _ e =
+ return $ printPrimStm e
+printStm (Imp.ArrayValue _ _ _ _ _ []) ind e = do
+ let e' = Index e (IdxExp (PostUnOp "++" ind))
+ return $ printPrimStm e'
+
+printStm (Imp.ArrayValue mem memsize space bt ept (outer:shape)) ind e = do
+ ptr <- newVName "shapePtr"
+ first <- newVName "printFirst"
+ let size = callMethod (CreateArray (Primitive $ CSInt Int32T) $ map compileDim $ outer:shape)
+ "Aggregate" [ Integer 1
+ , Lambda (Tuple [Var "acc", Var "val"])
+ [Exp $ BinOp "*" (Var "acc") (Var "val")]
+ ]
+ emptystr = "empty(" ++ ppArrayType bt (length shape) ++ ")"
+
+ printelem <- printStm (Imp.ArrayValue mem memsize space bt ept shape) ind e
+ return $
+ If (BinOp "==" size (Integer 0))
+ [puts emptystr]
+ [ Assign (Var $ pretty first) $ Var "true"
+ , puts "["
+ , For (pretty ptr) (compileDim outer)
+ [ If (simpleCall "!" [Var $ pretty first]) [puts ", "] []
+ , printelem
+ , Reassign (Var $ pretty first) $ Var "false"
+ ]
+ , puts "]"
+ ]
+
+ where ppArrayType :: PrimType -> Int -> String
+ ppArrayType t 0 = prettyPrimType ept t
+ ppArrayType t n = "[]" ++ ppArrayType t (n-1)
+
+ prettyPrimType Imp.TypeUnsigned (IntType Int8) = "u8"
+ prettyPrimType Imp.TypeUnsigned (IntType Int16) = "u16"
+ prettyPrimType Imp.TypeUnsigned (IntType Int32) = "u32"
+ prettyPrimType Imp.TypeUnsigned (IntType Int64) = "u64"
+ prettyPrimType _ t = pretty t
+
+ puts s = Exp $ simpleCall "Console.Write" [String s]
+
+printValue :: [(Imp.ExternalValue, CSExp)] -> CompilerM op s [CSStmt]
+printValue = fmap concat . mapM (uncurry printValue')
+ -- We copy non-host arrays to the host before printing. This is
+ -- done in a hacky way - we assume the entry-point output packing
+ -- has already produced an equivalent host-side (data, shape) tuple.
+ -- This works for CSOpenCL, but we will probably need yet another
+ -- plugin mechanism here in the future.
+ where printValue' (Imp.OpaqueValue desc _) _ =
+ return [Exp $ simpleCall "Console.Write"
+ [String $ "#<opaque " ++ desc ++ ">"]]
+ printValue' (Imp.TransparentValue r@Imp.ScalarValue{}) e = do
+ p <- printStm r (Integer 0) e
+ return [p, Exp $ simpleCall "Console.Write" [String "\n"]]
+ printValue' (Imp.TransparentValue r@Imp.ArrayValue{}) e = do
+ tuple <- newVName "resultArr"
+ i <- newVName "arrInd"
+ let i' = Var $ compileName i
+ p <- printStm r i' (Var $ compileName tuple)
+ let e' = Var $ pretty e
+ return [ Assign (Var $ compileName tuple) (Field e' "Item1")
+ , Assign i' (Integer 0)
+ , p
+ , Exp $ simpleCall "Console.Write" [String "\n"]]
+
+prepareEntry :: (Name, Imp.Function op) -> CompilerM op s
+ (String, [(CSType, String)], CSType, [CSStmt], [CSStmt], [CSStmt], [CSStmt],
+ [(Imp.ExternalValue, CSExp)], [CSStmt])
+prepareEntry (fname, Imp.Function _ outputs inputs _ results args) = do
+ let (output_types, output_paramNames) = unzip $ map compileTypedInput outputs
+ funTuple = tupleOrSingle $ fmap Var output_paramNames
+
+
+ (_, sizeDecls) <- collect' $ forM args declsfunction
+
+ (argexps_mem_copies, prepare_run) <- collect' $ forM inputs $ \case
+ Imp.MemParam name space -> do
+ -- A program might write to its input parameters, so create a new memory
+ -- block and copy the source there. This way the program can be run more
+ -- than once.
+ name' <- newVName $ baseString name <> "_copy"
+ copy <- asks envCopy
+ allocate <- asks envAllocate
+
+ let size = Var (compileName name ++ "_nbytes")
+ dest = name'
+ src = name
+ offset = Integer 0
+ case space of
+ DefaultSpace ->
+ stm $ Reassign (Var (compileName name'))
+ (simpleCall "allocateMem" [size]) -- FIXME
+ Space sid ->
+ allocate name' size sid
+ copy dest offset space src offset space size (IntType Int64) -- FIXME
+ return $ Just (compileName name')
+ _ -> return Nothing
+
+ prepareIn <- collect $ mapM_ entryPointInput $ zip3 [0..] args $
+ map (Var . extValueDescName) args
+ (res, prepareOut) <- collect' $ mapM entryPointOutput results
+
+ let mem_copies = mapMaybe liftMaybe $ zip argexps_mem_copies inputs
+ mem_copy_inits = map initCopy mem_copies
+
+ argexps_lib = map (compileName . Imp.paramName) inputs
+ argexps_bin = zipWith fromMaybe argexps_lib argexps_mem_copies
+ fname' = futharkFun (nameToString fname)
+ arg_types = map (fst . compileTypedInput) inputs
+ inputs' = zip arg_types (map extValueDescName args)
+ output_type = tupleOrSingleEntryT output_types
+ call_lib = [Reassign funTuple $ simpleCall fname' (fmap Var argexps_lib)]
+ call_bin = [Reassign funTuple $ simpleCall fname' (fmap Var argexps_bin)]
+ prepareIn' = prepareIn ++ mem_copy_inits ++ sizeDecls
+
+ return (nameToString fname, inputs', output_type,
+ prepareIn', call_lib, call_bin, prepareOut,
+ zip results res, prepare_run)
+
+ where liftMaybe (Just a, b) = Just (a,b)
+ liftMaybe _ = Nothing
+
+ initCopy (varName, Imp.MemParam _ space) = declMem' varName space
+ initCopy _ = Pass
+
+ valueDescFun (Imp.ArrayValue mem _ Imp.DefaultSpace _ _ _) =
+ stm $ Assign (Var $ compileName mem ++ "_nbytes") (Var $ compileName mem ++ ".Length")
+ valueDescFun (Imp.ArrayValue mem _ (Imp.Space _) bt _ dims) =
+ stm $ Assign (Var $ compileName mem ++ "_nbytes") $ foldr (BinOp "*" . compileDim) (sizeOf $ compilePrimTypeToAST bt) dims
+ valueDescFun _ = stm Pass
+
+ declsfunction (Imp.TransparentValue v) = valueDescFun v
+ declsfunction (Imp.OpaqueValue _ vs) = mapM_ valueDescFun vs
+
+copyMemoryDefaultSpace :: VName -> CSExp -> VName -> CSExp -> CSExp ->
+ CompilerM op s ()
+copyMemoryDefaultSpace destmem destidx srcmem srcidx nbytes =
+ stm $ Exp $ simpleCall "Buffer.BlockCopy" [ Var (compileName srcmem), srcidx
+ , Var (compileName destmem), destidx,
+ nbytes]
+
+compileEntryFun :: [CSStmt] -> (Name, Imp.Function op)
+ -> CompilerM op s CSFunDef
+compileEntryFun pre_timing entry@(_,Imp.Function _ outputs _ _ results args) = do
+ let params = map (getType &&& extValueDescName) args
+ let outputType = tupleOrSingleEntryT $ map getType results
+
+ (fname', _, _, prepareIn, body_lib, _, prepareOut, res, _) <- prepareEntry entry
+ let ret = Return $ tupleOrSingleEntry $ map snd res
+ let outputDecls = map getDefaultDecl outputs
+ do_run = body_lib ++ pre_timing
+ (do_run_with_timing, close_runtime_file) <- addTiming do_run
+
+ let do_warmup_run = If (Var "DoWarmupRun") do_run []
+ do_num_runs = For "i" (Var "NumRuns") do_run_with_timing
+
+ return $ Def fname' outputType params $
+ prepareIn ++ outputDecls ++ [do_warmup_run, do_num_runs, close_runtime_file] ++ prepareOut ++ [ret]
+
+ where getType :: Imp.ExternalValue -> CSType
+ getType (Imp.OpaqueValue _ valueDescs) =
+ let valueDescs' = map getType' valueDescs
+ in Composite $ SystemTupleT valueDescs'
+ getType (Imp.TransparentValue valueDesc) =
+ getType' valueDesc
+
+ getType' :: Imp.ValueDesc -> CSType
+ getType' (Imp.ScalarValue primtype signedness _) =
+ compilePrimTypeToASText primtype signedness
+ getType' (Imp.ArrayValue _ _ _ primtype signedness _) =
+ let t = compilePrimTypeToASText primtype signedness
+ in Composite $ SystemTupleT [Composite $ ArrayT t, Composite $ ArrayT $ Primitive $ CSInt Int64T]
+
+
+callEntryFun :: [CSStmt] -> (Name, Imp.Function op)
+ -> CompilerM op s (CSFunDef, String, CSExp)
+callEntryFun pre_timing entry@(fname, Imp.Function _ outputs _ _ _ decl_args) =
+ if any isOpaque decl_args then
+ return (Def fname' VoidT [] [exitException], nameToString fname, Var fname')
+ else do
+ (_, _, _, prepareIn, _, body_bin, prepare_out, res, prepare_run) <- prepareEntry entry
+ let str_input = map readInput decl_args
+
+ let outputDecls = map getDefaultDecl outputs
+ exitcall = [
+ Exp $ simpleCall "Console.Error.WriteLine" [formatString "Assertion.{0} failed" [Var "e"]]
+ , Exp $ simpleCall "Environment.Exit" [Integer 1]
+ ]
+ except' = Catch (Var "Exception") exitcall
+ do_run = body_bin ++ pre_timing
+ (do_run_with_timing, close_runtime_file) <- addTiming do_run
+
+ -- We ignore overflow errors and the like for executable entry
+ -- points. These are (somewhat) well-defined in Futhark.
+
+ let maybe_free =
+ [If (BinOp "<" (Var "i") (BinOp "-" (Var "NumRuns") (Integer 1)))
+ prepare_out []]
+
+ do_warmup_run =
+ If (Var "DoWarmupRun") (prepare_run ++ do_run ++ prepare_out) []
+
+ do_num_runs =
+ For "i" (Var "NumRuns") (prepare_run ++ do_run_with_timing ++ maybe_free)
+
+ str_output <- printValue res
+
+ return (Def fname' VoidT [] $
+ str_input ++ prepareIn ++ outputDecls ++
+ [Try [do_warmup_run, do_num_runs] [except']] ++
+ [close_runtime_file] ++
+ str_output,
+
+ nameToString fname,
+
+ Var fname')
+
+ where fname' = "entry_" ++ nameToString fname
+ isOpaque Imp.TransparentValue{} = False
+ isOpaque _ = True
+
+ exitException = Throw $ simpleInitClass "Exception" [String $ "The function " ++ nameToString fname ++ " is not available as an entry function."]
+
+addTiming :: [CSStmt] -> CompilerM s op ([CSStmt], CSStmt)
+addTiming statements = do
+ syncFun <- asks envSyncFun
+
+ return ([ Assign (Var "StopWatch") $ simpleInitClass "Stopwatch" []
+ , syncFun
+ , Exp $ simpleCall "StopWatch.Start" [] ] ++
+ statements ++
+ [ syncFun
+ , Exp $ simpleCall "StopWatch.Stop" []
+ , Assign (Var "timeElapsed") $ asMicroseconds (Var "StopWatch")
+ , If (not_null (Var "RuntimeFile")) [print_runtime] []
+ ]
+ , If (not_null (Var "RuntimeFile")) [
+ Exp $ simpleCall "RuntimeFileWriter.Close" [] ,
+ Exp $ simpleCall "RuntimeFile.Close" []
+ ] []
+ )
+
+ where print_runtime = Exp $ simpleCall "RuntimeFileWriter.WriteLine" [ callMethod (Var "timeElapsed") "ToString" [] ]
+ not_null var = BinOp "!=" var Null
+ asMicroseconds watch =
+ BinOp "/" (Field watch "ElapsedTicks")
+ (BinOp "/" (Field (Var "TimeSpan") "TicksPerMillisecond") (Integer 1000))
+
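+-- The division above is intended to yield microseconds: with
+-- TimeSpan.TicksPerMillisecond = 10000 the divisor is 10, i.e. one tick is
+-- treated as 100 ns.  (Note that Stopwatch.ElapsedTicks counts raw
+-- high-resolution ticks, so this assumes the two tick units coincide.)
+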
+compileUnOp :: Imp.UnOp -> String
+compileUnOp op =
+ case op of
+ Not -> "!"
+ Complement{} -> "~"
+ Abs{} -> "Math.Abs" -- actually write these helpers
+ FAbs{} -> "Math.Abs"
+ SSignum{} -> "ssignum"
+ USignum{} -> "usignum"
+
+compileBinOpLike :: Monad m =>
+ Imp.Exp -> Imp.Exp
+ -> CompilerM op s (CSExp, CSExp, String -> m CSExp)
+compileBinOpLike x y = do
+ x' <- compileExp x
+ y' <- compileExp y
+ let simple s = return $ BinOp s x' y'
+ return (x', y', simple)
+
+-- | The C# type name corresponding to a 'PrimType'.
+compilePrimType :: PrimType -> String
+compilePrimType t =
+ case t of
+ IntType Int8 -> "sbyte"
+ IntType Int16 -> "short"
+ IntType Int32 -> "int"
+ IntType Int64 -> "long"
+ FloatType Float32 -> "float"
+ FloatType Float64 -> "double"
+ Imp.Bool -> "bool"
+ Cert -> "bool"
+
+-- | The C# type name corresponding to a 'PrimType', taking signedness into account.
+compilePrimTypeExt :: PrimType -> Imp.Signedness -> String
+compilePrimTypeExt t ept =
+ case (t, ept) of
+ (IntType Int8, Imp.TypeUnsigned) -> "byte"
+ (IntType Int16, Imp.TypeUnsigned) -> "ushort"
+ (IntType Int32, Imp.TypeUnsigned) -> "uint"
+ (IntType Int64, Imp.TypeUnsigned) -> "ulong"
+ (IntType Int8, _) -> "sbyte"
+ (IntType Int16, _) -> "short"
+ (IntType Int32, _) -> "int"
+ (IntType Int64, _) -> "long"
+ (FloatType Float32, _) -> "float"
+ (FloatType Float64, _) -> "double"
+ (Imp.Bool, _) -> "bool"
+ (Cert, _) -> "byte"
+
+-- | A cast to the C# type corresponding to a 'PrimType', taking signedness
+-- into account.
+compileTypecastExt :: PrimType -> Imp.Signedness -> (CSExp -> CSExp)
+compileTypecastExt t ept =
+ let t' = case (t, ept) of
+ (IntType Int8 , Imp.TypeUnsigned)-> Primitive $ CSUInt UInt8T
+ (IntType Int16 , Imp.TypeUnsigned)-> Primitive $ CSUInt UInt16T
+ (IntType Int32 , Imp.TypeUnsigned)-> Primitive $ CSUInt UInt32T
+ (IntType Int64 , Imp.TypeUnsigned)-> Primitive $ CSUInt UInt64T
+ (IntType Int8 , _)-> Primitive $ CSInt Int8T
+ (IntType Int16 , _)-> Primitive $ CSInt Int16T
+ (IntType Int32 , _)-> Primitive $ CSInt Int32T
+ (IntType Int64 , _)-> Primitive $ CSInt Int64T
+ (FloatType Float32, _)-> Primitive $ CSFloat FloatT
+ (FloatType Float64, _)-> Primitive $ CSFloat DoubleT
+ (Imp.Bool , _)-> Primitive BoolT
+ (Cert, _)-> Primitive $ CSInt Int8T
+ in Cast t'
+
+-- | A cast to the C# type corresponding to a 'PrimType'.
+compileTypecast :: PrimType -> (CSExp -> CSExp)
+compileTypecast t =
+ let t' = case t of
+ IntType Int8 -> Primitive $ CSInt Int8T
+ IntType Int16 -> Primitive $ CSInt Int16T
+ IntType Int32 -> Primitive $ CSInt Int32T
+ IntType Int64 -> Primitive $ CSInt Int64T
+ FloatType Float32 -> Primitive $ CSFloat FloatT
+ FloatType Float64 -> Primitive $ CSFloat DoubleT
+ Imp.Bool -> Primitive BoolT
+ Cert -> Primitive $ CSInt Int8T
+ in Cast t'
+
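+-- For illustration: @compileTypecast (IntType Int32) (Integer 0)@ builds
+-- @Cast (Primitive (CSInt Int32T)) (Integer 0)@, roughly a C# @(int) 0@.
+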
+-- | Compile a Futhark primitive value to a C# expression.
+compilePrimValue :: Imp.PrimValue -> CSExp
+compilePrimValue (IntValue (Int8Value v)) =
+ Cast (Primitive $ CSInt Int8T) $ Integer $ toInteger v
+compilePrimValue (IntValue (Int16Value v)) =
+ Cast (Primitive $ CSInt Int16T) $ Integer $ toInteger v
+compilePrimValue (IntValue (Int32Value v)) =
+ Cast (Primitive $ CSInt Int32T) $ Integer $ toInteger v
+compilePrimValue (IntValue (Int64Value v)) =
+ Cast (Primitive $ CSInt Int64T) $ Integer $ toInteger v
+compilePrimValue (FloatValue (Float32Value v))
+ | isInfinite v =
+ if v > 0 then Var "Single.PositiveInfinity" else Var "Single.NegativeInfinity"
+ | isNaN v =
+ Var "Single.NaN"
+ | otherwise = Cast (Primitive $ CSFloat FloatT) (Float $ fromRational $ toRational v)
+compilePrimValue (FloatValue (Float64Value v))
+ | isInfinite v =
+ if v > 0 then Var "Double.PositiveInfinity" else Var "Double.NegativeInfinity"
+ | isNaN v =
+ Var "Double.NaN"
+ | otherwise = Cast (Primitive $ CSFloat DoubleT) (Float $ fromRational $ toRational v)
+compilePrimValue (BoolValue v) = Bool v
+compilePrimValue Checked = Bool True
+
+compileExp :: Imp.Exp -> CompilerM op s CSExp
+
+compileExp (Imp.ValueExp v) = return $ compilePrimValue v
+
+compileExp (Imp.LeafExp (Imp.ScalarVar vname) _) =
+ return $ Var $ compileName vname
+
+compileExp (Imp.LeafExp (Imp.SizeOf t) _) =
+ return $ (compileTypecast $ IntType Int32) (Integer $ primByteSize t)
+
+compileExp (Imp.LeafExp (Imp.Index src (Imp.Count iexp) (IntType Int8) DefaultSpace _) _) = do
+ let src' = compileName src
+ iexp' <- compileExp iexp
+ return $ Cast (Primitive $ CSInt Int8T) (Index (Var src') (IdxExp iexp'))
+
+compileExp (Imp.LeafExp (Imp.Index src (Imp.Count iexp) bt DefaultSpace _) _) = do
+ iexp' <- compileExp iexp
+ let bt' = compilePrimType bt
+ return $ simpleCall ("indexArray_" ++ bt') [Var $ compileName src, iexp']
+
+compileExp (Imp.LeafExp (Imp.Index src (Imp.Count iexp) restype (Imp.Space space) _) _) =
+ join $ asks envReadScalar
+ <*> pure src <*> compileExp iexp
+ <*> pure restype <*> pure space
+
+compileExp (Imp.BinOpExp op x y) = do
+ (x', y', simple) <- compileBinOpLike x y
+ case op of
+ FAdd{} -> simple "+"
+ FSub{} -> simple "-"
+ FMul{} -> simple "*"
+ FDiv{} -> simple "/"
+ LogAnd{} -> simple "&&"
+ LogOr{} -> simple "||"
+ _ -> return $ simpleCall (pretty op) [x', y']
+
+compileExp (Imp.ConvOpExp conv x) = do
+ x' <- compileExp x
+ return $ simpleCall (pretty conv) [x']
+
+compileExp (Imp.CmpOpExp cmp x y) = do
+ (x', y', simple) <- compileBinOpLike x y
+ case cmp of
+ CmpEq{} -> simple "=="
+ FCmpLt{} -> simple "<"
+ FCmpLe{} -> simple "<="
+ _ -> return $ simpleCall (pretty cmp) [x', y']
+
+compileExp (Imp.UnOpExp op exp1) =
+ PreUnOp (compileUnOp op) <$> compileExp exp1
+
+compileExp (Imp.FunExp h args _) =
+ simpleCall (futharkFun (pretty h)) <$> mapM compileExp args
+
+compileCode :: Imp.Code op -> CompilerM op s ()
+
+compileCode Imp.DebugPrint{} =
+ return ()
+
+compileCode (Imp.Op op) =
+ join $ asks envOpCompiler <*> pure op
+
+compileCode (Imp.If cond tb fb) = do
+ cond' <- compileExp cond
+ tb' <- blockScope $ compileCode tb
+ fb' <- blockScope $ compileCode fb
+ stm $ If cond' tb' fb'
+
+compileCode (c1 Imp.:>>: c2) = do
+ compileCode c1
+ compileCode c2
+
+compileCode (Imp.While cond body) = do
+ cond' <- compileExp cond
+ body' <- blockScope $ compileCode body
+ stm $ While cond' body'
+
+compileCode (Imp.For i it bound body) = do
+ bound' <- compileExp bound
+ let i' = compileName i
+ body' <- blockScope $ compileCode body
+ counter <- pretty <$> newVName "counter"
+ one <- pretty <$> newVName "one"
+ stm $ Assign (Var i') $ compileTypecast (IntType it) (Integer 0)
+ stm $ Assign (Var one) $ compileTypecast (IntType it) (Integer 1)
+ stm $ For counter bound' $ body' ++
+ [AssignOp "+" (Var i') (Var one)]
+
+
+compileCode (Imp.SetScalar vname exp1) = do
+ let name' = Var $ compileName vname
+ exp1' <- compileExp exp1
+ stm $ Reassign name' exp1'
+
+compileCode (Imp.DeclareMem v space) = declMem v space
+
+compileCode (Imp.DeclareScalar v Cert) =
+ stm $ Assign (Var $ compileName v) $ Bool True
+compileCode (Imp.DeclareScalar v t) =
+ stm $ AssignTyped t' (Var $ compileName v) Nothing
+ where t' = compilePrimTypeToAST t
+
+compileCode (Imp.DeclareArray name DefaultSpace t vs) =
+ stms [Assign (Var $ "init_"++name') $
+ simpleCall "unwrapArray"
+ [
+ CreateArray (compilePrimTypeToAST t) (map compilePrimValue vs)
+ , simpleCall "sizeof" [Var $ compilePrimType t]
+ ]
+ , Assign (Var name') $ Var ("init_"++name')
+ ]
+ where name' = compileName name
+
+
+compileCode (Imp.DeclareArray name (Space space) t vs) =
+ join $ asks envStaticArray <*>
+ pure name <*> pure space <*> pure t <*> pure vs
+
+compileCode (Imp.Comment s code) = do
+ code' <- blockScope $ compileCode code
+ stm $ Comment s code'
+
+compileCode (Imp.Assert e (Imp.ErrorMsg parts) (loc,locs)) = do
+ e' <- compileExp e
+ let onPart (i, Imp.ErrorString s) = return (printFormatArg i, String s)
+ onPart (i, Imp.ErrorInt32 x) = (printFormatArg i,) <$> compileExp x
+ (formatstrs, formatargs) <- unzip <$> mapM onPart (zip ([1..] :: [Integer]) parts)
+ stm $ Assert e' $ (String $ "Error at {0}:\n" <> concat formatstrs) : (String stacktrace : formatargs)
+ where stacktrace = intercalate " -> " (reverse $ map locStr $ loc:locs)
+ printFormatArg = printf "{%d}"
+
+compileCode (Imp.Call dests fname args) = do
+ args' <- mapM compileArg args
+ let dests' = tupleOrSingle $ fmap Var (map compileName dests)
+ fname' = futharkFun (pretty fname)
+ call' = simpleCall fname' args'
+ -- If the function returns nothing (is called only for side
+ -- effects), take care not to assign to an empty tuple.
+ stm $ if null dests
+ then Exp call'
+ else Reassign dests' call'
+ where compileArg (Imp.MemArg m) = return $ Var $ compileName m
+ compileArg (Imp.ExpArg e) = compileExp e
+
+compileCode (Imp.SetMem dest src DefaultSpace) = do
+ let src' = Var (compileName src)
+ let dest' = Var (compileName dest)
+ stm $ Reassign dest' src'
+
+compileCode (Imp.SetMem dest src _) = do
+ let src' = Var (compileName src)
+ let dest' = Var (compileName dest)
+ stm $ Exp $ simpleCall "MemblockSetDevice" [Ref $ Var "Ctx", Ref dest', Ref src', String (compileName src)]
+
+compileCode (Imp.Allocate name (Imp.Count e) DefaultSpace) = do
+ e' <- compileExp e
+ let allocate' = simpleCall "allocateMem" [e']
+ let name' = Var (compileName name)
+ stm $ Reassign name' allocate'
+
+compileCode (Imp.Allocate name (Imp.Count e) (Imp.Space space)) =
+ join $ asks envAllocate
+ <*> pure name
+ <*> compileExp e
+ <*> pure space
+
+compileCode (Imp.Free name space) = do
+ unRefMem name space
+ tell $ mempty { accFreedMem = [name] }
+
+compileCode (Imp.Copy dest (Imp.Count destoffset) DefaultSpace src (Imp.Count srcoffset) DefaultSpace (Imp.Count size)) = do
+ destoffset' <- compileExp destoffset
+ srcoffset' <- compileExp srcoffset
+ let dest' = Var (compileName dest)
+ let src' = Var (compileName src)
+ size' <- compileExp size
+ stm $ Exp $ simpleCall "Buffer.BlockCopy" [src', srcoffset', dest', destoffset', size']
+
+compileCode (Imp.Copy dest (Imp.Count destoffset) destspace src (Imp.Count srcoffset) srcspace (Imp.Count size)) = do
+ copy <- asks envCopy
+ join $ copy
+ <$> pure dest <*> compileExp destoffset <*> pure destspace
+ <*> pure src <*> compileExp srcoffset <*> pure srcspace
+ <*> compileExp size <*> pure (IntType Int64) -- FIXME
+
+compileCode (Imp.Write dest (Imp.Count idx) elemtype DefaultSpace _ elemexp) = do
+ idx' <- compileExp idx
+ elemexp' <- compileExp elemexp
+ let dest' = Var $ compileName dest
+ let elemtype' = compileTypecast elemtype
+ let ctype = elemtype' elemexp'
+ stm $ Exp $ simpleCall "writeScalarArray" [dest', idx', ctype]
+
+compileCode (Imp.Write dest (Imp.Count idx) elemtype (Imp.Space space) _ elemexp) =
+ join $ asks envWriteScalar
+ <*> pure dest
+ <*> compileExp idx
+ <*> pure elemtype
+ <*> pure space
+ <*> compileExp elemexp
+
+compileCode Imp.Skip = return ()
+
+blockScope :: CompilerM op s () -> CompilerM op s [CSStmt]
+blockScope = fmap snd . blockScope'
+
+blockScope' :: CompilerM op s a -> CompilerM op s (a, [CSStmt])
+blockScope' m = do
+ old_allocs <- gets compDeclaredMem
+ (x, items) <- pass $ do
+ (x, w) <- listen m
+ let items = accItems w
+ return ((x, items), const mempty)
+ new_allocs <- gets $ filter (`notElem` old_allocs) . compDeclaredMem
+ modify $ \s -> s { compDeclaredMem = old_allocs }
+ releases <- collect $ mapM_ (uncurry unRefMem) new_allocs
+ return (x, items <> releases)
+
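+-- A minimal usage sketch (names hypothetical): memory declared with 'declMem'
+-- inside the block is released when the block ends, e.g.
+--
+--   body <- blockScope $ do declMem tmp_mem (Space "device")
+--                           compileCode tmp_code
+--
+-- appends a MemblockUnrefDevice statement for @tmp_mem@ to @body@.
+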
+unRefMem :: VName -> Space -> CompilerM op s ()
+unRefMem mem (Space "device") =
+ (stm . Exp) $ simpleCall "MemblockUnrefDevice" [ Ref $ Var "Ctx"
+ , (Ref . Var . compileName) mem
+ , (String . compileName) mem]
+unRefMem _ DefaultSpace = stm Pass
+unRefMem _ (Space "local") = stm Pass
+unRefMem _ (Space _) = fail "The default compiler cannot compile unRefMem for other spaces"
+
+
+-- | Public names must have a consistent prefix.
+publicName :: String -> String
+publicName s = "Futhark" ++ s
+
+declMem :: VName -> Space -> CompilerM op s ()
+declMem name space = do
+ modify $ \s -> s { compDeclaredMem = (name, space) : compDeclaredMem s}
+ stm $ declMem' (compileName name) space
+
+declMem' :: String -> Space -> CSStmt
+declMem' name DefaultSpace =
+ AssignTyped (Composite $ ArrayT $ Primitive ByteT) (Var name) Nothing
+declMem' name (Space _) =
+ AssignTyped (CustomT "OpenCLMemblock") (Var name) (Just $ simpleCall "EmptyMemblock" [Var "Ctx.EMPTY_MEM_HANDLE"])
+
+rawMemCSType :: Space -> CSType
+rawMemCSType DefaultSpace = Composite $ ArrayT $ Primitive ByteT
+rawMemCSType (Space _) = CustomT "OpenCLMemblock"
+
+toIntPtr :: CSExp -> CSExp
+toIntPtr e = simpleInitClass "IntPtr" [e]
diff --git a/src/Futhark/CodeGen/Backends/PyOpenCL.hs b/src/Futhark/CodeGen/Backends/PyOpenCL.hs
index 075f648..d4483bb 100644
--- a/src/Futhark/CodeGen/Backends/PyOpenCL.hs
+++ b/src/Futhark/CodeGen/Backends/PyOpenCL.hs
@@ -101,23 +101,26 @@ callKernel (Imp.GetSizeMax v size_class) =
callKernel (Imp.HostCode c) =
Py.compileCode c
-callKernel (Imp.LaunchKernel name args kernel_size workgroup_size) = do
- kernel_size' <- mapM Py.compileExp kernel_size
- let total_elements = foldl mult_exp (Integer 1) kernel_size'
- let cond = BinOp "!=" total_elements (Integer 0)
- workgroup_size' <- Tuple <$> mapM (fmap asLong . Py.compileExp) workgroup_size
- body <- Py.collect $ launchKernel name kernel_size' workgroup_size' args
+callKernel (Imp.LaunchKernel name args num_workgroups workgroup_size) = do
+ num_workgroups' <- mapM (fmap asLong . Py.compileExp) num_workgroups
+ workgroup_size' <- mapM (fmap asLong . Py.compileExp) workgroup_size
+ let kernel_size = zipWith mult_exp num_workgroups' workgroup_size'
+ total_elements = foldl mult_exp (Integer 1) kernel_size
+ cond = BinOp "!=" total_elements (Integer 0)
+ body <- Py.collect $ launchKernel name kernel_size workgroup_size' args
Py.stm $ If cond body []
where mult_exp = BinOp "*"
-launchKernel :: String -> [PyExp] -> PyExp -> [Imp.KernelArg] -> Py.CompilerM op s ()
+launchKernel :: String -> [PyExp] -> [PyExp] -> [Imp.KernelArg]
+ -> Py.CompilerM op s ()
launchKernel kernel_name kernel_dims workgroup_dims args = do
- let kernel_dims' = Tuple $ map asLong kernel_dims
- let kernel_name' = "self." ++ kernel_name ++ "_var"
+ let kernel_dims' = Tuple kernel_dims
+ workgroup_dims' = Tuple workgroup_dims
+ kernel_name' = "self." ++ kernel_name ++ "_var"
args' <- mapM processKernelArg args
Py.stm $ Exp $ Py.simpleCall (kernel_name' ++ ".set_args") args'
Py.stm $ Exp $ Py.simpleCall "cl.enqueue_nd_range_kernel"
- [Var "self.queue", Var kernel_name', kernel_dims', workgroup_dims]
+ [Var "self.queue", Var kernel_name', kernel_dims', workgroup_dims']
finishIfSynchronous
where processKernelArg :: Imp.KernelArg -> Py.CompilerM op s PyExp
processKernelArg (Imp.ValueKArg e bt) = do
diff --git a/src/Futhark/CodeGen/Backends/PyOpenCL/Boilerplate.hs b/src/Futhark/CodeGen/Backends/PyOpenCL/Boilerplate.hs
index 8b111e5..c409f94 100644
--- a/src/Futhark/CodeGen/Backends/PyOpenCL/Boilerplate.hs
+++ b/src/Futhark/CodeGen/Backends/PyOpenCL/Boilerplate.hs
@@ -23,7 +23,7 @@ openClPrelude = $(embedStringFile "rts/python/opencl.py")
-- | Python code (as a string) that calls the
-- @initialise_opencl_object@ procedure. Should be put in the
-- class constructor.
-openClInit :: [PrimType] -> String -> M.Map VName (SizeClass, Name) -> String
+openClInit :: [PrimType] -> String -> M.Map Name SizeClass -> String
openClInit types assign sizes = T.unpack [text|
size_heuristics=$size_heuristics
program = initialise_opencl_object(self,
@@ -44,9 +44,9 @@ $assign'
where assign' = T.pack assign
size_heuristics = prettyText $ sizeHeuristicsToPython sizeHeuristicsTable
types' = prettyText $ map (show . pretty) types -- Looks enough like Python.
- sizes' = prettyText $ sizeClassesToPython $ M.map fst sizes
+ sizes' = prettyText $ sizeClassesToPython sizes
-sizeClassesToPython :: M.Map VName SizeClass -> PyExp
+sizeClassesToPython :: M.Map Name SizeClass -> PyExp
sizeClassesToPython = Dict . map f . M.toList
where f (size_name, size_class) =
(String $ pretty size_name,
diff --git a/src/Futhark/CodeGen/ImpCode.hs b/src/Futhark/CodeGen/ImpCode.hs
index 328c908..b5ad49c 100644
--- a/src/Futhark/CodeGen/ImpCode.hs
+++ b/src/Futhark/CodeGen/ImpCode.hs
@@ -62,7 +62,6 @@ import Data.List
import Data.Loc
import Data.Traversable
import qualified Data.Set as S
-import qualified Data.Semigroup as Sem
import Language.Futhark.Core
import Futhark.Representation.Primitive
@@ -94,12 +93,11 @@ paramName (ScalarParam name _) = name
-- | A collection of imperative functions.
newtype Functions a = Functions [(Name, Function a)]
-instance Sem.Semigroup (Functions a) where
+instance Semigroup (Functions a) where
Functions x <> Functions y = Functions $ x ++ y
instance Monoid (Functions a) where
mempty = Functions []
- mappend = (Sem.<>)
data Signedness = TypeUnsigned
| TypeDirect
@@ -146,7 +144,11 @@ data Code a = Skip
| DeclareMem VName Space
| DeclareScalar VName PrimType
| DeclareArray VName Space PrimType [PrimValue]
- -- ^ Create a read-only array containing the given values.
+ -- ^ Create an array containing the given values. The
+ -- lifetime of the array will be the entire application.
+ -- This is mostly used for constant arrays, but also for
+ -- some bookkeeping data, like the synchronisation
+ -- counts used to implement reduction.
| Allocate VName (Count Bytes) Space
-- ^ Memory space must match the corresponding
-- 'DeclareMem'.
@@ -186,14 +188,13 @@ data Code a = Skip
data Volatility = Volatile | Nonvolatile
deriving (Eq, Ord, Show)
-instance Sem.Semigroup (Code a) where
+instance Semigroup (Code a) where
Skip <> y = y
x <> Skip = x
x <> y = x :>>: y
instance Monoid (Code a) where
mempty = Skip
- mappend = (Sem.<>)
data ExpLeaf = ScalarVar VName
| SizeOf PrimType
diff --git a/src/Futhark/CodeGen/ImpCode/Kernels.hs b/src/Futhark/CodeGen/ImpCode/Kernels.hs
index ea3bf86..455e27c 100644
--- a/src/Futhark/CodeGen/ImpCode/Kernels.hs
+++ b/src/Futhark/CodeGen/ImpCode/Kernels.hs
@@ -12,8 +12,6 @@ module Futhark.CodeGen.ImpCode.Kernels
, HostOp (..)
, KernelOp (..)
, AtomicOp (..)
- , CallKernel (..)
- , MapKernel (..)
, Kernel (..)
, LocalMemoryUse
, KernelUse (..)
@@ -27,7 +25,6 @@ module Futhark.CodeGen.ImpCode.Kernels
import Control.Monad.Writer
import Data.List
-import qualified Data.Set as S
import Futhark.CodeGen.ImpCode hiding (Function, Code)
import qualified Futhark.CodeGen.ImpCode as Imp
@@ -39,42 +36,24 @@ import Futhark.Util.Pretty
type Program = Functions HostOp
type Function = Imp.Function HostOp
-- | Host-level code that can call kernels.
-type Code = Imp.Code CallKernel
+type Code = Imp.Code HostOp
-- | Code inside a kernel.
type KernelCode = Imp.Code KernelOp
-- | A run-time constant related to kernels.
-newtype KernelConst = SizeConst VName
+newtype KernelConst = SizeConst Name
deriving (Eq, Ord, Show)
-- | An expression whose variables are kernel constants.
type KernelConstExp = PrimExp KernelConst
-data HostOp = CallKernel CallKernel
- | GetSize VName VName SizeClass
- | CmpSizeLe VName VName SizeClass Imp.Exp
+data HostOp = CallKernel Kernel
+ | GetSize VName Name SizeClass
+ | CmpSizeLe VName Name SizeClass Imp.Exp
| GetSizeMax VName SizeClass
deriving (Show)
-data CallKernel = Map MapKernel
- | AnyKernel Kernel
- deriving (Show)
-
-- | A generic kernel containing arbitrary kernel code.
-data MapKernel = MapKernel { mapKernelThreadNum :: VName
- -- ^ Stm position - also serves as a unique
- -- name for the kernel.
- , mapKernelDesc :: String
- -- ^ Used to name the kernel for readability.
- , mapKernelBody :: Imp.Code KernelOp
- , mapKernelUses :: [KernelUse]
- , mapKernelNumGroups :: DimSize
- , mapKernelGroupSize :: DimSize
- , mapKernelSize :: Imp.Exp
- -- ^ Do not actually execute threads past this.
- }
- deriving (Show)
-
data Kernel = Kernel
{ kernelBody :: Imp.Code KernelOp
, kernelLocalMemory :: [LocalMemoryUse]
@@ -99,7 +78,7 @@ data KernelUse = ScalarUse VName PrimType
| ConstUse VName KernelConstExp
deriving (Eq, Show)
-getKernels :: Program -> [CallKernel]
+getKernels :: Program -> [Kernel]
getKernels = nubBy sameKernel . execWriter . traverse getFunKernels
where getFunKernels (CallKernel kernel) =
tell [kernel]
@@ -145,33 +124,17 @@ instance Pretty HostOp where
instance FreeIn HostOp where
freeIn (CallKernel c) = freeIn c
- freeIn (CmpSizeLe dest name _ x) =
- freeIn dest <> freeIn name <> freeIn x
+ freeIn (CmpSizeLe dest _ _ x) =
+ freeIn dest <> freeIn x
freeIn (GetSizeMax dest _) =
freeIn dest
freeIn (GetSize dest _ _) =
freeIn dest
-instance Pretty CallKernel where
- ppr (Map k) = ppr k
- ppr (AnyKernel k) = ppr k
-
-instance FreeIn CallKernel where
- freeIn (Map k) = freeIn k
- freeIn (AnyKernel k) = freeIn k
-
instance FreeIn Kernel where
freeIn kernel = freeIn (kernelBody kernel) <>
freeIn [kernelNumGroups kernel, kernelGroupSize kernel]
-instance Pretty MapKernel where
- ppr kernel =
- text "mapKernel" <+> brace
- (text "uses" <+> brace (commasep $ map ppr $ mapKernelUses kernel) </>
- text "body" <+> brace (ppr (mapKernelThreadNum kernel) <+>
- text "<- get_thread_number()" </>
- ppr (mapKernelBody kernel)))
-
instance Pretty Kernel where
ppr kernel =
text "kernel" <+> brace
@@ -187,10 +150,6 @@ instance Pretty Kernel where
ppLocalMemory (name, Right size) =
ppr name <+> parens (ppr size <+> text "bytes (const)")
-instance FreeIn MapKernel where
- freeIn kernel =
- mapKernelThreadNum kernel `S.delete` freeIn (mapKernelBody kernel)
-
data KernelOp = GetGroupId VName Int
| GetLocalId VName Int
| GetLocalSize VName Int
@@ -198,7 +157,8 @@ data KernelOp = GetGroupId VName Int
| GetGlobalId VName Int
| GetLockstepWidth VName
| Atomic AtomicOp
- | Barrier
+ | LocalBarrier
+ | GlobalBarrier
| MemFence
deriving (Show)
@@ -247,8 +207,10 @@ instance Pretty KernelOp where
ppr (GetLockstepWidth dest) =
ppr dest <+> text "<-" <+>
text "get_lockstep_width()"
- ppr Barrier =
- text "barrier()"
+ ppr LocalBarrier =
+ text "local_barrier()"
+ ppr GlobalBarrier =
+ text "global_barrier()"
ppr MemFence =
text "mem_fence()"
ppr (Atomic (AtomicAdd old arr ind x)) =
diff --git a/src/Futhark/CodeGen/ImpCode/OpenCL.hs b/src/Futhark/CodeGen/ImpCode/OpenCL.hs
index 1502297..8552871 100644
--- a/src/Futhark/CodeGen/ImpCode/OpenCL.hs
+++ b/src/Futhark/CodeGen/ImpCode/OpenCL.hs
@@ -14,6 +14,7 @@ module Futhark.CodeGen.ImpCode.OpenCL
, KernelName
, KernelArg (..)
, OpenCL (..)
+ , KernelTarget (..)
, module Futhark.CodeGen.ImpCode
, module Futhark.Representation.Kernels.Sizes
)
@@ -34,7 +35,7 @@ data Program = Program { openClProgram :: String
, openClKernelNames :: [KernelName]
, openClUsedTypes :: [PrimType]
-- ^ So we can detect whether the device is capable.
- , openClSizes :: M.Map VName (SizeClass, Name)
+ , openClSizes :: M.Map Name SizeClass
-- ^ Runtime-configurable constants.
, hostFunctions :: Functions OpenCL
}
@@ -60,10 +61,15 @@ data KernelArg = ValueKArg Exp PrimType
-- | Host-level OpenCL operation.
data OpenCL = LaunchKernel KernelName [KernelArg] [Exp] [Exp]
| HostCode Code
- | GetSize VName VName
- | CmpSizeLe VName VName Exp
+ | GetSize VName Name
+ | CmpSizeLe VName Name Exp
| GetSizeMax VName SizeClass
deriving (Show)
+-- | The target platform when compiling imperative code to a 'Program'
+data KernelTarget = TargetOpenCL
+ | TargetCUDA
+ deriving (Eq)
+
instance Pretty OpenCL where
ppr = text . show
diff --git a/src/Futhark/CodeGen/ImpGen.hs b/src/Futhark/CodeGen/ImpGen.hs
index 29179a6..9e322a8 100644
--- a/src/Futhark/CodeGen/ImpGen.hs
+++ b/src/Futhark/CodeGen/ImpGen.hs
@@ -20,7 +20,7 @@ module Futhark.CodeGen.ImpGen
-- * Monadic Compiler Interface
, ImpM
- , Env (envDefaultSpace)
+ , Env (envDefaultSpace, envFunction)
, VTable
, getVTable
, localVTable
@@ -79,7 +79,7 @@ module Futhark.CodeGen.ImpGen
, sIf, sWhen, sUnless
, sOp
, sAlloc
- , sArray
+ , sArray, sAllocArray, sStaticArray
, sWrite
, (<--)
)
@@ -152,6 +152,7 @@ data ArrayEntry = ArrayEntry {
entryArrayLocation :: MemLocation
, entryArrayElemType :: PrimType
}
+ deriving (Show)
entryArrayShape :: ArrayEntry -> [Imp.DimSize]
entryArrayShape = memLocationShape . entryArrayLocation
@@ -160,15 +161,18 @@ data MemEntry = MemEntry {
entryMemSize :: Imp.MemSize
, entryMemSpace :: Imp.Space
}
+ deriving (Show)
newtype ScalarEntry = ScalarEntry {
entryScalarType :: PrimType
}
+ deriving (Show)
-- | Every non-scalar variable must be associated with an entry.
data VarEntry lore = ArrayVar (Maybe (Exp lore)) ArrayEntry
| ScalarVar (Maybe (Exp lore)) ScalarEntry
| MemVar (Maybe (Exp lore)) MemEntry
+ deriving (Show)
-- | When compiling an expression, this is a description of where the
-- result should end up. The integer is a reference to the construct
@@ -198,16 +202,23 @@ data Env lore op = Env {
, envCopyCompiler :: CopyCompiler lore op
, envDefaultSpace :: Imp.Space
, envVolatility :: Imp.Volatility
+ , envFakeMemory :: [Space]
+ -- ^ Do not actually generate allocations for these memory spaces.
+ , envFunction :: Name
+ -- ^ Name of the function we are compiling.
}
-newEnv :: Operations lore op -> Imp.Space -> Env lore op
-newEnv ops ds = Env { envExpCompiler = opsExpCompiler ops
- , envStmsCompiler = opsStmsCompiler ops
- , envOpCompiler = opsOpCompiler ops
- , envCopyCompiler = opsCopyCompiler ops
- , envDefaultSpace = ds
- , envVolatility = Imp.Nonvolatile
- }
+newEnv :: Operations lore op -> Imp.Space -> [Imp.Space] -> Name -> Env lore op
+newEnv ops ds fake fname =
+ Env { envExpCompiler = opsExpCompiler ops
+ , envStmsCompiler = opsStmsCompiler ops
+ , envOpCompiler = opsOpCompiler ops
+ , envCopyCompiler = opsCopyCompiler ops
+ , envDefaultSpace = ds
+ , envVolatility = Imp.Nonvolatile
+ , envFakeMemory = fake
+ , envFunction = fname
+ }
-- | The symbol table used during compilation.
type VTable lore = M.Map VName (VarEntry lore)
@@ -247,11 +258,10 @@ instance HasScope SOACS (ImpM lore op) where
Prim $ entryScalarType scalarEntry
runImpM :: ImpM lore op a
- -> Operations lore op -> Imp.Space -> VNameSource
- -> Either InternalError (a, VNameSource, Imp.Code op, Imp.Functions op)
-runImpM (ImpM m) comp space src = do
- (a, s, code) <- runRWST m (newEnv comp space) (newState src)
- return (a, stateNameSource s, code, stateFunctions s)
+ -> Operations lore op -> Imp.Space -> [Imp.Space] -> Name -> State lore op
+ -> Either InternalError (a, State lore op, Imp.Code op)
+runImpM (ImpM m) comp space fake fname =
+ runRWST m (newEnv comp space fake fname)
subImpM_ :: Operations lore' op' -> ImpM lore' op' a
-> ImpM lore op (Imp.Code op')
@@ -311,13 +321,17 @@ hasFunction fname = gets $ \s -> let Imp.Functions fs = stateFunctions s
in isJust $ lookup fname fs
compileProg :: (ExplicitMemorish lore, MonadFreshNames m) =>
- Operations lore op -> Imp.Space
+ Operations lore op -> Imp.Space -> [Imp.Space]
-> Prog lore -> m (Either InternalError (Imp.Functions op))
-compileProg ops space prog =
+compileProg ops space fake prog =
modifyNameSource $ \src ->
- case runImpM (mapM_ compileFunDef $ progFunctions prog) ops space src of
+ case foldM compileFunDef' (newState src) (progFunctions prog) of
Left err -> (Left err, src)
- Right ((), src', _, fs) -> (Right fs, src')
+ Right s -> (Right $ stateFunctions s, stateNameSource s)
+ where compileFunDef' s fdef = do
+ ((), s', _) <-
+ runImpM (compileFunDef fdef) ops space fake (funDefName fdef) s
+ return s'
compileInParam :: ExplicitMemorish lore =>
FParam lore -> ImpM lore op (Either Imp.Param ArrayDecl)
@@ -1178,7 +1192,8 @@ compileAlloc :: ExplicitMemorish lore =>
-> ImpM lore op ()
compileAlloc (Pattern [] [mem]) e space = do
e' <- compileSubExp e
- emit $ Imp.Allocate (patElemName mem) (Imp.bytes e') space
+ fake <- asks $ elem space . envFakeMemory
+ unless fake $ emit $ Imp.Allocate (patElemName mem) (Imp.bytes e') space
compileAlloc pat _ _ =
compilerBugS $ "compileAlloc: Invalid pattern: " ++ pretty pat
@@ -1232,7 +1247,8 @@ sAlloc name size space = do
size_var <-- Imp.innerExp size
return $ Imp.VarSize size_var
emit $ Imp.DeclareMem name' space
- emit $ Imp.Allocate name' size space
+ fake <- asks $ elem space . envFakeMemory
+ unless fake $ emit $ Imp.Allocate name' size space
addVar name' $ MemVar Nothing $ MemEntry size' space
return name'
@@ -1242,6 +1258,26 @@ sArray name bt shape membind = do
dArray name' bt shape membind
return name'
+-- | Uses linear/iota index function.
+sAllocArray :: String -> PrimType -> ShapeBase SubExp -> Space -> ImpM lore op VName
+sAllocArray name pt shape space = do
+ let arr_bytes = Imp.bytes $ Imp.LeafExp (Imp.SizeOf pt) int32 *
+ product (map (compileSubExpOfType int32) (shapeDims shape))
+ mem <- sAlloc (name ++ "_mem") arr_bytes space
+ sArray name pt shape $
+ ArrayIn mem $ IxFun.iota $ map (primExpFromSubExp int32) $ shapeDims shape
+
+-- | Uses linear/iota index function.
+sStaticArray :: String -> Space -> PrimType -> [PrimValue] -> ImpM lore op VName
+sStaticArray name space pt vs = do
+ let shape = Shape [constant $ length vs]
+ size = Imp.ConstSize $ fromIntegral (length vs) * primByteSize pt
+ mem <- newVName $ name ++ "_mem"
+ emit $ Imp.DeclareArray mem space pt vs
+ addVar mem $ MemVar Nothing $ MemEntry size space
+ sArray name pt shape $
+ ArrayIn mem $ IxFun.iota $ map (primExpFromSubExp int32) $ shapeDims shape
+
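A minimal usage sketch of the two new helpers (hypothetical caller, not part of this patch; assumes some 'n :: SubExp' is in scope, int32 elements and the "device" space):

    tmp <- sAllocArray "tmp" int32 (Shape [n]) (Space "device")
    tbl <- sStaticArray "tbl" (Space "device") int32 $
             map (IntValue . Int32Value) [0 .. 7]

The first allocates a fresh memory block and binds an array with a linear (iota) index function over it; the second emits a DeclareArray for the constant contents instead of a runtime allocation.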
sWrite :: VName -> [Imp.Exp] -> PrimExp Imp.ExpLeaf -> ImpM lore op ()
sWrite arr is v = do
(mem, space, offset) <- fullyIndexArray arr is
diff --git a/src/Futhark/CodeGen/ImpGen/CUDA.hs b/src/Futhark/CodeGen/ImpGen/CUDA.hs
new file mode 100644
index 0000000..a37c97e
--- /dev/null
+++ b/src/Futhark/CodeGen/ImpGen/CUDA.hs
@@ -0,0 +1,14 @@
+module Futhark.CodeGen.ImpGen.CUDA
+ ( compileProg
+ ) where
+
+import Futhark.Error
+import Futhark.Representation.ExplicitMemory
+import qualified Futhark.CodeGen.ImpCode.OpenCL as OpenCL
+import qualified Futhark.CodeGen.ImpGen.Kernels as ImpGenKernels
+import Futhark.CodeGen.ImpGen.Kernels.ToOpenCL
+import Futhark.MonadFreshNames
+
+compileProg :: MonadFreshNames m => Prog ExplicitMemory
+ -> m (Either InternalError OpenCL.Program)
+compileProg prog = either Left kernelsToCUDA <$> ImpGenKernels.compileProg prog
diff --git a/src/Futhark/CodeGen/ImpGen/Kernels.hs b/src/Futhark/CodeGen/ImpGen/Kernels.hs
index a753060..ae62adf 100644
--- a/src/Futhark/CodeGen/ImpGen/Kernels.hs
+++ b/src/Futhark/CodeGen/ImpGen/Kernels.hs
@@ -1,6 +1,5 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ConstraintKinds #-}
module Futhark.CodeGen.ImpGen.Kernels
@@ -8,37 +7,29 @@ module Futhark.CodeGen.ImpGen.Kernels
)
where
-import Control.Arrow ((&&&))
import Control.Monad.Except
import Control.Monad.Reader
import Data.Maybe
-import Data.Semigroup ((<>))
import qualified Data.Map.Strict as M
-import qualified Data.Set as S
import Data.List
import Prelude hiding (quot)
import Futhark.Error
import Futhark.MonadFreshNames
-import Futhark.Transform.Rename
import Futhark.Representation.ExplicitMemory
import qualified Futhark.CodeGen.ImpCode.Kernels as Imp
import Futhark.CodeGen.ImpCode.Kernels (bytes)
import qualified Futhark.CodeGen.ImpGen as ImpGen
-import Futhark.CodeGen.ImpGen ((<--),
- sFor, sWhile, sComment, sIf, sWhen, sUnless,
+import Futhark.CodeGen.ImpGen.Kernels.Base
+import Futhark.CodeGen.ImpGen.Kernels.SegRed
+import Futhark.CodeGen.ImpGen (sFor, sWhen,
sOp,
dPrim, dPrim_, dPrimV)
import Futhark.CodeGen.ImpGen.Kernels.Transpose
import qualified Futhark.Representation.ExplicitMemory.IndexFunction as IxFun
import Futhark.CodeGen.SetDefaultSpace
-import Futhark.Tools (partitionChunkedKernelLambdaParameters)
-import Futhark.Util.IntegralExp (quotRoundingUp, quot, rem, IntegralExp)
-import Futhark.Util (splitAt3)
-
-type CallKernelGen = ImpGen.ImpM ExplicitMemory Imp.HostOp
-type InKernelGen = ImpGen.ImpM InKernel Imp.KernelOp
+import Futhark.Util.IntegralExp (quotRoundingUp, quot, IntegralExp)
callKernelOperations :: ImpGen.Operations ExplicitMemory Imp.HostOp
callKernelOperations =
@@ -48,17 +39,10 @@ callKernelOperations =
, ImpGen.opsStmsCompiler = ImpGen.defCompileStms
}
-inKernelOperations :: KernelConstants -> ImpGen.Operations InKernel Imp.KernelOp
-inKernelOperations constants = (ImpGen.defaultOperations $ compileInKernelOp constants)
- { ImpGen.opsCopyCompiler = inKernelCopy
- , ImpGen.opsExpCompiler = inKernelExpCompiler
- , ImpGen.opsStmsCompiler = \_ -> compileKernelStms constants
- }
-
compileProg :: MonadFreshNames m => Prog ExplicitMemory -> m (Either InternalError Imp.Program)
compileProg prog =
fmap (setDefaultSpace (Imp.Space "device")) <$>
- ImpGen.compileProg callKernelOperations (Imp.Space "device") prog
+ ImpGen.compileProg callKernelOperations (Imp.Space "device") [Imp.Space "local"] prog
opCompiler :: Pattern ExplicitMemory -> Op ExplicitMemory
-> CallKernelGen ()
@@ -67,71 +51,41 @@ opCompiler dest (Alloc e space) =
opCompiler dest (Inner kernel) =
kernelCompiler dest kernel
-compileInKernelOp :: KernelConstants -> Pattern InKernel -> Op InKernel
- -> InKernelGen ()
-compileInKernelOp _ (Pattern _ [mem]) Alloc{} =
- compilerLimitationS $ "Cannot allocate memory block " ++ pretty mem ++ " in kernel."
-compileInKernelOp _ dest Alloc{} =
- compilerBugS $ "Invalid target for in-kernel allocation: " ++ show dest
-compileInKernelOp constants pat (Inner op) =
- compileKernelExp constants pat op
-
--- | Recognise kernels (maps), give everything else back.
+sizeClassWithEntryPoint :: Name -> Imp.SizeClass -> Imp.SizeClass
+sizeClassWithEntryPoint fname (Imp.SizeThreshold path) =
+ Imp.SizeThreshold $ map f path
+ where f (name, x) = (keyWithEntryPoint fname name, x)
+sizeClassWithEntryPoint _ size_class = size_class
+
kernelCompiler :: Pattern ExplicitMemory -> Kernel InKernel
-> CallKernelGen ()
-kernelCompiler (Pattern _ [pe]) (GetSize key size_class) =
- sOp $ Imp.GetSize (patElemName pe) key size_class
+kernelCompiler (Pattern _ [pe]) (GetSize key size_class) = do
+ fname <- asks ImpGen.envFunction
+ sOp $ Imp.GetSize (patElemName pe) (keyWithEntryPoint fname key) $
+ sizeClassWithEntryPoint fname size_class
-kernelCompiler (Pattern _ [pe]) (CmpSizeLe key size_class x) =
- sOp . Imp.CmpSizeLe (patElemName pe) key size_class =<< ImpGen.compileSubExp x
+kernelCompiler (Pattern _ [pe]) (CmpSizeLe key size_class x) = do
+ fname <- asks ImpGen.envFunction
+ let size_class' = sizeClassWithEntryPoint fname size_class
+ sOp . Imp.CmpSizeLe (patElemName pe) (keyWithEntryPoint fname key) size_class'
+ =<< ImpGen.compileSubExp x
kernelCompiler (Pattern _ [pe]) (GetSizeMax size_class) =
sOp $ Imp.GetSizeMax (patElemName pe) size_class
kernelCompiler pat (Kernel desc space _ kernel_body) = do
+ (constants, init_constants) <- kernelInitialisation space
- group_size' <- ImpGen.subExpToDimSize $ spaceGroupSize space
- num_threads' <- ImpGen.subExpToDimSize $ spaceNumThreads space
+ kernel_body' <-
+ makeAllMemoryGlobal $ ImpGen.subImpM_ (inKernelOperations constants) $ do
+ init_constants
+ compileKernelBody pat constants kernel_body
let bound_in_kernel =
M.keys $
scopeOfKernelSpace space <>
scopeOf (kernelBodyStms kernel_body)
-
- let global_tid = spaceGlobalId space
- local_tid = spaceLocalId space
- group_id = spaceGroupId space
- wave_size <- newVName "wave_size"
- inner_group_size <- newVName "group_size"
- thread_active <- newVName "thread_active"
-
- let (space_is, space_dims) = unzip $ spaceDimensions space
- space_dims' <- mapM ImpGen.compileSubExp space_dims
- let constants = KernelConstants global_tid local_tid group_id
- group_size' num_threads'
- (Imp.VarSize wave_size) (zip space_is space_dims')
- (Imp.var thread_active Bool) mempty
-
- kernel_body' <-
- makeAllMemoryGlobal $ ImpGen.subImpM_ (inKernelOperations constants) $ do
- dPrim_ wave_size int32
- dPrim_ inner_group_size int32
- dPrim_ thread_active Bool
- ImpGen.dScope Nothing (scopeOfKernelSpace space)
-
- sOp (Imp.GetGlobalId global_tid 0)
- sOp (Imp.GetLocalId local_tid 0)
- sOp (Imp.GetLocalSize inner_group_size 0)
- sOp (Imp.GetLockstepWidth wave_size)
- sOp (Imp.GetGroupId group_id 0)
-
- setSpaceIndices space
-
- thread_active <-- isActive (spaceDimensions space)
-
- compileKernelBody pat constants kernel_body
-
(uses, local_memory) <- computeKernelUses kernel_body' bound_in_kernel
forM_ (kernelHints desc) $ \(s,v) -> do
@@ -144,86 +98,62 @@ kernelCompiler pat (Kernel desc space _ kernel_body) = do
ImpGen.compileSubExp v >>= ImpGen.emit . Imp.DebugPrint s (elemType ty)
- sOp $ Imp.CallKernel $ Imp.AnyKernel Imp.Kernel
+ sOp $ Imp.CallKernel Imp.Kernel
{ Imp.kernelBody = kernel_body'
, Imp.kernelLocalMemory = local_memory
, Imp.kernelUses = uses
, Imp.kernelNumGroups = [ImpGen.compileSubExpOfType int32 $ spaceNumGroups space]
, Imp.kernelGroupSize = [ImpGen.compileSubExpOfType int32 $ spaceGroupSize space]
, Imp.kernelName = nameFromString $ kernelName desc ++ "_" ++
- show (baseTag global_tid)
+ show (baseTag $ kernelGlobalThreadIdVar constants)
}
+-- First handle the simple case of a non-segmented reduction. Our
+-- strategy is the conventional approach of generating two kernels:
+-- one where each group is given a chunk of the total input and
+-- produces a partial result per group, and then a final kernel that
+-- combines the per-group partial results.
+kernelCompiler pat (SegRed space comm red_op nes _ body) =
+ compileSegRed pat space comm red_op nes body
+
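For intuition (illustrative numbers only, not part of the patch): reducing 2^20 elements with 1024 groups of 256 threads each, the first kernel gives every group a chunk of roughly 1024 elements and produces one partial result per group; the second kernel then combines those 1024 partials within a single group.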
kernelCompiler pat e =
compilerBugS $ "ImpGen.kernelCompiler: Invalid pattern\n " ++
pretty pat ++ "\nfor expression\n " ++ pretty e
expCompiler :: ImpGen.ExpCompiler ExplicitMemory Imp.HostOp
+
-- We generate a simple kernel for iota and replicate.
expCompiler (Pattern _ [pe]) (BasicOp (Iota n x s et)) = do
+ n' <- ImpGen.compileSubExp n
+ x' <- ImpGen.compileSubExp x
+ s' <- ImpGen.compileSubExp s
destloc <- ImpGen.entryArrayLocation <$> ImpGen.lookupArray (patElemName pe)
- let tag = Just $ baseTag $ patElemName pe
- thread_gid <- maybe (newVName "thread_gid") (return . VName (nameFromString "thread_gid")) tag
+ (constants, set_constants) <- simpleKernelConstants n' "iota"
- makeAllMemoryGlobal $ do
- (destmem, destspace, destidx) <-
- ImpGen.fullyIndexArray' destloc [ImpGen.varIndex thread_gid] (IntType et)
+ sKernel constants "iota" $ do
+ set_constants
+ let gtid = kernelGlobalThreadId constants
+ sWhen (kernelThreadActive constants) $ do
+ (destmem, destspace, destidx) <-
+ ImpGen.fullyIndexArray' destloc [gtid] (IntType et)
- n' <- ImpGen.compileSubExp n
- x' <- ImpGen.compileSubExp x
- s' <- ImpGen.compileSubExp s
-
- let body = Imp.Write destmem destidx (IntType et) destspace Imp.Nonvolatile $
- Imp.ConvOpExp (SExt Int32 et) (Imp.var thread_gid int32) * s' + x'
-
- (group_size, num_groups) <- computeMapKernelGroups n'
-
- (body_uses, _) <- computeKernelUses
- (freeIn body <> freeIn [n',x',s'])
- [thread_gid]
-
- sOp $ Imp.CallKernel $ Imp.Map Imp.MapKernel
- { Imp.mapKernelThreadNum = thread_gid
- , Imp.mapKernelDesc = "iota"
- , Imp.mapKernelNumGroups = Imp.VarSize num_groups
- , Imp.mapKernelGroupSize = Imp.VarSize group_size
- , Imp.mapKernelSize = n'
- , Imp.mapKernelUses = body_uses
- , Imp.mapKernelBody = body
- }
-
-expCompiler
- (Pattern _ [pe]) (BasicOp (Replicate (Shape ds) se)) = do
- constants <- simpleKernelConstants (Just $ baseTag $ patElemName pe) "replicate"
+ ImpGen.emit $
+ Imp.Write destmem destidx (IntType et) destspace Imp.Nonvolatile $
+ Imp.ConvOpExp (SExt Int32 et) gtid * s' + x'
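For intuition: each active thread with global id gtid writes x + gtid * s (sign-extended from i32 to the element type) to its slot, so iota with n = 5, x = 2 and s = 3 produces [2, 5, 8, 11, 14].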
+expCompiler (Pattern _ [pe]) (BasicOp (Replicate (Shape ds) se)) = do
t <- subExpType se
- let thread_gid = kernelGlobalThreadId constants
- row_dims = arrayDims t
- dims = ds ++ row_dims
- is' = unflattenIndex (map (ImpGen.compileSubExpOfType int32) dims) $
- ImpGen.varIndex thread_gid
- ds' <- mapM ImpGen.compileSubExp ds
-
- makeAllMemoryGlobal $ do
- body <- ImpGen.subImpM_ (inKernelOperations constants) $
- ImpGen.copyDWIM (patElemName pe) is' se $ drop (length ds) is'
- dims' <- mapM ImpGen.compileSubExp dims
- (group_size, num_groups) <- computeMapKernelGroups $ product dims'
+ dims <- mapM ImpGen.compileSubExp $ ds ++ arrayDims t
+ (constants, set_constants) <-
+ simpleKernelConstants (product dims) "replicate"
- (body_uses, _) <- computeKernelUses
- (freeIn body <> freeIn ds')
- [thread_gid]
+ let is' = unflattenIndex dims $ kernelGlobalThreadId constants
- sOp $ Imp.CallKernel $ Imp.Map Imp.MapKernel
- { Imp.mapKernelThreadNum = thread_gid
- , Imp.mapKernelDesc = "replicate"
- , Imp.mapKernelNumGroups = Imp.VarSize num_groups
- , Imp.mapKernelGroupSize = Imp.VarSize group_size
- , Imp.mapKernelSize = product dims'
- , Imp.mapKernelUses = body_uses
- , Imp.mapKernelBody = body
- }
+ sKernel constants "replicate" $ do
+ set_constants
+ sWhen (kernelThreadActive constants) $
+ ImpGen.copyDWIM (patElemName pe) is' se $ drop (length ds) is'
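For intuition: the flat thread id is unflattened over the concatenation of the replication dimensions and the row dimensions; e.g. when replicating a scalar to shape [2, 3], thread 4 maps to index [1, 1] and copies the value there. Only threads below the product of the dimensions are active.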
-- Allocation in the "local" space is just a placeholder.
expCompiler _ (Op (Alloc _ (Space "local"))) =
@@ -266,48 +196,28 @@ callKernelCopy bt
(n * row_size) `Imp.withElemType` bt
| otherwise = do
- global_thread_index <- newVName "copy_global_thread_index"
-- Note that the shapes of the destination and the source are
-- necessarily the same.
let shape = map Imp.sizeToExp srcshape
shape_se = map (Imp.innerExp . ImpGen.dimSizeToExp) srcshape
- dest_is = unflattenIndex shape_se $ ImpGen.varIndex global_thread_index
- src_is = dest_is
-
- makeAllMemoryGlobal $ do
- (_, destspace, destidx) <- ImpGen.fullyIndexArray' destloc dest_is bt
- (_, srcspace, srcidx) <- ImpGen.fullyIndexArray' srcloc src_is bt
-
- let body = Imp.Write destmem destidx bt destspace Imp.Nonvolatile $
- Imp.index srcmem srcidx bt srcspace Imp.Nonvolatile
-
- let writes_to = [Imp.MemoryUse destmem]
+ kernel_size = Imp.innerExp n * product (drop 1 shape)
- reads_from <- readsFromSet $
- S.singleton srcmem <>
- freeIn destIxFun <> freeIn srcIxFun <> freeIn destshape
+ (constants, set_constants) <- simpleKernelConstants kernel_size "copy"
- let kernel_size = Imp.innerExp n * product (drop 1 shape)
- (group_size, num_groups) <- computeMapKernelGroups kernel_size
+ sKernel constants "copy" $ do
+ set_constants
- let bound_in_kernel = [global_thread_index]
- (body_uses, _) <- computeKernelUses (kernel_size, body) bound_in_kernel
+ let gtid = kernelGlobalThreadId constants
+ dest_is = unflattenIndex shape_se gtid
+ src_is = dest_is
- sOp $ Imp.CallKernel $ Imp.Map Imp.MapKernel
- { Imp.mapKernelThreadNum = global_thread_index
- , Imp.mapKernelDesc = "copy"
- , Imp.mapKernelNumGroups = Imp.VarSize num_groups
- , Imp.mapKernelGroupSize = Imp.VarSize group_size
- , Imp.mapKernelSize = kernel_size
- , Imp.mapKernelUses = nub $ body_uses ++ writes_to ++ reads_from
- , Imp.mapKernelBody = body
- }
+ (_, destspace, destidx) <- ImpGen.fullyIndexArray' destloc dest_is bt
+ (_, srcspace, srcidx) <- ImpGen.fullyIndexArray' srcloc src_is bt
--- | We have no bulk copy operation (e.g. memmove) inside kernels, so
--- turn any copy into a loop.
-inKernelCopy :: ImpGen.CopyCompiler InKernel Imp.KernelOp
-inKernelCopy = ImpGen.copyElementWise
+ sWhen (gtid .<. kernel_size) $ ImpGen.emit $
+ Imp.Write destmem destidx bt destspace Imp.Nonvolatile $
+ Imp.index srcmem srcidx bt srcspace Imp.Nonvolatile
mapTransposeForType :: PrimType -> ImpGen.ImpM ExplicitMemory Imp.HostOp Name
mapTransposeForType bt = do
@@ -421,120 +331,13 @@ mapTransposeFunction bt =
(Imp.Count num_bytes)
callTransposeKernel =
- Imp.Op . Imp.CallKernel . Imp.AnyKernel .
+ Imp.Op . Imp.CallKernel .
mapTransposeKernel (mapTransposeName bt) block_dim_int
(destmem, v32 destoffset, srcmem, v32 srcoffset,
v32 x, v32 y, v32 in_elems, v32 out_elems,
v32 mulx, v32 muly, v32 num_arrays,
block) bt
-
-inKernelExpCompiler :: ImpGen.ExpCompiler InKernel Imp.KernelOp
-inKernelExpCompiler _ (BasicOp (Assert _ _ (loc, locs))) =
- compilerLimitationS $
- unlines [ "Cannot compile assertion at " ++
- intercalate " -> " (reverse $ map locStr $ loc:locs) ++
- " inside parallel kernel."
- , "As a workaround, surround the expression with 'unsafe'."]
--- The static arrays stuff does not work inside kernels.
-inKernelExpCompiler (Pattern _ [dest]) (BasicOp (ArrayLit es _)) =
- forM_ (zip [0..] es) $ \(i,e) ->
- ImpGen.copyDWIM (patElemName dest) [fromIntegral (i::Int32)] e []
-inKernelExpCompiler dest e =
- ImpGen.defCompileExp dest e
-
-computeKernelUses :: FreeIn a =>
- a -> [VName]
- -> CallKernelGen ([Imp.KernelUse], [Imp.LocalMemoryUse])
-computeKernelUses kernel_body bound_in_kernel = do
- let actually_free = freeIn kernel_body `S.difference` S.fromList bound_in_kernel
-
- -- Compute the variables that we need to pass to the kernel.
- reads_from <- readsFromSet actually_free
-
- -- Are we using any local memory?
- local_memory <- computeLocalMemoryUse actually_free
- return (nub reads_from, nub local_memory)
-
-readsFromSet :: Names -> CallKernelGen [Imp.KernelUse]
-readsFromSet free =
- fmap catMaybes $
- forM (S.toList free) $ \var -> do
- t <- lookupType var
- case t of
- Array {} -> return Nothing
- Mem _ (Space "local") -> return Nothing
- Mem _ _ -> return $ Just $ Imp.MemoryUse var
- Prim bt ->
- isConstExp var >>= \case
- Just ce -> return $ Just $ Imp.ConstUse var ce
- Nothing | bt == Cert -> return Nothing
- | otherwise -> return $ Just $ Imp.ScalarUse var bt
-
-computeLocalMemoryUse :: Names -> CallKernelGen [Imp.LocalMemoryUse]
-computeLocalMemoryUse free =
- fmap catMaybes $
- forM (S.toList free) $ \var -> do
- t <- lookupType var
- case t of
- Mem memsize (Space "local") -> do
- memsize' <- localMemSize =<< ImpGen.subExpToDimSize memsize
- return $ Just (var, memsize')
- _ -> return Nothing
-
-localMemSize :: Imp.MemSize -> CallKernelGen (Either Imp.MemSize Imp.KernelConstExp)
-localMemSize (Imp.ConstSize x) =
- return $ Right $ ValueExp $ IntValue $ Int64Value x
-localMemSize (Imp.VarSize v) = isConstExp v >>= \case
- Just e | isStaticExp e -> return $ Right e
- _ -> return $ Left $ Imp.VarSize v
-
--- | Only some constant expressions qualify as *static* expressions,
--- which we can use for static memory allocation. This is a bit of a
--- hack, as it is primarily motivated by what you can put as the size
--- when declaring an array in C.
-isStaticExp :: Imp.KernelConstExp -> Bool
-isStaticExp LeafExp{} = True
-isStaticExp ValueExp{} = True
-isStaticExp (BinOpExp Add{} x y) = isStaticExp x && isStaticExp y
-isStaticExp (BinOpExp Sub{} x y) = isStaticExp x && isStaticExp y
-isStaticExp (BinOpExp Mul{} x y) = isStaticExp x && isStaticExp y
-isStaticExp _ = False
-
-isConstExp :: VName -> CallKernelGen (Maybe Imp.KernelConstExp)
-isConstExp v = do
- vtable <- ImpGen.getVTable
- let lookupConstExp name = constExp =<< hasExp =<< M.lookup name vtable
- constExp (Op (Inner (GetSize key _))) = Just $ LeafExp (Imp.SizeConst key) int32
- constExp e = primExpFromExp lookupConstExp e
- return $ lookupConstExp v
- where hasExp (ImpGen.ArrayVar e _) = e
- hasExp (ImpGen.ScalarVar e _) = e
- hasExp (ImpGen.MemVar e _) = e
-
--- | Change every memory block to be in the global address space,
--- except those that are in the local memory space. This only affects
--- generated code - we still need to make sure that the memory is
--- actually present on the device (and declared as variables in the
--- kernel).
-makeAllMemoryGlobal :: CallKernelGen a -> CallKernelGen a
-makeAllMemoryGlobal =
- local (\env -> env { ImpGen.envDefaultSpace = Imp.Space "global" }) .
- ImpGen.localVTable (M.map globalMemory)
- where globalMemory (ImpGen.MemVar _ entry)
- | ImpGen.entryMemSpace entry /= Space "local" =
- ImpGen.MemVar Nothing entry { ImpGen.entryMemSpace = Imp.Space "global" }
- globalMemory entry =
- entry
-
-computeMapKernelGroups :: Imp.Exp -> CallKernelGen (VName, VName)
-computeMapKernelGroups kernel_size = do
- group_size <- dPrim "group_size" int32
- let group_size_var = Imp.var group_size int32
- sOp $ Imp.GetSize group_size group_size Imp.SizeGroup
- num_groups <- dPrimV "num_groups" $ kernel_size `quotRoundingUp` Imp.ConvOpExp (SExt Int32 Int32) group_size_var
- return (group_size, num_groups)
-
isMapTransposeKernel :: PrimType -> ImpGen.MemLocation -> ImpGen.MemLocation
-> Maybe (Imp.Exp, Imp.Exp,
Imp.Exp, Imp.Exp, Imp.Exp,
@@ -570,142 +373,38 @@ isMapTransposeKernel bt
(pretrans, posttrans) = f $ splitAt r2 notmapped
in (product mapped, product pretrans, product posttrans)
-writeParamToLocalMemory :: Typed (MemBound u) =>
- Imp.Exp -> (VName, t) -> Param (MemBound u)
- -> ImpGen.ImpM lore op ()
-writeParamToLocalMemory i (mem, _) param
- | Prim t <- paramType param =
- ImpGen.emit $
- Imp.Write mem (bytes i') bt (Space "local") Imp.Volatile $
- Imp.var (paramName param) t
- | otherwise =
- return ()
- where i' = i * Imp.LeafExp (Imp.SizeOf bt) int32
- bt = elemType $ paramType param
-
-readParamFromLocalMemory :: Typed (MemBound u) =>
- VName -> Imp.Exp -> Param (MemBound u) -> (VName, t)
- -> ImpGen.ImpM lore op ()
-readParamFromLocalMemory index i param (l_mem, _)
- | Prim _ <- paramType param =
- paramName param <--
- Imp.index l_mem (bytes i') bt (Space "local") Imp.Volatile
- | otherwise = index <-- i
- where i' = i * Imp.LeafExp (Imp.SizeOf bt) int32
- bt = elemType $ paramType param
-
-computeThreadChunkSize :: SplitOrdering
- -> Imp.Exp
- -> Imp.Count Imp.Elements
- -> Imp.Count Imp.Elements
- -> VName
- -> ImpGen.ImpM lore op ()
-computeThreadChunkSize (SplitStrided stride) thread_index elements_per_thread num_elements chunk_var = do
- stride' <- ImpGen.compileSubExp stride
- chunk_var <--
- Imp.BinOpExp (SMin Int32)
- (Imp.innerExp elements_per_thread)
- ((Imp.innerExp num_elements - thread_index) `quotRoundingUp` stride')
-
-computeThreadChunkSize SplitContiguous thread_index elements_per_thread num_elements chunk_var = do
- starting_point <- dPrimV "starting_point" $
- thread_index * Imp.innerExp elements_per_thread
- remaining_elements <- dPrimV "remaining_elements" $
- Imp.innerExp num_elements - Imp.var starting_point int32
-
- let no_remaining_elements = Imp.var remaining_elements int32 .<=. 0
- beyond_bounds = Imp.innerExp num_elements .<=. Imp.var starting_point int32
-
- sIf (no_remaining_elements .||. beyond_bounds)
- (chunk_var <-- 0)
- (sIf is_last_thread
- (chunk_var <-- Imp.innerExp last_thread_elements)
- (chunk_var <-- Imp.innerExp elements_per_thread))
- where last_thread_elements =
- num_elements - Imp.elements thread_index * elements_per_thread
- is_last_thread =
- Imp.innerExp num_elements .<.
- (thread_index + 1) * Imp.innerExp elements_per_thread
-
-inBlockScan :: Imp.Exp
- -> Imp.Exp
- -> Imp.Exp
- -> VName
- -> [(VName, t)]
- -> Lambda InKernel
- -> InKernelGen ()
-inBlockScan lockstep_width block_size active local_id acc_local_mem scan_lam = ImpGen.everythingVolatile $ do
- skip_threads <- dPrim "skip_threads" int32
- let in_block_thread_active =
- Imp.var skip_threads int32 .<=. in_block_id
- (scan_lam_i, other_index_param, actual_params) =
- partitionChunkedKernelLambdaParameters $ lambdaParams scan_lam
- (x_params, y_params) =
- splitAt (length actual_params `div` 2) actual_params
- read_operands =
- zipWithM_ (readParamFromLocalMemory (paramName other_index_param) $
- Imp.var local_id int32 - Imp.var skip_threads int32)
- x_params acc_local_mem
-
- -- Set initial y values
- sWhen active $
- zipWithM_ (readParamFromLocalMemory scan_lam_i $ Imp.var local_id int32)
- y_params acc_local_mem
-
- let op_to_y = ImpGen.compileBody' y_params $ lambdaBody scan_lam
- write_operation_result =
- zipWithM_ (writeParamToLocalMemory $ Imp.var local_id int32)
- acc_local_mem y_params
- maybeBarrier = sWhen (lockstep_width .<=. Imp.var skip_threads int32) $
- sOp Imp.Barrier
-
- sComment "in-block scan (hopefully no barriers needed)" $ do
- skip_threads <-- 1
- sWhile (Imp.var skip_threads int32 .<. block_size) $ do
- sWhen (in_block_thread_active .&&. active) $ do
- sComment "read operands" read_operands
- sComment "perform operation" op_to_y
-
- maybeBarrier
-
- sWhen (in_block_thread_active .&&. active) $
- sComment "write result" write_operation_result
-
- maybeBarrier
-
- skip_threads <-- Imp.var skip_threads int32 * 2
-
- where block_id = Imp.var local_id int32 `quot` block_size
- in_block_id = Imp.var local_id int32 - block_id * block_size
-
-data KernelConstants = KernelConstants
- { kernelGlobalThreadId :: VName
- , kernelLocalThreadId :: VName
- , kernelGroupId :: VName
- , kernelGroupSize :: Imp.DimSize
- , _kernelNumThreads :: Imp.DimSize
- , kernelWaveSize :: Imp.DimSize
- , kernelDimensions :: [(VName, Imp.Exp)]
- , kernelThreadActive :: Imp.Exp
- , kernelStreamed :: [(VName, Imp.DimSize)]
- -- ^ Chunk sizes and their maximum size. Hint
- -- for unrolling.
- }
-
--- FIXME: wing a KernelConstants structure for use in Replicate
--- compilation. This cannot be the best way to do this...
-simpleKernelConstants :: MonadFreshNames m =>
- Maybe Int -> String
- -> m KernelConstants
-simpleKernelConstants tag desc = do
- thread_gtid <- maybe (newVName $ desc ++ "_gtid")
- (return . VName (nameFromString $ desc ++ "_gtid")) tag
+simpleKernelConstants :: Imp.Exp -> String
+ -> CallKernelGen (KernelConstants, ImpGen.ImpM InKernel Imp.KernelOp ())
+simpleKernelConstants kernel_size desc = do
+ thread_gtid <- newVName $ desc ++ "_gtid"
thread_ltid <- newVName $ desc ++ "_ltid"
- thread_gid <- newVName $ desc ++ "_gid"
- return $ KernelConstants
- thread_gtid thread_ltid thread_gid
- (Imp.ConstSize 0) (Imp.ConstSize 0) (Imp.ConstSize 0)
- [] (Imp.ValueExp $ BoolValue True) mempty
+ group_id <- newVName $ desc ++ "_gid"
+ (group_size, num_groups) <- computeMapKernelGroups kernel_size
+ let set_constants = do
+ dPrim_ thread_gtid int32
+ dPrim_ thread_ltid int32
+ dPrim_ group_id int32
+ sOp (Imp.GetGlobalId thread_gtid 0)
+ sOp (Imp.GetLocalId thread_ltid 0)
+ sOp (Imp.GetGroupId group_id 0)
+
+ return (KernelConstants
+ (Imp.var thread_gtid int32) (Imp.var thread_ltid int32) (Imp.var group_id int32)
+ thread_gtid thread_ltid group_id
+ group_size num_groups (group_size*num_groups) 0
+ [] (Imp.var thread_gtid int32 .<. kernel_size) mempty,
+
+ set_constants)
+
+computeMapKernelGroups :: Imp.Exp -> CallKernelGen (Imp.Exp, Imp.Exp)
+computeMapKernelGroups kernel_size = do
+ group_size <- dPrim "group_size" int32
+ fname <- asks ImpGen.envFunction
+ let group_size_var = Imp.var group_size int32
+ group_size_key = keyWithEntryPoint fname $ nameFromString $ pretty group_size
+ sOp $ Imp.GetSize group_size group_size_key Imp.SizeGroup
+ num_groups <- dPrimV "num_groups" $ kernel_size `quotRoundingUp` Imp.ConvOpExp (SExt Int32 Int32) group_size_var
+ return (Imp.var group_size int32, Imp.var num_groups int32)
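For intuition: quotRoundingUp is a ceiling division, so with a kernel_size of 1000 and a runtime-selected group_size of 256 this yields num_groups = (1000 + 255) quot 256 = 4; the group size itself is requested through GetSize under a key qualified with the current entry point.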
compileKernelBody :: Pattern InKernel
-> KernelConstants
@@ -716,510 +415,6 @@ compileKernelBody pat constants kbody =
zipWithM_ (compileKernelResult constants) (patternElements pat) $
kernelBodyResult kbody
-compileKernelStms :: KernelConstants -> [Stm InKernel]
- -> InKernelGen a
- -> InKernelGen a
-compileKernelStms constants ungrouped_bnds m =
- compileGroupedKernelStms' $ groupStmsByGuard constants ungrouped_bnds
- where compileGroupedKernelStms' [] = m
- compileGroupedKernelStms' ((g, bnds):rest_bnds) = do
- ImpGen.dScopes (map ((Just . stmExp) &&& (castScope . scopeOf)) bnds)
- protect g $ mapM_ compileKernelStm bnds
- compileGroupedKernelStms' rest_bnds
-
- protect Nothing body_m =
- body_m
- protect (Just (Imp.ValueExp (BoolValue True))) body_m =
- body_m
- protect (Just g) body_m =
- sWhen g $ allThreads constants body_m
-
- compileKernelStm (Let pat _ e) = ImpGen.compileExp pat e
-
-groupStmsByGuard :: KernelConstants
- -> [Stm InKernel]
- -> [(Maybe Imp.Exp, [Stm InKernel])]
-groupStmsByGuard constants bnds =
- map collapse $ groupBy sameGuard $ zip (map bindingGuard bnds) bnds
- where bindingGuard (Let _ _ Op{}) = Nothing
- bindingGuard _ = Just $ kernelThreadActive constants
-
- sameGuard (g1, _) (g2, _) = g1 == g2
-
- collapse [] =
- (Nothing, [])
- collapse l@((g,_):_) =
- (g, map snd l)
-
-compileKernelExp :: KernelConstants -> Pattern InKernel -> KernelExp InKernel
- -> InKernelGen ()
-
-compileKernelExp _ pat (Barrier ses) = do
- forM_ (zip (patternNames pat) ses) $ \(d, se) ->
- ImpGen.copyDWIM d [] se []
- sOp Imp.Barrier
-
-compileKernelExp _ (Pattern [] [size]) (SplitSpace o w i elems_per_thread) = do
- num_elements <- Imp.elements <$> ImpGen.compileSubExp w
- i' <- ImpGen.compileSubExp i
- elems_per_thread' <- Imp.elements <$> ImpGen.compileSubExp elems_per_thread
- computeThreadChunkSize o i' elems_per_thread' num_elements (patElemName size)
-
-compileKernelExp constants pat (Combine (CombineSpace scatter cspace) _ aspace body) = do
- -- First we compute how many times we have to iterate to cover
- -- cspace with our group size. It is a fairly common case that
- -- we statically know that this requires 1 iteration, so we
- -- could detect it and not generate a loop in that case.
- -- However, it seems to have no impact on performance (an extra
- -- conditional jump), so for simplicity we just always generate
- -- the loop.
- let cspace_dims = map (streamBounded . snd) cspace
- num_iters
- | cspace_dims == [Imp.sizeToExp $ kernelGroupSize constants] = 1
- | otherwise = product cspace_dims `quotRoundingUp`
- Imp.sizeToExp (kernelGroupSize constants)
-
- iter <- newVName "comb_iter"
-
- sFor iter Int32 num_iters $ do
- mapM_ ((`dPrim_` int32) . fst) cspace
- -- Compute the *flat* array index.
- cid <- dPrimV "flat_comb_id" $
- Imp.var iter int32 * Imp.sizeToExp (kernelGroupSize constants) +
- Imp.var (kernelLocalThreadId constants) int32
-
- -- Turn it into a nested array index.
- zipWithM_ (<--) (map fst cspace) $ unflattenIndex cspace_dims (Imp.var cid int32)
-
- -- Construct the body. This is mostly about the book-keeping
- -- for the scatter-like part.
- let (scatter_ws, scatter_ns, _scatter_vs) = unzip3 scatter
- scatter_ws_repl = concat $ zipWith replicate scatter_ns scatter_ws
- (scatter_pes, normal_pes) =
- splitAt (sum scatter_ns) $ patternElements pat
- (res_is, res_vs, res_normal) =
- splitAt3 (sum scatter_ns) (sum scatter_ns) $ bodyResult body
-
- -- Execute the body if we are within bounds.
- sWhen (isActive cspace .&&. isActive aspace) $ allThreads constants $
- ImpGen.compileStms (freeIn $ bodyResult body) (stmsToList $ bodyStms body) $ do
-
- forM_ (zip4 scatter_ws_repl res_is res_vs scatter_pes) $
- \(w, res_i, res_v, scatter_pe) -> do
- let res_i' = ImpGen.compileSubExpOfType int32 res_i
- w' = ImpGen.compileSubExpOfType int32 w
- -- We have to check that 'res_i' is in-bounds wrt. an array of size 'w'.
- in_bounds = 0 .<=. res_i' .&&. res_i' .<. w'
- sWhen in_bounds $ ImpGen.copyDWIM (patElemName scatter_pe) [res_i'] res_v []
-
- forM_ (zip normal_pes res_normal) $ \(pe, res) ->
- ImpGen.copyDWIM (patElemName pe) local_index res []
-
- sOp Imp.Barrier
-
- where streamBounded (Var v)
- | Just x <- lookup v $ kernelStreamed constants =
- Imp.sizeToExp x
- streamBounded se = ImpGen.compileSubExpOfType int32 se
-
- local_index = map (ImpGen.compileSubExpOfType int32 . Var . fst) cspace
-
-compileKernelExp constants (Pattern _ dests) (GroupReduce w lam input) = do
- groupReduce constants w lam $ map snd input
- let (reduce_acc_params, _) =
- splitAt (length input) $ drop 2 $ lambdaParams lam
- forM_ (zip dests reduce_acc_params) $ \(dest, reduce_acc_param) ->
- ImpGen.copyDWIM (patElemName dest) [] (Var $ paramName reduce_acc_param) []
-
-compileKernelExp constants _ (GroupScan w lam input) = do
- renamed_lam <- renameLambda lam
- w' <- ImpGen.compileSubExp w
-
- when (any (not . primType . paramType) $ lambdaParams lam) $
- compilerLimitationS "Cannot compile parallel scans with array element type."
-
- let local_tid = kernelLocalThreadId constants
- (_nes, arrs) = unzip input
- (lam_i, other_index_param, actual_params) =
- partitionChunkedKernelLambdaParameters $ lambdaParams lam
- (x_params, y_params) =
- splitAt (length input) actual_params
-
- ImpGen.dLParams (lambdaParams lam++lambdaParams renamed_lam)
- lam_i <-- Imp.var local_tid int32
-
- acc_local_mem <- flip zip (repeat ()) <$>
- mapM (fmap (ImpGen.memLocationName . ImpGen.entryArrayLocation) .
- ImpGen.lookupArray) arrs
-
- -- The scan works by splitting the group into blocks, which are
- -- scanned separately. Typically, these blocks are smaller than
- -- the lockstep width, which enables barrier-free execution inside
- -- them.
- --
- -- We hardcode the block size here. The only requirement is that
- -- it should not be less than the square root of the group size.
- -- With 32, we will work on groups of size 1024 or smaller, which
- -- fits every device Troels has seen. Still, it would be nicer if
- -- it were a runtime parameter. Some day.
- let block_size = Imp.ValueExp $ IntValue $ Int32Value 32
- simd_width = Imp.sizeToExp $ kernelWaveSize constants
- block_id = Imp.var local_tid int32 `quot` block_size
- in_block_id = Imp.var local_tid int32 - block_id * block_size
- doInBlockScan active = inBlockScan simd_width block_size active local_tid acc_local_mem
- lid_in_bounds = Imp.var local_tid int32 .<. w'
-
- doInBlockScan lid_in_bounds lam
- sOp Imp.Barrier
-
- let last_in_block = in_block_id .==. block_size - 1
- sComment "last thread of block 'i' writes its result to offset 'i'" $
- sWhen (last_in_block .&&. lid_in_bounds) $
- zipWithM_ (writeParamToLocalMemory block_id) acc_local_mem y_params
-
- sOp Imp.Barrier
-
- let is_first_block = block_id .==. 0
- ImpGen.comment
- "scan the first block, after which offset 'i' contains carry-in for warp 'i+1'" $
- doInBlockScan (is_first_block .&&. lid_in_bounds) renamed_lam
-
- sOp Imp.Barrier
-
- let read_carry_in =
- zipWithM_ (readParamFromLocalMemory
- (paramName other_index_param) (block_id - 1))
- x_params acc_local_mem
-
- let op_to_y =
- ImpGen.compileBody' y_params $ lambdaBody lam
- write_final_result =
- zipWithM_ (writeParamToLocalMemory $ Imp.var local_tid int32) acc_local_mem y_params
-
- sComment "carry-in for every block except the first" $
- sUnless (is_first_block .||. Imp.UnOpExp Not lid_in_bounds) $ do
- sComment "read operands" read_carry_in
- sComment "perform operation" op_to_y
- sComment "write final result" write_final_result
-
- sOp Imp.Barrier
-
- sComment "restore correct values for first block" $
- sWhen is_first_block write_final_result
-
-compileKernelExp constants (Pattern _ final) (GroupStream w maxchunk lam accs _arrs) = do
- let GroupStreamLambda block_size block_offset acc_params arr_params body = lam
- block_offset' = Imp.var block_offset int32
- w' <- ImpGen.compileSubExp w
- max_block_size <- ImpGen.compileSubExp maxchunk
-
- ImpGen.dLParams (acc_params++arr_params)
- zipWithM_ ImpGen.compileSubExpTo (map paramName acc_params) accs
- dPrim_ block_size int32
-
- -- If the GroupStream is morally just a do-loop, generate simpler code.
- case mapM isSimpleThreadInSpace $ stmsToList $ bodyStms body of
- Just stms' | ValueExp x <- max_block_size, oneIsh x -> do
- let body' = body { bodyStms = stmsFromList stms' }
- body'' = allThreads constants $
- ImpGen.compileLoopBody (map paramName acc_params) body'
- block_size <-- 1
-
- -- Check if loop is candidate for unrolling.
- let loop =
- case w of
- Var w_var | Just w_bound <- lookup w_var $ kernelStreamed constants,
- w_bound /= Imp.ConstSize 1 ->
- -- Candidate for unrolling, so generate two loops.
- sIf (w' .==. Imp.sizeToExp w_bound)
- (sFor block_offset Int32 (Imp.sizeToExp w_bound) body'')
- (sFor block_offset Int32 w' body'')
- _ -> sFor block_offset Int32 w' body''
-
- if kernelThreadActive constants == Imp.ValueExp (BoolValue True)
- then loop
- else sWhen (kernelThreadActive constants) loop
-
- _ -> do
- dPrim_ block_offset int32
- let body' = streaming constants block_size maxchunk $
- ImpGen.compileBody' acc_params body
-
- block_offset <-- 0
-
- let not_at_end = block_offset' .<. w'
- set_block_size =
- sIf (w' - block_offset' .<. max_block_size)
- (block_size <-- (w' - block_offset'))
- (block_size <-- max_block_size)
- increase_offset =
- block_offset <-- block_offset' + max_block_size
-
- -- Three cases to consider for simpler generated code based
- -- on max block size: (0) if full input size, do not
- -- generate a loop; (1) if one, generate for-loop (2)
- -- otherwise, generate chunked while-loop.
- if max_block_size == w' then
- (block_size <-- w') >> body'
- else if max_block_size == Imp.ValueExp (value (1::Int32)) then do
- block_size <-- w'
- sFor block_offset Int32 w' body'
- else
- sWhile not_at_end $
- set_block_size >> body' >> increase_offset
-
- forM_ (zip final acc_params) $ \(pe, p) ->
- ImpGen.copyDWIM (patElemName pe) [] (Var $ paramName p) []
-
- where isSimpleThreadInSpace (Let _ _ Op{}) = Nothing
- isSimpleThreadInSpace bnd = Just bnd
-
-compileKernelExp _ _ (GroupGenReduce w arrs op bucket values locks) = do
- -- Check if bucket is in-bounds
- bucket' <- mapM ImpGen.compileSubExp bucket
- w' <- mapM ImpGen.compileSubExp w
- sWhen (indexInBounds bucket' w') $
- atomicUpdate arrs bucket op values locking
- where indexInBounds inds bounds =
- foldl1 (.&&.) $ zipWith checkBound inds bounds
- where checkBound ind bound = 0 .<=. ind .&&. ind .<. bound
- locking = Locking locks 0 1 0
-
-compileKernelExp _ dest e =
- compilerBugS $ unlines ["Invalid target", " " ++ show dest,
- "for kernel expression", " " ++ pretty e]
-
--- | Locking strategy used for an atomic update.
-data Locking = Locking { lockingArray :: VName -- ^ Array containing the lock.
- , lockingIsUnlocked :: Imp.Exp -- ^ Value for us to consider the lock free.
- , lockingToLock :: Imp.Exp -- ^ What to write when we lock it.
- , lockingToUnlock :: Imp.Exp -- ^ What to write when we unlock it.
- }
-
-groupReduce :: ExplicitMemorish lore =>
- KernelConstants
- -> SubExp
- -> Lambda lore
- -> [VName]
- -> ImpGen.ImpM lore Imp.KernelOp ()
-groupReduce constants w lam arrs = do
- w' <- ImpGen.compileSubExp w
-
- let local_tid = kernelLocalThreadId constants
- (reduce_i, reduce_j_param, actual_reduce_params) =
- partitionChunkedKernelLambdaParameters $ lambdaParams lam
- (reduce_acc_params, reduce_arr_params) =
- splitAt (length arrs) actual_reduce_params
- reduce_j = paramName reduce_j_param
-
- offset <- dPrim "offset" int32
-
- skip_waves <- dPrim "skip_waves" int32
- ImpGen.dLParams $ lambdaParams lam
-
- reduce_i <-- Imp.var local_tid int32
-
- let setOffset x = do
- offset <-- x
- reduce_j <-- Imp.var local_tid int32 + Imp.var offset int32
-
- setOffset 0
-
- sWhen (Imp.var local_tid int32 .<. w') $
- zipWithM_ (readReduceArgument offset) reduce_acc_params arrs
-
- let read_reduce_args = zipWithM_ (readReduceArgument offset)
- reduce_arr_params arrs
- do_reduce = do ImpGen.comment "read array element" read_reduce_args
- ImpGen.compileBody' reduce_acc_params $ lambdaBody lam
- zipWithM_ (writeReduceOpResult local_tid)
- reduce_acc_params arrs
- in_wave_reduce = ImpGen.everythingVolatile do_reduce
-
- wave_size = Imp.sizeToExp $ kernelWaveSize constants
- group_size = Imp.sizeToExp $ kernelGroupSize constants
- wave_id = Imp.var local_tid int32 `quot` wave_size
- in_wave_id = Imp.var local_tid int32 - wave_id * wave_size
- num_waves = (group_size + wave_size - 1) `quot` wave_size
- arg_in_bounds = Imp.var reduce_j int32 .<. w'
-
- doing_in_wave_reductions =
- Imp.var offset int32 .<. wave_size
- apply_in_in_wave_iteration =
- (in_wave_id .&. (2 * Imp.var offset int32 - 1)) .==. 0
- in_wave_reductions = do
- setOffset 1
- sWhile doing_in_wave_reductions $ do
- sWhen (arg_in_bounds .&&. apply_in_in_wave_iteration)
- in_wave_reduce
- setOffset $ Imp.var offset int32 * 2
-
- doing_cross_wave_reductions =
- Imp.var skip_waves int32 .<. num_waves
- is_first_thread_in_wave =
- in_wave_id .==. 0
- wave_not_skipped =
- (wave_id .&. (2 * Imp.var skip_waves int32 - 1)) .==. 0
- apply_in_cross_wave_iteration =
- arg_in_bounds .&&. is_first_thread_in_wave .&&. wave_not_skipped
- cross_wave_reductions = do
- skip_waves <-- 1
- sWhile doing_cross_wave_reductions $ do
- sOp Imp.Barrier
- setOffset (Imp.var skip_waves int32 * wave_size)
- sWhen apply_in_cross_wave_iteration
- do_reduce
- skip_waves <-- Imp.var skip_waves int32 * 2
-
- in_wave_reductions
- cross_wave_reductions
- where readReduceArgument offset param arr
- | Prim _ <- paramType param =
- ImpGen.copyDWIM (paramName param) [] (Var arr) [i]
- | otherwise =
- return ()
- where i = ImpGen.varIndex (kernelLocalThreadId constants) + ImpGen.varIndex offset
-
- writeReduceOpResult i param arr
- | Prim _ <- paramType param =
- ImpGen.copyDWIM arr [ImpGen.varIndex i] (Var $ paramName param) []
- | otherwise =
- return ()
-
-atomicUpdate :: ExplicitMemorish lore =>
- [VName] -> [SubExp] -> Lambda lore -> [SubExp] -> Locking
- -> ImpGen.ImpM lore Imp.KernelOp ()
-atomicUpdate [a] bucket op [v] _
- | [Prim t] <- lambdaReturnType op,
- primBitSize t == 32 = do
- -- If we have only one array and one non-array value (this is a
- -- one-to-one correspondence) then we need only one
- -- update. If the operator has an atomic implementation we use
- -- that, otherwise it is still a binary operator which can
- -- be implemented by atomic compare-and-swap if 32 bits.
-
- -- Common variables.
- old <- dPrim "old" t
- bucket' <- mapM ImpGen.compileSubExp bucket
-
- (arr', _a_space, bucket_offset) <- ImpGen.fullyIndexArray a bucket'
-
- val' <- ImpGen.compileSubExp v
- case opHasAtomicSupport old arr' bucket_offset op of
- Just f -> sOp $ f val'
-
- Nothing -> do
- -- Code generation target:
- --
- -- old = d_his[idx];
- -- do {
- -- assumed = old;
- -- tmp = OP::apply(val, assumed);
- -- old = atomicCAS(&d_his[idx], assumed, tmp);
- -- } while(assumed != old);
- assumed <- dPrim "assumed" t
- run_loop <- dPrimV "run_loop" true
- ImpGen.copyDWIM old [] (Var a) bucket'
-
- -- Preparing parameters
- let (acc_p:arr_p:_) = lambdaParams op
-
- -- Critical section
- ImpGen.dLParams $ lambdaParams op
-
- -- While-loop: Try to insert your value
- let (toBits, fromBits) =
- case t of FloatType Float32 -> (\x -> Imp.FunExp "to_bits32" [x] int32,
- \x -> Imp.FunExp "from_bits32" [x] t)
- _ -> (id, id)
- sWhile (Imp.var run_loop Bool) $ do
- assumed <-- Imp.var old t
- paramName acc_p <-- val'
- paramName arr_p <-- Imp.var assumed t
- ImpGen.compileBody' [acc_p] $ lambdaBody op
- old_bits <- dPrim "old_bits" int32
- sOp $ Imp.Atomic $
- Imp.AtomicCmpXchg old_bits arr' bucket_offset
- (toBits (Imp.var assumed int32)) (toBits (Imp.var (paramName acc_p) int32))
- old <-- fromBits (Imp.var old_bits int32)
- sWhen (toBits (Imp.var assumed t) .==. Imp.var old_bits int32)
- (run_loop <-- false)
- where opHasAtomicSupport old arr' bucket' lam = do
- let atomic f = Imp.Atomic . f old arr' bucket'
- [BasicOp (BinOp bop _ _)] <-
- Just $ map stmExp $ stmsToList $ bodyStms $ lambdaBody lam
- atomic <$> Imp.atomicBinOp bop
-
-atomicUpdate arrs bucket op values locking = do
- old <- dPrim "old" int32
- loop_done <- dPrimV "loop_done" 0
-
- -- Check if bucket is in-bounds
- bucket' <- mapM ImpGen.compileSubExp bucket
-
- -- Correctly index into locks.
- (locks', _locks_space, locks_offset) <-
- ImpGen.fullyIndexArray (lockingArray locking) bucket'
-
- -- Preparing parameters
- let (acc_params, arr_params) =
- splitAt (length values) $ lambdaParams op
-
- -- Critical section
- let try_acquire_lock =
- sOp $ Imp.Atomic $
- Imp.AtomicCmpXchg old locks' locks_offset (lockingIsUnlocked locking) (lockingToLock locking)
- lock_acquired = Imp.var old int32 .==. lockingIsUnlocked locking
- loop_cond = Imp.var loop_done int32 .==. 0
- release_lock = ImpGen.everythingVolatile $
- ImpGen.sWrite (lockingArray locking) bucket' $ lockingToUnlock locking
- break_loop = loop_done <-- 1
-
- -- We copy the current value and the new value to the parameters
- -- unless they are array-typed. If they are arrays, then the
- -- index functions should already be set up correctly, so there is
- -- nothing more to do.
- let bind_acc_params =
- forM_ (zip acc_params arrs) $ \(acc_p, arr) ->
- when (primType (paramType acc_p)) $
- ImpGen.copyDWIM (paramName acc_p) [] (Var arr) bucket'
-
- let bind_arr_params =
- forM_ (zip arr_params values) $ \(arr_p, val) ->
- when (primType (paramType arr_p)) $
- ImpGen.copyDWIM (paramName arr_p) [] val []
-
- let op_body = ImpGen.compileBody' acc_params $ lambdaBody op
-
- do_gen_reduce = zipWithM_ (writeArray bucket') arrs $ map (Var . paramName) acc_params
-
- -- While-loop: Try to insert your value
- sWhile loop_cond $ do
- try_acquire_lock
- sWhen lock_acquired $ do
- ImpGen.dLParams $ lambdaParams op
- bind_acc_params
- bind_arr_params
- op_body
- do_gen_reduce
- release_lock
- break_loop
- sOp Imp.MemFence
- where writeArray bucket' arr val =
- ImpGen.copyDWIM arr bucket' val []
-
-allThreads :: KernelConstants -> InKernelGen () -> InKernelGen ()
-allThreads constants = ImpGen.emit <=< ImpGen.subImpM_ (inKernelOperations constants')
- where constants' =
- constants { kernelThreadActive = Imp.ValueExp (BoolValue True) }
-
-streaming :: KernelConstants -> VName -> SubExp -> InKernelGen () -> InKernelGen ()
-streaming constants chunksize bound m = do
- bound' <- ImpGen.subExpToDimSize bound
- let constants' =
- constants { kernelStreamed = (chunksize, bound') : kernelStreamed constants }
- ImpGen.emit =<< ImpGen.subImpM_ (inKernelOperations constants') m
-
compileKernelResult :: KernelConstants -> PatElem InKernel -> KernelResult
-> InKernelGen ()
@@ -1227,12 +422,12 @@ compileKernelResult constants pe (ThreadsReturn OneResultPerGroup what) = do
i <- newVName "i"
in_local_memory <- arrayInLocalMemory what
- let me = Imp.var (kernelLocalThreadId constants) int32
+ let me = kernelLocalThreadId constants
if not in_local_memory then do
who' <- ImpGen.compileSubExp $ intConst Int32 0
sWhen (me .==. who') $
- ImpGen.copyDWIM (patElemName pe) [ImpGen.varIndex $ kernelGroupId constants] what []
+ ImpGen.copyDWIM (patElemName pe) [kernelGroupId constants] what []
else do
-- If the result of the group is an array in local memory, we
-- store it by collective copying among all the threads of the
@@ -1245,20 +440,20 @@ compileKernelResult constants pe (ThreadsReturn OneResultPerGroup what) = do
-- Compute how many elements this thread is responsible for.
-- Formula: (w - ltid) / group_size (rounded up).
let w = product ws
- ltid = ImpGen.varIndex (kernelLocalThreadId constants)
- group_size = Imp.sizeToExp (kernelGroupSize constants)
+ ltid = kernelLocalThreadId constants
+ group_size = kernelGroupSize constants
to_write = (w - ltid) `quotRoundingUp` group_size
is = unflattenIndex ws $ ImpGen.varIndex i * group_size + ltid
sFor i Int32 to_write $
- ImpGen.copyDWIM (patElemName pe) (ImpGen.varIndex (kernelGroupId constants) : is) what is
+ ImpGen.copyDWIM (patElemName pe) (kernelGroupId constants : is) what is
compileKernelResult constants pe (ThreadsReturn AllThreads what) =
- ImpGen.copyDWIM (patElemName pe) [ImpGen.varIndex $ kernelGlobalThreadId constants] what []
+ ImpGen.copyDWIM (patElemName pe) [kernelGlobalThreadId constants] what []
compileKernelResult constants pe (ThreadsReturn (ThreadsPerGroup limit) what) =
sWhen (isActive limit) $
- ImpGen.copyDWIM (patElemName pe) [ImpGen.varIndex $ kernelGroupId constants] what []
+ ImpGen.copyDWIM (patElemName pe) [kernelGroupId constants] what []
compileKernelResult constants pe (ThreadsReturn ThreadsInSpace what) = do
let is = map (ImpGen.varIndex . fst) $ kernelDimensions constants
@@ -1271,7 +466,7 @@ compileKernelResult constants pe (ConcatReturns SplitContiguous _ per_thread_ele
ImpGen.copyDWIMDest dest' [] (Var what) []
where offset = case moffset of
Nothing -> ImpGen.compileSubExpOfType int32 per_thread_elems *
- ImpGen.varIndex (kernelGlobalThreadId constants)
+ kernelGlobalThreadId constants
Just se -> ImpGen.compileSubExpOfType int32 se
compileKernelResult constants pe (ConcatReturns (SplitStrided stride) _ _ moffset what) = do
@@ -1282,7 +477,7 @@ compileKernelResult constants pe (ConcatReturns (SplitStrided stride) _ _ moffse
dest' = ImpGen.arrayDestination dest_loc'
ImpGen.copyDWIMDest dest' [] (Var what) []
where offset = case moffset of
- Nothing -> ImpGen.varIndex (kernelGlobalThreadId constants)
+ Nothing -> kernelGlobalThreadId constants
Just se -> ImpGen.compileSubExpOfType int32 se
compileKernelResult constants pe (WriteReturn rws _arr dests) = do
@@ -1298,46 +493,6 @@ compileKernelResult _ _ KernelInPlaceReturn{} =
-- Already in its place... said it was a hack.
return ()
-isActive :: [(VName, SubExp)] -> Imp.Exp
-isActive limit = case actives of
- [] -> Imp.ValueExp $ BoolValue True
- x:xs -> foldl (.&&.) x xs
- where (is, ws) = unzip limit
- actives = zipWith active is $ map (ImpGen.compileSubExpOfType Bool) ws
- active i = (Imp.var i int32 .<.)
-
-setSpaceIndices :: KernelSpace -> InKernelGen ()
-setSpaceIndices space =
- case spaceStructure space of
- FlatThreadSpace is_and_dims ->
- flatSpaceWith gtid is_and_dims
- NestedThreadSpace is_and_dims -> do
- let (gtids, gdims, ltids, ldims) = unzip4 is_and_dims
- gdims' <- mapM ImpGen.compileSubExp gdims
- ldims' <- mapM ImpGen.compileSubExp ldims
- let (gtid_es, ltid_es) = unzip $ unflattenNestedIndex gdims' ldims' gtid
- zipWithM_ (<--) gtids gtid_es
- zipWithM_ (<--) ltids ltid_es
- where gtid = Imp.var (spaceGlobalId space) int32
-
- flatSpaceWith base is_and_dims = do
- let (is, dims) = unzip is_and_dims
- dims' <- mapM ImpGen.compileSubExp dims
- let index_expressions = unflattenIndex dims' base
- zipWithM_ (<--) is index_expressions
-
-unflattenNestedIndex :: IntegralExp num => [num] -> [num] -> num -> [(num,num)]
-unflattenNestedIndex global_dims group_dims global_id =
- zip global_is local_is
- where num_groups_dims = zipWith quotRoundingUp global_dims group_dims
- group_size = product group_dims
- group_id = global_id `Futhark.Util.IntegralExp.quot` group_size
- local_id = global_id `Futhark.Util.IntegralExp.rem` group_size
-
- group_is = unflattenIndex num_groups_dims group_id
- local_is = unflattenIndex group_dims local_id
- global_is = zipWith (+) local_is $ zipWith (*) group_is group_dims
-
arrayInLocalMemory :: SubExp -> InKernelGen Bool
arrayInLocalMemory (Var name) = do
res <- ImpGen.lookupVar name
diff --git a/src/Futhark/CodeGen/ImpGen/Kernels/Base.hs b/src/Futhark/CodeGen/ImpGen/Kernels/Base.hs
new file mode 100644
index 0000000..29dae98
--- /dev/null
+++ b/src/Futhark/CodeGen/ImpGen/Kernels/Base.hs
@@ -0,0 +1,960 @@
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE TypeFamilies #-}
+module Futhark.CodeGen.ImpGen.Kernels.Base
+ ( KernelConstants (..)
+ , inKernelOperations
+ , computeKernelUses
+ , keyWithEntryPoint
+ , CallKernelGen
+ , InKernelGen
+ , computeThreadChunkSize
+ , kernelInitialisation
+ , kernelInitialisationSetSpace
+ , setSpaceIndices
+ , makeAllMemoryGlobal
+ , allThreads
+ , compileKernelStms
+ , groupReduce
+ , groupScan
+ , isActive
+ , sKernel
+ )
+ where
+
+import Control.Arrow ((&&&))
+import Control.Monad.Except
+import Control.Monad.Reader
+import Data.Maybe
+import qualified Data.Map.Strict as M
+import qualified Data.Set as S
+import Data.List
+
+import Prelude hiding (quot)
+
+import Futhark.Error
+import Futhark.MonadFreshNames
+import Futhark.Transform.Rename
+import Futhark.Representation.ExplicitMemory
+import qualified Futhark.CodeGen.ImpCode.Kernels as Imp
+import Futhark.CodeGen.ImpCode.Kernels (bytes)
+import qualified Futhark.CodeGen.ImpGen as ImpGen
+import Futhark.CodeGen.ImpGen ((<--),
+ sFor, sWhile, sComment, sIf, sWhen, sUnless,
+ sOp,
+ dPrim, dPrim_, dPrimV)
+import Futhark.Tools (partitionChunkedKernelLambdaParameters)
+import Futhark.Util.IntegralExp (quotRoundingUp, quot, rem, IntegralExp)
+import Futhark.Util (splitAt3, maybeNth)
+
+type CallKernelGen = ImpGen.ImpM ExplicitMemory Imp.HostOp
+type InKernelGen = ImpGen.ImpM InKernel Imp.KernelOp
+
+data KernelConstants = KernelConstants
+ { kernelGlobalThreadId :: Imp.Exp
+ , kernelLocalThreadId :: Imp.Exp
+ , kernelGroupId :: Imp.Exp
+ , kernelGlobalThreadIdVar :: VName
+ , kernelLocalThreadIdVar :: VName
+ , kernelGroupIdVar :: VName
+ , kernelGroupSize :: Imp.Exp
+ , kernelNumGroups :: Imp.Exp
+ , kernelNumThreads :: Imp.Exp
+ , kernelWaveSize :: Imp.Exp
+ , kernelDimensions :: [(VName, Imp.Exp)]
+ , kernelThreadActive :: Imp.Exp
+ , kernelStreamed :: [(VName, Imp.DimSize)]
+ -- ^ Chunk sizes and their maximum size. Hint
+ -- for unrolling.
+ }
+
+inKernelOperations :: KernelConstants -> ImpGen.Operations InKernel Imp.KernelOp
+inKernelOperations constants = (ImpGen.defaultOperations $ compileInKernelOp constants)
+ { ImpGen.opsCopyCompiler = inKernelCopy
+ , ImpGen.opsExpCompiler = inKernelExpCompiler
+ , ImpGen.opsStmsCompiler = \_ -> compileKernelStms constants
+ }
+
+keyWithEntryPoint :: Name -> Name -> Name
+keyWithEntryPoint fname key =
+ nameFromString $ nameToString fname ++ "." ++ nameToString key
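For example (illustrative names): with entry point main and key group_size_1234, keyWithEntryPoint yields the name main.group_size_1234, so tuning parameters are scoped per entry point.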
+
+-- | We have no bulk copy operation (e.g. memmove) inside kernels, so
+-- turn any copy into a loop.
+inKernelCopy :: ImpGen.CopyCompiler InKernel Imp.KernelOp
+inKernelCopy = ImpGen.copyElementWise
+
+compileInKernelOp :: KernelConstants -> Pattern InKernel -> Op InKernel
+ -> InKernelGen ()
+compileInKernelOp _ (Pattern _ [mem]) Alloc{} =
+ compilerLimitationS $ "Cannot allocate memory block " ++ pretty mem ++ " in kernel."
+compileInKernelOp _ dest Alloc{} =
+ compilerBugS $ "Invalid target for in-kernel allocation: " ++ show dest
+compileInKernelOp constants pat (Inner op) =
+ compileKernelExp constants pat op
+
+inKernelExpCompiler :: ImpGen.ExpCompiler InKernel Imp.KernelOp
+inKernelExpCompiler _ (BasicOp (Assert _ _ (loc, locs))) =
+ compilerLimitationS $
+ unlines [ "Cannot compile assertion at " ++
+ intercalate " -> " (reverse $ map locStr $ loc:locs) ++
+ " inside parallel kernel."
+ , "As a workaround, surround the expression with 'unsafe'."]
+-- The static arrays stuff does not work inside kernels.
+inKernelExpCompiler (Pattern _ [dest]) (BasicOp (ArrayLit es _)) =
+ forM_ (zip [0..] es) $ \(i,e) ->
+ ImpGen.copyDWIM (patElemName dest) [fromIntegral (i::Int32)] e []
+inKernelExpCompiler dest e =
+ ImpGen.defCompileExp dest e
+
+compileKernelExp :: KernelConstants -> Pattern InKernel -> KernelExp InKernel
+ -> InKernelGen ()
+
+compileKernelExp _ pat (Barrier ses) = do
+ forM_ (zip (patternNames pat) ses) $ \(d, se) ->
+ ImpGen.copyDWIM d [] se []
+ sOp Imp.LocalBarrier
+
+compileKernelExp _ (Pattern [] [size]) (SplitSpace o w i elems_per_thread) = do
+ num_elements <- Imp.elements <$> ImpGen.compileSubExp w
+ i' <- ImpGen.compileSubExp i
+ elems_per_thread' <- Imp.elements <$> ImpGen.compileSubExp elems_per_thread
+ computeThreadChunkSize o i' elems_per_thread' num_elements (patElemName size)
+
+compileKernelExp constants pat (Combine (CombineSpace scatter cspace) _ aspace body) = do
+ -- First we compute how many times we have to iterate to cover
+ -- cspace with our group size. It is a fairly common case that
+ -- we statically know that this requires 1 iteration, so we
+ -- could detect it and not generate a loop in that case.
+ -- However, it seems to have no impact on performance (an extra
+ -- conditional jump), so for simplicity we just always generate
+ -- the loop.
+ let cspace_dims = map (streamBounded . snd) cspace
+ num_iters
+ | cspace_dims == [kernelGroupSize constants] = 1
+ | otherwise = product cspace_dims `quotRoundingUp`
+ kernelGroupSize constants
+
+ iter <- newVName "comb_iter"
+
+ sFor iter Int32 num_iters $ do
+ mapM_ ((`dPrim_` int32) . fst) cspace
+ -- Compute the *flat* array index.
+ cid <- dPrimV "flat_comb_id" $
+ Imp.var iter int32 * kernelGroupSize constants +
+ kernelLocalThreadId constants
+
+ -- Turn it into a nested array index.
+ zipWithM_ (<--) (map fst cspace) $ unflattenIndex cspace_dims (Imp.var cid int32)
+
+ -- Construct the body. This is mostly about the book-keeping
+ -- for the scatter-like part.
+ let (scatter_ws, scatter_ns, _scatter_vs) = unzip3 scatter
+ scatter_ws_repl = concat $ zipWith replicate scatter_ns scatter_ws
+ (scatter_pes, normal_pes) =
+ splitAt (sum scatter_ns) $ patternElements pat
+ (res_is, res_vs, res_normal) =
+ splitAt3 (sum scatter_ns) (sum scatter_ns) $ bodyResult body
+
+ -- Execute the body if we are within bounds.
+ sWhen (isActive cspace .&&. isActive aspace) $ allThreads constants $
+ ImpGen.compileStms (freeIn $ bodyResult body) (stmsToList $ bodyStms body) $ do
+
+ forM_ (zip4 scatter_ws_repl res_is res_vs scatter_pes) $
+ \(w, res_i, res_v, scatter_pe) -> do
+ let res_i' = ImpGen.compileSubExpOfType int32 res_i
+ w' = ImpGen.compileSubExpOfType int32 w
+ -- We have to check that 'res_i' is in-bounds wrt. an array of size 'w'.
+ in_bounds = 0 .<=. res_i' .&&. res_i' .<. w'
+ sWhen in_bounds $ ImpGen.copyDWIM (patElemName scatter_pe) [res_i'] res_v []
+
+ forM_ (zip normal_pes res_normal) $ \(pe, res) ->
+ ImpGen.copyDWIM (patElemName pe) local_index res []
+
+ sOp Imp.LocalBarrier
+
+ where streamBounded (Var v)
+ | Just x <- lookup v $ kernelStreamed constants =
+ Imp.sizeToExp x
+ streamBounded se = ImpGen.compileSubExpOfType int32 se
+
+ local_index = map (ImpGen.compileSubExpOfType int32 . Var . fst) cspace
+
+compileKernelExp constants (Pattern _ dests) (GroupReduce w lam input) = do
+ let [my_index_param, offset_param] = take 2 $ lambdaParams lam
+ lam' = lam { lambdaParams = drop 2 $ lambdaParams lam }
+
+ dPrim_ (paramName my_index_param) int32
+ dPrim_ (paramName offset_param) int32
+ paramName my_index_param <-- kernelGlobalThreadId constants
+ w' <- ImpGen.compileSubExp w
+ groupReduceWithOffset constants (paramName offset_param) w' lam' $ map snd input
+
+ sOp Imp.LocalBarrier
+
+ -- The final result will be stored in element 0 of the local memory array.
+ forM_ (zip dests input) $ \(dest, (_, arr)) ->
+ ImpGen.copyDWIM (patElemName dest) [] (Var arr) [0]
+
+compileKernelExp constants _ (GroupScan w lam input) = do
+ w' <- ImpGen.compileSubExp w
+ groupScan constants Nothing w' lam $ map snd input
+
+compileKernelExp constants (Pattern _ final) (GroupStream w maxchunk lam accs _arrs) = do
+ let GroupStreamLambda block_size block_offset acc_params arr_params body = lam
+ block_offset' = Imp.var block_offset int32
+ w' <- ImpGen.compileSubExp w
+ max_block_size <- ImpGen.compileSubExp maxchunk
+
+ ImpGen.dLParams (acc_params++arr_params)
+ zipWithM_ ImpGen.compileSubExpTo (map paramName acc_params) accs
+ dPrim_ block_size int32
+
+ -- If the GroupStream is morally just a do-loop, generate simpler code.
+ case mapM isSimpleThreadInSpace $ stmsToList $ bodyStms body of
+ Just stms' | ValueExp x <- max_block_size, oneIsh x -> do
+ let body' = body { bodyStms = stmsFromList stms' }
+ body'' = allThreads constants $
+ ImpGen.compileLoopBody (map paramName acc_params) body'
+ block_size <-- 1
+
+ -- Check if loop is candidate for unrolling.
+ let loop =
+ case w of
+ Var w_var | Just w_bound <- lookup w_var $ kernelStreamed constants,
+ w_bound /= Imp.ConstSize 1 ->
+ -- Candidate for unrolling, so generate two loops.
+ sIf (w' .==. Imp.sizeToExp w_bound)
+ (sFor block_offset Int32 (Imp.sizeToExp w_bound) body'')
+ (sFor block_offset Int32 w' body'')
+ _ -> sFor block_offset Int32 w' body''
+
+ if kernelThreadActive constants == Imp.ValueExp (BoolValue True)
+ then loop
+ else sWhen (kernelThreadActive constants) loop
+
+ _ -> do
+ dPrim_ block_offset int32
+ let body' = streaming constants block_size maxchunk $
+ ImpGen.compileBody' acc_params body
+
+ block_offset <-- 0
+
+ let not_at_end = block_offset' .<. w'
+ set_block_size =
+ sIf (w' - block_offset' .<. max_block_size)
+ (block_size <-- (w' - block_offset'))
+ (block_size <-- max_block_size)
+ increase_offset =
+ block_offset <-- block_offset' + max_block_size
+
+      -- Three cases to consider for simpler generated code, based on
+      -- the max block size: (0) if it equals the full input size, do
+      -- not generate a loop; (1) if it is one, generate a for-loop;
+      -- (2) otherwise, generate a chunked while-loop.
+ if max_block_size == w' then
+ (block_size <-- w') >> body'
+ else if max_block_size == Imp.ValueExp (value (1::Int32)) then do
+ block_size <-- w'
+ sFor block_offset Int32 w' body'
+ else
+ sWhile not_at_end $
+ set_block_size >> body' >> increase_offset
+
+ forM_ (zip final acc_params) $ \(pe, p) ->
+ ImpGen.copyDWIM (patElemName pe) [] (Var $ paramName p) []
+
+ where isSimpleThreadInSpace (Let _ _ Op{}) = Nothing
+ isSimpleThreadInSpace bnd = Just bnd
+
+compileKernelExp _ _ (GroupGenReduce w arrs op bucket values locks) = do
+ -- Check if bucket is in-bounds
+ bucket' <- mapM ImpGen.compileSubExp bucket
+ w' <- mapM ImpGen.compileSubExp w
+ sWhen (indexInBounds bucket' w') $
+ atomicUpdate arrs bucket op values locking
+ where indexInBounds inds bounds =
+ foldl1 (.&&.) $ zipWith checkBound inds bounds
+ where checkBound ind bound = 0 .<=. ind .&&. ind .<. bound
+ locking = Locking locks 0 1 0
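+        -- (A value of 0 in the locks array means unlocked; we write 1
+        -- to take a lock and 0 again to release it; see 'Locking'
+        -- below.)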
+
+compileKernelExp _ dest e =
+ compilerBugS $ unlines ["Invalid target", " " ++ show dest,
+ "for kernel expression", " " ++ pretty e]
+
+streaming :: KernelConstants -> VName -> SubExp -> InKernelGen () -> InKernelGen ()
+streaming constants chunksize bound m = do
+ bound' <- ImpGen.subExpToDimSize bound
+ let constants' =
+ constants { kernelStreamed = (chunksize, bound') : kernelStreamed constants }
+ ImpGen.emit =<< ImpGen.subImpM_ (inKernelOperations constants') m
+
+-- | Locking strategy used for an atomic update.
+data Locking = Locking { lockingArray :: VName -- ^ Array containing the lock.
+ , lockingIsUnlocked :: Imp.Exp -- ^ Value for us to consider the lock free.
+ , lockingToLock :: Imp.Exp -- ^ What to write when we lock it.
+ , lockingToUnlock :: Imp.Exp -- ^ What to write when we unlock it.
+ }
+
+atomicUpdate :: ExplicitMemorish lore =>
+ [VName] -> [SubExp] -> Lambda lore -> [SubExp] -> Locking
+ -> ImpGen.ImpM lore Imp.KernelOp ()
+
+atomicUpdate arrs bucket lam values _
+ | Just ops_and_ts <- splitOp lam,
+ all ((==32) . primBitSize . snd) ops_and_ts =
+  -- If the operator is a vectorised binary operator on 32-bit values,
+  -- we can use a particularly efficient implementation.  If the
+  -- operator has a hardware-atomic implementation we use that;
+  -- otherwise, since the values are 32 bits wide, it can still be
+  -- implemented with an atomic compare-and-swap loop.
+ forM_ (zip3 arrs ops_and_ts values) $ \(a, (op, t), val) -> do
+
+ -- Common variables.
+ old <- dPrim "old" t
+ bucket' <- mapM ImpGen.compileSubExp bucket
+
+ (arr', _a_space, bucket_offset) <- ImpGen.fullyIndexArray a bucket'
+
+ val' <- ImpGen.compileSubExp val
+ case opHasAtomicSupport old arr' bucket_offset op of
+ Just f -> sOp $ f val'
+
+ Nothing -> do
+ -- Code generation target:
+ --
+ -- old = d_his[idx];
+ -- do {
+ -- assumed = old;
+ -- tmp = OP::apply(val, assumed);
+ -- old = atomicCAS(&d_his[idx], assumed, tmp);
+ -- } while(assumed != old);
+ assumed <- dPrim "assumed" t
+ run_loop <- dPrimV "run_loop" 1
+ ImpGen.copyDWIM old [] (Var a) bucket'
+
+ -- Critical section
+ x <- dPrim "x" t
+ y <- dPrim "y" t
+
+ -- While-loop: Try to insert your value
+ let (toBits, fromBits) =
+ case t of FloatType Float32 -> (\v -> Imp.FunExp "to_bits32" [v] int32,
+ \v -> Imp.FunExp "from_bits32" [v] t)
+ _ -> (id, id)
+ sWhile (Imp.var run_loop int32) $ do
+ assumed <-- Imp.var old t
+ x <-- val'
+ y <-- Imp.var assumed t
+ x <-- Imp.BinOpExp op (Imp.var x t) (Imp.var y t)
+ old_bits <- dPrim "old_bits" int32
+ sOp $ Imp.Atomic $
+ Imp.AtomicCmpXchg old_bits arr' bucket_offset
+ (toBits (Imp.var assumed t)) (toBits (Imp.var x t))
+ old <-- fromBits (Imp.var old_bits int32)
+ sWhen (toBits (Imp.var assumed t) .==. Imp.var old_bits int32)
+ (run_loop <-- 0)
+
+ where opHasAtomicSupport old arr' bucket' bop = do
+ let atomic f = Imp.Atomic . f old arr' bucket'
+ atomic <$> Imp.atomicBinOp bop
+
+atomicUpdate arrs bucket op values locking = do
+ old <- dPrim "old" int32
+ continue <- dPrimV "continue" true
+
+ -- Check if bucket is in-bounds
+ bucket' <- mapM ImpGen.compileSubExp bucket
+
+ -- Correctly index into locks.
+ (locks', _locks_space, locks_offset) <-
+ ImpGen.fullyIndexArray (lockingArray locking) bucket'
+
+ -- Preparing parameters
+ let (acc_params, arr_params) =
+ splitAt (length values) $ lambdaParams op
+
+ -- Critical section
+ let try_acquire_lock =
+ sOp $ Imp.Atomic $
+ Imp.AtomicCmpXchg old locks' locks_offset (lockingIsUnlocked locking) (lockingToLock locking)
+ lock_acquired = Imp.var old int32 .==. lockingIsUnlocked locking
+ release_lock = ImpGen.everythingVolatile $
+ ImpGen.sWrite (lockingArray locking) bucket' $ lockingToUnlock locking
+ break_loop = continue <-- false
+
+ -- We copy the current value and the new value to the parameters.
+ -- It is important that the right-hand-side is bound first for the
+ -- (rare) case when we are dealing with arrays.
+ let bind_acc_params =
+ ImpGen.sComment "bind lhs" $
+ forM_ (zip acc_params arrs) $ \(acc_p, arr) ->
+ ImpGen.copyDWIM (paramName acc_p) [] (Var arr) bucket'
+
+ let bind_arr_params =
+ ImpGen.sComment "bind rhs" $
+ forM_ (zip arr_params values) $ \(arr_p, val) ->
+ ImpGen.copyDWIM (paramName arr_p) [] val []
+
+ let op_body = ImpGen.sComment "execute operation" $
+ ImpGen.compileBody' acc_params $ lambdaBody op
+
+ do_gen_reduce = ImpGen.sComment "update global result" $
+ zipWithM_ (writeArray bucket') arrs $ map (Var . paramName) acc_params
+
+ -- While-loop: Try to insert your value
+ sWhile (Imp.var continue Bool) $ do
+ try_acquire_lock
+ sWhen lock_acquired $ do
+ ImpGen.dLParams $ lambdaParams op
+ bind_arr_params
+ bind_acc_params
+ op_body
+ do_gen_reduce
+ release_lock
+ break_loop
+ sOp Imp.MemFence
+ where writeArray bucket' arr val =
+ ImpGen.copyDWIM arr bucket' val []
+
+-- | Horizontally fission a lambda that models a binary operator.
+splitOp :: Attributes lore => Lambda lore -> Maybe [(BinOp, PrimType)]
+splitOp lam = mapM splitStm $ bodyResult $ lambdaBody lam
+ where n = length $ lambdaReturnType lam
+ splitStm :: SubExp -> Maybe (BinOp, PrimType)
+ splitStm (Var res) = do
+ Let (Pattern [] [pe]) _ (BasicOp (BinOp op (Var x) (Var y))) <-
+ find (([res]==) . patternNames . stmPattern) $
+ stmsToList $ bodyStms $ lambdaBody lam
+ i <- Var res `elemIndex` bodyResult (lambdaBody lam)
+ xp <- maybeNth i $ lambdaParams lam
+ yp <- maybeNth (n+i) $ lambdaParams lam
+ guard $ paramName xp == x
+ guard $ paramName yp == y
+ Prim t <- Just $ patElemType pe
+ return (op, t)
+ splitStm _ = Nothing
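+
+-- A purely illustrative (hypothetical) example of the fissioning
+-- above: a reduction operator of the form
+--
+--   \x1 x2 y1 y2 -> (x1 + y1, x2 * y2)
+--
+-- with two i32 results would be split into one independent binary
+-- operator per result (an i32 addition and an i32 multiplication),
+-- each of which can then be turned into an atomic update on its own.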
+
+computeKernelUses :: FreeIn a =>
+ a -> [VName]
+ -> CallKernelGen ([Imp.KernelUse], [Imp.LocalMemoryUse])
+computeKernelUses kernel_body bound_in_kernel = do
+ let actually_free = freeIn kernel_body `S.difference` S.fromList bound_in_kernel
+
+ -- Compute the variables that we need to pass to the kernel.
+ reads_from <- readsFromSet actually_free
+
+ -- Are we using any local memory?
+ local_memory <- computeLocalMemoryUse actually_free
+ return (nub reads_from, nub local_memory)
+
+readsFromSet :: Names -> CallKernelGen [Imp.KernelUse]
+readsFromSet free =
+ fmap catMaybes $
+ forM (S.toList free) $ \var -> do
+ t <- lookupType var
+ case t of
+ Array {} -> return Nothing
+ Mem _ (Space "local") -> return Nothing
+ Mem {} -> return $ Just $ Imp.MemoryUse var
+ Prim bt ->
+ isConstExp var >>= \case
+ Just ce -> return $ Just $ Imp.ConstUse var ce
+ Nothing | bt == Cert -> return Nothing
+ | otherwise -> return $ Just $ Imp.ScalarUse var bt
+
+computeLocalMemoryUse :: Names -> CallKernelGen [Imp.LocalMemoryUse]
+computeLocalMemoryUse free =
+ fmap catMaybes $
+ forM (S.toList free) $ \var -> do
+ t <- lookupType var
+ case t of
+ Mem memsize (Space "local") -> do
+ memsize' <- localMemSize =<< ImpGen.subExpToDimSize memsize
+ return $ Just (var, memsize')
+ _ -> return Nothing
+
+localMemSize :: Imp.MemSize -> CallKernelGen (Either Imp.MemSize Imp.KernelConstExp)
+localMemSize (Imp.ConstSize x) =
+ return $ Right $ ValueExp $ IntValue $ Int64Value x
+localMemSize (Imp.VarSize v) = isConstExp v >>= \case
+ Just e | isStaticExp e -> return $ Right e
+ _ -> return $ Left $ Imp.VarSize v
+
+isConstExp :: VName -> CallKernelGen (Maybe Imp.KernelConstExp)
+isConstExp v = do
+ vtable <- ImpGen.getVTable
+ fname <- asks ImpGen.envFunction
+ let lookupConstExp name = constExp =<< hasExp =<< M.lookup name vtable
+ constExp (Op (Inner (GetSize key _))) =
+ Just $ LeafExp (Imp.SizeConst $ keyWithEntryPoint fname key) int32
+ constExp e = primExpFromExp lookupConstExp e
+ return $ lookupConstExp v
+ where hasExp (ImpGen.ArrayVar e _) = e
+ hasExp (ImpGen.ScalarVar e _) = e
+ hasExp (ImpGen.MemVar e _) = e
+
+-- | Only some constant expressions qualify as *static* expressions,
+-- which we can use for static memory allocation.  This is a bit of a
+-- hack, as it is primarily motivated by what you can put as the size
+-- when declaring an array in C.
+isStaticExp :: Imp.KernelConstExp -> Bool
+isStaticExp LeafExp{} = True
+isStaticExp ValueExp{} = True
+isStaticExp (BinOpExp Add{} x y) = isStaticExp x && isStaticExp y
+isStaticExp (BinOpExp Sub{} x y) = isStaticExp x && isStaticExp y
+isStaticExp (BinOpExp Mul{} x y) = isStaticExp x && isStaticExp y
+isStaticExp _ = False
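+
+-- For instance (an illustrative case, not exhaustive): an expression
+-- like @group_size * 4 + 1@, built only from leaves, constants,
+-- additions and multiplications, is static and can be used directly
+-- as a C array size, whereas anything involving e.g. division is not.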
+
+computeThreadChunkSize :: SplitOrdering
+ -> Imp.Exp
+ -> Imp.Count Imp.Elements
+ -> Imp.Count Imp.Elements
+ -> VName
+ -> ImpGen.ImpM lore op ()
+computeThreadChunkSize (SplitStrided stride) thread_index elements_per_thread num_elements chunk_var = do
+ stride' <- ImpGen.compileSubExp stride
+ chunk_var <--
+ Imp.BinOpExp (SMin Int32)
+ (Imp.innerExp elements_per_thread)
+ ((Imp.innerExp num_elements - thread_index) `quotRoundingUp` stride')
+
+computeThreadChunkSize SplitContiguous thread_index elements_per_thread num_elements chunk_var = do
+ starting_point <- dPrimV "starting_point" $
+ thread_index * Imp.innerExp elements_per_thread
+ remaining_elements <- dPrimV "remaining_elements" $
+ Imp.innerExp num_elements - Imp.var starting_point int32
+
+ let no_remaining_elements = Imp.var remaining_elements int32 .<=. 0
+ beyond_bounds = Imp.innerExp num_elements .<=. Imp.var starting_point int32
+
+ sIf (no_remaining_elements .||. beyond_bounds)
+ (chunk_var <-- 0)
+ (sIf is_last_thread
+ (chunk_var <-- Imp.innerExp last_thread_elements)
+ (chunk_var <-- Imp.innerExp elements_per_thread))
+ where last_thread_elements =
+ num_elements - Imp.elements thread_index * elements_per_thread
+ is_last_thread =
+ Imp.innerExp num_elements .<.
+ (thread_index + 1) * Imp.innerExp elements_per_thread
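+
+-- As a purely illustrative example of the contiguous case (made-up
+-- numbers): with 100 elements and 8 elements per thread, thread 12
+-- starts at element 96 with 4 elements remaining; it is the last
+-- thread, so its chunk size becomes 4, threads 0-11 get 8 each, and
+-- any thread beyond 12 gets a chunk size of 0.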
+
+kernelInitialisationSetSpace :: KernelSpace -> InKernelGen ()
+ -> ImpGen.ImpM lore op (KernelConstants, ImpGen.ImpM InKernel Imp.KernelOp ())
+kernelInitialisationSetSpace space set_space = do
+ group_size' <- ImpGen.compileSubExp $ spaceGroupSize space
+ num_threads' <- ImpGen.compileSubExp $ spaceNumThreads space
+ num_groups <- ImpGen.compileSubExp $ spaceNumGroups space
+
+ let global_tid = spaceGlobalId space
+ local_tid = spaceLocalId space
+ group_id = spaceGroupId space
+ wave_size <- newVName "wave_size"
+ inner_group_size <- newVName "group_size"
+
+ let (space_is, space_dims) = unzip $ spaceDimensions space
+ space_dims' <- mapM ImpGen.compileSubExp space_dims
+ let constants =
+ KernelConstants
+ (Imp.var global_tid int32)
+ (Imp.var local_tid int32)
+ (Imp.var group_id int32)
+ global_tid local_tid group_id
+ group_size' num_groups num_threads'
+ (Imp.var wave_size int32) (zip space_is space_dims')
+ (if null (spaceDimensions space)
+ then true else isActive (spaceDimensions space)) mempty
+
+ let set_constants = do
+ dPrim_ wave_size int32
+ dPrim_ inner_group_size int32
+ ImpGen.dScope Nothing (scopeOfKernelSpace space)
+
+ sOp (Imp.GetGlobalId global_tid 0)
+ sOp (Imp.GetLocalId local_tid 0)
+ sOp (Imp.GetLocalSize inner_group_size 0)
+ sOp (Imp.GetLockstepWidth wave_size)
+ sOp (Imp.GetGroupId group_id 0)
+
+ set_space
+
+ return (constants, set_constants)
+
+kernelInitialisation :: KernelSpace
+ -> ImpGen.ImpM lore op (KernelConstants, ImpGen.ImpM InKernel Imp.KernelOp ())
+kernelInitialisation space =
+ kernelInitialisationSetSpace space $
+ setSpaceIndices (Imp.var (spaceGlobalId space) int32) space
+
+setSpaceIndices :: Imp.Exp -> KernelSpace -> InKernelGen ()
+setSpaceIndices gtid space =
+ case spaceStructure space of
+ FlatThreadSpace is_and_dims ->
+ flatSpaceWith gtid is_and_dims
+ NestedThreadSpace is_and_dims -> do
+ let (gtids, gdims, ltids, ldims) = unzip4 is_and_dims
+ gdims' <- mapM ImpGen.compileSubExp gdims
+ ldims' <- mapM ImpGen.compileSubExp ldims
+ let (gtid_es, ltid_es) = unzip $ unflattenNestedIndex gdims' ldims' gtid
+ zipWithM_ (<--) gtids gtid_es
+ zipWithM_ (<--) ltids ltid_es
+ where flatSpaceWith base is_and_dims = do
+ let (is, dims) = unzip is_and_dims
+ dims' <- mapM ImpGen.compileSubExp dims
+ let index_expressions = unflattenIndex dims' base
+ zipWithM_ (<--) is index_expressions
+
+isActive :: [(VName, SubExp)] -> Imp.Exp
+isActive limit = case actives of
+ [] -> Imp.ValueExp $ BoolValue True
+ x:xs -> foldl (.&&.) x xs
+ where (is, ws) = unzip limit
+ actives = zipWith active is $ map (ImpGen.compileSubExpOfType Bool) ws
+ active i = (Imp.var i int32 .<.)
+
+unflattenNestedIndex :: IntegralExp num => [num] -> [num] -> num -> [(num,num)]
+unflattenNestedIndex global_dims group_dims global_id =
+ zip global_is local_is
+ where num_groups_dims = zipWith quotRoundingUp global_dims group_dims
+ group_size = product group_dims
+ group_id = global_id `Futhark.Util.IntegralExp.quot` group_size
+ local_id = global_id `Futhark.Util.IntegralExp.rem` group_size
+
+ group_is = unflattenIndex num_groups_dims group_id
+ local_is = unflattenIndex group_dims local_id
+ global_is = zipWith (+) local_is $ zipWith (*) group_is group_dims
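+
+-- A hypothetical illustration: with global_dims = [6,8], group_dims =
+-- [2,4] and global_id = 13, we get group_id = 1 and local_id = 5, so
+-- group_is = [0,1], local_is = [1,1] and global_is = [1,5]; the result
+-- is then [(1,1),(5,1)], pairing each global index with its local one.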
+
+
+-- | Change every memory block to be in the global address space,
+-- except those that are in the local memory space.  This only affects
+-- generated code - we still need to make sure that the memory is
+-- actually present on the device (and declared as variables in the
+-- kernel).
+makeAllMemoryGlobal :: CallKernelGen a -> CallKernelGen a
+makeAllMemoryGlobal =
+ local (\env -> env { ImpGen.envDefaultSpace = Imp.Space "global" }) .
+ ImpGen.localVTable (M.map globalMemory)
+ where globalMemory (ImpGen.MemVar _ entry)
+ | ImpGen.entryMemSpace entry /= Space "local" =
+ ImpGen.MemVar Nothing entry { ImpGen.entryMemSpace = Imp.Space "global" }
+ globalMemory entry =
+ entry
+
+allThreads :: KernelConstants -> InKernelGen () -> InKernelGen ()
+allThreads constants = ImpGen.emit <=< ImpGen.subImpM_ (inKernelOperations constants')
+ where constants' =
+ constants { kernelThreadActive = Imp.ValueExp (BoolValue True) }
+
+
+
+writeParamToLocalMemory :: Typed (MemBound u) =>
+ Imp.Exp -> (VName, t) -> Param (MemBound u)
+ -> ImpGen.ImpM lore op ()
+writeParamToLocalMemory i (mem, _) param
+ | Prim t <- paramType param =
+ ImpGen.emit $
+ Imp.Write mem (bytes i') bt (Space "local") Imp.Volatile $
+ Imp.var (paramName param) t
+ | otherwise =
+ return ()
+ where i' = i * Imp.LeafExp (Imp.SizeOf bt) int32
+ bt = elemType $ paramType param
+
+readParamFromLocalMemory :: Typed (MemBound u) =>
+ VName -> Imp.Exp -> Param (MemBound u) -> (VName, t)
+ -> ImpGen.ImpM lore op ()
+readParamFromLocalMemory index i param (l_mem, _)
+ | Prim _ <- paramType param =
+ paramName param <--
+ Imp.index l_mem (bytes i') bt (Space "local") Imp.Volatile
+ | otherwise = index <-- i
+ where i' = i * Imp.LeafExp (Imp.SizeOf bt) int32
+ bt = elemType $ paramType param
+
+groupReduce :: ExplicitMemorish lore =>
+ KernelConstants
+ -> Imp.Exp
+ -> Lambda lore
+ -> [VName]
+ -> ImpGen.ImpM lore Imp.KernelOp ()
+groupReduce constants w lam arrs = do
+ offset <- dPrim "offset" int32
+ groupReduceWithOffset constants offset w lam arrs
+
+groupReduceWithOffset :: ExplicitMemorish lore =>
+ KernelConstants
+ -> VName
+ -> Imp.Exp
+ -> Lambda lore
+ -> [VName]
+ -> ImpGen.ImpM lore Imp.KernelOp ()
+groupReduceWithOffset constants offset w lam arrs = do
+ let (reduce_acc_params, reduce_arr_params) = splitAt (length arrs) $ lambdaParams lam
+
+ skip_waves <- dPrim "skip_waves" int32
+ ImpGen.dLParams $ lambdaParams lam
+
+ offset <-- 0
+
+ ImpGen.comment "participating threads read initial accumulator" $
+ sWhen (local_tid .<. w) $
+ zipWithM_ readReduceArgument reduce_acc_params arrs
+
+ let do_reduce = do ImpGen.comment "read array element" $
+ zipWithM_ readReduceArgument reduce_arr_params arrs
+ ImpGen.comment "apply reduction operation" $
+ ImpGen.compileBody' reduce_acc_params $ lambdaBody lam
+ ImpGen.comment "write result of operation" $
+ zipWithM_ writeReduceOpResult reduce_acc_params arrs
+ in_wave_reduce = ImpGen.everythingVolatile do_reduce
+
+ wave_size = kernelWaveSize constants
+ group_size = kernelGroupSize constants
+ wave_id = local_tid `quot` wave_size
+ in_wave_id = local_tid - wave_id * wave_size
+ num_waves = (group_size + wave_size - 1) `quot` wave_size
+ arg_in_bounds = local_tid + Imp.var offset int32 .<. w
+
+ doing_in_wave_reductions =
+ Imp.var offset int32 .<. wave_size
+ apply_in_in_wave_iteration =
+ (in_wave_id .&. (2 * Imp.var offset int32 - 1)) .==. 0
+ in_wave_reductions = do
+ offset <-- 1
+ sWhile doing_in_wave_reductions $ do
+ sWhen (arg_in_bounds .&&. apply_in_in_wave_iteration)
+ in_wave_reduce
+ offset <-- Imp.var offset int32 * 2
+
+ doing_cross_wave_reductions =
+ Imp.var skip_waves int32 .<. num_waves
+ is_first_thread_in_wave =
+ in_wave_id .==. 0
+ wave_not_skipped =
+ (wave_id .&. (2 * Imp.var skip_waves int32 - 1)) .==. 0
+ apply_in_cross_wave_iteration =
+ arg_in_bounds .&&. is_first_thread_in_wave .&&. wave_not_skipped
+ cross_wave_reductions = do
+ skip_waves <-- 1
+ sWhile doing_cross_wave_reductions $ do
+ barrier
+ offset <-- Imp.var skip_waves int32 * wave_size
+ sWhen apply_in_cross_wave_iteration
+ do_reduce
+ skip_waves <-- Imp.var skip_waves int32 * 2
+
+ in_wave_reductions
+ cross_wave_reductions
+ where local_tid = kernelLocalThreadId constants
+ global_tid = kernelGlobalThreadId constants
+
+ barrier
+ | all primType $ lambdaReturnType lam = sOp Imp.LocalBarrier
+ | otherwise = sOp Imp.GlobalBarrier
+
+ readReduceArgument param arr
+ | Prim _ <- paramType param = do
+ let i = local_tid + ImpGen.varIndex offset
+ ImpGen.copyDWIM (paramName param) [] (Var arr) [i]
+ | otherwise = do
+ let i = global_tid + ImpGen.varIndex offset
+ ImpGen.copyDWIM (paramName param) [] (Var arr) [i]
+
+ writeReduceOpResult param arr
+ | Prim _ <- paramType param =
+ ImpGen.copyDWIM arr [local_tid] (Var $ paramName param) []
+ | otherwise =
+ return ()
+
+groupScan :: KernelConstants
+ -> Maybe (Imp.Exp -> Imp.Exp -> Imp.Exp)
+ -> Imp.Exp
+ -> Lambda InKernel
+ -> [VName]
+ -> ImpGen.ImpM InKernel Imp.KernelOp ()
+groupScan constants seg_flag w lam arrs = do
+ when (any (not . primType . paramType) $ lambdaParams lam) $
+ compilerLimitationS "Cannot compile parallel scans with array element type."
+
+ renamed_lam <- renameLambda lam
+
+ acc_local_mem <- flip zip (repeat ()) <$>
+ mapM (fmap (ImpGen.memLocationName . ImpGen.entryArrayLocation) .
+ ImpGen.lookupArray) arrs
+
+ let ltid = kernelLocalThreadId constants
+ (lam_i, other_index_param, actual_params) =
+ partitionChunkedKernelLambdaParameters $ lambdaParams lam
+ (x_params, y_params) = splitAt (length arrs) actual_params
+
+ ImpGen.dLParams (lambdaParams lam++lambdaParams renamed_lam)
+ lam_i <-- ltid
+
+ -- The scan works by splitting the group into blocks, which are
+ -- scanned separately. Typically, these blocks are smaller than
+ -- the lockstep width, which enables barrier-free execution inside
+ -- them.
+ --
+ -- We hardcode the block size here. The only requirement is that
+ -- it should not be less than the square root of the group size.
+ -- With 32, we will work on groups of size 1024 or smaller, which
+ -- fits every device Troels has seen. Still, it would be nicer if
+ -- it were a runtime parameter. Some day.
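+  --
+  -- As a purely illustrative example: with a group size of 1024 and a
+  -- block size of 32, the group is split into 32 blocks of 32 threads,
+  -- and since 32 * 32 = 1024 the 32 per-block results can themselves
+  -- be scanned by a single 32-thread block in the second step.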
+ let block_size = Imp.ValueExp $ IntValue $ Int32Value 32
+ simd_width = kernelWaveSize constants
+ block_id = ltid `quot` block_size
+ in_block_id = ltid - block_id * block_size
+ doInBlockScan seg_flag' active = inBlockScan seg_flag' simd_width block_size active ltid acc_local_mem
+ ltid_in_bounds = ltid .<. w
+
+ doInBlockScan seg_flag ltid_in_bounds lam
+ sOp Imp.LocalBarrier
+
+ let last_in_block = in_block_id .==. block_size - 1
+ sComment "last thread of block 'i' writes its result to offset 'i'" $
+ sWhen (last_in_block .&&. ltid_in_bounds) $
+ zipWithM_ (writeParamToLocalMemory block_id) acc_local_mem y_params
+
+ sOp Imp.LocalBarrier
+
+ let is_first_block = block_id .==. 0
+ first_block_seg_flag = do
+ flag_true <- seg_flag
+ Just $ \from to ->
+ flag_true (from*block_size+block_size-1) (to*block_size+block_size-1)
+ ImpGen.comment
+ "scan the first block, after which offset 'i' contains carry-in for warp 'i+1'" $
+ doInBlockScan first_block_seg_flag (is_first_block .&&. ltid_in_bounds) renamed_lam
+
+ sOp Imp.LocalBarrier
+
+ let read_carry_in =
+ zipWithM_ (readParamFromLocalMemory
+ (paramName other_index_param) (block_id - 1))
+ x_params acc_local_mem
+
+ let op_to_y
+ | Nothing <- seg_flag =
+ ImpGen.compileBody' y_params $ lambdaBody lam
+ | Just flag_true <- seg_flag =
+ sUnless (flag_true (block_id*block_size-1) ltid) $
+ ImpGen.compileBody' y_params $ lambdaBody lam
+ write_final_result =
+ zipWithM_ (writeParamToLocalMemory ltid) acc_local_mem y_params
+
+ sComment "carry-in for every block except the first" $
+ sUnless (is_first_block .||. Imp.UnOpExp Not ltid_in_bounds) $ do
+ sComment "read operands" read_carry_in
+ sComment "perform operation" op_to_y
+ sComment "write final result" write_final_result
+
+ sOp Imp.LocalBarrier
+
+ sComment "restore correct values for first block" $
+ sWhen is_first_block write_final_result
+
+inBlockScan :: Maybe (Imp.Exp -> Imp.Exp -> Imp.Exp)
+ -> Imp.Exp
+ -> Imp.Exp
+ -> Imp.Exp
+ -> Imp.Exp
+ -> [(VName, t)]
+ -> Lambda InKernel
+ -> InKernelGen ()
+inBlockScan seg_flag lockstep_width block_size active ltid acc_local_mem scan_lam = ImpGen.everythingVolatile $ do
+ skip_threads <- dPrim "skip_threads" int32
+ let in_block_thread_active =
+ Imp.var skip_threads int32 .<=. in_block_id
+ (scan_lam_i, other_index_param, actual_params) =
+ partitionChunkedKernelLambdaParameters $ lambdaParams scan_lam
+ (x_params, y_params) =
+ splitAt (length actual_params `div` 2) actual_params
+ read_operands =
+ zipWithM_ (readParamFromLocalMemory (paramName other_index_param) $
+ ltid - Imp.var skip_threads int32)
+ x_params acc_local_mem
+
+ -- Set initial y values
+ sWhen active $
+ zipWithM_ (readParamFromLocalMemory scan_lam_i ltid)
+ y_params acc_local_mem
+
+ let op_to_y
+ | Nothing <- seg_flag =
+ ImpGen.compileBody' y_params $ lambdaBody scan_lam
+ | Just flag_true <- seg_flag =
+ sUnless (flag_true (ltid-Imp.var skip_threads int32) ltid) $
+ ImpGen.compileBody' y_params $ lambdaBody scan_lam
+ write_operation_result =
+ zipWithM_ (writeParamToLocalMemory ltid) acc_local_mem y_params
+ maybeLocalBarrier = sWhen (lockstep_width .<=. Imp.var skip_threads int32) $
+ sOp Imp.LocalBarrier
+
+ sComment "in-block scan (hopefully no barriers needed)" $ do
+ skip_threads <-- 1
+ sWhile (Imp.var skip_threads int32 .<. block_size) $ do
+ sWhen (in_block_thread_active .&&. active) $ do
+ sComment "read operands" read_operands
+ sComment "perform operation" op_to_y
+
+ maybeLocalBarrier
+
+ sWhen (in_block_thread_active .&&. active) $
+ sComment "write result" write_operation_result
+
+ maybeLocalBarrier
+
+ skip_threads <-- Imp.var skip_threads int32 * 2
+
+ where block_id = ltid `quot` block_size
+ in_block_id = ltid - block_id * block_size
+
+compileKernelStms :: KernelConstants -> [Stm InKernel]
+ -> InKernelGen a
+ -> InKernelGen a
+compileKernelStms constants ungrouped_bnds m =
+ compileGroupedKernelStms' $ groupStmsByGuard constants ungrouped_bnds
+ where compileGroupedKernelStms' [] = m
+ compileGroupedKernelStms' ((g, bnds):rest_bnds) = do
+ ImpGen.dScopes (map ((Just . stmExp) &&& (castScope . scopeOf)) bnds)
+ protect g $ mapM_ compileKernelStm bnds
+ compileGroupedKernelStms' rest_bnds
+
+ protect Nothing body_m =
+ body_m
+ protect (Just (Imp.ValueExp (BoolValue True))) body_m =
+ body_m
+ protect (Just g) body_m =
+ sWhen g $ allThreads constants body_m
+
+ compileKernelStm (Let pat _ e) = ImpGen.compileExp pat e
+
+groupStmsByGuard :: KernelConstants
+ -> [Stm InKernel]
+ -> [(Maybe Imp.Exp, [Stm InKernel])]
+groupStmsByGuard constants bnds =
+ map collapse $ groupBy sameGuard $ zip (map bindingGuard bnds) bnds
+ where bindingGuard (Let _ _ Op{}) = Nothing
+ bindingGuard _ = Just $ kernelThreadActive constants
+
+ sameGuard (g1, _) (g2, _) = g1 == g2
+
+ collapse [] =
+ (Nothing, [])
+ collapse l@((g,_):_) =
+ (g, map snd l)
+
+sKernel :: KernelConstants -> String -> ImpGen.ImpM InKernel Imp.KernelOp a -> CallKernelGen ()
+sKernel constants name m = do
+ body <- makeAllMemoryGlobal $
+ ImpGen.subImpM_ (inKernelOperations constants) m
+ (uses, local_memory) <- computeKernelUses body mempty
+ ImpGen.emit $ Imp.Op $ Imp.CallKernel Imp.Kernel
+ { Imp.kernelBody = body
+ , Imp.kernelLocalMemory = local_memory
+ , Imp.kernelUses = uses
+ , Imp.kernelNumGroups = [kernelNumGroups constants]
+ , Imp.kernelGroupSize = [kernelGroupSize constants]
+ , Imp.kernelName =
+ nameFromString $ name ++ "_" ++ show (baseTag $ kernelGlobalThreadIdVar constants)
+ }
diff --git a/src/Futhark/CodeGen/ImpGen/Kernels/SegRed.hs b/src/Futhark/CodeGen/ImpGen/Kernels/SegRed.hs
new file mode 100644
index 0000000..6b3fba3
--- /dev/null
+++ b/src/Futhark/CodeGen/ImpGen/Kernels/SegRed.hs
@@ -0,0 +1,601 @@
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE FlexibleContexts #-}
+-- | We generate code for non-segmented/single-segment SegRed using
+-- the basic approach outlined in the paper "Design and GPGPU
+-- Performance of Futhark’s Redomap Construct" (ARRAY '16). The main
+-- deviations are:
+--
+-- * While we still use two-phase reduction, we use only a single
+-- kernel, with the final workgroup to write a result (tracked via
+-- an atomic counter) performing the final reduction as well.
+--
+-- * Instead of depending on storage layout transformations to handle
+-- non-commutative reductions efficiently, we slide a
+-- 'groupsize'-sized window over the input, and perform a parallel
+-- reduction for each window. This sacrifices the notion of
+-- efficient sequentialisation, but is sometimes faster and
+-- definitely simpler and more predictable (and uses less auxiliary
+-- storage).
+--
+-- For segmented reductions we use the approach from "Strategies for
+-- Regular Segmented Reductions on GPU" (FHPC '17). This involves
+-- having two different strategies, and dynamically deciding which one
+-- to use based on the number of segments and segment size. We use the
+-- (static) @group_size@ to decide which of the following two
+-- strategies to choose:
+--
+-- * Large: uses one or more groups to process a single segment. If
+-- multiple groups are used per segment, the intermediate reduction
+-- results must be recursively reduced, until there is only a single
+-- value per segment.
+--
+-- Each thread /can/ read multiple elements, which will greatly
+-- increase performance; however, if the reduction is
+-- non-commutative we will have to use a less efficient traversal
+-- (with interim group-wide reductions) to enable coalesced memory
+-- accesses, just as in the non-segmented case.
+--
+-- * Small: uses a single group to process *multiple* segments at a
+--   time. We will only use this approach when we can
+-- process at least two segments within a single group. In those
+-- cases, we would allocate a /whole/ group per segment with the
+-- large strategy, but at most 50% of the threads in the group would
+-- have any element to read, which becomes highly inefficient.
+module Futhark.CodeGen.ImpGen.Kernels.SegRed
+ ( compileSegRed
+ )
+ where
+
+import Control.Monad.Except
+import Data.Maybe
+import qualified Data.Set as S
+import Data.List
+
+import Prelude hiding (quot, rem)
+
+import Futhark.MonadFreshNames
+import Futhark.Transform.Rename
+import Futhark.Representation.ExplicitMemory
+import qualified Futhark.CodeGen.ImpCode.Kernels as Imp
+import qualified Futhark.CodeGen.ImpGen as ImpGen
+import Futhark.CodeGen.ImpGen ((<--),
+ sFor, sComment, sIf, sWhen,
+ sOp,
+ dPrim, dPrimV)
+import Futhark.CodeGen.ImpGen.Kernels.Base
+import qualified Futhark.Representation.ExplicitMemory.IndexFunction as IxFun
+import Futhark.Util.IntegralExp (quotRoundingUp, quot, rem)
+
+-- | For many kernels, we may not have enough physical groups to cover
+-- the logical iteration space. Some groups thus have to perform
+-- double duty; we put an outer loop to accomplish this. The
+-- advantage over just launching a bazillion threads is that the cost
+-- of memory expansion should be proportional to the number of
+-- *physical* threads (hardware parallelism), not the amount of
+-- application parallelism.
+virtualiseGroups :: KernelConstants
+ -> Imp.Exp
+ -> (Imp.Exp -> ImpGen.ImpM lore op ())
+ -> ImpGen.ImpM lore op ()
+virtualiseGroups constants required_groups m = do
+ let group_id = kernelGroupId constants
+ iterations = (required_groups - group_id) `quotRoundingUp` kernelNumGroups constants
+ i <- newVName "i"
+ sFor i Int32 iterations $ m $ group_id + Imp.var i int32 * kernelNumGroups constants
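+
+-- As a purely illustrative example: with 3 physical groups and 7
+-- required groups, group 0 runs 3 iterations and handles virtual
+-- groups 0, 3 and 6, while groups 1 and 2 run 2 iterations each,
+-- handling virtual groups 1, 4 and 2, 5 respectively.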
+
+-- Compile 'SegRed' instance to host-level code with calls to various
+-- kernels.
+compileSegRed :: Pattern ExplicitMemory
+ -> KernelSpace
+ -> Commutativity -> Lambda InKernel -> [SubExp]
+ -> Body InKernel
+ -> CallKernelGen ()
+compileSegRed pat space comm red_op nes body
+ | [(_, Constant (IntValue (Int32Value 1))), _] <- spaceDimensions space =
+ nonsegmentedReduction pat space comm red_op nes body
+ | otherwise = do
+ segment_size <-
+ ImpGen.compileSubExp $ last $ map snd $ spaceDimensions space
+ group_size <- ImpGen.compileSubExp $ spaceGroupSize space
+ let use_small_segments = segment_size * 2 .<. group_size
+ sIf (segment_size .==. 1)
+ (unitSegmentsReduction pat space nes body) $
+ sIf use_small_segments
+ (smallSegmentsReduction pat space red_op nes body)
+ (largeSegmentsReduction pat space comm red_op nes body)
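+
+-- To illustrate the choice above with made-up numbers: with a group
+-- size of 256, a segment size of 1 hits the unit-segments case, a
+-- segment size of 100 uses the small-segments strategy (since
+-- 2*100 < 256), and a segment size of 1000 falls back to the
+-- large-segments strategy.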
+
+-- Handle degenerate case where segments are of size 1, meaning
+-- that it is really just a 'map' in disguise.
+unitSegmentsReduction :: Pattern ExplicitMemory
+ -> KernelSpace
+ -> [SubExp]
+ -> Body InKernel
+ -> CallKernelGen ()
+unitSegmentsReduction (Pattern _ segred_pes) space nes body = do
+ (constants, init_constants) <- kernelInitialisationSetSpace space $ return ()
+
+ let (gtids, dims) = unzip $ spaceDimensions space
+ (redout_pes, mapout_pes) = splitAt (length nes) segred_pes
+
+ dims' <- mapM ImpGen.compileSubExp dims
+
+ let num_segments = product $ init dims'
+ required_groups = num_segments `quotRoundingUp` kernelGroupSize constants
+
+ ImpGen.emit $ Imp.DebugPrint "num_segments" int32 num_segments
+ ImpGen.emit $ Imp.DebugPrint "required_groups" int32 required_groups
+
+ sKernel constants "segred_mapseg" $ do
+ init_constants
+ virtualiseGroups constants required_groups $ \group_id -> do
+ setSpaceIndices (group_id * kernelGroupSize constants + kernelLocalThreadId constants) space
+ ImpGen.compileStms mempty (stmsToList $ bodyStms body) $
+ sWhen (kernelThreadActive constants) $ do
+ let (redout_ses, mapout_ses) = splitAt (length nes) $ bodyResult body
+ forM_ (zip redout_pes redout_ses) $ \(pe, se) ->
+ ImpGen.copyDWIM (patElemName pe)
+ (map (`Imp.var` int32) (init gtids)) se []
+
+ forM_ (zip mapout_pes mapout_ses) $ \(pe, se) ->
+ ImpGen.copyDWIM (patElemName pe)
+ (map (`Imp.var` int32) gtids) se []
+
+nonsegmentedReduction :: Pattern ExplicitMemory
+ -> KernelSpace
+ -> Commutativity -> Lambda InKernel -> [SubExp]
+ -> Body InKernel
+ -> CallKernelGen ()
+nonsegmentedReduction segred_pat space comm red_op nes body = do
+ (base_constants, init_constants) <- kernelInitialisationSetSpace space $ return ()
+ let constants = base_constants { kernelThreadActive = true }
+ global_tid = kernelGlobalThreadId constants
+ (_, w) = last $ spaceDimensions space
+
+ let red_op_params = lambdaParams red_op
+ (red_acc_params, _) = splitAt (length nes) red_op_params
+ red_arrs <- forM red_acc_params $ \p ->
+ case paramAttr p of
+ MemArray pt shape _ (ArrayIn mem _) -> do
+ let shape' = Shape [spaceNumThreads space] <> shape
+ ImpGen.sArray "red_arr" pt shape' $
+ ArrayIn mem $ IxFun.iota $ map (primExpFromSubExp int32) $ shapeDims shape'
+ _ -> do
+ let pt = elemType $ paramType p
+ shape = Shape [spaceGroupSize space]
+ ImpGen.sAllocArray "red_arr" pt shape $ Space "local"
+
+ counter <-
+ ImpGen.sStaticArray "counter" (Space "device") int32 $
+ replicate 1 $ IntValue $ Int32Value 0
+
+ group_res_arrs <- forM (lambdaReturnType red_op) $ \t -> do
+ let pt = elemType t
+ shape = Shape [spaceNumGroups space] <> arrayShape t
+ ImpGen.sAllocArray "group_res_arr" pt shape $ Space "device"
+
+ sync_arr <- ImpGen.sAllocArray "sync_arr" Bool (Shape [intConst Int32 1]) $ Space "local"
+
+ num_threads <- dPrimV "num_threads" $ kernelNumThreads constants
+
+ sKernel constants "segred_nonseg" $ allThreads constants $ do
+ init_constants
+
+ -- Since this is the nonsegmented case, all outer segment IDs must
+ -- necessarily be 0.
+ let gtids = map fst $ spaceDimensions space
+ forM_ (init gtids) $ \v ->
+ v <-- 0
+
+ num_elements <- Imp.elements <$> ImpGen.compileSubExp w
+ let elems_per_thread = num_elements `quotRoundingUp` Imp.elements (kernelNumThreads constants)
+
+ (group_result_params, red_op_renamed) <-
+ reductionStageOne constants segred_pat num_elements
+ global_tid elems_per_thread num_threads
+ comm red_op nes red_arrs body
+
+ reductionStageTwo constants segred_pat 0 [0] 0
+ (kernelNumGroups constants) group_result_params red_acc_params red_op_renamed nes
+ 1 counter sync_arr group_res_arrs red_arrs
+
+hasMemoryAccesses :: Body InKernel -> ImpGen.ImpM InKernel Imp.KernelOp Bool
+hasMemoryAccesses body = or <$> mapM isArray (S.toList $ freeInBody body)
+ where isArray = fmap (not . primType) . lookupType
+
+smallSegmentsReduction :: Pattern ExplicitMemory
+ -> KernelSpace
+ -> Lambda InKernel -> [SubExp]
+ -> Body InKernel
+ -> CallKernelGen ()
+smallSegmentsReduction (Pattern _ segred_pes) space red_op nes body = do
+ (base_constants, init_constants) <- kernelInitialisationSetSpace space $ return ()
+ let constants = base_constants { kernelThreadActive = true }
+
+ let (gtids, dims) = unzip $ spaceDimensions space
+ dims' <- mapM ImpGen.compileSubExp dims
+
+ let segment_size = last dims'
+ num_segments = product $ init dims'
+ segments_per_group = kernelGroupSize constants `quot` segment_size
+ required_groups = num_segments `quotRoundingUp` segments_per_group
+
+ let red_op_params = lambdaParams red_op
+ (red_acc_params, _red_next_params) = splitAt (length nes) red_op_params
+ red_arrs <- forM red_acc_params $ \p ->
+ case paramAttr p of
+ MemArray pt shape _ (ArrayIn mem _) -> do
+ let shape' = Shape [spaceNumThreads space] <> shape
+ ImpGen.sArray "red_arr" pt shape' $
+ ArrayIn mem $ IxFun.iota $ map (primExpFromSubExp int32) $ shapeDims shape'
+ _ -> do
+ let pt = elemType $ paramType p
+ shape = Shape [spaceGroupSize space]
+ ImpGen.sAllocArray "red_arr" pt shape $ Space "local"
+
+ ImpGen.emit $ Imp.DebugPrint "num_segments" int32 num_segments
+ ImpGen.emit $ Imp.DebugPrint "segment_size" int32 segment_size
+ ImpGen.emit $ Imp.DebugPrint "segments_per_group" int32 segments_per_group
+ ImpGen.emit $ Imp.DebugPrint "required_groups" int32 required_groups
+
+ sKernel constants "segred_small" $ allThreads constants $ do
+ init_constants
+
+ -- We probably do not have enough actual workgroups to cover the
+ -- entire iteration space. Some groups thus have to perform double
+ -- duty; we put an outer loop to accomplish this.
+ virtualiseGroups constants required_groups $ \group_id' -> do
+ -- Compute the 'n' input indices. The outer 'n-1' correspond to
+ -- the segment ID, and are computed from the group id. The inner
+ -- is computed from the local thread id, and may be out-of-bounds.
+ let ltid = kernelLocalThreadId constants
+ segment_index = (ltid `quot` segment_size) + (group_id' * segments_per_group)
+ index_within_segment = ltid `rem` segment_size
+
+ zipWithM_ (<--) (init gtids) $ unflattenIndex (init dims') segment_index
+ last gtids <-- index_within_segment
+
+ let toLocalMemory ses =
+ forM_ (zip red_arrs ses) $ \(arr, se) -> do
+ se_t <- subExpType se
+ when (primType se_t) $
+ ImpGen.copyDWIM arr [ltid] se []
+
+ in_bounds =
+ ImpGen.compileStms mempty (stmsToList $ bodyStms body) $ do
+ let (red_res, map_res) = splitAt (length nes) $ bodyResult body
+
+ sComment "save results to be reduced" $
+ toLocalMemory red_res
+
+ sComment "save map-out results" $
+ forM_ (zip (drop (length nes) segred_pes) map_res) $ \(pe, se) ->
+ ImpGen.copyDWIM (patElemName pe) (map (`Imp.var` int32) gtids) se []
+
+ sComment "apply map function if in bounds" $
+ sIf (isActive (init $ zip gtids dims) .&&.
+ ltid .<. segment_size * segments_per_group) in_bounds (toLocalMemory nes)
+
+ sOp Imp.LocalBarrier
+
+ index_i <- newVName "index_i"
+ index_j <- newVName "index_j"
+ let crossesSegment from to =
+ (to-from) .>. (to `rem` segment_size)
+ red_op' = red_op { lambdaParams = Param index_i (MemPrim int32) :
+ Param index_j (MemPrim int32) :
+ lambdaParams red_op }
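+      -- Illustration with made-up numbers: for segment_size = 10,
+      -- crossesSegment 7 12 holds, since the distance 5 exceeds
+      -- 12 `rem` 10 = 2 (index 7 lies in the previous segment),
+      -- while crossesSegment 11 12 does not (distance 1 <= 2).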
+
+ sComment "perform segmented scan to imitate reduction" $
+ groupScan constants (Just crossesSegment) (segment_size*segments_per_group) red_op' red_arrs
+
+ sOp Imp.LocalBarrier
+
+ sComment "save final values of segments" $
+ sWhen (group_id' * segments_per_group + ltid .<. num_segments .&&.
+ ltid .<. segments_per_group) $
+ forM_ (zip segred_pes red_arrs) $ \(pe, arr) -> do
+ -- Figure out which segment result this thread should write...
+ let flat_segment_index = group_id' * segments_per_group + ltid
+ gtids' = unflattenIndex (init dims') flat_segment_index
+ ImpGen.copyDWIM (patElemName pe) gtids'
+ (Var arr) [(ltid+1) * segment_size - 1]
+
+largeSegmentsReduction :: Pattern ExplicitMemory
+ -> KernelSpace
+ -> Commutativity -> Lambda InKernel -> [SubExp]
+ -> Body InKernel
+ -> CallKernelGen ()
+largeSegmentsReduction segred_pat space comm red_op nes body = do
+ (base_constants, init_constants) <- kernelInitialisationSetSpace space $ return ()
+ let (gtids, dims) = unzip $ spaceDimensions space
+ dims' <- mapM ImpGen.compileSubExp dims
+ let segment_size = last dims'
+ num_segments = product $ init dims'
+
+ let (groups_per_segment, elems_per_thread) =
+ groupsPerSegmentAndElementsPerThread segment_size num_segments
+ (kernelNumGroups base_constants) (kernelGroupSize base_constants)
+ num_groups <- dPrimV "num_groups" $
+ groups_per_segment * num_segments
+
+ num_threads <- dPrimV "num_threads" $
+ Imp.var num_groups int32 * kernelGroupSize base_constants
+
+ threads_per_segment <- dPrimV "thread_per_segment" $
+ groups_per_segment * kernelGroupSize base_constants
+
+ let constants = base_constants
+ { kernelThreadActive = true
+ , kernelNumGroups = Imp.var num_groups int32
+ , kernelNumThreads = Imp.var num_threads int32
+ }
+
+ ImpGen.emit $ Imp.DebugPrint "num_segments" int32 num_segments
+ ImpGen.emit $ Imp.DebugPrint "segment_size" int32 segment_size
+ ImpGen.emit $ Imp.DebugPrint "num_groups" int32 (Imp.var num_groups int32)
+ ImpGen.emit $ Imp.DebugPrint "group_size" int32 (kernelGroupSize constants)
+ ImpGen.emit $ Imp.DebugPrint "elems_per_thread" int32 $ Imp.innerExp elems_per_thread
+ ImpGen.emit $ Imp.DebugPrint "groups_per_segment" int32 groups_per_segment
+
+ let red_op_params = lambdaParams red_op
+ (red_acc_params, _) = splitAt (length nes) red_op_params
+ red_arrs <- forM red_acc_params $ \p ->
+ case paramAttr p of
+ MemArray pt shape _ (ArrayIn mem _) -> do
+ let shape' = Shape [Var num_threads] <> shape
+ ImpGen.sArray "red_arr" pt shape' $
+ ArrayIn mem $ IxFun.iota $ map (primExpFromSubExp int32) $ shapeDims shape'
+ _ -> do
+ let pt = elemType $ paramType p
+ shape = Shape [spaceGroupSize space]
+ ImpGen.sAllocArray "red_arr" pt shape $ Space "local"
+
+ group_res_arrs <- forM (lambdaReturnType red_op) $ \t -> do
+ let pt = elemType t
+ shape = Shape [Var num_groups] <> arrayShape t
+ ImpGen.sAllocArray "group_res_arr" pt shape $ Space "device"
+
+ -- In principle we should have a counter for every segment. Since
+ -- the number of segments is a dynamic quantity, we would have to
+ -- allocate and zero out an array here, which is expensive.
+ -- However, we exploit the fact that the number of segments being
+ -- reduced at any point in time is limited by the number of
+ -- workgroups. If we bound the number of workgroups, we can get away
+ -- with using that many counters. FIXME: Is this limit checked
+ -- anywhere? There are other places in the compiler that will fail
+ -- if the group count exceeds the maximum group size, which is at
+ -- most 1024 anyway.
+ let num_counters = 1024
+ counter <-
+ ImpGen.sStaticArray "counter" (Space "device") int32 $
+ replicate num_counters $ IntValue $ Int32Value 0
+
+ sync_arr <- ImpGen.sAllocArray "sync_arr" Bool (Shape [intConst Int32 1]) $ Space "local"
+
+ sKernel constants "segred_large" $ allThreads constants $ do
+ init_constants
+ let segment_gtids = init gtids
+ group_id = kernelGroupId constants
+ group_size = kernelGroupSize constants
+ flat_segment_id = group_id `quot` groups_per_segment
+ local_tid = kernelLocalThreadId constants
+
+ global_tid = kernelGlobalThreadId constants
+ `rem` (group_size * groups_per_segment)
+ w = last dims
+ first_group_for_segment = flat_segment_id * groups_per_segment
+
+ zipWithM_ (<--) segment_gtids $ unflattenIndex (init dims') flat_segment_id
+ num_elements <- Imp.elements <$> ImpGen.compileSubExp w
+
+ (group_result_params, red_op_renamed) <-
+ reductionStageOne constants segred_pat num_elements
+ global_tid elems_per_thread threads_per_segment
+ comm red_op nes red_arrs body
+
+ let multiple_groups_per_segment =
+ reductionStageTwo constants segred_pat
+ flat_segment_id (map (`Imp.var` int32) segment_gtids)
+ first_group_for_segment groups_per_segment
+ group_result_params red_acc_params red_op_renamed
+ nes (fromIntegral num_counters) counter sync_arr group_res_arrs red_arrs
+
+ one_group_per_segment =
+ ImpGen.comment "first thread in group saves final result to memory" $
+ sWhen (local_tid .==. 0) $
+ forM_ (take (length nes) $ zip (patternNames segred_pat) group_result_params) $ \(v, p) ->
+ ImpGen.copyDWIM v (map (`Imp.var` int32) segment_gtids) (Var $ paramName p) []
+
+ sIf (groups_per_segment .==. 1) one_group_per_segment multiple_groups_per_segment
+
+groupsPerSegmentAndElementsPerThread :: Imp.Exp -> Imp.Exp -> Imp.Exp -> Imp.Exp
+ -> (Imp.Exp, Imp.Count Imp.Elements)
+groupsPerSegmentAndElementsPerThread segment_size num_segments num_groups_hint group_size =
+ let groups_per_segment =
+ num_groups_hint `quotRoundingUp` num_segments
+ elements_per_thread =
+ segment_size `quotRoundingUp` (group_size * groups_per_segment)
+ in (groups_per_segment, Imp.elements elements_per_thread)
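+
+-- Purely illustrative numbers: with segment_size = 100000,
+-- num_segments = 10, num_groups_hint = 128 and group_size = 256, this
+-- gives groups_per_segment = 13 (128 divided by 10, rounded up) and
+-- elements_per_thread = 31 (100000 divided by 13*256 = 3328, rounded
+-- up).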
+
+reductionStageOne :: KernelConstants
+ -> Pattern ExplicitMemory
+ -> Imp.Count Imp.Elements
+ -> Imp.Exp
+ -> Imp.Count Imp.Elements
+ -> VName
+ -> Commutativity
+ -> LambdaT InKernel
+ -> [SubExp]
+ -> [VName]
+ -> Body InKernel
+ -> InKernelGen ([LParam InKernel], Lambda InKernel)
+reductionStageOne constants (Pattern _ segred_pes) num_elements global_tid elems_per_thread threads_per_segment comm red_op nes red_arrs body = do
+
+ let red_op_params = lambdaParams red_op
+ (red_acc_params, red_next_params) = splitAt (length nes) red_op_params
+ (gtids, _dims) = unzip $ kernelDimensions constants
+ gtid = last gtids
+ local_tid = kernelLocalThreadId constants
+ index_in_segment = global_tid `quot` kernelGroupSize constants
+
+ -- Figure out how many elements this thread should process.
+ chunk_size <- dPrim "chunk_size" int32
+ let ordering = case comm of Commutative -> SplitStrided $ Var threads_per_segment
+ Noncommutative -> SplitContiguous
+ accesses_memory <- hasMemoryAccesses body
+ computeThreadChunkSize ordering global_tid elems_per_thread num_elements chunk_size
+
+ ImpGen.dScope Nothing $ scopeOfLParams $ lambdaParams red_op
+
+ forM_ (zip red_acc_params nes) $ \(p, ne) ->
+ ImpGen.copyDWIM (paramName p) [] ne []
+
+ red_op_renamed <- renameLambda red_op
+
+ let doTheReduction = do
+ ImpGen.comment "to reduce current chunk, first store our result to memory" $
+ forM_ (zip red_arrs red_acc_params) $ \(arr, p) ->
+ when (primType $ paramType p) $
+ ImpGen.copyDWIM arr [local_tid] (Var $ paramName p) []
+
+ sOp Imp.LocalBarrier
+
+ groupReduce constants (kernelGroupSize constants) red_op_renamed red_arrs
+
+ i <- newVName "i"
+ -- If this is a non-commutative reduction, each thread must run the
+ -- loop the same number of iterations, because we will be performing
+ -- a group-wide reduction in there.
+ let (bound, check_bounds) =
+ case comm of
+ Commutative -> (Imp.var chunk_size int32, id)
+ Noncommutative -> (Imp.innerExp elems_per_thread,
+ sWhen (Imp.var gtid int32 .<. Imp.innerExp num_elements))
+
+ sFor i Int32 bound $ do
+ gtid <--
+ case comm of
+ Commutative ->
+ global_tid +
+ Imp.var threads_per_segment int32 * Imp.var i int32
+ Noncommutative | accesses_memory ->
+ local_tid +
+ (index_in_segment * Imp.innerExp elems_per_thread + Imp.var i int32) *
+ kernelGroupSize constants
+ Noncommutative ->
+ Imp.var i int32 +
+ global_tid * Imp.innerExp elems_per_thread
+
+ check_bounds $ sComment "apply map function" $
+ ImpGen.compileStms mempty (stmsToList $ bodyStms body) $ do
+ let (red_res, map_res) = splitAt (length nes) $ bodyResult body
+
+ sComment "save results to be reduced" $
+ forM_ (zip red_next_params red_res) $ \(p, se) ->
+ ImpGen.copyDWIM (paramName p) [] se []
+
+ sComment "save map-out results" $
+ forM_ (zip (drop (length nes) segred_pes) map_res) $ \(pe, se) ->
+ ImpGen.copyDWIM (patElemName pe) (map (`Imp.var` int32) gtids) se []
+
+ sComment "apply reduction operator" $
+ ImpGen.compileBody' red_acc_params $ lambdaBody red_op
+
+ case comm of
+ Noncommutative | accesses_memory -> do
+ doTheReduction
+ sComment "first thread takes carry-out; others neutral element" $ do
+ let carry_out =
+ forM_ (zip red_acc_params $ lambdaParams red_op_renamed) $ \(p_to, p_from) ->
+ ImpGen.copyDWIM (paramName p_to) [] (Var $ paramName p_from) []
+ reset_to_neutral =
+ forM_ (zip red_acc_params nes) $ \(p, ne) ->
+ ImpGen.copyDWIM (paramName p) [] ne []
+ sIf (local_tid .==. 0) carry_out reset_to_neutral
+ _ ->
+ return ()
+
+ group_result_params <- case comm of
+ Noncommutative | accesses_memory ->
+ return red_acc_params
+
+ _ -> do
+ doTheReduction
+
+ return $ lambdaParams red_op_renamed
+
+ return (group_result_params, red_op_renamed)
+
+reductionStageTwo :: KernelConstants
+ -> Pattern ExplicitMemory
+ -> Imp.Exp
+ -> [Imp.Exp]
+ -> Imp.Exp
+ -> PrimExp Imp.ExpLeaf
+ -> [LParam InKernel]
+ -> [LParam InKernel]
+ -> Lambda InKernel
+ -> [SubExp]
+ -> Imp.Exp
+ -> VName
+ -> VName
+ -> [VName]
+ -> [VName]
+ -> InKernelGen ()
+reductionStageTwo constants segred_pat
+ flat_segment_id segment_gtids first_group_for_segment groups_per_segment
+ group_result_params red_acc_params
+ red_op_renamed nes
+ num_counters counter sync_arr group_res_arrs red_arrs = do
+ let local_tid = kernelLocalThreadId constants
+ group_id = kernelGroupId constants
+ group_size = kernelGroupSize constants
+ old_counter <- dPrim "old_counter" int32
+ (counter_mem, _, counter_offset) <- ImpGen.fullyIndexArray counter [flat_segment_id `rem` num_counters]
+ ImpGen.comment "first thread in group saves group result to memory" $
+ sWhen (local_tid .==. 0) $ do
+ forM_ (take (length nes) $ zip group_res_arrs group_result_params) $ \(v, p) ->
+ ImpGen.copyDWIM v [group_id] (Var $ paramName p) []
+ sOp Imp.MemFence
+ -- Increment the counter, thus stating that our result is
+ -- available.
+ sOp $ Imp.Atomic $ Imp.AtomicAdd old_counter counter_mem counter_offset 1
+ -- Now check if we were the last group to write our result. If
+ -- so, it is our responsibility to produce the final result.
+ ImpGen.sWrite sync_arr [0] $ Imp.var old_counter int32 .==. groups_per_segment - 1
+
+ sOp Imp.LocalBarrier
+
+ is_last_group <- dPrim "is_last_group" Bool
+ ImpGen.copyDWIM is_last_group [] (Var sync_arr) [0]
+ sWhen (Imp.var is_last_group Bool) $ do
+ -- The final group has written its result (and it was
+ -- us!), so read in all the group results and perform the
+ -- final stage of the reduction. But first, we reset the
+ -- counter so it is ready for next time. This is done
+ -- with an atomic to avoid warnings about write/write
+ -- races in oclgrind.
+ sWhen (local_tid .==. 0) $
+ sOp $ Imp.Atomic $ Imp.AtomicAdd old_counter counter_mem counter_offset $
+ negate groups_per_segment
+ ImpGen.comment "read in the per-group-results" $
+ forM_ (zip4 red_acc_params red_arrs nes group_res_arrs) $
+ \(p, arr, ne, group_res_arr) -> do
+ let load_group_result =
+ ImpGen.copyDWIM (paramName p) []
+ (Var group_res_arr) [first_group_for_segment + local_tid]
+ load_neutral_element =
+ ImpGen.copyDWIM (paramName p) [] ne []
+ ImpGen.sIf (local_tid .<. groups_per_segment)
+ load_group_result load_neutral_element
+ when (primType $ paramType p) $
+ ImpGen.copyDWIM arr [local_tid] (Var $ paramName p) []
+
+ sComment "reduce the per-group results" $ do
+ groupReduce constants group_size red_op_renamed red_arrs
+
+ sComment "and back to memory with the final result" $
+ sWhen (local_tid .==. 0) $
+ forM_ (take (length nes) $ zip (patternNames segred_pat) $
+ lambdaParams red_op_renamed) $ \(v, p) ->
+ ImpGen.copyDWIM v segment_gtids (Var $ paramName p) []
diff --git a/src/Futhark/CodeGen/ImpGen/Kernels/ToOpenCL.hs b/src/Futhark/CodeGen/ImpGen/Kernels/ToOpenCL.hs
index ca564cc..bfcbaf5 100644
--- a/src/Futhark/CodeGen/ImpGen/Kernels/ToOpenCL.hs
+++ b/src/Futhark/CodeGen/ImpGen/Kernels/ToOpenCL.hs
@@ -4,6 +4,7 @@
-- kernels to imperative code with OpenCL calls.
module Futhark.CodeGen.ImpGen.Kernels.ToOpenCL
( kernelsToOpenCL
+ , kernelsToCUDA
)
where
@@ -14,10 +15,10 @@ import Control.Monad.Reader
import Data.Maybe
import qualified Data.Set as S
import qualified Data.Map.Strict as M
-import qualified Data.Semigroup as Sem
import qualified Language.C.Syntax as C
import qualified Language.C.Quote.OpenCL as C
+import qualified Language.C.Quote.CUDA as CUDAC
import Futhark.Error
import qualified Futhark.CodeGen.Backends.GenericC as GenericC
@@ -30,19 +31,27 @@ import Futhark.MonadFreshNames
import Futhark.Util (zEncodeString)
import Futhark.Util.Pretty (pretty, prettyOneLine)
+kernelsToCUDA, kernelsToOpenCL :: ImpKernels.Program
+ -> Either InternalError ImpOpenCL.Program
+kernelsToCUDA = translateKernels TargetCUDA
+kernelsToOpenCL = translateKernels TargetOpenCL
+
-- | Translate a kernels-program to an OpenCL-program.
-kernelsToOpenCL :: ImpKernels.Program
- -> Either InternalError ImpOpenCL.Program
-kernelsToOpenCL (ImpKernels.Functions funs) = do
+translateKernels :: KernelTarget
+ -> ImpKernels.Program
+ -> Either InternalError ImpOpenCL.Program
+translateKernels target (ImpKernels.Functions funs) = do
(prog', ToOpenCL extra_funs kernels requirements sizes) <-
runWriterT $ fmap Functions $ forM funs $ \(fname, fun) ->
- (fname,) <$> runReaderT (traverse onHostOp fun) fname
+ (fname,) <$> runReaderT (traverse (onHostOp target) fun) fname
let kernel_names = M.keys kernels
opencl_code = openClCode $ M.elems kernels
- opencl_prelude = pretty $ genOpenClPrelude requirements
+ opencl_prelude = pretty $ genPrelude target requirements
return $ ImpOpenCL.Program opencl_code opencl_prelude kernel_names
(S.toList $ kernelUsedTypes requirements) sizes $
ImpOpenCL.Functions (M.toList extra_funs) <> prog'
+ where genPrelude TargetOpenCL = genOpenClPrelude
+ genPrelude TargetCUDA = genCUDAPrelude
pointerQuals :: Monad m => String -> m [C.TypeQual]
pointerQuals "global" = return [C.ctyquals|__global|]
@@ -61,75 +70,42 @@ data OpenClRequirements =
, _kernelConstants :: [(VName, KernelConstExp)]
}
-instance Sem.Semigroup OpenClRequirements where
+instance Semigroup OpenClRequirements where
OpenClRequirements ts1 consts1 <> OpenClRequirements ts2 consts2 =
OpenClRequirements (ts1 <> ts2) (consts1 <> consts2)
instance Monoid OpenClRequirements where
mempty = OpenClRequirements mempty mempty
- mappend = (Sem.<>)
data ToOpenCL = ToOpenCL { clExtraFuns :: M.Map Name ImpOpenCL.Function
, clKernels :: M.Map KernelName C.Func
, clRequirements :: OpenClRequirements
- , clSizes :: M.Map VName (SizeClass, Name)
+ , clSizes :: M.Map Name SizeClass
}
-instance Sem.Semigroup ToOpenCL where
+instance Semigroup ToOpenCL where
ToOpenCL f1 k1 r1 sz1 <> ToOpenCL f2 k2 r2 sz2 =
ToOpenCL (f1<>f2) (k1<>k2) (r1<>r2) (sz1<>sz2)
instance Monoid ToOpenCL where
mempty = ToOpenCL mempty mempty mempty mempty
- mappend = (Sem.<>)
type OnKernelM = ReaderT Name (WriterT ToOpenCL (Either InternalError))
-onHostOp :: HostOp -> OnKernelM OpenCL
-onHostOp (CallKernel k) = onKernel k
-onHostOp (ImpKernels.GetSize v key size_class) = do
- fname <- ask
- tell mempty { clSizes = M.singleton key (size_class, fname) }
+onHostOp :: KernelTarget -> HostOp -> OnKernelM OpenCL
+onHostOp target (CallKernel k) = onKernel target k
+onHostOp _ (ImpKernels.GetSize v key size_class) = do
+ tell mempty { clSizes = M.singleton key size_class }
return $ ImpOpenCL.GetSize v key
-onHostOp (ImpKernels.CmpSizeLe v key size_class x) = do
- fname <- ask
- tell mempty { clSizes = M.singleton key (size_class, fname) }
+onHostOp _ (ImpKernels.CmpSizeLe v key size_class x) = do
+ tell mempty { clSizes = M.singleton key size_class }
return $ ImpOpenCL.CmpSizeLe v key x
-onHostOp (ImpKernels.GetSizeMax v size_class) =
+onHostOp _ (ImpKernels.GetSizeMax v size_class) =
return $ ImpOpenCL.GetSizeMax v size_class
-onKernel :: CallKernel -> OnKernelM OpenCL
-
-onKernel called@(Map kernel) = do
- let (funbody, _) =
- GenericC.runCompilerM (Functions []) inKernelOperations blankNameSource mempty $ do
- size <- GenericC.compileExp $ mapKernelSize kernel
- let check = [C.citem|if ($id:(mapKernelThreadNum kernel) >= $exp:size) return;|]
- body <- GenericC.blockScope $ GenericC.compileCode $ mapKernelBody kernel
- return $ check : body
-
- params = mapMaybe useAsParam $ mapKernelUses kernel
-
- tell mempty
- { clExtraFuns = mempty
- , clKernels = M.singleton (mapKernelName kernel)
- [C.cfun|__kernel void $id:(mapKernelName kernel) ($params:params) {
- const uint $id:(mapKernelThreadNum kernel) = get_global_id(0);
- $items:funbody
- }|]
- , clRequirements = OpenClRequirements
- (typesInKernel called)
- (mapMaybe useAsConst $ mapKernelUses kernel)
- }
-
- return $ LaunchKernel
- (calledKernelName called) (kernelArgs called) kernel_size workgroup_size
-
- where kernel_size = [sizeToExp (mapKernelNumGroups kernel) *
- sizeToExp (mapKernelGroupSize kernel)]
- workgroup_size = [sizeToExp $ mapKernelGroupSize kernel]
-
-onKernel called@(AnyKernel kernel) = do
+onKernel :: KernelTarget -> Kernel -> OnKernelM OpenCL
+
+onKernel target kernel = do
let (kernel_body, _) =
GenericC.runCompilerM (Functions []) inKernelOperations blankNameSource mempty $
GenericC.blockScope $ GenericC.compileCode $ kernelBody kernel
@@ -139,35 +115,64 @@ onKernel called@(AnyKernel kernel) = do
(local_memory_params, local_memory_init) =
unzip $
flip evalState (blankNameSource :: VNameSource) $
- mapM prepareLocalMemory $ kernelLocalMemory kernel
-
- params = catMaybes local_memory_params ++ use_params
+ mapM (prepareLocalMemory target) $ kernelLocalMemory kernel
+
+ -- CUDA has very strict restrictions on the number of blocks
+ -- permitted along the 'y' and 'z' dimensions of the grid
+ -- (1<<16). To work around this, we are going to dynamically
+ -- permute the block dimensions to move the largest one to the
+ -- 'x' dimension, which has a higher limit (1<<31). This means
+ -- we need to extend the kernel with extra parameters that
+ -- contain information about this permutation, but we only do
+ -- this for multidimensional kernels (at the time of this
+ -- writing, only transposes). The corresponding arguments are
+ -- added automatically in CCUDA.hs.
+ (perm_params, block_dim_init) =
+ case (target, num_groups) of
+ (TargetCUDA, [_, _, _]) -> ([[C.cparam|const int block_dim0|],
+ [C.cparam|const int block_dim1|],
+ [C.cparam|const int block_dim2|]],
+ mempty)
+ _ -> (mempty,
+ [[C.citem|const int block_dim0 = 0;|],
+ [C.citem|const int block_dim1 = 1;|],
+ [C.citem|const int block_dim2 = 2;|]])
+
+ params = perm_params ++ catMaybes local_memory_params ++ use_params
tell mempty { clExtraFuns = mempty
- , clKernels = M.singleton name
- [C.cfun|__kernel void $id:name ($params:params) {
- $items:local_memory_init
- $items:kernel_body
- }|]
- , clRequirements = OpenClRequirements
- (typesInKernel called)
- (mapMaybe useAsConst $ kernelUses kernel)
- }
-
- return $ LaunchKernel
- (calledKernelName called) (kernelArgs called) kernel_size workgroup_size
-
- where prepareLocalMemory (mem, Left _) = do
+ , clKernels = M.singleton name
+ [C.cfun|__kernel void $id:name ($params:params) {
+ $items:block_dim_init
+ $items:local_memory_init
+ $items:kernel_body
+ }|]
+ , clRequirements = OpenClRequirements
+ (typesInKernel kernel)
+ (mapMaybe useAsConst $ kernelUses kernel)
+ }
+
+ return $ LaunchKernel name (kernelArgs kernel) num_groups group_size
+ where name = nameToString $ kernelName kernel
+ num_groups = kernelNumGroups kernel
+ group_size = kernelGroupSize kernel
+
+ prepareLocalMemory TargetOpenCL (mem, Left _) = do
mem_aligned <- newVName $ baseString mem ++ "_aligned"
return (Just [C.cparam|__local volatile typename int64_t* $id:mem_aligned|],
[C.citem|__local volatile char* restrict $id:mem = $id:mem_aligned;|])
- prepareLocalMemory (mem, Right size) = do
+ prepareLocalMemory TargetOpenCL (mem, Right size) = do
let size' = compilePrimExp size
return (Nothing,
[C.citem|ALIGNED_LOCAL_MEMORY($id:mem, $exp:size');|])
- name = calledKernelName called
- kernel_size = zipWith (*) (kernelNumGroups kernel) (kernelGroupSize kernel)
- workgroup_size = kernelGroupSize kernel
+ prepareLocalMemory TargetCUDA (mem, Left _) = do
+ param <- newVName $ baseString mem ++ "_offset"
+ return (Just [C.cparam|uint $id:param|],
+ [C.citem|volatile char *$id:mem = &shared_mem[$id:param];|])
+ prepareLocalMemory TargetCUDA (mem, Right size) = do
+ let size' = compilePrimExp size
+ return (Nothing,
+ [CUDAC.citem|__shared__ volatile char $id:mem[$exp:size'];|])
useAsParam :: KernelUse -> Maybe C.Param
useAsParam (ScalarUse name bt) =
@@ -222,28 +227,154 @@ $esc:("#define ALIGNED_LOCAL_MEMORY(m,size) __local unsigned char m[size] __attr
(if uses_float64 then cFloat64Ops ++ cFloat64Funs ++ cFloatConvOps else []) ++
[ [C.cedecl|$esc:def|] | def <- map constToDefine consts ]
where uses_float64 = FloatType Float64 `S.member` ts
- constToDefine (name, e) =
- let e' = compilePrimExp e
- in unwords ["#define", zEncodeString (pretty name), "("++prettyOneLine e'++")"]
+
+
+cudaAtomicOps :: [C.Definition]
+cudaAtomicOps = (return mkOp <*> opNames <*> types) ++ extraOps
+ where
+ mkOp (clName, cuName) t =
+ [C.cedecl|static inline $ty:t $id:clName(volatile $ty:t *p, $ty:t val) {
+ return $id:cuName(($ty:t *)p, val);
+ }|]
+ types = [ [C.cty|int|]
+ , [C.cty|unsigned int|]
+ , [C.cty|unsigned long long|]
+ ]
+ opNames = [ ("atomic_add", "atomicAdd")
+ , ("atomic_max", "atomicMax")
+ , ("atomic_min", "atomicMin")
+ , ("atomic_and", "atomicAnd")
+ , ("atomic_or", "atomicOr")
+ , ("atomic_xor", "atomicXor")
+ , ("atomic_xchg", "atomicExch")
+ ]
+ extraOps =
+ [ [C.cedecl|static inline $ty:t atomic_cmpxchg(volatile $ty:t *p, $ty:t cmp, $ty:t val) {
+ return atomicCAS(($ty:t *)p, cmp, val);
+ }|] | t <- types]
+
+genCUDAPrelude :: OpenClRequirements -> [C.Definition]
+genCUDAPrelude (OpenClRequirements _ consts) =
+ cudafy ++ cudaAtomicOps ++ defs ++ ops
+ where ops = cIntOps ++ cFloat32Ops ++ cFloat32Funs ++ cFloat64Ops
+ ++ cFloat64Funs ++ cFloatConvOps
+ defs = [ [C.cedecl|$esc:def|] | def <- map constToDefine consts ]
+ cudafy = [CUDAC.cunit|
+typedef char int8_t;
+typedef short int16_t;
+typedef int int32_t;
+typedef long int64_t;
+typedef unsigned char uint8_t;
+typedef unsigned short uint16_t;
+typedef unsigned int uint32_t;
+typedef unsigned long long uint64_t;
+typedef uint8_t uchar;
+typedef uint16_t ushort;
+typedef uint32_t uint;
+typedef uint64_t ulong;
+$esc:("#define __kernel extern \"C\" __global__ __launch_bounds__(MAX_THREADS_PER_BLOCK)")
+$esc:("#define __global")
+$esc:("#define __local")
+$esc:("#define __private")
+$esc:("#define __constant")
+$esc:("#define __write_only")
+$esc:("#define __read_only")
+
+static inline int get_group_id_fn(int block_dim0, int block_dim1, int block_dim2, int d)
+{
+ switch (d) {
+ case 0: d = block_dim0; break;
+ case 1: d = block_dim1; break;
+ case 2: d = block_dim2; break;
+ }
+ switch (d) {
+ case 0: return blockIdx.x;
+ case 1: return blockIdx.y;
+ case 2: return blockIdx.z;
+ default: return 0;
+ }
+}
+$esc:("#define get_group_id(d) get_group_id_fn(block_dim0, block_dim1, block_dim2, d)")
+
+static inline int get_num_groups_fn(int block_dim0, int block_dim1, int block_dim2, int d)
+{
+ switch (d) {
+ case 0: d = block_dim0; break;
+ case 1: d = block_dim1; break;
+ case 2: d = block_dim2; break;
+ }
+ switch(d) {
+ case 0: return gridDim.x;
+ case 1: return gridDim.y;
+ case 2: return gridDim.z;
+ default: return 0;
+ }
+}
+$esc:("#define get_num_groups(d) get_num_groups_fn(block_dim0, block_dim1, block_dim2, d)")
+
+static inline int get_local_id(int d)
+{
+ switch (d) {
+ case 0: return threadIdx.x;
+ case 1: return threadIdx.y;
+ case 2: return threadIdx.z;
+ default: return 0;
+ }
+}
+
+static inline int get_local_size(int d)
+{
+ switch (d) {
+ case 0: return blockDim.x;
+ case 1: return blockDim.y;
+ case 2: return blockDim.z;
+ default: return 0;
+ }
+}
+
+static inline int get_global_id_fn(int block_dim0, int block_dim1, int block_dim2, int d)
+{
+ return get_group_id(d) * get_local_size(d) + get_local_id(d);
+}
+$esc:("#define get_global_id(d) get_global_id_fn(block_dim0, block_dim1, block_dim2, d)")
+
+static inline int get_global_size(int block_dim0, int block_dim1, int block_dim2, int d)
+{
+ return get_num_groups(d) * get_local_size(d);
+}
+
+$esc:("#define CLK_LOCAL_MEM_FENCE 1")
+$esc:("#define CLK_GLOBAL_MEM_FENCE 2")
+static inline void barrier(int x)
+{
+ __syncthreads();
+}
+static inline void mem_fence(int x)
+{
+ if (x == CLK_LOCAL_MEM_FENCE) {
+ __threadfence_block();
+ } else {
+ __threadfence();
+ }
+}
+$esc:("#define NAN (0.0/0.0)")
+$esc:("#define INFINITY (1.0/0.0)")
+extern volatile __shared__ char shared_mem[];
+|]
+
+constToDefine :: (VName, KernelConstExp) -> String
+constToDefine (name, e) =
+ let e' = compilePrimExp e
+ in unwords ["#define", zEncodeString (pretty name), "("++prettyOneLine e'++")"]
+
compilePrimExp :: PrimExp KernelConst -> C.Exp
compilePrimExp e = runIdentity $ GenericC.compilePrimExp compileKernelConst e
- where compileKernelConst (SizeConst key) = return [C.cexp|$id:(pretty key)|]
-
-mapKernelName :: MapKernel -> String
-mapKernelName k = "kernel_"++ mapKernelDesc k ++ "_" ++
- show (baseTag $ mapKernelThreadNum k)
-
-calledKernelName :: CallKernel -> String
-calledKernelName (Map k) =
- mapKernelName k
-calledKernelName (AnyKernel k) =
- nameToString $ kernelName k
-
-kernelArgs :: CallKernel -> [KernelArg]
-kernelArgs (Map kernel) =
- mapMaybe useToArg $ mapKernelUses kernel
-kernelArgs (AnyKernel kernel) =
+ where compileKernelConst (SizeConst key) =
+ return [C.cexp|$id:(zEncodeString (pretty key))|]
+
+kernelArgs :: Kernel -> [KernelArg]
+kernelArgs kernel =
mapMaybe (fmap (SharedMemoryKArg . memSizeToExp) . localMemorySize)
(kernelLocalMemory kernel) ++
mapMaybe useToArg (kernelUses kernel)
@@ -277,8 +408,10 @@ inKernelOperations = GenericC.Operations
GenericC.stm [C.cstm|$id:v = get_global_size($int:i);|]
kernelOps (GetLockstepWidth v) =
GenericC.stm [C.cstm|$id:v = LOCKSTEP_WIDTH;|]
- kernelOps Barrier =
+ kernelOps LocalBarrier =
GenericC.stm [C.cstm|barrier(CLK_LOCAL_MEM_FENCE);|]
+ kernelOps GlobalBarrier =
+ GenericC.stm [C.cstm|barrier(CLK_GLOBAL_MEM_FENCE);|]
kernelOps MemFence =
GenericC.stm [C.cstm|mem_fence(CLK_GLOBAL_MEM_FENCE);|]
kernelOps (Atomic aop) = atomicOps aop
@@ -361,9 +494,8 @@ useToArg (MemoryUse mem) = Just $ MemKArg mem
useToArg (ScalarUse v bt) = Just $ ValueKArg (LeafExp (ScalarVar v) bt) bt
useToArg ConstUse{} = Nothing
-typesInKernel :: CallKernel -> S.Set PrimType
-typesInKernel (Map kernel) = typesInCode $ mapKernelBody kernel
-typesInKernel (AnyKernel kernel) = typesInCode $ kernelBody kernel
+typesInKernel :: Kernel -> S.Set PrimType
+typesInKernel kernel = typesInCode $ kernelBody kernel
typesInCode :: ImpKernels.KernelCode -> S.Set PrimType
typesInCode Skip = mempty
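
For readers not used to the quasi-quoting above, the following is an illustrative expansion of what the templates in cudaAtomicOps and the CUDA prelude produce; it is not part of the patch itself, and the kernel name map_kernel is only a hypothetical example.

/* Sketch of one generated wrapper: mkOp ("atomic_add", "atomicAdd")
   instantiated at the `int` type in cudaAtomicOps above yields a
   definition equivalent to this, so kernel code can keep using the
   OpenCL-style name when compiled for the CUDA target. */
static inline int atomic_add(volatile int *p, int val) {
  return atomicAdd((int *)p, val);
}

/* With the prelude's defines in scope, an OpenCL-style kernel head such
   as (hypothetical name)
       __kernel void map_kernel(const int block_dim0, const int block_dim1,
                                const int block_dim2, ...)
   expands to
       extern "C" __global__ __launch_bounds__(MAX_THREADS_PER_BLOCK)
       void map_kernel(const int block_dim0, ...)
   and a use of get_group_id(0) in its body becomes
       get_group_id_fn(block_dim0, block_dim1, block_dim2, 0),
   which undoes the host-side permutation of the grid dimensions
   described in the comment about CUDA's block-count limits. */
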
diff --git a/src/Futhark/CodeGen/ImpGen/Kernels/Transpose.hs b/src/Futhark/CodeGen/ImpGen/Kernels/Transpose.hs
index 9caec80..24de3f4 100644
--- a/src/Futhark/CodeGen/ImpGen/Kernels/Transpose.hs
+++ b/src/Futhark/CodeGen/ImpGen/Kernels/Transpose.hs
@@ -7,7 +7,6 @@ module Futhark.CodeGen.ImpGen.Kernels.Transpose
where
import qualified Data.Set as S
-import Data.Semigroup ((<>))
import Prelude hiding (quot, rem)
@@ -119,7 +118,7 @@ mapTranspose block_dim args t kind =
t (Space "local") Nonvolatile $
index idata (bytes $ (v32 idata_offset + v32 index_in) * tsize)
t (Space "global") Nonvolatile]
- , Op Barrier
+ , Op LocalBarrier
, SetScalar x_index $ v32 get_group_id_1 * tile_dim + v32 get_local_id_0
, SetScalar y_index $ v32 get_group_id_0 * tile_dim + v32 get_local_id_1
, when (v32 x_index .<. height) $
@@ -203,7 +202,7 @@ mapTranspose block_dim args t kind =
t (Space "local") Nonvolatile $
index idata (bytes $ (v32 idata_offset + v32 index_in) * tsize)
t (Space "global") Nonvolatile
- , Op Barrier
+ , Op LocalBarrier
, SetScalar x_index x_out_index
, SetScalar y_index y_out_index
, dec index_out $ v32 y_index * height + v32 x_index
diff --git a/src/Futhark/CodeGen/ImpGen/Sequential.hs b/src/Futhark/CodeGen/ImpGen/Sequential.hs
index 90d8d84..f124fc6 100644
--- a/src/Futhark/CodeGen/ImpGen/Sequential.hs
+++ b/src/Futhark/CodeGen/ImpGen/Sequential.hs
@@ -11,7 +11,7 @@ import Futhark.Representation.ExplicitMemory
import Futhark.MonadFreshNames
compileProg :: MonadFreshNames m => Prog ExplicitMemory -> m (Either InternalError Imp.Program)
-compileProg = ImpGen.compileProg ops Imp.DefaultSpace
+compileProg = ImpGen.compileProg ops Imp.DefaultSpace []
where ops = ImpGen.defaultOperations opCompiler
opCompiler :: ImpGen.OpCompiler ExplicitMemory Imp.Sequential
opCompiler dest (Alloc e space) =
diff --git a/src/Futhark/CodeGen/OpenCL/Kernels.hs b/src/Futhark/CodeGen/OpenCL/Kernels.hs
index 6ba2174..25fde52 100644
--- a/src/Futhark/CodeGen/OpenCL/Kernels.hs
+++ b/src/Futhark/CodeGen/OpenCL/Kernels.hs
@@ -44,7 +44,7 @@ sizeHeuristicsTable =
[ SizeHeuristic "NVIDIA CUDA" DeviceGPU LockstepWidth $ HeuristicConst 32
, SizeHeuristic "AMD Accelerated Parallel Processing" DeviceGPU LockstepWidth $ HeuristicConst 64
, SizeHeuristic "" DeviceGPU LockstepWidth $ HeuristicConst 1
- , SizeHeuristic "" DeviceGPU NumGroups $ HeuristicConst 128
+ , SizeHeuristic "" DeviceGPU NumGroups $ HeuristicConst 256
, SizeHeuristic "" DeviceGPU GroupSize $ HeuristicConst 256
, SizeHeuristic "" DeviceGPU TileSize $ HeuristicConst 32
diff --git a/src/Futhark/Compiler.hs b/src/Futhark/Compiler.hs
index 294387a..9dfd51b 100644
--- a/src/Futhark/Compiler.hs
+++ b/src/Futhark/Compiler.hs
@@ -8,7 +8,6 @@ module Futhark.Compiler
, FutharkConfig (..)
, newFutharkConfig
, dumpError
- , reportingIOErrors
, module Futhark.Compiler.Program
, readProgram
@@ -16,14 +15,11 @@ module Futhark.Compiler
)
where
-import Data.Semigroup ((<>))
-import Control.Exception
import Control.Monad
import Control.Monad.Reader
import Control.Monad.Except
import System.Exit (exitWith, ExitCode(..))
import System.IO
-import qualified Data.Text as T
import qualified Data.Text.IO as T
import Futhark.Internalise
@@ -70,23 +66,6 @@ dumpError config err =
maybe (T.hPutStr stderr) T.writeFile
(snd (futharkVerbose config)) $ info <> "\n"
--- | Catch all IO exceptions and print a better error message if they
--- happen. Use this at the top-level of all Futhark compiler
--- frontends.
-reportingIOErrors :: IO () -> IO ()
-reportingIOErrors = flip catches [Handler onExit, Handler onError]
- where onExit :: ExitCode -> IO ()
- onExit = throwIO
- onError :: SomeException -> IO ()
- onError e
- | Just UserInterrupt <- asyncExceptionFromException e =
- return () -- This corresponds to CTRL-C, which is not an error.
- | otherwise = do
- T.hPutStrLn stderr "Internal compiler error (unhandled IO exception)."
- T.hPutStrLn stderr "Please report this at https://github.com/diku-dk/futhark/issues"
- T.hPutStrLn stderr $ T.pack $ show e
- exitWith $ ExitFailure 1
-
runCompilerOnProgram :: FutharkConfig
-> Pipeline I.SOACS lore
-> Action lore
diff --git a/src/Futhark/Compiler/CLI.hs b/src/Futhark/Compiler/CLI.hs
index 9406ff6..a54f59f 100644
--- a/src/Futhark/Compiler/CLI.hs
+++ b/src/Futhark/Compiler/CLI.hs
@@ -30,13 +30,14 @@ compilerMain :: cfg -- ^ Initial configuration.
-> Pipeline SOACS lore -- ^ The pipeline to use.
-> (cfg -> CompilerMode -> FilePath -> Prog lore -> FutharkM ())
-- ^ The action to take on the result of the pipeline.
+ -> String -- ^ Program name
+ -> [String] -- ^ Command line arguments.
-> IO ()
-compilerMain cfg cfg_opts name desc pipeline doIt = do
+compilerMain cfg cfg_opts name desc pipeline doIt prog args = do
hSetEncoding stdout utf8
hSetEncoding stderr utf8
- reportingIOErrors $
- mainWithOptions (newCompilerConfig cfg) (commandLineOptions ++ map wrapOption cfg_opts)
- "options... program" inspectNonOptions
+ mainWithOptions (newCompilerConfig cfg) (commandLineOptions ++ map wrapOption cfg_opts)
+ "options... program" inspectNonOptions prog args
where inspectNonOptions [file] config = Just $ compile config file
inspectNonOptions _ _ = Nothing
diff --git a/src/Futhark/Compiler/Program.hs b/src/Futhark/Compiler/Program.hs
index db82f37..a3e08a3 100644
--- a/src/Futhark/Compiler/Program.hs
+++ b/src/Futhark/Compiler/Program.hs
@@ -15,7 +15,6 @@ module Futhark.Compiler.Program
)
where
-import Data.Semigroup ((<>))
import Data.Loc
import Control.Exception
import Control.Monad
diff --git a/src/Futhark/Doc/Generator.hs b/src/Futhark/Doc/Generator.hs
index b9a7f7a..259804a 100644
--- a/src/Futhark/Doc/Generator.hs
+++ b/src/Futhark/Doc/Generator.hs
@@ -376,7 +376,9 @@ synopsisValBindBind :: (VName, BoundV) -> DocM Html
synopsisValBindBind (name, BoundV tps t) = do
let tps' = map typeParamHtml tps
t' <- typeHtml t
- return $ keyword "val " <> vnameHtml name <> joinBy " " tps' <> ": " <> t'
+ return $
+ keyword "val " <> vnameHtml name <>
+ mconcat (map (" "<>) tps') <> ": " <> t'
prettyEnum :: [Name] -> Html
prettyEnum cs = pipes $ map (("#"<>) . renderName) cs
@@ -396,7 +398,7 @@ typeHtml t = case t of
targs' <- mapM typeArgHtml targs
et' <- typeNameHtml et
return $ prettyU u <> et' <> joinBy " " targs'
- Array et shape u -> do
+ Array _ u et shape -> do
shape' <- prettyShapeDecl shape
et' <- prettyElem et
return $ prettyU u <> shape' <> et'
@@ -410,9 +412,9 @@ typeHtml t = case t of
t1' <> " -> " <> t2'
Enum cs -> return $ prettyEnum cs
-prettyElem :: ArrayElemTypeBase (DimDecl VName) () -> DocM Html
-prettyElem (ArrayPrimElem et _) = return $ primTypeHtml et
-prettyElem (ArrayPolyElem et targs _) = do
+prettyElem :: ArrayElemTypeBase (DimDecl VName) -> DocM Html
+prettyElem (ArrayPrimElem et) = return $ primTypeHtml et
+prettyElem (ArrayPolyElem et targs) = do
targs' <- mapM typeArgHtml targs
return $ prettyTypeName et <> joinBy " " targs'
prettyElem (ArrayRecordElem fs)
@@ -423,18 +425,18 @@ prettyElem (ArrayRecordElem fs)
where ppField (name, tp) = do
tp' <- prettyRecordElem tp
return $ toHtml (nameToString name) <> ": " <> tp'
-prettyElem (ArrayEnumElem cs _ ) = return $ braces $ prettyEnum cs
+prettyElem (ArrayEnumElem cs) = return $ braces $ prettyEnum cs
-prettyRecordElem :: RecordArrayElemTypeBase (DimDecl VName) () -> DocM Html
+prettyRecordElem :: RecordArrayElemTypeBase (DimDecl VName) -> DocM Html
prettyRecordElem (RecordArrayElem et) = prettyElem et
-prettyRecordElem (RecordArrayArrayElem et shape u) =
- typeHtml $ Array et shape u
+prettyRecordElem (RecordArrayArrayElem et shape) =
+ typeHtml $ Array () Nonunique et shape
prettyShapeDecl :: ShapeDecl (DimDecl VName) -> DocM Html
prettyShapeDecl (ShapeDecl ds) =
mconcat <$> mapM (fmap brackets . dimDeclHtml) ds
-typeArgHtml :: TypeArg (DimDecl VName) () -> DocM Html
+typeArgHtml :: TypeArg (DimDecl VName) -> DocM Html
typeArgHtml (TypeArgDim d _) = brackets <$> dimDeclHtml d
typeArgHtml (TypeArgType t _) = typeHtml t
diff --git a/src/Futhark/Doc/Html.hs b/src/Futhark/Doc/Html.hs
index 5808069..d983260 100644
--- a/src/Futhark/Doc/Html.hs
+++ b/src/Futhark/Doc/Html.hs
@@ -12,8 +12,6 @@ module Futhark.Doc.Html
)
where
-import Data.Semigroup ((<>))
-
import Language.Futhark
import Futhark.Util.Pretty (Doc,ppr)
diff --git a/src/Futhark/FreshNames.hs b/src/Futhark/FreshNames.hs
index 137147c..c8a2462 100644
--- a/src/Futhark/FreshNames.hs
+++ b/src/Futhark/FreshNames.hs
@@ -9,7 +9,6 @@ module Futhark.FreshNames
, newVNameFromName
) where
-import qualified Data.Semigroup as Sem
import Language.Haskell.TH.Syntax (Lift)
import Language.Futhark.Core
@@ -24,12 +23,11 @@ import Language.Futhark.Core
newtype VNameSource = VNameSource Int
deriving (Lift, Eq, Ord)
-instance Sem.Semigroup VNameSource where
+instance Semigroup VNameSource where
VNameSource x <> VNameSource y = VNameSource (x `max` y)
instance Monoid VNameSource where
mempty = blankNameSource
- mappend = (Sem.<>)
-- | Produce a fresh name, using the given name as a template.
newName :: VNameSource -> VName -> (VName, VNameSource)
diff --git a/src/Futhark/Internalise.hs b/src/Futhark/Internalise.hs
index bbbf13f..4d6c14d 100644
--- a/src/Futhark/Internalise.hs
+++ b/src/Futhark/Internalise.hs
@@ -14,7 +14,6 @@ import Control.Monad.State
import Control.Monad.Reader
import qualified Data.Map.Strict as M
import qualified Data.Set as S
-import Data.Semigroup ((<>))
import Data.List
import Data.Loc
import Data.Char (chr)
@@ -168,11 +167,11 @@ entryPoint params (retdecl, eret, crets) =
-> [EntryPointType]
entryPointType (_, E.Prim E.Unsigned{}, _) =
[I.TypeUnsigned]
- entryPointType (_, E.Array (ArrayPrimElem Unsigned{} _) _ _, _) =
+ entryPointType (_, E.Array _ _ (ArrayPrimElem Unsigned{}) _, _) =
[I.TypeUnsigned]
entryPointType (_, E.Prim{}, _) =
[I.TypeDirect]
- entryPointType (_, E.Array ArrayPrimElem{} _ _, _) =
+ entryPointType (_, E.Array _ _ ArrayPrimElem{} _, _) =
[I.TypeDirect]
entryPointType (te, t, ts) =
[I.TypeOpaque desc $ length ts]
@@ -272,7 +271,7 @@ internaliseExp desc (E.ArrayLit es (Info arr_t) loc)
forM flat_arrs $ \flat_arr -> do
flat_arr_t <- lookupType flat_arr
let new_shape' = reshapeOuter (map (DimNew . constant) new_shape)
- (length new_shape) $ arrayShape flat_arr_t
+ 1 $ arrayShape flat_arr_t
letSubExp desc $ I.BasicOp $ I.Reshape new_shape' flat_arr
| otherwise = do
@@ -1385,8 +1384,7 @@ isOverloadedFunction qname args loc = do
where isCharLit (Literal (SignedValue iv) _) = Just $ chr $ fromIntegral $ intToInt64 iv
isCharLit _ = Nothing
- handle [E.TupLit [n, m, arr] _] f
- | f `elem` ["unflatten", "cosmin_unflatten"] = Just $ \desc -> do
+ handle [E.TupLit [n, m, arr] _] "unflatten" = Just $ \desc -> do
arrs <- internaliseExpToVars "unflatten_arr" arr
n' <- internaliseExp1 "n" n
m' <- internaliseExp1 "m" m
@@ -1403,8 +1401,7 @@ isOverloadedFunction qname args loc = do
letSubExp desc $ I.BasicOp $
I.Reshape (reshapeOuter [DimNew n', DimNew m'] 1 $ arrayShape arr_t) arr'
- handle [arr] f
- | f `elem` ["flatten", "cosmin_flatten"] = Just $ \desc -> do
+ handle [arr] "flatten" = Just $ \desc -> do
arrs <- internaliseExpToVars "flatten_arr" arr
forM arrs $ \arr' -> do
arr_t <- lookupType arr'
@@ -1424,23 +1421,46 @@ isOverloadedFunction qname args loc = do
mapM (fmap (arraysSize 0) . mapM lookupType) [ys]
let conc xarr yarr = do
- -- All dimensions except for dimension 'i' must match.
+ -- All dimensions except the outermost must match. An
+ -- empty array matches anything.
xt <- lookupType xarr
yt <- lookupType yarr
let matches n m =
- letExp "match" =<<
- eAssert (pure $ I.BasicOp $ I.CmpOp (I.CmpEq I.int32) n m)
- "arguments do not have the same row shape" loc
- x_inner_dims = drop 1 $ I.arrayDims xt
- y_inner_dims = drop 1 $ I.arrayDims yt
- updims = zipWith3 updims' [(0::Int)..] (I.arrayDims xt)
- updims' j xd yd | j == 0 = yd
- | otherwise = xd
- matchcs <- asserting $ Certificates <$>
- zipWithM matches x_inner_dims y_inner_dims
+ letSubExp "match" $
+ I.BasicOp $ I.CmpOp (I.CmpEq I.int32) n m
+
+ emptyRow arr_t =
+ letSubExp "empty_row" =<<
+ foldBinOp I.LogOr (constant False) =<<
+ mapM (matches (intConst Int32 0)) (arrayDims $ rowType arr_t)
+
+ all_match <- letSubExp "all_match" =<<
+ foldBinOp I.LogAnd (constant True) =<<
+ zipWithM matches
+ (arrayDims (rowType xt)) (arrayDims (rowType yt))
+ xarr_empty <- emptyRow xt
+ yarr_empty <- emptyRow yt
+ either_empty <- letSubExp "either_empty" $
+ I.BasicOp $ I.BinOp I.LogOr xarr_empty yarr_empty
+ matchcs <- assertingOne $ letExp "concat_ok" =<<
+ eAssert (pure $ I.BasicOp $ I.BinOp I.LogOr either_empty all_match)
+ "row sizes do not match when concatenating" loc
+
+ let updims (j, xd, yd)
+ | j == 0 =
+ return (xd, yd)
+ | otherwise = do
+ d <- letSubExp "dim" $ I.BasicOp $ I.BinOp (SMax Int32) xd yd
+ return (d, d)
+
+ (xdims, ydims) <- unzip <$>
+ mapM updims (zip3 [(0::Int)..] (I.arrayDims xt) (I.arrayDims yt))
+
+ xarr' <- certifying matchcs $ letExp "concat_x_reshaped" $
+ shapeCoerce xdims xarr
yarr' <- certifying matchcs $ letExp "concat_y_reshaped" $
- shapeCoerce (updims $ I.arrayDims yt) yarr
- return $ I.BasicOp $ I.Concat 0 xarr [yarr'] ressize
+ shapeCoerce ydims yarr
+ return $ I.BasicOp $ I.Concat 0 xarr' [yarr'] ressize
letSubExps desc =<< zipWithM conc xs ys
handle [TupLit [offset, e] _] "rotate" = Just $ \desc -> do
diff --git a/src/Futhark/Internalise/Defunctionalise.hs b/src/Futhark/Internalise/Defunctionalise.hs
index 70bada0..66c4673 100644
--- a/src/Futhark/Internalise/Defunctionalise.hs
+++ b/src/Futhark/Internalise/Defunctionalise.hs
@@ -11,7 +11,6 @@ import Data.List
import Data.Loc
import qualified Data.Map.Strict as M
import qualified Data.Set as S
-import qualified Data.Semigroup as Sem
import qualified Data.Sequence as Seq
import Futhark.MonadFreshNames
@@ -21,7 +20,7 @@ import Futhark.Representation.AST.Pretty ()
-- | A static value stores additional information about the result of
-- defunctionalization of an expression, aside from the residual expression.
data StaticVal = Dynamic CompType
- | LambdaSV [VName] Pattern Exp Env
+ | LambdaSV [VName] Pattern StructType Exp Env
-- ^ The 'VName's are shape parameters that are bound
-- by the 'Pattern'.
| RecordSV [(Name, StaticVal)]
@@ -62,8 +61,8 @@ restrictEnvTo (NameSet m) = restrict <$> ask
Dynamic $ t `setUniqueness` Nonunique
restrict' _ (Dynamic t) =
Dynamic t
- restrict' u (LambdaSV dims pat e env) =
- LambdaSV dims pat e $ M.map (restrict' u) env
+ restrict' u (LambdaSV dims pat t e env) =
+ LambdaSV dims pat t e $ M.map (restrict' u) env
restrict' u (RecordSV fields) =
RecordSV $ map (fmap $ restrict' u) fields
restrict' u (DynamicFun (e, sv1) sv2) =
@@ -74,9 +73,9 @@ restrictEnvTo (NameSet m) = restrict <$> ask
-- the current Env as well as the set of globally defined dynamic
-- functions. This is used to avoid unnecessarily large closure
-- environments.
-newtype DefM a = DefM (RWS (Names, Env) (Seq.Seq ValBind) VNameSource a)
+newtype DefM a = DefM (RWS (S.Set VName, Env) (Seq.Seq ValBind) VNameSource a)
deriving (Functor, Applicative, Monad,
- MonadReader (Names, Env),
+ MonadReader (S.Set VName, Env),
MonadWriter (Seq.Seq ValBind),
MonadFreshNames)
@@ -211,15 +210,15 @@ defuncExp (Negate e0 loc) = do
(e0', sv) <- defuncExp e0
return (Negate e0' loc, sv)
-defuncExp e@(Lambda tparams pats e0 decl tp loc) = do
+defuncExp e@(Lambda tparams pats e0 decl (Info (closure, ret)) loc) = do
when (any isTypeParam tparams) $
error $ "Received a lambda with type parameters at " ++ locStr loc
++ ", but the defunctionalizer expects a monomorphic input program."
-- Extract the first parameter of the lambda and "push" the
-- remaining ones (if there are any) into the body of the lambda.
- let (dims, pat, e0') = case pats of
+ let (dims, pat, ret', e0') = case pats of
[] -> error "Received a lambda with no parameters."
- [pat'] -> (map typeParamName tparams, pat', e0)
+ [pat'] -> (map typeParamName tparams, pat', ret, e0)
(pat' : pats') ->
-- Split shape parameters into those that are determined by
-- the first pattern, and those that are determined by later
@@ -227,14 +226,15 @@ defuncExp e@(Lambda tparams pats e0 decl tp loc) = do
let bound_by_pat = (`S.member` patternDimNames pat') . typeParamName
(pat_dims, rest_dims) = partition bound_by_pat tparams
in (map typeParamName pat_dims, pat',
- Lambda rest_dims pats' e0 decl tp loc)
+ foldFunType (map (toStruct . patternPatternType) pats') ret,
+ Lambda rest_dims pats' e0 decl (Info (closure, ret)) loc)
-- Construct a record literal that closes over the environment of
-- the lambda. Closed-over 'DynamicFun's are converted to their
-- closure representation.
env <- restrictEnvTo (freeVars e)
let (fields, env') = unzip $ map closureFromDynamicFun $ M.toList env
- return (RecordLit fields loc, LambdaSV dims pat e0' $ M.fromList env')
+ return (RecordLit fields loc, LambdaSV dims pat ret' e0' $ M.fromList env')
where closureFromDynamicFun (vn, DynamicFun (clsr_env, sv) _) =
let name = nameFromString $ pretty vn
@@ -489,7 +489,7 @@ defuncApply depth e@(Apply e1 e2 d t@(Info ret) loc) = do
(e2', sv2) <- defuncExp e2
let e' = Apply e1' e2' d t loc
case sv1 of
- LambdaSV dims pat e0 closure_env -> do
+ LambdaSV dims pat e0_t e0 closure_env -> do
let env' = matchPatternSV pat sv2
env_dim = envFromDimNames dims
(e0', sv) <- localNewEnv (env' <> closure_env <> env_dim) $ defuncExp e0
@@ -531,7 +531,7 @@ defuncApply depth e@(Apply e1 e2 d t@(Info ret) loc) = do
else do
-- Lift lambda to top-level function definition.
let params = [closure_pat, pat']
- rettype = buildRetType closure_env params $ typeOf e0'
+ rettype = buildRetType closure_env params e0_t $ typeOf e0'
-- Embed some information about the original function
-- into the name of the lifted function, to make the
@@ -584,9 +584,9 @@ defuncApply depth e@(Var qn (Info t) loc) = do
fname <- newName $ qualLeaf qn
let (dims, pats, e0, sv') = liftDynFun sv depth
(argtypes', rettype) = dynamicFunType sv' argtypes
- liftValDec fname rettype dims pats e0
+ liftValDec fname (fromStruct rettype) dims pats e0
return (Var (qualName fname)
- (Info (foldFunType argtypes' rettype)) loc, sv')
+ (Info (foldFunType argtypes' $ fromStruct rettype)) loc, sv')
IntrinsicSV -> return (e, IntrinsicSV)
@@ -608,7 +608,7 @@ fullyApplied _ _ = True
-- depth of partial application.
liftDynFun :: StaticVal -> Int -> ([VName], [Pattern], Exp, StaticVal)
liftDynFun (DynamicFun (e, sv) _) 0 = ([], [], e, sv)
-liftDynFun (DynamicFun clsr@(_, LambdaSV dims pat _ _) sv) d
+liftDynFun (DynamicFun clsr@(_, LambdaSV dims pat _ _ _) sv) d
| d > 0 = let (dims', pats, e', sv') = liftDynFun sv (d-1)
in (dims ++ dims', pat : pats, e', DynamicFun clsr sv')
liftDynFun sv _ = error $ "Tried to lift a StaticVal " ++ show sv
@@ -671,22 +671,27 @@ buildEnvPattern env = RecordPattern (map buildField $ M.toList env) noLoc
-- lifted function can create unique arrays as long as they do not
-- alias any of its parameters. XXX: it is not clear that this is a
-- sufficient property, unfortunately.
-buildRetType :: Env -> [Pattern] -> CompType -> PatternType
-buildRetType env pats = vacuousShapeAnnotations . descend
+buildRetType :: Env -> [Pattern] -> StructType -> CompType -> PatternType
+buildRetType env pats = comb
where bound = foldMap oneName (M.keys env) <> foldMap patternVars pats
boundAsUnique v =
maybe False (unique . unInfo . identType) $
find ((==v) . identName) $ S.toList $ foldMap patIdentSet pats
problematic v = (v `member` bound) && not (boundAsUnique v)
+ comb (Record fs_annot) (Record fs_got) =
+ Record $ M.intersectionWith comb fs_annot fs_got
+ comb Arrow{} t = vacuousShapeAnnotations $ descend t
+ comb got _ = fromStruct got
+
descend t@Array{}
- | any problematic (aliases t) = t `setUniqueness` Nonunique
+ | any (problematic . aliasVar) (aliases t) = t `setUniqueness` Nonunique
descend (Record t) = Record $ fmap descend t
descend t = t
-- | Compute the corresponding type for a given static value.
typeFromSV :: StaticVal -> CompType
typeFromSV (Dynamic tp) = tp
-typeFromSV (LambdaSV _ _ _ env) = typeFromEnv env
+typeFromSV (LambdaSV _ _ _ _ env) = typeFromEnv env
typeFromSV (RecordSV ls) = Record $ M.fromList $ map (fmap typeFromSV) ls
typeFromSV (DynamicFun (_, sv) _) = typeFromSV sv
typeFromSV IntrinsicSV = error $ "Tried to get the type from the "
@@ -770,12 +775,11 @@ svFromType t = Dynamic t
-- A set of names where we also track uniqueness.
newtype NameSet = NameSet (M.Map VName Uniqueness)
-instance Sem.Semigroup NameSet where
+instance Semigroup NameSet where
NameSet x <> NameSet y = NameSet $ M.unionWith max x y
instance Monoid NameSet where
mempty = NameSet mempty
- mappend = (Sem.<>)
without :: NameSet -> NameSet -> NameSet
without (NameSet x) (NameSet y) = NameSet $ x `M.difference` y
@@ -789,7 +793,7 @@ ident v = NameSet $ M.singleton (identName v) (uniqueness $ unInfo $ identType v
oneName :: VName -> NameSet
oneName v = NameSet $ M.singleton v Nonunique
-names :: Names -> NameSet
+names :: S.Set VName -> NameSet
names = foldMap oneName
-- | Compute the set of free variables of an expression.
@@ -883,34 +887,34 @@ patternVars = mconcat . map ident . S.toList . patIdentSet
-- argument is the original type and the second is the type of the transformed
-- expression. This is necessary since the original type may contain additional
-- information (e.g., shape restrictions) from the user-given annotation.
-combineTypeShapes :: ArrayDim dim =>
+combineTypeShapes :: (Monoid as, ArrayDim dim) =>
TypeBase dim as -> TypeBase dim as -> TypeBase dim as
combineTypeShapes (Record ts1) (Record ts2)
| M.keys ts1 == M.keys ts2 =
- Record $ M.map (uncurry combineTypeShapes) (M.intersectionWith (,) ts1 ts2)
-combineTypeShapes (Array et1 shape1 u1) (Array et2 shape2 _u2)
+ Record $ M.map (uncurry combineTypeShapes) (M.intersectionWith (,) ts1 ts2)
+combineTypeShapes (Array als1 u1 et1 shape1) (Array als2 _u2 et2 shape2)
| Just new_shape <- unifyShapes shape1 shape2 =
- Array (combineElemTypeInfo et1 et2) new_shape u1
+ Array (als1<>als2) u1 (combineElemTypeInfo et1 et2) new_shape
combineTypeShapes _ new_tp = new_tp
combineElemTypeInfo :: ArrayDim dim =>
- ArrayElemTypeBase dim as
- -> ArrayElemTypeBase dim as -> ArrayElemTypeBase dim as
+ ArrayElemTypeBase dim
+ -> ArrayElemTypeBase dim -> ArrayElemTypeBase dim
combineElemTypeInfo (ArrayRecordElem et1) (ArrayRecordElem et2) =
ArrayRecordElem $ M.map (uncurry combineRecordArrayTypeInfo)
(M.intersectionWith (,) et1 et2)
combineElemTypeInfo _ new_tp = new_tp
combineRecordArrayTypeInfo :: ArrayDim dim =>
- RecordArrayElemTypeBase dim as
- -> RecordArrayElemTypeBase dim as
- -> RecordArrayElemTypeBase dim as
+ RecordArrayElemTypeBase dim
+ -> RecordArrayElemTypeBase dim
+ -> RecordArrayElemTypeBase dim
combineRecordArrayTypeInfo (RecordArrayElem et1) (RecordArrayElem et2) =
RecordArrayElem $ combineElemTypeInfo et1 et2
-combineRecordArrayTypeInfo (RecordArrayArrayElem et1 shape1 u1)
- (RecordArrayArrayElem et2 shape2 u2)
+combineRecordArrayTypeInfo (RecordArrayArrayElem et1 shape1)
+ (RecordArrayArrayElem et2 shape2)
| Just new_shape <- unifyShapes shape1 shape2 =
- RecordArrayArrayElem (combineElemTypeInfo et1 et2) new_shape (u1 <> u2)
+ RecordArrayArrayElem (combineElemTypeInfo et1 et2) new_shape
combineRecordArrayTypeInfo _ new_tp = new_tp
-- | Defunctionalize a top-level value binding. Returns the
@@ -920,12 +924,18 @@ combineRecordArrayTypeInfo _ new_tp = new_tp
defuncValBind :: ValBind -> DefM (ValBind, Env, Bool)
-- Eta-expand entry points with a functional return type.
-defuncValBind (ValBind True name retdecl (Info rettype) tparams params body _ loc)
+defuncValBind (ValBind True name _ (Info rettype) tparams params body _ loc)
| (rettype_ps, rettype') <- unfoldFunType rettype,
not $ null rettype_ps = do
(body_pats, body', _) <- etaExpand body
- defuncValBind $ ValBind True name retdecl (Info rettype')
+ -- FIXME: we should also handle non-constant size annotations
+ -- here.
+ defuncValBind $ ValBind True name Nothing
+ (Info $ onlyConstantDims rettype')
tparams (params <> body_pats) body' Nothing loc
+ where onlyConstantDims = bimap onDim id
+ onDim (ConstDim x) = ConstDim x
+ onDim _ = AnyDim
defuncValBind valbind@(ValBind _ name retdecl rettype tparams params body _ _) = do
let env = envFromShapeParams tparams
diff --git a/src/Futhark/Internalise/Defunctorise.hs b/src/Futhark/Internalise/Defunctorise.hs
index c543858..43815e2 100644
--- a/src/Futhark/Internalise/Defunctorise.hs
+++ b/src/Futhark/Internalise/Defunctorise.hs
@@ -10,7 +10,6 @@ import qualified Data.Map as M
import qualified Data.Set as S
import Data.Maybe
import Data.Loc
-import qualified Data.Semigroup as Sem
import Prelude hiding (mod, abs)
@@ -53,13 +52,11 @@ lookupSubstInScope qn@(QualName quals name) scope@(Scope substs mods) =
Just (ModMod mod_scope) -> lookupSubstInScope (QualName qs name) mod_scope
_ -> (qn, scope)
-instance Sem.Semigroup Scope where
- Scope ss1 mt1 <> Scope ss2 mt2 =
- Scope (ss1<>ss2) (mt1<>mt2)
+instance Semigroup Scope where
+ Scope ss1 mt1 <> Scope ss2 mt2 = Scope (ss1<>ss2) (mt1<>mt2)
instance Monoid Scope where
mempty = Scope mempty mempty
- mappend = (Sem.<>)
type TySet = S.Set VName
diff --git a/src/Futhark/Internalise/Monad.hs b/src/Futhark/Internalise/Monad.hs
index 39775ce..4c9e42f 100644
--- a/src/Futhark/Internalise/Monad.hs
+++ b/src/Futhark/Internalise/Monad.hs
@@ -40,7 +40,6 @@ import Control.Monad.Writer
import Control.Monad.RWS
import qualified Control.Monad.Fail as Fail
import qualified Data.Map.Strict as M
-import qualified Data.Semigroup as Sem
import Futhark.Representation.SOACS
import Futhark.MonadFreshNames
@@ -75,7 +74,7 @@ data InternaliseState = InternaliseState {
}
newtype InternaliseResult = InternaliseResult [FunDef]
- deriving (Sem.Semigroup, Monoid)
+ deriving (Semigroup, Monoid)
newtype InternaliseM a = InternaliseM (BinderT SOACS
(RWST
diff --git a/src/Futhark/Internalise/Monomorphise.hs b/src/Futhark/Internalise/Monomorphise.hs
index f3479ad..1b76f23 100644
--- a/src/Futhark/Internalise/Monomorphise.hs
+++ b/src/Futhark/Internalise/Monomorphise.hs
@@ -31,7 +31,6 @@ import Control.Monad.RWS
import Control.Monad.State
import Data.Loc
import qualified Data.Map.Strict as M
-import qualified Data.Semigroup as Sem
import qualified Data.Sequence as Seq
import Data.Foldable
@@ -43,8 +42,13 @@ import Language.Futhark.TypeChecker.Types
-- | The monomorphization monad reads 'PolyBinding's and writes 'ValBinding's.
-- The 'TypeParam's in a 'ValBinding' can only be shape parameters.
-newtype PolyBinding = PolyBinding (VName, [TypeParam], [Pattern],
- Maybe (TypeExp VName), StructType, Exp, SrcLoc)
+--
+-- Each 'PolyBinding' is also connected with the 'RecordReplacements'
+-- that were active when the binding was defined. This is used only
+-- in local functions.
+data PolyBinding = PolyBinding RecordReplacements
+ (VName, [TypeParam], [Pattern],
+ Maybe (TypeExp VName), StructType, Exp, SrcLoc)
-- | Mapping from record names to the variable names that contain the
-- fields. This is used because the monomorphiser also expands all
@@ -60,12 +64,11 @@ data Env = Env { envPolyBindings :: M.Map VName PolyBinding
, envRecordReplacements :: RecordReplacements
}
-instance Sem.Semigroup Env where
+instance Semigroup Env where
Env tb1 pb1 rr1 <> Env tb2 pb2 rr2 = Env (tb1 <> tb2) (pb1 <> pb2) (rr1 <> rr2)
instance Monoid Env where
mempty = Env mempty mempty mempty
- mappend = (Sem.<>)
localEnv :: Env -> MonoM a -> MonoM a
localEnv env = local (env <>)
@@ -75,10 +78,10 @@ extendEnv vn binding = localEnv
mempty { envPolyBindings = M.singleton vn binding }
withRecordReplacements :: RecordReplacements -> MonoM a -> MonoM a
-withRecordReplacements rr = localEnv mempty { envRecordReplacements = rr}
+withRecordReplacements rr = localEnv mempty { envRecordReplacements = rr }
-noRecordReplacements :: MonoM a -> MonoM a
-noRecordReplacements = local $ \env -> env { envRecordReplacements = mempty }
+replaceRecordReplacements :: RecordReplacements -> MonoM a -> MonoM a
+replaceRecordReplacements rr = local $ \env -> env { envRecordReplacements = rr }
-- | The monomorphization monad.
newtype MonoM a = MonoM (RWST Env (Seq.Seq (VName, ValBind)) VNameSource
@@ -194,7 +197,8 @@ transformExp (LetFun fname (tparams, params, retdecl, Info ret, body) e loc)
-- Retrieve the lifted monomorphic function bindings that are produced,
-- filter those that are monomorphic versions of the current let-bound
-- function and insert them at this point, and propagate the rest.
- let funbind = PolyBinding (fname, tparams, params, retdecl, ret, body, loc)
+ rr <- asks envRecordReplacements
+ let funbind = PolyBinding rr (fname, tparams, params, retdecl, ret, body, loc)
pass $ do
(e', bs) <- listen $ extendEnv fname funbind $ transformExp e
let (bs_local, bs_prop) = Seq.partition ((== fname) . fst) bs
@@ -469,8 +473,8 @@ expandRecordPattern (PatternLit e t loc) = return (PatternLit e t loc, mempty)
-- list. Monomorphizes the body of the function as well. Returns the fresh name
-- of the generated monomorphic function and its 'ValBind' representation.
monomorphizeBinding :: PolyBinding -> TypeBase () () -> MonoM (VName, ValBind)
-monomorphizeBinding (PolyBinding (name, tparams, params, retdecl, rettype, body, loc)) t =
- noRecordReplacements $ do
+monomorphizeBinding (PolyBinding rr (name, tparams, params, retdecl, rettype, body, loc)) t =
+ replaceRecordReplacements rr $ do
t' <- removeTypeVariablesInType t
let bind_t = foldFunType (map (toStructural . patternType) params) $
toStructural rettype
@@ -542,7 +546,7 @@ substPattern f pat = case pat of
toPolyBinding :: ValBind -> PolyBinding
toPolyBinding (ValBind _ name retdecl (Info rettype) tparams params body _ loc) =
- PolyBinding (name, tparams, params, retdecl, rettype, body, loc)
+ PolyBinding mempty (name, tparams, params, retdecl, rettype, body, loc)
-- | Remove all type variables and type abbreviations from a value binding.
removeTypeVariables :: ValBind -> MonoM ValBind
diff --git a/src/Futhark/Internalise/TypesValues.hs b/src/Futhark/Internalise/TypesValues.hs
index 5231f02..6d95166 100644
--- a/src/Futhark/Internalise/TypesValues.hs
+++ b/src/Futhark/Internalise/TypesValues.hs
@@ -24,7 +24,6 @@ import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Data.Maybe
import Data.Monoid ((<>))
-import Data.Semigroup (Semigroup)
import qualified Language.Futhark as E
import Futhark.Representation.SOACS as I
@@ -114,7 +113,7 @@ internaliseTypeM orig_t =
fail "internaliseTypeM: cannot handle type variable."
E.Record ets ->
concat <$> mapM (internaliseTypeM . snd) (E.sortFields ets)
- E.Array et shape u -> do
+ E.Array _ u et shape -> do
dims <- internaliseShape shape
ets <- internaliseElemType et
return [I.arrayOf et' (Shape dims) $ internaliseUniqueness u | et' <- ets ]
@@ -123,17 +122,17 @@ internaliseTypeM orig_t =
where internaliseElemType E.ArrayPolyElem{} =
fail "internaliseElemType: cannot handle type variable."
- internaliseElemType (E.ArrayPrimElem bt _) =
+ internaliseElemType (E.ArrayPrimElem bt) =
return [I.Prim $ internalisePrimType bt]
internaliseElemType (E.ArrayRecordElem elemts) =
concat <$> mapM (internaliseRecordElem . snd) (E.sortFields elemts)
- internaliseElemType (E.ArrayEnumElem _ _) =
+ internaliseElemType E.ArrayEnumElem{} =
return [I.Prim $ I.IntType I.Int8]
internaliseRecordElem (E.RecordArrayElem et) =
internaliseElemType et
- internaliseRecordElem (E.RecordArrayArrayElem et shape u) =
- internaliseTypeM $ E.Array et shape u
+ internaliseRecordElem (E.RecordArrayArrayElem et shape) =
+ internaliseTypeM $ E.Array mempty Nonunique et shape
internaliseShape = mapM internaliseDim . E.shapeDims
diff --git a/src/Futhark/Optimise/CSE.hs b/src/Futhark/Optimise/CSE.hs
index c071805..ed7101d 100644
--- a/src/Futhark/Optimise/CSE.hs
+++ b/src/Futhark/Optimise/CSE.hs
@@ -35,7 +35,6 @@ module Futhark.Optimise.CSE
import Control.Monad.Reader
import qualified Data.Set as S
import qualified Data.Map.Strict as M
-import Data.Semigroup ((<>))
import Futhark.Analysis.Alias
import Futhark.Representation.AST
diff --git a/src/Futhark/Optimise/Fusion.hs b/src/Futhark/Optimise/Fusion.hs
index b5867c3..5695e6b 100644
--- a/src/Futhark/Optimise/Fusion.hs
+++ b/src/Futhark/Optimise/Fusion.hs
@@ -10,9 +10,7 @@ module Futhark.Optimise.Fusion ( fuseSOACs )
import Control.Monad.State
import Control.Monad.Reader
import Control.Monad.Except
-import qualified Data.Semigroup as Sem
import Data.Maybe
-import Data.Semigroup ((<>))
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import qualified Data.List as L
@@ -245,7 +243,7 @@ data FusedRes = FusedRes {
-- ^ The map recording the uses
}
-instance Sem.Semigroup FusedRes where
+instance Semigroup FusedRes where
res1 <> res2 =
FusedRes (rsucc res1 || rsucc res2)
(outArr res1 `M.union` outArr res2)
@@ -256,7 +254,6 @@ instance Sem.Semigroup FusedRes where
instance Monoid FusedRes where
mempty = FusedRes { rsucc = False, outArr = M.empty, inpArr = M.empty,
infusible = S.empty, kernels = M.empty }
- mappend = (Sem.<>)
isInpArrInResModKers :: FusedRes -> S.Set KernName -> VName -> Bool
isInpArrInResModKers ress kers nm =
diff --git a/src/Futhark/Optimise/Fusion/Composing.hs b/src/Futhark/Optimise/Fusion/Composing.hs
index aa84b42..caec60b 100644
--- a/src/Futhark/Optimise/Fusion/Composing.hs
+++ b/src/Futhark/Optimise/Fusion/Composing.hs
@@ -18,7 +18,6 @@ module Futhark.Optimise.Fusion.Composing
where
import Data.List
-import Data.Semigroup ((<>))
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Data.Maybe
diff --git a/src/Futhark/Optimise/Fusion/LoopKernel.hs b/src/Futhark/Optimise/Fusion/LoopKernel.hs
index 2d0d0d9..7ed85a6 100644
--- a/src/Futhark/Optimise/Fusion/LoopKernel.hs
+++ b/src/Futhark/Optimise/Fusion/LoopKernel.hs
@@ -20,7 +20,6 @@ import Control.Monad
import qualified Data.Set as S
import qualified Data.Map.Strict as M
import Data.Maybe
-import Data.Semigroup ((<>))
import Data.List
import Futhark.Representation.SOACS hiding (SOAC(..))
diff --git a/src/Futhark/Optimise/InPlaceLowering.hs b/src/Futhark/Optimise/InPlaceLowering.hs
index e525fe0..01b75fa 100644
--- a/src/Futhark/Optimise/InPlaceLowering.hs
+++ b/src/Futhark/Optimise/InPlaceLowering.hs
@@ -66,7 +66,6 @@ module Futhark.Optimise.InPlaceLowering
import Control.Monad.RWS
import qualified Data.Map.Strict as M
import qualified Data.Set as S
-import qualified Data.Semigroup as Sem
import Futhark.Analysis.Alias
import Futhark.Representation.Aliases
@@ -196,13 +195,12 @@ data BottomUp lore = BottomUp { bottomUpSeen :: Names
, forwardThese :: [DesiredUpdate (LetAttr (Aliases lore))]
}
-instance Sem.Semigroup (BottomUp lore) where
+instance Semigroup (BottomUp lore) where
BottomUp seen1 forward1 <> BottomUp seen2 forward2 =
BottomUp (seen1 <> seen2) (forward1 <> forward2)
instance Monoid (BottomUp lore) where
mempty = BottomUp mempty mempty
- mappend = (Sem.<>)
updateStm :: Constraints lore => DesiredUpdate (LetAttr (Aliases lore)) -> Stm (Aliases lore)
updateStm fwd =
diff --git a/src/Futhark/Optimise/InPlaceLowering/SubstituteIndices.hs b/src/Futhark/Optimise/InPlaceLowering/SubstituteIndices.hs
index 4b793e5..43469dc 100644
--- a/src/Futhark/Optimise/InPlaceLowering/SubstituteIndices.hs
+++ b/src/Futhark/Optimise/InPlaceLowering/SubstituteIndices.hs
@@ -11,7 +11,6 @@ module Futhark.Optimise.InPlaceLowering.SubstituteIndices
, IndexSubstitutions
) where
-import Data.Semigroup ((<>))
import Control.Monad
import qualified Data.Map.Strict as M
import qualified Data.Set as S
diff --git a/src/Futhark/Optimise/MemoryBlockMerging/Types.hs b/src/Futhark/Optimise/MemoryBlockMerging/Types.hs
index a703a3e..9798bb9 100644
--- a/src/Futhark/Optimise/MemoryBlockMerging/Types.hs
+++ b/src/Futhark/Optimise/MemoryBlockMerging/Types.hs
@@ -18,7 +18,6 @@ module Futhark.Optimise.MemoryBlockMerging.Types
where
import qualified Data.Map.Strict as M
-import qualified Data.Semigroup as Sem
import Futhark.Representation.AST
import qualified Futhark.Representation.ExplicitMemory as ExpMem
@@ -83,9 +82,8 @@ type ActualVariables = M.Map VName Names
newtype Log = Log (M.Map VName [(String, String)])
deriving (Show, Eq, Ord)
-instance Sem.Semigroup Log where
+instance Semigroup Log where
Log a <> Log b = Log $ M.unionWith (++) a b
instance Monoid Log where
mempty = Log M.empty
- mappend = (Sem.<>)
diff --git a/src/Futhark/Optimise/Simplify.hs b/src/Futhark/Optimise/Simplify.hs
index c4992f4..bb9a6f1 100644
--- a/src/Futhark/Optimise/Simplify.hs
+++ b/src/Futhark/Optimise/Simplify.hs
@@ -18,8 +18,6 @@ module Futhark.Optimise.Simplify
)
where
-import Data.Semigroup ((<>))
-
import Futhark.Representation.AST
import Futhark.MonadFreshNames
import qualified Futhark.Optimise.Simplify.Engine as Engine
diff --git a/src/Futhark/Optimise/Simplify/ClosedForm.hs b/src/Futhark/Optimise/Simplify/ClosedForm.hs
index 7bfe3d7..087b28c 100644
--- a/src/Futhark/Optimise/Simplify/ClosedForm.hs
+++ b/src/Futhark/Optimise/Simplify/ClosedForm.hs
@@ -16,7 +16,6 @@ import Control.Monad
import Data.Maybe
import qualified Data.Map.Strict as M
import qualified Data.Set as S
-import Data.Semigroup ((<>))
import Futhark.Construct
import Futhark.Representation.AST
diff --git a/src/Futhark/Optimise/Simplify/Lore.hs b/src/Futhark/Optimise/Simplify/Lore.hs
index f3a1e62..bd268a3 100644
--- a/src/Futhark/Optimise/Simplify/Lore.hs
+++ b/src/Futhark/Optimise/Simplify/Lore.hs
@@ -29,7 +29,6 @@ module Futhark.Optimise.Simplify.Lore
import Control.Monad.Identity
import Control.Monad.Reader
-import Data.Semigroup ((<>))
import qualified Data.Map.Strict as M
import Futhark.Representation.AST
diff --git a/src/Futhark/Optimise/Simplify/Rule.hs b/src/Futhark/Optimise/Simplify/Rule.hs
index 036188b..2959d0e 100644
--- a/src/Futhark/Optimise/Simplify/Rule.hs
+++ b/src/Futhark/Optimise/Simplify/Rule.hs
@@ -51,9 +51,7 @@ module Futhark.Optimise.Simplify.Rule
, bottomUpSimplifyStm
) where
-import Data.Semigroup ((<>))
import Control.Monad.State
-import qualified Data.Semigroup as Sem
import qualified Control.Monad.Fail as Fail
import Control.Monad.Except
@@ -137,13 +135,12 @@ data Rules lore a = Rules { rulesAny :: [SimplificationRule lore a]
, rulesOp :: [SimplificationRule lore a]
}
-instance Sem.Semigroup (Rules lore a) where
+instance Semigroup (Rules lore a) where
Rules as1 bs1 cs1 ds1 es1 <> Rules as2 bs2 cs2 ds2 es2 =
Rules (as1<>as2) (bs1<>bs2) (cs1<>cs2) (ds1<>ds2) (es1<>es2)
instance Monoid (Rules lore a) where
mempty = Rules mempty mempty mempty mempty mempty
- mappend = (Sem.<>)
-- | Context for a rule applied during top-down traversal of the
-- program. Takes a symbol table as argument.
@@ -178,12 +175,11 @@ data RuleBook lore = RuleBook { bookTopDownRules :: TopDownRules lore
, bookBottomUpRules :: BottomUpRules lore
}
-instance Sem.Semigroup (RuleBook lore) where
+instance Semigroup (RuleBook lore) where
RuleBook ts1 bs1 <> RuleBook ts2 bs2 = RuleBook (ts1<>ts2) (bs1<>bs2)
instance Monoid (RuleBook lore) where
mempty = RuleBook mempty mempty
- mappend = (Sem.<>)
-- | Construct a rule book from a collection of rules.
ruleBook :: [TopDownRule m]
diff --git a/src/Futhark/Optimise/Simplify/Rules.hs b/src/Futhark/Optimise/Simplify/Rules.hs
index 9b35f73..d78211d 100644
--- a/src/Futhark/Optimise/Simplify/Rules.hs
+++ b/src/Futhark/Optimise/Simplify/Rules.hs
@@ -24,7 +24,6 @@ import Data.Either
import Data.Foldable (all)
import Data.List hiding (all)
import Data.Maybe
-import Data.Semigroup ((<>))
import qualified Data.Map.Strict as M
import qualified Data.Set as S
diff --git a/src/Futhark/Optimise/TileLoops.hs b/src/Futhark/Optimise/TileLoops.hs
index 2327d79..5828172 100644
--- a/src/Futhark/Optimise/TileLoops.hs
+++ b/src/Futhark/Optimise/TileLoops.hs
@@ -11,7 +11,6 @@ import Control.Monad.State
import Control.Monad.Reader
import qualified Data.Set as S
import qualified Data.Map.Strict as M
-import Data.Semigroup ((<>))
import Data.List
import Data.Maybe
@@ -25,7 +24,7 @@ import Futhark.Optimise.TileLoops.RegTiling3D
tileLoops :: Pass Kernels Kernels
tileLoops = Pass "tile loops" "Tile stream loops inside kernels" $
- intraproceduralTransformation optimiseFunDef
+ fmap Prog . mapM optimiseFunDef . progFunctions
optimiseFunDef :: MonadFreshNames m => FunDef Kernels -> m (FunDef Kernels)
optimiseFunDef fundec = do
@@ -126,7 +125,7 @@ tileInStms branch_variant initial_variance initial_kspace kstms = do
arrs arr_chunk_params = do
((tile_size, tiled_group_size), tile_size_bnds) <- runBinder $ do
- tile_size_key <- newVName "tile_size"
+ tile_size_key <- nameFromString . pretty <$> newVName "tile_size"
tile_size <- letSubExp "tile_size" $ Op $ GetSize tile_size_key SizeTile
tiled_group_size <- letSubExp "tiled_group_size" $
BasicOp $ BinOp (Mul Int32) tile_size tile_size
diff --git a/src/Futhark/Optimise/TileLoops/RegTiling3D.hs b/src/Futhark/Optimise/TileLoops/RegTiling3D.hs
index c5ea862..ede4f7a 100644
--- a/src/Futhark/Optimise/TileLoops/RegTiling3D.hs
+++ b/src/Futhark/Optimise/TileLoops/RegTiling3D.hs
@@ -22,7 +22,6 @@ import Control.Monad.Reader
import qualified Data.Set as S
import qualified Data.Map.Strict as M
import Data.List
-import Data.Semigroup ((<>))
import Data.Maybe
import Futhark.MonadFreshNames
@@ -659,7 +658,7 @@ mkKerSpaceExtraStms reg_tile gspace = do
(gidx,sz_x) : (gidy,sz_y) : (gidz,m_M) : untiled_gspace = reverse gspace
((tile_size_x, tile_size_y, tiled_group_size), tile_size_bnds) <- runBinder $ do
- tile_size_key <- newVName "tile_size"
+ tile_size_key <- nameFromString . pretty <$> newVName "tile_size"
tile_ct_size <- letSubExp "tile_size" $ Op $ GetSize tile_size_key SizeTile
tile_size_x <- letSubExp "tile_size_x" $ BasicOp $
BinOp (SMin Int32) tile_ct_size sz_x
diff --git a/src/Futhark/Pass/ExpandAllocations.hs b/src/Futhark/Pass/ExpandAllocations.hs
index 23a47af..4706576 100644
--- a/src/Futhark/Pass/ExpandAllocations.hs
+++ b/src/Futhark/Pass/ExpandAllocations.hs
@@ -12,7 +12,6 @@ import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Data.Maybe
import Data.List
-import Data.Semigroup ((<>))
import Prelude hiding (quot)
@@ -72,6 +71,52 @@ transformExp (Op (Inner (Kernel desc kspace ts kbody))) = do
variantAlloc _ = False
(variant_allocs, invariant_allocs) = M.partition (variantAlloc . fst) allocs
+ (alloc_stms, alloc_offsets) <-
+ memoryRequirements kspace (kernelBodyStms kbody) variant_allocs invariant_allocs
+
+ kbody'' <- either compilerLimitationS pure $
+ offsetMemoryInKernelBody alloc_offsets
+ kbody'
+
+ return (alloc_stms,
+ Op $ Inner $ Kernel desc kspace ts kbody'')
+
+ where bound_in_kernel =
+ S.fromList $ M.keys $ scopeOfKernelSpace kspace <>
+ scopeOf (kernelBodyStms kbody)
+
+transformExp (Op (Inner (SegRed kspace comm red_op nes ts kbody))) = do
+ let (kbody', kbody_allocs) = extractBodyAllocations kbody
+ (red_op', red_op_allocs) = extractLambdaAllocations red_op
+ variantAlloc (Var v) = v `S.member` bound_in_kernel
+ variantAlloc _ = False
+ allocs = kbody_allocs <> red_op_allocs
+ (variant_allocs, invariant_allocs) = M.partition (variantAlloc . fst) allocs
+
+ (alloc_stms, alloc_offsets) <-
+ memoryRequirements kspace (bodyStms kbody) variant_allocs invariant_allocs
+
+ either compilerLimitationS pure $ do
+ kbody'' <- offsetMemoryInBody alloc_offsets kbody'
+ red_op'' <- offsetMemoryInLambda alloc_offsets red_op'
+
+ return (alloc_stms,
+ Op $ Inner $ SegRed kspace comm red_op'' nes ts kbody'')
+
+ where bound_in_kernel =
+ S.fromList $ map fst (spaceDimensions kspace) ++
+ M.keys (scopeOfKernelSpace kspace <>
+ scopeOf (bodyStms kbody))
+
+transformExp e =
+ return (mempty, e)
+
+memoryRequirements :: KernelSpace
+ -> Stms InKernel
+ -> M.Map VName (SubExp, Space)
+ -> M.Map VName (SubExp, Space)
+ -> ExpandM (Stms ExplicitMemory, RebaseMap)
+memoryRequirements kspace kstms variant_allocs invariant_allocs = do
num_threads64 <- newVName "num_threads64"
let num_threads64_pat = Pattern [] [PatElem num_threads64 $ MemPrim int64]
num_threads64_bnd = Let num_threads64_pat (defAux ()) $ BasicOp $
@@ -83,33 +128,42 @@ transformExp (Op (Inner (Kernel desc kspace ts kbody))) = do
(spaceGlobalId kspace, spaceGroupId kspace, spaceLocalId kspace) invariant_allocs
(variant_alloc_stms, variant_alloc_offsets) <-
- expandedVariantAllocations kspace kbody variant_allocs
+ expandedVariantAllocations kspace kstms variant_allocs
let alloc_offsets = invariant_alloc_offsets <> variant_alloc_offsets
alloc_stms = invariant_alloc_stms <> variant_alloc_stms
- kbody'' <- either compilerLimitationS pure $
- offsetMemoryInKernelBody alloc_offsets
- kbody' { kernelBodyStms = kernelBodyStms kbody' }
-
- return (oneStm num_threads64_bnd <> alloc_stms,
- Op $ Inner $ Kernel desc kspace ts kbody'')
-
- where bound_in_kernel =
- S.fromList $ M.keys $ scopeOfKernelSpace kspace <>
- scopeOf (kernelBodyStms kbody)
-
-transformExp e =
- return (mempty, e)
+ return (oneStm num_threads64_bnd <> alloc_stms, alloc_offsets)
-- | Extract allocations from 'Thread' statements with
-- 'extractThreadAllocations'.
extractKernelBodyAllocations :: KernelBody InKernel
-> (KernelBody InKernel,
M.Map VName (SubExp, Space))
-extractKernelBodyAllocations kbody =
- let (allocs, stms) = mapAccumL extract M.empty $ stmsToList $ kernelBodyStms kbody
- in (kbody { kernelBodyStms = mconcat stms }, allocs)
+extractKernelBodyAllocations = extractGenericBodyAllocations kernelBodyStms $
+ \stms kbody -> kbody { kernelBodyStms = stms }
+
+extractBodyAllocations :: Body InKernel
+ -> (Body InKernel,
+ M.Map VName (SubExp, Space))
+extractBodyAllocations = extractGenericBodyAllocations bodyStms $
+ \stms body -> body { bodyStms = stms }
+
+extractLambdaAllocations :: Lambda InKernel
+ -> (Lambda InKernel,
+ M.Map VName (SubExp, Space))
+extractLambdaAllocations lam = (lam { lambdaBody = body' }, allocs)
+ where (body', allocs) = extractGenericBodyAllocations bodyStms
+ (\stms body -> body { bodyStms = stms }) $ lambdaBody lam
+
+extractGenericBodyAllocations :: (body -> Stms InKernel)
+ -> (Stms InKernel -> body -> body)
+ -> body
+ -> (body,
+ M.Map VName (SubExp, Space))
+extractGenericBodyAllocations get_stms set_stms body =
+ let (allocs, stms) = mapAccumL extract M.empty $ stmsToList $ get_stms body
+ in (set_stms (mconcat stms) body, allocs)
where extract allocs bnd =
let (bnds, body_allocs) = extractThreadAllocations $ oneStm bnd
in (allocs <> body_allocs, bnds)
@@ -170,17 +224,17 @@ expandedInvariantAllocations (num_threads64, num_groups, group_size)
map untouched old_shape
in offset_ixfun
-expandedVariantAllocations :: KernelSpace -> KernelBody InKernel
+expandedVariantAllocations :: KernelSpace -> Stms InKernel
-> M.Map VName (SubExp, Space)
-> ExpandM (Stms ExplicitMemory, RebaseMap)
expandedVariantAllocations _ _ variant_allocs
| null variant_allocs = return (mempty, mempty)
-expandedVariantAllocations kspace kbody variant_allocs = do
+expandedVariantAllocations kspace kstms variant_allocs = do
let sizes_to_blocks = removeCommonSizes variant_allocs
variant_sizes = map fst sizes_to_blocks
(slice_stms, offsets, size_sums) <-
- sliceKernelSizes variant_sizes kspace kbody
+ sliceKernelSizes variant_sizes kspace kstms
-- Note the recursive call to expand allocations inside the newly
-- produced kernels.
slice_stms_tmp <- ExplicitMemory.simplifyStms =<< explicitAllocationsInStms slice_stms
@@ -241,6 +295,11 @@ offsetMemoryInBody offsets (Body attr stms res) = do
stms' <- stmsFromList . snd <$> mapAccumLM offsetMemoryInStm offsets (stmsToList stms)
return $ Body attr stms' res
+offsetMemoryInLambda :: RebaseMap -> Lambda InKernel -> Either String (Lambda InKernel)
+offsetMemoryInLambda offset lam = do
+ body <- offsetMemoryInBody offset $ lambdaBody lam
+ return $ lam { lambdaBody = body }
+
offsetMemoryInStm :: RebaseMap -> Stm InKernel
-> Either String (RebaseMap, Stm InKernel)
offsetMemoryInStm offsets (Let pat attr e) = do
@@ -319,16 +378,13 @@ offsetMemoryInExp offsets e = mapExpM recurse e
---- Slicing allocation sizes out of a kernel.
-unAllocInKernelBody :: KernelBody InKernel
- -> Either String (KernelBody Kernels.InKernel)
-unAllocInKernelBody = unAllocKernelBody False
+unAllocInKernelStms :: Stms InKernel
+ -> Either String (Stms Kernels.InKernel)
+unAllocInKernelStms = unAllocStms False
where
unAllocBody (Body attr stms res) =
Body attr <$> unAllocStms True stms <*> pure res
- unAllocKernelBody nested (KernelBody attr stms res) =
- KernelBody attr <$> unAllocStms nested stms <*> pure res
-
unAllocStms nested =
fmap (stmsFromList . catMaybes) . mapM (unAllocStm nested) . stmsToList
@@ -408,10 +464,10 @@ removeCommonSizes :: M.Map VName (SubExp, Space)
removeCommonSizes = M.toList . foldl' comb mempty . M.toList
where comb m (mem, (size, space)) = M.insertWith (++) size [(mem, space)] m
-sliceKernelSizes :: [SubExp] -> KernelSpace -> KernelBody InKernel
+sliceKernelSizes :: [SubExp] -> KernelSpace -> Stms InKernel
-> ExpandM (Stms Kernels.Kernels, [VName], [VName])
-sliceKernelSizes sizes kspace kbody = do
- kbody' <- either compilerLimitationS return $ unAllocInKernelBody kbody
+sliceKernelSizes sizes kspace kstms = do
+ kstms' <- either compilerLimitationS return $ unAllocInKernelStms kstms
let num_sizes = length sizes
i64s = replicate num_sizes $ Prim int64
inkernels_scope <- asks unAllocScope
@@ -430,7 +486,7 @@ sliceKernelSizes sizes kspace kbody = do
params <- replicateM num_sizes $ newParam "x" (Prim int64)
(zs, stms) <- localScope (scopeOfLParams params <>
scopeOfKernelSpace kspace) $ collectStms $ do
- mapM_ addStm $ kernelBodyStms kbody'
+ mapM_ addStm kstms'
return sizes
localScope (scopeOfKernelSpace kspace) $
Kernels.simplifyLambda kspace -- XXX, is this the right KernelSpace?
diff --git a/src/Futhark/Pass/ExplicitAllocations.hs b/src/Futhark/Pass/ExplicitAllocations.hs
index 33ce62d..64d27a1 100644
--- a/src/Futhark/Pass/ExplicitAllocations.hs
+++ b/src/Futhark/Pass/ExplicitAllocations.hs
@@ -529,16 +529,25 @@ allocInFun (FunDef entry fname rettype params fbody) =
return $ FunDef entry fname (memoryInRetType rettype) params' fbody'
handleKernel :: Kernel InInKernel
- -> AllocM fromlore2 ExplicitMemory (MemOp (Kernel OutInKernel))
+ -> AllocM Kernels ExplicitMemory (MemOp (Kernel OutInKernel))
handleKernel (GetSize key size_class) =
return $ Inner $ GetSize key size_class
handleKernel (GetSizeMax size_class) =
return $ Inner $ GetSizeMax size_class
handleKernel (CmpSizeLe key size_class x) =
return $ Inner $ CmpSizeLe key size_class x
-handleKernel (Kernel desc space kernel_ts kbody) = subAllocM handleKernelExp True $
+handleKernel (Kernel desc space kernel_ts kbody) = subInKernel $
Inner . Kernel desc space kernel_ts <$>
localScope (scopeOfKernelSpace space) (allocInKernelBody kbody)
+
+handleKernel (SegRed space comm red_op nes ts body) = do
+ body' <- subInKernel $ localScope (scopeOfKernelSpace space) $ allocInBodyNoDirect body
+ red_op' <- allocInSegRedLambda (spaceGlobalId space) (spaceNumThreads space) red_op
+ return $ Inner $ SegRed space comm red_op' nes ts body'
+
+subInKernel :: AllocM InInKernel OutInKernel a
+ -> AllocM fromlore2 ExplicitMemory a
+subInKernel = subAllocM handleKernelExp True
where handleKernelExp (Barrier se) =
return $ Inner $ Barrier se
@@ -756,7 +765,7 @@ allocInReduceLambda lam input_summaries = do
(acc_params, arr_params) =
splitAt (length input_summaries) actual_params
this_index = LeafExp i int32
- other_index = LeafExp (paramName j_param) int32
+ other_index = this_index + LeafExp (paramName j_param) int32
acc_params' <-
allocInReduceParameters this_index $
zip acc_params input_summaries
@@ -784,6 +793,48 @@ allocInReduceParameters my_id = mapM allocInReduceParameter
Mem size space ->
return p { paramAttr = MemMem size space }
+allocInSegRedLambda :: VName -> SubExp -> Lambda InInKernel
+ -> AllocM Kernels ExplicitMemory (Lambda OutInKernel)
+allocInSegRedLambda gtid num_threads lam = do
+ let (acc_params, arr_params) =
+ splitAt (length (lambdaParams lam) `div` 2) $ lambdaParams lam
+ this_index = LeafExp gtid int32
+ other_index = this_index + primExpFromSubExp int32 num_threads
+ (acc_params', arr_params') <-
+ allocInSegRedParameters num_threads this_index other_index acc_params arr_params
+
+ subInKernel $ allocInLambda (acc_params' ++ arr_params')
+ (lambdaBody lam) (lambdaReturnType lam)
+
+allocInSegRedParameters :: SubExp
+ -> PrimExp VName -> PrimExp VName
+ -> [LParam InInKernel]
+ -> [LParam InInKernel]
+ -> AllocM Kernels ExplicitMemory ([LParam ExplicitMemory], [LParam ExplicitMemory])
+allocInSegRedParameters num_threads my_id other_id xs ys = unzip <$> zipWithM alloc xs ys
+ where alloc x y =
+ case paramType x of
+ Array bt shape u -> do
+ twice_num_threads <- letSubExp "twice_num_threads" $
+ BasicOp $ BinOp (Mul Int32) num_threads $ intConst Int32 2
+ let t = paramType x `arrayOfRow` twice_num_threads
+ (_, mem) <- allocForArray t DefaultSpace
+ -- XXX: this iota ixfun is a bit inefficient; leading to uncoalesced access.
+ let ixfun_base = IxFun.iota $
+ map (primExpFromSubExp int32) (arrayDims t)
+ ixfun_x = IxFun.slice ixfun_base $
+ fullSliceNum (IxFun.shape ixfun_base) [DimFix my_id]
+ ixfun_y = IxFun.slice ixfun_base $
+ fullSliceNum (IxFun.shape ixfun_base) [DimFix other_id]
+ return (x { paramAttr = MemArray bt shape u $ ArrayIn mem ixfun_x },
+ y { paramAttr = MemArray bt shape u $ ArrayIn mem ixfun_y })
+ Prim bt ->
+ return (x { paramAttr = MemPrim bt },
+ y { paramAttr = MemPrim bt })
+ Mem size space ->
+ return (x { paramAttr = MemMem size space },
+ y { paramAttr = MemMem size space })
+
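The parameter allocation above shares one backing array of 2*num_threads rows between the two arguments of the SegRed operator: thread gtid's accumulator sits at row gtid, and the value it combines with at row gtid + num_threads (the other_index built in allocInSegRedLambda). A minimal plain-Haskell sketch of that addressing, with hypothetical names and none of the compiler's IxFun machinery:

    -- Row indices one thread uses into the shared array of 2*numThreads rows.
    pairedRows :: Int -> Int -> (Int, Int)
    pairedRows numThreads gtid = (gtid, gtid + numThreads)

    main :: IO ()
    main =
      -- With 4 threads this prints [(0,4),(1,5),(2,6),(3,7)]: each thread
      -- pairs its own row with the row numThreads further on, mirroring the
      -- my_id/other_id slices above.
      print (map (pairedRows 4) [0 .. 3])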
allocInChunkedParameters :: PrimExp VName
-> [(LParam InInKernel, (VName, IxFun))]
-> AllocM InInKernel OutInKernel [LParam OutInKernel]
@@ -961,6 +1012,7 @@ kernelExpHints (BasicOp (Manifest perm v)) = do
ixfun = IxFun.permute (IxFun.iota $ map (primExpFromSubExp int32) dims')
perm_inv
return [Hint ixfun DefaultSpace]
+
kernelExpHints (Op (Inner (Kernel _ space rets kbody))) =
zipWithM hint rets $ kernelBodyResult kbody
where num_threads = spaceNumThreads space
@@ -975,17 +1027,6 @@ kernelExpHints (Op (Inner (Kernel _ space rets kbody))) =
coalesceReturnOfShape bs [Constant (IntValue (Int32Value d))] = bs * d > 4
coalesceReturnOfShape _ _ = True
- innermost space_dims t_dims =
- let r = length t_dims
- dims = space_dims ++ t_dims
- perm = [length space_dims..length space_dims+r-1] ++
- [0..length space_dims-1]
- perm_inv = rearrangeInverse perm
- dims_perm = rearrangeShape perm dims
- ixfun_base = IxFun.iota $ map (primExpFromSubExp int32) dims_perm
- ixfun_rearranged = IxFun.permute ixfun_base perm_inv
- in ixfun_rearranged
-
hint t (ThreadsReturn threads _)
| coalesceReturnOfShape (primByteSize (elemType t)) $ arrayDims t,
Just space_dims <- spacy threads = do
@@ -1004,9 +1045,30 @@ kernelExpHints (Op (Inner (Kernel _ space rets kbody))) =
return $ Hint ixfun DefaultSpace
hint _ _ = return NoHint
+
+kernelExpHints (Op (Inner (SegRed space _ _ nes ts body))) =
+ (map (const NoHint) red_res <>) <$> zipWithM mapHint (drop (length nes) ts) map_res
+ where (red_res, map_res) = splitAt (length nes) $ bodyResult body
+
+ mapHint t _ = do
+ t_dims <- mapM dimAllocationSize $ arrayDims t
+ return $ Hint (innermost (map snd $ spaceDimensions space) t_dims) DefaultSpace
+
kernelExpHints e =
return $ replicate (expExtTypeSize e) NoHint
+innermost :: [SubExp] -> [SubExp] -> IxFun
+innermost space_dims t_dims =
+ let r = length t_dims
+ dims = space_dims ++ t_dims
+ perm = [length space_dims..length space_dims+r-1] ++
+ [0..length space_dims-1]
+ perm_inv = rearrangeInverse perm
+ dims_perm = rearrangeShape perm dims
+ ixfun_base = IxFun.iota $ map (primExpFromSubExp int32) dims_perm
+ ixfun_rearranged = IxFun.permute ixfun_base perm_inv
+ in ixfun_rearranged
+
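The relocated innermost helper is what mapHint uses for SegRed map results: the backing array is laid out with the kernel-space dimensions innermost (varying fastest), so adjacent threads write adjacent addresses, while the hint's logical shape remains space_dims ++ t_dims. A list-level sketch of just the permutation arithmetic, in plain Haskell with illustrative helper names rather than the compiler's IxFun/rearrange utilities:

    import Data.List (sortOn)

    -- The permutation innermost builds for s space dimensions and r
    -- per-thread result dimensions.
    innermostPerm :: Int -> Int -> [Int]
    innermostPerm s r = [s .. s + r - 1] ++ [0 .. s - 1]

    -- Inverse of a permutation.
    inversePerm :: [Int] -> [Int]
    inversePerm p = map snd (sortOn fst (zip p [0 ..]))

    -- Apply a permutation to a list of dimensions.
    rearrange :: [Int] -> [a] -> [a]
    rearrange p xs = map (xs !!) p

    main :: IO ()
    main = do
      let p = innermostPerm 1 2
      -- With one space dimension n and result dimensions a, b:
      --   p             == [1,2,0]  -- layout order [a, b, n], n innermost
      --   inversePerm p == [2,0,1]  -- permuting back restores [n, a, b]
      print (p, inversePerm p, rearrange (inversePerm p) "abn")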
inKernelExpHints :: (Allocator lore m, Op lore ~ MemOp (KernelExp somelore)) =>
Exp lore -> m [ExpHint]
inKernelExpHints (Op (Inner (Combine (CombineSpace scatter cspace) ts _ _))) =
diff --git a/src/Futhark/Pass/ExtractKernels.hs b/src/Futhark/Pass/ExtractKernels.hs
index dd07807..2a251a2 100644
--- a/src/Futhark/Pass/ExtractKernels.hs
+++ b/src/Futhark/Pass/ExtractKernels.hs
@@ -6,6 +6,7 @@
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE RankNTypes #-}
-- | Kernel extraction.
--
-- In the following, I will use the term "width" to denote the amount
@@ -163,12 +164,11 @@ module Futhark.Pass.ExtractKernels
import Control.Monad.RWS.Strict
import Control.Monad.Reader
+import Control.Monad.Writer.Strict
import Control.Monad.Trans.Maybe
-import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Data.Maybe
import Data.List
-import qualified Data.Semigroup as Sem
import Futhark.Representation.SOACS
import Futhark.Representation.SOACS.Simplify (simplifyStms, simpleSOACS)
@@ -200,31 +200,38 @@ extractKernels :: Pass SOACS Out.Kernels
extractKernels =
Pass { passName = "extract kernels"
, passDescription = "Perform kernel extraction"
- , passFunction = runDistribM . fmap Prog . mapM transformFunDef . progFunctions
+ , passFunction = fmap Prog . mapM transformFunDef . progFunctions
}
-newtype DistribM a = DistribM (RWS (Scope Out.Kernels) Log VNameSource a)
+-- In order to generate more stable threshold names, we keep track of
+-- the numbers used for thresholds separately from the ordinary name
+-- source.
+data State = State { stateNameSource :: VNameSource
+ , stateThresholdCounter :: Int
+ }
+
+newtype DistribM a = DistribM (RWS (Scope Out.Kernels) Log State a)
deriving (Functor, Applicative, Monad,
- HasScope Out.Kernels,
- LocalScope Out.Kernels,
- MonadFreshNames,
+ HasScope Out.Kernels, LocalScope Out.Kernels,
+ MonadState State,
MonadLogger)
+instance MonadFreshNames DistribM where
+ getNameSource = gets stateNameSource
+ putNameSource src = modify $ \s -> s { stateNameSource = src }
+
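The point of keeping stateThresholdCounter apart from the ordinary name source is that threshold names, which end up as size keys, should not shift just because unrelated passes consumed fresh names (the new cmpSizeLe further down builds them as desc ++ "_" ++ show counter). A minimal standalone sketch of the idea in plain Haskell, with hypothetical names (NameState, thresholdName) rather than the real DistribM stack:

    import Control.Monad.State

    -- Simplified stand-in for the DistribM state: an ordinary fresh-name
    -- counter plus a counter that advances only when a threshold is named.
    data NameState = NameState
      { freshCounter     :: Int
      , thresholdCounter :: Int
      }

    type M = State NameState

    -- Ordinary fresh names; every request perturbs freshCounter.
    freshName :: String -> M String
    freshName desc = do
      n <- gets freshCounter
      modify $ \s -> s { freshCounter = n + 1 }
      pure (desc ++ "_" ++ show n)

    -- Threshold names depend only on how many thresholds came before,
    -- so they are stable under unrelated changes to name generation.
    thresholdName :: String -> M String
    thresholdName desc = do
      n <- gets thresholdCounter
      modify $ \s -> s { thresholdCounter = n + 1 }
      pure (desc ++ "_" ++ show n)

    main :: IO ()
    main = print $ evalState program (NameState 0 0)
      where
        -- Prints ("suff_outer_redomap_0","suff_outer_par_1") regardless of
        -- how many ordinary fresh names are requested in between.
        program = do
          _  <- freshName "tmp"
          t1 <- thresholdName "suff_outer_redomap"
          _  <- freshName "tmp"
          t2 <- thresholdName "suff_outer_par"
          pure (t1, t2)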
runDistribM :: (MonadLogger m, MonadFreshNames m) =>
DistribM a -> m a
runDistribM (DistribM m) = do
- (x, msgs) <- modifyNameSource $ positionNameSource . runRWS m M.empty
+ (x, msgs) <- modifyNameSource $ \src ->
+ let (x, s, msgs) = runRWS m mempty (State src 0)
+ in ((x, msgs), stateNameSource s)
addLog msgs
return x
- where positionNameSource (x, src, msgs) = ((x, msgs), src)
-runDistribM' :: MonadFreshNames m => DistribM a -> m a
-runDistribM' (DistribM m) =
- fmap fst $ modifyNameSource $ positionNameSource . runRWS m M.empty
- where positionNameSource (x, src, msgs) = ((x, msgs), src)
-
-transformFunDef :: FunDef -> DistribM (Out.FunDef Out.Kernels)
-transformFunDef (FunDef entry name rettype params body) = do
+transformFunDef :: (MonadFreshNames m, MonadLogger m) =>
+ FunDef -> m (Out.FunDef Out.Kernels)
+transformFunDef (FunDef entry name rettype params body) = runDistribM $ do
body' <- localScope (scopeOfFParams params) $
transformBody mempty body
return $ FunDef entry name rettype params body'
@@ -262,9 +269,10 @@ scopeForKernels = castScope
transformStm :: KernelPath -> Stm -> DistribM KernelsStms
-transformStm path (Let pat aux (Op (CmpThreshold what s))) =
+transformStm path (Let pat aux (Op (CmpThreshold what s))) = do
+ ((r, _), stms) <- cmpSizeLe s (Out.SizeThreshold path) what
runBinder_ $ do
- (r, _) <- cmpSizeLe s (Out.SizeThreshold path) what
+ addStms stms
addStm $ Let pat aux $ BasicOp $ SubExp r
transformStm path (Let pat aux (If c tb fb rt)) = do
@@ -324,11 +332,12 @@ transformStm path (Let res_pat (StmAux cs _) (Op (Screma w form arrs)))
transformStm path (Let pat (StmAux cs _) (Op (Screma w form arrs)))
| Just (comm, red_lam, nes, map_lam) <- isRedomapSOAC form = do
- let paralleliseOuter = do
+ let paralleliseOuter = runBinder_ $ do
red_lam_sequential <- Kernelise.transformLambda red_lam
map_lam_sequential <- Kernelise.transformLambda map_lam
- fmap (certify cs) <$>
- blockedReduction pat w comm' red_lam_sequential map_lam_sequential [] nes arrs
+ addStms =<<
+ (fmap (certify cs) <$>
+ nonSegRed pat w comm' red_lam_sequential map_lam_sequential nes arrs)
outerParallelBody =
renameBody =<<
@@ -342,7 +351,6 @@ transformStm path (Let pat (StmAux cs _) (Op (Screma w form arrs)))
renameBody =<<
(mkBody <$> paralleliseInner path' <*> pure (map Var (patternNames pat)))
-
comm' | commutativeLambda red_lam = Commutative
| otherwise = comm
@@ -350,7 +358,7 @@ transformStm path (Let pat (StmAux cs _) (Op (Screma w form arrs)))
then paralleliseOuter
else if incrementalFlattening then do
((outer_suff, outer_suff_key), suff_stms) <-
- runBinder $ sufficientParallelism "suff_outer_redomap" w path
+ sufficientParallelism "suff_outer_redomap" w path
outer_stms <- outerParallelBody
inner_stms <- innerParallelBody ((outer_suff_key, False):path)
@@ -370,7 +378,7 @@ transformStm path (Let pat (StmAux cs _) (Op (Stream w (Parallel _ _ _ []) map_f
transformStm path (Let pat aux@(StmAux cs _) (Op (Stream w (Parallel o comm red_fun nes) fold_fun arrs)))
| incrementalFlattening = do
((outer_suff, outer_suff_key), suff_stms) <-
- runBinder $ sufficientParallelism "suff_outer_stream" w path
+ sufficientParallelism "suff_outer_stream" w path
outer_stms <- outerParallelBody ((outer_suff_key, True) : path)
inner_stms <- innerParallelBody ((outer_suff_key, False) : path)
@@ -466,13 +474,11 @@ data MapLoop = MapLoop Pattern Certificates SubExp Lambda [VName]
mapLoopStm :: MapLoop -> Stm
mapLoopStm (MapLoop pat cs w lam arrs) = Let pat (StmAux cs ()) $ Op $ Screma w (mapSOAC lam) arrs
-sufficientParallelism :: (Op (Lore m) ~ Kernel innerlore, MonadBinder m) =>
- String -> SubExp -> KernelPath -> m (SubExp, VName)
+sufficientParallelism :: String -> SubExp -> KernelPath
+ -> DistribM ((SubExp, Name), Out.Stms Out.Kernels)
sufficientParallelism desc what path = cmpSizeLe desc (Out.SizeThreshold path) what
-distributeMap :: (HasScope Out.Kernels m,
- MonadFreshNames m, MonadLogger m) =>
- KernelPath -> MapLoop -> m KernelsStms
+distributeMap :: KernelPath -> MapLoop -> DistribM KernelsStms
distributeMap path (MapLoop pat cs w lam arrs) = do
types <- askScope
let loopnest = MapNesting pat cs w $ zip (lambdaParams lam) arrs
@@ -505,7 +511,7 @@ distributeMap path (MapLoop pat cs w lam arrs) = do
return $ postKernelsStms postkernels <>
identityStms (outerTarget $ kernelTargets acc')
- distributeMap' (newKernel loopnest) path exploitOuterParallelism exploitInnerParallelism pat w lam
+ distributeMap' id (newKernel loopnest) path exploitOuterParallelism exploitInnerParallelism pat w lam
where acc = KernelAcc { kernelTargets = singleTarget (pat, bodyResult $ lambdaBody lam)
, kernelStms = mempty
}
@@ -520,19 +526,20 @@ distributeMap path (MapLoop pat cs w lam arrs) = do
Let (Pattern [] [pe]) (defAux ()) $ BasicOp $ Replicate (Shape [w]) se
distributeMap' :: (HasScope Out.Kernels m, MonadFreshNames m) =>
- KernelNest -> KernelPath
+ (forall a. DistribM a -> m a)
+ -> KernelNest -> KernelPath
-> (KernelPath -> m (Out.Stms Out.Kernels))
-> (KernelPath -> m (Out.Stms Out.Kernels))
-> PatternT Type
-> SubExp
-> LambdaT SOACS
-> m (Out.Stms Out.Kernels)
-distributeMap' loopnest path mk_seq_stms mk_par_stms pat nest_w lam = do
+distributeMap' distribM loopnest path mk_seq_stms mk_par_stms pat nest_w lam = do
let res = map Var $ patternNames pat
types <- askScope
- ((outer_suff, outer_suff_key), outer_suff_stms) <- runBinder $
- sufficientParallelism "suff_outer_par" nest_w path
+ ((outer_suff, outer_suff_key), outer_suff_stms) <-
+ distribM $ sufficientParallelism "suff_outer_par" nest_w path
intra <- if worthIntraGroup lam then
flip runReaderT types $ intraGroupParallelise loopnest lam
@@ -550,19 +557,25 @@ distributeMap' loopnest path mk_seq_stms mk_par_stms pat nest_w lam = do
Just ((_intra_min_par, intra_avail_par), group_size, intra_prelude, intra_stms) -> do
-- We must check that all intra-group parallelism fits in a group.
- ((intra_ok, intra_suff_key), intra_suff_stms) <- runBinder $ do
- addStms intra_prelude
+ ((intra_ok, intra_suff_key), intra_suff_stms) <- do
+
+ ((intra_suff, suff_key), check_suff_stms) <-
+ distribM $ sufficientParallelism "suff_intra_par" intra_avail_par $
+ (outer_suff_key, False) : path
+
+ runBinder $ do
+
+ addStms intra_prelude
- max_group_size <-
- letSubExp "max_group_size" $ Op $ Out.GetSizeMax Out.SizeGroup
- fits <- letSubExp "fits" $ BasicOp $
- CmpOp (CmpSle Int32) group_size max_group_size
+ max_group_size <-
+ letSubExp "max_group_size" $ Op $ Out.GetSizeMax Out.SizeGroup
+ fits <- letSubExp "fits" $ BasicOp $
+ CmpOp (CmpSle Int32) group_size max_group_size
- (intra_suff, suff_key) <- sufficientParallelism "suff_intra_par" intra_avail_par $
- (outer_suff_key, False) : path
+ addStms check_suff_stms
- intra_ok <- letSubExp "intra_suff_and_fits" $ BasicOp $ BinOp LogAnd fits intra_suff
- return (intra_ok, suff_key)
+ intra_ok <- letSubExp "intra_suff_and_fits" $ BasicOp $ BinOp LogAnd fits intra_suff
+ return (intra_ok, suff_key)
group_par_body <- renameBody $ mkBody intra_stms res
@@ -587,24 +600,22 @@ data KernelRes = KernelRes { accPostKernels :: PostKernels
, accLog :: Log
}
-instance Sem.Semigroup KernelRes where
+instance Semigroup KernelRes where
KernelRes ks1 log1 <> KernelRes ks2 log2 =
KernelRes (ks1 <> ks2) (log1 <> log2)
instance Monoid KernelRes where
mempty = KernelRes mempty mempty
- mappend = (Sem.<>)
newtype PostKernel = PostKernel { unPostKernel :: KernelsStms }
newtype PostKernels = PostKernels [PostKernel]
-instance Sem.Semigroup PostKernels where
+instance Semigroup PostKernels where
PostKernels xs <> PostKernels ys = PostKernels $ ys ++ xs
instance Monoid PostKernels where
mempty = PostKernels mempty
- mappend = (Sem.<>)
postKernelsStms :: PostKernels -> KernelsStms
postKernelsStms (PostKernels kernels) = mconcat $ map unPostKernel kernels
@@ -622,11 +633,19 @@ addStmToKernel bnd acc = do
stms <- runBinder_ $ Kernelise.transformStm bnd
return acc { kernelStms = stms <> kernelStms acc }
-newtype KernelM a = KernelM (RWS KernelEnv KernelRes VNameSource a)
+newtype KernelM a = KernelM (ReaderT KernelEnv (WriterT KernelRes DistribM) a)
deriving (Functor, Applicative, Monad,
MonadReader KernelEnv,
- MonadWriter KernelRes,
- MonadFreshNames)
+ MonadWriter KernelRes)
+
+liftDistribM :: DistribM a -> KernelM a
+liftDistribM m = do
+ scope <- askScope
+ KernelM $ lift $ lift $ localScope scope m
+
+instance MonadFreshNames KernelM where
+ getNameSource = KernelM $ lift getNameSource
+ putNameSource = KernelM . lift . putNameSource
instance HasScope Out.Kernels KernelM where
askScope = asks kernelScope
@@ -638,13 +657,11 @@ instance LocalScope Out.Kernels KernelM where
instance MonadLogger KernelM where
addLog msgs = tell mempty { accLog = msgs }
-runKernelM :: (MonadFreshNames m, MonadLogger m) =>
- KernelEnv -> KernelM a -> m (a, PostKernels)
+runKernelM :: KernelEnv -> KernelM a -> DistribM (a, PostKernels)
runKernelM env (KernelM m) = do
- (x, res) <- modifyNameSource $ getKernels . runRWS m env
+ (x, res) <- runWriterT $ runReaderT m env
addLog $ accLog res
return (x, accPostKernels res)
- where getKernels (x,s,a) = ((x, a), s)
collectKernels :: KernelM a -> KernelM (a, PostKernels)
collectKernels m = pass $ do
@@ -841,7 +858,7 @@ distributeInnerMap maploop@(MapLoop pat cs w lam arrs) acc
let outer_pat = loopNestingPattern $ fst nest
path <- asks kernelPath
addKernel =<< (nestw_bnds<>) <$>
- localScope extra_scope (distributeMap' nest' path
+ localScope extra_scope (distributeMap' liftDistribM nest' path
(const $ return $ oneStm sequentialised_kernel)
exploitInnerParallelism
outer_pat nestw
@@ -908,13 +925,11 @@ maybeDistributeStm bnd@(Let pat _ (DoLoop [] val form@ForLoop{} body)) acc
addKernels kernels
nest' <- expandKernelNest pat_unused nest
types <- asksScope scopeForSOACs
- scope <- askScope
+
bnds <- runReaderT
(interchangeLoops nest' (SeqLoop perm pat val form body)) types
- -- runDistribM starts out with an empty scope, so we have to
- -- immediately insert the real one.
path <- asks kernelPath
- bnds' <- runDistribM $ localScope scope $ transformStms path $ stmsToList bnds
+ bnds' <- liftDistribM $ transformStms path $ stmsToList bnds
addKernel bnds'
return acc'
_ ->
@@ -937,11 +952,9 @@ maybeDistributeStm stm@(Let pat _ (If cond tbranch fbranch ret)) acc
types <- asksScope scopeForSOACs
let branch = Branch perm pat cond tbranch fbranch ret
stms <- runReaderT (interchangeBranch nest' branch) types
- -- runDistribM starts out with an empty scope, so we have to
- -- immediately insert the real one.
- scope <- askScope
+
path <- asks kernelPath
- stms' <- runDistribM $ localScope scope $ transformStms path $ stmsToList stms
+ stms' <- liftDistribM $ transformStms path $ stmsToList stms
addKernel stms'
return acc'
_ ->
@@ -1293,7 +1306,7 @@ segmentedGenReduceKernel nest perm cs genred_w ops lam arrs = do
-- array). They will not be used anywhere else (due to uniqueness
-- constraints), so this is safe.
let all_dests = concatMap genReduceDest ops'
- (nest_stms<>) <$>
+ liftDistribM $ (nest_stms<>) <$>
inScopeOf nest_stms
(genReduceKernel path (kernelNestLoops $ removeArraysFromNest all_dests nest)
orig_pat ispace inputs cs genred_w ops' lam arrs)
@@ -1301,12 +1314,11 @@ segmentedGenReduceKernel nest perm cs genred_w ops lam arrs = do
maybe bad return $ find ((==a) . kernelInputName) kernel_inps
bad = fail "Ill-typed nested GenReduce encountered."
-genReduceKernel :: (HasScope Out.Kernels m, MonadFreshNames m) =>
- KernelPath -> [LoopNesting]
+genReduceKernel :: KernelPath -> [LoopNesting]
-> Pattern -> [(VName, SubExp)] -> [KernelInput]
-> Certificates -> SubExp -> [GenReduceOp SOACS]
-> InKernelLambda -> [VName]
- -> m KernelsStms
+ -> DistribM KernelsStms
genReduceKernel path nests orig_pat ispace inputs cs genred_w ops lam arrs = do
ops' <- forM ops $ \(GenReduceOp num_bins dests nes op) ->
GenReduceOp num_bins dests nes <$> Kernelise.transformLambda op
@@ -1314,26 +1326,17 @@ genReduceKernel path nests orig_pat ispace inputs cs genred_w ops lam arrs = do
let isDest = flip elem $ concatMap genReduceDest ops'
inputs' = filter (not . isDest . kernelInputArray) inputs
- runBinder_ $ do
- (histos, k_stms) <- blockedGenReduce genred_w ispace inputs' ops' lam arrs
-
- addStms $ fmap (certify cs) k_stms
+ (histos, k_stms) <- blockedGenReduce genred_w ispace inputs' ops' lam arrs
- let histos' = chunks (map (length . genReduceDest) ops') histos
- pes = chunks (map (length . genReduceDest) ops') $ patternElements orig_pat
+ let histos' = chunks (map (length . genReduceDest) ops') histos
+ pes = chunks (map (length . genReduceDest) ops') $ patternElements orig_pat
- mapM_ combineIntermediateResults (zip3 pes ops histos')
+ (fmap (certify cs) k_stms<>) . mconcat <$>
+ inScopeOf k_stms (mapM combineIntermediateResults (zip3 pes ops histos'))
where depth = length nests
- combineIntermediateResults (pes, GenReduceOp num_bins _ nes op, histos) = do
- num_histos <- arraysSize depth <$> mapM lookupType histos
-
- -- Avoid the segmented reduction if num_histos is 1.
- num_histos_is_one <-
- letSubExp "num_histos_is_one" $
- BasicOp $ CmpOp (CmpEq int32) num_histos $ intConst Int32 1
-
+ mkBodies num_histos pes num_bins nes op histos = runBinder $ do
body_with_reshape <- runBodyBinder $
fmap resultBody $ forM histos $ \histo -> do
histo_dims <- arrayDims <$> lookupType histo
@@ -1361,19 +1364,34 @@ genReduceKernel path nests orig_pat ispace inputs cs genred_w ops lam arrs = do
nests' <-
moreArrays (map paramName map_params) histos_tr_t histos_tr $
nests ++ [MapNesting inner_segred_pat cs num_bins $ zip (lambdaParams lam) arrs]
+
let collapse_body = reconstructMapNest nests' (map (rowType . patElemType) pes) $
mkBody map_stms $ map Var map_res
- scope <- askScope
+ return (body_with_reshape, collapse_body)
+
+ combineIntermediateResults (pes, GenReduceOp num_bins _ nes op, histos) = do
+ num_histos <- arraysSize depth <$> mapM lookupType histos
+
+ ((body_with_reshape, collapse_body), aux_stms) <- mkBodies num_histos pes num_bins nes op histos
+
segmented_reduce_stms <-
- runDistribM' $ localScope scope $ transformStms path $
- stmsToList $ bodyStms collapse_body
+ inScopeOf aux_stms $ transformStms path $ stmsToList $ bodyStms collapse_body
let body_with_segred = mkBody segmented_reduce_stms $
bodyResult collapse_body
- letBindNames (map patElemName pes) $
- If num_histos_is_one body_with_reshape body_with_segred $
- IfAttr (staticShapes $ map patElemType pes) IfNormal
+
+ runBinder_ $ do
+ addStms aux_stms
+
+ -- Avoid the segmented reduction if num_histos is 1.
+ num_histos_is_one <-
+ letSubExp "num_histos_is_one" $
+ BasicOp $ CmpOp (CmpEq int32) num_histos $ intConst Int32 1
+
+ letBindNames (map patElemName pes) $
+ If num_histos_is_one body_with_reshape body_with_segred $
+ IfAttr (staticShapes $ map patElemType pes) IfNormal
reconstructMapNest :: [LoopNesting] -> [Type] -> BodyT SOACS -> BodyT SOACS
reconstructMapNest [] _ body = body
@@ -1425,11 +1443,8 @@ regularSegmentedRedomapKernel :: KernelNest
regularSegmentedRedomapKernel nest perm segment_size comm lam map_lam nes arrs =
isSegmentedOp nest perm segment_size
(lambdaReturnType map_lam) (freeInLambda lam) (freeInLambda map_lam) nes arrs $
- \pat flat_pat num_segments total_num_elements ispace inps nes' _ arrs' -> do
- fold_lam <- composeLambda nilFn lam map_lam
- regularSegmentedRedomap
- segment_size num_segments (kernelNestWidths nest)
- flat_pat pat total_num_elements comm lam fold_lam ispace inps nes' arrs'
+ \pat _flat_pat _num_segments total_num_elements ispace inps nes' _ _ ->
+ addStms =<< segRed pat total_num_elements segment_size comm lam map_lam nes' arrs ispace inps
isSegmentedOp :: KernelNest
-> [Int]
@@ -1565,6 +1580,16 @@ expandKernelNest pes (outer_nest, inner_nests) = do
, patElemAttr = patElemType pe `arrayOfShape` Shape dims
}
+cmpSizeLe :: String -> Out.SizeClass -> SubExp
+ -> DistribM ((SubExp, Name), Out.Stms Out.Kernels)
+cmpSizeLe desc size_class to_what = do
+ x <- gets stateThresholdCounter
+ modify $ \s -> s { stateThresholdCounter = x + 1}
+ let size_key = nameFromString $ desc ++ "_" ++ show x
+ runBinder $ do
+ cmp_res <- letSubExp desc $ Op $ CmpSizeLe size_key size_class to_what
+ return (cmp_res, size_key)
+
kernelAlternatives :: (MonadFreshNames m, HasScope Out.Kernels m) =>
Out.Pattern Out.Kernels
-> Out.Body Out.Kernels
diff --git a/src/Futhark/Pass/ExtractKernels/BlockedKernel.hs b/src/Futhark/Pass/ExtractKernels/BlockedKernel.hs
index 598ab9c..b59c57d 100644
--- a/src/Futhark/Pass/ExtractKernels/BlockedKernel.hs
+++ b/src/Futhark/Pass/ExtractKernels/BlockedKernel.hs
@@ -7,6 +7,9 @@ module Futhark.Pass.ExtractKernels.BlockedKernel
, blockedMap
, blockedScan
+ , segRed
+ , nonSegRed
+
, mapKernel
, mapKernelFromBody
, KernelInput(..)
@@ -18,14 +21,12 @@ module Futhark.Pass.ExtractKernels.BlockedKernel
, chunkLambda
, splitArrays
, getSize
- , cmpSizeLe
)
where
import Control.Monad
import Data.Maybe
import Data.List
-import Data.Semigroup ((<>))
import qualified Data.Set as S
import Prelude hiding (quot)
@@ -48,16 +49,9 @@ import Futhark.Util.IntegralExp
getSize :: (MonadBinder m, Op (Lore m) ~ Kernel innerlore) =>
String -> SizeClass -> m SubExp
getSize desc size_class = do
- size_key <- newVName desc
+ size_key <- nameFromString . pretty <$> newVName desc
letSubExp desc $ Op $ GetSize size_key size_class
-cmpSizeLe :: (MonadBinder m, Op (Lore m) ~ Kernel innerlore) =>
- String -> SizeClass -> SubExp -> m (SubExp, VName)
-cmpSizeLe desc size_class to_what = do
- size_key <- newVName desc
- cmp_res <- letSubExp desc $ Op $ CmpSizeLe size_key size_class to_what
- return (cmp_res, size_key)
-
blockedReductionStream :: (MonadFreshNames m, HasScope Kernels m) =>
Pattern Kernels
-> SubExp
@@ -291,6 +285,58 @@ kerneliseLambda nes lam = do
fold_inp_params
}
+segRed :: (MonadFreshNames m, HasScope Kernels m) =>
+ Pattern Kernels
+ -> SubExp
+ -> SubExp -- segment size
+ -> Commutativity
+ -> Lambda InKernel -> Lambda InKernel
+ -> [SubExp] -> [VName]
+ -> [(VName, SubExp)] -- ispace = pair of (gtid, size) for the maps on "top" of this reduction
+ -> [KernelInput] -- inps = inputs that can be looked up by using the gtids from ispace
+ -> m (Stms Kernels)
+segRed pat total_num_elements w comm reduce_lam map_lam nes arrs ispace inps = runBinder_ $ do
+ (_, KernelSize num_groups group_size _ _ num_threads) <- blockedKernelSize =<< asIntS Int64 total_num_elements
+ gtid <- newVName "gtid"
+ kspace <- newKernelSpace (num_groups, group_size, num_threads) $ FlatThreadSpace $
+ ispace ++ [(gtid, w)]
+ body <- runBodyBinder $ localScope (scopeOfKernelSpace kspace) $ do
+ mapM_ (addStm <=< readKernelInput) inps
+ forM_ (zip (lambdaParams map_lam) arrs) $ \(p, arr) -> do
+ arr_t <- lookupType arr
+ letBindNames_ [paramName p] $
+ BasicOp $ Index arr $ fullSlice arr_t [DimFix $ Var gtid]
+ return $ lambdaBody map_lam
+
+ letBind_ pat $ Op $
+ SegRed kspace comm reduce_lam nes (lambdaReturnType map_lam) body
+
+nonSegRed :: (MonadFreshNames m, HasScope Kernels m) =>
+ Pattern Kernels
+ -> SubExp
+ -> Commutativity
+ -> Lambda InKernel
+ -> Lambda InKernel
+ -> [SubExp]
+ -> [VName]
+ -> m (Stms Kernels)
+nonSegRed pat w comm red_lam map_lam nes arrs = runBinder_ $ do
+ -- We add a unit-size segment on top to ensure that the result
+ -- of the SegRed is an array, which we then immediately index.
+ -- This is useful in the case that the value is used on the
+ -- device afterwards, as this may save an expensive
+ -- host-device copy (scalars are kept on the host, but arrays
+ -- may be on the device).
+ let addDummyDim t = t `arrayOfRow` intConst Int32 1
+ pat' <- fmap addDummyDim <$> renamePattern pat
+ dummy <- newVName "dummy"
+ addStms =<<
+ segRed pat' w w comm red_lam map_lam nes arrs [(dummy, intConst Int32 1)] []
+
+ forM_ (zip (patternNames pat') (patternNames pat)) $ \(from, to) -> do
+ from_t <- lookupType from
+ letBindNames_ [to] $ BasicOp $ Index from $ fullSlice from_t [DimFix $ intConst Int32 0]
+
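The comment at the top of nonSegRed is the entire trick: run the reduction under a single dummy segment so the result is a one-element array (which may stay on the device), then index the scalar back out. A list-level analogy in plain Haskell, not compiler code, showing the wrapper is observationally a plain reduction:

    -- Segmented reduction: one result per segment.
    segReduce :: (a -> a -> a) -> a -> [[a]] -> [a]
    segReduce op ne = map (foldl op ne)

    -- Non-segmented reduction via segReduce: wrap the input in one dummy
    -- segment, reduce, and index the single result out again.
    nonSegReduce :: (a -> a -> a) -> a -> [a] -> a
    nonSegReduce op ne xs = head (segReduce op ne [xs])

    main :: IO ()
    main = print (nonSegReduce (+) 0 [1 .. 10 :: Int])  -- 55, same as foldl (+) 0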
blockedReduction :: (MonadFreshNames m, HasScope Kernels m) =>
Pattern Kernels
-> SubExp
diff --git a/src/Futhark/Pass/ExtractKernels/ISRWIM.hs b/src/Futhark/Pass/ExtractKernels/ISRWIM.hs
index 82b2274..dbee919 100644
--- a/src/Futhark/Pass/ExtractKernels/ISRWIM.hs
+++ b/src/Futhark/Pass/ExtractKernels/ISRWIM.hs
@@ -9,7 +9,6 @@ module Futhark.Pass.ExtractKernels.ISRWIM
import Control.Arrow (first)
import Control.Monad.State
-import Data.Semigroup ((<>))
import Futhark.MonadFreshNames
import Futhark.Representation.SOACS
diff --git a/src/Futhark/Pass/ExtractKernels/Intragroup.hs b/src/Futhark/Pass/ExtractKernels/Intragroup.hs
index aa536a6..d6c59a9 100644
--- a/src/Futhark/Pass/ExtractKernels/Intragroup.hs
+++ b/src/Futhark/Pass/ExtractKernels/Intragroup.hs
@@ -191,11 +191,11 @@ intraGroupStm stm@(Let pat _ e) = do
-- A GroupScan lambda needs two more parameters.
my_index <- newVName "my_index"
- other_index <- newVName "other_index"
+ offset <- newVName "offset"
let my_index_param = Param my_index (Prim int32)
- other_index_param = Param other_index (Prim int32)
+ offset_param = Param offset (Prim int32)
scanfun'' = scanfun' { lambdaParams = my_index_param :
- other_index_param :
+ offset_param :
lambdaParams scanfun'
}
letBind_ (Pattern [] scan_pes) $
@@ -212,11 +212,11 @@ intraGroupStm stm@(Let pat _ e) = do
-- A GroupReduce lambda needs two more parameters.
my_index <- newVName "my_index"
- other_index <- newVName "other_index"
+ offset <- newVName "offset"
let my_index_param = Param my_index (Prim int32)
- other_index_param = Param other_index (Prim int32)
+ offset_param = Param offset (Prim int32)
redfun'' = redfun' { lambdaParams = my_index_param :
- other_index_param :
+ offset_param :
lambdaParams redfun'
}
letBind_ (Pattern [] red_pes) $
diff --git a/src/Futhark/Pass/ExtractKernels/Kernelise.hs b/src/Futhark/Pass/ExtractKernels/Kernelise.hs
index 91f12d3..bc073eb 100644
--- a/src/Futhark/Pass/ExtractKernels/Kernelise.hs
+++ b/src/Futhark/Pass/ExtractKernels/Kernelise.hs
@@ -14,7 +14,6 @@ module Futhark.Pass.ExtractKernels.Kernelise
where
import Control.Monad
-import Data.Semigroup ((<>))
import qualified Data.Set as S
import qualified Futhark.Analysis.Alias as Alias
diff --git a/src/Futhark/Pass/ExtractKernels/Segmented.hs b/src/Futhark/Pass/ExtractKernels/Segmented.hs
index 7a210ab..180fc4a 100644
--- a/src/Futhark/Pass/ExtractKernels/Segmented.hs
+++ b/src/Futhark/Pass/ExtractKernels/Segmented.hs
@@ -1,15 +1,13 @@
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}
--- | Multiversion segmented reduction.
+-- | Segmented scan.
module Futhark.Pass.ExtractKernels.Segmented
- ( regularSegmentedRedomap
- , regularSegmentedScan
+ ( regularSegmentedScan
)
where
import Control.Monad
import qualified Data.Map.Strict as M
-import Data.Semigroup ((<>))
import Futhark.Transform.Rename
import Futhark.Representation.Kernels
@@ -18,814 +16,6 @@ import Futhark.MonadFreshNames
import Futhark.Tools hiding (true, false)
import Futhark.Pass.ExtractKernels.BlockedKernel
-data SegmentedVersion = OneGroupOneSegment
- | ManyGroupsOneSegment
- deriving (Eq, Ord, Show)
-
--- | @regularSegmentedRedomap@ will generate code for a segmented redomap using
--- two different strategies, and dynamically deciding which one to use based on
--- the number of segments and segment size. We use the (static) @group_size@ to
--- decide which of the following two strategies to choose:
---
--- * Large: uses one or more groups to process a single segment. If multiple
--- groups are used per segment, the intermediate reduction results must be
--- recursively reduced, until there is only a single value per segment.
---
--- Each thread /can/ read multiple elements, which will greatly increase
--- performance; however, if the reduction is non-commutative the input array
--- will be transposed (by the KernelBabysitter) to enable memory coalesced
--- accesses. Currently we will always make each thread read as many elements
--- as it can, but this /could/ be unfavorable because of the transpose: in
--- the case where each thread can only read 2 elements, the cost of the
--- transpose might not be worth the performance gained by letting each thread
--- read multiple elements. This could be investigated more in depth in the
--- future (TODO)
---
--- * Small: is used to let each group process *multiple* segments within a
--- group. We will only use this approach when we can process at least two
--- segments within a single group. In those cases, we would normally allocate
--- a /whole/ group per segment with the large strategy, but at most 50% of the
--- threads in the group would have any element to read, which becomes highly
--- inefficient.
-regularSegmentedRedomap :: (HasScope Kernels m, MonadBinder m, Lore m ~ Kernels) =>
- SubExp -- segment_size
- -> SubExp -- num_segments
- -> [SubExp] -- nest_sizes = the sizes of the maps on "top" of this redomap
- -> Pattern Kernels -- flat_pat ... pat where each type is array with dim [w]
- -> Pattern Kernels -- pat
- -> SubExp -- w = total_num_elements
- -> Commutativity -- comm
- -> Lambda InKernel -- reduce_lam
- -> Lambda InKernel -- fold_lam = this lambda performs both the map-part and
- -- reduce-part of a redomap (described in redomap paper)
- -> [(VName, SubExp)] -- ispace = pair of (gtid, size) for the maps on "top" of this redomap
- -> [KernelInput] -- inps = inputs that can be looked up by using the gtids from ispace
- -> [SubExp] -- nes
- -> [VName] -- arrs_flat
- -> m ()
-regularSegmentedRedomap segment_size num_segments nest_sizes flat_pat
- pat w comm reduce_lam fold_lam ispace inps nes arrs_flat = do
- unless (null $ patternContextElements pat) $ fail "regularSegmentedRedomap result pattern contains context elements, and Rasmus did not think this would ever happen."
-
- -- the result of the "map" part of a redomap has to be stored somewhere within
- -- the chunking loop of a kernel. The current way to do this is to make some
- -- scratch space initially, and each thread will get a part of this by
- -- splitting it. Finally it is returned as a result of the kernel (to not
- -- break functional semantics).
- map_out_arrs <- forM (drop num_redres $ patternIdents pat) $ \(Ident name t) -> do
- tmp <- letExp (baseString name <> "_out_in") $
- BasicOp $ Scratch (elemType t) (arrayDims t)
- -- This reshape will not always work.
- letExp (baseString name ++ "_out_in") $
- BasicOp $ Reshape (reshapeOuter [DimNew w] (length nest_sizes+1) $ arrayShape t) tmp
-
- -- Check that we're only dealing with arrays with dimension [w]
- forM_ arrs_flat $ \arr -> do
- tp <- lookupType arr
- case tp of
- -- TODO: this won't work if the reduction operator works on lists... but
- -- they seem to be handled in some other way (which makes sense)
- Array _primtp (Shape (flatsize:_)) _uniqness ->
- when (flatsize /= w) $
- fail$ "regularSegmentedRedomap: first dimension of array has incorrect size " ++ pretty arr ++ ":" ++ pretty tp
- _ ->
- fail $ "regularSegmentedRedomap: non array encountered " ++ pretty arr ++ ":" ++ pretty tp
-
- -- The pattern passed to chunkLambda must have exactly *one* array dimension,
- -- to get the correct size of [chunk_size]type.
- --
- -- TODO: not sure if this will work when result of map is multidimensional,
- -- or if reduction operator uses lists... must check
- chunk_pat <- fmap (Pattern []) $ forM (patternValueElements pat) $ \pat_e ->
- case patElemType pat_e of
- Array ty (Shape (dim0:_)) u -> do
- vn' <- newName $ patElemName pat_e
- return $ PatElem vn' $ Array ty (Shape [dim0]) u
- _ -> fail $ "segmentedRedomap: result pattern is not array " ++ pretty pat_e
-
- chunk_fold_lam <- chunkLambda chunk_pat nes fold_lam
-
- kern_chunk_fold_lam <- kerneliseLambda nes chunk_fold_lam
-
- let chunk_red_pat = Pattern [] $ take num_redres $ patternValueElements chunk_pat
- kern_chunk_reduce_lam <- kerneliseLambda nes =<< chunkLambda chunk_red_pat nes reduce_lam
-
- -- the lambda for a GroupReduce needs these two extra parameters
- my_index <- newVName "my_index"
- other_offset <- newVName "other_offset"
- let my_index_param = Param my_index (Prim int32)
- let other_offset_param = Param other_offset (Prim int32)
- let reduce_lam' = reduce_lam { lambdaParams = my_index_param :
- other_offset_param :
- lambdaParams reduce_lam
- }
- flag_reduce_lam <- addFlagToLambda nes reduce_lam
- let flag_reduce_lam' = flag_reduce_lam { lambdaParams = my_index_param :
- other_offset_param :
- lambdaParams flag_reduce_lam
- }
-
-
- -- TODO: 'blockedReductionStream' in BlockedKernel.hs which is very similar
- -- performs a copy here... however, I have not seen a need for it yet.
-
- group_size <- getSize "group_size" SizeGroup
- num_groups_hint <- getSize "num_groups_hint" SizeNumGroups
-
- -- Here we make a small optimization: if we will use the large kernel, and
- -- only one group per segment, we can simplify the calculations within the
- -- kernel for the indexes of which segment is it working on; therefore we
- -- create two different kernels (this will increase the final code size a bit
- -- though). TODO: test how much we win by doing this.
-
- (num_groups_per_segment, _) <-
- calcGroupsPerSegmentAndElementsPerThread
- segment_size num_segments num_groups_hint group_size ManyGroupsOneSegment
-
- let all_arrs = arrs_flat ++ map_out_arrs
- (large_1_ses, large_1_stms) <- runBinder $
- useLargeOnePerSeg group_size all_arrs reduce_lam' kern_chunk_fold_lam
- (large_m_ses, large_m_stms) <- runBinder $
- useLargeMultiRecursiveReduce group_size all_arrs reduce_lam' kern_chunk_fold_lam
- kern_chunk_reduce_lam flag_reduce_lam'
-
- let e_large_seg = eIf (eCmpOp (CmpEq $ IntType Int32) (eSubExp num_groups_per_segment)
- (eSubExp one))
- (mkBodyM large_1_stms large_1_ses)
- (mkBodyM large_m_stms large_m_ses)
-
-
- (small_ses, small_stms) <- runBinder $ useSmallKernel group_size map_out_arrs flag_reduce_lam'
-
- -- if (group_size/2) < segment_size, means that we will not be able to fit two
- -- segments into one group, and therefore we should not use the kernel that
- -- relies on this.
- e <- eIf (eCmpOp (CmpSlt Int32) (eBinOp (SQuot Int32) (eSubExp group_size) (eSubExp two))
- (eSubExp segment_size))
- (eBody [e_large_seg])
- (mkBodyM small_stms small_ses)
-
- redres_pes <- forM (take num_redres (patternValueElements pat)) $ \pe -> do
- vn' <- newName $ patElemName pe
- return $ PatElem vn' $ replaceSegmentDims num_segments $ patElemType pe
- let mapres_pes = drop num_redres $ patternValueElements flat_pat
- let unreshaped_pat = Pattern [] $ redres_pes ++ mapres_pes
-
- letBind_ unreshaped_pat e
-
- forM_ (zip (patternValueElements unreshaped_pat)
- (patternValueElements pat)) $ \(kpe, pe) ->
- letBind_ (Pattern [] [pe]) $
- BasicOp $ Reshape [DimNew se | se <- arrayDims $ patElemAttr pe]
- (patElemName kpe)
-
- where
- replaceSegmentDims d t =
- t `setArrayDims` (d : drop (length nest_sizes) (arrayDims t))
-
- one = constant (1 :: Int32)
- two = constant (2 :: Int32)
-
- -- number of reduction results (tuple size for reduction operator)
- num_redres = length nes
-
- ----------------------------------------------------------------------------
- -- The functions below generate all the needed code for the two different
- -- version of segmented-redomap (one group per segment, and many groups per
- -- segment).
- --
- -- We rename statements before adding them because the same lambdas
- -- (reduce/fold) are used multiple times, and we do not want to bind the
- -- same VName twice (as this is a type error)
- ----------------------------------------------------------------------------
- useLargeOnePerSeg group_size all_arrs reduce_lam' kern_chunk_fold_lam = do
- mapres_pes <- forM (drop num_redres $ patternValueElements flat_pat) $ \pe -> do
- vn' <- newName $ patElemName pe
- return $ PatElem vn' $ patElemType pe
-
- (kernel, _, _) <-
- largeKernel group_size segment_size num_segments nest_sizes
- all_arrs comm reduce_lam' kern_chunk_fold_lam
- nes w OneGroupOneSegment
- ispace inps
-
- kernel_redres_pes <- forM (take num_redres (patternValueElements pat)) $ \pe -> do
- vn' <- newName $ patElemName pe
- return $ PatElem vn' $ replaceSegmentDims num_segments $ patElemType pe
-
- let kernel_pat = Pattern [] $ kernel_redres_pes ++ mapres_pes
-
- addStm =<< renameStm (Let kernel_pat (defAux ()) $ Op kernel)
- return $ map (Var . patElemName) $ patternValueElements kernel_pat
-
- ----------------------------------------------------------------------------
- useLargeMultiRecursiveReduce group_size all_arrs reduce_lam' kern_chunk_fold_lam kern_chunk_reduce_lam flag_reduce_lam' = do
- mapres_pes <- forM (drop num_redres $ patternValueElements flat_pat) $ \pe -> do
- vn' <- newName $ patElemName pe
- return $ PatElem vn' $ patElemType pe
-
- (firstkernel, num_groups_used, num_groups_per_segment) <-
- largeKernel group_size segment_size num_segments nest_sizes
- all_arrs comm reduce_lam' kern_chunk_fold_lam
- nes w ManyGroupsOneSegment
- ispace inps
-
- firstkernel_redres_pes <- forM (take num_redres (patternValueElements pat)) $ \pe -> do
- vn' <- newName $ patElemName pe
- return $ PatElem vn' $ replaceSegmentDims num_groups_used $ patElemType pe
-
- let first_pat = Pattern [] $ firstkernel_redres_pes ++ mapres_pes
- addStm =<< renameStm (Let first_pat (defAux ()) $ Op firstkernel)
-
- let new_segment_size = num_groups_per_segment
- let new_total_elems = num_groups_used
- let tmp_redres = map patElemName firstkernel_redres_pes
-
- (finalredres, part_two_stms) <- runBinder $ performFinalReduction
- new_segment_size new_total_elems tmp_redres
- reduce_lam' kern_chunk_reduce_lam flag_reduce_lam'
-
- mapM_ (addStm <=< renameStm) part_two_stms
-
- return $ finalredres ++ map (Var . patElemName) mapres_pes
-
- ----------------------------------------------------------------------------
- -- The "recursive" reduction step. However, will always do this using
- -- exactly one extra step. Either by using the small kernel, or by using the
- -- large kernel with one group per segment.
- performFinalReduction new_segment_size new_total_elems tmp_redres
- reduce_lam' kern_chunk_reduce_lam flag_reduce_lam' = do
- group_size <- getSize "group_size" SizeGroup
-
- -- Large kernel, using one group per segment (ogps)
- (large_ses, large_stms) <- runBinder $ do
- (large_kernel, _, _) <- largeKernel group_size new_segment_size num_segments nest_sizes
- tmp_redres comm reduce_lam' kern_chunk_reduce_lam
- nes new_total_elems OneGroupOneSegment
- ispace inps
- letTupExp' "kernel_result" $ Op large_kernel
-
- -- Small kernel, using one group many segments (ogms)
- (small_ses, small_stms) <- runBinder $ do
- red_scratch_arrs <- forM (take num_redres $ patternIdents pat) $ \(Ident name t) -> do
- -- We construct a scratch array for writing the result, but
- -- we have to flatten the dimensions corresponding to the
- -- map nest, because multi-dimensional WriteReturns are/were
- -- not supported.
- tmp <- letExp (baseString name <> "_redres_scratch") $
- BasicOp $ Scratch (elemType t) (arrayDims t)
- let reshape = reshapeOuter [DimNew num_segments] (length nest_sizes) $ arrayShape t
- letExp (baseString name ++ "_redres_scratch") $
- BasicOp $ Reshape reshape tmp
- kernel <- smallKernel group_size new_segment_size num_segments
- tmp_redres red_scratch_arrs
- comm flag_reduce_lam' reduce_lam
- nes new_total_elems ispace inps
- letTupExp' "kernel_result" $ Op kernel
-
- e <- eIf (eCmpOp (CmpSlt Int32)
- (eBinOp (SQuot Int32) (eSubExp group_size) (eSubExp two))
- (eSubExp new_segment_size))
- (mkBodyM large_stms large_ses)
- (mkBodyM small_stms small_ses)
-
- letTupExp' "step_two_kernel_result" e
-
- ----------------------------------------------------------------------------
- useSmallKernel group_size map_out_arrs flag_reduce_lam' = do
- red_scratch_arrs <-
- forM (take num_redres $ patternIdents pat) $ \(Ident name t) -> do
- tmp <- letExp (baseString name <> "_redres_scratch") $
- BasicOp $ Scratch (elemType t) (arrayDims t)
- let shape_change = reshapeOuter [DimNew num_segments]
- (length nest_sizes) (arrayShape t)
- letExp (baseString name ++ "_redres_scratch") $
- BasicOp $ Reshape shape_change tmp
-
- let scratch_arrays = red_scratch_arrs ++ map_out_arrs
-
- kernel <- smallKernel group_size segment_size num_segments
- arrs_flat scratch_arrays
- comm flag_reduce_lam' fold_lam
- nes w ispace inps
- letTupExp' "kernel_result" $ Op kernel
-
-largeKernel :: (MonadBinder m, Lore m ~ Kernels) =>
- SubExp -- group_size
- -> SubExp -- segment_size
- -> SubExp -- num_segments
- -> [SubExp] -- nest sizes
- -> [VName] -- all_arrs: flat arrays (also the "map_out" ones)
- -> Commutativity -- comm
- -> Lambda InKernel -- reduce_lam
- -> Lambda InKernel -- kern_chunk_fold_lam
- -> [SubExp] -- nes
- -> SubExp -- w = total_num_elements
- -> SegmentedVersion -- segver
- -> [(VName, SubExp)] -- ispace = pair of (gtid, size) for the maps on "top" of this redomap
- -> [KernelInput] -- inps = inputs that can be looked up by using the gtids from ispace
- -> m (Kernel InKernel, SubExp, SubExp)
-largeKernel group_size segment_size num_segments nest_sizes all_arrs comm
- reduce_lam' kern_chunk_fold_lam
- nes w segver ispace inps = do
- let num_redres = length nes -- number of reduction results (tuple size for
- -- reduction operator)
-
- num_groups_hint <- getSize "num_groups_hint" SizeNumGroups
-
- (num_groups_per_segment, elements_per_thread) <-
- calcGroupsPerSegmentAndElementsPerThread segment_size num_segments num_groups_hint group_size segver
-
- num_groups <- letSubExp "num_groups" $
- case segver of
- OneGroupOneSegment -> BasicOp $ SubExp num_segments
- ManyGroupsOneSegment -> BasicOp $ BinOp (Mul Int32) num_segments num_groups_per_segment
-
- num_threads <- letSubExp "num_threads" $
- BasicOp $ BinOp (Mul Int32) num_groups group_size
-
- threads_within_segment <- letSubExp "threads_within_segment" $
- BasicOp $ BinOp (Mul Int32) group_size num_groups_per_segment
-
- gtid_vn <- newVName "gtid"
- gtid_ln <- newVName "gtid"
-
- -- the array passed here is the structure for how to layout the kernel space
- space <- newKernelSpace (num_groups, group_size, num_threads) $
- FlatThreadSpace $ ispace ++ [(gtid_vn, num_groups_per_segment),(gtid_ln,group_size)]
-
- let red_ts = take num_redres $ lambdaReturnType kern_chunk_fold_lam
- let map_ts = map rowType $ drop num_redres $ lambdaReturnType kern_chunk_fold_lam
- let kernel_return_types = red_ts ++ map_ts
-
- let ordering = case comm of Commutative -> SplitStrided threads_within_segment
- Noncommutative -> SplitContiguous
-
- let stride = case ordering of SplitStrided s -> s
- SplitContiguous -> one
-
- let each_thread = do
- segment_index <- letSubExp "segment_index" $
- BasicOp $ BinOp (SQuot Int32) (Var $ spaceGroupId space) num_groups_per_segment
-
- -- localId + (group_size * (groupId % num_groups_per_segment))
- index_within_segment <- letSubExp "index_within_segment" =<<
- eBinOp (Add Int32)
- (eSubExp $ Var gtid_ln)
- (eBinOp (Mul Int32)
- (eSubExp group_size)
- (eBinOp (SRem Int32) (eSubExp $ Var $ spaceGroupId space) (eSubExp num_groups_per_segment))
- )
-
- (in_segment_offset,offset) <-
- makeOffsetExp ordering index_within_segment elements_per_thread segment_index
-
- let (_, chunksize, [], arr_params) =
- partitionChunkedKernelFoldParameters 0 $ lambdaParams kern_chunk_fold_lam
- let chunksize_se = Var $ paramName chunksize
-
- patelems_res_of_split <- forM arr_params $ \arr_param -> do
- let chunk_t = paramType arr_param `setOuterSize` Var (paramName chunksize)
- return $ PatElem (paramName arr_param) chunk_t
-
- letBind_ (Pattern [] [PatElem (paramName chunksize) $ paramType chunksize]) $
- Op $ SplitSpace ordering segment_size index_within_segment elements_per_thread
-
- addKernelInputStms inps
-
- forM_ (zip all_arrs patelems_res_of_split) $ \(arr, pe) -> do
- let pe_t = patElemType pe
- segment_dims = nest_sizes ++ arrayDims (pe_t `setOuterSize` segment_size)
- arr_nested <- letExp (baseString arr ++ "_nested") $
- BasicOp $ Reshape (map DimNew segment_dims) arr
- arr_nested_t <- lookupType arr_nested
- let slice = fullSlice arr_nested_t $ map (DimFix . Var . fst) ispace ++
- [DimSlice in_segment_offset chunksize_se stride]
- letBind_ (Pattern [] [pe]) $ BasicOp $ Index arr_nested slice
-
- red_pes <- forM red_ts $ \red_t -> do
- pe_name <- newVName "chunk_fold_red"
- return $ PatElem pe_name red_t
- map_pes <- forM map_ts $ \map_t -> do
- pe_name <- newVName "chunk_fold_map"
- return $ PatElem pe_name $ map_t `arrayOfRow` chunksize_se
-
- -- we add the lets here, as we practically don't know if the resulting subexp
- -- is a Constant or a Var, so better be safe (?)
- addStms $ bodyStms (lambdaBody kern_chunk_fold_lam)
- addStms $ stmsFromList
- [ Let (Pattern [] [pe]) (defAux ()) $ BasicOp $ SubExp se
- | (pe,se) <- zip (red_pes ++ map_pes)
- (bodyResult $ lambdaBody kern_chunk_fold_lam) ]
-
- -- Combine the reduction results from each thread. This will put results in
- -- local memory, so a GroupReduce can be performed on them
- combine_red_pes <- forM red_ts $ \red_t -> do
- pe_name <- newVName "chunk_fold_red"
- return $ PatElem pe_name $ red_t `arrayOfRow` group_size
- cids <- replicateM (length red_pes) $ newVName "cid"
- addStms $ stmsFromList
- [ Let (Pattern [] [pe']) (defAux ()) $
- Op $ Combine (combineSpace [(cid, group_size)]) [patElemType pe] [] $
- Body () mempty [Var $ patElemName pe]
- | (cid, pe', pe) <- zip3 cids combine_red_pes red_pes ]
-
- final_red_pes <- forM (lambdaReturnType reduce_lam') $ \t -> do
- pe_name <- newVName "final_result"
- return $ PatElem pe_name t
- letBind_ (Pattern [] final_red_pes) $
- Op $ GroupReduce group_size reduce_lam' $
- zip nes (map patElemName combine_red_pes)
-
- return (final_red_pes, map_pes, offset)
-
-
- ((final_red_pes, map_pes, offset), stms) <- runBinder each_thread
-
- red_returns <- forM final_red_pes $ \pe ->
- return $ ThreadsReturn OneResultPerGroup $ Var $ patElemName pe
- map_returns <- forM map_pes $ \pe ->
- return $ ConcatReturns ordering w elements_per_thread
- (Just offset) $
- patElemName pe
- let kernel_returns = red_returns ++ map_returns
-
- let kerneldebughints = KernelDebugHints kernelname
- [ ("num_segment", num_segments)
- , ("segment_size", segment_size)
- , ("num_groups", num_groups)
- , ("group_size", group_size)
- , ("elements_per_thread", elements_per_thread)
- , ("num_groups_per_segment", num_groups_per_segment)
- ]
-
- let kernel = Kernel kerneldebughints space kernel_return_types $
- KernelBody () stms kernel_returns
-
- return (kernel, num_groups, num_groups_per_segment)
-
- where
- one = constant (1 :: Int32)
-
- commname = case comm of Commutative -> "comm"
- Noncommutative -> "nocomm"
-
- kernelname = case segver of
- OneGroupOneSegment -> "segmented_redomap__large_" ++ commname ++ "_one"
- ManyGroupsOneSegment -> "segmented_redomap__large_" ++ commname ++ "_many"
-
- makeOffsetExp SplitContiguous index_within_segment elements_per_thread segment_index = do
- in_segment_offset <- letSubExp "in_segment_offset" $
- BasicOp $ BinOp (Mul Int32) elements_per_thread index_within_segment
- offset <- letSubExp "offset" =<< eBinOp (Add Int32) (eSubExp in_segment_offset)
- (eBinOp (Mul Int32) (eSubExp segment_size) (eSubExp segment_index))
- return (in_segment_offset, offset)
- makeOffsetExp (SplitStrided _) index_within_segment _elements_per_thread segment_index = do
- offset <- letSubExp "offset" =<< eBinOp (Add Int32) (eSubExp index_within_segment)
- (eBinOp (Mul Int32) (eSubExp segment_size) (eSubExp segment_index))
- return (index_within_segment, offset)
-
-calcGroupsPerSegmentAndElementsPerThread :: (MonadBinder m, Lore m ~ Kernels) =>
- SubExp
- -> SubExp
- -> SubExp
- -> SubExp
- -> SegmentedVersion
- -> m (SubExp, SubExp)
-calcGroupsPerSegmentAndElementsPerThread segment_size num_segments
- num_groups_hint group_size segver = do
- num_groups_per_segment_hint <-
- letSubExp "num_groups_per_segment_hint" =<<
- case segver of
- OneGroupOneSegment -> eSubExp one
- ManyGroupsOneSegment -> eDivRoundingUp Int32 (eSubExp num_groups_hint)
- (eSubExp num_segments)
- elements_per_thread <-
- letSubExp "elements_per_thread" =<<
- eDivRoundingUp Int32 (eSubExp segment_size)
- (eBinOp (Mul Int32) (eSubExp group_size)
- (eSubExp num_groups_per_segment_hint))
-
-  -- If we are using 1 element per thread, we might be launching too many
-  -- groups. The expression below remedies this.
- --
- -- For example, if there are 3 segments of size 512, we are using group size
- -- 128, and @num_groups_hint@ is 256; then we would use 1 element per thread,
- -- and launch 256 groups. However, we only need 4 groups per segment to
- -- process all elements.
- num_groups_per_segment <-
- letSubExp "num_groups_per_segment" =<<
- case segver of
- OneGroupOneSegment -> eSubExp one
- ManyGroupsOneSegment ->
- eIf (eCmpOp (CmpEq $ IntType Int32) (eSubExp elements_per_thread) (eSubExp one))
- (eBody [eDivRoundingUp Int32 (eSubExp segment_size) (eSubExp group_size)])
- (mkBodyM mempty [num_groups_per_segment_hint])
-
- return (num_groups_per_segment, elements_per_thread)
-
- where
- one = constant (1 :: Int32)
-
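
A minimal pure sketch of the ManyGroupsOneSegment arithmetic above, using plain Ints instead of SubExps (the helper names are hypothetical and only restate the formulas):

divRoundingUp :: Int -> Int -> Int
divRoundingUp x y = (x + y - 1) `div` y

groupsAndElemsPerThread :: Int -> Int -> Int -> Int -> (Int, Int)
groupsAndElemsPerThread segment_size num_segments num_groups_hint group_size =
  let groups_hint = divRoundingUp num_groups_hint num_segments
      elems_per_thread = divRoundingUp segment_size (group_size * groups_hint)
      groups_per_segment
        | elems_per_thread == 1 = divRoundingUp segment_size group_size
        | otherwise             = groups_hint
  in (groups_per_segment, elems_per_thread)

-- Using the numbers from the comment above:
-- groupsAndElemsPerThread 512 3 256 128 == (4, 1)
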
-smallKernel :: (MonadBinder m, Lore m ~ Kernels) =>
- SubExp -- group_size
- -> SubExp -- segment_size
- -> SubExp -- num_segments
- -> [VName] -- in_arrs: flat arrays (containing input to fold_lam)
- -> [VName] -- scratch_arrs: Preallocated space that we can write into
- -> Commutativity -- comm
- -> Lambda InKernel -- flag_reduce_lam'
- -> Lambda InKernel -- fold_lam
- -> [SubExp] -- nes
- -> SubExp -- w = total_num_elements
- -> [(VName, SubExp)] -- ispace = pair of (gtid, size) for the maps on "top" of this redomap
- -> [KernelInput] -- inps = inputs that can be looked up by using the gtids from ispace
- -> m (Kernel InKernel)
-smallKernel group_size segment_size num_segments in_arrs scratch_arrs
- comm flag_reduce_lam' fold_lam_unrenamed
- nes w ispace inps = do
- let num_redres = length nes -- number of reduction results (tuple size for
- -- reduction operator)
-
- fold_lam <- renameLambda fold_lam_unrenamed
-
- num_segments_per_group <- letSubExp "num_segments_per_group" $
- BasicOp $ BinOp (SQuot Int32) group_size segment_size
-
- num_groups <- letSubExp "num_groups" =<<
- eDivRoundingUp Int32 (eSubExp num_segments) (eSubExp num_segments_per_group)
-
- num_threads <- letSubExp "num_threads" $
- BasicOp $ BinOp (Mul Int32) num_groups group_size
-
- active_threads_per_group <- letSubExp "active_threads_per_group" $
- BasicOp $ BinOp (Mul Int32) segment_size num_segments_per_group
-
- let remainder_last_group = eBinOp (SRem Int32) (eSubExp num_segments) (eSubExp num_segments_per_group)
-
- segments_in_last_group <- letSubExp "seg_in_last_group" =<<
- eIf (eCmpOp (CmpEq $ IntType Int32) remainder_last_group
- (eSubExp zero))
- (eBody [eSubExp num_segments_per_group])
- (eBody [remainder_last_group])
-
- active_threads_in_last_group <- letSubExp "active_threads_last_group" $
- BasicOp $ BinOp (Mul Int32) segment_size segments_in_last_group
-
-  -- the structure passed here describes how to lay out the kernel space
- space <- newKernelSpace (num_groups, group_size, num_threads) $
- FlatThreadSpace []
-
- ------------------------------------------------------------------------------
-  -- What follows are the statements used in the kernel
- ------------------------------------------------------------------------------
-
- let lid = Var $ spaceLocalId space
-
- let (red_ts, map_ts) = splitAt num_redres $ lambdaReturnType fold_lam
- let kernel_return_types = red_ts ++ map_ts
-
- let wasted_thread_part1 = do
- let create_dummy_val (Prim ty) = return $ Constant $ blankPrimValue ty
- create_dummy_val (Array ty sh _) = letSubExp "dummy" $ BasicOp $ Scratch ty (shapeDims sh)
- create_dummy_val Mem{} = fail "segredomap, 'Mem' used as result type"
- dummy_vals <- mapM create_dummy_val kernel_return_types
- return (negone : dummy_vals)
-
- let normal_thread_part1 = do
- segment_index <- letSubExp "segment_index" =<<
- eBinOp (Add Int32)
- (eBinOp (SQuot Int32) (eSubExp $ Var $ spaceLocalId space) (eSubExp segment_size))
- (eBinOp (Mul Int32) (eSubExp $ Var $ spaceGroupId space) (eSubExp num_segments_per_group))
-
- index_within_segment <- letSubExp "index_within_segment" =<<
- eBinOp (SRem Int32) (eSubExp $ Var $ spaceLocalId space) (eSubExp segment_size)
-
- offset <- makeOffsetExp index_within_segment segment_index
-
- red_pes <- forM red_ts $ \red_t -> do
- pe_name <- newVName "fold_red"
- return $ PatElem pe_name red_t
- map_pes <- forM map_ts $ \map_t -> do
- pe_name <- newVName "fold_map"
- return $ PatElem pe_name map_t
-
- addManualIspaceCalcStms segment_index ispace
-
- addKernelInputStms inps
-
- -- Index input array to get arguments to fold_lam
- let arr_params = drop num_redres $ lambdaParams fold_lam
- let nonred_lamparam_pes = map
- (\p -> PatElem (paramName p) (paramType p)) arr_params
- forM_ (zip in_arrs nonred_lamparam_pes) $ \(arr, pe) -> do
- tp <- lookupType arr
- let slice = fullSlice tp [DimFix offset]
- letBind_ (Pattern [] [pe]) $ BasicOp $ Index arr slice
-
- -- Bind neutral element (serves as the reduction arguments to fold_lam)
- forM_ (zip nes (take num_redres $ lambdaParams fold_lam)) $ \(ne,param) -> do
- let pe = PatElem (paramName param) (paramType param)
- letBind_ (Pattern [] [pe]) $ BasicOp $ SubExp ne
-
- addStms $ bodyStms $ lambdaBody fold_lam
-
-            -- We add the lets here, as we practically cannot know whether the resulting
-            -- subexp is a Constant or a Var, so it is safer to bind it explicitly.
- addStms $ stmsFromList
- [ Let (Pattern [] [pe]) (defAux ()) $ BasicOp $ SubExp se
- | (pe,se) <- zip (red_pes ++ map_pes) (bodyResult $ lambdaBody fold_lam) ]
-
- let mapoffset = offset
- let mapret_elems = map (Var . patElemName) map_pes
- let redres_elems = map (Var . patElemName) red_pes
- return (mapoffset : redres_elems ++ mapret_elems)
-
- let all_threads red_pes = do
- isfirstinsegment <- letExp "isfirstinsegment" =<<
- eCmpOp (CmpEq $ IntType Int32)
- (eBinOp (SRem Int32) (eSubExp lid) (eSubExp segment_size))
- (eSubExp zero)
-
- -- We will perform a segmented-scan, so all the prime variables here
- -- include the flag, which is the first argument to flag_reduce_lam
- let red_pes_wflag = PatElem isfirstinsegment (Prim Bool) : red_pes
- let red_ts_wflag = Prim Bool : red_ts
-
- -- Combine the reduction results from each thread. This will put results in
- -- local memory, so a GroupReduce/GroupScan can be performed on them
- combine_red_pes' <- forM red_ts_wflag $ \red_t -> do
- pe_name <- newVName "chunk_fold_red"
- return $ PatElem pe_name $ red_t `arrayOfRow` group_size
- cids <- replicateM (length red_pes_wflag) $ newVName "cid"
- addStms $ stmsFromList [ Let (Pattern [] [pe']) (defAux ()) $ Op $
- Combine (combineSpace [(cid, group_size)]) [patElemType pe] [] $
- Body () mempty [Var $ patElemName pe]
- | (cid, pe', pe) <- zip3 cids combine_red_pes' red_pes_wflag ]
-
- scan_red_pes_wflag <- forM red_ts_wflag $ \red_t -> do
- pe_name <- newVName "scanned"
- return $ PatElem pe_name $ red_t `arrayOfRow` group_size
- let scan_red_pes = drop 1 scan_red_pes_wflag
- letBind_ (Pattern [] scan_red_pes_wflag) $ Op $
- GroupScan group_size flag_reduce_lam' $
- zip (false:nes) (map patElemName combine_red_pes')
-
- return scan_red_pes
-
- let normal_thread_part2 scan_red_pes = do
- segment_index <- letSubExp "segment_index" =<<
- eBinOp (Add Int32)
- (eBinOp (SQuot Int32) (eSubExp $ Var $ spaceLocalId space) (eSubExp segment_size))
- (eBinOp (Mul Int32) (eSubExp $ Var $ spaceGroupId space) (eSubExp num_segments_per_group))
-
- islastinsegment <- letExp "islastinseg" =<< eCmpOp (CmpEq $ IntType Int32)
- (eBinOp (SRem Int32) (eSubExp lid) (eSubExp segment_size))
- (eBinOp (Sub Int32) (eSubExp segment_size) (eSubExp one))
-
- redoffset <- letSubExp "redoffset" =<<
- eIf (eSubExp $ Var islastinsegment)
- (eBody [eSubExp segment_index])
- (mkBodyM mempty [negone])
-
- redret_elems <- fmap (map Var) $ letTupExp "red_return_elem" =<<
- eIf (eSubExp $ Var islastinsegment)
- (eBody [return $ BasicOp $ Index (patElemName pe) (fullSlice (patElemType pe) [DimFix lid])
- | pe <- scan_red_pes])
- (mkBodyM mempty nes)
-
- return (redoffset : redret_elems)
-
-
- let picknchoose = do
- is_last_group <- letSubExp "islastgroup" =<<
- eCmpOp (CmpEq $ IntType Int32)
- (eSubExp $ Var $ spaceGroupId space)
- (eBinOp (Sub Int32) (eSubExp num_groups) (eSubExp one))
-
- active_threads_this_group <- letSubExp "active_thread_this_group" =<<
- eIf (eSubExp is_last_group)
- (eBody [eSubExp active_threads_in_last_group])
- (eBody [eSubExp active_threads_per_group])
-
- isactive <- letSubExp "isactive" =<<
- eCmpOp (CmpSlt Int32) (eSubExp lid) (eSubExp active_threads_this_group)
-
-          -- Part 1: All active threads read an element from the input array and apply
-          -- the folding function; "wasted" threads just create dummy values
- (normal_res1, normal_stms1) <- runBinder normal_thread_part1
- (wasted_res1, wasted_stms1) <- runBinder wasted_thread_part1
-
- -- we could just have used letTupExp, but this would not give as nice
- -- names in the generated code
- mapoffset_pe <- (`PatElem` i32) <$> newVName "mapoffset"
- redtmp_pes <- forM red_ts $ \red_t -> do
- pe_name <- newVName "redtmp_res"
- return $ PatElem pe_name red_t
- map_pes <- forM map_ts $ \map_t -> do
- pe_name <- newVName "map_res"
- return $ PatElem pe_name map_t
-
- e1 <- eIf (eSubExp isactive)
- (mkBodyM normal_stms1 normal_res1)
- (mkBodyM wasted_stms1 wasted_res1)
- letBind_ (Pattern [] (mapoffset_pe:redtmp_pes++map_pes)) e1
-
-          -- Part 2: All threads participate in Combine & GroupScan
- scan_red_pes <- all_threads redtmp_pes
-
-          -- Part 3: Active threads that are the last element in their segment
-          -- write the element from local memory to the output array
- (normal_res2, normal_stms2) <- runBinder $ normal_thread_part2 scan_red_pes
-
- redoffset_pe <- (`PatElem` i32) <$> newVName "redoffset"
- red_pes <- forM red_ts $ \red_t -> do
- pe_name <- newVName "red_res"
- return $ PatElem pe_name red_t
-
- e2 <- eIf (eSubExp isactive)
- (mkBodyM normal_stms2 normal_res2)
- (mkBodyM mempty (negone : nes))
- letBind_ (Pattern [] (redoffset_pe:red_pes)) e2
-
- return $ map (Var . patElemName) $ redoffset_pe:mapoffset_pe:red_pes++map_pes
-
- (redoffset:mapoffset:redmapres, stms) <- runBinder picknchoose
- let (finalredvals, finalmapvals) = splitAt num_redres redmapres
-
-  -- To be able to return elements from only some threads, we exploit the fact
-  -- that a WriteReturn with offset=-1 does nothing.
- red_returns <- forM (zip finalredvals $ take num_redres scratch_arrs) $ \(se, scarr) ->
- return $ WriteReturn [num_segments] scarr [([redoffset], se)]
- map_returns <- forM (zip finalmapvals $ drop num_redres scratch_arrs) $ \(se, scarr) ->
- return $ WriteReturn [w] scarr [([mapoffset], se)]
- let kernel_returns = red_returns ++ map_returns
-
- let kerneldebughints = KernelDebugHints kernelname
- [ ("num_segment", num_segments)
- , ("segment_size", segment_size)
- , ("num_groups", num_groups)
- , ("group_size", group_size)
- , ("num_segments_per_group", num_segments_per_group)
- , ("active_threads_per_group", active_threads_per_group)
- ]
-
- let kernel = Kernel kerneldebughints space kernel_return_types $
- KernelBody () stms kernel_returns
-
- return kernel
-
- where
- i32 = Prim $ IntType Int32
- zero = constant (0 :: Int32)
- one = constant (1 :: Int32)
- negone = constant (-1 :: Int32)
- false = constant False
-
-
- commname = case comm of Commutative -> "comm"
- Noncommutative -> "nocomm"
- kernelname = "segmented_redomap__small_" ++ commname
-
- makeOffsetExp index_within_segment segment_index = do
- e <- eBinOp (Add Int32)
- (eSubExp index_within_segment)
- (eBinOp (Mul Int32) (eSubExp segment_size) (eSubExp segment_index))
- letSubExp "offset" e
-
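
The flag-based segmented scan used in smallKernel above (the GroupScan over flag_reduce_lam') relies on the standard construction of a flag-carrying operator. As an illustration only, and not the compiler's lambda representation, the construction looks like this at the list level:

-- Flag/value operator for a segmented scan: a True flag marks the first
-- element of a segment and stops values from flowing across the boundary.
-- The combined operator is associative whenever op is.
segOp :: (a -> a -> a) -> (Bool, a) -> (Bool, a) -> (Bool, a)
segOp op (f1, x) (f2, y)
  | f2        = (f1 || f2, y)
  | otherwise = (f1 || f2, x `op` y)

-- For example, the value components of
--   scanl1 (segOp (+)) [(True,1),(False,2),(True,3),(False,4)]
-- are [1,3,3,7]: prefix sums that restart at every flagged element.
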
-addKernelInputStms :: (MonadBinder m, Lore m ~ InKernel) =>
- [KernelInput]
- -> m ()
-addKernelInputStms = mapM_ $ \kin -> do
- let pe = PatElem (kernelInputName kin) (kernelInputType kin)
- let arr = kernelInputArray kin
- arrtp <- lookupType arr
- let slice = fullSlice arrtp [DimFix se | se <- kernelInputIndices kin]
- letBind (Pattern [] [pe]) $ BasicOp $ Index arr slice
-
--- | Manually calculate the values for the ispace identifiers, when the
--- 'SpaceStructure' won't do. The ispace argument gives the dimensions of the enclosing maps.
---
--- If the input is @i [(a_vn, a), (b_vn, b), (c_vn, c)]@ then @i@ should hit all
--- the values [0,a*b*c). We can calculate the indexes for the other dimensions:
---
--- > c_vn = i % c
--- > b_vn = (i/c) % b
--- > a_vn = ((i/c)/b) % a
-addManualIspaceCalcStms :: (MonadBinder m, Lore m ~ InKernel) =>
- SubExp
- -> [(VName, SubExp)]
- -> m ()
-addManualIspaceCalcStms outer_index ispace = do
-  -- TODO: The ispace index is calculated slightly differently from how it
-  -- would have been had the ThreadSpace been used. However, this works.
-  -- Maybe ask Troels if doing it the other way has some benefit?
- let calc_ispace_index prev_val (vn,size) = do
- let pe = PatElem vn (Prim $ IntType Int32)
- letBind_ (Pattern [] [pe]) $ BasicOp $ BinOp (SRem Int32) prev_val size
- letSubExp "tmp_val" $ BasicOp $ BinOp (SQuot Int32) prev_val size
- foldM_ calc_ispace_index outer_index (reverse ispace)
-
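
A pure sketch of the index computation described in the comment above (hypothetical helper, plain Ints): the flat index is repeatedly divided by the innermost sizes, and the remainders become the per-dimension indices, listed outermost first.

unflattenIndex :: [Int] -> Int -> [Int]
unflattenIndex dims i = snd $ foldl step (i, []) (reverse dims)
  where step (rest, acc) size = (rest `quot` size, rest `rem` size : acc)

-- For dimensions [a,b,c] = [2,3,4] and i = 13:
--   unflattenIndex [2,3,4] 13 == [1,0,1]   -- since 13 == 1*3*4 + 0*4 + 1
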
addFlagToLambda :: (MonadBinder m, Lore m ~ Kernels) =>
[SubExp] -> Lambda InKernel -> m (Lambda InKernel)
addFlagToLambda nes lam = do
diff --git a/src/Futhark/Pass/ExtractKernels/Split.hs b/src/Futhark/Pass/ExtractKernels/Split.hs
new file mode 100644
index 0000000..8bff6b2
--- /dev/null
+++ b/src/Futhark/Pass/ExtractKernels/Split.hs
@@ -0,0 +1,41 @@
+-- | Functionality for identifying chunks of interesting parallelism
+-- inside of a map nesting.
+module Futhark.Pass.ExtractKernels.Split
+ ( splitMap) where
+
+import Control.Monad.RWS.Strict
+import Control.Monad.Reader
+import Control.Monad.Trans.Maybe
+import qualified Data.Map.Strict as M
+import qualified Data.Set as S
+import Data.Maybe
+import Data.List
+import qualified Data.Semigroup as Sem
+
+import Futhark.Representation.SOACS
+import Futhark.Representation.SOACS.Simplify (simplifyStms, simpleSOACS)
+import qualified Futhark.Representation.Kernels as Out
+import Futhark.Representation.Kernels.Kernel
+import Futhark.MonadFreshNames
+import Futhark.Tools
+import qualified Futhark.Transform.FirstOrderTransform as FOT
+import qualified Futhark.Pass.ExtractKernels.Kernelise as Kernelise
+import Futhark.Transform.Rename
+import Futhark.Pass
+import Futhark.Transform.CopyPropagate
+import Futhark.Pass.ExtractKernels.Distribution
+import Futhark.Pass.ExtractKernels.ISRWIM
+import Futhark.Pass.ExtractKernels.BlockedKernel
+import Futhark.Pass.ExtractKernels.Segmented
+import Futhark.Pass.ExtractKernels.Interchange
+import Futhark.Pass.ExtractKernels.Intragroup
+import Futhark.Util
+import Futhark.Util.Log
+
+type KernelsStms = Out.Stms Out.Kernels
+type InKernelStms = Out.Stms Out.InKernel
+type InKernelLambda = Out.Lambda Out.InKernel
+
+splitMap :: (MonadFreshNames m) =>
+ Scope SOACS -> a -> m [a]
+splitMap scope loop = return [loop]
diff --git a/src/Futhark/Pass/KernelBabysitting.hs b/src/Futhark/Pass/KernelBabysitting.hs
index f00694c..e632bf3 100644
--- a/src/Futhark/Pass/KernelBabysitting.hs
+++ b/src/Futhark/Pass/KernelBabysitting.hs
@@ -14,7 +14,6 @@ import qualified Data.Set as S
import Data.Foldable
import Data.List
import Data.Maybe
-import Data.Semigroup ((<>))
import Futhark.MonadFreshNames
import Futhark.Representation.AST
diff --git a/src/Futhark/Pkg/Info.hs b/src/Futhark/Pkg/Info.hs
index 4312a1e..2af16d9 100644
--- a/src/Futhark/Pkg/Info.hs
+++ b/src/Futhark/Pkg/Info.hs
@@ -25,7 +25,6 @@ import qualified Data.Map as M
import qualified Data.Text as T
import qualified Data.ByteString as BS
import qualified Data.Text.Encoding as T
-import qualified Data.Semigroup as Sem
import Data.List
import Data.Monoid ((<>))
import qualified System.FilePath.Posix as Posix
@@ -255,12 +254,11 @@ glPkgInfo owner repo versions =
-- monoidically. In essence, the PkgRegistry is just a cache.
newtype PkgRegistry m = PkgRegistry (M.Map PkgPath (PkgInfo m))
-instance Sem.Semigroup (PkgRegistry m) where
+instance Semigroup (PkgRegistry m) where
PkgRegistry x <> PkgRegistry y = PkgRegistry $ x <> y
instance Monoid (PkgRegistry m) where
mempty = PkgRegistry mempty
- mappend = (Sem.<>)
lookupKnownPackage :: PkgPath -> PkgRegistry m -> Maybe (PkgInfo m)
lookupKnownPackage p (PkgRegistry m) = M.lookup p m
diff --git a/src/Futhark/Pkg/Types.hs b/src/Futhark/Pkg/Types.hs
index 8a7712c..29b5cc3 100644
--- a/src/Futhark/Pkg/Types.hs
+++ b/src/Futhark/Pkg/Types.hs
@@ -42,8 +42,6 @@ import Data.List
import Data.Maybe
import Data.Traversable
import Data.Void
-import Data.Semigroup ((<>))
-import qualified Data.Semigroup as Sem
import qualified Data.Text as T
import qualified Data.Text.IO as T
import qualified Data.Map as M
@@ -102,12 +100,11 @@ semver' = SemVer <$> majorP <*> minorP <*> patchP <*> preRel <*> metaData
newtype PkgRevDeps = PkgRevDeps (M.Map PkgPath (SemVer, Maybe T.Text))
deriving (Show)
-instance Sem.Semigroup PkgRevDeps where
+instance Semigroup PkgRevDeps where
PkgRevDeps x <> PkgRevDeps y = PkgRevDeps $ x <> y
instance Monoid PkgRevDeps where
mempty = PkgRevDeps mempty
- mappend = (Sem.<>)
--- Package manifest
diff --git a/src/Futhark/Representation/AST/Attributes/TypeOf.hs b/src/Futhark/Representation/AST/Attributes/TypeOf.hs
index b24ad8b..db4ec08 100644
--- a/src/Futhark/Representation/AST/Attributes/TypeOf.hs
+++ b/src/Futhark/Representation/AST/Attributes/TypeOf.hs
@@ -38,7 +38,6 @@ module Futhark.Representation.AST.Attributes.TypeOf
where
import Data.Maybe
-import Data.Semigroup ((<>))
import Data.Foldable
import qualified Data.Set as S
diff --git a/src/Futhark/Representation/AST/Syntax.hs b/src/Futhark/Representation/AST/Syntax.hs
index ca1ad90..071ca95 100644
--- a/src/Futhark/Representation/AST/Syntax.hs
+++ b/src/Futhark/Representation/AST/Syntax.hs
@@ -69,7 +69,6 @@ module Futhark.Representation.AST.Syntax
import Data.Foldable
import Data.Loc
import qualified Data.Sequence as Seq
-import qualified Data.Semigroup as Sem
import Language.Futhark.Core
import Futhark.Representation.AST.Annotations
@@ -90,12 +89,11 @@ data PatternT attr =
instance Functor PatternT where
fmap f (Pattern ctx val) = Pattern (map (fmap f) ctx) (map (fmap f) val)
-instance Sem.Semigroup (PatternT attr) where
+instance Semigroup (PatternT attr) where
Pattern cs1 vs1 <> Pattern cs2 vs2 = Pattern (cs1++cs2) (vs1++vs2)
instance Monoid (PatternT attr) where
mempty = Pattern [] []
- mappend = (Sem.<>)
-- | A type alias for namespace control.
type Pattern lore = PatternT (LetAttr lore)
@@ -152,14 +150,20 @@ type Body = BodyT
-- | The new dimension in a 'Reshape'-like operation. This allows us to
 -- disambiguate "real" reshapes, which change the actual shape of the
-- array, from type coercions that are just present to make the types
--- work out.
+-- work out. The two constructors are considered equal for purposes of 'Eq'.
data DimChange d = DimCoercion d
-- ^ The new dimension is guaranteed to be numerically
-- equal to the old one.
| DimNew d
-- ^ The new dimension is not necessarily numerically
-- equal to the old one.
- deriving (Eq, Ord, Show)
+ deriving (Ord, Show)
+
+instance Eq d => Eq (DimChange d) where
+ DimCoercion x == DimNew y = x == y
+ DimCoercion x == DimCoercion y = x == y
+ DimNew x == DimCoercion y = x == y
+ DimNew x == DimNew y = x == y
instance Functor DimChange where
fmap f (DimCoercion d) = DimCoercion $ f d
diff --git a/src/Futhark/Representation/AST/Syntax/Core.hs b/src/Futhark/Representation/AST/Syntax/Core.hs
index 1cfadc3..33a6b27 100644
--- a/src/Futhark/Representation/AST/Syntax/Core.hs
+++ b/src/Futhark/Representation/AST/Syntax/Core.hs
@@ -59,7 +59,6 @@ import Data.Monoid ((<>))
import Data.String
import qualified Data.Set as S
import qualified Data.Map.Strict as M
-import qualified Data.Semigroup as Sem
import Data.Traversable
import Language.Futhark.Core
@@ -101,12 +100,11 @@ class (Monoid a, Eq a, Ord a) => ArrayShape a where
 -- | Check whether one shape is a subset of another shape.
subShapeOf :: a -> a -> Bool
-instance Sem.Semigroup (ShapeBase d) where
+instance Semigroup (ShapeBase d) where
Shape l1 <> Shape l2 = Shape $ l1 `mappend` l2
instance Monoid (ShapeBase d) where
mempty = Shape mempty
- mappend = (Sem.<>)
instance Functor ShapeBase where
fmap f = Shape . map f . shapeDims
@@ -135,12 +133,11 @@ instance ArrayShape (ShapeBase ExtSize) where
Nothing -> do put $ M.insert y x extmap
return True
-instance Sem.Semigroup Rank where
+instance Semigroup Rank where
Rank x <> Rank y = Rank $ x + y
instance Monoid Rank where
mempty = Rank 0
- mappend = (Sem.<>)
instance ArrayShape Rank where
shapeRank (Rank x) = x
@@ -214,12 +211,11 @@ instance Ord Ident where
newtype Certificates = Certificates { unCertificates :: [VName] }
deriving (Eq, Ord, Show)
-instance Sem.Semigroup Certificates where
+instance Semigroup Certificates where
Certificates x <> Certificates y = Certificates (x <> y)
instance Monoid Certificates where
mempty = Certificates mempty
- mappend = (Sem.<>)
-- | A subexpression is either a scalar constant or a variable. One
-- important property is that evaluation of a subexpression is
diff --git a/src/Futhark/Representation/Aliases.hs b/src/Futhark/Representation/Aliases.hs
index 6bad757..48c08ca 100644
--- a/src/Futhark/Representation/Aliases.hs
+++ b/src/Futhark/Representation/Aliases.hs
@@ -46,7 +46,6 @@ import Data.Maybe
import Data.Monoid ((<>))
import qualified Data.Map.Strict as M
import qualified Data.Set as S
-import qualified Data.Semigroup as Sem
import Futhark.Representation.AST.Syntax
import Futhark.Representation.AST.Attributes
@@ -68,12 +67,11 @@ data Aliases lore
newtype Names' = Names' { unNames :: Names }
deriving (Show)
-instance Sem.Semigroup Names' where
+instance Semigroup Names' where
x <> y = Names' $ unNames x <> unNames y
instance Monoid Names' where
mempty = Names' mempty
- mappend = (Sem.<>)
instance Eq Names' where
_ == _ = True
diff --git a/src/Futhark/Representation/ExplicitMemory/Simplify.hs b/src/Futhark/Representation/ExplicitMemory/Simplify.hs
index ba1adc6..c88b968 100644
--- a/src/Futhark/Representation/ExplicitMemory/Simplify.hs
+++ b/src/Futhark/Representation/ExplicitMemory/Simplify.hs
@@ -12,7 +12,6 @@ where
import Control.Monad
import qualified Data.Set as S
-import Data.Semigroup ((<>))
import Data.Maybe
import Data.List
diff --git a/src/Futhark/Representation/Kernels/Kernel.hs b/src/Futhark/Representation/Kernels/Kernel.hs
index 6f0bcee..f539307 100644
--- a/src/Futhark/Representation/Kernels/Kernel.hs
+++ b/src/Futhark/Representation/Kernels/Kernel.hs
@@ -80,11 +80,14 @@ data KernelDebugHints =
deriving (Eq, Show, Ord)
data Kernel lore =
- GetSize VName SizeClass -- ^ Produce some runtime-configurable size.
+ GetSize Name SizeClass -- ^ Produce some runtime-configurable size.
| GetSizeMax SizeClass -- ^ The maximum size of some class.
- | CmpSizeLe VName SizeClass SubExp
+ | CmpSizeLe Name SizeClass SubExp
-- ^ Compare size (likely a threshold) with some Int32 value.
| Kernel KernelDebugHints KernelSpace [Type] (KernelBody lore)
+ | SegRed KernelSpace Commutativity (Lambda lore) [SubExp] [Type] (Body lore)
+ -- ^ The KernelSpace must always have at least two dimensions,
+ -- implying that the result of a SegRed is always an array.
deriving (Eq, Show, Ord)
data KernelSpace = KernelSpace { spaceGlobalId :: VName
@@ -189,28 +192,39 @@ mapKernelM _ (GetSizeMax size_class) =
pure $ GetSizeMax size_class
mapKernelM tv (CmpSizeLe name size_class x) =
CmpSizeLe name size_class <$> mapOnKernelSubExp tv x
+mapKernelM tv (SegRed space comm red_op nes ts lam) =
+ SegRed
+ <$> mapOnKernelSpace tv space
+ <*> pure comm
+ <*> mapOnKernelLambda tv red_op
+ <*> mapM (mapOnKernelSubExp tv) nes
+ <*> mapM (mapOnType $ mapOnKernelSubExp tv) ts
+ <*> mapOnKernelBody tv lam
mapKernelM tv (Kernel desc space ts kernel_body) =
Kernel <$> mapOnKernelDebugHints desc <*>
- mapOnKernelSpace space <*>
+ mapOnKernelSpace tv space <*>
mapM (mapOnKernelType tv) ts <*>
mapOnKernelKernelBody tv kernel_body
where mapOnKernelDebugHints (KernelDebugHints name kvs) =
KernelDebugHints name <$>
(zip (map fst kvs) <$> mapM (mapOnKernelSubExp tv . snd) kvs)
- mapOnKernelSpace (KernelSpace gtid ltid gid num_threads num_groups group_size structure) =
- KernelSpace gtid ltid gid -- all in binding position
- <$> mapOnKernelSubExp tv num_threads
- <*> mapOnKernelSubExp tv num_groups
- <*> mapOnKernelSubExp tv group_size
- <*> mapOnKernelStructure structure
- mapOnKernelStructure (FlatThreadSpace dims) =
+
+mapOnKernelSpace :: Monad f =>
+ KernelMapper flore tlore f -> KernelSpace -> f KernelSpace
+mapOnKernelSpace tv (KernelSpace gtid ltid gid num_threads num_groups group_size structure) =
+ KernelSpace gtid ltid gid -- all in binding position
+ <$> mapOnKernelSubExp tv num_threads
+ <*> mapOnKernelSubExp tv num_groups
+ <*> mapOnKernelSubExp tv group_size
+ <*> mapOnKernelStructure structure
+ where mapOnKernelStructure (FlatThreadSpace dims) =
FlatThreadSpace <$> (zip gtids <$> mapM (mapOnKernelSubExp tv) gdim_sizes)
where (gtids, gdim_sizes) = unzip dims
mapOnKernelStructure (NestedThreadSpace dims) =
- NestedThreadSpace <$> (zip4 gtids
- <$> mapM (mapOnKernelSubExp tv) gdim_sizes
- <*> pure ltids
- <*> mapM (mapOnKernelSubExp tv) ldim_sizes)
+ NestedThreadSpace <$> (zip4 gtids
+ <$> mapM (mapOnKernelSubExp tv) gdim_sizes
+ <*> pure ltids
+ <*> mapM (mapOnKernelSubExp tv) ldim_sizes)
where (gtids, gdim_sizes, ltids, ldim_sizes) = unzip4 dims
mapOnKernelType :: Monad m =>
@@ -400,6 +414,13 @@ kernelType (Kernel _ space ts body) =
resultShape t KernelInPlaceReturn{} =
t
+kernelType (SegRed space _ _ nes ts _) =
+ map (`arrayOfShape` Shape outer_dims) red_ts ++
+ map (`arrayOfShape` Shape dims) map_ts
+ where (red_ts, map_ts) = splitAt (length nes) ts
+ dims = map snd $ spaceDimensions space
+ outer_dims = init dims
+
kernelType GetSize{} = [Prim int32]
kernelType GetSizeMax{} = [Prim int32]
kernelType CmpSizeLe{} = [Prim Bool]
@@ -544,6 +565,8 @@ instance Attributes lore => ST.IndexOp (Kernel lore) where
instance Aliased lore => UsageInOp (Kernel lore) where
usageInOp (Kernel _ _ _ kbody) =
mconcat $ map UT.consumedUsage $ S.toList $ consumedInKernelBody kbody
+ usageInOp (SegRed _ _ _ _ _ body) =
+ mconcat $ map UT.consumedUsage $ S.toList $ consumedInBody body
usageInOp GetSize{} = mempty
usageInOp GetSizeMax{} = mempty
usageInOp CmpSizeLe{} = mempty
@@ -559,6 +582,22 @@ typeCheckKernel GetSize{} = return ()
typeCheckKernel GetSizeMax{} = return ()
typeCheckKernel (CmpSizeLe _ _ x) = TC.require [Prim int32] x
+typeCheckKernel (SegRed space _ red_op nes ts body) = do
+ checkSpace space
+ mapM_ TC.checkType ts
+
+ ne_ts <- mapM subExpType nes
+
+ let asArg t = (t, mempty)
+ TC.binding (scopeOfKernelSpace space) $ do
+ TC.checkLambda red_op $ map asArg $ ne_ts ++ ne_ts
+ unless (lambdaReturnType red_op == ne_ts &&
+ take (length nes) ts == ne_ts) $
+ TC.bad $ TC.TypeError
+ "SegRed: wrong type for reduction or neutral elements."
+
+ TC.checkLambdaBody ts body
+
typeCheckKernel (Kernel _ space kts kbody) = do
checkSpace space
mapM_ TC.checkType kts
@@ -566,16 +605,7 @@ typeCheckKernel (Kernel _ space kts kbody) = do
TC.binding (scopeOfKernelSpace space) $
checkKernelBody kts kbody
- where checkSpace (KernelSpace _ _ _ num_threads num_groups group_size structure) = do
- mapM_ (TC.require [Prim int32]) [num_threads,num_groups,group_size]
- case structure of
- FlatThreadSpace dims ->
- mapM_ (TC.require [Prim int32] . snd) dims
- NestedThreadSpace dims ->
- let (_, gdim_sizes, _, ldim_sizes) = unzip4 dims
- in mapM_ (TC.require [Prim int32]) $ gdim_sizes ++ ldim_sizes
-
- checkKernelBody ts (KernelBody (_, attr) stms res) = do
+ where checkKernelBody ts (KernelBody (_, attr) stms res) = do
TC.checkBodyLore attr
TC.checkStms stms $ do
unless (length ts == length res) $
@@ -617,11 +647,23 @@ typeCheckKernel (Kernel _ space kts kbody) = do
mapM_ (TC.requireI [Prim int32] . fst) limit
mapM_ (TC.require [Prim int32] . snd) limit
+checkSpace :: TC.Checkable lore => KernelSpace -> TC.TypeM lore ()
+checkSpace (KernelSpace _ _ _ num_threads num_groups group_size structure) = do
+ mapM_ (TC.require [Prim int32]) [num_threads,num_groups,group_size]
+ case structure of
+ FlatThreadSpace dims ->
+ mapM_ (TC.require [Prim int32] . snd) dims
+ NestedThreadSpace dims ->
+ let (_, gdim_sizes, _, ldim_sizes) = unzip4 dims
+ in mapM_ (TC.require [Prim int32]) $ gdim_sizes ++ ldim_sizes
+
instance OpMetrics (Op lore) => OpMetrics (Kernel lore) where
opMetrics (Kernel _ _ _ kbody) =
inside "Kernel" $ kernelBodyMetrics kbody
where kernelBodyMetrics :: KernelBody lore -> MetricsM ()
kernelBodyMetrics = mapM_ bindingMetrics . kernelBodyStms
+ opMetrics (SegRed _ _ red_op _ _ body) =
+ inside "SegRed" $ lambdaMetrics red_op >> bodyMetrics body
opMetrics GetSize{} = seen "GetSize"
opMetrics GetSizeMax{} = seen "GetSizeMax"
opMetrics CmpSizeLe{} = seen "CmpSizeLe"
@@ -642,6 +684,14 @@ instance PrettyLore lore => PP.Pretty (Kernel lore) where
PP.align (ppr space) <+>
PP.colon <+> ppTuple' ts <+> PP.nestedBlock "{" "}" (ppr body)
+ ppr (SegRed space comm red_op nes ts body) =
+ text name <> PP.parens (ppr red_op <> PP.comma </>
+ PP.braces (PP.commasep $ map ppr nes)) </>
+ PP.align (ppr space) <+> PP.colon <+> ppTuple' ts <+>
+ PP.nestedBlock "{" "}" (ppr body)
+ where name = case comm of Commutative -> "segred_comm"
+ Noncommutative -> "segred"
+
instance Pretty KernelSpace where
ppr (KernelSpace f_gtid f_ltid gid num_threads num_groups group_size structure) =
parens (commasep [text "num groups:" <+> ppr num_groups,
diff --git a/src/Futhark/Representation/Kernels/Simplify.hs b/src/Futhark/Representation/Kernels/Simplify.hs
index c4b9557..4c616ca 100644
--- a/src/Futhark/Representation/Kernels/Simplify.hs
+++ b/src/Futhark/Representation/Kernels/Simplify.hs
@@ -19,7 +19,6 @@ import Data.Either
import Data.Foldable
import Data.List
import Data.Maybe
-import Data.Semigroup ((<>))
import qualified Data.Map.Strict as M
import qualified Data.Set as S
@@ -77,6 +76,37 @@ simplifyKernelOp mk_ops env (Kernel desc space ts kbody) = do
kbody_hoisted' <- mapM processHoistedStm kbody_hoisted
return (Kernel desc space' ts' $ mkWiseKernelBody () kbody_stms kbody_res,
kbody_hoisted')
+ where scope = scopeOfKernelSpace space
+ scope_vtable = ST.fromScope scope
+ bound_here = S.fromList $ M.keys scope
+
+simplifyKernelOp mk_ops env (SegRed space comm red_op nes ts body) = do
+ space' <- Engine.simplify space
+ nes' <- mapM Engine.simplify nes
+ ts' <- mapM Engine.simplify ts
+ outer_vtable <- Engine.askVtable
+
+ (red_op', red_op_hoisted) <-
+ Engine.subSimpleM (mk_ops space) env outer_vtable $
+ Engine.localVtable (<>scope_vtable) $
+ Engine.simplifyLambda red_op $ replicate (length nes * 2) Nothing
+ red_op_hoisted' <- mapM processHoistedStm red_op_hoisted
+
+ ((body_stms, body_res), body_hoisted) <-
+ Engine.subSimpleM (mk_ops space) env outer_vtable $ do
+ par_blocker <- Engine.asksEngineEnv $ Engine.blockHoistPar . Engine.envHoistBlockers
+ Engine.localVtable (<>scope_vtable) $
+ Engine.blockIf (Engine.hasFree bound_here
+ `Engine.orIf` Engine.isOp
+ `Engine.orIf` par_blocker
+ `Engine.orIf` Engine.isConsumed) $
+ Engine.simplifyBody (replicate (length ts) Observe) body
+ body_hoisted' <- mapM processHoistedStm body_hoisted
+
+ return (SegRed space' comm red_op' nes' ts' $
+ mkWiseBody () body_stms body_res,
+ red_op_hoisted' <> body_hoisted')
+
where scope_vtable = ST.fromScope scope
scope = scopeOfKernelSpace space
bound_here = S.fromList $ M.keys scope
diff --git a/src/Futhark/Representation/Kernels/Sizes.hs b/src/Futhark/Representation/Kernels/Sizes.hs
index 3a2b1c1..6d96167 100644
--- a/src/Futhark/Representation/Kernels/Sizes.hs
+++ b/src/Futhark/Representation/Kernels/Sizes.hs
@@ -3,12 +3,12 @@ module Futhark.Representation.Kernels.Sizes
where
import Futhark.Util.Pretty
-import Language.Futhark.Core (VName)
+import Language.Futhark.Core (Name)
import Futhark.Representation.AST.Pretty ()
-- | An indication of which comparisons have been performed to get to
-- this point, as well as the result of each comparison.
-type KernelPath = [(VName, Bool)]
+type KernelPath = [(Name, Bool)]
-- | The class of some kind of configurable size. Each class may
-- impose constraints on the valid values.
diff --git a/src/Futhark/Representation/SOACS/Simplify.hs b/src/Futhark/Representation/SOACS/Simplify.hs
index 1771bcd..03a456f 100644
--- a/src/Futhark/Representation/SOACS/Simplify.hs
+++ b/src/Futhark/Representation/SOACS/Simplify.hs
@@ -18,7 +18,6 @@ import Data.Foldable
import Data.Either
import Data.List
import Data.Maybe
-import Data.Semigroup ((<>))
import qualified Data.Map.Strict as M
import qualified Data.Set as S
diff --git a/src/Futhark/Test.hs b/src/Futhark/Test.hs
index fb9cae0..265ed3b 100644
--- a/src/Futhark/Test.hs
+++ b/src/Futhark/Test.hs
@@ -1,4 +1,6 @@
{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE TupleSections #-}
-- | Facilities for reading Futhark test programs. A Futhark test
-- program is an ordinary Futhark program where an initial comment
-- block specifies input- and output-sets.
@@ -9,6 +11,12 @@ module Futhark.Test
, getValues
, getValuesBS
, compareValues
+ , compareValues1
+ , testRunReferenceOutput
+ , getExpectedResult
+ , compileProgram
+ , runProgram
+ , ensureReferenceOutput
, Mismatch
, ProgramTest (..)
@@ -20,6 +28,7 @@ module Futhark.Test
, InputOutputs (..)
, TestRun (..)
, ExpectedResult (..)
+ , Success(..)
, Values (..)
, GenValue (..)
, Value
@@ -31,14 +40,13 @@ import qualified Data.ByteString.Lazy as BS
import qualified Data.ByteString as SBS
import Control.Exception (catch)
import Control.Monad
-import Control.Monad.IO.Class
+import Control.Monad.Except
import qualified Data.Map.Strict as M
import Data.Char
import Data.Functor
import Data.Maybe
import Data.Foldable (foldl')
import Data.List
-import Data.Semigroup
import qualified Data.Text as T
import qualified Data.Text.IO as T
import qualified Data.Text.Encoding as T
@@ -54,15 +62,16 @@ import Text.Regex.TDFA
import System.Directory
import System.Exit
import System.Process.ByteString (readProcessWithExitCode)
-import System.IO (withFile, IOMode(..), hFileSize)
+import System.IO (withFile, IOMode(..), hFileSize, hClose)
import System.IO.Error
+import System.IO.Temp
import Prelude
import Futhark.Analysis.Metrics
import Futhark.Representation.Primitive (IntType(..), FloatType(..), intByteSize, floatByteSize)
import Futhark.Test.Values
-import Futhark.Util (directoryContents)
+import Futhark.Util (directoryContents, pmapIO)
import Futhark.Util.Pretty (pretty, prettyText)
import Language.Futhark.Syntax (PrimType(..), Int32)
@@ -120,7 +129,7 @@ instance Show WarningTest where
data TestRun = TestRun
{ runTags :: [String]
, runInput :: Values
- , runExpectedResult :: ExpectedResult Values
+ , runExpectedResult :: ExpectedResult Success
, runIndex :: Int
, runDescription :: String
}
@@ -155,6 +164,14 @@ data ExpectedResult values
| RunTimeFailure ExpectedError -- ^ Execution fails with this error.
deriving (Show)
+-- | The result expected from a successful execution.
+data Success = SuccessValues Values
+ -- ^ These values are expected.
+ | SuccessGenerateValues
+ -- ^ Compute expected values from executing a known-good
+ -- reference implementation.
+ deriving (Show)
+
type Parser = Parsec Void T.Text
lexeme :: Parser a -> Parser a
@@ -221,7 +238,9 @@ parseRunCases = parseRunCases' (0::Int)
parseRunCase i = do
tags <- parseRunTags
lexstr "input"
- input <- if "random" `elem` tags then parseRandomValues else parseValues
+ input <- if "random" `elem` tags
+ then parseRandomValues
+ else parseValues
expr <- parseExpectedResult
return $ TestRun tags input expr i $ desc i input
@@ -242,9 +261,10 @@ parseRunCases = parseRunCases' (0::Int)
desc _ (GenValues gens) =
unwords $ map genValueType gens
-parseExpectedResult :: Parser (ExpectedResult Values)
+parseExpectedResult :: Parser (ExpectedResult Success)
parseExpectedResult =
- (Succeeds . Just <$> (lexstr "output" *> parseValues)) <|>
+ (lexstr "auto" *> lexstr "output" $> Succeeds (Just SuccessGenerateValues)) <|>
+ (Succeeds . Just . SuccessValues <$> (lexstr "output" *> parseValues)) <|>
(RunTimeFailure <$> (lexstr "error:" *> parseExpectedError)) <|>
pure (Succeeds Nothing)
@@ -461,16 +481,29 @@ getValuesBS dir (InFile file) =
getValuesBS dir (GenValues gens) =
mconcat <$> mapM (getGenBS dir) gens
+-- | There is a risk of race conditions when multiple programs have
+-- identical 'GenValues'. In such cases, multiple threads in 'futhark
+-- test' might attempt to create the same file (or read from it, while
+-- something else is constructing it). This leads to a mess. To
+-- avoid this, we create a temporary file, and only when it is
+-- complete do we move it into place. It would be better if we could
+-- use file locking, but that does not work on some file systems. The
+-- approach here seems robust enough for now, but certainly it could
+-- be made even better. The race condition that remains should mostly
+-- result in duplicate work, not crashes or data corruption.
getGenBS :: MonadIO m => FilePath -> GenValue -> m BS.ByteString
getGenBS dir gen = do
+ liftIO $ createDirectoryIfMissing True $ dir </> "data"
exists_and_proper_size <- liftIO $
withFile (dir </> file) ReadMode (fmap (== genFileSize gen) . hFileSize)
`catch` \ex -> if isDoesNotExistError ex then return False
else E.throw ex
unless exists_and_proper_size $ liftIO $ do
s <- genValues [gen]
- createDirectoryIfMissing True $ takeDirectory $ dir </> file
- SBS.writeFile (dir </> file) s
+ withTempFile (dir </> "data") (genFileName gen) $ \tmpfile h -> do
+ hClose h -- We will be writing and reading this ourselves.
+ SBS.writeFile tmpfile s
+ renameFile tmpfile $ dir </> file
getValuesBS dir $ InFile file
where file = "data" </> genFileName gen
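
Distilled to a standalone form, the write-then-rename idea described above amounts to the following sketch (the function name is hypothetical; it mirrors, rather than replaces, the code in getGenBS):

import qualified Data.ByteString as SBS
import System.Directory (renameFile)
import System.IO (hClose)
import System.IO.Temp (withTempFile)

-- Write the complete payload to a temporary file in the target directory and
-- only then rename it to its final name, so concurrent readers never observe
-- a partially written file.
writeViaTempFile :: FilePath -> FilePath -> SBS.ByteString -> IO ()
writeViaTempFile dir final payload =
  withTempFile dir "partial.tmp" $ \tmp h -> do
    hClose h                  -- the file is written below, not via the handle
    SBS.writeFile tmp payload
    renameFile tmp final
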
@@ -501,3 +534,94 @@ genFileSize = genSize
primSize (Unsigned it) = intByteSize it
primSize (FloatType ft) = floatByteSize ft
primSize Bool = 1
+
+-- | When/if generating a reference output file for this run, what
+-- should it be called? Includes the "data/" folder.
+testRunReferenceOutput :: FilePath -> T.Text -> TestRun -> FilePath
+testRunReferenceOutput prog entry tr =
+ "data"
+ </> takeBaseName prog
+ <> ":" <> T.unpack entry
+ <> "-" <> map clean (runDescription tr)
+ <.> "out"
+ where clean '/' = '_' -- Would this ever happen?
+ clean ' ' = '_'
+ clean c = c
+
+-- | Get the values corresponding to an expected result, if any.
+getExpectedResult :: MonadIO m =>
+ FilePath -> T.Text -> TestRun
+ -> m (ExpectedResult [Value])
+getExpectedResult prog entry tr =
+ case runExpectedResult tr of
+ (Succeeds (Just (SuccessValues vals))) ->
+ Succeeds . Just <$> getValues (takeDirectory prog) vals
+ Succeeds (Just SuccessGenerateValues) ->
+ getExpectedResult prog entry
+ tr { runExpectedResult = Succeeds $ Just $ SuccessValues $ InFile $
+ testRunReferenceOutput prog entry tr }
+ Succeeds Nothing ->
+ return $ Succeeds Nothing
+ RunTimeFailure err ->
+ return $ RunTimeFailure err
+
+compileProgram :: (MonadIO m, MonadError [T.Text] m) =>
+ [String] -> FilePath -> String -> FilePath
+ -> m (SBS.ByteString, SBS.ByteString)
+compileProgram extra_options futhark backend program = do
+ (futcode, stdout, stderr) <- liftIO $ readProcessWithExitCode futhark (backend:options) ""
+ case futcode of
+ ExitFailure 127 -> throwError [progNotFound $ T.pack futhark]
+ ExitFailure _ -> throwError [T.decodeUtf8 stderr]
+ ExitSuccess -> return ()
+ return (stdout, stderr)
+ where binOutputf = dropExtension program
+ options = [program, "-o", binOutputf] ++ extra_options
+ progNotFound s = s <> ": command not found"
+
+runProgram :: MonadIO m =>
+ String -> [String]
+ -> String -> T.Text -> Values
+ -> m (ExitCode, SBS.ByteString, SBS.ByteString)
+runProgram runner extra_options prog entry input = do
+ let progbin = dropExtension prog
+ dir = takeDirectory prog
+ binpath = "." </> progbin
+ entry_options = ["-e", T.unpack entry]
+
+ (to_run, to_run_args)
+ | null runner = (binpath, entry_options ++ extra_options)
+ | otherwise = (runner, binpath : entry_options ++ extra_options)
+
+ input' <- getValuesBS dir input
+ liftIO $ readProcessWithExitCode to_run to_run_args $ BS.toStrict input'
+
+-- | Ensure that any reference output files exist, or create them (by
+-- compiling the program with the reference compiler and running it on
+-- the input) if necessary.
+ensureReferenceOutput :: (MonadIO m, MonadError [T.Text] m) =>
+ FilePath -> String -> FilePath -> [InputOutputs]
+ -> m ()
+ensureReferenceOutput futhark compiler prog ios = do
+ missing <- filterM isReferenceMissing $ concatMap entryAndRuns ios
+ unless (null missing) $ do
+ void $ compileProgram [] futhark compiler prog
+ liftIO $ void $ flip pmapIO missing $ \(entry, tr) -> do
+ (code, stdout, stderr) <- runProgram "" ["-b"] prog entry $ runInput tr
+ case code of
+ ExitFailure e ->
+ fail $ "Reference dataset generation failed with exit code " ++
+ show e ++ " and stderr:\n" ++
+ map (chr . fromIntegral) (SBS.unpack stderr)
+ ExitSuccess ->
+ SBS.writeFile (file (entry, tr)) stdout
+ where file (entry, tr) =
+ takeDirectory prog </> testRunReferenceOutput prog entry tr
+
+ entryAndRuns (InputOutputs entry rts) = map (entry,) rts
+
+ isReferenceMissing (entry, tr)
+ | Succeeds (Just SuccessGenerateValues) <- runExpectedResult tr =
+ liftIO . fmap not . doesFileExist . file $ (entry, tr)
+ | otherwise =
+ return False
diff --git a/src/Futhark/Test/Values.hs b/src/Futhark/Test/Values.hs
index 0f5fad5..dd2d399 100644
--- a/src/Futhark/Test/Values.hs
+++ b/src/Futhark/Test/Values.hs
@@ -17,8 +17,8 @@ module Futhark.Test.Values
-- * Comparing Values
, compareValues
+ , compareValues1
, Mismatch
- , explainMismatch
)
where
@@ -29,7 +29,6 @@ import Data.Binary.Put
import Data.Binary.Get
import Data.Binary.IEEE754
import qualified Data.ByteString.Lazy.Char8 as BS
-import Data.Maybe
import Data.Int (Int8, Int16, Int32, Int64)
import Data.Char (isSpace, ord, chr)
import Data.Vector.Binary
@@ -46,6 +45,7 @@ import qualified Futhark.Util.Pretty as PP
import Futhark.Representation.AST.Attributes.Constants (IsValue(..))
import Futhark.Representation.AST.Pretty ()
import Futhark.Util.Pretty
+import Futhark.Util (maybeHead)
type STVector s = UMVec.STVector s
type Vector = UVec.Vector
@@ -452,7 +452,8 @@ readValues = readValues' . dropSpaces
-- Comparisons
--- | Two values differ in some way.
+-- | Two values differ in some way. The 'Show' instance produces a
+-- human-readable explanation.
data Mismatch = PrimValueMismatch (Int,Int) PrimValue PrimValue
 -- ^ The position: the value number and a flat index
-- into the array.
@@ -477,17 +478,18 @@ explainMismatch i what got expected =
-- | Compare two sets of Futhark values for equality. Shapes and
-- types must also match.
-compareValues :: [Value] -> [Value] -> Maybe [Mismatch]
+compareValues :: [Value] -> [Value] -> [Mismatch]
compareValues got expected
- | n /= m = Just [ValueCountMismatch n m]
- | otherwise = case catMaybes $ zipWith3 compareValue [0..] got expected of
- [] -> Nothing
- es -> Just es
+ | n /= m = [ValueCountMismatch n m]
+ | otherwise = concat $ zipWith3 compareValue [0..] got expected
where n = length got
m = length expected
+-- | As 'compareValues', but only reports one mismatch.
+compareValues1 :: [Value] -> [Value] -> Maybe Mismatch
+compareValues1 got expected = maybeHead $ compareValues got expected
-compareValue :: Int -> Value -> Value -> Maybe Mismatch
+compareValue :: Int -> Value -> Value -> [Mismatch]
compareValue i got_v expected_v
| valueShape got_v == valueShape expected_v =
case (got_v, expected_v) of
@@ -514,29 +516,29 @@ compareValue i got_v expected_v
(BoolValue _ got_vs, BoolValue _ expected_vs) ->
compareGen compareBool got_vs expected_vs
_ ->
- Just $ TypeMismatch i (pretty $ valueElemType got_v) (pretty $ valueElemType expected_v)
+ [TypeMismatch i (pretty $ valueElemType got_v) (pretty $ valueElemType expected_v)]
| otherwise =
- Just $ ArrayShapeMismatch i (valueShape got_v) (valueShape expected_v)
+ [ArrayShapeMismatch i (valueShape got_v) (valueShape expected_v)]
where compareNum tol = compareGen $ compareElement tol
compareFloat tol = compareGen $ compareFloatElement tol
compareGen cmp got expected =
- foldl mplus Nothing $
+ concat $
zipWith cmp (UVec.toList $ UVec.indexed got) (UVec.toList expected)
compareElement tol (j, got) expected
- | comparePrimValue tol got expected = Nothing
- | otherwise = Just $ PrimValueMismatch (i,j) (value got) (value expected)
+ | comparePrimValue tol got expected = []
+ | otherwise = [PrimValueMismatch (i,j) (value got) (value expected)]
compareFloatElement tol (j, got) expected
- | isNaN got, isNaN expected = Nothing
+ | isNaN got, isNaN expected = []
| isInfinite got, isInfinite expected,
- signum got == signum expected = Nothing
+ signum got == signum expected = []
| otherwise = compareElement tol (j, got) expected
compareBool (j, got) expected
- | got == expected = Nothing
- | otherwise = Just $ PrimValueMismatch (i,j) (value got) (value expected)
+ | got == expected = []
+ | otherwise = [PrimValueMismatch (i,j) (value got) (value expected)]
comparePrimValue :: (Ord num, Num num) =>
num -> num -> num -> Bool
diff --git a/src/Futhark/Tools.hs b/src/Futhark/Tools.hs
index 23a3423..226ece1 100644
--- a/src/Futhark/Tools.hs
+++ b/src/Futhark/Tools.hs
@@ -20,7 +20,6 @@ module Futhark.Tools
where
import Control.Monad.Identity
-import Data.Semigroup ((<>))
import Futhark.Representation.AST
import Futhark.Representation.SOACS.SOAC
diff --git a/src/Futhark/Transform/FirstOrderTransform.hs b/src/Futhark/Transform/FirstOrderTransform.hs
index e48336a..b0bb7c6 100644
--- a/src/Futhark/Transform/FirstOrderTransform.hs
+++ b/src/Futhark/Transform/FirstOrderTransform.hs
@@ -23,7 +23,6 @@ module Futhark.Transform.FirstOrderTransform
import Control.Monad.Except
import Control.Monad.State
-import Data.Semigroup ((<>))
import qualified Data.Map.Strict as M
import qualified Data.Set as S
diff --git a/src/Futhark/Transform/Rename.hs b/src/Futhark/Transform/Rename.hs
index eea5958..67b3963 100644
--- a/src/Futhark/Transform/Rename.hs
+++ b/src/Futhark/Transform/Rename.hs
@@ -38,7 +38,6 @@ import Control.Monad.Reader
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Data.Maybe
-import Data.Semigroup ((<>))
import Futhark.Representation.AST.Syntax
import Futhark.Representation.AST.Traversals
diff --git a/src/Futhark/TypeCheck.hs b/src/Futhark/TypeCheck.hs
index ab653e7..3f68dee 100644
--- a/src/Futhark/TypeCheck.hs
+++ b/