\section{RBR - RepeatBeater revisited}

RBR examines word frequencies in EST data and, based on the assumption
that EST coverage (expression) is relatively uniform across the positions
within each gene, identifies extremely common words as potential repeats.

A predecessor to RBR was implemented as a set of scripts parsing
\texttt{xsact} output; this is a revised, standalone implementation.

\begin{code}
{-# LANGUAGE CPP #-}

module Main (main) where

import Bio.Sequence

import Lib.FreqTable
import Lib.Stats    hiding (median,uv_add)
import Lib.Util     (interactl)
import Lib.Options

-- import Text.Printf
import qualified Data.ByteString.Lazy.Char8 as B (map)

import System.IO
import System.IO.Unsafe
import Control.Monad
import Data.List as List
import Data.Char as Char

import Data.Map hiding (null,filter,map,elems,keys)

#define perror (\s->error ("Program error in function \""++s++"\" at "++__FILE__++":"++show (__LINE__::Int)++\
"\nPlease report this incident to <ketil@ii.uib.no>."))

-- | Program entry point.  Parses the command line; with no file argument
-- the FASTA input is read once from stdin, otherwise main_real receives an
-- action that reads the named file each time it is executed.
main :: IO ()
main = do
  (opt1, non, err) <- getOptions
  -- print (opt1,non,err)
  when ((null non && null opt1) || not (null err)) $ error (usage err)
  opts <- parseargs opt1
  case non of
    [] -> do
      ss <- nucs (hReadFasta stdin)
      main_real opts ("<stdin>", return ss)
    [f1] -> main_real opts (f1, nucs (readFasta f1))
    fs -> error ("Only supply one input file (you supplied: "++show fs++")")
  where
    -- convert every record read by the given action to nucleotides
    nucs act = fmap (map castToNuc) act

-- | One-line comment header describing the run parameters (mask mode, k,
-- stringency, threshold, gap and skip).  The input list is currently unused.
header :: Opts -> [FilePath] -> String
header opts _inputs = concat
    [ "# rbr version ", version
    , ", mask=", maskTag (lower opts)
    , " k=", show (kval opts)
    , " stringency=", show (string opts)
    , " thresh=", show (thresh opts)
    , " g=", show (gap opts)
    , " skip=", show (skip opts)
    ]
  where
    maskTag Lower = "L"
    maskTag Ns    = "n"
    maskTag Xs    = "X"

\end{code}

\subsection{The real main}

\begin{code}

-- | Write a diagnostic message to stderr, but only when the flag is set.
debug :: Bool -> String -> IO ()
debug p m
  | p         = hPutStr stderr m
  | otherwise = return ()

-- | Attach a progress counter to a list.  The returned list has the same
-- elements as @xs@, but demanding them prints @msg@ followed by a running
-- count to stderr every @step@ elements, and @post@ after the last one.
-- Each per-element action is deferred with 'unsafeInterleaveIO', so the
-- messages appear only as elements are forced.  Note that 'mapM' still
-- walks the whole spine of the list when countIO itself is run.
--
-- The counter reports the number of elements consumed so far.  An empty
-- input yields an empty list; the counter and @post@ are then printed
-- immediately instead of failing on @init []@.
countIO :: String -> String -> Int -> [a] -> IO [a]
countIO msg post step xs =
    case ct (0::Int) xs of
      []     -> blank >> outmsg (0::Int) >> output post >> return []
      (c:cs) -> mapM unsafeInterleaveIO ((blank >> outmsg (0::Int) >> c) : cs)
   where output   = hPutStr stderr
         blank    = output ('\r':replicate 70 ' ')
         outmsg x = output ('\r':msg++show x) >> hFlush stderr
         -- Split the list into chunks of 'step' elements; the last element
         -- of each chunk carries the counter update.  ct never returns an
         -- empty list for a non-empty input, and in the [] case below the
         -- chunk a is non-empty, so init/last are safe there.
         ct _ [] = []
         ct s ys = let (a,b) = splitAt (step-1) ys
                       next  = s+step
                   in case b of
                        [b1]    -> map return a ++ [outmsg next >> output post >> return b1]
                        []      -> map return (init a) ++ [outmsg (s+length a) >> output post >> return (last a)]
                        (b1:bs) -> map return a ++ [outmsg next >> return b1] ++ ct next bs

