# -*- coding: utf-8 -*-
"""Code variables.
:copyright: Copyright (c) 2020 RadiaSoft LLC. All Rights Reserved.
:license: http://www.apache.org/licenses/LICENSE-2.0.html
"""
from __future__ import absolute_import, division, print_function

import ast
import functools
import inspect
import math
import operator
import re

from pykern.pkcollections import PKDict
from pykern.pkdebug import pkdc, pkdlog, pkdp

from sirepo.template import lattice
from sirepo.template import template_common
class CodeVar:
    """Named code variables whose values may be infix expressions.

    ``variables`` maps each canonicalized name to its value as given and
    ``postfix_variables`` maps the same names to the value converted by
    ``infix_to_postfix``. Evaluation is delegated to ``evaluator``, which
    must provide ``eval_var(expr, depends, variables)``.
    """

    # ast operator node type -> RPN token emitted by __parse_expr_infix
    _INFIX_TO_RPN = PKDict(
        {
            ast.Add: "+",
            ast.Div: "/",
            ast.Invert: "!",
            ast.Mult: "*",
            ast.Not: "!",
            ast.Pow: "pow",
            ast.Sub: "-",
            ast.USub: "chs",
        }
    )

    def __init__(self, variables, evaluator, case_insensitive=False):
        """Index ``variables`` by name and precompute their postfix form.

        Args:
            variables (iterable): dict-like items with a ``"name"`` key and an
                optional ``"value"`` (default 0); None items and items whose
                name is None are skipped
            evaluator: object providing ``eval_var(expr, depends, variables)``
            case_insensitive (bool): lower-case names and string values
        """
        # set first: __variables_by_name calls canonicalize
        self.case_insensitive = case_insensitive
        self.variables = self.__variables_by_name(variables)
        self.postfix_variables = self.__variables_to_postfix(self.variables)
        self.evaluator = evaluator

    def canonicalize(self, expr):
        """Return ``expr`` lower-cased when case-insensitive, else unchanged."""
        if self.case_insensitive:
            return expr.lower()
        return expr

    def compute_cache(self, data, schema):
        """Evaluate the RPN values in ``data`` and store them on ``data.models.rpnCache``.

        Model fields are evaluated via a CodeVarIterator, then every variable
        is added by name.

        Returns:
            the iteration ``result`` mapping, or None if ``data`` has no models
        """
        if "models" not in data:
            return None
        it = CodeVarIterator(self, data, schema)
        cache = lattice.LatticeUtil(data, schema).iterate_models(it).result
        for name, value in self.variables.items():
            it.add_to_cache(name, value)
        data.models.rpnCache = cache
        return cache

    def eval_var(self, expr):
        """Evaluate ``expr`` with the configured evaluator.

        Returns:
            tuple: (value, error). Values that are not variable expressions
            are returned unchanged with a None error.
        """
        if not self.is_var_value(expr):
            return expr, None
        expr = self.infix_to_postfix(self.canonicalize(expr))
        return self.evaluator.eval_var(
            expr,
            self.get_expr_dependencies(expr),
            self.postfix_variables,
        )

    def eval_var_with_assert(self, expr):
        """Evaluate ``expr`` and return the result, as a float when possible.

        Raises:
            AssertionError: if evaluation reports an error
        """
        v, err = self.eval_var(expr)
        # explicit raise rather than ``assert`` so the check survives python -O
        if err:
            raise AssertionError(f"expr={expr} err={err}")
        try:
            return float(v)
        except (TypeError, ValueError):
            # non-numeric results (strings, lists, None) are returned as-is
            return v

    def get_expr_dependencies(self, expr, depends=None, visited=None):
        """Return the variables that ``expr`` depends on, dependencies first.

        ``expr`` must be in postfix format. ``depends`` and ``visited`` are
        accumulators used by the recursion.
        """
        if depends is None:
            depends = []
            visited = {}
        elif visited is None:
            visited = {}
        if self.is_var_value(expr):
            expr = self.canonicalize(expr)
        for v in str(expr).split(" "):
            if v not in self.postfix_variables or v in depends:
                continue
            if v in visited:
                # v is still being resolved further up the recursion, so this
                # is a circular reference: skip it but keep scanning the
                # remaining tokens of this expression
                continue
            visited[v] = True
            self.get_expr_dependencies(
                self.postfix_variables[v],
                depends,
                visited,
            )
            depends.append(v)
        return depends

    def generate_variables(self, variable_formatter, postfix=False):
        """Concatenate ``variable_formatter`` output for every variable.

        Variables are visited in sorted name order, and each variable's
        dependencies are formatted before the variable itself. The formatter
        is called as ``variable_formatter(name, variables, visited)`` with a
        ``visited`` PKDict shared across all calls.
        """
        res = []
        visited = PKDict()
        variables = self.postfix_variables if postfix else self.variables
        for name in sorted(variables):
            for dependency in self.get_expr_dependencies(
                self.postfix_variables[name],
            ):
                res.append(variable_formatter(dependency, variables, visited))
            res.append(variable_formatter(name, variables, visited))
        return "".join(res)

    def recompute_cache(self, cache):
        """Re-evaluate every key of ``cache`` in place; keys with errors keep their old value."""
        for k in cache:
            v, err = self.eval_var(k)
            if not err:
                cache[k] = v

    def stateful_compute_rpn_value(self, data, schema, **kwargs):
        """Evaluate ``data.value`` and set ``data.result`` or ``data.error``."""
        v, err = self.eval_var(data.value)
        if err:
            data.error = err
        else:
            data.result = v
        return data

    def stateful_compute_recompute_rpn_cache_values(self, data, schema, **kwargs):
        """Recompute ``data.cache`` in place and return ``data``."""
        self.recompute_cache(data.cache)
        return data

    def stateful_compute_validate_rpn_delete(self, data, schema, **kwargs):
        """Set ``data.error`` if variable ``data.name`` is in use in the stored simulation."""
        from sirepo import simulation_db

        model_data = simulation_db.read_json(
            simulation_db.sim_data_file(
                data.simulationType,
                data.simulationId,
            )
        )
        data.error = self.validate_var_delete(
            data.name,
            model_data,
            schema,
        )
        return data

    def validate_var_delete(self, name, data, schema):
        """Return an error message if ``name`` is referenced, else None.

        Other variables are checked first, then lattice model fields.
        """
        search = self.canonicalize(name)
        in_use = []
        for k, value in self.postfix_variables.items():
            if k == search:
                continue
            for v in str(value).split(" "):
                if v == search:
                    in_use.append(k)
                    break
        if in_use:
            return '"{}" is in use in variable(s): {}'.format(
                name,
                ", ".join(in_use),
            )
        in_use = (
            lattice.LatticeUtil(data, schema)
            .iterate_models(
                CodeVarDeleteIterator(self, search),
            )
            .result
        )
        if in_use:
            return '"{}" is in use in element(s): {}'.format(
                name,
                ", ".join(in_use),
            )
        return None

    @classmethod
    def infix_to_postfix(cls, expr):
        """Convert an infix expression to a space-separated postfix string.

        Non-variable values are converted with ``float``. On any failure the
        (possibly ``^``-substituted) input is returned instead of raising.
        """
        try:
            if cls.is_var_value(expr):
                # "^" is accepted as power; the Python parser needs "**"
                expr = re.sub(r"\^", "**", expr)
                rpn = cls.__parse_expr_infix(expr)
                expr = rpn
            else:
                expr = float(expr)
        except Exception:
            # deliberate best-effort: unparseable input is passed through
            pass
        return expr

    @classmethod
    def is_var_value(cls, value):
        """Return True if ``value`` is truthy and not matched by ``template_common.NUMERIC_RE``."""
        if value:
            # is it a single value in numeric format?
            if template_common.NUMERIC_RE.search(str(value)):
                return False
            return True
        return False

    @classmethod
    def __parse_expr_infix(cls, expr):
        """Use Python parser (ast) and return depth first (RPN) tree"""

        # https://bitbucket.org/takluyver/greentreesnakes/src/587ad72894bc7595bc30e33affaa238ac32f0740/astpp.py?at=default&fileviewer=file-view-default
        def _do(n):
            # http://greentreesnakes.readthedocs.io/en/latest/nodes.html
            # Literals are ast.Constant since Python 3.8; ast.Str and ast.Num
            # are deprecated aliases that were removed in Python 3.14.
            if isinstance(n, ast.Constant):
                if isinstance(n.value, str):
                    # NOTE(review): this rejects strings that contain no quote
                    # characters, which looks inverted for a value that is then
                    # wrapped in double quotes -- confirm intent before changing.
                    if re.search(r'^[^\'"]*$', n.value):
                        raise ValueError("{}: invalid string".format(n.value))
                    return ['"{}"'.format(n.value)]
                # bool is an int subclass but was never accepted as a number
                if isinstance(n.value, (int, float, complex)) and not isinstance(
                    n.value, bool
                ):
                    return [str(n.value)]
                raise ValueError("invalid node: {}".format(ast.dump(n)))
            if isinstance(n, ast.Name):
                return [str(n.id)]
            if isinstance(n, ast.Expression):
                return _do(n.body)
            if isinstance(n, ast.Call):
                res = []
                for x in n.args:
                    res.extend(_do(x))
                return res + [n.func.id]
            if isinstance(n, ast.BinOp):
                return _do(n.left) + _do(n.right) + _do(n.op)
            if isinstance(n, ast.UAdd):
                return []
            if isinstance(n, ast.UnaryOp):
                return _do(n.operand) + _do(n.op)
            if isinstance(n, ast.IfExp):
                return _do(n.test) + ["?"] + _do(n.body) + [":"] + _do(n.orelse) + ["$"]
            # convert an attribute-like value, ex. l.MQ, into a string "l.MQ"
            if isinstance(n, ast.Attribute):
                return ["{}.{}".format(_do(n.value)[0], n.attr)]
            x = cls._INFIX_TO_RPN.get(type(n), None)
            if x:
                return [x]
            raise ValueError("invalid node: {}".format(ast.dump(n)))

        tree = ast.parse(expr, filename="eval", mode="eval")
        assert isinstance(tree, ast.Expression), "{}: must be an expression".format(
            tree
        )
        return " ".join(_do(tree))

    def __variables_by_name(self, variables):
        """Return PKDict of canonicalized name -> value (default 0)."""
        res = PKDict()
        for v in variables:
            # work-around for #4935 skip invalid variables
            if v is None or v["name"] is None:
                continue
            n = self.canonicalize(v["name"])
            value = v.get("value", 0)
            if self.case_insensitive and isinstance(value, str):
                value = value.lower()
            res[n] = value
        return res

    def __variables_to_postfix(self, variables):
        """Return PKDict of name -> ``infix_to_postfix(value)``."""
        res = PKDict()
        for name in variables:
            res[name] = self.infix_to_postfix(variables[name])
        return res
