# Source code for torchgpe.utils.configuration
import warnings
import yaml
from math import sqrt
import re
from scipy. constants import pi, hbar, c
from .potentials import linear_ramp, quench, s_ramp
# Namespace handed to expressions evaluated through the !eval tag.
# It starts with builtins disabled, so expressions cannot reach them by name.
__globals = {'__builtins__': None}
__globals.update(
    # Square root from the math module
    sqrt=sqrt,
    # Helpers imported from the potentials module
    linear_ramp=linear_ramp,
    s_ramp=s_ramp,
    quench=quench,
    # Constants imported from scipy.constants
    pi=pi,
    hbar=hbar,
    c=c,
)
def __config_tag_evaluate(loader, node):
    """Evaluates a YAML node tagged with !eval and returns the result.

    Two node shapes are accepted:

    * a sequence ``!eval [<expression>, {<locals>}]``, whose first item is the
      expression and whose optional second item is a mapping of extra names
      made available to the expression;
    * a plain scalar ``!eval <expression>``, evaluated without extra names.

    The expression is evaluated with ``eval``, using ``__globals`` as the
    global namespace and the supplied mapping as the local namespace. A
    warning is emitted when the mapping redefines one of the names in
    ``__globals``; the value from the mapping then takes precedence.

    Args:
        loader (yaml.Loader): The YAML loader.
        node (yaml.Node): The YAML node carrying the !eval tag.

    Returns:
        The value the expression evaluates to.
    """
    # NOTE(review): eval executes whatever the configuration file contains.
    # Disabling builtins limits what an expression can reach by name but is
    # presumably not a complete sandbox - only load trusted configuration files.
    if isinstance(node.value, str):
        # Scalar node: its value is the expression text itself.
        expression = loader.construct_scalar(node)
        local_vars = {}
    else:
        expression = loader.construct_scalar(node.value[0])
        # deep=True forces nested sequences/mappings to be fully built now.
        # With the default shallow construction PyYAML fills them in later,
        # so they would still be empty when the expression is evaluated.
        local_vars = (
            loader.construct_mapping(node.value[1], deep=True)
            if len(node.value) > 1
            else {}
        )

    if any(key in local_vars for key in __globals):
        warnings.warn(
            f"{', '.join(__globals.keys())} are reserved keywords and are set to the respective constants. By specifying them, their value is overwritten")

    return eval(expression, __globals, local_vars)
# Regex recognising float scalars, including exponential notation written
# without a decimal point or without an explicit exponent sign (e.g. 1e5).
# Adapted from https://stackoverflow.com/questions/30458977/how-to-parse-exponential-numbers-with-pyyaml
# Differences from the linked pattern, both in the leading-dot alternative:
#   * the exponent sign is optional, as in the other alternatives, so that
#     ".5e3" is recognised just like "0.5e3";
#   * at least one digit is required after the dot, so that "._" is no longer
#     tagged as a float (it cannot be converted to one).
__config_exponential_resolver = re.compile(
    r'''^(?:
     [-+]?(?:[0-9][0-9_]*)\.[0-9_]*(?:[eE][-+]?[0-9]+)?
    |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
    |\.[0-9_]*[0-9][0-9_]*(?:[eE][-+]?[0-9]+)?
    |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\.[0-9_]*
    |[-+]?\.(?:inf|Inf|INF)
    |\.(?:nan|NaN|NAN))$''',
    re.X,
)
# [docs]
def parse_config(path):
    """Parses a YAML configuration file.

    On top of what ``yaml.SafeLoader`` supports, the file may use the
    ``!eval`` tag and floats matched by ``__config_exponential_resolver``.

    Args:
        path (str): The path to the configuration file.

    Returns:
        dict: The parsed configuration.

    Raises:
        yaml.YAMLError: If the configuration file is not valid YAML.
        OSError: If the file cannot be opened.
    """
    # Register the extensions on a private subclass rather than on
    # yaml.SafeLoader itself: registering on the shared class would append a
    # duplicate resolver on every call and would enable !eval for every other
    # user of SafeLoader in the process.
    class _ConfigLoader(yaml.SafeLoader):
        """SafeLoader extended with the configuration-specific tags."""

    _ConfigLoader.add_implicit_resolver(
        u'tag:yaml.org,2002:float', __config_exponential_resolver, list(u'-+0123456789.'))
    _ConfigLoader.add_constructor('!eval', __config_tag_evaluate)

    with open(path, "r") as file:
        return yaml.load(file, Loader=_ConfigLoader)