Source code for sirepo.srunit_servers

"""start servers for unit tests

:copyright: Copyright (c) 2024 RadiaSoft LLC.  All Rights Reserved.
:license: http://www.apache.org/licenses/LICENSE-2.0.html
"""

import contextlib

# Import sirepo/pykern inside functions to limit module-level (global) imports


@contextlib.contextmanager
def api_and_supervisor(pytest_req, fc_args):
    """Run the Sirepo API server and job supervisor around a test.

    Starts ``sirepo service server`` and ``sirepo job_supervisor`` as
    subprocesses (cwd is the pykern unit work dir), polls the
    ``/job-supervisor-ping`` URI until the reply has ``state == "ok"``,
    yields the HTTP client, and stops the subprocesses on exit.

    Args:
        pytest_req: only ``pytest_req.module.__name__`` is read; a name
            containing ``"sbatch"`` selects the sbatch configuration.
        fc_args: object supporting ``pksetdefault`` (presumably a PKDict).
            Defaults are filled in for ``cfg``, ``sim_types``,
            ``append_package`` and ``empty_work_dir``; ``cfg`` is updated
            in place.

    Yields:
        The client returned by ``srunit.http_client``.
    """
    import os
    import re
    import subprocess
    import time

    import requests
    from pykern import pkio, pkjson, pkunit
    from pykern.pkcollections import PKDict
    from pykern.pkdebug import pkdlog

    fc_args.pksetdefault(
        cfg=PKDict,
        sim_types=None,
        append_package=None,
        empty_work_dir=True,
    )

    def _config_sbatch_supervisor_env(env):
        """Add the sbatch driver settings to ``env``.

        Requires a ``localhost`` entry in ``~/.ssh/known_hosts``; that line
        is passed on as the sbatch host key.
        """
        h = "localhost"
        k = pkio.py_path("~/.ssh/known_hosts").read()
        m = re.search("^{}.*$".format(h), k, re.MULTILINE)
        assert bool(m), "You need to ssh into {} to get the host key".format(h)
        env.pkupdate(
            SIREPO_JOB_DRIVER_MODULES="local:sbatch",
            SIREPO_JOB_DRIVER_SBATCH_CORES=os.getenv(
                "SIREPO_JOB_DRIVER_SBATCH_CORES",
                "2",
            ),
            SIREPO_JOB_DRIVER_SBATCH_HOST=h,
            SIREPO_JOB_DRIVER_SBATCH_HOST_KEY=m.group(0),
            SIREPO_JOB_DRIVER_SBATCH_SIREPO_CMD="sirepo",
            SIREPO_JOB_DRIVER_SBATCH_SRDB_ROOT=str(
                pkunit.work_dir().join("/{sbatch_user}/sirepo")
            ),
            SIREPO_JOB_SUPERVISOR_SBATCH_POLL_SECS="2",
        )

    def _ping_supervisor(uri):
        """Poll ``uri`` until the supervisor is ok, else ``restart_or_fail``."""
        # Pre-set so the failure report works even if the poll loop is empty.
        status = None
        reply = None
        for _ in _poll_iter():
            status = None
            reply = None
            try:
                r = requests.post(uri, json=None, allow_redirects=False)
            except requests.exceptions.ConnectionError:
                # API server is not accepting connections yet; keep polling.
                status = 0
                continue
            status = r.status_code
            if status != 200:
                break
            reply = pkjson.load_any(r.text)
            if reply.get("state") == "ok":
                return
            # Only "unable to connect" is treated as retryable.
            if "unable to connect" not in reply.get("error", ""):
                break
        pkunit.restart_or_fail("uri={} status={} reply={}", uri, status, reply)

    def _stop(proc):
        """Terminate ``proc``; escalate to kill if it does not exit in time."""
        try:
            proc.terminate()
            proc.wait(timeout=4)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait(timeout=2)

    def _subprocess(cmd):
        """Start ``cmd`` with the test environment and track it for cleanup."""
        procs.append(subprocess.Popen(cmd, env=env, cwd=wd))

    def _subprocess_setup(pytest_req, fc_args):
        """Build the subprocess environment and the HTTP client.

        Returns:
            tuple: ``(env, client)``
        """
        from sirepo import srunit

        sbatch_module = "sbatch" in pytest_req.module.__name__
        env = PKDict(os.environ)
        cfg = fc_args.cfg
        supervisor_port = _port()
        cfg.pkupdate(
            PYKERN_PKDEBUG_WANT_PID_TIME="1",
            SIREPO_PKCLI_JOB_SUPERVISOR_IP=pkunit.LOCALHOST_IP,
            SIREPO_PKCLI_JOB_SUPERVISOR_PORT=supervisor_port,
            SIREPO_PKCLI_JOB_SUPERVISOR_USE_RELOADER="0",
            SIREPO_PKCLI_SERVICE_IP=pkunit.LOCALHOST_IP,
            SIREPO_PKCLI_SERVICE_USE_RELOADER="0",
            SIREPO_SRDB_ROOT=str(pkio.mkdir_parent(pkunit.work_dir().join("db"))),
        )
        cfg.SIREPO_PKCLI_SERVICE_PORT = _port()
        for x in "DRIVER_LOCAL", "DRIVER_DOCKER", "API", "DRIVER_SBATCH":
            cfg[f"SIREPO_JOB_{x}_SUPERVISOR_URI"] = (
                f"http://{pkunit.LOCALHOST_IP}:{supervisor_port}"
            )
        if sbatch_module:
            cfg.pkupdate(SIREPO_SIMULATION_DB_SBATCH_DISPLAY="testing@123")
        env.pkupdate(**cfg)
        ports = [env.SIREPO_PKCLI_JOB_SUPERVISOR_PORT]
        client = srunit.http_client(
            env=env,
            empty_work_dir=fc_args.empty_work_dir,
            job_run_mode="sbatch" if sbatch_module else None,
            sim_types=fc_args.sim_types,
            port=env.SIREPO_PKCLI_SERVICE_PORT,
        )
        ports.append(client.port)
        # NOTE(review): assigned to cfg after env was populated from cfg, so
        # this statement does not copy the value into env -- confirm intended.
        cfg.SIREPO_FEATURE_CONFIG_SIM_TYPES = _sim_types(fc_args)
        for port in ports:
            # Best effort: kill anything still bound to our ports. POSIX
            # redirection is used because shell=True runs /bin/sh, which is
            # not necessarily bash (">&" is a bashism).
            subprocess.run(
                f"kill -9 $(lsof -t -i :{port}) > /dev/null 2>&1",
                shell=True,
            )
        if sbatch_module:
            # must be performed after fc initialized so work_dir is configured
            _config_sbatch_supervisor_env(env)
        return (env, client)

    env, client = _subprocess_setup(pytest_req, fc_args)
    wd = pkunit.work_dir()
    procs = []
    try:
        for k in sorted(env):
            if k.endswith("_PORT"):
                pkdlog("{}={}", k, env[k])
        _subprocess(("sirepo", "service", "server"))
        # allow db to be created
        time.sleep(0.5)
        _subprocess(("sirepo", "job_supervisor"))
        _ping_supervisor(client.http_prefix + "/job-supervisor-ping")
        yield client
    finally:
        # Attempt to stop every process even if one of them fails to stop,
        # then surface the first failure.
        first_error = None
        for proc in procs:
            try:
                _stop(proc)
            except Exception as e:
                pkdlog("stop failed pid={} error={}", proc.pid, e)
                if first_error is None:
                    first_error = e
        if first_error is not None:
            raise first_error
