# -*- coding: utf-8 -*-
"""Lattice utilities.
:copyright: Copyright (c) 2019 RadiaSoft LLC. All Rights Reserved.
:license: http://www.apache.org/licenses/LICENSE-2.0.html
"""
from __future__ import absolute_import, division, print_function
from pykern.pkcollections import PKDict
from pykern.pkdebug import pkdc, pkdlog, pkdp
from sirepo.template.line_parser import LineParser
import re
class ModelIterator(object):
    """Base class for model iterators; every hook is a no-op.

    ``LatticeUtil.iterate_models()`` drives an iterator like this::

        it.start(model)
        for each field of model (sorted) that appears in the model's schema:
            it.field(model, field_schema, field)
        it.end(model)

    Subclasses override only the hooks they need.
    """

    def start(self, model):
        """Hook called before the fields of ``model`` are visited."""

    def field(self, model, field_schema, field):
        """Hook called once per schema field of ``model``."""

    def end(self, model):
        """Hook called after all fields of ``model`` have been visited."""
class ElementIterator(ModelIterator):
    """Collect, per model, the formatted fields whose values are not at their default.

    After iteration ``result`` holds one ``[model, fields]`` pair per visited
    model, where ``fields`` contains the truthy values returned by
    ``formatter(iterator, model, field, field_type)``.
    """

    IS_DISABLED_FIELD = "isDisabled"

    def __init__(self, filename_map, formatter):
        self.formatter = formatter
        self.filename_map = filename_map
        self.result = []

    def end(self, model):
        """Store the fields collected for ``model``."""
        self.result.append([model, self.fields])

    def field(self, model, field_schema, field):
        """Format ``field`` unless it is disabled/_super, ignored, or at its default."""
        if field in (self.IS_DISABLED_FIELD, "_super"):
            return
        # counted before the ignore/default checks, so it reflects the
        # position among all non-disabled fields
        self.field_index += 1
        skip = self.is_ignore_field(field) or self.__is_default(
            model, field_schema, field
        )
        if skip:
            return
        formatted = self.formatter(self, model, field, field_schema[1])
        if formatted:
            self.fields.append(formatted)

    def is_ignore_field(self, field):
        """Subclass hook: return truthy to skip ``field``; nothing is skipped here."""
        return False

    def start(self, model):
        """Reset the per-model field counter and collected fields."""
        self.fields, self.field_index = [], 0

    def __is_default(self, model, field_schema, field):
        """Return True if ``model[field]`` should be treated as its schema default."""
        from sirepo.template.code_variable import CodeVar

        # schema entries without a default value are never emitted
        if len(field_schema) < 3:
            return True
        default_value = field_schema[2]
        value = model[field]
        if value is None:
            return True
        if default_value is None:
            return False
        if value == default_value:
            return True
        # numeric RPN literals compare numerically; expressions compare as text
        if field_schema[1] == "RPNValue" and value and not CodeVar.is_var_value(value):
            return float(value) == default_value
        return str(value) == str(default_value)
class LatticeIterator(ElementIterator):
    """Iterate all lattice elements/fields which are not set to the default value."""

    # fields this iterator always skips, in addition to the suffix rule below
    _IGNORED_FIELDS = frozenset(("name", "type", "_id"))

    def is_ignore_field(self, field):
        """Return True for name/type/_id and for fields ending in X, Y or File.

        Always returns a bool (previously a ``re.Match``/``None`` could leak out).
        """
        return field in self._IGNORED_FIELDS or bool(re.search("(X|Y|File)$", field))
class UpdateIterator(ModelIterator):
    """Call ``update_func(model, field)`` for every RPNValue field visited."""

    def __init__(self, update_func):
        self.update_func = update_func

    def field(self, model, field_schema, field):
        """Forward RPNValue fields to the update callback; ignore all others."""
        if field_schema[1] != "RPNValue":
            return
        self.update_func(model, field)