class CodeVarIterator(lattice.ModelIterator):
    """Model iterator that evaluates RPN values into ``result``.

    ``result`` maps canonicalized variable expressions, plus any names passed
    to ``add_to_cache``, to their evaluated values. Beamline fields and
    beamline positions are evaluated eagerly by the constructor.
    """

    def __init__(self, code_var, data, schema):
        self.result = PKDict()
        self.code_var = code_var
        self.__add_beamline_fields(data, schema)

    def add_to_cache(self, name, value):
        """Evaluate ``value`` and store the result under ``name`` unless evaluation failed."""
        evaluated = self.__add_value(value)
        if evaluated is None:
            return
        self.result[name] = evaluated

    def field(self, model, field_schema, field):
        """Evaluate ``model[field]`` when its schema type is ``RPNValue``."""
        raw = model[field]
        if field_schema[1] != "RPNValue":
            return
        self.__add_value(raw)

    def __add_beamline_fields(self, data, schema):
        """Evaluate beamline fields described by the beamline schema and
        every variable expression found in beamline positions."""
        if not schema.get("model"):
            return
        beamline_schema = schema.model.get("beamline")
        if not beamline_schema:
            return
        for beamline in data.models.beamlines:
            if "positions" not in beamline:
                continue
            for key, key_schema in beamline_schema.items():
                if key in beamline and beamline[key]:
                    self.field(beamline, key_schema, key)
            for position in beamline.positions:
                for key in position:
                    expr = position[key]
                    if expr and self.code_var.is_var_value(expr):
                        self.add_to_cache(expr, expr)

    def __add_value(self, value):
        """Return the evaluated ``value``.

        Variable expressions are canonicalized and memoized in ``result``;
        None is returned if evaluation reports an error. Other values are
        converted with ``float`` (falsy values give 0) and not stored.
        """
        code_var = self.code_var
        if not code_var.is_var_value(value):
            return float(value) if value else 0
        key = code_var.canonicalize(value)
        if key not in self.result:
            computed, err = code_var.eval_var(key)
            if err:
                return None
            self.result[key] = computed
        return self.result[key]
