# Source code for sirepo.template.template_common

# -*- coding: utf-8 -*-
u"""SRW execution template.

:copyright: Copyright (c) 2015 RadiaSoft LLC.  All Rights Reserved.
:license: http://www.apache.org/licenses/LICENSE-2.0.html
"""
from __future__ import absolute_import, division, print_function
from pykern import pkcollections
from pykern import pkcompat
from pykern import pkio
from pykern import pkjinja
from pykern import pkresource
from pykern.pkdebug import pkdc, pkdlog, pkdp
import hashlib
import json
import numpy as np
import os.path
import py.path
import re
import sirepo.template

#: Matches a version suffix such as 'v2' at the end of the first animationArgs token
ANIMATION_ARGS_VERSION_RE = re.compile(r'v(\d+)$')

#: Default distance used for intensity reports
DEFAULT_INTENSITY_DISTANCE = 20

#: Input json file
INPUT_BASE_NAME = 'in'

#: Matches model parameter names which reference library files
LIB_FILE_PARAM_RE = re.compile(r'.*File$')

#: Output json file
OUTPUT_BASE_NAME = 'out'

#: Python file (not all simulations)
PARAMETERS_PYTHON_FILE = 'parameters.py'

#: stderr and stdout
RUN_LOG = 'run.log'

#: Upper bound enforced on histogramBins values (see histogram_bins)
_HISTOGRAM_BINS_MAX = 500

#: Colors assigned to parameter plot lines (see compute_plot_color_and_range)
_PLOT_LINE_COLOR = ['#1f77b4', '#ff7f0e', '#2ca02c']

#: Base directory holding per-simulation-type template resources
_RESOURCE_DIR = py.path.local(pkresource.filename('template'))

#: Prefix of watchpoint report names
_WATCHPOINT_REPORT_NAME = 'watchpointReport'


def compute_field_range(args, compute_range):
    """Computes the fieldRange values for all parameters across all animation files.

    The result is cached on the animation input file; compute_range() is
    invoked only on a cache miss to read the simulation specific datafiles
    and extract the ranges by field.
    """
    from sirepo import simulation_db
    run_dir = simulation_db.simulation_run_dir({
        'simulationType': args['simulationType'],
        'simulationId': args['simulationId'],
        'report': 'animation',
    })
    data = simulation_db.read_json(run_dir.join(INPUT_BASE_NAME))
    field_range = None
    model_name = args['modelName']
    if model_name in data.models:
        model = data.models[model_name]
        if 'fieldRange' in model:
            # cache hit: reuse the previously computed ranges
            field_range = model.fieldRange
        else:
            field_range = compute_range(run_dir, data)
            model.fieldRange = field_range
            simulation_db.write_json(run_dir.join(INPUT_BASE_NAME), data)
    return {
        'fieldRange': field_range,
    }
def compute_plot_color_and_range(plots):
    """For parameter plots, assign each plot a color and compute the full y_range.

    Args:
        plots (list): plot dicts, each with a 'points' list; a 'color'
            key is assigned to each in place
    Returns:
        list: [min, max] across all plots' points, or None if plots is empty
    """
    y_range = None
    for i, plot in enumerate(plots):
        # cycle through the palette rather than raising IndexError when
        # there are more plots than configured colors
        plot['color'] = _PLOT_LINE_COLOR[i % len(_PLOT_LINE_COLOR)]
        vmin = min(plot['points'])
        vmax = max(plot['points'])
        if y_range:
            if vmin < y_range[0]:
                y_range[0] = vmin
            if vmax > y_range[1]:
                y_range[1] = vmax
        else:
            y_range = [vmin, vmax]
    return y_range
