#!/usr/bin/env python3

"""Set up passwordless ssh on 1..n hosts, specified using a UUCP-like bang path."""

import pprint
import shlex
import sys

sys.path.insert(0, '/usr/local/lib')  # noqa: ignore=E402
import deep_ssh


def usage(retval):
    """
    Output a usage message and exit.

    :param retval: exit status handed to sys.exit().  When truthy (a usage
        error) the message goes to stderr; otherwise (e.g. --help) it goes
        to stdout.
    """
    # Bind the stream's write method: the stream object itself is not callable.
    write = sys.stderr.write if retval else sys.stdout.write
    prog = sys.argv[0]

    write("Usage: {} --bang-path host1!host2\n".format(prog))
    write("\n")
    write("EG:\n")
    write("    {} --bang-path host1\n".format(prog))
    write("    {} --bang-path user1@host1\n".format(prog))
    write("    {} --bang-path user1@host1!host2\n".format(prog))
    write("    {} --bang-path user1@host1!user2@host2\n".format(prog))
    write("(The assumption above is that you can only reach host2 from host1)\n")
    write("\n")
    write("In each case, you enter requested passwords - two hosts joined with a ! can request up to 2 valid\n")
    write("passwords.\n")
    write('\n')
    write("Note that if you have your public key on host A, and you hop through B to install the key on host\n")
    write("C, C will get the public key for A, not B!  Sometimes that's what you want, but certainly not always.\n")

    sys.exit(retval)


def yield_pairs(bang_path):
    """
    Yield up the (left, right) chain pairs of our bang path, one per hop.

    For each pair, main() reads the public key via the left chain and
    installs it on the host at the end of the right chain.  The first left
    chain is always 'localhost'; each right chain becomes the next left.

    If we have abc!def!ghi, then we yield:
    (localhost, abc)
    (abc, abc!def)
    (abc!def, abc!def!ghi)

    :param bang_path: '!'-separated hops, e.g. 'user1@host1!host2'.
    """
    hops = bang_path.split('!')
    left = 'localhost'
    # Grow the right-hand chain one hop at a time.
    for hop_count in range(1, len(hops) + 1):
        right = '!'.join(hops[:hop_count])
        yield (left, right)
        left = right


def _parse_args(argv):
    """
    Parse the command line and return the --bang-path value.

    Exits via usage() on --help/-h, on an unrecognized option, on a
    --bang-path with no value, when --bang-path is absent, or when the bang
    path contains an empty hop (e.g. 'a!!b').

    :param argv: full argument vector, program name first (like sys.argv).
    """
    prog = argv[0]
    bang_path = ''
    args = list(argv[1:])

    while args:
        option = args.pop(0)
        if option == '--bang-path':
            if not args:
                sys.stderr.write('{}: --bang-path requires an argument\n'.format(prog))
                usage(1)
            bang_path = args.pop(0)
        elif option in ('--help', '-h'):
            usage(0)
        else:
            sys.stderr.write("{}: Unrecognized option: {}\n".format(prog, option))
            usage(1)

    if bang_path == '':
        sys.stderr.write('{}: --bang-path is a required option\n'.format(prog))
        usage(1)

    if '' in bang_path.split('!'):
        sys.stderr.write('{}: --bang-path contains an empty hop: {}\n'.format(prog, bang_path))
        usage(1)

    return bang_path


