#!/usr/bin/python3

"""Allow reading fstab and blkid's."""

import collections
import os
import pprint
import re
import subprocess
import sys
import typing


class Fsent:
    """One fstab entry: the six whitespace-separated fields of a line, kept as bytes."""

    def __init__(self, device_spec: bytes, mount_point: bytes, fstype: bytes, options: bytes, dump: bytes, pass_: bytes) -> None:
        """Store the fields verbatim, except that double quotes are dropped from device_spec (UUID="x" -> UUID=x)."""
        self.device_spec = device_spec.replace(b'"', b'')
        self.mount_point = mount_point
        self.fstype = fstype
        self.options = options
        self.dump = dump
        self.pass_ = pass_

    def __str__(self):
        """Return the six fields, each rendered as a bytes literal, separated by ' - '."""
        fields = (self.device_spec, self.mount_point, self.fstype, self.options, self.dump, self.pass_)
        return ' - '.join(str(field) for field in fields)

    __repr__ = __str__


def getfsents(file_: typing.BinaryIO) -> typing.List[Fsent]:
    """
    Parse fstab-format lines from a binary file into a list of Fsent objects.

    Text from '#' to the end of a line is treated as a comment and dropped; lines left empty are skipped.
    Lines that do not have exactly six whitespace-separated fields trigger a warning on stderr and are skipped.
    """
    entries = []
    for number, raw_line in enumerate(file_, start=1):
        without_newline = raw_line.rstrip(b'\n')
        fields = re.sub(br'#.*$', b'', without_newline).split()
        if not fields:
            continue
        if len(fields) == 6:
            entries.append(Fsent(*fields))
        else:
            print(f'Warning: fstab line {number} does not contain 6 fields', file=sys.stderr)
    return entries


class Blkid:
    """Hold one line of blkid output: the device path plus its NAME="value" attributes."""

    def __init__(self, line: bytes) -> None:
        """
        Parse one line of blkid output, such as b'/dev/sda1: UUID="..." TYPE="ext4"'.

        :param line: a single output line; a trailing newline is ignored.
        :raises ValueError: if the line has no ':' separating the device path from its attributes.
        """
        device_spec, separator, attrs_string = line.rstrip(b'\n').partition(b':')
        if not separator:
            # Validate explicitly rather than with assert, which "python -O" strips.
            raise ValueError(b'Odd input from blkid (1): %s' % (line, ))
        self.device_spec = device_spec
        self.attrs_dict: typing.Dict[bytes, bytes] = {}
        # re.finditer yields only successful matches, so there is no "no match" object to check for.
        # The value group [^"]* excludes the quotes, so values need no further stripping.
        for match in re.finditer(br'\s*([\w]*)="([^"]*)"', attrs_string):
            self.attrs_dict[match.group(1)] = match.group(2)

    def __str__(self) -> str:
        """Return 'DEVICE: NAME: value NAME: value ...', decoded as UTF-8."""
        joined_bytes = b' '.join(b'%s: %s' % (attr_name, attr_value) for attr_name, attr_value in self.attrs_dict.items())
        bytes_result = (b'%s: ' % self.device_spec) + joined_bytes
        return bytes_result.decode('UTF-8')

    __repr__ = __str__


def get_blkids() -> typing.List[Blkid]:
    """
    Run the blkid command and parse each line of its standard output into a Blkid.

    Example output (blkid prints each device on a single line; the /dev/sdd1 line is wrapped here only to
    keep this docstring within the line-length limit):

        $ blkid
        /dev/sda1: UUID="401d5c7d-5486-4278-8a3b-d74247f55dd6" TYPE="ext4" PARTUUID="2e95eb85-7548-44ff-a9ed-2851461312cf"
        /dev/sdc1: UUID="0d6cb13d-9e6f-4787-a63e-a2ce4dd6b1d6" TYPE="ext4" PARTUUID="8e590aa5-01"
        /dev/sdd1: UUID="ad63caa5-30a9-452a-b0de-c8881e32ba0b" TYPE="xfs" PARTLABEL="primary"
              PARTUUID="1c594d54-359e-4973-bffe-b1c4cce52eb1"
        /dev/sdb1: UUID="ceec2609-86a8-4e08-b0b7-f0866537fbf5" TYPE="xfs" PARTUUID="74805800-01"
    """
    # The with-block waits for blkid to exit before we parse anything.
    with subprocess.Popen(['blkid'], stdout=subprocess.PIPE) as blkid_process:
        assert blkid_process.stdout is not None
        output_lines = blkid_process.stdout.readlines()
    return [Blkid(output_line) for output_line in output_lines]


