#!/usr/local/cpython-3.3/bin/python

"""Perform unit tests for binary_tree_dict_mod."""

import sys
import math
import random

import binary_tree_dict_mod


def make_used(var):
    """
    Mark var as used so that linters do not complain about it.

    The argument is only bound to a throwaway name; nothing is computed from
    it and None is returned.
    """
    _ = var


def test_depth():
    """
    Check that depth() reports a plausible value for a small tree.

    Nine key-value pairs are inserted, then the reported depth must lie
    between round(log2(len)) and len, inclusive.

    Returns True on success; otherwise writes a diagnostic to stderr and
    returns False.
    """
    tree = binary_tree_dict_mod.BinaryTreeDict()
    pairs = (
        (2, "def"),
        (3, "ghi"),
        (1, "abc"),
        (4, "jkl"),
        (5, "mno"),
        (7, "stu"),
        (8, "vwx"),
        (9, "yz"),
        (6, "pqr"),
    )
    for key, value in pairs:
        tree[key] = value

    depth = tree.depth()
    node_count = len(tree)
    lower_bound = int(round(math.log(node_count, 2)))

    # The upper bound is the node count because a binary tree can degenerate
    # into a linear list in the worst case.
    if lower_bound <= depth <= node_count:
        return True

    sys.stderr.write("%s: test_depth: Bogus depth: %d\n" % (sys.argv[0], depth))
    return False


def test_insert():
    """
    Insert shuffled keys 0..9 and verify each one can be looked up again.

    Every key is stored with str(key) as its value.  Returns True when all
    lookups give back the stored value; each mismatch is reported on stderr
    and makes the result False.
    """
    random.seed(0)
    keys = list(range(10))
    random.shuffle(keys)

    tree = binary_tree_dict_mod.BinaryTreeDict()
    for key in keys:
        tree[key] = str(key)

    all_good = True
    for key in keys:
        expected = str(key)
        if expected == tree[key]:
            continue
        all_good = False
        sys.stderr.write(
            "%s: test_insert: Found mismatched key: Got %s, expected %s\n"
            % (sys.argv[0], tree[key], expected)
        )

    return all_good


def test_all_removal():
    """
    Delete every key from a tree and confirm the tree ends up empty.

    Thirty keys are inserted in one shuffled order and deleted in another.
    After each deletion a lookup of that key must raise KeyError, and once
    all keys are gone the tree must be falsy.

    Returns True if everything held, False (with stderr diagnostics) if not.
    """
    random.seed(0)
    keys = list(range(30))
    random.shuffle(keys)

    tree = binary_tree_dict_mod.BinaryTreeDict()
    for key in keys:
        tree[key] = str(key)

    # Delete in a different order than we inserted, for a more interesting test
    random.shuffle(keys)

    all_good = True
    for key in keys:
        del tree[key]
        try:
            make_used(tree[key])
        except KeyError:
            # The key is really gone, which is what we want
            continue
        all_good = False
        sys.stderr.write(
            "%s: test_all_removal: element %s not removed\n" % (sys.argv[0], key)
        )

    if tree:
        all_good = False
        sys.stderr.write("%s: test_all_removal: final tree nonempty\n" % (sys.argv[0],))

    return all_good


def test_two_3rd_removal():
    """
    Remove two thirds of the keys and check that the right third survives.

    Keys 0..29 are inserted in shuffled order, then every key that is not a
    multiple of 3 is deleted (also in shuffled order).  Iterating over the
    tree must then yield exactly the multiples of 3 in ascending order.

    Returns True on success, False (with stderr diagnostics) otherwise.
    """
    all_keys = list(range(30))
    # Multiples of 3 stay; built in ascending order so no sort is needed
    survivors = list(range(0, 30, 3))
    doomed = [key for key in all_keys if key % 3]

    # Python 2.x and 3.x (and likewise pypy and pypy3) produce different
    # number sequences for random.seed(0), so the shuffled orders differ
    # between them.
    random.seed(0)
    random.shuffle(all_keys)
    random.shuffle(doomed)

    tree = binary_tree_dict_mod.BinaryTreeDict()
    for key in all_keys:
        tree[key] = str(key)
    for key in doomed:
        del tree[key]

    if list(tree) == survivors:
        return True

    sys.stderr.write(
        "%s: test_two_3rd_removal: final tree has bad values\n\n" % (sys.argv[0],)
    )
    sys.stderr.write("binary_tree_dict:\n")
    sys.stderr.write(str(tree))
    sys.stderr.write("one_3rd_keys:\n")
    sys.stderr.write(str(survivors))
    return False


