#!/usr/bin/env python

'''Divide a group of files into like classes - rapidly'''

# Areas for improvement:
# 3) There's little sense in processing 0 length files - they're always equal to each other
# 4) We might be able to revive the dev+ino test
# 5) B4b_class should raise an exception on premature (inconsistent) EOF for callers to use
#    (perhaps especially modified_merge_sort)

import os
import sys
#import stat
import time

sys.path.insert(0, os.path.expanduser('~/lib'))
sys.path.insert(0, '/usr/local/lib')

try:
    import readline0
except ImportError:
    HAVE_READLINE0 = False
else:
    HAVE_READLINE0 = True

try:
    import hashlib
    HASH_MODULE = hashlib
except ImportError:
    # Ubuntu 9.04's pypy has no hashlib yet (2009-06-12)
    import md5
    HASH_MODULE = md5


def usage(retval):
    '''Output a usage message'''
    sys.stderr.write('Usage: %s\n' % sys.argv[0])
    sys.stderr.write('\t--verbosity n        set verbosity to n\n')
    sys.stderr.write('\t-v                   increase verbosity by 1\n')
    sys.stderr.write('\t-h|--help            output this message\n')
    sys.stderr.write('\t--prefix-length n    set prefix hash length to n\n')
    sys.stderr.write('\t--block-size n       set block size to n\n')
    sys.stderr.write('\t--skip-uniques       do not output unique files\n')
    sys.stderr.write('\t--skip-duplicates    do not output duplicate files\n')
    sys.stderr.write('\t--one-per-duplicate  for duplicates, output one filename for a given file content\n')
    sys.stderr.write('\t--use-dev-ino        optimize for lots of hard links.  Does not work on Windows!\n')
    sys.stderr.write('\t--sort-test          assume all full hashes are collisions so the double merge sort code is tested\n')
    sys.stderr.write('\t-0                   read filenames null separated, not newline separated\n')
    sys.exit(retval)


class Left_exception(Exception):
    '''A simple exception for when the left file in a file comparison gives some sort of error during open or read'''
    pass


class Right_exception(Exception):
    '''A simple exception for when the right file in a file comparison gives some sort of error during open or read'''
    pass


def separate_uniques(unique, possible_remaining_dups, verbosity, dict_, initial_time, end=False):
    # pylint: disable-msg=R0913
    # R0913: we need lots of arguments
    '''Separate out the recently-recognized unique values from the still-possibly-duplicate files'''
    log(verbosity, 2, 'Separating out the unique values...\n')
    new_uniques = 0
    tuples = dict_.items()
    for key, values in tuples:
        if values[0:] and not values[1:]:
            # exactly one filename in this bucket - it has to be unique
            log(verbosity, 2, 'Got a unique value\n')
            new_uniques += 1
            unique.append(values[0])
            del dict_[key]
    possible_remaining_dups -= new_uniques
    if verbosity >= 1:
        if end:
            dupli_ambi = 'duplicates'
            informs = ''
        else:
            dupli_ambi = 'ambiguous'
            informs = ' - but this stage informs later stages'
        denominator = time.time() - initial_time
        if denominator == 0.0:
            # can't divide by zero of course
            denominator = 0.000001
        uniques_per_second = new_uniques / denominator
        sys.stderr.write('Got %d new unique values, %d %s remain\n%f uniques found per second%s\n' % (
            new_uniques,
            possible_remaining_dups,
            dupli_ambi,
            uniques_per_second,
            informs,
            ))
    return possible_remaining_dups

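# Illustrative sketch (assumed keys and filenames, not from the original program): separate_uniques()
# treats any bucket holding exactly one filename as provably unique and moves it out of the dict,
# so only multi-file buckets flow on to the next (more expensive) stage.  Roughly:
#
#   unique = []
#   dict_ = {100: ['a'], 200: ['b', 'c']}
#   remaining = separate_uniques(unique, 3, 0, dict_, time.time())
#   # unique is now ['a'], dict_ is {200: ['b', 'c']}, remaining is 2
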
def by_size(verbosity, devino_stuff, total_file_count, single_iterator):
    '''
    Divide files into possibly-same groups based on their sizes.

    This isn't that accurate, but it's quite fast, and large, expensive files tend to have unique lengths.
    '''
    size_dict = {}
    for filename in single_iterator:
        total_file_count += 1
        log(verbosity, 3, 'stat()ing %s\n' % filename)
        try:
            statbuf = os.stat(filename)
        except OSError:
            sys.stderr.write('Error accessing %s - removing from list and continuing\n' % filename)
            continue
        length = statbuf.st_size
        if size_dict.has_key(length):
            size_dict[length].append(filename)
        else:
            size_dict[length] = [filename]
        if devino_stuff:
            devino_stuff.filename_to_devino[filename] = (statbuf.st_dev, statbuf.st_ino)
    return total_file_count, size_dict


def get_prefix_hash(verbosity, devino_stuff, prefix_length, filename):
    '''Get a hash from a file's initial prefix'''
    log(verbosity, 3, 'prefix hashing %s\n' % filename)
    if devino_stuff:
        devino = devino_stuff.filename_to_devino[filename]
        if devino_stuff.devino_to_prefix_hash.has_key(devino):
            if verbosity >= 2:
                sys.stderr.write('Pulled prefix hash from devino_to_prefix_hash\n')
            return devino_stuff.devino_to_prefix_hash[devino]
    md5_hash = HASH_MODULE.md5()
    # open in binary mode so the hash reflects the bytes on disk
    file_ = open(filename, 'rb')
    md5_hash.update(file_.read(prefix_length))
    file_.close()
    result = md5_hash.hexdigest()
    if devino_stuff:
        devino_stuff.devino_to_prefix_hash[devino] = result
    return result


def get_full_hash(verbosity, devino_stuff, blocksize, filename):
    '''Get a file's full hash'''
    log(verbosity, 3, 'full hashing %s\n' % filename)
    if devino_stuff:
        devino = devino_stuff.filename_to_devino[filename]
        if devino_stuff.devino_to_full_hash.has_key(devino):
            if verbosity >= 2:
                sys.stderr.write('Pulled full hash from devino_to_full_hash\n')
            return devino_stuff.devino_to_full_hash[devino]
    md5_hash = HASH_MODULE.md5()
    file_ = open(filename, 'rb')
    while 1:
        block = file_.read(blocksize)
        if not block:
            break
        md5_hash.update(block)
    file_.close()
    result = md5_hash.hexdigest()
    if devino_stuff:
        devino_stuff.devino_to_full_hash[devino] = result
    return result

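# Illustrative sketch (assumed filenames and sizes, not from the original): by_size() buckets
# filenames by st_size, so only same-length files are ever hashed or compared later.
#
#   count, size_dict = by_size(0, None, 0, iter(['a.txt', 'b.txt', 'c.txt']))
#   # e.g. if a.txt and c.txt are both 1024 bytes and b.txt is 10 bytes:
#   # size_dict == {1024: ['a.txt', 'c.txt'], 10: ['b.txt']}, count == 3
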
def by_prefix_hash(verbosity, devino_stuff, prefix_length, double_iterator):
    '''
    Compare files by a hash of an initial prefix of the file.  If they have the same prefix hash, they might be
    equal files.  If they do not have the same prefix hash, they are definitely different files.  This isn't as
    thorough as a full hash or byte for byte compare, but it's fast.
    '''
    prefix_hash_dict = {}
    small_file_pass_through_count = 0
    pair_pass_through_count = 0
    for original_key, values in double_iterator:
        if original_key <= prefix_length:
            # This is a "small" file - there's no sense in prefix hashing it now, and full hashing it later.
            # So pass it through this stage unchecked.
            if verbosity >= 2:
                len_values = len(values)
                sys.stderr.write('Passing through %d short files of length %d\n' % (len_values, original_key))
            prefix_hash_dict['%d/' % original_key] = values
            small_file_pass_through_count += len(values)
        else:
            if values[1:] and not values[2:]:
                # exactly two values
                if verbosity >= 2:
                    sys.stderr.write('Skipping prefix hash of two items (because there are but two and hashing won\'t add anything)\n')
                new_key = '%s/' % original_key
                prefix_hash_dict[new_key] = values
                pair_pass_through_count += 1
            else:
                if verbosity >= 2:
                    len_values = len(values)
                    sys.stderr.write('Doing prefix hash of %d items for %s\n' % (len_values, original_key))
                for filename in values:
                    try:
                        new_key = '%s/%s' % (original_key, get_prefix_hash(verbosity, devino_stuff, prefix_length, filename))
                        if prefix_hash_dict.has_key(new_key):
                            prefix_hash_dict[new_key].append(filename)
                        else:
                            prefix_hash_dict[new_key] = [filename]
                    except (IOError, OSError):
                        sys.stderr.write('Error accessing %s - removing from list and continuing\n' % filename)
    log(verbosity, 1, 'Passed through %d small files without prefix hashing\n' % small_file_pass_through_count)
    log(verbosity, 1, 'Passed through %d pairs without prefix hashing\n' % pair_pass_through_count)
    return prefix_hash_dict


def by_full_hash(verbosity, devino_stuff, blocksize, double_iterator):
    '''
    Compare values by their complete hashes.  If they have the same hash, they're almost certainly equal files.
    If they do not have the same hash, they are definitely different files.  This isn't as thorough as a byte
    for byte compare, but it's fast-ish.
    '''
    full_hash_dict = {}
    pass_through_count = 0
    for original_key, values in double_iterator:
        len_values = len(values)
        if len_values == 1:
            sys.stderr.write('Error: Unique value in by_full_hash\n')
            sys.exit(1)
        elif len_values == 2:
            # If there are exactly 2 values that may or may not be equal, skip the full hash, as the b4b comparison
            # that comes after this will take the same or less time, and give greater assurance that the files
            # are different
            if verbosity >= 2:
                sys.stderr.write('Skipping full hash of two items (because there are but two and hashing won\'t add anything)\n')
            full_hash_dict[original_key + '/'] = values
            pass_through_count += 1
        else:
            if verbosity >= 2:
                sys.stderr.write('Doing full hash of %d items for %s\n' % (len_values, original_key))
            for filename in values:
                try:
                    new_key = '%s/%s' % (original_key, get_full_hash(verbosity, devino_stuff, blocksize, filename))
                    if full_hash_dict.has_key(new_key):
                        full_hash_dict[new_key].append(filename)
                    else:
                        full_hash_dict[new_key] = [filename]
                except (IOError, OSError):
                    sys.stderr.write('Error accessing %s - removing from list and continuing\n' % filename)
    log(verbosity, 1, 'Passed through %d pairs without full hashing\n' % pass_through_count)
    return full_hash_dict

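# Illustrative sketch (hypothetical sizes and digests, not real output): each stage refines the
# previous stage's key by appending its own discriminator after a '/', so buckets only ever get
# smaller as the pipeline proceeds.
#
#   size stage:        {4096: ['a', 'b', 'c']}
#   prefix hash stage: {'4096/d41d8c...': ['a', 'b'], '4096/9e107d...': ['c']}
#   full hash stage:   {'4096/d41d8c.../68b329...': ['a', 'b']}
#
# A trailing '/' with no digest marks a bucket that was passed through unhashed
# (small files, or buckets of exactly two).
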
def list_of_lists_from_likely_same(verbosity, lst):
    '''
    Our lists (passed to double_merge_sort) will normally be of equal values, because (at one time at least)
    the values all had the same length, prefix hash and full hash.  So instead of incurring an O(nlogn) sort
    right off the bat, first we check if the elements are all the same in O(n) time.  -And- we optimize a
    little bit as we go - we make this scan for sameness a degenerate first pass through the sort.

    The algorithm in this function is to compare the 1..n-1 elements to the 0th element, each in turn (this
    helps maximize the benefit of the buffer cache too).  If the 1..n-1 elements are all the same as the first
    element, the list is equal already - no need to sort.

    If they're not all the same as the first element, hopefully the first element at least isn't unique, so we
    get started on setting up buckets of equal values - which is really what double_merge_sort is doing.  In
    short, we get started on setting up buckets of equal values before we start sorting: so if we don't need
    to sort, we don't, and if we do need to sort, we have a little headstart.
    '''
    if not lst[1:]:
        # the list is either empty or has a single value - return it as (at most) one bucket
        if lst:
            return [lst]
        return []
    magic_element = lst[0]
    sames = [magic_element]
    maybe_differents = []
    tail = lst[1:]
    len_tail = len(tail)
    for consider_element_no in xrange(len_tail - 1, -1, -1):
        consider_element = tail[consider_element_no]
        try:
            comparison_result = cmp(magic_element, consider_element)
        except Left_exception:
            sys.stderr.write('Error accessing %s - skipping linear sameness check and continuing\n' % sames[0])
            del sames[0]
            # the not-yet-considered elements all become "maybe different"
            maybe_differents.extend(tail[:consider_element_no + 1])
            break
        except Right_exception:
            sys.stderr.write('Error accessing %s - removing from list and continuing\n' % consider_element)
            continue
        else:
            if comparison_result == 0:
                sames.append(consider_element)
            else:
                maybe_differents.append(consider_element)
    if sames:
        lst2 = [sames] + [[x] for x in maybe_differents]
    else:
        lst2 = [[x] for x in maybe_differents]
    if maybe_differents:
        log(verbosity, 2, 'One or more files were found different via linear check\n')
    else:
        log(verbosity, 2, 'All files were found same via linear check\n')
    return lst2


# we first check if the list is a bunch of equal values, because frequently, the list will be
def double_merge_sort(verbosity, lst):
    '''Merge sort that compresses things down to a list of sublists, where the sublists have equal values'''
    # convert to a list of lists and then sort recursively
    lst2 = list_of_lists_from_likely_same(verbosity, lst)
    if lst2:
        # lst2 is not empty
        if not lst2[1:]:
            # the list is not empty, but we have no second element, so the values in lst were already all equal
            return lst2
        if lst2[1:] and not lst2[2:]:
            # we have two buckets but not a third - this too is already "sorted" for our purposes - because
            # we don't care if a < b or a > b, we only care that a != b
            return lst2
    else:
        # lst2 is empty - of course it's already sorted
        return lst2
    return double_merge_sort_backend(verbosity, lst2)


def merge_or_start_new(result, bucket):
    '''Either start a new bucket, or add to an existing bucket.  Part of our double merge sort'''
    if result[0:] and bucket[0] == result[-1][0]:
        # bucket has elements that match the end of result, so merge
        result[-1] += bucket
    else:
        # bucket has elements that are different from the end of result, so append, starting a new bucket in result
        result.append(bucket)

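# Illustrative sketch (plain integers standing in for B4b_class instances): the linear sameness
# check converts a flat list into an initial list of buckets before any real sorting happens.
#
#   list_of_lists_from_likely_same(0, [5, 5, 5])   # -> [[5, 5, 5]] - no sort needed
#   list_of_lists_from_likely_same(0, [5, 7, 5])   # -> [[5, 5], [7]] (the first bucket holds
#                                                  #    the elements equal to lst[0])
#   double_merge_sort(0, [5, 7, 5, 7, 9])          # -> [[5, 5], [7, 7], [9]] - buckets of equal values
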
def double_merge_sort_backend(verbosity, lst):
    # pylint: disable-msg=R0912
    # R0912: I'm afraid we just need a lot of branches for this.
    '''
    We merge sorted lists, as is traditional, but we additionally merge buckets of same values together.
    We expect as input something like [ [1], [2], [5], [4], [3], [2] ] and return [ [1], [2, 2], [3], [4], [5] ].
    '''
    length = len(lst)
    if length <= 1:
        return lst
    midpoint = length / 2
    list1 = double_merge_sort_backend(verbosity, lst[:midpoint])
    list2 = double_merge_sort_backend(verbosity, lst[midpoint:])
    index1 = 0
    index2 = 0
    result = []
    # Note: the lengths are recomputed each time around, because buckets (and elements within buckets)
    # can be deleted when a file gives an I/O error during comparison
    while index1 < len(list1) and index2 < len(list2):
        # sometimes we do get empty lists - because of the bad file removals
        if not list1[index1]:
            index1 += 1
            continue
        if not list2[index2]:
            index2 += 1
            continue
        try:
            comparison = cmp(list1[index1][0], list2[index2][0])
        except Left_exception:
            log(verbosity, 1, 'Error accessing %s - removing from list and continuing\n' % list1[index1][0])
            del list1[index1][0]
            if not list1[index1]:
                del list1[index1]
            continue
        except Right_exception:
            log(verbosity, 2, 'Error accessing %s - removing from list and continuing\n' % list2[index2][0])
            del list2[index2][0]
            if not list2[index2]:
                del list2[index2]
            continue
        if comparison < 0:
            merge_or_start_new(result, list1[index1])
            index1 += 1
        elif comparison > 0:
            merge_or_start_new(result, list2[index2])
            index2 += 1
        else:
            # they're equal - collapse.  These are lists of lists of equal values, so we can concatenate
            # the two buckets of equal values to save some time
            merge_or_start_new(result, list1[index1] + list2[index2])
            index1 += 1
            index2 += 1
    while index1 < len(list1):
        merge_or_start_new(result, list1[index1])
        index1 += 1
    while index2 < len(list2):
        merge_or_start_new(result, list2[index2])
        index2 += 1
    return result

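# Illustrative sketch (integers rather than B4b_class instances, matching the docstring above):
#
#   double_merge_sort_backend(0, [[1], [2], [5], [4], [3], [2]])
#   # -> [[1], [2, 2], [3], [4], [5]] - a sorted list of buckets, with equal values collapsed
#   # into a single bucket rather than appearing as separate adjacent buckets
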
class B4b_class:
    # pylint: disable-msg=R0903
    # R0903: We want few public methods
    '''Handle byte for byte comparisons of two files'''

    blocksize = 2**20
    devino_stuff = None
    verbosity = 1

    def __init__(self, filename):
        self.filename = filename

    def __str__(self):
        return self.filename

    __repr__ = __str__

    def __cmp__(self, other):
        # pylint: disable-msg=R0912
        # R0912: I'm afraid we just need a lot of branches for this.
        left_filename = self.filename
        right_filename = other.filename
        log(self.verbosity, 3, 'byte for byte comparing %s and %s\n' % (left_filename, right_filename))
        if self.devino_stuff is not None:
            left_devino = self.devino_stuff.filename_to_devino[left_filename]
            right_devino = self.devino_stuff.filename_to_devino[right_filename]
            if left_devino == right_devino:
                # same device and inode - they're the same file, so they're equal
                return 0
        try:
            left_file = open(left_filename, 'rb')
        except (IOError, OSError):
            raise Left_exception
        try:
            right_file = open(right_filename, 'rb')
        except (IOError, OSError):
            left_file.close()
            raise Right_exception
        try:
            while 1:
                try:
                    left_block = left_file.read(self.blocksize)
                except (IOError, OSError):
                    raise Left_exception
                left_eof = not left_block
                try:
                    right_block = right_file.read(self.blocksize)
                except (IOError, OSError):
                    raise Right_exception
                right_eof = not right_block
                if left_eof:
                    if right_eof:
                        # good, we hit EOF at the same time
                        return 0
                    sys.stderr.write('Warning: EOF on %s but not %s, even though they previously had the same length\n' % (
                        left_filename,
                        right_filename,
                        ))
                    return -1
                else:
                    if right_eof:
                        sys.stderr.write('Warning: EOF on %s but not %s, even though they previously had the same length\n' % (
                            right_filename,
                            left_filename,
                            ))
                        return 1
                    # We still have more to compare
                    comparison_result = cmp(left_block, right_block)
                    if comparison_result:
                        return comparison_result
        finally:
            left_file.close()
            right_file.close()
        # we should never reach this point
        sys.stderr.write('This should never happen\n')
        sys.exit(1)

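# Illustrative sketch (hypothetical filenames): B4b_class wraps a filename so Python 2's cmp() -
# and hence double_merge_sort - orders files by their byte-for-byte content.
#
#   left = B4b_class('/tmp/a')
#   right = B4b_class('/tmp/b')
#   cmp(left, right)   # 0 if the contents are identical, nonzero otherwise; raises Left_exception
#                      # or Right_exception if the corresponding file can't be opened or read
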
# we should replace the values.sort() with something that will coalesce identical items into a single item!
def by_byte_for_byte_comparison(verbosity, double_iterator):
    '''Compare files by byte for byte comparisons'''
    # b4b_dict[distinguisher] perhaps should be a list!  It's basically an array with hash overhead
    b4b_dict = {}
    distinguisher = 0
    for key, values in double_iterator:
        len_values = len(values)
        if len_values == 1:
            sys.stderr.write('Error: Unique value in by_byte_for_byte_comparison\n')
            sys.exit(1)
        b4b_list = [B4b_class(value) for value in values]
        if verbosity >= 2:
            sys.stderr.write('Sorting %d values, byte for byte, via double_merge_sort for %s\n' % (len_values, key))
        # Yes, we're sorting, but we're sorting lots of tiny lists - should be very fast.
        # Also, we only rarely sort - when there's a hash collision or a file's content changes during the run.
        sorted_buckets = double_merge_sort(verbosity, b4b_list)
        for bucket in sorted_buckets:
            distinguisher += 1
            b4b_dict[distinguisher] = bucket
    return b4b_dict


def log(verbosity, level, message):
    '''Log a message to stderr, but only if we're at an appropriate log level'''
    if verbosity >= level:
        sys.stderr.write(message)


def deal_with_command_line_options():
    # pylint: disable=R0912
    # R0912: We need a bunch of branches
    '''Pretty self-explanatory'''
    verbosity = 0
    prefix_length = 1024
    blocksize = 2**18
    null_delimited = False
    show_uniques = True
    show_duplicates = True
    one_per_duplicate = False
    use_dev_ino = False
    #sort_test = False
    while sys.argv[1:]:
        if sys.argv[1] == '--verbosity':
            verbosity = int(sys.argv[2])
            del sys.argv[1]
        elif sys.argv[1] in ('-h', '--help'):
            usage(0)
        elif sys.argv[1] == '-v':
            verbosity += 1
        elif sys.argv[1] == '--prefix-length':
            prefix_length = int(sys.argv[2])
            del sys.argv[1]
        elif sys.argv[1] == '--block-size':
            blocksize = int(sys.argv[2])
            del sys.argv[1]
        elif sys.argv[1] == '--skip-uniques':
            show_uniques = False
        elif sys.argv[1] == '--skip-duplicates':
            show_duplicates = False
        elif sys.argv[1] == '--one-per-duplicate':
            one_per_duplicate = True
        elif sys.argv[1] == '--use-dev-ino':
            use_dev_ino = True
#        elif sys.argv[1] == '--sort-test':
#            sort_test = True
        elif sys.argv[1] == '-0':
            if HAVE_READLINE0:
                null_delimited = True
            else:
                sys.stderr.write('%s: Apologies - we have no readline0 so -0 is unavailable\n' % sys.argv[0])
        else:
            sys.stderr.write('Illegal argument: %s\n' % sys.argv[1])
            usage(1)
        del sys.argv[1]
    return verbosity, prefix_length, blocksize, show_uniques, show_duplicates, one_per_duplicate, use_dev_ino, null_delimited


def newline_delimited_lines():
    '''Yield a series of lines, but strip off any trailing newline for consistency with readline0.readline0'''
    for filename in sys.stdin:
        if filename[-1:] == '\n':
            filename = filename[:-1]
        yield filename


class Devino_stuff_class:
    # pylint: disable-msg=R0903
    # R0903: We want few public methods
    '''Just a container for the device # and inode # stuff'''
    def __init__(self):
        self.filename_to_devino = {}
        self.devino_to_prefix_hash = {}
        self.devino_to_full_hash = {}

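# Illustrative sketch (hypothetical filenames): with --use-dev-ino, files sharing a (st_dev, st_ino)
# pair - i.e. hard links - are hashed only once and compare equal without being read again.
#
#   stuff = Devino_stuff_class()
#   stuff.filename_to_devino['a'] = (2049, 12345)
#   stuff.filename_to_devino['b'] = (2049, 12345)   # hard link to the same inode
#   # get_prefix_hash()/get_full_hash() memoize by devino, and B4b_class.__cmp__ returns 0
#   # for 'a' vs 'b' without opening either file.
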
def main():
    # pylint: disable-msg=R0914,R0912
    # R0914: We want lots of local variables.  It's better than making them globals ^_^
    # R0912: I'm afraid we want lots of branches
    '''main function'''
#    if False:
#        # quick double_merge_sort test
#        lst = [ [1], [2], [3], [2], [5], [2], [2], [2], [3] ]
#        #import random
#        #lst = [ int(random.random()*100) for x in xrange(200) ]
#        print double_merge_sort(lst)
#        # this annoying sleep fixes an even more annoying sys.excepthook error
#        time.sleep(0.1)
#        sys.exit(0)
    verbosity, prefix_length, blocksize, show_uniques, show_duplicates, one_per_duplicate, use_dev_ino, null_delimited = \
        deal_with_command_line_options()
    if use_dev_ino:
        devino_stuff = Devino_stuff_class()
    else:
        devino_stuff = None
    B4b_class.devino_stuff = devino_stuff
    B4b_class.verbosity = verbosity
    unique = []
    total_file_count = 0
    # actually, it'd be a little faster to merge the separate_uniques calls into the various by_* functions, but:
    # 1) That's not as good software engineering.
    # 2) It gives the user a warm fuzzy feeling to get some indication of how well a given stage did.

    log(verbosity, 1, '----------\nGetting size_dict: %s\n' % time.ctime())
    initial_time = time.time()
    if null_delimited:
        iterator = readline0.readline0(sys.stdin)
    else:
        iterator = newline_delimited_lines()
    total_file_count, size_dict = by_size(verbosity, devino_stuff, total_file_count, iterator)
    # All files are considered possible duplicates at this point, because we haven't separated out the ones
    # that have to be unique based on their sizes yet.
    possible_remaining_dups = total_file_count
    log(verbosity, 1, 'Total file count is %d\n' % total_file_count)
    possible_remaining_dups = separate_uniques(unique, possible_remaining_dups, verbosity, size_dict, initial_time)

    log(verbosity, 1, '----------\nGetting prefix_hash_dict from size_dict: %s\n' % time.ctime())
    initial_time = time.time()
    prefix_hash_dict = by_prefix_hash(verbosity, devino_stuff, prefix_length, size_dict.items())
    possible_remaining_dups = separate_uniques(unique, possible_remaining_dups, verbosity, prefix_hash_dict, initial_time)
    del size_dict

    log(verbosity, 1, '----------\nGetting full_hash_dict from prefix_hash_dict: %s\n' % time.ctime())
    initial_time = time.time()
    full_hash_dict = by_full_hash(verbosity, devino_stuff, blocksize, prefix_hash_dict.items())
    possible_remaining_dups = separate_uniques(unique, possible_remaining_dups, verbosity, full_hash_dict, initial_time)
    del prefix_hash_dict

    log(verbosity, 1, '----------\nGetting duplicates_only_dict from full_hash_dict: %s\n' % time.ctime())
    initial_time = time.time()
    duplicates_only_dict = by_byte_for_byte_comparison(verbosity, full_hash_dict.iteritems())
    possible_remaining_dups = separate_uniques(unique, possible_remaining_dups, verbosity, duplicates_only_dict, initial_time, end=True)
    del full_hash_dict

    log(verbosity, 1, 'Final result: Got %d unique files and %d distinct duplicates\n' % (len(unique), len(duplicates_only_dict)))

    # display all the unique files
    if show_uniques:
        for unique_filename in unique:
            print unique_filename

    # now display all the duplicates
    if show_duplicates:
        for dups in duplicates_only_dict.values():
            if one_per_duplicate:
                print str(dups[0])
            else:
                print ' '.join([str(filename) for filename in dups])


if __name__ == '__main__':
    main()