#!/usr/bin/python

"""Unit tests for splay_mod."""

from __future__ import print_function

import sys
import math
import random

import splay_mod


def make_used(*var):
    """Reference the given arguments so linters stop reporting them as unused.

    Does nothing at runtime and always returns None.
    """
    _ = var


def test_insert():
    """Insert 100 shuffled keys into a splay tree and look each one up again.

    Every key k is stored with the value str(k).

    Returns:
        True if every lookup returned the stored value, False otherwise
        (one stderr line is written per mismatch).
    """
    shuffled = list(range(100))
    random.seed(0)
    random.shuffle(shuffled)

    tree = splay_mod.Splay()
    for k in shuffled:
        tree.insert(k, str(k))

    ok = True
    for k in shuffled:
        if str(k) == tree.find(k):
            continue
        ok = False
        sys.stderr.write('%s: test_insert: Found mismatched key\n' % sys.argv[0])

    return ok


def test_remove():
    """Fill a splay tree, delete every key, and confirm the tree ends up empty.

    After each `del`, indexing the deleted key must raise KeyError. After all
    deletions, the tree must be falsy.

    Returns:
        True if all removals behaved correctly, False otherwise (details are
        written to stderr).
    """
    order = list(range(100))
    random.seed(1)
    random.shuffle(order)

    tree = splay_mod.Splay()
    for k in order:
        tree.insert(k, str(k))

    ok = True

    # Delete in a different random order from the one used for insertion.
    random.shuffle(order)
    for k in order:
        del tree[k]
        try:
            make_used(tree[k])
        except KeyError:
            continue
        ok = False
        sys.stderr.write('%s: test_remove: element not removed\n' % sys.argv[0])

    if tree:
        ok = False
        sys.stderr.write('%s: test_remove: final tree nonempty\n' % sys.argv[0])

    return ok


def test_large_inserts():
    """Insert many keys into a splay tree and check find_min/find_max.

    Keys are successive multiples of 997 taken modulo 100000, stopping when
    the sequence returns to 0. 997 and 100000 are coprime, so every nonzero
    residue below 100000 is inserted once. The main goal is simply to reach
    the end without a traceback.

    Returns:
        True if find_min/find_max match the smallest/largest inserted key,
        False otherwise (details are written to stderr).
    """
    tree = splay_mod.Splay()
    modulus = 100000
    step = 997

    inserted = []
    current = step
    while current != 0:
        tree.insert(current, str(current))
        inserted.append(current)
        current = (current + step) % modulus

    ok = True

    if min(inserted) != tree.find_min():
        sys.stderr.write('%s: Large splay did not return correct minimum\n' % sys.argv[0])
        ok = False

    if max(inserted) != tree.find_max():
        sys.stderr.write('%s: Large splay did not return correct maximum\n' % sys.argv[0])
        ok = False

    return ok


def test_nonempty():
    """Check that a splay tree holding 100 keys evaluates as truthy.

    Returns:
        True if the populated tree is truthy, False otherwise.
    """
    keys = list(range(100))
    random.seed(2)
    random.shuffle(keys)

    tree = splay_mod.Splay()
    for k in keys:
        tree.insert(k, str(k))

    if tree:
        return True

    sys.stderr.write('%s: nonempty splay looks empty\n' % sys.argv[0])
    return False


def test_empty():
    """Check that a freshly constructed splay tree evaluates as falsy.

    Returns:
        True if the new tree is falsy, False otherwise.
    """
    fresh = splay_mod.Splay()
    if not fresh:
        return True

    sys.stderr.write('%s: empty splay looks nonempty\n' % sys.argv[0])
    return False


def test_min_max():
    """Insert keys 0..99 in shuffled order, then verify find_min and find_max.

    Returns:
        True if the minimum is 0 and the maximum is 99, False otherwise
        (details are written to stderr).
    """
    keys = list(range(100))
    random.seed(3)
    random.shuffle(keys)

    tree = splay_mod.Splay()
    for k in keys:
        tree.insert(k, str(k))

    # Checked in this order: minimum first, then maximum.
    checks = (
        (tree.find_min, 0, '%s: minimum was not 0\n'),
        (tree.find_max, 99, '%s: maximum was not 99\n'),
    )

    ok = True
    for finder, wanted, message in checks:
        if finder() != wanted:
            sys.stderr.write(message % sys.argv[0])
            ok = False

    return ok


def test_values():
    """Insert a few key-value pairs and make sure they come back out OK.

    Checks that find_min returns 1, find_max returns 9, and find(5)
    returns 'mno'.

    Returns:
        True if every check passed, False otherwise (details are written
        to stderr).
    """
    all_good = True

    splay = splay_mod.Splay()
    # Inserted deliberately out of key order.
    pairs = (
        (2, 'def'),
        (3, 'ghi'),
        (1, 'abc'),
        (4, 'jkl'),
        (5, 'mno'),
        (7, 'stu'),
        (8, 'vwx'),
        (9, 'yz'),
        (6, 'pqr'),
    )
    for key, value in pairs:
        splay.insert(key, value)

    if splay.find_min() != 1:
        # The smallest inserted key is 1, so the message must say 1 (it used to say 0).
        sys.stderr.write('%s: test_values: Minimum was not 1\n' % sys.argv[0])
        all_good = False

    if splay.find_max() != 9:
        sys.stderr.write('%s: test_values: Maximum was not 9\n' % sys.argv[0])
        all_good = False

    if splay.find(5) != 'mno':
        sys.stderr.write('%s: test_values: Middle was not mno\n' % sys.argv[0])
        all_good = False

    return all_good


