"""Handles dispatching of uris to server.api_* functions
:copyright: Copyright (c) 2017 RadiaSoft LLC. All Rights Reserved.
:license: http://www.apache.org/licenses/LICENSE-2.0.html
"""
from pykern import pkcollections
from pykern import pkconfig
from pykern import pkinspect
from pykern.pkcollections import PKDict
from pykern.pkdebug import pkdc, pkdexc, pkdlog, pkdp, pkdformat
import asyncio
import importlib
import inspect
import re
import sirepo.api_auth
import sirepo.auth
import sirepo.const
import sirepo.events
import sirepo.feature_config
import sirepo.http_util
import sirepo.spa_session
import sirepo.uri
import sirepo.util
#: prefix for api functions
_FUNC_PREFIX = "api_"
#: modules that must be initialized (convenient as a list)
_REQUIRED_MODULES = ["auth_api", "job_api", "server", "srtime", "auth_role_moderation"]
#: uri for default dispatches
_ROUTE_URI_DEFAULT = ""
#: uri for not found dispatches
_ROUTE_URI_NOT_FOUND = "not-found"
#: Where to route when no routes match (root); set by _init_uris
_route_default = None
#: Where to route when a route is not found (notFound); set by _init_uris
_not_found_route = None
#: dict of base_uri to route (base_uri, func, name, decl_uri, params)
_uri_to_route = None
#: dict of api name (without `api_` prefix) to route (base_uri, func, name, decl_uri, params)
_api_to_route = None
#: modules which support APIs
_api_modules = []
#: functions which implement APIs, keyed by method name (with `api_` prefix)
_api_funcs = PKDict()
#: key under which the current _Route is stored in a quest's bucket
_BUCKET_KEY = "uri_route"
[docs]
async def call_api(qcall, name, kwargs=None, body=None):
    """Call another API with permission checks.

    Should not be called outside of Base.call_api(). Use self.call_api() to call API.

    Args:
        qcall (quest.API): calling request object; the nested call is attached to it as parent
        name (str): api name (without `api_` prefix)
        kwargs (PKDict): to be passed to API [None]
        body (PKDict): if not None, set as the request body of the nested call
            (what the called API sees as `qcall.body_as_dict`) [None]
    Returns:
        Response: result, detached from the nested quest
    Raises:
        KeyError: if `name` is not a registered api
    """
    return await _call_api(qcall, _api_to_route[name], kwargs=kwargs, body=body)
[docs]
def init_module(want_apis, **imports):
    """Register API modules and build the dispatch tables.

    Initializes `_uri_to_route` (and, when apis are wanted, `_api_to_route`).
    A second call is a no-op.

    Args:
        want_apis (bool): if false, no modules are registered and
            `_uri_to_route` is left empty
        imports (dict): passed to `sirepo.util.setattr_imports`
    """
    global _uri_to_route

    if _uri_to_route is not None:
        return
    # presumably makes `simulation_db` (used below) a module global — confirm
    sirepo.util.setattr_imports(imports)
    if want_apis:
        cfg = sirepo.feature_config.cfg()
        for name in _REQUIRED_MODULES + sorted(cfg.api_modules):
            register_api_module(f"sirepo.{name}")
        _register_sim_modules("sim_api", cfg.sim_types)
        _register_sim_modules("sim_oauth", cfg.proprietary_oauth_sim_types)
        _init_uris(simulation_db, cfg.sim_types)
    else:
        _uri_to_route = PKDict()
[docs]
def maybe_sim_type_required_for_api(qcall):
    """Delegate to api_auth for the API currently being served by `qcall`.

    Args:
        qcall (quest.API): request whose bucket holds the current _Route
    Returns:
        object: result of `sirepo.api_auth.maybe_sim_type_required_for_api`
    """
    current_route = qcall.bucket_get(_BUCKET_KEY)
    return sirepo.api_auth.maybe_sim_type_required_for_api(current_route.func)
