#!/usr/bin/env python3

# pylint: disable=F0401,W0404
# F0401: It's not a problem that we can't always import rolling_checksum_pyx_mod.  In fact, it's normal.
# W0404: Actually, it's not imported multiple times, but we turn off the warning anyway, because we don't
#    want the warning fouling things.

"""Unit tests for rolling checksum mod - there are others, these are just the most fundamental"""

import sys
import math
import functools

import bufsock
import comma_mod

sys.path.insert(0, '.')

quiet = False
got_module = False

# Consume the command line one option at a time, so that only sys.argv[0] remains afterward.
# Options that are not recognized are discarded without complaint.  If more than one
# implementation is requested, the last one imported is the one bound to rolling_checksum_mod.
while len(sys.argv) > 1:
    option = sys.argv.pop(1)
    if option == '--quiet':
        quiet = True
    elif option == '--pure-python':
        import rolling_checksum_py_mod as rolling_checksum_mod
        got_module = True
    elif option == '--cython':
        import rolling_checksum_pyx_mod as rolling_checksum_mod
        got_module = True
    elif option == '--shedskin':
        import rolling_checksum_ss_mod as rolling_checksum_mod
        got_module = True


# Without an implementation to exercise there is nothing to test, so bail out early.
if not got_module:
    sys.stderr.write('You must specify one of --pure-python, --cython or --shedskin\n')
    sys.exit(1)


def my_range(up_to):
    """Generate 0, 1, 2, ... for as long as the value is below up_to - the same on 2.x and 3.x"""
    counter = 0
    while True:
        if not counter < up_to:
            return
        yield counter
        counter += 1


class Block_sequence_params(object):
    """Pair a block sequence generator with the limits its block lengths are expected to respect"""

    def __init__(self, generator, description, min_average, max_average, min_stddev, max_stddev, maximum):
        # pylint: disable=R0913
        # R0913: We want a few arguments
        self.generator = generator
        self.description = description
        self.min_average, self.max_average = min_average, max_average
        self.min_stddev, self.max_stddev = min_stddev, max_stddev
        self.maximum = maximum

    def average_in_range(self, average):
        """Say on stdout whether average is strictly between min_average and max_average; return the verdict"""
        if not (self.min_average < average < self.max_average):
            low = comma_mod.gimme_commas(self.min_average)
            high = comma_mod.gimme_commas(self.max_average)
            found = comma_mod.gimme_commas(str(int(average)))
            sys.stdout.write('%s average is not between %s and %s: %s\n' % (self.description, low, high, found))
            return False
        sys.stdout.write('Good, %s average in range\n' % self.description)
        return True

    def stddev_in_range(self, standard_deviation):
        """Say on stdout whether standard_deviation is strictly between min_stddev and max_stddev; return the verdict"""
        if not (self.min_stddev < standard_deviation < self.max_stddev):
            low = comma_mod.gimme_commas(self.min_stddev)
            high = comma_mod.gimme_commas(self.max_stddev)
            found = comma_mod.gimme_commas(str(int(standard_deviation)))
            sys.stdout.write(
                '%s standard_deviation is not between %s and %s: %s\n' % (self.description, low, high, found)
            )
            return False
        sys.stdout.write('Good, %s standard deviation in range\n' % self.description)
        return True

    def highest_beneath_maximum(self, highest):
        """Say on stdout whether the longest block respects the maximum (if one is set); return the verdict"""
        limit = self.maximum
        if limit is None:
            # No upper bound was configured, so any length is acceptable
            sys.stdout.write('Good, %s maximum,is None\n' % self.description)
            return True
        details = (self.description, highest, limit)
        if not (highest <= limit):
            sys.stdout.write('Bad, %s highest %s not beneath %s maximum\n' % details)
            return False
        sys.stdout.write('Good, %s highest %s beneath %s maximum\n' % details)
        return True


