#!/usr/local/cpython-3.10/bin/python3

"""Test dupdict."""

import sys
import itertools

import treap as treap_mod
import dupdict_mod
import red_black_dict_mod


def simple_generic_test(dict_like_class):
    """Round-trip data through a Dupdict backed by a fresh dict_like_class() instance.

    Every key in 0..9 is stored twice, each time with the value 2 ** key.
    The keys, values and items read back are then compared, as sets, against
    what was stored.

    Args:
        dict_like_class: A zero-argument callable producing an empty dict-like store.

    Returns:
        True if all three comparisons match, False otherwise (details go to stderr).
    """
    backing_store = dict_like_class()
    dupdict = dupdict_mod.Dupdict(backing_store)

    expected_keys = set()
    expected_values = set()
    expected_items = set()

    # Each key appears twice in this sequence, so every key gets a duplicate entry.
    for number in itertools.chain(range(10), range(10)):
        power = 2 ** number
        dupdict[number] = power
        expected_keys.add(number)
        expected_values.add(power)
        expected_items.add((number, power))

    # Each reader is a lambda so the Dupdict is only read right before its comparison.
    comparisons = (
        ('{}: simple_generic_test: set(dupdict) != key_set\n',
         lambda: set(dupdict), expected_keys),
        ('{}: simple_generic_test: set(dupdict.values()) != value_set\n',
         lambda: set(dupdict.values()), expected_values),
        ('{}: simple_generic_test: set(dupdict.items()) != items_set\n',
         lambda: set(dupdict.items()), expected_items),
    )

    passed = True
    for message, read_actual, expected in comparisons:
        if read_actual() != expected:
            sys.stderr.write(message.format(sys.argv[0]))
            passed = False

    return passed


def test_simple_dict_test():
    """Run simple_generic_test with a builtin dict as the Dupdict's backing store."""
    backing_class = dict
    return simple_generic_test(backing_class)


def test_simple_treap_test():
    """Run simple_generic_test with a treap as the Dupdict's backing store."""
    backing_class = treap_mod.treap
    return simple_generic_test(backing_class)


def test_simple_red_black_tree_test():
    """Run simple_generic_test with a red-black tree as the Dupdict's backing store."""
    backing_class = red_black_dict_mod.RedBlackTree
    return simple_generic_test(backing_class)


def test_find_min_dict():
    """Check that find_min() on a dict-backed Dupdict returns the smallest key (0).

    Keys 9 down to 0 are inserted in descending order, so the minimum is not
    simply the first key stored.

    Returns:
        True on success, False otherwise (details go to stderr).
    """
    dupdict = dupdict_mod.Dupdict({})

    for exponent in range(9, -1, -1):
        dupdict[exponent] = 2 ** exponent

    if dupdict.find_min() == 0:
        return True

    sys.stderr.write('{}: test_find_min_dict: Incorrect minimum\n'.format(sys.argv[0]))
    return False


def test_find_min_treap():
    """Check that find_min() on a treap-backed Dupdict returns the smallest key (0).

    Keys 9 down to 0 are inserted in descending order, so the minimum is not
    simply the first key stored.

    Returns:
        True on success, False otherwise (details go to stderr).
    """
    dupdict = dupdict_mod.Dupdict(treap_mod.treap())

    for exponent in range(9, -1, -1):
        dupdict[exponent] = 2 ** exponent

    if dupdict.find_min() == 0:
        return True

    sys.stderr.write('{}: test_find_min_treap: Incorrect minimum\n'.format(sys.argv[0]))
    return False


def test_find_max():
    """Check that find_max() on a red-black-tree-backed Dupdict returns the largest key (9).

    Keys are inserted in descending order, so the maximum is the first key stored.

    Returns:
        True on success, False otherwise (details go to stderr).
    """
    dupdict = dupdict_mod.Dupdict(red_black_dict_mod.RedBlackTree())

    for exponent in range(9, -1, -1):
        dupdict[exponent] = 2 ** exponent

    if dupdict.find_max() == 9:
        return True

    sys.stderr.write('{}: test_find_max: Incorrect maximum\n'.format(sys.argv[0]))
    return False


def test_get_all():
    """Check get_all() for a key that holds two values in a treap-backed Dupdict.

    Key 2 is assigned 4 and then 6. The test expects get_all(2) to return [4, 6],
    i.e. both values in the order they were stored.

    Returns:
        True on success, False otherwise (details go to stderr).
    """
    dupdict = dupdict_mod.Dupdict(treap_mod.treap())

    for key, value in ((1, 2), (2, 4), (2, 6), (3, 8)):
        dupdict[key] = value

    if dupdict.get_all(2) != [4, 6]:
        sys.stderr.write('{}: test_get_all: dupdict[2] != [4, 6]\n'.format(sys.argv[0]))
        return False

    return True


