Coverage for /builds/kinetik161/ase/ase/optimize/precon/fire.py: 85.22%

115 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-12-10 11:04 +0000

1import time 

2 

3import numpy as np 

4 

5from ase.filters import UnitCellFilter 

6from ase.optimize.optimize import Optimizer 

7 

8 

class PreconFIRE(Optimizer):
    """Preconditioned FIRE (Fast Inertial Relaxation Engine) optimizer.

    Optionally applies a preconditioner to the forces and an Armijo-type
    sufficient-decrease test on a trial step before each FIRE update.
    """

    def __init__(self, atoms, restart=None, logfile='-', trajectory=None,
                 dt=0.1, maxmove=0.2, dtmax=1.0, Nmin=5, finc=1.1, fdec=0.5,
                 astart=0.1, fa=0.99, a=0.1, theta=0.1, master=None,
                 precon=None, use_armijo=True, variable_cell=False):
        """
        Preconditioned version of the FIRE optimizer

        Parameters:

        atoms: Atoms object
            The Atoms object to relax.

        restart: string
            File used to store the optimizer state (the velocities and the
            current time step). If the file exists, the state stored in it
            is used to resume the optimization.

        logfile: file object or str
            If *logfile* is a string, a file with that name will be opened.
            Use '-' for stdout.

        trajectory: string
            Trajectory file for the atomic movement (handled by the base
            Optimizer).

        dt: float
            Initial time step.

        maxmove: float
            Maximum norm of the total displacement (all atoms together)
            applied in one step.

        dtmax: float
            Upper bound for the time step when it is increased.

        Nmin: int
            Number of consecutive downhill steps required before the time
            step is increased and the mixing parameter decreased.

        finc: float
            Factor by which the time step grows after Nmin downhill steps.

        fdec: float
            Factor by which the time step shrinks when the dynamics is reset.

        astart: float
            Value the mixing parameter is reset to.

        fa: float
            Factor by which the mixing parameter decays.

        a: float
            Initial value of the mixing parameter.

        theta: float
            Constant of the Armijo sufficient-decrease condition.

        master: bool
            Defaults to None, which causes only rank 0 to save files. If
            set to true, this rank will save files.

        precon: preconditioner or None
            Object providing make_precon(atoms), solve(vector) and
            dot(u, v). None means no preconditioning.

        use_armijo: bool
            If True, test a trial step with an Armijo condition on the
            energy, and reset the dynamics when it fails.

        variable_cell: bool
            If True, wrap atoms in UnitCellFilter to relax cell and positions.

        In time this implementation is expected to replace
        ase.optimize.fire.FIRE.
        """
        if variable_cell:
            atoms = UnitCellFilter(atoms)
        Optimizer.__init__(self, atoms, restart, logfile, trajectory, master)

        self._actual_atoms = atoms

        self.dt = dt
        self.Nsteps = 0
        self.maxmove = maxmove
        self.dtmax = dtmax
        self.Nmin = Nmin
        self.finc = finc
        self.fdec = fdec
        self.astart = astart
        self.fa = fa
        self.a = a
        self.theta = theta
        self.precon = precon
        self.use_armijo = use_armijo
        # Stress convergence threshold; set by run(). None means "use fmax",
        # so converged() also works if it is reached without run().
        self.smax = None

    def initialize(self):
        """Reset the dynamical state for a fresh optimization."""
        self.v = None
        self.skip_flag = False
        self.e1 = None

    def read(self):
        """Restore the velocities and the time step from the restart file."""
        self.v, self.dt = self.load()
        # The restart file only holds (v, dt). The base class presumably
        # calls read() instead of initialize() when restarting, so the
        # remaining per-run state must be created here. Otherwise step()
        # and log() raise AttributeError.
        self.skip_flag = False
        self.e1 = None

    def step(self, f=None):
        """Perform one (preconditioned) FIRE step.

        Parameters:

        f: array or None
            Forces at the current positions. They are computed from the
            atoms if not given.
        """
        atoms = self._actual_atoms

        if f is None:
            f = atoms.get_forces()

        r = atoms.get_positions()

        # Direction used to drive the velocities: the raw forces, or the
        # preconditioned forces P^{-1} f.
        if self.precon is None:
            direction = f
        else:
            # TODO: Can this be moved out of the step method?
            self.precon.make_precon(atoms)
            direction = self.precon.solve(
                f.reshape(-1)).reshape(len(atoms), -1)

        if self.v is None:
            self.v = np.zeros((len(atoms), 3))
        else:
            if self.use_armijo:
                self._armijo_test(r, f, direction)
            if not self.skip_flag:
                self._fire_mix(f, direction)

        self.v += self.dt * direction
        dr = self.dt * self.v
        # Cap the norm of the full displacement vector at maxmove.
        normdr = np.sqrt(np.vdot(dr, dr))
        if normdr > self.maxmove:
            dr = self.maxmove * dr / normdr
        atoms.set_positions(r + dr)
        self.dump((self.v, self.dt))

    def _armijo_test(self, r, f, direction):
        """Evaluate a trial step and reset the dynamics if it fails Armijo.

        Sets ``self.skip_flag`` to True when the trial step is rejected.
        Stores the trial-point energy in ``self.e1``. The atoms are left at
        the trial positions; step() always sets the final positions
        afterwards.
        """
        v_test = self.v + self.dt * direction
        r_test = r + self.dt * v_test

        # Reference energy at r, read before the atoms are moved. It can
        # usually be reused from the calculator's results for the current
        # positions. Moving back to r after the trial point (as func(r)
        # would) forces a fresh evaluation.
        e_ref = self._actual_atoms.get_potential_energy()

        self.skip_flag = False
        e_test = self.func(r_test)
        self.e1 = e_test
        if e_test > e_ref - self.theta * self.dt * np.vdot(v_test, f):
            self._reset_dynamics()
            self.skip_flag = True

    def _fire_mix(self, f, direction):
        """Apply the FIRE velocity mixing and the adaptive time-step update."""
        if np.vdot(self.v, f) > 0.0:
            # Moving downhill: bend v towards the (preconditioned) force,
            # keeping the magnitude of v in the corresponding metric.
            if self.precon is None:
                self.v = (1.0 - self.a) * self.v + self.a * f / \
                    np.sqrt(np.vdot(f, f)) * \
                    np.sqrt(np.vdot(self.v, self.v))
            else:
                v_flat = self.v.reshape(-1)
                v_norm = np.sqrt(self.precon.dot(v_flat, v_flat))
                f_norm = np.sqrt(np.dot(f.reshape(-1),
                                        direction.reshape(-1)))
                self.v = ((1.0 - self.a) * self.v +
                          self.a * (v_norm / f_norm * direction))
            if self.Nsteps > self.Nmin:
                self.dt = min(self.dt * self.finc, self.dtmax)
                self.a *= self.fa
            self.Nsteps += 1
        else:
            self._reset_dynamics()

    def _reset_dynamics(self):
        """Zero the velocities, restore the mixing parameter and shrink dt."""
        # Assign rather than multiply by zero, so non-finite entries are
        # cleared too.
        self.v[:] = 0.0
        self.a = self.astart
        self.dt *= self.fdec
        self.Nsteps = 0

    def func(self, x):
        """Return the potential energy with the positions set to ``x``.

        Note: this moves the atoms to ``x`` as a side effect.
        """
        self._actual_atoms.set_positions(x.reshape(-1, 3))
        potl = self._actual_atoms.get_potential_energy()
        return potl

    def run(self, fmax=0.05, steps=100000000, smax=None):
        """Run the optimization.

        Parameters:

        fmax: float
            Convergence threshold on the largest atomic force norm.

        steps: int
            Maximum number of steps.

        smax: float or None
            Convergence threshold on the largest absolute stress component.
            It is only used with a variable cell and defaults to fmax.
        """
        if smax is None:
            smax = fmax
        self.smax = smax
        return Optimizer.run(self, fmax, steps)

    def _max_sq_force_stress(self, forces=None):
        """Return (largest squared atomic force, largest squared stress).

        The stress entry is None unless the atoms are wrapped in a
        UnitCellFilter.
        """
        if forces is None:
            forces = self._actual_atoms.get_forces()
        if isinstance(self._actual_atoms, UnitCellFilter):
            # Only the first natoms rows are atomic forces; the remaining
            # rows presumably hold the cell degrees of freedom. The stress
            # is presumably refreshed by the filter's get_forces().
            natoms = len(self._actual_atoms.atoms)
            fmax_sq = (forces[:natoms]**2).sum(axis=1).max()
            smax_sq = (self._actual_atoms.stress**2).max()
            return fmax_sq, smax_sq
        return (forces**2).sum(axis=1).max(), None

    def converged(self, forces=None):
        """Did the optimization converge?"""
        fmax_sq, smax_sq = self._max_sq_force_stress(forces)
        if smax_sq is None:
            return fmax_sq < self.fmax**2
        smax = self.fmax if self.smax is None else self.smax
        return (fmax_sq < self.fmax**2 and smax_sq < smax**2)

    def log(self, forces=None):
        """Write one line with the step, time, energy, fmax (and smax)."""
        if self.logfile is None:
            return
        fmax_sq, smax_sq = self._max_sq_force_stress(forces)
        fmax = np.sqrt(fmax_sq)
        if self.e1 is not None:
            # Reuse the energy from the Armijo test to avoid an extra call.
            # NOTE(review): e1 is the energy at the trial point r_test,
            # which generally differs from the configuration actually
            # accepted. The logged energy may therefore lag the current
            # structure.
            e = self.e1
        else:
            e = self._actual_atoms.get_potential_energy()
        T = time.localtime()
        name = self.__class__.__name__
        if smax_sq is None:
            self.logfile.write(
                '%s: %3d %02d:%02d:%02d %15.6f %12.4f\n' %
                (name, self.nsteps, T[3], T[4], T[5], e, fmax))
        else:
            smax = np.sqrt(smax_sq)
            self.logfile.write(
                '%s: %3d %02d:%02d:%02d %15.6f %12.4f %12.4f\n' %
                (name, self.nsteps, T[3], T[4], T[5], e, fmax, smax))
        self.logfile.flush()