-- | Print the header (verbose or distribution mode) and, when verbose, the
-- number of input sequences; then continue with Int table keys for
-- k <= 16 and Integer keys otherwise.
main_real :: Opts -> (String, IO [Sequence Nuc]) -> IO ()
main_real opts (ifile, ssio) = do
   let verbose = verb opts
   when (verbose || distr opts) $ putStrLn (header opts [ifile])
   when verbose $ do
      debug verbose ("Reading file: "++ifile++"\n")
      nseqs <- fmap length ssio
      debug verbose ("Number of sequences: "++show nseqs++"\n")
   if kval opts > 16
      then main_cont make_ft_integer opts ssio
      else main_cont make_ft_int opts ssio

-- | Dispatch on the run mode.
--
--   * server mode: index the input and answer label queries interactively;
--   * distribution mode: print per-sequence coverage statistics;
--   * otherwise: mask the input and write it out.
--
-- The masking function (which may index the sample file and every repeat
-- library) is only built in masking mode, where it is actually used.  In
-- the other modes it used to be built and discarded; in verbose mode that
-- meant reading those files for nothing.
main_cont :: Integral a => MkFT a -> Opts -> IO [Sequence Nuc] -> IO ()
main_cont make_ft opts ssio
   | server opts = do
        ss'' <- counted
        ft <- make_ft "Building index array" (verb opts) (skip opts) (kval opts) ssio
        interactl "READY: " (concatMap (main_server opts ft ss'') . words)
   | distr opts = main_coverage make_ft opts ssio
   | otherwise = do
        mask1 <- mask_func make_ft opts ssio
        ss'' <- counted
        main_masked opts mask1 ss''
   where
     -- the input sequences, wrapped with a progress counter when verbose
     -- (the label is kept as before, including in server mode)
     counted | verb opts = countIO "  ...masking sequences: " ", done.\n" 100 =<< ssio
             | otherwise = ssio

-- | A table builder: progress label, verbose flag, skip (sparse step,
-- 0 for a dense table), word length k, and the action yielding the
-- sequences to index.
type MkFT a = String -> Bool -> Int -> Int -> IO [Sequence Nuc] -> IO (FreqTable a)

-- | Table builder with Int keys (chosen by main_real when k <= 16).
make_ft_int :: MkFT Int
make_ft_int msg vb skp k ssio = fmap build (withProgress =<< ssio)
  where
    withProgress | vb        = countIO ("  "++msg++": ") ", done.\n" 100
                 | otherwise = return
    build ss | skp > 0   = sparsetable_int skp (rcontig k) ss
             | otherwise = freqtable_int (rcontig k) ss

-- | Table builder with Integer keys (chosen by main_real when k > 16).
make_ft_integer :: MkFT Integer
make_ft_integer msg vb skp k ssio = do
  seqs <- if not vb then ssio
                    else ssio >>= countIO ("  "++msg++": ") ", done.\n" 100
  return $ if skp > 0 then sparsetable_integer skp (rcontig k) seqs
                      else freqtable_integer (rcontig k) seqs

-- | Distribution mode.  For every input sequence, print one line with the
-- shown triple (header, (mean, stdev, threshold), raw coverage list) to
-- the output file, or to stdout when none is given.  The statistics are
-- (0,0,0) when the sequence has no usable coverage.
main_coverage :: Integral a => MkFT a -> Opts -> IO [Sequence Nuc] -> IO ()
main_coverage make_ft opts ssio = do
    ft <- make_ft "Building index array" (verb opts) (skip opts) (kval opts) ssio
    debug (verb opts) "Calculating coverage\n"
    ss <- ssio
    emit $ unlines [ show (seqStats ft s) | s <- ss ]
  where
    k  = kval opts
    sk = skip opts
    (stdevs, b2, strg) = (thresh opts, 5, string opts)
    emit = maybe putStrLn writeFile (ofile opts)
    seqStats tbl s =
      let cv        = coverage k sk (rcontig k) tbl s
          ls        = sortv (blunt b2 [ (fromIntegral (head g), length g) | g <- group (sort cv) ])
          (mu,stdv) = distrib strg ls
          t         = mu + max b2 (stdevs*stdv)
          summary | null cv || null ls = (0,0,0)
                  | otherwise          = (mu,stdv,t)
      in (seqheader s, summary, cv) -- todo: round!

