Coverage for /builds/kinetik161/ase/ase/utils/parsemath.py: 85.11%
94 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
1"""A Module to safely parse/evaluate Mathematical Expressions"""
2import ast
3import math
4import operator as op
6from numpy import int64
# Largest magnitude that operands and intermediate results may reach. The
# overflow-guarded helpers and the ``limit`` decorator reject anything larger,
# which keeps evaluation cheap and protects against denial-of-service (DoS)
# inputs such as enormous exponentiations.
max_value = 1e17
# Overflow-guarded replacements for the arithmetic operators: each helper
# refuses inputs whose operands (or estimated result) would exceed max_value,
# protecting against denial-of-service (DoS) expressions.
def add(a, b):
    """Return ``a + b`` after bounding both operands.

    Raises:
        ValueError: carrying ``(a, b)`` if either operand's magnitude is
            greater than ``max_value``.
    """
    for operand in (a, b):
        if abs(operand) > max_value:
            raise ValueError((a, b))
    return op.add(a, b)
def sub(a, b):
    """Return ``a - b`` once both operands pass the ``max_value`` bound.

    Raises:
        ValueError: carrying ``(a, b)`` if ``abs(a)`` or ``abs(b)`` is
            greater than ``max_value``.
    """
    if abs(a) > max_value or abs(b) > max_value:
        raise ValueError((a, b))
    return op.sub(a, b)
def mul(a, b):
    """Return ``a * b`` unless the product's magnitude would exceed ``max_value``.

    The check adds the operands' decimal exponents, so the (possibly huge)
    product is never formed before it is validated. A zero operand skips
    the check, since ``log10(0)`` is undefined and the product is 0.

    Raises:
        ValueError: carrying ``(a, b)`` when the product is too large.
    """
    if a != 0.0 and b != 0.0:
        product_exponent = math.log10(abs(a)) + math.log10(abs(b))
        if product_exponent > math.log10(max_value):
            raise ValueError((a, b))
    return op.mul(a, b)
def div(a, b):
    """Return the true quotient ``a / b`` with zero and overflow checks.

    Raises:
        ValueError: carrying ``(a, b)`` when *b* is zero, or when the
            difference of the operands' decimal exponents shows the
            quotient would exceed ``max_value``.
    """
    if b == 0.0:
        raise ValueError((a, b))
    # A zero numerator always yields 0 and has no logarithm; skip the check.
    if a != 0.0:
        quotient_exponent = math.log10(abs(a)) - math.log10(abs(b))
        if quotient_exponent > math.log10(max_value):
            raise ValueError((a, b))
    return op.truediv(a, b)
def power(a, b):
    """Return ``a ** b``, refusing results whose magnitude reaches ``max_value``.

    The size check compares *b* against ``log_|a|(max_value)``, so the
    potentially enormous power is never computed before it is validated.

    Raises:
        ValueError: carrying ``(a, b)`` if the result would be too large,
            or if *a* is zero and *b* is negative (division by zero).
    """
    if a == 0.0:
        # 0 ** b: a division by zero for b < 0, 1 for b == 0 (as in Python
        # and math.pow), and 0 otherwise.
        if b < 0:
            raise ValueError((a, b))
        return 1.0 if b == 0 else 0.0
    # For |a| == 1 the result has magnitude 1 for any real b, and a logarithm
    # with base 1 is undefined (math.log would raise ZeroDivisionError), so
    # the size check is skipped.
    if abs(a) != 1 and b / math.log(max_value, abs(a)) >= 1:
        raise ValueError((a, b))
    return op.pow(a, b)
def exp(a):
    """Return ``e ** a`` unless it would exceed ``max_value``.

    Raises:
        ValueError: carrying *a* when ``a > ln(max_value)``.
    """
    ceiling = math.log(max_value)
    if a > ceiling:
        raise ValueError(a)
    return math.exp(a)
# Mapping from each permitted AST operator node type to the callable that
# implements it. Operators missing here are rejected: the evaluator's lookup
# raises KeyError. +, -, *, / and ** go through the overflow-guarded helpers.
operators = {
    ast.Add: add,
    ast.Sub: sub,
    ast.Mult: mul,
    ast.Div: div,
    ast.Pow: power,
    ast.USub: op.neg,
    # Unary plus, e.g. "+1" or "2 * +x"; previously rejected with KeyError.
    ast.UAdd: op.pos,
    ast.Mod: op.mod,
    # The plain (not in-place) operator: the operands are immutable numbers.
    ast.FloorDiv: op.floordiv,
}
# Whitelist of callables an expression may invoke, keyed by the name used in
# the expression. This is a curated subset of the math module with a few
# substitutions: "abs" is math.fabs, "round" is the builtin round, and "exp"
# is this module's overflow-guarded exp.
allowed_math_fxn = dict(
    sin=math.sin,
    cos=math.cos,
    tan=math.tan,
    asin=math.asin,
    acos=math.acos,
    atan=math.atan,
    atan2=math.atan2,
    hypot=math.hypot,
    sinh=math.sinh,
    cosh=math.cosh,
    tanh=math.tanh,
    asinh=math.asinh,
    acosh=math.acosh,
    atanh=math.atanh,
    radians=math.radians,
    degrees=math.degrees,
    sqrt=math.sqrt,
    log=math.log,
    log10=math.log10,
    log2=math.log2,
    fmod=math.fmod,
    abs=math.fabs,
    ceil=math.ceil,
    floor=math.floor,
    round=round,
    exp=exp,
)
def get_function(node):
    """Return the name of the callable invoked by an ``ast.Call`` node.

    A bare call ``f(x)`` yields ``"f"``; a qualified call ``mod.f(x)``
    yields ``"f"`` (the module prefix is discarded).

    Raises:
        TypeError: if the call target is neither a name nor an attribute.
    """
    target = node.func
    if isinstance(target, ast.Attribute):
        return target.attr
    if isinstance(target, ast.Name):
        return target.id
    raise TypeError("node.func is of the wrong type")