def test_one_removal():
    """
    Insert a single pair, delete it, and check the size at each step.

    The tree must report length 1 after the insert, and after the delete it
    must report length 0 and be falsy.  Returns True if all checks pass,
    False (with stderr diagnostics) otherwise.
    """
    program = sys.argv[0]
    tree = binary_tree_dict_mod.BinaryTreeDict()
    all_good = True

    tree[1] = 2

    # Exactly one pair should be present now
    if len(tree) != 1:
        all_good = False
        sys.stderr.write(
            "%s: test_one_removal: tree does not have one node\n" % (program,)
        )

    del tree[1]

    # Emptiness is checked twice: once via len(), once via truthiness
    if len(tree) != 0:
        all_good = False
        sys.stderr.write(
            "%s: test_one_removal: tree does not have zero nodes\n" % (program,)
        )

    if tree:
        all_good = False
        sys.stderr.write("%s: test_one_removal: tree not empty\n" % (program,))

    return all_good


def test_none_removal():
    """
    Add a single node keyed by None to a tree, then delete it.

    The tree must report length 1 after the insert, and after the delete it
    must report length 0 and be falsy.

    Returns True if all checks pass, False (with stderr diagnostics)
    otherwise.
    """
    all_good = True
    binary_tree_dict = binary_tree_dict_mod.BinaryTreeDict()

    # add one key-value pair, using None as both key and value
    binary_tree_dict[None] = None

    # Check that it knows it has one pair
    if len(binary_tree_dict) != 1:
        sys.stderr.write(
            "%s: test_none_removal: tree does not have one node\n" % (sys.argv[0],)
        )
        all_good = False

    # Delete the pair - by the key that was actually inserted (None).  The key
    # 1 was never added, so deleting it could not empty the tree.
    del binary_tree_dict[None]

    # Test that it knows it has zero pairs
    if len(binary_tree_dict) != 0:
        sys.stderr.write(
            "%s: test_none_removal: tree does not have zero nodes\n" % (sys.argv[0],)
        )
        all_good = False

    # Test that it knows it has zero pairs in a 2nd way
    if binary_tree_dict:
        sys.stderr.write("%s: test_none_removal: tree not empty\n" % (sys.argv[0],))
        all_good = False

    return all_good


def test_large_inserts():
    """
    Insert many keys, then verify find_min and find_max.

    Keys are generated by repeatedly adding 997 modulo 100000, starting at
    997 and stopping when the sequence reaches 0.  The smallest and largest
    generated keys are tracked and compared with what the tree reports.

    Returns True if both extremes match, False (with stderr diagnostics)
    otherwise.
    """
    tree = binary_tree_dict_mod.BinaryTreeDict()
    modulus = 100000
    stride = 997

    # The first key is the stride itself, so it seeds both running extremes
    smallest = largest = stride
    key = stride
    while key != 0:
        smallest = min(smallest, key)
        largest = max(largest, key)
        tree[key] = str(key)
        key = (key + stride) % modulus

    all_good = True

    reported_min = tree.find_min()
    if smallest != reported_min:
        all_good = False
        sys.stderr.write(
            "%s: Large binary_tree_dict did not return correct minimum: expected: %s, actual: %s\n"
            % (sys.argv[0], smallest, reported_min)
        )

    reported_max = tree.find_max()
    if largest != reported_max:
        all_good = False
        sys.stderr.write(
            "%s: Large binary_tree_dict did not return correct maximum: expected: %s, actual: %s\n"
            % (sys.argv[0], largest, reported_max)
        )

    return all_good


def test_nonempty():
    """
    Check that a tree holding ten keys is truthy.

    Returns True if so; otherwise reports the problem on stderr and returns
    False.
    """
    random.seed(0)
    keys = list(range(10))
    random.shuffle(keys)

    tree = binary_tree_dict_mod.BinaryTreeDict()
    for key in keys:
        tree[key] = str(key)

    if tree:
        return True

    sys.stderr.write("%s: nonempty binary_tree_dict looks empty\n" % sys.argv[0])
    return False


def test_empty():
    """
    Check that a freshly created tree is falsy.

    Returns True if so; otherwise reports the problem on stderr and returns
    False.
    """
    tree = binary_tree_dict_mod.BinaryTreeDict()

    if not tree:
        return True

    sys.stderr.write("%s: empty binary_tree_dict looks nonempty\n" % sys.argv[0])
    return False


