#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

/* Drop any macro definition of fprintf so the calls below reach the C
 * library function.
 * NOTE(review): presumably undoes a restricting macro from an earlier
 * header — confirm. */
#undef fprintf

/* Set to nonzero once print_dct_zero has been registered with atexit();
 * static storage, so it starts out as 0. */
static int setexit;

/*
 * Counters for the zero patterns of an 8-coefficient row (or of several rows
 * classified together). Each field counts how often that pattern was seen;
 * count_a_row produces a 0/1 sample of this struct for one row.
 */
typedef struct row_stat {
    uint64_t whole_row;   /* coefficients 0..7 all zero            */
    uint64_t except_dc;   /* coefficients 1..7 all zero            */
    uint64_t left_side;   /* coefficients 0..3 all zero            */
    uint64_t right_side;  /* coefficients 4..7 all zero            */
    uint64_t complex;     /* none of the four patterns above held  */
} row_stat;

/*
 * Aggregated zero-pattern statistics for one kind of IDCT pass.
 * Filled in by count_dct_zeroes, which processes one 8x8 block per call.
 */
typedef struct dct_stat {
    row_stat rows[8];        /* one bucket per row index, bumped once per block      */
    row_stat doublerows[4];  /* row pairs 0-1, 2-3, 4-5, 6-7, once per block         */
    row_stat quadrows[2];    /* row quads 0-3, 4-7, once per block                   */
    row_stat before_lastnz;  /* rows at or before the row holding the last nonzero   */
    row_stat general;        /* every row processed                                  */
    uint64_t total;          /* number of rows processed (incremented once per row)  */
    uint64_t lastnz;         /* sum of find_last_nonzero() results, once per block   */
} dct_stat;

/*
 * Statistics for the row and column passes of idct_put and idct_add, as
 * labelled in the report written by print_dct_zero. Objects with static
 * storage duration are zero-initialized, so no initializer is needed (an
 * empty `= {}` initializer is not valid C before C23).
 */
static dct_stat put_rows, put_cols, add_rows, add_cols;

/*
 * Print one row_stat as a single line of percentages of t_.
 *
 * out: destination stream.
 * r:   counters to print.
 * t_:  denominator (number of samples the counters were accumulated over).
 *
 * When t_ is zero (nothing was recorded) every percentage is printed as 0
 * instead of dividing by zero. Arithmetic is done in double so large 64-bit
 * counts keep their precision.
 */
static void print_rs(FILE *out, row_stat *r, uint64_t t_)
{
    /* Converts a raw count into a percentage of t_; 0 guards t_ == 0. */
    const double scale = t_ ? 100.0 / (double)t_ : 0.0;

    fprintf(out, "whole row: %f, except dc: %f, left: %f, right: %f, none: %f\n",
            (double)r->whole_row * scale, (double)r->except_dc * scale,
            (double)r->left_side * scale, (double)r->right_side * scale,
            (double)r->complex * scale);
}

/*
 * Write the full report for one dct_stat to `out`.
 *
 * s->total is incremented once per row, i.e. 8 times per block (see
 * count_dct_zeroes), so s->total / 8 is the number of blocks recorded.
 * The per-row, per-pair and per-quad buckets are each bumped once per block
 * and are therefore reported as percentages of the block count.
 * before_lastnz and general accumulate per row and are reported against
 * s->total.
 */
static void print_dct_stat(FILE *out, dct_stat *s)
{
    const uint64_t blocks = s->total / 8;
    int i;

    for (i = 0; i < 2; i++) {
        fprintf(out, "rows %d-%d: ", i*4, i*4+3);
        print_rs(out, &s->quadrows[i], blocks);
    }

    for (i = 0; i < 4; i++) {
        fprintf(out, "rows %d-%d: ", i*2, i*2+1);
        print_rs(out, &s->doublerows[i], blocks);
    }

    for (i = 0; i < 8; i++) {
        fprintf(out, "row %d: ", i);
        print_rs(out, &s->rows[i], blocks);
    }

    fprintf(out, "rows before lastnz: ");
    print_rs(out, &s->before_lastnz, s->total);

    fprintf(out, "rows in general: ");
    print_rs(out, &s->general, s->total);

    /* s->lastnz gains one find_last_nonzero() result per block, so this is
     * the mean last-nonzero index per block; guard against no blocks. */
    fprintf(out, "average row lastnz: %f\n",
            blocks ? (double)s->lastnz / (double)blocks : 0.0);
    /* uint64_t needs PRIu64; %lld does not match its type on all platforms. */
    fprintf(out, "total lastnz: %" PRIu64 " total: %" PRIu64 "\n", s->lastnz, s->total);
}

/*
 * atexit() handler: append the reports for all four statistics sets to
 * "dct-stats.txt" (opened relative to the current working directory).
 * If the file cannot be opened, the failure is reported on stderr and no
 * report is written; a failing fclose is reported as well.
 */
