#!/usr/bin/python

"""Unit test trie_mod."""

from __future__ import print_function

import sys
import pprint

# import treap as treap_mod
import trie_mod


def dump(trie):
    """Print a header, then one ``trie[key] => value`` line per key.

    Iterates *trie* to obtain its keys and indexes it for each value, so any
    mapping-like object supporting iteration and ``[]`` works.
    """
    print("Dumping trie:")
    for each_key in iter(trie):
        each_value = trie[each_key]
        print("  trie[%s] => %s" % (each_key, each_value))


def make_used(variable):
    """Reference *variable* so pylint and pyflakes treat it as used.

    Has no runtime effect and always returns None.
    """
    del variable


def test_basic_assignment():
    """Test simple assignments.

    Stores one value and checks several things: it reads back; a missing key
    that shares a prefix with it ("Food") raises KeyError; len(),
    node_count() and membership reflect the single insertion. It then stores
    None under another key and checks that this key counts as present.

    Returns True when every check passes. Otherwise writes one diagnostic
    line to stderr per failed check and returns False.
    """
    all_good = True

    trie = trie_mod.Trie()
    trie["Foo"] = 101

    if trie["Foo"] != 101:
        sys.stderr.write('%s: test_basic_assignment: trie["Foo"] is not 101\n' % (sys.argv[0], ))
        all_good = False

    try:
        # make_used() only keeps linters from flagging the bare lookup.
        make_used(trie['Food'])
    except KeyError:
        pass
    else:
        sys.stderr.write('%s: test_basic_assignment: trie["Food"] did not raise KeyError\n' % (sys.argv[0], ))
        all_good = False

    if len(trie) != 1:
        sys.stderr.write('%s: test_basic_assignment: len incorrect\n' % (sys.argv[0], ))
        all_good = False

    # "Foo" has three characters; the expected 3 presumably means one node per
    # character, not counting any root -- confirm against trie_mod.
    if trie.node_count() != 3:
        sys.stderr.write('%s: test_basic_assignment: node_count incorrect\n' % (sys.argv[0], ))
        all_good = False

    if 'Foo' not in trie:
        sys.stderr.write('%s: test_basic_assignment: Foo not in trie\n' % (sys.argv[0], ))
        all_good = False

    # A key whose value is None must still be reported as present.
    trie["Bar"] = None

    if 'Bar' not in trie:
        sys.stderr.write('%s: test_basic_assignment: Bar not in trie\n' % (sys.argv[0], ))
        all_good = False

    return all_good


def test_basic_removal():
    """Test deletion.

    Inserts a single key and deletes it. Then checks that lookup raises
    KeyError, the trie is falsy, node_count() is back to 0, and membership
    reports the key absent.

    Returns True when every check passes. Otherwise writes one diagnostic
    line to stderr per failed check and returns False.
    """
    all_good = True

    trie = trie_mod.Trie()

    trie["Foo"] = 101

    # Compare with the stored value instead of testing truthiness, so a
    # wrong (but still truthy) value is reported as well.
    if trie["Foo"] != 101:
        sys.stderr.write('%s: test_basic_removal: trie["Foo"] is not 101\n' % (sys.argv[0], ))
        all_good = False

    del trie["Foo"]

    try:
        # make_used() only keeps linters from flagging the bare lookup.
        make_used(trie['Foo'])
    except KeyError:
        pass
    else:
        sys.stderr.write('%s: test_basic_removal: Foo in trie\n' % (sys.argv[0], ))
        all_good = False

    # With its only key removed, the trie is expected to be falsy.
    if trie:
        sys.stderr.write('%s: test_basic_removal: len(trie) not 0\n' % (sys.argv[0], ))
        all_good = False

    # Deleting the only key is expected to prune every node it created.
    if trie.node_count():
        sys.stderr.write('%s: test_basic_removal: trie.node_count() not 0\n' % (sys.argv[0], ))
        all_good = False

    if 'Foo' in trie:
        sys.stderr.write('%s: test_basic_removal: Foo in trie\n' % (sys.argv[0], ))
        all_good = False

    return all_good


