summaryrefslogtreecommitdiff
path: root/src/AI/Search/FiniteDomain/Int/Constraint.hs
blob: 50d3b47a3b16e4c22b35c4d45017f26f21c91c90 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module AI.Search.FiniteDomain.Int.Constraint
  ( (#=)
  , (#/=)
  , (#<)
  , (#<=)
  , (#>)
  , (#>=)
  , (/\)
  , (\/)
  , Constraint
  , FD
  , Labeling(..)
  , allDifferent
  , between
  , initNewVar
  , labeling
  , newVar
  , not'
  , runFD
  ) where

-- base
import Control.Monad    ( forM, forM_ )
import Control.Monad.ST ( ST, runST )
import Data.List        ( find )

-- domain
import qualified Numeric.Domain as D

-- propeller
import Data.Propagator.Cell as P ( Cell, cell, connect, label, propagateMany
                                 , readCell, succeeded, sync, syncWith )

-- transformers
import Control.Monad.Trans.State ( State, evalState, execState, get, modify, put )

import AI.Search.FiniteDomain.Int.Cell       ( domainJoin, eqJoin, mustHoldJoin )
import AI.Search.FiniteDomain.Int.Expression ( Expression, cellifyExpression, var )

-- | The monad in which all constraint-solving actions are performed. It keeps
-- track of the constraints declared so far together with a counter that is
-- used to hand out fresh variable identifiers.
newtype FD a = FD { unFD :: State ([IntConstraint], Int) a }
  deriving (Functor, Applicative, Monad)

-- | A constraint restricts the possible values of the 'Expression's it
-- involves. Constraints are ordinary 'FD' actions without a result, so they
-- can be sequenced like any other monadic computation.
type Constraint = FD ()

-- | Records a single constraint in the FD state. The new constraint is
-- prepended to the list; the variable counter is left untouched.
addConstraint :: IntConstraint -> FD ()
addConstraint cons = FD . modify $ \(cs, ix) -> (cons : cs, ix)

-- | Runs an FD computation, starting with no constraints and a variable
-- counter of zero, and returns the computation's result.
runFD :: FD a -> a
runFD fd = evalState (unFD fd) ([], 0)

-- | Internal representation of a declared constraint. Values are built by the
-- combinators of this module and later converted into propagator cells.
data IntConstraint
  = Equals    Expression Expression
  -- ^ Both expressions have the same value.
  | NotEquals Expression Expression
  -- ^ The expressions have different values.
  | LessThan  Expression Expression
  -- ^ The first expression is strictly less than the second.
  | And       IntConstraint IntConstraint
  -- ^ Both constraints hold.
  | Or        IntConstraint IntConstraint
  -- ^ At least one of the constraints holds.
  deriving (Eq, Ord, Show)

-- | Enforces that both expressions evaluate to the same value.
infix 4 #=
(#=) :: Expression -> Expression -> Constraint
left #= right = addConstraint (Equals left right)

-- | Enforces that the two expressions evaluate to different values.
infix 4 #/=
(#/=) :: Expression -> Expression -> Constraint
left #/= right = addConstraint (NotEquals left right)

-- | Enforces that the left expression is strictly less than the right one.
infix 4 #<
(#<) :: Expression -> Expression -> Constraint
left #< right = addConstraint (LessThan left right)

-- | Enforces that the left expression is less than or equal to the right one.
-- Encoded as the disjunction of "less than" and "equal".
infix 4 #<=
(#<=) :: Expression -> Expression -> Constraint
left #<= right = addConstraint (Or (LessThan left right) (Equals left right))

-- | Enforces that the left expression is strictly greater than the right one,
-- expressed as the right expression being less than the left.
infix 4 #>
(#>) :: Expression -> Expression -> Constraint
left #> right = right #< left

-- | Enforces that the left expression is greater than or equal to the right
-- one, expressed as the right expression being less than or equal to the left.
infix 4 #>=
(#>=) :: Expression -> Expression -> Constraint
left #>= right = right #<= left

-- | Conjunction of constraints: both must hold. Since constraints accumulate
-- in the FD state, this is simply sequencing.
infixl 3 /\
(/\) :: Constraint -> Constraint -> Constraint
left /\ right = left >> right

-- | Disjunction of constraints, i.e. at least one of the two constraints must
-- hold.
--
-- Each side is run in isolation (starting from an empty constraint list) so
-- that its constraints can be conjoined and combined into a single 'Or'. The
-- variable counter is threaded through both sides, so variables created in
-- either branch keep unique identifiers.
--
-- A side that declares no constraints is trivially satisfied, which makes the
-- whole disjunction trivially satisfied; in that case nothing is added.
infixl 2 \/
(\/) :: Constraint -> Constraint -> Constraint
(\/) left right = FD $ do
  (cs, ix) <- get
  let (lcs, lx) = execState (unFD left) ([], ix)
      (rcs, nx) = execState (unFD right) ([], lx)
  case (lcs, rcs) of
    (l:ls, r:rs) ->
      let leftAnd    = foldl And l ls
          rightAnd   = foldl And r rs
          constraint = leftAnd `Or` rightAnd
      in put (constraint : cs, nx)
    -- At least one side is unconstrained (always true): the disjunction
    -- always holds, so the other side must NOT be enforced on its own.
    _ -> put (cs, nx)

-- | Negates a constraint, i.e. the specified constraint must not hold.
--
-- The inner constraint is run in isolation. All constraints it declares are
-- conjoined, and the negation is pushed inwards (De Morgan) so that the result
-- only uses the basic relations and 'And' / 'Or'. The variable counter is
-- threaded through, so variables created inside stay unique.
--
-- NOTE(review): if the inner constraint declares nothing (trivially true),
-- nothing is added, although its negation would be unsatisfiable.
not' :: Constraint -> Constraint
not' constraint = FD $ do
  (existing, counter) <- get
  let (inner, counter') = execState (unFD constraint) ([], counter)
  case inner of
    []       -> put (existing, counter')
    (i : is) -> put (invert (foldl And i is) : existing, counter')
  where
    invert ic = case ic of
      Equals    l r -> NotEquals l r
      NotEquals l r -> Equals l r
      -- not (l < r)  <=>  l >= r  <=>  r < l  or  r == l
      LessThan  l r -> Or (LessThan r l) (Equals r l)
      And       l r -> Or (invert l) (invert r)
      Or        l r -> And (invert l) (invert r)

-- | Enforces that all the given expressions have pairwise different values.
-- Each expression is constrained to differ from every expression after it.
allDifferent :: [Expression] -> Constraint
allDifferent exprs = case exprs of
  []       -> pure ()
  e : rest -> forM_ rest (e #/=) >> allDifferent rest

-- | Enforces that the value of an expression lies within two inclusive bounds.
between :: Expression -- ^ The inclusive lower bound of the expression.
        -> Expression -- ^ The inclusive upper bound of the expression.
        -> Expression -- ^ The expression to be constrained.
        -> Constraint -- ^ The resulting constraint.
between low high target = (target #>= low) >> (target #<= high)

-- | Creates a new variable with a default value range from negative infinity
-- to positive infinity. The variable's identifier is taken from the FD
-- counter, which is then incremented. Values are assigned during 'labeling'.
newVar :: FD Expression
newVar = FD $ do
  (constraints, fresh) <- get
  var fresh <$ put (constraints, fresh + 1)

-- | Creates a new variable that is constrained to equal the given expression,
-- and returns the variable.
initNewVar :: Expression -> FD Expression
initNewVar initExpr = do
  fresh <- newVar
  fresh <$ (fresh #= initExpr)

-- | The outcome of 'labeling', i.e. the solver's attempt to assign values to
-- the requested variables so that every constraint holds. Results follow the
-- order of the expressions passed to 'labeling', so functions such as 'zip'
-- can relate expressions to their values.
data Labeling a
  = Unsolvable [D.Domain Int]
  -- ^ The constraints cannot be satisfied: no combination of values for the
  --   labelled variables fulfils all of them. The list holds the domains of
  --   the labelled expressions as narrowed down during the search.
  | Unbounded  [D.Domain Int]
  -- ^ The constraints cannot be solved in their current form because at
  --   least one variable has a lower bound of negative infinity or an upper
  --   bound of positive infinity, so candidate solutions cannot be
  --   enumerated. The list holds the domains of the labelled expressions as
  --   narrowed down during the search.
  | Solutions  [a]
  -- ^ Values were successfully assigned to the labelled variables; the list
  --   contains all solutions.
  deriving (Eq, Show)

-- Maps over every solution; failures are passed through unchanged.
instance Functor Labeling where
  fmap f l = case l of
    Solutions  xs -> Solutions (map f xs)
    Unsolvable ds -> Unsolvable ds
    Unbounded  ds -> Unbounded ds

-- Failures short-circuit: if either operand is 'Unsolvable' or 'Unbounded',
-- that failure is the result, with the left operand checked first. Two
-- 'Solutions' are combined like the list Applicative (each function applied
-- to each value).
instance Applicative Labeling where
  pure x = Solutions [x]
  lf <*> lx =
    case (lf, lx) of
      (Unsolvable ds, _)           -> Unsolvable ds
      (Unbounded ds, _)            -> Unbounded ds
      (_, Unsolvable ds)           -> Unsolvable ds
      (_, Unbounded ds)            -> Unbounded ds
      (Solutions fs, Solutions xs) -> Solutions (fs <*> xs)

-- Binding over 'Solutions' runs the continuation on every solution in order
-- and concatenates all resulting solutions. The first 'Unsolvable' or
-- 'Unbounded' result (in solution order) becomes the result of the bind.
instance Monad Labeling where
  Unsolvable ds >>= _ = Unsolvable ds
  Unbounded  ds >>= _ = Unbounded ds
  Solutions  xs >>= f = go [] xs
    where
      -- Result chunks are collected in reverse and concatenated once at the
      -- end; appending each chunk to a growing accumulator would be O(n^2).
      go chunks []     = Solutions (concat (reverse chunks))
      go chunks (y:ys) =
        case f y of
          Unsolvable ds -> Unsolvable ds
          Unbounded  ds -> Unbounded ds
          Solutions  bs -> go (bs : chunks) ys

-- | Searches all combinations of values for the given expressions (most
-- likely variables) that fulfil every constraint declared so far in the FD
-- monad.
--
-- Constraint propagation runs first:
--
-- * if it fails, the result is 'Unsolvable';
-- * if any requested expression is still infinite, the result is 'Unbounded';
-- * otherwise the domains are enumerated, and an empty enumeration is also
--   reported as 'Unsolvable'.
--
-- Results follow the order of the given expressions, so 'zip' can relate the
-- expressions to their solution values. Expressions that occur in no
-- constraint are all backed by one shared, unrestricted domain cell.
labeling :: [Expression] -> FD (Labeling [Int])
labeling vars = do
  constraints <- fst <$> FD get
  pure $ runST $ do
    (holds, declared, created) <- cellifyConstraints constraints []
    mustHold      <- cell True mustHoldJoin
    unconstrained <- cell D.maxDomain domainJoin
    -- Every constraint's truth cell feeds the single must-hold cell.
    mapM_ (\h -> connect h mustHold pure) holds
    let targets =
          [ maybe unconstrained snd (find ((== v) . fst) declared)
          | v <- vars
          ]
    outcome <- propagateMany created
    domains <- mapM P.readCell targets
    if not (succeeded outcome)
      then pure (Unsolvable domains)
      else if any D.isInfinite domains
        then pure (Unbounded domains)
        else do
          found <- label (concat . D.elems) D.singleton targets
          pure $ if null found then Unsolvable domains else Solutions found

-- | A propagator cell holding the current domain of an integer expression.
type DomainCell s = Cell s (D.Domain Int)
-- | A propagator cell holding the truth value of a constraint.
type LogicCell s  = Cell s Bool
-- | An expression (typically a variable) paired with its domain cell.
type VarCell s    = (Expression, DomainCell s)

-- | Converts a list of constraints to propagator cells, threading the list of
-- known variable cells from each constraint into the next.
-- The result consists of one truth cell per constraint (in input order), the
-- final list of declared variables, and all cells created for the
-- constraints.
cellifyConstraints
  :: [IntConstraint]
  -> [VarCell s]
  -> ST s ([LogicCell s], [VarCell s], [DomainCell s])
cellifyConstraints [] known = pure ([], known, [])
cellifyConstraints (c : rest) known = do
  (hold, known', created)   <- cellifyConstraint c known
  (holds, final, created')  <- cellifyConstraints rest known'
  pure (hold : holds, final, created ++ created')

-- | Converts a constraint to a propagator cell.
-- The result consists of the new cell that represents the constraint, a list
-- of currently declared variables, and a list of all cells that were created
-- for this constraint.
--
-- * 'Equals', 'NotEquals' and 'LessThan' cellify both expressions and relate
--   their domain cells (see @binary@ below).
-- * 'And' cellifies both sides, threading the variable list from the left
--   side into the right side, and combines their truth cells with '&&'.
-- * 'Or' cellifies each side against an /empty/ variable list, combines the
--   truth cells with '||', and then links the branch-local variable cells to
--   the shared variable cells.
cellifyConstraint
  :: IntConstraint
  -> [VarCell s]
  -> ST s (LogicCell s, [VarCell s], [DomainCell s])
cellifyConstraint constraint vars =
  case constraint of
    Equals left right ->
      binary left right sync
    NotEquals left right ->
      binary left right (syncWith D.notEqual D.notEqual)
    -- NOTE(review): each function narrows one side from the other side's
    -- domain; which direction each serves is defined by 'syncWith' — confirm.
    LessThan left right ->
      binary left right (syncWith D.greaterThanDomain D.lessThanDomain)
    And left right -> do
      (ls, nvs, xs) <- cellifyConstraint left vars
      (rs, rvs, ys) <- cellifyConstraint right nvs
      newCell <- cell True eqJoin
      -- The conjunction cell is recomputed from whichever side changes.
      connect ls newCell (\ld -> (ld &&) <$> P.readCell rs)
      connect rs newCell (\rd -> (&& rd) <$> P.readCell ls)
      pure (newCell, rvs, xs ++ ys)
    Or left right -> do
      -- Each branch starts with no known variables, so it gets its own
      -- private variable cells; they are tied to the shared variable cells
      -- only through the links created below.
      (ls, lvs, xs) <- cellifyConstraint left []
      (rs, rvs, ys) <- cellifyConstraint right []
      newCell <- cell True eqJoin
      connect ls newCell (\ld -> (ld ||) <$> P.readCell rs)
      connect rs newCell (\rd -> (|| rd) <$> P.readCell ls)
      -- pairs: variables mentioned in both branches; solos: in only one.
      let (pairs, solos) = split lvs rvs
      pairNews <-
        forM pairs $ \(v,lc,rc) -> do
          -- Reuse the already-declared cell for v, or create (and report) a
          -- new one with the maximal domain.
          (varCell, new) <-
            case find ((v ==) . fst) vars of
              Just av -> pure (snd av, [])
              Nothing -> do
                varCell <- cell D.maxDomain domainJoin
                pure (varCell, [(v, varCell)])
          -- Push the shared domain down into both branch-local cells ...
          connect varCell lc pure
          connect varCell rc pure
          -- ... and feed back the union of both branch domains, since the
          -- variable may take any value allowed by either branch.
          connect lc varCell (\ld -> D.union ld     <$> P.readCell rc)
          connect rc varCell (\rd -> (`D.union` rd) <$> P.readCell lc)
          pure new
      soloNews <-
        forM solos $ \(v,sc) -> do
          (varCell, new) <-
            case find ((v ==) . fst) vars of
              Just av -> pure (snd av, [])
              Nothing -> do
                varCell <- cell D.maxDomain domainJoin
                pure (varCell, [(v, varCell)])
          -- Only a downward link: narrowing from a single branch is not fed
          -- back to the shared cell (the other branch might be the one that
          -- holds).
          connect varCell sc pure
          pure new
      -- NOTE(review): variable cells created above are added to the returned
      -- variable list but not to the created-cells list; confirm that
      -- 'propagateMany' in 'labeling' does not need them.
      pure (newCell, concat pairNews ++ concat soloNews ++ vars, xs ++ ys)
  where
    -- Shared wiring for the binary relations: cellify both expressions
    -- (threading the variable list), relate their domain cells with @wire@,
    -- and derive the constraint's truth cell from whether either domain has
    -- become empty.
    binary left right wire = do
      (ls, lcs, xs) <- cellifyExpression left vars
      (rs, rcs, ys) <- cellifyExpression right lcs
      newCell <- cell True eqJoin
      _ <- wire ls rs
      connect ls newCell (pure . not . D.null)
      connect rs newCell (pure . not . D.null)
      pure (newCell, rcs, xs ++ ys)

-- | Looks up the first entry for the given expression in a cell list. If one
-- exists, returns it together with the list minus that entry (the remaining
-- entries keep their order); otherwise returns 'Nothing'.
extract :: Expression -> [VarCell s] -> Maybe (VarCell s, [VarCell s])
extract a cells =
  case break ((a ==) . fst) cells of
    (before, hit : after) -> Just (hit, before ++ after)
    (_, [])               -> Nothing

-- | Takes two lists of cells and pairs up entries that stand for the same
-- expression. The first component of the result holds one triple per shared
-- expression (the expression, its cell from the first list, and its cell from
-- the second list), in the order of the first list. The second component
-- holds all "solo" entries without a counterpart: the unmatched entries of
-- the first list (in order) followed by the unmatched entries of the second.
split :: [VarCell s]
      -> [VarCell s]
      -> ([(Expression, DomainCell s, DomainCell s)], [VarCell s])
split lefts rights = go lefts rights
  where
    go [] pool = ([], pool)
    go (entry : rest) pool =
      case extract key pool of
        Just (match, pool') ->
          let (paired, solos) = go rest pool'
          in ((key, snd entry, snd match) : paired, solos)
        Nothing ->
          let (paired, solos) = go rest pool
          in (paired, entry : solos)
      where
        key = fst entry