clortho.py (6314B)
# clortho - A simple key/value server
#
# Copyright 2014-2018 by Brian C. Lane <bcl@brianlane.com>
# All Rights Reserved
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# Author(s): Brian C. Lane <bcl@brianlane.com>
#
from typing import cast, Tuple, Dict

import os
import asyncio
from asyncio.base_events import Server
from asyncio import AbstractEventLoop
import html
import signal
from argparse import ArgumentParser, Namespace
import pickle
import tempfile
from aiohttp import web

# client address -> {key name -> value}
KeystoreType = Dict[str, Dict[str, str]]

VERSION = "1.1.0" # type: str

# How often (in seconds) the keystore is written to disk in the background.
SAVE_INTERVAL = 3600 # type: int


def get_client(request: web.Request) -> str:
    """Return the address used to identify the client making *request*.

    If an X-Forwarded-For header is present, the first (left-most) address in it
    is used, with an IPv4-mapped IPv6 prefix ("::ffff:") stripped. Otherwise the
    host part of the transport's peer address is used. Returns "" when neither
    is available.

    NOTE(review): X-Forwarded-For is supplied by the client, so a client that can
    reach this server directly can claim another client's identity (and read or
    overwrite its keys). This is only safe behind a proxy that sets the header.
    """
    client = "" # type: str
    if "X-Forwarded-For" in request.headers:
        client = request.headers["X-Forwarded-For"].split(",")[0].strip()
        if client.startswith("::ffff:"):
            client = client[7:]
    else:
        transport = request.transport
        if transport is not None:
            peername = transport.get_extra_info('peername') # type: Tuple[str, ...]
            if peername is not None:
                # IPv4 peers are (host, port) but IPv6 peers are
                # (host, port, flowinfo, scope_id): take the host by index
                # instead of unpacking exactly two values.
                client = peername[0]
    return client


async def get_version(request: web.Request) -> web.Response:
    """Return the server version as plain text."""
    text = "version: %s" % VERSION
    status = 200
    return web.Response(text=text, status=status)


async def show_info(request: web.Request) -> web.Response:
    """Return an HTML page echoing the request headers and the peer address.

    Header names and values are client-controlled, so they are HTML-escaped
    before being embedded in the page; otherwise this endpoint would reflect
    arbitrary markup/script back to the browser.
    """
    text = "<html><body><pre>\n"
    # items() shows each value of a repeated header instead of the first value
    # over and over.
    text += "\n".join("%s = %s" % (html.escape(hdr), html.escape(value))
                      for hdr, value in request.headers.items())
    transport = request.transport
    peername = transport.get_extra_info('peername') if transport is not None else None
    if peername is not None:
        text += "\npeer = %s:%s\n" % (peername[0], peername[1])
    text += "</pre></body></html>\n"

    return web.Response(text=text, content_type="text/html", status=200)


async def get_key(request: web.Request) -> web.Response:
    """Return the value stored under the URL's {key} for the requesting client.

    Responds 200 with the stored value, or 404 with an explanatory message when
    the client has no value for that key.
    """
    keystore = request.app["keystore"] # type: KeystoreType
    key = request.match_info.get('key') # type: str

    client = get_client(request) # type: str
    if client in keystore and key in keystore[client]:
        text = keystore[client][key] # type: str
        status = 200 # type: int
    else:
        text = "%s doesn't exist for %s" % (key, client)
        status = 404
    return web.Response(text=text, status=status)


async def set_key(request: web.Request) -> web.Response:
    """Store the POSTed "value" form field under the URL's {key} for the client.

    Responds "OK"/200 on success. Responds "ERROR"/404 when the client cannot be
    identified, the key is empty, or no usable "value" field was posted. The 404
    status is kept for compatibility with existing clients.
    """
    keystore = request.app["keystore"] # type: KeystoreType
    key = request.match_info.get('key') # type: str
    post_data = await request.post()
    value = post_data.get("value")

    client = get_client(request) # type: str
    # Only plain form fields are accepted. A multipart file upload produces a
    # FileField wrapping an open file, which pickle cannot serialize; storing
    # it would make every later save_keystore() call fail.
    if client == "" or not key or not isinstance(value, str):
        return web.Response(text="ERROR", status=404)

    keystore.setdefault(client, {})[key] = value
    return web.Response(text="OK", status=200)


def setup_app(loop: AbstractEventLoop) -> web.Application:
    """Create the aiohttp application and register the /keystore routes.

    The fixed /version and /info routes are added before the {key} pattern;
    keep that order so they are not shadowed by it.

    NOTE(review): passing loop= to web.Application is deprecated in newer
    aiohttp releases.
    """
    app = web.Application(loop=loop)
    app.router.add_route('GET', '/keystore/version', get_version)
    app.router.add_route('GET', '/keystore/info', show_info)
    app.router.add_route('GET', '/keystore/{key}', get_key)
    app.router.add_route('POST', '/keystore/{key}', set_key)
    return app


