{-# Language FlexibleContexts, RankNTypes #-}

module Cluster where

import Control.Monad (when)
import Data.Array.Base
import Data.Array.IO
import Data.Array.ST 
import Control.Monad.ST
import Control.Concurrent.STM
import Data.List (nub)

import Base

-- | Per-contig entry: the id of the cluster that currently holds the contig.
-- 'initialize' points contig @i@ at cluster @i@; 'merge' repoints contigs of
-- an absorbed cluster at the absorbing one.
data ClusterPtr = CP ClusterID -- todo: id, and free ends
                deriving Show
-- | A cluster: its size in contigs (1 for a fresh cluster, the summed sizes
-- after a 'merge', 0 once absorbed into another cluster) and the joins recorded
-- so far, each as (contig, end, contig, end).
data Cluster = C !Int [(ContigID,End,ContigID,End)] -- todo: more specifics
             deriving Show
-- | Which end of a contig a join uses.  'scaffold1' records 'Five' for joins
-- made from a contig's left-link lookup and 'Three' for its right-link lookup.
data End = Five | Three deriving (Show,Eq)


-- | Print the join list of a cluster, but only for clusters holding more than
-- one contig; singleton and absorbed (size 0) clusters print nothing.
printCluster :: Cluster -> IO ()
printCluster (C size joins) = when (size > 1) (print joins)

-- | Print every multi-contig cluster of an 'IOArray', in index order.
printIO :: IOArray ClusterID Cluster -> IO ()
printIO carr = getElems carr >>= mapM_ printCluster

-- | Print every multi-contig cluster of a 'TArray', in index order.
--
-- Each element is read in its own 'atomically' transaction, so the output is
-- not one consistent snapshot if other threads are merging concurrently.
printSTM :: TArray ClusterID Cluster -> IO ()
printSTM carr = do
  (lo, hi) <- atomically (getBounds carr)
  mapM_ printAt [lo .. hi]
  where
    printAt idx = atomically (readArray carr idx) >>= printCluster

-- | Initial pointer for a contig: contig @x@ starts out in cluster @x@.
opencptr :: ContigID -> ClusterPtr
opencptr = CP -- todo: free-end flags (previously sketched as "True True")

-- | Initial cluster for a contig: size 1, no joins.  The contig id itself is
-- not stored (see the todo on 'Cluster').
singleton :: ContigID -> Cluster
singleton = const (C 1 [])

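-- Ideally we would name the two array types with synonyms like the ones below,
-- but the array type variable 'a' is not a parameter of either synonym, so they
-- are not valid Haskell; the MArray constraints are spelled out in each
-- signature instead: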
-- type CtgMap m =  MArray a ContigID m => a ContigID ClusterPtr 
-- type ClsMap m = MArray a ClusterID m => a ClusterID Cluster

-- | Scaffold one contig @c@.  For each end of @c@ (the left-link lookup first,
-- recorded as 'Five', then the right-link lookup, recorded as 'Three'), take
-- the first link leaving that end.  Then take the first link leaving the
-- contig end that link reaches ('lookup_link').  If that second link points
-- back at @c@, 'merge' the two clusters with the join
-- (c, end of c, contigid of the second link, end reached by the first link).
--
-- Call this for each contig to complete the scaffolding.  The 'Reads'
-- argument is currently unused.
scaffold1 :: (MArray a ClusterPtr m, MArray a Cluster m) 
            => Reads -> Links 
             -> (a ContigID ClusterPtr, a ClusterID Cluster)
             -> ContigID -> m ()
scaffold1 _rds ls (ctgs,clss) c = do
  joinEnd Five  (lookup_left ls c)
  joinEnd Three (lookup_right ls c)
  where
    -- Try to join end 'myEnd' of 'c', given the links leaving that end.
    -- NOTE(review): reciprocity is judged by contig id only
    -- ('targetid x == c'); which end of 'c' the partner's link lands on is
    -- not compared.  Confirm this is intended.
    joinEnd myEnd outLinks =
      case outLinks of
        []    -> return ()
        (l:_) ->
          case lookup_link ls l of
            []    -> return ()
            (x:_) ->
              let partnerEnd = if targetLeft l then Five else Three
              in when (targetid x == c) $
                   merge ctgs clss (c, myEnd, contigid x, partnerEnd)

-- | Links leaving the contig end that link @r@ points at: the left-end links
-- of @targetid r@ when @targetLeft r@ holds, otherwise its right-end links.
lookup_link ls r
  | targetLeft r = lookup_left ls (targetid r)
  | otherwise    = lookup_right ls (targetid r)

-- | Distinct contig ids mentioned in a list of joins, in order of first
-- appearance (for each join, its first contig before its second).
--
-- Inefficient: 'nub' is quadratic in the number of ids.
myelems :: [(ContigID,End,ContigID,End)] -> [ContigID]
myelems joins = nub (concatMap endpoints joins)
  where endpoints (a,_,b,_) = [a, b]

-- | Merge the clusters holding contigs @c1@ and @c2@, recording the join
-- @(c1,e1,c2,e2)@.
--
-- If both contigs already share a cluster, nothing is written.  Otherwise the
-- smaller cluster (by its size field; ties go to @c1@'s cluster) is folded into
-- the larger one:
--
-- * the larger cluster gets the summed size and the smaller cluster's joins,
--   then the new join, then its own joins;
-- * the smaller cluster is reset to @C 0 []@;
-- * every contig of the smaller cluster is repointed at the larger one.
merge :: (MArray a ClusterPtr m, MArray a Cluster m) 
          => a ContigID ClusterPtr -> a ClusterID Cluster 
         -> (ContigID,End,ContigID,End) -> m ()
merge ctgs clss edge@(c1,_,c2,_) = do
    CP i1 <- readArray ctgs c1
    CP i2 <- readArray ctgs c2
    -- Same cluster: no-op.  NOTE(review): the original author asked
    -- "unread STM vars?" about this path.
    when (i1 /= i2) $ do
      C n1 cl1 <- readArray clss i1
      C n2 cl2 <- readArray clss i2
      -- Fold the smaller cluster into the larger to minimize the writes below.
      -- This matters especially for STM, which scales poorly with the number
      -- of TVars a transaction touches.
      let (small, smallJoins, big, bigJoins)
            | n1 <= n2  = (i1, cl1, i2, cl2)
            | otherwise = (i2, cl2, i1, cl1)
      writeArray clss big (C (n1 + n2) (smallJoins ++ edge : bigJoins))
      writeArray clss small (C 0 [])
      -- The id 'small' is itself listed as a contig.  This relies on the
      -- invariant from 'initialize' that cluster k holds contig k while it is
      -- live, and it covers a singleton cluster, whose join list is empty.
      mapM_ (\ctg -> writeArray ctgs ctg (CP big)) (small : myelems smallJoins)

-- | Allocate the two arrays for @n@ contigs, both indexed @0 .. n-1@.  Contig
-- @i@ points at cluster @i@ ('opencptr'), and cluster @i@ starts as a size-1
-- cluster with no joins ('singleton').
initialize :: (MArray a ClusterPtr m, MArray a Cluster m) 
              => Int -> m (a ContigID ClusterPtr, a ClusterID Cluster) 
initialize n = do
  let idxRange = (0, n - 1)
      ids      = [0 .. n - 1]
  ctgs <- newListArray idxRange (map opencptr ids)
  clss <- newListArray idxRange (map singleton ids)
  return (ctgs, clss)

-- Specializations of 'initialize' for common choices of the monad "m".

-- | 'initialize' specialized to 'IOArray's in 'IO'.
initializeIO :: Int -> IO (IOArray ContigID ClusterPtr, IOArray ClusterID Cluster)
initializeIO n = initialize n

-- | Build the two arrays as 'TArray's, with the same contents as 'initialize'.
--
-- Rather than running 'initialize' as one large transaction
-- (@atomically . initialize@), each array is allocated with a placeholder
-- value and then filled in by one small 'atomically' call per element.
-- TODO(review): the split is meant to be faster; confirm with profiling.
-- Every placeholder is overwritten before this action returns.
initializeSTM :: Int -> IO (TArray ContigID ClusterPtr, TArray ClusterID Cluster)
initializeSTM n = do
    ctgs <- atomically $ newArray idxRange (CP (-1))  -- (-1): not yet assigned
    mapM_ (\i -> atomically $ writeArray ctgs i (CP i)) ids
    clss <- atomically $ newArray idxRange unfilled
    mapM_ (\i -> atomically $ writeArray clss i (singleton i)) ids
    return (ctgs,clss)
  where
    idxRange = (0, n-1)
    ids      = [0 .. n-1]
    -- Only forced if a slot were read before its write above.  An explicit
    -- message replaces the former bare 'undefined' so such a bug is diagnosable.
    unfilled = error "initializeSTM: cluster slot read before initialization"

-- | 'initialize' specialized to 'STArray's in the 'ST' monad.
initializeST :: forall s . Int -> ST s (STArray s ContigID ClusterPtr, STArray s ClusterID Cluster)
initializeST n = initialize n