[docs]
def register_api_module(module):
    """Add `module` to the list of modules which implement apis.

    The module's apis are methods named api_XXX on its class `API`; they must
    not collide with other apis. The module may also have init_apis(), which
    will be called unless the module is already registered. A module without
    an `API` class is accepted (it has no apis), but it must not define
    module-level api_XXX functions (old interface).

    Args:
        module (module or str): name of module or module
    Raises:
        AssertionError: routes already initialized, old interface, or duplicate api
    """

    def _is_api_func(cls, name, obj):
        # only functions defined directly on cls (not inherited)
        return (
            name.startswith(_FUNC_PREFIX)
            and inspect.isfunction(obj)
            and name in cls.__dict__
        )

    assert (
        not _route_default
    ), "_init_uris already called. All APIs must be registered at init"
    m = importlib.import_module(module) if isinstance(module, str) else module
    if m in _api_modules:
        return
    # prevent recursion: appended before init_apis so a re-registration of
    # the same module returns early above
    _api_modules.append(m)
    if hasattr(m, "init_apis"):
        m.init_apis(uri_router=pkinspect.this_module())
    if not hasattr(m, "API"):
        if pkinspect.module_functions(_FUNC_PREFIX, module=m):
            raise AssertionError(f"module={m.__name__} has old interface")
        if pkconfig.in_dev_mode():
            pkdlog(f"api_module={m.__name__} does not have API class (no apis)")
        # some modules (ex: sirepo.auth.basic) don't have any APIs
        return
    c = m.API
    for n, o in inspect.getmembers(c):
        if _is_api_func(cls=c, name=n, obj=o):
            assert (
                n not in _api_funcs
            ), "function is duplicate: func={} module={}".format(n, m.__name__)
            _api_funcs[n] = _Route(func=o, cls=c, func_name=n)
