-- post-process BLASTN output in XML format
-- BAH:  Just readXML and extract the matches (pairs)

module Main where

import System
import System.Directory
import Data.Map (Map)
import qualified Data.Map as M
import Data.List (sortBy,maximumBy)
import Text.Printf

import Bio.Sequence
import Bio.Alignment.BlastData
import Bio.Alignment.BlastXML
import Bio.Alignment.AlignData
import Bio.Alignment.QAlign
import Bio.Alignment.Matrices

-- | Program entry point.
--
-- Expects exactly five command-line arguments (mode, query, qual, db,
-- blast) and hands them to 'main1'.  Any other argument count aborts with
-- the usage message rather than an opaque pattern-match failure.
main :: IO ()
main = do
  args <- getArgs
  case args of
    [mode,query,qual,db,blast] -> main1 mode query qual db blast
    _ -> error "usage: postn -[l|s] query qual db blastres"

-- | Run the whole post-processing pipeline.
--
-- Arguments, in order: output mode (@-l@ selects 'printProc', @-s@ selects
-- 'printProc2'), query file, quality file, database file and BLAST XML file.
-- An unrecognised mode aborts with the usage message before any input is read.
main1 :: String -> FilePath -> FilePath -> FilePath -> FilePath -> IO ()
main1 mode query qual db blast = do
  -- Choose the printer first so a bad mode fails fast instead of after
  -- all the inputs have been opened and parsed.
  printer <- case mode of
               "-l" -> return printProc
               "-s" -> return printProc2
               _    -> error "usage: postn -[l|s] query qual db blastres"

  putStrLn ("..processing: "++show blast)
  blastResults <- readXML blast
  -- Group the flattened hits by their target id (second tuple component).
  let bs = M.fromListWith (++) [ (fromStr t,[b])
                               | b@(_,t,_,_,_,_) <- concatMap flatten blastResults ]

  putStrLn "reading query db.."
  queries <- readFastaQual query qual
  -- Index the query sequences by label for lookup during processing.
  let qs = M.fromList [ (seqlabel s,s) | s <- queries ]

  putStrLn "processing BLAST results"
  dbs <- readFasta db
  mapM_ printer (process qs bs dbs)

-- | One result row as built by 'process'.  Components, in order:
-- query id, target id, the 'Int' field of the originating 'Flat' entry,
-- its bit score, its e-value, then two (score, alignment) pairs from
-- 'qscore': the first computed on the query with the third 'Seq' field
-- set to 'Nothing', the second on the query as stored.
type ProcOutput = (String,String,Int,Double,Double,(Double,Alignment),(Double,Alignment))

-- | Long-format printer: a blank line, a summary line (ids, length and the
-- two BLAST figures), then each re-scored alignment preceded by its score.
printProc :: ProcOutput -> IO ()
printProc (qid,tid,len,bitscore,evalue,(s1,aln1),(s2,aln2)) =
    mapM_ putStrLn
      [ ""
      , unwords (qid : tid : show len : map fmt [bitscore,evalue])
      , fmt s1
      , showalign aln1
      , fmt s2
      , showalign aln2
      ]
  where
    -- All floating-point figures share the same three-significant-digit format.
    fmt :: Double -> String
    fmt = printf "%.3g"

-- | Short-format printer: a single space-separated line holding the ids,
-- the length, the four numeric figures and the lengths of both alignments.
printProc2 :: ProcOutput -> IO ()
printProc2 (qid,tid,len,bitscore,evalue,(s1,aln1),(s2,aln2)) =
    putStrLn (unwords (idFields ++ numFields ++ alnLengths))
  where
    fmt :: Double -> String
    fmt = printf "%.3g"
    idFields   = [qid, tid, show len]
    numFields  = map fmt [bitscore, evalue, s1, s2]
    alnLengths = [show (length aln1), show (length aln2)]

