#!/usr/bin/python

# pylint: disable=superfluous-parens
# superfluous-parens: Parentheses are good for clarity and portability

"""Unit tests for bloom_filter_mod."""

import itertools
import math
import sys
import time
try:
    import anydbm
except ImportError:
    import dbm as anydbm
import random

import bloom_filter_mod

# Alphabet (lowercase ASCII letters followed by the ten digits) from which the
# random candidate strings and random test content in this file are drawn.
CHARACTERS = 'abcdefghijklmnopqrstuvwxyz1234567890'


def make_used(variable):
    """Do nothing with variable; exists only so linters count it as used."""
    # Dropping the local reference is a no-op for the caller.
    del variable


def test(description, values, trials, error_rate, probe_bitnoer=bloom_filter_mod.get_bitno_lin_comb, filename=None):
    # pylint: disable=R0913,R0914
    # R0913: We want a few arguments
    # R0914: We want some local variables too.  This is just test code.
    """Perform some quick automatic tests for the bloom filter class.

    A fresh filter sized for ``trials * 2`` elements is filled from
    ``values.generator()``.  Every generated value must then be reported as
    present, and the share of random non-member candidates reported as
    present must not exceed ``error_rate``.

    Arguments:
        description: label for the run; accepted but not used in the body.
        values: object offering generator(), within(value) and length().
        trials: number of random non-member candidates to probe.
        error_rate: target false positive rate, also passed to the filter.
        probe_bitnoer: probe function handed to the Bloom_filter constructor.
        filename: backing store specification handed to the constructor.

    Returns:
        True if both checks passed, False otherwise.
    """
    all_good = True

    bloom_filter = bloom_filter_mod.Bloom_filter(
        ideal_num_elements_n=trials * 2,
        error_rate_p=error_rate,
        probe_bitnoer=probe_bitnoer,
        filename=filename,
        start_fresh=True,
        )

    for value in values.generator():
        bloom_filter.add(value)

    # A bloom filter must never give a false negative, so every added value has to be found.
    include_in_count = sum(include in bloom_filter for include in values.generator())
    if include_in_count != values.length():
        print(f'{sys.argv[0]}: include_in_count has a strange value: {include_in_count}, should be {values.length()}')
        all_good = False

    false_positives = 0
    for _ in range(trials):
        candidate = ''.join(random.sample(CHARACTERS, 5))
        # If we accidentally found a member, draw again until we have a non-member
        while values.within(candidate):
            candidate = ''.join(random.sample(CHARACTERS, 5))
        if candidate in bloom_filter:
            false_positives += 1

    actual_error_rate = float(false_positives) / trials

    if actual_error_rate > error_rate:
        # write() takes only the text; the previous file= keyword raised
        # TypeError exactly when this failure needed reporting.
        sys.stderr.write('%s: Too many false positives: actual: %s, expected: %s\n' % (
            sys.argv[0],
            actual_error_rate,
            error_rate,
            ))
        all_good = False

    return all_good


class States(object):
    """Value source made of the names of the USA's states (spaces removed)."""

    states = '''Alabama Alaska Arizona Arkansas California Colorado Connecticut
        Delaware Florida Georgia Hawaii Idaho Illinois Indiana Iowa Kansas
        Kentucky Louisiana Maine Maryland Massachusetts Michigan Minnesota
        Mississippi Missouri Montana Nebraska Nevada NewHampshire NewJersey
        NewMexico NewYork NorthCarolina NorthDakota Ohio Oklahoma Oregon
        Pennsylvania RhodeIsland SouthCarolina SouthDakota Tennessee Texas Utah
        Vermont Virginia Washington WestVirginia Wisconsin Wyoming'''.split()

    def __init__(self):
        """Nothing to set up; all data lives on the class."""

    @classmethod
    def generator(cls):
        """Yield each state name in list order."""
        for state_name in cls.states:
            yield state_name

    @staticmethod
    def within(value):
        """Report whether value is one of the state names."""
        known_states = States.states
        return value in known_states

    @staticmethod
    def length():
        """Return how many state names are held."""
        known_states = States.states
        return len(known_states)


