#!/usr/local/cpython-3.3/bin/python

'''Unit tests for scapegoat_mod'''

import sys
import math

import scapegoat_mod


# This is intentionally an invalid alpha, to make sure it gets overridden
# later: main() replaces it with the value given as sys.argv[1] before any
# test runs.
ALPHA = 0.0

def test_depth():
    '''Test that a scapegoat tree's depth respects the alpha-balance bound.

    Inserts nine keys in a scrambled order, then checks that
    scapegoat.depth() does not exceed log base (1 / alpha) of the number of
    entries, where alpha is read back from the tree.

    Returns True if the depth is within the bound; otherwise writes a
    diagnostic to stderr and returns False.
    '''
    all_good = True

    scapegoat = scapegoat_mod.ScapeGoatTree(ALPHA)
    scapegoat[2] = 'def'
    scapegoat[3] = 'ghi'
    scapegoat[1] = 'abc'
    scapegoat[4] = 'jkl'
    scapegoat[5] = 'mno'
    scapegoat[7] = 'stu'
    scapegoat[8] = 'vwx'
    scapegoat[9] = 'yz'
    scapegoat[6] = 'pqr'

    depth = scapegoat.depth()
    len_scapegoat = len(scapegoat)
    # Read alpha back from the tree so the bound reflects the value the tree
    # was actually built with.
    base = 1.0 / scapegoat.alpha
    maximum_allowable_depth = math.log(len_scapegoat, base)

    if depth > maximum_allowable_depth:
        sys.stderr.write('%s: test_depth: Bogus depth: %d\n' % (sys.argv[0], depth))
        all_good = False

    return all_good

def test_insert():
    '''Insert the keys 0-9, then verify that each maps back to its string form.

    Returns True when every lookup matches. Each mismatch is reported on
    stderr and makes the result False.
    '''
    keys = list(range(10))
    tree = scapegoat_mod.ScapeGoatTree(ALPHA)
    for key in keys:
        tree[key] = str(key)

    ok = True
    for key in keys:
        expected = str(key)
        if expected == tree[key]:
            continue
        ok = False
        sys.stderr.write('%s: test_insert: Found mismatched key: Got %s, expected %s\n' % (
            sys.argv[0], tree[key], expected))

    return ok

def test_remove():
    '''
    Fill a scapegoat tree with the keys 0-9, then delete them one at a time.
    After each deletion, check that looking the key up raises KeyError.
    Finally, confirm the emptied tree is falsy.
    '''
    tree = scapegoat_mod.ScapeGoatTree(ALPHA)
    keys = list(range(10))
    for key in keys:
        tree[key] = str(key)

    ok = True
    for key in keys:
        del tree[key]
        try:
            tree[key]
        except KeyError:
            # Expected: the key is gone.
            continue
        ok = False
        sys.stderr.write('%s: test_remove: element not removed\n' % sys.argv[0])

    if tree:
        ok = False
        sys.stderr.write('%s: test_remove: final tree nonempty\n' % sys.argv[0])

    return ok

def test_large_inserts():
    '''Stress-test insertion with many keys, then check find_min/find_max.

    Starting at 997, keys are generated by repeatedly stepping
    key = (key + 997) % 100000 until the value wraps to 0. The smallest and
    largest keys seen are tracked and compared with what the tree reports.
    '''
    modulus = 100000
    step = 997
    tree = scapegoat_mod.ScapeGoatTree(ALPHA)

    smallest = largest = None
    key = step
    while key:
        smallest = key if smallest is None else min(smallest, key)
        largest = key if largest is None else max(largest, key)
        tree[key] = str(key)
        key = (key + step) % modulus

    ok = True
    if smallest != tree.find_min():
        sys.stderr.write('%s: Large scapegoat did not return correct minimum\n' % sys.argv[0])
        ok = False
    if largest != tree.find_max():
        sys.stderr.write('%s: Large scapegoat did not return correct maximum\n' % sys.argv[0])
        ok = False

    return ok

