#!/usr/local/cpython-3.4/bin/python3

'''Read data from a source (file or port) and write to a destination (file, port or bitbucket)'''

import re
import os
import sys
import errno
import signal
import functools
import socket as socket_mod


def usage(retval):
    '''Write the usage message and exit with status retval.

    The message goes to stdout when retval is 0 (help was requested) and to
    stderr otherwise (the command line was wrong).  This function never
    returns: it always ends by calling sys.exit(retval).
    '''

    if retval == 0:
        writer = sys.stdout.write
    else:
        writer = sys.stderr.write

    writer('Usage: {}\n'.format(sys.argv[0]))
    writer('   [-b blocksize_in_bytes]\n')
    writer('   [-t total_bytes_to_transfer]\n')
    writer('   [-i|-I port]\n')
    writer('   [-o|-O hostname port|-n]\n')
    writer('   [-w window]\n')
    writer('   [-v]\n')
    writer('\n')
    writer('-i\t\t\tsays to use file I/O to read from stdin\n')
    writer('-I port\t\t\tsays to use sockets to get input from "port" on the local machine\n')
    writer('\n')
    writer('-o\t\t\tsays to use file I/O to write to stdout\n')
    writer('-O hostname port\tsays to use sockets to connect to hostname on port\n')
    writer('-n\t\t\tsays to only read - do not write the data anywhere (should be faster than writing to /dev/null)\n')
    writer('-w size\t\t\tsays to set the TCP window size\n')
    writer('\n')
    writer('--O-udp\t\t\tSays to use UDP sockets, but only for the producer - IE modify -O mode\n')
    writer('--I-udp\t\t\tSays to use UDP sockets, but only for the consumer - IE modify -I mode\n')
    writer('-u\t\t\tSays to use UDP sockets, not the default, which is TCP (implies both of the above)\n')
    writer('\n')
    writer('Naturally, -i conflicts with -I, and -o conflicts with -O\n')
    writer('\n')
    writer("-w window conflicts with -u, because it's for setting the TCP window size\n")
    writer('\n')
    writer("-N [0|1] says to set TCP_NODELAY to 0 or 1 (0 enables Nagel, 1 disables) (_N_agel Algorithm)\n")
    writer('\n')
    # The option that switches the -I side to UDP is --I-udp; there is no --I-tcp.
    writer('Also, -u of course requires -I or -O, --O-udp requires -O and --I-udp requires -I\n')
    writer('Also note that -u may require small blocksizes compared to TCP\n')
    writer('For example 65508 was once the highest UDP blocksize on a\n')
    writer('run of the mill Fedora Core 4 system\n')
    writer('\n')
    writer('Among -i/-I and -o/-O, the lower case letter does file I/O, and the capital does socket I/O\n')

    sys.exit(retval)


def protocol_string(is_tcp):
    '''Name the transport protocol: 'tcp' when is_tcp is true, otherwise 'udp'.'''
    return 'tcp' if is_tcp else 'udp'


def portno(service, is_tcp):
    '''Convert a port description into a port number.

    service is a string: either all decimal digits, which is converted with
    int(), or a service name, which is looked up with getservbyname() for
    the protocol selected by is_tcp ('tcp' or 'udp').

    Returns the port number as an int.  If the service name cannot be
    resolved, an error is written to stderr and the process exits with
    status 1.
    '''
    if re.match('^[0-9][0-9]*$', service):
        return int(service)

    protocol = protocol_string(is_tcp)
    try:
        return socket_mod.getservbyname(service, protocol)
    except socket_mod.error:
        sys.stderr.write("getservbyname({}, '{}') failed\n".format(service, protocol))
        # There is no usable port to return, so stop here rather than fall
        # through with nothing bound.
        sys.exit(1)