class LatticeParser(object):
    """Parse ``;``-terminated lattice text into sirepo simulation data.

    Variable assignments become ``rpnVariables``, element definitions become
    ``elements``, ``LINE`` statements become ``beamlines``, known commands
    become ``commands`` and ``SEQUENCE`` blocks become ``sequences``.
    Code-specific subclasses supply ``sim_data`` and are expected to
    override ``ignore_commands``.
    """

    COMMAND_PREFIX = "command_"
    # lowercase command names skipped without an "unknown cmd" log; subclasses
    # are expected to override this (it was previously read but never defined)
    ignore_commands = frozenset()

    def __init__(self, sim_data):
        self.sim_data = sim_data
        self.schema = sim_data.schema()

    def parse_file(self, lattice_text):
        """Parse ``lattice_text`` and return the populated simulation data."""
        from sirepo import simulation_db

        self.data = simulation_db.default_data(self.sim_data.sim_type())
        self.parser = LineParser(100)
        # collected as {name: value}; converted to a list at the end of __parse_lines
        self.data.models.rpnVariables = {}
        self.data.models.sequences = []
        # None | sequence | track | match | edit
        self.container = None
        self.elements_by_name = PKDict()
        lines = lattice_text.replace("\r", "").split("\n")
        self.__parse_lines(lines)
        return self.data

    def _add_variables_for_lattice_references(self):
        """Turn "x->y" lattice references into variables named "x.y".

        Rewrites "->" as "." in variable and element field expressions, then,
        for every dotted variable token "a.b" found, appends an rpnVariable
        "a.b" whose value is field "b" of the element named "a"
        (element names compared case-insensitively).
        """
        from sirepo.template.code_variable import CodeVar

        def _fix_value(value, names):
            value = re.sub(r"\-\>", ".", value)
            expr = CodeVar.infix_to_postfix(value.lower())
            for v in expr.split(" "):
                if CodeVar.is_var_value(v):
                    m = re.match(r"^(.*?)\.(.*)", v)
                    if m:
                        names[v] = [m.group(1), m.group(2)]
            return value

        names = {}
        for v in self.data.models.rpnVariables:
            if CodeVar.is_var_value(v.value):
                v.value = _fix_value(v.value, names)
        for el in self.data.models.elements:
            for f in el:
                v = el[f]
                if CodeVar.is_var_value(v):
                    el[f] = _fix_value(v, names)
        for name in names:
            for el in self.data.models.elements:
                if el.name.lower() == names[name][0]:
                    f = names[name][1]
                    if f in el:
                        self.data.models.rpnVariables.append(
                            PKDict(
                                name=name,
                                value=el[f],
                            )
                        )

    def _code_variables_to_float(self, code_var):
        """Convert non-expression rpnVariable values and RPNValue fields to float."""

        def _float_update(model, field):
            if not code_var.is_var_value(model[field]) and type(model[field]) != float:
                model[field] = float(model[field])

        for v in self.data.models.rpnVariables:
            if not code_var.is_var_value(v.value):
                v.value = float(v.value)
        LatticeUtil(self.data, self.schema).iterate_models(
            UpdateIterator(_float_update)
        )

    def _compute_drifts(self, code_var):
        """Return {formatted length: _id} for the first DRIFT of each length."""
        drifts = PKDict()
        for el in self.data.models.elements:
            if el.type == "DRIFT":
                length = self._format_length(self._eval_var(code_var, el.l))
                if length not in drifts:
                    drifts[length] = el._id
        return drifts

    def _downcase_variables(self, code_var):
        """Lowercase variable names and expression values everywhere."""

        def _downcase_update(model, field):
            if code_var.is_var_value(model[field]):
                model[field] = model[field].lower()

        for v in self.data.models.rpnVariables:
            v.name = v.name.lower()
            if code_var.is_var_value(v.value):
                v.value = v.value.lower()
        LatticeUtil(self.data, self.schema).iterate_models(
            UpdateIterator(_downcase_update)
        )

    def _eval_var(self, code_var, value):
        return code_var.eval_var_with_assert(value)

    @classmethod
    def _format_command(cls, name):
        """Return the schema model name for command ``name``."""
        return f"{cls.COMMAND_PREFIX}{name}"

    def _format_length(self, length):
        """Format a length as a compact key, e.g. 1.5 -> "1.5", 0.5 -> "5.0e-01".

        Uses 9 significant digits in lowercase exponent notation, drops
        trailing mantissa zeros (keeping at least one decimal) and omits an
        "e+00" exponent. The exponent is always lowercase so the regexes here
        and in _get_drift apply even when no trailing zeros are removed
        (previously "E" was left in place for such values).
        """
        res = "{:.8e}".format(length)
        res = re.sub(r"(\.\d+?)(0+)e", r"\1e", res)
        res = re.sub(r"e\+00$", "", res)
        return res

    def _get_drift(self, drifts, length, allow_negative_drift=False):
        """Return the _id of a DRIFT with ``length``, creating one if needed.

        Returns None for zero length, and for negative length unless
        ``allow_negative_drift`` is set.
        """
        if length == 0:
            return None
        if length < 0 and not allow_negative_drift:
            pkdlog("warning: negative drift: {}", length)
            return None
        length = self._format_length(length)
        if length not in drifts:
            name = "D{}".format(length)
            name = re.sub(r"\+", "", name)
            name = re.sub(r"e?-", "R", name)
            drift = PKDict(
                _id=self.parser.next_id(),
                l=float(length),
                name=name,
                type="DRIFT",
            )
            self.sim_data.update_model_defaults(drift, "DRIFT")
            self.data.models.elements.append(drift)
            drifts[length] = drift._id
        return drifts[length]

    def _set_default_beamline(self, cmd_type, field1, field2=None):
        """Set the active and visualization beamline ids.

        Uses the first ``cmd_type`` command whose ``field1`` (or, when absent,
        ``field2``) names a known element; otherwise the last parsed beamline,
        or None when there are no beamlines.
        """
        name = None
        for cmd in self.data.models.commands:
            if cmd._type != cmd_type:
                continue
            if field1 in cmd:
                candidate = cmd.get(field1)
            elif field2 and field2 in cmd:
                candidate = cmd.get(field2)
            else:
                candidate = None
            if candidate and candidate.upper() in self.elements_by_name:
                name = candidate.upper()
                break
        beamline_id = None
        if name:
            beamline_id = self.elements_by_name[name].id
        elif self.data.models.beamlines:
            beamline_id = self.data.models.beamlines[-1].id
        self.data.models.simulation.activeBeamlineId = (
            self.data.models.simulation.visualizationBeamlineId
        ) = beamline_id

    def __model_name(self, cmd):
        """Resolve ``cmd`` to a schema model name by following element parents."""
        res = cmd
        while res not in self.schema.model:
            parent = self.elements_by_name[res]
            assert parent and parent.type
            res = parent.type
        return res

    def __parse_beamline(self, label, values):
        """Parse ``LABEL: LINE=(a, n*b, -c, ...)`` into a beamline.

        ``name=value`` entries become beamline attributes, ``n*item`` repeats
        an item and a leading "-" stores the item id negated (reversed).
        """
        assert label
        # separate beamline attributes from the item list
        attrs = PKDict()
        items = []
        values[0] = re.sub(r"^.*?=\s*\(\s*", "", values[0])
        for v in values:
            if "=" in v:
                m = re.match(r"^\s*([\w.]+)\s*:?=\s*(.+?)\s*$", v)
                if m:
                    attrs[m.group(1).lower()] = m.group(2)
            else:
                items.append(v)
        if items:
            items[-1] = re.sub(r"\s*\)$", "", items[-1])
        res = PKDict(
            name=label,
            id=self.parser.next_id(),
            items=[],
        ).pkupdate(attrs)
        self.sim_data.update_model_defaults(res, "beamline")
        for v in items:
            v = self.__remove_quotes(v)
            if not v:
                # an empty line such as "LINE=()" leaves a blank item
                continue
            count = 1
            m = re.match(r"^(\d+)\s*\*\s*\(?([\w.]+)\)?$", v)
            if m:
                count = int(m.group(1))
                v = m.group(2)
            reverse = v.startswith("-")
            if reverse:
                v = v[1:]
            el = self.elements_by_name.get(v.upper())
            assert el, "line: {}, element not found: {}".format(label, v)
            el_id = el._id if "_id" in el else el.id
            res["items"].extend([-el_id if reverse else el_id] * count)
        assert (
            label.upper() not in self.elements_by_name
        ), "duplicate beamline: {}".format(label)
        self.elements_by_name[label.upper()] = res
        self.data.models.beamlines.append(res)

    def __parse_element(self, cmd, label, values):
        """Parse an element definition, a sequence member, or an attribute update.

        Inside a sequence the element's "at" position is added to the
        container items; an unlabeled member uses ``cmd`` as its label and an
        already-known label is referenced instead of copied. Outside a
        sequence, an unlabeled statement (e.g. "QF, K1=0.5") updates the
        existing element in place.
        """
        res = self.__parse_fields(
            self.__model_name(cmd),
            values,
            PKDict(
                name=label,
                _id=self.parser.next_id(),
            ),
        )
        res.type = cmd
        if self.container:
            assert "at" in res, 'sequence element missing "at": {}'.format(values)
            at = res.at
            del res["at"]
            if not label:
                label = cmd
            if label.upper() in self.elements_by_name:
                self.container["items"].append(
                    [self.elements_by_name[label.upper()]._id, at]
                )
                return
            if cmd not in self.schema.model:
                assert self.elements_by_name[cmd]
                assert len(res) >= 3
            self.container["items"].append([res._id, at])
        assert "at" not in res
        # copy in superclass values: inherited fields first, local values override
        while cmd not in self.schema.model:
            parent = self.elements_by_name[cmd]
            assert parent and parent.type
            res = PKDict(list(parent.items()) + list(res.items()))
            res.type = parent.type
            cmd = parent.type
        self.sim_data.update_model_defaults(res, res.type)
        if not label:
            label = values[0].upper()
            assert (
                label in self.elements_by_name
            ), "no element for label: {}: {}".format(label, values)
            existing = self.elements_by_name[label]
            # res carries name=None and a freshly allocated _id; keep the
            # existing identity so beamlines referencing this element stay valid
            res.name = existing.name
            res._id = existing._id
            existing.update(res)
        else:
            # duplicate labels are allowed; the latest definition wins the lookup
            self.elements_by_name[label.upper()] = res
            self.data.models.elements.append(res)

    def __parse_fields(self, cmd, values, res):
        """Parse ``values[1:]`` into ``res`` and return it.

        Accepts ``field=value`` / ``field:=value`` assignments and bare flags
        ("flag" -> "1", "!flag" or "-flag" -> "0"). Field names are matched
        against the ``cmd`` model schema case-insensitively; non-schema fields
        are skipped, except positional fields and extra "column" values.
        """
        model_schema = self.schema.model.get(cmd)
        prev_field = None
        for value in values[1:]:
            m = re.match(r"^\s*([\w.]+)\s*:?=\s*(.+?)\s*$", value)
            if m:
                f, v = m.group(1, 2)
                f = f.lower()
                # allow native fields named "type" which conflict with the sirepo element type
                if f == "type" and model_schema and f not in model_schema:
                    f = f"{cmd.lower()}_type"
                # skip non-schema fields, with the exception of positional fields "at" and "elemedge"
                if (
                    model_schema
                    and f not in model_schema
                    and f not in ("at", "elemedge", "z", "origin")
                ):
                    continue
                if f != "name":
                    # some commands may have a "name" field
                    assert f not in res, "field already defined: {}, values: {}".format(
                        f, values
                    )
                res[f] = self.__remove_quotes(v)
                prev_field = f
                continue
            # no assignment, maybe a boolean value
            m = re.match(r"^\s*(!|-)?\s*([\w.]+)\s*$", value)
            assert m, "failed to parse field assignment: {}".format(value)
            negated, flag = m.group(1, 2)
            # lowercase before the schema check, as for assigned fields above
            if model_schema and flag.lower() not in model_schema:
                # special case for "column" field, may contain multiple comma separated values
                if prev_field == "column":
                    res[prev_field] += f", {flag}"
                continue
            res[flag.lower()] = "0" if negated else "1"
        return res

    def __parse_lines(self, lines):
        """Strip comments, join continuation lines and parse each ``;`` statement."""
        prev_line = ""
        in_comment = False
        for line in lines:
            self.parser.increment_line_number()
            line = re.sub(r"\&\s*$", "", line)
            # strip comments
            line = line.strip()
            line = re.sub(r"(.*?)(!|//).*$", r"\1", line)
            line = re.sub(r"\/\*.*?\*\/", "", line)
            # special case, some commands often missing a comma
            line = re.sub(
                r"^\s*(title|exec|call)\s+([^,])", r"\1, \2", line, flags=re.IGNORECASE
            )
            if in_comment:
                if not re.search(r"\*\/", line):
                    continue
                line = re.sub(r"^.*\*\/", "", line)
                in_comment = False
            if re.search(r"\/\*", line):
                # a block comment opens here; keep the text before it
                # (previously the whole line was discarded)
                line = re.sub(r"\/\*.*$", "", line)
                in_comment = True
            if not line:
                continue
            assert not re.search(
                r"^\s*(if|while)\s*\(", line, re.IGNORECASE
            ), "Remove conditional if() or while() statements from input file before import"
            while ";" in line:
                m = re.match(r"^(.*?);(.*)$", line)
                assert m, "parse ; failed: {}".format(line)
                item = (prev_line + " " + m.group(1)).strip()
                self.__parse_values(self.__split_values(item))
                line = m.group(2)
                prev_line = ""
            prev_line += line
        self.data.models["rpnVariables"] = [
            PKDict(name=k, value=v) for k, v in self.data.models.rpnVariables.items()
        ]

    def __parse_statement(self, cmd, label, values):
        """Dispatch a statement to element, container, command or line parsing."""
        if cmd.upper() in self.schema.model or cmd.upper() in self.elements_by_name:
            self.__parse_element(cmd.upper(), label, values)
            return
        cmd = cmd.lower()
        if self.container and cmd == "end{}".format(self.container.type):
            assert len(values) == 1, "invalid end{}: {}".format(self.container, values)
            self.container = None
            return
        if cmd in ("sequence", "track"):
            self.container = PKDict(
                name=label,
                type=cmd,
                _id=self.parser.next_id(),
            )
            self.__parse_fields(self._format_command(cmd), values, self.container)
            self.container["items"] = []
            if cmd == "sequence":
                self.data.models.sequences.append(self.container)
                return
            # a track block falls through so it is also recorded as a command
        if self._format_command(cmd) in self.schema.model:
            res = PKDict(
                _type=cmd,
                _id=self.parser.next_id(),
                name=label,
            )
            self.__parse_fields(self._format_command(cmd), values, res)
            self.sim_data.update_model_defaults(
                res, LatticeUtil.model_name_for_data(res)
            )
            self.data.models.commands.append(res)
        elif cmd == "line":
            self.__parse_beamline(label, values)
        elif cmd == "title":
            if len(values) > 1:
                self.data.models.simulation.name = self.__remove_quotes(values[1])
        elif cmd not in self.ignore_commands:
            assert (
                cmd != "call"
            ), '"CALL" statement not supported, combine subfiles into one input file before import'
            if not re.search(r"^ptc_", cmd):
                pkdlog("unknown cmd: {}", values)

    def __parse_values(self, values):
        """Handle one split statement: a variable assignment or a (labeled) statement."""
        if not values:
            return
        if (
            (re.search(r"^\s*REAL\s", values[0], re.IGNORECASE) or len(values) == 1)
            and "=" in values[0]
            and not re.search(r"\Wline\s*\:?=\s*\(", values[0].lower())
        ):
            if re.search(r"^\s*(BOOL|STRING)\s", values[0], re.IGNORECASE):
                return
            # a variable assignment; the first assignment of a name is kept
            val = ", ".join(values)
            m = re.match(r".*?([\w.\'\-]+)\s*:?=\s*(.*)$", val)
            assert m, "invalid variable assignment: {}".format(val)
            name = m.group(1)
            v = m.group(2)
            if name not in self.data.models.rpnVariables:
                self.data.models.rpnVariables[name] = v
            return
        if ":" in values[0]:
            m = re.match(r'(".*?")\s*:\s*(".*?")', values[0])
            if not m:
                m = re.match(r'(".*?")\s*:\s*(\w+)', values[0])
            if not m:
                m = re.match(r'([\w.#"\-\/]+)\s*:\s*([\w."]+)', values[0])
            assert m, "label match failed: {}".format(values[0])
            label, cmd = m.group(1, 2)
            label = self.__remove_quotes(label)
            cmd = self.__remove_quotes(cmd)
        else:
            label, cmd = None, values[0]
        self.__parse_statement(cmd, label, values)

    def __remove_quotes(self, value):
        return re.sub(r'[\'"](.*)[\'"]', r"\1", value)

    def __split_values(self, item):
        """Split a statement into comma separated values.

        Quoted strings, ``{...}`` groups and ``f(...)`` calls in assignments
        are kept whole; values split inside unbalanced parentheses are
        rejoined (except within LINE definitions).
        """
        values = []
        while item:
            item = item.strip()
            m = re.match(r'^(".*?"\s*:\s*\w+)\s*,(.*)$', item)
            if m:
                values.append(m.group(1))
                item = m.group(2)
                continue
            m = re.match(
                r'^\s*((?:[\w.\']+\s*:?=\s*)(?:(?:".*?")|(?:\'.*?\')|(?:\{.*?\})|(?:\w+\(.*?\).*?)))(?:,(.*))?$',
                item,
            )
            if m:
                values.append(m.group(1))
                assert item != m.group(2)
                item = m.group(2)
                continue
            m = re.match(r"^\s*(.+?)(?:,(.*))?$", item)
            if m:
                values.append(m.group(1).strip())
                assert item != m.group(2)
                item = m.group(2)
                continue
            assert False, "line parse failed: {}".format(item)
        # try to fix up mismatched parenthesis
        res = []
        for v in values:
            if res:
                prev = res[-1]
                mismatch_count = prev.count("(") - prev.count(")")
                if mismatch_count and not re.search(r"\bline\b", prev, re.IGNORECASE):
                    if mismatch_count == v.count(")") - v.count("("):
                        res[-1] += f", {v}"
                        continue
            res.append(v)
        return res
