#!/usr/bin/env python

# pylint: disable=F0401,W0404
# F0401: It's not a problem that we can't always import rolling_checksum_pyx_mod.  In fact, it's normal.
# W0404: Actually, it's not imported multiple times, but we turn off the warning anyway, because we don't
#    want the warning fouling things.

'''Unit tests for rolling checksum mod - there are others, these are just the most fundamental'''

#mport os
import sys
import math
import functools
#mport pprint
#mport random

import bufsock
import comma_mod
import python2x3

# Pick the rolling-checksum implementation under test from the first command-line argument:
# 'py' selects rolling_checksum_py_mod, 'pyx' selects rolling_checksum_pyx_mod (presumably the
# Cython build).  Either way it is bound to the name rolling_checksum_mod.  A missing or
# unrecognised argument is a usage error and the script exits with status 1.
# sys.argv[1:2] is [] when no argument was given, so no separate length check is needed.
if sys.argv[1:2] == ['py']:
    import rolling_checksum_py_mod as rolling_checksum_mod
elif sys.argv[1:2] == ['pyx']:
    import rolling_checksum_pyx_mod as rolling_checksum_mod
else:
    sys.stderr.write('argv[1] must be py or pyx\n')
    sys.exit(1)

def my_range(up_to):
    '''Yield 0, 1, 2, ... while the value is still less than up_to.

    A generator with the same semantics on Python 2.x and 3.x.  Unlike range(), it also
    accepts a non-integer bound, e.g. my_range(2.5) yields 0, 1, 2.
    '''
    counter = 0
    while True:
        if counter >= up_to:
            return
        yield counter
        counter += 1

class Block_sequence_params:
    '''Describe one block-sequence generator and the statistics its output must satisfy.

    Bundles the chunking callable under test with the acceptable ranges for the average and
    the standard deviation of the block lengths it emits.  Both ranges are exclusive at each
    end.  It also holds an optional, inclusive upper bound on the largest block.

    Each check method prints a human-readable verdict.  It returns True on success and
    False on failure.
    '''
    def __init__(self, generator, description, min_average, max_average, min_stddev, max_stddev, maximum):
        # pylint: disable=R0913
        # R0913: We want a few arguments
        # Callable invoked as generator(file_handle); its result is iterated to obtain blocks.
        self.generator = generator
        # Label used in every printed verdict.
        self.description = description
        # Exclusive bounds on the mean block length.
        self.min_average = min_average
        self.max_average = max_average
        # Exclusive bounds on the standard deviation of block lengths.
        self.min_stddev = min_stddev
        self.max_stddev = max_stddev
        # Inclusive cap on the longest block, or None for "no limit".
        self.maximum = maximum

    def average_in_range(self, average):
        '''Return True if min_average < average < max_average; print the verdict either way.'''
        if self.min_average < average < self.max_average:
            print('Good, %s average in range' % self.description)
            return True
        # No trailing comma directly inside print(...).  Under Python 2's print statement,
        # "print(x,)" prints the 1-tuple (x,) instead of the message.
        print('%s average is not between %s and %s: %s' % (
            self.description,
            comma_mod.gimme_commas(self.min_average),
            comma_mod.gimme_commas(self.max_average),
            comma_mod.gimme_commas(str(int(average))),
            ))
        return False

    def stddev_in_range(self, standard_deviation):
        '''Return True if min_stddev < standard_deviation < max_stddev; print the verdict either way.'''
        if self.min_stddev < standard_deviation < self.max_stddev:
            print('Good, %s standard deviation in range' % self.description)
            return True
        print('%s standard_deviation is not between %s and %s: %s' % (
            self.description,
            comma_mod.gimme_commas(self.min_stddev),
            comma_mod.gimme_commas(self.max_stddev),
            comma_mod.gimme_commas(str(int(standard_deviation))),
            ))
        return False

    def highest_beneath_maximum(self, highest):
        '''Return True if there is no maximum, or if highest <= maximum; print the verdict.'''
        if self.maximum is None:
            print('Good, %s maximum is None' % self.description)
            return True
        if highest <= self.maximum:
            print('Good, %s highest %s beneath %s maximum' % (self.description, highest, self.maximum))
            return True
        # NOTE(review): 'zbad' looks like a typo for 'Bad'.  It is left unchanged in case
        # something greps this script's output for it.
        print('zbad, %s highest %s not beneath %s maximum' % (self.description, highest, self.maximum))
        return False

def _read_file_bytes(filename):
    '''Return the whole of filename as one bytes object, read through bufsock in 1 MiB pieces.'''
    file_handle = bufsock.bufsock(bufsock.rawio(filename, 'rb'))
    try:
        pieces = []
        while True:
            piece = file_handle.read(2 ** 20)
            if not piece:
                break
            pieces.append(piece)
    finally:
        # Close even if a read raises.
        file_handle.close()
    return python2x3.empty_bytes.join(pieces)