def random_string():
    """Build a 10 character string of random CHARACTERS members - for testing purposes."""
    alphabet_size = len(CHARACTERS)
    # One random.random() draw per character, scaled to an index into CHARACTERS.
    return ''.join(CHARACTERS[int(random.random() * alphabet_size)] for _ in range(10))


class Random_content(object):
    """Value source holding 1000 strings produced by random_string().

    The strings are kept in generation order; nothing sorts them.
    """

    random_content = [random_string() for dummy in range(1000)]

    # Hashed copy of random_content so within() does not scan the whole
    # list for every candidate it is asked about.
    _random_content_set = frozenset(random_content)

    def __init__(self):
        """Initialize."""
        pass

    @staticmethod
    def generator():
        """Generate all values, in the order they were created."""
        yield from Random_content.random_content

    @staticmethod
    def within(value):
        """Test for membership."""
        return value in Random_content._random_content_set

    @staticmethod
    def length():
        """Return the number of members (duplicates counted, as in generator())."""
        return len(Random_content.random_content)


class Evens(object):
    """Value source of the even integers in range(maximum), as strings."""

    def __init__(self, maximum):
        """Store the exclusive upper bound."""
        self.maximum = maximum

    def generator(self):
        """Yield every even number below maximum as a decimal string."""
        for even_number in range(0, self.maximum, 2):
            yield str(even_number)

    def within(self, value):
        """Report whether value parses to an even integer in [0, maximum)."""
        try:
            number = int(value)
        except ValueError:
            return False

        if number % 2 != 0:
            return False
        return 0 <= number < self.maximum

    def length(self):
        """Return how many values generator() yields."""
        half = self.maximum / 2.0
        return int(math.ceil(half))


def and_test(filename1, filename2):
    """Check the in-place & operator on filters holding a,b,c and b,c,d.

    Returns True when b and c are reported present and a and d are not.
    """
    left = bloom_filter_mod.Bloom_filter(ideal_num_elements_n=100, error_rate_p=0.01, filename=filename1, start_fresh=True)
    for letter in 'abc':
        left += letter

    right = bloom_filter_mod.Bloom_filter(ideal_num_elements_n=100, error_rate_p=0.01, filename=filename2, start_fresh=True)
    for letter in 'bcd':
        right += letter

    intersection = left
    intersection &= right

    # (letter, whether it is expected to be in the intersection)
    expectations = (('a', False), ('b', True), ('c', True), ('d', False))

    all_good = True
    for letter, should_be_member in expectations:
        if (letter in intersection) == should_be_member:
            continue
        if should_be_member:
            complaint = f'and_test: {letter} not in abc_and_bcd, but should be; filename1 is {filename1}, filename2 is {filename2}'
        else:
            complaint = f'and_test: {letter} in abc_and_bcd, but should not be; filename1 is {filename1}, filename2 is {filename2}'
        print(complaint, file=sys.stderr)
        all_good = False

    return all_good


def or_test(filename1, filename2):
    """Check the in-place | operator on filters holding a,b,c and b,c,d.

    Returns True when a, b, c and d are reported present and e is not.
    """
    left = bloom_filter_mod.Bloom_filter(ideal_num_elements_n=100, error_rate_p=0.01, filename=filename1, start_fresh=True)
    for letter in 'abc':
        left += letter

    right = bloom_filter_mod.Bloom_filter(ideal_num_elements_n=100, error_rate_p=0.01, filename=filename2, start_fresh=True)
    for letter in 'bcd':
        right += letter

    union = left
    union |= right

    # (letter, whether it is expected to be in the union)
    expectations = (('a', True), ('b', True), ('c', True), ('d', True), ('e', False))

    all_good = True
    for letter, should_be_member in expectations:
        if (letter in union) == should_be_member:
            continue
        if should_be_member:
            complaint = f'or_test: {letter} not in abc_and_bcd, but should be; filename1 is {filename1}, filename2 is {filename2}'
        else:
            complaint = f'or_test: {letter} in abc_and_bcd, but should not be; filename1 is {filename1}, filename2 is {filename2}'
        print(complaint, file=sys.stderr)
        all_good = False

    return all_good


