#!/usr/libexec/platform-python

#
# split git for Qubes OS
# Copyright (C) 2018  Wojtek Porczyk <woju@invisiblethingslab.com>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License along
# with this program; if not, write to the Free Software Foundation, Inc.,
# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
#

#
# git remote add <name> qrexec://<vm>/<path>[?keyring=<trustedkeys.kbx>[&keyring=...]][&list_head_only=0]
#   <path>
#   - does not contain '/'
#   - means /home/user/QubesGit/<path>
#   - is possibly a symlink somewhere else
#   - is a qrexec argument for policy purposes (git.List[HeadOnly], git.Fetch)
#
#   keyring= may specify additional trusted keyrings; this is passed directly
#   to gpgv as --keyring; see gpgv(1)
#
#   list_head_only= if true (the default) means that only the latest tag is
#   listed; set to something false to list all tags
#

import argparse
import asyncio
import asyncio.streams
import configparser
import functools
import hashlib
import logging
import os
import pathlib
import string
import subprocess
import sys
import tempfile
import urllib.parse
import zlib

# Largest reply, in bytes, accepted from a single qrexec call.
MAX_BUFFER_SIZE = 10 * 1000 * 1000  # 10 MB

# Byte values (as ints) permitted in an object ID and in a tag name.
ALLOWED_IN_OBJECT_ID = set(b'0123456789abcdefABCDEF')
ALLOWED_IN_TAG_NAME = (
    set(string.ascii_letters.encode('ascii'))
    | set(string.digits.encode('ascii'))
    | set(b'.-_'))

TIMEOUT = 5  # in seconds, for single qrexec call

# Capabilities announced to git in reply to the "capabilities" command.
CAPABILITIES = ['fetch', 'option']

# Marker separating the signed payload of a tag from its armored signature.
BEGIN_PGP_SIGNATURE = b'-----BEGIN PGP SIGNATURE-----'

# Command line: <repository> <url>
parser = argparse.ArgumentParser()
parser.add_argument('repository')
parser.add_argument('url')

class VerificationError(Exception):
    """Raised when untrusted data fails a validity or authenticity check."""

class GitObject:
    """A git object whose ID, header and size have been checked.

    The raw, uncompressed object (``<type> <size>\\0<content>``) is accepted
    only if its SHA-1 equals the expected object ID, the header is
    well-formed, the declared size matches the content and the type is one
    of tag, commit, tree or blob.  For tags and commits the header lines
    (``key value``) are additionally available via ``obj[key]``; when a key
    occurs more than once the last occurrence wins.

    Attributes set by the constructor:
        objectid -- the verified object ID (hex string)
        type     -- 'tag', 'commit', 'tree' or 'blob'
        content  -- object content without the ``<type> <size>\\0`` prefix
    """
    # pylint: disable=too-few-public-methods

    def __init__(self, objectid, *, untrusted_data):
        """Verify untrusted_data against objectid and parse it.

        Raises VerificationError if any check fails.
        """
        if objectid != hashlib.sha1(untrusted_data).hexdigest():
            raise VerificationError()
        self.objectid = objectid
        self._data = untrusted_data
        del untrusted_data

        # partition() instead of split(): a missing NUL must be reported as
        # a verification failure, not as an unpacking ValueError
        untrusted_header, nul, untrusted_content = self._data.partition(b'\0')
        if not nul:
            raise VerificationError()
        try:
            untrusted_type, untrusted_size = (
                untrusted_header.decode('ascii').split(' '))
        except ValueError:
            # wrong number of fields, or non-ASCII bytes (UnicodeDecodeError
            # is a subclass of ValueError)
            raise VerificationError() from None
        del untrusted_header

        if not untrusted_size.isdigit():
            raise VerificationError()
        untrusted_size = int(untrusted_size)
        if len(untrusted_content) != untrusted_size:
            raise VerificationError()
        del untrusted_size

        if untrusted_type not in ('tag', 'commit', 'tree', 'blob'):
            raise VerificationError()
        self.type = untrusted_type
        del untrusted_type

        self.content = untrusted_content
        del untrusted_content

        # always defined, so that obj[key] on a tree or blob raises KeyError
        # rather than AttributeError
        self._bag = {}
        if self.type in ('tag', 'commit'):
            # the header ends at the first blank line; an object with an
            # empty message has no blank line left after stripping, in which
            # case everything is header
            header, _, _ = self.content.rstrip(b'\n').partition(b'\n\n')
            try:
                header = header.decode()
            except UnicodeDecodeError:
                raise VerificationError() from None
            for line in header.split('\n'):
                key, space, value = line.partition(' ')
                if not space:
                    raise VerificationError()
                self._bag[key] = value

    def __getitem__(self, key):
        """Return the value of header line key; KeyError if there is none."""
        return self._bag[key]

    def write_to(self, remote):
        """Store the object, zlib-compressed, as a loose object.

        The file is written to <remote.git_dir>/objects/<id[:2]>/<id[2:]>;
        the directory is created if needed.
        """
        dirpath = remote.git_dir / 'objects' / self.objectid[:2]
        os.makedirs(str(dirpath), exist_ok=True)
        with open(str(dirpath / self.objectid[2:]), 'wb') as file:
            file.write(zlib.compress(self._data))

