Coverage for /builds/kinetik161/ase/ase/utils/linesearch.py: 83.68%

239 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-12-10 11:04 +0000

1# flake8: noqa 

2import numpy as np 

3 

# Module-level aliases for the builtins ``min`` and ``max``.
pymin, pymax = min, max

6 

7 

class LineSearch:
    """Bracketing line search with sufficient-decrease and curvature tests.

    The control flow (``task`` strings, ``isave``/``dsave`` state arrays,
    Fortran-style messages, cubic/quadratic/secant trial steps) follows the
    MINPACK-2 ``dcsrch``/``dcstep`` routines. In addition, every trial step
    is capped by :meth:`determine_step`, so that no 3-component block of the
    change in displacement exceeds ``maxstep``.

    Attributes
    ----------
    xtol : float
        Relative interval-width tolerance used in the termination tests.
    task : str
        Search state. It is 'START', 'FG' (evaluate f and g at the returned
        step), 'CONVERGENCE', or a message starting with 'WARNING' or 'ERROR'.
    fc, gc : int
        Running counts of function and gradient evaluations. They keep
        accumulating across successive searches on the same instance.
    case : int
        Which of the four update cases (1-4) :meth:`update` used last.
    old_stp : float
        The last step at which f and g were evaluated.
    """

    def __init__(self, xtol=1e-14):
        self.xtol = xtol
        self.fc = 0
        self.gc = 0
        self.case = 0
        self._reset_search_state()

    def _reset_search_state(self):
        """Restore the per-search state (task, save arrays, old_stp)."""
        self.task = 'START'
        self.isave = np.zeros((2,), np.intc)
        self.dsave = np.zeros((13,), float)
        self.old_stp = 0

    def _line_search(self, func, myfprime, xk, pk, gfk, old_fval, old_old_fval,
                     maxstep=.2, c1=.23, c2=0.46, xtrapl=1.1, xtrapu=4.,
                     stpmax=50., stpmin=1e-8, args=()):
        """Search along ``pk`` from ``xk`` for an acceptable step length.

        Parameters
        ----------
        func : callable
            ``func(x, *args)`` returns the objective value at ``x``.
        myfprime : callable or tuple
            Either a gradient ``myfprime(x, *args)``, or a tuple
            ``(fprime, eps)``. In the tuple case ``fprime(x, func, eps,
            *args)`` is called (finite-difference convention), and each call
            is counted as ``len(xk) + 1`` function evaluations.
        xk : ndarray
            Starting point.
        pk : ndarray
            Search direction. Its length must be a multiple of 3, because
            :meth:`determine_step` reshapes it to ``(-1, 3)``.
        gfk : ndarray
            Gradient at ``xk``.
        old_fval : float
            Function value at ``xk``.
        old_old_fval : float
            Unused; kept for interface compatibility.
        maxstep : float
            Largest allowed change, per 3-component block, between
            successive trial points (see :meth:`determine_step`).
        c1, c2 : float
            Sufficient-decrease and curvature constants.
        xtrapl, xtrapu : float
            Extrapolation factors for the unbracketed step bounds.
        stpmax, stpmin : float
            Bounds on the step length.
        args : tuple
            Extra positional arguments passed to ``func`` and the gradient.

        Returns
        -------
        stp : float or None
            Accepted step, or None if the search ended in an 'ERROR' state.
        fval : float
            Function value at the last evaluated point (``old_fval`` if no
            evaluation happened).
        old_fval : float
            The ``old_fval`` argument, returned unchanged.
        no_update : bool
            True if the search stopped because no further progress was
            possible at ``stpmax``.
        """
        # Each call starts a fresh search. Without this, a reused instance
        # would resume from the previous search's terminal task and old_stp.
        self._reset_search_state()
        self.stpmin = stpmin
        self.stpmax = stpmax
        self.xtrapl = xtrapl
        self.xtrapu = xtrapu
        self.maxstep = maxstep
        self.pk = pk
        self.dim = len(pk)
        self.gms = np.sqrt(self.dim) * maxstep  # not read within this class
        self.no_update = False
        self.steps = []

        if isinstance(myfprime, tuple):
            # Finite-difference gradient: fprime(x, func, eps, *args).
            fprime, eps = myfprime[0], myfprime[1]
            newargs = (func, eps) + tuple(args)
            gradient = False
        else:
            fprime = myfprime
            newargs = args
            gradient = True

        phi0 = old_fval
        derphi0 = np.dot(gfk, pk)
        alpha1 = 1.
        fval = old_fval

        while True:
            stp = self.step(alpha1, phi0, derphi0, c1, c2, self.xtol,
                            self.isave, self.dsave)
            if self.task[:2] != 'FG':
                # 'CONVERGENCE', 'WARNING...' or 'ERROR...': the search is over.
                break

            alpha1 = stp
            x_new = xk + stp * pk
            fval = func(x_new, *args)
            self.fc += 1
            gval = fprime(x_new, *newargs)
            if gradient:
                self.gc += 1
            else:
                self.fc += len(xk) + 1
            phi0 = fval
            derphi0 = np.dot(gval, pk)
            self.old_stp = alpha1
            if self.no_update:
                break

        # NOTE(review): only ERROR states are reported as failure. An earlier
        # test ``self.task[1:4] == 'WARN'`` could never match (3-character
        # slice vs 4-character string), so warnings have always returned the
        # current step. That behavior is kept deliberately, because callers
        # may depend on it.
        if self.task[:5] == 'ERROR':
            stp = None  # failed
        return stp, fval, old_fval, self.no_update

    def step(self, stp, f, g, c1, c2, xtol, isave, dsave):
        """Advance the line search by one iteration.

        Parameters
        ----------
        stp : float
            Current step. On 'START' this is the initial trial step.
        f, g : float
            Function value and directional derivative at ``stp``.
        c1, c2 : float
            Sufficient-decrease and curvature constants.
        xtol : float
            Only validated on 'START'; the termination tests use
            ``self.xtol``.
        isave, dsave :
            Unused; the saved state is read from ``self.isave`` and
            ``self.dsave``.

        Returns
        -------
        float
            The next trial step, with ``self.task`` set to 'FG'. On
            'CONVERGENCE', 'WARNING...' or 'ERROR...' the input ``stp`` is
            returned unchanged.
        """
        if self.task[:5] == 'START':
            # Validate the inputs. If several checks fail, the last message wins.
            if stp < self.stpmin:
                self.task = 'ERROR: STP .LT. minstep'
            if stp > self.stpmax:
                self.task = 'ERROR: STP .GT. maxstep'
            if g >= 0:
                self.task = 'ERROR: INITIAL G >= 0'
            if c1 < 0:
                self.task = 'ERROR: c1 .LT. 0'
            if c2 < 0:
                self.task = 'ERROR: c2 .LT. 0'
            if xtol < 0:
                self.task = 'ERROR: XTOL .LT. 0'
            if self.stpmin < 0:
                self.task = 'ERROR: minstep .LT. 0'
            if self.stpmax < self.stpmin:
                self.task = 'ERROR: maxstep .LT. minstep'
            if self.task[:5] == 'ERROR':
                return stp

            self.bracket = False
            stage = 1
            finit = f
            ginit = g
            gtest = c1 * ginit
            width = self.stpmax - self.stpmin
            width1 = width / .5  # i.e. twice the width
            # (stx, fx, gx): step, function value and derivative at the best
            # step so far. (sty, fy, gy): the other end of the interval.
            stx, fx, gx = 0, finit, ginit
            sty, fy, gy = 0, finit, ginit
            stmin = 0
            stmax = stp + self.xtrapu * stp
            self.task = 'FG'
            self.save((stage, ginit, gtest, gx,
                       gy, finit, fx, fy, stx, sty,
                       stmin, stmax, width, width1))
            return self.determine_step(stp)

        # Restore the state saved by the previous call.
        self.bracket = bool(self.isave[0] == 1)
        stage = self.isave[1]
        (ginit, gtest, gx, gy, finit, fx, fy, stx, sty, stmin, stmax,
         width, width1) = self.dsave

        # Sufficient-decrease threshold at stp. The algorithm enters the
        # second stage once a step with f < ftest and g >= 0 has been seen.
        ftest = finit + stp * gtest
        if stage == 1 and f < ftest and g >= 0.:
            stage = 2

        # Test for warnings. A later test overrides an earlier one.
        if self.bracket and (stp <= stmin or stp >= stmax):
            self.task = 'WARNING: ROUNDING ERRORS PREVENT PROGRESS'
        if self.bracket and stmax - stmin <= self.xtol * stmax:
            self.task = 'WARNING: XTOL TEST SATISFIED'
        if stp == self.stpmax and f <= ftest and g <= gtest:
            self.task = 'WARNING: STP = maxstep'
        if stp == self.stpmin and (f > ftest or g >= gtest):
            self.task = 'WARNING: STP = minstep'

        # Test for convergence: sufficient decrease plus the curvature test.
        if f <= ftest and abs(g) <= c2 * (- ginit):
            self.task = 'CONVERGENCE'

        # Test for termination.
        if self.task[:4] == 'WARN' or self.task[:4] == 'CONV':
            self.save((stage, ginit, gtest, gx,
                       gy, finit, fx, fy, stx, sty,
                       stmin, stmax, width, width1))
            return stp

        # NOTE: the stage-1 "modified function" variant is not implemented,
        # so ``stage`` is tracked and saved but does not affect the update.
        stx, sty, stp, gx, fx, gy, fy = self.update(stx, fx, gx, sty,
                                                    fy, gy, stp, f, g,
                                                    stmin, stmax)

        # Bisect if the bracket did not shrink enough over two iterations.
        if self.bracket:
            if abs(sty - stx) >= .66 * width1:
                stp = stx + .5 * (sty - stx)
            width1 = width
            width = abs(sty - stx)

        # Set the minimum and maximum steps allowed for stp.
        if self.bracket:
            stmin = min(stx, sty)
            stmax = max(stx, sty)
        else:
            stmin = stp + self.xtrapl * (stp - stx)
            stmax = stp + self.xtrapu * (stp - stx)

        # Force the step to be within the bounds maxstep and minstep.
        stp = max(stp, self.stpmin)
        stp = min(stp, self.stpmax)

        if stx == stp and stp == self.stpmax and stmin > self.stpmax:
            self.no_update = True

        # If further progress is not possible, fall back to the best point
        # obtained so far. Both conditions only apply once the minimizer is
        # bracketed (parenthesized so ``stp >= stmax`` is not tested
        # unbracketed).
        if (self.bracket and (stp < stmin or stp >= stmax)) \
                or (self.bracket and stmax - stmin < self.xtol * stmax):
            stp = stx

        # Request another function and derivative evaluation.
        self.task = 'FG'
        self.save((stage, ginit, gtest, gx,
                   gy, finit, fx, fy, stx, sty,
                   stmin, stmax, width, width1))
        return stp

    def update(self, stx, fx, gx, sty, fy, gy, stp, fp, gp,
               stpmin, stpmax):
        """Compute a safeguarded trial step and update the interval.

        Parameters
        ----------
        stx, fx, gx : float
            Step, function value and derivative at the best step so far.
        sty, fy, gy : float
            Step, function value and derivative at the other interval end.
        stp, fp, gp : float
            Step, function value and derivative at the current trial.
        stpmin, stpmax : float
            Bounds used to clip the new step in the unbracketed cases.

        Returns
        -------
        tuple
            ``(stx, sty, stp, gx, fx, gy, fy)`` with the updated interval and
            the new trial step (already capped by :meth:`determine_step`).
            Also sets ``self.case`` and may set ``self.bracket = True``.
        """
        # Sign of gp relative to gx. np.sign avoids a division by zero when
        # gx == 0 and gives the same branching as gx / abs(gx) otherwise.
        sign = gp * np.sign(gx)

        if fp > fx:
            # Case 1: higher function value, so the minimum is bracketed.
            # Take the cubic step if it is closer to stx than the quadratic
            # step, otherwise take the average of the two.
            self.case = 1
            theta = 3. * (fx - fp) / (stp - stx) + gx + gp
            s = max(abs(theta), abs(gx), abs(gp))
            gamma = s * np.sqrt((theta / s) ** 2. - (gx / s) * (gp / s))
            if stp < stx:
                gamma = -gamma
            p = (gamma - gx) + theta
            q = ((gamma - gx) + gamma) + gp
            r = p / q
            stpc = stx + r * (stp - stx)
            stpq = stx + ((gx / ((fx - fp) / (stp - stx) + gx)) / 2.) \
                * (stp - stx)
            if abs(stpc - stx) < abs(stpq - stx):
                stpf = stpc
            else:
                stpf = stpc + (stpq - stpc) / 2.
            self.bracket = True

        elif sign < 0:
            # Case 2: lower function value and derivatives of opposite sign,
            # so the minimum is bracketed. Take the cubic step if it is
            # farther from stp than the secant step, otherwise the secant step.
            self.case = 2
            theta = 3. * (fx - fp) / (stp - stx) + gx + gp
            s = max(abs(theta), abs(gx), abs(gp))
            gamma = s * np.sqrt((theta / s) ** 2 - (gx / s) * (gp / s))
            if stp > stx:
                gamma = -gamma
            p = (gamma - gp) + theta
            q = ((gamma - gp) + gamma) + gx
            r = p / q
            stpc = stp + r * (stx - stp)
            stpq = stp + (gp / (gp - gx)) * (stx - stp)
            if abs(stpc - stp) > abs(stpq - stp):
                stpf = stpc
            else:
                stpf = stpq
            self.bracket = True

        elif abs(gp) < abs(gx):
            # Case 3: lower function value, derivatives of the same sign,
            # and the magnitude of the derivative decreases. The cubic step
            # is used only if the cubic tends to infinity in the direction
            # of the step or its minimum is beyond stp. Otherwise the step
            # is set to stpmax/stpmin.
            self.case = 3
            theta = 3. * (fx - fp) / (stp - stx) + gx + gp
            s = max(abs(theta), abs(gx), abs(gp))
            # gamma == 0 only if the cubic does not tend to infinity in the
            # direction of the step.
            gamma = s * np.sqrt(max(0., (theta / s) ** 2 - (gx / s) * (gp / s)))
            if stp > stx:
                gamma = -gamma
            p = (gamma - gp) + theta
            q = (gamma + (gx - gp)) + gamma
            r = p / q
            if r < 0. and gamma != 0:
                stpc = stp + r * (stx - stp)
            elif stp > stx:
                stpc = stpmax
            else:
                stpc = stpmin
            stpq = stp + (gp / (gp - gx)) * (stx - stp)

            if self.bracket:
                # Bracketed: take the step closer to stp, but no more than
                # 66% of the way towards sty.
                if abs(stpc - stp) < abs(stpq - stp):
                    stpf = stpc
                else:
                    stpf = stpq
                if stp > stx:
                    stpf = min(stp + .66 * (sty - stp), stpf)
                else:
                    stpf = max(stp + .66 * (sty - stp), stpf)
            else:
                # Not bracketed: take the step farther from stp, clipped to
                # [stpmin, stpmax].
                if abs(stpc - stp) > abs(stpq - stp):
                    stpf = stpc
                else:
                    stpf = stpq
                stpf = min(stpmax, stpf)
                stpf = max(stpmin, stpf)

        else:
            # Case 4: lower function value, derivatives of the same sign, and
            # the magnitude of the derivative does not decrease. Unbracketed,
            # the step is stpmin or stpmax. Bracketed, take the cubic step
            # between stp and sty.
            self.case = 4
            if self.bracket:
                theta = 3. * (fp - fy) / (sty - stp) + gy + gp
                s = max(abs(theta), abs(gy), abs(gp))
                gamma = s * np.sqrt((theta / s) ** 2 - (gy / s) * (gp / s))
                if stp > sty:
                    gamma = -gamma
                p = (gamma - gp) + theta
                q = ((gamma - gp) + gamma) + gy
                r = p / q
                stpc = stp + r * (sty - stp)
                stpf = stpc
            elif stp > stx:
                stpf = stpmax
            else:
                stpf = stpmin

        # Update the interval which contains a minimizer.
        if fp > fx:
            sty = stp
            fy = fp
            gy = gp
        else:
            if sign < 0:
                sty = stx
                fy = fx
                gy = gx
            stx = stp
            fx = fp
            gx = gp

        # Compute the new step, capped to the allowed displacement.
        stp = self.determine_step(stpf)
        return stx, sty, stp, gx, fx, gy, fy

    def determine_step(self, stp):
        """Cap the change of step so no 3-vector block moves more than maxstep.

        The proposed displacement ``(stp - self.old_stp) * self.pk`` is
        viewed as rows of 3 components (``reshape(-1, 3)``). Presumably these
        are per-atom Cartesian components; confirm with callers. If the
        largest row norm reaches ``self.maxstep``, the step change is scaled
        down so that this norm equals ``self.maxstep``.
        """
        dr = stp - self.old_stp
        x = np.reshape(self.pk, (-1, 3))
        steplengths = ((dr * x) ** 2).sum(1) ** 0.5
        maxsteplength = steplengths.max()
        if maxsteplength >= self.maxstep:
            dr *= self.maxstep / maxsteplength
        return self.old_stp + dr

    def save(self, data):
        """Store the search state.

        ``data`` is ``(stage, ginit, gtest, gx, gy, finit, fx, fy, stx, sty,
        stmin, stmax, width, width1)``. The bracket flag and stage go to
        ``self.isave``; the remaining values go to ``self.dsave``.
        """
        self.isave[0] = 1 if self.bracket else 0
        self.isave[1] = data[0]
        self.dsave = data[1:]