def give_description(filename):
    """Name the backing-store type implied by filename: array, mmap, hybrid or seek."""
    if filename is None:
        return 'array'

    if not isinstance(filename, tuple):
        return 'seek'

    # For tuples, a second element of -1 selects mmap; anything else is hybrid.
    return 'mmap' if filename[1] == -1 else 'hybrid'


def minimal_1_test(filename):
    """Add the single element 'x' to a fresh filter; expect 'x' found and 'y', 'z' not."""
    tiny_filter = bloom_filter_mod.Bloom_filter(ideal_num_elements_n=100, error_rate_p=0.01, filename=filename, start_fresh=True)
    tiny_filter.add('x')

    member_found = 'x' in tiny_filter
    if not member_found:
        print(f'{sys.argv[0]}: x in but not in minimalist bloom filter', file=sys.stderr)

    stranger_found = any(stranger in tiny_filter for stranger in ('y', 'z'))
    if stranger_found:
        print(f'{sys.argv[0]}: y or z not in but in minimalist bloom filter', file=sys.stderr)

    return member_found and not stranger_found


def minimal_20_test(filename):
    """Create a bloom filter, add 20 elements to it, and check membership both ways.

    The letters a..t are added and must all be reported present; the letters
    u..z are never added and must all be reported absent.

    Arguments:
        filename: backing store specification handed to Bloom_filter.

    Returns:
        True if every check passed, False otherwise.
    """
    all_good = True
    bloom_filter = bloom_filter_mod.Bloom_filter(ideal_num_elements_n=100, error_rate_p=0.01, filename=filename, start_fresh=True)
    elements = list('abcdefghijklmnopqrst')
    for string in elements:
        bloom_filter.add(string)
    for string in elements:
        if string not in bloom_filter:
            print(f'{sys.argv[0]}: {string} in but not in minimalist bloom filter', file=sys.stderr)
            all_good = False
    for string in 'uvwxyz':
        if string in bloom_filter:
            # This is the false-positive case: the element was never added.
            # The message previously repeated the false-negative wording.
            print(f'{sys.argv[0]}: {string} not in but in minimalist bloom filter', file=sys.stderr)
            all_good = False
    return all_good


def persistence_test(description, persistent, filename):
    """Test for persistence and nonpersistence.

    A fresh filter is filled with a few strings and closed, then the same
    filename is reopened with start_fresh=False.  For a persistent store the
    bit count and membership must survive the reopen; for a nonpersistent
    store the reopened filter must be empty.

    Arguments:
        description: human-readable name of the store, used in messages.
        persistent: whether the store is expected to keep its contents.
        filename: backing store specification handed to Bloom_filter.

    Returns:
        True if every check passed, False otherwise.
    """
    all_good = True

    # We "start_fresh", so it doesn't much matter what prior uses of this 'filename' were like.
    bloom_filter = bloom_filter_mod.Bloom_filter(
        ideal_num_elements_n=10000,
        error_rate_p=0.01,
        filename=filename,
        start_fresh=True,
    )

    persistent_in_strings = (
        'abc',
        'def',
        'ghi',
    )

    persistent_not_in_strings = (
        '123',
        'jkl',
    )

    for string in persistent_in_strings:
        bloom_filter.add(string)

    early_bit_count = bloom_filter.bit_count()

    bloom_filter.close()

    bloom_filter = bloom_filter_mod.Bloom_filter(
        ideal_num_elements_n=10000,
        error_rate_p=0.01,
        filename=filename,
        start_fresh=False,
    )

    try:
        late_bit_count = bloom_filter.bit_count()

        if persistent:
            if early_bit_count != late_bit_count:
                print(f'early_bit_count {early_bit_count} != late_bit_count {late_bit_count}')
                all_good = False
            for string in persistent_in_strings:
                if string not in bloom_filter:
                    print(f'string {string} not in "{description}" filter, but should be')
                    all_good = False
            for string in persistent_not_in_strings:
                if string in bloom_filter:
                    print(f'string {string} in "{description}" filter, but should not be')
                    # Previously only reported; an unexpected member must fail the test too.
                    all_good = False
        else:
            if late_bit_count != 0:
                print('late_bit_count != 0')
                all_good = False
            for string in itertools.chain(persistent_in_strings, persistent_not_in_strings):
                if string in bloom_filter:
                    print(f'string {string} in {description} filter')
                    all_good = False
    finally:
        # Release the reopened filter, mirroring the close() of the first one.
        bloom_filter.close()

    return all_good


