{- | 
   parse blastX and try to determine correct reading frame
   
   Each blastx match (HSP) adds its bitscore? /length? to the 
   first position of each codon.

   Todo: smooth (-1,2,-1), then add -k for past stop codons, -k,+k for aug/kozaks

   Use DP to determine optimal path through "frame"-space.
   What is frame-shift cost?  Direction-shift? Termination?

   Bonus points:
     extrapolated start/stop
     Poly-A, Kozak, STP, AUG, ...

   Problem: protein switch: follow good/strong protein in one frame, then switch to a different,
   weaker match to (artificially) extend the frame.

-}
module Blastx (Weights, Frame
              , frames, select_dir -- , indet
              ,genFasta, genCDS
              ) where

import Bio.Core
import Bio.Core.Sequence.Util
import BlastFlat
import Bio.Sequence.Fasta

import Data.List (groupBy,maximumBy)
import Text.Printf
import qualified Data.ByteString.Lazy.Char8 as B
import Data.Monoid

-- | Look up the residue character stored at the given offset of a sequence.
--   The offset is unwrapped with 'unOff' and passed straight to the
--   ByteString index, so an out-of-range offset fails there.
{-# INLINE (!) #-}
(!) :: Sequence -> Offset -> Char
(!) (Seq _ (SeqData bytes) _) off = B.index bytes (unOff off)

-- import qualified Debug.Trace (trace)
-- | Debug hook: currently discards the message and returns the value
--   untouched.  To get debugging output, define it as
--   @Debug.Trace.trace@ instead (and enable the import above).
trace :: String -> a -> a
trace _ x = x

-- todo: use ST/U/Arrays for efficiency
-- | Per-position scores along a query sequence ('genweights' emits one
--   entry per query position, i.e. @qlength@ entries).
type Weights = [Double]
-- | A scored path through frame space: the total score together with the
--   positions chosen as codon starts.  'opt_score' conses each new position
--   onto the front, so the list runs from the last codon back to the first.
type Frame   = (Double,[Int])

genFasta, genCDS :: (Sequence, (Frame,Frame)) -> [Sequence]
-- genFasta: emit the (possibly reverse-complemented) input sequence as a
-- singleton list, with the chosen direction, both strand scores and both
-- CDS ranges appended to its header.  A strand whose score is not positive
-- gets "-" instead of a range.
-- NOTE(review): the forward range is printed as head..last while cds_fwd
-- prints last..head -- confirm which order is intended.
genFasta (s, scored@((fwdScore,fwdPos),(revScore,revPos))) =
    -- trace (show fwdScore++" : "++show fwdPos++"\n"++show revScore++" : "++show revPos++"\n") $
    [tag oriented (summary ++ fwdRange ++ "/" ++ revRange)]
  where
    dir = select_dir s scored
    -- only a reverse verdict flips the sequence; Fwd and Indet keep it as is
    oriented | dir == Rev = revc s
             | otherwise  = s
    summary  = printf "%s: %.2f/%.2f CDS: " (show dir) fwdScore revScore
    fwdRange | fwdScore > 0 = printf "%d..%d" (head fwdPos) (last fwdPos)
             | otherwise    = "-"
    revRange | revScore > 0 = printf "%d..%d" (last revPos) (head revPos)
             | otherwise    = "-"

-- | Reverse-complement a sequence: the header is kept, the residues go
--   through 'revcompl', and the quality data (when present) through 'revqual'.
revc :: Sequence -> Sequence
revc (Seq header residues quality) =
    Seq header (revcompl residues) (fmap revqual quality)

-- genCDS: translate the selected coding region.  Yields one protein
-- sequence when a direction could be chosen, and nothing when the
-- direction is indeterminate.
genCDS (s, scored@((fwdScore,fwdPos),(revScore,revPos))) =
    trace (show fwdScore++" : "++show fwdPos++"\n"++show revScore++" : "++show revPos++"\n") translated
  where
    translated = case select_dir s scored of
        Fwd   -> [cds_fwd scored s]
        Rev   -> [cds_rev scored s]
        Indet -> []
--      Vague -> [cds_fwd fs s,cds_rev fs s]

-- todo: appears to give correct results, but...
-- | Build the translated sequence for the forward ('cds_fwd') or reverse
--   ('cds_rev') frame.  One amino acid is taken at every recorded codon
--   position, visiting the positions in the reverse of their stored order,
--   and the header is tagged with the score, the CDS range and the length
--   of the input sequence.  The result carries no quality data.
cds_fwd, cds_rev :: (Frame,Frame) -> Sequence -> Sequence
cds_fwd ((score,is),_) s = tag protein header
  where
    protein = Seq (seqlabel s) (toIUPAC aminos) Nothing
    header  = printf "FWD trans: %.2f CDS: %d..%d l=%d" score (last is) (head is) sl
    aminos  = concatMap codonAt (reverse is)
    -- translate from the given offset, but keep only the first codon
    codonAt = take 1 . translate (seqdata s) . fromIntegral
    sl      = fromIntegral (seqlength s) :: Int

cds_rev (_,(score,is)) s =
    trace (show (sl-last is)++".."++show (sl-head is)++": "++show is) $
    tag protein header
  where
    protein = Seq (seqlabel s) (toIUPAC aminos) Nothing
    -- positions index the reverse complement; they are reported relative
    -- to the sequence length
    header  = printf "REV trans: %.2f CDS: %d..%d l=%d" score (sl-head is) (sl-last is) sl
    aminos  = concatMap codonAt (reverse is)
    codonAt = take 1 . translate (revcompl $ seqdata s) . fromIntegral
    sl      = fromIntegral $ seqlength s

-- | Pair every sequence with its best forward and reverse frame.  The BLAST
--   hits are grouped by query, each group's weights are merged, and the
--   merged weights are matched up with the sequences by 'joins'.
frames :: [Sequence] -> [BlastFlat] -> [(Sequence,(Frame,Frame))]
frames ss bfs = joins ss [ merge grp | grp <- partitions bfs ]
  where
    -- groups come from groupBy and are therefore never empty
    merge grp@(p:_) = (query p, (fwdW, revW))
      where (fwdW, revW) = mergeW (map genweights grp)

-- | Walk the sequences and the per-query weights in parallel.  A sequence
--   whose header differs from the next query label gets 'no_frame'; on a
--   match the pattern adjustments are added to the weights and the optimal
--   frame is computed for each strand.
--   NOTE(review): there is no equation for an empty sequence list with
--   weights still pending, so that input fails with a pattern-match error.
joins :: [Sequence] -> [(SeqLabel,(Weights,Weights))] -> [(Sequence,(Frame,Frame))]
joins ss [] = [ (s, no_frame) | s <- ss ]
joins (s:ss) hits@((label,ws):rest)
    | label /= seqheader s = (s, no_frame) : joins ss hits
    | otherwise            = (s, (opt_score penalty fwdW, opt_score penalty revW)) : joins ss rest
  where
    -- one constant serves both as the pattern weight and the frame-shift cost
    penalty = 10
    fwdW = zipWith (+) (fst ws) (patterns penalty s)
    -- original author's note: this is WRONG:
    revW = zipWith (+) (snd ws) (patterns penalty $ revc s)

-- Placeholder result for a sequence without any BLAST evidence: both
-- strands get a zero score and an empty position list.
no_frame = (unscored, unscored)
  where unscored = (0, [])

-- | Split the hits into runs of consecutive entries sharing a query name.
--   Only adjacent hits are grouped, so hits for one query must be contiguous
--   in the input to end up in a single group.
partitions :: [BlastFlat] -> [[BlastFlat]]
partitions = groupBy sameQuery
  where sameQuery a b = query a == query b

{-- RECURRENCE

opt_score (w1:w2:w3:w4:ws) = 
    max | w1 + opt_score (w4:ws)
        | w0 - p + opt_score (w3:w4:ws)
        | w2 - p + opt_score ws

Perhaps easier:
  s1,s2,s3 = w1,w2,w3
  si = wi+max (s(i-3),-p+s(i-2),-p+s(i-4),0)

Todo: add codon bias
      add amino comp(?)
      add slight positive weight overall to prefer long ORFs?
-}

-- TODO: xxx-AUG-yyy- -- xxx is -p , AUG is +p-epsilon? (or just hit-bits?)
-- starting at AUG gives bonus, passing (including) it is neutral
-- but...we shouldn't encourage frame shifts at the AUG.
-- xxx-STP-yyy  -- xxx is -p STP is +1.5p, yyy is -p (+eps - rather force two frame shifts)
-- should be better to frame-shift before the STP than to include it...

-- | Sequence-derived weight adjustments: the position-wise sum of the codon
--   adjustment (scaled by @p@) and the poly-A adjustment.  The result is as
--   long as the shorter of the two.
patterns :: Double -> Sequence -> Weights
patterns p s = zipWith (+) codonTerm polyATerm
  where
    codonTerm = codon_adjust p s
    polyATerm = polyA_adjust s

-- | Per-position codon adjustment.  The sequence is translated at offsets
--   0, 1 and 2 and the three translations are interleaved, so that entry
--   @i@ describes the codon starting at position @i@.  At present only stop
--   codons are scored: each one contributes @-2*p@, everything else 0.
codon_adjust :: Double -> Sequence -> Weights
codon_adjust p s = map stopPenalty (weave frame0 frame1 frame2)
  where
    [frame0, frame1, frame2] = map (translate $ seqdata s) [0,1,2]

    -- round-robin over the three frames; all three must run out together
    weave [] [] []       = []
    weave (x:xs) ys zs   = x : weave ys zs xs

    stopPenalty STP = -2*p
    stopPenalty _   =    0

-- | Poly-A adjustment, one entry per sequence position.  When 'pApos'
--   reports a position, everything before it gets a small bonus and
--   everything from it onwards a small penalty; without a poly-A signal
--   every position gets the same small constant.
polyA_adjust :: Sequence -> Weights
polyA_adjust s = maybe flat split (pApos s)
  where
    flat      = replicate sl zero
    split off = replicate i positive ++ replicate (sl-i) negative
      where i = fromIntegral off
    positive  =  0.1
    negative  = -0.05
    zero      =  0.01
    sl        = fromIntegral $ seqlength s

-- | Find the first position whose preceding window of @w@ residues contains
--   more than @t@ As (upper or lower case).  The returned offset is the
--   exclusive end of that window, which may equal the sequence length.
--   Returns 'Nothing' when no window qualifies, or when the sequence is
--   shorter than one window.
pApos :: Sequence -> Maybe Offset
pApos s@(Seq _ d _) = go w0 w
    where -- number of As in the initial window [0,w)
          w0 = fromIntegral . length . filter isA . take (fromIntegral w) . toString $ d
          isA c = c == 'a' || c == 'A'
          -- invariant: count is the number of As in [pos-w, pos)
          go :: Offset -> Offset -> Maybe Offset
          go count pos
              | pos > l   = Nothing   -- sequence shorter than a single window
              | count > t = Just pos  -- tested first so the final window (pos == l) counts too
              | pos == l  = Nothing   -- nothing left to slide over
              | otherwise = let next = if isA (s!pos) then 1 else 0
                                prev = if isA (s!(pos-w)) then 1 else 0
                            in go (count+next-prev) (pos+1)
          l = seqlength s
          t = 20 -- threshold
          w = 25 -- window size

-- | Best-scoring path through the weights, with the codon start positions
--   that make it up (latest position first).  The score of a path ending at
--   position @i@ is the weight at @i@ plus the best of: the path ending at
--   @i-3@, the paths ending at @i-4@ or @i-2@ minus the frame-shift penalty
--   @p@, or a fresh start at 0.  Only the last four path scores are kept,
--   next to the best candidate seen so far.
opt_score :: Double -> [Double] -> Frame
opt_score p (w1:w2:w3:w4:ws) = osc 4 (best [i0,i1,i2,i3]) (i0,i1,i2,i3) ws
  where
    -- seed values for the first four positions
    i0 = (w1,[0])
    i1 = (w2,[1])
    i2 = (w3,[2])
    i3 | w1 >= w2-p = (w4 + w1,[3,0])
       | otherwise  = (w4+w2-p,[3,1])

    -- osc params: curidx maxidx cache rest
    osc i m (s0,s1,s2,s3) (w:rest) =
        report (osc (i+1) m' (s1,s2,s3,(w+s,i:is)) rest)
      where
        -- best predecessor for position i; list order decides ties
        (s,is) = best [s1, shifted s0, shifted s2, (0,[])]
        m' | s > fst m = (s,is)
           | otherwise = m
        report
          | (i-1) `rem` 3 == 0 = trace (printf "%d :\t%.1f\t%.1f\t%.1f\tmax %.1f,%d..%d" (i-1) (fst s0) (fst s1) (fst s2) (fst m) (head $ snd m) (last $ snd m))
          | otherwise          = id
    osc i m (_,s1,s2,s3) [] =
        trace (printf "%d :\t%.1f\t%.1f\t%.1f\n" i (fst s1) (fst s2) (fst s3)) $
        best [m,s1,s2,s3]

    -- charge the frame-shift penalty
    shifted (sc,path) = (sc-p,path)
    best = maximumBy (\a b -> compare (fst a) (fst b))

opt_score _ _ = (0,[]) -- less than four nucleotides is too short a sequence :-)

{-
-- | Subtract the opposing view from a set of weights (why does this matter?)
convolve :: (Weights,Weights) -> (Weights,Weights)
convolve = unzip . c . zip 
    where c xs@(x:_) = c1 x : map c2 (takeWhile (>=2 . length) tails xs)
          c1 ((a,b):(c,d):_) = (a-(b+c+d)/3, b-(a+c+d)/3)
          c2 ((a,b):(c,d):(e,f):_) = let z = a+b+e+f 
                                     in (c-(d+z)/5, d-(c+z)/5)
          c2 z@[_,_] = c1 $ reverse z
-}

-- | Combine the weights of several hits by taking the position-wise maximum
--   on each strand, starting from an endless baseline of @-0.0001@.
mergeW :: [(Weights,Weights)] -> (Weights,Weights)
mergeW = foldr (sumW max) (baseline, baseline)
  where baseline = repeat (-0.0001)

-- | Combine two forward/reverse weight pairs position by position with the
--   given operator; each strand is combined independently and truncated to
--   the shorter operand.
sumW :: (Double -> Double -> Double) -> (Weights,Weights) -> (Weights,Weights) -> (Weights,Weights)
sumW op (fwdA,revA) (fwdB,revB) = (combine fwdA fwdB, combine revA revB)
  where combine = zipWith op

-- | Generate the weights contributed by a single BLAST hit.
--   BLAST coordinates are for the FWD strand only, start at 1, and the hit
--   range is inclusive.  Within the hit every third position receives the
--   per-codon bit score (3 * bits / hit length); everything else is 0.  A
--   plus-strand hit fills the forward weights; a minus-strand hit fills the
--   reverse weights with the profile reversed.
genweights :: BlastFlat -> (Weights,Weights)
genweights bf = case strand of
    Plus  -> (profile, blank)
    Minus -> (blank, reverse profile)
  where
    Frame strand _ = aux bf
    hitLen   = q_to bf - q_from bf + 1
    perCodon = 3*bits bf/fromIntegral hitLen
    -- where the score sits inside each codon depends on the strand
    codonPattern = case strand of
        Plus  -> [perCodon,0,0]
        Minus -> [0,0,perCodon]
    profile  = replicate (q_from bf-1) 0
            ++ take hitLen (cycle codonPattern)
            ++ replicate (qlength bf - q_to bf) 0
    blank    = replicate (qlength bf) 0

-- use appendHeader from new biolib (todo!)
-- | Append a space and the given text to the sequence header.  The residue
--   data and any quality data are passed through unchanged, so quality
--   values prepared by the caller (e.g. reversed by 'revc') are not lost.
tag :: Sequence -> String -> Sequence
tag (Seq h d q) t = Seq (h <> SeqLabel (B.pack (" "++t))) d q

{-
-- | What to do if no data is available
indet :: Sequence -> Sequence
indet s = tag s "INDET"

fwd, rev, vague :: (Frame,Frame) -> Sequence -> Sequence
fwd ((f,(i:is)),(r,_)) s = tag s 
                       (printf "FWD: %.2f/%.2f CDS: %d..%d" f r (last (i:is)) i)
rev ((f,_),(r,(i:is))) s = tag (revcompl s) 
                       (printf "REV: %.2f/%.2f CDS: %d..%d" r f (sl-i) (sl-last (i:is)))
    where sl = fromIntegral $ seqlength s
vague ((f,(fi:fis)),(r,(ri:ris))) s = 
    tag s (printf "WEAK: %.2f/%.2f CDSs: %d..%d/%d..%d" f r  (last (fi:fis)) fi (last (ri:ris)) ri)
vague x s = error $ show x -- todo: happens if only score in one dir, but less than threshold
-}

-- | Orient based on BLAST data: a strand wins only when its score beats the
--   other strand's by more than the margin; otherwise the direction is
--   indeterminate.  The sequence argument is not inspected.
select_dir :: Sequence -> (Frame,Frame) -> Dir
select_dir _ ((fwdScore,_),(revScore,_))
    | fwdScore - margin > revScore = Fwd
    | revScore - margin > fwdScore = Rev
    | otherwise                    = Indet
  where margin = 5

-- | Orientation verdict returned by 'select_dir': 'Fwd' or 'Rev' when one
--   strand's score clearly exceeds the other's, 'Indet' otherwise.
data Dir = Indet | Fwd | Rev deriving (Show,Eq)
