Coverage for /builds/kinetik161/ase/ase/utils/parsemath.py: 85.11%

94 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-12-10 11:04 +0000

1"""A Module to safely parse/evaluate Mathematical Expressions""" 

2import ast 

3import math 

4import operator as op 

5 

6from numpy import int64 

7 

# Largest magnitude any (intermediate) evaluation result may reach; keeps
# expressions from building huge numbers (denial-of-service protection).
max_value = 1.0e17

10 

11 

12# Redefine mathematical operations to prevent DNS attacks 

13def add(a, b): 

14 """Redefine add function to prevent too large numbers""" 

15 if any(abs(n) > max_value for n in [a, b]): 

16 raise ValueError((a, b)) 

17 return op.add(a, b) 

18 

19 

def sub(a, b):
    """Return ``a - b``.

    Raises
    ------
    ValueError
        With ``(a, b)`` if the magnitude of either operand exceeds
        ``max_value``.
    """
    for operand in (a, b):
        if abs(operand) > max_value:
            raise ValueError((a, b))
    return op.sub(a, b)

25 

26 

def mul(a, b):
    """Return ``a * b``, refusing products that would be too large.

    The magnitude of the product is estimated as ``log10|a| + log10|b|``
    before multiplying. If either factor is zero the check is skipped
    (the logarithm is undefined there and the product is zero anyway).

    Raises
    ------
    ValueError
        With ``(a, b)`` if ``log10|a| + log10|b| > log10(max_value)``.
    """
    if a != 0.0 and b != 0.0:
        log_product = math.log10(abs(a)) + math.log10(abs(b))
        if log_product > math.log10(max_value):
            raise ValueError((a, b))
    return op.mul(a, b)

34 

35 

def div(a, b):
    """Return the true quotient ``a / b``, with guards.

    Raises
    ------
    ValueError
        With ``(a, b)`` if ``b`` is zero, or if the estimated quotient
        magnitude ``log10|a| - log10|b|`` exceeds ``log10(max_value)``
        (the estimate is skipped when ``a`` is zero).
    """
    if b == 0.0:
        raise ValueError((a, b))
    if a != 0.0:
        log_quotient = math.log10(abs(a)) - math.log10(abs(b))
        if log_quotient > math.log10(max_value):
            raise ValueError((a, b))
    return op.truediv(a, b)

45 

46 

def power(a, b):
    """Return ``a ** b``, refusing results that would be too large.

    Special cases for a zero base:

    * ``0 ** 0`` returns ``1.0`` (the convention used by Python's ``**``
      and ``math.pow``).
    * ``0 ** b`` with ``b > 0`` returns ``0.0``.
    * ``0 ** b`` with ``b < 0`` is a division by zero and raises
      ``ValueError``, like ``div`` does for a zero divisor.

    Raises
    ------
    ValueError
        With ``(a, b)`` if ``|a ** b| >= max_value`` (checked via
        logarithms before computing the power), or for ``0 ** negative``.
    """
    if a == 0.0:
        if b < 0:
            raise ValueError((a, b))
        return 1.0 if b == 0 else 0.0
    # |a ** b| >= max_value  <=>  b * ln|a| >= ln(max_value).
    # Multiplying by ln|a| (instead of dividing by log(max_value, |a|))
    # avoids a ZeroDivisionError when |a| == 1, where ln|a| == 0 and the
    # result always has magnitude 1.
    if b * math.log(abs(a)) >= math.log(max_value):
        raise ValueError((a, b))
    return op.pow(a, b)

54 

55 

def exp(a):
    """Return ``e ** a`` via ``math.exp``, with an overflow guard.

    Raises
    ------
    ValueError
        With ``a`` if ``a > ln(max_value)``, i.e. if the result would
        exceed ``max_value``.
    """
    threshold = math.log(max_value)
    if a > threshold:
        raise ValueError(a)
    return math.exp(a)

61 

62 