[docs]
def start_tornado(ip, port, debug, is_primary=True):
    """Start tornado server, does not return

    Websocket connections are served on /ws; every other path is handled as
    a plain HTTP request. Both dispatch through `_call_api`.

    Args:
        ip (str): address passed to `HTTPServer.listen`
        port (int): port passed to `HTTPServer.listen`
        debug (bool): passed to `web.Application`
        is_primary (bool): if true, the IOLoop is passed to
            `cron.CronTask.init_class`, otherwise None is passed [True]
    """
    from tornado import httpserver, ioloop, web, log, websocket

    ws_count = 0
    # Strong references to in-flight websocket message tasks. asyncio keeps
    # only weak references to tasks, so an unreferenced task may be garbage
    # collected before it completes.
    ws_tasks = set()

    class _HTTPRequest(web.RequestHandler):
        async def _route(self):
            _log(self, "start")
            p = sirepo.uri.decode_to_str(self.request.path)
            e, r, k = _path_to_route(p[1:])
            if e:
                _log(
                    self,
                    "error",
                    fmt=" msg={} route={} kwargs={}",
                    args=[e, r, k],
                )
                r = _not_found_route
            await _call_api(
                None,
                r,
                kwargs=k,
                internal_req=self,
                reply_op=lambda r: r.tornado_response(self),
            )

        async def get(self):
            await self._route()

        async def post(self):
            await self._route()

        def sr_get_log_user(self):
            return getattr(self, "_sr_log_user", "")

        def sr_set_log_user(self, log_user):
            self._sr_log_user = log_user

    class _WebSocket(websocket.WebSocketHandler):
        async def get(self, *args, **kwargs):
            _log(self, "start")
            return await super().get(*args, **kwargs)

        async def on_message(self, msg):
            # WebSocketHandler only allows one on_message at a time, so each
            # message is processed in its own task; keep a reference to it.
            t = asyncio.ensure_future(self.__on_message(msg))
            ws_tasks.add(t)
            t.add_done_callback(ws_tasks.discard)

        def on_close(self):
            self.sr_log(
                None,
                "close",
                fmt=" code={} reason={}",
                args=[self.close_code, self.close_reason or ""],
            )

        def open(self):
            nonlocal ws_count

            # self.get_compression_options
            self.set_nodelay(True)
            r = self.request
            ws_count += 1
            self.__headers = PKDict(r.headers)
            self.cookie_state = self.__headers.get("Cookie")
            self.http_server_uri = f"{r.protocol}://{r.host}/"
            self.remote_addr = sirepo.http_util.remote_ip(r)
            self.ws_id = ws_count
            self.sr_log(None, "open", fmt=" ip={}", args=[_remote_peer(r)])

        def sr_get_log_user(self):
            """Needed for initial websocket creation call"""
            return ""

        def sr_log(self, ws_req, which, fmt="", args=None):
            # ws_req is None for open/close, and its header may be missing or
            # malformed if parse_msg failed early; never let logging raise.
            h = ws_req.get("header") if ws_req else None
            pkdlog(
                "{} ws={}#{}" + fmt,
                which,
                self.ws_id,
                (h.get("reqSeq") if isinstance(h, dict) else None) or 0,
                *(args or ()),
            )

        async def __on_message(self, msg):
            w = _WebSocketRequest(handler=self, headers=self.__headers)

            async def _reply_op(sreply):
                self.cookie_state = sreply.qcall.cookie.export_state()
                await sreply.websocket_response(w)

            try:
                w.parse_msg(msg)
                await _call_api(
                    None,
                    w.route,
                    kwargs=w.kwargs,
                    internal_req=w,
                    reply_op=_reply_op,
                )
            # TODO(robnagler) what if msg poorly constructed? Close socket?
            except Exception as e:
                self.sr_log(w, "error", fmt=" msg={} uri={}", args=[e, w.get("uri")])
                raise
            finally:
                self.sr_log(w, "end", fmt=" uid={}", args=[w.get("log_user")])

    class _WebSocketRequest(PKDict):
        def parse_msg(self, msg):
            import msgpack

            def _maybe_srunit_caller():
                if pkconfig.in_dev_mode() and (c := self.header.get("srunit_caller")):
                    return pkdformat(" srunit={}", c)
                return ""

            if not isinstance(msg, bytes):
                raise AssertionError(f"incoming msg type={type(msg)}")
            u = msgpack.Unpacker(
                max_buffer_size=sirepo.job.cfg().max_message_bytes,
                object_pairs_hook=pkcollections.object_pairs_hook,
            )
            u.feed(msg)
            self.header = u.unpack()
            self.handler.sr_log(
                self,
                "start",
                fmt=" uri={}{}",
                args=[self.header.get("uri"), _maybe_srunit_caller()],
            )
            if sirepo.const.SCHEMA_COMMON.websocketMsg.version != self.header.get(
                "version"
            ):
                raise AssertionError(
                    pkdformat("invalid header.version={}", self.header.get("version"))
                )
            # Ensures protocol conforms for all requests
            if (
                sirepo.const.SCHEMA_COMMON.websocketMsg.kind.httpRequest
                != self.header.get("kind")
            ):
                raise AssertionError(
                    pkdformat("invalid header.kind={}", self.header.get("kind"))
                )
            self.req_seq = self.header.reqSeq
            self.uri = self.header.uri
            if u.tell() < len(msg):
                self.body_as_dict = u.unpack()
                if u.tell() < len(msg):
                    self.attachment = u.unpack()
            # content may or may not exist so defer checking
            e, self.route, self.kwargs = _path_to_route(self.uri[1:])
            if e:
                self.handler.sr_log(
                    self,
                    "error",
                    fmt=" msg={} route={} kwargs={}",
                    args=[e, self.route, self.kwargs],
                )
                self.route = _not_found_route
            # Overwrite kwarg values if present in the message body
            if self.get("body_as_dict"):
                for k in self.body_as_dict:
                    if k in self.kwargs:
                        self.kwargs[k] = self.body_as_dict[k]

        def set_log_user(self, log_user):
            self.log_user = log_user

    def _cron_and_start():
        from sirepo import cron

        l = ioloop.IOLoop.current()
        cron.CronTask.init_class(l if is_primary else None)
        l.start()

    def _log(handler, which="end", fmt="", args=None):
        r = handler.request
        f = "{} ip={} uri={} "
        a = [which, _remote_peer(r), r.uri]
        if fmt:
            f += " " + fmt
            a += args or []
        elif which == "start":
            f += "proto={} {} ref={} ua={}"
            a += [
                r.method,
                r.version,
                r.headers.get("Referer") or "",
                r.headers.get("User-Agent") or "",
            ]
        else:
            f += "uid={} status={} ms={:.2f}"
            a += [
                handler.sr_get_log_user(),
                handler.get_status(),
                r.request_time() * 1000.0,
            ]
        pkdlog(f, *a)

    def _remote_peer(request):
        # https://github.com/tornadoweb/tornado/issues/2967#issuecomment-757370594
        # implementation may change; Code in tornado.httputil check connection.
        p = 0
        if c := request.connection:
            # socket is not set on stream for websockets.
            if hasattr(c, "stream") and hasattr(c.stream, "socket"):
                p = c.stream.socket.getpeername()[1]
        return f"{sirepo.http_util.remote_ip(request)}:{p}"

    sirepo.modules.import_and_init("sirepo.server").init_tornado()
    # listen() returns None; the server stays registered with the IOLoop
    httpserver.HTTPServer(
        web.Application(
            [
                ("/ws", _WebSocket),
                ("/.*", _HTTPRequest),
            ],
            debug=debug,
            websocket_max_message_size=sirepo.job.cfg().max_message_bytes,
            websocket_ping_interval=sirepo.job.cfg().ping_interval_secs,
            websocket_ping_timeout=sirepo.job.cfg().ping_timeout_secs,
            log_function=_log,
        ),
        xheaders=True,
        max_buffer_size=sirepo.job.cfg().max_message_bytes,
    ).listen(port=port, address=ip)
    log.enable_pretty_logging()
    _cron_and_start()
