Coverage for /builds/kinetik161/ase/ase/calculators/mixing.py: 88.75%

80 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-12-10 11:04 +0000

1from ase.calculators.calculator import ( 

2 BaseCalculator, 

3 CalculatorSetupError, 

4 PropertyNotImplementedError, 

5 all_changes, 

6) 

7from ase.stress import full_3x3_to_voigt_6_stress 

8 

9 

class Mixer:
    """Combine the results of several calculators as a weighted sum.

    Every property is evaluated on each calculator and summed with the
    matching weight.  The per-calculator values are kept next to the
    total under the key ``"<property>_contributions"``.
    """

    def __init__(self, calcs, weights):
        """Validate the inputs and determine the shared properties.

        calcs: list
            Calculators to combine; each must be a ``BaseCalculator``.
        weights: list of float
            One weight per calculator.

        Raises ``CalculatorSetupError`` for an empty or invalid ``calcs``,
        ``ValueError`` when ``weights`` and ``calcs`` differ in length and
        ``PropertyNotImplementedError`` when the calculators have no
        implemented property in common.
        """
        self.check_input(calcs, weights)
        common_properties = set.intersection(
            *(set(calc.implemented_properties) for calc in calcs)
        )
        # Keep the order in which the first calculator lists its properties
        # instead of set iteration order, which for strings changes from one
        # interpreter run to the next (hash randomisation).
        first_calc = next(iter(calcs))
        self.implemented_properties = list(
            dict.fromkeys(
                prop
                for prop in first_calc.implemented_properties
                if prop in common_properties
            )
        )
        if not self.implemented_properties:
            raise PropertyNotImplementedError(
                "The provided Calculators have"
                " no properties in common!"
            )
        self.calcs = calcs
        self.weights = weights

    @staticmethod
    def check_input(calcs, weights):
        """Raise if ``calcs``/``weights`` cannot be mixed.

        Raises ``CalculatorSetupError`` when ``calcs`` is empty or holds an
        object that is not a ``BaseCalculator``, and ``ValueError`` when the
        number of weights differs from the number of calculators.
        """
        if len(calcs) == 0:
            raise CalculatorSetupError("Please provide a list of Calculators")
        for calc in calcs:
            if not isinstance(calc, BaseCalculator):
                raise CalculatorSetupError(
                    "All Calculators should inherit"
                    " from the BaseCalculator class"
                )
        if len(weights) != len(calcs):
            raise ValueError(
                "The length of the weights must be the same as"
                " the number of Calculators!"
            )

    def get_properties(self, properties, atoms):
        """Evaluate ``properties`` with every calculator and mix them.

        properties: list of str
            Properties that must be computed.
        atoms:
            Structure handed to each calculator's ``get_property``.

        Returns a dict that holds, for every evaluated property ``p``, the
        weighted sum under ``p`` and the list of per-calculator values under
        ``f"{p}_contributions"``.  Besides the requested properties, every
        implemented property that is already present in the ``results`` of
        all calculators is included as well.
        """
        results = {}

        def get_property(prop):
            contribs = [calc.get_property(prop, atoms) for calc in self.calcs]
            # Calculators may report stress either as a (6,) Voigt vector or
            # as a full (3, 3) tensor; when they disagree, bring every
            # contribution to Voigt form so the weighted sum is well defined.
            if prop == "stress":
                shapes = [contrib.shape for contrib in contribs]
                if any(shape != shapes[0] for shape in shapes):
                    contribs = self.make_stress_voigt(contribs)
            results[f"{prop}_contributions"] = contribs
            results[prop] = sum(
                weight * value for weight, value in zip(self.weights, contribs)
            )

        for prop in properties:  # get requested properties
            get_property(prop)
        for prop in self.implemented_properties:  # cache all available props
            if prop in results:
                # Already mixed by the loop above; no need to query again.
                continue
            if all(prop in calc.results for calc in self.calcs):
                get_property(prop)
        return results

    @staticmethod
    def make_stress_voigt(stresses):
        """Return ``stresses`` with every entry in 6-component Voigt form.

        Entries of shape (6,) are passed through unchanged and entries of
        shape (3, 3) are converted with ``full_3x3_to_voigt_6_stress``.

        Raises ``ValueError`` for any other shape.
        """
        new_contribs = []
        for contrib in stresses:
            if contrib.shape == (6,):
                new_contribs.append(contrib)
            elif contrib.shape == (3, 3):
                new_contribs.append(full_3x3_to_voigt_6_stress(contrib))
            else:
                raise ValueError(
                    "The shapes of the stress"
                    " property are not the same"
                    " from all calculators"
                )
        return new_contribs

