#!/usr/bin/python

"""Unit test splay_mod."""

from __future__ import print_function

import sys
import math
import random

import splay_mod

# Key/value pairs shared by several tests, listed in the order they are inserted.
_SAMPLE_ITEMS = (
    (2, 'def'),
    (3, 'ghi'),
    (1, 'abc'),
    (4, 'jkl'),
    (5, 'mno'),
    (7, 'stu'),
    (8, 'vwx'),
    (9, 'yz'),
    (6, 'pqr'),
)


def make_used(*var):
    """Persuade linters that var is used."""
    assert True or var


def _complain(message):
    """Write message to stderr, prefixed by the program name and ended with a newline."""
    sys.stderr.write('%s: %s\n' % (sys.argv[0], message))


def _shuffled_splay(seed, count=100):
    """Build a splay tree holding the keys 0..count-1, inserted in a seeded random order.

    Each key is stored with str(key) as its value.

    Returns a (keys, splay) tuple, where keys is the shuffled insertion order.
    """
    keys = list(range(count))
    random.seed(seed)
    random.shuffle(keys)
    splay = splay_mod.Splay()
    for key in keys:
        splay.insert(key, str(key))
    return keys, splay


def _sample_splay():
    """Build a splay tree by inserting _SAMPLE_ITEMS in their listed order."""
    splay = splay_mod.Splay()
    for key, value in _SAMPLE_ITEMS:
        splay.insert(key, value)
    return splay


def test_insert():
    """Insert some values into a splay tree, then make sure they can all be found in the tree.

    Returns True if every key maps back to its value, else False.
    """
    keys, splay = _shuffled_splay(0)
    all_good = True
    for key in keys:
        if str(key) != splay.find(key):
            all_good = False
            _complain('test_insert: Found mismatched key')
    return all_good


def test_remove():
    """Insert some values into a splay tree, then make sure they can all be removed.

    Finally, ensure the splay tree is empty.

    Returns True on success, else False.
    """
    keys, splay = _shuffled_splay(1)
    all_good = True
    # Delete in a different (still seeded) order than the insertion order.
    random.shuffle(keys)
    for key in keys:
        del splay[key]
        try:
            make_used(splay[key])
        except KeyError:
            # Expected: the key was just deleted.
            pass
        else:
            all_good = False
            _complain('test_remove: element not removed')
    if splay:
        all_good = False
        _complain('test_remove: final tree nonempty')
    return all_good


def test_large_inserts():
    """Insert lots of values into a splay tree, just to see if we get a traceback.

    Also checks find_min and find_max against the smallest and largest keys inserted.

    Returns True on success, else False.
    """
    all_good = True
    splay = splay_mod.Splay()
    nums = 100000
    gap = 997
    key = gap
    actual_min = None
    actual_max = None
    # Step through the keys in strides of gap modulo nums, stopping when the
    # sequence wraps around to 0 (0 itself is never inserted).
    while key != 0:
        if actual_min is None or key < actual_min:
            actual_min = key
        if actual_max is None or key > actual_max:
            actual_max = key
        splay.insert(key, str(key))
        key = (key + gap) % nums
    if actual_min != splay.find_min():
        _complain('Large splay did not return correct minimum')
        all_good = False
    if actual_max != splay.find_max():
        _complain('Large splay did not return correct maximum')
        all_good = False
    return all_good


def test_nonempty():
    """Test that a nonempty splay tree is truthy.

    Returns True on success, else False.
    """
    _, splay = _shuffled_splay(2)
    all_good = True
    if not splay:
        all_good = False
        _complain('nonempty splay looks empty')
    return all_good


def test_empty():
    """Test that an empty splay tree is falsy.

    Returns True on success, else False.
    """
    all_good = True
    splay = splay_mod.Splay()
    if splay:
        all_good = False
        _complain('empty splay looks nonempty')
    return all_good


def test_min_max():
    """Insert some values into a splay tree, then test find_min and find_max.

    Returns True on success, else False.
    """
    _, splay = _shuffled_splay(3)
    all_good = True
    if splay.find_min() != 0:
        _complain('minimum was not 0')
        all_good = False
    if splay.find_max() != 99:
        _complain('maximum was not 99')
        all_good = False
    return all_good