def test_del_all():
    """Test del_all().

    Stores values under keys 1, 2 (twice) and 3 in a treap-backed Dupdict.
    It then removes every value under key 2 with del_all(2) and checks that
    all_items() lists only keys 1 and 3, each with its single value.

    Returns:
        True if the check passed, False otherwise (details go to stderr).
    """
    dupdict = dupdict_mod.Dupdict(treap_mod.treap())

    dupdict[1] = 2
    dupdict[2] = 4
    dupdict[2] = 6
    dupdict[3] = 8

    dupdict.del_all(2)

    # Compared as an ordered list: with a treap as the backing store the items are expected in key order.
    actual = list(dupdict.all_items())
    expected = [(1, [2]), (3, [8])]

    if actual != expected:
        sys.stderr.write(
            '{}: test_del_all: all_items() after del_all(2) is {!r}, expected {!r}\n'.format(
                sys.argv[0], actual, expected,
            )
        )
        return False

    return True


def test_all_items():
    """Check all_items() on a treap-backed Dupdict in which key 2 holds two values.

    Returns:
        True on success, False otherwise (details go to stderr).
    """
    dupdict = dupdict_mod.Dupdict(treap_mod.treap())

    for key, value in ((1, 2), (2, 4), (2, 6), (3, 8)):
        dupdict[key] = value

    # Expected in key order, because the backing store is a treap.
    expected = [(1, [2]), (2, [4, 6]), (3, [8])]

    if list(dupdict.all_items()) != expected:
        sys.stderr.write('{}: test_all_items: dupdict.all_items() is off\n'.format(sys.argv[0]))
        return False

    return True


def test_len():
    """Check that len() counts every stored value, duplicates included, across inserts and deletes.

    Returns:
        True if every length check passed, False otherwise (details go to stderr).
    """
    dupdict = dupdict_mod.Dupdict(treap_mod.treap())
    outcomes = []

    def expect_len(expected, message):
        """Record whether len(dupdict) == expected; report message to stderr if not."""
        if len(dupdict) == expected:
            return True
        sys.stderr.write(message.format(sys.argv[0]))
        return False

    outcomes.append(expect_len(0, '{}: test_len: empty dupdict is not len 0\n'))

    dupdict[0] = 1
    outcomes.append(expect_len(1, '{}: test_len: dupdict with one element is not len 1\n'))

    # A second value under the same key still counts toward the length.
    dupdict[0] = 1.5
    outcomes.append(expect_len(2, '{}: test_len: dupdict with two elements is not len 2\n'))

    dupdict[1] = 2
    outcomes.append(expect_len(3, '{}: test_len: dupdict with three elements is not len 3\n'))

    del dupdict[0]
    outcomes.append(expect_len(2, '{}: test_len: dupdict with two elements post-del is not again len 2\n'))

    del dupdict[1]
    outcomes.append(expect_len(1, '{}: test_len: dupdict with one element post-2nd-del is not again len 1\n'))

    return all(outcomes)


def test_remove_min(dict_like_object):
    """Test the remove_min method.

    Stores four values (two of them under key 1) in a Dupdict backed by
    dict_like_object. It then checks that successive remove_min() calls yield
    them in ascending key order. Once the Dupdict is empty, remove_min() must
    raise IndexError, KeyError or ValueError.

    For the duplicated key 1 the test expects the value stored last (3) to come
    out before the value stored first (2).

    Args:
        dict_like_object: The (expected to be empty) dict-like store backing the Dupdict.

    Returns:
        True if every check passed, False otherwise (details go to stderr).
    """
    all_good = True

    dupdict = dupdict_mod.Dupdict(dict_like_object)

    dupdict[0] = 1
    dupdict[1] = 2
    dupdict[2] = 4
    dupdict[1] = 3

    expected_pairs = [(0, 1), (1, 3), (1, 2), (2, 4)]

    # Messages are generated from the expected values so they can't drift out of sync with the checks.
    for call_number, (expected_key, expected_value) in enumerate(expected_pairs, start=1):
        (key, value) = dupdict.remove_min()

        if key != expected_key:
            sys.stderr.write('{}: test_remove_min: key not {} (call {}, got {!r})\n'.format(
                sys.argv[0], expected_key, call_number, key))
            all_good = False

        if value != expected_value:
            sys.stderr.write('{}: test_remove_min: value not {} (call {}, got {!r})\n'.format(
                sys.argv[0], expected_value, call_number, value))
            all_good = False

    try:
        dupdict.remove_min()
    except (IndexError, KeyError, ValueError):
        pass
    else:
        sys.stderr.write(
            '{}: test_remove_min: remove_min from empty dictionary like object '.format(sys.argv[0]) +
            'did not raise an appropriate exception\n'
        )
        all_good = False

    return all_good