async def stdio(*, loop=None):
    """Wrap the process's stdin and stdout in asyncio streams.

    loop is the event loop the pipes are attached to; when it is None the
    current event loop is used.  Returns a (reader, writer) pair: an
    asyncio.StreamReader fed from sys.stdin and an asyncio.StreamWriter
    bound to sys.stdout.
    """
    # loosely based on https://gist.github.com/nathan-hoad/8966377
    if loop is None:
        loop = asyncio.get_event_loop()

    stdin_reader = asyncio.StreamReader()
    await loop.connect_read_pipe(
        lambda: asyncio.StreamReaderProtocol(stdin_reader), sys.stdin)

    os.set_blocking(sys.stdout.fileno(), False)

    transport, protocol = await loop.connect_write_pipe(
        asyncio.streams.FlowControlMixin, sys.stdout)
    stdout_writer = asyncio.StreamWriter(transport, protocol, None, loop)

    return stdin_reader, stdout_writer

class AbstractTagVerifier:
    """Base class for verifying the PGP signature of a tag with gpgv.

    Subclasses implement gpgv(), which starts the gpgv process and returns
    the two channels (signature, signed data) it reads from, and optionally
    cleanup(), which releases whatever gpgv() allocated.
    """

    def __init__(self, keyrings):
        """keyrings is an iterable of values passed to gpgv as --keyring."""
        self.keyrings = keyrings

    async def verify_tag(self, untrusted_tag):
        """Check the signature embedded in a tag object.

        untrusted_tag is an object whose .content holds the tag body
        followed by an armored signature.  The content is split at the first
        BEGIN_PGP_SIGNATURE marker and both parts are fed to gpgv.

        Returns the signed part of the content (bytes, everything before the
        signature).  Raises VerificationError when there is no signature or
        gpgv exits with a non-zero status.
        """
        logging.debug('verify_tag')
        untrusted_data = untrusted_tag.content
        index = untrusted_data.find(BEGIN_PGP_SIGNATURE)
        if index < 0:
            raise VerificationError()
        untrusted_signature = untrusted_data[index:]
        untrusted_payload = untrusted_data[:index]
        del untrusted_data

        proc, sigfile, tagfile = await self.gpgv()

        # NOTE: open()ing and write()ing has to be done concurrently for both
        # pipes, because we don't know how remote end is coordinated. For
        # os.pipe() this is not relevant, but os.mkfifo() case is tricky:
        # open(2) for writing end returns only after reading end has been
        # opened.
        try:
            await asyncio.gather(
                self.open_and_write_to_fd(sigfile, untrusted_signature),
                self.open_and_write_to_fd(tagfile, untrusted_payload))
        finally:
            # release the channels even when feeding gpgv failed
            self.cleanup()
        del untrusted_signature

        await proc.wait()
        if proc.returncode != 0:
            raise VerificationError()

        tag = untrusted_payload
        del untrusted_payload
        return tag

    def get_common_gpgv_args(self):
        """Yield the gpgv arguments shared by all verifier flavours."""
        # '--quiet' is deliberately not passed
        for keyring in self.keyrings:
            yield '--keyring={}'.format(keyring)

    async def open_and_write_to_fd(self, file, input, *, loop=None):
        """Open file for writing, write input to it and close it.

        file is anything open() accepts (a path or a file descriptor).  The
        open() call runs in the default executor because it may block; see
        the note in verify_tag().
        """
        # pylint: disable=redefined-builtin
        if loop is None:
            loop = asyncio.get_event_loop()
        pipe = await loop.run_in_executor(None, open, file, 'wb')
        transport, protocol = await loop.connect_write_pipe(
            asyncio.streams.FlowControlMixin, pipe)
        writer = asyncio.StreamWriter(transport, protocol, None, loop)
        writer.write(input)
        await writer.drain()
        writer.close()

    @classmethod
    def select_tag_verifier(cls, *args, **kwds):
        """Instantiate a verifier suitable for this system.

        LegacyTagVerifier is used when /usr/bin/gpgv2 exists; otherwise the
        class this method was called on.  Arguments go to the constructor.
        """
        if pathlib.Path('/usr/bin/gpgv2').is_file():
            verifier_cls = LegacyTagVerifier
        else:
            verifier_cls = cls
        return verifier_cls(*args, **kwds)

    # the following methods can be overloaded:

    async def gpgv(self):
        """Start gpgv; return (process, signature channel, data channel)."""
        raise NotImplementedError()

    def cleanup(self):
        """Release resources allocated by gpgv(); nothing by default."""


