{- |
   QuickBench - extends "Test.QuickCheck" with support for benchmarking
      properties, and also provides a framework for running tests.
-}

module Test.QuickBench ( quickTime
                       , module Test.QuickCheck
                       ) where

import Test.QuickCheck
import System.CPUTime
import System.Time
import System.IO
import Random

-- The definitions below are adapted from QuickCheck's own source.

-- | Check a property with at most 10 tests (otherwise using
-- QuickCheck's 'defaultConfig').
--
-- If the run ends without a counterexample, the test count and the CPU
-- time spent are printed.
quickTime :: Testable t => t -> IO ()
quickTime prop = tcheck tenTests prop
  where
    tenTests = defaultConfig { configMaxTest = 10 }

-- | Run a property under the given configuration.
--
-- A fresh random generator is created first, and only then is the CPU
-- clock read, so that the reported time covers the test loop in
-- 'tbench'.
tcheck :: Testable a => Config -> a -> IO ()
tcheck cfg prop = do
  gen0  <- newStdGen
  start <- getCPUTime
  let propGen = evaluate prop
  tbench cfg propGen gen0 0 0 start

-- | The test loop.
--
-- Arguments are:
--
-- * the configuration,
-- * the property's result generator,
-- * the current random generator,
-- * the number of passed tests,
-- * the number of tests with no verdict,
-- * the CPU time at which the run started.
--
-- The loop stops when either counter reaches its configured limit, and
-- then reports via 'done'. It also stops at the first falsifying case,
-- printing the counterexample arguments; no timing is printed then.
tbench :: Config -> Gen Result -> StdGen -> Int -> Int -> Integer -> IO ()
tbench cfg gen g0 passed noVerdict start
  | passed == configMaxTest cfg    = done "OK, passed" passed start
  | noVerdict == configMaxFail cfg = done "Arguments exhausted after" passed start
  | otherwise = do
      putStr (configEvery cfg passed args)
      case ok outcome of
        Nothing    -> loop passed (noVerdict + 1)
        Just True  -> loop (passed + 1) noVerdict
        Just False -> putStr falsified
  where
    -- One half of the split seeds this test; the other feeds the next iteration.
    (gNext, gHere) = split g0
    outcome        = generate (configSize cfg passed) gHere gen
    args           = arguments outcome
    loop p f       = tbench cfg gen gNext p f start
    falsified      = "Falsifiable, after "
                  ++ show passed
                  ++ " tests:\n"
                  ++ unlines args

-- | Print a one-line summary.
--
-- The line contains the status message @str@, the number of tests
-- @ntest@, and the CPU time elapsed since @t1@. @t1@ is an earlier
-- 'getCPUTime' reading, in picoseconds.
done :: Show a => String -> a -> Integer -> IO ()
done str ntest t1 = do
  t2 <- getCPUTime
  putStrLn (str ++ " " ++ show ntest ++ " tests, CPU time: " ++ showT (t2-t1))

-- | Run an action and report how long it took.
--
-- After the action finishes, a blank line is printed, followed by
-- @msg@, the CPU time and the wall-clock time it took. Both durations
-- are shown in seconds via 'showT'.
--
-- The wall-clock duration is computed directly from the two
-- 'ClockTime' readings. 'timeDiffToString' drops the picosecond part of
-- the difference, so it lost fractional seconds and showed nothing at
-- all for sub-second runs.
time :: String -> IO () -> IO ()
time msg action = do
    d1 <- getClockTime
    t1 <- getCPUTime
    action
    t2 <- getCPUTime
    d2 <- getClockTime
    putStrLn $ "\n" ++ msg ++ ", CPU time: " ++ showT (t2 - t1)
            ++ ", wall clock: " ++ showT (picos d2 - picos d1)
  where
    -- Collapse a clock reading (seconds, picoseconds) into picoseconds.
    picos :: ClockTime -> Integer
    picos (TOD secs ps) = secs * 1000000000000 + ps

-- | Render a duration, given as an integral number of picoseconds
-- (the unit of 'getCPUTime'), as seconds with an @s@ suffix.
--
-- For example, @1500000000000@ is rendered as @1.5s@.
showT :: Integral a => a -> String
showT picoseconds = shows seconds "s"
  where
    seconds = fromIntegral picoseconds / 1e12 :: Double