def test_remove_max(dict_like_object):
    """Test the remove_max method.

    Stores four values (two of them under key 1) in a Dupdict backed by
    dict_like_object. It then checks that successive remove_max() calls yield
    them in descending key order. Once the Dupdict is empty, remove_max() must
    raise IndexError, KeyError or ValueError.

    For the duplicated key 1 the test expects the value stored last (3) to come
    out before the value stored first (2).

    Args:
        dict_like_object: The (expected to be empty) dict-like store backing the Dupdict.

    Returns:
        True if every check passed, False otherwise (details go to stderr).
    """
    all_good = True

    dupdict = dupdict_mod.Dupdict(dict_like_object)

    dupdict[0] = 1
    dupdict[1] = 2
    dupdict[2] = 4
    dupdict[1] = 3

    expected_pairs = [(2, 4), (1, 3), (1, 2), (0, 1)]

    # Messages are generated from the expected values so they can't drift out of sync with the checks.
    for call_number, (expected_key, expected_value) in enumerate(expected_pairs, start=1):
        (key, value) = dupdict.remove_max()

        if key != expected_key:
            sys.stderr.write('{}: test_remove_max: key not {} (call {}, got {!r})\n'.format(
                sys.argv[0], expected_key, call_number, key))
            all_good = False

        if value != expected_value:
            sys.stderr.write('{}: test_remove_max: value not {} (call {}, got {!r})\n'.format(
                sys.argv[0], expected_value, call_number, value))
            all_good = False

    try:
        dupdict.remove_max()
    except (IndexError, KeyError, ValueError):
        pass
    else:
        sys.stderr.write(
            '{}: test_remove_max: remove_max from empty dictionary like object '.format(sys.argv[0]) +
            'did not raise an appropriate exception\n'
        )
        all_good = False

    return all_good


def test_del_til_empty():
    """Test deleting all values associated with a key - and then one more, expecting a KeyError exception.

    Key 'b' holds two values. Two `del dupdict['b']` statements should consume
    both, and a third must raise KeyError.

    Returns:
        True if the final delete raised KeyError, False otherwise (details go to stderr).
    """
    dupdict = dupdict_mod.Dupdict({})

    dupdict['a'] = 'a'
    dupdict['b'] = 'b'
    dupdict['b'] = 'c'

    del dupdict['b']
    del dupdict['b']

    try:
        del dupdict['b']
    except KeyError:
        pass
    else:
        # Report on stderr like every other test; the original print() lacked its f-prefix
        # and emitted the literal text "{sys.argv[0]}".
        sys.stderr.write('{}: test_del_til_empty: No KeyError.\n'.format(sys.argv[0]))
        return False

    return True


def main():
    """Run every test, then exit with status 0 if all passed or 1 (after a stderr summary) otherwise."""
    tests = (
        test_simple_dict_test,
        test_simple_treap_test,
        test_simple_red_black_tree_test,
        test_find_min_dict,
        test_find_min_treap,
        test_find_max,
        test_get_all,
        test_del_all,
        test_all_items,
        test_len,
        test_del_til_empty,
        # remove_min and remove_max against each of the 3 backing data structures;
        # each store is created fresh, immediately before the test that uses it.
        lambda: test_remove_min(treap_mod.treap()),
        lambda: test_remove_max(treap_mod.treap()),
        lambda: test_remove_min(red_black_dict_mod.RedBlackTree()),
        lambda: test_remove_max(red_black_dict_mod.RedBlackTree()),
        lambda: test_remove_min({}),
        lambda: test_remove_max({}),
    )

    # Build the full list first (no short-circuiting) so every failing test reports itself.
    outcomes = [run_test() for run_test in tests]

    if not all(outcomes):
        sys.stderr.write('{}: One or more tests failed\n'.format(sys.argv[0]))
        sys.exit(1)

    sys.exit(0)


# Only run the suite when executed as a script, so importing this module
# neither runs the tests nor triggers main()'s sys.exit().
if __name__ == '__main__':
    main()