@contextlib.contextmanager
def sim_db_file(pytest_req):
    """Run a sim_db_file server in a forked child around a test.

    Configures the environment and pkconfig for a fixed test uid and token,
    forks a child that serves ``sim_db_file.SimDbServer`` on a free
    localhost port, waits until the port accepts connections, yields
    ``None``, and finally kills and reaps the child.

    Note: ``os.environ`` is updated in place and is not restored on exit.

    Args:
        pytest_req: not used by this function.

    Yields:
        None
    """
    import os
    import signal

    from pykern import pkconfig, pkdebug, pkio, pkunit
    from pykern.pkcollections import PKDict

    port = _port()
    # must match job.unique.key
    token = "a" * 32
    # must be valid uid
    uid = "simdbfil"
    # test won't pass if this is different from job.SIM_DB_FILE_URI
    uri = "/sim-db-file"

    def _server():
        """Child process body: patch ``SimDbServer`` and serve until killed."""
        from pykern import pkasyncio
        from sirepo import job, sim_db_file

        def _token_for_user(*args, **kwargs):
            return token

        _sim_dir()
        # Install the stub callable (not the token string itself) so that
        # ``token_for_user(...)`` stays callable and returns the test token.
        setattr(
            sim_db_file.SimDbServer,
            "token_for_user",
            staticmethod(_token_for_user),
        )
        setattr(sim_db_file.SimDbServer, "_TOKEN_TO_UID", PKDict({token: uid}))
        l = pkasyncio.Loop()
        l.http_server(
            PKDict(
                uri_map=((job.SIM_DB_FILE_URI + "/(.+)", sim_db_file.SimDbServer),),
                tcp_port=port,
                tcp_ip=pkunit.LOCALHOST_IP,
            )
        )
        l.start()

    def _setup():
        """Set environment and pkconfig state for the test uid and server."""
        c = PKDict(
            PYKERN_PKDEBUG_WANT_PID_TIME="1",
            SIREPO_AUTH_LOGGED_IN_USER=uid,
            SIREPO_SIMULATION_DB_LOGGED_IN_USER=uid,
            SIREPO_SIM_DB_FILE_SERVER_TOKEN=token,
            SIREPO_SIM_DB_FILE_SERVER_URI=f"http://{pkunit.LOCALHOST_IP}:{port}{uri}/{uid}/",
            SIREPO_SRDB_ROOT=str(pkio.mkdir_parent(pkunit.work_dir().join("db"))),
        )
        os.environ.update(**c)
        pkconfig.reset_state_for_testing(c)

    def _sim_dir():
        """Create the user dir, example simulations, and a lib test file."""
        from sirepo import simulation_db, quest, srunit

        stype = srunit.SR_SIM_TYPE_DEFAULT
        simulation_db.user_path_root().join(simulation_db._uid_arg()).ensure(dir=1)
        with quest.start(in_pkcli=True) as qcall:
            # create examples
            simulation_db.simulation_dir(stype, qcall=qcall)
        pkio.write_text(
            simulation_db.simulation_lib_dir(stype).join("hello.txt"),
            "xyzzy",
        )

    _setup()
    pid = os.fork()
    if pid == 0:
        # Child: never return into the caller's (pytest) stack.
        try:
            pkdebug.pkdlog("start server")
            _server()
        except Exception as e:
            pkdebug.pkdlog("server exception={} stack={}", e, pkdebug.pkdexc())
        finally:
            os._exit(0)
    try:
        _wait_for_port(port)
        yield None
    finally:
        os.kill(pid, signal.SIGKILL)
        # Reap the child so it does not linger as a zombie.
        os.waitpid(pid, 0)