[docs]
def uri_for_api(api_name, params=None):
    """Generate uri for api method

    Args:
        api_name (str): name of api (without `api_` prefix)
        params (PKDict): parameters to pass to uri [None]
    Returns:
        str: formatted URI ("/" if it would otherwise be empty)
    Raises:
        KeyError: if `api_name` is not a registered route
        AssertionError: if a required parameter is missing or empty
    """
    if params is None:
        params = PKDict()
    r = _api_to_route[api_name]
    res = ("/" + r.base_uri).rstrip("/")
    for p in r.params:
        v = params.get(p.name)
        # missing, None, and "" are all treated as not supplied
        if v:
            # path_info values may already carry their leading separator
            if not (p.is_path_info and v.startswith("/")):
                res += "/"
            res += v
            continue
        assert p.is_optional, "missing parameter={} for api={}".format(p.name, api_name)
    return res or "/"
class _Route(PKDict):
    """All routing information for a single API.

    Keys:
        base_uri (str): leading component of the URI (ex: 'adjust-time')
        cls (class): class in the API's module that defines the API method
        decl_uri (str): full URI as declared in the schema (ex: '/adjust-time/?<days>')
        func (function): carries the api_perm attributes; not to be called directly
        func_name (str): name of the method in cls implementing the route (ex: 'api_admJobs')
        name (str): API route name
        params (list): URI parameters (_URIParams)
    """
class _URIParams(PKDict):
    """Description of one parameter in a route's URI.

    Keys:
        is_optional (bool): whether the parameter may be omitted
        is_path_info (bool): whether the parameter captures the rest of the path
        name (str): parameter name
    """