def copy_lib_files(data, source, target):
    """Copy auxiliary files to target

    Args:
        data (dict): simulation db
        source (py.path): source directory; when set, files are copied
            rather than symlinked (presumably another session's lib)
        target (py.path): destination directory
    """
    for f in lib_files(data, source):
        path = target.join(f.basename)
        pkio.mkdir_parent_only(path)
        if not path.exists():
            if not f.exists():
                sim_resource = resource_dir(data.simulationType)
                r = sim_resource.join(f.basename)
                # the file doesn't exist in the simulation lib, check the resource lib
                if r.exists():
                    # seed the simulation lib from the resource copy first
                    pkio.mkdir_parent_only(f)
                    r.copy(f)
                else:
                    # best-effort: log and skip rather than fail the whole copy
                    pkdlog('No file in lib or resource: {}', f)
                    continue
            if source:
                # copy files from another session
                f.copy(path)
            else:
                # symlink into the run directory
                path.mksymlinkto(f, absolute=False)
def enum_text(schema, name, value):
    """Return the display text for an enum value.

    Args:
        schema (dict): schema containing an 'enum' map of [value, text] pairs
        name (str): enum name
        value: enum value to look up
    Returns:
        str: display text for the matching pair
    Raises:
        AssertionError: if value is not found in the enum
    """
    for e in schema['enum'][name]:
        if e[0] == value:
            return e[1]
    # raise explicitly: a bare "assert False" is stripped under python -O
    raise AssertionError('unknown {} enum value: {}'.format(name, value))
def flatten_data(d, res, prefix=''):
    """Takes a nested dictionary and converts it to a single level dictionary
    with flattened (underscore-joined) keys. Lists are dropped."""
    for key, value in d.items():
        if isinstance(value, dict):
            flatten_data(value, res, '{}{}_'.format(prefix, key))
        elif not isinstance(value, list):
            res[prefix + key] = value
    return res
def filename_to_path(files, source_lib):
    """Returns full, unique paths of simulation files

    Args:
        files (list): file basenames, possibly with duplicates
        source_lib (py.path): directory containing the files
    Returns:
        list: py.path.local to files
    """
    seen = set()
    paths = []
    for name in files:
        if name in seen:
            continue
        seen.add(name)
        paths.append(source_lib.join(name))
    return paths
def heatmap(values, model, plot_fields=None):
    """Computes a report histogram (x_range, y_range, z_matrix) for a report model.

    Args:
        values: (2, N) array_like of x and y point values
        model (dict): report model with histogramBins and optional plot-range fields
        plot_fields (dict): optional extra fields merged into the result
    Returns:
        dict: x_range, y_range, z_matrix plus any plot_fields
    """
    # named hist_range so the builtin range() is not shadowed
    hist_range = None
    if not np.any(values):
        values = [[], []]
    if 'plotRangeType' in model:
        if model['plotRangeType'] == 'fixed':
            hist_range = [_plot_range(model, 'horizontal'), _plot_range(model, 'vertical')]
        elif model['plotRangeType'] == 'fit' and 'fieldRange' in model:
            hist_range = [model.fieldRange[model['x']], model.fieldRange[model['y']]]
    hist, edges = np.histogramdd(values, histogram_bins(model['histogramBins']), range=hist_range)
    res = {
        'x_range': [float(edges[0][0]), float(edges[0][-1]), len(hist)],
        'y_range': [float(edges[1][0]), float(edges[1][-1]), len(hist[0])],
        # transpose so rows are y and columns are x, as the client expects
        'z_matrix': hist.T.tolist(),
    }
    if plot_fields:
        res.update(plot_fields)
    return res
def histogram_bins(nbins):
    """Ensure the histogram count is in a valid range"""
    # clamp to [1, _HISTOGRAM_BINS_MAX]
    return min(max(int(nbins), 1), _HISTOGRAM_BINS_MAX)
def is_watchpoint(name):
    """True if the report name refers to a watchpoint report."""
    return name.find(_WATCHPOINT_REPORT_NAME) >= 0
def lib_file_name(model_name, field, value):
    """Build the canonical lib file name: <model>-<field>.<value>."""
    return '-'.join([model_name, field]) + '.' + value
