#!/usr/bin/python3

# pylint: disable=superfluous-parens
# superfluous-parens: Parentheses are good for clarity and portability

"""
Compute the maximum TCP window size for the system we're running on.

First find an upper bound by trying successive powers of 2 until one is
rejected. Then binary-search between the last two powers of 2 to find
the exact value.
"""

import sys
import errno
import socket


def powers_of_2():
    """Yield 1, 2, 4, 8, ... indefinitely (an infinite generator)."""
    exponent = 0
    while True:
        yield 1 << exponent
        exponent += 1


class Window_test(object):
    # pylint: disable=too-few-public-methods
    # too-few-public-methods: We don't need a lot of public methods
    """Test setting a TCP window (SO_RCVBUF) of a given size.

    One TCP socket is created in the constructor and reused for every
    probe made through window_size_works().  Call close() (or use the
    object as a context manager) to release the socket.
    """

    # errno values meaning "the kernel refused this buffer size", as opposed
    # to a genuine failure.  EBADMSG is the value this script originally
    # checked for; ENOBUFS is believed to be what BSD-derived kernels
    # (e.g. macOS) return when the size exceeds their limit -- TODO confirm
    # on each target platform.
    REJECTION_ERRNOS = (errno.EBADMSG, errno.ENOBUFS)

    def __init__(self, port=12346):
        """Create the probe socket.

        port: local port to bind to (default 12346, as before).  Pass None
        to skip binding; binding is presumably not needed just to set
        SO_RCVBUF, and skipping it avoids clashing with another process.
        """
        self.socket_type = socket.SOCK_STREAM
        self.input_socket = socket.socket(socket.AF_INET, self.socket_type)
        try:
            self.input_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            if port is not None:
                self.input_socket.bind(('', port))
        except BaseException:
            # Don't leak the descriptor if setup fails (e.g. port in use).
            self.input_socket.close()
            raise

    def close(self):
        """Close the probe socket."""
        self.input_socket.close()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    def window_size_works(self, tcp_window_size):
        """Return True iff tcp_window_size is accepted as SO_RCVBUF.

        Returns False when the kernel rejects the size with one of
        REJECTION_ERRNOS, or when the value cannot be passed to setsockopt
        as a C int.  Any other socket error is reported on stderr and the
        process exits with status 1.
        """
        try:
            self.input_socket.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, tcp_window_size)
        except socket.error as err:
            # Compare err.errno, not err itself: an exception object never
            # equals an int, which previously made this branch unreachable.
            if err.errno in self.REJECTION_ERRNOS:
                return False
            string = '%s: unexpected error %s setting TCP window size to %d\n'
            message_tuple = (sys.argv[0], err, tcp_window_size)
            sys.stderr.write(string % message_tuple)
            sys.exit(1)
        except (TypeError, OverflowError):
            # Sizes >= 0x80000000 don't fit a C int; CPython's setsockopt
            # then raises TypeError (OverflowError caught defensively).
            return False
        return True


def bisect(window_test, lower_bound, upper_bound):
    """Binary-search for the largest accepted window size.

    Keeps lower_bound as the largest size known to work and upper_bound as
    the smallest size known (or assumed) to fail.  It narrows the interval
    until the midpoint coincides with one of the ends, then returns that
    midpoint.
    """
    low, high = lower_bound, upper_bound
    while True:
        sys.stderr.write('Checking between {} and {}\n'.format(low, high))
        middle = (low + high) // 2
        if middle in (low, high):
            return middle
        if window_test.window_size_works(middle):
            low = middle
        else:
            high = middle


def main():
    """Find and print the maximum TCP window size this system accepts.

    Doubles the candidate size until one is rejected, then binary-searches
    between that size and the previous power of 2.  Progress goes to
    stderr; the result is printed to stdout in decimal and hex.
    """
    window_test = Window_test()
    try:
        # Initialised only to keep linters quiet: powers_of_2() never ends,
        # so the loop always binds tcp_window_size at least once.
        tcp_window_size = 0
        for tcp_window_size in powers_of_2():
            sys.stderr.write('Checking {}\n'.format(tcp_window_size))
            if not window_test.window_size_works(tcp_window_size):
                break
        # tcp_window_size is the first rejected power of 2; half of it is the
        # last accepted one (0 if even a size of 1 was rejected).
        lower_bound = tcp_window_size // 2
        upper_bound = tcp_window_size
        actual_value = bisect(window_test, lower_bound, upper_bound)
    finally:
        # Release the probe socket even if a probe calls sys.exit().
        window_test.input_socket.close()
    print('{} == {}'.format(actual_value, hex(actual_value)))


# Only probe when run as a script, not when imported (e.g. by tests).
if __name__ == '__main__':
    main()