class TagVerifier(AbstractTagVerifier):
    """Verifier that hands data to gpgv through inherited pipe descriptors."""

    async def gpgv(self):
        """Start gpgv reading the signature and the data from two pipes.

        Returns (process, signature write fd, data write fd).
        """
        sig_rfd, sig_wfd = os.pipe()
        tag_rfd, tag_wfd = os.pipe()
        inherited = (sig_rfd, tag_rfd)

        argv = ['gpgv']
        argv.extend(self.get_common_gpgv_args())
        argv.append('--enable-special-filenames')
        argv.append('--')
        # "-&N" is how the descriptors are named on the gpgv command line
        argv.extend('-&{}'.format(fd) for fd in inherited)

        # stderr is left attached to ours, so gpgv's messages stay visible
        proc = await asyncio.create_subprocess_exec(
            *argv, pass_fds=inherited)

        # the child has its own copies of the read ends now
        for fd in inherited:
            os.close(fd)

        return proc, sig_wfd, tag_wfd

class LegacyTagVerifier(AbstractTagVerifier):
    """Verifier that hands data to gpgv through named FIFOs."""
    # fc25, old gpgv without --enable-special-filenames

    # paths allocated by gpgv() and released by cleanup()
    tmpdir = None
    sigfifo = None
    tagfifo = None

    async def gpgv(self):
        """Start gpgv reading the signature and the data from two FIFOs.

        The FIFOs live in a freshly created private temporary directory, so
        their paths are neither predictable nor able to collide with
        leftovers of an earlier process.

        Returns (process, signature FIFO path, data FIFO path).
        """
        self.tmpdir = tempfile.mkdtemp(prefix='splitgit-')
        self.sigfifo = os.path.join(self.tmpdir, 'sig')
        self.tagfifo = os.path.join(self.tmpdir, 'tag')

        try:
            os.mkfifo(self.sigfifo)
            os.mkfifo(self.tagfifo)

            proc = await asyncio.create_subprocess_exec('gpgv',
                    *self.get_common_gpgv_args(),
                    '--', self.sigfifo, self.tagfifo,
                    stdin=None,
                    stdout=None,
                    stderr=None,
                )
        except BaseException:
            # do not leave the directory behind when gpgv could not start
            self.cleanup()
            raise

        return proc, self.sigfifo, self.tagfifo

    def cleanup(self):
        """Remove the FIFOs and their directory; safe to call repeatedly."""
        super().cleanup()
        for path in (self.sigfifo, self.tagfifo):
            if path is None:
                continue
            try:
                os.unlink(path)
            except FileNotFoundError:
                pass
        if self.tmpdir is not None:
            try:
                os.rmdir(self.tmpdir)
            except FileNotFoundError:
                pass
        self.tmpdir = self.sigfifo = self.tagfifo = None


