Source code for sirepo.util

"""Support routines and classes, mostly around errors and I/O.

:copyright: Copyright (c) 2018 RadiaSoft LLC.  All Rights Reserved.
:license: http://www.apache.org/licenses/LICENSE-2.0.html
"""

# NOTE: limit sirepo imports here
from pykern import pkcompat
from pykern import pkconfig
from pykern.pkcollections import PKDict
from pykern.pkdebug import pkdlog, pkdp, pkdexc, pkdc
import asyncio
import base64
import hashlib
import importlib
import io
import inspect
import numconv
import numpy
import os.path
import pykern.pkinspect
import pykern.pkio
import pykern.pkjson
import re
import random
import sirepo.const
import unicodedata
import zipfile


# Placeholder for module configuration; assigned by `pkconfig.init` at the
# end of this module.
_cfg = None

#: Number of characters in the string returned by `create_token`
TOKEN_SIZE = 16

#: POSIT: Matches anything generated by `unique_key`
UNIQUE_KEY_CHARS_RE = r"\w+"

#: Anchored form of `UNIQUE_KEY_CHARS_RE`: a standalone unique key
UNIQUE_KEY_RE = re.compile(rf"^{UNIQUE_KEY_CHARS_RE}$")

# Cached result of `first_sim_type`; None until first computed
_FIRST_SIM_TYPE = None

# See https://github.com/radiasoft/sirepo/pull/3889#discussion_r738769716
# for reasoning on why define both
_INVALID_PYTHON_IDENTIFIER = re.compile(r"\W|^(?=\d)")
_VALID_PYTHON_IDENTIFIER = re.compile(r"^[a-z_]\w*$", flags=re.IGNORECASE)

# Any character outside of ASCII letters, digits, underscore, dot, and dash
_INVALID_PATH_CHARS = re.compile(r"[^A-Za-z0-9_.-]")


