summaryrefslogtreecommitdiff
path: root/tests/Tests/Correlation.hs
blob: 756358307326970238b2a1ced1e9499ae4a16cc0 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
{-#LANGUAGE BangPatterns #-}

module Tests.Correlation
    ( tests ) where

import Control.Arrow (Arrow(..))
import qualified Data.Vector as V
import Statistics.Correlation
import Statistics.Correlation.Kendall
import Test.QuickCheck ((==>),Property,counterexample)
import Test.Tasty
import Test.Tasty.QuickCheck
import Test.Tasty.HUnit

import Tests.ApproxEq

----------------------------------------------------------------
-- Tests list
----------------------------------------------------------------

tests :: TestTree
tests = testGroup "Correlation"
    [ testProperty "Pearson correlation"           testPearson
    , testProperty "Spearman correlation is scale invariant" testSpearmanScale
    , testProperty "Spearman correlation, nonlinear"         testSpearmanNonlinear
    , testProperty "Kendall test -- general"       testKendall
    , testCase     "Kendall test -- special cases" testKendallSpecial
    ]


----------------------------------------------------------------
-- Pearson's correlation
----------------------------------------------------------------

testPearson :: [(Double,Double)] -> Property
testPearson sample
  = (length sample > 1) ==> (exact ~= fast)
  where
    (~=) = eql 1e-12
    exact = exactPearson $ map (realToFrac *** realToFrac) sample
    fast  = pearson $ V.fromList sample

exactPearson :: [(Rational,Rational)] -> Double
exactPearson sample
  = realToFrac cov / sqrt (realToFrac (varX * varY))
  where
    (xs,ys) = unzip sample
    n       = fromIntegral $ length sample
    -- Mean
    muX  = sum xs / n
    muY  = sum ys / n
    -- Mean of squares
    muX2 = sum (map (\x->x*x) xs) / n
    muY2 = sum (map (\x->x*x) ys) / n
    -- Covariance
    cov  = sum (zipWith (*) [x - muX | x<-xs] [y - muY | y<-ys]) / n
    varX = muX2 - muX*muX
    varY = muY2 - muY*muY

----------------------------------------------------------------
-- Spearman's correlation
----------------------------------------------------------------

-- Test that Spearman correlation is scale invariant
testSpearmanScale :: [(Double,Double)] -> Double -> Property
testSpearmanScale xs a
  = and [ length xs > 1       -- Enough to calculate underflow
        , a /= 0
        , not (isNaN c1)
        , not (isNaN c2)
        , not (isNaN c3)
        , not (isNaN c4)
        ]
  ==> ( counterexample (show xs2)
      $ counterexample (show xs3)
      $ counterexample (show xs4)
      $ counterexample (show (c1,c2,c3,c4))
      $ and [ c1 == c4
           , c1 == signum a * c2
           , c1 == signum a * c3
           ]
      )
  where
    xs2 = map ((*a) *** id  ) xs
    xs3 = map (id   *** (*a)) xs
    xs4 = map ((*a) *** (*a)) xs
    c1 = spearman $ V.fromList xs
    c2 = spearman $ V.fromList xs2
    c3 = spearman $ V.fromList xs3
    c4 = spearman $ V.fromList xs4

-- Test that Spearman correlation allows to transform sample with
testSpearmanNonlinear :: [(Double,Double)] -> Property
testSpearmanNonlinear sample0
  = and [ length sample0 > 1
        , not (isNaN c1)
        , not (isNaN c2)
        , not (isNaN c3)
        , not (isNaN c4)
        ]
  ==> ( counterexample (show sample0)
      $ counterexample (show sample1)
      $ counterexample (show sample2)
      $ counterexample (show sample3)
      $ counterexample (show sample4)
      $ counterexample (show (c1,c2,c3,c4))
      $ and [ c1 == c2
            , c1 == c3
            , c1 == c4
            ]
      )
  where
    -- We need to stretch sample into [-10 .. 10] range to avoid
    -- problems with under/overflows etc.
    stretch xs
      | a == b = xs
      | otherwise = [ (x - a - 10) * 20 / (a - b) | x <- xs ]
      where
        a = minimum xs
        b = maximum xs
    sample1 = uncurry zip $ (stretch *** stretch) $ unzip sample0
    sample2 = map (exp *** id ) sample1
    sample3 = map (id  *** exp) sample1
    sample4 = map (exp *** exp) sample1
    c1 = spearman $ V.fromList sample1
    c2 = spearman $ V.fromList sample2
    c3 = spearman $ V.fromList sample3
    c4 = spearman $ V.fromList sample4


----------------------------------------------------------------
-- Kendall's correlation
----------------------------------------------------------------

testKendall :: [(Double, Double)] -> Bool
testKendall xy | isNaN r1 = isNaN r2
               | otherwise = r1 == r2
  where
    r1 = kendallBruteForce xy
    r2 = kendall $ V.fromList xy

testKendallSpecial :: Assertion
testKendallSpecial = vs @=? map (\(xs, ys) -> kendall $ V.fromList $ zip xs ys) d
  where
    (d, vs) = unzip testData
    testData :: [(([Double], [Double]), Double)]
    testData = [ (([1, 2, 3, 1, 2], [1, 2, 1, 5, 2]), -0.375)
               , (([1, 1, 1, 3, 3], [3, 3, 3, 2, 5]), 0)
               ]


kendallBruteForce :: [(Double, Double)] -> Double
kendallBruteForce xy = (n_c - n_d) / sqrt ((n_0 - n_1) * (n_0 - n_2))
  where
    allPairs = f xy
    (n_c, n_d, n_1, n_2) = foldl g (0,0,0,0) allPairs
    n_0 = fromIntegral.length $ allPairs
    g (!nc, !nd, !n1, !n2) ((x1, y1), (x2, y2))
      | (x2 - x1) * (y2 - y1) > 0 = (nc+1, nd, n1, n2)
      | (x2 - x1) * (y2 - y1) < 0 = (nc, nd+1, n1, n2)
      | otherwise = if x1 == x2
                       then if y1 == y2
                               then (nc, nd, n1+1, n2+1)
                               else (nc, nd, n1+1, n2)
                       else (nc, nd, n1, n2+1)
    f (x:xs) = zip (repeat x) xs ++ f xs
    f _ = []