-- | Masking mode output: apply the masking function to every sequence
-- (skipped entirely when k is 0) and write FASTA to the output file, or to
-- stdout when none is given.
main_masked :: Opts -> (Sequence a->Sequence a) -> [Sequence a] -> IO ()
main_masked opts mf ss =
    maybe (hWriteFasta stdout) writeFasta (ofile opts) $
      if kval opts == 0 then ss else map mf ss

-- | Build the per-sequence masking function.
--
-- Statistical masking indexes the sample file when one is given, otherwise
-- the input itself when statistics are enabled.  Each repeat library adds
-- a dictionary mask.  The whitelists are ANDed (combine), short masked runs
-- are dropped (min_mask), gaps are closed when gap > 0, and the result is
-- applied with mask_generic.  Sequences are upcased first unless lower case
-- is preserved; with no masks at all, only the upcasing is done.
mask_func :: Integral a => MkFT a -> Opts -> IO [Sequence Nuc] -> IO (Sequence Nuc -> Sequence Nuc)
mask_func make_ft opts ssio = do
    statMasks <- case (masksample opts, stats opts) of
      (Just f, _) -> do
        ft <- make_ft "Indexing sample" (verb opts) sk k (readNucs f)
        return [mask_stat k sk sparm ft]
      (Nothing, True) -> do
        ft <- make_ft "Indexing input" (verb opts) sk k ssio
        return [mask_stat k sk sparm ft]
      (Nothing, False) -> return []
    libMasks <- forM (lib opts) $ \libf -> do
      lft <- make_ft "Indexing repeat library" (verb opts) sk k (readNucs libf)
      return (mask_table k sk lft)
    return (applyAll (statMasks ++ libMasks) . upcase)
  where
    k     = kval opts
    sk    = skip opts
    sparm = (thresh opts, 5, string opts)
    gc | gap opts > 0 = closegaps (gap opts + k - 1)
       | otherwise    = id
    upcase | preserve_lower opts = id
           | otherwise           = upcaseSequence
    readNucs f = fmap (map castToNuc) (readFasta f)
    -- width of one whitelist entry in characters (longer for sparse tables)
    applyAll [] = id
    applyAll ws = mask_generic (lower opts) (k + if sk == 0 then 0 else sk - 1)
                               (gc . min_mask ((sk + 2) `div` 2) . combine ws)

-- | Convert the sequence data to upper case; the other two fields of the
-- Seq constructor (presumably header and quality) are kept unchanged.
upcaseSequence :: Sequence a -> Sequence a
upcaseSequence (Seq hdr residues qual) = Seq hdr upper qual
  where upper = B.map toUpper residues

-- Sparse keys can give aberrant counts, so a masked run is only kept when
-- it is at least sk positions long; shorter runs of False are reset to True.
-- todo: really really convert to RLL-encoded whitelist
min_mask :: Int -> [Bool] -> [Bool]
min_mask sk = concatMap keep . group
  where
    keep run@(False:_) | length run < sk = map (const True) run
    keep run = run

combine :: [a -> [Bool]] -> (a -> [Bool])
combine []  = perror "combine"
-- Pointwise AND of all whitelists (truncated to the shortest one, as with
-- zipWith); an empty list of whitelist functions is a program error.
combine fs  = \s -> foldr1 (zipWith (&&)) [ f s | f <- fs ]


-- | Request handler for server mode.  Each request is a sequence label;
-- the reply is three lines (each rendered with show): the statistically
-- masked sequence, its raw coverage, and the unmasked sequence data.
-- An unknown label gives a single error line.
--
-- NOTE(review): unlike mask_func, this applies neither min_mask, the
-- repeat libraries, the sample index nor upcasing, so with a sparse index
-- (skip > 0) the masked line can differ from the batch output.  Confirm
-- whether that is intended.
main_server :: Integral a => Opts -> FreqTable a -> [Sequence Nuc] -> (String -> String)
main_server opts ft ss = respond
  where
    byLabel = fromList [ (seqlabel sq, sq) | sq <- ss ]
    k       = kval opts
    sk      = skip opts
    width   = if sk == 0 then k else k + sk - 1
    ps      = (thresh opts, 5, string opts)
    gc | gap opts > 0 = closegaps (k + gap opts - 1)
       | otherwise    = id
    respond s = case Data.Map.lookup (fromStr s) byLabel of
      Nothing -> "input error: "++s++" not found\n"
      Just sq -> unlines [ show (mask_generic (lower opts) width (gc . mask_stat k sk ps ft) sq)
                         , show (coverage k sk (rcontig k) ft sq)
                         , show (seqdata sq)
                         ]