class Remote:
    """A qrexec:// git remote served by another (untrusted) qube.

    Everything received over qrexec is treated as untrusted until checked:
    object IDs and tag names are validated syntactically, objects are
    checked against their SHA-1 by GitObject, and tags must carry a
    signature accepted by gpgv.
    """

    def __init__(self, name, url):
        """Parse the remote URL.

        name is the remote's name as given by git; url has the form
        qrexec://<vm>/<path>[?keyring=...][&list_head_only=...].

        Raises ValueError for a malformed or unsupported URL and KeyError if
        the GIT_DIR environment variable is not set.
        """
        self.name = name

        url = urllib.parse.urlsplit(url)
        if url.scheme != 'qrexec' or url.fragment:
            raise ValueError('invalid remote URI')

        query = {} if not url.query else urllib.parse.parse_qs(url.query,
            strict_parsing=True, errors='strict')

        self._keyrings = query.pop('keyring', ())

        try:
            list_head_only = query.pop('list_head_only')
        except KeyError:
            self._list_head_only = True
        else:
            # the last occurrence wins; an unrecognised boolean is a URI
            # error like any other
            try:
                self._list_head_only = (
                    configparser.ConfigParser.BOOLEAN_STATES[
                        list_head_only[-1].casefold()])
            except KeyError:
                raise ValueError('invalid remote URI') from None

        if query:
            raise ValueError('invalid remote URI')

        self.vmname = url.netloc
        self.path = url.path.lstrip('/')
        if '/' in self.path:
            raise ValueError('invalid remote URI')

        self.git_dir = pathlib.Path(os.environ['GIT_DIR'])
        self._objects_seen = set()

    async def list(self):
        """Ask the remote for its tags.

        Returns a list of [commitid, tagid, tagname] string triples; the
        list is empty when the remote reports nothing.  Raises
        VerificationError if any line of the reply is malformed.
        """
        untrusted_data = await self._qrexec(
            'git.ListHeadOnly' if self._list_head_only else 'git.List')

        untrusted_lines = untrusted_data.rstrip(b'\n')
        if not untrusted_lines:
            if self._list_head_only:
                logging.warning(
                    'warning: list_head_only is on and HEAD has no signed tag')
            return []

        ret = []
        for untrusted_line in untrusted_lines.split(b'\n'):
            untrusted_fields = untrusted_line.split(b' ')
            if len(untrusted_fields) != 3:
                raise VerificationError()
            untrusted_commitid, untrusted_tagid, untrusted_tag = (
                untrusted_fields)

            commitid = self._verify_objid(untrusted_objid=untrusted_commitid)
            tagid = self._verify_objid(untrusted_objid=untrusted_tagid)

            if not untrusted_tag or set(untrusted_tag) - ALLOWED_IN_TAG_NAME:
                raise VerificationError()
            tag = untrusted_tag
            del untrusted_tag

            # refuse temptation to yield: first check all lines, then return
            ret.append([commitid.decode(), tagid.decode(), tag.decode()])

        return ret

    async def fetch(self, objid, refname):
        """Fetch a signed tag and everything reachable from it.

        objid is the ID of the tag object and refname the ref git wants it
        for; only refs/tags/* is supported.  The tag must be a tag object,
        carry a valid signature, be named like the ref and point at a
        commit.

        Returns the verified tag (a GitObject).  Raises VerificationError if
        any check fails.
        """
        if not refname.startswith('refs/tags/'):
            raise VerificationError()
        tagname = refname[len('refs/tags/'):]

        untrusted_tag = await self._qrexec_fetch(objid)
        if untrusted_tag.type != 'tag':
            raise VerificationError()
        await self.verify_tag(untrusted_tag=untrusted_tag)

        try:
            acceptable = (untrusted_tag['tag'] == tagname
                and untrusted_tag['type'] == 'commit')
        except KeyError:
            acceptable = False
        if not acceptable:
            raise VerificationError()

        tag = untrusted_tag
        del untrusted_tag

        tag.write_to(self)
        await self._fetch_recursive(tag['object'], type=tag['type'])
        return tag

    @staticmethod
    async def _readn_until_eof(stream, n, iden='<unknown>'):
        """Read stream until EOF and return the data, at most n bytes.

        If the stream delivers more than n bytes the process is terminated;
        iden identifies the request in the error message.
        """
        # stream.read(n) where n > 0 does not read until EOF, but just one
        # chunk, therefore we need this...

        logging.debug('_readn(%r, %r, iden=%r)', stream, n, iden)
        try:
            await stream.readexactly(n + 1)
        except asyncio.IncompleteReadError as e:
            return e.partial
        else:
            logging.critical('remote sent too big file for %s', iden)
            sys.exit(1)

    async def _qrexec(self, rpcname, input=None):
        """Call qrexec service <rpcname>+<path> in the remote qube.

        input, if given, is written to the service's stdin.  Returns the
        service's stdout (bytes).  Raises VerificationError if the call
        exits with a non-zero status, asyncio.TimeoutError if the reply does
        not complete within TIMEOUT seconds and FileNotFoundError if no
        qrexec client is installed.
        """
        # pylint: disable=redefined-builtin
        rpcfull = '{}+{}'.format(rpcname, self.path)
        if pathlib.Path('/usr/bin/qrexec-client-vm').is_file():
            args = ('/usr/bin/qrexec-client-vm', self.vmname, rpcfull)
        elif pathlib.Path('/usr/bin/qrexec-client').is_file():
            args = ('/usr/bin/qrexec-client', '-d', self.vmname,
                'DEFAULT:QUBESRPC {} dom0'.format(rpcfull))
        else:
            raise FileNotFoundError(
                'neither /usr/bin/qrexec-client-vm'
                ' nor /usr/bin/qrexec-client found')

        proc = await asyncio.create_subprocess_exec(*args,
            stdin=(subprocess.PIPE
                if input is not None else subprocess.DEVNULL),
            stdout=subprocess.PIPE,
            stderr=None)

        if input is not None:
            proc.stdin.write(input)
            proc.stdin.close()
        stdout = await asyncio.wait_for(
            self._readn_until_eof(proc.stdout, MAX_BUFFER_SIZE,
                iden='{} input={!r}'.format(rpcfull, input)),
            timeout=TIMEOUT)
        await proc.wait()
        if proc.returncode != 0:
            raise VerificationError()

        return stdout

    async def _qrexec_fetch(self, objectid):
        """Fetch one object from the remote; GitObject verifies its hash."""
        return GitObject(objectid, untrusted_data=(
            await self._qrexec('git.Fetch', input=objectid.encode('ascii'))))

    @staticmethod
    async def _subprocess_check_output(*args, input=None,
            stdout=subprocess.PIPE, stderr=subprocess.PIPE, **kwds):
        """Asynchronous counterpart of subprocess.check_output().

        Runs args, returns its stdout and raises
        subprocess.CalledProcessError on a non-zero exit status.
        """
        proc = await asyncio.create_subprocess_exec(*args, stdout=stdout,
                stderr=stderr, **kwds)
        stdout, stderr = await proc.communicate(input)
        if proc.returncode != 0:
            raise subprocess.CalledProcessError(proc.returncode, args,
                    output=stdout, stderr=stderr)
        return stdout

    async def _fetch_recursive(self, objectid, *, type=None):
        """Make objectid and everything reachable from it available locally.

        type, if given, is the type the object is required to have.  Objects
        missing from the local repository are fetched from the remote,
        verified and written as loose objects.

        The traversal uses an explicit stack instead of recursion, so its
        depth is not limited by the length of the history.
        """
        # pylint: disable=redefined-builtin
        pending = [(objectid, type)]
        while pending:
            untrusted_objid, expected_type = pending.pop()
            children = await self._fetch_one(untrusted_objid, expected_type)
            # reversed, so that children are visited in the order returned
            pending.extend(reversed(children))

    async def _fetch_one(self, untrusted_objid, expected_type):
        """Load or fetch a single object; return the objects it refers to.

        The result is a list of (untrusted object ID, expected type) pairs;
        it is empty for objects already handled during this run.
        """
        logging.debug('_fetch_recursive(%r)', untrusted_objid)

        # IDs taken from commit headers are attacker-controlled; validate
        # them before they reach a git command line
        objectid = self._verify_objid_str(untrusted_objid)
        del untrusted_objid

        if objectid in self._objects_seen:
            logging.debug('  object %s already seen, skipping', objectid)
            return []

        # check if we already have this object
        try:
            objtype = expected_type
            if objtype is None:
                objtype = (await self._subprocess_check_output(
                    'git', 'cat-file', '-t', objectid)).decode().strip()
            content = await self._subprocess_check_output(
                'git', 'cat-file', objtype, objectid)

        # nope, let's fetch it
        except subprocess.CalledProcessError:
            obj = await self._qrexec_fetch(objectid)
            # explicit check rather than assert: must survive python -O
            if expected_type is not None and obj.type != expected_type:
                raise VerificationError()
            obj.write_to(self)

        else:
            logging.debug('  loaded existing object %s', objectid)
            untrusted_data = ('{} {}'.format(objtype, len(content)).encode()
                + b'\0' + content)
            obj = GitObject(objectid, untrusted_data=untrusted_data)

        self._objects_seen.add(objectid)

        if obj.type == 'commit':
            untrusted_trees = self._untrusted_header_values(obj, 'tree')
            if len(untrusted_trees) != 1:
                raise VerificationError()
            # a merge commit has several parents, all of them are needed
            untrusted_parents = self._untrusted_header_values(obj, 'parent')
            return ([(untrusted_trees[0], 'tree')]
                + [(parent, 'commit') for parent in untrusted_parents])

        if obj.type == 'tree':
            return await self._tree_children(objectid)

        return []

    async def _tree_children(self, objectid):
        """Return (object ID, type) for the blobs and subtrees of a tree.

        The tree has to be present in the local repository already.
        """
        listing = (await self._subprocess_check_output(
            'git', 'ls-tree', objectid)).decode()

        children = []
        for line in listing.split('\n'):
            if not line:
                # empty tree, or the trailing newline
                continue

            # <mode> SP <type> SP <object> TAB <file>
            meta, _, name = line.partition('\t')
            try:
                _, objtype, objid = meta.split(' ')
            except ValueError:
                raise VerificationError() from None

            logging.debug('objtype %s objid %s', objtype, objid)
            if objtype not in ('blob', 'tree', 'commit'):
                logging.error('unsupported object type in tree: %s', objtype)
                raise VerificationError()

            if objtype == 'commit':
                # TODO submodules
                logging.warning('submodules are unsupported!'
                    ' found submodule %s with commit %s', name, objid)
                continue

            children.append((objid, objtype))

        return children

    @staticmethod
    def _untrusted_header_values(obj, key):
        """Return the values of all 'key ...' header lines of obj.

        The values are returned as (still untrusted) strings, in the order
        they appear in the object.
        """
        prefix = key.encode('ascii') + b' '
        untrusted_header = obj.content.partition(b'\n\n')[0]
        return [
            untrusted_line[len(prefix):].decode('ascii', 'replace')
            for untrusted_line in untrusted_header.split(b'\n')
            if untrusted_line.startswith(prefix)]

    async def verify_tag(self, untrusted_tag):
        """Verify the tag's signature against the configured keyrings.

        Returns the signed part of the tag content; raises
        VerificationError if the signature is not accepted.
        """
        tag_verifier = TagVerifier.select_tag_verifier(keyrings=self._keyrings)
        return await tag_verifier.verify_tag(untrusted_tag)

    @staticmethod
    def _verify_objid(untrusted_objid):
        """Return untrusted_objid (bytes) if it is 40 hex digits.

        Raises VerificationError otherwise.
        """
        if (len(untrusted_objid) != 40
        or set(untrusted_objid) - ALLOWED_IN_OBJECT_ID):
            raise VerificationError()
        objid = untrusted_objid
        del untrusted_objid
        return objid

    @classmethod
    def _verify_objid_str(cls, untrusted_objid):
        """Like _verify_objid(), but takes and returns str."""
        try:
            untrusted_raw = untrusted_objid.encode('ascii')
        except UnicodeEncodeError:
            raise VerificationError() from None
        return cls._verify_objid(untrusted_objid=untrusted_raw).decode('ascii')