86 

87 

class LinearCombinationCalculator(BaseCalculator):
    """Calculator whose results are a weighted sum of other calculators."""

    def __init__(self, calcs, weights):
        """Set up a weighted combination of calculators.

        calcs: list
            List of an arbitrary number of :mod:`ase.calculators` objects.
        weights: list of float
            Weight applied to each calculator, in the same order as
            ``calcs``.
        """
        super().__init__()
        mixer = Mixer(calcs, weights)
        self.mixer = mixer
        # Only the properties every wrapped calculator provides can be mixed.
        self.implemented_properties = mixer.implemented_properties

    def calculate(self, atoms, properties, system_changes):
        """Evaluate ``properties`` on all wrapped calculators.

        The weighted totals and per-calculator contributions produced by
        the mixer become ``self.results``.  ``system_changes`` is part of
        the calculator interface but is not used here.
        """
        # A private copy of the structure is kept for result caching.
        self.atoms = atoms.copy()
        mixed = self.mixer.get_properties(properties, atoms)
        self.results = mixed

    def __str__(self):
        names = [type(calc).__name__ for calc in self.mixer.calcs]
        return "{}({})".format(type(self).__name__, ", ".join(names))

116 

117 

class MixedCalculator(LinearCombinationCalculator):
    """
    Mixing of two calculators with different weights

    H = weight1 * H1 + weight2 * H2

    Has functionality to get the energy contributions from each calculator

    Parameters
    ----------
    calc1 : ASE-calculator
    calc2 : ASE-calculator
    weight1 : float
        weight for calculator 1
    weight2 : float
        weight for calculator 2
    """

    def __init__(self, calc1, calc2, weight1, weight2):
        super().__init__([calc1, calc2], [weight1, weight2])

    def set_weights(self, w1, w2):
        """Replace the weights of calculator 1 and calculator 2.

        Results cached so far were mixed with the old weights.  They are
        discarded so that they are not handed out as if they used the new
        weights.
        """
        self.mixer.weights[0] = w1
        self.mixer.weights[1] = w2
        self.results = {}

    def get_energy_contributions(self, atoms=None):
        """Return the potential energy from calc1 and calc2 respectively.

        atoms: Atoms, optional
            Structure to evaluate.  When omitted, the atoms stored by the
            previous calculation are used.

        Raises ``ValueError`` when no atoms are given and none are stored.
        """
        if atoms is None:
            # Fall back to the stored structure instead of passing None on
            # to ``calculate``, which would fail on ``None.copy()``.
            atoms = getattr(self, "atoms", None)
        if atoms is None:
            raise ValueError(
                "No atoms given and no atoms stored from a previous"
                " calculation"
            )
        self.calculate(
            properties=["energy"],
            atoms=atoms,
            system_changes=all_changes
        )
        return self.results["energy_contributions"]

151 

152 

class SumCalculator(LinearCombinationCalculator):
    """SumCalculator for combining multiple calculators.

    Useful when different calculators handle different chemical
    environments, or for delta learning.  It accepts a list of arbitrary
    calculators and evaluates them in sequence whenever a property is
    needed.  The supported properties are the intersection of the
    properties implemented by every calculator.
    """

    def __init__(self, calcs):
        """Add up the results of all calculators with unit weight.

        calcs: list
            List of an arbitrary number of :mod:`ase.calculators` objects.
        """
        unit_weights = [1.0 for _ in range(len(calcs))]
        super().__init__(calcs, unit_weights)

174 

175 

class AverageCalculator(LinearCombinationCalculator):
    """Arithmetic mean of several calculators, each weighted equally
    (for thermodynamic purposes)."""

    def __init__(self, calcs):
        """Average the results of all calculators.

        calcs: list
            List of an arbitrary number of :mod:`ase.calculators` objects.

        Raises ``CalculatorSetupError`` when ``calcs`` is empty, since an
        average over zero calculators is undefined.
        """
        count = len(calcs)
        if not count:
            raise CalculatorSetupError(
                "The value of the calcs must be a list of Calculators"
            )
        share = 1 / count
        super().__init__(calcs, [share for _ in range(count)])