async def init(loop: AbstractEventLoop, host: str, port: int, keystore: KeystoreType) -> Server:
    """Build the app around *keystore* and start listening on *host*:*port*.

    Returns the asyncio Server created by loop.create_server().

    NOTE(review): app.make_handler() is deprecated in newer aiohttp releases in
    favour of web.AppRunner.
    """
    app = setup_app(loop)
    app["keystore"] = keystore
    srv = await loop.create_server(app.make_handler(), host, port)
    print("Server started at http://%s:%s" % (host, port))
    return srv


def setup_parser() -> ArgumentParser:
    """Return the command line parser for --host, --port and --keystore."""
    parser = ArgumentParser(description="Clortho key server")
    parser.add_argument("--host", default="127.0.0.1", help="Hostname or IP address to bind to")
    # type=int rejects a non-numeric port with a usage error at parse time.
    parser.add_argument("--port", type=int, default=9001, help="Port number to listen to")
    parser.add_argument("--keystore", default="clortho.dat", help="File to store keys in")

    return parser


def read_keystore(filename: str) -> KeystoreType:
    """Load the keystore pickled in *filename*.

    Returns an empty keystore when the file does not exist or is empty. Other
    unpickling errors propagate, so a corrupt file is not silently replaced by
    an empty keystore on the next save.

    NOTE: unpickling can execute arbitrary code, so *filename* must only be
    writable by trusted users.
    """
    if not os.path.exists(filename):
        return {}

    with open(filename, "rb") as f:
        try:
            return cast(KeystoreType, pickle.load(f))
        except EOFError:
            return {}


def clean_exit(signame: str, loop: AbstractEventLoop, filename: str, keystore: KeystoreType) -> None:
    """Signal handler: save the keystore to *filename*, then stop *loop*."""
    print("got signal %s, exiting" % signame)
    try:
        save_keystore(filename, keystore)
    finally:
        # Stop even when the save fails; otherwise SIGINT/SIGTERM would leave
        # the server running. The save error still propagates to the loop.
        loop.stop()


def handle_usr1(filename: str, keystore: KeystoreType) -> None:
    """SIGUSR1 handler: save the keystore to *filename* without exiting."""
    print("Got USR1 signal, saving keystore")
    save_keystore(filename, keystore)


def hourly_save_keystore(loop: AbstractEventLoop, filename: str, keystore: KeystoreType) -> None:
    """Save the keystore and schedule the next save SAVE_INTERVAL seconds later."""
    # Reschedule first so one failed save (e.g. disk full) does not silently
    # end all future periodic saves.
    loop.call_later(SAVE_INTERVAL, hourly_save_keystore, loop, filename, keystore)
    save_keystore(filename, keystore)


def save_keystore(filename: str, keystore: KeystoreType) -> None:
    """Atomically write *keystore* to *filename* as a pickle.

    The data is written to a temporary file in the same directory, flushed to
    disk and then renamed over *filename*. A crash or failed write therefore
    leaves either the old file or the new one, never a truncated one. When
    *filename* already exists its permission bits are copied to the new file.
    Otherwise the new file is created with mode 0600 (the mkstemp default).
    """
    dirname = os.path.dirname(os.path.abspath(filename))
    fd, tmp_path = tempfile.mkstemp(prefix=".clortho-", suffix=".tmp", dir=dirname)
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(keystore, f, pickle.HIGHEST_PROTOCOL)
            f.flush()
            os.fsync(f.fileno())
        try:
            os.chmod(tmp_path, os.stat(filename).st_mode & 0o777)
        except FileNotFoundError:
            pass
        # os.replace() overwrites the destination atomically (POSIX rename).
        os.replace(tmp_path, filename)
    except BaseException:
        # Don't leave a stray temporary file behind on failure.
        try:
            os.unlink(tmp_path)
        except OSError:
            pass
        raise


def main(args: Namespace) -> None:
    """Load the keystore, install signal handlers and periodic saves, then serve forever."""
    keystore = read_keystore(args.keystore)

    # asyncio.get_event_loop() without a running loop is deprecated (and an
    # error on recent Pythons), so create and install the loop explicitly.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    for signame in ('SIGINT', 'SIGTERM'):
        loop.add_signal_handler(getattr(signal, signame), clean_exit,
                                signame, loop, args.keystore, keystore)
    loop.add_signal_handler(signal.SIGUSR1, handle_usr1, args.keystore, keystore)

    # Start saving the keys every hour
    loop.call_later(SAVE_INTERVAL, hourly_save_keystore, loop, args.keystore, keystore)

    loop.run_until_complete(init(loop, args.host, int(args.port), keystore))
    loop.run_forever()


if __name__ == '__main__':
    main(setup_parser().parse_args())