class CodeVarDeleteIterator(lattice.ModelIterator):
    """Collect the model fields whose RPN expression references ``name``.

    ``result`` gets one entry per matching field: ``"<_type>.<field>"`` for
    commands and ``"<type> <name>.<field>"`` for other models.
    """

    def __init__(self, code_var, name):
        self.result = []
        self.code_var = code_var
        # compared against canonicalized postfix tokens, so callers pass a
        # canonicalized name (CodeVar.validate_var_delete does)
        self.name = name

    def field(self, model, field_schema, field):
        """Record ``model``/``field`` if its RPNValue expression uses ``self.name``."""
        if field_schema[1] == "RPNValue" and self.code_var.is_var_value(model[field]):
            expr = self.code_var.canonicalize(
                self.code_var.infix_to_postfix(str(model[field]))
            )
            for v in str(expr).split(" "):
                if v == self.name:
                    if lattice.LatticeUtil.is_command(model):
                        self.result.append("{}.{}".format(model._type, field))
                    else:
                        self.result.append(
                            "{} {}.{}".format(model.type, model.name, field),
                        )
                    # one entry per field, even if the name appears repeatedly
                    # (e.g. "x x *"), so the "in use" message has no duplicates
                    break
class PurePythonEval:
    """Evaluate space-separated postfix (RPN) expressions in pure Python."""

    # RPN token -> callable; the number of operands popped from the stack is
    # the callable's parameter count (see _get_arg_count)
    _OPS = PKDict(
        {
            "*": operator.mul,
            "+": operator.add,
            "-": operator.sub,
            "/": operator.truediv,
            "abs": operator.abs,
            "acos": math.acos,
            "asin": math.asin,
            "atan": math.atan,
            "chs": operator.neg,
            "cos": math.cos,
            "pow": operator.pow,
            "sin": math.sin,
            "sqrt": math.sqrt,
            "tan": math.tan,
        }
    )

    def __init__(self, constants=None):
        """
        Args:
            constants (dict): optional token -> value mapping consulted for
                tokens that are not variables
        """
        # a mapping, because __eval_python_stack indexes it by token
        self.constants = constants or {}

    def eval_var(self, expr, depends, variables):
        """Evaluate postfix ``expr``.

        Args:
            expr: postfix expression (or list of expressions)
            depends (list): variable names to evaluate first, in order
            variables (dict): variable name -> postfix value; not modified
        Returns:
            tuple: (value, error) where error is None on success
        """
        variables = variables.copy()
        for d in depends:
            # recurse eval_var, but with empty dependencies; earlier entries
            # of depends have already been replaced by their values. Called
            # on PurePythonEval explicitly, so subclass overrides are not used.
            v, err = PurePythonEval.eval_var(
                self,
                self.__eval_indexed_variable(variables[d], variables),
                {},
                variables,
            )
            if err:
                return None, err
            variables[d] = v
        return self.__eval_python_stack(
            self.__eval_indexed_variable(expr, variables), variables
        )

    @classmethod
    def postfix_to_infix(cls, expr):
        """Convert a postfix expression to infix.

        ``expr`` is returned unchanged if it is not a variable expression or
        has too few operands for its operators.
        """
        if not CodeVar.is_var_value(expr):
            return expr

        def _strip_parens(v):
            # drop one pair of enclosing parentheses
            return re.sub(r"^\((.*)\)$", r"\1", v)

        stack = []
        for v in str(expr).split(" "):
            if v not in cls._OPS:
                stack.append(v)
                continue
            try:
                op = cls._OPS[v]
                args = list(
                    reversed([stack.pop() for _ in range(_get_arg_count(op))])
                )
            except IndexError:
                # not parseable, return original expression
                return expr
            if v == "chs":
                stack.append("-{}".format(args[0]))
            elif re.search(r"\w", v):
                # named function, e.g. sin(x) or pow(a,b)
                stack.append(
                    "{}({})".format(v, ",".join([_strip_parens(a) for a in args]))
                )
            else:
                stack.append("({} {} {})".format(args[0], v, args[1]))
        return _strip_parens(stack[-1])

    def __eval_indexed_variable(self, expr, variables):
        """Replace ``name[i]`` with element ``i`` of ``variables[name]``, then
        convert to postfix. Lists are processed element-wise."""
        if isinstance(expr, list):
            return [self.__eval_indexed_variable(e, variables) for e in expr]
        if not variables:
            # an empty alternation would match "" and look up variables[""]
            return CodeVar.infix_to_postfix(expr)
        # escape names (they may contain regex metacharacters such as ".") and
        # require that a name not continue a longer identifier, so "ab[0]" is
        # never resolved against a variable named "b"
        names = "|".join(re.escape(k) for k in variables)
        r = rf"(.*)(?<![\w.])({names})\s*\[\s*(\d+)\s*\]"
        if not re.match(r, str(expr)):
            return CodeVar.infix_to_postfix(expr)
        return self.__eval_indexed_variable(
            re.sub(
                r,
                lambda m: m.group(1) + str(variables[m.group(2)][int(m.group(3))]),
                expr,
            ),
            variables,
        )

    def __eval_python_stack(self, expr, variables):
        """Evaluate a postfix ``expr`` (or list of them) with a value stack.

        Returns:
            tuple: (value, error); error is a message string on failure
        """
        if not CodeVar.is_var_value(expr):
            return expr, None
        if isinstance(expr, list):
            evs = []
            # loop instead of map so we can fail out on the first error
            for e in expr:
                ev = self.__eval_python_stack(CodeVar.infix_to_postfix(e), variables)
                if ev[1] is not None:
                    return None, ev[1]
                evs.append(ev[0])
            return evs, None
        stack = []
        for v in str(expr).split(" "):
            if v in variables:
                stack.append(variables[v])
            elif v in self.constants:
                stack.append(self.constants[v])
            elif v in self._OPS:
                try:
                    op = self._OPS[v]
                    args = reversed(
                        [float(stack.pop()) for _ in range(_get_arg_count(op))],
                    )
                    stack.append(op(*args))
                except IndexError:
                    return None, "too few items on stack"
                except ZeroDivisionError:
                    return None, "division by zero"
                except (OverflowError, TypeError, ValueError) as e:
                    # e.g. math domain errors (sqrt(-1)), overflow in pow, or
                    # an operand that is not a number: report, don't crash
                    return None, "{}: {}".format(v, e)
            else:
                try:
                    stack.append(float(v))
                except ValueError:
                    return None, "unknown token: {}".format(v)
        return stack[-1], None
@functools.lru_cache(maxsize=128)
def _get_arg_count(fn):
    """Return the number of parameters in ``fn``'s signature.

    Used as the operand count (arity) of RPN operators. Cached because it is
    called for every operator token on every evaluation, and
    ``inspect.signature`` is comparatively expensive; the callers pass
    functions from a fixed operator table, so the cache stays small.
    """
    return len(inspect.signature(fn).parameters)