def getip(host):
    '''Return an address for host that stays the same across lookups.

    An IPv4 or IPv6 address literal is returned unchanged, with no lookup.
    An empty string is also returned as-is.  Anything else is resolved with
    gethostbyname() several times; if every lookup yields the same address
    that address is returned, otherwise (e.g. round robin DNS) an error is
    written to stderr and the process exits with status 1.
    '''
    if not host:
        return host

    # inet_pton() only accepts well-formed address literals, so success for
    # either family means there is nothing to resolve.
    for family in (socket_mod.AF_INET, socket_mod.AF_INET6):
        try:
            socket_mod.inet_pton(family, host)
        except (OSError, ValueError):
            continue
        return host

    # test for round robin DNS or other forms of variability
    reps = 10
    addresses = {socket_mod.gethostbyname(host) for _ in range(reps)}
    if len(addresses) != 1:
        sys.stderr.write('Sorry, there is some variability in hostname lookups for {}\n'.format(host))
        sys.stderr.write('Please use an IP address instead\n')
        sys.exit(1)
    return addresses.pop()


def percent(numerator, denominator):
    '''Express numerator / denominator as a floor-divided percentage, returned as a string'''
    scaled = numerator * 100
    whole_percent = scaled // denominator
    return str(whole_percent)


def xfer_forever(verbose, mode, block_len, readfn, writefn):
    '''Copy blocks from readfn to writefn until end of input.

    verbose   -- when true, progress messages are written to stderr
    mode      -- description of the transfer, used only in messages
    block_len -- number of bytes requested from readfn on each call
    readfn    -- callable taking a length and returning a block of data
    writefn   -- callable taking a block of data

    The loop ends when readfn returns an empty block, or when writefn fails
    with EPIPE (the receiver went away).  Any other socket/OS error raised
    by writefn is reported on stderr and the process exits with status 1.
    '''
    if verbose:
        sys.stderr.write('{}, {}: xfer_forever: Infinite loop starting\n'.format(sys.argv[0], mode))
    while True:
        block = readfn(block_len)
        if verbose:
            sys.stderr.write('{}, {}: read block (no len for speed)\n'.format(sys.argv[0], mode))
        if not block:
            if verbose:
                sys.stderr.write('{}, {}: 0 length block received\n'.format(sys.argv[0], mode))
            break
        try:
            writefn(block)
        except (socket_mod.error, OSError, IOError) as exc:
            # Exceptions are not subscriptable in Python 3; the error number
            # is an attribute of the exception itself.
            this_errno = exc.errno
            if this_errno == errno.EPIPE:
                # Broken pipe - just ignore it and exit
                break
            # Fall back to the exception text when errno is missing or unknown.
            sys.stderr.write('Error: {}\n'.format(errno.errorcode.get(this_errno, exc)))
            sys.exit(1)
        if verbose:
            sys.stderr.write('{}, {}: wrote block (no len for speed)\n'.format(sys.argv[0], mode))


def xfer_to_maximum(verbose, mode, total_to_xfer, block_len, readfn, writefn, total_bytes=None):
    # pylint: disable=too-many-arguments,too-many-branches,unused-argument
    '''Copy at most total_to_xfer bytes from readfn to writefn.

    verbose       -- when true, progress messages are written to stderr
    mode          -- description of the transfer, used only in messages
    total_to_xfer -- upper limit on the number of bytes transferred
    block_len     -- largest number of bytes requested from readfn per call
    readfn        -- callable taking a length and returning a block of data
    writefn       -- callable taking a block of data
    total_bytes   -- accepted for compatibility with existing callers and
                     ignored; total_to_xfer alone sets the limit

    The loop ends when the limit is reached, when readfn returns an empty
    block, or when writefn fails with EPIPE.  Any other socket/OS error
    raised by writefn is reported on stderr and the process exits with
    status 1.
    '''
    if verbose:
        sys.stderr.write('{}, {}: xfer_to_maximum: Transferring <= {} bytes (all told)\n'.format(sys.argv[0], mode, total_to_xfer))
    total_xferred = 0
    while True:
        # Never ask for more than what is left of the limit, so the final
        # block is shortened rather than overshooting.
        intended_block_len = min(block_len, total_to_xfer - total_xferred)
        if intended_block_len <= 0:
            if verbose:
                sys.stderr.write('{}, {}: All bytes transferred\n'.format(sys.argv[0], mode))
            break
        block = readfn(intended_block_len)
        if not block:
            if verbose:
                sys.stderr.write('{}, {}: 0 length block received\n'.format(sys.argv[0], mode))
            break
        actual_block_len = len(block)
        total_xferred += actual_block_len
        try:
            writefn(block)
        except (socket_mod.error, OSError, IOError) as exc:
            this_errno = exc.errno
            if this_errno == errno.EPIPE:
                # Broken pipe - just ignore it
                break
            # Fall back to the exception text when errno is missing or unknown.
            sys.stderr.write('Error: {}\n'.format(errno.errorcode.get(this_errno, exc)))
            sys.exit(1)
        if verbose:
            sys.stderr.write('{}, {}: Transmitted block of size {}, '.format(
                sys.argv[0],
                mode,
                actual_block_len,
            ))
            # str.format does not treat '%' specially, so a single '%' is
            # what prints a single percent sign.
            sys.stderr.write(' transferred {} of {}, {}% complete\n'.format(
                total_xferred,
                total_to_xfer,
                percent(total_xferred, total_to_xfer),
            ))