# Whitelist mapping AST operator node types to the functions implementing
# them. ``_eval`` looks operators up here, so any operator type missing from
# this dict makes ``_eval`` raise KeyError.
# NOTE(review): ``%`` and ``//`` by zero raise ZeroDivisionError, whereas
# ``div`` raises ValueError for a zero divisor.
operators = {
    ast.Add: add,
    ast.Sub: sub,
    ast.Mult: mul,
    ast.Div: div,
    ast.Pow: power,
    ast.USub: op.neg,
    # Unary plus ("+x"); previously unsupported and raised KeyError.
    ast.UAdd: op.pos,
    ast.Mod: op.mod,
    # Plain floor division; the in-place ``op.ifloordiv`` used before only
    # behaved the same because numbers are immutable.
    ast.FloorDiv: op.floordiv,
}

74 

# Whitelist of callables that expressions may call, looked up by name.
# This is a hand-picked subset of the math module, not the whole module.
# ``abs`` maps to math.fabs (always returns a float), ``round`` is the
# builtin, and ``exp`` is this module's overflow-guarded wrapper.
allowed_math_fxn = dict(
    sin=math.sin,
    cos=math.cos,
    tan=math.tan,
    asin=math.asin,
    acos=math.acos,
    atan=math.atan,
    atan2=math.atan2,
    hypot=math.hypot,
    sinh=math.sinh,
    cosh=math.cosh,
    tanh=math.tanh,
    asinh=math.asinh,
    acosh=math.acosh,
    atanh=math.atanh,
    radians=math.radians,
    degrees=math.degrees,
    sqrt=math.sqrt,
    log=math.log,
    log10=math.log10,
    log2=math.log2,
    fmod=math.fmod,
    abs=math.fabs,
    ceil=math.ceil,
    floor=math.floor,
    round=round,
    exp=exp,
)

104 

105 

def get_function(node):
    """Return the name of the callable invoked by an ``ast.Call`` node.

    Bare calls such as ``sin(x)`` yield ``"sin"``. For attribute calls such
    as ``math.sin(x)`` only the attribute name (``"sin"``) is returned; the
    object before the dot is not inspected.

    Raises
    ------
    TypeError
        If the call target is neither an ``ast.Name`` nor an
        ``ast.Attribute`` (e.g. ``f()()`` or a lambda call).
    """
    target = node.func
    if isinstance(target, ast.Attribute):
        return target.attr
    if isinstance(target, ast.Name):
        return target.id
    raise TypeError("node.func is of the wrong type")

116 

117 

def limit(max_=None):
    """Return a decorator that rejects results whose magnitude is too large.

    The wrapped function is called normally. If its return value supports
    ``abs()`` and ``abs(ret) > max_``, a ``ValueError`` carrying the value is
    raised. Values that do not support ``abs()`` pass through unchecked.
    Plain Python ``int`` results are converted to ``numpy.int64`` when they
    fit in 64 bits.

    Parameters
    ----------
    max_ : float or None
        Largest allowed magnitude (inclusive). ``None`` (the default)
        disables the magnitude check. Previously, ``None`` made every call
        fail with TypeError from ``mag > None``.
    """
    import functools

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            ret = func(*args, **kwargs)
            if max_ is not None:
                try:
                    mag = abs(ret)
                except TypeError:
                    pass  # abs() not applicable to this type; no check
                else:
                    if mag > max_:
                        raise ValueError(ret)
            # int64(...) raises OverflowError for ints outside 64 bits,
            # which is reachable when max_ is None or larger than 2**63.
            if isinstance(ret, int) and -2 ** 63 <= ret < 2 ** 63:
                ret = int64(ret)
            return ret

        return wrapper

    return decorator

140 

141 