def test_nonempty():
    '''Check that a scapegoat tree holding ten entries is truthy.'''
    tree = scapegoat_mod.ScapeGoatTree(ALPHA)
    for key in range(10):
        tree[key] = str(key)

    if tree:
        return True
    sys.stderr.write('%s: nonempty scapegoat looks empty\n' % sys.argv[0])
    return False

def test_empty():
    '''Check that a freshly constructed scapegoat tree is falsy.'''
    tree = scapegoat_mod.ScapeGoatTree(ALPHA)

    if not tree:
        return True
    sys.stderr.write('%s: empty scapegoat looks nonempty\n' % sys.argv[0])
    return False

def test_min_max():
    '''Fill a tree with the keys 0-9 and check that find_min() is 0 and find_max() is 9.'''
    tree = scapegoat_mod.ScapeGoatTree(ALPHA)
    for key in range(10):
        tree[key] = str(key)

    min_ok = tree.find_min() == 0
    if not min_ok:
        sys.stderr.write('%s: minimum was not 0\n' % sys.argv[0])

    max_ok = tree.find_max() == 9
    if not max_ok:
        sys.stderr.write('%s: maximum was not 9\n' % sys.argv[0])

    return min_ok and max_ok

def test_values():
    '''Insert a few key-value pairs, and make sure they come back out OK.

    Inserts keys 1-9 in a scrambled order. It then checks that the minimum
    key is 1, the maximum key is 9, and key 5 maps to 'mno'.

    Returns True if every check passes; otherwise writes one diagnostic per
    failed check to stderr and returns False.
    '''

    all_good = True

    scapegoat = scapegoat_mod.ScapeGoatTree(ALPHA)
    scapegoat[2] = 'def'
    scapegoat[3] = 'ghi'
    scapegoat[1] = 'abc'
    scapegoat[4] = 'jkl'
    scapegoat[5] = 'mno'
    scapegoat[7] = 'stu'
    scapegoat[8] = 'vwx'
    scapegoat[9] = 'yz'
    scapegoat[6] = 'pqr'

    # The smallest key inserted above is 1 (there is no key 0 here).
    if scapegoat.find_min() != 1:
        sys.stderr.write('%s: test_values: Minimum was not 1\n' % sys.argv[0])
        all_good = False

    if scapegoat.find_max() != 9:
        sys.stderr.write('%s: test_values: Maximum was not 9\n' % sys.argv[0])
        all_good = False

    if scapegoat[5] != 'mno':
        sys.stderr.write('%s: test_values: Middle was not mno\n' % sys.argv[0])
        all_good = False

    return all_good

def test_inorder_traversal():
    '''Check that inorder_traversal visits (key, value) pairs in ascending key order.'''

    visited = []

    def visit(key, value):
        '''Record a visited node by appending its (key, value) pair to the list.'''
        visited.append((key, value))

    # 1, 4, 7, ..., 28
    keys = list(range(1, 30, 3))
    expected = [(key, str(key)) for key in keys]

    tree = scapegoat_mod.ScapeGoatTree(ALPHA)
    for key, value in expected:
        tree[key] = value

    tree.inorder_traversal(visit)

    if visited == expected:
        return True

    sys.stderr.write('%s: test_inorder_traversal: inorder_traversal failed to rebuild the list\n' % (sys.argv[0], ))
    sys.stderr.write('Expected %s\n' % (expected, ))
    sys.stderr.write('Got %s\n' % (visited, ))
    return False