def xfer(verbose, mode, total_to_xfer, block_len, readfn, writefn, total_bytes):
    # pylint: disable=too-many-arguments
    '''Dispatch to the bounded or unbounded transfer loop.

    A total_to_xfer of 0 means "no limit" and selects xfer_forever; any other
    value selects xfer_to_maximum.
    '''
    if total_to_xfer != 0:
        xfer_to_maximum(verbose, mode, total_to_xfer, block_len, readfn, writefn, total_bytes)
        return
    xfer_forever(verbose, mode, block_len, readfn, writefn)


def read_from_0(length):
    '''Return up to length bytes read directly from file descriptor 0 (stdin)'''
    stdin_fd = 0
    return os.read(stdin_fd, length)


def write_to_1(data_buffer):
    '''Write all of data_buffer to file descriptor 1 (stdout).

    os.write() reports how many bytes it actually wrote, which can be fewer
    than requested, so the remainder is retried until the whole buffer has
    been written.  Returns None.
    '''
    # Slicing a memoryview yields another view of the same memory, so retrying
    # after a short write does not copy the unwritten tail.
    view = memoryview(data_buffer)
    buffer_length = len(view)
    offset = 0
    while True:
        offset += os.write(1, view[offset:])
        if offset >= buffer_length:
            return


def udpreadfn(length, sockin, verbose):
    '''Receive one datagram of up to length bytes from sockin and return its data.

    When verbose is true the sender's address is reported on stderr.
    '''
    payload, sender = sockin.recvfrom(length)
    if not verbose:
        return payload
    sys.stderr.write('Received UDP packet from {}\n'.format(sender))
    return payload


def udpwritefn(data_buffer, sockout, hostname, output_portno):
    '''Send data_buffer through sockout as a datagram addressed to (hostname, output_portno)'''
    destination = (hostname, output_portno)
    sockout.sendto(data_buffer, destination)


def noop(data_buffer):
    '''Discard data_buffer - the bit bucket writer used when no output is wanted'''
    del data_buffer


def type_of_socket(is_tcp):
    '''Pick the socket type: SOCK_STREAM (TCP) when is_tcp is true, else SOCK_DGRAM (UDP)'''
    return socket_mod.SOCK_STREAM if is_tcp else socket_mod.SOCK_DGRAM


def do_outgoing_socket(hostname, port, stream_or_dgram):
    '''
    Open and connect an outgoing socket, trying each address getaddrinfo reports
    (any address family) until one connects.  Returns the connected socket; if no
    candidate works, an error is written to stderr and the process exits with 1.
    Based on an example at https://docs.python.org/3/library/socket.html#socket.getaddrinfo
    '''
    candidates = socket_mod.getaddrinfo(hostname, port, socket_mod.AF_UNSPEC, stream_or_dgram)
    for address_family, socktype, proto, _canonname, sockaddr in candidates:
        try:
            connection = socket_mod.socket(address_family, socktype, proto)
        except OSError:
            # This kind of socket cannot be created here; try the next candidate
            continue
        try:
            connection.connect(sockaddr)
        except OSError:
            connection.close()
            continue
        return connection
    sys.stderr.write('{}: Could not open socket\n'.format(sys.argv[0]))
    sys.exit(1)