\end{code}

Generic masking function.  The list of Bool is a whitelist, so that
True corresponds to retained nucleotides, while False indicates
nucleotides to be masked.

(TODO: Run-length encode the boolean list?)

\begin{code}

-- | Compute the whitelist for a sequence and mask its data accordingly
-- (via masked); the first and third Seq fields are kept unchanged.
mask_generic :: MaskWith -> Int -> (Sequence Nuc -> [Bool]) -> Sequence Nuc -> Sequence Nuc
mask_generic how width whitelistOf s@(Seq lbl dat qual) = Seq lbl newData qual
  where newData = fromStr (masked how width (whitelistOf s) (toStr dat))

\end{code}

Mask all words from a dictionary (i.e.\ a library of known repeats).

\begin{code}

-- | Whitelist from a repeat library: a position is kept (True) exactly when
-- its word has a zero count in the library table (this reuses coverage as
-- a lookup).
mask_table :: Integral a => Int -> Int -> FreqTable a -> Sequence Nuc -> [Bool]
mask_table k skp rlib = map (== 0) . coverage k skp (rcontig k) rlib

\end{code}

Masking sequences from frequency counts.

\begin{code}

-- (stdevs, b2, stringency): stdevs multiplies the standard deviation for
-- the threshold; b2 is both the blunt percentage and the minimum margin
-- above the mean; stringency is passed to distrib.
type MPars = (Double,Double,Double) -- stdevs thres, 1-elim, stringency

-- | Calculate the statistical whitelist.  A position is kept when its
-- coverage is at most mean + max b2 (stdevs * stdev) of the base
-- distribution.  Without usable coverage everything is kept.
mask_stat :: Integral a => Int -> Int -> MPars -> FreqTable a -> Sequence Nuc -> [Bool]
mask_stat k sk (stdevs,b2,strg) ft s
  | null cv || null ls = replicate (1 - k + fromIntegral (seqlength s)) True
  | otherwise          = [ fromIntegral c <= t | c <- cv ]
  where
    cv        = coverage k sk (rcontig k) ft s
    ls        = sortv (blunt b2 [ (fromIntegral (head g), length g) | g <- group (sort cv) ])
    (mu,stdv) = distrib strg ls
    t         = mu + max b2 (stdevs*stdv)

-- distrib calculates the base distribution: frequency classes are added,
-- in the given order (sortv puts the ones nearest the modal interval
-- first), until the score reaches the stringency or the classes run out.
-- Returns (mean, stdev) of the accumulated classes.
-- NOTE(review): if Lib.Stats defines stdev as the square root of variance,
-- the score variance/stdev is simply stdev; confirm whether variance/mean
-- (index of dispersion) was intended.
distrib :: Double -> [(Double,Int)] -> (Double,Double)
distrib strg ls = go (uv_mk []) ls
    where
          go _ [] = perror "distrib"
          go acc ((mag,cnt):rest)
              | score >= strg || null rest = (mean u, stdev u)
              | otherwise                  = go acc' rest
              where acc'  = uv_add mag cnt acc
                    u     = uniVar acc'
                    score = variance u / stdev u

-- s {seqannot = seqannot s
--              ++[UKV "Threshold" $ printf "%.2f" t]
--	      ,seqdata = listArray (bounds $ seqdata s) ms}

-- Compensate for sequencing errors (e.g. a 1% error rate): drop classes of
-- coverage 0, then reduce the count of the coverage-1 class by k percent
-- of the total count (never below zero).
-- This is inaccurate, but possibly a workable estimate nonetheless.
blunt :: Double -> [(Double,Int)] -> [(Double,Int)]
blunt k counts =
    case dropWhile ((== 0) . fst) counts of
      xs@((1, ones) : rest) -> (1, max 0 (ones - (floor k * total xs) `div` 100)) : rest
      xs                    -> xs  -- includes [] when the sequence contains no words
  where total ys = sum (map snd ys)

-- sort ls around the centre of the modal interval (closest first); the
-- sort is stable, so ties keep their original order
sortv :: [(Double,Int)] -> [(Double,Int)]
sortv ls = sortBy closerToCentre ls
  where centre             = scaled_modal ls
        dist (v,_)         = abs (v - centre)
        closerToCentre x y = compare (dist x) (dist y)

