-- Source: Proof/Propositional/TH.hs (blob 0f7af1cb272ae78923b7b7ffc658accc7d69acc7)
{-# LANGUAGE CPP, ExplicitNamespaces, MultiWayIf, PatternGuards #-}
{-# LANGUAGE TemplateHaskell, TupleSections                     #-}
module Proof.Propositional.TH where
import Proof.Propositional.Empty
import Proof.Propositional.Inhabited

import           Control.Arrow               (Kleisli (..), second)
import           Control.Monad               (forM, zipWithM)
import           Data.Foldable               (asum)
import           Data.Map                    (Map)
import qualified Data.Map                    as M
import           Data.Maybe                  (fromJust)
import           Data.Semigroup              (Semigroup (..))
import           Data.Type.Equality          ((:~:) (..))
import           Language.Haskell.TH         (DecsQ, Lit (CharL, IntegerL),
                                              Name, Q, TypeQ, isInstance,
                                              newName, ppr)
import           Language.Haskell.TH.Desugar (DClause (..), DCon (..),
                                              DConFields (..), DCxt, DDec (..),
                                              DExp (..), DInfo (..),
                                              DLetDec (DFunD),
                                              DPat (DConPa, DVarPa), DPred (..),
                                              DTyVarBndr (..), DType (..),
                                              Overlap (Overlapping), desugar,
                                              dsReify, expandType, substTy,
                                              sweeten)

-- | Macro to automatically derive @'Empty'@ instance for
--   concrete (variable-free) types which may contain products.
--
--   The given type is desugared and expanded, then split into its head
--   type constructor and arguments.
--
--   * If the head is @(':~:')@, the two sides are compared with
--     'compareType'; an instance is generated only when they are judged
--     'NonEqual', otherwise the splice fails.
--   * Otherwise the head is reified, its constructors are specialised to
--     the given arguments, and one 'eliminate' clause is built per
--     constructor with 'buildRefuteClause'.  All constructors must yield a
--     clause, otherwise the splice fails.
--
--   Every failure is reported through 'fail' with a descriptive message
--   instead of an irrefutable-pattern or 'fromJust' crash.
refute :: TypeQ -> DecsQ
refute tps = do
  tp <- expandType =<< desugar =<< tps
  (tyName, args) <- case splitType tp of
    Just (_, n, as') -> return (n, as')
    Nothing          -> fail $ "refute: cannot determine the head type constructor of: "
                          ++ show (ppr $ sweeten tp)
  let mkInst dxt cls = return $ sweeten
                       [DInstanceD (Just Overlapping) dxt
                        (DAppT (DConT ''Empty) (foldl DAppT (DConT tyName) args))
                        [DLetDec $ DFunD 'eliminate cls]
                       ]
  if tyName == ''(:~:)
    then case args of
      [l, r] -> do
        v    <- newName "_v"
        dist <- compareType l r
        case dist of
          -- An empty case on the argument serves as the refutation.
          NonEqual    -> mkInst [] [DClause [] $ DLamE [v] (DCaseE (DVarE v) []) ]
          Equal       -> fail $ "Equal: " ++ show (ppr $ sweeten l) ++ " ~ " ++ show (ppr $ sweeten r)
          Undecidable -> fail $ "Not enough info to check non-equality: " ++
                           show (ppr $ sweeten l) ++ " ~ " ++ show (ppr $ sweeten r)
      _ -> fail "refute: (:~:) must be applied to exactly two types"
    else do
      info <- maybe (fail $ "refute: cannot reify type constructor: " ++ show tyName)
                    return
                =<< dsReify tyName
      (dxt, cons) <- resolveSubsts args info
      mcls <- mapM buildRefuteClause cons
      case sequence mcls of
        Just cls -> mkInst dxt cls
        Nothing  -> fail $ "refute: some constructor of " ++ show tyName
                      ++ " has no field with a known Empty instance"
-- | Macro to automatically derive @'Inhabited'@ instance for
--   concrete (variable-free) types which may contain sums.
--
--   The given type is desugared and expanded, then split into its head
--   type constructor and arguments.  Cases are tried in this order:
--
--   * the type has a 'Num' instance: the witness is the literal @0@;
--   * the head is 'Char': the witness is @'\\NUL'@;
--   * the head is @(':~:')@: the two sides are compared with
--     'compareType' and 'Refl' is used only when they are judged 'Equal';
--   * otherwise the head is reified, its constructors are specialised to
--     the given arguments, and the first constructor for which
--     'buildProveClause' succeeds provides the witness.
--
--   Every failure is reported through 'fail' with a descriptive message
--   instead of an irrefutable-pattern or 'fromJust' crash.
prove :: TypeQ -> DecsQ
prove tps = do
  tp <- expandType =<< desugar =<< tps
  (tyName, args) <- case splitType tp of
    Just (_, n, as') -> return (n, as')
    Nothing          -> fail $ "prove: cannot determine the head type constructor of: "
                          ++ show (ppr $ sweeten tp)
  let mkInst dxt cls = return $ sweeten
                       [DInstanceD (Just Overlapping) dxt
                        (DAppT (DConT ''Inhabited) (foldl DAppT (DConT tyName) args))
                        [DLetDec $ DFunD 'trivial cls]
                       ]
  isNum <- isInstance ''Num [sweeten tp]

  if | isNum -> mkInst [] [DClause [] $ DLitE $ IntegerL 0 ]
     | tyName == ''Char  -> mkInst [] [DClause [] $ DLitE $ CharL '\NUL']
     | tyName == ''(:~:) -> case args of
         [l, r] -> do
           dist <- compareType l r
           case dist of
             -- The sides are known to differ, so no 'Refl' can exist.
             NonEqual    -> fail $ "Non-equal: " ++ show (ppr $ sweeten l) ++ " ~ " ++ show (ppr $ sweeten r)
             Equal       -> mkInst [] [DClause [] $ DConE 'Refl ]
             Undecidable -> fail $ "Not enough info to check equality: " ++
                              show (ppr $ sweeten l) ++ " ~ " ++ show (ppr $ sweeten r)
         _ -> fail "prove: (:~:) must be applied to exactly two types"
     | otherwise -> do
        info <- maybe (fail $ "prove: cannot reify type constructor: " ++ show tyName)
                      return
                  =<< dsReify tyName
        (dxt, cons) <- resolveSubsts args info
        mcls <- asum <$> mapM buildProveClause cons
        case mcls of
          Just cls -> mkInst dxt [cls]
          Nothing  -> fail $ "prove: no constructor of " ++ show tyName
                        ++ " has all of its fields known to be Inhabited"

-- | Generic worker that turns one data constructor into a function clause.
--
--   For each field type of the constructor a placeholder is generated with
--   @genPlaceHolder@.  A field contributes @Just (buildFactor ty placeholder)@
--   when its type has an instance of the class @clsName@, and 'Nothing'
--   otherwise.  @flattenExps@ combines those per-field results into the
--   optional clause body, and @toPats@ produces the clause patterns from the
--   constructor name and the placeholders.
buildClause :: Name -> (DType -> Q b) -> (DType -> b -> DExp)
            -> (Name -> [Maybe DExp] -> Maybe DExp) -> (Name -> [b] -> [DPat])
            -> DCon -> Q (Maybe DClause)
buildClause clsName genPlaceHolder buildFactor flattenExps toPats (DCon _ _ cName flds _) = do
  holders <- mapM genPlaceHolder fieldTys
  factors <- zipWithM factorFor fieldTys holders
  return $ DClause (toPats cName holders) <$> flattenExps cName factors
  where
    fieldTys = fieldsVars flds

    -- A field yields a factor only if its type is an instance of clsName.
    factorFor ty holder = do
      hasInstance <- isInstance clsName [sweeten ty]
      return $ if hasInstance
        then Just (buildFactor ty holder)
        else Nothing

-- | Build an 'eliminate' clause for one constructor.
--
--   The clause matches the constructor with a fresh variable per field and
--   applies 'eliminate' to the first field whose type has an 'Empty'
--   instance.  The result is 'Nothing' when no field has such an instance.
buildRefuteClause :: DCon -> Q (Maybe DClause)
buildRefuteClause =
  buildClause ''Empty freshVar applyEliminate firstRefutation conPattern
  where
    freshVar _             = newName "_x"
    applyEliminate _ var   = DVarE 'eliminate `DAppE` DVarE var
    firstRefutation _      = asum
    conPattern con fldVars = [DConPa con (map DVarPa fldVars)]

-- | Build a 'trivial' clause for one constructor.
--
--   The clause takes no patterns and applies the constructor to 'trivial'
--   once per field.  The result is 'Nothing' unless every field type has an
--   'Inhabited' instance.
buildProveClause :: DCon -> Q (Maybe DClause)
buildProveClause =
  buildClause ''Inhabited noHolder useTrivial applyCon noPats
  where
    noHolder _          = return ()
    useTrivial _ _      = DVarE 'trivial
    applyCon con mArgs  = foldl DAppE (DConE con) <$> sequence mArgs
    noPats _ _          = []

-- | Extract the field types of a constructor, for both positional and
--   record constructors.
fieldsVars :: DConFields -> [DType]
fieldsVars conFields = case conFields of
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ < 804
  DNormalC positional   -> map snd positional
#else
  DNormalC _ positional -> map snd positional
#endif
  DRecC recFields       -> [ty | (_, _, ty) <- recFields]

-- | Specialise the constructors of a reified data type to concrete arguments.
--
--   The type variable binders of the declaration are paired positionally
--   with @args@ and the resulting substitution is applied to every
--   constructor.  Returns the declaration's context together with the
--   substituted constructors.
--
--   Only plain data declarations are handled; type families, data families
--   and data instances are not supported and make the computation fail, as
--   does any non-type-constructor info.
resolveSubsts :: [DType] -> DInfo -> Q (DCxt, [DCon])
resolveSubsts args (DTyConI (DDataD _ cxt _ tvbs
#if MIN_VERSION_th_desugar(1,9,0)
                                    _
#endif
                                    dcons _) _) = do
  let tvMap = M.fromList (zip (map dtvbToName tvbs) args)
  substituted <- mapM (substDCon tvMap) dcons
  return (cxt, substituted)
resolveSubsts _ (DTyConI _ _) = fail "Not supported data ty"
resolveSubsts _ _             = fail "Please pass data-type"

-- | A substitution from type variable names to the types replacing them,
--   as consumed by 'substDCon', 'substFields' and 'substPred'.
type SubstDic = Map Name DType

-- | Apply a substitution to a constructor: its fields are rewritten with
--   'substFields' and its last component with 'substTy'.  Binders, context
--   and constructor name are kept as they are.
substDCon :: SubstDic -> DCon -> Q DCon
substDCon dic (DCon binders ctx name flds res) = do
  flds' <- substFields dic flds
#if MIN_VERSION_th_desugar(1,9,0)
  res'  <- substTy dic res
#else
  res'  <- mapM (substTy dic) res
#endif
  return $ DCon binders ctx name flds' res'

-- | Apply a substitution to every field type of a constructor, leaving the
--   other per-field components (and record field names) untouched.
substFields :: SubstDic -> DConFields -> Q DConFields
#if defined(__GLASGOW_HASKELL__) && __GLASGOW_HASKELL__ < 804
substFields dic (DNormalC flds) =
  DNormalC <$> mapM (runKleisli (second (Kleisli (substTy dic)))) flds
#else
substFields dic (DNormalC isInfix flds) =
  DNormalC isInfix <$> mapM (runKleisli (second (Kleisli (substTy dic)))) flds
#endif
substFields dic (DRecC flds) =
  DRecC <$> forM flds substRecField
  where
    -- Only the third component (the field type) is substituted.
    substRecField (fldName, fldBang, fldTy) =
      (fldName, fldBang ,) <$> substTy dic fldTy

-- | The variable name bound by a type variable binder, ignoring any kind
--   annotation.
dtvbToName :: DTyVarBndr -> Name
dtvbToName bndr = case bndr of
  DPlainTV nm    -> nm
  DKindedTV nm _ -> nm

-- | Decompose a type into @(quantified variables, head constructor, arguments)@.
--
--   Outer @forall@s contribute their bound names, applications contribute
--   their arguments in left-to-right order, and kind signatures are skipped.
--   The function arrow is reported as the constructor @(->)@.  Returns
--   'Nothing' when the head is a type variable, a literal, a wildcard (or,
--   on older th-desugar, the star kind).
splitType :: DType -> Maybe ([Name], Name, [DType])
splitType ty = case ty of
  DForallT bndrs _ body -> prependVars (map dtvbToName bndrs) <$> splitType body
  DAppT fun arg         -> appendArg arg <$> splitType fun
  DSigT inner _         -> splitType inner
  DConT hd              -> Just ([], hd, [])
  DArrowT               -> Just ([], ''(->), [])
  DVarT _               -> Nothing
  DLitT _               -> Nothing
  DWildCardT            -> Nothing
#if !MIN_VERSION_th_desugar(1,9,0)
  DStarT                -> Nothing
#endif
  where
    prependVars new (vars, hd, tyArgs) = (new ++ vars, hd, tyArgs)
    appendArg arg (vars, hd, tyArgs)   = (vars, hd, tyArgs ++ [arg])


-- | Outcome of comparing two types: known to differ, not decidable from the
--   available information, or known to be equal.
data EqlJudge = NonEqual | Undecidable | Equal
              deriving (Read, Show, Eq, Ord)

-- | Combines the verdicts of sub-comparisons: 'NonEqual' absorbs everything,
--   'Equal' is neutral, and 'Undecidable' wins over 'Equal' only.
instance Semigroup EqlJudge where
  l <> r = case l of
    NonEqual    -> NonEqual
    Equal       -> r
    Undecidable -> case r of
      NonEqual -> NonEqual
      _        -> Undecidable

-- | The neutral verdict is 'Equal'.
instance Monoid EqlJudge where
  mempty  = Equal
  mappend = (<>)

-- | Compare two types after running 'expandType' on each of them (left one
--   first); the actual structural comparison is done by @compareType'@.
compareType :: DType -> DType -> Q EqlJudge
compareType t0 s0 =
  expandType t0 >>= \lhs ->
  expandType s0 >>= \rhs ->
  compareType' lhs rhs

-- | Structural comparison of two types.
--
--   * Kind signatures are compared component-wise when present on both
--     sides and ignored when present on one side only.
--   * Any type variable or wildcard on either side makes the result
--     'Undecidable', except for two occurrences of the same variable.
--   * Two @forall@ types with the same number of binders are compared
--     after renaming the right-hand binders to the left-hand ones; a
--     differing number of binders is 'NonEqual'.
--   * Applications are compared component-wise; constructors, the arrow,
--     literals (and, on older th-desugar, the star kind) are 'Equal'
--     exactly when they are identical.
--   * Every remaining combination is 'NonEqual'.
--
--   Clause order matters: the signature, variable and wildcard clauses must
--   stay ahead of the constructor-specific fall-through clauses.
compareType' :: DType -> DType -> Q EqlJudge
compareType' (DSigT t1 t2) (DSigT s1 s2)
  = (<>) <$> compareType' t1 s1 <*> compareType' t2 s2
compareType' (DSigT t _) s
  = compareType' t s
compareType' t (DSigT s _)
  = compareType' t s
compareType' (DVarT t) (DVarT s)
  | t == s    = return Equal
  | otherwise = return Undecidable
compareType' (DVarT _) _ = return Undecidable
compareType' _ (DVarT _) = return Undecidable
compareType' DWildCardT _ = return Undecidable
compareType' _ DWildCardT = return Undecidable
compareType' (DForallT tTvBs tCxt t) (DForallT sTvBs sCxt s)
  | length tTvBs == length sTvBs = do
      -- Rename the right-hand binders to the left-hand ones so that bound
      -- variables at the same position compare as the same variable.
      let dic = M.fromList $ zip (map dtvbToName sTvBs) (map (DVarT . dtvbToName) tTvBs)
      s' <- substTy dic s
      pd <- compareCxt tCxt =<< mapM (substPred dic) sCxt
      bd <- compareType' t s'
      return (pd <> bd)
  | otherwise = return NonEqual
compareType' DForallT{} _   = return NonEqual
compareType' (DAppT t1 t2) (DAppT s1 s2)
  = (<>) <$> compareType' t1 s1 <*> compareType' t2 s2
compareType' (DConT t) (DConT s)
  | t == s    = return Equal
  | otherwise = return NonEqual
compareType' (DConT _) _ = return NonEqual
compareType' DArrowT DArrowT = return Equal
compareType' DArrowT _ = return NonEqual
compareType' (DLitT t) (DLitT s)
  | t == s    = return Equal
  | otherwise = return NonEqual
compareType' (DLitT _) _ = return NonEqual
#if !MIN_VERSION_th_desugar(1,9,0)
-- The star kind is equal to itself; answering NonEqual here would make any
-- two types carrying a @*@ kind signature compare as different.
compareType' DStarT DStarT = return Equal
#endif
compareType' _ _ = return NonEqual

-- | Compare two contexts predicate by predicate, in order, and combine the
--   verdicts.
--
--   Contexts with a different number of predicates are 'NonEqual': without
--   this check 'zipWithM' would silently drop the surplus predicates and a
--   context could compare as 'Equal' to a proper prefix of itself.  Two
--   empty contexts are 'Equal'.
compareCxt :: DCxt -> DCxt -> Q EqlJudge
compareCxt l r
  | length l /= length r = return NonEqual
  | otherwise            = mconcat <$> zipWithM comparePred l r

-- | Structural comparison of two predicates.
--
--   Wildcards and predicate variables make the result 'Undecidable' (two
--   occurrences of the same variable are 'Equal'); kind signatures are
--   compared when present on both sides and skipped otherwise;
--   applications are compared head against head and argument against
--   argument; class constructors are 'Equal' exactly when identical.
--   Clause order is significant and must be kept.
comparePred :: DPred -> DPred -> Q EqlJudge
comparePred DWildCardPr _ = return Undecidable
comparePred _ DWildCardPr = return Undecidable
comparePred (DVarPr a) (DVarPr b) =
  return $ if a == b then Equal else Undecidable
comparePred (DVarPr _) _ = return Undecidable
comparePred _ (DVarPr _) = return Undecidable
comparePred (DSigPr p k) (DSigPr q j) = do
  kindVerdict <- compareType' k j
  predVerdict <- comparePred p q
  return (kindVerdict <> predVerdict)
comparePred (DSigPr p _) q = comparePred p q
comparePred p (DSigPr q _) = comparePred p q
comparePred (DAppPr f x) (DAppPr g y) = do
  x' <- expandType x
  y' <- expandType y
  headVerdict <- comparePred f g
  argVerdict  <- compareType' x' y'
  return (headVerdict <> argVerdict)
comparePred (DAppPr _ _) _ = return NonEqual
comparePred (DConPr c) (DConPr d) =
  return $ if c == d then Equal else NonEqual
comparePred (DConPr _) _ = return NonEqual
#if MIN_VERSION_th_desugar(1,9,0)
comparePred DForallPr{} DForallPr{} = return Undecidable
comparePred DForallPr{} _           = return NonEqual
#endif

-- | Apply a substitution to a predicate.
--
--   Applications and kind signatures are traversed; their type components
--   are substituted and then expanded.  A predicate variable is replaced
--   only when it maps to a type variable or a type constructor; any other
--   predicate is returned unchanged.
substPred :: SubstDic -> DPred -> Q DPred
substPred dic prd = case prd of
  DAppPr fun arg -> do
    fun' <- substPred dic fun
    arg' <- expandType =<< substTy dic arg
    return (DAppPr fun' arg')
  DSigPr inner knd -> do
    inner' <- substPred dic inner
    knd'   <- expandType =<< substTy dic knd
    return (DSigPr inner' knd')
  DVarPr var -> return $ case M.lookup var dic of
    Just (DVarT var') -> DVarPr var'
    Just (DConT con)  -> DConPr con
    _                 -> prd
  _ -> return prd