
module Main () where

import Bio.Sequence
import Bio.Util (countIO)

import System.Environment (getArgs)
import Data.List (maximumBy)
import System.IO
-- import Data.ByteString.Lazy.Char8 (ByteString)
import Text.Printf

import BloomFilter

-- read files, construct 1 bf each, classify sequences from stdin
-- (every command-line argument is a file handed to 'genfilter'; at least one
-- is required, because classification picks the best of the filters)
main :: IO ()
main = do
  xs <- getArgs
  -- Fail fast with a readable message: with no filters, classify's
  -- maximumBy has nothing to choose from (and main_mask rejects it), so the
  -- program would otherwise die with an opaque error on the first sequence.
  if null xs
    then ioError (userError "no filter files given: pass one or more files as arguments; sequences are read from stdin")
    else return ()
  bs <- mapM genfilter xs
  hPutStr stderr $ unlines $ zipWith (\ n b -> (n++": "++b)) xs $ map show bs
  ss <- countIO "Processing.." "..done.\n" 100 =<< hReadFasta stdin

  -- Choose one of three:
  main_single xs bs ss
  -- main_separate xs bs ss
  -- main_mask xs bs ss

main_single, main_separate, main_mask :: [FilePath] -> [BloomFilter] -> [Sequence] -> IO ()
-- Pair every filter with its file name and print the best-matching filter
-- for each sequence on stdout.
main_single names filters seqs = classify (zip names filters) to_stdout seqs

-- Classify every sequence and write each result either to "<filter file>.out"
-- for its best-matching filter, or to "unmatched.out" when 'separately'
-- rejects the hit.  Every output file is opened with withFile, so each one is
-- closed (and its buffer flushed) even if opening a later file or the
-- classification itself fails part-way through.
main_separate xs bs ss =
  withOutputs xs $ \hs ->
    withFile "unmatched.out" WriteMode $ \h ->
      classify (zip xs bs) (separately h hs) ss
  where
    -- Open "<name>.out" for writing for each name (in order), run the
    -- continuation with the (name, handle) pairs, and close every handle
    -- opened so far when the continuation finishes or throws.
    withOutputs :: [FilePath] -> ([(FilePath, Handle)] -> IO r) -> IO r
    withOutputs [] act = act []
    withOutputs (x:rest) act =
      withFile (x++".out") WriteMode $ \hx ->
        withOutputs rest (\hs -> act ((x,hx):hs))

main_mask _ [bf] = mapM_ maskOne
    where maskOne :: Sequence -> IO ()
          -- Print the running count of True results returned by 'matches'
          -- for this sequence: the list starts at 0 and grows by one at
          -- every True.
          maskOne sq = print (scanl (+) 0 (map fromEnum (matches bf sq)) :: [Int])
          -- todo: pick median? modal? avg? and chop off extreme ends?
main_mask _ _ = error "only one filter allowed when masking"

-- | What to do with the classification result for one sequence.  The triple
-- holds the query sequence, the file name of the best-scoring filter, and
-- that filter's match count (as built by 'classify').
type Action = (Sequence,FilePath,Int) -> IO ()

