"""Support routines and classes, mostly around errors and I/O.
:copyright: Copyright (c) 2018 RadiaSoft LLC. All Rights Reserved.
:license: http://www.apache.org/licenses/LICENSE-2.0.html
"""
# NOTE: limit sirepo imports here
from pykern import pkcompat
from pykern import pkconfig
from pykern.pkcollections import PKDict
from pykern.pkdebug import pkdlog, pkdp, pkdexc, pkdc
import asyncio
import base64
import hashlib
import importlib
import io
import inspect
import numconv
import numpy
import os.path
import pykern.pkinspect
import pykern.pkio
import pykern.pkjson
import re
import random
import sirepo.const
import unicodedata
import zipfile
#: module configuration; placeholder here, assigned by `pkconfig.init` at the end of this module
_cfg = None
#: length of string returned by create_token
TOKEN_SIZE = 16
#: POSIT: Matches anything generated by `unique_key`
UNIQUE_KEY_CHARS_RE = r"\w+"
#: A standalone unique key
UNIQUE_KEY_RE = re.compile(r"^{}$".format(UNIQUE_KEY_CHARS_RE))
#: cache for `first_sim_type`; None until first computed
_FIRST_SIM_TYPE = None
# See https://github.com/radiasoft/sirepo/pull/3889#discussion_r738769716
# for reasoning on why define both
#: matches every non-word character, and the empty position before a leading digit
_INVALID_PYTHON_IDENTIFIER = re.compile(r"\W|^(?=\d)")
#: matches a whole string that starts with a letter or underscore followed by word characters
_VALID_PYTHON_IDENTIFIER = re.compile(r"^[a-z_]\w*$", re.IGNORECASE)
#: matches any character other than ASCII letters, digits, underscore, dot, and hyphen
_INVALID_PATH_CHARS = re.compile(r"[^A-Za-z0-9_.-]")
[docs]
class ReplyExc(Exception):
    """Base exception raised to end the request.

    The ``sr_args`` keyword is stored on the instance. Any other
    positional or keyword arguments are forwarded to `pkdlog`.

    Args:
        sr_args (dict): exception args that are Sirepo specific [empty PKDict]
        log_fmt (str): server side log data
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.sr_args = kwargs.pop("sr_args") if "sr_args" in kwargs else PKDict()
        if not (args or kwargs):
            # Nothing to log
            return
        # Frame lookup must stay in this method: the frame two levels up
        # from here is what is handed to pkdlog as ``pkdebug_frame``.
        kwargs["pkdebug_frame"] = inspect.currentframe().f_back.f_back
        pkdlog(*args, **kwargs)

    def __repr__(self):
        sr = self.sr_args
        pairs = ["{}={}".format(name, sr[name]) for name in sorted(sr.keys())]
        return "{}({})".format(self.__class__.__name__, ",".join(pairs))

    def __str__(self):
        # Same text as __repr__ (dispatches to any subclass override)
        return self.__repr__()
[docs]
class BadRequest(ReplyExc):
    """Signals that the request was bad; adds nothing to `ReplyExc`."""
[docs]
class OKReplyExc(ReplyExc):
    """Base for `ReplyExc` exceptions that represent a successful response."""
[docs]
class InvalidEmail(ReplyExc):
    """The email a user is trying to register with is malformed or its domain is on the deny list."""
[docs]
class Error(ReplyExc):
    """Raised to send an error response

    Args:
        error (str): just the error to output to user
        log_fmt (str): server side log data

    Raises:
        AssertionError: if `error` is not a str
    """

    def __init__(self, error, *args, **kwargs):
        # Explicit raise instead of `assert` so the type check is not
        # stripped when Python runs with -O. (dict usage was removed.)
        if not isinstance(error, str):
            raise AssertionError(f"error={error} must be a str")
        super().__init__(*args, sr_args=PKDict(error=error), **kwargs)
[docs]
class Forbidden(ReplyExc):
    """Signals a forbidden request; adds nothing to `ReplyExc`."""
[docs]
class NotFound(ReplyExc):
    """Signals that something was not found; adds nothing to `ReplyExc`."""
[docs]
class PlanExpired(ReplyExc):
    """The API requires an active plan."""
[docs]
class Redirect(OKReplyExc):
    """Raised to redirect the client.

    Args:
        uri (str): where to redirect to
        log_fmt (str): server side log data
    """

    def __init__(self, uri, *args, **kwargs):
        target = PKDict(uri=uri)
        super().__init__(*args, sr_args=target, **kwargs)
[docs]
class ContentTooLarge(ReplyExc):
    """The content requested by the user was too large (e.g. a large data file)."""
[docs]
class ServerError(ReplyExc):
    """Signals a server error; adds nothing to `ReplyExc`."""
[docs]
class SPathNotFound(NotFound):
    """Raised by simulation_db.

    The three identifiers are stored in ``sr_args`` under the keys
    ``sim_type``, ``uid``, and ``sid``.

    Args:
        sim_type (str): simulation type
        uid (str): user
        sid (str): simulation id
    """

    def __init__(self, sim_type, uid, sid, *args, **kwargs):
        ids = PKDict(sim_type=sim_type, uid=uid, sid=sid)
        super().__init__(*args, sr_args=ids, **kwargs)
[docs]
class SReplyExc(OKReplyExc):
    """Raised to carry an SReply object.

    Args:
        sreply (object): what the reply should be
        log_fmt (str): server side log data
    """

    def __init__(self, sreply, *args, **kwargs):
        payload = PKDict(sreply=sreply)
        super().__init__(*args, sr_args=payload, **kwargs)
[docs]
class SRException(ReplyExc):
    """Raised to communicate a local redirect and log info.

    `params` may have ``simulationType``, which will be used for routeName
    rendering. Otherwise, ``sim_type`` on ``qcall`` will be used.

    Args:
        route_name (str): a local route (stored as ``routeName``)
        params (dict): for route url or for srExceptionOnly case
        log_fmt (str): server side log data
    """

    def __init__(self, route_name, params, *args, **kwargs):
        route = PKDict(routeName=route_name, params=params)
        super().__init__(*args, sr_args=route, **kwargs)
[docs]
class Unauthorized(ReplyExc):
    """Raise this to produce a 401 response."""
[docs]
class UserAlert(ReplyExc):
    """Raised to display a user error and log info.

    Args:
        display_text (str): string that user will see (stored as ``error``)
        log_fmt (str): server side log data
    """

    def __init__(self, display_text, *args, **kwargs):
        alert = PKDict(error=display_text)
        super().__init__(*args, sr_args=alert, **kwargs)
[docs]
class UserDirNotFound(NotFound):
    """Raised by simulation_db.

    Args:
        user_dir (py.path): directory not found
        uid (str): user
    """

    def __init__(self, user_dir, uid, *args, **kwargs):
        missing = PKDict(user_dir=user_dir, uid=uid)
        super().__init__(*args, sr_args=missing, **kwargs)
[docs]
class WWWAuthenticate(ReplyExc):
    """Raise this to produce a 401 response with a WWWAuthenticate response."""
[docs]
def assert_sim_type(sim_type):
    """Return `sim_type` unchanged if it is a configured simulation type.

    Args:
        sim_type (str): to check
    Returns:
        str: validated sim_type
    Raises:
        AssertionError: if `is_sim_type` rejects `sim_type`
    """
    if is_sim_type(sim_type):
        return sim_type
    raise AssertionError(f"invalid simulation type={sim_type}")
[docs]
def create_token(value):
    """Create a token of `TOKEN_SIZE` characters.

    In the internal test channel with a configured ``create_token_secret``,
    the token is derived from `value`: the first `TOKEN_SIZE` characters of
    the base32 encoded SHA-256 of ``value + secret``. Otherwise the token
    is random base62.

    Args:
        value (str): seed for the derived token (unused for random tokens)
    Returns:
        str: the token
    """
    if not pkconfig.channel_in_internal_test() or not _cfg.create_token_secret:
        return random_base62(TOKEN_SIZE)
    digest = hashlib.sha256(
        pkcompat.to_bytes(value + _cfg.create_token_secret),
    ).digest()
    return pkcompat.from_bytes(base64.b32encode(digest)[:TOKEN_SIZE])
[docs]
def err(obj, fmt="", *args, **kwargs):
    """Format a message prefixed by `obj`.

    Args:
        obj (object): rendered first, followed by ``": "``
        fmt (str): `str.format` template for the rest of the message [""]
        args (tuple): positional values for `fmt`
        kwargs (dict): keyword values for `fmt`
    Returns:
        str: ``"<obj>: <formatted fmt>"``
    """
    prefix = f"{obj}: "
    return prefix + fmt.format(*args, **kwargs)
[docs]
def files_to_watch_for_reload(*extensions):
    """Yield files with the given extensions found under package directories.

    For each extension, every package in ``feature_config.cfg().package_path``
    plus ``sirepo`` is visited in sorted order, and the directory of the
    package's ``__file__`` is globbed recursively.

    Args:
        extensions (str): file extensions without the leading dot
    Yields:
        object: each entry returned by `pykern.pkio.sorted_glob`
    """
    from sirepo import feature_config

    for ext in extensions:
        for pkg in sorted({"sirepo", *feature_config.cfg().package_path}):
            pkg_dir = pykern.pkio.py_path(
                importlib.import_module(pkg).__file__,
            ).dirname
            for found in pykern.pkio.sorted_glob(f"{pkg_dir}/**/*.{ext}"):
                yield found
[docs]
def find_obj(arr, key, value):
    """Find the first element of `arr` whose `key` entry equals `value`.

    Args:
        arr (list): list of dict-like objects
        key (str): object key
        value (*): value
    Returns:
        object: the object, or None if not found
    """
    return next((item for item in arr if item[key] == value), None)
[docs]
def first_sim_type():
    """Return the first configured sim_type, caching the answer.

    The sorted-first sim type that is not auth controlled is preferred.
    If every sim type is auth controlled, the sorted-first of those is used.

    Returns:
        str: the sim type
    """
    global _FIRST_SIM_TYPE
    if not _FIRST_SIM_TYPE:
        from sirepo import feature_config

        controlled = feature_config.auth_controlled_sim_types()
        candidates = sorted(feature_config.cfg().sim_types - controlled) or sorted(
            controlled
        )
        _FIRST_SIM_TYPE = candidates[0]
    return _FIRST_SIM_TYPE
[docs]
def import_submodule(submodule, type_or_data):
    """Import fully qualified module that contains submodule for sim type

    sirepo.feature_config.package_path will be searched for a match.

    Args:
        submodule (str): the name of the submodule
        type_or_data (str or dict): simulation type or description
    Returns:
        module: simulation type module instance
    Raises:
        ModuleNotFoundError: if a found module fails on one of its own imports
        AssertionError: if no package in package_path has the module
    """
    from sirepo import feature_config
    from sirepo import template

    sim_type = template.assert_sim_type(
        type_or_data.simulationType
        if isinstance(type_or_data, PKDict)
        else type_or_data,
    )
    # Initialized so that an empty package_path reaches the AssertionError
    # below instead of failing with UnboundLocalError.
    s = None
    for p in feature_config.cfg().package_path:
        n = f"{p}.{submodule}.{sim_type}"
        try:
            return importlib.import_module(n)
        except ModuleNotFoundError as e:
            if n != e.name:
                # import is failing due to ModuleNotFoundError in a sub-import
                # not the module we are looking for
                raise
            s = pkdexc()
    if s is not None:
        # gives more debugging info (perhaps more confusion)
        pkdc(s)
    raise AssertionError(
        f"cannot find submodule={submodule} for sim_type={sim_type} in package_path={feature_config.cfg().package_path}"
    )
[docs]
def is_jupyter_enabled():
    """Check whether the jupyterhublogin sim type is configured.

    Returns:
        bool: result of `is_sim_type` for `sirepo.const.SIM_TYPE_JUPYTERHUBLOGIN`
    """
    jupyter = sirepo.const.SIM_TYPE_JUPYTERHUBLOGIN
    return is_sim_type(jupyter)
[docs]
def is_python_identifier(name):
    """Test `name` against the valid Python identifier pattern.

    Args:
        name (str): to check
    Returns:
        re.Match: truthy match object if `name` is valid, else None
    """
    found = _VALID_PYTHON_IDENTIFIER.search(name)
    return found
[docs]
def is_sim_type(sim_type):
    """Check whether `sim_type` is one of the configured simulation types.

    Args:
        sim_type (str): to check
    Returns:
        bool: true if is a sim_type
    """
    from sirepo import feature_config

    configured = feature_config.cfg().sim_types
    return sim_type in configured
[docs]
def json_dump(obj, path=None, pretty=False, **kwargs):
    """Format `obj` as a JSON string and optionally write it atomically to disk.

    Formatting is done by `pykern.pkjson.dump_pretty`, always with
    ``allow_nan=False``.

    Args:
        obj (object): any Python object
        path (py.path): where to write (atomic) [None]
        pretty (bool): pretty print [False]
        kwargs (object): other arguments to `json.dumps`
    Returns:
        str: sorted and formatted JSON
    """
    text = pykern.pkjson.dump_pretty(obj, pretty=pretty, allow_nan=False, **kwargs)
    if not path:
        return text
    pykern.pkio.atomic_write(path, text)
    return text
[docs]
def json_path(path, run_dir=None):
    """Convert `path` to a path object ending in sirepo.const.JSON_SUFFIX.

    Args:
        path (py.path or str): to convert
        run_dir (py.path): which directory to join (only if path is str)
    Returns:
        py.path: path.json
    Raises:
        AssertionError: if `run_dir` is given with a non-str or absolute `path`
    """
    if not isinstance(path, str):
        if run_dir:
            raise AssertionError(
                f"path={path} is a py.path, cannot join run_dir={run_dir}"
            )
        res = path
    elif not run_dir:
        res = pykern.pkio.py_path(path)
    elif os.path.isabs(path):
        raise AssertionError(
            f"path={path} is absolute, cannot join run_dir={run_dir}"
        )
    else:
        res = run_dir.join(path)
    if res.ext != sirepo.const.JSON_SUFFIX:
        # Do not replace using new, because may already have suffix
        res = res + sirepo.const.JSON_SUFFIX
    return res
[docs]
def json_read(path):
    """Load the JSON file at `path`.

    Args:
        path (py.path or str): will append sirepo.const.JSON_SUFFIX if necessary
    Returns:
        object: json converted to python
    """
    source = json_path(path)
    return pykern.pkjson.load_any(source)
[docs]
def numpy_to_py(obj):
    """Convert numpy objects to Python objects

    Use to avoid `repr` conversions. Lists, tuples (including
    namedtuples), and dicts are rebuilt recursively with the same type.
    Anything else is returned as is.

    Args:
        obj (object): source
    Returns:
        object: no numpy.floating objects
    """
    if isinstance(obj, numpy.floating):
        return float(obj)
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):
        # namedtuple constructors take one positional arg per field, not a
        # single iterable, so they must be rebuilt with _make
        return type(obj)._make(numpy_to_py(o) for o in obj)
    if isinstance(obj, (list, tuple)):
        return type(obj)(numpy_to_py(o) for o in obj)
    if isinstance(obj, dict):
        return type(obj)({k: numpy_to_py(v) for k, v in obj.items()})
    return obj
[docs]
def plan_role_expiration(role):
    """Compute the expiration for a plan role.

    The trial role lasts ``feature_config.cfg().trial_expiration_days``
    days. Every other plan role lasts 365 days.

    Args:
        role (str): plan to change to
    Returns:
        datetime: new expiration, or None if payments are not enabled
    Raises:
        AssertionError: if `role` is not in `auth_role.PLAN_ROLES`
    """
    from sirepo import auth_role, feature_config, srtime
    import datetime

    if role not in auth_role.PLAN_ROLES:
        raise AssertionError(f"invalid plan role={role}")
    if not feature_config.have_payments():
        return None
    start = srtime.utc_now()
    if role == auth_role.ROLE_PLAN_TRIAL:
        days = feature_config.cfg().trial_expiration_days
    else:
        days = 365
    return start + datetime.timedelta(days)
[docs]
def random_base62(length=32, prefix=None):
    """Returns a safe string of sufficient length to be a nonce

    Args:
        length (int): how long to make the base62 string [32]
        prefix (str): if truthy, result is ``"<prefix>_<random>"`` [None]
    Returns:
        str: random base62 characters, optionally prefixed
    """
    # One OS-backed generator for the whole string, not one per character
    rng = random.SystemRandom()
    res = "".join(rng.choice(numconv.BASE62) for _ in range(length))
    return f"{prefix}_{res}" if prefix else res
[docs]
def read_zip(path_or_bytes):
    """Iterate over the files in a zip archive.

    Protects against malicious filenames (ex ../../filename) by yielding
    only the basename of each member. Directories are skipped.

    Args:
        path_or_bytes (py.path or str or bytes): The path to the archive or its contents
    Yields:
        tuple: (basename of the file, contents of the file as bytes)
    """
    if isinstance(path_or_bytes, bytes):
        source = io.BytesIO(path_or_bytes)
    else:
        source = path_or_bytes
    with zipfile.ZipFile(source, "r") as archive:
        for member in archive.infolist():
            if not member.is_dir():
                # SECURITY: Use only basename of file to prevent against
                # malicious files (ex ../../filename)
                yield (
                    pykern.pkio.py_path(member.filename).basename,
                    archive.read(member),
                )
[docs]
def sanitize_string(string):
    """Replace special characters in `string` with underscores.

    The result is a valid python identifier. It can also be used as a
    css id because valid python identifiers are also valid css ids.

    Args:
        string (str): The string to sanitize
    Returns:
        (str): A string with special characters replaced
    """
    if not is_python_identifier(string):
        return _INVALID_PYTHON_IDENTIFIER.sub("_", string)
    # Already valid, nothing to replace
    return string
[docs]
def secure_filename(path):
    """Convert a user supplied path to a secure file name.

    Steps: NFKD-normalize and drop non-ASCII characters, treat ``/`` as
    whitespace, join whitespace-separated words with ``_``, remove
    characters outside ``A-Za-z0-9_.-``, and strip leading/trailing
    ``.`` and ``_``.

    Args:
        path (str): contains anything
    Returns:
        str: does not contain special file system chars or path chars;
            ``"file"`` if nothing is left
    """
    ascii_only = (
        unicodedata.normalize("NFKD", path).encode("ascii", "ignore").decode("ascii")
    )
    words = ascii_only.replace("/", " ").split()
    cleaned = _INVALID_PATH_CHARS.sub("", "_".join(words)).strip("._")
    return cleaned or "file"
[docs]
def setattr_imports(imports):
    """Set each item of `imports` as an attribute on the caller's module.

    The target module is whatever `pykern.pkinspect.caller_module` returns
    when invoked from this function.

    Args:
        imports (dict): attribute name to value
    """
    target = pykern.pkinspect.caller_module()
    for name, value in imports.items():
        setattr(target, name, value)
[docs]
def split_comma_delimited_string(s, f_type):
    """Split `s` on commas and convert each piece with `f_type`.

    Whitespace around each comma is removed by the split.

    Args:
        s (str): comma delimited values
        f_type (callable): converter applied to each piece
    Returns:
        list: converted values
    """
    pieces = re.split(r"\s*,\s*", s)
    return list(map(f_type, pieces))
[docs]
def to_comma_delimited_string(arr):
    """Join the `str` of each element of `arr` with commas.

    Args:
        arr (iterable): values to join
    Returns:
        str: comma delimited values (no spaces)
    """
    return ",".join(map(str, arr))
[docs]
def unique_key():
    """Generate a 32 character random base62 key.

    Returns:
        str: result of `random_base62` with length 32
    """
    # TODO(e-carlin): create_token should be aligned with unique_key.
    key = random_base62(32)
    return key
[docs]
def url_safe_hash(value):
    """Hash `value` to a hexadecimal MD5 digest.

    Args:
        value (str): converted with `pkcompat.to_bytes` before hashing
    Returns:
        str: hex digest
    """
    digest = hashlib.md5(pkcompat.to_bytes(value))
    return digest.hexdigest()
[docs]
def validate_path(uri):
    """Ensure every ``/`` separated component of `uri` is safe.

    Very strict. Doesn't allow any dot files and few specials.

    Args:
        uri (str): unchecked path
    Returns:
        str: validated path
    Raises:
        AssertionError: if `uri` is empty, or a component is empty, has an
            illegal character, or starts with a dot
    """
    if uri == "" or uri is None:
        raise AssertionError("empty uri")
    parts = uri.split("/")
    for part in parts:
        if _INVALID_PATH_CHARS.search(part):
            raise AssertionError(f"illegal char(s) in component={part} uri={uri}")
        if not part:
            # covers absolute path case
            raise AssertionError(f"empty component in uri={uri}")
        if part[0] == ".":
            raise AssertionError(f"dot prefix in component={part} uri={uri}")
    return "/".join(parts)
[docs]
def write_zip(path):
    """Open a zip archive for writing with deflate compression.

    Args:
        path (object): passed to `zipfile.ZipFile` as the file
    Returns:
        zipfile.ZipFile: open archive; caller is responsible for closing it
    """
    return zipfile.ZipFile(path, "w", zipfile.ZIP_DEFLATED)
[docs]
async def yield_to_event_loop():
    """Give the event loop a chance to run; wraps ``asyncio.sleep(0)``.

    If a server (api, supervisor, agent) is doing a lot of work, call
    this routine to release the processor to the event loop.
    """
    pause = asyncio.sleep(0)
    await pause
# Initialize module configuration, replacing the `None` placeholder assigned
# near the top of the module. `create_token` reads `create_token_secret`.
# NOTE(review): the tuple looks like (default, type, description) per
# pkconfig convention -- confirm against pykern.pkconfig.init.
_cfg = pkconfig.init(
    create_token_secret=("oh so secret!", str, "used for internal test only"),
)