static void print_dct_zero(void)
{
    FILE *out = fopen("dct-stats.txt", "a");

    if (!out) {
        perror("dct-stats.txt");
        return;
    }

    fprintf(out,"-idct_put rows\n");
    print_dct_stat(out, &put_rows);
    fprintf(out,"-idct_put cols\n");
    print_dct_stat(out, &put_cols);
    fprintf(out,"-idct_add rows\n");
    print_dct_stat(out, &add_rows);
    fprintf(out,"-idct_add cols\n");
    print_dct_stat(out, &add_cols);
    fprintf(out, "----------------------\n");

    /* fclose flushes buffered output, so a write error may only surface here. */
    if (fclose(out) != 0)
        perror("dct-stats.txt");
}

/*
 * Return the index (0..63) of the last nonzero coefficient among the 64
 * entries of dct, scanning from the end.
 * An all-zero block yields 64, not -1.
 */
static int find_last_nonzero(int16_t *dct)
{
    int idx = 64;

    while (idx-- > 0) {
        if (dct[idx] != 0)
            return idx;
    }
    return 64;
}

/*
 * Return 1 if the first n coefficients starting at dct are all zero,
 * otherwise 0. The uint64_t return lets the result be stored directly into
 * row_stat counters.
 */
static uint64_t row_is_zero(int16_t *dct, int n)
{
    int k;

    for (k = 0; k < n; k++) {
        if (dct[k] != 0)
            return 0;
    }
    return 1;
}

/*
 * Classify one 8-coefficient row into a 0/1 row_stat sample.
 * Several of the zero patterns may hold at once; `complex` is set only when
 * none of them does.
 */
static row_stat count_a_row(int16_t *dct)
{
    row_stat res = {0};

    res.right_side = row_is_zero(dct + 4, 4);
    res.left_side  = row_is_zero(dct, 4);
    res.except_dc  = row_is_zero(dct + 1, 7);
    res.whole_row  = row_is_zero(dct, 8);

    res.complex = !res.whole_row && !res.except_dc &&
                  !res.left_side && !res.right_side;

    return res;
}

/*
 * Classify n consecutive rows (8 coefficients each, starting at dct) as if
 * they were a single row. The rows are OR-ed together column by column, so
 * column j of the combined row is zero only when it is zero in all n rows.
 * The combined row is then classified with count_a_row.
 */
static row_stat count_multiple_rows(int16_t *dct, int n)
{
    int16_t row[8] = {0};
    int i, j;

    for (i = 0; i < n; i++) {
        for (j = 0; j < 8; j++) {
            /* Accumulate per column; indexing by the row number (i) would
             * lose column positions and overflow row[] for n > 8. */
            row[j] |= dct[i*8+j];
        }
    }

    return count_a_row(row);
}

/*
 * Accumulate: add every counter of the sample s into the totals d
 * (d += s, field by field). s is not modified.
 */
static void add_rs(row_stat *s, row_stat *d)
{
    const row_stat src = *s;

    d->complex    += src.complex;
    d->right_side += src.right_side;
    d->left_side  += src.left_side;
    d->except_dc  += src.except_dc;
    d->whole_row  += src.whole_row;
}

/*
 * Record the zero-pattern statistics of one 64-coefficient block into s.
 *
 * Per call: s->lastnz grows by the block's last-nonzero index and s->total
 * by 8 (one per row). Each row is added to its rows[] bucket and to general,
 * and to before_lastnz when it is at or before the row holding the last
 * nonzero coefficient. Each row pair and row quad is added once to its
 * bucket.
 *
 * On the first call, print_dct_zero is registered with atexit() so the
 * report is written at program exit.
 */
static void count_dct_zeroes(int16_t *dct, dct_stat *s)
{
    const int lastnz = find_last_nonzero(dct);
    const int lastnz_row = lastnz / 8;
    int row;

    if (setexit == 0) {
        setexit = 1;
        atexit(print_dct_zero);
    }

    s->lastnz += lastnz;

    for (row = 0; row < 8; row++) {
        row_stat rs = count_a_row(&dct[row * 8]);

        s->total++;
        add_rs(&rs, &s->rows[row]);
        if (row <= lastnz_row)
            add_rs(&rs, &s->before_lastnz);
        add_rs(&rs, &s->general);
    }

    /* Row pairs start at rows 0, 2, 4, 6. */
    for (row = 0; row < 8; row += 2) {
        row_stat rs = count_multiple_rows(&dct[row * 8], 2);
        add_rs(&rs, &s->doublerows[row / 2]);
    }

    /* Row quads start at rows 0 and 4. */
    for (row = 0; row < 8; row += 4) {
        row_stat rs = count_multiple_rows(&dct[row * 8], 4);
        add_rs(&rs, &s->quadrows[row / 4]);
    }
}