-- | Print one classification result, together with its statistics, on stdout.
to_stdout :: Action
to_stdout hit@(s,_,ms) = putStrLn (show' hit (calc_stats s ms))

-- | Render one classification result as a single line.  The line contains
-- the query label, the matching filter, the raw score, the expected
-- false-positive adjustment with its standard deviation, the p-value, the
-- adjusted score and the coverage percentage.
show' :: (Sequence, String, Int) -> (Double, Double, Int, Double,Double) -> String
show' (query, target, rawScore) (adj, sd, adjScore, coverage, pv) = header ++ details
    where
      header :: String
      header = toStr (seqlabel query) ++ " matches: " ++ target
      details :: String
      details = printf " \tbase: %d \tadjustment: %.1f±%.1f\t pval: %.3f score: %d\t cov: %.1f%%"
                       rawScore adj sd pv adjScore coverage

-- | Statistics for a query whose best filter scored @score@ hits:
-- (expected false positives, their standard deviation, adjusted score,
--  coverage in percent, p-value).
calc_stats :: Sequence -> Int -> (Double, Double, Int, Double,Double)
calc_stats qry score = (adjust,stdev,adj_score,cov,pval)
    where
      -- cov is a lower bound for query sequence coverage (percent)
      cov = if score == 0 then 0 
            else max 0 (100 * (fromIntegral (k-skipAmount+adj_score)/fromIntegral (seqlength qry))) :: Double
      -- adj_score is (a lower bound for) the number of matching positions
      adj_score = {- k + -} skipAmount*(score-1) - floor (adjust*skipAmount) :: Int
      -- n is the number of trials,
      -- complicated by the fact that we skip (skipAmount-1) whenever we encounter a 'True'.
      -- The estimate subtracts a full skip for every hit, so it can go negative
      -- (e.g. short queries with hits near the end); a negative number of
      -- trials is meaningless, so clamp it at zero.
      n = max 0 (fromIntegral (seqlength qry -k + 1) - fromIntegral (score*(skipAmount-1))) :: Double
      -- adjust is the expected number of false positives (binomial distribution)
      adjust = n*p :: Double
      -- stdev is the standard deviation for adjust (also binomial)
      stdev  = sqrt (n*p*(1-p)) :: Double
      -- whole number of trials, for the exact binomial terms below
      trials = floor n :: Int
      -- P(X = i) for X ~ Binomial(trials, p).  It is zero for i > trials; this
      -- case is handled explicitly because (1-p)^(trials-i) would otherwise
      -- abort with "Negative exponent" ((^) rejects negative powers).
      prob :: Int -> Double
      prob i | i > trials = 0
             | otherwise  = fromIntegral (trials `choose` i) * p^i * (1-p)^(trials-i)
      -- calculate p-value (prob of getting this high results by chance):
      -- cdf !! j = P(X < j); only the first 10 entries are computed
      cdf = take 10 $ scanl (+) 0 (map prob [0..])
      pval = case drop score cdf of
               (below:_) -> 1 - below
               []        -> 0   -- score >= 10: tail not computed, reported as 0
      -- Approximate using Chernoff:
      --    pval = exp (-(n*(1-p)-(n-fromIntegral score))^2/(2*(1-p)*n))
      -- using Hoeffding:
      --    pval = exp (-2*(n*(1-p)-(n-fromIntegral score))^2/n)

-- | Binomial coefficient, computed exactly as an 'Integer': the falling
-- factorial n*(n-1)*...*(n-k+1) divided by k!.
choose :: Int -> Int -> Integer
choose n' k' = fallingFactorial `div` factorial
  where
    top = toInteger n'
    r   = toInteger k'
    fallingFactorial = product [top - r + 1 .. top]
    factorial        = product [1 .. r]

-- | Route each classification result to an output handle.  The result goes to
-- the handle registered for its best-matching filter, provided it has at
-- least one match and its p-value is not above 0.05.  Otherwise, or when no
-- handle is registered for that filter, it goes to @h_default@.
separately :: Handle -> [(FilePath,Handle)] -> Action
separately h_default hs (s,f,ms) = hPutStrLn target (show' (s,f,ms) stats)
  where
    stats@(_,_,_,_,pv) = calc_stats s ms
    target | ms == 0 || pv > 0.05 = h_default
           | otherwise            = maybe h_default id (lookup f hs)

-- | For every sequence, score it against each (file, filter) pair with
-- 'matchcount', keep the highest-scoring triple (sequence, file, score), and
-- pass it to the action.
-- possibility: reverse the bfs list each time (saves one cache flushing per cycle)
classify :: [(FilePath,BloomFilter)] -> Action -> [Sequence] -> IO ()
classify bfs action ss = mapM_ (\s -> action (bestHit s)) ss
    where
      bestHit s = maximumBy (\a b -> trd a `compare` trd b)
                            [ (s, file, matchcount ft s) | (file, ft) <- bfs ]

-- | Third component of a triple.
trd :: (a,b,c) -> c
trd triple = case triple of
  (_, _, third) -> third

-- | Combine two values with a binary function after projecting both through
-- the same key function, e.g. @compare `on` fst@.
on :: (t1 -> t1 -> t2) -> (t -> t1) -> t -> t -> t2
on combine key = \l r -> combine (key l) (key r)