Source code for sirepo.srunit

"""Support for unit tests

:copyright: Copyright (c) 2016 RadiaSoft LLC.  All Rights Reserved.
:license: http://www.apache.org/licenses/LICENSE-2.0.html
"""

from pykern import pkcollections
from pykern import pkcompat
from pykern import pkjson
from pykern.pkcollections import PKDict
import base64
import contextlib
import copy
import json
import os.path
import pykern.pkinspect
import re
import requests
import threading
import urllib

#: Default "app"
MYAPP = "myapp"

#: Matches javascript-redirect.html
_JAVASCRIPT_REDIRECT_RE = re.compile(r'window.location = "([^"]+)"')

#: set by conftest.py
CONFTEST_DEFAULT_CODES = None

SR_SIM_TYPE_DEFAULT = MYAPP

SR_SIM_NAME_DEFAULT = "Scooby Doo"

#: Sirepo db dir
_DB_DIR = "db"

#: How many checks and how long to wait between checks for scenarios
_ITER_SLEEP = PKDict(
    slurm=PKDict(
        count=5,
        sleep_secs=5,
    )
)


__cfg = None


[docs] class JobAgentError(RuntimeError): pass
[docs] class ServerRequestError(RuntimeError): pass
[docs] def http_client( env=None, sim_types=None, job_run_mode=None, empty_work_dir=True, port=None ): """Create an http_client that talks to server""" t = sim_types or CONFTEST_DEFAULT_CODES if t: if isinstance(t, (tuple, list)): t = ":".join(t) env.SIREPO_FEATURE_CONFIG_SIM_TYPES = t from pykern import pkconfig pkconfig.reset_state_for_testing(env) from pykern import pkunit if empty_work_dir: pkunit.empty_work_dir() else: pkunit.work_dir() setup_srdb_root(cfg=env) from sirepo import modules modules.import_and_init("sirepo.uri") return _TestClient(env=env, job_run_mode=job_run_mode, port=port)
[docs] @contextlib.contextmanager def quest_start(want_user=False, cfg=None): if cfg is None: cfg = {} setup_srdb_root(cfg=cfg) from pykern import pkconfig pkconfig.reset_state_for_testing(cfg) from sirepo import quest with quest.start(in_pkcli=True) as qcall: qcall.auth_db.create_or_upgrade() if want_user: qcall.auth.srunit_user() yield qcall
[docs] def setup_srdb_root(cfg=None): from pykern import pkunit, pkio import os e = cfg if not e: e = os.environ e.update( SIREPO_SRDB_ROOT=str(pkio.mkdir_parent(pkunit.work_dir().join(_DB_DIR))), )
[docs] def template_import_file(sim_type, path, arguments=None): """Call `stateful_compute_import_file` Args: sim_type (str): template name path (object): if path is a str, will be joined with `data_dir`, arguments (object): literal passed as import_file_arguments Returns: PKDict: imported_data if successful; otherwise, error state """ from pykern import pkio, pkunit from sirepo import template if isinstance(path, str): path = pkunit.data_dir().join(path) return template.import_module(sim_type).stateful_compute_import_file( data=PKDict( args=PKDict( basename=path.basename, ext_lower=path.ext.lower(), file_as_str=pkio.read_text(path), folder="/import_test", import_file_arguments=arguments, purebasename=path.purebasename, ), ), )
class _TestClient: def __init__(self, env, job_run_mode, port): from sirepo import feature_config super().__init__() self._init_args = PKDict(env=env, job_run_mode=job_run_mode, port=port) self.sr_job_run_mode = job_run_mode self.sr_sbatch_logged_in = False self.sr_sim_type = None self.sr_uid = None self.port = port self.http_prefix = f"http://{env.SIREPO_PKCLI_SERVICE_IP}:{port}" self._session = requests.Session() self.cookie_jar = self._session.cookies self._threads = PKDict() if feature_config.cfg().ui_websocket: self._websocket = _WebSocket(self) def add_plan_trial_role(self): from sirepo.pkcli import roles from sirepo import auth_role roles.add_plan(self.sr_uid, auth_role.ROLE_PLAN_TRIAL) def assert_post_will_redirect(self, expect_re, *args, **kwargs): rv = self.sr_post(*args, **kwargs) rv.assert_http_redirect(expect_re) return rv def error_or_sr_exception(self): """Hack to check for SRException or Error Works around the websocket/http differences in this module. What should happen is that an exception is always created and can be checked later. This means fixing some tests. This is easier for now. """ from pykern.pkunit import pkexcept return pkexcept(r"(SRException|Error)\(") def iter_sleep(self, kind, op_desc): import time from pykern import pkunit def _setup(): rv = PKDict(_ITER_SLEEP[kind]) rv.countdown = range(rv.count, -1, -1) return rv s = _setup() for i in s.countdown: yield if i > 0: time.sleep(s.sleep_secs) else: pkunit.pkfail( "timeout secs={} kind={} op_desc={}", s.sleep_secs * s.count, kind, op_desc, ) @contextlib.contextmanager def sr_adjust_time(self, days): from sirepo import srtime def _do(days): srtime.adjust_time(days) self.sr_get_json("adjustTime", params=PKDict(days=days)) _do(days) yield _do(0) def sr_animation_run(self, data, compute_model, reports=None, **kwargs): from pykern import pkunit from pykern.pkcollections import PKDict from pykern.pkdebug import pkdp, pkdlog import re run = self.sr_run_sim(data, compute_model, **kwargs) for r, a in reports.items(): if "runSimulation" in a: f = self.sr_run_sim(data, r) for k, v in a.items(): m = re.search("^expect_(.+)", k) if m: pkunit.pkre( v(i) if callable(v) else v, str(f.get(m.group(1))), ) continue if "frame_index" in a: c = [a.get("frame_index")] else: c = range(run.get(a.get("frame_count_key", "frameCount"))) assert c, "frame_count_key={} or frameCount={} is zero".format( a.get("frame_count_key"), a.get("frameCount"), ) pkdlog("frameReport={} count={}", r, c) from sirepo import sim_data s = sim_data.get_class(self.sr_sim_type) for i in c: pkdlog("frameIndex={} frameCount={}", i, run.get("frameCount")) f = self.sr_get_json( "simulationFrame", PKDict(frame_id=s.frame_id(data, run, r, i)), ) for k, v in a.items(): m = re.search("^expect_(.+)", k) if m: pkunit.pkre( v(i) if callable(v) else v, str(f.get(m.group(1))), ) def sr_auth_state(self, **kwargs): """Gets authState and parses Returns: dict: parsed auth_state """ from pykern import pkunit, pkcollections from pykern.pkdebug import pkdp m = re.search( r"(\{.*\})", pkcompat.from_bytes(self.sr_get("authState").data), ) s = pkjson.load_any(m.group(1)) for k, v in kwargs.items(): pkunit.pkeq( v, s[k], "key={} expected={} != actual={}: auth_state={}", k, v, s[k], s, ) return s def sr_clone(self): res = self.__class__(**self._init_args) res._session = copy.deepcopy(self._session) res.cookie_jar = res._session.cookies return res def sr_db_dir(self): """Database directory of the server""" from pykern import pkunit return pkunit.work_dir().join(_DB_DIR) def sr_email_confirm(self, resp, display_name=None): from pykern.pkdebug import pkdlog self.sr_get(resp.uri) pkdlog(resp.uri) m = re.search(r"/(\w+)$", resp.uri) assert bool(m) r = PKDict(token=m.group(1)) if display_name: r.displayName = display_name self.sr_post(resp.uri, r, raw_response=True) def sr_email_login(self, email, sim_type=None): self._uid_clear() self.sr_sim_type_set(sim_type) self.sr_logout() r = self.sr_post( "authEmailLogin", PKDict(email=email, simulationType=self.sr_sim_type), ) self.sr_email_confirm(r, display_name=email) return self._uid_verify_and_save() def sr_get( self, route_or_uri, params=None, query=None, raw_response=True, **kwargs ): """Gets a request to route_or_uri to server Args: route_or_uri (str): string name of route or uri if contains '/' (http:// or '/foo') params (dict): optional params to route_or_uri Returns: SReply: reply object """ return self._do_req( route_or_uri, params=params, query=query, op=self._get, raw_response=raw_response, **kwargs, ) def sr_get_json(self, route_or_uri, params=None, query=None, **kwargs): """Gets a request to route_or_uri to server Args: route_or_uri (str): identifies route in schema-common.json params (dict): optional params to route_or_uri query (dict): query to be added to uri headers (dict): headers to apply (in kwargs) Returns: object: Parsed JSON result """ return self._do_req( route_or_uri, params=params, query=query, op=self._get, raw_response=False, **kwargs, ) def sr_get_root(self, sim_type=None, **kwargs): """Gets root app for sim_type Args: sim_type (str): app name ['myapp' or default type] Returns: SReply: reply object """ self.sr_sim_type_set(sim_type) return self._do_req( "root", params={"path_info": self.sr_sim_type}, query=None, op=self._get, raw_response=True, **kwargs, ) def sr_login_as_guest(self, sim_type=None): """Sets up a guest login Args: sim_type (str): simulation type ['myapp'] Returns: str: new user id """ self._uid_clear() self.sr_sim_type_set(sim_type) self.cookie_jar.clear() # Get a cookie self.sr_get("authState") self.sr_get("authGuestLogin", {"simulation_type": self.sr_sim_type}) return self._uid_verify_and_save() def sr_logout(self): """Logout but leave cookie in place Returns: object: self """ self._uid_clear() self.sr_get("authLogout", PKDict(simulation_type=self.sr_sim_type)) return self def sr_post(self, route_or_uri, data, params=None, raw_response=False, **kwargs): """Posts JSON data to route_or_uri to server File parameters are posted as `file_handle`. Args: route_or_uri (str): string name of route or uri if contains '/' (http:// or '/foo') data (object): will be formatted as form data params (dict): optional params to route_or_uri Returns: object: Parsed JSON result """ kwargs["json"] = data return self._do_req( route_or_uri, params=params, query={}, op=self._post, raw_response=raw_response, **kwargs, ) def sr_post_form( self, route_or_uri, data, params=None, raw_response=False, file=None, **kwargs ): """Posts form data to route_or_uri to server with data Args: route_or_uri (str): identifies route in schema-common.json data (dict): content to post params (dict): optional params to route_or_uri file (object): if str, will look in data_dir, else assumed py.path Returns: object: Parsed JSON result """ from pykern import pkunit, pkconfig k = PKDict(data=data, **kwargs) if file: p = file if isinstance(p, pkconfig.STRING_TYPES): p = pkunit.data_dir().join(p) k.file_handle = open(str(p), "rb") return self._do_req( route_or_uri, params, query=PKDict(), op=self._post, raw_response=raw_response, **k, ) def sr_run_sim(self, data, model, expect_completed=True, timeout=20, **post_args): from pykern import pkunit from pykern.pkdebug import pkdlog, pkdexc from sirepo import job import time # Needs to be specific def _assert_no_mpiexec(): """Ensure mpiexec is not running in work_dir""" import subprocess o = pkcompat.from_bytes( subprocess.check_output(["ps", "axwwe"], stderr=subprocess.STDOUT), ) d = str(pkunit.work_dir()) o = list( filter( # regex is problematic because "d" may have specials so just do this lambda x: "mpiexec" in x and d in x, o.split("\n"), ), ) if o: pkdlog('found "mpiexec" after cancel in ps={}', "\n".join(o)) # this exception won't be seen because in finally raise AssertionError("cancel failed") cancel = None try: r = self.sr_post( "runSimulation", PKDict( models=data.models, report=model, simulationId=data.models.simulation.simulationId, simulationType=data.simulationType, ).pkupdate(**post_args), ) if r.state == "completed": return r pkunit.pkok( r.state in ("running", "pending"), "runSimulation did not start: reply={}", r, ) cancel = r.nextRequest for _ in range(timeout): r = self.sr_post("runStatus", r.nextRequest) if r.state in job.EXIT_STATUSES: cancel = None break pkunit.pkok( "nextRequest" in r, "nextRequest missing from reply={}", r, ) pkdlog("reply.state={}", r.get("state")) # not asyncio.sleep: in a synchronous unit test time.sleep(1) else: pkunit.pkok(not expect_completed, "did not complete: runStatus={}", r) if expect_completed: pkunit.pkeq("completed", r.state, "reply={}", r) return r finally: if cancel: pkdlog("runCancel") self.sr_post("runCancel", cancel) _assert_no_mpiexec() def sr_sbatch_creds(self): return PKDict(sbatchCredentials=_cfg().sbatch.copy()) def sr_sbatch_login(self, compute_model, sim_name): from pykern.pkunit import pkexcept if self.sr_sbatch_logged_in: return d = self.sr_sim_data(sim_name, compute_model=compute_model) with pkexcept("SRException.*no-creds"): # Must try to run sim first to seed job_supervisor.db self.sr_run_sim(d, compute_model, expect_completed=False) self.sr_post( "sbatchLogin", PKDict( report=compute_model, simulationId=d.models.simulation.simulationId, simulationType=d.simulationType, ).pkupdate(self.sr_sbatch_creds()), raw_response=True, ).assert_success() self.sr_sbatch_logged_in = True def sr_sim_data(self, sim_name=None, sim_type=None, compute_model=None): """Return simulation data by name Args: sim_name (str): case sensitive name ['Scooby Doo'] sim_type (str): app ['myapp'] compute_model (str): what model is selected to set jobRunMode [None] Returns: dict: data """ from pykern import pkunit from pykern.pkdebug import pkdpretty, pkdp self.sr_sim_type_set(sim_type) if not sim_name: sim_name = SR_SIM_NAME_DEFAULT d = self.sr_post( "listSimulations", PKDict( simulationType=self.sr_sim_type, search=PKDict({"simulation.name": sim_name}), ), ) assert 1 == len(d), "listSimulations name={} returned count={}".format( sim_name, len(d) ) d = d[0].simulation res = self.sr_get_json( "simulationData", PKDict( simulation_type=self.sr_sim_type, pretty="0", simulation_id=d.simulationId, ), ) pkunit.pkeq(sim_name, res.models.simulation.name) if compute_model and self.sr_job_run_mode: res.models[compute_model].jobRunMode = self.sr_job_run_mode return res def sr_sim_type_set(self, sim_type=None): """Set `sr_sim_type Args: sim_type (str): app name Returns: object: self """ self.sr_sim_type = sim_type or self.sr_sim_type or SR_SIM_TYPE_DEFAULT return self def sr_thread_join(self, *names): from pykern import pkunit res = PKDict() for n in names or list(self._threads.keys()): t = self._threads[n] t.join() del self._threads[n] pkunit.pkok(t.sr_ok, f"thread={n} got an exception") res[n] = t.sr_res return res def sr_thread_start(self, name, op, **kwargs): t = self._threads[name] = _Thread( op=op, kwargs=PKDict(kwargs).pkupdate(fc=self.sr_clone()), ) t.start() def sr_user_dir(self, uid=None): """User's db dir""" if not uid: uid = self.sr_auth_state().uid return self.sr_db_dir().join("user", uid) def timeout_secs(self): if not hasattr(self, "_timeout_secs"): self._timeout_secs = round(_cfg().cpu_div * 0.5) return self._timeout_secs def _do_req(self, route_or_uri, params, query, op, raw_response, **kwargs): """Make request and parse result Args: route_or_uri (str): string name of route or uri if contains '/' (http:// or '/foo') params (dict): parameters to apply to route op (func): how to request Returns: object: parsed JSON result """ from pykern.pkdebug import pkdlog, pkdexc, pkdc, pkdp from pykern import pkjson from sirepo import uri, util, reply redirects = kwargs.setdefault("__redirects", 0) + 1 assert redirects <= 5 kwargs["__redirects"] = redirects u = None r = None e = None try: u = uri.server_route(route_or_uri, params, query) pkdc("uri={}", u) r = op(u, **kwargs) pkdc( "status={} data={}", r.status_code, "<snip-file>" if "download-data-file" in u else r.data, ) res = r.process(raw_response) if r.status_code in (301, 302, 303, 305, 307, 308): if kwargs.get("redirect", True): # Execute the redirect pkdlog( "redirect status={} location={}", r.status_code, r.header_get("Location"), ) return self._do_req( r.header_get("Location"), params=None, query=None, op=self._get, raw_response=raw_response, __redirects=redirects, ) return res except util.ReplyExc: raise except Exception as e2: pkdlog( "Exception: {}: msg={} uri={} status={} data={} stack={}", type(e2), e2, u, r and r.status_code, r and r.data, pkdexc(simplify=True), ) e = e2 if e: # Avoids long stack traces raise ServerRequestError(str(e)) def _get(self, uri, **kwargs): return self._requests_op("get", uri, kwargs=PKDict(kwargs)) def _post(self, uri, data=None, json=None, file_handle=None, **kwargs): assert (data is None) != (json is None) k = PKDict(kwargs) if data is not None: k.data = data else: k.json = json if file_handle is not None: k.files = PKDict(file=file_handle) return self._requests_op("post", uri, k) def _requests_op(self, op, uri, kwargs): from pykern.pkdebug import pkdlog, pkdexc, pkdc, pkdp from sirepo import const if hasattr(self, "_websocket") and not kwargs.get("want_http"): r = self._websocket.send(op, uri, **kwargs) if r is not None: return r if "headers" not in kwargs: kwargs.headers = PKDict() kwargs.headers.setdefault( "User-Agent", f"{const.SRUNIT_USER_AGENT} {pykern.pkinspect.caller()}", ) u = self._uri(uri) # Delete all of our kwargs kwargs.pkdel("want_http") kwargs.pkdel("__redirects") kwargs.pkdel("redirect") try: return _HTTPResponse( getattr(self._session, op)(u, allow_redirects=False, **kwargs), self, ) except requests.exceptions.ConnectionError as e: from pykern.pkdebug import pkdlog pkdlog("op={} uri={} headers={}", op, u, kwargs.headers) raise def _uid_clear(self): self.sr_uid = None def _uid_verify_and_save(self): self.sr_uid = self.sr_auth_state( needCompleteRegistration=False, isLoggedIn=True ).uid return self.sr_uid def _uri(self, uri): from pykern.pkdebug import pkdp u = urllib.parse.urlparse(uri) if u.scheme: return uri return self.http_prefix + uri class _Response: def assert_http_redirect(self, expect_re): from pykern import pkunit, pkdebug self.assert_http_status(302) pkunit.pkre(expect_re, self.header_get("Location")) def assert_http_status(self, expect): from pykern import pkunit pkunit.pkeq( expect, self.status_code, "expect={} status={} data={}", expect, self.status_code, self.data, ) def assert_success(self): from sirepo import util self.assert_http_status(200) d = self._assert_not_exception() if isinstance(d, dict) and ( d.get("state") == "error" or d.get("error") is not None ): raise util.Error( d.get("error", "internal error"), "reply in error reply={}", d, ) return d def change_to_redirect(self, uri): self._headers = PKDict(Location=uri) self.data = "" self.mimetype = "text/plain" self.status_code = 302 return self def header_get(self, name): return self._headers[name] def _convert_http_status(self, sr_args): """set self.status_code Returns: bool: True if is a redirect """ from pykern import pkdebug from sirepo import uri if sr_args.routeName == "httpException": self.status_code = sr_args.params.code return False elif sr_args.routeName == "httpRedirect": self.change_to_redirect(sr_args.params.uri) return True elif not uri.is_sr_exception_only( self._test_client.sr_sim_type, sr_args.routeName ): self.change_to_redirect( uri.local_route( self._test_client.sr_sim_type, sr_args.routeName, sr_args.params ) ) return True return False def _maybe_json_decode(self): from pykern import pkjson if self.mimetype == pkjson.MIME_TYPE: return pkjson.load_any(self.data) return self.data class _HTTPResponse(_Response): def __init__(self, reply, test_client): from pykern.pkdebug import pkdlog, pkdexc, pkdc, pkdp self.status_code = reply.status_code self.data = reply.content self._test_client = test_client c = reply.headers.get("content-type") self.mimetype = c.split(";")[0].strip() if c else "" self._headers = reply.headers def process(self, raw_response): from pykern import pkdebug from sirepo import util, reply # Emulate code in sirepo.js to deal with redirects if self.status_code == 200: if self.mimetype == "text/html": m = _JAVASCRIPT_REDIRECT_RE.search(pkcompat.from_bytes(self.data)) if m: if m.group(1).endswith("#/error"): raise util.Error( PKDict(error="server error uri={}".format(m.group(1))), ) return self.change_to_redirect(m.group(1)) else: d = self._maybe_json_decode() if ( isinstance(d, dict) and d.get("state") == reply.SR_EXCEPTION_STATE and self._convert_http_status(d.srException) ): return self if raw_response: return self return self._assert_not_exception() def _assert_not_exception(self): from pykern import pkjson from sirepo import reply, util d = self._maybe_json_decode() if isinstance(d, dict) and d.get("state") == reply.SR_EXCEPTION_STATE: # Treat SRException as a real exception (so we don't ignore them) raise util.SRException( d.srException.routeName, d.srException.params, ) if self.status_code != 200: raise util.Error( f"unexpected status={self.status_code}", "reply={}", self.data, ) return d class _Thread(threading.Thread): def __init__(self, op, kwargs): super().__init__() self.sr_ok = False self._op = op self._kwargs = kwargs def run(self): self.sr_res = self._op(**self._kwargs) self.sr_ok = True class _WebSocket: # /download is special below _HTTP_RE = re.compile(r"^(?:/download-|https?:)") _ANCHOR_RE = re.compile(r"(/.*?)#") def __init__(self, test_client): self._enabled = False self._connection = None self.test_client = test_client self._is_async = None def save_cookie_hash(self): self._cookie_hash = self._hash_cookies() def send(self, op, uri, headers=None, data=None, files=None, json=None, **kwargs): from pykern.pkdebug import pkdp, pkdlog import msgpack from sirepo import const from sirepo import uri as sirepo_uri def _combine_req(encoded_uri): m = PKDict( header=PKDict( kind=const.SCHEMA_COMMON.websocketMsg.kind.httpRequest, uri=sirepo_uri.decode_to_str(encoded_uri), version=const.SCHEMA_COMMON.websocketMsg.version, # POSIT: uri_router will look for this in_dev_mode srunit_caller=str(pykern.pkinspect.caller()), ), ) if op == "get": assert ( data is None and json is None and files is None ), f"GET does not support content uri={uri}" else: assert ( data is None or json is None ), f"only json or data may be supplied uri={uri}" m.content = json if data is None else data if files: m.attachment = PKDict( blob=files.file.read(), filename=os.path.basename(files.file.name), ) return m def _must_be_http(uri): # POSIT: /auth- match like sirepo.js msgRouter and https?: # for browser click on email msg. If there are headers, # it's basic auth. /download is special, because we # don't have a way of saving a file (easily) in sirepo.js. if headers: if not headers.get("Authorization"): raise AssertionError(f"restricted use of headers={headers}") # basic auth return True if self._HTTP_RE.search(uri): return True if not self._enabled: self.start() if self._cookie_hash != self._hash_cookies(): # Cookies have changed via http so need to reset websocket state self.stop() return True return False if _must_be_http(uri): pkdlog("via http: uri={} websocket.enabled={}", uri, self._enabled) return None assert uri[0] == "/", f"uri={uri} must begin with '/'" m = self._ANCHOR_RE.search(uri) return self._send(_combine_req(m.group(1) if m else uri)) def start(self): assert not self._enabled self.save_cookie_hash() self._enabled = True def stop(self): if not self._enabled: return c = self._connection self._connection = None self._enabled = False self.req_seq = None if c: c.close() def _hash_cookies(self): """Reaches inside cookiejar. No other way...""" rv = 0 for d, domains in self.test_client.cookie_jar._cookies.items(): rv += hash(d) for p, paths in domains.items(): rv += hash(p) for n, c in paths.items(): rv += hash(n) + hash(c.value) return rv def _send(self, msg): from websockets.sync import client from sirepo import job from pykern import pkdebug def _connect(): r = requests.Request( url=self.test_client.http_prefix + "/ws", cookies=self.test_client.cookie_jar, ).prepare() self._connection = client.connect( r.url.replace("http", "ws"), additional_headers=r.headers, max_size=job.cfg().max_message_bytes, ) self.req_seq = 1 def _connect_and_send(): if not self._connection: _connect() self.req_seq += 1 msg.header.reqSeq = self.req_seq self._connection.send(_WebSocketRequest(msg).buf) for _ in range(10): r = _WebSocketResponse( self._connection.recv(timeout=self.test_client.timeout_secs()), msg, self, ) if not r.is_async_msg: return r return "too many asyncMsg _WebSocketResponses" rv = None try: rv = _connect_and_send() except Exception as e: pkdebug.pkdlog( "communication error={} stack={}", e, pkdebug.pkdexc(simplify=True) ) if rv is None: # Avoids complex stack traces raise JobAgentError("Check job_agent.log for errors") if isinstance(rv, str): raise AssertionError(rv) return rv class _WebSocketRequest: def __init__(self, msg): import msgpack p = msgpack.Packer(autoreset=False) p.pack(msg.header) for x in "content", "attachment": if x in msg: p.pack(msg[x]) self.buf = p.bytes() p.reset() class _WebSocketResponse(_Response): def __init__(self, msg, req_msg, websocket): from pykern.pkdebug import pkdp from sirepo import const, util import msgpack self._req_msg = req_msg self._websocket = websocket self._test_client = websocket.test_client u = msgpack.Unpacker(object_pairs_hook=pkcollections.object_pairs_hook) u.feed(msg) h = u.unpack() self.data = u.unpack() if u.tell() < len(msg) else None assert ( const.SCHEMA_COMMON.websocketMsg.version == h.version ), f"invalid msg.version={h.version}" self.is_async_msg = const.SCHEMA_COMMON.websocketMsg.kind.asyncMsg == h.kind if self.is_async_msg: getattr(self, "_async_msg_" + h.method)(self.data) return assert ( req_msg.header.reqSeq == h.reqSeq ), f"invalid msg.reqSeq={h.reqSeq} expect={req_msg.header.reqSeq}" self._headers = PKDict() self.mimetype = "application/octet" self.status_code = 200 if const.SCHEMA_COMMON.websocketMsg.kind.httpReply == h.kind: self._sr_exception = None elif const.SCHEMA_COMMON.websocketMsg.kind.srException == h.kind: self._sr_exception = util.SRException(self.data.routeName, self.data.params) else: raise AssertionError(f"invalid msg.kind={h.kind}") def assert_http_status(self, expect): from pykern import pkunit, pkdebug if expect != 200 and not self._sr_exception: pkunit.pkfail("not an srException data={}", self.data) super().assert_http_status(expect) def header_get(self, name): return self._headers[name] def process(self, raw_response): from sirepo import util, reply def _data_to_json(): from pykern import pkjson if not isinstance(self.data, (dict, list)): return self if self._sr_exception: self.data = PKDict( state=reply.SR_EXCEPTION_STATE, srException=self.data, ) # Opposite of what you expect: raw_response is encoded json so need to encode self.mimetype = pkjson.MIME_TYPE self.data = pkjson.dump_pretty(self.data, pretty=False) return self if not self._sr_exception: if raw_response: return _data_to_json() return self.data if self._convert_http_status(self._sr_exception.sr_args): return self if raw_response: return _data_to_json() # Always raises, because _sr_exception is truthy at this point self._assert_not_exception() def _assert_not_exception(self): from sirepo import util from pykern.pkdebug import pkdp # All websocket errors are sr exceptions if self._sr_exception: raise self._sr_exception if self.status_code != 200: raise util.Error(f"unexpected status={self.status_code}", "reply={}", d) return self._maybe_json_decode() def _async_msg_setCookies(self, content): self._test_client.cookie_jar.extract_cookies( response=PKDict(info=lambda: PKDict(get_all=lambda x, y: content)), request=PKDict( unverifiable=False, get_full_url=lambda: self._test_client.http_prefix + self._req_msg.header.uri, ), ) self._websocket.save_cookie_hash() def _cfg(): global __cfg if __cfg: return __cfg from pykern import pkconfig import getpass u = getpass.getuser() __cfg = pkconfig.init( # 50 is based on a 2.2 GHz server cpu_div=(50, int, "cpu speed divisor to compute timeouts"), sbatch=dict( password=(u, str, "password to login to sbatch"), username=(u, str, "user to login to sbatch"), ), ) return __cfg