def _fetch_public_key(left_path):
    """
    Return the first line of ~/.ssh/id_rsa.pub on the host reached via left_path.

    Terminates the program if the remote command fails, if its output is not
    exactly two lines ($USER, then the key line), or if the key line has
    fewer than two whitespace-separated fields.
    """
    # This looks up the remote user name and public key line.
    left_shell_command = '''
set -eu
# Using $USER has the advantage of working on nearly everything.
echo "$USER"
# Get the public key data
head -1 < ~/.ssh/id_rsa.pub
'''
    # For this one, we want the output to go to stderr, but we want to get stdout back into line_list.
    print("ssh'ing to left_path: {}".format(left_path))
    left_file = deep_ssh.handle(optional_opts='', chain=left_path, command=left_shell_command, popen=True, echo_chain=False)

    line_list = [line.strip() for line in left_file]

    retval = left_file.close()

    if retval is not None:
        sys.stderr.write('{}: left_path ({}) gave a weird exit code: {}\n'.format(sys.argv[0], left_path, retval))
        sys.stderr.write('Terminating unsuccessfully\n')
        sys.exit(1)

    if len(line_list) != 2:
        sys.stderr.write('{}: left_path ({}) gave weird output:\n'.format(sys.argv[0], left_path))
        sys.stderr.write(pprint.pformat(line_list) + '\n')
        sys.stderr.write('Terminating unsuccessfully\n')
        sys.exit(1)

    public_key_str = line_list[1]
    # A usable key line is at least "<type> <base64-body>"; the comment is optional.
    if len(public_key_str.split()) < 2:
        sys.stderr.write('{}: left_path ({}) gave an unrecognized public key line: {!r}\n'.format(
            sys.argv[0], left_path, public_key_str))
        sys.stderr.write('Terminating unsuccessfully\n')
        sys.exit(1)

    return public_key_str


def _build_install_command(public_key_str):
    """
    Return a remote shell script that installs public_key_str in ~/.ssh/authorized_keys.

    Existing lines with the same key body (field 2), or, when the key has a
    comment, the same comment (field 3), are removed first, and then the key
    is appended.  Re-running the script therefore replaces the entry rather
    than duplicating it.

    :param public_key_str: an OpenSSH public key line with at least two fields.
    """
    fields = public_key_str.split()
    key_body = fields[1]
    key_comment = fields[2] if len(fields) > 2 else ''
    # Values are shell-quoted and handed to awk via the environment, so quotes or
    # other special characters in the key line cannot break the script.
    return '''
set -eu
# First-time setup: without these the authorized_keys redirect below fails under set -e.
mkdir -p -m 700 ~/.ssh
touch ~/.ssh/authorized_keys
key_body={body}
key_comment={comment}
export key_body key_comment
temp=~/.ssh/authorized_keys.temp
touch "$temp"
chmod 600 "$temp"
awk '$2 != ENVIRON["key_body"] && (ENVIRON["key_comment"] == "" || $3 != ENVIRON["key_comment"])' < ~/.ssh/authorized_keys > "$temp"
printf '%s\\n' {line} >> "$temp"
mv "$temp" ~/.ssh/authorized_keys
'''.format(body=shlex.quote(key_body), comment=shlex.quote(key_comment), line=shlex.quote(public_key_str))


def _install_public_key(right_path, public_key_str):
    """Install public_key_str on the host reached via right_path, exiting nonzero on failure."""
    right_shell_command = _build_install_command(public_key_str)

    # For this one, we just let the output go to stdout/stderr
    print("ssh'ing to right_path: {}".format(right_path))
    retval = deep_ssh.handle(optional_opts='', chain=right_path, command=right_shell_command, popen=False, echo_chain=False)

    if retval:
        sys.stderr.write('{}: right_shell_command failed on right_path ({})\n'.format(sys.argv[0], right_path))
        sys.stderr.write('Terminating unsuccessfully\n')
        # NOTE(review): retval may be a raw wait status (e.g. 256 for exit code 1),
        # which the OS would truncate to 0; only pass through values that fit an
        # exit status, otherwise exit 1.  Confirm deep_ssh's return convention.
        exit_code = retval if isinstance(retval, int) and 0 < retval < 256 else 1
        sys.exit(exit_code)


def main():
    """Set up passwordless ssh across a bang path, one hop at a time."""
    bang_path = _parse_args(sys.argv)

    for left_path, right_path in yield_pairs(bang_path):
        public_key_str = _fetch_public_key(left_path)
        _install_public_key(right_path, public_key_str)
        print()


# Only run when executed as a script, so the module can be imported without side effects.
if __name__ == '__main__':
    main()