def _poll_iter():
    """Yield once per polling attempt, pausing 0.1 seconds between attempts.

    The number of attempts is ten times the integer value of the
    ``SIREPO_SRUNIT_SERVERS_PING_TIMEOUT`` environment variable (default 20).
    """
    import os
    import time

    seconds = int(os.environ.get("SIREPO_SRUNIT_SERVERS_PING_TIMEOUT", 20))
    remaining = seconds * 10
    while remaining > 0:
        remaining -= 1
        yield
        time.sleep(0.1)


def _wait_for_port(port):
    """Return once a TCP connection to 127.0.0.1:``port`` succeeds.

    Raises:
        RuntimeError: no connection succeeded within the polling attempts.
    """
    import socket

    for _ in _poll_iter():
        try:
            conn = socket.create_connection(("127.0.0.1", int(port)), timeout=1)
            conn.close()
        except OSError:
            # not listening yet; try again on the next poll
            continue
        return
    raise RuntimeError(f"server did not start port={port}")


def _port():
    """Return, as a string, an unbound localhost TCP port in the test range."""
    from pykern import pkunit
    from sirepo import const

    ports = const.TEST_PORT_RANGE
    return str(pkunit.unbound_localhost_tcp_port(ports.start, ports.stop))


def _sim_types(fc_args):
    """Return ``fc_args.sim_types``, colon-joined when it is a tuple or list."""
    types = fc_args.sim_types
    return ":".join(types) if isinstance(types, (tuple, list)) else types