summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPatrickBahr <>2020-08-10 15:18:00 (GMT)
committerhdiff <hdiff@hdiff.luite.com>2020-08-10 15:18:00 (GMT)
commit4e4effaff27ef38f29731eb835ff48879017578f (patch)
tree3df0974fcabbd6dc10f8ea9db4d4e4823dfa52a0
parent38ebeba655f97f40f7447058a05d5dd84e7abbe7 (diff)
version 0.30.3
-rwxr-xr-xCHANGELOG.md8
-rw-r--r--Rattus.cabal21
-rw-r--r--src/Rattus.hs11
-rw-r--r--src/Rattus/Plugin.hs66
-rw-r--r--src/Rattus/Plugin/Annotation.hs30
-rw-r--r--src/Rattus/Plugin/Dependency.hs334
-rw-r--r--src/Rattus/Plugin/ScopeCheck.hs842
-rw-r--r--src/Rattus/Plugin/StableSolver.hs7
-rw-r--r--src/Rattus/Plugin/Utils.hs34
-rw-r--r--src/Rattus/Stream.hs12
-rw-r--r--test/IllTyped.hs20
-rw-r--r--test/TimeLeak.hs1
-rw-r--r--test/WellTyped.hs29
13 files changed, 1086 insertions, 329 deletions
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 24d4cdd..f176842 100755
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,3 +1,11 @@
+0.3
+---
+
+Rattus code is now checked just after GHC's type checking phase
+(instead of after desugaring to Core). As a consequence, error
+messages for some corner cases are much improved and we don't need
+to use the -g2 compiler option anymore to get good error messages.
+
0.2
---
diff --git a/Rattus.cabal b/Rattus.cabal
index 860142c..5179205 100644
--- a/Rattus.cabal
+++ b/Rattus.cabal
@@ -1,6 +1,6 @@
cabal-version: 1.18
name: Rattus
-version: 0.2
+version: 0.3
category: FRP
synopsis: A modal FRP language
description:
@@ -68,8 +68,7 @@ description:
non-standard typing rules. To write Rattus programs, one
must enable this plugin via the GHC option
@-fplugin=Rattus.Plugin@, e.g. by including the following
- line in the source file (for better error messages we also
- suggest using the option @-g2@):
+ line in the source file:
.
@@ -149,13 +148,15 @@ library
Rattus.Primitives
other-modules: Rattus.Plugin.ScopeCheck
+ Rattus.Plugin.Annotation
Rattus.Plugin.Strictify
Rattus.Plugin.Utils
+ Rattus.Plugin.Dependency
Rattus.Plugin.StableSolver
- build-depends: base >=4.12 && <5, containers, simple-affine-space, ghc >= 8.6
+ build-depends: base >=4.12 && <5, containers, simple-affine-space, ghc >= 8.6 && < 8.11
hs-source-dirs: src
default-language: Haskell2010
- ghc-options: -W -g2
+ ghc-options: -W
Test-Suite memory-leak
type: exitcode-stdio-1.0
@@ -163,7 +164,7 @@ Test-Suite memory-leak
hs-source-dirs: test
default-language: Haskell2010
build-depends: Rattus, base
- ghc-options: -fplugin=Rattus.Plugin -rtsopts -g2
+ ghc-options: -fplugin=Rattus.Plugin -rtsopts
Test-Suite time-leak
@@ -172,14 +173,14 @@ Test-Suite time-leak
hs-source-dirs: test
default-language: Haskell2010
build-depends: Rattus, base
- ghc-options: -fplugin=Rattus.Plugin -rtsopts -g2
+ ghc-options: -fplugin=Rattus.Plugin -rtsopts
Test-Suite ill-typed
type: exitcode-stdio-1.0
main-is: test/IllTyped.hs
default-language: Haskell2010
build-depends: Rattus, base
- ghc-options: -fplugin=Rattus.Plugin -rtsopts -g2
+ ghc-options: -fplugin=Rattus.Plugin -rtsopts
Test-Suite well-typed
@@ -188,7 +189,7 @@ Test-Suite well-typed
hs-source-dirs: test
default-language: Haskell2010
build-depends: Rattus, base, containers
- ghc-options: -fplugin=Rattus.Plugin -rtsopts -g2
+ ghc-options: -fplugin=Rattus.Plugin -rtsopts
Test-Suite rewrite
@@ -197,5 +198,5 @@ Test-Suite rewrite
hs-source-dirs: test
default-language: Haskell2010
build-depends: Rattus, base, containers
- ghc-options: -fplugin=Rattus.Plugin -rtsopts -g2
+ ghc-options: -fplugin=Rattus.Plugin -rtsopts
diff --git a/src/Rattus.hs b/src/Rattus.hs
index 5886c1e..037f4dd 100644
--- a/src/Rattus.hs
+++ b/src/Rattus.hs
@@ -16,7 +16,10 @@ module Rattus (
(|*|),
(|**),
(<*>),
- (<**))
+ (<**),
+ -- * box for stable types
+ box'
+ )
where
import Rattus.Plugin
@@ -49,3 +52,9 @@ f |*| x = box (unbox f (unbox x))
{-# INLINE (|**) #-}
(|**) :: Stable a => Box (a -> b) -> a -> Box b
f |** x = box (unbox f x)
+
+
+-- | Variant of 'box' for stable types that can be safely used nested
+-- in recursive definitions or in another box.
+box' :: Stable a => a -> Box a
+box' x = box x
diff --git a/src/Rattus/Plugin.hs b/src/Rattus/Plugin.hs
index 83b91b5..78d6aaa 100644
--- a/src/Rattus/Plugin.hs
+++ b/src/Rattus/Plugin.hs
@@ -9,43 +9,19 @@ import Rattus.Plugin.StableSolver
import Rattus.Plugin.ScopeCheck
import Rattus.Plugin.Strictify
import Rattus.Plugin.Utils
+import Rattus.Plugin.Annotation
import Prelude hiding ((<>))
import GhcPlugins
+import TcRnTypes
+
-import qualified Data.Set as Set
import Control.Monad
import Data.Maybe
import Data.Data hiding (tyConName)
-import System.Exit
--- | Use this type to mark a Haskell function definition as a Rattus
--- function:
---
--- > {-# ANN myFunction Rattus #-}
---
--- Or mark a whole module as consisting of Rattus functions only:
---
--- > {-# ANN module Rattus #-}
---
--- If you use the latter option, you can mark exceptions
--- (i.e. functions that should be treated as ordinary Haskell function
--- definitions) as follows:
---
--- > {-# ANN myFunction NotRattus #-}
---
--- By default all Rattus functions are checked for use of lazy data
--- types, since these may cause memory leaks. If any lazy data types
--- are used, a warning is issued. These warnings can be disabled by
--- annotating the module or the function with 'AllowLazyData'
---
--- > {-# ANN myFunction AllowLazyData #-}
--- >
--- > {-# ANN module AllowLazyData #-}
-
-data Rattus = Rattus | NotRattus | AllowLazyData deriving (Typeable, Data, Show, Eq)
-- | Use this to enable Rattus' plugin, either by supplying the option
-- @-fplugin=Rattus.Plugin@ directly to GHC. or by including the
@@ -55,17 +31,18 @@ data Rattus = Rattus | NotRattus | AllowLazyData deriving (Typeable, Data, Show,
plugin :: Plugin
plugin = defaultPlugin {
installCoreToDos = install,
- tcPlugin = tcStable,
- pluginRecompile = purePlugin
+ pluginRecompile = purePlugin,
+ typeCheckResultAction = typechecked,
+ tcPlugin = tcStable
}
+typechecked :: [CommandLineOption] -> ModSummary -> TcGblEnv -> TcM TcGblEnv
+typechecked _ _ env = checkAll env >> return env
install :: [CommandLineOption] -> [CoreToDo] -> CoreM [CoreToDo]
-install _ todo = return (scPass : strPass : todo)
- where scPass = CoreDoPluginPass "Rattus scopecheck" scopecheckProgram
-
- strPass = CoreDoPluginPass "Rattus strictify" strictifyProgram
+install _ todo = return (strPass : todo)
+ where strPass = CoreDoPluginPass "Rattus strictify" strictifyProgram
strictifyProgram :: ModGuts -> CoreM ModGuts
strictifyProgram guts = do
@@ -90,29 +67,6 @@ strictify guts b@(NonRec v e) = do
return (NonRec v e')
else return b
-
-scopecheckProgram :: ModGuts -> CoreM ModGuts
-scopecheckProgram guts = do
- res <- mapM (scopecheck guts) (mg_binds guts)
- if and res then return guts else liftIO exitFailure
-
-
-scopecheck :: ModGuts -> CoreBind -> CoreM Bool
-scopecheck guts (Rec bs) = do
- tr <- liftM or (mapM (shouldTransform guts . fst) bs)
- if tr then do
- let vs = map fst bs
- let vs' = Set.fromList vs
- valid <- mapM (\ (v,e) -> checkExpr (emptyCtx (Just vs') v) e) bs
- return (and valid)
- else return True
-scopecheck guts (NonRec v e) = do
- tr <- shouldTransform guts v
- if tr then do
- valid <- checkExpr (emptyCtx Nothing v) e
- return valid
- else return True
-
getModuleAnnotations :: Data a => ModGuts -> [a]
getModuleAnnotations guts = anns'
where anns = filter (\a-> case ann_target a of
diff --git a/src/Rattus/Plugin/Annotation.hs b/src/Rattus/Plugin/Annotation.hs
new file mode 100644
index 0000000..29941c2
--- /dev/null
+++ b/src/Rattus/Plugin/Annotation.hs
@@ -0,0 +1,30 @@
+{-# LANGUAGE DeriveDataTypeable #-}
+module Rattus.Plugin.Annotation (Rattus(..)) where
+
+import Data.Data
+
+-- | Use this type to mark a Haskell function definition as a Rattus
+-- function:
+--
+-- > {-# ANN myFunction Rattus #-}
+--
+-- Or mark a whole module as consisting of Rattus functions only:
+--
+-- > {-# ANN module Rattus #-}
+--
+-- If you use the latter option, you can mark exceptions
+-- (i.e. functions that should be treated as ordinary Haskell function
+-- definitions) as follows:
+--
+-- > {-# ANN myFunction NotRattus #-}
+--
+-- By default all Rattus functions are checked for use of lazy data
+-- types, since these may cause memory leaks. If any lazy data types
+-- are used, a warning is issued. These warnings can be disabled by
+-- annotating the module or the function with 'AllowLazyData'
+--
+-- > {-# ANN myFunction AllowLazyData #-}
+-- >
+-- > {-# ANN module AllowLazyData #-}
+
+data Rattus = Rattus | NotRattus | AllowLazyData deriving (Typeable, Data, Show, Eq)
diff --git a/src/Rattus/Plugin/Dependency.hs b/src/Rattus/Plugin/Dependency.hs
new file mode 100644
index 0000000..968b3a0
--- /dev/null
+++ b/src/Rattus/Plugin/Dependency.hs
@@ -0,0 +1,334 @@
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE CPP #-}
+
+-- | This module is used to perform a dependency analysis of top-level
+-- function definitions, i.e. to find out which defintions are
+-- (mutual) recursive. To this end, this module also provides a
+-- functions to compute, bound variables and variable occurrences.
+
+module Rattus.Plugin.Dependency (dependency, HasBV (..)) where
+
+
+import GhcPlugins
+import Bag
+
+
+#if __GLASGOW_HASKELL__ >= 810
+import GHC.Hs.Extension
+import GHC.Hs.Expr
+import GHC.Hs.Pat
+import GHC.Hs.Binds
+import GHC.Hs.Types
+#else
+import HsExtension
+import HsExpr
+import HsPat
+import HsBinds
+import HsTypes
+#endif
+
+import Data.Set (Set)
+import qualified Data.Set as Set
+import Data.Graph
+import Data.Maybe
+import Data.Either
+import Prelude hiding ((<>))
+
+
+
+
+
+-- | Compute the dependencies of a bag of bindings, returning a list
+-- of the strongly-connected components.
+dependency :: Bag (LHsBindLR GhcTc GhcTc) -> [SCC (LHsBindLR GhcTc GhcTc, Set Var)]
+dependency binds = map AcyclicSCC noDeps ++ catMaybes (map filterJust (stronglyConnComp (concat deps)))
+ where (deps,noDeps) = partitionEithers $ map mkDep $ bagToList binds
+ mkDep :: GenLocated l (HsBindLR GhcTc GhcTc) ->
+ Either [(Maybe (GenLocated l (HsBindLR GhcTc GhcTc), Set Var), Name, [Name])]
+ (GenLocated l (HsBindLR GhcTc GhcTc), Set Var)
+ mkDep b =
+ let dep = map varName $ Set.toList (getFV b)
+ vars = getBV b in
+ case Set.toList vars of
+ (v:vs) -> Left ((Just (b,vars), varName v , dep) : map (\ v' -> (Nothing, varName v' , dep)) vs)
+ [] -> Right (b,vars)
+ filterJust (AcyclicSCC Nothing) = Nothing -- this should not happen
+ filterJust (AcyclicSCC (Just b)) = Just (AcyclicSCC b)
+ filterJust (CyclicSCC bs) = Just (CyclicSCC (catMaybes bs))
+
+
+-- printBinds (AcyclicSCC bind) = liftIO (putStr "acyclic bind: ") >> printBind (fst bind) >> liftIO (putStrLn "")
+-- printBinds (CyclicSCC binds) = liftIO (putStr "cyclic binds: ") >> mapM_ (printBind . fst) binds >> liftIO (putStrLn "")
+
+
+-- printBind (L _ FunBind{fun_id = L _ name}) =
+-- liftIO $ putStr $ (getOccString name ++ " ")
+-- printBind (L _ (AbsBinds {abs_exports = exp})) =
+-- mapM_ (\ e -> liftIO $ putStr $ ((getOccString $ abe_poly e) ++ " ")) exp
+-- printBind (L _ (VarBind {var_id = name})) = liftIO $ putStr $ (getOccString name ++ " ")
+-- printBind _ = return ()
+
+-- | Computes the variables that are bound by a given piece of syntax.
+
+class HasBV a where
+ getBV :: a -> Set Var
+
+instance HasBV (HsBindLR GhcTc GhcTc) where
+ getBV (FunBind{fun_id = L _ v}) = Set.singleton v
+ getBV (AbsBinds {abs_exports = es}) = Set.fromList (map abe_poly es)
+ getBV (PatBind {pat_lhs = pat}) = getBV pat
+ getBV (VarBind {var_id = v}) = Set.singleton v
+ getBV PatSynBind{} = Set.empty
+ getBV XHsBindsLR{} = Set.empty
+
+instance HasBV a => HasBV (GenLocated b a) where
+ getBV (L _ e) = getBV e
+
+instance HasBV a => HasBV [a] where
+ getBV ps = foldl (\s p -> getBV p `Set.union` s) Set.empty ps
+
+
+
+getConBV (PrefixCon ps) = getBV ps
+getConBV (InfixCon p p') = getBV p `Set.union` getBV p'
+getConBV (RecCon (HsRecFields {rec_flds = fs})) = foldl run Set.empty fs
+ where run s (L _ f) = getBV (hsRecFieldArg f) `Set.union` s
+
+instance HasBV (Pat GhcTc) where
+ getBV (VarPat _ (L _ v)) = Set.singleton v
+ getBV (LazyPat _ p) = getBV p
+ getBV (AsPat _ (L _ v) p) = Set.insert v (getBV p)
+ getBV (ParPat _ p) = getBV p
+ getBV (BangPat _ p) = getBV p
+ getBV (ListPat _ ps) = getBV ps
+ getBV (TuplePat _ ps _) = getBV ps
+ getBV (SumPat _ p _ _) = getBV p
+ getBV (ConPatIn (L _ v) con) = Set.insert v (getConBV con)
+ getBV (ConPatOut {pat_args = con}) = getConBV con
+ getBV (ViewPat _ _ p) = getBV p
+ getBV (SplicePat _ sp) =
+ case sp of
+ HsTypedSplice _ _ v _ -> Set.singleton v
+ HsUntypedSplice _ _ v _ -> Set.singleton v
+ HsQuasiQuote _ p p' _ _ -> Set.fromList [p,p']
+ HsSpliced _ _ (HsSplicedPat p) -> getBV p
+ _ -> Set.empty
+ getBV (NPlusKPat _ (L _ v) _ _ _ _) = Set.singleton v
+ getBV (CoPat _ _ p _) = getBV p
+ getBV (NPat {}) = Set.empty
+ getBV (XPat p) = getBV p
+ getBV (WildPat {}) = Set.empty
+ getBV (LitPat {}) = Set.empty
+
+#if __GLASGOW_HASKELL__ >= 808
+ getBV (SigPat _ p _) =
+#else
+ getBV (SigPat _ p) =
+#endif
+ getBV p
+
+#if __GLASGOW_HASKELL__ >= 810
+instance HasBV NoExtCon where
+#else
+instance HasBV NoExt where
+#endif
+ getBV _ = Set.empty
+
+
+-- | Syntax that may contain variables.
+class HasFV a where
+ -- | Compute the set of variables occurring in the given piece of
+ -- syntax. The name falsely suggests that returns free variables,
+ -- but in fact it returns all variable occurrences, no matter
+ -- whether they are free or bound.
+ getFV :: a -> Set Var
+
+instance HasFV a => HasFV (GenLocated b a) where
+ getFV (L _ e) = getFV e
+
+instance HasFV a => HasFV [a] where
+ getFV es = foldMap getFV es
+
+instance HasFV a => HasFV (Bag a) where
+ getFV es = foldMap getFV es
+
+instance HasFV Var where
+ getFV v = Set.singleton v
+
+instance HasFV a => HasFV (MatchGroup GhcTc a) where
+ getFV MG {mg_alts = alts} = getFV alts
+ getFV XMatchGroup{} = Set.empty
+
+instance HasFV a => HasFV (Match GhcTc a) where
+ getFV Match {m_grhss = rhss} = getFV rhss
+ getFV XMatch{} = Set.empty
+
+instance HasFV (HsTupArg GhcTc) where
+ getFV (Present _ e) = getFV e
+ getFV _ = Set.empty
+
+
+instance HasFV a => HasFV (GRHS GhcTc a) where
+ getFV (GRHS _ g b) = getFV g `Set.union` getFV b
+ getFV XGRHS{} = Set.empty
+
+instance HasFV a => HasFV (GRHSs GhcTc a) where
+ getFV GRHSs {grhssGRHSs = rhs, grhssLocalBinds = lbs} =
+ getFV rhs `Set.union` getFV lbs
+ getFV _ = Set.empty
+
+
+instance HasFV (HsLocalBindsLR GhcTc GhcTc) where
+ getFV (HsValBinds _ bs) = getFV bs
+ getFV (HsIPBinds _ bs) = getFV bs
+ getFV _ = Set.empty
+
+instance HasFV (HsValBindsLR GhcTc GhcTc) where
+ getFV (ValBinds _ b _) = getFV b
+ getFV _ = Set.empty
+
+instance HasFV (HsBindLR GhcTc GhcTc) where
+ getFV FunBind {fun_matches = ms} = getFV ms
+ getFV PatBind {pat_rhs = rhs} = getFV rhs
+ getFV VarBind {var_rhs = rhs} = getFV rhs
+ getFV AbsBinds {abs_binds = bs} = getFV bs
+ getFV _ = Set.empty
+
+
+instance HasFV (IPBind GhcTc) where
+ getFV (IPBind _ _ e) = getFV e
+ getFV _ = Set.empty
+
+instance HasFV (HsIPBinds GhcTc) where
+ getFV (IPBinds _ bs) = getFV bs
+ getFV _ = Set.empty
+
+instance HasFV (ApplicativeArg GhcTc) where
+#if __GLASGOW_HASKELL__ >= 810
+ getFV (ApplicativeArgOne _ _ e _ _)
+#else
+ getFV (ApplicativeArgOne _ _ e _)
+#endif
+ = getFV e
+ getFV (ApplicativeArgMany _ es e _) = getFV es `Set.union` getFV e
+ getFV XApplicativeArg{} = Set.empty
+
+instance HasFV (ParStmtBlock GhcTc GhcTc) where
+ getFV (ParStmtBlock _ es _ _) = getFV es
+ getFV XParStmtBlock{} = Set.empty
+
+instance HasFV a => HasFV (StmtLR GhcTc GhcTc a) where
+ getFV (LastStmt _ e _ _) = getFV e
+ getFV (BindStmt _ _ e _ _) = getFV e
+ getFV (ApplicativeStmt _ args _) = foldMap (getFV . snd) args
+ getFV (BodyStmt _ e _ _) = getFV e
+ getFV (LetStmt _ bs) = getFV bs
+ getFV (ParStmt _ stms e _) = getFV stms `Set.union` getFV e
+ getFV TransStmt{} = Set.empty -- TODO
+ getFV RecStmt{} = Set.empty -- TODO
+ getFV XStmtLR{} = Set.empty
+
+instance HasFV (HsRecordBinds GhcTc) where
+ getFV HsRecFields{rec_flds = fs} = getFV fs
+
+instance HasFV (HsRecField' o (LHsExpr GhcTc)) where
+ getFV HsRecField {hsRecFieldArg = arg} = getFV arg
+
+instance HasFV (ArithSeqInfo GhcTc) where
+ getFV (From e) = getFV e
+ getFV (FromThen e1 e2) = getFV e1 `Set.union` getFV e2
+ getFV (FromTo e1 e2) = getFV e1 `Set.union` getFV e2
+ getFV (FromThenTo e1 e2 e3) = getFV e1 `Set.union` getFV e2 `Set.union` getFV e3
+
+instance HasFV (HsBracket GhcTc) where
+ getFV (ExpBr _ e) = getFV e
+ getFV (VarBr _ _ e) = getFV e
+ getFV _ = Set.empty
+
+instance HasFV (HsCmd GhcTc) where
+ getFV (HsCmdArrApp _ e1 e2 _ _) = getFV e1 `Set.union` getFV e2
+ getFV (HsCmdArrForm _ e _ _ cmd) = getFV e `Set.union` getFV cmd
+ getFV (HsCmdApp _ e1 e2) = getFV e1 `Set.union` getFV e2
+ getFV (HsCmdLam _ l) = getFV l
+ getFV (HsCmdPar _ cmd) = getFV cmd
+ getFV (HsCmdCase _ _ mg) = getFV mg
+ getFV (HsCmdIf _ _ e1 e2 e3) = getFV e1 `Set.union` getFV e2 `Set.union` getFV e3
+ getFV (HsCmdLet _ bs _) = getFV bs
+ getFV (HsCmdDo _ cmd) = getFV cmd
+ getFV (HsCmdWrap _ _ cmd) = getFV cmd
+ getFV XCmd{} = Set.empty
+
+
+instance HasFV (HsCmdTop GhcTc) where
+ getFV (HsCmdTop _ cmd) = getFV cmd
+ getFV XCmdTop{} = Set.empty
+
+instance HasFV (HsExpr GhcTc) where
+ getFV (HsVar _ v) = getFV v
+ getFV HsUnboundVar {} = Set.empty
+ getFV HsConLikeOut {} = Set.empty
+ getFV HsRecFld {} = Set.empty
+ getFV HsOverLabel {} = Set.empty
+ getFV HsIPVar {} = Set.empty
+ getFV HsOverLit {} = Set.empty
+ getFV HsLit {} = Set.empty
+ getFV (HsLam _ mg) = getFV mg
+ getFV (HsLamCase _ mg) = getFV mg
+ getFV (HsApp _ e1 e2) = getFV e1 `Set.union` getFV e2
+
+#if __GLASGOW_HASKELL__ >= 808
+ getFV (HsAppType _ e _)
+#else
+ getFV (HsAppType _ e)
+#endif
+ = getFV e
+
+ getFV (OpApp _ e1 e2 e3) = getFV e1 `Set.union` getFV e2 `Set.union` getFV e3
+ getFV (NegApp _ e _) = getFV e
+ getFV (HsPar _ e) = getFV e
+ getFV (SectionL _ e1 e2) = getFV e1 `Set.union` getFV e2
+ getFV (SectionR _ e1 e2) = getFV e1 `Set.union` getFV e2
+ getFV (ExplicitTuple _ es _) = getFV es
+ getFV (ExplicitSum _ _ _ e) = getFV e
+ getFV (HsCase _ e mg) = getFV e `Set.union` getFV mg
+ getFV (HsIf _ _ e1 e2 e3) = getFV e1 `Set.union` getFV e2 `Set.union` getFV e3
+ getFV (HsMultiIf _ es) = getFV es
+ getFV (HsLet _ bs e) = getFV bs `Set.union` getFV e
+ getFV (HsDo _ _ e) = getFV e
+ getFV (ExplicitList _ _ es) = getFV es
+ getFV (RecordCon {rcon_flds = fs}) = getFV fs
+ getFV (RecordUpd {rupd_expr = e, rupd_flds = fs}) = getFV e `Set.union` getFV fs
+
+#if __GLASGOW_HASKELL__ >= 808
+ getFV (ExprWithTySig _ e _)
+#else
+ getFV (ExprWithTySig _ e)
+#endif
+ = getFV e
+
+ getFV (ArithSeq _ _ e) = getFV e
+ getFV (HsSCC _ _ _ e) = getFV e
+ getFV (HsCoreAnn _ _ _ e) = getFV e
+ getFV (HsBracket _ e) = getFV e
+ getFV HsRnBracketOut {} = Set.empty
+ getFV HsTcBracketOut {} = Set.empty
+ getFV HsSpliceE{} = Set.empty
+ getFV (HsProc _ _ e) = getFV e
+ getFV (HsStatic _ e) = getFV e
+
+#if __GLASGOW_HASKELL__ < 810
+ getFV (HsArrApp _ e1 e2 _ _) = getFV e1 `Set.union` getFV e2
+ getFV (HsArrForm _ e _ cmd) = getFV e `Set.union` getFV cmd
+ getFV EWildPat {} = Set.empty
+ getFV (EAsPat _ e1 e2) = getFV e1 `Set.union` getFV e2
+ getFV (EViewPat _ e1 e2) = getFV e1 `Set.union` getFV e2
+ getFV (ELazyPat _ e) = getFV e
+#endif
+
+ getFV (HsTick _ _ e) = getFV e
+ getFV (HsBinTick _ _ _ e) = getFV e
+ getFV (HsTickPragma _ _ _ _ e) = getFV e
+ getFV (HsWrap _ _ e) = getFV e
+ getFV XExpr{} = Set.empty
+
diff --git a/src/Rattus/Plugin/ScopeCheck.hs b/src/Rattus/Plugin/ScopeCheck.hs
index d72a629..fc87f18 100644
--- a/src/Rattus/Plugin/ScopeCheck.hs
+++ b/src/Rattus/Plugin/ScopeCheck.hs
@@ -1,274 +1,668 @@
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE OverloadedStrings #-}
-module Rattus.Plugin.ScopeCheck (checkExpr, emptyCtx) where
+{-# LANGUAGE ImplicitParams #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE CPP #-}
-import Rattus.Plugin.Utils
+
+
+-- | This module implements the source plugin that checks the variable
+-- scope of of Rattus programs.
+
+module Rattus.Plugin.ScopeCheck (checkAll) where
+
+import Rattus.Plugin.Utils
+import Rattus.Plugin.Dependency
+import Rattus.Plugin.Annotation
import Prelude hiding ((<>))
import GhcPlugins
-import Control.Monad
-import Data.Set (Set)
+import TcRnTypes
+import Bag
+
+#if __GLASGOW_HASKELL__ >= 810
+import GHC.Hs.Extension
+import GHC.Hs.Expr
+import GHC.Hs.Pat
+import GHC.Hs.Binds
+#else
+import HsExtension
+import HsExpr
+import HsPat
+import HsBinds
+#endif
+
+import Data.Graph
import qualified Data.Set as Set
-import Data.Map (Map)
import qualified Data.Map as Map
+import Data.Set (Set)
+import Data.Map (Map)
+import System.Exit
import Data.Maybe
-type LCtx = Set Var
-data HiddenReason = BoxApp | AdvApp | NestedRec Var | FunDef
-type Hidden = Map Var HiddenReason
-
-data Prim = Delay | Adv | Box | Unbox | Arr
-
-instance Outputable Prim where
- ppr Delay = "delay"
- ppr Adv = "adv"
- ppr Box = "box"
- ppr Unbox = "unbox"
- ppr Arr = "arr"
-
-
-type RecDef = Set Var
+import Control.Monad
-data Ctx = Ctx
- { current :: LCtx,
+-- | The current context for scope checking
+data Ctxt = Ctxt
+ {
+ -- | Variables that are in scope now (i.e. occurring in the typing
+ -- context but not to the left of a tick)
+ current :: LCtxt,
+ -- | Variables that are in the typing context, but to the left of a
+ -- tick
+ earlier :: Maybe LCtxt,
+ -- | Variables that have fallen out of scope. The map contains the
+ -- reason why they have fallen out of scope.
hidden :: Hidden,
+ -- | Same as 'hidden' but for recursive variables.
hiddenRec :: Hidden,
- earlier :: Maybe LCtx,
+ -- | The current location information.
srcLoc :: SrcSpan,
+ -- | If we are in the body of a recursively defined function, this
+ -- field contains the variables that are defined recursively
+ -- (could be more than one due to mutual recursion or because of a
+ -- recursive pattern definition) and the location of the recursive
+ -- definition.
recDef :: Maybe RecDef,
+ -- | Type variables with a 'Stable' constraint attached to them.
stableTypes :: Set Var,
+ -- | A mapping from variables to the primitives that they are
+ -- defined equal to. For example, a program could contain @let
+ -- mydel = delay in mydel 1@, in which case @mydel@ is mapped to
+ -- 'Delay'.
primAlias :: Map Var Prim,
- funDef :: Var,
- stabilized :: Bool}
+ -- | This flag indicates whether the context was 'stabilized'
+ -- (stripped of all non-stable stuff). It is set when typechecking
+ -- 'box', 'arr' and guarded recursion.
+ stabilized :: Maybe StableReason}
+ deriving Show
-primMap :: Map FastString Prim
-primMap = Map.fromList
- [("Delay", Delay),
- ("delay", Delay),
- ("adv", Adv),
- ("box", Box),
- ("arr", Arr),
- ("unbox", Unbox)]
+-- | The starting context for checking a top-level definition. For
+-- non-recursive definitions, the argument is @Nothing@. Otherwise, it
+-- contains the recursively defined variables along with the location
+-- of the recursive definition.
+emptyCtxt :: Maybe (Set Var,SrcSpan) -> Ctxt
+emptyCtxt mvar =
+ Ctxt { current = Set.empty,
+ earlier = Nothing,
+ hidden = Map.empty,
+ hiddenRec = Map.empty,
+ srcLoc = UnhelpfulSpan "<no location info>",
+ recDef = mvar,
+ primAlias = Map.empty,
+ stableTypes = Set.empty,
+ stabilized = case mvar of
+ Just (_,loc) -> Just (StableRec loc)
+ _ -> Nothing}
-isPrim :: Ctx -> Var -> Maybe Prim
-isPrim c v
- | Just p <- Map.lookup v (primAlias c) = Just p
- | otherwise = do
- (name,mod) <- getNameModule v
- if isRattModule mod then Map.lookup name primMap
- else Nothing
+-- | A local context, consisting of a set of variables.
+type LCtxt = Set Var
-stabilizeLater :: Ctx -> Ctx
-stabilizeLater c =
- if isJust (earlier c)
- then c {earlier = Nothing,
- hidden = hid,
- hiddenRec = maybe (hiddenRec c) (Map.union (hidden c) . Map.fromSet (const FunDef)) (recDef c),
- recDef = Nothing}
- else c {earlier = Nothing,
- hidden = hid}
- where hid = maybe (hidden c) (Map.union (hidden c) . Map.fromSet (const FunDef)) (earlier c)
+-- | The recursively defined variables + the position where the
+-- recursive definition starts
+type RecDef = (Set Var, SrcSpan)
-stabilize :: HiddenReason -> Ctx -> Ctx
-stabilize hr c = c
- {current = Set.empty,
- earlier = Nothing,
- hidden = hidden c `Map.union` Map.fromSet (const hr) ctxHid,
- hiddenRec = hiddenRec c `Map.union` maybe Map.empty (Map.fromSet (const hr)) (recDef c),
- recDef = Nothing,
- stabilized = True}
- where ctxHid = maybe (current c) (Set.union (current c)) (earlier c)
-data Scope = Hidden SDoc | Visible | ImplUnboxed
+data StableReason = StableRec SrcSpan | StableBox | StableArr deriving Show
+
+-- | Indicates, why a variable has fallen out of scope.
+data HiddenReason = Stabilize StableReason | FunDef | AdvApp deriving Show
+
+-- | Hidden context, containing variables that have fallen out of
+-- context along with the reason why they have.
+type Hidden = Map Var HiddenReason
+
+-- | The 4 primitive Rattus operations plus 'arr'.
+data Prim = Delay | Adv | Box | Unbox | Arr deriving Show
+
+-- | This constraint is used to pass along the context implicitly via
+-- an implicit parameter.
+type GetCtxt = ?ctxt :: Ctxt
+
+
+-- | This type class is implemented for each AST type @a@ for which we
+-- can check whether it adheres to the scoping rules of Rattus.
+class Scope a where
+ -- | Check whether the argument is a scope correct piece of syntax
+ -- in the given context.
+ check :: GetCtxt => a -> TcM Bool
+
+-- | This is a variant of 'Scope' for syntax that can also bind
+-- variables.
+class ScopeBind a where
+ -- | 'checkBind' checks whether its argument is scope-correct and in
+-- addition returns the the set of variables bound by it.
+ checkBind :: GetCtxt => a -> TcM (Bool,Set Var)
+
+
+-- | set the current context.
+setCtxt :: Ctxt -> (GetCtxt => a) -> a
+setCtxt c a = let ?ctxt = c in a
+
+
+-- | modify the current context.
+modifyCtxt :: (Ctxt -> Ctxt) -> (GetCtxt => a) -> (GetCtxt => a)
+modifyCtxt f a =
+ let newc = f ?ctxt in
+ let ?ctxt = newc in a
+
+
+-- | Check all definitions in the given module. If Scope errors are
+-- found, the current execution is halted with 'exitFailure'.
+checkAll :: TcGblEnv -> TcM ()
+checkAll env = do
+ let bindDep = filter (filterBinds (tcg_mod env) (tcg_ann_env env)) (dependency (tcg_binds env))
+ res <- mapM checkSCC bindDep
+ if and res then return () else liftIO exitFailure
+
+-- | This function checks whether a given top-level definition (either
+-- a single non-recursive definition or a group of mutual recursive
+-- definitions) is marked as Rattus code (via an annotation). In a
+-- group of mutual recursive definitions, the whole group is
+-- considered Rattus code if at least one of its constituents is
+-- marked as such.
+filterBinds :: Module -> AnnEnv -> SCC (LHsBindLR GhcTc GhcTc, Set Var) -> Bool
+filterBinds mod anEnv scc =
+ case scc of
+ (AcyclicSCC (_,vs)) -> any checkVar vs
+ (CyclicSCC bs) -> any (any checkVar . snd) bs
+ where checkVar :: Var -> Bool
+ checkVar v =
+ let anns = findAnns deserializeWithData anEnv (NamedTarget name) :: [Rattus]
+ annsMod = findAnns deserializeWithData anEnv (ModuleTarget mod) :: [Rattus]
+ name :: Name
+ name = varName v
+ in Rattus `elem` anns || (not (NotRattus `elem` anns) && Rattus `elem` annsMod)
+
+
+instance Scope a => Scope (GenLocated SrcSpan a) where
+ check (L l x) = (\c -> c {srcLoc = l}) `modifyCtxt` check x
+
+
+instance Scope (LHsBinds GhcTc) where
+ check bs = fmap and (mapM check (bagToList bs))
+
+instance Scope a => Scope [a] where
+ check ls = fmap and (mapM check ls)
+
+
+instance Scope a => Scope (Match GhcTc a) where
+ check Match{m_pats=ps,m_grhss=rhs} = mod `modifyCtxt` check rhs
+ where mod c = addVars (getBV ps) (if null ps then c else stabilizeLater c)
+ check XMatch{} = return True
+
+instance Scope a => Scope (MatchGroup GhcTc a) where
+ check MG {mg_alts = alts} = check alts
+ check XMatchGroup {} = return True
+
+instance Scope a => ScopeBind (StmtLR GhcTc GhcTc a) where
+ checkBind (LastStmt _ b _ _) = ( , Set.empty) <$> check b
+ checkBind (BindStmt _ p b _ _) = do
+ let vs = getBV p
+ let c' = addVars vs ?ctxt
+ r <- setCtxt c' (check b)
+ return (r,vs)
+ checkBind (BodyStmt _ b _ _) = ( , Set.empty) <$> check b
+ checkBind (LetStmt _ bs) = checkBind bs
+ checkBind ParStmt{} = notSupported "monad comprehensions"
+ checkBind TransStmt{} = notSupported "monad comprehensions"
+ checkBind ApplicativeStmt{} = notSupported "applicative do notation"
+ checkBind RecStmt{} = notSupported "recursive do notation"
+ checkBind XStmtLR {} = return (True,Set.empty)
+
+
+instance ScopeBind a => ScopeBind [a] where
+ checkBind [] = return (True,Set.empty)
+ checkBind (x:xs) = do
+ (r,vs) <- checkBind x
+ (r',vs') <- addVars vs `modifyCtxt` (checkBind xs)
+ return (r && r',vs `Set.union` vs')
+
+instance ScopeBind a => ScopeBind (GenLocated SrcSpan a) where
+ checkBind (L l x) = (\c -> c {srcLoc = l}) `modifyCtxt` checkBind x
+
+
+instance Scope a => Scope (GRHS GhcTc a) where
+ check (GRHS _ gs b) = do
+ (r, vs) <- checkBind gs
+ r' <- addVars vs `modifyCtxt` (check b)
+ return (r && r')
+ check XGRHS{} = return True
+
+
+
+
+-- | Check the scope of a list of (mutual) recursive bindings. The
+-- second argument is the set of variables defined by the (mutual)
+-- recursive bindings
+checkRecursiveBinds :: GetCtxt => [LHsBindLR GhcTc GhcTc] -> Set Var -> TcM (Bool, Set Var)
+checkRecursiveBinds bs vs = do
+ res <- fmap and (mapM check' bs)
+ case stabilized ?ctxt of
+ Just reason | res ->
+ (printMessage' SevWarning (recReason reason <> " can cause time leaks")) >> return (res, vs)
+ _ -> return (res, vs)
+ where check' b@(L l _) = fc l `modifyCtxt` check b
+ fc l c = let
+ ctxHid = maybe (current c) (Set.union (current c)) (earlier c)
+ recHid = maybe ctxHid (Set.union ctxHid . fst) (recDef c)
+ in c {current = Set.empty,
+ earlier = Nothing,
+ hidden = hidden c `Map.union`
+ (Map.fromSet (const (Stabilize (StableRec l))) recHid),
+ recDef = Just (vs,l),
+ stabilized = Just (StableRec l)}
+
+ recReason :: StableReason -> SDoc
+ recReason (StableRec _) = "nested recursive definitions"
+ recReason StableBox = "recursive definitions nested under box"
+ recReason StableArr = "recursive definitions nested under arr"
+
-getScope :: Ctx -> Var -> Scope
-getScope Ctx{recDef = Just (vs), funDef = recV, earlier = e} v
- | v `Set.member` vs =
- case e of
- Just _ -> Visible
- Nothing
- | recV == v -> Hidden ("Recursive call to " <> ppr v <> " must occur under delay")
- | otherwise -> Hidden ("Mutually recursice call to " <> ppr v <> " must occur under delay")
+
+instance ScopeBind (SCC (LHsBindLR GhcTc GhcTc, Set Var)) where
+ checkBind (AcyclicSCC (b,vs)) = (, vs) <$> check b
+ checkBind (CyclicSCC bs) = checkRecursiveBinds (map fst bs) (foldMap snd bs)
---getScope Ctx{hiddenRecs = h} v
- -- recursive call that is out of scope
--- | (Set.member v h) = Hidden ""
-getScope c v =
- case Map.lookup v (hiddenRec c) of
- Just (NestedRec rv) -> Hidden ("Recursive call to" <> ppr v <>
- " is not allowed as it occurs in a local recursive definiton (namely of " <> ppr rv <> ")")
- Just BoxApp -> Hidden ("Recursive call to " <> ppr v <> " is not allowed here, since it occurs under a box")
- Just FunDef -> Hidden ("Recursive call to " <> ppr v <> " is not allowed here, since it occurs in a function that is defined under delay")
- Just AdvApp -> Hidden ("This should not happen: recursive call to " <> ppr v <> " is out of scope due to adv")
- Nothing ->
- case Map.lookup v (hidden c) of
- Just (NestedRec rv) ->
- if (isStable (stableTypes c) (varType v)) then Visible
- else Hidden ("Variable " <> ppr v <> " is no longer in scope:" $$
- "It appears in a local recursive definiton (namely of " <> ppr rv <> ")"
- $$ "and is of type " <> ppr (varType v) <> ", which is not stable.")
- Just BoxApp ->
- if (isStable (stableTypes c) (varType v)) then Visible
- else Hidden ("Variable " <> ppr v <> " is no longer in scope:" $$
- "It occurs under " <> keyword "box" $$ "and is of type " <> ppr (varType v) <> ", which is not stable.")
- Just AdvApp -> Hidden ("Variable " <> ppr v <> " is no longer in scope: It occurs under adv.")
- Just FunDef -> if (isStable (stableTypes c) (varType v)) then Visible
- else Hidden ("Variable " <> ppr v <> " is no longer in scope: It occurs in a function that is defined under a delay, is a of a non-stable type " <> ppr (varType v) <> ", and is bound outside delay")
- Nothing
- | maybe False (Set.member v) (earlier c) ->
- if isStable (stableTypes c) (varType v) then Visible
- else Hidden ("Variable " <> ppr v <> " is no longer in scope:" $$
- "It occurs under delay" $$ "and is of type " <> ppr (varType v) <> ", which is not stable.")
- | Set.member v (current c) -> Visible
- | isTemporal (varType v) && isJust (earlier c) && userFunction v
- -> ImplUnboxed
- | otherwise -> Visible
+instance ScopeBind (HsValBindsLR GhcTc GhcTc) where
+ checkBind (ValBinds _ bs _) = checkBind (dependency bs)
+
+ checkBind (XValBindsLR (NValBinds binds _)) = checkBind binds
+instance ScopeBind (HsBindLR GhcTc GhcTc) where
+ checkBind b = (, getBV b) <$> check b
-pickFirst :: SrcSpan -> SrcSpan -> SrcSpan
-pickFirst s@RealSrcSpan{} _ = s
-pickFirst _ s = s
-printMessage' :: Severity -> Ctx -> Var -> SDoc -> CoreM ()
-printMessage' sev cxt var doc =
- printMessage sev (pickFirst (srcLoc cxt) (nameSrcSpan (varName var))) doc
+-- | Compute the set of variables defined by the given Haskell binder.
+getAllBV :: GenLocated l (HsBindLR GhcTc GhcTc) -> Set Var
+getAllBV (L _ b) = getAllBV' b where
+ getAllBV' (FunBind{fun_id = L _ v}) = Set.singleton v
+ getAllBV' (AbsBinds {abs_exports = es, abs_binds = bs}) = Set.fromList (map abe_poly es) `Set.union` foldMap getBV bs
+ getAllBV' (PatBind {pat_lhs = pat}) = getBV pat
+ getAllBV' (VarBind {var_id = v}) = Set.singleton v
+ getAllBV' PatSynBind{} = Set.empty
+ getAllBV' XHsBindsLR{} = Set.empty
+
+
+-- Check nested bindings
+instance ScopeBind (RecFlag, LHsBinds GhcTc) where
+ checkBind (NonRecursive, bs) = checkBind $ bagToList bs
+ checkBind (Recursive, bs) = checkRecursiveBinds bs' (foldMap getAllBV bs')
+ where bs' = bagToList bs
-printMessageCheck :: Severity -> Ctx -> Var -> SDoc -> CoreM Bool
-printMessageCheck sev cxt var doc = printMessage' sev cxt var doc >>
- case sev of
- SevError -> return False
- _ -> return True
+instance ScopeBind (HsLocalBindsLR GhcTc GhcTc) where
+ checkBind (HsValBinds _ bs) = checkBind bs
+ checkBind HsIPBinds {} = notSupported "implicit parameters"
+ checkBind EmptyLocalBinds{} = return (True,Set.empty)
+ checkBind XHsLocalBindsLR{} = return (True,Set.empty)
+instance Scope a => Scope (GRHSs GhcTc a) where
+ check GRHSs{grhssGRHSs = rhs, grhssLocalBinds = lbinds} = do
+ (l,vs) <- checkBind lbinds
+ r <- addVars vs `modifyCtxt` (check rhs)
+ return (r && l)
+ check XGRHSs{} = return True
-emptyCtx :: Maybe (Set Var) -> Var -> Ctx
-emptyCtx mvar fun =
- Ctx { current = Set.empty,
- earlier = Nothing,
- hidden = Map.empty,
- hiddenRec = Map.empty,
- srcLoc = UnhelpfulSpan "<no location info>",
- recDef = mvar,
- funDef = fun,
- primAlias = Map.empty,
- stableTypes = Set.empty,
- stabilized = isJust mvar}
+instance Show Var where
+ show v = getOccString v
-isPrimExpr :: Ctx -> Expr Var -> Maybe (Prim,Var)
-isPrimExpr c (App e (Type _)) = isPrimExpr c e
-isPrimExpr c (App e e') | not $ tcIsLiftedTypeKind $ typeKind $ exprType e' = isPrimExpr c e
-isPrimExpr c (Var v) = fmap (,v) (isPrim c v)
-isPrimExpr c (Tick _ e) = isPrimExpr c e
-isPrimExpr c (Lam v e)
- | isTyVar v || (not $ tcIsLiftedTypeKind $ typeKind $ varType v) = isPrimExpr c e
-isPrimExpr _ _ = Nothing
+boxReason StableBox = "Nested use of box"
+boxReason StableArr = "The use of box in the scope of arr"
+boxReason (StableRec _ ) = "The use of box in a recursive definition"
+arrReason StableArr = "Nested use of arr"
+arrReason StableBox = "The use of arr in the scope of box"
+arrReason (StableRec _) = "The use of arr in a recursive definition"
-isStableConstr :: Type -> CoreM (Maybe Var)
+instance Scope (HsExpr GhcTc) where
+ check (HsVar _ (L _ v))
+ | Just p <- isPrim v =
+ case p of
+ Unbox -> return True
+ _ -> printMessageCheck SevError ("Defining an alias for " <> ppr v <> " is not allowed")
+ | otherwise = case getScope v of
+ Hidden reason -> printMessageCheck SevError reason
+ Visible -> return True
+ ImplUnboxed -> printMessageCheck SevWarning
+ (ppr v <> text " is an external temporal function used under delay, which may cause time leaks")
+ check (HsApp _ e1 e2) =
+ case isPrimExpr e1 of
+ Just (p,_) -> case p of
+ Box -> do
+ ch <- stabilize StableBox `modifyCtxt` check e2
+ case stabilized ?ctxt of
+ Just reason | ch ->
+ (printMessage' SevWarning (boxReason reason <> " can cause time leaks")) >> return ch
+ _ -> return ch
+ Arr -> do
+ ch <- stabilize StableArr `modifyCtxt` check e2
+ -- don't bother with a warning if the scopecheck fails
+ case stabilized ?ctxt of
+ Just reason | ch ->
+ printMessage' SevWarning (arrReason reason <> " can cause time leaks") >> return ch
+ _ -> return ch
+
+ Unbox -> check e2
+ Delay -> case earlier ?ctxt of
+ Just _ -> printMessageCheck SevError (text "cannot delay more than once")
+ Nothing -> (\c -> c{current = Set.empty, earlier = Just (current ?ctxt)})
+ `modifyCtxt` check e2
+ Adv -> case earlier ?ctxt of
+ Just er -> mod `modifyCtxt` check e2
+ where mod c = c{earlier = Nothing, current = er,
+ hidden = hidden ?ctxt `Map.union`
+ Map.fromSet (const AdvApp) (current ?ctxt)}
+ Nothing -> printMessageCheck SevError (text "can only advance under delay")
+ _ -> liftM2 (&&) (check e1) (check e2)
+ check HsUnboundVar{} = return True
+ check HsConLikeOut{} = return True
+ check HsRecFld{} = return True
+ check HsOverLabel{} = return True
+ check HsIPVar{} = notSupported "implicit parameters"
+ check HsOverLit{} = return True
+
+#if __GLASGOW_HASKELL__ >= 808
+ check (HsAppType _ e _)
+#else
+ check (HsAppType _ e)
+#endif
+ = check e
+
+ check (HsTick _ _ e) = check e
+ check (HsBinTick _ _ _ e) = check e
+ check (HsSCC _ _ _ e) = check e
+ check (HsPar _ e) = check e
+ check (HsWrap _ _ e) = check e
+ check HsLit{} = return True
+ check (OpApp _ e1 e2 e3) = and <$> mapM check [e1,e2,e3]
+ check (HsLam _ mg) = stabilizeLater `modifyCtxt` check mg
+ check (HsLamCase _ mg) = stabilizeLater `modifyCtxt` check mg
+ check (HsIf _ _ e1 e2 e3) = and <$> mapM check [e1,e2,e3]
+ check (HsCase _ e1 e2) = (&&) <$> check e1 <*> check e2
+ check (SectionL _ e1 e2) = (&&) <$> check e1 <*> check e2
+ check (SectionR _ e1 e2) = (&&) <$> check e1 <*> check e2
+ check (ExplicitTuple _ e _) = check e
+ check (HsLet _ bs e) = do
+ (l,vs) <- checkBind bs
+ r <- addVars vs `modifyCtxt` (check e)
+ return (r && l)
+ check (NegApp _ e _) = check e
+ check (ExplicitSum _ _ _ e) = check e
+ check (HsMultiIf _ e) = check e
+ check (ExplicitList _ _ e) = check e
+ check RecordCon { rcon_flds = f} = check f
+ check RecordUpd { rupd_expr = e, rupd_flds = fs} = (&&) <$> check e <*> check fs
+#if __GLASGOW_HASKELL__ >= 808
+ check (ExprWithTySig _ e _)
+#else
+ check (ExprWithTySig _ e)
+#endif
+ = check e
+ check (ArithSeq _ _ e) = check e
+ check HsBracket{} = notSupported "MetaHaskell"
+ check HsRnBracketOut{} = notSupported "MetaHaskell"
+ check HsTcBracketOut{} = notSupported "MetaHaskell"
+ check HsSpliceE{} = notSupported "Template Haskell"
+ check (HsProc _ p e) = mod `modifyCtxt` check e
+ where mod c = addVars (getBV p) (stabilize StableArr c)
+ check (HsStatic _ e) = check e
+ check (HsDo _ _ e) = fst <$> checkBind e
+ check (HsCoreAnn _ _ _ e) = check e
+ check (HsTickPragma _ _ _ _ e) = check e
+ check XExpr {} = return True
+#if __GLASGOW_HASKELL__ < 810
+ check HsArrApp{} = impossible
+ check HsArrForm{} = impossible
+ check EWildPat{} = impossible
+ check EAsPat{} = impossible
+ check EViewPat{} = impossible
+ check ELazyPat{} = impossible
+
+impossible :: GetCtxt => TcM Bool
+impossible = printMessageCheck SevError "This syntax should never occur after typechecking"
+#endif
+
+
+instance Scope (HsCmdTop GhcTc) where
+ check (HsCmdTop _ e) = check e
+ check XCmdTop{} = return True
+
+instance Scope (HsCmd GhcTc) where
+ check (HsCmdArrApp _ e1 e2 _ _) = (&&) <$> check e1 <*> check e2
+ check (HsCmdDo _ e) = fst <$> checkBind e
+ check (HsCmdArrForm _ e1 _ _ e2) = (&&) <$> check e1 <*> check e2
+ check (HsCmdApp _ e1 e2) = (&&) <$> check e1 <*> check e2
+ check (HsCmdLam _ e) = check e
+ check (HsCmdPar _ e) = check e
+ check (HsCmdCase _ e1 e2) = (&&) <$> check e1 <*> check e2
+ check (HsCmdIf _ _ e1 e2 e3) = (&&) <$> ((&&) <$> check e1 <*> check e2) <*> check e3
+ check (HsCmdLet _ bs e) = do
+ (l,vs) <- checkBind bs
+ r <- addVars vs `modifyCtxt` (check e)
+ return (r && l)
+ check (HsCmdWrap _ _ e) = check e
+ check XCmd{} = return True
+
+
+-- | This is used when checking function definitions. If the context
+-- is not ticked, it stays the same. Otherwise, it is stabilized
+-- (similar to 'box').
+stabilizeLater :: Ctxt -> Ctxt
+stabilizeLater c =
+ if isJust (earlier c)
+ then c {earlier = Nothing,
+ hidden = hid,
+ hiddenRec = maybe (hiddenRec c) (Map.union (hidden c) . Map.fromSet (const FunDef))
+ (fst <$> recDef c),
+ recDef = Nothing}
+ else c
+ where hid = maybe (hidden c) (Map.union (hidden c) . Map.fromSet (const FunDef)) (earlier c)
+
+
+instance Scope (ArithSeqInfo GhcTc) where
+ check (From e) = check e
+ check (FromThen e1 e2) = (&&) <$> check e1 <*> check e2
+ check (FromTo e1 e2) = (&&) <$> check e1 <*> check e2
+ check (FromThenTo e1 e2 e3) = (&&) <$> ((&&) <$> check e1 <*> check e2) <*> check e3
+
+instance Scope (HsRecordBinds GhcTc) where
+ check HsRecFields {rec_flds = fs} = check fs
+
+instance Scope (HsRecField' a (LHsExpr GhcTc)) where
+ check HsRecField{hsRecFieldArg = a} = check a
+
+instance Scope (HsTupArg GhcTc) where
+ check (Present _ e) = check e
+ check Missing{} = return True
+ check XTupArg{} = return True
+
+instance Scope (HsBindLR GhcTc GhcTc) where
+ check AbsBinds {abs_binds = binds, abs_ev_vars = ev} = mod `modifyCtxt` check binds
+ where mod c = c { stableTypes= stableTypes c `Set.union`
+ Set.fromList (mapMaybe (isStableConstr . varType) ev)}
+ check FunBind{fun_matches= matches, fun_id = L _ v} = mod `modifyCtxt` check matches
+ where mod c = c { stableTypes= stableTypes c `Set.union`
+ Set.fromList (extractStableConstr (varType v))}
+ check PatBind{pat_lhs = lhs, pat_rhs=rhs} = addVars (getBV lhs) `modifyCtxt` check rhs
+ check VarBind{var_rhs = rhs} = check rhs
+ check PatSynBind {} = return True -- pattern synonyms are not supported
+ check XHsBindsLR {} = return True
+
+
+-- | Checks whether the given type is a type constraint of the form
+-- @Stable a@ for some type variable @a@. In that case it returns the
+-- type variable @a@.
+isStableConstr :: Type -> Maybe TyVar
isStableConstr t =
case splitTyConApp_maybe t of
Just (con,[args]) ->
case getNameModule con of
Just (name, mod) ->
if isRattModule mod && name == "Stable"
- then return (getTyVar_maybe args)
- else return Nothing
- _ -> return Nothing
- _ -> return Nothing
-
-checkExpr :: Ctx -> Expr Var -> CoreM Bool
-checkExpr c (App e e') | isType e' || (not $ tcIsLiftedTypeKind $ typeKind $ exprType e')
- = checkExpr c e
-checkExpr c@Ctx{current = cur, hidden = hid, earlier = earl} (App e1 e2) =
- case isPrimExpr c e1 of
- Just (p,v) -> case p of
- Box -> do
- ch <- checkExpr (stabilize BoxApp c) e2
- -- don't bother with a warning if the scopecheck fails
- when (ch && stabilized c && not (isStable (stableTypes c) (exprType e2)))
- (printMessage' SevWarning c v
- (text "When box is used inside another box or a recursive definition, it can cause time leaks unless applied to an expression of stable type"))
- return ch
- Arr -> do
- ch <- checkExpr (stabilize BoxApp c) e2
- -- don't bother with a warning if the scopecheck fails
- when (ch && stabilized c && not (isStable (stableTypes c) (exprType e2)))
- (printMessage' SevWarning c v
- (text "When arr is used inside box or a recursive definition, it can cause time leaks unless applied to an expression of stable type"))
- return ch
-
- Unbox -> checkExpr c e2
- Delay -> case earl of
- Just _ -> printMessageCheck SevError c v (text "cannot delay more than once")
- Nothing -> checkExpr c{current = Set.empty, earlier = Just cur} e2
- Adv -> case earl of
- Just er -> checkExpr c{earlier = Nothing, current = er,
- hidden = hid `Map.union` Map.fromSet (const AdvApp) cur} e2
- Nothing -> printMessageCheck SevError c v (text "can only advance under delay")
- _ -> liftM2 (&&) (checkExpr c e1) (checkExpr c e2)
-checkExpr c (Case e v _ alts) =
- liftM2 (&&) (checkExpr c e) (liftM and (mapM (\ (_,vs,e)-> checkExpr (addVars vs c') e) alts))
- where c' = addVars [v] c
-checkExpr c (Lam v e)
- | isTyVar v || (not $ tcIsLiftedTypeKind $ typeKind $ varType v) = do
- is <- isStableConstr (varType v)
- let c' = case is of
- Nothing -> c
- Just t -> c{stableTypes = Set.insert t (stableTypes c)}
- checkExpr c' e
- | otherwise = checkExpr (addVars [v] (stabilizeLater c)) e
-checkExpr _ (Type _) = return True
-checkExpr _ (Lit _) = return True
-checkExpr _ (Coercion _) = return True
-checkExpr c (Tick (SourceNote span _name) e) =
- checkExpr c{srcLoc = RealSrcSpan span} e
-checkExpr c (Tick _ e) = checkExpr c e
-checkExpr c (Cast e _) = checkExpr c e
-checkExpr c (Let (NonRec v e1) e2) =
- case isPrimExpr c e1 of
- Just (p,_) -> (checkExpr (c{primAlias = Map.insert v p (primAlias c)}) e2)
- Nothing -> liftM2 (&&) (checkExpr c e1) (checkExpr (addVars [v] c) e2)
-checkExpr _ (Let (Rec ([])) _) = return True
-checkExpr c (Let (Rec binds) e2) = do
- r1 <- mapM (\ (v,e) -> checkExpr (c' v) e) binds
- r2 <- checkExpr (addVars vs c) e2
- let r = (and r1 && r2)
- when (r && stabilized c) (printMessage' SevWarning c (head vs)
- (text "recursive definition nested inside a box or annother recursive definition can cause time leaks"))
- return r
- where vs = map fst binds
- vs' = Set.fromList vs
- ctxHid = maybe (current c) (Set.union (current c)) (earlier c)
- recHid = maybe ctxHid (Set.union ctxHid) (recDef c)
- c' v = c {current = Set.empty,
- earlier = Nothing,
- hidden = hidden c `Map.union`
- (Map.fromSet (const (NestedRec v)) recHid),
- recDef = Just (vs'),
- funDef = v,
- stabilized = True}
-checkExpr c (Var v)
- | tcIsLiftedTypeKind $ typeKind $ varType v =
- case isPrim c v of
- Just p ->
- case p of
- Unbox -> return True
- _ -> printMessage SevError (nameSrcSpan (varName (funDef c))) ("Defining an alias for " <> ppr v <> " is not allowed") >> return False
- _ -> case getScope c v of
- Hidden reason -> printMessageCheck SevError c v reason
- Visible -> return True
- ImplUnboxed -> printMessageCheck SevWarning c v
- (ppr v <> text " is an external temporal function used under delay, which may cause time leaks")
+ then (getTyVar_maybe args)
+ else Nothing
+ _ -> Nothing
+ _ -> Nothing
+
+
+-- | Given a type @(C1, ... Cn) => t@, this function returns the list
+-- of type variables @[a1,...,am]@ for which there is a constraint
+-- @Stable ai@ among @C1, ... Cn@.
+extractStableConstr :: Type -> [TyVar]
+extractStableConstr = mapMaybe isStableConstr . fst . splitFunTys . snd . splitForAllTys
- | otherwise = return True
+-- | Checks a top-level definition group, which is either a single
+-- non-recursive definition or a group of (mutual) recursive
+-- definitions.
+checkSCC :: SCC (LHsBindLR GhcTc GhcTc, Set Var) -> TcM Bool
+checkSCC (AcyclicSCC (b,_)) = setCtxt (emptyCtxt Nothing) (check b)
-addVars :: [Var] -> Ctx -> Ctx
-addVars v c = c{current = Set.fromList v `Set.union` current c }
+checkSCC (CyclicSCC bs) = (fmap and (mapM check' bs'))
+ where bs' = map fst bs
+ vs = foldMap snd bs
+ check' b@(L l _) = setCtxt (emptyCtxt (Just (vs,l))) (check b)
+-- | Stabilizes the given context, i.e. remove all non-stable types
+-- and any tick. This is performed on checking 'box', 'arr' and
+-- guarded recursive definitions. To provide better error messages a
+-- reason has to be given as well.
+stabilize :: StableReason -> Ctxt -> Ctxt
+stabilize sr c = c
+ {current = Set.empty,
+ earlier = Nothing,
+ hidden = hidden c `Map.union` Map.fromSet (const hr) ctxHid,
+ hiddenRec = hiddenRec c `Map.union` maybe Map.empty (Map.fromSet (const hr) . fst) (recDef c),
+ recDef = Nothing,
+ stabilized = Just sr}
+ where ctxHid = maybe (current c) (Set.union (current c)) (earlier c)
+ hr = Stabilize sr
+
+data VarScope = Hidden SDoc | Visible | ImplUnboxed
+
+
+-- | This function checks whether the given variable is in scope.
+getScope :: GetCtxt => Var -> VarScope
+getScope v =
+ case ?ctxt of
+ Ctxt{recDef = Just (vs,_), earlier = e}
+ | v `Set.member` vs ->
+ case e of
+ Just _ -> Visible
+ Nothing -> Hidden ("(Mutually) recursive call to " <> ppr v <> " must occur under delay")
+ _ ->
+ case Map.lookup v (hiddenRec ?ctxt) of
+ Just (Stabilize (StableRec rv)) -> Hidden ("Recursive call to" <> ppr v <>
+ " is not allowed as it occurs in a local recursive definiton (at " <> ppr rv <> ")")
+ Just (Stabilize StableBox) -> Hidden ("Recursive call to " <> ppr v <> " is not allowed here, since it occurs under a box")
+ Just (Stabilize StableArr) -> Hidden ("Recursive call to " <> ppr v <> " is not allowed here, since it occurs inside an arrow notation")
+ Just FunDef -> Hidden ("Recursive call to " <> ppr v <> " is not allowed here, since it occurs in a function that is defined under delay")
+ Just AdvApp -> Hidden ("This should not happen: recursive call to " <> ppr v <> " is out of scope due to adv")
+ Nothing ->
+ case Map.lookup v (hidden ?ctxt) of
+ Just (Stabilize (StableRec rv)) ->
+ if (isStable (stableTypes ?ctxt) (varType v)) then Visible
+ else Hidden ("Variable " <> ppr v <> " is no longer in scope:" $$
+ "It appears in a local recursive definiton (at " <> ppr rv <> ")"
+ $$ "and is of type " <> ppr (varType v) <> ", which is not stable.")
+ Just (Stabilize StableBox) ->
+ if (isStable (stableTypes ?ctxt) (varType v)) then Visible
+ else Hidden ("Variable " <> ppr v <> " is no longer in scope:" $$
+ "It occurs under " <> keyword "box" $$ "and is of type " <> ppr (varType v) <> ", which is not stable.")
+ Just (Stabilize StableArr) ->
+ if (isStable (stableTypes ?ctxt) (varType v)) then Visible
+ else Hidden ("Variable " <> ppr v <> " is no longer in scope:" $$
+ "It occurs under inside an arrow notation and is of type " <> ppr (varType v) <> ", which is not stable.")
+ Just AdvApp -> Hidden ("Variable " <> ppr v <> " is no longer in scope: It occurs under adv.")
+ Just FunDef -> if (isStable (stableTypes ?ctxt) (varType v)) then Visible
+ else Hidden ("Variable " <> ppr v <> " is no longer in scope: It occurs in a function that is defined under a delay, is a of a non-stable type " <> ppr (varType v) <> ", and is bound outside delay")
+ Nothing
+ | maybe False (Set.member v) (earlier ?ctxt) ->
+ if isStable (stableTypes ?ctxt) (varType v) then Visible
+ else Hidden ("Variable " <> ppr v <> " is no longer in scope:" $$
+ "It occurs under delay" $$ "and is of type " <> ppr (varType v) <> ", which is not stable.")
+ | Set.member v (current ?ctxt) -> Visible
+ | isTemporal (varType v) && isJust (earlier ?ctxt) && userFunction v
+ -> ImplUnboxed
+ | otherwise -> Visible
+
+-- | A map from the syntax of a primitive of Rattus to 'Prim'.
+primMap :: Map FastString Prim
+primMap = Map.fromList
+ [("Delay", Delay),
+ ("delay", Delay),
+ ("adv", Adv),
+ ("box", Box),
+ ("arr", Arr),
+ ("unbox", Unbox)]
+
+
+-- | Checks whether a given variable is in fact a Rattus primitive.
+isPrim :: GetCtxt => Var -> Maybe Prim
+isPrim v
+ | Just p <- Map.lookup v (primAlias ?ctxt) = Just p
+ | otherwise = do
+ (name,mod) <- getNameModule v
+ if isRattModule mod then Map.lookup name primMap
+ else Nothing
+
+
+-- | Checks whether a given expression is in fact a Rattus primitive.
+isPrimExpr :: GetCtxt => LHsExpr GhcTc -> Maybe (Prim,Var)
+isPrimExpr (L _ e) = isPrimExpr' e where
+ isPrimExpr' :: GetCtxt => HsExpr GhcTc -> Maybe (Prim,Var)
+ isPrimExpr' (HsVar _ (L _ v)) = fmap (,v) (isPrim v)
+
+#if __GLASGOW_HASKELL__ >= 808
+ isPrimExpr' (HsAppType _ e _)
+#else
+ isPrimExpr' (HsAppType _ e)
+#endif
+ = isPrimExpr e
+ isPrimExpr' (HsTick _ _ e) = isPrimExpr e
+ isPrimExpr' (HsBinTick _ _ _ e) = isPrimExpr e
+ isPrimExpr' (HsSCC _ _ _ e) = isPrimExpr e
+ isPrimExpr' (HsWrap _ _ e) = isPrimExpr' e
+ isPrimExpr' (HsPar _ e) = isPrimExpr e
+ isPrimExpr' _ = Nothing
+
+
+-- | This type class provides default implementations for 'check' and
+-- 'checkBind' for Haskell syntax that is not supported. These default
+-- implementations simply print an error message.
+class NotSupported a where
+ notSupported :: GetCtxt => SDoc -> TcM a
+
+instance NotSupported Bool where
+ notSupported doc = printMessageCheck SevError ("Rattus does not support " <> doc)
+
+instance NotSupported (Bool,Set Var) where
+ notSupported doc = (,Set.empty) <$> notSupported doc
+
+
+-- | Add variables to the current context.
+addVars :: Set Var -> Ctxt -> Ctxt
+addVars vs c = c{current = vs `Set.union` current c }
+
+-- | Print a message with the current location.
+printMessage' :: GetCtxt => Severity -> SDoc -> TcM ()
+printMessage' sev doc = printMessage sev (srcLoc ?ctxt) doc
+
+-- | Print a message with the current location. Returns 'False', if
+-- the severity is 'SevError' and otherwise 'True.
+printMessageCheck :: GetCtxt => Severity -> SDoc -> TcM Bool
+printMessageCheck sev doc = printMessage' sev doc >>
+ case sev of
+ SevError -> return False
+ _ -> return True
diff --git a/src/Rattus/Plugin/StableSolver.hs b/src/Rattus/Plugin/StableSolver.hs
index 9390c27..51baa8f 100644
--- a/src/Rattus/Plugin/StableSolver.hs
+++ b/src/Rattus/Plugin/StableSolver.hs
@@ -1,6 +1,10 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE CPP #-}
+
+-- | This module implements a constraint solver plugin for the
+-- 'Stable' type class.
+
module Rattus.Plugin.StableSolver (tcStable) where
import Rattus.Plugin.Utils
@@ -24,7 +28,7 @@ import qualified Data.Set as Set
import TcRnTypes
-
+-- | Constraint solver plugin for the 'Stable' type class.
tcStable :: [CommandLineOption] -> Maybe TcPlugin
tcStable _ = Just $ TcPlugin
{ tcPluginInit = return ()
@@ -32,6 +36,7 @@ tcStable _ = Just $ TcPlugin
, tcPluginStop = \ () -> return ()
}
+
wrap :: Class -> Type -> EvTerm
wrap cls ty = EvExpr appDc
where
diff --git a/src/Rattus/Plugin/Utils.hs b/src/Rattus/Plugin/Utils.hs
index 68c38eb..ebf2658 100644
--- a/src/Rattus/Plugin/Utils.hs
+++ b/src/Rattus/Plugin/Utils.hs
@@ -21,6 +21,7 @@ import Data.Set (Set)
import qualified Data.Set as Set
import Data.Char
import Data.Maybe
+import MonadUtils
isType Type {} = True
isType (App e _) = isType e
@@ -28,36 +29,15 @@ isType (Cast e _) = isType e
isType (Tick _ e) = isType e
isType _ = False
-
-
--- printMessage :: Severity -> SDoc -> CoreM ()
--- printMessage sev doc =
--- case sev of
--- Error -> errorMsg doc
--- Warning -> warnMsg doc
-
-
--- printMessageV :: Severity -> Var -> SDoc -> CoreM ()
--- printMessageV sev var doc =
--- let loc = nameSrcLoc (varName var)
--- doc' = ppr loc <> text ": " <> doc
--- in case sev of
--- Error -> errorMsg doc'
--- Warning -> warnMsg doc'
-
-
-printMessage :: Severity -> SrcSpan -> SDoc -> CoreM ()
+printMessage :: (HasDynFlags m, MonadIO m) =>
+ Severity -> SrcSpan -> MsgDoc -> m ()
printMessage sev loc doc = do
dflags <- getDynFlags
- unqual <- getPrintUnqualified
let sty = case sev of
- SevError -> err_sty
- SevWarning -> err_sty
- SevDump -> dump_sty
- _ -> user_sty
- err_sty = mkErrStyle dflags unqual
- user_sty = mkUserStyle dflags unqual AllTheWay
- dump_sty = mkDumpStyle dflags unqual
+ SevError -> defaultErrStyle dflags
+ SevWarning -> defaultErrStyle dflags
+ SevDump -> defaultDumpStyle dflags
+ _ -> defaultUserStyle dflags
liftIO $ putLogMsg dflags NoReason sev loc sty doc
diff --git a/src/Rattus/Stream.hs b/src/Rattus/Stream.hs
index 0c57b28..642f494 100644
--- a/src/Rattus/Stream.hs
+++ b/src/Rattus/Stream.hs
@@ -1,6 +1,6 @@
{-# OPTIONS -fplugin=Rattus.Plugin #-}
{-# LANGUAGE TypeOperators #-}
-
+{-# LANGUAGE CPP #-}
-- | Programming with streams.
module Rattus.Stream
@@ -148,6 +148,13 @@ integral acc (t ::: ts) (a ::: as) = acc' ::: delay (integral acc' (adv ts) (adv
"map/scan" forall f p acc as.
map p (scan f acc as) = scanMap f p acc as ;
+ "zip/map" forall xs ys f.
+ map f (zip xs ys) = let f' = unbox f in zipWith (box (\ x y -> f' (x :* y))) xs ys
+#-}
+
+
+#if __GLASGOW_HASKELL__ >= 808
+{-# RULES
"scan/scan" forall f g b c as.
scan g c (scan f b as) =
let f' = unbox f; g' = unbox g in
@@ -158,6 +165,5 @@ integral acc (t ::: ts) (a ::: as) = acc' ::: delay (integral acc' (adv ts) (adv
let f' = unbox f; g' = unbox g; p' = unbox p in
scanMap (box (\ (b:*c) a -> let b' = f' (p' b) a in (b':* g' c b'))) (box snd') (b:*c) as ;
- "zip/map" forall xs ys f.
- map f (zip xs ys) = let f' = unbox f in zipWith (box (\ x y -> f' (x :* y))) xs ys
#-}
+#endif
diff --git a/test/IllTyped.hs b/test/IllTyped.hs
index d542883..3b80d37 100644
--- a/test/IllTyped.hs
+++ b/test/IllTyped.hs
@@ -1,8 +1,11 @@
+{-# LANGUAGE Arrows #-}
+{-# LANGUAGE RebindableSyntax #-}
module Main (module Main) where
import Rattus
import Rattus.Stream
+import Rattus.Yampa
import Prelude hiding ((<*>), map, const)
-- Uncomment the annotation below to check that the definitions in
@@ -11,20 +14,27 @@ import Prelude hiding ((<*>), map, const)
-- {-# ANN module Rattus #-}
--- This function will produce a confusing scoping error message since
--- GHC will inline the let-binding before Rattus' scope checker gets
--- to see it.
+sfLeak :: O Int -> SF () (O Int)
+sfLeak x = proc _ -> do
+ returnA -< x
+
advDelay :: O (O a) -> O a
advDelay y = delay (let x = adv y in adv x)
+advDelay' :: O a -> a
+advDelay' y = let x = adv y in x
+
+dblAdv :: O (O a) -> O a
+dblAdv y = delay (adv (adv y))
+
dblDelay :: O (O Int)
dblDelay = delay (delay 1)
lambdaUnderDelay :: O (O Int -> Int -> Int)
lambdaUnderDelay = delay (\x _ -> adv x)
-sneakyLambdaUnderDelay :: O (O Int -> Int -> Int)
-sneakyLambdaUnderDelay = delay (let f x _ = adv x in f)
+sneakyLambdaUnderDelay :: O (Int -> Int)
+sneakyLambdaUnderDelay = delay (let f _ = adv (delay 1) in f)
lambdaUnderDelay' :: O Int -> O (Int -> O Int)
diff --git a/test/TimeLeak.hs b/test/TimeLeak.hs
index 809c07a..2ab5301 100644
--- a/test/TimeLeak.hs
+++ b/test/TimeLeak.hs
@@ -10,7 +10,6 @@ import Prelude hiding ((<*>), map)
{-# ANN module Rattus #-}
-
nats' :: Str Int
nats' = unfold (box ((+) 1)) 0
diff --git a/test/WellTyped.hs b/test/WellTyped.hs
index 759feb5..36f6b6c 100644
--- a/test/WellTyped.hs
+++ b/test/WellTyped.hs
@@ -1,7 +1,12 @@
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE Arrows #-}
+{-# LANGUAGE RebindableSyntax #-}
+
module Main (module Main) where
import Rattus
import Rattus.Stream
+import Rattus.Yampa
import Prelude hiding ((<*>))
import Data.Set
import qualified Data.Set as Set
@@ -9,6 +14,22 @@ import qualified Data.Set as Set
{-# ANN module Rattus #-}
+ballPos :: SF (Int :* Int) Int
+ballPos = arr (\ (x :* y) -> x)
+
+padPos :: SF Int Int
+padPos = arr (\ x -> x)
+
+{-# ANN pong AllowLazyData #-}
+pong :: SF Int (Int :* Int)
+pong = proc inp -> do
+ pad <- padPos -< inp
+ ball <- ballPos -< (pad :* inp)
+ returnA -< (ball :* pad)
+
+
+boxedInt :: Box Int
+boxedInt = box 8
lambdaUnderDelay :: O (Int -> Int -> Int)
lambdaUnderDelay = delay (\x _ -> x)
@@ -29,7 +50,7 @@ scanBox f acc (a ::: as) = unbox acc' ::: delay (scanBox f (unbox acc') (adv as
sumBox :: Str Int -> Str Int
-sumBox = scanBox (box (\x y -> box (x + y))) 0
+sumBox = scanBox (box (\x y -> box' (x + y))) 0
map1 :: Box (a -> b) -> Str a -> Str b
map1 f (x ::: xs) = unbox f x ::: delay (map1 f (adv xs))
@@ -49,6 +70,12 @@ bar2 :: Box (a -> b) -> Str a -> Str b
bar2 f (x ::: xs) = unbox f x ::: (delay (bar1 f) <*> xs)
+-- mutual recursive definition
+foo1,foo2 :: Box (a -> b) -> Str a -> Str b
+(foo1,foo2) = (\ f (x ::: xs) -> unbox f x ::: (delay (foo2 f) <*> xs),
+ \ f (x ::: xs) -> unbox f x ::: (delay (foo1 f) <*> xs))
+
+
applyDelay :: O (O (a -> b)) -> O (O a) -> O (O b)
applyDelay f x = delay (adv f <*> adv x)