#!/usr/bin/env python3

"""_thread-based performance comparison."""

import sys
import _thread

# Pick the clock module: try utime first and fall back to the standard time module.
# The comment in timesecs() indicates utime is the MicroPython case.
try:
    import utime
except ImportError:
    import time
    # True means utime was not importable, so the standard `time` module is in use.
    have_time = True
else:
    # utime imported successfully; timesecs() and pause_slightly() use it instead of `time`.
    have_time = False


def interpreter_name():
    """Build a '<name>-<major>.<minor>.<micro>' label for the running interpreter.

    The name defaults to sys.implementation.name.  The version triple comes from
    sys.pypy_version_info when that attribute exists, from the module-level
    __compiled__ object when one is present (in which case the name is reported
    as 'nuitka'), and from sys.implementation.version otherwise.
    """
    implementation = sys.implementation
    label = implementation.name
    module_globals = globals()

    if hasattr(sys, 'pypy_version_info'):
        version = sys.pypy_version_info
    elif '__compiled__' in module_globals:
        # A __compiled__ global is taken as the sign that we were compiled by nuitka.
        label, version = 'nuitka', module_globals['__compiled__']
    else:
        version = implementation.version

    # Only the first three version components are kept.
    dotted = '.'.join(map(str, version[:3]))
    return '{}-{}'.format(label, dotted)


def avg(*numbers):
    """Compute the arithmetic mean of the positional arguments.

    Uses true division, so the result of averaging ints is a float.
    Raises ZeroDivisionError when called with no arguments.
    """
    count = len(numbers)
    total = sum(numbers)
    return total / count


def append(filename, num_cores, duration):
    """Add a 'num_cores duration' line to the end of filename.

    The file is opened in append mode, so it is created if it does not exist
    and earlier lines are preserved.
    """
    record = '{} {}\n'.format(num_cores, duration)
    with open(filename, 'a') as stats_file:
        stats_file.write(record)


def timesecs():
    """Return the current time in seconds since the epoch.

    The clock is read from whichever module the import fallback at the top of
    the file selected (signalled by have_time).
    """
    if not have_time:
        # It appears that micropython does not support time.time(), but it has utime.time_ns();
        # convert its nanoseconds to seconds.
        return utime.time_ns() / 1_000_000_000
    return time.time()


def pause_slightly():
    """Request a sleep of one microsecond from the active time module."""
    if not have_time:
        utime.sleep_us(1)
        return
    # 0.000001 seconds == 1 microsecond, matching the utime branch.
    time.sleep(0.000001)


def usage(retval):
    """Print the command-line help and exit the process with status retval.

    The help text is written to stdout when retval is 0 and to stderr for any
    other value.  This function does not return: it always calls sys.exit().
    """
    stream = sys.stdout if retval == 0 else sys.stderr
    emit = stream.write

    help_lines = (
        '{}: --max-threads 4 --number-of-numbers 1000000000 --samples 3 --help\n'.format(sys.argv[0]),
        '\n',
        'This benchmark will run a simple sum-of-range test, using 1..max_threads threads.\n',
        'The count of numbers to total is specified with --number-of-numbers\n',
        'The number of times to repeat a given test is specified with --samples.\n',
        'It takes some of the sting out of variable CPU availability\n',
    )
    for line in help_lines:
        emit(line)

    sys.exit(retval)


def add_up(groupno, finished, subtotals, low, high):
    """Total the integers low..high (inclusive) with an explicit Python-level loop.

    The total is stored in subtotals[groupno], and only afterwards is
    finished[groupno] set to True, so a reader that sees the flag also sees
    the stored total.
    """
    # NOTE(review): presumably kept as a manual loop to contrast with add_up2's
    # built-in sum(); confirm before collapsing the two.
    running = 0
    stop = high + 1
    for value in range(low, stop):
        running = running + value
    subtotals[groupno] = running
    finished[groupno] = True


def add_up2(groupno, finished, subtotals, low, high):
    """Total the integers low..high (inclusive) using the built-in sum().

    The total is stored in subtotals[groupno] before finished[groupno] is set
    to True, so a reader that sees the flag also sees the stored total.
    """
    subtotals[groupno] = sum(range(low, high + 1))
    finished[groupno] = True


