#!/usr/bin/env python3
# Author: Sean Pesce

# References:
#   https://stackoverflow.com/questions/19705785/python-3-simple-https-server
#   https://docs.python.org/3/library/ssl.html
#   https://docs.python.org/3/library/http.server.html

# Shell command to create a self-signed TLS certificate and private key:
#    openssl req -new -newkey rsa:4096 -x509 -sha256 -days 365 -nodes -out cert.crt -keyout private.key

import http.server
import json
import ssl
import sys
from urllib.parse import urlparse

# URL path the target app requests for its configuration. Matching ignores the
# query string and any trailing slash.
CONFIG_PATH: str = '/config'

# Glue placed between the canary and the key path inside every tagged value,
# producing strings like "<canary>__cl.services.ct.kinesis.defaultChannel".
# Because the canary remains the leading part of the text, DOM Invader still
# matches it, and whatever follows the separator names the config key that
# reached the sink.
CANARY_SEP: str = '__'


class CORSRequestHandler(http.server.SimpleHTTPRequestHandler):
    """Serves files from the working directory with permissive CORS headers so a
    browser on another origin can read the response body (e.g. an app fetching
    ./config via ?configUrl). Reflects the request Origin to support credentialed
    requests; falls back to '*' when no Origin is sent.

    GET requests for CONFIG_PATH are answered with a canary-tagged copy of the
    JSON file named by CONFIG_FILE instead of a file from the working directory.
    """

    # Defaults so the handler is usable even before serve() configures it;
    # serve() overwrites all three from the command-line arguments.
    CANARY = 'domInvaderCanary'  # same default canary as the command line
    CONFIG_FILE = None           # path of the JSON config; None = not configured
    MODE = 'append'              # 'append' or 'replace' (see _tag)

    def _send_cors_headers(self):
        """Emit the CORS response headers for the current request.

        The Origin is echoed back (with credentials allowed) when present.
        Requested headers are echoed as well: browsers honour the literal '*'
        in Access-Control-Allow-Headers as a wildcard only for requests without
        credentials, so '*' alone would make credentialed preflights that ask
        for custom headers fail.
        """
        origin = self.headers.get('Origin', '*')
        self.send_header('Access-Control-Allow-Origin', origin)
        if origin != '*':
            self.send_header('Access-Control-Allow-Credentials', 'true')
            self.send_header('Vary', 'Origin')
        self.send_header('Access-Control-Allow-Methods', 'GET, OPTIONS')
        requested = self.headers.get('Access-Control-Request-Headers')
        self.send_header('Access-Control-Allow-Headers', requested or '*')

    def end_headers(self):
        # Injected just before the blank line that terminates the header block,
        # so CORS headers are added to every response (including file serving
        # and send_error() pages).
        self._send_cors_headers()
        super().end_headers()

    def _log_request_headers(self):
        """Print the request line, client IP and every request header."""
        print(f'\n===== {self.command} {self.path} from {self.client_address[0]} =====')
        for name, value in self.headers.items():
            print(f'  {name}: {value}')
        print('=' * 60, flush=True)

    def _tag(self, node, path):
        """Walk the parsed config and mark every string leaf with the canary.

        Each string becomes "<canary><sep><key-path>" when MODE is 'replace',
        otherwise "<original><canary><sep><key-path>" (append mode keeps the
        original value so the app still works). Key paths use '.' for object
        keys and '[i]' for list indices. Numbers/bools/null are left as-is to
        avoid breaking type-sensitive logic in the app.

        Returns (new_node, count) where count is the number of strings tagged.
        """
        canary = type(self).CANARY
        if isinstance(node, dict):
            count = 0
            out = {}
            for k, v in node.items():
                child_path = f'{path}.{k}' if path else k
                out[k], n = self._tag(v, child_path)
                count += n
            return out, count
        if isinstance(node, list):
            count = 0
            out = []
            for i, v in enumerate(node):
                out_v, n = self._tag(v, f'{path}[{i}]')
                out.append(out_v)
                count += n
            return out, count
        if isinstance(node, str):
            # Some configs store arrays/objects as JSON-encoded strings (the app
            # parses them back into the expected type). Detect those, tag the
            # leaves inside, and re-serialize so the value stays parseable.
            if node.lstrip()[:1] in ('[', '{'):
                try:
                    inner = json.loads(node)
                except ValueError:
                    inner = None
                if isinstance(inner, (list, dict)):
                    tagged_inner, n = self._tag(inner, path)
                    return json.dumps(tagged_inner), n
            tag = f'{canary}{CANARY_SEP}{path}'
            if type(self).MODE == 'replace':
                return tag, 1
            return f'{node}{tag}', 1  # append: preserve original value so the app still works
        return node, 0  # leave numbers / bools / null unchanged

    def _send_config_error(self, reason, exc):
        """Answer with a 500 for a config-loading failure.

        The status line carries only the fixed ASCII *reason*: http.server
        encodes that line as Latin-1, so exception text there (which can
        contain a non-Latin-1 file path) would raise mid-response. The full
        detail goes into the HTML body (escaped and UTF-8 encoded by
        send_error) and to the console.
        """
        detail = f'{reason}: {exc}'
        print(f'  -> {detail}', flush=True)
        self.send_error(500, reason, detail)

    def _serve_config(self):
        """Read the user-supplied config file fresh each request (edit it live)
        and tag every string value with the canary before serving."""
        config_file = type(self).CONFIG_FILE
        if config_file is None:
            self.send_error(500, 'No config file configured')
            return
        try:
            # JSON is UTF-8; don't depend on the platform locale. 'utf-8-sig'
            # also accepts a leading BOM, which json.load would otherwise reject.
            with open(config_file, encoding='utf-8-sig') as f:
                data = json.load(f)
        except OSError as e:
            self._send_config_error('Cannot read config file', e)
            return
        except ValueError as e:  # JSONDecodeError and UnicodeDecodeError
            self._send_config_error('Config file is not valid JSON', e)
            return
        tagged, count = self._tag(data, '')
        body = json.dumps(tagged, indent=2).encode()
        print(f'  -> served {config_file}: tagged {count} string value(s) '
              f'with canary {type(self).CANARY!r}', flush=True)
        self.send_response(200)
        self.send_header('Content-Type', 'application/json')
        self.send_header('Content-Length', str(len(body)))
        self.end_headers()  # CORS headers injected here
        self.wfile.write(body)

    def do_GET(self):
        """Serve the tagged config for CONFIG_PATH, otherwise a static file."""
        self._log_request_headers()
        if urlparse(self.path).path.rstrip('/') == CONFIG_PATH:
            self._serve_config()
            return
        super().do_GET()

    def do_OPTIONS(self):
        """Answer CORS preflights with an empty 204 carrying the CORS headers."""
        self._log_request_headers()
        self.send_response(204)
        self.end_headers()