async def _call_api(parent, route, kwargs, body=None, internal_req=None, reply_op=None):
    """Instantiate `route.cls` as a quest and invoke the route's API method.

    The quest is always destroyed on exit; it is committed only when the API
    call (or an OKReplyExc) succeeded and no later step raised.

    Args:
        parent (quest.API): calling quest for nested calls, or None for a
            top-level request
        route (_Route): route to invoke
        kwargs (PKDict): keyword arguments for the API method; None is
            treated as empty
        body (PKDict): if not None, set as the request body [None]
        internal_req (object): server-side request (tornado handler or
            websocket request) passed to `sirepo.auth.init_quest` for
            top-level calls [None]
        reply_op (callable): sync or async function applied to the reply for
            top-level calls; must be provided when parent is None [None]
    Returns:
        object: nested call: the reply detached from the quest; top-level
            call: whatever `reply_op` returns
    """
    qcall = route.cls()
    # commit flag passed to qcall.destroy below
    c = False
    r = None
    try:
        if parent:
            qcall.parent_set(parent)
        qcall.bucket_set(_BUCKET_KEY, route)
        qcall.sim_type_set_from_spec(route.func)
        if not parent:
            # only top-level requests establish auth and the spa session
            sirepo.auth.init_quest(qcall=qcall, internal_req=internal_req)
            await sirepo.spa_session.maybe_begin(qcall=qcall)
        if body is not None:
            qcall.sreq.set_body(body)
        try:
            # must be first so exceptions have access to sim_type
            if kwargs:
                # Any (GET) uri will have simulation_type in uri if it is application
                # specific.
                qcall.sim_type_set(kwargs.get("simulation_type"))
            elif kwargs is None:
                kwargs = PKDict()
            _check_route(qcall, route)
            r = qcall.sreply.uri_router_process_api_call(
                await getattr(qcall, route.func_name)(**kwargs)
            )
            c = True
        except Exception as e:
            # ReplyExc subclasses are expected control flow (logged only at
            # debug level); OKReplyExc additionally counts as success
            if isinstance(e, sirepo.util.ReplyExc):
                if isinstance(e, sirepo.util.OKReplyExc):
                    c = True
                pkdc("api={} exception={} stack={}", route.name, e, pkdexc())
            else:
                pkdlog("api={} exception={} stack={}", route.name, e, pkdexc())
            # NOTE(review): return value unused; presumably called for its
            # side effects/checks on the cookie — confirm
            qcall.cookie.has_sentinel()
            r = qcall.sreply.gen_exception(e)
        if parent:
            # Done with nested call. Detach since qcall destroyed below
            return r.detach_from_quest()
        sirepo.events.emit(qcall, "end_api_call", PKDict(resp=r))
        if pkconfig.in_dev_mode():
            r.header_set("Access-Control-Allow-Origin", "*")
        if inspect.iscoroutinefunction(reply_op):
            return await reply_op(r)
        else:
            return reply_op(r)
    except:
        # anything raised (including BaseException) outside the inner handler,
        # even after a successful API call, must not commit
        c = False
        raise
    finally:
        qcall.destroy(commit=c)
def _check_route(qcall, route):
    """Verify via api_auth that `qcall` may call `route`.

    Args:
        qcall (quest.API): current request
        route (_Route): API to check
    """
    api_func = route.func
    sirepo.api_auth.check_api_call(qcall, api_func)
def _init_uris(simulation_db, sim_types):
    """Build `_uri_to_route` and `_api_to_route` from the schema route table.

    Schema routes whose api function was not registered are skipped. Also
    records the default and not-found routes and validates rootRedirectUri.

    Args:
        simulation_db (module): provides SCHEMA_COMMON
        sim_types (object): not used here
    """
    global _route_default, _not_found_route, _api_to_route, _uri_to_route

    assert not _route_default, "_init_uris called twice"
    _uri_to_route = PKDict()
    _api_to_route = PKDict()
    for api_name, decl in simulation_db.SCHEMA_COMMON.route.items():
        # parse first so malformed uris fail even if the api isn't registered
        route = _Route(_split_uri(decl))
        func_key = _FUNC_PREFIX + api_name
        if func_key not in _api_funcs:
            pkdc("not adding api, because module not registered: uri={}", decl)
            continue
        route.update(_api_funcs[func_key])
        sirepo.api_auth.assert_api_def(route.func)
        route.decl_uri = decl
        route.name = api_name
        assert (
            route.base_uri not in _uri_to_route
        ), "{}: duplicate end point; other={}".format(
            decl, _uri_to_route[route.base_uri]
        )
        _uri_to_route[route.base_uri] = route
        _api_to_route[api_name] = route
        if route.base_uri == _ROUTE_URI_DEFAULT:
            _route_default = route
        elif route.base_uri == _ROUTE_URI_NOT_FOUND:
            _not_found_route = route
    assert _route_default, f"missing constant route: default /{_ROUTE_URI_DEFAULT}"
    assert (
        _not_found_route
    ), f"missing constant route: not found /{_ROUTE_URI_NOT_FOUND}"
    _validate_root_redirect_uris(_uri_to_route, simulation_db)