def lib_files(data, source_lib=None):
    """Return list of files used by the simulation

    Args:
        data (dict): sim db
        source_lib (py.path): directory holding the files; defaults to the
            simulation lib dir
    Returns:
        list: py.path.local to files
    """
    from sirepo import simulation_db
    sim_type = data.simulationType
    lib_dir = source_lib or simulation_db.simulation_lib_dir(sim_type)
    # delegate to the simulation-type-specific template module
    return sirepo.template.import_module(data).lib_files(data, lib_dir)
def model_defaults(name, schema):
    """Returns a set of default model values from the schema."""
    res = pkcollections.Dict()
    for field, info in schema['model'][name].items():
        # the third schema element, when present and non-None, is the default
        if len(info) >= 3 and info[2] is not None:
            res[field] = info[2]
    return res
def parameter_plot(x, plots, model, plot_fields=None):
    """Build a parameter-plot result (x_points, plots, x/y ranges).

    Ranges may be overridden by the model's plotRangeType setting
    ('fixed' uses the size/offset fields; 'fit' uses cached fieldRange).
    """
    res = {
        'x_points': x,
        'x_range': [min(x), max(x)],
        'plots': plots,
        'y_range': compute_plot_color_and_range(plots),
    }
    if 'plotRangeType' in model:
        if model.plotRangeType == 'fixed':
            res['x_range'] = _plot_range(model, 'horizontal')
            res['y_range'] = _plot_range(model, 'vertical')
        elif model.plotRangeType == 'fit':
            res['x_range'] = model.fieldRange[model.x]
            y_range = res['y_range']
            # widen y_range to cover every plotted field's cached range
            for plot in plots:
                r = model.fieldRange[plot['field']]
                if r[0] < y_range[0]:
                    y_range[0] = r[0]
                if r[1] > y_range[1]:
                    y_range[1] = r[1]
    if plot_fields:
        res.update(plot_fields)
    return res
def parse_animation_args(data, key_map):
    """Parse animation args according to key_map

    Args:
        data (dict): contains animationArgs
        key_map (dict): version to keys mapping, default is ''
    Returns:
        Dict: mapped animationArgs with version
    """
    values = data['animationArgs'].split('_')
    version = 1
    match = ANIMATION_ARGS_VERSION_RE.search(values[0])
    if match:
        # a leading 'v<N>' token carries the args version
        version = int(match.group(1))
        values.pop(0)
    if version in key_map:
        keys = key_map[version]
    else:
        keys = key_map['']
    res = pkcollections.Dict(zip(keys, values))
    res.version = version
    return res
def parse_enums(enum_schema):
    """Returns a lookup of valid enum values, keyed by enum name."""
    return {
        name: {pair[0]: True for pair in pairs}
        for name, pairs in enum_schema.items()
    }
def render_jinja(sim_type, v, name=PARAMETERS_PYTHON_FILE):
    """Render the values into a jinja template.

    Args:
        sim_type (str): application name
        v: flattened model data
        name (str): template base name (default PARAMETERS_PYTHON_FILE)
    Returns:
        str: source text
    """
    template = resource_dir(sim_type).join(name)
    # the template file lives next to the resource with a '.jinja' suffix
    return pkjinja.render_file(template + '.jinja', v)
def report_parameters_hash(data):
    """Compute a hash of the parameters for this report.

    Only needs to be unique relative to the report, not globally unique
    so MD5 is adequate. Long and cryptographic hashes make the
    cache checks slower.

    Args:
        data (dict): report and related models
    Returns:
        str: url safe encoded hash
    """
    if 'reportParametersHash' not in data:
        models = sirepo.template.import_module(data).models_related_to_report(data)
        res = hashlib.md5()
        dm = data['models']
        for m in models:
            if pkcompat.isinstance_str(m):
                # 'model.field' selects a single field; 'model' takes the whole model
                name, field = m.split('.') if '.' in m else (m, None)
                value = dm[name][field] if field else dm[name]
            else:
                # non-string entries are literal values, hashed directly
                value = m
            # sort_keys keeps the digest stable across dict orderings
            res.update(json.dumps(value, sort_keys=True, allow_nan=False).encode())
        data['reportParametersHash'] = res.hexdigest()
    return data['reportParametersHash']