def test_str():
    '''Test formatting a scapegoat tree as a string.

    Builds a nine-node tree and formats it with str(). It then checks that
    the number of newlines in the result does not exceed log base
    (1 / alpha) of the number of entries.

    Returns True on success; otherwise writes the offending count and the
    formatted tree to stderr and returns False.
    '''

    all_good = True

    scapegoat = scapegoat_mod.ScapeGoatTree(ALPHA)
    scapegoat[1] = 'abc'
    scapegoat[2] = 'def'
    scapegoat[3] = 'ghi'
    scapegoat[4] = 'jkl'
    scapegoat[5] = 'mno'
    scapegoat[6] = 'pqr'
    scapegoat[7] = 'stu'
    scapegoat[8] = 'vwx'
    scapegoat[9] = 'yz'

    # NOTE(review): the result of this lookup is unused. It looks like a
    # holdover from a splay-tree version of this test, where a lookup
    # restructures the tree. Kept in case __getitem__ has side effects.
    dummy = scapegoat[3]

    string = str(scapegoat)

    count = string.count('\n')
    len_scapegoat = len(scapegoat)
    base = 1.0 / scapegoat.alpha
    maximum_allowable_depth = math.log(len_scapegoat, base)

    if count > maximum_allowable_depth:
        sys.stderr.write('%s: test_str: bad number of newlines: %d\n' % (sys.argv[0], count))
        sys.stderr.write('%s\n' % (string, ))
        all_good = False

    return all_good

def test_iterator():
    '''Check that iterating over the entire tree's values() yields them in ascending key order.'''
    tree = scapegoat_mod.ScapeGoatTree(ALPHA)
    insertions = (
        (2, 'def'),
        (3, 'ghi'),
        (1, 'abc'),
        (4, 'jkl'),
        (5, 'mno'),
        (7, 'stu'),
        (8, 'vwx'),
        (9, 'yz'),
        (6, 'pqr'),
    )
    for key, value in insertions:
        tree[key] = value

    actual = list(tree.values())
    expected = ['abc', 'def', 'ghi', 'jkl', 'mno', 'pqr', 'stu', 'vwx', 'yz']

    if actual == expected:
        return True
    sys.stderr.write('%s: test_iterator: values did not come back in correct order\n' % sys.argv[0])
    return False


def main():
    # pylint: disable=global-statement
    '''Run every test, using the alpha given as the first command-line argument.

    Usage: <script> ALPHA, where 0.50 < ALPHA < 0.99.

    ALPHA may be missing, not a number, or out of range. In that case a
    diagnostic and usage line are written to stderr and the process exits
    with status 2. If any test fails, exits with status 1.
    '''

    global ALPHA

    # A high alpha value results in fewer balances, making insertion quicker
    # but lookups and deletions slower, and vice versa for a low alpha.
    # Therefore in practical applications, an alpha can be chosen depending
    # on how frequently these actions should be performed.

    # Doesn't work well
    #ALPHA = 0.5001

    #ALPHA = 0.6000
    #ALPHA = 0.7500
    #ALPHA = 0.9000

    # Doesn't work well
    #ALPHA = 0.9999

    usage = 'Usage: %s ALPHA   (0.50 < ALPHA < 0.99)\n' % sys.argv[0]

    if len(sys.argv) < 2:
        sys.stderr.write(usage)
        sys.exit(2)

    try:
        ALPHA = float(sys.argv[1])
    except ValueError:
        sys.stderr.write('%s: ALPHA must be a number, got %r\n' % (sys.argv[0], sys.argv[1]))
        sys.stderr.write(usage)
        sys.exit(2)

    # Validate explicitly rather than with assert, which python -O strips.
    # Also rejects NaN, since every comparison against NaN is False.
    if not 0.50 < ALPHA < 0.99:
        sys.stderr.write('%s: ALPHA out of range: %s\n' % (sys.argv[0], ALPHA))
        sys.stderr.write(usage)
        sys.exit(2)

    # Use &= rather than `and` so every test runs even after a failure.
    all_good = True
    all_good &= test_depth()
    all_good &= test_insert()
    all_good &= test_remove()
    all_good &= test_large_inserts()
    all_good &= test_nonempty()
    all_good &= test_empty()
    all_good &= test_min_max()
    all_good &= test_values()
    all_good &= test_inorder_traversal()
    all_good &= test_str()
    all_good &= test_iterator()

    if not all_good:
        sys.stderr.write('%s: One or more tests failed\n' % sys.argv[0])
        sys.exit(1)

# Only run the suite when executed as a script, not when imported (e.g. by a
# test collector), since main() reads sys.argv[1] and may call sys.exit().
if __name__ == '__main__':
    main()