#!/usr/bin/env python3

"""Use dd and gprog to more-safely clone a partition with progress."""

# TODO:
# + 1) Might someday get some use out of: findmnt --fstab --evaluate --json
#      It should help tell what should be mounted where, according to the fstab.
# + 2) Yes/no question should make it clear that you can answer y or n.
# + 3) Use parted --machine; see comment in get_partition_size().
# - 4) Use fsent_mod instead of findmnt --fstab

import os
import re
import subprocess
import sys
import typing

sys.path.insert(0, '/usr/local/lib')

import fsent_mod  # noqa: E402


def usage(retval: int) -> None:
    """
    Print the command-line help text and terminate the program.

    The text goes to stderr when retval is nonzero (an error exit) and to stdout otherwise; the process
    then exits with retval as its status.
    """
    stream = sys.stdout if retval == 0 else sys.stderr

    help_lines = [
        'Usage: {}'.format(sys.argv[0]),
        '    --source-partition /dev/sdb1',
        '    --allow-mounted-source',
        '    --destination-partition /dev/sdd1',
        '    --dry-run',
        '    --help',
        '',
        'Clone a partition to another partition on the same computer.',
        '',
        'Consider using fsck on your source partition before running this script.',
        'On XFS, that might look like: xfs_repair /dev/sdb1',
        '',
        'Also, you may or may not want to generate a new UUID for your destination',
        'filesystem after running this script.  On XFS, that might look like:',
        'xfs_admin -U generate /dev/sdd1',
    ]
    for text in help_lines:
        stream.write(text + '\n')

    sys.exit(retval)


def get_partition_size(partition: bytes) -> int:
    """
    Look up and return the number of bytes in a given partition.

    Runs ``parted --machine <partition> unit B print``. The size is read from the line whose first
    colon-separated field is the partition itself. Sample output:

        $ parted --machine /dev/sda1 unit B print
        BYT;
        /dev/sda1:6001173463040B:unknown:512:4096:loop:Unknown:;
        1:0B:6001173463039B:6001173463040B:xfs::;

    :param partition: device file path, as bytes (e.g. b'/dev/sdb1').
    :returns: the partition length in bytes.

    If parted does not report the partition, an error is written to stderr and the program exits
    with status 1.
    """
    with subprocess.Popen(['parted', '--machine', partition, 'unit', 'B', 'print'], stdout=subprocess.PIPE) as proc:
        assert proc.stdout is not None
        lines = proc.stdout.readlines()
    for line in lines:
        fields = line.split(b':')
        # The device line looks like "<partition>:<size>B:..."; require the size field to be present.
        if len(fields) > 1 and fields[0] == partition:
            length_field_bytes = fields[1].rstrip(b'B')
            break
    else:
        # This is an error, so report it on stderr like the script's other diagnostics.
        os.write(2, b'%s: partition %s not found by parted\n' % (sys.argv[0].encode('UTF-8'), partition))
        sys.exit(1)
    length_field_str = length_field_bytes.decode('ISO-8859-1')
    return int(length_field_str)


class OneFstabEntry:
    """Simple record holding the source, target and filesystem type of a single fstab entry."""

    def __init__(self, *, source, target, fstype):
        """Store the three fstab fields, which must be passed as keyword arguments."""
        self.source, self.target, self.fstype = source, target, fstype

    def __repr__(self):
        """Return a debugging view: source, target and fstype separated by ' - '."""
        return ' - '.join(format(part) for part in (self.source, self.target, self.fstype))

    def __str__(self):
        """Return a human-readable description naming the mount target and its filesystem type."""
        return '{target} of type {fstype} in fstab'.format(target=self.target, fstype=self.fstype)


def get_mount_points_from_fstab() -> typing.DefaultDict[bytes, typing.List[bytes]]:
    """
    Return a defaultdict mapping device files to their mount points.

    The mapping comes from fsent_mod.PartitionInfo built on /etc/fstab. That helper does more than
    a literal reading of the fstab, so the result may include information beyond the file itself.
    """
    return fsent_mod.PartitionInfo(b'/etc/fstab').mount_point_from_device_path


