summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorlyxia <>2017-03-05 20:25:00 (GMT)
committerhdiff <hdiff@hdiff.luite.com>2017-03-05 20:25:00 (GMT)
commita0a76156ef45e3428161d250727512c86c22ac42 (patch)
tree127e23df4aaf9cf3f5a73e3511460010a52d1e9c
version 0.1.0.00.1.0.0
-rw-r--r--LICENSE20
-rw-r--r--README.md54
-rw-r--r--Setup.hs2
-rw-r--r--bench/binaryTree.hs100
-rw-r--r--boltzmann-samplers.cabal90
-rw-r--r--src/Boltzmann/Data.hs313
-rw-r--r--src/Boltzmann/Data/Common.hs39
-rw-r--r--src/Boltzmann/Data/Data.hs152
-rw-r--r--src/Boltzmann/Data/Oracle.hs506
-rw-r--r--src/Boltzmann/Data/Types.hs197
-rw-r--r--src/Boltzmann/Solver.hs69
-rw-r--r--src/Boltzmann/Species.hs220
-rw-r--r--test/Test/Stats.hs76
-rw-r--r--test/tree.hs85
14 files changed, 1923 insertions, 0 deletions
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..eebc804
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,20 @@
+Copyright 2017 Li-yao Xia
+Copyright 2017 Li-yao Xia
+
+Permission is hereby granted, free of charge, to any person obtaining a copy of
+this software and associated documentation files (the "Software"), to deal in
+the Software without restriction, including without limitation the rights to
+use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
+of the Software, and to permit persons to whom the Software is furnished to do
+so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE. \ No newline at end of file
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..657b0df
--- /dev/null
+++ b/README.md
@@ -0,0 +1,54 @@
+Boltzmann samplers [![Hackage](https://img.shields.io/hackage/v/boltzmann-samplers.svg)](https://hackage.haskell.org/package/generic-random) [![Build Status](https://travis-ci.org/Lysxia/generic-random.svg)](https://travis-ci.org/Lysxia/boltzmann-samplers)
+==================
+
+`Boltzmann.Data`
+----------------
+
+Define sized random generators for `Data.Data` generic types.
+
+```haskell
+ {-# LANGUAGE DeriveDataTypeable #-}
+
+ import Data.Data
+ import Test.QuickCheck
+ import Boltzmann.Data
+
+ data Term = Lambda Int Term | App Term Term | Var Int
+ deriving (Show, Data)
+
+ instance Arbitrary Term where
+ arbitrary = sized $ generatorPWith [positiveInts]
+
+ positiveInts :: Alias Gen
+ positiveInts =
+ alias $ \() -> fmap getPositive arbitrary :: Gen Int
+
+ main = sample (arbitrary :: Gen Term)
+```
+
+- Objects of the same size (number of constructors) occur with the same
+ probability (see Duchon et al., references below).
+- Implements rejection sampling and pointing.
+- Works with QuickCheck and MonadRandom, but also similar user-defined monads
+ for randomness (just implement `MonadRandomLike`).
+- Can be tweaked somewhat with user defined generators.
+
+`Boltzmann.Species`
+-------------------
+
+An experimental interface to obtain Boltzmann samplers from an applicative
+specification of a combinatorial system.
+
+No documentation (yet).
+
+References
+----------
+
+- The core theory of Boltzmann samplers is described in
+ [Boltzmann Samplers for the Random Generation of Combinatorial Structures](http://algo.inria.fr/flajolet/Publications/DuFlLoSc04.pdf),
+ P. Duchon, P. Flajolet, G. Louchard, G. Schaeffer.
+
+- The numerical evaluation of recursively defined generating functions
+ is taken from
+ [Boltzmann Oracle for Combinatorial Systems](http://www.dmtcs.org/pdfpapers/dmAI0132.pdf),
+ C. Pivoteau, B. Salvy, M. Soria.
diff --git a/Setup.hs b/Setup.hs
new file mode 100644
index 0000000..9a994af
--- /dev/null
+++ b/Setup.hs
@@ -0,0 +1,2 @@
+import Distribution.Simple
+main = defaultMain
diff --git a/bench/binaryTree.hs b/bench/binaryTree.hs
new file mode 100644
index 0000000..d09d4b0
--- /dev/null
+++ b/bench/binaryTree.hs
@@ -0,0 +1,100 @@
+{-# LANGUAGE DeriveDataTypeable #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE TemplateHaskell #-}
+
+module Main where
+
+import Control.Applicative
+import Control.Monad
+import Control.Monad.Trans.Class
+import Data.Bool
+import Data.Data
+import Data.Functor
+import GHC.Generics
+import Control.DeepSeq
+import Criterion.Main
+import Test.Feat
+import Test.QuickCheck
+import Test.QuickCheck.Gen
+import Test.QuickCheck.Random
+import Control.Exception ( evaluate )
+
+import Boltzmann.Data
+import Boltzmann.Data.Data
+import Boltzmann.Data.Types
+
+data T = N T T | L
+ deriving (Eq, Ord, Show, Data, Typeable, Generic)
+
+instance NFData T
+
+deriveEnumerable ''T
+
+size :: Num a => T -> a
+size L = 1
+size (N l r) = 1 + size l + size r
+
+gen1 :: Int -> Gen T
+gen1 n = runRejectT (tolerance epsilon (n + 1)) gen'
+ where
+ gen' = incr >> lift arbitrary >>= bool (return L) (liftA2 N gen' gen')
+
+gen2 :: Int -> Gen T
+gen2 n = g
+ where
+ (minSize, maxSize) = tolerance epsilon (n + 1)
+ g = gen' 0 (\m t -> if m < minSize then g else return t)
+ gen' n k | n >= maxSize = g
+ gen' n k =
+ arbitrary >>= bool
+ (k (n+1) L)
+ (gen' (n+1) $ \m l -> gen' m $ \m r -> k m (N l r))
+
+genFeat :: Int -> Gen T
+genFeat = uniform
+
+main = newQCGen >>= \g -> defaultMain $ liftA2 (\n f -> f n g)
+ [4 ^ e | e <- [1 .. 6]]
+
+ -- Singular rejection sampling
+ [ bg "handwritten1" gen1
+ , bg "handwritten2" gen2
+
+ , bg "feat" genFeat
+
+ -- Pointed generator
+ , bg "P" generatorP'
+
+ -- Pointed generator with rejection sampling
+ , bg "PR" generatorPR'
+
+ , bg "SR" generatorSR
+
+ -- Sized rejection sampling
+ , bg "R" generatorR'
+
+ -- Sized rejection sampling, not memoizing oracle
+ , bg' "R-recomp" generatorR'
+
+ -- Pointed generator, not memoizing oracle
+ , bg' "P-recomp" generatorP'
+ ]
+
+bg, bg' :: String -> (Int -> Gen T) -> Int -> QCGen -> Benchmark
+bg name gen n g =
+ bench (name ++ "_" ++ show n) $ nf f g
+ where
+ go 0 = return (0 :: Int)
+ go k = liftA2 (\t s -> size t + s) gg (go (k-1))
+ gg = gen n
+ f g = unGen (go 100) g 0
+
+bg' name gen n g =
+ bench (name ++ "_" ++ show n) $ nf f (n, g)
+ where
+ go n 0 = return (0 :: Int)
+ go n k = liftA2 (\t s -> size t + s) (gen n) (go n (k-1))
+ f (n, g) = unGen (go n 100) g 0
+
+avgSize :: [T] -> Double
+avgSize ts = sum (fmap size ts) / fromIntegral (length ts)
diff --git a/boltzmann-samplers.cabal b/boltzmann-samplers.cabal
new file mode 100644
index 0000000..8e3a5f7
--- /dev/null
+++ b/boltzmann-samplers.cabal
@@ -0,0 +1,90 @@
+name: boltzmann-samplers
+version: 0.1.0.0
+synopsis: Uniform random generators
+description:
+
+ Random generators with a uniform distribution conditioned
+ to a given size.
+
+ See also @<http://hackage.haskell.org/package/testing-feat testing-feat>@,
+ which is currently a faster method with similar results.
+
+homepage: https://github.com/Lysxia/boltzmann-samplers#readme
+license: MIT
+license-file: LICENSE
+author: Li-yao Xia
+maintainer: lysxia@gmail.com
+category: Data, Generic, Random
+build-type: Simple
+extra-source-files: README.md
+cabal-version: >=1.10
+
+flag test
+ Description:
+ Enable testing. Disabled by default because the current test suite
+ is slow and can fail with non-zero probability.
+ Manual: True
+ Default: False
+
+library
+ hs-source-dirs: src
+ exposed-modules:
+ Boltzmann.Data
+ Boltzmann.Data.Data
+ Boltzmann.Data.Common
+ Boltzmann.Data.Oracle
+ Boltzmann.Data.Types
+ Boltzmann.Solver
+ Boltzmann.Species
+ build-depends:
+ ad,
+ base >= 4.9 && < 5,
+ containers,
+ hashable,
+ hmatrix,
+ ieee754,
+ unordered-containers,
+ MonadRandom,
+ mtl,
+ QuickCheck,
+ transformers,
+ vector
+ default-language: Haskell2010
+
+source-repository head
+ type: git
+ location: https://github.com/Lysxia/boltzmann-samplers
+
+test-suite test-tree
+ type: exitcode-stdio-1.0
+ hs-source-dirs: test
+ main-is: tree.hs
+ default-language: Haskell2010
+ other-modules:
+ Test.Stats
+ if flag(test)
+ build-depends:
+ base,
+ QuickCheck,
+ optparse-generic,
+ boltzmann-samplers
+ else
+ buildable: False
+
+benchmark bench-binarytree
+ type: exitcode-stdio-1.0
+ hs-source-dirs: bench
+ main-is: binaryTree.hs
+ default-language: Haskell2010
+ ghc-options: -O2
+ if flag(test)
+ build-depends:
+ base,
+ criterion,
+ deepseq,
+ QuickCheck,
+ transformers,
+ testing-feat,
+ boltzmann-samplers
+ else
+ buildable: False
diff --git a/src/Boltzmann/Data.hs b/src/Boltzmann/Data.hs
new file mode 100644
index 0000000..bf5ac23
--- /dev/null
+++ b/src/Boltzmann/Data.hs
@@ -0,0 +1,313 @@
+-- | Generic Boltzmann samplers.
+--
+-- Here, the words "/sampler/" and "/generator/" are used interchangeably.
+--
+-- Given an algebraic datatype:
+--
+-- > data A = A1 B C | A2 D
+--
+-- a Boltzmann sampler is recursively defined by choosing a constructor with
+-- some fixed distribution, and /independently/ generating values for the
+-- corresponding fields with the same method.
+--
+-- A key component is the aforementioned distribution, defined for every type
+-- such that the resulting generator produces a finite value in the end. These
+-- distributions are obtained from a precomputed object called /oracle/, which
+-- we will not describe further here.
+--
+-- Oracles depend on the target size of the generated data (except for singular
+-- samplers), and can be fairly expensive to compute repeatedly, hence some of
+-- the functions below attempt to avoid (re)computing too many of them even
+-- when the required size changes.
+--
+-- When these functions are specialized, oracles are memoized and will be
+-- reused for different sizes.
+
+module Boltzmann.Data (
+ Size',
+ -- * Main functions
+ -- $sized
+ generatorSR,
+ generatorP,
+ generatorPR,
+ generatorR,
+ -- ** Fixed size
+ -- $fixed
+ generatorP',
+ generatorPR',
+ generatorR',
+ generator',
+ -- * Generators with aliases
+ -- $aliases
+ generatorSRWith,
+ generatorPWith,
+ generatorPRWith,
+ generatorRWith,
+ -- ** Fixed size
+ generatorPWith',
+ generatorPRWith',
+ generatorRWith',
+ generatorWith',
+ -- * Other generators
+ -- $other
+ Points,
+ generatorM,
+ generatorMR,
+ generator_,
+ generatorR_,
+ -- * Auxiliary definitions
+ -- ** Type classes
+ MonadRandomLike (..),
+ AMonadRandom (..),
+ -- ** Alias
+ alias,
+ aliasR,
+ coerceAlias,
+ coerceAliases,
+ Alias (..),
+ AliasR,
+ ) where
+
+import Data.Data
+import Boltzmann.Data.Data
+import Boltzmann.Data.Types
+
+-- * Main functions
+
+-- $sized
+--
+-- === Suffixes
+--
+-- [@S@] Singular sampler.
+--
+-- This works with recursive tree-like structures, as opposed to (lists of)
+-- structures with bounded size. More precisely, the generating function of
+-- the given type should have a finite radius of convergence, with a
+-- singularity of a certain kind (see Duchon et al., reference in the
+-- README), so that the oracle can be evaluated at that point.
+--
+-- This has the advantage of using the same oracle for all size parameters,
+-- which simply specify a target size interval.
+--
+-- [@P@] Generator of pointed values.
+--
+-- It usually has a flatter distribution of sizes than a simple Boltzmann
+-- sampler, making it an efficient alternative to rejection sampling.
+--
+-- It also works on more types, particularly lists and finite types,
+-- but relies on multiple oracles.
+--
+-- [@R@] Rejection sampling.
+--
+-- These generators filter out values whose sizes are not within some
+-- interval. In the first two sections, that interval is implicit:
+-- @[(1-'epsilon')*size', (1+'epsilon')*size']@, for @'epsilon' = 0.1@.
+--
+-- The generator restarts as soon as it has produced more constructors than
+-- the upper bound, this strategy is called /ceiled rejection sampling/.
+--
+-- = Pointing
+--
+-- The /pointing/ of a type @t@ is a derived type whose values are essentially
+-- values of type @t@, with one of their constructors being "pointed".
+-- Alternatively, we may turn every constructor into variants that indicate
+-- the position of points.
+--
+-- @
+-- -- Original type
+-- data Tree = Node Tree Tree | Leaf
+-- -- Pointing of Tree
+-- data Tree'
+-- = Tree' Tree -- Point at the root
+-- | Node'0 Tree' Tree -- Point to the left
+-- | Node'1 Tree Tree' -- Point to the right
+-- @
+--
+-- Pointed values are easily mapped back to the original type by erasing the
+-- point. Pointing makes larger values occur much more frequently, while
+-- preserving the uniformness of the distribution conditionally to a fixed
+-- size.
+--
+
+-- | @
+-- 'generatorSR' :: Int -> 'Gen' a
+-- 'asMonadRandom' . 'generatorSR' :: 'MonadRandom' m => Int -> m a
+-- @
+--
+-- Singular ceiled rejection sampler.
+generatorSR :: (Data a, MonadRandomLike m) => Size' -> m a
+generatorSR = generatorSRWith []
+
+-- | @
+-- 'generatorP' :: Int -> 'Gen' a
+-- 'asMonadRandom' . 'generatorP' :: 'MonadRandom' m => Int -> m a
+-- @
+--
+-- Generator of pointed values.
+
+generatorP :: (Data a, MonadRandomLike m) => Size' -> m a
+generatorP = generatorPWith []
+
+-- | Pointed generator with rejection.
+generatorPR :: (Data a, MonadRandomLike m) => Size' -> m a
+generatorPR = generatorPRWith []
+
+-- | Generator with rejection and dynamic average size.
+generatorR :: (Data a, MonadRandomLike m) => Size' -> m a
+generatorR = generatorRWith []
+
+-- ** Fixed size
+
+-- $fixed
+-- The @'@ suffix indicates functions which do not do any
+-- precomputation before passing the size parameter.
+--
+-- This means that oracles are computed from scratch for every size value,
+-- which may incur a significant overhead.
+
+-- | Pointed generator.
+generatorP' :: (Data a, MonadRandomLike m) => Size' -> m a
+generatorP' = generatorPWith' []
+
+-- | Pointed generator with rejection.
+generatorPR' :: (Data a, MonadRandomLike m) => Size' -> m a
+generatorPR' = generatorPRWith' []
+
+-- | Ceiled rejection sampler with given average size.
+generatorR' :: (Data a, MonadRandomLike m) => Size' -> m a
+generatorR' = generatorRWith' []
+
+-- | Basic boltzmann sampler with no optimization.
+generator' :: (Data a, MonadRandomLike m) => Size' -> m a
+generator' = generatorWith' []
+
+-- * Generators with aliases
+
+-- $aliases
+-- Boltzmann samplers can normally be defined only for types @a@ such that:
+--
+-- - they are instances of 'Data';
+-- - the set of types of subterms of values of type @a@ is finite;
+-- - and all of these types have at least one finite value (i.e., values with
+-- finitely many constructors).
+--
+-- Examples of misbehaving types are:
+--
+-- - @a -> b -- Not Data@
+-- - @data E a = L a | R (E [a]) -- Contains a, [a], [[a]], [[[a]]], etc.@
+-- - @data I = C I -- No finite value@
+--
+-- = Alias
+--
+-- The 'Alias' type works around these limitations ('AliasR' for rejection
+-- samplers).
+-- This existential wrapper around a user-defined function @f :: a -> m b@
+-- makes @boltzmann-samplers@ view occurences of the type @b@ as @a@ when
+-- processing a recursive system of types, possibly stopping some infinite
+-- unrolling of type definitions. When a value of type @b@ needs to be
+-- generated, it generates an @a@ which is passed to @f@.
+--
+-- @
+-- let
+-- as = ['aliasR' $ \\() -> return (L []) :: 'Gen' (E [[Int]])]
+-- in
+-- 'generatorSRWith' as 'asGen' :: 'Size' -> 'Gen' (E Int)
+-- @
+--
+-- Another use case is to plug in user-defined generators where the default is
+-- not satisfactory, for example, to generate positive @Int@s:
+--
+-- @
+-- let
+-- as = ['alias' $ \\() -> 'choose' (0, 100) :: 'Gen' Int)]
+-- in
+-- 'generatorPWith' as 'asGen' :: 'Size' -> 'Gen' [Int]
+-- @
+--
+-- or to modify the weights assigned to some types. In particular, in some
+-- cases it seems preferable to make @String@ (and @Text@) have the same weight
+-- as @Int@ and @()@.
+--
+-- @
+-- let
+-- as = ['alias' $ \\() -> arbitrary :: 'Gen' String]
+-- in
+-- 'generatorPWith' as 'asGen' :: 'Size' -> 'Gen' (Either Int String)
+-- @
+
+generatorSRWith
+ :: (Data a, MonadRandomLike m) => [AliasR m] -> Size' -> m a
+generatorSRWith aliases =
+ generatorR_ aliases 0 Nothing . tolerance epsilon
+
+generatorPRWith
+ :: (Data a, MonadRandomLike m) => [AliasR m] -> Size' -> m a
+generatorPRWith aliases size' =
+ generatorMR aliases 1 size' (tolerance epsilon size')
+
+generatorPWith
+ :: (Data a, MonadRandomLike m) => [Alias m] -> Size' -> m a
+generatorPWith aliases = generatorM aliases 1
+
+generatorRWith
+ :: (Data a, MonadRandomLike m) => [AliasR m] -> Size' -> m a
+generatorRWith aliases size' =
+ generatorMR aliases 0 size' (tolerance epsilon size')
+
+-- ** Fixed size
+
+generatorPWith'
+ :: (Data a, MonadRandomLike m) => [Alias m] -> Size' -> m a
+generatorPWith' aliases = generator_ aliases 1 . Just
+
+generatorPRWith'
+ :: (Data a, MonadRandomLike m) => [AliasR m] -> Size' -> m a
+generatorPRWith' aliases size' =
+ generatorR_ aliases 1 (Just size') (tolerance epsilon size')
+
+generatorRWith'
+ :: (Data a, MonadRandomLike m) => [AliasR m] -> Size' -> m a
+generatorRWith' aliases size' =
+ generatorR_ aliases 0 (Just size') (tolerance epsilon size')
+
+generatorWith'
+ :: (Data a, MonadRandomLike m) => [Alias m] -> Size' -> m a
+generatorWith' aliases = generator_ aliases 0 . Just
+
+-- * Other generators
+
+-- $other Used in the implementation of the generators above.
+-- These also allow to apply pointing more than once.
+--
+-- === Suffixes
+--
+-- [@M@] Sized generators are memoized for some sparsely chosen values of
+-- sizes. Subsequently supplied sizes are approximated by the closest larger
+-- value. This strategy avoids recomputing too many oracles. Aside from
+-- singular samplers, all other generators above not marked by @'@ use this.
+--
+-- [@_@] If the size parameter is @Nothing@, produces the singular generator
+-- (associated with the suffix @S@); otherwise the generator produces values
+-- with average size equal to the given value.
+
+generatorM
+ :: (Data a, MonadRandomLike m)
+ => [Alias m] -> Points -> Size' -> m a
+generatorM = memo make apply
+
+generatorMR
+ :: (Data a, MonadRandomLike m)
+ => [AliasR m] -> Points -> Size' -> (Size', Size') -> m a
+generatorMR = memo makeR applyR
+
+-- | Boltzmann sampler without rejection.
+generator_
+ :: (Data a, MonadRandomLike m)
+ => [Alias m] -> Points -> Maybe Size' -> m a
+generator_ aliases = apply (make aliases [])
+
+-- | Boltzmann sampler with rejection.
+generatorR_
+ :: (Data a, MonadRandomLike m)
+ => [AliasR m] -> Points -> Maybe Size' -> (Size', Size') -> m a
+generatorR_ aliases = applyR (makeR aliases [])
diff --git a/src/Boltzmann/Data/Common.hs b/src/Boltzmann/Data/Common.hs
new file mode 100644
index 0000000..121567c
--- /dev/null
+++ b/src/Boltzmann/Data/Common.hs
@@ -0,0 +1,39 @@
+-- | General helper functions
+
+module Boltzmann.Data.Common where
+
+frequencyWith
+ :: (Ord r, Num r, Monad m) => (r -> m r) -> [(r, m a)] -> m a
+frequencyWith _ [(_, a)] = a
+frequencyWith randomR as = randomR total >>= select as
+ where
+ total = (sum . fmap fst) as
+ select ((w, a) : as) x
+ | x < w = a
+ | otherwise = select as (x - w)
+ select _ _ = (snd . head) as
+ -- That should not happen in theory, but floating point might be funny.
+
+-- | @partitions k n@: lists of non-negative integers of length @n@ with sum
+-- less than or equal to @k@.
+partitions :: Int -> Int -> [[Int]]
+partitions _ 0 = [[]]
+partitions k n = do
+ p <- [0 .. k]
+ (p :) <$> partitions (k - p) (n - 1)
+
+-- | Binomial coefficient.
+--
+-- > binomial n k == factorial n `div` (factorial k * factorial (n-k))
+binomial :: Int -> Int -> Integer
+binomial = \n k -> pascal !! n !! k
+ where
+ pascal = [1] : fmap nextRow pascal
+ nextRow r = zipWith (+) (0 : r) (r ++ [0])
+
+-- | Multinomial coefficient.
+--
+-- > multinomial n ps == factorial n `div` product [factorial p | p <- ps]
+multinomial :: Int -> [Int] -> Integer
+multinomial _ [] = 1
+multinomial n (p : ps) = binomial n p * multinomial (n - p) ps
diff --git a/src/Boltzmann/Data/Data.hs b/src/Boltzmann/Data/Data.hs
new file mode 100644
index 0000000..ec9061f
--- /dev/null
+++ b/src/Boltzmann/Data/Data.hs
@@ -0,0 +1,152 @@
+-- | Internal module
+--
+-- Derive Boltzmann samplers for SYB-generic types.
+
+{-# LANGUAGE DeriveFunctor #-}
+{-# LANGUAGE RecordWildCards #-}
+
+module Boltzmann.Data.Data where
+
+import Control.Arrow ( (&&&) )
+import Control.Applicative
+import Data.Data
+import Data.Foldable
+import Data.Maybe
+import qualified Data.HashMap.Lazy as HashMap
+import Boltzmann.Data.Oracle
+import Boltzmann.Data.Types
+
+-- | Sized generator.
+data SG r = SG
+ { minSize :: Size
+ , maxSizeM :: Maybe Size
+ , runSG :: Points -> Maybe Double -> r
+ , runSmallG :: Points -> r
+ } deriving Functor
+
+-- | Number of pointing iterations.
+type Points = Int
+
+rangeSG :: SG r -> (Size, Maybe Size)
+rangeSG = minSize &&& maxSizeM
+
+-- | For documentation.
+applySG :: SG r -> Points -> Maybe Double -> r
+applySG SG{..} k sizeM
+ | Just minSize == maxSizeM = runSG k (fmap fromIntegral maxSizeM)
+ | Just size <- sizeM, size <= fromIntegral minSize =
+ error "Target size too small."
+ | Just True <- liftA2 ((<=) . fromIntegral) maxSizeM sizeM =
+ error "Target size too large."
+ | Nothing <- sizeM, Just _ <- maxSizeM =
+ error "Cannot make singular sampler for finite type."
+ | otherwise = runSG k sizeM
+
+-- * Helper functions
+
+make :: (Data a, MonadRandomLike m)
+ => [Alias m] -> proxy a -> SG (m a)
+make aliases a =
+ SG minSize maxSizeM make' makeSmall
+ where
+ dd = collectTypes aliases a
+ t = typeRep a
+ i = case index dd #! t of
+ Left j -> fst (xedni' dd #! j)
+ Right i -> i
+ minSize = natToInt $ fst (lTerm dd #! i)
+ maxSizeM = HashMap.lookup i (degree dd)
+ make' k sizeM = getGenerator dd' generators a k
+ where
+ dd' = dds !! k
+ oracle = makeOracle dd' t sizeM
+ generators = makeGenerators dd' oracle
+ makeSmall k = getSmallGenerator dd' (smallGenerators dd') a
+ where
+ dd' = dds !! k
+ dds = iterate point dd
+
+makeR :: (Data a, MonadRandomLike m)
+ => [AliasR m] -> proxy a
+ -> SG ((Size, Size) -> m a)
+makeR aliases a = fmap (flip runRejectT) (make aliases a)
+
+-- | The size of a value is its number of constructors.
+--
+-- Here, however, the 'Size'' type is interpreted differently to make better
+-- use of QuickCheck's size parameter provided by the 'Test.QuickCheck.sized'
+-- combinator, so that we generate non-trivial data even at very small size
+-- values.
+--
+-- For infinite types, with objects of unbounded sizes @> minSize@, given a
+-- parameter @delta :: 'Size''@, the produced values have an average size close
+-- to @minSize + delta@.
+--
+-- For example, values of type @Either () [Bool]@ have at least two constructors,
+-- so
+--
+-- @
+-- 'generator' delta :: 'Gen' (Either () [Bool])
+-- @
+--
+-- will target sizes close to @2 + delta@;
+-- the offset becomes less noticeable as @delta@ grows to infinity.
+--
+-- For finite types with sizes in @[minSize, maxSize]@, the target expected
+-- size is obtained by clamping a 'Size'' to @[0, 99]@ and applying an affine
+-- mapping.
+type Size' = Int
+
+rescale :: SG r -> Size' -> Double
+rescale (SG minSize (Just maxSize) _ _) size' =
+ fromIntegral minSize + fromIntegral (min 99 size' * (maxSize - minSize)) / 100
+rescale (SG minSize Nothing _ _) size' = fromIntegral (minSize + size')
+
+apply :: SG r -> Points -> Maybe Size' -> r
+apply sg k (Just 0) = runSmallG sg k
+apply sg k size' = runSG sg k (fmap (rescale sg) size')
+
+applyR :: SG ((Size, Size) -> r) -> Points -> Maybe Size' -> (Size', Size') -> r
+applyR sg k size' = apply sg k size' . rescaleInterval sg
+
+rescaleInterval :: SG r -> (Size', Size') -> (Size, Size)
+rescaleInterval sg (a', b') = (a, b)
+ where
+ a = (clamp . floor .rescale sg) a'
+ b = (clamp . ceiling . rescale sg) b'
+ clamp x
+ | Just maxSize <- maxSizeM sg, x >= 100 = maxSize
+ | otherwise = x
+
+-- | > 'epsilon' = 0.1
+--
+-- Default approximation ratio.
+epsilon :: Double
+epsilon = 0.1
+
+-- | > (size * (1 - epsilon), size * (1 + epsilon))
+tolerance :: Double -> Int -> (Int, Int)
+tolerance epsilon size = (size - delta, size + delta)
+ where
+ delta = ceiling (fromIntegral size * epsilon)
+
+-- * Auxiliary definitions
+
+memo
+ :: (t -> [t2] -> SG r)
+ -> (SG r -> t1 -> Maybe Int -> a)
+ -> t -> t1 -> Int -> a
+memo make apply aliases k = generators
+ where
+ sg = make aliases []
+ generators = sparseSized (apply sg k . Just) (99 <$ maxSizeM sg)
+
+-- Oracles are computed only for sizes that are a power of two away from
+-- the minimum size of the datatype @minSize + 2 ^ e@.
+sparseSized :: (Int -> a) -> Maybe Int -> Int -> a
+sparseSized f maxSizeM =
+ maybe a0 snd . \size' -> find ((>= size') . fst) as
+ where
+ as = [ (s, f s) | s <- ss ]
+ ss = 0 : maybe id (takeWhile . (>)) maxSizeM [ 2 ^ e | e <- [ 0 :: Int ..] ]
+ a0 = f (fromJust maxSizeM)
diff --git a/src/Boltzmann/Data/Oracle.hs b/src/Boltzmann/Data/Oracle.hs
new file mode 100644
index 0000000..63fbd84
--- /dev/null
+++ b/src/Boltzmann/Data/Oracle.hs
@@ -0,0 +1,506 @@
+-- | Internal module
+
+{-# LANGUAGE DeriveDataTypeable #-}
+{-# LANGUAGE DeriveGeneric #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE RecordWildCards #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+
+module Boltzmann.Data.Oracle where
+
+import Control.Applicative
+import Control.Monad
+import Control.Monad.Fix
+import Control.Monad.Reader
+import Control.Monad.State
+import Data.Bifunctor
+import Data.Data
+import Data.Hashable ( Hashable )
+import Data.HashMap.Lazy ( HashMap )
+import qualified Data.HashMap.Lazy as HashMap
+import Data.Maybe ( fromJust, isJust )
+import Data.Monoid
+import qualified Data.Vector as V
+import GHC.Generics ( Generic )
+import Numeric.AD
+import Boltzmann.Data.Common
+import Boltzmann.Data.Types
+import Boltzmann.Solver
+
+-- | We build a dictionary which reifies type information in order to
+-- create a Boltzmann generator.
+--
+-- We denote by @n@ (or 'count') the number of types in the dictionary.
+--
+-- Every type has an index @0 <= i < n@; the variable @X i@ represents its
+-- generating function @C_i(x)@, and @X (i + k*n)@ the GF of its @k@-th
+-- "pointing" @C_i[k](x)@; we have
+--
+-- @
+-- C_i[0](x) = C_i(x)
+-- C_i[k+1](x) = x * C_i[k]'(x)
+-- @
+--
+-- where @C_i[k]'@ is the derivative of @C_i[k]@. See also 'point'.
+--
+-- The /order/ (or /valuation/) of a power series is the index of the first
+-- non-zero coefficient, called the /leading coefficient/.
+
+data DataDef m = DataDef
+ { count :: Int -- ^ Number of registered types
+ , points :: Int -- ^ Number of iterations of the pointing operator
+ , index :: HashMap TypeRep (Either Aliased Ix) -- ^ Map from types to indices
+ , xedni :: HashMap Ix SomeData' -- ^ Inverse map from indices to types
+ , xedni' :: HashMap Aliased (Ix, Alias m) -- ^ Inverse map to aliases
+ , types :: HashMap C [(Integer, Constr, [C'])]
+ -- ^ Structure of types and their pointings (up to 'points', initially 0)
+ --
+ -- Primitive types and empty types are mapped to an empty constructor list, and
+ -- can be distinguished using 'Data.Data.dataTypeRep' on the 'SomeData'
+ -- associated to it by 'xedni'.
+ --
+ -- The integer is a multiplicity which can be > 1 for pointings.
+ , lTerm :: HashMap Ix (Nat, Integer)
+ -- ^ Leading term @a * x ^ u@ of the generating functions @C_i[k](x)@ in the
+ -- form (u, a).
+ --
+ -- [Order @u@] Smallest size of objects of a given type.
+ -- [Leading coefficient @a@] number of objects of smallest size.
+ , degree :: HashMap Ix Int
+ -- ^ Degrees of the generating functions, when applicable: greatest size of
+ -- objects of a given type.
+ } deriving Show
+
+-- | A pair @C i k@ represents the @k@-th "pointing" of the type at index @i@,
+-- with generating function @C_i[k](x)@.
+data C = C Ix Int
+ deriving (Eq, Ord, Show, Generic)
+
+instance Hashable C
+
+data AC = AC Aliased Int
+ deriving (Eq, Ord, Show, Generic)
+
+instance Hashable AC
+
+type C' = (Maybe Aliased, C)
+
+newtype Aliased = Aliased Int
+ deriving (Eq, Ord, Show, Generic)
+
+instance Hashable Aliased
+
+type Ix = Int
+
+data Nat = Zero | Succ Nat
+ deriving (Eq, Ord, Show)
+
+instance Monoid Nat where
+ mempty = Zero
+ mappend (Succ n) = Succ . mappend n
+ mappend Zero = id
+
+natToInt :: Nat -> Int
+natToInt Zero = 0
+natToInt (Succ n) = 1 + natToInt n
+
+infinity :: Nat
+infinity = Succ infinity
+
+dataDef :: [Alias m] -> DataDef m
+dataDef as = DataDef
+ { count = 0
+ , points = 0
+ , index = index
+ , xedni = HashMap.empty
+ , xedni' = xedni'
+ , types = HashMap.empty
+ , lTerm = HashMap.empty
+ , degree = HashMap.empty
+ } where
+ xedni' = HashMap.fromList (fmap (\(i, a) -> (i, (-1, a))) as')
+ index = HashMap.fromList (fmap (\(i, a) -> (ofType a, Left i)) as')
+ as' = zip (fmap Aliased [0 ..]) as
+ ofType (Alias f) = typeRep (f undefined)
+
+-- | Find all types that may be types of subterms of a value of type @a@.
+--
+-- This will loop if there are infinitely many such types.
+collectTypes :: Data a => [Alias m] -> proxy a -> DataDef m
+collectTypes as a = collectTypesM a `execState` dataDef as
+
+-- | Primitive datatypes have @C(x) = x@: they are considered as
+-- having a single object (@lCoef@) of size 1 (@order@)).
+primOrder :: Int
+primOrder = 1
+
+primOrder' :: Nat
+primOrder' = Succ Zero
+
+primlCoef :: Integer
+primlCoef = 1
+
+-- | The type of the first argument of 'Data.Data.gunfold'.
+type GUnfold m = forall b r. Data b => m (b -> r) -> m r
+
+-- | Type of 'xedni''.
+type AMap m = HashMap Aliased (Ix, Alias m)
+
+collectTypesM :: Data a => proxy a
+ -> State (DataDef m) (Either Aliased Ix, ((Nat, Integer), Maybe Int))
+collectTypesM a = chaseType a (const id)
+
+chaseType :: Data a => proxy a
+ -> ((Maybe (Alias m), Ix) -> AMap m -> AMap m)
+ -> State (DataDef m) (Either Aliased Ix, ((Nat, Integer), Maybe Int))
+chaseType a k = do
+ let t = typeRep a
+ dd@DataDef{..} <- get
+ let
+ lookup i r =
+ let
+ lTerm_i = lTerm #! i
+ degree_i = HashMap.lookup i degree
+ in return (r, (lTerm_i, degree_i))
+ case HashMap.lookup t index of
+ Nothing -> do
+ let i = count
+ put dd
+ { count = i + 1
+ , index = HashMap.insert t (Right i) index
+ , xedni = HashMap.insert i (someData' a) xedni
+ , xedni' = k (Nothing, i) xedni'
+ }
+ traverseType a i -- Updates lTerm and degree
+ Just (Right i) -> do
+ put dd { xedni' = k (Nothing, i) xedni' }
+ lookup i (Right i)
+ Just (Left j) ->
+ case xedni' #! j of
+ (-1, Alias f) -> do
+ (_, ld) <- chaseType (ofType f) $ \(alias, i) ->
+ let
+ alias' = case alias of
+ Nothing -> Alias f
+ Just (Alias g) -> Alias (composeCastM f g)
+ in
+ k (Just alias', i) . HashMap.insert j (i, alias')
+ return (Left j, ld)
+ (i, _) -> lookup i (Left j)
+ where
+ ofType :: (m a -> m b) -> m a
+ ofType _ = undefined
+
+-- | Traversal of the definition of a datatype.
+traverseType
+ :: Data a => proxy a -> Ix
+ -> State (DataDef m) (Either Aliased Ix, ((Nat, Integer), Maybe Int))
+traverseType a i = do
+ let d = withProxy dataTypeOf a
+ mfix $ \ ~(_, (lTerm_i0, _)) -> do
+ modify $ \dd@DataDef{..} -> dd
+ { lTerm = HashMap.insert i lTerm_i0 lTerm
+ }
+ (types_i, ld@(_, degree_i)) <- traverseType' a d
+ modify $ \dd@DataDef{..} -> dd
+ { types = HashMap.insert (C i 0) types_i types
+ , degree = maybe id (HashMap.insert i) degree_i degree
+ }
+ return (Right i, ld)
+
+traverseType'
+ :: Data a => proxy a -> DataType
+ -> State (DataDef m)
+ ([(Integer, Constr, [(Maybe Aliased, C)])], ((Nat, Integer), Maybe Int))
+traverseType' a d | isAlgType d = do
+ let
+ constrs = dataTypeConstrs d
+ collect
+ :: GUnfold (StateT
+ ([Either Aliased Ix], (Nat, Integer), Maybe Int)
+ (State (DataDef m)))
+ collect mkCon = do
+ f <- mkCon
+ let ofType :: (b -> a) -> Proxy b
+ ofType _ = Proxy
+ b = ofType f
+ (j, (lTerm_, degree_)) <- lift (collectTypesM b)
+ modify $ \(js, lTerm', degree') ->
+ (j : js, lMul lTerm_ lTerm', liftA2 (+) degree_ degree')
+ return (withProxy f b)
+ tlds <- forM constrs $ \constr -> do
+ (js, lTerm', degree') <-
+ gunfold collect return constr `proxyType` a
+ `execStateT` ([], (Zero, 1), Just 1)
+ dd <- get
+ let
+ c (Left j) = (Just j, C (fst (xedni' dd #! j)) 0)
+ c (Right i) = (Nothing, C i 0)
+ return ((1, constr, [ c j | j <- js]), lTerm', degree')
+ let
+ (types_i, ls, ds) = unzip3 tlds
+ lTerm_i = first Succ (lSum ls)
+ degree_i = maxDegree ds
+ return (types_i, (lTerm_i, degree_i))
+traverseType' _ _ =
+ return ([], ((primOrder', primlCoef), Just primOrder))
+
+-- | If @(u, a)@ represents a power series of leading term @a * x ^ u@, and
+-- similarly for @(u', a')@, this finds the leading term of their sum.
+--
+-- The comparison of 'Nat' is unrolled here for maximum laziness.
+lPlus :: (Nat, Integer) -> (Nat, Integer) -> (Nat, Integer)
+lPlus (Zero, lCoef) (Zero, lCoef') = (Zero, lCoef + lCoef')
+lPlus (Zero, lCoef) _ = (Zero, lCoef)
+lPlus _ (Zero, lCoef') = (Zero, lCoef')
+lPlus (Succ order, lCoef) (Succ order', lCoef') =
+ first Succ $ lPlus (order, lCoef) (order', lCoef')
+
+-- | Sum of a list of series.
+lSum :: [(Nat, Integer)] -> (Nat, Integer)
+lSum [] = (infinity, 0)
+lSum ls = foldl1 lPlus ls
+
+-- | Leading term of a product of series.
+lMul :: (Nat, Integer) -> (Nat, Integer) -> (Nat, Integer)
+lMul (order, lCoef) (order', lCoef') = (order <> order', lCoef * lCoef')
+
+lProd :: [(Nat, Integer)] -> (Nat, Integer)
+lProd = foldl lMul (Zero, 1)
+
+maxDegree :: [Maybe Int] -> Maybe Int
+maxDegree = foldl (liftA2 max) (Just minBound)
+
+-- | Pointing operator.
+--
+-- Populates a 'DataDef' with one more level of pointings.
+-- ('collectTypes' produces a dictionary at level 0.)
+--
+-- The "pointing" of a type @t@ is a derived type whose values are essentially
+-- values of type @t@, with one of their constructors being "pointed".
+-- Alternatively, we may turn every constructor into variants that indicate
+-- the position of points.
+--
+-- @
+-- -- Original type
+-- data Tree = Node Tree Tree | Leaf
+-- -- Pointing of Tree
+-- data Tree'
+-- = Tree' Tree -- Point at the root
+-- | Node'0 Tree' Tree -- Point to the left
+-- | Node'1 Tree Tree' -- Point to the right
+-- -- Pointing of the pointing
+-- -- Notice that the "points" introduced by both applications of pointing
+-- -- are considered different: exchanging their positions (when different)
+-- -- produces a different tree.
+-- data Tree''
+-- = Tree'' Tree' -- Point 2 at the root, the inner Tree' places point 1
+-- | Node'0' Tree' Tree -- Point 1 at the root, point 2 to the left
+-- | Node'1' Tree Tree' -- Point 1 at the root, point 2 to the right
+-- | Node'0'0 Tree'' Tree -- Points 1 and 2 to the left
+-- | Node'0'1 Tree' Tree' -- Point 1 to the left, point 2 to the right
+-- | Node'1'0 Tree' Tree' -- Point 1 to the right, point 2 to the left
+-- | Node'0'1 Tree Tree'' -- Points 1 and 2 to the right
+-- @
+--
+-- If we ignore points, some constructors are equivalent. Thus we may simply
+-- calculate their multiplicity instead of duplicating them.
+--
+-- Given a constructor with @c@ arguments @C x_1 ... x_c@, and a sequence
+-- @p_0 + p_1 + ... + p_c = k@ corresponding to a distribution of @k@ points
+-- (@p_0@ are assigned to the constructor @C@ itself, and for @i > 0@, @p_i@
+-- points are assigned within the @i@-th subterm), the multiplicity of the
+-- constructor paired with that distribution is the multinomial coefficient
+-- @multinomial k [p_1, ..., p_c]@.
+
+point :: DataDef m -> DataDef m
+point dd@DataDef{..} = dd
+ { points = points'
+ , types = foldl g types [0 .. count-1]
+ } where
+ points' = points + 1
+ g types i = HashMap.insert (C i points') (types' i) types
+ types' i = types #! C i 0 >>= h
+ h (_, constr, js) = do
+ ps <- partitions points' (length js)
+ let
+ mult = multinomial points' ps
+ js' = zipWith (\(j', C i _) p -> (j', C i p)) js ps
+ return (mult, constr, js')
+
+-- | An oracle gives the values of the generating functions at some @x@.
+type Oracle = HashMap C Double
+
+-- | Find the value of @x@ such that the average size of the generator
+-- for the @k-1@-th pointing is equal to @size@, and produce the associated
+-- oracle. If the size is @Nothing@, find the radius of convergence.
+--
+-- The search evaluates the generating functions for some values of @x@ in
+-- order to run a binary search. The evaluator is implemented using Newton's
+-- method, the convergence of which has been shown for relevant systems in
+-- /Boltzmann Oracle for Combinatorial Systems/,
+-- C. Pivoteau, B. Salvy, M. Soria.
+makeOracle :: DataDef m -> TypeRep -> Maybe Double -> Oracle
+makeOracle dd0 t size' =
+ seq v
+ HashMap.fromList (zip cs (V.toList v))
+ where
+ -- We need the next pointing to capture the average size in an equation.
+ dd@DataDef{..} = if isJust size' then point dd0 else dd0
+ cs = flip C <$> [0 .. points] <*> [0 .. count - 1]
+ m = count * (points + 1)
+ k = points - 1
+ i = case index #! t of
+ Left j -> fst (xedni' #! j)
+ Right i -> i
+ checkSize _ (Just ys) | V.any (< 0) ys = False
+ -- There may be solutions outside of the radius
+ -- of convergence, but with negative components.
+ checkSize (Just size) (Just ys) =
+ size >= size_
+ where
+ size_ = ys V.! j' / ys V.! j
+ j = dd ? C i k
+ j' = dd ? C i (k + 1)
+ checkSize Nothing (Just _) = True
+ checkSize _ Nothing = False
+ -- Equations defining C_i(x) for all types with indices i
+ phis :: Num a => V.Vector (a -> V.Vector a -> a)
+ phis = V.fromList [ phi dd c (types #! c) | c <- listCs dd ]
+ eval' :: Double -> Maybe (V.Vector Double)
+ eval' x = fixedPoint defSolveArgs phi' (V.replicate m 0)
+ where
+ phi' :: (Mode a, Scalar a ~ Double) => V.Vector a -> V.Vector a
+ phi' y = fmap (\f -> f (auto x) y) phis
+ v = (fromJust . snd) (search eval' (checkSize size'))
+
+-- | Generating function definition. This defines a @Phi_i[k]@ function
+-- associated with the @k@-th pointing of the type at index @i@, such that:
+--
+-- > C_i[k](x)
+-- > = Phi_i[k](x, C_0[0](x), ..., C_(n-1)[0](x),
+-- > ..., C_0[k](x), ..., C_(n-1)[k](x))
+--
+-- Primitive datatypes have @C(x) = x@: they are considered as
+-- having a single object ('lCoef') of size 1 ('order')).
+phi :: Num a => DataDef m -> C -> [(Integer, constr, [C'])]
+ -> a -> V.Vector a -> a
+phi DataDef{..} (C i _) [] =
+ case xedni #! i of
+ SomeData a ->
+ case (dataTypeRep . withProxy dataTypeOf) a of
+ AlgRep _ -> \_ _ -> 0
+ _ -> \x _ -> fromInteger primlCoef * x ^ primOrder
+phi dd@DataDef{..} _ tyInfo = f
+ where
+ f x y = x * (sum . fmap (toProd y)) tyInfo
+ toProd y (w, _, js) =
+ fromInteger w * product [ y V.! (dd ? j) | (_, j) <- js ]
+
+-- | Maps a key representing a type @a@ (or one of its pointings) to a
+-- generator @m a@.
+type Generators m = (HashMap AC (SomeData m), HashMap C (SomeData m))
+
+-- | Build all involved generators at once.
+makeGenerators
+ :: forall m. MonadRandomLike m
+ => DataDef m -> Oracle -> Generators m
+makeGenerators DataDef{..} oracle =
+ seq oracle
+ (generatorsL, generatorsR)
+ where
+ f (C i _) tyInfo = case xedni #! i of
+ SomeData a -> SomeData $ incr >>
+ case tyInfo of
+ [] -> defGen
+ _ -> frequencyWith doubleR (fmap g tyInfo) `proxyType` a
+ g :: Data a => (Integer, Constr, [C']) -> (Double, m a)
+ g (v, constr, js) =
+ ( fromInteger v * w
+ , gunfold generate return constr `runReaderT` gs)
+ where
+ gs = fmap (\(j', i) -> m j' i) js
+ m = maybe (generatorsR #!) m'
+ m' j (C _ k) = (generatorsL #! AC j k)
+ w = product $ fmap ((oracle #!) . snd) js
+ h (j, (i, Alias f)) k =
+ (AC j k, applyCast f (generatorsR #! C i k))
+ generatorsL = HashMap.fromList (liftA2 h (HashMap.toList xedni') [0 .. points])
+ generatorsR = HashMap.mapWithKey f types
+
+type SmallGenerators m =
+ (HashMap Aliased (SomeData m), HashMap Ix (SomeData m))
+
+-- | Generators of values of minimal sizes.
+smallGenerators
+ :: forall m. MonadRandomLike m => DataDef m -> SmallGenerators m
+smallGenerators DataDef{..} = (generatorsL, generatorsR)
+ where
+ f i (SomeData a) = SomeData $ incr >>
+ case types #! C i 0 of
+ [] -> defGen
+ tyInfo ->
+ let gs = (tyInfo >>= g (fst (lTerm #! i))) in
+ frequencyWith integerR gs `proxyType` a
+ g :: Data a => Nat -> (Integer, Constr, [C']) -> [(Integer, m a)]
+ g minSize (_, constr, js) =
+ guard (minSize == Succ size) *>
+ [(weight, gunfold generate return constr `runReaderT` gs)]
+ where
+ (size, weight) = lProd [ lTerm #! i | (_, C i _) <- js ]
+ gs = fmap lookup js
+ lookup (j', C i _) = maybe (generatorsR #! i) (generatorsL #!) j'
+ h (j, (i, Alias f)) = (j, applyCast f (generatorsR #! i))
+ generatorsL = (HashMap.fromList . fmap h . HashMap.toList) xedni'
+ generatorsR = HashMap.mapWithKey f xedni
+
+generate :: Applicative m => GUnfold (ReaderT [SomeData m] m)
+generate rest = ReaderT $ \(g : gs) ->
+ rest `runReaderT` gs <*> unSomeData g
+
+defGen :: (Data a, MonadRandomLike m) => m a
+defGen = gen
+ where
+ gen =
+ let dt = withProxy dataTypeOf gen in
+ case dataTypeRep dt of
+ IntRep -> fromConstr . mkIntegralConstr dt <$> int
+ FloatRep -> fromConstr . mkRealConstr dt <$> double
+ CharRep -> fromConstr . mkCharConstr dt <$> char
+ AlgRep _ -> error "Cannot generate for empty type."
+ NoRep -> error "No representation."
+
+-- * Short operators
+
+(?) :: DataDef m -> C -> Int
+dd ? C i k = i + k * count dd
+
+-- | > dd ? (listCs dd !! i) = i
+listCs :: DataDef m -> [C]
+listCs dd = liftA2 (flip C) [0 .. points dd] [0 .. count dd - 1]
+
+ix :: C -> Int
+ix (C i _) = i
+
+-- | > dd ? (dd ?! i) = i
+(?!) :: DataDef m -> Int -> C
+dd ?! j = C i k
+ where (k, i) = j `divMod` count dd
+
+getGenerator :: Data a => DataDef m -> Generators m -> proxy a -> Int -> m a
+getGenerator dd (l, r) a k = unSomeData $
+ case index dd #! typeRep a of
+ Right i -> (r #! C i k)
+ Left j -> (l #! AC j k)
+
+getSmallGenerator :: Data a => DataDef m -> SmallGenerators m -> proxy a -> m a
+getSmallGenerator dd (l, r) a = unSomeData $
+ case index dd #! typeRep a of
+ Right i -> (r #! i)
+ Left j -> (l #! j)
+
+(#!) :: (Eq k, Hashable k)
+ => HashMap k v -> k -> v
+(#!) = (HashMap.!)
diff --git a/src/Boltzmann/Data/Types.hs b/src/Boltzmann/Data/Types.hs
new file mode 100644
index 0000000..fae00db
--- /dev/null
+++ b/src/Boltzmann/Data/Types.hs
@@ -0,0 +1,197 @@
+-- | Internal module
+
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeOperators #-}
+
+module Boltzmann.Data.Types where
+
+import Control.Monad.Random
+import Control.Monad.Trans
+import Data.Coerce
+import Data.Data
+import Data.Function
+import Test.QuickCheck
+
+data SomeData m where
+ SomeData :: Data a => m a -> SomeData m
+
+type SomeData' = SomeData Proxy
+
+-- | Dummy instance for debugging.
+instance Show (SomeData m) where
+ show _ = "SomeData"
+
+data Alias m where
+ Alias :: (Data a, Data b) => !(m a -> m b) -> Alias m
+
+type AliasR m = Alias (RejectT m)
+
+-- | Dummy instance for debugging.
+instance Show (Alias m) where
+ show _ = "Alias"
+
+-- | Main constructor for 'Alias'.
+alias :: (Monad m, Data a, Data b) => (a -> m b) -> Alias m
+alias = Alias . (=<<)
+
+-- | Main constructor for 'AliasR'.
+aliasR :: (Monad m, Data a, Data b) => (a -> m b) -> AliasR m
+aliasR = Alias . (=<<) . fmap lift
+
+-- | > coerceAlias :: Alias m -> Alias (AMonadRandom m)
+coerceAlias :: Coercible m n => Alias m -> Alias n
+coerceAlias = coerce
+
+-- | > coerceAliases :: [Alias m] -> [Alias (AMonadRandom m)]
+coerceAliases :: Coercible m n => [Alias m] -> [Alias n]
+coerceAliases = coerce
+
+-- | > composeCast f g = f . g
+composeCastM :: forall a b c d m
+ . (Typeable b, Typeable c)
+ => (m c -> d) -> (a -> m b) -> (a -> d)
+composeCastM f g | Just Refl <- eqT :: Maybe (b :~: c) = f . g
+composeCastM _ _ = castError ([] :: [b]) ([] :: [c])
+
+castM :: forall a b m
+ . (Typeable a, Typeable b)
+ => m a -> m b
+castM a | Just Refl <- eqT :: Maybe (a :~: b) = a
+castM a = let x = castError a x in x
+
+unSomeData :: Typeable a => SomeData m -> m a
+unSomeData (SomeData a) = castM a
+
+applyCast :: (Typeable a, Data b) => (m a -> m b) -> SomeData m -> SomeData m
+applyCast f = SomeData . f . unSomeData
+
+castError :: (Typeable a, Typeable b)
+ => proxy a -> proxy' b -> c
+castError a b = error $ unlines
+ [ "Error trying to cast"
+ , " " ++ show (typeRep a)
+ , "to"
+ , " " ++ show (typeRep b)
+ ]
+
+withProxy :: (a -> b) -> proxy a -> b
+withProxy f _ =
+ f (error "This should not be evaluated\n")
+
+reproxy :: proxy a -> Proxy a
+reproxy _ = Proxy
+
+proxyType :: m a -> proxy a -> m a
+proxyType = const
+
+someData' :: Data a => proxy a -> SomeData'
+someData' = SomeData . reproxy
+
+-- | Size as the number of constructors.
+type Size = Int
+
+-- | Internal transformer for rejection sampling.
+--
+-- > ReaderT Size (StateT Size (MaybeT m)) a
+newtype RejectT m a = RejectT
+ { unRejectT :: forall r. Size -> Size -> m r -> (Size -> a -> m r) -> m r
+ }
+
+instance Functor (RejectT m) where
+ fmap f (RejectT go) = RejectT $ \maxSize size retry cont ->
+ go maxSize size retry $ \size a -> cont size (f a)
+
+instance Applicative (RejectT m) where
+ pure a = RejectT $ \_maxSize size _retry cont ->
+ cont size a
+ RejectT f <*> RejectT x = RejectT $ \maxSize size retry cont ->
+ f maxSize size retry $ \size f_ ->
+ x maxSize size retry $ \size x_ ->
+ cont size (f_ x_)
+
+instance Monad (RejectT m) where
+ RejectT x >>= f = RejectT $ \maxSize size retry cont ->
+ x maxSize size retry $ \size x_ ->
+ unRejectT (f x_) maxSize size retry cont
+
+instance MonadTrans RejectT where
+ lift m = RejectT $ \_maxSize size _retry cont ->
+ m >>= cont size
+
+-- | Set lower bound
+runRejectT :: Monad m => (Size, Size) -> RejectT m a -> m a
+runRejectT (minSize, maxSize) (RejectT m) = fix $ \go ->
+ m maxSize 0 go $ \size a ->
+ if size < minSize then
+ go
+ else
+ return a
+--runRejectT (minSize, maxSize) (RejectT m) = fix $ \go -> do
+-- x' <- runMaybeT (m `runReaderT` maxSize `runStateT` 0)
+-- case x' of
+-- Just (x, size) | size >= minSize -> return x
+-- _ -> go
+
+newtype AMonadRandom m a = AMonadRandom
+ { asMonadRandom :: m a
+ } deriving (Functor, Applicative, Monad)
+
+instance MonadTrans AMonadRandom where
+ lift = AMonadRandom
+
+-- ** Dictionaries
+
+-- | @'MonadRandomLike' m@ defines basic components to build generators,
+-- allowing the implementation to remain abstract over both the
+-- 'Test.QuickCheck.Gen' type and 'MonadRandom' instances.
+--
+-- For the latter, the wrapper 'AMonadRandom' is provided to avoid
+-- overlapping instances.
+class Monad m => MonadRandomLike m where
+ -- | Called for every constructor. Counter for ceiled rejection sampling.
+ incr :: m ()
+ incr = return ()
+
+ -- | @doubleR upperBound@: generates values in @[0, upperBound]@.
+ doubleR :: Double -> m Double
+
+ -- | @integerR upperBound@: generates values in @[0, upperBound-1]@.
+ integerR :: Integer -> m Integer
+
+ -- | Default @Int@ generator.
+ int :: m Int
+
+ -- | Default @Double@ generator.
+ double :: m Double
+
+ -- | Default @Char@ generator.
+ char :: m Char
+
+instance MonadRandomLike Gen where
+ doubleR x = choose (0, x)
+ integerR x = choose (0, x-1)
+ int = arbitrary
+ double = arbitrary
+ char = arbitrary
+
+instance MonadRandomLike m => MonadRandomLike (RejectT m) where
+ incr = RejectT $ \maxSize size retry cont ->
+ if size >= maxSize then
+ retry
+ else
+ cont (size + 1) ()
+ doubleR = lift . doubleR
+ integerR = lift . integerR
+ int = lift int
+ double = lift double
+ char = lift char
+
+instance MonadRandom m => MonadRandomLike (AMonadRandom m) where
+ doubleR x = lift $ getRandomR (0, x)
+ integerR x = lift $ getRandomR (0, x-1)
+ int = lift getRandom
+ double = lift getRandom
+ char = lift getRandom
diff --git a/src/Boltzmann/Solver.hs b/src/Boltzmann/Solver.hs
new file mode 100644
index 0000000..7fbe784
--- /dev/null
+++ b/src/Boltzmann/Solver.hs
@@ -0,0 +1,69 @@
+-- | Solve systems of equations
+
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE RankNTypes #-}
+{-# LANGUAGE RecordWildCards #-}
+{-# LANGUAGE TypeFamilies #-}
+
+module Boltzmann.Solver where
+
+import Control.Applicative
+import Data.AEq ( (~==) )
+import Numeric.AD.Mode
+import Numeric.AD.Mode.Forward
+import Numeric.LinearAlgebra
+import qualified Data.Vector as V
+import qualified Data.Vector.Storable as S
+
+data SolveArgs = SolveArgs
+ { accuracy :: Double
+ , numIterations :: Int
+ } deriving (Eq, Ord, Show)
+
+defSolveArgs :: SolveArgs
+defSolveArgs = SolveArgs 1e-8 20
+
+findZero
+ :: SolveArgs
+ -> (forall s. V.Vector (AD s (Forward R)) -> V.Vector (AD s (Forward R)))
+ -> Vector R
+ -> Maybe (Vector R)
+findZero SolveArgs{..} f = newton numIterations
+ where
+ newton 0 _ = Nothing
+ newton n x
+ | norm_y == 1/0 = Nothing
+ | norm_y > accuracy = newton (n - 1) (x - jacobian <\> y)
+ | otherwise = Just x
+ where
+ norm_y = norm_Inf y
+ jacobian = (fromRows . V.toList . fmap (V.convert . snd)) yj
+ y = (V.convert . fmap fst) yj
+ yj = jacobian' f (S.convert x)
+
+fixedPoint
+ :: SolveArgs
+ -> (forall a. (Mode a, Scalar a ~ R) => V.Vector a -> V.Vector a)
+ -> V.Vector R
+ -> Maybe (V.Vector R)
+fixedPoint args f =
+ fmap S.convert . findZero args (liftA2 (V.zipWith (-)) f id) . S.convert
+
+-- | Assuming @p . f@ is satisfied only for positive values in some interval
+-- @(0, r]@, find @f r@.
+search :: (Double -> a) -> (a -> Bool) -> (Double, a)
+search f p = search' e0 (0 : [2 ^ n | n <- [0 .. 100 :: Int]])
+ where
+ search' y (x : xs@(x' : _))
+ | p y' = search' y' xs
+ | otherwise = search'' y x x'
+ where y' = f x'
+ search' _ _ = error "Solution not found. Uncontradictable predicate?"
+ search'' y x x'
+ | x ~== x' = (x, y)
+ | p y_ = search'' y_ x_ x'
+ | otherwise = search'' y x x_
+ where
+ x_ = (x + x') / 2
+ y_ = f x_
+ e0 = error "Solution not found. Unsatisfiable predicate?"
diff --git a/src/Boltzmann/Species.hs b/src/Boltzmann/Species.hs
new file mode 100644
index 0000000..26b152c
--- /dev/null
+++ b/src/Boltzmann/Species.hs
@@ -0,0 +1,220 @@
+-- | Applicative interface to define recursive structures and derive Boltzmann
+-- samplers.
+--
+-- Given the recursive structure of the types, and how to combine generators,
+-- the library takes care of computing the oracles and setting the right
+-- distributions.
+
+{-# LANGUAGE FlexibleContexts, FlexibleInstances, GADTs, RankNTypes #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE DeriveFunctor, DeriveGeneric, ImplicitParams #-}
+{-# LANGUAGE RecordWildCards, DeriveDataTypeable #-}
+{-# LANGUAGE TypeFamilies, MultiParamTypeClasses #-}
+
+module Boltzmann.Species where
+
+import Control.Applicative
+import Control.Monad
+import Data.Bifunctor
+import Data.Coerce
+import Data.Function
+import Data.Foldable
+import Data.List
+import Data.Maybe
+import Data.Vector ( Vector )
+import qualified Data.Vector as V
+import qualified Numeric.AD as AD
+
+import Boltzmann.Data.Common
+import Boltzmann.Data.Types
+import Boltzmann.Solver
+
+class Embed f m where
+ emap :: (m a -> m b) -> f a -> f b
+ -- | A natural transformation between @f@ and @m@?
+ embed :: m a -> f a
+
+-- | 'Applicative' defines a product, 'Alternative' defines an addition,
+-- with scalar multiplication we get a module.
+--
+-- This typeclass allows to directly tweak weights in the oracle by
+-- chosen factors.
+class (Alternative f, Num (Scalar f)) => Module f where
+ type Scalar f :: *
+
+ -- | Scalar embedding.
+ scalar :: Scalar f -> f ()
+ scalar x = x <.> pure ()
+
+ -- | Scalar multiplication.
+ (<.>) :: Scalar f -> f a -> f a
+ x <.> f = scalar x *> f
+
+infixr 3 <.>
+
+type Endo a = a -> a
+
+data System f a c = System
+ { dim :: Int
+ , sys' :: f () -> Vector (f a) -> (Vector (f a), c)
+ } deriving (Functor)
+
+sys :: System f a c -> f () -> Vector (f a) -> Vector (f a)
+sys = (fmap . fmap . fmap) fst sys'
+
+newtype ConstModule r a = ConstModule { unConstModule :: r }
+
+instance Functor (ConstModule r) where
+ fmap _ (ConstModule r) = ConstModule r
+
+instance Num r => Embed (ConstModule r) m where
+ emap _ (ConstModule r) = ConstModule r
+ embed _ = ConstModule 1
+
+instance Num r => Applicative (ConstModule r) where
+ pure _ = ConstModule 1
+ ConstModule x <*> ConstModule y = ConstModule (x * y)
+
+instance Num r => Alternative (ConstModule r) where
+ empty = ConstModule 0
+ ConstModule x <|> ConstModule y = ConstModule (x + y)
+
+instance Num r => Module (ConstModule r) where
+ type Scalar (ConstModule r) = r
+ scalar = ConstModule
+ x <.> ConstModule r = ConstModule (x * r)
+
+solve
+ :: forall b c
+ . (forall a. Num a => System (ConstModule a) b c)
+ -> Double -> Maybe (Vector Double)
+solve s x = fixedPoint defSolveArgs phi' (V.replicate (dim s') 0)
+ where
+ phi' :: forall a. (AD.Mode a, AD.Scalar a ~ Double) => Endo (Vector a)
+ phi' = coerce (sys s (scalar (AD.auto x)) :: Endo (Vector (ConstModule a b)))
+ -- Arbitrary instantiation to get its dimension.
+ s' :: System (ConstModule Int) b c
+ s' = s
+
+sizedGenerator
+ :: forall b c m
+ . MonadRandomLike m
+ => (forall f. (Module f, Embed f m) => System (Pointiful f) b c)
+ -> Int -- ^ Index of type
+ -> Int -- ^ Points
+ -> Maybe Double -- ^ Expected size (or singular sampler)
+ -> m b
+sizedGenerator s i k size' = fst (sfix s' x oracle) V.! j
+ where
+ (x, oracle) = solveSized s i k size'
+ s' = point (k + 1) s
+ j = i * (k + 2) + k
+
+solveSized
+ :: forall b c
+ . (forall a. Num a => System (Pointiful (ConstModule a)) b c)
+ -> Int -- ^ Index of type
+ -> Int -- ^ Points
+ -> Maybe Double -- ^ Expected size (or singular sampler)
+ -> (Double, Vector Double)
+solveSized s i k size' =
+ fmap fromJust (search (solve s') (checkSize size'))
+ where
+ s' :: forall a. Num a => System (ConstModule a) b c
+ s' = point (k + 1) s
+ j = i * (k + 2) + k
+ j' = i * (k + 2) + k + 1
+ checkSize _ (Just ys) | V.any (< 0) ys = False
+ checkSize (Just size) (Just ys) = size >= ys V.! j' / ys V.! j
+ checkSize Nothing (Just _) = True
+ checkSize _ Nothing = False
+
+newtype Weighted m a = Weighted [(Double, m a)]
+
+weighted :: Double -> m a -> Weighted m a
+weighted x a = Weighted [(x, a)]
+
+runWeighted :: MonadRandomLike m => Weighted m a -> (Double, m a)
+runWeighted (Weighted [a]) = a
+runWeighted (Weighted as) = (sum (fmap fst as), frequencyWith doubleR as)
+
+instance Functor m => Functor (Weighted m) where
+ fmap f (Weighted as) = Weighted ((fmap . fmap . fmap) f as)
+
+instance MonadRandomLike m => Embed (Weighted m) m where
+ emap f = Weighted . (: []) . fmap f . runWeighted
+ embed m = Weighted [(1, m)]
+
+instance MonadRandomLike m => Applicative (Weighted m) where
+ pure a = Weighted [(1, pure a)]
+ f' <*> a' = Weighted [(u * v, f <*> a)]
+ where
+ (u, f) = runWeighted f'
+ (v, a) = runWeighted a'
+
+instance MonadRandomLike m => Alternative (Weighted m) where
+ empty = Weighted []
+ Weighted as <|> Weighted bs = Weighted (as ++ bs)
+
+instance MonadRandomLike m => Module (Weighted m) where
+ type Scalar (Weighted m) = Double
+ scalar x = Weighted [(x, pure ())]
+ x <.> Weighted as = Weighted (fmap (first (x *)) as)
+
+sfix
+ :: MonadRandomLike m
+ => System (Weighted m) b c -> Double -> Vector Double -> (Vector (m b), c)
+sfix s x oracle =
+ fix $
+ (first . fmap) (snd . runWeighted) .
+ sys' s (scalar x) .
+ V.zipWith weighted oracle .
+ fst
+
+data Pointiful f a = Pointiful [f a] | Zero (f a)
+
+instance Functor f => Functor (Pointiful f) where
+ fmap f (Pointiful v) = Pointiful ((fmap . fmap) f v)
+ fmap f (Zero x) = Zero (fmap f x)
+
+instance Embed f m => Embed (Pointiful f) m where
+ emap f (Pointiful v) = Pointiful ((fmap . emap) f v)
+ emap f (Zero x) = Zero (emap f x)
+ embed = Zero . embed
+
+instance Module f => Applicative (Pointiful f) where
+ pure a = Zero (pure a)
+ Zero f <*> Zero x = Zero (f <*> x)
+ Zero f <*> Pointiful xs = Pointiful (fmap (f <*>) xs)
+ Pointiful fs <*> Zero x = Pointiful (fmap (<*> x) fs)
+ Pointiful fs <*> Pointiful xs = Pointiful (convolute fs xs)
+ where
+ convolute fs xs = zipWith3 sumOfProducts [0 ..] (inits' fs) (inits' xs)
+ inits' = tail . inits
+ sumOfProducts k f x = asum (zipWith3 (times k) [0 ..] f (reverse x))
+ times k k1 f x = fromInteger (binomial k k1) <.> f <*> x
+
+instance Module f => Alternative (Pointiful f) where
+ empty = Zero empty
+ Pointiful xs <|> Pointiful ys = Pointiful (zipWith (<|>) xs ys)
+ Pointiful (x : xs) <|> Zero y = Pointiful ((x <|> y) : xs)
+ Zero x <|> Pointiful (y : ys) = Pointiful ((x <|> y) : ys)
+ Zero x <|> Zero y = Zero (x <|> y)
+ Pointiful [] <|> m = m
+ m <|> Pointiful [] = m
+
+instance Module f => Module (Pointiful f) where
+ type Scalar (Pointiful f) = Scalar f
+ scalar = Zero . scalar
+
+unPointiful :: Alternative f => Pointiful f a -> [f a]
+unPointiful (Pointiful as) = as
+unPointiful (Zero a) = a : repeat empty
+
+point :: Module f => Int -> System (Pointiful f) b c -> System f b c
+point k s = System ((k + 1) * dim s) $ \x ->
+ first flatten . sys' s (Pointiful (repeat x)) . resize
+ where
+ flatten = join . fmap (V.fromList . take (k + 1) . unPointiful)
+ resize v = V.generate (dim s) $ \i ->
+ Pointiful [v V.! j | j <- [i * (k + 1) .. i * (k + 1) + k]]
diff --git a/test/Test/Stats.hs b/test/Test/Stats.hs
new file mode 100644
index 0000000..d29f000
--- /dev/null
+++ b/test/Test/Stats.hs
@@ -0,0 +1,76 @@
+module Test.Stats where
+
+import Control.Monad
+
+import Data.List
+import Data.Maybe
+
+mean :: Foldable v => v Int -> Double
+mean xs = fromIntegral (sum xs) / fromIntegral (length xs)
+
+-- | Number of samples to estimate a probability distribution on a finite set
+-- of size @n@ to precision @epsilon@ (infinity-norm between distributions)
+-- with probability at least @(1 - delta)@.
+sampleSize
+ :: Int -- ^ Domain size
+ -> Double -- ^ Target distance (infinity-norm)
+ -> Double -- ^ Target error probability
+ -> Int
+sampleSize n epsilon delta =
+ ceiling (log (2 * fromIntegral n / delta) / (2 * epsilon ^ 2))
+
+-- | Number of trees with @n@ internal nodes.
+catalan :: [Integer]
+catalan = fmap catalan' [0 ..]
+ where
+ catalan' 0 = 1
+ catalan' i =
+ let prefix = take i catalan
+ in sum $ zipWith (*) prefix (reverse prefix)
+
+-- | Average size of a binary tree given the probability (@> 1/2@) of choosing
+-- a leaf.
+avgSize :: Fractional a => a -> a
+avgSize p = 1 / (2 * p - 1)
+
+-- | Inverse of 'avgSize'.
+invAvgSize :: Fractional a => a -> a
+invAvgSize s = (1 / s + 1) / 2
+
+-- | Distribution of sizes (actually, @(size - 1) / 2@), given the probability
+-- of choosing a leaf.
+distribution :: Fractional a => a -> [a]
+distribution p = zipWith f [0 ..] catalan
+ where
+ f i c = fromInteger c * p * (p * (1 - p)) ^ i
+
+expected :: Fractional a => Maybe a -> (Int, Int) -> Double -> Double -> (Int, [(Int, a)])
+expected avgSize' (minSize_, maxSize_) epsilon delta = (k, d)
+ where
+ p = maybe (1/2) invAvgSize avgSize'
+ minSize = (minSize_ + 1) `div` 2
+ maxSize = maxSize_ `div` 2
+ n = maxSize - minSize + 1
+ k = sampleSize n epsilon delta
+ d_ = (take n . drop minSize . distribution) p
+ d = zip [minSize ..] (fmap (/ sum d_) d_)
+
+runExperiment
+ :: (Fractional a, Ord a, Monad m)
+ => (Int, [(Int, a)]) -> m Int -> m ([(Int, a)], [(Int, a)], a)
+runExperiment (k, d) gen = cmp' . collect <$> replicateM k gen
+ where
+ collect :: Fractional a => [Int] -> [(Int, a)]
+ collect = fmap c . group . sort
+ c xs@(x : _) = (x, fromIntegral (length xs) / fromIntegral k)
+ c _ = undefined
+ cmp' z = (d, z, cmp d z)
+ cmp :: (Ord a, Num a) => [(Int, a)] -> [(Int, a)] -> a
+ cmp xs ys = maximum (zipWith_ (\x y -> abs (x - y)) xs ys)
+ zipWith_ :: (a -> a -> a) -> [(Int, a)] -> [(Int, a)] -> [a]
+ zipWith_ f xxs@((x, m) : xs) yys@((y, n) : ys)
+ | x == y = f m n : zipWith_ f xs ys
+ | x < y = m : zipWith_ f xs yys
+ | otherwise = n : zipWith_ f xxs ys
+ zipWith_ f [] ys = fmap snd ys
+ zipWith_ f xs [] = fmap snd xs
diff --git a/test/tree.hs b/test/tree.hs
new file mode 100644
index 0000000..d25f6a0
--- /dev/null
+++ b/test/tree.hs
@@ -0,0 +1,85 @@
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE DeriveDataTypeable #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE TypeOperators #-}
+
+import Control.Monad
+import Data.Data
+import Data.Foldable
+import Data.IORef
+import Data.List
+import System.Exit
+import System.IO
+
+import Options.Generic
+
+import Boltzmann.Data
+import Boltzmann.Data.Data
+
+import Test.Stats
+
+data T = L | N T T
+ deriving (Eq, Ord, Show, Data)
+
+size :: T -> Int
+size (N l r) = 1 + size l + size r
+size L = 0
+
+eps, del :: Double
+eps = 0.01
+del = 0.001
+
+-- | Periodically print stuff so that Travis does not think we're stuck.
+counting x gen = do
+ modifyIORef x (+ 1)
+ readIORef x >>= \x ->
+ when (x `mod` 1000 == 0) $ putStr "." >> hFlush stdout
+ gen
+
+-- | Invocation: stack test [--test-arguments TEST_SIZE]
+type Input = Maybe (Int <?> "Test size")
+
+main = do
+ n_ <- getRecord "Test program" :: IO Input
+ success <- newIORef True
+
+ let n = maybe 10 unHelpful n_
+ range = tolerance epsilon n
+
+ for_
+ [ ( "reject "
+ , generatorSR
+ , expected Nothing range eps del
+ )
+ , ( "rejectSimple "
+ , generatorR'
+ , expected (Just (fromIntegral n)) range eps del
+ )
+ ] $ \(name, g, kdist) -> do
+ putStrLn $ name ++ show n
+ let gen = (fmap size . asMonadRandom . g) n
+ x <- newIORef 0
+ (expectedDist, estimatedDist, diff) <- runExperiment kdist (counting x gen)
+ putStrLn ""
+ when (diff > eps) $ do
+ writeIORef success False
+ putStrLn $ "FAIL > " ++ show diff
+ print expectedDist
+ print estimatedDist
+
+{-
+ let k = 80000
+ eps = 0.1
+ gen = (fmap size . asMonadRandom . generatorP') n
+ putStrLn $ "pointed " ++ show n
+ x <- newIORef 0
+ sizes <- replicateM k (counting x gen)
+ putStrLn ""
+ let diff = abs (mean sizes - fromIntegral (n `div` 2))
+ when (diff > eps) $ do
+ writeIORef success False
+ putStrLn $ "FAIL > " ++ show diff
+-}
+
+ success <- readIORef success
+ unless success exitFailure