def test_inorder_traversal():
    """Test that inorder_traversal visits every (key, value) pair in ascending key order.

    Keys 1, 4, 7, ..., 298 are inserted, each with str(key) as its value.
    The traversal callback records every pair it sees, and the recorded
    list must equal the expected ascending list exactly.

    Returns:
        True if the traversal reproduced the expected list, False otherwise.
    """
    visited = []

    def visit(key, value):
        """Visit a node by appending its (key, value) pair to ``visited``."""
        visited.append((key, value))

    keys = [x * 3 + 1 for x in range(100)]
    items = [(key, str(key)) for key in keys]
    splay = splay_mod.Splay()
    for key in keys:
        splay.insert(key, str(key))

    all_good = True

    splay.inorder_traversal(visit)

    if items != visited:
        # Previously the '%s' was never filled in and the line lacked a newline.
        sys.stderr.write(
            '%s: test_inorder_traversal: inorder_traversal failed to rebuild the list\n'
            % sys.argv[0])
        all_good = False

    return all_good


def test_depth():
    """Check that the depth of a nine-node splay tree falls in a plausible range.

    The accepted range is floor(log2(9)) through 9 inclusive. This is loose
    enough that it does not depend on whether depth counts nodes or edges
    along the longest path (that convention isn't visible from here).

    Returns:
        True if the reported depth is within range, False otherwise.
    """
    tree = splay_mod.Splay()
    for key, value in [(2, 'def'), (3, 'ghi'), (1, 'abc'), (4, 'jkl'), (5, 'mno'),
                       (7, 'stu'), (8, 'vwx'), (9, 'yz'), (6, 'pqr')]:
        tree.insert(key, value)

    measured = tree.depth()
    lower_bound = math.floor(math.log(9) / math.log(2))
    if not (lower_bound <= measured <= 9):
        sys.stderr.write('%s: test_depth: Bogus depth: %d\n' % (sys.argv[0], measured))
        return False

    return True


def test_str():
    """Check that str() of a six-node splay tree contains exactly two newlines.

    Key 3 is looked up before formatting. Presumably this is done so the tree
    has been reshaped by a splay first.

    Returns:
        True if the string has exactly two newline characters, False otherwise.
    """
    tree = splay_mod.Splay()
    for key, value in zip(range(1, 7), ('abc', 'def', 'ghi', 'jkl', 'mno', 'pqr')):
        tree.insert(key, value)

    make_used(tree.find(3))

    newline_count = str(tree).count('\n')
    # The output is three lines, but the last one has no trailing newline so
    # the string can be passed straight to print.
    if newline_count == 2:
        return True

    sys.stderr.write('%s: test_str: bad number of newlines: %d\n' % (sys.argv[0], newline_count))
    return False


def test_iterator():
    """Test iterating over the entire splay tree via values().

    Nine pairs are inserted out of key order, and values() must yield their
    values in ascending key order.

    Returns:
        True if the yielded values match the expected sequence, False otherwise.
    """
    tree = splay_mod.Splay()
    for key, value in ((2, 'def'), (3, 'ghi'), (1, 'abc'), (4, 'jkl'), (5, 'mno'),
                       (7, 'stu'), (8, 'vwx'), (9, 'yz'), (6, 'pqr')):
        tree.insert(key, value)

    expected = ['abc', 'def', 'ghi', 'jkl', 'mno', 'pqr', 'stu', 'vwx', 'yz']
    observed = [value for value in tree.values()]

    if expected != observed:
        sys.stderr.write('%s: test_iterator: values incorrect\n' % sys.argv[0])
        return False

    return True


def test_replace():
    """Test that assigning to an existing key replaces its value.

    Nine pairs are stored via item assignment, then key 4 is reassigned.
    items() must then yield all nine pairs in key order, with the new value
    for key 4.

    Returns:
        True if items() matches the expected list, False otherwise. On failure,
        the tree is also printed to stdout.
    """
    tree = splay_mod.Splay()
    for key, value in ((2, 'def'), (3, 'ghi'), (1, 'abc'), (4, 'jkl'), (5, 'mno'),
                       (7, 'stu'), (8, 'vwx'), (9, 'yz'), (6, 'pqr')):
        tree[key] = value

    # Overwrite an existing key; every other pair must be left untouched.
    tree[4] = 'rep'

    expected = [(1, 'abc'), (2, 'def'), (3, 'ghi'), (4, 'rep'), (5, 'mno'),
                (6, 'pqr'), (7, 'stu'), (8, 'vwx'), (9, 'yz')]
    observed = [pair for pair in tree.items()]

    if expected != observed:
        sys.stderr.write('%s: test_replace: items incorrect: %s\n' % (sys.argv[0], observed))
        print(tree)
        return False

    return True


def main():
    """Run every test function and report the overall result.

    Every test runs even after an earlier one fails. On any failure, a
    summary line is written to stderr and the process exits with status 1.
    Otherwise 'All tests passed.' is printed.
    """
    suite = (
        test_insert,
        test_remove,
        test_large_inserts,
        test_nonempty,
        test_empty,
        test_min_max,
        test_values,
        test_inorder_traversal,
        test_depth,
        test_str,
        test_iterator,
        test_replace,
    )

    all_good = True
    for run_test in suite:
        # Use &= rather than all(...): all() would stop at the first failure
        # and skip the remaining tests.
        all_good &= run_test()

    if all_good:
        print('All tests passed.')
        return

    sys.stderr.write('%s: One or more tests failed\n' % sys.argv[0])
    sys.exit(1)


# Run the suite only when executed as a script. Importing this module (e.g.
# during pytest collection) must not run every test or call sys.exit().
if __name__ == '__main__':
    main()