citadel

My dotfiles, scripts and nix configs
git clone git://jb55.com/citadel
Log | Files | Refs | README | LICENSE

av98 (61330B)


      1 #!/usr/bin/env python3
      2 # AV-98 Gemini client
      3 # Derived from VF-1 (https://github.com/solderpunk/VF-1),
      4 # (C) 2019, 2020 Solderpunk <solderpunk@sdf.org>
      5 # With contributions from:
      6 #  - danceka <hannu.hartikainen@gmail.com>
      7 #  - <jprjr@tilde.club>
      8 #  - <vee@vnsf.xyz>
      9 #  - Klaus Alexander Seistrup <klaus@seistrup.dk>
     10 
     11 import argparse
     12 import cmd
     13 import cgi
     14 import codecs
     15 import collections
     16 import datetime
     17 import fnmatch
     18 import getpass
     19 import glob
     20 import hashlib
     21 import io
     22 import mimetypes
     23 import os
     24 import os.path
     25 import random
     26 import shlex
     27 import shutil
     28 import socket
     29 import sqlite3
     30 import ssl
     31 from ssl import CertificateError
     32 import subprocess
     33 import sys
     34 import tempfile
     35 import time
     36 import urllib.parse
     37 import uuid
     38 import webbrowser
     39 
     40 try:
     41     import ansiwrap as textwrap
     42 except ModuleNotFoundError:
     43     import textwrap
     44 
     45 try:
     46     from cryptography import x509
     47     from cryptography.hazmat.backends import default_backend
     48     _HAS_CRYPTOGRAPHY = True
     49     _BACKEND = default_backend()
     50 except ModuleNotFoundError:
     51     _HAS_CRYPTOGRAPHY = False
     52 
_VERSION = "1.0.2dev"

# Maximum number of consecutive redirects _go_to_gi will follow before it
# refuses to go any further.
_MAX_REDIRECTS = 5

# Command abbreviations: maps each short alias to the full command name it
# stands for.
_ABBREVS = {
    "a":    "add",
    "b":    "back",
    "bb":   "blackbox",
    "bm":   "bookmarks",
    "book": "bookmarks",
    "f":    "fold",
    "fo":   "forward",
    "g":    "go",
    "h":    "history",
    "hist": "history",
    "l":    "less",
    "n":    "next",
    "p":    "previous",
    "prev": "previous",
    "q":    "quit",
    "r":    "reload",
    "s":    "save",
    "se":   "search",
    "/":    "search",
    "t":    "tour",
    "u":    "up",
}

# Default external handler command templates, keyed by MIME type (keys may
# be glob patterns such as "image/*").  _go_to_gi runs the selected template
# as `cmd_str % tmpf.name`, so "%s" becomes the path of the temporary file
# holding the response body.
# NOTE(review): the lookup itself happens in _get_handler_cmd (not shown
# here); confirm whether dict order matters when several patterns match.
_MIME_HANDLERS = {
    "application/pdf":      "zathura %s",
    "audio/mpeg":           "mpg123 %s",
    "audio/ogg":            "ogg123 %s",
    "image/*":              "sxiv %s",
    "text/html":            "lynx -dump -force_html %s",
    "text/plain":           "less --quit-if-one-screen %s",
    "text/gemini":          "less --quit-if-one-screen %s",
}