def serve(host, port, cert_fpath, privkey_fpath, config_file, canary, mode):
    """Run the HTTPS config server until the process is interrupted.

    Stores *canary*, *config_file* and *mode* on CORSRequestHandler, where
    request handling reads them. Then it loads the certificate/key pair into a
    server-side TLS context, binds (host, port), wraps the listening socket in
    TLS and serves forever.
    """
    handler_cls = CORSRequestHandler
    handler_cls.CANARY = canary
    handler_cls.CONFIG_FILE = config_file
    handler_cls.MODE = mode

    # PROTOCOL_TLS_SERVER needs Python 3.6+; older versions may need ssl.PROTOCOL_TLS.
    tls_ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    # An explicit empty password presumably avoids an interactive passphrase
    # prompt; the openssl command in the file header creates an unencrypted key.
    tls_ctx.load_cert_chain(certfile=cert_fpath, keyfile=privkey_fpath, password='')

    httpd = http.server.HTTPServer((host, port), handler_cls)
    httpd.socket = tls_ctx.wrap_socket(httpd.socket, server_side=True)

    print(f'Serving {CONFIG_PATH} from {config_file!r}; '
          f'{mode}-tagging every string value with canary {canary!r}.', flush=True)
    httpd.serve_forever()


if __name__ == '__main__':
    # Command-line entry point. Every error path uses sys.exit(<message>), which
    # prints to stderr and exits with status 1, so a failed launch is
    # distinguishable from a clean exit by wrapper scripts.
    if len(sys.argv) < 5:
        sys.exit(f'Usage:\n  {sys.argv[0]} <port> <PEM certificate file> '
                 f'<private key file> <config file> [canary] [append|replace]')

    try:
        PORT = int(sys.argv[1])
    except ValueError:
        sys.exit(f'Invalid port {sys.argv[1]!r}: expected an integer')
    if not 0 <= PORT <= 65535:
        sys.exit(f'Invalid port {PORT}: expected a value in 0-65535')

    CERT_FPATH = sys.argv[2]
    PRIVKEY_FPATH = sys.argv[3]
    CONFIG_FILE = sys.argv[4]
    CANARY = sys.argv[5] if len(sys.argv) > 5 else 'domInvaderCanary'
    MODE = sys.argv[6] if len(sys.argv) > 6 else 'append'
    if MODE not in ('append', 'replace'):
        sys.exit(f"Invalid mode {MODE!r}: expected 'append' or 'replace'")

    serve('0.0.0.0', PORT, CERT_FPATH, PRIVKEY_FPATH, CONFIG_FILE, CANARY, MODE)