def _path_to_route(path):
    """Map a request path to a route and its keyword arguments.

    Args:
        path (str): request path without its leading "/"; None selects the
            default route
    Returns:
        tuple: (error, route, kwargs); error is None on success, otherwise a
            message, with route/kwargs as far as they were determined
    """
    if path is None:
        return (None, _route_default, PKDict(path_info=None))
    pieces = path.split("/")
    route = None
    kwargs = None
    try:
        try:
            route = _uri_to_route[pieces[0]]
        except KeyError:
            # First component is not a top-level route; most likely a
            # sim_type handled by the default route. Any extra components
            # are reported as "too many parts" below.
            route = _route_default
        else:
            del pieces[0]
        kwargs = PKDict()
        for param in route.params:
            if not pieces:
                if param.is_optional:
                    break
                return (f"missing parameter={param.name}", route, kwargs)
            if param.is_path_info:
                # path_info swallows every remaining component
                kwargs[param.name] = "/".join(pieces)
                pieces = []
                break
            kwargs[param.name] = pieces.pop(0)
        if pieces:
            return (pkdformat("has too many parts={}", pieces), route, kwargs)
    except Exception as e:
        return (pkdformat("parse exception={} stack={}", e, pkdexc()), route, kwargs)
    return (None, route, kwargs)
def _register_sim_modules(package, sim_types):
    """Register the `package` submodule of every sim type that has one.

    Args:
        package (str): submodule package name (ex: "sim_api")
        sim_types (iterable): sim type names; those without the submodule are skipped
    """
    package_path = sirepo.feature_config.cfg().package_path
    for sim_type in sim_types:
        try:
            submodule = pkinspect.import_submodule(sim_type, package, package_path)
        except pkinspect.SubmoduleNotFound:
            continue
        register_api_module(submodule)
def _split_uri(uri):
    """Parse the URL for parameters

    Args:
        uri (str): full path with parameter args in uri format; must start with "/"
    Returns:
        PKDict: `base_uri` (str, "" if the uri has no literal component) and
            `params` (list of _URIParams, in uri order)
    Raises:
        AssertionError: malformed uri (more than one literal component,
            path_info not last, or optional followed by required parameter)
    """
    # same pattern for every component, so compile it once
    param_re = re.compile(f"^{sirepo.uri.PARAM_RE.format('(.+?)')}$")
    parts = uri.split("/")
    assert "" == parts.pop(0), "uri={} must start with /".format(uri)
    params = []
    res = PKDict(params=params)
    in_optional = None
    in_path_info = None
    first = None
    for p in parts:
        # rp is the previous (path_info) parameter whenever this fires
        assert not in_path_info, "path_info parameter={} must be last: next={}".format(
            rp.name, p
        )
        m = param_re.search(p)
        if not m:
            assert first is None, "too many non-parameter components of uri={}".format(
                uri
            )
            first = p
            continue
        rp = _URIParams()
        params.append(rp)
        rp.is_optional = bool(m.group(1))
        if rp.is_optional:
            rp.is_path_info = m.group(1) == sirepo.uri.PATH_INFO_CHAR
            in_path_info = rp.is_path_info
        else:
            rp.is_path_info = False
        rp.name = m.group(2)
        if rp.is_optional:
            in_optional = True
        else:
            assert (
                not in_optional
            ), "{}: optional parameter ({}) followed by non-optional".format(
                uri,
                rp.name,
            )
    res.base_uri = first or ""
    return res
def _validate_root_redirect_uris(uri_to_route, simulation_db):
    """Check rootRedirectUri keys against routes and sim_types.

    Route base uris, configured sim_types, and rootRedirectUri keys must be
    pairwise disjoint, and each rootRedirectUri key must be lowercase
    letters only.

    Args:
        uri_to_route (dict): base_uri to _Route
        simulation_db (module): provides SCHEMA_COMMON
    Raises:
        AssertionError: on overlap or an invalid rootRedirectUri key
    """
    u = set(uri_to_route.keys())
    # cfg sim_types may not be a set; `set & tuple` would raise TypeError
    t = set(sirepo.feature_config.cfg().sim_types)
    r = set(simulation_db.SCHEMA_COMMON.rootRedirectUri.keys())
    i = (u & r) | (u & t) | (r & t)
    assert not i, f"rootRedirectUri, sim_types, and routes have overlapping uris={i}"
    for x in r:
        # fullmatch: `^...$` with re.search would accept a trailing newline
        assert re.fullmatch(
            r"[a-z]+", x
        ), f"rootRedirectUri={x} must consist of letters only"