def _chunk_file(block_sequence_param, filename):
    '''Run block_sequence_param.generator over filename and return the list of blocks it yields.

    Writes the running total and each block's length to stderr as it goes.
    '''
    blocks = []
    total_len = 0
    file_handle = bufsock.bufsock(bufsock.rawio(filename, 'rb'))
    try:
        for blockno, block in enumerate(block_sequence_param.generator(file_handle)):
            blocks.append(block)
            total_len += len(block)
            sys.stderr.write('Appended blockno %d of length (%s) %s\n' % (
                blockno,
                comma_mod.gimme_commas(total_len),
                comma_mod.gimme_commas(len(block)),
                ))
    finally:
        # Close even if the chunker raises partway through.
        file_handle.close()
    return blocks

def rcm_size_and_accuracy_test(filename='rcm-input-data'):
    '''Test each chunker for accuracy and block-size statistics.

    For every configured generator, chunk filename and check two things: the blocks must
    concatenate back to the original file contents, and the block-length statistics must
    fall within that generator's limits (see report()).

    filename: input data file; defaults to 'rcm-input-data' in the current directory.
    Returns True only if every generator passes every check.
    '''
    all_good = True

    expected = _read_file_bytes(filename)

    block_sequence_params = [
        Block_sequence_params(
            generator = functools.partial(rolling_checksum_mod.n_level_chunker, levels=3),
            description = 'n_level_chunker',
            min_average = 300000,
            max_average = 4000000,
            min_stddev = 10,
            max_stddev = 1500000,
            maximum = None,
            ),
        Block_sequence_params(
            generator = functools.partial(rolling_checksum_mod.min_max_chunker),
            description = 'min_max_chunker',
            min_average = 700000,
            max_average = 4000000,
            min_stddev = 300000,
            max_stddev = 1500000,
            maximum = 2**22,
            ),
        ]

    for block_sequence_param in block_sequence_params:
        blocks = _chunk_file(block_sequence_param, filename)
        all_good &= report(block_sequence_param, expected, blocks)

    return all_good

def report(block_sequence_param, expected, blocks):
    '''Report on our findings from rcm_size_and_accuracy_test.

    block_sequence_param: the Block_sequence_params whose generator produced blocks.
    expected: the full contents of the input file, as bytes.
    blocks: the list of byte blocks the generator yielded.

    Prints a verdict for each check.  Returns True only if the blocks reassemble to expected
    and the average, standard deviation and largest block length are all within the limits
    held by block_sequence_param.  Zero blocks counts as a failure, because the statistics
    are undefined in that case.
    '''
    all_good = True

    actual = python2x3.empty_bytes.join(blocks)
    if actual == expected:
        print('Good, %s files match' % block_sequence_param.description)
    else:
        print('Incorrect, %s files do not match' % block_sequence_param.description)
        all_good = False

    lengths = [len(block) for block in blocks]
    if not lengths:
        # Without this guard, the average would raise ZeroDivisionError and max() would
        # raise ValueError.  Report a failure instead.
        print('Incorrect, %s produced no blocks' % block_sequence_param.description)
        return False

    average = float(sum(lengths)) / float(len(lengths))
    all_good &= block_sequence_param.average_in_range(average)

    standard_deviation = stddev(lengths, average)
    all_good &= block_sequence_param.stddev_in_range(standard_deviation)

    highest = max(lengths)
    all_good &= block_sequence_param.highest_beneath_maximum(highest)

    return all_good

def stddev(list_, average):
    '''Return the population standard deviation of list_ about the supplied mean.

    The caller passes the already-computed average so it is not recomputed here.  The sum
    of squared deviations is divided by len(list_), not len - 1.  An empty list raises
    ZeroDivisionError.
    '''
    # Start the sum at 0.0 so the accumulation is float even for integer inputs.  This keeps
    # the division below a true division on Python 2 as well.
    squared_deviations = ((value - average) ** 2 for value in list_)
    return math.sqrt(sum(squared_deviations, 0.0) / len(list_))

def main():
    '''Run every test in this module and report the overall result on stderr.

    Exits with status 1 if any test failed; otherwise it returns normally.
    '''
    outcome = True
    for test in (rcm_size_and_accuracy_test,):
        outcome &= test()

    program = sys.argv[0]
    if not outcome:
        sys.stderr.write('%s: One or more tests failed\n' % program)
        sys.exit(1)
    sys.stderr.write('%s: All tests passed\n' % program)

# Run the tests only when executed as a script, not when the module is merely imported.
if __name__ == '__main__':
    main()