def test_mixed_types():
    """Check that a string key and a list key can live in the same trie.

    Verifies both stored values and both memberships, then deletes the list
    key and verifies it is gone. This mix may be problematic on Python 3.x.

    Returns True if every check passed. Otherwise writes a diagnostic for each
    failure to stderr and returns False.
    """
    def complain(template, *extra):
        """Write one formatted diagnostic to stderr and signal failure."""
        sys.stderr.write(template % ((sys.argv[0], ) + extra))
        return False

    ok = True

    mixed = trie_mod.Trie()
    mixed["Foo"] = 101
    list_key = [1, 2, 3]
    mixed[list_key] = 101

    if mixed["Foo"] != 101:
        ok = complain('%s: test_mixed_types: Foo in trie but wrong value\n')

    if mixed[list_key] != 101:
        ok = complain('%s: test_mixed_types: %s in trie but incorrect value\n', list_key)

    if list_key not in mixed:
        ok = complain('%s: test_mixed_types: %s not in trie\n', list_key)

    if 'Foo' not in mixed:
        ok = complain('%s: test_mixed_types: Foo not in trie\n')

    del mixed[list_key]

    if list_key in mixed:
        ok = complain('%s: test_mixed_types: %s still in trie\n', list_key)

    return ok


def test_iteration():
    """Test iterating over a trie.

    Each key yielded by iteration must be a member and must map to the stored
    value. In addition, iteration must yield exactly the inserted keys.
    Without that last check, an iterator that yields nothing (or skips keys)
    would pass vacuously.

    Returns True when every check passes. Otherwise writes one diagnostic
    line to stderr per failed check and returns False.
    """
    all_good = True
    inserted = ["Foo", "Bar", "Grok"]
    trie = trie_mod.Trie()
    for key in inserted:
        trie[key] = 101

    seen = []
    for key in trie:
        seen.append(key)

        if key not in trie:
            sys.stderr.write('%s: test_iteration: key in and not in trie\n' % (sys.argv[0], ))
            all_good = False

        if trie[key] != 101:
            sys.stderr.write('%s: test_iteration: key in trie but wrong value\n' % (sys.argv[0], ))
            all_good = False

    # Sort both sides so the comparison does not depend on iteration order.
    if sorted(seen) != sorted(inserted):
        sys.stderr.write('%s: test_iteration: iterated keys %s != inserted keys %s\n'
                         % (sys.argv[0], sorted(seen), sorted(inserted)))
        all_good = False

    return all_good


def test_addition():
    """Test adding two tries.

    ``trie2 + trie`` must produce a trie holding the keys of both operands.
    Each operand must keep only its own keys, because addition must not
    mutate its inputs.

    Returns True when every check passes. Otherwise writes one diagnostic
    line to stderr per failed check and returns False.
    """
    all_good = True
    trie = trie_mod.Trie()
    trie["Foo"] = 101
    trie2 = trie_mod.Trie()
    trie2["Food"] = 101
    trie3 = trie2 + trie

    if "Foo" not in trie:
        sys.stderr.write('%s: test_addition: Foo not in trie\n' % (sys.argv[0], ))
        all_good = False

    # Fires when the right operand's key leaked into the left operand.
    if "Food" in trie:
        sys.stderr.write('%s: test_addition: Food in trie\n' % (sys.argv[0], ))
        all_good = False

    if "Food" not in trie2:
        sys.stderr.write('%s: test_addition: Food not in trie2\n' % (sys.argv[0], ))
        all_good = False

    if "Foo" in trie2:
        sys.stderr.write('%s: test_addition: Foo in trie2\n' % (sys.argv[0], ))
        all_good = False

    if "Foo" not in trie3:
        sys.stderr.write('%s: test_addition: Foo not in trie3\n' % (sys.argv[0], ))
        all_good = False

    if "Food" not in trie3:
        sys.stderr.write('%s: test_addition: Food not in trie3\n' % (sys.argv[0], ))
        all_good = False

    return all_good


def test_subtraction():
    """Check the ``-`` operator on tries.

    ``trie - trie2`` must keep only the keys absent from ``trie2`` ("Foo").
    ``trie2 - trie`` must hold neither key. Both operands must still hold
    their original keys afterwards.

    Returns True when all checks pass. Each failed check writes a line to
    stderr and makes the result False.
    """
    def fail(template):
        """Emit one diagnostic line and signal failure."""
        sys.stderr.write(template % (sys.argv[0], ))
        return False

    left = trie_mod.Trie()
    left["Food"] = 101
    left["Foo"] = 101
    right = trie_mod.Trie()
    right["Food"] = 101
    difference = left - right
    reverse_difference = right - left

    passed = True

    # The operands themselves must be unchanged.
    if "Food" not in left:
        passed = fail('%s: test_subtraction: Food not in trie\n')
    if "Foo" not in left:
        passed = fail('%s: test_subtraction: Foo not in trie\n')
    if "Food" not in right:
        passed = fail('%s: test_subtraction: Food not in trie2\n')

    # left - right
    if "Foo" not in difference:
        passed = fail('%s: test_subtraction: Foo not in trie3\n')
    if "Food" in difference:
        passed = fail('%s: test_subtraction: Food in trie3\n')

    # right - left
    if "Foo" in reverse_difference:
        passed = fail('%s: test_subtraction: Foo in trie4\n')
    if "Food" in reverse_difference:
        passed = fail('%s: test_subtraction: Food in trie4\n')

    return passed


