{-# LANGUAGE ForeignFunctionInterface #-}

-- | Yet another simple module for implementing statistics stuff.
module Statistics (sample, samples, normal, normals
                  , stdnormal, stdnormals, Distribution(..), Prob(..)
                  , pdf, pvalue, fromPdf, invcum, worsen
                  , module Control.Monad.Random
                  ) where

import Control.Monad.Random
import Control.Applicative
import Data.Char
import Data.Array.Unboxed

-- | The error function, imported from C's @erf@ through the FFI as a pure
--   function.  Used by 'cumnorm' and 'cumlognorm'.
--   NOTE(review): assumes the C math library providing @erf@ is linked.
foreign import ccall "erf" erf :: Double -> Double

-- | A probability, wrapped in its own type so it can have a dedicated
--   'Random' instance (which draws with @randomR (0,1)@).  The constructor
--   does not check that the wrapped value lies in [0,1].
newtype Prob = P Double deriving Show

-- | Random probabilities come from the unit interval; ranged generation is
--   deliberately unsupported and fails with an error.
instance Random Prob where
    random gen = (P p, gen')
      where (p, gen') = randomR (0, 1) gen
    randomR  = error "Specifying a range for a probability makes no sense."

-- TODO: "implement" -- presumably the further distributions listed after the
-- deriving clause below (StudentsT, SkewNormal, SkewT); confirm.
-- | Probability distributions understood by this module.  The textual form
--   accepted by the 'Read' instance is e.g. @normal 0 1@ (case-insensitive).
data Distribution = Normal Double Double     -- ^ mean @mu@ and standard deviation @sigma@
                  | LogNormal Double Double  -- ^ @mu@ and @sigma@ of the underlying normal, i.e. of @log x@
                  | Uniform Double Double    -- ^ lower and upper bound; 'pvalue' treats both as
                                             --   inclusive, while 'pdf' uses the half-open @[from, to)@
                  | Empirical Double Double (UArray Int Double)
                    -- ^ Cumulative distribution sampled on a grid: array entry @i@
                    -- is the cumulative probability at @mu + i*step@, where @mu@
                    -- is the first param (the centre, ref. 'worsen') and @step@
                    -- is the second param (the grid spacing).
                  deriving (Show)
                  -- add more: StudentsT, SkewNormal, SkewT

-- | Parse a distribution from text such as @"Normal 0 1"@ or
--   @"uniform 2 5"@.  The keyword is matched case-insensitively and must be
--   followed by exactly two numbers; anything else yields a 'Left' message.
parseDist :: String -> Either String Distribution
parseDist str =
    case words (map toLower str) of
      (keyword : args) | Just (label, build) <- lookup keyword constructors ->
          binary label build args
      _ -> Left ("Failed to parse as distribution: '"++str++"'.")
  where
    -- Lower-cased keyword, the name shown in error messages, and the
    -- two-parameter constructor it stands for.
    constructors =
      [ ("uniform",   ("Uniform",   Uniform))
      , ("lognormal", ("LogNormal", LogNormal))
      , ("normal",    ("Normal",    Normal))
      ]
    -- Apply a two-argument constructor, parsing both arguments as numbers.
    binary _ build [x, y] = build <$> num x <*> num y
    binary label _ args =
      Left ("Incorrect number of arguments to '"++label++"': "++show args)

-- | Parse a whole string as a 'Double'; any unparsed remainder (including
--   trailing whitespace) is an error.
num :: String -> Either String Double
num str
  | [(x, "")] <- reads str = Right x
  | otherwise              = Left ("Couldn't parse '"++str++"' as a number")

-- | Read a distribution via 'parseDist', always consuming the whole input.
--   NOTE(review): on a parse failure this calls 'error' rather than
--   returning @[]@, so 'reads' / 'readMaybe' crash instead of failing
--   gracefully; the trade-off is the descriptive message shown by 'read'.
instance Read Distribution where
  readsPrec _ = either complain (\d -> [(d, "")]) . parseDist
    where
      complain e = error ("Couldn't understand the distribution you specified:\n"++e)

-- | Build an empirical probability distribution from probability masses
--   assigned to uniformly spaced points @start@, @start+h@, @start+2h@, ...
--   (@h@ is the spacing between points, not a number of points per unit).
--
--   The masses are normalised by their sum and accumulated into a cumulative
--   distribution that is stored on the cell boundaries halfway between the
--   points (the leading 0 sits at @start - h/2@).  The grid is re-indexed so
--   that index 0 holds the last cumulative value still below 0.5, which
--   roughly centres the distribution on its median (pvalue 0.5).  This
--   assumes non-negative masses.
--
--   NOTE(review): if a non-empty list of masses sums to zero, the
--   normalisation divides by zero and the resulting table contains NaNs.
fromPdf :: Double -> Double -> [Prob] -> Distribution
fromPdf start h ps = Empirical mu h (listArray (a, b) (0 : cumulative))
  where
    scale = sum [ x | P x <- ps ]
    -- Running totals of the normalised masses, cut off at (and always
    -- ending with) exactly 1.
    cumulative = acc 0 [ x / scale | P x <- ps ]
    acc _ [] = [1]
    acc c (x:xs) = let v = c + x
                   in if v >= 1 then [1] else v : acc v xs
    -- Number of cumulative values strictly below the median.
    k = length (takeWhile (< 0.5) cumulative)
    mu = start + fromIntegral k * h - h / 2
    -- Index of the leading 0 entry.  This equals floor ((start-mu)/h) in exact
    -- arithmetic, but is computed without floating-point cancellation, which
    -- could otherwise misplace the table when |start| is large relative to h.
    a = negate k
    b = a + length cumulative

-- | Inverse of the cumulative distribution function (the quantile function):
--   returns the value whose 'pvalue' is the given probability.  The normal
--   family is inverted numerically by bisection ('invcumnorm'); the uniform
--   case is closed-form.
invcum :: Distribution -> Prob -> Double
invcum (Normal mu sigma) (P z) = invcumnorm mu sigma z
invcum (LogNormal mu sigma) (P z) = exp $ invcum (Normal mu sigma) (P z)
invcum (Uniform a b) (P z) = a+(b-a)*z
-- Invert the same piecewise-linear CDF that 'pvalue' uses for Empirical
-- distributions.  Bisecting the stored values through 'round' would search a
-- step function instead, so results would snap to half-way points between
-- grid positions and would not round-trip through 'pvalue'.
invcum (Empirical mu h cd) (P z)
    | hi <= lo  = mu + h * fromIntegral lo   -- single-entry (or empty) table
    | otherwise = mu + h * bisect interp (fromIntegral lo) (fromIntegral hi) z
  where
    (lo, hi) = bounds cd
    -- Linear interpolation between neighbouring grid points.  The lower
    -- index is clamped so that t == hi still reads a valid pair of entries.
    interp t = let i  = max lo (min (hi-1) (floor t))
                   y1 = cd ! i
                   y2 = cd ! (i+1)
               in y1 + (y2-y1) * (t - fromIntegral i)


-- | Calculate the probability of sampling a value less than @x@
--   (i.e. the value of the cumulative distribution function at @x@).
pvalue :: Distribution -> Double -> Prob
pvalue (Normal mu sigma) x = P (cumnorm mu sigma x)
pvalue (Uniform a b) x | x <= a    = P 0
                       | x >= b    = P 1
                       | otherwise = P ((x-a)/(b-a))
-- A log-normal variable is strictly positive, so its CDF is 0 at and below 0.
-- Only non-positive x need special-casing ('log' gives -Infinity or NaN
-- there).  For any positive x, however small, 'cumlognorm' is well defined,
-- and cutting off at a fixed threshold would be wrong for distributions
-- concentrated near 0 (e.g. a very negative mu).
pvalue (LogNormal mu sigma) x | x <= 0    = P 0
                              | otherwise = P (cumlognorm mu sigma x)
-- Linear interpolation of the stored cumulative values; entry i sits at
-- mu + i*h, and values outside the table clamp to 0 or 1.
pvalue (Empirical mu h cd) x = let (a,b) = bounds cd
                                   x' = (x-mu)/h
                               in if x' < fromIntegral a then P 0
                                  else if x' >= fromIntegral b then P 1
                                    else let x1 = floor x'
                                             x2 = x1+1
                                             y1 = cd!x1
                                             y2 = cd!x2
                                         in P (y1 + (y2-y1)*(x'-fromIntegral x1))

-- general functions for sampling

-- | Draw a single value from a distribution by pushing a random probability
--   through its inverse CDF ('invcum').
sample :: RandomGen g => Distribution -> Rand g Double
sample d = do
    p <- getRandom
    return (invcum d p)

-- | An infinite list of samples from a distribution.
--   TODO: the specialised cases could eventually be replaced by the generic
--   inverse-CDF construction used for 'Empirical' below.
samples  :: RandomGen g => Distribution -> Rand g [Double]
samples (Normal mu sigma) = normals mu sigma
samples (LogNormal mu sigma) = map exp `fmap` samples (Normal mu sigma)
samples (Uniform a b) = getRandomRs (a,b)
-- No specialised sampler exists for empirical distributions, so push a stream
-- of random probabilities through the inverse CDF, exactly as 'sample' does
-- for a single draw.
samples d@(Empirical {}) = map (invcum d) `fmap` getRandoms

-- | The probability density function of a distribution, evaluated at @x@.
pdf :: Distribution -> Double -> Double -- Prob?
pdf dist x = case dist of
    Normal mu sigma ->
        exp (negate (square (x-mu) / (2*square sigma))) / (sigma*sqrt (2*pi))
    LogNormal mu sigma
        | x > 0     -> exp (negate (square (log x - mu) / (2*square sigma))) / (x*sigma*sqrt (2*pi))
        | otherwise -> 0
    -- Support is open on the right, [a, b), matching the half-open range
    -- check used for Empirical distributions below.
    Uniform a b
        | x >= a && x < b -> 1/(b-a)
        | otherwise       -> 0
    -- Slope of the piecewise-linear cumulative table over the cell holding x.
    Empirical mu h cd ->
        let (lo, hi) = bounds cd
            t = (x-mu)/h
            i = floor t
        in if t < fromIntegral lo || t >= fromIntegral hi
             then 0
             else (cd ! (i+1) - cd ! i) / h
-- ------------------------------
-- Specifics
-- ------------------------------
-- | The square of a number.
square :: Double -> Double
square x = x ^ (2 :: Int)

-- | One draw from the standard normal distribution (mean 0, stdev 1).
stdnormal :: RandomGen g => Rand g Double
stdnormal = uncurry normal stdParams

-- | A list of draws from the standard normal distribution.
stdnormals :: RandomGen g => Rand g [Double]
stdnormals = uncurry normals stdParams

-- Mean and standard deviation of the standard normal distribution.
stdParams :: (Double, Double)
stdParams = (0, 1)

-- | One draw from a normal distribution with mean @mu@ and standard
--   deviation @sigma@, obtained by inverting the normal CDF at a random
--   value drawn with @getRandomR (0,1)@.
normal :: RandomGen g => Double -> Double -> Rand g Double
normal mu sigma = invcumnorm mu sigma <$> getRandomR (0,1)

-- | An infinite list of draws from a normal distribution with mean @mu@ and
--   standard deviation @sigma@.
--
--   The uniform inputs come from 'getRandomRs', which splits the generator,
--   so the surrounding 'Rand' computation can keep drawing random numbers
--   afterwards.  A directly recursive definition never finishes threading
--   the generator, so any draw sequenced after it would never terminate.
normals :: RandomGen g => Double -> Double -> Rand g [Double]
normals mu sigma = map (invcumnorm mu sigma) <$> getRandomRs (0,1)

-- Numerical support: normal/log-normal CDFs, their inversion by bisection,
-- and the tolerances used.
-- | Inverse normal CDF: the value whose cumulative probability under
--   N(mu, sigma) is @z@.  Found by bisection on the zero-mean CDF over
--   @[-limit*sigma, limit*sigma]@ and then shifted by @mu@, so results are
--   clamped to within 'limit' standard deviations of the mean.
invcumnorm :: Double -> Double -> Double -> Double
invcumnorm mu sigma z = mu + offset
  where
    offset = bisect (cumnorm 0 sigma) lo hi z
    hi     = limit * sigma
    lo     = negate hi

-- | Solve @f x = z@ for @x@ in the interval @[a, b]@ by bisection.  The
--   search moves left when @f@ at the midpoint exceeds @z@, i.e. it assumes
--   @f@ is non-decreasing on the interval.
--
--   It stops and returns the midpoint when any of these holds:
--
--   * @f@ there is within @10*epsilon@ of @z@;
--   * the interval is narrower than 'epsilon';
--   * the midpoint rounds to one of the endpoints, so no further progress is
--     possible in floating point.
--
--   The last check matters for magnitudes above roughly 1e6, where adjacent
--   doubles are more than 'epsilon' apart.  Without it, a target outside
--   @f@'s range (e.g. a probability beyond the CDF value at the search
--   limit) could recurse forever.
bisect :: (Double -> Double) -> Double -> Double -> Double -> Double
bisect f a b z
    | abs (z - cn) < 10*epsilon || abs (a-b) < epsilon = c
    | c == a || c == b = c
    | cn > z    = bisect f a c z
    | otherwise = bisect f c b z
  where
    c  = (a+b)/2
    cn = f c

-- | Cumulative distribution function of N(mu, sigma) evaluated at @x@,
--   expressed through the error function.
cumnorm :: Double -> Double -> Double -> Double
cumnorm mu sigma x = 0.5 * (1 + erf u)
  where u = (x-mu) / (sigma * sqrt 2)

-- | Cumulative distribution function of a log-normal distribution (whose
--   logarithm is N(mu, sigma)) evaluated at @x@.  Non-positive @x@ give a
--   NaN or -Infinity logarithm, so callers should handle them first.
cumlognorm :: Double -> Double -> Double -> Double
cumlognorm mu sigma x = 0.5 + 0.5 * erf u
  where u = (log x - mu) / (sigma * sqrt 2)

-- | Absolute tolerance for 'bisect': the interval-width cut-off, and (times
--   10) the acceptable distance between @f@ at the midpoint and the target.
epsilon :: Double
epsilon = 1e-10

-- | Half-width, in standard deviations, of the interval that 'invcumnorm'
--   searches.
limit :: Double
limit = 4.4

-- | Make a distribution wider by expanding its spread by the factor @d@
--   (or narrower, if @d@ is less than 1), keeping its location parameter
--   fixed:
--
--   * 'Normal': scales the standard deviation.
--   * 'LogNormal': scales @sigma@ of the underlying normal (in log space),
--     not the standard deviation of the log-normal itself.
--   * 'Uniform': scales the width about the midpoint, which scales the
--     standard deviation @(b-a)/sqrt 12@ by @d@.
--   * 'Empirical': scales the grid spacing, stretching the table about @mu@.
worsen :: Double -> Distribution -> Distribution
worsen d (Normal mu sigma) = Normal mu (sigma*d)
worsen d (LogNormal mu sigma) = LogNormal mu (sigma*d)
-- Move each endpoint outward by (d-1) half-widths.  This is written relative
-- to the endpoints (rather than recomputed from the midpoint) so that d == 1
-- returns the original bounds exactly.
worsen d (Uniform a b) = let ext = (d-1)*(b-a)/2 in Uniform (a-ext) (b+ext)
worsen d (Empirical mu h ds) = Empirical mu (h*d) ds