def report_fields(data, report_name, style_fields):
    """If the model has "style" fields, return the full list of non-style
    fields; otherwise return the report name (which implies all model
    fields)."""
    model = data.models[report_name]
    if any(f in model for f in style_fields):
        return [
            '{}.{}'.format(report_name, f)
            for f in model
            if f not in style_fields
        ]
    return [report_name]
def resource_dir(sim_type):
    """Where to get library files from

    Args:
        sim_type (str): application name
    Returns:
        py.path.Local: absolute path to folder
    """
    d = _RESOURCE_DIR.join(sim_type)
    return d
def update_model_defaults(model, name, schema):
    """Fill in any missing model fields from the schema defaults (in place)."""
    for field, value in model_defaults(name, schema).items():
        if field not in model:
            model[field] = value
def validate_model(model_data, model_schema, enum_info):
    """Ensure the value is valid for the field type. Scales values as needed.

    Args:
        model_data (dict): model field values, updated in place
        model_schema (dict): field name -> [label, type, (default)]
        enum_info (dict): enum values by enum name (see parse_enums)
    Raises:
        Exception: if a field has no value and no schema default
        AssertionError: if a value is not in its enum
    """
    for k in model_schema:
        label = model_schema[k][0]
        field_type = model_schema[k][1]
        if k in model_data:
            value = model_data[k]
        elif len(model_schema[k]) > 2:
            value = model_schema[k][2]
        else:
            raise Exception('no value for field "{}" and no default value in schema'.format(k))
        if field_type in enum_info:
            if str(value) not in enum_info[field_type]:
                # accept a comma-delimited list of enum values; assert each directly
                # (previously guarded by a redundant identical "if not in" test)
                for item in re.split(r'\s*,\s*', str(value)):
                    assert item in enum_info[field_type], \
                        '{}: invalid enum "{}" value for field "{}"'.format(item, field_type, k)
        elif field_type == 'Float':
            if not value:
                value = 0
            v = float(value)
            # scale by the units in the label; raw strings avoid the
            # invalid-escape-sequence warnings the originals produced
            if re.search(r'\[m(m|rad)\]', label) or re.search(r'\[Lines/mm', label):
                v /= 1000
            elif re.search(r'\[n(m|rad)\]', label) or re.search(r'\[nm/pixel\]', label):
                v /= 1e09
            elif re.search(r'\[ps]', label):
                v /= 1e12
            #TODO(pjm): need to handle unicode in label better (mu)
            elif re.search(r'\[\xb5(m|rad)\]', label) or re.search(r'\[mm-mrad\]', label):
                v /= 1e6
            model_data[k] = float(v)
        elif field_type == 'Integer':
            if not value:
                value = 0
            model_data[k] = int(value)
        else:
            model_data[k] = _escape(value)
def validate_models(model_data, model_schema):
    """Validate top-level models in the schema. Returns enum_info."""
    enum_info = parse_enums(model_schema['enum'])
    models = model_data['models']
    for name in models:
        if name in model_schema['model']:
            validate_models_one = model_schema['model'][name]
            validate_model(models[name], validate_models_one, enum_info)
    # beamline elements are validated against their per-type schema
    if 'beamline' in models:
        for item in models['beamline']:
            validate_model(item, model_schema['model'][item['type']], enum_info)
    return enum_info
def file_extension_ok(file_path, white_list=None, black_list=('py', 'pyc')):
    """Determine whether a file has an acceptable extension

    Args:
        file_path (str): name of the file to examine
        white_list ([str]): list of file types allowed (defaults to empty)
        black_list ([str]): list of file types rejected (defaults to ('py', 'pyc')).
            Ignored if white_list is not empty
    Returns:
        If file is a directory: True
        If white_list non-empty: True if the file's extension matches any in the list, otherwise False
        If white_list is empty: False if the file's extension matches any in black_list, otherwise True
    """
    import os

    if os.path.isdir(file_path):
        return True
    # white_list defaults to None (not a mutable []) to avoid the shared
    # mutable-default-argument pitfall; empty/None behave identically
    if white_list:
        return any(pkio.has_file_extension(file_path, ext) for ext in white_list)
    return not any(pkio.has_file_extension(file_path, ext) for ext in black_list)