# monkey-patch Gemini support in urllib.parse, so that urljoin() (used by
# GeminiItem.absolutise_url) resolves relative links against gemini:// URLs
# and the parse/unparse helpers treat gemini:// like other schemes that
# carry a network location.
# see https://github.com/python/cpython/blob/master/Lib/urllib/parse.py
urllib.parse.uses_relative.append("gemini")
urllib.parse.uses_netloc.append("gemini")
     96 
     97 
     98 def fix_ipv6_url(url):
     99     if not url.count(":") > 2: # Best way to detect them?
    100         return url
    101     # If there's a pair of []s in there, it's probably fine as is.
    102     if "[" in url and "]" in url:
    103         return url
    104     # Easiest case is a raw address, no schema, no path.
    105     # Just wrap it in square brackets and whack a slash on the end
    106     if "/" not in url:
    107         return "[" + url + "]/"
    108     # Now the trickier cases...
    109     if "://" in url:
    110         schema, schemaless = url.split("://")
    111     else:
    112         schema, schemaless = None, url
    113     if "/" in schemaless:
    114         netloc, rest = schemaless.split("/",1)
    115         schemaless = "[" + netloc + "]" + "/" + rest
    116     if schema:
    117         return schema + "://" + schemaless
    118     return schemaless
    119 
    120 standard_ports = {
    121         "gemini": 1965,
    122         "gopher": 70,
    123 }
    124 
    125 class GeminiItem():
    126 
    127     def __init__(self, url, name=""):
    128         if "://" not in url:
    129             url = "gemini://" + url
    130         self.url = fix_ipv6_url(url)
    131         self.name = name
    132         parsed = urllib.parse.urlparse(self.url)
    133         self.scheme = parsed.scheme
    134         self.host = parsed.hostname
    135         self.port = parsed.port or standard_ports.get(self.scheme, 0)
    136         self.path = parsed.path
    137 
    138     def root(self):
    139         return GeminiItem(self._derive_url("/"))
    140 
    141     def up(self):
    142         pathbits = list(os.path.split(self.path.rstrip('/')))
    143         # Don't try to go higher than root
    144         if len(pathbits) == 1:
    145             return self
    146         # Get rid of bottom component
    147         pathbits.pop()
    148         new_path = os.path.join(*pathbits)
    149         return GeminiItem(self._derive_url(new_path))
    150 
    151     def query(self, query):
    152         query = urllib.parse.quote(query)
    153         return GeminiItem(self._derive_url(query=query))
    154 
    155     def _derive_url(self, path="", query=""):
    156         """
    157         A thin wrapper around urlunparse which avoids inserting standard ports
    158         into URLs just to keep things clean.
    159         """
    160         return urllib.parse.urlunparse((self.scheme,
    161             self.host if self.port == standard_ports[self.scheme] else self.host + ":" + str(self.port),
    162             path or self.path, "", query, ""))
    163 
    164     def absolutise_url(self, relative_url):
    165         """
    166         Convert a relative URL to an absolute URL by using the URL of this
    167         GeminiItem as a base.
    168         """
    169         return urllib.parse.urljoin(self.url, relative_url)
    170 
    171     def to_map_line(self, name=None):
    172         if name or self.name:
    173             return "=> {} {}\n".format(self.url, name or self.name)
    174         else:
    175             return "=> {}\n".format(self.url)
    176 
    177     @classmethod
    178     def from_map_line(cls, line, origin_gi):
    179         assert line.startswith("=>")
    180         assert line[2:].strip()
    181         bits = line[2:].strip().split(maxsplit=1)
    182         bits[0] = origin_gi.absolutise_url(bits[0])
    183         return cls(*bits)
    184 
    185 CRLF = '\r\n'
    186 
    187 # Cheap and cheerful URL detector
    188 def looks_like_url(word):
    189     return "." in word and word.startswith("gemini://")
    190 
    191 # GeminiClient Decorators
    192 def needs_gi(inner):
    193     def outer(self, *args, **kwargs):
    194         if not self.gi:
    195             print("You need to 'go' somewhere, first")
    196             return None
    197         else:
    198             return inner(self, *args, **kwargs)
    199     outer.__doc__ = inner.__doc__
    200     return outer
    201 
    202 def restricted(inner):
    203     def outer(self, *args, **kwargs):
    204         if self.restricted:
    205             print("Sorry, this command is not available in restricted mode!")
    206             return None
    207         else:
    208             return inner(self, *args, **kwargs)
    209     outer.__doc__ = inner.__doc__
    210     return outer
    211 
    212 class GeminiClient(cmd.Cmd):
    213 
    def __init__(self, restricted=False):
        """
        Set up a new AV-98 session.

        Args:
            restricted: if true, methods guarded by @restricted refuse to run
                and client-certificate requests are declined.
        """
        cmd.Cmd.__init__(self)

        # Nothing we create should be readable by anybody else: the
        # certificate cache and TOFU database hold sensitive,
        # "browser history"-like information.
        os.umask(0o077)

        # Reuse the first config directory that already exists...
        candidates = (os.path.expanduser(path)
                      for path in ("~/.av98/", "~/.config/av98/"))
        found = next((path for path in candidates if os.path.exists(path)), None)
        if found is not None:
            self.config_dir = found
        else:
            # ...otherwise create one, under ~/.config/ when that exists.
            if os.path.exists(os.path.expanduser("~/.config/")):
                self.config_dir = os.path.expanduser("~/.config/av98/")
            else:
                self.config_dir = os.path.expanduser("~/.av98/")
            print("Creating config directory {}".format(self.config_dir))
            os.makedirs(self.config_dir)

        # ANSI-coloured prompts; "+cert" variant for use with a client cert.
        self.no_cert_prompt = "\x1b[38;5;76mAV-98\x1b[38;5;255m> \x1b[0m"
        self.cert_prompt = "\x1b[38;5;202mAV-98\x1b[38;5;255m+cert> \x1b[0m"
        self.prompt = self.no_cert_prompt

        self.restricted = restricted

        # Current location, history and saved places
        self.gi = None
        self.history = []
        self.hist_index = 0
        self.marks = {}
        self.waypoints = []
        self.visited_hosts = set()

        # Link index of the current page; lookup aliases the same list object
        self.idx_filename = ""
        self.index = []
        self.index_index = -1
        self.lookup = self.index
        self.page_index = 0
        self.tmp_filename = ""

        # Redirect bookkeeping
        self.permanent_redirects = {}
        self.previous_redirectors = set()

        # Client certificate state
        self.client_certs = dict(active=None)
        self.active_cert_domains = []
        self.active_is_transient = False
        self.transient_certs_created = []

        self.options = dict(
            debug=False,
            ipv6=True,
            timeout=10,
            width=80,
            auto_follow_redirects=True,
            gopher_proxy="localhost:1965",
            tls_mode="tofu",
        )

        self.log = dict(
            start_time=time.time(),
            requests=0,
            ipv4_requests=0,
            ipv6_requests=0,
            bytes_recvd=0,
            ipv4_bytes_recvd=0,
            ipv6_bytes_recvd=0,
            dns_failures=0,
            refused_connections=0,
            reset_connections=0,
            timeouts=0,
        )

        self._connect_to_tofu_db()
    289 
    290     def _connect_to_tofu_db(self):
    291 
    292         db_path = os.path.join(self.config_dir, "tofu.db")
    293         self.db_conn = sqlite3.connect(db_path)
    294         self.db_cur = self.db_conn.cursor()
    295 
    296         self.db_cur.execute("""CREATE TABLE IF NOT EXISTS cert_cache
    297             (hostname text, address text, fingerprint text,
    298             first_seen date, last_seen date, count integer)""")
    299 
    300     def _go_to_gi(self, gi, update_hist=True, handle=True):
    301         """This method might be considered "the heart of AV-98".
    302         Everything involved in fetching a gemini resource happens here:
    303         sending the request over the network, parsing the response if
    304         its a menu, storing the response in a temporary file, choosing
    305         and calling a handler program, and updating the history."""
    306         # Don't try to speak to servers running other protocols
    307         if gi.scheme in ("http", "https"):
    308             webbrowser.open_new_tab(gi.url)
    309             return
    310         elif gi.scheme == "gopher" and not self.options.get("gopher_proxy", None):
    311             print("""AV-98 does not speak Gopher natively.
    312 However, you can use `set gopher_proxy hostname:port` to tell it about a
    313 Gopher-to-Gemini proxy (such as a running Agena instance), in which case
    314 you'll be able to transparently follow links to Gopherspace!""")
    315             return
    316         elif gi.scheme not in ("gemini", "gopher"):
    317             print("Sorry, no support for {} links.".format(gi.scheme))
    318             return
    319         # Obey permanent redirects
    320         if gi.url in self.permanent_redirects:
    321             new_gi = GeminiItem(self.permanent_redirects[gi.url], name=gi.name)
    322             self._go_to_gi(new_gi)
    323             return
    324 
    325         # Be careful with client certificates!
    326         # Are we crossing a domain boundary?
    327         if self.active_cert_domains and gi.host not in self.active_cert_domains:
    328             if self.active_is_transient:
    329                 print("Permanently delete currently active transient certificate?")
    330                 resp = input("Y/N? ")
    331                 if resp.strip().lower() in ("y", "yes"):
    332                     print("Destroying certificate.")
    333                     self._deactivate_client_cert()
    334                 else:
    335                     print("Staying here.")
    336 
    337             else:
    338                 print("PRIVACY ALERT: Deactivate client cert before connecting to a new domain?")
    339                 resp = input("Y/N? ")
    340                 if resp.strip().lower() in ("n", "no"):
    341                     print("Keeping certificate active for {}".format(gi.host))
    342                 else:
    343                     print("Deactivating certificate.")
    344                     self._deactivate_client_cert()
    345 
    346         # Suggest reactivating previous certs
    347         if not self.client_certs["active"] and gi.host in self.client_certs:
    348             print("PRIVACY ALERT: Reactivate previously used client cert for {}?".format(gi.host))
    349             resp = input("Y/N? ")
    350             if resp.strip().lower() in ("y", "yes"):
    351                 self._activate_client_cert(*self.client_certs[gi.host])
    352             else:
    353                 print("Remaining unidentified.")
    354                 self.client_certs.pop(gi.host)
    355 
    356         # Do everything which touches the network in one block,
    357         # so we only need to catch exceptions once
    358         try:
    359             # Is this a local file?
    360             if not gi.host:
    361                 address, f = None, open(gi.path, "rb")
    362             else:
    363                 address, f = self._send_request(gi)
    364 
    365             # Spec dictates <META> should not exceed 1024 bytes,
    366             # so maximum valid header length is 1027 bytes.
    367             header = f.readline(1027)
    368             header = header.decode("UTF-8")
    369             if not header or header[-1] != '\n':
    370                 raise RuntimeError("Received invalid header from server!")
    371             header = header.strip()
    372             self._debug("Response header: %s." % header)
    373 
    374         # Catch network errors which may happen on initial connection
    375         except Exception as err:
    376             # Print an error message
    377             if isinstance(err, socket.gaierror):
    378                 self.log["dns_failures"] += 1
    379                 print("ERROR: DNS error!")
    380             elif isinstance(err, ConnectionRefusedError):
    381                 self.log["refused_connections"] += 1
    382                 print("ERROR: Connection refused!")
    383             elif isinstance(err, ConnectionResetError):
    384                 self.log["reset_connections"] += 1
    385                 print("ERROR: Connection reset!")
    386             elif isinstance(err, (TimeoutError, socket.timeout)):
    387                 self.log["timeouts"] += 1
    388                 print("""ERROR: Connection timed out!
    389 Slow internet connection?  Use 'set timeout' to be more patient.""")
    390             else:
    391                 print("ERROR: " + str(err))
    392             return
    393 
    394         # Validate header
    395         status, meta = header.split(maxsplit=1)
    396         if len(meta) > 1024 or len(status) != 2 or not status.isnumeric():
    397             print("ERROR: Received invalid header from server!")
    398             f.close()
    399             return
    400 
    401         # Update redirect loop/maze escaping state
    402         if not status.startswith("3"):
    403             self.previous_redirectors = set()
    404 
    405         # Handle non-SUCCESS headers, which don't have a response body
    406         # Inputs
    407         if status.startswith("1"):
    408             print(meta)
    409             if status == "11":
    410                 user_input = getpass.getpass("> ")
    411             else:
    412                 user_input = input("> ")
    413             self._go_to_gi(gi.query(user_input))
    414             return
    415         # Redirects
    416         elif status.startswith("3"):
    417             new_gi = GeminiItem(gi.absolutise_url(meta))
    418             if new_gi.url in self.previous_redirectors:
    419                 print("Error: caught in redirect loop!")
    420                 return
    421             elif len(self.previous_redirectors) == _MAX_REDIRECTS:
    422                 print("Error: refusing to follow more than %d consecutive redirects!" % _MAX_REDIRECTS)
    423                 return
    424             # Never follow cross-domain redirects without asking
    425             elif new_gi.host != gi.host:
    426                 follow = input("Follow cross-domain redirect to %s? (y/n) " % new_gi.url)
    427             # Never follow cross-protocol redirects without asking
    428             elif new_gi.scheme != gi.scheme:
    429                 follow = input("Follow cross-protocol redirect to %s? (y/n) " % new_gi.url)
    430             # Don't follow *any* redirect without asking if auto-follow is off
    431             elif not self.options["auto_follow_redirects"]:
    432                 follow = input("Follow redirect to %s? (y/n) " % new_gi.url)
    433             # Otherwise, follow away
    434             else:
    435                 follow = "yes"
    436             if follow.strip().lower() not in ("y", "yes"):
    437                 return
    438             self._debug("Following redirect to %s." % new_gi.url)
    439             self._debug("This is consecutive redirect number %d." % len(self.previous_redirectors))
    440             self.previous_redirectors.add(gi.url)
    441             if status == "31":
    442                 # Permanent redirect
    443                 self.permanent_redirects[gi.url] = new_gi.url
    444             self._go_to_gi(new_gi)
    445             return
    446         # Errors
    447         elif status.startswith("4") or status.startswith("5"):
    448             print("Error: %s" % meta)
    449             return
    450         # Client cert
    451         elif status.startswith("6"):
    452             # Don't do client cert stuff in restricted mode, as in principle
    453             # it could be used to fill up the disk by creating a whole lot of
    454             # certificates
    455             if self.restricted:
    456                 print("The server is requesting a client certificate.")
    457                 print("These are not supported in restricted mode, sorry.")
    458                 return
    459 
    460             # Transient certs are a special case
    461             if status == "61":
    462                 print("The server is asking to start a transient client certificate session.")
    463                 print("What do you want to do?")
    464                 print("1. Start a transient session.")
    465                 print("2. Refuse.")
    466                 choice = input("> ").strip()
    467                 if choice.strip() == "1":
    468                     self._generate_transient_cert_cert()
    469                     self._go_to_gi(gi, update_hist, handle)
    470                     return
    471                 else:
    472                     return
    473 
    474             # Present different messages for different 6x statuses, but
    475             # handle them the same.
    476             if status in ("64", "65"):
    477                 print("The server rejected your certificate because it is either expired or not yet valid.")
    478             elif status == "63":
    479                 print("The server did not accept your certificate.")
    480                 print("You may need to e.g. coordinate with the admin to get your certificate fingerprint whitelisted.")
    481             else:
    482                 print("The site {} is requesting a client certificate.".format(gi.host))
    483                 print("This will allow the site to recognise you across requests.")
    484             print("What do you want to do?")
    485             print("1. Give up.")
    486             print("2. Generate new certificate and retry the request.")
    487             print("3. Load previously generated certificate from file.")
    488             print("4. Load certificate from file and retry the request.")
    489             choice = input("> ").strip()
    490             if choice == "2":
    491                 self._generate_persistent_client_cert()
    492                 self._go_to_gi(gi, update_hist, handle)
    493             elif choice == "3":
    494                 self._choose_client_cert()
    495                 self._go_to_gi(gi, update_hist, handle)
    496             elif choice == "4":
    497                 self._load_client_cert()
    498                 self._go_to_gi(gi, update_hist, handle)
    499             else:
    500                 print("Giving up.")
    501             return
    502         # Invalid status
    503         elif not status.startswith("2"):
    504             print("ERROR: Server returned undefined status code %s!" % status)
    505             return
    506 
    507         # If we're here, this must be a success and there's a response body
    508         assert status.startswith("2")
    509 
    510         # Can we terminate a transient client session?
    511         if status == "21":
    512             # Make sure we're actually in such a session
    513             if self.active_is_transient:
    514                 self._deactivate_client_cert()
    515                 print("INFO: Server terminated transient client certificate session.")
    516             else:
    517                 # Huh, that's weird
    518                 self._debug("Server issues a 21 but we're not in transient session?")
    519 
    520         mime = meta
    521         if mime == "":
    522             mime = "text/gemini; charset=utf-8"
    523         mime, mime_options = cgi.parse_header(mime)
    524         if "charset" in mime_options:
    525             try:
    526                 codecs.lookup(mime_options["charset"])
    527             except LookupError:
    528                 print("Header declared unknown encoding %s" % value)
    529                 return
    530 
    531         # Read the response body over the network
    532         body = f.read()
    533 
    534         # Save the result in a temporary file
    535         ## Delete old file
    536         if self.tmp_filename and os.path.exists(self.tmp_filename):
    537             os.unlink(self.tmp_filename)
    538         ## Set file mode
    539         if mime.startswith("text/"):
    540             mode = "w"
    541             encoding = mime_options.get("charset", "UTF-8")
    542             try:
    543                 body = body.decode(encoding)
    544             except UnicodeError:
    545                 print("Could not decode response body using %s encoding declared in header!" % encoding)
    546                 return
    547         else:
    548             mode = "wb"
    549             encoding = None
    550         ## Write
    551         tmpf = tempfile.NamedTemporaryFile(mode, encoding=encoding, delete=False)
    552         size = tmpf.write(body)
    553         tmpf.close()
    554         self.tmp_filename = tmpf.name
    555         self._debug("Wrote %d byte response to %s." % (size, self.tmp_filename))
    556 
    557         # Pass file to handler, unless we were asked not to
    558         if handle:
    559             if mime == "text/gemini":
    560                 self._handle_index(body, gi)
    561             else:
    562                 cmd_str = self._get_handler_cmd(mime)
    563                 try:
    564                     subprocess.call(shlex.split(cmd_str % tmpf.name))
    565                 except FileNotFoundError:
    566                     print("Handler program %s not found!" % shlex.split(cmd_str)[0])
    567                     print("You can use the ! command to specify another handler program or pipeline.")
    568 
    569         # Update state
    570         self.gi = gi
    571         self.mime = mime
    572         self._log_visit(gi, address, size)
    573         if update_hist:
    574             self._update_history(gi)
    575 
    def _send_request(self, gi):
        """Send a selector to a given host and port.

        Gemini URLs are fetched from their own host and port.  Gopher URLs
        are fetched through the proxy named by the "gopher_proxy" option
        ("host:port").  Every resolved address is tried in turn until one
        connects.  Unless the "tls_mode" option is "ca", the server
        certificate is checked with self._validate_cert (TOFU).

        Returns:
            (address, f): the getaddrinfo entry that was connected to, and
            a binary file object reading the server's reply.

        Raises:
            ValueError: if gi.scheme is neither "gemini" nor "gopher".
            OSError (including ssl.SSLError), or whatever _validate_cert
            raises, on connection/TLS/certificate failure.  The socket is
            closed before the exception propagates.
        """
        if gi.scheme == "gemini":
            # For Gemini requests, connect to the host and port specified in the URL
            host, port = gi.host, gi.port
        elif gi.scheme == "gopher":
            # For Gopher requests, use the configured proxy
            host, port = self.options["gopher_proxy"].rsplit(":", 1)
            self._debug("Using gopher proxy: " + self.options["gopher_proxy"])
        else:
            # Previously this fell through to an UnboundLocalError below.
            raise ValueError("Unsupported URL scheme: {}".format(gi.scheme))

        # Do DNS resolution
        addresses = self._get_addresses(host, port)

        # Prepare TLS context
        protocol = ssl.PROTOCOL_TLS if sys.version_info.minor >=6 else ssl.PROTOCOL_TLSv1_2
        context = ssl.SSLContext(protocol)
        # Use CAs or TOFU
        if self.options["tls_mode"] == "ca":
            context.verify_mode = ssl.CERT_REQUIRED
            context.check_hostname = True
            context.load_default_certs()
        else:
            context.check_hostname = False
            context.verify_mode = ssl.CERT_NONE
        # Impose minimum TLS version
        ## In 3.7 and above, this is easy...
        if sys.version_info.minor >= 7:
            context.minimum_version = ssl.TLSVersion.TLSv1_2
        ## Otherwise, it seems very hard...
        ## The below is less strict than it ought to be, but trying to disable
        ## TLS v1.1 here using ssl.OP_NO_TLSv1_1 produces unexpected failures
        ## with recent versions of OpenSSL.  What a mess...
        else:
            context.options |= ssl.OP_NO_SSLv3
            context.options |= ssl.OP_NO_SSLv2
        # Try to enforce sensible ciphers
        try:
            context.set_ciphers("AESGCM+ECDHE:AESGCM+DHE:CHACHA20+ECDHE:CHACHA20+DHE:!DSS:!SHA1:!MD5:@STRENGTH")
        except ssl.SSLError:
            # Rely on the server to only support sensible things, I guess...
            pass
        # Load client certificate if needed
        if self.client_certs["active"]:
            certfile, keyfile = self.client_certs["active"]
            context.load_cert_chain(certfile, keyfile)

        # Connect to remote host by any address possible
        err = None
        for address in addresses:
            self._debug("Connecting to: " + str(address[4]))
            s = socket.socket(address[0], address[1])
            s.settimeout(self.options["timeout"])
            s = context.wrap_socket(s, server_hostname=gi.host)
            try:
                s.connect(address[4])
                break
            except OSError as e:
                # Don't leak the descriptor of a failed attempt.
                s.close()
                err = e
        else:
            # If we couldn't connect to *any* of the addresses, just
            # bubble up the exception from the last attempt and deny
            # knowledge of earlier failures.
            if err is None:
                raise OSError("No addresses to connect to for {}".format(host))
            raise err

        # From here on, make sure the connection is closed if anything fails;
        # on success the caller owns it through the returned file object.
        try:
            if sys.version_info.minor >=5:
                self._debug("Established {} connection.".format(s.version()))
            self._debug("Cipher is: {}.".format(s.cipher()))

            # Do TOFU
            if self.options["tls_mode"] != "ca":
                cert = s.getpeercert(binary_form=True)
                self._validate_cert(address[4][0], host, cert)

            # Remember that we showed the current cert to this domain...
            if self.client_certs["active"]:
                self.active_cert_domains.append(gi.host)
                self.client_certs[gi.host] = self.client_certs["active"]

            # Send request and wrap response in a file descriptor
            self._debug("Sending %s<CRLF>" % gi.url)
            s.sendall((gi.url + CRLF).encode("UTF-8"))
            return address, s.makefile(mode = "rb")
        except BaseException:
            # Close, then re-raise unchanged (includes KeyboardInterrupt).
            s.close()
            raise
    659 
    660     def _get_addresses(self, host, port):
    661         # DNS lookup - will get IPv4 and IPv6 records if IPv6 is enabled
    662         if ":" in host:
    663             # This is likely a literal IPv6 address, so we can *only* ask for
    664             # IPv6 addresses or getaddrinfo will complain
    665             family_mask = socket.AF_INET6
    666         elif socket.has_ipv6 and self.options["ipv6"]:
    667             # Accept either IPv4 or IPv6 addresses
    668             family_mask = 0
    669         else:
    670             # IPv4 only
    671             family_mask = socket.AF_INET
    672         addresses = socket.getaddrinfo(host, port, family=family_mask,
    673                 type=socket.SOCK_STREAM)
    674         # Sort addresses so IPv6 ones come first
    675         addresses.sort(key=lambda add: add[0] == socket.AF_INET6, reverse=True)
    676 
    677         return addresses
    678 
    def _validate_cert(self, address, host, cert):
        """
        Validate a TLS certificate in TOFU mode.

        `cert` is the certificate in DER form: it is parsed with
        x509.load_der_x509_certificate and its SHA-256 digest is the
        fingerprint stored in the TOFU database.

        If the cryptography module is installed:
         - Check the certificate Common Name or SAN matches `host`
         - Check the certificate's not valid before date is in the past
         - Check the certificate's not valid after date is in the future

        Whether the cryptography module is installed or not, check the
        certificate's fingerprint against the TOFU database to see if we've
        previously encountered a different certificate for this IP address and
        hostname.  An unrecognised certificate is shown to the user, who may
        accept it.  First-seen and accepted certificates are inserted into the
        cert_cache table and written to <config_dir>/cert_cache/<fingerprint>.crt.

        Raises CertificateError if a validity-date or hostname check fails,
        and Exception("TOFU Failure!") if the user rejects a new certificate.
        """

        def name_matches(pattern, hostname):
            # Hostname matching in the spirit of RFC 6125 section 6.4.3:
            # case-insensitive, trailing dot ignored, and a wildcard is only
            # honoured as the entire left-most label, matching exactly one
            # non-empty label.  (The previous code relied on the private
            # ssl._dnsname_match but ignored its boolean result, so every
            # certificate with any name at all was accepted.)
            pattern = pattern.lower().rstrip(".")
            hostname = hostname.lower().rstrip(".")
            if not pattern or not hostname:
                return False
            if "*" not in pattern:
                return pattern == hostname
            wild_label, _, pattern_rest = pattern.partition(".")
            host_label, _, host_rest = hostname.partition(".")
            if wild_label != "*" or not pattern_rest or "*" in pattern_rest:
                return False
            return bool(host_label) and host_rest == pattern_rest

        now = datetime.datetime.utcnow()
        if _HAS_CRYPTOGRAPHY:
            # Using the cryptography module we can get detailed access
            # to the properties of even self-signed certs, unlike in
            # the standard ssl library...
            c = x509.load_der_x509_certificate(cert, _BACKEND)

            # Check certificate validity dates
            if c.not_valid_before >= now:
                raise CertificateError("Certificate not valid until: {}!".format(c.not_valid_before))
            elif c.not_valid_after <= now:
                raise CertificateError("Certificate expired as of: {}!".format(c.not_valid_after))

            # Check certificate hostnames
            names = []
            common_name = c.subject.get_attributes_for_oid(x509.oid.NameOID.COMMON_NAME)
            if common_name:
                names.append(common_name[0].value)
            try:
                names.extend([alt.value for alt in c.extensions.get_extension_for_oid(x509.oid.ExtensionOID.SUBJECT_ALTERNATIVE_NAME).value])
            except x509.ExtensionNotFound:
                pass
            # SAN values are not always strings (e.g. IP address entries),
            # so compare their string forms.
            if not any(name_matches(str(name), host) for name in set(names)):
                raise CertificateError("Hostname does not match certificate common name or any alternative names.")

        sha = hashlib.sha256()
        sha.update(cert)
        fingerprint = sha.hexdigest()
        # Needed on every path that writes or reads a cached certificate,
        # regardless of whether the cryptography module is available.
        certdir = os.path.join(self.config_dir, "cert_cache")

        # Have we been here before?
        self.db_cur.execute("""SELECT fingerprint, first_seen, last_seen, count
            FROM cert_cache
            WHERE hostname=? AND address=?""", (host, address))
        cached_certs = self.db_cur.fetchall()

        # If not, cache this cert
        if not cached_certs:
            self._debug("TOFU: Blindly trusting first ever certificate for this host!")
            self.db_cur.execute("""INSERT INTO cert_cache
                VALUES (?, ?, ?, ?, ?, ?)""",
                (host, address, fingerprint, now, now, 1))
            self.db_conn.commit()
            os.makedirs(certdir, exist_ok=True)
            with open(os.path.join(certdir, fingerprint+".crt"), "wb") as fp:
                fp.write(cert)
            return

        # If so, check for a match, tracking the most frequently seen cert
        # in case we need to warn about a mismatch.
        max_count = 0
        most_frequent_cert = None
        for cached_fingerprint, _first, _last, count in cached_certs:
            if count > max_count:
                max_count = count
                most_frequent_cert = cached_fingerprint
            if fingerprint == cached_fingerprint:
                # Matched!
                self._debug("TOFU: Accepting previously seen ({} times) certificate {}".format(count, fingerprint))
                self.db_cur.execute("""UPDATE cert_cache
                    SET last_seen=?, count=?
                    WHERE hostname=? AND address=? AND fingerprint=?""",
                    (now, count+1, host, address, fingerprint))
                self.db_conn.commit()
                return

        # Load the most frequently seen certificate to see if it has expired.
        # If its cached copy is missing or unreadable we simply can't say.
        previous_ttl = None
        if _HAS_CRYPTOGRAPHY and most_frequent_cert:
            try:
                with open(os.path.join(certdir, most_frequent_cert+".crt"), "rb") as fp:
                    previous_cert = x509.load_der_x509_certificate(fp.read(), _BACKEND)
                previous_ttl = previous_cert.not_valid_after - now
            except (OSError, ValueError):
                previous_ttl = None

        self._debug("TOFU: Unrecognised certificate {}!  Raising the alarm...".format(fingerprint))
        print("****************************************")
        print("[SECURITY WARNING] Unrecognised certificate!")
        print("The certificate presented for {} ({}) has never been seen before.".format(host, address))
        print("This MIGHT be a Man-in-the-Middle attack.")
        print("A different certificate has previously been seen {} times.".format(max_count))
        if previous_ttl is not None:
            if previous_ttl < datetime.timedelta():
                print("That certificate has expired, which reduces suspicion somewhat.")
            else:
                print("That certificate is still valid for: {}".format(previous_ttl))
        print("****************************************")
        print("Attempt to verify the new certificate fingerprint out-of-band:")
        print(fingerprint)
        choice = input("Accept this new certificate? Y/N ").strip().lower()
        if choice in ("y", "yes"):
            self.db_cur.execute("""INSERT INTO cert_cache
                VALUES (?, ?, ?, ?, ?, ?)""",
                (host, address, fingerprint, now, now, 1))
            self.db_conn.commit()
            os.makedirs(certdir, exist_ok=True)
            with open(os.path.join(certdir, fingerprint+".crt"), "wb") as fp:
                fp.write(cert)
        else:
            raise Exception("TOFU Failure!")
    801 
    def _get_handler_cmd(self, mimetype):
        """
        Return the command template registered in _MIME_HANDLERS for
        `mimetype`.

        Patterns without "*" are tried before wildcard patterns, each group
        in dictionary order, using fnmatch.  If nothing matches, the
        "xdg-open %s" fallback is returned.
        """
        # sorted() is stable and False < True, so exact patterns keep their
        # relative order and all come before the wildcard ones.
        candidates = sorted(_MIME_HANDLERS.items(),
                            key=lambda entry: "*" in entry[0])
        for pattern, command in candidates:
            if fnmatch.fnmatch(mimetype, pattern):
                break
        else:
            # Use "xdg-open" as a last resort.
            command = "xdg-open %s"
        self._debug("Using handler: %s" % command)
        return command
    820 
    821     def _handle_index(self, body, menu_gi, display=True):
    822         self.index = []
    823         preformatted = False
    824         if self.idx_filename:
    825             os.unlink(self.idx_filename)
    826         tmpf = tempfile.NamedTemporaryFile("w", encoding="UTF-8", delete=False)
    827         self.idx_filename = tmpf.name
    828         for line in body.splitlines():
    829             if line.startswith("```"):
    830                 preformatted = not preformatted
    831             elif preformatted:
    832                 tmpf.write(line + "\n")
    833             elif line.startswith("=>"):
    834                 try:
    835                     gi = GeminiItem.from_map_line(line, menu_gi)
    836                     self.index.append(gi)
    837                     tmpf.write(self._format_geminiitem(len(self.index), gi) + "\n")
    838                 except:
    839                     self._debug("Skipping possible link: %s" % line)
    840             elif line.startswith("* "):
    841                 line = line[1:].lstrip("\t ")
    842                 tmpf.write(textwrap.fill(line, self.options["width"],
    843                     initial_indent = "• ", subsequent_indent="  ") + "\n")
    844             elif line.startswith(">"):
    845                 line = line[1:].lstrip("\t ")
    846                 tmpf.write(textwrap.fill(line, self.options["width"],
    847                     initial_indent = "> ", subsequent_indent="> ") + "\n")
    848             elif line.startswith("###"):
    849                 line = line[3:].lstrip("\t ")
    850                 tmpf.write("\x1b[4m" + line + "\x1b[0m""\n")
    851             elif line.startswith("##"):
    852                 line = line[2:].lstrip("\t ")
    853                 tmpf.write("\x1b[1m" + line + "\x1b[0m""\n")
    854             elif line.startswith("#"):
    855                 line = line[1:].lstrip("\t ")
    856                 tmpf.write("\x1b[1m\x1b[4m" + line + "\x1b[0m""\n")
    857             else:
    858                 tmpf.write(textwrap.fill(line, self.options["width"]) + "\n")
    859         tmpf.close()
    860 
    861         self.lookup = self.index
    862         self.page_index = 0
    863         self.index_index = -1
    864 
    865         if display:
    866             cmd_str = _MIME_HANDLERS["text/plain"]
    867             subprocess.call(shlex.split(cmd_str % self.idx_filename))
    868 
    869     def _format_geminiitem(self, index, gi, url=False):
    870         line = "[%d] %s" % (index, gi.name or gi.url)
    871         if gi.name and url:
    872             line += " (%s)" % gi.url
    873         return line
    874 
    875     def _show_lookup(self, offset=0, end=None, url=False):
    876         for n, gi in enumerate(self.lookup[offset:end]):
    877             print(self._format_geminiitem(n+offset+1, gi, url))
    878 
    879     def _update_history(self, gi):
    880         # Don't duplicate
    881         if self.history and self.history[self.hist_index] == gi:
    882             return
    883         self.history = self.history[0:self.hist_index+1]
    884         self.history.append(gi)
    885         self.hist_index = len(self.history) - 1
    886 
    887     def _log_visit(self, gi, address, size):
    888         if not address:
    889             return
    890         self.log["requests"] += 1
    891         self.log["bytes_recvd"] += size
    892         self.visited_hosts.add(address)
    893         if address[0] == socket.AF_INET:
    894             self.log["ipv4_requests"] += 1
    895             self.log["ipv4_bytes_recvd"] += size
    896         elif address[0] == socket.AF_INET6:
    897             self.log["ipv6_requests"] += 1
    898             self.log["ipv6_bytes_recvd"] += size
    899 
    900     def _get_active_tmpfile(self):
    901         if self.mime == "text/gemini":
    902             return self.idx_filename
    903         else:
    904             return self.tmp_filename
    905 
    906     def _debug(self, debug_text):
    907         if not self.options["debug"]:
    908             return
    909         debug_text = "\x1b[0;32m[DEBUG] " + debug_text + "\x1b[0m"
    910         print(debug_text)
    911 
    912     def _load_client_cert(self):
    913         """
    914         Interactively load a TLS client certificate from the filesystem in PEM
    915         format.
    916         """
    917         print("Loading client certificate file, in PEM format (blank line to cancel)")
    918         certfile = input("Certfile path: ").strip()
    919         if not certfile:
    920             print("Aborting.")
    921             return
    922         elif not os.path.exists(certfile):
    923             print("Certificate file {} does not exist.".format(certfile))
    924             return
    925         print("Loading private key file, in PEM format (blank line to cancel)")
    926         keyfile = input("Keyfile path: ").strip()
    927         if not keyfile:
    928             print("Aborting.")
    929             return
    930         elif not os.path.exists(keyfile):
    931             print("Private key file {} does not exist.".format(keyfile))
    932             return
    933         self._activate_client_cert(certfile, keyfile)
    934 
    935     def _generate_transient_cert_cert(self):
    936         """
    937         Use `openssl` command to generate a new transient client certificate
    938         with 24 hours of validity.
    939         """
    940         certdir = os.path.join(self.config_dir, "transient_certs")
    941         name = str(uuid.uuid4())
    942         self._generate_client_cert(certdir, name, transient=True)
    943         self.active_is_transient = True
    944         self.transient_certs_created.append(name)
    945 
    946     def _generate_persistent_client_cert(self):
    947         """
    948         Interactively use `openssl` command to generate a new persistent client
    949         certificate with one year of validity.
    950         """
    951         print("What do you want to name this new certificate?")
    952         print("Answering `mycert` will create `~/.av98/certs/mycert.crt` and `~/.av98/certs/mycert.key`")
    953         name = input()
    954         if not name.strip():
    955             print("Aborting.")
    956             return
    957         certdir = os.path.join(self.config_dir, "client_certs")
    958         self._generate_client_cert(certdir, name)
    959 
    960     def _generate_client_cert(self, certdir, basename, transient=False):
    961         """
    962         Use `openssl` binary to generate a client certificate (which may be
    963         transient or persistent) and save the certificate and private key to the
    964         specified directory with the specified basename.
    965         """
    966         if not os.path.exists(certdir):
    967             os.makedirs(certdir)
    968         certfile = os.path.join(certdir, basename+".crt")
    969         keyfile = os.path.join(certdir, basename+".key")
    970         cmd = "openssl req -x509 -newkey rsa:2048 -days {} -nodes -keyout {} -out {}".format(1 if transient else 365, keyfile, certfile)
    971         if transient:
    972             cmd += " -subj '/CN={}'".format(basename)
    973         os.system(cmd)
    974         self._activate_client_cert(certfile, keyfile)
    975 
    976     def _choose_client_cert(self):
    977         """
    978         Interactively select a previously generated client certificate and
    979         activate it.
    980         """
    981         certdir = os.path.join(self.config_dir, "client_certs")
    982         certs = glob.glob(os.path.join(certdir, "*.crt"))
    983         certdir = {}
    984         for n, cert in enumerate(certs):
    985             certdir[str(n+1)] = (cert, os.path.splitext(cert)[0] + ".key")
    986             print("{}. {}".format(n+1, os.path.splitext(os.path.basename(cert))[0]))
    987         choice = input("> ").strip()
    988         if choice in certdir:
    989             certfile, keyfile = certdir[choice]
    990             self._activate_client_cert(certfile, keyfile)
    991         else:
    992             print("What?")
    993 
    994     def _activate_client_cert(self, certfile, keyfile):
    995         self.client_certs["active"] = (certfile, keyfile)
    996         self.active_cert_domains = []
    997         self.prompt = self.cert_prompt
    998         self._debug("Using ID {} / {}.".format(*self.client_certs["active"]))
    999 
   1000     def _deactivate_client_cert(self):
   1001         if self.active_is_transient:
   1002             for filename in self.client_certs["active"]:
   1003                 os.remove(filename)
   1004             for domain in self.active_cert_domains:
   1005                 self.client_certs.pop(domain)
   1006         self.client_certs["active"] = None
   1007         self.active_cert_domains = []
   1008         self.prompt = self.no_cert_prompt
   1009         self.active_is_transient = False
   1010 
   1011     # Cmd implementation follows
   1012 
   1013     def default(self, line):
   1014         if line.strip() == "EOF":
   1015             return self.onecmd("quit")
   1016         elif line.strip() == "..":
   1017             return self.do_up()
   1018         elif line.startswith("/"):
   1019             return self.do_search(line[1:])
   1020 
   1021         # Expand abbreviated commands
   1022         first_word = line.split()[0].strip()
   1023         if first_word in _ABBREVS:
   1024             full_cmd = _ABBREVS[first_word]
   1025             expanded = line.replace(first_word, full_cmd, 1)
   1026             return self.onecmd(expanded)
   1027 
   1028         # Try to parse numerical index for lookup table
   1029         try:
   1030             n = int(line.strip())
   1031         except ValueError:
   1032             print("What?")
   1033             return
   1034 
   1035         try:
   1036             gi = self.lookup[n-1]
   1037         except IndexError:
   1038             print ("Index too high!")
   1039             return
   1040 
   1041         self.index_index = n
   1042         self._go_to_gi(gi)
   1043 
   1044     ### Settings
   1045     @restricted
   1046     def do_set(self, line):
   1047         """View or set various options."""
   1048         if not line.strip():
   1049             # Show all current settings
   1050             for option in sorted(self.options.keys()):
   1051                 print("%s   %s" % (option, self.options[option]))
   1052         elif len(line.split()) == 1:
   1053             # Show current value of one specific setting
   1054             option = line.strip()
   1055             if option in self.options:
   1056                 print("%s   %s" % (option, self.options[option]))
   1057             else:
   1058                 print("Unrecognised option %s" % option)
   1059         else:
   1060             # Set value of one specific setting
   1061             option, value = line.split(" ", 1)
   1062             if option not in self.options:
   1063                 print("Unrecognised option %s" % option)
   1064                 return
   1065             # Validate / convert values
   1066             if option == "gopher_proxy":
   1067                 if ":" not in value:
   1068                     value += ":1965"
   1069                 else:
   1070                     host, port = value.rsplit(":",1)
   1071                     if not port.isnumeric():
   1072                         print("Invalid proxy port %s" % port)
   1073                         return
   1074             elif option == "tls_mode":
   1075                 if value.lower() not in ("ca", "tofu"):
   1076                     print("TLS mode must be `ca` or `tofu`!")
   1077                     return
   1078             elif value.isnumeric():
   1079                 value = int(value)
   1080             elif value.lower() == "false":
   1081                 value = False
   1082             elif value.lower() == "true":
   1083                 value = True
   1084             else:
   1085                 try:
   1086                     value = float(value)
   1087                 except ValueError:
   1088                     pass
   1089             self.options[option] = value
   1090 
   1091     @restricted
   1092     def do_cert(self, line):
   1093         """Manage client certificates"""
   1094         print("Managing client certificates")
   1095         if self.client_certs["active"]:
   1096             print("Active certificate: {}".format(self.client_certs["active"][0]))
   1097         print("1. Deactivate client certificate.")
   1098         print("2. Generate new certificate.")
   1099         print("3. Load previously generated certificate.")
   1100         print("4. Load externally created client certificate from file.")
   1101         print("Enter blank line to exit certificate manager.")
   1102         choice = input("> ").strip()
   1103         if choice == "1":
   1104             print("Deactivating client certificate.")
   1105             self._deactivate_client_cert()
   1106         elif choice == "2":
   1107             self._generate_persistent_client_cert()
   1108         elif choice == "3":
   1109             self._choose_client_cert()
   1110         elif choice == "4":
   1111             self._load_client_cert()
   1112         else:
   1113             print("Aborting.")
   1114 
   1115     @restricted
   1116     def do_handler(self, line):
   1117         """View or set handler commands for different MIME types."""
   1118         if not line.strip():
   1119             # Show all current handlers
   1120             for mime in sorted(_MIME_HANDLERS.keys()):
   1121                 print("%s   %s" % (mime, _MIME_HANDLERS[mime]))
   1122         elif len(line.split()) == 1:
   1123             mime = line.strip()
   1124             if mime in _MIME_HANDLERS:
   1125                 print("%s   %s" % (mime, _MIME_HANDLERS[mime]))
   1126             else:
   1127                 print("No handler set for MIME type %s" % mime)
   1128         else:
   1129             mime, handler = line.split(" ", 1)
   1130             _MIME_HANDLERS[mime] = handler
   1131             if "%s" not in handler:
   1132                 print("Are you sure you don't want to pass the filename to the handler?")
   1133 
   1134     def do_abbrevs(self, *args):
   1135         """Print all AV-98 command abbreviations."""
   1136         header = "Command Abbreviations:"
   1137         self.stdout.write("\n{}\n".format(str(header)))
   1138         if self.ruler:
   1139             self.stdout.write("{}\n".format(str(self.ruler * len(header))))
   1140         for k, v in _ABBREVS.items():
   1141             self.stdout.write("{:<7}  {}\n".format(k, v))
   1142         self.stdout.write("\n")
   1143 
   1144     ### Stuff for getting around
   1145     def do_go(self, line):
   1146         """Go to a gemini URL or marked item."""
   1147         line = line.strip()
   1148         if not line:
   1149             print("Go where?")
   1150         # First, check for possible marks
   1151         elif line in self.marks:
   1152             gi = self.marks[line]
   1153             self._go_to_gi(gi)
   1154         # or a local file
   1155         elif os.path.exists(os.path.expanduser(line)):
   1156             gi = GeminiItem(None, None, os.path.expanduser(line),
   1157                             "1", line, False)
   1158             self._go_to_gi(gi)
   1159         # If this isn't a mark, treat it as a URL
   1160         else:
   1161             self._go_to_gi(GeminiItem(line))
   1162 
   1163     @needs_gi
   1164     def do_reload(self, *args):
   1165         """Reload the current URL."""
   1166         self._go_to_gi(self.gi)
   1167 
   1168     @needs_gi
   1169     def do_up(self, *args):
   1170         """Go up one directory in the path."""
   1171         self._go_to_gi(self.gi.up())
   1172 
   1173     def do_back(self, *args):
   1174         """Go back to the previous gemini item."""
   1175         if not self.history or self.hist_index == 0:
   1176             return
   1177         self.hist_index -= 1
   1178         gi = self.history[self.hist_index]
   1179         self._go_to_gi(gi, update_hist=False)
   1180 
   1181     def do_forward(self, *args):
   1182         """Go forward to the next gemini item."""
   1183         if not self.history or self.hist_index == len(self.history) - 1:
   1184             return
   1185         self.hist_index += 1
   1186         gi = self.history[self.hist_index]
   1187         self._go_to_gi(gi, update_hist=False)
   1188 
   1189     def do_next(self, *args):
   1190         """Go to next item after current in index."""
   1191         return self.onecmd(str(self.index_index+1))
   1192 
   1193     def do_previous(self, *args):
   1194         """Go to previous item before current in index."""
   1195         self.lookup = self.index
   1196         return self.onecmd(str(self.index_index-1))
   1197 
   1198     @needs_gi
   1199     def do_root(self, *args):
   1200         """Go to root selector of the server hosting current item."""
   1201         self._go_to_gi(self.gi.root())
   1202 
   1203     def do_tour(self, line):
   1204         """Add index items as waypoints on a tour, which is basically a FIFO
   1205 queue of gemini items.
   1206 
   1207 Items can be added with `tour 1 2 3 4` or ranges like `tour 1-4`.
   1208 All items in current menu can be added with `tour *`.
   1209 Current tour can be listed with `tour ls` and scrubbed with `tour clear`."""
   1210         line = line.strip()
   1211         if not line:
   1212             # Fly to next waypoint on tour
   1213             if not self.waypoints:
   1214                 print("End of tour.")
   1215             else:
   1216                 gi = self.waypoints.pop(0)
   1217                 self._go_to_gi(gi)
   1218         elif line == "ls":
   1219             old_lookup = self.lookup
   1220             self.lookup = self.waypoints
   1221             self._show_lookup()
   1222             self.lookup = old_lookup
   1223         elif line == "clear":
   1224             self.waypoints = []
   1225         elif line == "*":
   1226             self.waypoints.extend(self.lookup)
   1227         elif looks_like_url(line):
   1228             self.waypoints.append(GeminiItem(line))
   1229         else:
   1230             for index in line.split():
   1231                 try:
   1232                     pair = index.split('-')
   1233                     if len(pair) == 1:
   1234                         # Just a single index
   1235                         n = int(index)
   1236                         gi = self.lookup[n-1]
   1237                         self.waypoints.append(gi)
   1238                     elif len(pair) == 2:
   1239                         # Two endpoints for a range of indices
   1240                         for n in range(int(pair[0]), int(pair[1]) + 1):
   1241                             gi = self.lookup[n-1]
   1242                             self.waypoints.append(gi)
   1243                     else:
   1244                         # Syntax error
   1245                         print("Invalid use of range syntax %s, skipping" % index)
   1246                 except ValueError:
   1247                     print("Non-numeric index %s, skipping." % index)
   1248                 except IndexError:
   1249                     print("Invalid index %d, skipping." % n)
   1250 
   1251     @needs_gi
   1252     def do_mark(self, line):
   1253         """Mark the current item with a single letter.  This letter can then
   1254 be passed to the 'go' command to return to the current item later.
   1255 Think of it like marks in vi: 'mark a'='ma' and 'go a'=''a'."""
   1256         line = line.strip()
   1257         if not line:
   1258             for mark, gi in self.marks.items():
   1259                 print("[%s] %s (%s)" % (mark, gi.name, gi.url))
   1260         elif line.isalpha() and len(line) == 1:
   1261             self.marks[line] = self.gi
   1262         else:
   1263             print("Invalid mark, must be one letter")
   1264 
   1265     def do_version(self, line):
   1266         """Display version information."""
   1267         print("AV-98 " + _VERSION)
   1268 
   1269     ### Stuff that modifies the lookup table
   1270     def do_ls(self, line):
   1271         """List contents of current index.
   1272 Use 'ls -l' to see URLs."""
   1273         self.lookup = self.index
   1274         self._show_lookup(url = "-l" in line)
   1275         self.page_index = 0
   1276 
   1277     def do_gus(self, line):
   1278         """Submit a search query to the GUS search engine."""
   1279         gus = GeminiItem("gemini://gus.guru/search")
   1280         self._go_to_gi(gus.query(line))
   1281 
   1282     def do_history(self, *args):
   1283         """Display history."""
   1284         self.lookup = self.history
   1285         self._show_lookup(url=True)
   1286         self.page_index = 0
   1287 
   1288     def do_search(self, searchterm):
   1289         """Search index (case insensitive)."""
   1290         results = [
   1291             gi for gi in self.lookup if searchterm.lower() in gi.name.lower()]
   1292         if results:
   1293             self.lookup = results
   1294             self._show_lookup()
   1295             self.page_index = 0
   1296         else:
   1297             print("No results found.")
   1298 
   1299     def emptyline(self):
   1300         """Page through index ten lines at a time."""
   1301         i = self.page_index
   1302         if i > len(self.lookup):
   1303             return
   1304         self._show_lookup(offset=i, end=i+10)
   1305         self.page_index += 10
   1306 
   1307     ### Stuff that does something to most recently viewed item
   1308     @needs_gi
   1309     def do_cat(self, *args):
   1310         """Run most recently visited item through "cat" command."""
   1311         subprocess.call(shlex.split("cat %s" % self._get_active_tmpfile()))
   1312 
   1313     @needs_gi
   1314     def do_less(self, *args):
   1315         """Run most recently visited item through "less" command."""
   1316         cmd_str = self._get_handler_cmd(self.mime)
   1317         cmd_str = cmd_str % self._get_active_tmpfile()
   1318         subprocess.call("%s | less -R" % cmd_str, shell=True)
   1319 
   1320     @needs_gi
   1321     def do_fold(self, *args):
   1322         """Run most recently visited item through "fold" command."""
   1323         cmd_str = self._get_handler_cmd(self.mime)
   1324         cmd_str = cmd_str % self._get_active_tmpfile()
   1325         subprocess.call("%s | fold -w 70 -s" % cmd_str, shell=True)
   1326 
   1327     @restricted
   1328     @needs_gi
   1329     def do_shell(self, line):
   1330         """'cat' most recently visited item through a shell pipeline."""
   1331         subprocess.call(("cat %s |" % self._get_active_tmpfile()) + line, shell=True)
   1332 
   1333     @restricted
   1334     @needs_gi
   1335     def do_save(self, line):
   1336         """Save an item to the filesystem.
   1337 'save n filename' saves menu item n to the specified filename.
   1338 'save filename' saves the last viewed item to the specified filename.
   1339 'save n' saves menu item n to an automagic filename."""
   1340         args = line.strip().split()
   1341 
   1342         # First things first, figure out what our arguments are
   1343         if len(args) == 0:
   1344             # No arguments given at all
   1345             # Save current item, if there is one, to a file whose name is
   1346             # inferred from the gemini path
   1347             if not self.tmp_filename:
   1348                 print("You need to visit an item first!")
   1349                 return
   1350             else:
   1351                 index = None
   1352                 filename = None
   1353         elif len(args) == 1:
   1354             # One argument given
   1355             # If it's numeric, treat it as an index, and infer the filename
   1356             try:
   1357                 index = int(args[0])
   1358                 filename = None
   1359             # If it's not numeric, treat it as a filename and
   1360             # save the current item
   1361             except ValueError:
   1362                 index = None
   1363                 filename = os.path.expanduser(args[0])
   1364         elif len(args) == 2:
   1365             # Two arguments given
   1366             # Treat first as an index and second as filename
   1367             index, filename = args
   1368             try:
   1369                 index = int(index)
   1370             except ValueError:
   1371                 print("First argument is not a valid item index!")
   1372                 return
   1373             filename = os.path.expanduser(filename)
   1374         else:
   1375             print("You must provide an index, a filename, or both.")
   1376             return
   1377 
   1378         # Next, fetch the item to save, if it's not the current one.
   1379         if index:
   1380             last_gi = self.gi
   1381             try:
   1382                 gi = self.lookup[index-1]
   1383                 self._go_to_gi(gi, update_hist = False, handle = False)
   1384             except IndexError:
   1385                 print ("Index too high!")
   1386                 self.gi = last_gi
   1387                 return
   1388         else:
   1389             gi = self.gi
   1390 
   1391         # Derive filename from current GI's path, if one hasn't been set
   1392         if not filename:
   1393             filename = os.path.basename(gi.path)
   1394 
   1395         # Check for filename collisions and actually do the save if safe
   1396         if os.path.exists(filename):
   1397             print("File %s already exists!" % filename)
   1398         else:
   1399             # Don't use _get_active_tmpfile() here, because we want to save the
   1400             # "source code" of menus, not the rendered view - this way AV-98
   1401             # can navigate to it later.
   1402             shutil.copyfile(self.tmp_filename, filename)
   1403             print("Saved to %s" % filename)
   1404 
   1405         # Restore gi if necessary
   1406         if index != None:
   1407             self._go_to_gi(last_gi, handle=False)
   1408 
   1409     @needs_gi
   1410     def do_url(self, *args):
   1411         """Print URL of most recently visited item."""
   1412         print(self.gi.url)
   1413 
   1414     ### Bookmarking stuff
   1415     @restricted
   1416     @needs_gi
   1417     def do_add(self, line):
   1418         """Add the current URL to the bookmarks menu.
   1419 Optionally, specify the new name for the bookmark."""
   1420         with open(os.path.join(self.config_dir, "bookmarks.gmi"), "a") as fp:
   1421             fp.write(self.gi.to_map_line(line))
   1422 
   1423     def do_bookmarks(self, line):
   1424         """Show or access the bookmarks menu.
   1425 'bookmarks' shows all bookmarks.
   1426 'bookmarks n' navigates immediately to item n in the bookmark menu.
   1427 Bookmarks are stored using the 'add' command."""
   1428         bm_file = os.path.join(self.config_dir, "bookmarks.gmi")
   1429         if not os.path.exists(bm_file):
   1430             print("You need to 'add' some bookmarks, first!")
   1431             return
   1432         args = line.strip()
   1433         if len(args.split()) > 1 or (args and not args.isnumeric()):
   1434             print("bookmarks command takes a single integer argument!")
   1435             return
   1436         with open(bm_file, "r") as fp:
   1437             body = fp.read()
   1438             gi = GeminiItem("localhost/" + bm_file)
   1439             self._handle_index(body, gi, display = not args)
   1440             if args:
   1441                 # Use argument as a numeric index
   1442                 self.default(line)
   1443 
   1444     ### Help
   1445     def do_help(self, arg):
   1446         """ALARM! Recursion detected! ALARM! Prepare to eject!"""
   1447         if arg == "!":
   1448             print("! is an alias for 'shell'")
   1449         elif arg == "?":
   1450             print("? is an alias for 'help'")
   1451         else:
   1452             cmd.Cmd.do_help(self, arg)
   1453 
   1454     ### Flight recorder
   1455     def do_blackbox(self, *args):
   1456         """Display contents of flight recorder, showing statistics for the
   1457 current gemini browsing session."""
   1458         lines = []
   1459         # Compute flight time
   1460         now = time.time()
   1461         delta = now - self.log["start_time"]
   1462         hours, remainder = divmod(delta, 3600)
   1463         minutes, seconds = divmod(remainder, 60)
   1464         # Count hosts
   1465         ipv4_hosts = len([host for host in self.visited_hosts if host[0] == socket.AF_INET])
   1466         ipv6_hosts = len([host for host in self.visited_hosts if host[0] == socket.AF_INET6])
   1467         # Assemble lines
   1468         lines.append(("Patrol duration", "%02d:%02d:%02d" % (hours, minutes, seconds)))
   1469         lines.append(("Requests sent:", self.log["requests"]))
   1470         lines.append(("   IPv4 requests:", self.log["ipv4_requests"]))
   1471         lines.append(("   IPv6 requests:", self.log["ipv6_requests"]))
   1472         lines.append(("Bytes received:", self.log["bytes_recvd"]))
   1473         lines.append(("   IPv4 bytes:", self.log["ipv4_bytes_recvd"]))
   1474         lines.append(("   IPv6 bytes:", self.log["ipv6_bytes_recvd"]))
   1475         lines.append(("Unique hosts visited:", len(self.visited_hosts)))
   1476         lines.append(("   IPv4 hosts:", ipv4_hosts))
   1477         lines.append(("   IPv6 hosts:", ipv6_hosts))
   1478         lines.append(("DNS failures:", self.log["dns_failures"]))
   1479         lines.append(("Timeouts:", self.log["timeouts"]))
   1480         lines.append(("Refused connections:", self.log["refused_connections"]))
   1481         lines.append(("Reset connections:", self.log["reset_connections"]))
   1482         # Print
   1483         for key, value in lines:
   1484             print(key.ljust(24) + str(value).rjust(8))
   1485 
   1486     ### The end!
   1487     def do_quit(self, *args):
   1488         """Exit AV-98."""
   1489         # Close TOFU DB
   1490         self.db_conn.commit()
   1491         self.db_conn.close()
   1492         # Clean up after ourself
   1493         if self.tmp_filename and os.path.exists(self.tmp_filename):
   1494             os.unlink(self.tmp_filename)
   1495         if self.idx_filename and os.path.exists(self.idx_filename):
   1496             os.unlink(self.idx_filename)
   1497         for cert in self.transient_certs_created:
   1498             for ext in (".crt", ".key"):
   1499                 certfile = os.path.join(self.config_dir, "transient_certs", cert+ext)
   1500                 if os.path.exists(certfile):
   1501                     os.remove(certfile)
   1502         print()
   1503         print("Thank you for flying AV-98!")
   1504         sys.exit()
   1505 
   1506     do_exit = do_quit
   1507 
   1508 # Main function
   1509 def main():
   1510 
   1511     # Parse args
   1512     parser = argparse.ArgumentParser(description='A command line gemini client.')
   1513     parser.add_argument('--bookmarks', action='store_true',
   1514                         help='start with your list of bookmarks')
   1515     parser.add_argument('--tls-cert', metavar='FILE', help='TLS client certificate file')
   1516     parser.add_argument('--tls-key', metavar='FILE', help='TLS client certificate private key file')
   1517     parser.add_argument('--restricted', action="store_true", help='Disallow shell, add, and save commands')
   1518     parser.add_argument('--version', action='store_true',
   1519                         help='display version information and quit')
   1520     parser.add_argument('url', metavar='URL', nargs='*',
   1521                         help='start with this URL')
   1522     args = parser.parse_args()
   1523 
   1524     # Handle --version
   1525     if args.version:
   1526         print("AV-98 " + _VERSION)
   1527         sys.exit()
   1528 
   1529     # Instantiate client
   1530     gc = GeminiClient(args.restricted)
   1531 
   1532     # Process config file
   1533     rcfile = os.path.join(gc.config_dir, "av98rc")
   1534     if os.path.exists(rcfile):
   1535         print("Using config %s" % rcfile)
   1536         with open(rcfile, "r") as fp:
   1537             for line in fp:
   1538                 line = line.strip()
   1539                 if ((args.bookmarks or args.url) and
   1540                     any((line.startswith(x) for x in ("go", "g", "tour", "t")))
   1541                    ):
   1542                     if args.bookmarks:
   1543                         print("Skipping rc command \"%s\" due to --bookmarks option." % line)
   1544                     else:
   1545                         print("Skipping rc command \"%s\" due to provided URLs." % line)
   1546                     continue
   1547                 gc.cmdqueue.append(line)
   1548 
   1549     # Say hi
   1550     print("Welcome to AV-98!")
   1551     if args.restricted:
   1552         print("Restricted mode engaged!")
   1553     print("Enjoy your patrol through Geminispace...")
   1554 
   1555     # Act on args
   1556     if args.tls_cert:
   1557         # If tls_key is None, python will attempt to load the key from tls_cert.
   1558         gc._activate_client_cert(args.tls_cert, args.tls_key)
   1559     if args.bookmarks:
   1560         gc.cmdqueue.append("bookmarks")
   1561     elif args.url:
   1562         if len(args.url) == 1:
   1563             gc.cmdqueue.append("go %s" % args.url[0])
   1564         else:
   1565             for url in args.url:
   1566                 if not url.startswith("gemini://"):
   1567                     url = "gemini://" + url
   1568                 gc.cmdqueue.append("tour %s" % url)
   1569             gc.cmdqueue.append("tour")
   1570 
   1571     # Endless interpret loop
   1572     while True:
   1573         try:
   1574             gc.cmdloop()
   1575         except KeyboardInterrupt:
   1576             print("")
   1577 
   1578 if __name__ == '__main__':
   1579     main()