-- NB: 1 = error - chose mode ignoring 1s? (see C.e. seq. 24)

-- For each starting frequency n, form the window [n, n + sqrt n) of
-- occurrence frequencies; return the centre of the window that covers the
-- most positions.
scaled_modal :: [(Double,Int)] -> Double
scaled_modal ls = modal [ window suffix | suffix <- tails' ls ]

-- | Window starting at the first class: it collects the following classes
-- whose magnitude is below mag + sqrt mag.  The input is assumed to be
-- sorted by ascending magnitude.  Returns (centre of window, total count).
window :: [(Double,Int)] -> (Double,Int)
window [] = perror "window"
window (first@(mag,_):rest) = (centre, sum (map snd members))
   where top     = mag + sqrt mag
         members = first : takeWhile ((< top) . fst) rest
         centre  = (mag + fst (last members)) / 2

-- tails without the final empty suffix
tails' :: [a] -> [[a]]
tails' = takeWhile (not . null) . tails

{-
windowize ((mag,count):ls) =
    let (this,rest) = break ((> mag+sqrt mag) . fst) ls
    in (mag+sqrt mag,count+sum (map snd this)) : windowize rest
windowize [] = []
-}

-- | First component of the pair with the largest second component
-- (maximumBy semantics: the last of several equal maxima wins).
modal :: Ord b => [(a,b)] -> a
modal ps = fst (maximumBy bySnd ps)
  where bySnd (_,p) (_,q) = compare p q

-- | Median of a (value, count) histogram: the value at index total/2.
median :: [(Double,Int)] -> Double
median ls = nquant half ls
  where half = sum (map snd ls) `div` 2

-- | Value at index n of the histogram expanded by counts.
nquant :: Int -> [(Double,Int)] -> Double
nquant _ [] = perror "median of an empty list is impossible!"
nquant n ((v,c):vs) = if n < c then v else nquant (n - c) vs


-- Local replacements (intended to be faster) for the Lib.Stats uv and
-- uv_del accumulators.  A UV holds (count, sum x, sum x^2, sum x^3,
-- sum x^4) over samples with integer multiplicities.
type UV = (Int,Double,Double,Double,Double)

-- | Accumulate a whole (value, count) histogram.
uv_mk :: [(Double,Int)] -> UV
uv_mk [] = (0,0,0,0,0)
uv_mk ((m,cnt):rest) = case uv_mk rest of
    (n,x,x2,x3,x4) -> let c = fromIntegral cnt
                      in (n+cnt,x+m*c,x2+m*m*c,x3+m*m*m*c,x4+m*m*m*m*c)

-- | Remove (uv_sub) or add (uv_add) cnt copies of value m.
uv_sub, uv_add :: Double -> Int -> UV -> UV
uv_sub m cnt (n,x,x2,x3,x4) = (n-cnt, x-m*c, x2-m*m*c, x3-m*m*m*c, x4-m*m*m*m*c)
  where c = fromIntegral cnt

uv_add m cnt (n,x,x2,x3,x4) = (n+cnt, x+m*c, x2+m*m*c, x3+m*m*m*c, x4+m*m*m*m*c)
  where c = fromIntegral cnt

-- masked takes the mask type, the masking width (characters covered by
-- one whitelist entry), a whitelist (False means mask) and the
-- (nucleotides in the) sequence
masked :: MaskWith -> Int -> [Bool] -> String -> String
masked how = masked' maskChar 0
  where maskChar = case how of
          Lower -> toLower
          Ns    -> const 'n'
          Xs    -> const 'X'

-- mask using an arbitrary masking function
-- Arguments: the per-character masking function, the number of characters
-- still pending from the current flagged word (ns), the width k of one
-- flagged word in characters, the whitelist (one entry per word position,
-- False means mask) and the sequence characters.
-- When a word is flagged, its first character is masked at once and the
-- remaining k-1 characters are recorded in ns, so that we finish masking
-- out the whole over-represented word: exactly k characters.  Recording k
-- instead of k-1 would mask one character past the end of every word.
-- The sequence is expected to be k-1 characters longer than the whitelist;
-- those trailing characters are masked only while a word is still pending.
masked' :: (Char->Char) -> Int -> Int -> [Bool] -> String -> String
masked' _ _ _ (_:_) [] =
    perror "masked' (1)"
masked' mf ns _ [] es = map mf (take ns es) ++ drop ns es
masked' mf ns k (b:bs) (e:es)
    | ns == 0 && b  = e    : masked' mf 0 k bs es
    | ns > 0  && b  = mf e : masked' mf (ns-1) k bs es
    | not b         = mf e : masked' mf (max 0 (k-1)) k bs es
    | otherwise     = perror "masked' (2)"

-- Close gaps of at most g unmasked words between groups of masked
-- characters (False marks words to mask, as in the whitelist).
-- A whitelist that contains no masked word at all is returned unchanged.
-- Otherwise a short, repeat-free sequence (whitelist length <= g) would be
-- masked completely, although there is no masked group to join.
-- NOTE(review): leading and trailing unmasked runs of at most g words next
-- to a masked group are still closed, as before; confirm that this is
-- intended.
closegaps :: Int -> [Bool] -> [Bool]
closegaps g ws
  | and ws    = ws
  | otherwise = close g True (partitions ws)

-- Run lengths of the whitelist, always starting with a (possibly empty)
-- unmasked group, so the groups alternate unmasked / masked.
partitions :: [Bool] -> [Int]
partitions es = case es of
    []        -> []    -- at least mask_table may return an empty list here
    (False:_) -> 0 : runLengths
    _         -> runLengths
  where runLengths = map List.length (List.group es)

-- Convert runs of <=g Trues to Falses.  The Bool says whether the next run
-- length is an unmasked (True) group; masked groups stay False.
close :: Int -> Bool -> [Int] -> [Bool]
close _ _ [] = []
close g unmaskedRun (c:cs)
  | unmaskedRun = replicate c (c > g) ++ close g False cs
  | otherwise   = replicate c False   ++ close g True cs

\end{code}

\subsection{Coverage}

Calculating the coverage over a sequence.

For sparse indices, the raw per-word counts are replaced by a sliding average.

\begin{code}

-- | Table count for every word position of the sequence (0 where no key is
-- produced for that position).  With a sparse table (skp > 0) the counts
-- are smoothed with sliding_avg over skp positions.
coverage :: Integral a => Int -> Int -> HashF a -> FreqTable a -> Sequence Nuc -> [Int]
coverage k skp kf tb s
  | skp > 0   = sliding_avg skp raw_counts
  | otherwise = raw_counts
  where
    nwords     = seqlength s - fromIntegral k + 1
    raw_counts = [ maybe 0 (count tb) key | key <- padkeys nwords 0 (hashes kf (seqdata s)) ]

-- | Place (key, position) pairs on consecutive positions starting at cur,
-- filling missing positions with Nothing and padding up to top.
-- NOTE(review): assumes positions are strictly increasing and not below
-- cur; otherwise this recursion never terminates.  Confirm for hashes.
padkeys :: Offset -> Offset -> [(k,Offset)] -> [Maybe k]
padkeys top cur [] = replicate (fromIntegral (top - cur)) Nothing
padkeys top cur keys@((key,pos):rest)
  | pos == cur = Just key : padkeys top (cur+1) rest
  | otherwise  = Nothing  : padkeys top (cur+1) keys


-- or just:
-- sliding_avg k = map ((`div` k) . sum . take k) . tails

-- | Moving average (integer division) over windows of s elements; s is
-- expected to be positive.  A list of length n >= s gives n - s + 1
-- averages.  A shorter non-empty list gives the single average of all its
-- elements.  An empty list gives an empty list instead of dividing by zero,
-- so that sparse coverage of a sequence shorter than k no longer crashes.
sliding_avg :: Int -> [Int] -> [Int]
sliding_avg _ [] = []
sliding_avg s xs = let x1 = take s xs
                   in if length x1 == s then slav s (sum x1) xs (drop s xs)
                   else [sum x1 `div` length x1]

-- todo: use floats?
-- | Running-sum helper: acc is the sum of the current window, the first
-- list starts at the element leaving the window and the second at the
-- element entering it.
slav :: Int -> Int -> [Int] -> [Int] -> [Int]
slav s acc (h:hs) (t:ts) = acc `div` s : slav s (acc - h + t) hs ts
slav s acc _ _ = [acc `div` s]

\end{code}

