#!/usr/local/pypy-2.2/bin/pypy

from __future__ import with_statement

import gc
import os
import sys
import copy
import glob
import time
import random
import functools

import btree_mod
import splay_mod
import treap_mod
import aa_tree_mod
import avl_tree_mod
import scapegoat_mod
import red_black_dict_mod
import binary_tree_dict_mod


def my_range(top):
    """Lazily yield 0, 1, 2, ... for as long as the value is below ``top``.

    Nothing is yielded when ``top`` is zero or negative.
    """
    counter = 0
    while True:
        if not (counter < top):
            break
        yield counter
        counter = counter + 1


def get_random_number(maximum):
    """Return ``random.random()`` scaled by ``maximum`` and truncated to int.

    Draws from the module-level ``random`` generator, so the result is
    repeatable after ``random.seed(...)``.
    """
    fraction = random.random()
    scaled = fraction * maximum
    return int(scaled)


class Random_key_class:
    """Produce a repeatable stream of pseudo-random integer keys.

    The n-th key depends only on n: it is derived from a generator seeded
    with the running ``index``, so every instance yields the same sequence.
    Keys fall in the range 0 <= key < 10000000 and may repeat.
    """

    # Short label used in progress output and result file names.
    description = 'random'

    def __init__(self):
        self.index = 0
        # A private generator: seeding it for each key must not disturb the
        # state of the module-level ``random`` functions, which callers use
        # (after their own ``random.seed``) to make repeatable decisions.
        self._generator = random.Random()

    def get_key(self):
        """Return the next key and advance ``index`` by one."""
        self._generator.seed(self.index)
        retval = int(self._generator.random() * 10000000)
        self.index += 1
        return retval

    def key_verifies(self, value1, value2):
        """Always return True.

        We cannot verify these keys because of collisions.
        """
        return True


class Sequential_key_class:
    """Hand out the integers 0, 1, 2, ... as keys, one per call."""

    # Short label used in progress output and result file names.
    description = 'sequential'

    def __init__(self):
        # The next key to hand out.
        self.index = 0

    def get_key(self):
        """Return the current index, then move on to the next one."""
        current, self.index = self.index, self.index + 1
        return current

    def key_verifies(self, value1, value2):
        """Report whether ``value2`` equals the string form of ``value1``."""
        expected = str(value1)
        return expected == value2

        
class Datastructure:
    """Pair a mapping factory with a human-readable description.

    ``initializer`` is called with no arguments to build a fresh container;
    ``description`` labels it in output and file names.  Instances hash,
    compare equal and order by ``description`` alone, so they can be stored
    in sets and sorted.

    Rich comparison methods are defined explicitly because ``__cmp__`` and
    the ``cmp`` builtin only exist on Python 2; without them equality would
    silently fall back to identity on Python 3 while ``__hash__`` still used
    the description.
    """

    def __init__(self, initializer, description):
        self.initializer = initializer
        self.description = description

    def __hash__(self):
        return hash(self.description)

    def __eq__(self, other):
        if not isinstance(other, Datastructure):
            return NotImplemented
        return self.description == other.description

    def __ne__(self, other):
        if not isinstance(other, Datastructure):
            return NotImplemented
        return self.description != other.description

    def __lt__(self, other):
        if not isinstance(other, Datastructure):
            return NotImplemented
        return self.description < other.description

    def __le__(self, other):
        if not isinstance(other, Datastructure):
            return NotImplemented
        return self.description <= other.description

    def __gt__(self, other):
        if not isinstance(other, Datastructure):
            return NotImplemented
        return self.description > other.description

    def __ge__(self, other):
        if not isinstance(other, Datastructure):
            return NotImplemented
        return self.description >= other.description

    def __cmp__(self, other):
        # Kept for Python 2 callers; written without the ``cmp`` builtin so
        # it also works when invoked explicitly on Python 3.
        mine = self.description
        theirs = other.description
        return (mine > theirs) - (mine < theirs)