def _parse_int_option(argv, index):
    """Return the integer following the option at argv[index]; exit via usage(1) if missing or invalid."""
    option = argv[index]
    if index + 1 >= len(argv):
        print('{}: {} requires an integer argument'.format(argv[0], option), file=sys.stderr)
        usage(1)
    text = argv[index + 1]
    try:
        return int(text)
    except ValueError:
        print('{}: {} requires an integer argument, got: {}'.format(argv[0], option, text), file=sys.stderr)
        usage(1)


def _parse_args(argv):
    """Parse argv into (max_threads, number_of_numbers, samples); exit via usage() on --help or bad input."""
    max_threads = 0
    number_of_numbers = 1_000_000_000
    samples = 1

    index = 1
    while index < len(argv):
        option = argv[index]
        if option == '--max-threads':
            max_threads = _parse_int_option(argv, index)
            index += 2
        elif option == '--samples':
            samples = _parse_int_option(argv, index)
            index += 2
        elif option == '--number-of-numbers':
            number_of_numbers = _parse_int_option(argv, index)
            index += 2
        elif option in ('-h', '--help'):
            usage(0)
        else:
            print('{}: unrecognized option: {}'.format(argv[0], option), file=sys.stderr)
            usage(1)

    if max_threads <= 0:
        print('{}: --max-threads is a required option'.format(argv[0]), file=sys.stderr)
        usage(1)
    if samples <= 0:
        # The per-thread-count mean divides by samples, so zero samples cannot be averaged.
        print('{}: --samples must be at least 1'.format(argv[0]), file=sys.stderr)
        usage(1)
    if number_of_numbers < 0:
        # A negative upper bound gives empty ranges, which can never match the closed-form total.
        print('{}: --number-of-numbers must not be negative'.format(argv[0]), file=sys.stderr)
        usage(1)

    return max_threads, number_of_numbers, samples


def _run_sample(num_threads, number_of_numbers):
    """Sum 0..number_of_numbers across num_threads threads, verify the total, return elapsed seconds."""
    time0 = timesecs()
    # Rounding error is very real for this quotient.
    numbers_per_group = number_of_numbers // num_threads
    finished = [False] * num_threads
    subtotals = [0] * num_threads
    for groupno in range(num_threads):
        low = numbers_per_group * groupno
        if groupno == num_threads - 1:
            # The last thread may have a little more work to do, due to rounding error in integer division.
            high = number_of_numbers
        else:
            high = numbers_per_group * (groupno + 1) - 1
        _thread.start_new_thread(add_up2, (groupno, finished, subtotals, low, high))

    # Each worker sets its finished flag after storing its subtotal; poll until all have done so.
    while not all(finished):
        pause_slightly()

    actual_total = sum(subtotals)
    # Closed form for 0 + 1 + ... + n in pure integer arithmetic.  Going through a float
    # mean loses precision once the total exceeds 2**53 and makes correct runs fail.
    expected_total = number_of_numbers * (number_of_numbers + 1) // 2
    if actual_total != expected_total:
        raise AssertionError('actual_total {} != expected_total {}, subtotals: {}'.format(
                actual_total,
                expected_total,
                subtotals,
        ))
    return timesecs() - time0


def main():
    """Run the benchmark for 1..max_threads threads and record the mean time per thread count.

    Options are read from sys.argv (see usage()).  For every thread count, the
    sum-of-range test is repeated --samples times; each sample's time is
    printed, and the mean is appended to '<interpreter_name()>.dat'.

    Raises AssertionError if the threads' combined total is wrong.  Exits via
    usage() on --help or on invalid command-line input.
    """
    max_threads, number_of_numbers, samples = _parse_args(sys.argv)

    for num_threads in range(1, max_threads + 1):
        delta_time = 0.0
        for sampleno in range(samples):
            sample_delta_time = _run_sample(num_threads, number_of_numbers)
            print('sampleno: {}, num_threads: {}, time: {}'.format(
                    sampleno,
                    num_threads,
                    sample_delta_time,
            ))
            delta_time += sample_delta_time
        append('%s.dat' % interpreter_name(), num_threads, delta_time / samples)


# Run the benchmark only when executed as a script, not when imported as a module.
if __name__ == '__main__':
    main()