module Yhc.Core.Strictness(coreStrictness) where
import Data.List(foldl', intersect, nub, partition)
import qualified Data.Map as Map

import Yhc.Core.Prim
import Yhc.Core.Type
{-
ALGORITHM:

SCC PARTIAL SORT:
    First sort the functions so that they occur in childmost-first order:
    x1 < x2 if x1 doesn't transitively call x2, and x2 does transitively call x1.
    An imperfect order is fine, but a better order gives better results.

PRIM STRICTNESS:
    The strictness of the various primitive operations.

BASE STRICTNESS:
    If all paths case on a particular value, then the function is strict in that value.
    If the body calls onwards to another function, then strictness is based on the callee.
-}
-- | Compute argument strictness for the functions of a 'Core' program.
--
-- The result maps a function name to one 'Bool' per argument: 'True' means
-- the function is strict in that argument, 'False' means it is not.
-- A name with no recorded entry yields @[]@ (unknown strictness).
coreStrictness :: Core -> (CoreFuncName -> [Bool])
coreStrictness core = lookupStrict
    where
        -- Bound outside the lookup function so that a single
        -- @coreStrictness core@ shares one table across all of its queries.
        table = mapStrictness (sccSort (coreFuncs core))

        lookupStrict name = Map.findWithDefault [] name table
-- | Build the strictness table by folding over the functions in list order.
--
-- Each function is analysed against the table accumulated so far, so a call
-- to a function that has not been processed yet (including a recursive call)
-- finds no entry and contributes no strictness information.
--
-- The fold is strict ('foldl'') so the accumulated map is forced at every
-- step instead of building a chain of suspended inserts as long as the
-- function list.
mapStrictness :: [CoreFunc] -> Map.Map CoreFuncName [Bool]
mapStrictness funcs = foldl' f Map.empty funcs
    where
        -- Primitives take their strictness from 'primStrict'; names that
        -- 'corePrimMaybe' does not recognise are left out of the table.
        f mp (CorePrim{coreFuncName=name}) = case corePrimMaybe name of
            Nothing -> mp
            Just p -> Map.insert name (primStrict p) mp

        -- An argument is strict when its name occurs among the strict
        -- variables of the body.
        f mp (CoreFunc name args body) = Map.insert name (map (`elem` strict) args) mp
            where
                strict = strictVars body

                -- which variables are strict
                strictVars :: CoreExpr -> [String]
                strictVars (CorePos _ x) = strictVars x
                strictVars (CoreVar x) = [x]
                -- A case on a variable is strict in that variable and in
                -- whatever every alternative is strict in.
                strictVars (CoreCase (CoreVar x) alts) = nub $ x : intersectList (map (strictVars . snd) alts)
                -- Any other scrutinee: only the scrutinee itself is inspected.
                strictVars (CoreCase x _) = strictVars x
                -- A call whose argument count matches the recorded strictness
                -- list: collect from the arguments in strict positions. When
                -- the guard fails this falls through to the next equation.
                strictVars (CoreApp (CoreFun x) xs)
                    | length xs == length res
                    = nub $ concatMap strictVars [arg | (True, arg) <- zip res xs]
                    where res = Map.findWithDefault [] x mp
                strictVars (CoreApp x _) = strictVars x
                strictVars _ = []
-- | Intersection of a list of lists.
--
-- The result keeps the elements of the first list (in order, duplicates
-- included) that also occur in every other list. An empty outer list
-- gives @[]@.
intersectList :: Eq a => [[a]] -> [a]
intersectList [] = []
intersectList (firstList : others) = filter inAllOthers firstList
    where
        inAllOthers item = all (item `elem`) others
-- Approximate SCC ordering: no call-graph analysis is done here. All
-- primitives are moved ahead of all other functions, and the relative
-- order inside each group is kept.
sccSort :: [CoreFunc] -> [CoreFunc]
sccSort = uncurry (++) . partition isCorePrim