# Source code for sirepo.request

"""Requests hold context for API calls

:copyright: Copyright (c) 2022 RadiaSoft LLC.  All Rights Reserved.
:license: http://www.apache.org/licenses/LICENSE-2.0.html
"""

from pykern.pkcollections import PKDict
from pykern.pkdebug import pkdp, pkdexc
import base64
import email.utils
import pykern.pkcompat
import pykern.pkjson
import sirepo.const
import sirepo.http_util
import sirepo.quest
import sirepo.util
import user_agents


#: Form field name used to look up an uploaded file in an HTTP request;
#: we always use the same name for a file upload
_FORM_FILE_NAME = "file"


def init_quest(qcall, internal_req=None):
    """Attach the request context matching how ``qcall`` was invoked.

    The concrete request class is selected and its ``init_quest``
    classmethod is called with ``qcall`` and ``internal_req``.

    Args:
        qcall: quest being initialized; its ``in_pkcli`` bucket value is
            consulted first
        internal_req: transport-specific request object, if any

    Raises:
        AssertionError: if ``in_pkcli`` is not set and the type of
            ``internal_req`` is not recognized
    """

    def _class():
        if qcall.bucket_unchecked_get("in_pkcli"):
            return _SRequestCLI
        # The transport is recognized by a substring of the type's string
        # form, not by isinstance: "websocket" is matched case-insensitively,
        # "tornado" case-sensitively.
        kind = str(type(internal_req))
        if "websocket" in kind.lower():
            return _SRequestWebSocket
        if "tornado" in kind:
            return _SRequestHTTP
        raise AssertionError(f"unknown internal_req={type(internal_req)}")

    _class().init_quest(qcall, internal_req=internal_req)
class _FormFileBase(PKDict):
    """Uploaded file taken from a request.

    Subclasses provide ``_get`` (locate the upload in the transport's
    request object) and ``as_bytes`` (return the upload's content).

    Keys set on the instance:
        filename: ``filename`` attribute of the located upload
        _internal: the located upload object itself
    """

    def __init__(self, sreq):
        """Locate the upload in ``sreq.internal_req``.

        Args:
            sreq: request whose ``internal_req`` is passed to ``_get``

        Raises:
            sirepo.util.Error: if ``_get`` returns a false value
        """
        super().__init__()
        f = self._get(sreq.internal_req)
        if not f:
            raise sirepo.util.Error("must supply a file", "no file in request={}", sreq)
        self.filename = f.filename
        # TODO(robnagler) need to garbage collect
        self._internal = f

    def as_str(self):
        """Return the upload's bytes converted by `pykern.pkcompat.from_bytes`."""
        return pykern.pkcompat.from_bytes(self.as_bytes())


class _FormFileHTTP(_FormFileBase):
    """Upload carried in the ``files`` of an HTTP request."""

    def as_bytes(self):
        """Return the ``body`` of the uploaded file."""
        return self._internal.body

    def _get(self, internal_req):
        """Return the single upload stored under `_FORM_FILE_NAME`, or None.

        Raises:
            sirepo.util.BadRequest: if more than one file was supplied
        """
        res = internal_req.request.files.get(_FORM_FILE_NAME)
        if not res:
            return None
        if len(res) > 1:
            raise sirepo.util.BadRequest("too many files={} in form", len(res))
        return res[0]


class _FormFileWebSocket(_FormFileBase):
    """Upload carried as the ``attachment`` of a websocket request."""

    def as_bytes(self):
        """Return the ``blob`` of the attachment."""
        return self._internal.blob

    def _get(self, internal_req):
        """Return ``internal_req``'s ``attachment`` value (may be None)."""
        return internal_req.get("attachment")