def check_partition_form(path: bytes, description: bytes) -> bytes:
    """
    Check whether partition path is of the form /dev/sd[a-m]<partition number>.

    The entire path must match: trailing characters, including a newline, are rejected.
    Multi-digit partition numbers (e.g. /dev/sdb12) are accepted.

    Note that other operating systems with /dev or similar can be added here with minimal fuss; we make no further
    assumptions about the general form of a device file.

    :param path: candidate device file path, as bytes.
    :param description: label such as b'source' or b'destination', used in the error message.
    :returns: path, unchanged, when it has the accepted form.

    Otherwise, an error is written to stderr and usage(1) is called, which exits the program.
    """
    # fullmatch rather than match('^...$'): '$' would also match just before a trailing newline.
    if re.fullmatch(rb'/dev/sd[a-m][0-9]+', path) is None:
        os.write(2, b'%s: %s partition %s not of the correct form. Are you on Linux?\n' % (
            sys.argv[0].encode('UTF-8'), description, path),
        )
        usage(1)
    return path


def partition_mounted(partition: bytes) -> bool:
    """
    Return True iff partition is mounted.

    Runs ``mount`` with no arguments. The partition counts as mounted if it equals the first
    whitespace-separated field of any output line.

    :param partition: device file path, as bytes (e.g. b'/dev/sdb1').
    """
    with subprocess.Popen(['mount'], stdout=subprocess.PIPE) as proc:
        assert proc.stdout is not None
        lines = proc.stdout.readlines()
    for line in lines:
        fields = line.split()
        # Blank or whitespace-only lines have no first field; skip them instead of raising IndexError.
        if fields and fields[0] == partition:
            return True
    return False


def _parse_args(
    argv: typing.List[str],
) -> typing.Tuple[typing.Optional[bytes], typing.Optional[bytes], bool, bool]:
    """
    Parse command-line options; argv excludes the program name.

    :returns: (source_partition, destination_partition, dry_run, allow_mounted_source). Each
        partition is None if its option was not given.

    Calls usage(), which exits, for --help/-h, for an unknown option, or for an option that is
    missing its partition argument.
    """
    source_partition = None
    destination_partition = None
    dry_run = False
    allow_mounted_source = False

    remaining = list(argv)
    while remaining:
        option = remaining.pop(0)
        if option in ('--source-partition', '--destination-partition'):
            if not remaining:
                os.write(2, b'%s: option %s requires a partition argument\n' % (
                    sys.argv[0].encode('UTF-8'), option.encode('UTF-8')))
                usage(1)
            value = remaining.pop(0).encode('UTF-8')
            if option == '--source-partition':
                source_partition = check_partition_form(value, b'source')
            else:
                destination_partition = check_partition_form(value, b'destination')
        elif option == '--dry-run':
            dry_run = True
        elif option == '--allow-mounted-source':
            allow_mounted_source = True
        elif option in ('--help', '-h'):
            usage(0)
        else:
            os.write(2, b'%s: unrecognized option: %s\n' % (sys.argv[0].encode('UTF-8'), option.encode('UTF-8')))
            usage(1)

    return source_partition, destination_partition, dry_run, allow_mounted_source


def _fstab_description(fstab_dict: typing.DefaultDict[bytes, typing.List[bytes]], partition: bytes) -> bytes:
    """Return the str() of each fstab_dict entry for partition, UTF-8 encoded and joined by ', '."""
    return b', '.join(str(fsent).encode('UTF-8') for fsent in fstab_dict[partition])


def _confirm_or_exit(
    command: bytes,
    source_partition: bytes,
    destination_partition: bytes,
    fstab_dict: typing.DefaultDict[bytes, typing.List[bytes]],
) -> None:
    """
    Describe the pending copy and ask the user to confirm it.

    Returns when the answer starts with 'y' (case-insensitive, surrounding whitespace ignored).
    Exits with status 1 when the answer starts with 'n', or when stdin reaches end of input.
    Re-prompts on any other answer.
    """
    while True:
        os.write(1, b'Preflight check passed.\n')
        os.write(1, b'About to run %s\n\n' % (command, ))
        os.write(1, b'This command will erase all data in your %s partition (%s),\n' % (
            destination_partition,
            _fstab_description(fstab_dict, destination_partition),
        ))
        os.write(1, b'replacing it with what is currently in %s (%s)!\n' % (
            source_partition,
            _fstab_description(fstab_dict, source_partition),
        ))
        os.write(1, b'\n')
        os.write(1, b'Are you completely sure you want to copy all of %s to %s (y/n)?' % (source_partition, destination_partition))

        line = sys.stdin.readline()
        if not line:
            # readline() returns '' only at end of input; without this check the prompt would loop forever.
            os.write(2, b'\nNo answer received (end of input).  Exiting without copying.\n')
            sys.exit(1)
        answer = line.strip().lower()
        if answer.startswith('y'):
            return
        if answer.startswith('n'):
            os.write(1, b'Exiting by request of user.\n')
            sys.exit(1)
        os.write(1, b'Please enter "y" or "n" (without the quotes)\n')


