#!/usr/bin/python3

"""
Perform set arithmetic, treating files as sets.

Each file is a list of elements, one element per line.  The order of lines
does not matter, and repeated lines count as a single element.
"""

import decimal
import os
import sys


def read_set(filename):
    """Return the lines of *filename* as a set of bytes objects.

    The file is opened in binary mode, so no decoding takes place.  One
    trailing newline byte is stripped from each line; a final line without
    a newline is kept unchanged, and any other bytes (including carriage
    returns) are preserved.
    """
    with open(filename, 'rb') as handle:
        return {
            raw[:-1] if raw.endswith(b'\n') else raw
            for raw in handle
        }


def write_set(set_, fd=1):
    """Write each element of a set to a file descriptor, one per line.

    set_: iterable of bytes elements; each is followed by a newline byte.
    fd: file descriptor to write to; defaults to 1 (stdout).  Output goes
        directly to the descriptor, bypassing sys.stdout's buffer.

    os.write() returns the number of bytes actually written, which can be
    fewer than requested, so each record is written in a loop until every
    byte has been accepted instead of silently dropping the remainder.
    """
    for element in set_:
        pending = memoryview(b'%s\n' % element)
        while pending:
            written = os.write(fd, pending)
            # Slicing a memoryview does not copy the underlying bytes.
            pending = pending[written:]


def usage(retval):
    """Write a usage message and exit with status *retval*.

    The message goes to stdout when retval is 0 and to stderr otherwise.
    This function never returns: it always ends with sys.exit(retval).
    """
    write = sys.stderr.write if retval else sys.stdout.write

    write('Usage: %s\n' % sys.argv[0])
    write('\t--union file1 file2                   write the union of the two files to stdout\n')
    write('\t--intersection file1 file2            write the intersection of the two files to stdout\n')
    write('\t--difference file1 file2              write the difference of the two files to stdout\n')
    write('\t--symmetric-difference file1 file2    write the symmetric difference of the two files to stdout\n')
    write('\t--is-subset file1 file2               exit true if file1 is a subset of file2\n')
    write('\t--is-superset file1 file2             exit true if file1 is a superset of file2\n')
    write('\t--is-proper-subset file1 file2        exit true if file1 is a proper subset of file2\n')
    write('\t--is-proper-superset file1 file2      exit true if file1 is a proper superset of file2\n')
    write('\t--is-equal file1 file2                exit true if file1 is equal to file2\n')
    # The option parser matches '--is-not-equal'; the help text must name it.
    write('\t--is-not-equal file1 file2            exit true if file1 is not equal to file2\n')
    write(
        '\t--fuzzy-match file1 file2             output 0.0 for no overlap, 1.0 for all overlap, 0 < n < 1.0 for partial overlap\n'
    )
    write('\t-h, --help                            write this help message to stdout\n')
    write('\nThis command treats files as sets, one element per line.\n')
    write('All output is to stdout or the exit status.\n')

    sys.exit(retval)


# Options that combine two sets into a new set, which is written to stdout.
# Each function returns a fresh set, so neither operand is modified.
_SET_OPERATIONS = {
    '--union': lambda set1, set2: set1 | set2,
    '--intersection': lambda set1, set2: set1 & set2,
    '--difference': lambda set1, set2: set1 - set2,
    '--symmetric-difference': lambda set1, set2: set1 ^ set2,
}

# Options that test a relation between two sets.  The answer is reported
# only through the exit status: 0 when the relation holds, 1 when it does not.
_SET_PREDICATES = {
    '--is-subset': lambda set1, set2: set1 <= set2,
    '--is-superset': lambda set1, set2: set1 >= set2,
    '--is-proper-subset': lambda set1, set2: set1 < set2,
    '--is-proper-superset': lambda set1, set2: set1 > set2,
    '--is-equal': lambda set1, set2: set1 == set2,
    '--is-not-equal': lambda set1, set2: set1 != set2,
    # Alternate spelling of --is-not-equal, accepted for compatibility.
    '--is-unequal': lambda set1, set2: set1 != set2,
}

# Exit status for usage errors and unreadable files.  It must differ from 1,
# which a predicate option uses to mean "relation does not hold".
_ERROR_STATUS = 2


def _fuzzy_match(set1, set2):
    """Return the overlap of two sets as a Decimal between 0 and 1.

    The value is 1 - |set1 ^ set2| / |set1 | set2|, which equals the size of
    the intersection divided by the size of the union.  Two empty sets are
    identical, so they score 1 instead of triggering a 0/0 division.
    Neither argument is modified.
    """
    union = set1 | set2
    one = decimal.Decimal(1)
    if not union:
        return one
    len_sym_diff = decimal.Decimal(len(set1 ^ set2))
    return one - len_sym_diff / decimal.Decimal(len(union))


def _read_operands(program, filename1, filename2):
    """Read both operand files, exiting with _ERROR_STATUS if one is unreadable."""
    try:
        return read_set(filename1), read_set(filename2)
    except OSError as exc:
        sys.stderr.write('%s: %s\n' % (program, exc))
        sys.exit(_ERROR_STATUS)


def main(argv=None):
    """Perform set arithmetic based on command line options.

    argv: full argument vector including the program name; defaults to
        sys.argv.

    Exit status: 0 on success or when a predicate holds; 1 when a predicate
    does not hold; 2 on a usage error or when an operand file cannot be read.
    """
    if argv is None:
        argv = sys.argv
    program = argv[0]
    args = argv[1:]

    if not args or args[0] in ('-h', '--help'):
        usage(0)

    option = args[0]
    known = (
        option in _SET_OPERATIONS
        or option in _SET_PREDICATES
        or option == '--fuzzy-match'
    )
    if not known:
        sys.stderr.write('%s: Unrecognized option: %s\n' % (program, option))
        usage(_ERROR_STATUS)
    if len(args) != 3:
        # Previously too few arguments crashed with IndexError and too many
        # were misreported as an unrecognized option.
        sys.stderr.write(
            '%s: %s requires exactly two file arguments\n' % (program, option)
        )
        usage(_ERROR_STATUS)

    set1, set2 = _read_operands(program, args[1], args[2])

    if option in _SET_OPERATIONS:
        write_set(_SET_OPERATIONS[option](set1, set2))
        sys.exit(0)
    if option in _SET_PREDICATES:
        sys.exit(0 if _SET_PREDICATES[option](set1, set2) else 1)
    # The only remaining option is --fuzzy-match.
    print(_fuzzy_match(set1, set2))


# Run only when executed as a script, so importing this module (e.g. from
# tests) does not parse sys.argv and call sys.exit().
if __name__ == '__main__':
    main()