def main():
    '''Parse sys.argv, build the reader and writer callables, and run the transfer.

    Options are removed from sys.argv as they are parsed.  Exactly one input
    (-i or -I) and exactly one output (-o, -O or -n) must be selected;
    otherwise an error and the usage message are written and the process
    exits with status 1.  Once the transfer finishes, the process exits with
    status 0.
    '''
    # pylint: disable=too-many-statements,too-many-branches,too-many-locals

    # Ignore SIGPIPE (where it exists) so that a vanished reader shows up as
    # an EPIPE error from the write call instead of killing the process.
    if hasattr(signal, 'SIGPIPE'):
        signal.signal(signal.SIGPIPE, signal.SIG_IGN)

    # infinite by default
    total_bytes = 0

    # one megabyte by default
    blocksize = 1024 * 1024

    verbose = False
    use_stdio = False
    tcp_nodelay_flag = 0
    # Only becomes true when -N is given; until then the system default for
    # TCP_NODELAY is left alone.
    tcp_nodelay_selected = False
    default_tcp_window_size = 64 * 1024
    # -1 means "-w was not given"
    tcp_window_size = -1

    output_tcp = True
    input_tcp = True

    file_in = False
    file_out = False
    socket_in = False
    socket_out = False
    null_out = False

    # Each pass handles the option at sys.argv[1].  Options that take values
    # delete one element per value here, and the final del at the bottom of
    # the loop removes the remaining element.
    while sys.argv[1:]:
        if sys.argv[1] == '-b' and sys.argv[2:]:
            blocksize = int(sys.argv[2])
            del sys.argv[1]
        elif sys.argv[1] == '-i':
            file_in = True
        elif sys.argv[1] == '-I' and sys.argv[2:]:
            socket_in = True
            input_port_description = sys.argv[2]
            del sys.argv[1]
        elif sys.argv[1] == '-o':
            file_out = True
        elif sys.argv[1] == '-O' and sys.argv[3:]:
            socket_out = True
            hostname = getip(sys.argv[2])
            output_port_description = sys.argv[3]
            del sys.argv[1]
            del sys.argv[1]
        elif sys.argv[1] == '-n':
            null_out = True
        elif sys.argv[1] == '-N' and sys.argv[2:]:
            tcp_nodelay_selected = True
            tcp_nodelay_flag = int(sys.argv[2])
            del sys.argv[1]
        elif sys.argv[1] == '-t' and sys.argv[2:]:
            total_bytes = int(sys.argv[2])
            del sys.argv[1]
        elif sys.argv[1] in ['-h', '--help']:
            usage(0)
        elif sys.argv[1] == '-v':
            verbose = True
        elif sys.argv[1] == '-u':
            output_tcp = False
            input_tcp = False
        elif sys.argv[1] == '--O-udp':
            output_tcp = False
        elif sys.argv[1] == '--I-udp':
            input_tcp = False
        elif sys.argv[1] == '-w' and sys.argv[2:]:
            tcp_window_size = int(sys.argv[2])
            del sys.argv[1]
        else:
            sys.stderr.write('Illegal option: {}\n'.format(sys.argv[1]))
            usage(1)
        del sys.argv[1]

    # The booleans add up as 0/1, so these sums count the selected modes.
    if file_in + socket_in != 1:
        sys.stderr.write('You must specify exactly one of -i and -I (file_in is {}, socket_in is {})\n'.format(file_in, socket_in))
        usage(1)

    if file_out + socket_out + null_out != 1:
        sys.stderr.write('You must specify exactly one of -o, -O and -n\n')
        usage(1)

    # input_tcp is cleared by -u or --I-udp
    if not input_tcp and not (socket_in or socket_out):
        sys.stderr.write('-u requires -I and/or -O\n--I-udp requires -I\n')
        usage(1)

    # output_tcp is cleared by -u or --O-udp
    if not output_tcp and not (socket_in or socket_out):
        sys.stderr.write('-u requires -I and/or -O\n--O-udp requires -O\n')
        usage(1)

    if tcp_window_size != -1 and not input_tcp and not output_tcp:
        sys.stderr.write('-u and -w conflict\n')
        usage(1)

    if tcp_window_size == -1:
        tcp_window_size = default_tcp_window_size

    if socket_in:
        input_portno = portno(input_port_description, input_tcp)

    if socket_out:
        output_portno = portno(output_port_description, output_tcp)

    if file_in:
        if use_stdio:
            readfn = sys.stdin.read
        else:
            readfn = read_from_0
        mode = 'input from stdin'
    elif socket_in:
        input_socket = socket_mod.socket(socket_mod.AF_INET, type_of_socket(input_tcp))
        input_socket.setsockopt(socket_mod.SOL_SOCKET, socket_mod.SO_REUSEADDR, 1)
        # Binding to the machine's own hostname would only bind to its primary
        # address; '' is supposed to mean "bind to all interfaces"
        input_socket.bind(('', input_portno))
        if input_tcp:
            input_socket.setsockopt(socket_mod.SOL_SOCKET, socket_mod.SO_RCVBUF, tcp_window_size)
            if tcp_nodelay_selected:
                sys.stderr.write('Setting TCP_NODELAY to {}\n'.format(tcp_nodelay_flag))
                input_socket.setsockopt(socket_mod.SOL_TCP, socket_mod.TCP_NODELAY, tcp_nodelay_flag)
            else:
                sys.stderr.write('Accepting system default for TCP_NODELAY\n')
            input_socket.listen(0)
            if verbose:
                sys.stderr.write('Waiting for connection...\n')
            input_connection, (remotehost, remoteport) = input_socket.accept()
            if verbose:
                sys.stderr.write('Received connection from {} {}\n'.format(remotehost, remoteport))
            readfn = input_connection.recv
        else:
            readfn = functools.partial(udpreadfn, sockin=input_socket, verbose=verbose)
        mode = 'input from socket'
    else:
        sys.stderr.write('{}: Weird 1\n'.format(sys.argv[0]))
        sys.exit(1)

    if file_out:
        if use_stdio:
            writefn = sys.stdout.write
        else:
            writefn = write_to_1
        mode += ', output to stdout'
    elif socket_out:
        if output_tcp:
            output_socket = do_outgoing_socket(hostname, output_portno, socket_mod.SOCK_STREAM)
            # send() may transmit only part of a block and the transfer loops
            # do not look at its return value; sendall() keeps going until the
            # whole block is sent.
            writefn = output_socket.sendall
            # this one's assumed to only be appropriate for TCP, though there may be other reliable types someday that want this too
            output_socket.setsockopt(socket_mod.SOL_SOCKET, socket_mod.SO_SNDBUF, tcp_window_size)
            if tcp_nodelay_selected:
                sys.stderr.write('Setting TCP_NODELAY to {}\n'.format(tcp_nodelay_flag))
                output_socket.setsockopt(socket_mod.SOL_TCP, socket_mod.TCP_NODELAY, tcp_nodelay_flag)
            else:
                sys.stderr.write('Accepting system default for TCP_NODELAY\n')
        else:
            output_socket = do_outgoing_socket(hostname, output_portno, socket_mod.SOCK_DGRAM)
            writefn = functools.partial(udpwritefn, sockout=output_socket, hostname=hostname, output_portno=output_portno)
        mode += ', output to {} socket'.format(protocol_string(output_tcp))
    elif null_out:
        writefn = noop
        mode += ', no output'
    else:
        sys.stderr.write('{}: Weird 2\n'.format(sys.argv[0]))
        sys.exit(1)

    if verbose:
        sys.stderr.write('{} starting up: {}\n'.format(sys.argv[0], mode))

    # A total_bytes of 0 makes xfer copy until end of input
    xfer(verbose, mode, total_bytes, blocksize, readfn, writefn, total_bytes)

    if verbose:
        sys.stderr.write('Terminating normally: {}\n'.format(mode))

    sys.exit(0)


# Run the program only when executed as a script, so importing this module
# does not start a transfer.
if __name__ == '__main__':
    main()