{-| The mutator randomly introduces substitutions and indels
    (insertions and deletions) into sequences read in FASTA format.
-}
{-# LANGUAGE DeriveDataTypeable #-}

module Mutator where

import System.IO
import System.Console.CmdArgs
import Bio.Core.Sequence
import Bio.Sequence.Fasta
import Data.ByteString.Lazy.Char8 as B

import Statistics
import Version

-- | Run-time settings of the mutator, filled in from the command line.
data Conf = Conf
  { input  :: FilePath  -- ^ Fasta file to read; @"-"@ selects standard input.
  , output :: FilePath  -- ^ Fasta file to write; @"-"@ selects standard output.
  , subst  :: Double    -- ^ Substitution rate.
  , indel  :: Double    -- ^ Combined insertion/deletion rate.

    -- , gapext :: Double
  } deriving (Typeable, Data)

-- | Default configuration together with its cmdargs annotations.
--
-- The values on the left of each '&=' are the defaults used when the
-- corresponding option is not given: @"-"@ for both files (which 'main'
-- treats as stdin/stdout) and 0.01 for both rates.
--
-- NOTE(review): the '&=' annotations are written directly inside the record
-- expression, which is the form cmdargs' implicit mode expects; restructuring
-- this value (e.g. hoisting fields into let-bindings) may break option
-- parsing - confirm against the cmdargs documentation before refactoring.
conf :: Conf
conf = Conf 
  -- Input is annotated with 'args' rather than 'help', so it is presumably
  -- taken from the unnamed (positional) command-line arguments.
  { input = "-"   &= args                      &= typFile
  , output = "-"  &= help "Output file"        &= typFile
  -- Both rates are 'Double' fields; "Float" is only the type label attached
  -- via 'typ'.
  , subst = 0.01  &= help "Substitution rate"  &= typ "Float"
  , indel = 0.01  &= help "Indel rate"         &= typ "Float"
  -- , gapext = 0.05 &= help "Gap extension rate" &= typ "Float" -- todo: affine gaps
  } &= program "mutator"
    -- 'version' is not defined in this file (presumably from the Version module).
    &= summary ("mutator "++version)
    &= details ["Mutate sequences in Fasta format by introducing"
               ,"random substitutions and insertions/deletions"]

-- | Program entry point: parse the command line, read the input sequences,
--   mutate them and write the result.
main :: IO ()
main = do
  cfg     <- cmdArgs conf
  seqs    <- readFrom (input cfg)
  mutated <- doMutate cfg seqs
  writeTo (output cfg) mutated
  where
    -- A file name of "-" stands for the corresponding standard stream.
    readFrom "-"  = hReadFasta stdin
    readFrom path = readFasta path

    writeTo "-"  = hWriteFasta stdout
    writeTo path = writeFasta path
  
-- | Run 'mutate' in 'IO' via 'evalRandIO', using the substitution and indel
--   rates stored in the configuration.
doMutate :: Conf -> [Sequence] -> IO [Sequence]
doMutate cf seqs = evalRandIO (mutate subRate indelRate seqs)
  where
    subRate   = subst cf
    indelRate = indel cf

-- | Randomly mutate every sequence in the list.
--
-- Each position of a sequence is visited in order and a random number @z@ is
-- drawn with @sample (Uniform 0 1)@ (presumably uniform on [0,1] - confirm
-- against the module providing 'sample'):
--
--   * @z < sub@:             the base is replaced by a /different/ base;
--   * @z < sub + ind/2@:     a random base is inserted before the current one;
--   * @z < sub + ind@:       the base is deleted;
--   * otherwise:             the base is copied unchanged.
--
-- The first argument is therefore the per-position substitution probability
-- and the second the per-position indel probability, split evenly between
-- insertions and deletions.
--
-- New bases are always emitted as lower-case @a@, @c@, @g@ or @t@.  The third
-- field of each input 'Seq' is discarded and replaced by 'Nothing', since it
-- can no longer line up with the mutated data once indels were applied
-- (presumably this field holds per-base quality values - confirm).
mutate :: RandomGen g => 
          Double -> Double -> [Sequence] -> Rand g [Sequence]
mutate sub ind = mapM mut1
  where
    -- Mutate the sequence data of one record, keeping its header.
    mut1 (Seq h d _) = do
      d2 <- go (B.unpack $ unSD d)
      return $ Seq h (SeqData $ B.pack d2) Nothing

    -- Total mapping from an index to a base; indices are taken modulo 4 so
    -- that no input can hit a missing case alternative.
    nuc :: Int -> Char
    nuc n = case n `mod` 4 of
      0 -> 'a'
      1 -> 'c'
      2 -> 'g'
      _ -> 't'

    -- Index of a base in the order used by 'nuc', ignoring case; 'Nothing'
    -- for any other symbol (e.g. ambiguity codes).
    baseIndex :: Char -> Maybe Int
    baseIndex c = case c of
      'a' -> Just 0
      'A' -> Just 0
      'c' -> Just 1
      'C' -> Just 1
      'g' -> Just 2
      'G' -> Just 2
      't' -> Just 3
      'T' -> Just 3
      _   -> Nothing

    -- Pick a replacement for a base.  For a/c/g/t the index is shifted by
    -- 1..3 (mod 4), which selects uniformly among the three *other* bases,
    -- so a substitution event always changes the base.  Unknown symbols are
    -- replaced by any of the four bases.
    substitute x = case baseIndex x of
      Just i  -> do
        k <- getRandomR (1, 3)
        return (nuc (i + k))
      Nothing -> do
        k <- getRandomR (0, 3)
        return (nuc k)

    -- Walk along the sequence, deciding the fate of each base in turn.
    go [] = return []
    go (x:xs) = sample (Uniform 0 1) >>= step
      where
        step z
          | z < sub           = do            -- substitute
              y  <- substitute x
              ys <- go xs
              return (y : ys)
          | z < sub + ind / 2 = do            -- insert
              y  <- getRandomR (0, 3)
              -- The current base is processed again after the insertion, so
              -- it may itself be mutated or preceded by further insertions.
              ys <- go (x : xs)
              return (nuc y : ys)
          | z < sub + ind     = go xs         -- delete
          | otherwise         = do            -- keep
              ys <- go xs
              return (x : ys)