class LatticeUtil(object):
    """Utility class for generating lattice elements, beamlines and commands."""

    _OUTPUT_NAME_PREFIX = "elementAnimation"
    _FILE_ID_SEP = "-"

    def __init__(self, data, schema):
        self.data = data
        self.schema = schema
        # note: on instances the int attribute max_id shadows the max_id()
        # classmethod; use LatticeUtil.max_id(data) for the classmethod
        self.id_map, self.max_id = self.__build_id_map(data)

    def explode_beamline(self, beamline_id):
        """Flatten a beamline into a list of element ids.

        Nested beamlines are expanded recursively; a nested beamline referenced
        with a negative id has its expansion reversed. Element ids keep their sign.
        """
        res = []
        for bid in self.get_item(beamline_id)["items"]:
            e = self.get_item(abs(bid))
            if self.is_beamline(e):
                r = self.explode_beamline(e.id)
                if bid < 0:
                    r.reverse()
                res += r
            else:
                res.append(bid)
        return res

    @classmethod
    def find_first_command(cls, data, command_type):
        """Return the first command of ``command_type``, or None."""
        return next(
            (m for m in data.models.commands if m._type == command_type), None
        )

    @classmethod
    def file_id(cls, model_id, field_index):
        """Build an output file id "<model_id>-<field_index>"."""
        return f"{model_id}{cls._FILE_ID_SEP}{field_index}"

    @classmethod
    def file_id_from_output_model_name(cls, name):
        return re.sub(cls._OUTPUT_NAME_PREFIX, "", name)

    @classmethod
    def fixup_output_files(cls, data, schema, output_file_iterator):
        """Re-key output models whose file id is no longer valid.

        If new model fields are added to the schema, the output file id may be
        invalid; stale output models are re-keyed by matching their ``xFile``
        against ``output_file_iterator.result`` (presumably {file_id: filename}),
        and unmatched stale output models are removed.
        """
        v = LatticeUtil(data, schema).iterate_models(output_file_iterator).result
        remove_list = []
        add_list = {}
        for m in data.models:
            if not cls.__is_output_model_name(m):
                continue
            if cls.file_id_from_output_model_name(m) in v:
                continue
            file_id = None
            if "xFile" in data.models[m]:
                for k in v:
                    if v[k] == data.models[m].xFile:
                        file_id = k
                        break
            if file_id:
                name = cls.output_model_name(file_id)
                if name not in data.models:
                    add_list[name] = data.models[m]
                    data.models[m].xFileId = file_id
            remove_list.append(m)
        for m in remove_list:
            del data.models[m]
        for m in add_list:
            data.models[m] = add_list[m]

    def get_item(self, item_id):
        return self.id_map[item_id]

    @classmethod
    def get_lattice_id_from_file_id(cls, data, file_id):
        """Return ``use_beamline`` of the command whose _id prefixes ``file_id``."""
        model_id = int(file_id.split(cls._FILE_ID_SEP)[0])
        for c in data.models.commands:
            if c._id == model_id and "use_beamline" in c:
                return c.use_beamline
        return None

    @classmethod
    def has_command(cls, data, command_type):
        return any(cmd._type == command_type for cmd in data.models.commands)

    @classmethod
    def is_beamline(cls, model):
        """Is the model a beamline?"""
        return "_id" not in model and "type" not in model

    @classmethod
    def is_command(cls, model):
        """Is the model a command (as opposed to a lattice element)?"""
        return "_type" in model

    def iterate_models(self, iterator, name=None):
        """Iterate the models in the named container.

        By default the commands and elements containers are iterated. Each
        model's schema fields are visited in sorted order.
        """
        iterator.id_map = self.id_map
        containers = (name,) if name else ("commands", "elements")
        for container in containers:
            for m in self.data.models[container]:
                model_schema = self.schema.model[self.model_name_for_data(m)]
                iterator.start(m)
                for k in sorted(m):
                    if k in model_schema:
                        iterator.field(m, model_schema[k], k)
                iterator.end(m)
        return iterator

    @classmethod
    def max_id(cls, data):
        """Return the largest element/beamline/command id (at least 1)."""
        max_id = 1
        for model_type in "elements", "beamlines", "commands":
            if model_type not in data.models:
                continue
            for m in data.models[model_type]:
                assert "_id" in m or "id" in m, "Missing id: {}".format(m)
                i = m._id if "_id" in m else m.id
                if i > max_id:
                    max_id = i
        return max_id

    @classmethod
    def model_name_for_data(cls, model):
        """Returns the model's schema name."""
        return (
            LatticeParser._format_command(model._type)
            if cls.is_command(model)
            else model.type
        )

    @classmethod
    def output_model_name(cls, file_id):
        return f"{cls._OUTPUT_NAME_PREFIX}{file_id}"

    def render_lattice(
        self,
        fields,
        quote_name=False,
        want_semicolon=False,
        want_name=True,
        want_var_assign=False,
        madx_name=False,
        comment="//",
    ):
        """Render lattice elements (or commands), one per line.

        ``fields`` is an iterator result: ``[[model, [[field, value], ...]], ...]``.
        Disabled models (isDisabled == "1") are prefixed with ``comment``.
        With ``want_var_assign``, RPNValue fields holding a variable expression
        are written as ``field:=value``.
        """
        from sirepo.template.code_variable import CodeVar

        lines = []
        for el in fields:
            model, field_values = el[0], el[1]
            el_type = self.type_for_data(model)
            prefix = ""
            if (
                ElementIterator.IS_DISABLED_FIELD in model
                and model[ElementIterator.IS_DISABLED_FIELD] == "1"
            ):
                prefix = comment + " "
            if want_name:
                name = self.__format_name(model.name, quote_name, madx_name)
                parts = ["{}: {}".format(name, el_type)]
            else:
                parts = ["{}".format(el_type)]
            for f in field_values:
                var_assign = ""
                if want_var_assign:
                    # resolve the schema like iterate_models() does, so commands
                    # use their "command_<type>" model rather than the bare type
                    s = self.schema.model[self.model_name_for_data(model)]
                    if (
                        f[0] in s
                        and s[f[0]][1] == "RPNValue"
                        and CodeVar.is_var_value(f[1])
                    ):
                        var_assign = ":"
                parts.append("{}{}={}".format(f[0], var_assign, f[1]))
            lines.append(
                "{}{}{}\n".format(prefix, ",".join(parts), ";" if want_semicolon else "")
            )
        return "".join(lines)

    def render_lattice_and_beamline(self, iterator, **kwargs):
        return self.render_lattice(
            self.iterate_models(iterator, "elements").result, **kwargs
        ) + self.__render_beamline(**kwargs)

    def select_beamline(self):
        """Returns the beamline to use based for the selected report."""
        sim = self.data.models.simulation
        if self.data.get("report", "") == "twissReport":
            beamline_id = sim.activeBeamlineId
        else:
            if "visualizationBeamlineId" not in sim or not sim.visualizationBeamlineId:
                sim.visualizationBeamlineId = self.data.models.beamlines[0].id
            beamline_id = sim.visualizationBeamlineId
        return self.get_item(int(beamline_id))

    def sort_elements_and_beamlines(self):
        """Sort elements and beamline models in place, by (type, name) and (name)"""
        m = self.data.models
        m.elements = sorted(m.elements, key=lambda e: (e.type, e.name.lower()))
        m.beamlines = sorted(m.beamlines, key=lambda e: e.name.lower())

    @classmethod
    def type_for_data(cls, model):
        return model["_type" if cls.is_command(model) else "type"]

    def __add_beamlines(self, beamline, beamlines, ordered_beamlines):
        """Append ``beamline`` after the nested beamlines it references (depth first)."""
        if beamline in ordered_beamlines:
            return
        for bid in beamline["items"]:
            bid = abs(bid)
            if bid in beamlines and "type" not in beamlines[bid]:
                self.__add_beamlines(beamlines[bid], beamlines, ordered_beamlines)
        ordered_beamlines.append(beamline)

    def __build_id_map(self, data):
        """Returns a map of beamlines and elements, (id => model)."""
        res = {}
        for bl in data.models.beamlines:
            res[bl.id] = bl
        for el in data.models.elements:
            res[el._id] = el
        if "commands" in data.models:
            for cmd in data.models.commands:
                # TODO(pjm): some old elegant sims have overlap in element and command ids
                if cmd._id not in res:
                    res[cmd._id] = cmd
        max_id = max(res.keys()) if res else 0
        return res, max_id

    def __format_name(self, name, quote_name, madx_name, is_reversed=False):
        name = name.upper()
        if madx_name:
            name = re.sub(r"[\-:/]", "_", name)
        if is_reversed:
            name = f"-{name}"
        if quote_name:
            name = '"{}"'.format(name)
        return name

    @classmethod
    def __is_output_model_name(cls, name):
        return cls._OUTPUT_NAME_PREFIX in name

    def __render_beamline(
        self,
        quote_name=False,
        want_semicolon=False,
        want_var_assign=False,
        madx_name=False,
        want_name=True,
        comment="//",
    ):
        """Render non-empty beamlines as LINE statements, dependencies first.

        ``want_var_assign``, ``want_name`` and ``comment`` are accepted and
        ignored so render_lattice_and_beamline() can forward every
        render_lattice() keyword argument.
        """
        ordered_beamlines = []
        for bid in sorted(self.id_map):
            model = self.id_map[bid]
            if "type" not in model and not self.is_command(model):
                self.__add_beamlines(model, self.id_map, ordered_beamlines)
        res = []
        for bl in ordered_beamlines:
            if not bl["items"]:
                continue
            item_names = ",".join(
                self.__format_name(
                    self.id_map[abs(bid)].name,
                    quote_name,
                    madx_name,
                    is_reversed=bid < 0,
                )
                for bid in bl["items"]
            )
            res.append(
                "{}: LINE=({}){}\n".format(
                    self.__format_name(bl.name, quote_name, madx_name),
                    item_names,
                    ";" if want_semicolon else "",
                )
            )
        return "".join(res)