#!/usr/bin/env python3
# Author: Sean Pesce

# References:
#   https://stackoverflow.com/questions/19705785/python-3-simple-https-server
#   https://docs.python.org/3/library/ssl.html
#   https://docs.python.org/3/library/http.server.html

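# Shell command to create a self-signed TLS certificate and an unencrypted (-nodes) private key,
# both PEM-encoded (serve() loads the key with an empty password, so it must not be encrypted):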
#    openssl req -new -newkey rsa:4096 -x509 -sha256 -days 365 -nodes -out cert.crt -keyout private.key

import http.server
import ssl
import sys


class CORSRequestHandler(http.server.SimpleHTTPRequestHandler):
    """Serves files from the working directory with permissive CORS headers so a
    browser on another origin can read the response body (e.g. an app fetching
    ./config via ?configUrl). Reflects the request Origin to support credentialed
    requests; falls back to '*' when no Origin is sent.

    WARNING: echoing any Origin (including 'null') together with
    ``Access-Control-Allow-Credentials: true`` lets any web page read these
    responses on the visitor's behalf. That is the intent for a development
    server, but do not expose it to untrusted networks.
    """

    def _send_cors_headers(self):
        """Queue the CORS headers for the current response.

        Safe to call before the request headers were parsed (e.g. when
        ``send_error`` rejects a malformed or over-long request line); the
        request is then treated as having sent no Origin.
        """
        # self.headers is only assigned once the request's header block has been
        # parsed; early error responses reach end_headers() without it.
        request_headers = getattr(self, 'headers', None)
        if request_headers is None:
            request_headers = {}

        origin = request_headers.get('Origin', '*')
        self.send_header('Access-Control-Allow-Origin', origin)
        if origin != '*':
            self.send_header('Access-Control-Allow-Credentials', 'true')
        self.send_header('Access-Control-Allow-Methods', 'GET, HEAD, OPTIONS')
        # For credentialed requests browsers treat '*' in Allow-Headers as a literal
        # header name rather than a wildcard, so echo whatever the preflight asked for.
        requested_headers = request_headers.get('Access-Control-Request-Headers')
        self.send_header('Access-Control-Allow-Headers', requested_headers or '*')
        # The values above depend on these request headers even when they are
        # absent ('*' vs. a reflected origin), so caches must always key on them.
        self.send_header('Vary', 'Origin, Access-Control-Request-Headers')

    def end_headers(self):
        # Injected just before the blank line that terminates the header block,
        # so CORS headers are added to every response (including file serving
        # and error pages produced by send_error).
        self._send_cors_headers()
        super().end_headers()

    def do_OPTIONS(self):
        """Answer CORS preflight requests with an empty 204 response."""
        self.send_response(204)
        self.end_headers()


def serve(host, port, cert_fpath, privkey_fpath):
    """Serve the current working directory over HTTPS until interrupted.

    Args:
        host: Interface address to bind, e.g. '0.0.0.0' for all IPv4 interfaces.
        port: TCP port to listen on.
        cert_fpath: Path to the PEM certificate (chain) file.
        privkey_fpath: Path to the PEM private key. An empty password is passed
            to load_cert_chain, so the key must be unencrypted.
    """
    # ssl.PROTOCOL_TLS_SERVER was added in Python 3.6; older versions may need
    # ssl.PROTOCOL_TLS instead.
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    context.load_cert_chain(certfile=cert_fpath, keyfile=privkey_fpath, password='')
    # A single-threaded HTTPServer stalls every other client while one connection
    # sits idle without sending a request (browsers open such speculative
    # connections), so handle each connection on its own thread when the
    # running Python provides ThreadingHTTPServer (3.7+).
    server_class = getattr(http.server, 'ThreadingHTTPServer', http.server.HTTPServer)
    # The context manager closes the listening socket when serve_forever() exits.
    with server_class((host, port), CORSRequestHandler) as httpd:
        # NOTE: the listening socket is wrapped, so the TLS handshake of each new
        # connection still runs inside accept() on the serving thread.
        httpd.socket = context.wrap_socket(httpd.socket, server_side=True)
        httpd.serve_forever()


if __name__ == '__main__':
    if len(sys.argv) < 4:
        # Usage errors go to stderr with a non-zero status so callers notice them.
        print(f'Usage:\n  {sys.argv[0]} <port> <PEM certificate file> <private key file> [bind address]',
              file=sys.stderr)
        sys.exit(1)

    try:
        PORT = int(sys.argv[1])
    except ValueError:
        sys.exit(f'Invalid port: {sys.argv[1]!r}')
    if not 0 <= PORT <= 65535:
        sys.exit(f'Port out of range (0-65535): {PORT}')
    CERT_FPATH = sys.argv[2]
    PRIVKEY_FPATH = sys.argv[3]
    # Optional 4th argument; defaults to all IPv4 interfaces as before.
    HOST = sys.argv[4] if len(sys.argv) > 4 else '0.0.0.0'

    try:
        serve(HOST, PORT, CERT_FPATH, PRIVKEY_FPATH)
    except KeyboardInterrupt:
        # Ctrl-C is the normal way to stop the server; exit without a traceback.
        pass