def rcm_size_and_accuracy_test():
    """Run each chunker over 'rcm-input-data' and check both fidelity and block length statistics.

    The input file is first read in full to form the expected content.  Each chunker is then run
    over a fresh handle on the same file, and report() checks that the blocks it produced
    concatenate back to the expected content and that their lengths fall within that chunker's
    configured limits.

    Returns True if every chunker passed every check, otherwise False.
    """
    all_good = True

    file_handle = bufsock.bufsock(bufsock.rawio('rcm-input-data', 'rb'))
    try:
        expected_list = []
        while True:
            block = file_handle.read(2 ** 20)
            if not block:
                break
            expected_list.append(block)
    finally:
        # Release the handle even if a read fails part way through
        file_handle.close()
    # The pieces must actually be joined; comparing against an empty value would make the
    # fidelity check in report() vacuous.
    expected = b''.join(expected_list)

    block_sequence_params = [
        Block_sequence_params(
            generator=functools.partial(rolling_checksum_mod.n_level_chunker, levels=3),
            description='n_level_chunker',
            min_average=300000,
            max_average=4000000,
            min_stddev=10,
            max_stddev=1500000,
            maximum=None,
        ),
        Block_sequence_params(
            generator=functools.partial(rolling_checksum_mod.min_max_chunker),
            description='min_max_chunker',
            min_average=700000,
            max_average=4000000,
            min_stddev=300000,
            max_stddev=1500000,
            maximum=2**22,
        ),
    ]

    for block_sequence_param in block_sequence_params:
        blocks = []
        total_len = 0

        file_handle = bufsock.bufsock(bufsock.rawio('rcm-input-data', 'rb'))
        try:
            for blockno, block in enumerate(block_sequence_param.generator(file_handle)):
                blocks.append(block)
                total_len += len(block)
                if not quiet:
                    sys.stderr.write('Appended blockno %d of length (%s) %s\n' % (
                        blockno,
                        comma_mod.gimme_commas(total_len),
                        comma_mod.gimme_commas(len(block)),
                    ))
        finally:
            # Release the handle even if the chunker raises
            file_handle.close()

        all_good &= report(block_sequence_param, expected, blocks)

    return all_good


def report(block_sequence_param, expected, blocks):
    """Report on our findings from rcm_size_and_accuracy_test.

    block_sequence_param supplies the description used in messages and the range checks.
    expected is the full content the blocks should reproduce.
    blocks is the list of blocks a chunker produced; they are joined with b''.join, so they
    are assumed to be bytes-like (the input is opened in 'rb' mode - confirm for each chunker).

    Writes one line per check to stdout.  Returns True only if the concatenated blocks equal
    expected and the average, standard deviation and maximum of the block lengths are all
    acceptable.  An empty blocks list is reported as a failure, since no statistics can be
    computed from it.
    """
    all_good = True

    # Reassemble what the chunker emitted, so the comparison really exercises its output
    actual = b''.join(blocks)
    if actual == expected:
        sys.stdout.write('Good, %s files match\n' % block_sequence_param.description)
    else:
        sys.stdout.write('Incorrect, %s files do not match\n' % block_sequence_param.description)
        all_good = False

    lengths = [len(block) for block in blocks]
    if not lengths:
        # Average, standard deviation and maximum are all undefined without any blocks
        sys.stdout.write('Bad, %s produced no blocks\n' % block_sequence_param.description)
        return False

    average = float(sum(lengths)) / float(len(lengths))
    all_good &= block_sequence_param.average_in_range(average)

    standard_deviation = stddev(lengths, average)
    all_good &= block_sequence_param.stddev_in_range(standard_deviation)

    highest = max(lengths)
    all_good &= block_sequence_param.highest_beneath_maximum(highest)

    return all_good


def stddev(list_, average):
    """Return the population standard deviation of list_, using the caller's precomputed average"""
    sum_of_squares = 0.0
    for value in list_:
        sum_of_squares = sum_of_squares + (value - average) ** 2
    variance = sum_of_squares / len(list_)
    return math.sqrt(variance)


def main():
    """Run the tests, announce the overall result on stderr, and exit with status 1 on any failure"""

    outcome = True
    for test_function in (rcm_size_and_accuracy_test, ):
        outcome &= test_function()

    if not outcome:
        sys.stderr.write('%s: One or more tests failed\n' % sys.argv[0])
        sys.exit(1)

    sys.stderr.write('%s: All tests passed\n' % sys.argv[0])


main()