def test_min_max():
    """
    Insert shuffled keys 0..9, then verify find_min gives 0 and find_max 9.

    Returns True if both match, False (with stderr diagnostics) otherwise.
    """
    random.seed(0)
    keys = list(range(10))
    random.shuffle(keys)

    tree = binary_tree_dict_mod.BinaryTreeDict()
    for key in keys:
        tree[key] = str(key)

    all_good = True

    smallest = tree.find_min()
    if smallest != 0:
        all_good = False
        sys.stderr.write("%s: minimum was not 0: %s\n" % (sys.argv[0], smallest))

    largest = tree.find_max()
    if largest != 9:
        all_good = False
        sys.stderr.write("%s: maximum was not 9: %s\n" % (sys.argv[0], largest))

    return all_good


def test_values():
    """
    Insert a few key-value pairs, and make sure they come back out OK.

    Keys 1..9 are inserted out of order.  The test then checks that find_min
    gives 1, find_max gives 9, and that looking up key 5 gives "mno".

    Returns True if all checks pass, False (with stderr diagnostics)
    otherwise.
    """
    all_good = True

    binary_tree_dict = binary_tree_dict_mod.BinaryTreeDict()
    binary_tree_dict[2] = "def"
    binary_tree_dict[3] = "ghi"
    binary_tree_dict[1] = "abc"
    binary_tree_dict[4] = "jkl"
    binary_tree_dict[5] = "mno"
    binary_tree_dict[7] = "stu"
    binary_tree_dict[8] = "vwx"
    binary_tree_dict[9] = "yz"
    binary_tree_dict[6] = "pqr"

    # The smallest key inserted above is 1 (there is no key 0), so the
    # diagnostic names 1 to agree with the condition being checked.
    if binary_tree_dict.find_min() != 1:
        sys.stderr.write("%s: test_values: Minimum was not 1\n" % sys.argv[0])
        all_good = False

    if binary_tree_dict.find_max() != 9:
        sys.stderr.write("%s: test_values: Maximum was not 9\n" % sys.argv[0])
        all_good = False

    if binary_tree_dict[5] != "mno":
        sys.stderr.write("%s: test_values: Middle was not mno\n" % sys.argv[0])
        all_good = False

    return all_good


def test_inorder_traversal():
    """
    Check that inorder_traversal visits the pairs in ascending key order.

    Keys 1, 4, 7, ..., 28 are inserted in shuffled order with str(key) as the
    value.  The (key, value) pairs handed to the visitor callback must equal
    the sorted list of inserted pairs.

    Returns True on success, False (with stderr diagnostics) otherwise.
    """
    random.seed(0)
    keys = list(range(1, 30, 3))
    random.shuffle(keys)
    expected = sorted((key, str(key)) for key in keys)

    tree = binary_tree_dict_mod.BinaryTreeDict()
    for key in keys:
        tree[key] = str(key)

    # The callback records each visited pair in the order it is visited
    visited = []
    tree.inorder_traversal(lambda key, value: visited.append((key, value)))

    if expected == visited:
        return True

    sys.stderr.write(
        "%s: test_inorder_traversal: inorder_traversal failed to rebuild the list\n"
        % (sys.argv[0],)
    )
    sys.stderr.write("Expected %s\n" % (expected,))
    sys.stderr.write("Got %s\n" % (visited,))
    return False


def test_str():
    """
    Test formatting a binary_tree_dict tree as a string.

    Keys 1..9 are inserted in ascending order, one lookup is performed, and
    the tree is converted with str().  The number of newlines in the result
    must lie between round(log2(len)) and len, inclusive.

    NOTE(review): the bounds are named as depths, so the rendering presumably
    has about one line per tree level - confirm against
    BinaryTreeDict.__str__.

    Returns True if the count is in range, False (with stderr diagnostics)
    otherwise.
    """
    all_good = True

    binary_tree_dict = binary_tree_dict_mod.BinaryTreeDict()
    binary_tree_dict[1] = "abc"
    binary_tree_dict[2] = "def"
    binary_tree_dict[3] = "ghi"
    binary_tree_dict[4] = "jkl"
    binary_tree_dict[5] = "mno"
    binary_tree_dict[6] = "pqr"
    binary_tree_dict[7] = "stu"
    binary_tree_dict[8] = "vwx"
    binary_tree_dict[9] = "yz"

    make_used(binary_tree_dict[3])

    string = str(binary_tree_dict)

    count = string.count("\n")
    len_binary_tree_dict = len(binary_tree_dict)
    maximum_allowable_depth = len_binary_tree_dict

    minimum_allowable_depth = int(round(math.log(len_binary_tree_dict, 2.0)))

    # Fail when count is OUTSIDE [minimum, maximum].  The earlier form,
    # "minimum >= count >= maximum", could never be true because the minimum
    # is below the maximum, so this test was unable to fail.
    if not minimum_allowable_depth <= count <= maximum_allowable_depth:
        sys.stderr.write(
            "%s: test_str: bad number of newlines: %d\n" % (sys.argv[0], count)
        )
        sys.stderr.write("%s\n" % count)
        sys.stderr.write("%s\n" % (string,))
        all_good = False

    return all_good


