"""Agent for managing the execution of jobs.
:copyright: Copyright (c) 2019 RadiaSoft LLC. All Rights Reserved.
:license: http://www.apache.org/licenses/LICENSE-2.0.html
"""
from pykern import pkconfig
from pykern import pkio
from pykern import pkjson
from pykern.pkcollections import PKDict
from pykern.pkdebug import pkdlog, pkdp, pkdexc, pkdc, pkdformat
from sirepo import job
from sirepo.template import template_common
import copy
import datetime
import os
import re
import signal
import sirepo.const
import sirepo.feature_config
import sirepo.modules
import sirepo.nersc
import sirepo.tornado
import sirepo.util
import socket
import subprocess
import time
import tornado.gen
import tornado.ioloop
import tornado.iostream
import tornado.locks
import tornado.netutil
import tornado.process
import tornado.websocket
#: Long enough for job_cmd to write result in run_dir
_TERMINATE_SECS = 3
#: How often to poll in loop()
_LOOP_RETRY_SECS = 1
#: How many retries before the agent kills itself
_MAX_LOOP_RETRY = 10
#: Reasonable over the Internet connection
_CONNECT_SECS = 10
_IN_FILE = "in-{}.json"
_PID_FILE = "job_agent.pid"
_SBATCH_STATUS_FILE = "sbatch_status.json"
_MIN_SBATCH_POLL_SECS = 5
_MAX_SBATCH_QUERY_TRIES = 5
_cfg = None
_DEV_PYTHON_PATH = ":".join(
str(pkio.py_path(sirepo.const.DEV_SRC_RADIASOFT_DIR).join(p))
for p in ("sirepo", "pykern")
)
[docs]
def start():
# TODO(robnagler) commands need their own init hook like the server has
global _cfg
_cfg = pkconfig.init(
agent_id=pkconfig.Required(str, "id of this agent"),
# POSIT: same as job_driver.DriverBase._agent_env
dev_source_dirs=(
pkconfig.in_dev_mode(),
bool,
f"set PYTHONPATH={_DEV_PYTHON_PATH}",
),
fastcgi_sock_dir=(
pkio.py_path("/tmp"),
pkio.py_path,
"directory of fastcfgi socket, must be less than 50 chars",
),
mpich_shm_clean_up=(False, bool, "mpich4 orphans shm; see sirepo#7741"),
no_hdf5_do_mpi_file_sync=(False, bool, "turn off hdf5 file sync"),
start_delay=(0, pkconfig.parse_seconds, "delay startup in internal_test mode"),
global_resources_server_token=pkconfig.Required(
str,
"credential for global resources server",
),
global_resources_server_uri=pkconfig.Required(
str,
"how to connect to global resources",
),
sim_db_file_server_token=pkconfig.Required(
str,
"credential for sim db files",
),
sim_db_file_server_uri=pkconfig.Required(
str,
"how to connect to sim db files",
),
supervisor_uri=pkconfig.Required(
str,
"how to connect to the supervisor",
),
)
pkdlog("{}", _cfg)
if pkconfig.channel_in_internal_test() and _cfg.start_delay:
pkdlog("start_delay={}", _cfg.start_delay)
# Not asyncio.sleep: delay to startup tornado for testing
time.sleep(_cfg.start_delay)
i = tornado.ioloop.IOLoop.current()
d = _Dispatcher()
def s(*args):
return i.add_callback_from_signal(_terminate, d)
signal.signal(signal.SIGTERM, s)
signal.signal(signal.SIGINT, s)
i.spawn_callback(d.loop)
i.start()
[docs]
def start_sbatch():
def _get_host():
h = socket.gethostname()
if "." not in h:
h = socket.getfqdn()
return h
def _kill_agent(pid_file):
pkio.unchecked_remove(_PID_FILE)
if _get_host() == pid_file.host:
os.kill(pid_file.pid, signal.SIGKILL)
else:
try:
subprocess.run(
("ssh", pid_file.host, "kill", "-KILL", str(pid_file.pid)),
capture_output=True,
text=True,
).check_returncode()
except subprocess.CalledProcessError as e:
if "({}) - No such process".format(pid_file.pid) not in e.stderr:
pkdlog(
"cmd={cmd} returncode={returncode} stderr={stderr}", **vars(e)
)
def _read_pid_file():
try:
rv = pkjson.load_any(pkio.py_path(_PID_FILE))
if "host" in rv and "pid" in rv:
return rv
except Exception as e:
if not pkio.exception_is_not_found(e):
pkdlog("file={} error={} stack={}", e, pkdexc())
return None
def _remove_own_pid_file(info):
try:
if (f := _read_pid_file()) and f.host == info.host and f.pid == info.pid:
# race condition but very small so probably ok
pkio.unchecked_remove(_PID_FILE)
except Exception:
pass
try:
if f := _read_pid_file():
_kill_agent(f)
except Exception as e:
pkdlog("error={} stack={}", e, pkdexc())
p = None
try:
pkjson.dump_pretty(p := PKDict(host=_get_host(), pid=os.getpid()), _PID_FILE)
start()
finally:
_remove_own_pid_file(p)
class _Dispatcher(PKDict):
def __init__(self):
super().__init__(
cmds=[],
fastcgi_cmd=None,
fastcgi_error_count=0,
)
def fastcgi_destroy(self):
self._fastcgi_file and pkio.unchecked_remove(self._fastcgi_file)
self._fastcgi_file = None
self.fastcgi_cmd = None
def format_canceled(self, msg):
return self.format_op(msg, job.OP_OK, reply=PKDict(state=job.CANCELED))
def format_op(self, msg, op_name, **kwargs):
if msg:
kwargs["opId"] = msg.get("opId")
rv = _OpMsg(agentId=_cfg.agent_id, opName=op_name).pksetdefault(**kwargs)
if not rv.get("opName"):
raise AssertionError("missing opName in msg")
return rv
async def job_cmd_reply(self, msg, op_name, text=None, cmd=None, msg_items=None):
def _fixup(reply):
rv = PKDict(**msg_items) if msg_items else PKDict()
if msg.opName in (job.OP_RUN, job.OP_RUN_STATUS):
cmd.process_job_cmd_reply(reply)
rv.pkupdate(
# Not a "reply", just a msg with all these values, no opId
msg=None,
op_name=job.OP_RUN_STATUS_UPDATE,
computeJid=msg.computeJid,
computeJobSerial=msg.computeJobSerial,
state=cmd.job_state,
)
if t := cmd.get("computeJobStart"):
rv.computeJobStart = t
# Allow reply to override these things
rv.pkupdate(reply)
else:
rv.pkupdate(msg=msg, reply=reply, op_name=op_name)
return rv
def _parse_text():
if text is None:
return PKDict()
try:
return pkjson.load_any(text)
except Exception:
return PKDict(
state=job.ERROR,
error="unable to parse job_cmd output",
stdout=text,
op_name=job.ERROR,
)
try:
await self.send(self.format_op(**_fixup(_parse_text())))
except Exception as e:
pkdlog(
"text={} msg_items={} error={} stack={}", text, msg_items, e, pkdexc()
)
# something is really wrong, because format_op is messed up
raise
async def loop(self):
async def _connect_and_loop():
self._websocket = await tornado.websocket.websocket_connect(
tornado.httpclient.HTTPRequest(
connect_timeout=_CONNECT_SECS,
url=_cfg.supervisor_uri,
validate_cert=job.cfg().verify_tls,
),
max_message_size=job.cfg().max_message_bytes,
ping_interval=job.cfg().ping_interval_secs,
ping_timeout=job.cfg().ping_timeout_secs,
)
s = self.format_op(None, job.OP_ALIVE)
rv = False
while True:
if s and not await self.send(s):
break
r = await self._websocket.read_message()
if r is None:
pkdlog(
"websocket closed in response to len={} send={}",
s and len(s),
s,
)
raise tornado.iostream.StreamClosedError()
s = await self._op(r)
# One success
rv = True
return rv
t = _MAX_LOOP_RETRY
while t > 0:
self._websocket = None
try:
if await _connect_and_loop():
t = _MAX_LOOP_RETRY
except Exception as e:
if not isinstance(
e,
(
ConnectionError,
tornado.simple_httpclient.HTTPStreamClosedError,
tornado.iostream.StreamClosedError,
),
):
pkdlog(
"retries countdown={}, websocket; error={} stack={}",
t,
e,
pkdexc(),
)
finally:
if self._websocket:
self._websocket.close()
await tornado.gen.sleep(_LOOP_RETRY_SECS)
t -= 1
pkdlog("terminating after connection attempts={}", _MAX_LOOP_RETRY)
self.terminate()
def new_run_maybe_destroy_old(self, jid):
for c in list(self.cmds):
if c.jid == jid:
c.destroy()
async def send(self, msg):
if not self._websocket:
return False
try:
if not isinstance(msg, _OpMsg):
raise AssertionError(f"expected _OpMsg not msg type={type(msg)}")
await self._websocket.write_message(pkjson.dump_bytes(msg), binary=True)
return True
except Exception as e:
pkdlog("exception={} msg={} stack={}", e, msg, pkdexc())
return False
def terminate(self):
try:
x = self.cmds
# compute_jobs are passive
self.cmds = []
for c in x:
try:
c.destroy(terminating=True)
except Exception as e:
pkdlog("cmd={} error={} stack={}", c, e, pkdexc())
return None
finally:
tornado.ioloop.IOLoop.current().stop()
async def _cmd(self, msg, **kwargs):
def _class(msg):
if msg.jobRunMode == job.SBATCH:
return _SbatchRun if msg.opName == job.OP_RUN else _SbatchCmd
if msg.jobCmd == "fastcgi":
return _FastCgiCmd
return _Cmd
try:
if (
msg.opName
in (
job.OP_ANALYSIS,
job.OP_IO,
)
and msg.jobCmd != "fastcgi"
):
return await self._fastcgi_op(msg)
c = _class(msg)(msg=msg, dispatcher=self, **kwargs)
except _RunDirNotFound:
return self.format_op(
msg,
job.OP_ERROR,
reply=PKDict(runDirNotFound=True),
)
if msg.jobCmd == "fastcgi":
self.fastcgi_cmd = c
try:
return c.start()
except Exception as e:
pkdlog("start exception={} stack={}", e, pkdexc())
c.destroy()
def _fastcgi_accept(self, connection, *args, **kwargs):
# Impedence mismatch: _fastcgi_accept cannot be async, because
# bind_unix_socket doesn't await the callable.
_call_later_0(self._fastcgi_loop, connection)
async def _fastcgi_handle_error(self, msg, error, stack=None):
async def _reply_error(msg):
try:
await self.send(
self.format_op(
msg,
job.OP_ERROR,
error=error,
reply=PKDict(
state=job.ERROR,
error="internal error",
fastCgiErrorCount=self.fastcgi_error_count,
stack=stack,
),
)
)
except Exception as e:
pkdlog("msg={} error={} stack={}", msg, e, pkdexc())
# destroy _fastcgi state first, then send replies to avoid
# asynchronous modification of _fastcgi state.
self.fastcgi_error_count += 1
self._fastcgi_remove_handler()
q = self._fastcgi_msg_q
self._fastcgi_msg_q = None
self.fastcgi_cmd.destroy()
if msg:
await _reply_error(msg)
while q.qsize() > 0:
await _reply_error(q.get_nowait())
q.task_done()
async def _fastcgi_op(self, msg):
if msg.runDir:
_assert_run_dir_exists(pkio.py_path(msg.runDir))
if not self.fastcgi_cmd:
m = copy.deepcopy(msg)
m.jobCmd = "fastcgi"
self._fastcgi_file = _cfg.fastcgi_sock_dir.join(
f"sirepo_job_cmd-{_cfg.agent_id:8}.sock",
)
self._fastcgi_msg_q = sirepo.tornado.Queue(1)
pkio.unchecked_remove(self._fastcgi_file)
m.fastcgiFile = self._fastcgi_file
# Runs in an agent's directory and chdirs to real runDirs.
# Except in stateless_compute which doesn't interact with the db.
m.runDir = pkio.py_path()
# Kind of backwards, but it makes sense since we need to listen
# so _do_fastcgi can connect
self._fastcgi_remove_handler = tornado.netutil.add_accept_handler(
tornado.netutil.bind_unix_socket(str(self._fastcgi_file)),
self._fastcgi_accept,
)
# last thing, because of await: start fastcgi process
await self._cmd(m, send_reply=False)
if not self._fastcgi_msg_q:
return self.format_op(
msg,
job.ERROR,
reply=PKDict(state=job.ERROR, error="fastcgi process got an error"),
)
if msg.jobCmd == "fastcgi":
raise AssertionError("fastcgi called within fastcgi")
self._fastcgi_msg_q.put_nowait(msg)
# For better logging, msg.opId is used in format_op (reply)
# Also used in op_cancel so a cancel, cancels the fastcgi process
self.fastcgi_cmd.op_id = msg.opId
return None
async def _fastcgi_loop(self, connection):
s = None
m = None
try:
s = tornado.iostream.IOStream(
connection,
max_buffer_size=job.cfg().max_message_bytes,
)
while True:
m = await self._fastcgi_msg_q.get()
# Avoid issues with exceptions. We don't use q.join()
# so not an issue to call before work is done.
self._fastcgi_msg_q.task_done()
await s.write(pkjson.dump_bytes(m) + b"\n")
await self.job_cmd_reply(
m,
job.OP_OK,
text=await s.read_until(b"\n", job.cfg().max_message_bytes),
)
except Exception as e:
if isinstance(e, tornado.iostream.StreamClosedError):
pkdlog(
"msg={} stream closed unexpectedly exception={} real_error={}",
m,
e,
getattr(e, "real_error", None),
)
else:
pkdlog("msg={} error={} stack={}", m, e, pkdexc())
# If self.fastcgi_cmd is None we initiated the kill so not an error
if not self.fastcgi_cmd:
return
await self._fastcgi_handle_error(m, e, pkdexc())
finally:
if s:
s.close()
async def _op(self, msg):
m = None
try:
m = pkjson.load_any(msg)
pkdlog(
"opName={} o={:.4} runDir={}", m.opName, m.get("opId"), m.get("runDir")
)
pkdc("m={}", m)
return await getattr(self, "_op_" + m.opName)(m)
except Exception as e:
err = "exception=" + str(e)
stack = pkdexc()
pkdlog(
"opName={} o={:.4} exception={} stack={}",
m and m.get("opName"),
m and m.get("opId"),
e,
stack,
)
return self.format_op(m, job.OP_ERROR, error=err, stack=stack)
async def _op_analysis(self, msg):
return await self._cmd(msg)
async def _op_begin_session(self, msg):
return self.format_op(msg, job.OP_OK, reply=PKDict(awake=True))
async def _op_cancel(self, msg):
def _matches(op_id, jid):
return list(c for c in self.cmds if c.op_id == op_id or c.jid == jid)
for c in _matches(msg.get("opId", "no match"), msg.get("jid", "no match")):
c.cancel_request()
return self.format_canceled(msg)
async def _op_io(self, msg):
return await self._cmd(msg)
async def _op_kill(self, msg):
self.terminate()
return None
async def _op_run(self, msg):
return await self._cmd(msg)
async def _op_run_status(self, msg):
def _find():
for c in list(self.cmds):
if c.jid == msg.computeJid and c.get("job_state"):
return _reply(c)
return None
def _reply(cmd):
if cmd.computeJobSerial != msg.computeJobSerial:
pkdlog(
"expected computeJobSerial={} in msg={}", cmd.computeJobSerial, msg
)
# Supervisor is always right, so kill the job
cmd.destroy()
return self.format_op(
msg,
job.OP_ERROR,
reply=PKDict(
state=job.UNKNOWN, error="run_status computeJobSerial mismatch"
),
)
return self.format_op(
msg,
job.OP_OK,
reply=_copy_truthy(
cmd, PKDict(state=cmd.job_state), ("parallelStatus", "error")
),
)
if rv := _find():
pass
elif msg.jobRunMode == job.SBATCH:
# Try to ask job_state for status of the job
rv = _SbatchRunStatus.sbatch_status_request(msg=msg, dispatcher=self)
else:
# did not find job so assumed canceled, e.g. server restart
rv = self.format_canceled(msg)
pkdlog("reply={} computeJid={}", rv, msg.computeJid)
return rv
async def _op_sbatch_login(self, msg):
return self.format_op(msg, job.OP_OK, reply=PKDict(loginSuccess=True))
class _Cmd(PKDict):
def __init__(self, **kwargs):
def _run():
self.dispatcher.new_run_maybe_destroy_old(self.msg.computeJid)
self.computeJobSerial = self.msg.computeJobSerial
self.job_state = job.PENDING
if self.msg.opName == job.OP_RUN:
pkio.unchecked_remove(self.run_dir)
pkio.mkdir_parent(self.run_dir)
sirepo.sim_data.get_class(
self.msg.data.simulationType
).sim_run_input_to_run_dir(self.msg.data, self.run_dir)
else:
# Needs to exist for run_status so in_file can be created
pkio.mkdir_parent(self.run_dir)
super().__init__(**kwargs)
self.pksetdefault(
send_reply=True,
).pkupdate(
_destroying=False,
_terminating=False,
_uid=job.split_jid(jid=self.msg.computeJid).uid,
jid=self.msg.computeJid,
op_id=self.msg.opId,
)
# only certain types of commands have runDir
if self.msg.get("runDir"):
self.run_dir = pkio.py_path(self.msg.runDir)
if self.msg.opName in (job.OP_RUN, job.OP_RUN_STATUS):
_run()
else:
_assert_run_dir_exists(self.run_dir)
else:
# POSIT: same as fast_cgi
# Use agent's runDir.
self.run_dir = pkio.py_path()
self._process = _Process(self)
self.dispatcher.cmds.append(self)
def cancel_request(self):
self.destroy()
def destroy(self, terminating=False):
def _mpich_shm_clean_up():
if not self.msg.get("isParallel") or terminating:
return
# POSIT: only one parallel process per parallel job agent.
# Running inside a container.
pkio.unchecked_remove(*pkio.sorted_glob("/dev/shm/*"))
if self._destroying:
return
if self.dispatcher.fastcgi_cmd == self:
self.dispatcher.fastcgi_destroy()
self._destroying = True
self._terminating = terminating
if "_in_file" in self:
pkio.unchecked_remove(self.pkdel("_in_file"))
self._process.kill()
try:
self.dispatcher.cmds.remove(self)
except ValueError:
pass
if _cfg.mpich_shm_clean_up:
_mpich_shm_clean_up()
def format_op(self, **kwargs):
return self.dispatcher.format_op(
**PKDict(kwargs).pksetdefault(
op_name=job.ERROR,
msg=self.msg,
),
)
def format_op_reply(self, **reply_kwargs):
return self.dispatcher.format_op(
op_name=job.OP_OK,
msg=self.msg,
reply=PKDict(reply_kwargs),
)
def job_cmd_cmd(self):
return ("sirepo", "job_cmd", self._in_file)
def job_cmd_cmd_stdin_env(self):
return job.agent_cmd_stdin_env(
cmd=self.job_cmd_cmd(),
env=self.job_cmd_env(),
source_bashrc=self.job_cmd_source_bashrc(),
uid=self._uid,
)
def job_cmd_env(self, env=None):
e = (env or PKDict()).pksetdefault(
SIREPO_GLOBAL_RESOURCES_SERVER_TOKEN=_cfg.global_resources_server_token,
SIREPO_GLOBAL_RESOURCES_SERVER_URI=_cfg.global_resources_server_uri,
SIREPO_MPI_CORES=self.msg.get("mpiCores", 1),
SIREPO_SIM_DB_FILE_SERVER_TOKEN=_cfg.sim_db_file_server_token,
SIREPO_SIM_DB_FILE_SERVER_URI=_cfg.sim_db_file_server_uri,
)
if _cfg.no_hdf5_do_mpi_file_sync:
e.HDF5_DO_MPI_FILE_SYNC = "FALSE"
return job.agent_env(env=e, uid=self._uid)
def job_cmd_source_bashrc(self):
if sirepo.feature_config.cfg().trust_sh_env:
return ""
return "source $HOME/.bashrc"
async def on_stderr_read(self, text):
try:
await self.dispatcher.send(
self.format_op(
msg=None,
op_name=job.OP_JOB_CMD_STDERR,
stderr=text.decode("utf-8", errors="ignore"),
)
)
except Exception as exc:
pkdlog("{} text={} error={} stack={}", self, text, exc, pkdexc())
async def on_stdout_read(self, text):
if self._destroying:
return
if not self.send_reply:
pkdlog("{} unexpected stdout={}", self, text)
return
try:
await self.dispatcher.job_cmd_reply(
self.msg, job.OP_OK, text=text, cmd=self
)
except Exception as e:
pkdlog("{} text={} error={} stack={}", self, text, e, pkdexc())
def pkdebug_str(self):
return pkdformat(
"{}(a={:.4} jid={} o={:.4} job_cmd={} run_dir={})",
self.__class__.__name__,
_cfg.agent_id,
self.get("jid"),
self.get("op_id"),
self.msg.get("jobCmd"),
self.run_dir,
)
def process_job_cmd_reply(self, reply):
if "job_state" not in self:
pkdlog("{} unexpected reply={}", self, reply)
raise AssertionError("unexpected process_job_cmd_reply")
_copy_truthy(reply, self, ("state", "parallelStatus", "error"))
def start(self):
try:
self._in_file = self._create_in_file()
self._process.start()
except Exception as e:
pkdlog("{} exception={} stack={}", self, e, pkdexc())
rv = self.format_op(
reply=PKDict(state=job.ERROR, error="failed to start process")
)
self.destroy()
return rv
_call_later_0(self._await_exit)
if self.msg.opName != job.OP_RUN:
return None
self.job_state = job.RUNNING
self.computeJobStart = int(time.time())
rv = self.format_op_reply(state=job.STATE_OK)
self.msg.opId = None
self.msg.opName = job.OP_RUN_STATUS
_call_later_0(
self.dispatcher.job_cmd_reply,
msg=self.msg,
op_name=job.OP_RUN_STATUS_UPDATE,
cmd=self,
)
return rv
async def _await_exit(self):
try:
await self._process.exit_ready()
e = self._process.stderr.text.decode("utf-8", errors="ignore")
if e:
pkdlog("{} exit={} stderr={}", self, self._process.returncode, e)
if self._destroying:
return
if self._process.returncode != 0:
await self.dispatcher.send(
self.format_op(
error=e,
reply=PKDict(
state=job.ERROR,
error=f"process exit={self._process.returncode} jid={self.job.jid}",
),
)
)
except Exception as exc:
pkdlog(
"{} error={} returncode={} stack={}",
self,
exc,
self._process.returncode,
pkdexc(),
)
await self.dispatcher.send(
self.format_op(
error=str(exc),
reply=PKDict(
state=job.ERROR,
error="job_agent error",
),
),
)
finally:
self.destroy()
def _create_in_file(self):
f = self.run_dir.join(
_IN_FILE.format(sirepo.util.unique_key()),
)
pkjson.dump_pretty(self.msg, filename=f, pretty=False)
return f
class _FastCgiCmd(_Cmd):
pass
class _OpMsg(PKDict):
pass
class _Process(PKDict):
def __init__(self, cmd):
super().__init__()
self.update(
stderr=None,
stdout=None,
cmd=cmd,
_exit=sirepo.tornado.Event(),
)
async def exit_ready(self):
await self._exit.wait()
await self.stdout.stream_closed.wait()
await self.stderr.stream_closed.wait()
def kill(self):
# If the process is't started
if "returncode" in self or "_subprocess" not in self:
return
p = None
try:
pkdlog("{}", self)
p = self.pkdel("_subprocess").proc.pid
os.killpg(p, signal.SIGKILL)
except Exception as e:
pkdlog("{} error={}", self, e)
def pkdebug_str(self):
return pkdformat(
"{}(pid={} cmd={})",
self.__class__.__name__,
self._subprocess.proc.pid if self.get("_subprocess") else None,
self.cmd,
)
def start(self):
# SECURITY: msg must not contain agentId
assert not self.cmd.msg.get("agentId")
c, s, e = self.cmd.job_cmd_cmd_stdin_env()
pkdlog("cmd={} stdin={}", c, s.read())
s.seek(0)
self._subprocess = tornado.process.Subprocess(
c,
close_fds=True,
cwd=str(self.cmd.run_dir),
env=e,
start_new_session=True,
stdin=s,
stdout=tornado.process.Subprocess.STREAM,
stderr=tornado.process.Subprocess.STREAM,
)
s.close()
self.stdout = _ReadJsonlStream(self._subprocess.stdout, self.cmd)
self.stderr = _ReadUntilCloseStream(self._subprocess.stderr, self.cmd)
self._subprocess.set_exit_callback(self._on_exit)
return self
def _on_exit(self, returncode):
self.returncode = returncode
pkdlog("{} returncode={}", self, returncode)
self._exit.set()
class _RunDirNotFound(Exception):
pass
class _Stream(PKDict):
def __init__(self, stream, cmd):
super().__init__(
cmd=cmd,
stream_closed=sirepo.tornado.Event(),
text=bytearray(),
_stream=stream,
)
_call_later_0(self._begin_read_stream)
async def _begin_read_stream(self):
try:
while True:
await self._read_stream()
except tornado.iostream.StreamClosedError as e:
if x := getattr(e, "real_error", None):
raise AssertionError(f"real_error={x}")
finally:
self._stream.close()
self.stream_closed.set()
async def _read_stream(self):
raise NotImplementedError()
class _ReadJsonlStream(_Stream):
def __init__(self, *args):
self.proceed_with_read = tornado.locks.Condition()
self.read_occurred = tornado.locks.Condition()
super().__init__(*args)
async def _read_stream(self):
self.text = await self._stream.read_until(b"\n", job.cfg().max_message_bytes)
pkdc("cmd={} stdout={}", self.cmd, self.text[:1000])
await self.cmd.on_stdout_read(self.text)
class _ReadUntilCloseStream(_Stream):
def __init__(self, *args):
super().__init__(*args)
async def _read_stream(self):
t = await self._stream.read_bytes(
job.cfg().max_message_bytes - len(self.text),
partial=True,
)
pkdlog("cmd={} stderr={}", self.cmd, t)
await self.cmd.on_stderr_read(t)
l = len(self.text) + len(t)
assert (
l < job.cfg().max_message_bytes
), "len(bytes)={} greater than max_message_size={}".format(
l,
job.cfg().max_message_bytes,
)
self.text.extend(t)
class _SbatchCmd(_Cmd):
def __init__(self, **kwargs):
super().__init__(**kwargs)
if "job_state" not in self:
return
self._sbatch_status_file = self.run_dir.join(_SBATCH_STATUS_FILE)
self.msg.sbatchStatusFile = str(self._sbatch_status_file)
# Only exists when _SbatchRun starts _SbatchRunStatus
if r := self.pkdel("sbatch_run"):
self._sbatch_status = r._sbatch_status.copy()
else:
self._sbatch_status = PKDict(
job_cmd_state=None,
sbatch_id=None,
computeJobSerial=self.computeJobSerial,
computeJobStart=None,
)
def job_cmd_cmd_stdin_env(self, *args, **kwargs):
c, s, e = super().job_cmd_cmd_stdin_env()
if self.msg.get("shifterImage"):
c = (
"shifter",
"--entrypoint",
f"--image={self.msg.shifterImage}",
"/bin/bash",
"--norc",
"--noprofile",
"-l",
)
return c, s, e
def job_cmd_env(self):
# POSIT: sirepo.mpi cfg sentinel for running in slurm
e = PKDict(SIREPO_MPI_IN_SLURM=1)
if _cfg.dev_source_dirs:
e.PYTHONPATH = _DEV_PYTHON_PATH
return super().job_cmd_env(e)
def job_cmd_source_bashrc(self):
if not self.msg.get("shifterImage"):
return super().job_cmd_source_bashrc()
return ""
def _sbatch_status_update(self, want_write=True, **kwargs):
def _cascade_to_self():
if (
s := self._sbatch_status.job_cmd_state
) != job.JOB_CMD_STATE_SBATCH_RUN_STATUS_STOP:
self.job_state = s
_copy_truthy(
self._sbatch_status,
self,
("computeJobStart", "parallelStatus", "error"),
)
p = self._sbatch_status.copy()
self._sbatch_status.pkupdate(kwargs)
if p == self._sbatch_status:
return True
_cascade_to_self()
if not want_write:
return True
try:
pkio.atomic_write(
self._sbatch_status_file, pkjson.dump_pretty(self._sbatch_status)
)
return True
except Exception as e:
# The simulation directory might get deleted out from
# under this process or some other error.
pkdlog(
"error writing file={} exception={} stack={}",
self._sbatch_status_file,
e,
pkdexc(),
)
return False
class _SbatchRun(_SbatchCmd):
def start(self):
p = subprocess.run(
("bash", self._script()),
close_fds=True,
cwd=str(self.run_dir),
capture_output=True,
text=True,
)
m = re.search(r"Submitted batch job (\d+)", p.stdout)
# Failure might be out of hours or batch system down
if m:
if self._sbatch_status_update(
job_cmd_state=job.PENDING, sbatch_id=m.group(1)
):
_SbatchRunStatus(
msg=copy.deepcopy(self.msg),
dispatcher=self.dispatcher,
sbatch_run=self,
).start()
rv = self.format_op_reply(state=job.STATE_OK)
else:
rv = self.format_op(error="unable to write sbatch state file")
# TODO(robnagler) need to cancel job, because no way to attach
else:
pkdlog("exit={} stdout={} stderr={}", p.returncode, p.stdout, p.stderr)
rv = self.format_op(
error=f"error submitting sbatch job error={p.stderr}",
)
self.destroy()
return rv
def _script(self):
def _in_file(cmd, basename, content):
f = self.run_dir.join(basename)
f.write(content)
return f"{cmd} {f}" if cmd else str(f)
def _nodes_tasks():
if n := self.msg.get("sbatchNodes"):
return f"#SBATCH --nodes={n}\n#SBATCH --cpus-per-task={self.msg.sbatchCores}"
return f"#SBATCH --ntasks={self.msg.sbatchCores}"
def _python(include_image=False):
rv = "python"
if i := self.msg.get("shifterImage"):
if include_image:
rv = f"--image={i} {rv}"
rv = "shifter --entrypoint " + rv
return rv
def _shifter_header():
# POSIT: job_api has validated values
if not self.msg.get("shifterImage"):
return ""
return f"""#SBATCH --image={self.msg.shifterImage}
#SBATCH --constraint=cpu
#SBATCH --qos={self.msg.sbatchQueue}
#SBATCH --tasks-per-node={self.msg.tasksPerNode}
{sirepo.nersc.sbatch_project_option(self.msg.sbatchProject)}"""
def _sbatch_cmd():
return _in_file(
"sbatch",
"in.sbatch",
f"""#!/bin/bash
#SBATCH --error={template_common.RUN_LOG}
#SBATCH --output={template_common.RUN_LOG}
#SBATCH --time={_time()}
{_nodes_tasks()}
{_shifter_header()}
{self.job_cmd_env()}
{self.job_cmd_source_bashrc()}
# POSIT: same as sim_run_dir_prepare() return value
exec {_srun()} {_python()} {template_common.PARAMETERS_PYTHON_FILE}
""",
)
def _sim_prepare_cmd():
return _in_file(
_python(True),
"prepare.py",
f"""#!/usr/bin/env python
import sirepo.sim_data
sirepo.sim_data.get_class('{self.msg.data.simulationType}').sim_run_dir_prepare(
'{self.run_dir}',
)
""",
)
def _srun():
return "srun" + (
" --cpu-bind=cores" if self.msg.get("shifterImage") else ""
)
def _time():
return str(
datetime.timedelta(
seconds=int(
datetime.timedelta(
hours=float(self.msg.sbatchHours)
).total_seconds(),
),
)
)
return _in_file(
None,
"run.bash",
f"""#!/bin/bash
set -eou pipefail
{self.job_cmd_env()}
{_sim_prepare_cmd()}
{_sbatch_cmd()}
""",
)
class _SbatchRunStatus(_SbatchCmd):
def __init__(self, **kwargs):
kwargs["msg"].pkupdate(
jobCmd="sbatch_parallel_status",
opName=job.OP_RUN_STATUS,
)
super().__init__(**kwargs)
self.pkdel("computeJobStart")
self.pkupdate(
_sbatch_status_cb=None,
_sbatch_query_tries=0,
)
def destroy(self, terminating=False):
def _scancel(sbatch_id):
pkdlog("sbatch_id={}", sbatch_id)
p = subprocess.run(
("scancel", "--full", "--quiet", sbatch_id),
close_fds=True,
cwd=str(self.run_dir),
capture_output=True,
text=True,
)
if p.returncode != 0:
pkdlog(
"{} cancel error exit={} sbatch={} stderr={} stdout={}",
self,
p.returncode,
sbatch_id,
p.stderr,
p.stdout,
)
if self._destroying:
return
if self._sbatch_status_cb:
self._sbatch_status_cb.stop()
self._sbatch_status_cb = None
if (
self._sbatch_status.job_cmd_state not in job.JOB_CMD_STATE_EXITS
and self._sbatch_status.sbatch_id
and not terminating
):
self._sbatch_status_update(job_cmd_state=job.CANCELED)
_scancel(self._sbatch_status.sbatch_id)
super().destroy(terminating=terminating)
async def on_stdout_read(self, text):
if self._destroying:
return
try:
await self._sbatch_send_update(text=text)
except Exception as e:
pkdlog("{} text={} error={} stack={}", self, text, e, pkdexc())
def process_job_cmd_reply(self, reply):
super().process_job_cmd_reply(reply)
if v := _copy_truthy(reply, PKDict(), ("parallelStatus", "error")):
# job_cmd writes the final state so don't write again
self._sbatch_status_update(
want_write=reply.get("state") != job.COMPLETED, **v
)
@classmethod
def sbatch_status_request(cls, **kwargs):
self = cls(**kwargs)
if s := self._sbatch_is_not_running():
rv = self.format_op_reply(state=s)
if x := self._sbatch_status.get("parallelStatus"):
rv.parallelStatus = x
self.destroy()
return rv
# can't answer the question yet
rv = self.format_op_reply(state=job.UNKNOWN)
# running, possibly completed, but needs to write parallel status
self.start()
return rv
def start(self):
# Detach from op_run_status or op_run
self.op_id = self.msg.opId = None
super().start()
self._sbatch_status_cb = tornado.ioloop.PeriodicCallback(
self._sbatch_poll_query,
min(_MIN_SBATCH_POLL_SECS, self.msg.nextRequestSeconds) * 1000,
)
self._sbatch_status_cb.start()
# So happens right away
_call_later_0(self._sbatch_poll_query)
return None
def _sbatch_is_not_running(self):
def _read():
c = None
try:
c = self._sbatch_status_file.read()
s = pkjson.load_any(c)
except Exception as e:
pkdlog(
"file={} exception={} contents={}", self._sbatch_status_file, e, c
)
return None
if not s.get("sbatch_id") or not s.get("job_cmd_state"):
pkdlog(
"invalid sbatch_status={} status={} file={}",
s,
self._sbatch_status_file,
)
return None
if (x := self.msg.computeJobSerial) != s.get("computeJobSerial"):
pkdlog(
"expected computeJobSerial={} status={} file={}",
x,
s,
self._sbatch_status_file,
)
return None
return s
if not self._sbatch_status_file.exists():
# TODO(robnagler) could be missing run dir. Should cancel the job
pkdlog("missing sbatch status file={}", self._sbatch_status_file)
return job.CANCELED
if not (s := _read()):
if not pkconfig.in_dev_mode():
pkio.unchecked_remove(self._sbatch_status_file)
return job.CANCELED
# save in self for start() and sbatch_status_request()
self._sbatch_status_update(want_write=False, **s)
if s.job_cmd_state in job.EXIT_STATUSES:
return s.job_cmd_state
return None
async def _sbatch_poll_query(self):
async def _sbatch_query_try_count_ok():
if self._sbatch_query_tries < _MAX_SBATCH_QUERY_TRIES:
return True
pkdlog(
"{} sbatch_query failed after tries={} sbatch_id={}",
self,
self._sbatch_query_tries,
self._sbatch_status.sbatch_id,
)
self._sbatch_status_update(
job_cmd_state=job.ERROR,
error="sbatch_query unavailable or invalid output",
)
await self._sbatch_send_update()
return False
def _transition_state(prev, curr):
if prev == curr or (
curr == job.COMPLETED
and prev == job.JOB_CMD_STATE_SBATCH_RUN_STATUS_STOP
):
return False
rv = True
if prev == job.PENDING and curr in (job.RUNNING, job.COMPLETED):
if not self._sbatch_status.get("computeJobStart"):
self._sbatch_status.computeJobStart = int(time.time())
if curr == job.COMPLETED:
curr = job.JOB_CMD_STATE_SBATCH_RUN_STATUS_STOP
# waits for parallelStatus from job_cmd to send COMPLETED
rv = False
self._sbatch_status_update(job_cmd_state=curr)
return rv
try:
if self._destroying:
return
self._sbatch_query_tries += 1
if (
not (s := self._sbatch_query())
and not await _sbatch_query_try_count_ok()
):
return
self._sbatch_query_tries = 0
if _transition_state(self._sbatch_status.job_cmd_state, s):
await self._sbatch_send_update()
except Exception as e:
pkdlog("program error, stopping exception={} stack={}", e, pkdexc())
try:
await self.dispatcher.send(
self.dispatcher.format_op(
op_name=job.OP_RUN_STATUS_UPDATE,
msg=None,
computeJid=self.jid,
computeJobSerial=self.computeJobSerial,
error=f"_SbatchRunStatus exception={e}",
state=job.ERROR,
),
)
except Exception as e:
pkdlog("unable to send, stopping exception={} stack={}", e, pkdexc())
self.destroy()
def _sbatch_query(self):
def _sacct():
# Invalid job id specified (not running)
p = subprocess.run(
("sacct", f"--jobs={self._sbatch_status.sbatch_id}", "--format=State"),
cwd=str(self.run_dir),
close_fds=True,
capture_output=True,
text=True,
)
if p.returncode != 0:
pkdlog(
"{} sacct error exit={} sbatch={} stderr={} stdout={}",
self,
p.returncode,
self._sbatch_status.sbatch_id,
p.stderr,
p.stdout,
)
if "disabled" in p.stderr:
# Only in dev: saccount is not configured and job not running, assume canceled
return "CANCELLED"
# Job never ran?
return "FAILED"
# sacct outputs state for each part of the job (shifter, external, etc.) so be pessimistic.
rv = set()
for l in re.split(r"\s+", p.stdout):
if len(l) and not l.startswith("-") and l != "State":
rv.add(l)
if len(rv) == 1:
# Normal case
return next(iter(rv))
if len(rv) > 1 and "CANCELLED" in rv:
return "CANCELLED"
pkdlog("{} sacct parse failed words={} stdout={}", self, rv, p.stdout)
return "FAILED"
def _scontrol():
# try scontrol first, because that's the normal case and easier to parse
p = subprocess.run(
("scontrol", "show", "job", self._sbatch_status.sbatch_id),
cwd=str(self.run_dir),
close_fds=True,
capture_output=True,
text=True,
)
if p.returncode != 0:
# Invalid job id will happen on NERSC. No jobs in dev
if re.search("Invalid job id|No jobs", p.stderr):
pkdlog(
"sbatch={} not in system, trying sacct",
self._sbatch_status.sbatch_id,
)
return _sacct()
pkdlog(
"{} scontrol error exit={} sbatch={} stderr={} stdout={}",
self,
p.returncode,
self._sbatch_status.sbatch_id,
p.stderr,
p.stdout,
)
return None
r = re.search(r"(?<=JobState=)(\S+)(?= Reason)", p.stdout)
if not r:
pkdlog(
"{} failed to find JobState in stderr={} stdout={}",
self,
p.stderr,
p.stdout,
)
return None
return r.group(1)
if not (s := _scontrol()):
return None
if s in ("PENDING", "CONFIGURING"):
return job.PENDING
if s in ("COMPLETING", "RUNNING"):
return job.RUNNING
if s == "COMPLETED":
return job.COMPLETED
if s == "CANCELLED":
return job.CANCELED
if s == "FAILED":
return job.ERROR
if s == "TIMEOUT":
return job.CANCELED
pkdlog(
"{} sbatch_id={} unexpected sbatch query state={}",
self,
self._sbatch_status.sbatch_id,
s,
)
return job.ERROR
async def _sbatch_send_update(self, text=None):
def _optional():
# parallelStatus only happens in the case we are at the end
rv = PKDict()
for f in "error", "parallelStatus", "computeJobStart":
# Will be overwritten if in "text"
if x := self._sbatch_status.get(f):
rv[f] = x
return rv
await self.dispatcher.job_cmd_reply(
msg=self.msg,
op_name=job.OP_RUN_STATUS_UPDATE,
text=text,
cmd=self,
msg_items=_optional(),
)
if self._sbatch_status.job_cmd_state in job.EXIT_STATUSES:
self.destroy()
def _assert_run_dir_exists(run_dir):
if not run_dir.exists():
raise _RunDirNotFound()
def _call_later_0(*args, **kwargs):
return tornado.ioloop.IOLoop.current().call_later(0, *args, **kwargs)
def _copy_truthy(src, dst, keys):
for x in keys:
if y := src.get(x):
dst[x] = y
return dst
def _terminate(dispatcher):
dispatcher.terminate()