class _SRequestBase(sirepo.quest.Attr):
    """Holds context for incoming requests"""

    # bare minimum to operate a child quest
    _INIT_QUEST_FOR_CHILD_KEYS = frozenset(
        (
            "http_authorization",
            "http_headers",
            "http_method",
            "http_server_uri",
            "remote_addr",
        )
    )

    def body_as_bytes(self):
        """Return `body_as_dict` serialized by `pykern.pkjson.dump_bytes`."""
        return pykern.pkjson.dump_bytes(self.body_as_dict())

    def body_as_dict(self):
        """Return the stored ``_body_as_dict`` value.

        Raises:
            sirepo.util.BadRequest: if no ``_body_as_dict`` key is present
        """
        if "_body_as_dict" not in self:
            raise sirepo.util.BadRequest("no body")
        return self.get("_body_as_dict")

    def header_uget(self, key):
        """Return header ``key`` from ``http_headers`` via its ``get``."""
        return self.http_headers.get(key)

    def is_spider(self):
        """Decide from the ``User-Agent`` header whether the client is a bot.

        Returns:
            bool: True when the header is missing or contains
            ``python-requests``; False when it starts with
            `sirepo.const.SRUNIT_USER_AGENT`; otherwise the ``is_bot``
            result of `user_agents.parse`
        """
        a = self.header_uget("User-Agent")
        if not a:
            # assume it's a spider if there's no header
            return True
        if a.startswith(sirepo.const.SRUNIT_USER_AGENT):
            # So our unit tests can run
            return False
        if "python-requests" in a:
            # user_agents doesn't see Python's requests module as a bot.
            # The package robot_detection does see it, but we don't want
            # to introduce another dependency.
            return True
        return user_agents.parse(a).is_bot

    def init_quest_for_child(self, child, parent):
        """Create the request attribute for a child quest.

        Delegates to the superclass and then overrides ``cookie_state``
        and ``http_method`` on the result.

        Returns:
            the object returned by ``pkupdate`` on the superclass result.
            A stray trailing comma previously wrapped this in a 1-tuple.
            Assumes ``pkupdate`` returns the updated object, as the
            ``init_quest`` classmethods in this module also rely on.
        """
        return (
            super()
            .init_quest_for_child(child, parent)
            .pkupdate(
                # need to cascade current value, not parent.sreq.cookie_state
                cookie_state=parent.cookie.export_state(),
                # no data yet; set_body will change
                http_method="GET",
            )
        )

    def method_is_post(self):
        """Return True if ``http_method`` equals ``"POST"``."""
        return self.http_method == "POST"

    def set_body(self, body):
        """Store ``body`` as the request body and mark the request as POST.

        Args:
            body (PKDict): content to be returned by `body_as_dict`

        Raises:
            AssertionError: if a ``_body_as_dict`` or ``_body_as_bytes`` key
                already exists, or ``body`` is not a `PKDict`
        """
        if "_body_as_dict" in self or "_body_as_bytes" in self:
            raise AssertionError(f"body may only be set once; new body={body}")
        if not isinstance(body, PKDict):
            raise AssertionError(f"invalid body type={type(body)} body={body}")
        self.http_method = "POST"
        self._body_as_dict = body


class _SRequestCLI(_SRequestBase):
    """Request context with fixed placeholder values (no cookie, no
    authorization, empty headers, GET, localhost)."""

    @classmethod
    def init_quest(cls, qcall, internal_req):
        """Build the instance for ``qcall`` with constant request values."""
        return cls(qcall, internal_req=internal_req).pkupdate(
            cookie_state=None,
            http_authorization=None,
            http_headers=PKDict(),
            http_method="GET",
            http_server_uri="http://localhost/",
            remote_addr="0.0.0.0",
        )

    def set_log_user(self, log_user):
        """Do nothing; ``log_user`` is ignored."""
        pass