def main() -> None:
    """
    Kick off the process.

    The steps are:
    1. Parse the options.
    2. Run the early preflight checks: root privileges, both partitions given, and distinct partitions.
    3. Run the late preflight checks: mount state, relative sizes, and a whole-megabyte source size.
    4. Ask the user to confirm.
    5. Run the dd | gprog pipeline, unless --dry-run was given.

    Exits with status 0 on a successful copy, and 1 on any failed check, user refusal or copy failure.
    """
    source_partition, destination_partition, dry_run, allow_mounted_source = _parse_args(sys.argv[1:])

    all_good = True

    if os.geteuid() != 0:
        print('Please run me as root instead', file=sys.stderr)
        all_good = False

    fstab_dict = get_mount_points_from_fstab()

    if source_partition is None:
        print('--source-partition partition is required.', file=sys.stderr)
        all_good = False

    if destination_partition is None:
        print('--destination-partition partition is required.', file=sys.stderr)
        all_good = False

    # Only compare when a partition was given; otherwise None == None would raise a bogus alarm.
    if source_partition is not None and source_partition == destination_partition:
        print('Danger. Do not try to copy a partition to itself!', file=sys.stderr)
        all_good = False

    if not all_good:
        print('One or more items of the early-stage preflight check failed.  Please try again.\n', file=sys.stderr)
        usage(1)

    # These asserts are to keep mypy happy; usage(1) above has already exited if either is None.
    assert source_partition is not None
    assert destination_partition is not None

    if not allow_mounted_source and partition_mounted(source_partition):
        os.write(2, b'source partition %s is mounted. Please umount it and try again\n' % (source_partition, ))
        all_good = False

    if partition_mounted(destination_partition):
        os.write(
            2,
            b'destination partition %s is mounted. Please umount it and try again\n' % (destination_partition, ),
        )
        all_good = False

    source_partition_length = get_partition_size(source_partition)
    destination_partition_length = get_partition_size(destination_partition)

    if source_partition_length == destination_partition_length:
        os.write(1, b'Good %s and %s have the same length.  You should not need to resize.\n' % (
            source_partition,
            destination_partition,
        ))
    elif source_partition_length < destination_partition_length:
        os.write(1, b"OK, %s is smaller than %s.  You'll probably want to resize %s after the copy.\n" % (
            source_partition,
            destination_partition,
            destination_partition,
        ))
    else:
        os.write(
            2,
            b'%s: %s is too small.  Exiting without copying.\n' %
            (sys.argv[0].encode('UTF-8'), destination_partition),
        )
        all_good = False

    # dd copies in 1 MiB blocks, so the source length must be a whole number of MiB.
    source_partition_length_in_meg, remainder = divmod(source_partition_length, 2**20)
    if remainder != 0:
        os.write(
            2,
            b"Sorry, source partition %s's length is not evenly divisible by 1 megabyte.\n" % (source_partition, ),
        )
        print('Exiting without copying.', file=sys.stderr)
        all_good = False

    # A shell is needed for the pipe; the partition paths were validated by check_partition_form().
    command = b'dd bs=1024k count="%d" if="%b" | gprog --size-estimate "%d" > "%b"' % (
        source_partition_length_in_meg,
        source_partition,
        source_partition_length,
        destination_partition,
    )

    if not all_good:
        print('One or more items in the late-stage preflight check failed.  Please try again.\n', file=sys.stderr)
        usage(1)

    _confirm_or_exit(command, source_partition, destination_partition, fstab_dict)

    if dry_run:
        os.write(1, b"I would have run it, but you specified --dry-run\n")
        return

    process = subprocess.run(command, shell=True)

    if process.returncode == 0:
        os.write(1, b'Happy completion :)\n')
        sys.exit(0)

    os.write(2, b'Sad completion :(\n')
    os.write(2, str(process).encode('UTF-8'))
    os.write(2, b'\n')
    sys.exit(1)


if __name__ == '__main__':
    # Run the command-line entry point only when executed as a script, not when imported.
    main()