def test_iterator():
    """
    Test iterating over the values of the entire binary_tree_dict tree.

    Keys 1..9 are inserted out of order; values() must then yield the values
    in ascending key order.

    Returns True on success, False (with a stderr diagnostic) otherwise.
    """
    tree = binary_tree_dict_mod.BinaryTreeDict()
    pairs = (
        (2, "def"),
        (3, "ghi"),
        (1, "abc"),
        (4, "jkl"),
        (5, "mno"),
        (7, "stu"),
        (8, "vwx"),
        (9, "yz"),
        (6, "pqr"),
    )
    for key, value in pairs:
        tree[key] = value

    expected = ["abc", "def", "ghi", "jkl", "mno", "pqr", "stu", "vwx", "yz"]
    actual = list(tree.values())

    if expected == actual:
        return True

    sys.stderr.write(
        "%s: test_iterator: values did not come back in correct order\n"
        % sys.argv[0]
    )
    return False


def test_sequential():
    """
    Insert 2000 ascending keys and make sure __setitem__ does not blow up.

    Returns True if every insert succeeds; if a RuntimeError is raised, a
    diagnostic is written to stderr and False is returned.
    """
    tree = binary_tree_dict_mod.BinaryTreeDict()

    # NOTE(review): the count is presumably chosen to exceed the interpreter's
    # recursion limit (see sys.getrecursionlimit()) - confirm.
    top = 2000
    try:
        for index in range(top):
            tree[index] = 1
    except RuntimeError:
        sys.stderr.write("%s: Stack blown on __setitem__\n" % (sys.argv[0],))
        return False

    return True


def test_duplication():
    """
    Test inserting duplicate keys.

    Keys 0..19 are inserted twice, each time in a different shuffled order,
    with 2 ** key as the value.  Afterwards the tree must hold exactly 20
    entries.

    Returns True if the length is 20, False (with a stderr diagnostic)
    otherwise.
    """
    all_good = True

    binary_tree_dict = binary_tree_dict_mod.BinaryTreeDict()

    # Seed the generator, as the other randomized tests in this file do, so
    # that a failure can be reproduced with the same insertion orders.
    random.seed(0)

    list_ = list(range(20))
    random.shuffle(list_)
    for number in list_:
        binary_tree_dict[number] = 2 ** number

    # Second pass: same keys again, in a different order
    random.shuffle(list_)
    for number in list_:
        binary_tree_dict[number] = 2 ** number

    if len(binary_tree_dict) != 20:
        sys.stderr.write(
            "%s: number of elements is not 20: %s\n"
            % (sys.argv[0], len(binary_tree_dict))
        )
        all_good = False

    return all_good


def main():
    """
    Run the tests in a fixed order and exit with status 1 on any failure.

    Every test is run even if an earlier one failed, so all diagnostics show
    up in a single run.
    """
    # NOTE(review): test_none_removal is defined in this file but is not in
    # this list - confirm whether that is intentional.
    tests = (
        test_depth,
        test_insert,
        test_min_max,
        test_large_inserts,
        test_nonempty,
        test_empty,
        test_values,
        test_inorder_traversal,
        test_str,
        test_iterator,
        test_all_removal,
        test_two_3rd_removal,
        test_one_removal,
        test_sequential,
        test_duplication,
    )

    all_good = True
    for test in tests:
        # &= rather than "and": the test is called before its result is combined
        all_good &= test()

    if all_good:
        return

    sys.stderr.write("%s: One or more tests failed\n" % sys.argv[0])
    sys.exit(1)


main()