class ReplyExc(Exception):
    """Raised to end the request.

    Args:
        sr_args (dict): Sirepo-specific exception args [empty PKDict]
        log_fmt (str): server side log data
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        # sr_args is consumed here; whatever remains is forwarded to pkdlog
        self.sr_args = kwargs.pop("sr_args") if "sr_args" in kwargs else PKDict()
        if not (args or kwargs):
            return
        # Attribute the log entry to the frame two levels above this one.
        # Must be computed directly in this method so the depth is stable.
        kwargs["pkdebug_frame"] = inspect.currentframe().f_back.f_back
        pkdlog(*args, **kwargs)

    def __repr__(self):
        a = self.sr_args
        pairs = ",".join(f"{k}={a[k]}" for k in sorted(a.keys()))
        return f"{self.__class__.__name__}({pairs})"

    def __str__(self):
        return self.__repr__()
class BadRequest(ReplyExc):
    """Raised when the request is bad"""
class OKReplyExc(ReplyExc):
    """Base for `ReplyExc` exceptions that represent a successful response"""
class InvalidEmail(ReplyExc):
    """Raised when the email a user is trying to register with is malformed
    or its domain is on the deny list.
    """
class Error(ReplyExc):
    """Raised to send an error response

    Args:
        error (str): just the error to output to user
        log_fmt (str): server side log data

    Raises:
        AssertionError: if `error` is not a str
    """

    def __init__(self, error, *args, **kwargs):
        # dict values for `error` are no longer supported. Raise explicitly
        # (instead of using an `assert` statement) so the check is not
        # stripped when Python runs with -O, matching the explicit
        # AssertionError style used by the rest of this module.
        if not isinstance(error, str):
            raise AssertionError(
                f"error must be a str, got type={type(error).__name__}"
            )
        super().__init__(*args, sr_args=PKDict(error=error), **kwargs)
class Forbidden(ReplyExc):
    """Raised when the request is forbidden"""
class NotFound(ReplyExc):
    """Raised when what was requested is not found"""
class PlanExpired(ReplyExc):
    """Raised when an API requires an active plan"""
class Redirect(OKReplyExc):
    """Raised to redirect

    Args:
        uri (str): where to redirect to
        log_fmt (str): server side log data
    """

    def __init__(self, uri, *args, **kwargs):
        target = PKDict(uri=uri)
        super().__init__(*args, sr_args=target, **kwargs)
class ContentTooLarge(ReplyExc):
    """Raised when the content requested by the user is too large
    (ex large data file)
    """
class ServerError(ReplyExc):
    """Raised when there is a server error"""
class SPathNotFound(NotFound):
    """Raised by simulation_db

    Args:
        sim_type (str): simulation type
        uid (str): user
        sid (str): simulation id
        log_fmt (str): server side log data
    """

    def __init__(self, sim_type, uid, sid, *args, **kwargs):
        where = PKDict(sim_type=sim_type, uid=uid, sid=sid)
        super().__init__(*args, sr_args=where, **kwargs)
class SReplyExc(OKReplyExc):
    """Raise with an SReply object

    Args:
        sreply (object): what the reply should be
        log_fmt (str): server side log data
    """

    def __init__(self, sreply, *args, **kwargs):
        payload = PKDict(sreply=sreply)
        super().__init__(*args, sr_args=payload, **kwargs)
class SRException(ReplyExc):
    """Raised to communicate a local redirect and log info

    `params` may have ``simulationType``, which will be used for
    routeName rendering. Otherwise, ``sim_type`` on ``qcall`` will be
    used.

    Args:
        route_name (str): a local route
        params (dict): for route url or for srExceptionOnly case
        log_fmt (str): server side log data
    """

    def __init__(self, route_name, params, *args, **kwargs):
        # Stored under the camelCase key "routeName"
        route = PKDict(routeName=route_name, params=params)
        super().__init__(*args, sr_args=route, **kwargs)
class Unauthorized(ReplyExc):
    """Raised to generate a 401 response"""
class UserAlert(ReplyExc):
    """Raised to display a user error and log info

    Args:
        display_text (str): string that user will see
        log_fmt (str): server side log data
    """

    def __init__(self, display_text, *args, **kwargs):
        # The text is stored under the "error" key of sr_args
        alert = PKDict(error=display_text)
        super().__init__(*args, sr_args=alert, **kwargs)
class UserDirNotFound(NotFound):
    """Raised by simulation_db

    Args:
        user_dir (py.path): directory not found
        uid (str): user
        log_fmt (str): server side log data
    """

    def __init__(self, user_dir, uid, *args, **kwargs):
        missing = PKDict(user_dir=user_dir, uid=uid)
        super().__init__(*args, sr_args=missing, **kwargs)
class WWWAuthenticate(ReplyExc):
    """Raised to generate a 401 response with WWWAuthenticate response"""
def assert_sim_type(sim_type):
    """Validate simulation type

    Args:
        sim_type (str): to check
    Returns:
        str: validated sim_type
    Raises:
        AssertionError: if `is_sim_type` rejects `sim_type`
    """
    if is_sim_type(sim_type):
        return sim_type
    raise AssertionError(f"invalid simulation type={sim_type}")
def create_token(value):
    """Create a token of `TOKEN_SIZE` characters

    When `pkconfig.channel_in_internal_test` is true and
    ``create_token_secret`` is configured, the token is the start of the
    base32-encoded SHA-256 digest of `value` concatenated with the secret.
    Otherwise, the token comes from `random_base62`.

    Args:
        value (str): combined with the secret in the internal test case
    Returns:
        str: token
    """
    if not (pkconfig.channel_in_internal_test() and _cfg.create_token_secret):
        return random_base62(TOKEN_SIZE)
    digest = hashlib.sha256(
        pkcompat.to_bytes(value + _cfg.create_token_secret),
    ).digest()
    encoded = base64.b32encode(digest)
    return pkcompat.from_bytes(encoded[:TOKEN_SIZE])
def err(obj, fmt="", *args, **kwargs):
    """Format a message prefixed with `obj`

    Args:
        obj (object): rendered first, followed by ": "
        fmt (str): `str.format` template for the rest of the message [""]
        args (tuple): positional values for `fmt`
        kwargs (dict): keyword values for `fmt`
    Returns:
        str: "<obj>: <formatted message>"
    """
    return f"{obj}: {fmt.format(*args, **kwargs)}"
def files_to_watch_for_reload(*extensions):
    """Iterate files with `extensions` in sirepo and configured packages

    For each extension, globs recursively below the directory that holds
    ``__file__`` of "sirepo" and of every package in
    ``feature_config.cfg().package_path`` (packages visited in sorted order).

    Args:
        extensions (str): file extensions without the leading dot
    Yields:
        object: each entry returned by `pykern.pkio.sorted_glob`
    """
    from sirepo import feature_config

    for ext in extensions:
        packages = sorted({"sirepo", *feature_config.cfg().package_path})
        for pkg in packages:
            module = importlib.import_module(pkg)
            pkg_dir = pykern.pkio.py_path(module.__file__).dirname
            yield from pykern.pkio.sorted_glob(f"{pkg_dir}/**/*.{ext}")
def find_obj(arr, key, value):
    """Return the first object in the array such that obj[key] == value

    Args:
        arr (list): list of dict-like objects
        key (str): object key
        value (*): value
    Returns:
        object: the object, or None if not found
    """
    matches = (candidate for candidate in arr if candidate[key] == value)
    return next(matches, None)
def first_sim_type():
    """Returns the first configured sim_type

    Sim types that are not auth controlled are preferred; if there are none,
    the auth controlled ones are used. In both cases the first in sorted
    order is chosen. The result is cached in ``_FIRST_SIM_TYPE``.

    Returns:
        str: sim type
    """
    global _FIRST_SIM_TYPE

    if not _FIRST_SIM_TYPE:
        from sirepo import feature_config

        controlled = feature_config.auth_controlled_sim_types()
        candidates = sorted(feature_config.cfg().sim_types - controlled) or sorted(
            controlled
        )
        _FIRST_SIM_TYPE = candidates[0]
    return _FIRST_SIM_TYPE
def import_submodule(submodule, type_or_data):
    """Import fully qualified module that contains submodule for sim type

    sirepo.feature_config.package_path will be searched for a match.

    Args:
        submodule (str): the name of the submodule
        type_or_data (str or dict): simulation type or description
    Returns:
        module: simulation type module instance
    Raises:
        AssertionError: if no package in package_path has the module
        ModuleNotFoundError: if the module exists, but one of its own
            imports cannot be found
    """
    from sirepo import feature_config
    from sirepo import template

    sim_type = template.assert_sim_type(
        type_or_data.simulationType
        if isinstance(type_or_data, PKDict)
        else type_or_data,
    )
    # Initialized up front so the debug output below cannot hit an unbound
    # local when package_path is empty.
    s = None
    for p in feature_config.cfg().package_path:
        n = f"{p}.{submodule}.{sim_type}"
        try:
            return importlib.import_module(n)
        except ModuleNotFoundError as e:
            if n != e.name:
                # import is failing due to ModuleNotFoundError in a sub-import
                # not the module we are looking for
                raise
            s = pkdexc()
    if s is not None:
        # gives more debugging info (perhaps more confusion)
        pkdc(s)
    raise AssertionError(
        f"cannot find submodule={submodule} for sim_type={sim_type} in package_path={feature_config.cfg().package_path}"
    )
def is_jupyter_enabled():
    """Check whether `sirepo.const.SIM_TYPE_JUPYTERHUBLOGIN` is a sim type

    Returns:
        bool: result of `is_sim_type` for that sim type
    """
    jupyter = sirepo.const.SIM_TYPE_JUPYTERHUBLOGIN
    return is_sim_type(jupyter)
def is_python_identifier(name):
    """Check whether `name` matches ``_VALID_PYTHON_IDENTIFIER``

    Args:
        name (str): to check
    Returns:
        re.Match: match object (truthy) if valid, else None
    """
    # fullmatch (not search): ``$`` also matches just before a trailing
    # newline, so search would accept "abc\n" as a valid identifier.
    return _VALID_PYTHON_IDENTIFIER.fullmatch(name)
def is_sim_type(sim_type):
    """Validate simulation type

    Args:
        sim_type (str): to check
    Returns:
        bool: true if `sim_type` is in ``feature_config.cfg().sim_types``
    """
    from sirepo import feature_config

    configured = feature_config.cfg().sim_types
    return sim_type in configured
def json_dump(obj, path=None, pretty=False, **kwargs):
    """Format `obj` as a JSON string and optionally write it to `path`

    NaN values are not allowed (``allow_nan=False``).

    Args:
        obj (object): any Python object
        path (py.path): where to write (atomic) [None]
        pretty (bool): pretty print [False]
        kwargs (object): other arguments passed to `pykern.pkjson.dump_pretty`
    Returns:
        str: formatted JSON
    """
    text = pykern.pkjson.dump_pretty(obj, pretty=pretty, allow_nan=False, **kwargs)
    if not path:
        return text
    pykern.pkio.atomic_write(path, text)
    return text
def json_path(path, run_dir=None):
    """Append sirepo.const.JSON_SUFFIX if necessary and convert to a path

    Args:
        path (py.path or str): to convert
        run_dir (py.path): which directory to join (only if path is str)
    Returns:
        py.path: path.json
    Raises:
        AssertionError: if `run_dir` is given with a non-str `path` or with
            an absolute str `path`
    """
    if not isinstance(path, str):
        if run_dir:
            raise AssertionError(
                f"path={path} is a py.path, cannot join run_dir={run_dir}"
            )
        res = path
    elif not run_dir:
        res = pykern.pkio.py_path(path)
    elif os.path.isabs(path):
        raise AssertionError(
            f"path={path} is absolute, cannot join run_dir={run_dir}"
        )
    else:
        res = run_dir.join(path)
    if res.ext != sirepo.const.JSON_SUFFIX:
        # Do not replace using new, because may already have suffix
        return res + sirepo.const.JSON_SUFFIX
    return res
def json_read(path):
    """Read data from json file

    Args:
        path (py.path or str): will append sirepo.const.JSON_SUFFIX if necessary
    Returns:
        object: json converted to python
    """
    source = json_path(path)
    return pykern.pkjson.load_any(source)
def numpy_to_py(obj):
    """Convert numpy objects to Python objects

    Use to avoid `repr` conversions. Recurses into lists, tuples, and dicts,
    rebuilding each container with its own type. Only `numpy.floating`
    values are converted; anything else is returned as is.

    Args:
        obj (object): source
    Returns:
        object: no numpy.floating objects
    """
    if isinstance(obj, numpy.floating):
        return float(obj)
    container = type(obj)
    if isinstance(obj, (list, tuple)):
        return container(map(numpy_to_py, obj))
    if isinstance(obj, dict):
        return container({key: numpy_to_py(val) for key, val in obj.items()})
    return obj
def plan_role_expiration(role):
    """Get expiration for an (asserted) plan

    Trial plans last ``feature_config.cfg().trial_expiration_days`` days;
    every other plan lasts 365 days.

    Args:
        role (str): plan to change to
    Returns:
        datetime: new expiration, or None if payments are not enabled
    Raises:
        AssertionError: if `role` is not one of ``auth_role.PLAN_ROLES``
    """
    from sirepo import auth_role, feature_config, srtime
    import datetime

    if role not in auth_role.PLAN_ROLES:
        raise AssertionError(f"invalid plan role={role}")
    if not feature_config.have_payments():
        return None
    now = srtime.utc_now()
    if role == auth_role.ROLE_PLAN_TRIAL:
        days = feature_config.cfg().trial_expiration_days
    else:
        days = 365
    return now + datetime.timedelta(days)
def random_base62(length=32, prefix=None):
    """Returns a safe string of sufficient length to be a nonce

    Args:
        length (int): how long to make the base62 string [32]
        prefix (str): if truthy, prepended to the result followed by "_" [None]
    Returns:
        str: random base62 characters
    """
    rng = random.SystemRandom()
    chars = [rng.choice(numconv.BASE62) for _ in range(length)]
    res = "".join(chars)
    if prefix:
        return f"{prefix}_{res}"
    return res
def read_zip(path_or_bytes):
    """Read the contents of a zip archive.

    Protects against malicious filenames (ex ../../filename)

    Args:
        path_or_bytes (py.path or str or bytes): The path to the archive or its contents
    Yields:
        tuple: (basename of the file, contents of the file) for every
            member that is not a directory
    """
    if isinstance(path_or_bytes, bytes):
        source = io.BytesIO(path_or_bytes)
    else:
        source = path_or_bytes
    with zipfile.ZipFile(source, "r") as archive:
        for info in archive.infolist():
            if not info.is_dir():
                # SECURITY: Use only basename of file to prevent against
                # malicious files (ex ../../filename)
                yield pykern.pkio.py_path(info.filename).basename, archive.read(info)
def sanitize_string(string):
    """Remove special characters from string

    This results in a string that is a valid python identifier.
    This string can also be used as a css id because valid
    python identifiers are also valid css ids.

    Args:
        string (str): The string to sanitize
    Returns:
        (str): A string with special characters replaced by "_"
    """
    if not is_python_identifier(string):
        return _INVALID_PYTHON_IDENTIFIER.sub("_", string)
    return string
def secure_filename(path):
    """Converts a user supplied path to a secure file

    The path is NFKD-normalized and reduced to ASCII, slashes and runs of
    whitespace become single underscores, characters matching
    ``_INVALID_PATH_CHARS`` are removed, and leading/trailing "." and "_"
    are stripped.

    Args:
        path (str): contains anything
    Returns:
        str: does not contain special file system chars or path chars;
            "file" if nothing is left
    """
    normalized = unicodedata.normalize("NFKD", path)
    ascii_only = normalized.encode("ascii", "ignore").decode("ascii")
    words = ascii_only.replace("/", " ").split()
    res = _INVALID_PATH_CHARS.sub("", "_".join(words)).strip("._")
    return res if res != "" else "file"
def setattr_imports(imports):
    """Set each item of `imports` as an attribute on the caller's module

    Args:
        imports (dict): attribute names mapped to the values to set
    """
    # Must be called directly from this function: the lookup is relative to
    # the call stack.
    target = pykern.pkinspect.caller_module()
    for name, value in imports.items():
        setattr(target, name, value)
def split_comma_delimited_string(s, f_type):
    """Split `s` on commas and convert each piece with `f_type`

    Whitespace around each comma is discarded; whitespace at the very start
    or end of `s` is left on the first/last piece.

    Args:
        s (str): comma delimited values
        f_type (callable): applied to each piece
    Returns:
        list: converted pieces
    """
    pieces = re.split(r"\s*,\s*", s)
    return list(map(f_type, pieces))
def to_comma_delimited_string(arr):
    """Join the `str` of each element of `arr` with commas

    Args:
        arr (iterable): values to join
    Returns:
        str: comma delimited values (no spaces)
    """
    return ",".join(map(str, arr))
def unique_key():
    """Generate a random key of 32 base62 characters

    Returns:
        str: result of `random_base62`
    """
    # TODO(e-carlin): create_token should be aligned with unique_key.
    return random_base62(length=32)
def url_safe_hash(value):
    """Hex MD5 digest of `value`

    Args:
        value (object): converted with `pkcompat.to_bytes` before hashing
    Returns:
        str: hexadecimal digest
    """
    data = pkcompat.to_bytes(value)
    return hashlib.md5(data).hexdigest()
def validate_path(uri):
    """Ensures path component of uri is safe

    Very strict. Doesn't allow any dot files and few specials.

    Args:
        uri (str): unchecked path
    Returns:
        str: validated path
    Raises:
        AssertionError: if `uri` is empty or any component is empty,
            starts with a dot, or contains an illegal character
    """
    if uri == "" or uri is None:
        raise AssertionError("empty uri")
    components = uri.split("/")
    for c in components:
        if _INVALID_PATH_CHARS.search(c):
            raise AssertionError(f"illegal char(s) in component={c} uri={uri}")
        if not c:
            # covers absolute path case
            raise AssertionError(f"empty component in uri={uri}")
        if c.startswith("."):
            raise AssertionError(f"dot prefix in component={c} uri={uri}")
    return "/".join(components)
def write_zip(path):
    """Open a deflate-compressed zip archive for writing

    Args:
        path (object): file name or file-like object passed to `zipfile.ZipFile`
    Returns:
        zipfile.ZipFile: archive opened in "w" mode
    """
    return zipfile.ZipFile(path, "w", zipfile.ZIP_DEFLATED)
async def yield_to_event_loop():
    """Documents and wraps ``asyncio.sleep(0)``

    If a server (api, supervisor, agent) is doing a lot of work,
    call this routine to release the processor to the event loop.
    """
    await asyncio.sleep(delay=0)
# Initialize module configuration. This replaces the `_cfg = None`
# placeholder assigned near the top of the module. `create_token` reads
# `create_token_secret` when running in an internal test channel.
_cfg = pkconfig.init(
    create_token_secret=("oh so secret!", str, "used for internal test only"),
)