def test(datastructure, number_pattern, size, get, set_):
    """Time a mixed get/set workload against one datastructure.

    Parameters:
        datastructure: object whose ``initializer`` attribute is called with
            no arguments to build a fresh mapping supporting ``obj[key]`` and
            ``obj[key] = value``.
        number_pattern: key source providing ``get_key()`` and
            ``key_verifies(key, value)``.  It is not reset between
            repetitions, so keys keep advancing.
        size: number of operations attempted per repetition.
        get, set_: relative weights of lookups and insertions; each
            operation is a lookup with probability ``get / (get + set_)``.

    Returns the wall-clock seconds summed over three repetitions.  Each
    repetition's time includes building the container, running the
    operations, and the forced garbage collection afterwards.

    Exits the process with status 1 if a looked-up value fails
    ``number_pattern.key_verifies``.
    """

    # Make the get/set decisions repeatable from run to run.
    random.seed(0)

    total = get + set_

    total_time = 0

    # We do this 3 times, because some datastructures seem to hit the garbage collection harder than others
    sys.stdout.write('.')
    for _repetition in range(3):
        sys.stdout.flush()
        time0 = time.time()
        obj = datastructure.initializer()
        # Keys inserted so far in this repetition.  Lookups draw from this
        # list so that they always target a key that is actually present;
        # asking number_pattern for a fresh key would name one that has
        # never been stored.
        keys_set = []
        for _operation in my_range(size):
            if get_random_number(total) < get:
                # Do a get
                if not keys_set:
                    # unless there's nothing there to get yet
                    continue
                key = keys_set[get_random_number(len(keys_set))]
                value = obj[key]
                if not number_pattern.key_verifies(key, value):
                    sys.stderr.write('%s: key failed to verify: %s, %s\n' % (sys.argv[0], key, value))
                    sys.exit(1)
            else:
                # Do a set
                key = number_pattern.get_key()
                obj[key] = str(key)
                keys_set.append(key)
        del obj
        # Now force a garbage collection - otherwise the numbers produced can come up pretty misleading
        sys.stdout.write('+gc ')
        sys.stdout.flush()
        gc.collect()
        sys.stdout.write('-gc ')
        sys.stdout.flush()

        # The collection above is deliberately inside the timed region.
        time1 = time.time()
        total_time += time1 - time0
    sys.stdout.write('. ')
    return total_time
        
def get_interpreter():
    """Return the name of the directory two levels above ``sys.executable``.

    For an executable at ``<prefix>/<name>/bin/<exe>`` this is ``<name>``.
    """
    install_root = os.path.dirname(os.path.dirname(sys.executable))
    return os.path.basename(install_root)

def get_filename(get, set_, number_pattern, tree_type):
    """Build the ``.dat`` file name for one benchmark combination.

    The name joins, with dashes: the interpreter name, the key pattern's
    ``description``, the get weight, the set weight and the tree type.
    """
    fields = (
        get_interpreter(),
        number_pattern.description,
        get,
        set_,
        tree_type,
    )
    return '%s-%s-%s-%s-%s.dat' % fields

def main():
    """Run every benchmark combination and record the timings.

    Result files left over from an earlier run under this interpreter are
    deleted first.  Then, for each key pattern and each get/set ratio, every
    datastructure is timed at sizes 2**0, 2**1, ... and each "amount
    duration" pair is appended to that combination's ``.dat`` file.  A
    datastructure whose run exceeds the time limit is dropped from the
    remaining sizes of that ratio, and the over-limit measurement itself is
    not recorded.
    """

    # Clear out results from a previous run under this interpreter.
    stale_pattern = '%s-*.dat' % get_interpreter()
    for filename in glob.glob(stale_pattern):
        print('removing %s' % filename)
        os.unlink(filename)

    get_set_ratios = [(95, 5), (50, 50), (5, 95)]

    # Sizes run from 2**0 up to 2**(number_of_exponents - 1); lower this
    # (e.g. to 8) for a quick run.
    number_of_exponents = 24

    # Two minutes
    too_long = 2.0 * 60.0

    key_patterns = [Random_key_class(), Sequential_key_class()]
    for number_pattern in key_patterns:
        for get, set_ in get_set_ratios:
            # (Re)Create the set for each get, set_ pair.
            candidates = [
                (dict, 'dict'),
                (aa_tree_mod.AA_tree, 'AA_tree'),
                (avl_tree_mod.AVL_tree, 'AVL_tree'),
                (btree_mod.BPlusTree, 'B_tree'),
                (splay_mod.Splay, 'splay_tree'),
                (treap_mod.Treap, 'treap'),
                (red_black_dict_mod.RedBlackTree, 'red-black_tree'),
                (functools.partial(scapegoat_mod.ScapeGoatTree, alpha=0.6), 'scapegoat_tree_0_6'),
                (functools.partial(scapegoat_mod.ScapeGoatTree, alpha=0.75), 'scapegoat_tree_0_75'),
                (functools.partial(scapegoat_mod.ScapeGoatTree, alpha=0.9), 'scapegoat_tree_0_9'),
                (binary_tree_dict_mod.BinaryTreeDict, 'binary_tree_dict'),
            ]
            datastructures = set([Datastructure(factory, label) for factory, label in candidates])
            for exponent in range(number_of_exponents):
                amount = 2 ** exponent
                # Iterate over a copy: slow datastructures are removed from
                # the live set while we loop.
                for datastructure in copy.copy(datastructures):
                    label = datastructure.description
                    sys.stdout.write('%-20s %-12s %d and %d %d ' % (label, number_pattern.description, get, set_, amount))
                    sys.stdout.flush()
                    duration = test(datastructure, number_pattern, amount, get, set_)
                    print('%s' % (duration, ))
                    if duration > too_long:
                        datastructures.remove(datastructure)
                        continue
                    result_path = get_filename(get, set_, number_pattern, label)
                    with open(result_path, 'a') as file_:
                        file_.write('%s %s\n' % (amount, duration))
                sys.stdout.write('\n')


# Only run the benchmark when executed as a script, so that importing this
# module does not delete result files or start a long-running benchmark.
if __name__ == '__main__':
    main()