class PartitionInfo:
    """Get partition data, scrutinize it, and make it available to the caller."""

    def __init__(self, filename: bytes):
        """
        Read the fstab at filename and the blkid output, check both for duplicates, and map devices to mount points.

        :param filename: path of the fstab file to read.
        :raises ValueError: on hard-coded /dev/ device specs, or on duplicates in fstab or in the blkid output.
        :raises AssertionError: if one device's UUID appears in more than one fstab entry.
        """
        self.fstab_filename = filename
        with open(self.fstab_filename, 'rb') as file_:
            self.fsents = getfsents(file_)
        self.check_fsents_for_duplicates()
        self.blkids = get_blkids()
        self.check_blkids_for_duplicates()
        self.build_mount_point_from_device()

    def build_mount_point_from_device(self) -> None:
        """
        Build a dictionary self.mount_point_from_device_path, that maps device paths (EG /dev/sda1) to mount points (EG /).

        We get the UUID's in by_blkid_device_spec using the device path, and then get the fsent by UUID in by_fsent_uuid.
        Devices with no UUID, or whose UUID has no fstab entry, are reported on stdout and skipped.

        :raises AssertionError: if more than one fstab entry matches a device's UUID.
        """
        self.mount_point_from_device_path: typing.DefaultDict[bytes, typing.List[bytes]] = collections.defaultdict(list)
        for device_path, blkids in self.by_blkid_device_spec.items():
            blkid = blkids[0]
            # We prefer a normal UUID, but we're willing to use a partition with no UUID yet.
            uuid: typing.Optional[bytes] = blkid.attrs_dict.get(b'UUID')
            if uuid is None:
                fsents: typing.List[Fsent] = []
            else:
                # .get() rather than [] so the lookup does not insert empty lists into the defaultdict.
                fsents = self.by_fsent_uuid.get(uuid, [])
            if not fsents:
                # bytes %-formatting rejects None, so show a placeholder for UUID-less devices.
                uuid_shown = b'(none)' if uuid is None else uuid
                os.write(1, b'No mountpoint for UUID %s in fstab. Continuing.\n' % (uuid_shown, ))
                continue
            # If there's more than one mount point for a given device file, error
            if len(fsents) != 1:
                raise AssertionError(f"len(fsents) ({len(fsents)}) is not 1: {fsents}")
            fsent = fsents[0]
            # It's actually the mount point that we want.  UUID in fsent is unimportant.
            self.mount_point_from_device_path[device_path].append(fsent.mount_point)

    def check_fsents_for_duplicates(self) -> None:
        """
        Error out if there is a hardcoded /dev/ device, duplicate mount point or duplicate device spec in fstab.

        Also builds the by_fsent_device_spec, by_fsent_uuid and by_mount_point indexes.

        :raises ValueError: on any of the problems above.
        """
        self.by_fsent_device_spec: typing.DefaultDict[bytes, typing.List[Fsent]] = collections.defaultdict(list)
        for fsent in self.fsents:
            if fsent.device_spec.startswith(b'/dev/'):
                raise ValueError(b'Error: %s contains one or more hardcoded device files' % (self.fstab_filename, ))
            self.by_fsent_device_spec[fsent.device_spec].append(fsent)
        for device_spec, ddict in self.by_fsent_device_spec.items():
            if len(ddict) != 1:
                device_spec_str = device_spec.decode('UTF-8')
                ddict_str = str(ddict)
                raise ValueError(f'duplicate fstab device spec found: {device_spec_str}, {ddict_str}')

        self.by_fsent_uuid: typing.DefaultDict[bytes, typing.List[Fsent]] = collections.defaultdict(list)
        for fsent in self.fsents:
            if fsent.device_spec.startswith(b'UUID='):
                part = fsent.device_spec.partition(b'=')
                assert part[0] == b'UUID', f'part[0] is {str(part[0])}'
                assert part[1] == b'='
                uuid = part[2].strip(b'"')
                self.by_fsent_uuid[uuid].append(fsent)

        self.by_mount_point: typing.DefaultDict[bytes, typing.List[Fsent]] = collections.defaultdict(list)
        for fsent in self.fsents:
            self.by_mount_point[fsent.mount_point].append(fsent)
        for mount_point, ddict in self.by_mount_point.items():
            if len(ddict) != 1:
                str_ddict = str(ddict)
                bytes_ddict = str_ddict.encode('UTF-8')
                raise ValueError(b'duplicate fstab mount point found: %s, %s' % (mount_point, bytes_ddict))

    def check_blkids_for_duplicates(self) -> None:
        """
        Error out if there is a duplicate device spec or duplicate uuid in blkids.

        Builds the by_blkid_device_spec and by_blkid_uuid indexes.  Partitions without a UUID are
        left out of by_blkid_uuid rather than causing a KeyError.

        :raises ValueError: on a duplicate device spec or a duplicate UUID.
        """
        self.by_blkid_device_spec: typing.DefaultDict[bytes, typing.List[Blkid]] = collections.defaultdict(list)
        for blkid in self.blkids:
            self.by_blkid_device_spec[blkid.device_spec].append(blkid)
        for device_spec, ddict in self.by_blkid_device_spec.items():
            if len(ddict) != 1:
                # bytes %s needs bytes, not a list, so encode the list's str() first.
                ddict_bytes = str(ddict).encode('UTF-8')
                raise ValueError(b'duplicate blkid device spec found: %s, %s' % (device_spec, ddict_bytes))

        self.by_blkid_uuid: typing.DefaultDict[bytes, typing.List[Blkid]] = collections.defaultdict(list)
        for blkid in self.blkids:
            uuid = blkid.attrs_dict.get(b'UUID')
            if uuid is None:
                # A partition with no UUID yet cannot collide with anything by UUID.
                continue
            self.by_blkid_uuid[uuid].append(blkid)
        for uuid, ddict in self.by_blkid_uuid.items():
            if len(ddict) != 1:
                ddict_str = str(ddict)
                ddict_bytes = ddict_str.encode('UTF-8')
                raise ValueError(b'duplicate blkid uuid found: %s, %s' % (uuid, ddict_bytes))

    def __str__(self):
        """Convert to str: one pretty-printed line per internal index."""
        list_ = [
            ('by_fsent_device_spec', self.by_fsent_device_spec),
            ('by_fsent_uuid', self.by_fsent_uuid),
            ('by_fsent_mount_point', self.by_mount_point),
            ('by_blkid_device_spec', self.by_blkid_device_spec),
            ('by_blkid_uuid', self.by_blkid_uuid),
            ('mount_point_from_device_path', self.mount_point_from_device_path),
            ]
        return '\n'.join(description + ' ' + pprint.pformat(element) for description, element in list_)

    def check_one_device(self, device: bytes) -> None:
        """
        Perform a preflight check of a single device.

        Warns on stdout when the device is absent from the blkid output.

        :raises ValueError: if the device, or its UUID, maps to more than one blkid entry.
        """
        if device not in self.by_blkid_device_spec:
            os.write(1, b'Warning: device %s not found in blkid output\n' % (device, ))
            return
        device_str = device.decode('UTF-8', errors='replace')
        blkids = self.by_blkid_device_spec[device]
        if len(blkids) != 1:
            raise ValueError(f"Too many blkid device_spec's for {device_str}")
        uuid = blkids[0].attrs_dict.get(b'UUID')
        if uuid is None:
            # No UUID yet, so there is nothing to cross-check by UUID.
            return
        # .get() avoids inserting into the defaultdict; a missing UUID counts as 0 entries.
        if len(self.by_blkid_uuid.get(uuid, [])) != 1:
            uuid_str = uuid.decode('UTF-8', errors='replace')
            raise ValueError(f"Too many blkid UUID's for {device_str} - {uuid_str}")

    def check_transfer(self, from_device: bytes, to_device: bytes) -> None:
        """
        Perform a preflight check of a transfer from from_device to to_device.

        :raises ValueError: if both devices are the same, or if either fails check_one_device.
        """
        if from_device == to_device:
            from_dev_str = from_device.decode('UTF-8')
            to_dev_str = to_device.decode('UTF-8')
            raise ValueError(f'from_device {from_dev_str} is the same as to_device {to_dev_str}. Do not copy a device to itself')
        self.check_one_device(from_device)
        self.check_one_device(to_device)


def main(argv: typing.Optional[typing.Sequence[str]] = None) -> None:
    """
    Read /etc/fstab and blkid, print the result, and preflight-check a device-to-device transfer.

    :param argv: command-line arguments, defaulting to sys.argv[1:].  With no arguments the historical
        defaults /dev/sdf1 -> /dev/sdg1 are checked; with exactly two, they are the from and to devices.
    """
    if argv is None:
        argv = sys.argv[1:]
    if not argv:
        from_device, to_device = b'/dev/sdf1', b'/dev/sdg1'
    elif len(argv) == 2:
        # Device paths are filesystem names, so encode them the way the OS expects.
        from_device, to_device = os.fsencode(argv[0]), os.fsencode(argv[1])
    else:
        sys.exit('Usage: [from_device to_device]')
    partition_info = PartitionInfo(b'/etc/fstab')
    print(partition_info)
    partition_info.check_transfer(from_device, to_device)


if __name__ == '__main__':
    main()