@limit(max_=max_value)
def _eval(node):
    """Recursively evaluate an ``ast`` expression node.

    Accepted nodes:

    * numeric literals (``int``, ``float``, ``complex``; not ``bool``),
    * binary/unary operations whose operator type is a key of ``operators``,
    * calls to functions in ``allowed_math_fxn``; positional and keyword
      arguments are evaluated recursively,
    * the names ``pi``, ``e`` and ``tau`` (case-insensitive).

    Because of the ``limit`` decorator, every result, including each
    intermediate one in the recursion, raises ``ValueError`` if its
    magnitude exceeds ``max_value``.

    Raises
    ------
    TypeError
        For any other node (strings, booleans, unknown names, ...).
    KeyError
        For an operator or function that is not whitelisted.
    """
    if isinstance(node, ast.Constant):  # <number>
        value = node.value
        # Accept the same literals as the deprecated ast.Num (deprecated
        # since Python 3.8, removed in 3.14): int/float/complex, not bool.
        is_number = (isinstance(value, (int, float, complex))
                     and not isinstance(value, bool))
        if is_number:
            return value
        raise TypeError(node)
    elif isinstance(node, ast.BinOp):  # <left> <operator> <right>
        return operators[type(node.op)](_eval(node.left), _eval(node.right))
    elif isinstance(node, ast.UnaryOp):  # <operator> <operand> e.g., -1
        return operators[type(node.op)](_eval(node.operand))
    elif isinstance(node, ast.Call):  # using math.function
        func = get_function(node)
        # Evaluate all arguments; keyword arguments used to be dropped
        # silently (e.g. round(x, ndigits=2) behaved like round(x)).
        evaled_args = [_eval(arg) for arg in node.args]
        evaled_kwargs = {kw.arg: _eval(kw.value) for kw in node.keywords}
        return allowed_math_fxn[func](*evaled_args, **evaled_kwargs)
    elif isinstance(node, ast.Name):
        name = node.id.lower()
        if name == "pi":
            return math.pi
        elif name == "e":
            return math.e
        elif name == "tau":
            return math.pi * 2.0
        else:
            raise TypeError(
                "Found a str in the expression, either param_dct/the "
                "expression has a mistake in the parameter names or "
                "attempting to parse non-mathematical code")
    else:
        raise TypeError(node)

171 

172 

def eval_expression(expression, param_dct=None):
    """Safely evaluate a mathematical expression given as a string.

    Each key of ``param_dct`` occurring in ``expression`` is first replaced
    by its value (``str(value)`` wrapped in parentheses). The result is then
    parsed with ``ast`` and evaluated by ``_eval``, which only accepts
    numbers, whitelisted operators/functions and the constants ``pi``,
    ``e`` and ``tau``.

    Parameters
    ----------
    expression : str
        The expression to evaluate, at most 10000 characters long.
    param_dct : dict, optional
        Mapping of parameter names to values substituted into
        ``expression``. Defaults to no substitution.

    Returns
    -------
    The numeric value of the expression.

    Raises
    ------
    TypeError
        If ``expression`` is not a string or contains constructs ``_eval``
        does not accept.
    ValueError
        If the expression is too long, contains ``()``, or evaluates to or
        through a value that is too large.
    SyntaxError
        If the substituted expression cannot be parsed.

    Notes
    -----
    Substitution is plain substring replacement in the iteration order of
    ``param_dct``. A key that also occurs inside another name (e.g. ``x``
    inside ``x1``, or ``a`` inside ``tan``) will corrupt that name.
    """
    if param_dct is None:
        param_dct = {}
    if not isinstance(expression, str):
        raise TypeError("The expression must be a string")
    if len(expression) > 1e4:
        raise ValueError("The expression is too long.")

    expression_rep = expression.strip()

    # Any literal "()" (empty call or empty tuple) is rejected up front.
    if "()" in expression_rep:
        raise ValueError("Invalid operation in expression")

    for key, val in param_dct.items():
        # Parenthesize the substituted value to preserve precedence: with
        # x = -2, "x**2" must become "(-2)**2" == 4, not "-2**2" == -4.
        expression_rep = expression_rep.replace(key, "(" + str(val) + ")")

    return _eval(ast.parse(expression_rep, mode="eval").body)