def test_values():
    """Insert a few key-value pairs, and make sure they come back out OK.

    Returns True on success, else False.
    """
    all_good = True
    splay = _sample_splay()
    if splay.find_min() != 1:
        # The smallest sample key is 1, so that is what the message reports.
        _complain('test_values: Minimum was not 1')
        all_good = False
    if splay.find_max() != 9:
        _complain('test_values: Maximum was not 9')
        all_good = False
    if splay.find(5) != 'mno':
        _complain('test_values: Middle was not mno')
        all_good = False
    return all_good


def test_inorder_traversal():
    """Test an inorder traversal.

    Returns True if the traversal visits the (key, value) pairs in ascending
    key order, else False.
    """
    visited = []

    def visit(key, value):
        """Visit a node by sticking its key, value into a list."""
        visited.append((key, value))

    keys = [x * 3 + 1 for x in range(100)]
    items = [(key, str(key)) for key in keys]
    splay = splay_mod.Splay()
    for key in keys:
        splay.insert(key, str(key))
    all_good = True
    splay.inorder_traversal(visit)
    if items != visited:
        _complain('test_inorder_traversal: inorder_traversal failed to rebuild the list')
        all_good = False
    return all_good


def test_depth():
    """Test getting the depth of a splay tree.

    The reported depth of the nine-element tree must lie between
    floor(log2(9)) and 9 inclusive.

    Returns True on success, else False.
    """
    all_good = True
    splay = _sample_splay()
    depth = splay.depth()
    min_depth = math.floor(math.log(9) / math.log(2))
    if not min_depth <= depth <= 9:
        _complain('test_depth: Bogus depth: %d' % depth)
        all_good = False
    return all_good


def test_str():
    """Test formatting a splay tree as a string.

    Returns True on success, else False.
    """
    all_good = True
    splay = splay_mod.Splay()
    splay.insert(1, 'abc')
    splay.insert(2, 'def')
    splay.insert(3, 'ghi')
    splay.insert(4, 'jkl')
    splay.insert(5, 'mno')
    splay.insert(6, 'pqr')
    dummy = splay.find(3)
    make_used(dummy)
    string = str(splay)
    count = string.count('\n')
    # It's actually 3 lines, but the 3rd line has no newline so we can feed it directly to print
    if count != 2:
        _complain('test_str: bad number of newlines: %d' % count)
        all_good = False
    return all_good


def test_iterator():
    """Test iterating over the entire splay tree.

    Returns True if values() yields the values in ascending key order, else False.
    """
    all_good = True
    splay = _sample_splay()
    expected = ['abc', 'def', 'ghi', 'jkl', 'mno', 'pqr', 'stu', 'vwx', 'yz']
    actual = list(splay.values())
    if expected != actual:
        _complain('test_iterator: values incorrect')
        all_good = False
    return all_good


def test_replace():
    """Test replacing a value.

    Stores the sample pairs via item assignment, overwrites the value for
    key 4, then checks what items() yields.

    Returns True on success, else False.
    """
    all_good = True
    splay = splay_mod.Splay()
    for key, value in _SAMPLE_ITEMS:
        splay[key] = value
    splay[4] = 'rep'
    expected = [
        (1, 'abc'),
        (2, 'def'),
        (3, 'ghi'),
        (4, 'rep'),
        (5, 'mno'),
        (6, 'pqr'),
        (7, 'stu'),
        (8, 'vwx'),
        (9, 'yz'),
    ]
    actual = list(splay.items())
    if expected != actual:
        _complain('test_replace: items incorrect: %s' % (actual,))
        print(splay)
        all_good = False
    return all_good


def main():
    """Run tests (main function).

    Every test is run even if an earlier one fails.  Exits with status 1 if
    any test failed; otherwise prints a success message.
    """
    tests = (
        test_insert,
        test_remove,
        test_large_inserts,
        test_nonempty,
        test_empty,
        test_min_max,
        test_values,
        test_inorder_traversal,
        test_depth,
        test_str,
        test_iterator,
        test_replace,
    )
    all_good = True
    for test in tests:
        all_good &= test()

    if not all_good:
        _complain('One or more tests failed')
        sys.exit(1)

    print('All tests passed.')


# Called unconditionally, so the tests also run if this module is imported.
main()