[docs]def validate_safe_zip(zip_file_name, target_dir='.', *args): """Determine whether a zip file is safe to extract from Performs the following checks: - Each file must end up at or below the target directory - Files must be 100MB or smaller - If possible to determine, disallow "non-regular" and executable files - Existing files cannot be overwritten Args: zip_file_name (str): name of the zip file to examine target_dir (str): name of the directory to extract into (default to current directory) *args: list of validator functions taking a zip file as argument and returning True or False and a string Throws: AssertionError if any test fails, otherwise completes silently """ import zipfile import os def path_is_sub_path(path, dir_name): real_dir = os.path.realpath(dir_name) end_path = os.path.realpath(real_dir + '/' + path) return end_path.startswith(real_dir) def file_exists_in_dir(file_name, dir_name): return os.path.exists(os.path.realpath(dir_name + '/' + file_name)) def file_attrs_ok(attrs): # ms-dos attributes only use two bytes and don't contain much useful info, so pass them if attrs < 2 << 16: return True # UNIX file attributes live in the top two bytes mask = attrs >> 16 is_file_or_dir = mask & (0o0100000 | 0o0040000) != 0 no_exec = mask & (0o0000100 | 0o0000010 | 0o0000001) == 0 return is_file_or_dir and no_exec # 100MB max_file_size = 100000000 zip_file = zipfile.ZipFile(zip_file_name) for f in zip_file.namelist(): i = zip_file.getinfo(f) s = i.file_size attrs = i.external_attr assert path_is_sub_path(f, target_dir), 'Cannot extract {} above target directory'.format(f) assert s <= max_file_size, '{} too large ({} > {})'.format(f, str(s), str(max_file_size)) assert file_attrs_ok(attrs), '{} not a normal file or is executable'.format(f) assert not file_exists_in_dir(f, target_dir), 'Cannot overwrite file {} in target directory {}'.format(f, target_dir) for validator in args: res, err_string = validator(zip_file) assert res, '{} failed validator: 
{}'.format(os.path.basename(zip_file_name), err_string)
def zip_path_for_file(zf, file_to_find):
    """Find the full path of the specified file within the zip.

    For a zip zf containing:
        foo1
        foo2
        bar/
        bar/foo3
    zip_path_for_file(zf, 'foo3') will return 'bar/foo3'

    Args:
        zf (zipfile.ZipFile): the zip file to examine
        file_to_find (str): name of the file to find
    Returns:
        The first path in the zip that matches the file name, or None if no match is found
    """
    import os

    # Compare base file names (directories have a basename of '').
    # The original used map(...).index(), which fails under python 3
    # (map returns an iterator) and raised ValueError on no match even
    # though the docstring promises None.
    for path in zf.namelist():
        if os.path.basename(path) == file_to_find:
            return path
    return None
def watchpoint_id(report):
    """Extract the integer id from a watchpoint report name.

    Args:
        report (str): report name, e.g. 'watchpointReport12'
    Returns:
        int: the watchpoint number
    Raises:
        RuntimeError: if the name is not a watchpoint report
    """
    m = re.search(_WATCHPOINT_REPORT_NAME + r'(\d+)', report)
    if not m:
        # format the name into the message; the original passed two args,
        # making the report a separate exception arg rather than message text
        raise RuntimeError('invalid watchpoint report name: {}'.format(report))
    return int(m.group(1))
def _escape(v): return re.sub("[\"'()]", '', str(v)) def _plot_range(report, axis): half_size = float(report['{}Size'.format(axis)]) / 2.0 midpoint = float(report['{}Offset'.format(axis)]) return [midpoint - half_size, midpoint + half_size]