#!/usr/bin/env python

'''Unit tests for rolling checksum mod - there are others, these are just the most fundamental'''

import os
import sys
import math
import functools
#mport pprint
#mport random

import comma_mod
import python2x3

# Pick the rolling checksum implementation under test from the first command-line
# argument; both choices are bound to the same name, rolling_checksum_mod.
# The one-element slice is empty when no argument was given, so neither comparison matches.
if sys.argv[1:2] == ['py']:
	import rolling_checksum_py_mod as rolling_checksum_mod
elif sys.argv[1:2] == ['pyx']:
	import rolling_checksum_pyx_mod as rolling_checksum_mod
else:
	sys.stderr.write('argv[1] must be py or pyx\n')
	sys.exit(1)

def my_range(up_to):
	'''Generate 0, 1, 2, ... for as long as the value is less than up_to (same behaviour on 2.x and 3.x)'''
	current = 0
	while True:
		if not current < up_to:
			# Bound reached (or up_to was not positive): end the generator
			return
		yield current
		current += 1

class Block_sequence_params(object):
	'''Acceptance criteria for one kind of block sequence generator.

	Bundles a generator callable and a description used in messages with the bounds
	that the generated block lengths are checked against: an exclusive range for the
	average, an exclusive range for the standard deviation, and an optional inclusive
	cap on the largest block (None means no cap).
	'''

	def __init__(self, generator, description, min_average, max_average, min_stddev, max_stddev, maximum):
		'''Store the generator, its description and the bounds to check against'''
		# pylint: disable=R0913
		# R0913: We want a few arguments
		self.generator = generator
		self.description = description
		self.min_average = min_average
		self.max_average = max_average
		self.min_stddev = min_stddev
		self.max_stddev = max_stddev
		self.maximum = maximum

	def average_in_range(self, average):
		'''Print a verdict and return True if min_average < average < max_average, else False'''
		if self.min_average < average < self.max_average:
			print('Good, %s average in range' % self.description)
			return True
		# Build the message first so print() gets exactly one argument; a trailing
		# comma inside print(...) makes Python 2 print a 1-tuple instead of the text.
		message = '%s average is not between %s and %s: %s' % (
			self.description,
			comma_mod.gimme_commas(self.min_average),
			comma_mod.gimme_commas(self.max_average),
			comma_mod.gimme_commas(str(int(average))),
			)
		print(message)
		return False

	def stddev_in_range(self, standard_deviation):
		'''Print a verdict and return True if min_stddev < standard_deviation < max_stddev, else False'''
		if self.min_stddev < standard_deviation < self.max_stddev:
			print('Good, %s standard deviation in range' % self.description)
			return True
		message = '%s standard_deviation is not between %s and %s: %s' % (
			self.description,
			comma_mod.gimme_commas(self.min_stddev),
			comma_mod.gimme_commas(self.max_stddev),
			comma_mod.gimme_commas(str(int(standard_deviation))),
			)
		print(message)
		return False

	def highest_beneath_maximum(self, highest):
		'''Print a verdict and return True if there is no maximum or highest <= maximum, else False'''
		if self.maximum is None:
			# No cap configured for this generator, so any block length is acceptable
			print('Good, %s maximum is None' % self.description)
			return True
		if highest <= self.maximum:
			print('Good, %s highest %s beneath %s maximum' % (self.description, highest, self.maximum))
			return True
		print('Bad, %s highest %s not beneath %s maximum' % (self.description, highest, self.maximum))
		return False

def _read_entire_file(filename):
	'''Return the whole content of filename, read with os.read in pieces of up to 2**20 bytes and joined'''
	file_handle = os.open(filename, os.O_RDONLY)
	try:
		pieces = []
		while True:
			piece = os.read(file_handle, 2**20)
			if not piece:
				# os.read returns an empty result at end of file
				break
			pieces.append(piece)
	finally:
		# Release the descriptor even if a read fails part way through
		os.close(file_handle)
	return python2x3.empty_bytes.join(pieces)

def rcm_size_and_accuracy_test():
	'''Run each chunker over rcm-input-data and check its output.

	For every configured chunker the blocks it yields are collected and passed to
	report(), together with the file content read directly, for comparison and for
	the block length checks.  Progress is written to stderr, one line per block.

	Returns True if report() succeeded for every chunker, False otherwise.
	'''
	all_good = True
	input_filename = 'rcm-input-data'

	expected = _read_entire_file(input_filename)

	block_sequence_params = [
		Block_sequence_params(
			generator = functools.partial(rolling_checksum_mod.n_level_chunker, levels=3),
			description = 'n_level_chunker',
			min_average = 300000,
			max_average = 4000000,
			min_stddev = 10,
			max_stddev = 1500000,
			maximum = None,
			),
		Block_sequence_params(
			generator = functools.partial(rolling_checksum_mod.min_max_chunker),
			description = 'min_max_chunker',
			min_average = 700000,
			max_average = 4000000,
			min_stddev = 300000,
			max_stddev = 1500000,
			maximum = 2**22,
			),
		]

	for block_sequence_param in block_sequence_params:
		blocks = []
		total_len = 0

		# Each chunker gets its own freshly opened descriptor on the input file
		file_handle = os.open(input_filename, os.O_RDONLY)
		try:
			for blockno, block in enumerate(block_sequence_param.generator(file_handle)):
				blocks.append(block)
				total_len += len(block)
				sys.stderr.write('Appended blockno %d of length (%s) %s\n' % (
					blockno,
					comma_mod.gimme_commas(total_len),
					comma_mod.gimme_commas(len(block)),
					))
		finally:
			# Do not leak the descriptor if the chunker raises
			os.close(file_handle)

		all_good &= report(block_sequence_param, expected, blocks)

	return all_good

def report(block_sequence_param, expected, blocks):
	'''Report on our findings from rcm_size_and_accuracy_test.

	Checks that the blocks joined together equal expected, then checks the average,
	standard deviation and largest block length against the bounds held by
	block_sequence_param.  Each check prints its own verdict.

	Returns True only if every check passed.  An empty blocks list is reported as a
	failure, because the length statistics are undefined without any blocks.
	'''
	all_good = True

	actual = python2x3.empty_bytes.join(blocks)
	if actual == expected:
		print('Good, %s files match' % block_sequence_param.description)
	else:
		print('Incorrect, %s files do not match' % block_sequence_param.description)
		all_good = False

	lengths = [ len(block) for block in blocks ]
	if not lengths:
		# Without this guard the average below divides by zero and max() raises
		print('Incorrect, %s produced no blocks' % block_sequence_param.description)
		return False

	average = float(sum(lengths)) / float(len(lengths))
	all_good &= block_sequence_param.average_in_range(average)

	standard_deviation = stddev(lengths, average)
	all_good &= block_sequence_param.stddev_in_range(standard_deviation)

	highest = max(lengths)
	all_good &= block_sequence_param.highest_beneath_maximum(highest)

	return all_good

def stddev(list_, average):
	'''Return the square root of the mean squared difference between each element of list_ and average'''
	# Start the sum at 0.0 so the accumulation is done in floating point
	sum_of_squares = sum(((item - average) ** 2 for item in list_), 0.0)
	mean_square = sum_of_squares / len(list_)
	return math.sqrt(mean_square)

def main():
	'''Run the tests, write the overall outcome to stderr, and exit with status 1 on failure'''

	passed = True
	passed &= rcm_size_and_accuracy_test()

	if not passed:
		sys.stderr.write('%s: One or more tests failed\n' % sys.argv[0])
		sys.exit(1)

	sys.stderr.write('%s: All tests passed\n' % sys.argv[0])

main()