def main():
    """Perform unit tests for Bloom_filter class.

    Runs the functional tests against each backing-store configuration.  When
    the sole command line argument is --performance-test, it also times
    test() on Evens sets of growing size and records the elapsed seconds in
    the 'performance-numbers' dbm database, skipping keys already recorded.

    Exits the process with status 0 if every test passed, 1 otherwise.
    """
    performance_test = sys.argv[1:] == ['--performance-test']

    all_good = True

    for (means, persistent, filename) in (
        # In memory as well as partially file seek (if needed).  Checked as persistent.
        ('part in memory, part file seek', True, ('bloom-filter.bin', 5000)),  # works partially
        ('part in memory, part file seek', True, ('bloom-filter.bin', 10000)),  # works fully
        ('part in memory, part file seek', True, ('bloom-filter.bin', 0)),  # works fully
        # File seek alone: persistent
        ('file seek', True, 'bloom-filter.bin'),
        # Wholly in-memory.  Not persistent.
        ('in memory', False, None),
        # mmap alone: persistent
        ('mmap', True, ('bloom-filter.bin', -1)),
    ):
        print(f'Filename: {filename}')

        all_good &= minimal_1_test(filename=filename)
        all_good &= minimal_20_test(filename=filename)

        all_good &= test('states', States(), trials=100000, error_rate=0.01, filename=filename)

        all_good &= test('random', Random_content(), trials=10000, error_rate=0.1, filename=filename)
        all_good &= test(
            'random',
            Random_content(),
            trials=10000,
            error_rate=0.1,
            probe_bitnoer=bloom_filter_mod.get_bitno_seed_rnd,
            filename=filename,
        )

        all_good &= test('random', Random_content(), trials=10000, error_rate=0.1, filename=filename)

        # Derive a second, distinct backing store of the same kind for the two-filter tests.
        if isinstance(filename, tuple):
            filename2 = (filename[0] + '.2', filename[1])
        elif isinstance(filename, str):
            filename2 = filename + '.2'
        elif filename is None:
            filename2 = None
        else:
            raise ValueError(f'filename {filename} looks weird')
        all_good &= and_test(filename1=filename, filename2=filename2)
        all_good &= or_test(filename1=filename, filename2=filename2)

        all_good &= persistence_test(description=means, filename=filename, persistent=persistent)

    if performance_test:
        sqrt_of_10 = math.sqrt(10)
        for exponent in range(19):  # this is a lot, but probably not unreasonable
            elements = int(sqrt_of_10 ** exponent + 0.5)
            for filename in [None, 'bloom-filter-rm-me', ('bloom-filter-rm-me', 768 * 2**20), ('bloom-filter-rm-me', -1)]:
                description = give_description(filename)
                key = '%s %s' % (description, elements)
                # Look the key up on the database object itself: keys() yields
                # bytes, which never compare equal to a str key.  The handle is
                # closed before any early exit so it cannot leak.
                with anydbm.open('performance-numbers', 'c') as database:
                    already_timed = key in database
                if already_timed:
                    continue
                if elements >= 100000000 and description in ('seek', 'mmap'):
                    continue
                if elements >= 1000000000 and description == 'array':
                    continue
                time0 = time.time()
                all_good &= test(
                    'evens %s elements: %d' % (description, elements),
                    Evens(elements),
                    trials=elements,
                    error_rate=1e-2,
                    filename=filename,
                    )
                time1 = time.time()
                delta_t = time1 - time0
                with anydbm.open('performance-numbers', 'c') as database:
                    database[key] = '%f' % delta_t

    if all_good:
        sys.stderr.write('%s: All tests passed\n' % sys.argv[0])
        sys.exit(0)
    else:
        sys.stderr.write('%s: One or more tests failed\n' % sys.argv[0])
        sys.exit(1)


# Run the tests only when executed as a script; main() ends in sys.exit(),
# so running it on import would terminate the importing process.
if __name__ == '__main__':
    main()