def limit(max_=None):
    """Return a decorator that bounds the magnitude of a function's result.

    Parameters:
        max_: number or None
            Largest allowed ``abs(result)``. ``None`` (the default) disables
            the bound; previously it made every numeric result fail with
            ``TypeError`` from ``mag > None``.

    The wrapped function's return value is checked with ``abs()``; results
    that do not support ``abs()`` are passed through unchecked. Python
    ``int`` results are converted to ``numpy.int64``.

    Raises (from the wrapped function):
        ValueError: carrying the result when its magnitude exceeds *max_*.
    """
    import functools

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            ret = func(*args, **kwargs)
            if max_ is not None:
                try:
                    mag = abs(ret)
                except TypeError:
                    pass  # not applicable, e.g. non-numeric results
                else:
                    if mag > max_:
                        raise ValueError(ret)
            if isinstance(ret, int):
                ret = int64(ret)
            return ret

        return wrapper

    return decorator
@limit(max_=max_value)
def _eval(node):
    """Recursively evaluate a node of an expression parsed by :mod:`ast`.

    Only a whitelisted subset of Python syntax is accepted:
    numeric literals, binary/unary operators listed in ``operators``, calls
    to functions listed in ``allowed_math_fxn``, and the case-insensitive
    constants ``pi``, ``e`` and ``tau``. Through the ``limit`` decorator,
    any result whose magnitude exceeds ``max_value`` raises ValueError and
    Python ``int`` results are returned as ``numpy.int64``.

    Raises:
        TypeError: for any other node type (including non-numeric
            constants) or an unknown name.
        KeyError: for an operator or function that is not whitelisted.
        ValueError: when a value is too large; the whitelisted functions
            themselves may also raise (e.g. math domain errors).
    """
    if isinstance(node, ast.Constant):  # <number>
        value = node.value
        # ast.Num (deprecated since Python 3.8, removed in 3.14) matched
        # int, float and complex constants but not bools; keep that contract.
        if (isinstance(value, (int, float, complex))
                and not isinstance(value, bool)):
            return value
        raise TypeError(node)
    elif isinstance(node, ast.BinOp):  # <left> <operator> <right>
        return operators[type(node.op)](_eval(node.left), _eval(node.right))
    elif isinstance(node, ast.UnaryOp):  # <operator> <operand> e.g., -1
        return operators[type(node.op)](_eval(node.operand))
    elif isinstance(node, ast.Call):  # using math.function
        func = get_function(node)
        # Evaluate all positional arguments; keyword arguments are not read.
        evaled_args = [_eval(arg) for arg in node.args]
        return allowed_math_fxn[func](*evaled_args)
    elif isinstance(node, ast.Name):
        name = node.id.lower()
        if name == "pi":
            return math.pi
        elif name == "e":
            return math.e
        elif name == "tau":
            return math.pi * 2.0
        else:
            raise TypeError(
                "Found a str in the expression, either param_dct/the "
                "expression has a mistake in the parameter names or "
                "attempting to parse non-mathematical code")
    else:
        raise TypeError(node)
def eval_expression(expression, param_dct=None):
    """Parse and safely evaluate a mathematical expression.

    Each key of *param_dct* that occurs as a whole word in *expression* is
    replaced by its value, wrapped in parentheses. The result is then parsed
    with :mod:`ast` and evaluated with the whitelist-based ``_eval``.

    Parameters:
        expression: str
            The expression to evaluate, e.g. ``"2 * sin(x) + 1e-3"``.
        param_dct: dict, optional
            Mapping from parameter names to the values substituted into
            the expression. ``None`` (the default) means no substitutions.

    Returns:
        The numeric value of the expression.

    Raises:
        TypeError: if *expression* is not a string, or it contains
            unsupported syntax or unknown names.
        ValueError: if *expression* is longer than 10000 characters,
            contains ``()``, or evaluates to a value that is too large.
        SyntaxError: if the substituted expression cannot be parsed.
    """
    import re

    if param_dct is None:
        param_dct = {}
    if not isinstance(expression, str):
        raise TypeError("The expression must be a string")
    if len(expression) > 1e4:
        raise ValueError("The expression is too long.")

    expression_rep = expression.strip()

    # An empty call/tuple has no meaning in a math expression; reject it.
    if "()" in expression_rep:
        raise ValueError("Invalid operation in expression")

    for key, val in param_dct.items():
        # Match the key only as a whole token. A plain str.replace would also
        # rewrite e.g. the "a" inside "tan" or the "e" inside "1e5" or "exp".
        pattern = r"(?<!\w)" + re.escape(key) + r"(?!\w)"
        # Parenthesize so the value binds as a unit: with x = -3, "x**2"
        # must evaluate to 9, not -(3**2).
        replacement = "(" + str(val) + ")"
        # A function replacement avoids backslash-escape processing in
        # the replacement string.
        expression_rep = re.sub(
            pattern, lambda _match: replacement, expression_rep)

    return _eval(ast.parse(expression_rep, mode="eval").body)