def test_self_add():
    """Test iadd.

    Checks the two operands' contents first. After ``trie += trie2``, checks
    that ``trie`` holds both keys and that ``trie2`` is left untouched.

    Returns True when every check passes. Otherwise writes one diagnostic
    line to stderr per failed check and returns False.
    """
    all_good = True

    trie = trie_mod.Trie()
    trie["Foo"] = 101
    trie2 = trie_mod.Trie()
    trie2["Food"] = 101

    if "Foo" not in trie:
        sys.stderr.write('%s: test_self_add: Foo not in trie\n' % (sys.argv[0], ))
        all_good = False

    if "Food" in trie:
        sys.stderr.write('%s: test_self_add: Food in trie\n' % (sys.argv[0], ))
        all_good = False

    if "Food" not in trie2:
        sys.stderr.write('%s: test_self_add: Food not in trie2\n' % (sys.argv[0], ))
        all_good = False

    if "Foo" in trie2:
        sys.stderr.write('%s: test_self_add: Foo in trie2\n' % (sys.argv[0], ))
        all_good = False

    trie += trie2

    if "Foo" not in trie:
        sys.stderr.write('%s: test_self_add: Foo not in trie\n' % (sys.argv[0], ))
        all_good = False

    if "Food" not in trie:
        sys.stderr.write('%s: test_self_add: Food not in trie\n' % (sys.argv[0], ))
        all_good = False

    # In-place addition may only modify the left operand (mirrors test_self_sub).
    if "Food" not in trie2:
        sys.stderr.write('%s: test_self_add: Food not in trie2 after +=\n' % (sys.argv[0], ))
        all_good = False

    if "Foo" in trie2:
        sys.stderr.write('%s: test_self_add: Foo in trie2 after +=\n' % (sys.argv[0], ))
        all_good = False

    return all_good


def test_self_sub():
    """Check in-place subtraction (``-=``) between tries.

    Verifies the operands' starting contents. After ``trie -= trie2``,
    verifies that the shared key is gone from ``trie``, its other key
    survives, and ``trie2`` still holds its key.

    Returns True when all checks pass. Otherwise writes a diagnostic per
    failure to stderr and returns False.
    """
    def fail(template):
        """Emit one diagnostic line and signal failure."""
        sys.stderr.write(template % (sys.argv[0], ))
        return False

    passed = True

    trie = trie_mod.Trie()
    trie["Foo"] = 101
    trie["Food"] = 101
    removal = trie_mod.Trie()
    removal["Food"] = 101

    # Preconditions.
    if "Food" not in trie:
        passed = fail('%s: test_self_sub: Food not in trie\n')
    if "Foo" not in trie:
        passed = fail('%s: test_self_sub: Foo not in trie\n')
    if "Food" not in removal:
        passed = fail('%s: test_self_sub: Food not in trie2\n')

    trie -= removal

    # Postconditions.
    if "Food" in trie:
        passed = fail('%s: test_self_sub: Food in trie\n')
    if "Foo" not in trie:
        passed = fail('%s: test_self_sub: Foo not in trie\n')
    if 'Food' not in removal:
        passed = fail('%s: test_self_sub: Food not in trie2\n')

    return passed