class _SRequestHTTP(_SRequestBase):
    """Request context populated from ``internal_req.request``."""

    @classmethod
    def init_quest(cls, qcall, internal_req):
        """Build the instance from the headers, method, and URL of the request.

        Raises:
            sirepo.util.Forbidden: if an ``Authorization`` header is present
                but cannot be parsed or is not of type ``basic``
        """

        def _parse_authorization(value):
            # Returns None when there is no header, else a PKDict with
            # type, username, and password.
            if not value:
                return None
            try:
                t, i = value.split(None, 1)
                t = t.lower()
            except Exception:
                raise sirepo.util.Forbidden(
                    "unparseable authorization header={}", value
                )
            if t != "basic":
                raise sirepo.util.Forbidden("unhandled authorization type={}", t)
            try:
                u, p = base64.b64decode(i).split(b":", 1)
                # Converted inside the try so that a conversion failure on
                # client-supplied bytes is also reported as Forbidden.
                # NOTE(review): assumes from_bytes can raise on bad input;
                # confirm against pykern.pkcompat.
                u = pykern.pkcompat.from_bytes(u)
                p = pykern.pkcompat.from_bytes(p)
            except Exception:
                raise sirepo.util.Forbidden(
                    "unparseable authorization info={} type={}", i, t
                )
            return PKDict(type=t, username=u, password=p)

        r = internal_req.request
        return cls(qcall, internal_req=internal_req).pkupdate(
            # Property that extracts the body so defer until use
            _body_as_bytes=lambda: r.body,
            cookie_state=r.headers.get("Cookie"),
            http_authorization=_parse_authorization(r.headers.get("Authorization")),
            http_headers=r.headers,
            http_method=r.method,
            http_request_uri=r.full_url(),
            http_server_uri=f"{r.protocol}://{r.host}/",
            remote_addr=sirepo.http_util.remote_ip(r),
        )

    def body_as_bytes(self):
        """Return the request body as bytes.

        If ``_body_as_dict`` is set, it is serialized by the base class;
        otherwise the deferred ``_body_as_bytes`` callable is invoked.

        Raises:
            sirepo.util.BadRequest: if neither key is present
        """
        if "_body_as_dict" in self:
            return super().body_as_bytes()
        if "_body_as_bytes" not in self:
            raise sirepo.util.BadRequest("no body")
        return self._body_as_bytes()

    def body_as_dict(self):
        """Return the request body parsed as JSON.

        If ``_body_as_dict`` is set, it is returned by the base class.
        Otherwise the main value of the ``Content-Type`` header must equal
        `pykern.pkjson.MIME_TYPE` (compared case-insensitively) and the
        result of `pykern.pkjson.load_any` on `body_as_bytes` is returned.

        Raises:
            sirepo.util.BadRequest: if the content type does not match
        """

        def _content_type_eq(value):
            c = _content_type()._key
            if c is None:
                return False
            return c.lower() == value.lower()

        def _content_type():
            # Parse the header once and keep the result on self
            if "_content_type" not in self:
                self._content_type = _parse_header(
                    self.header_uget("Content-Type") or ""
                )
            return self._content_type

        def _parse_header(line):
            """Parse a Content-type like header.

            Adapted from tornado.httputil._parse_header. Unlike that
            function, this returns a single `PKDict`.

            Args:
                line (str): header value; may be empty

            Returns:
                PKDict: ``_key`` holds the main value (None if ``line`` has
                no main value) and each ``name=value`` parameter is stored
                under its lower-cased name with surrounding double quotes
                removed. For example, ``"application/json; charset=utf-8"``
                yields ``_key="application/json"`` and ``charset="utf-8"``.
            """

            def _parseparam(s):
                # tornado.httputil._parseparam
                while s[:1] == ";":
                    s = s[1:]
                    end = s.find(";")
                    # skip semicolons that fall inside a quoted string
                    while (
                        end > 0
                        and (s.count('"', 0, end) - s.count('\\"', 0, end)) % 2
                    ):
                        end = s.find(";", end + 1)
                    if end < 0:
                        end = len(s)
                    f = s[:end]
                    yield f.strip()
                    s = s[end:]

            parts = _parseparam(";" + line)
            key = next(parts)
            if len(key) == 0:
                return PKDict(_key=None)
            # decode_params treats first argument special, but we already stripped key
            params = [("Dummy", "value")]
            for p in parts:
                i = p.find("=")
                if i >= 0:
                    name = p[:i].strip().lower()
                    value = p[i + 1 :].strip()
                    params.append((name, value))
            decoded_params = email.utils.decode_params(params)
            decoded_params.pop(0)  # get rid of the dummy again
            rv = PKDict(_key=key)
            for name, decoded_value in decoded_params:
                value = email.utils.collapse_rfc2231_value(decoded_value)
                if len(value) >= 2 and value[0] == '"' and value[-1] == '"':
                    value = value[1:-1]
                rv[name] = value
            return rv

        if "_body_as_dict" in self:
            return super().body_as_dict()
        if not _content_type_eq(pykern.pkjson.MIME_TYPE):
            raise sirepo.util.BadRequest(
                "Content-Type={} must be {}",
                self.header_uget("Content-Type"),
                pykern.pkjson.MIME_TYPE,
            )
        return pykern.pkjson.load_any(self.body_as_bytes())

    def form_file_get(self):
        """Return a `_FormFileHTTP` for this request."""
        return _FormFileHTTP(self)

    def form_get(self, name, default):
        """Return ``internal_req.get_argument(name, default)``."""
        return self.internal_req.get_argument(name, default)

    def set_log_user(self, log_user):
        """Forward ``log_user`` to ``internal_req.sr_set_log_user``."""
        self.internal_req.sr_set_log_user(log_user)


class _SRequestWebSocket(_SRequestBase):
    """Request context populated from a websocket message and its handler."""

    @classmethod
    def init_quest(cls, qcall, internal_req):
        """Build the instance from ``internal_req`` and ``internal_req.handler``.

        ``http_method`` is ``"POST"`` when the message has a truthy
        ``body_as_dict`` and ``"GET"`` otherwise.
        """
        b = internal_req.get("body_as_dict")
        return cls(qcall, internal_req=internal_req).pkupdate(
            # NOTE(review): the key is stored even when b is None, so
            # body_as_dict() then returns None instead of raising and
            # set_body() refuses a later body. Confirm this is intended.
            _body_as_dict=b,
            # This is not used except in api_errorLogging, which shouldn't happen much
            cookie_state=internal_req.handler.cookie_state,
            http_authorization=None,
            http_headers=internal_req.headers,
            http_method="POST" if b else "GET",
            http_server_uri=internal_req.handler.http_server_uri,
            remote_addr=internal_req.handler.remote_addr,
        )

    def form_file_get(self):
        """Return a `_FormFileWebSocket` for this request."""
        return _FormFileWebSocket(self)

    def form_get(self, name, default):
        """Return ``name`` from the body dict, or ``default`` if absent."""
        return self.body_as_dict().get(name, default)

    def set_log_user(self, log_user):
        """Forward ``log_user`` to ``internal_req.set_log_user``."""
        self.internal_req.set_log_user(log_user)