module Idct (idct) where
	
import Data.Array.Unboxed
import Data.Bits
import Data.Int
import Data.List (foldl')

{- Haskell IDCT (based on the simple_idct C code by Michael Niedermayer).
   Five points if you can get good strict asm out of this.
 -}

-- | Two-dimensional inverse DCT of one 8x8 block.
--
-- The input is treated as eight rows of eight consecutive elements, so it
-- must contain the indices 0..63.  The row pass runs first, then the column
-- pass, and finally every element is clamped to 0..255 and converted to
-- 'Int8'.  The result has the same bounds as the input.
--
-- NOTE(review): after clamping to 0..255 the conversion to 'Int8' wraps
-- values above 127 to negative numbers; confirm that callers expect this.
idct :: UArray Int Int -> UArray Int Int8

-- | Fixed-point IDCT weights, indexed 1..7.
--
-- Entry @k@ is approximately @cos (k * pi / 16) * sqrt 2 * 2^14@, matching the
-- W1..W7 constants of the C simple_idct (which uses 16383 rather than 16384
-- for W4).
--
-- The explicit signature pins the array and element types; without it the
-- binding is subject to the monomorphism restriction and nothing else in the
-- module determines which array type to build.
ws :: UArray Int Int
ws = simpleArray 1 [22725, 21407, 19266, 16383, 12873, 8867, 4520]

-- | Number of bits by which the row pass shifts its results right
-- (the @shift@ argument 'idctRows' hands to 'idct1D').
rowShift :: Int
rowShift = 11

-- | Number of bits by which the column pass shifts its results right
-- (the @shift@ argument 'idctCols' hands to 'idct1D').
colShift :: Int
colShift = 20

-- | Build an immutable array from a list, numbering the elements
-- consecutively starting at @stIdx@.
--
-- The bounds are @(stIdx, stIdx + length xs - 1)@, so an empty list yields an
-- empty array.
simpleArray :: IArray a e => Int -> [e] -> a Int e
simpleArray stIdx xs = listArray (stIdx, lastIdx) xs
  where
    lastIdx = length xs - 1 + stIdx
-- | 'simpleArray' with the first index fixed at 0.
simpleArray0 :: IArray a e => [e] -> a Int e
simpleArray0 = simpleArray 0

-- | Left-to-right function composition: @(f >>> g) x@ is @g (f x)@.
(>>>) :: (a -> b) -> (b -> c) -> a -> c
(f >>> g) x = g (f x)
-- | Shift @v@ right by @s@ bits (argument order flipped relative to
-- 'shiftR' so the shift amount can be partially applied).
descale :: Bits a => Int -> a -> a
descale s v = shiftR v s

-- | One 8-point inverse-DCT pass over a strided slice of @coef@.
--
-- The eight inputs are read from indices @off + k * stride@ for @k = 0..7@,
-- and the eight outputs are written back to those same indices; every other
-- element of @coef@ is returned unchanged.
--
-- A rounding bias of @1 `shiftL` (shift - 1)@ is folded into the DC term
-- (so @shift@ is expected to be at least 1) and each output is then shifted
-- right by @shift@ bits.
idct1D :: Int -> Int -> Int -> UArray Int Int -> UArray Int Int
idct1D shift stride off coef = coef // zip slots outputs
  where
    -- Array index of the k-th sample of this slice.
    slot k = k * stride + off
    slots = map slot [0 .. 7]

    -- Product of input sample k and weight j.
    m k j = (coef ! slot k) * (ws ! j)

    -- DC term with the rounding bias already added.
    dc = m 0 4 + (1 `shiftL` (shift - 1))

    -- Even half: contributions of inputs 0, 2, 4 and 6.
    e0 = dc + m 2 2 + m 4 4 + m 6 6
    e1 = dc + m 2 6 - m 4 4 - m 6 2
    -- The input-6 term enters e2 with the opposite sign to e1.
    e2 = dc - m 2 6 - m 4 4 + m 6 2
    e3 = dc - m 2 2 + m 4 4 - m 6 6

    -- Odd half: contributions of inputs 1, 3, 5 and 7.
    o0 = m 1 1 + m 3 3 + m 5 5 + m 7 7
    o1 = m 1 3 - m 3 7 - m 5 1 - m 7 5
    o2 = m 1 5 - m 3 1 + m 5 7 + m 7 3
    o3 = m 1 7 - m 3 5 + m 5 3 - m 7 1

    evens = [e0, e1, e2, e3]
    odds = [o0, o1, o2, o3]

    -- Butterfly: output k is e_k + o_k and output (7 - k) is e_k - o_k,
    -- each half paired with its own index (e1 with o1, and so on).
    outputs =
      map
        (descale shift)
        (zipWith (+) evens odds ++ reverse (zipWith (-) evens odds))

-- | Apply the 8-point 1-D transform to each of the eight rows of the block.
--
-- Row @r@ occupies indices @8 * r .. 8 * r + 7@ (stride 1).  Rows are
-- processed in order 0..7, each pass using 'rowShift' as its shift amount.
-- The fold is strict so each intermediate array is built before the next
-- row is processed.
idctRows :: UArray Int Int -> UArray Int Int
idctRows block = foldl' transformRow block [0 .. 7]
  where
    transformRow acc r = idct1D rowShift 1 (r * 8) acc
				
-- | Apply the 8-point 1-D transform to each of the eight columns of the block.
--
-- Column @c@ occupies indices @c, c + 8, .., c + 56@ (stride 8).  Columns are
-- processed in order 0..7, each pass using 'colShift' as its shift amount.
-- The fold is strict so each intermediate array is built before the next
-- column is processed.
idctCols :: UArray Int Int -> UArray Int Int
idctCols block = foldl' transformCol block [0 .. 7]
  where
    transformCol acc c = idct1D colShift 8 c acc

-- | Clamp every element to the range 0..255 and convert it to 'Int8'.
--
-- NOTE(review): the conversion is a plain 'fromIntegral', so clamped values
-- in 128..255 wrap to negative 'Int8' values; confirm whether 'Word8' was
-- intended.
clipto8bit :: UArray Int Int -> UArray Int Int8
clipto8bit = amap (fromIntegral . clamp)
  where
    clamp v = max 0 (min 255 v)

-- Full 2-D transform: row pass, then column pass, then clamp and narrow.
idct block = clipto8bit (idctCols (idctRows block))