-- | Re-score every BLAST hit against its database sequence.
--
-- Walks the database sequences in order; for each one, looks up the hits
-- recorded under its label and emits one 'ProcOutput' per hit.  Sequences
-- without hits contribute nothing.
--
-- Warning: db can be huge, laziness required
process :: Map SeqId Sequence -> Map SeqId [Flat] -> [Sequence] -> [ProcOutput]
process qm res dbs =
    [ rescore target hit
    | target <- dbs
    , hit <- M.findWithDefault [] (seqlabel target) res
    ]
  where
    rescore target (q,tgt,len,bits,eval,rev) =
        (q, tgt, len, bits, eval, qscore stripped target, qscore oriented target)
      where
        -- Reverse-complement the query when the hit is flagged as reversed.
        orient = if rev then revcompl else id
        oriented@(Seq h d _) = orient (lookup' (fromStr q) qm)
        -- Same query with the third Seq field dropped
        -- (presumably the quality data -- confirm against Bio.Sequence).
        stripped = Seq h d Nothing
{-
                   ,bits',eval*2**(bits-bits')
                   ,bits'',eval*2**(bits-bits''))
-}

-- | Look a key up in a map, aborting with a descriptive error when the key
-- is absent.  The error names the missing key and the number of entries in
-- the map, which helps tell "empty map" from "wrong key".
lookup' :: (Ord k, Show k) => k -> Map k a -> a
lookup' v m = maybe missing id (M.lookup v m)
  where
    -- M.size is O(1); no need to materialise the whole map as a list.
    missing = error $ "lookup': key not found: " ++ show v
                      ++ " (map has " ++ show (M.size m) ++ " entries)"

-- | Locally align two sequences with 'local_align', using 'qualMx' and the
-- penalty pair (-10,-4); yields the score together with the alignment.
-- NOTE(review): the pair is presumably (gap open, gap extend) -- confirm
-- against Bio.Alignment.QAlign.
--
-- blastn defaults: 1 -3 -5 -3  (compact alignments!)
--  2 -10 -10 -2 is more realistic(?)
qscore :: Sequence -> Sequence -> (Double,Alignment)
qscore a b = local_align qualMx gapPenalties a b
  where gapPenalties = (-10,-4)

-- | One flattened BLAST match as built by 'flatmatch'.  Components, in
-- order: first word of the query id, first word of the subject id, first
-- component of the match's @identity@ pair, bit score, e-value, and a flag
-- that is 'True' when the two strands of the match differ.
type Flat = (String,String,Int,Double,Double,Bool)

-- | Flatten a whole BLAST result into one 'Flat' entry per hit, by
-- concatenating the output of 'flatten1' over all of its records.
flatten :: BlastResult -> [Flat]
flatten = concatMap flatten1 . results

-- | Flatten a single BLAST record: one 'Flat' entry per hit, each tagged
-- with the record's query id.
flatten1 :: BlastRecord -> [Flat]
flatten1 br = [ flathit qid h | h <- hits br ]
  where qid = query br

-- | Reduce a hit to its single best match, i.e. the one with the greatest
-- bit score (fourth 'Flat' component).
--
-- Partial: 'maximumBy' fails on a hit that carries no matches.
flathit :: SeqId -> BlastHit -> Flat
flathit s bh = maximumBy byBits [ flatmatch s sid m | m <- matches bh ]
  where
    sid = subject bh
    -- Only the bit score takes part in the comparison.
    byBits (_,_,_,x,_,_) (_,_,_,y,_,_) = compare x y

-- | Turn one BLAST match into a 'Flat' tuple for the given query and
-- target ids.  Both ids are cut down to their first whitespace-separated
-- word.
--
-- Partial: fails on an id with no words, and on any @aux@ value that is
-- not a @Strands@ pair.
flatmatch :: SeqId -> SeqId -> BlastMatch -> Flat
flatmatch q t bm =
    (firstWord q, firstWord t, fst (identity bm), bits bm, e_val bm, opposite)
  where
    firstWord sid = head (words (toStr sid))
    -- The match is flagged when the two strands differ.
    opposite = case aux bm of Strands (p,p') -> p/=p'