def test_self_get():
    """Exercise item access and ``Trie.get()``.

    Covers:

    - a present key;
    - a missing key, which must raise KeyError from ``[]``;
    - ``get()`` with a positional and with a keyword default;
    - ``get()`` on a present key;
    - ``get()`` returning None when no default is given.

    Returns True when every check passes. Otherwise writes a diagnostic for
    each failure to stderr and returns False.
    """
    def report(template):
        """Emit one diagnostic line and signal failure."""
        sys.stderr.write(template % (sys.argv[0], ))
        return False

    outcome = True

    trie = trie_mod.Trie()
    trie["Foo"] = 101

    if trie["Foo"] != 101:
        outcome = report('%s: test_self_get: trie["Foo"] not 101\n')

    try:
        make_used(trie['Food'])
    except KeyError:
        pass
    else:
        outcome = report('%s: test_self_get: trie["Food"] did not raise KeyError\n')

    if trie.get("Food", "Bar") != "Bar":
        outcome = report('%s: test_self_get: trie.get("Food", "Bar") did not return "Bar"\n')

    if trie.get("Food", default="Bar") != "Bar":
        outcome = report('%s: test_self_get: trie.get("Food", "Bar") did not return "Bar" (2)\n')

    if trie.get("Foo") != 101:
        outcome = report('%s: test_self_get: trie.get("Foo") not 101\n')

    if trie.get("Food") is not None:
        outcome = report('%s: test_self_get: trie.get("Food") not None\n')

    return outcome


def test_keys_by_prefix():
    """Check ``Trie.keys()`` with and without a prefix argument.

    With no argument, every inserted key is expected. With "Foo", the keys
    "Foo" and "Food" are expected but not "Eggs". With "Ox", which no key
    starts with, the result is expected to be empty/falsy.

    Returns True when every check passes. Otherwise writes a diagnostic for
    each failure to stderr and returns False.
    """
    def fail(template):
        """Emit one diagnostic line and signal failure."""
        sys.stderr.write(template % (sys.argv[0], ))
        return False

    trie = trie_mod.Trie()
    for word in ("Foo", "Food", "Eggs"):
        trie[word] = 101

    passed = True

    everything = trie.keys()
    for word, template in (
            ("Foo", '%s: test_keys_by_prefix: "Foo" not in kset\n'),
            ("Food", '%s: test_keys_by_prefix: "Food" not in kset\n'),
            ("Eggs", '%s: test_keys_by_prefix: "Eggs" not in kset\n'),
    ):
        if word not in everything:
            passed = fail(template)

    restricted = trie.keys("Foo")
    for word, template in (
            ("Foo", '%s: test_keys_by_prefix: "Foo" not in restricted kset\n'),
            ("Food", '%s: test_keys_by_prefix: "Food" not in restricted kset\n'),
    ):
        if word not in restricted:
            passed = fail(template)

    if "Eggs" in restricted:
        passed = fail('%s: test_keys_by_prefix: "Eggs" in restricted kset\n')

    if trie.keys("Ox"):
        passed = fail('%s: test_keys_by_prefix: len(kset) (very restricted) not 0\n')

    return passed


def test_list_keys():
    """Check that iterating a list-keyed trie yields exactly the inserted lists.

    Compares the sorted iteration result with the sorted inserted keys. On a
    mismatch, pretty-prints both to stderr and returns False. Otherwise
    returns True.
    """
    expected = sorted([[1, 2, 3], [5, 10, 15, 20], [2, 4, 8]])

    trie = trie_mod.Trie()
    for sequence in expected:
        trie[sequence] = True

    actual = sorted(trie)

    if expected != actual:
        sys.stderr.write('%s: test_list_keys: list_ != keys\n' % (sys.argv[0], ))
        sys.stderr.write(pprint.pformat(expected))
        sys.stderr.write('\n')
        sys.stderr.write(pprint.pformat(actual))
        sys.stderr.write('\n')
        return False

    return True


def test_list_ordered_keys():
    """Test getting ordered keys.

    Inserts list keys in a deliberately unsorted order and expects
    ``in_order()`` to yield them sorted. Pre-sorting the input would let an
    implementation that merely replays insertion order pass.

    Returns True on success. Otherwise pretty-prints the expected and actual
    key lists to stderr and returns False.
    """
    all_good = True

    trie = trie_mod.Trie()

    # Insertion order differs from sorted order ([2, 4, 8] < [5, 10, 15, 20]).
    list_ = [[1, 2, 3], [5, 10, 15, 20], [2, 4, 8]]

    for sublist in list_:
        trie[sublist] = True
    expected = sorted(list_)
    keys = list(trie.in_order())

    if expected != keys:
        sys.stderr.write('%s: test_list_ordered_keys: list_ != keys\n' % (sys.argv[0], ))
        sys.stderr.write(pprint.pformat(expected))
        sys.stderr.write('\n')
        sys.stderr.write(pprint.pformat(keys))
        sys.stderr.write('\n')
        all_good = False

    return all_good