async def main(args=None):
    """Run the remote helper: serve git's commands arriving on stdin.

    args is the argument list given to the module-level parser (None makes
    argparse use sys.argv[1:]).  The commands handled are "capabilities",
    "list", "fetch" (a batch terminated by an empty line) and "option";
    anything else is ignored.  Returns when stdin reaches EOF.
    """
    # for debugging, set level to NOTSET, this also disables 'option verbosity'
    logging.basicConfig(
        level=logging.WARNING,
#       level=logging.NOTSET,
        format='{}: %(message)s'.format(pathlib.Path(sys.argv[0]).name),
    )

    args = parser.parse_args(args)
    remote = Remote(args.repository, args.url)
    logging.debug('remote name %s vmname %s', remote.name, remote.vmname)

    # TODO actually use this
    followtags = False

    stdin, stdout = await stdio()

    async def out(*words):
        """Send one line, made of space-separated words, to git."""
        logging.debug('→ %s', ' '.join(map(str, words)))
        stdout.write(' '.join(words).encode() + b'\n')
        await stdout.drain()

    while not stdin.at_eof():
        line = (await stdin.readline()).decode().strip()
        logging.debug('← %s', line)

        if line == 'capabilities':
            for cap in CAPABILITIES:
                await out(cap)
            await out()
            continue

        if line.startswith('list'):
            for commitid, tagid, tag in (await remote.list()):
                await out('{} refs/tags/{}'.format(tagid, tag))
                await out('{} refs/tags/{}^{{}}'.format(commitid, tag))

            await out()
            continue

        if line.startswith('fetch'):
            # a batch of fetch lines ends with an empty line
            while line:
                _, objectid, ref = line.split()
                logging.debug('objectid %r ref %r', objectid, ref)
                await remote.fetch(objectid, ref)
                line = (await stdin.readline()).decode().strip()

            await out()

            logging.debug('fetch ended')
            continue

        if line.startswith('option'):
            # "option <name> <value>"; a missing name or value becomes ''
            # instead of breaking the unpacking
            _, _, rest = line.partition(' ')
            option, _, value = rest.partition(' ')

            if option == 'followtags':
                followtags = (value.lower() == 'true')
                await out('ok')
                continue

            if option == 'verbosity':
                try:
                    verbosity = int(value)
                except ValueError:
                    await out('error invalid verbosity {!r}'.format(value))
                    continue
                if logging.root.level != logging.NOTSET:
                    # never go below DEBUG: NOTSET is reserved for the
                    # debugging mode described at the top of this function
                    logging.root.setLevel(
                        max(logging.DEBUG, logging.ERROR - 10 * verbosity))
                await out('ok')
                continue

            await out('unsupported')
            continue

    logging.debug('loop ended; sys.stdin.readable')

if __name__ == '__main__':
    loop = asyncio.get_event_loop()
    try:
        loop.run_until_complete(main())
    finally:
        # close the loop also when main() raises
        loop.close()

# vim: ts=4 sts=4 sw=4 et
