-- | A small module for sampling from simple statistical distributions and evaluating their cumulative distribution functions.

module Statistics (sample, samples, normal, normals, stdnormal, stdnormals, Distribution(..)
                  , module Control.Monad.Random 
                  ) where
import System.Random
import Control.Monad.Random
import Data.Array.Unboxed

-- | A probability value.  Its 'Random' instance draws the wrapped Double
--   with @randomR (0, 1)@; 'invcum' turns it into a sample of a
--   'Distribution' and 'pvalue' produces one from a sample.
newtype Prob = P Double
  deriving (Show)

instance Random Prob where
    -- Draw a Double with randomR over (0, 1) and tag it as a Prob.
    random g = (P u, g')
      where
        (u, g') = randomR (0, 1) g
    -- A probability is always drawn from the unit interval, so a
    -- caller-supplied range is rejected.
    randomR = error "Specifying a range for a probability makes no sense."

-- | The distributions this module can sample from and evaluate.
--
--   Derives 'Show' and 'Eq' so values can be printed and compared.
--
--   TODO: add more, e.g. StudentsT, SkewNormal, SkewT.
data Distribution
  = Normal Double Double     -- ^ mean (mu) and standard deviation (sigma)
  | LogNormal Double Double  -- ^ mu and sigma of the normal distribution of the logarithm
  | Uniform Double Double    -- ^ lower and upper bound (from/to), both inclusive
  deriving (Eq, Show)

-- Intended property: for every distribution d,
--   pvalue d . invcum d == id   (up to numerical precision)

invcum :: Distribution -> Prob -> Double
invcum (Normal mu sigma) (P z) = invcumnorm mu sigma z
invcum (LogNormal mu sigma) (P z) = exp $ invcum (Normal mu sigma) (P z)
invcum (Uniform a b) (P z) = a+(b-a)*z

-- invcum _ _ = error "The inverse cumulative distribution is undefined"

-- | Calculate the probability of sampling a value less than @x@,
--   i.e. the cumulative distribution function evaluated at @x@.
pvalue :: Distribution -> Double -> Prob
pvalue (Normal mu sigma) x = P (cumnorm mu sigma x)
pvalue (Uniform a b) x
    -- The CDF is flat outside [a,b].  The linear formula alone would give
    -- values below 0 or above 1 there, and 0/0 when a == b.
    -- Assumes a <= b, matching the from/to order of 'Uniform'.
    | x <= a    = P 0
    | x >= b    = P 1
    | otherwise = P ((x - a) / (b - a))
pvalue (LogNormal mu sigma) x = P (cumlognorm mu sigma x)

-- pvalue _ _ = error "The cumulative distribution is undefined"

-- general functions for sampling

-- | Draw a single sample from a distribution by inverse-transform
--   sampling: a random 'Prob' is pushed through 'invcum'.
sample :: RandomGen g => Distribution -> Rand g Double
sample d = do
    p <- getRandom
    return (invcum d p)

-- sample (Uniform a b) = getRandomR (a,b)
-- | An infinite list of independent samples from the distribution.  Each
--   sample is a random 'Prob' mapped through 'invcum', the same transform
--   'sample' uses, so this works for every 'Distribution'.
--
--   The uniform stream comes from 'getRandoms', which MonadRandom
--   implements by splitting the generator.  Actions sequenced after
--   'samples' therefore still get a usable generator.  A
--   draw-one-then-recurse definition never finishes, which leaves any
--   following action without a generator state.
samples :: RandomGen g => Distribution -> Rand g [Double]
samples d = map (invcum d) `fmap` getRandoms

-- ------------------------------
-- Specifics
-- ------------------------------
-- | One sample from the standard normal distribution (mean 0, standard
--   deviation 1).
stdnormal :: RandomGen g => Rand g Double
stdnormal = normal 0.0 1.0

-- | Infinite list of samples from the standard normal distribution.
stdnormals :: RandomGen g => Rand g [Double]
stdnormals = normals 0.0 1.0

-- | One sample from a normal distribution with mean @mu@ and standard
--   deviation @sigma@.  A value drawn with @getRandomR (0, 1)@ is
--   mapped through the normal quantile function.
normal :: RandomGen g => Double -> Double -> Rand g Double
normal mu sigma = fmap (invcumnorm mu sigma) (getRandomR (0, 1))

-- | Infinite list of samples from a normal distribution with mean @mu@
--   and standard deviation @sigma@.
--
--   The uniform values come from 'getRandomRs', which MonadRandom
--   implements by splitting the generator.  This means actions sequenced
--   after 'normals' still receive a generator.  The previous recursive
--   definition had no base case, so its final generator state was never
--   produced and any later random action diverged.
normals :: RandomGen g => Double -> Double -> Rand g [Double]
normals mu sigma = map (invcumnorm mu sigma) `fmap` getRandomRs (0, 1)

-- | One sample from a log-normal distribution: the exponential of a
--   normal sample with mean @mu@ and standard deviation @sigma@.  This is
--   the same relationship 'invcum' uses for 'LogNormal'.
lognormal :: RandomGen g => Double -> Double -> Rand g Double
lognormal mu sigma = exp `fmap` normal mu sigma

-- support

-- | Quantile of the normal distribution with mean @mu@ and standard
--   deviation @sigma@ at probability @z@, found by bisection.
--
--   The centred value is searched in [-limit*sigma, limit*sigma].  The
--   interval is halved until the CDF at the midpoint is within
--   @10*epsilon@ of @z@, or the interval is narrower than @epsilon@.  The
--   result is then shifted by @mu@.  Probabilities whose quantile lies
--   outside the search interval end up near its endpoints.
--
--   NOTE(review): assumes sigma > 0.  For negative sigma the CDF is
--   decreasing in the search variable and the bisection moves the wrong
--   way.
invcumnorm :: Double -> Double -> Double -> Double
invcumnorm mu sigma z = mu + bisect (negate (limit * sigma)) (limit * sigma)
  where
    bisect lo hi
      | closeEnough = mid
      | cdfMid > z  = bisect lo mid
      | otherwise   = bisect mid hi
      where
        mid         = (lo + hi) / 2
        cdfMid      = cumnorm 0 sigma mid
        closeEnough = abs (z - cdfMid) < 10 * epsilon || abs (lo - hi) < epsilon

-- | Cumulative distribution function of the standard normal distribution
--   (mean 0, standard deviation 1).
cumstdnorm :: Double -> Double
cumstdnorm = cumnorm 0 1

-- | Cumulative distribution function at @x@ of the normal distribution
--   with mean @mu@ and standard deviation @sigma@, expressed through 'erf'.
cumnorm :: Double -> Double -> Double -> Double
cumnorm mu sigma x = 0.5 * (1 + erf scaled)
  where
    scaled = (x - mu) / (sigma * sqrt 2)

-- | Cumulative distribution function at @x@ of the log-normal distribution
--   whose logarithm is normal with mean @mu@ and standard deviation
--   @sigma@.
--
--   A log-normal variable is strictly positive, so the CDF is 0 for
--   x <= 0.  This is handled explicitly because 'log' yields -Infinity
--   (x == 0) or NaN (x < 0) there, which would give a wrong result.
cumlognorm :: Double -> Double -> Double -> Double
cumlognorm mu sigma x
    | x <= 0    = 0
    | otherwise = 0.5 + 0.5 * erf ((log x - mu) / (sigma * sqrt 2))

-- | Error function, computed from its Taylor series (see Wikipedia,
--   "Error function").  The series was tested within (-limit..limit);
--   beyond +limit the result is saturated to 1.
--
--   erf is odd (erf (-x) == - erf x), so negative arguments are reduced to
--   positive ones.  Without this reduction, arguments below -limit would
--   saturate to 0 instead of -1, making e.g. 'cumnorm' return 0.5 far
--   below the mean.
erf :: Double -> Double
erf x
    | x < 0     = negate (erf (negate x))
    | x > limit = 1
    | otherwise = (2 / sqrt pi) * sum (reverse terms)
  where
    -- Series terms (-1)^n * x^(2n+1) / (n! * (2n+1)), kept while their
    -- magnitude is at least epsilon.  They are summed in reverse order,
    -- presumably so the small tail terms are added first.
    terms = takeWhile ((>= epsilon) . abs) [term n | n <- [0 ..]]
    term n = let n' = fromIntegral n
             in ((-1) ** n' * x ** (2 * n' + 1)) / (fac n * (2 * n' + 1))

-- | Absolute tolerance.  'erf' drops series terms smaller than this.
--   'invcumnorm' stops bisecting once its interval is narrower than this,
--   or once the CDF at the midpoint is within ten times this of the target.
epsilon :: Double
epsilon = 1e-10

-- | Cut-off shared by 'erf' (arguments beyond +/-limit are treated as
--   saturated) and 'invcumnorm' (the bisection searches within
--   +/- limit*sigma).
limit :: Double
limit = 4.4

-- | Factorial of a non-negative Int, as a Double.  Arguments 0..127 are
--   looked up in 'ftab'.  Larger ones multiply 127! by the remaining
--   factors.
fac :: Int -> Double
fac x
  | x >= 128  = ftab ! 127 * product [128 .. fromIntegral x]
  | otherwise = ftab ! x

-- | Precomputed table of n! for n in [0, 127], as Doubles.  It is built as
--   a running product (1, 1*1, 1*1*2, ...), multiplying left to right.
ftab :: UArray Int Double
ftab = listArray (0, 127) (scanl (*) 1 [1 .. 127])