def test_empty():
    """Verify that a freshly constructed trie is falsy.

    Returns True if it is. Otherwise writes a diagnostic to stderr and
    returns False.
    """
    if not trie_mod.Trie():
        return True
    sys.stderr.write('%s: test_empty: trie not False\n' % (sys.argv[0], ))
    return False


def test_nonempty():
    """Test a non-empty trie for truthiness.

    Returns True when a trie holding one key is truthy. Otherwise writes a
    diagnostic to stderr and returns False.
    """
    all_good = True
    trie = trie_mod.Trie()
    trie['abc'] = 'def'
    if not trie:
        sys.stderr.write('%s: test_nonempty: trie not True\n' % (sys.argv[0], ))
        all_good = False

    return all_good


def test_find_min_strings():
    """Test find_min method with string keys.

    Keys are inserted unsorted, with the minimum ('ab') inserted last. This
    way an implementation that simply returns the first-inserted key fails.

    Returns True on success. Otherwise writes a diagnostic to stderr and
    returns False.
    """
    all_good = True

    trie = trie_mod.Trie()

    # Deliberately NOT sorted: the minimum must be found, not remembered.
    list_ = ['abc', 'def', 'ab']

    for key in list_:
        trie[key] = True

    if trie.find_min() != 'ab':
        sys.stderr.write('%s: test_find_min_strings: min != ab\n' % (sys.argv[0], ))
        all_good = False

    return all_good


def test_find_min_lists():
    """Check find_min() on a trie keyed by integer lists.

    The expected minimum is Python's own ordering of the inserted lists.
    Returns True when find_min() agrees with it. Otherwise writes both values
    to stderr and returns False.
    """
    candidates = [[3, 4, 5], [1, 2, 3], [2, 4, 6]]

    trie = trie_mod.Trie()
    for candidate in candidates:
        trie[candidate] = True

    expected = min(candidates)
    actual = trie.find_min()

    if expected != actual:
        sys.stderr.write('%s: test_find_min_lists: expected: %s != actual: %s\n' % (sys.argv[0], expected, actual))
        return False
    return True


def test_find_max_strings():
    """Test find_max method with string keys.

    Keys are inserted unsorted, with the maximum ('def') in the middle. This
    way an implementation that simply returns the last-inserted key fails.

    Returns True on success. Otherwise writes a diagnostic to stderr and
    returns False.
    """
    all_good = True

    trie = trie_mod.Trie()

    # Deliberately NOT sorted: the maximum must be found, not remembered.
    list_ = ['abc', 'def', 'ab']

    for key in list_:
        trie[key] = True

    if trie.find_max() != 'def':
        sys.stderr.write('%s: test_find_max_strings: max != def\n' % (sys.argv[0], ))
        all_good = False

    return all_good


def test_find_max_lists():
    """Test find_max method with list keys.

    The expected maximum is Python's own ordering of the inserted lists.
    Returns True when find_max() agrees with it. Otherwise writes both values
    to stderr and returns False.
    """
    all_good = True

    trie = trie_mod.Trie()

    list_ = [[3, 4, 5], [1, 2, 3], [2, 4, 6]]

    for sublist in list_:
        trie[sublist] = True

    maximum_expected = max(list_)

    maximum_actual = trie.find_max()

    if maximum_expected != maximum_actual:
        sys.stderr.write('%s: test_find_max_lists: expected: %s != actual: %s\n' % (sys.argv[0], maximum_expected, maximum_actual))
        all_good = False

    return all_good


def main():
    """Run every test in a fixed order, then exit the process.

    Every test runs even after an earlier one fails. Exits with status 0 when
    all tests pass. Otherwise writes a summary line to stderr and exits with
    status 1.
    """
    tests = (
        test_basic_assignment,
        test_basic_removal,
        test_mixed_types,
        test_iteration,
        test_addition,
        test_subtraction,
        test_self_add,
        test_self_sub,
        test_self_get,
        test_keys_by_prefix,
        test_list_keys,
        test_list_ordered_keys,
        test_empty,
        test_nonempty,
        test_find_min_strings,
        test_find_min_lists,
        test_find_max_strings,
        test_find_max_lists,
    )

    failures = 0
    for test in tests:
        if not test():
            failures += 1

    if failures:
        sys.stderr.write('%s: One or more tests failed\n' % (sys.argv[0], ))
        sys.exit(1)
    sys.exit(0)


# Only run the suite (and call sys.exit) when executed as a script, so that
# importing this module has no side effects.
if __name__ == '__main__':
    main()