Coverage for /builds/kinetik161/ase/ase/utils/linesearch.py: 83.68%
239 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
# flake8: noqa
import numpy as np

# Aliases for the built-in min and max.
pymin = min
pymax = max
class LineSearch:
    """Line search along a fixed direction with a capped per-row step length.

    The search is a small state machine driven by ``self.task``:
    ``'START'`` -> ``'FG'`` (evaluate function and gradient at the returned
    step) -> ``'CONVERGENCE'``, ``'WARNING: ...'`` or ``'ERROR: ...'``.
    ``task`` is never reset to ``'START'``, so one instance serves one search.

    NOTE(review): the task strings, variable names and case analysis look
    like a port of the MINPACK-2 ``dcsrch``/``dcstep`` routines -- confirm
    before relying on that correspondence.

    Parameters
    ----------
    xtol : float
        Relative width of the bracketing interval below which the search
        stops with a warning.
    """

    def __init__(self, xtol=1e-14):
        self.xtol = xtol
        self.task = 'START'
        # isave[0] holds the bracket flag, isave[1] the stage; see save().
        self.isave = np.zeros((2,), np.intc)
        # Thirteen floats of search state; save() rebinds this to a tuple.
        self.dsave = np.zeros((13,), float)
        self.fc = 0  # function evaluation counter
        self.gc = 0  # gradient evaluation counter
        self.case = 0  # branch (1-4) taken by the last update() call
        self.old_stp = 0  # step at which func/fprime were last evaluated

    def _line_search(self, func, myfprime, xk, pk, gfk, old_fval, old_old_fval,
                     maxstep=.2, c1=.23, c2=0.46, xtrapl=1.1, xtrapu=4.,
                     stpmax=50., stpmin=1e-8, args=()):
        """Search along ``pk`` starting from ``xk``.

        Parameters
        ----------
        func : callable
            Called as ``func(x, *args)``; returns the function value.
        myfprime : callable or tuple
            Either a gradient callable, called as ``myfprime(x, *args)``, or
            a tuple ``(approx, eps)`` whose first item is called as
            ``approx(x, func, eps, *args)``.
        xk, pk : array
            Start point and search direction.  ``pk`` must be reshapeable
            to ``(-1, 3)`` (see ``determine_step``).
        gfk : array
            Gradient at ``xk``.
        old_fval : float
            Function value at ``xk``.
        old_old_fval
            Unused; kept for interface compatibility.
        maxstep : float
            Cap on the displacement of any single row of ``pk`` between
            two consecutive evaluations.
        c1, c2 : float
            Sufficient-decrease and curvature parameters used by ``step``.
        xtrapl, xtrapu : float
            Lower/upper extrapolation factors used while the minimum is not
            bracketed.
        stpmax, stpmin : float
            Bounds on the step.
        args : tuple
            Extra arguments forwarded to ``func`` and the gradient callable.

        Returns
        -------
        tuple
            ``(stp, fval, old_fval, no_update)``.  ``stp`` is ``None`` when
            the search ended with an ``'ERROR'`` task.
        """
        self.stpmin = stpmin
        self.pk = pk
        self.stpmax = stpmax
        self.xtrapl = xtrapl
        self.xtrapu = xtrapu
        self.maxstep = maxstep
        self.dim = len(pk)
        self.gms = np.sqrt(self.dim) * maxstep
        self.no_update = False
        self.steps = []

        if isinstance(myfprime, tuple):
            # (approx, eps): the callable needs the function and the
            # finite-difference step ahead of the user arguments.  The
            # original left ``newargs`` unbound on this path, so the first
            # gradient evaluation raised UnboundLocalError.
            fprime = myfprime[0]
            newargs = (func, myfprime[1]) + tuple(args)
            gradient = False
        else:
            fprime = myfprime
            newargs = args
            gradient = True

        phi0 = old_fval
        derphi0 = np.dot(gfk, pk)
        alpha1 = 1.
        fval = old_fval

        while True:
            stp = self.step(alpha1, phi0, derphi0, c1, c2,
                            self.xtol,
                            self.isave, self.dsave)
            if self.task[:2] != 'FG':
                break

            alpha1 = stp
            fval = func(xk + stp * pk, *args)
            self.fc += 1
            gval = fprime(xk + stp * pk, *newargs)
            if gradient:
                self.gc += 1
            else:
                # Accounting for a finite-difference gradient.
                self.fc += len(xk) + 1
            phi0 = fval
            derphi0 = np.dot(gval, pk)
            self.old_stp = alpha1
            if self.no_update:
                break

        # NOTE(review): the original also tested ``task[1:4] == 'WARN'``,
        # which can never match ('WARNING'[1:4] is 'ARN').  Warnings have
        # therefore always returned the last step rather than None; that
        # behaviour is kept here -- confirm whether callers expect it.
        if self.task[:5] == 'ERROR':
            stp = None  # failed
        return stp, fval, old_fval, self.no_update

    def step(self, stp, f, g, c1, c2, xtol, isave, dsave):
        """Advance the state machine by one call and return the next step.

        ``f`` and ``g`` are the function value and directional derivative at
        ``stp``.  On return ``self.task`` tells the caller what to do:
        ``'FG'`` means evaluate at the returned step and call again; any
        other value ends the search.  ``isave`` and ``dsave`` are accepted
        for interface compatibility; the state is read from ``self.isave``
        and ``self.dsave``.
        """
        if self.task[:5] == 'START':
            # Check the input arguments for errors.  Every check runs, so
            # when several fail the message of the last one is reported.
            checks = (
                (stp < self.stpmin, 'ERROR: STP .LT. minstep'),
                (stp > self.stpmax, 'ERROR: STP .GT. maxstep'),
                (g >= 0, 'ERROR: INITIAL G >= 0'),
                (c1 < 0, 'ERROR: c1 .LT. 0'),
                (c2 < 0, 'ERROR: c2 .LT. 0'),
                (xtol < 0, 'ERROR: XTOL .LT. 0'),
                (self.stpmin < 0, 'ERROR: minstep .LT. 0'),
                (self.stpmax < self.stpmin, 'ERROR: maxstep .LT. minstep'),
            )
            for failed, message in checks:
                if failed:
                    self.task = message
            if self.task[:5] == 'ERROR':
                return stp

            # Initialize local variables.
            self.bracket = False
            stage = 1
            finit = f
            ginit = g
            gtest = c1 * ginit
            width = self.stpmax - self.stpmin
            width1 = width / .5

            # The variables stx, fx, gx contain the values of the step,
            # function, and derivative at the best step.
            # The variables sty, fy, gy contain the values of the step,
            # function, and derivative at sty.
            # The variables stp, f, g contain the values of the step,
            # function, and derivative at stp.
            stx = 0
            fx = finit
            gx = ginit
            sty = 0
            fy = finit
            gy = ginit
            stmin = 0
            stmax = stp + self.xtrapu * stp
            self.task = 'FG'
            self.save((stage, ginit, gtest, gx,
                       gy, finit, fx, fy, stx, sty,
                       stmin, stmax, width, width1))
            return self.determine_step(stp)

        # Restore the state stored by the previous call.
        self.bracket = bool(self.isave[0] == 1)
        stage = self.isave[1]
        (ginit, gtest, gx, gy, finit, fx, fy, stx, sty, stmin, stmax,
         width, width1) = self.dsave

        # If psi(stp) <= 0 and f'(stp) >= 0 for some step, then the
        # algorithm enters the second stage.
        ftest = finit + stp * gtest
        if stage == 1 and f < ftest and g >= 0.:
            stage = 2

        # Test for warnings.  Later tests overwrite earlier ones.
        if self.bracket and (stp <= stmin or stp >= stmax):
            self.task = 'WARNING: ROUNDING ERRORS PREVENT PROGRESS'
        if self.bracket and stmax - stmin <= self.xtol * stmax:
            self.task = 'WARNING: XTOL TEST SATISFIED'
        if stp == self.stpmax and f <= ftest and g <= gtest:
            self.task = 'WARNING: STP = maxstep'
        if stp == self.stpmin and (f > ftest or g >= gtest):
            self.task = 'WARNING: STP = minstep'

        # Test for convergence (takes precedence over any warning).
        if f <= ftest and abs(g) <= c2 * (- ginit):
            self.task = 'CONVERGENCE'

        # Test for termination.
        if self.task[:4] in ('WARN', 'CONV'):
            self.save((stage, ginit, gtest, gx,
                       gy, finit, fx, fy, stx, sty,
                       stmin, stmax, width, width1))
            return stp

        # The variant that predicts the stage-1 step from a modified
        # function (f - stp * gtest) was commented out in the original;
        # the update below always works on f and g directly.
        stx, sty, stp, gx, fx, gy, fy = self.update(stx, fx, gx, sty,
                                                    fy, gy, stp, f, g,
                                                    stmin, stmax)

        # Decide if a bisection step is needed.
        if self.bracket:
            if abs(sty - stx) >= .66 * width1:
                stp = stx + .5 * (sty - stx)
            width1 = width
            width = abs(sty - stx)

        # Set the minimum and maximum steps allowed for stp.
        if self.bracket:
            stmin = min(stx, sty)
            stmax = max(stx, sty)
        else:
            stmin = stp + self.xtrapl * (stp - stx)
            stmax = stp + self.xtrapu * (stp - stx)

        # Force the step to be within the bounds maxstep and minstep.
        stp = max(stp, self.stpmin)
        stp = min(stp, self.stpmax)

        # Stuck at the upper bound: tell _line_search to stop after the
        # next evaluation.
        if (stx == stp and stp == self.stpmax and stmin > self.stpmax):
            self.no_update = True

        # If further progress is not possible, let stp be the best
        # point obtained during the search.
        # NOTE(review): the parentheses spell out how Python already parsed
        # the original ``bracket and stp < stmin or stp >= stmax``; the
        # ``stp >= stmax`` test is therefore not guarded by ``bracket``.
        # Confirm that this is intended.
        if ((self.bracket and stp < stmin) or stp >= stmax) \
           or (self.bracket and stmax - stmin < self.xtol * stmax):
            stp = stx

        # Obtain another function and derivative.
        self.task = 'FG'
        self.save((stage, ginit, gtest, gx,
                   gy, finit, fx, fy, stx, sty,
                   stmin, stmax, width, width1))
        return stp

    @staticmethod
    def _cubic_terms(df, span, g_a, g_b, clamp=False):
        """Return ``(theta, gamma)`` of the cubic interpolation formulas.

        ``gamma`` is returned non-negative (or NaN for a negative radicand
        when ``clamp`` is false); the caller applies the sign.
        """
        theta = 3. * df / span + g_a + g_b
        s = max(abs(theta), abs(g_a), abs(g_b))
        radicand = (theta / s) ** 2 - (g_a / s) * (g_b / s)
        if clamp:
            radicand = max(0., radicand)
        return theta, s * np.sqrt(radicand)

    def update(self, stx, fx, gx, sty, fy, gy, stp, fp, gp,
               stpmin, stpmax):
        """Compute a trial step and update the interval ``[stx, sty]``.

        ``(stx, fx, gx)`` is the best point so far, ``(sty, fy, gy)`` the
        other interval end point and ``(stp, fp, gp)`` the point just
        evaluated.  ``stpmin``/``stpmax`` bound the trial step while the
        minimum is not bracketed.  Sets ``self.case`` and may set
        ``self.bracket``.

        Returns ``(stx, sty, stp, gx, fx, gy, fy)`` with ``stp`` already
        limited by ``determine_step``.
        """
        # Negative when gp and gx have opposite signs.  np.sign avoids the
        # division by zero of ``gx / abs(gx)`` when gx is exactly zero.
        sign = gp * np.sign(gx)

        # First case: A higher function value. The minimum is bracketed.
        # If the cubic step is closer to stx than the quadratic step, the
        # cubic step is taken, otherwise the average of the cubic and
        # quadratic steps is taken.
        if fp > fx:  # case1
            self.case = 1
            theta, gamma = self._cubic_terms(fx - fp, stp - stx, gx, gp)
            if stp < stx:
                gamma = -gamma
            p = (gamma - gx) + theta
            q = ((gamma - gx) + gamma) + gp
            r = p / q
            stpc = stx + r * (stp - stx)
            stpq = stx + ((gx / ((fx - fp) / (stp - stx) + gx)) / 2.) \
                * (stp - stx)
            if abs(stpc - stx) < abs(stpq - stx):
                stpf = stpc
            else:
                stpf = stpc + (stpq - stpc) / 2.
            self.bracket = True

        # Second case: A lower function value and derivatives of opposite
        # sign. The minimum is bracketed. If the cubic step is farther from
        # stp than the secant step, the cubic step is taken, otherwise the
        # secant step is taken.
        elif sign < 0:  # case2
            self.case = 2
            theta, gamma = self._cubic_terms(fx - fp, stp - stx, gx, gp)
            if stp > stx:
                gamma = -gamma
            p = (gamma - gp) + theta
            q = ((gamma - gp) + gamma) + gx
            r = p / q
            stpc = stp + r * (stx - stp)
            stpq = stp + (gp / (gp - gx)) * (stx - stp)
            if abs(stpc - stp) > abs(stpq - stp):
                stpf = stpc
            else:
                stpf = stpq
            self.bracket = True

        # Third case: A lower function value, derivatives of the same sign,
        # and the magnitude of the derivative decreases.
        elif abs(gp) < abs(gx):  # case3
            self.case = 3
            # The cubic step is computed only if the cubic tends to infinity
            # in the direction of the step or if the minimum of the cubic
            # is beyond stp. Otherwise the cubic step is defined to be the
            # secant step.
            # The case gamma = 0 only arises if the cubic does not tend
            # to infinity in the direction of the step.
            theta, gamma = self._cubic_terms(fx - fp, stp - stx, gx, gp,
                                             clamp=True)
            if stp > stx:
                gamma = -gamma
            p = (gamma - gp) + theta
            q = (gamma + (gx - gp)) + gamma
            r = p / q
            if r < 0. and gamma != 0:
                stpc = stp + r * (stx - stp)
            elif stp > stx:
                stpc = stpmax
            else:
                stpc = stpmin
            stpq = stp + (gp / (gp - gx)) * (stx - stp)

            if self.bracket:
                # A minimizer has been bracketed. If the cubic step is
                # closer to stp than the secant step, the cubic step is
                # taken, otherwise the secant step is taken.
                if abs(stpc - stp) < abs(stpq - stp):
                    stpf = stpc
                else:
                    stpf = stpq
                if stp > stx:
                    stpf = min(stp + .66 * (sty - stp), stpf)
                else:
                    stpf = max(stp + .66 * (sty - stp), stpf)
            else:
                # A minimizer has not been bracketed. If the cubic step is
                # farther from stp than the secant step, the cubic step is
                # taken, otherwise the secant step is taken.
                if abs(stpc - stp) > abs(stpq - stp):
                    stpf = stpc
                else:
                    stpf = stpq
                stpf = min(stpmax, stpf)
                stpf = max(stpmin, stpf)

        # Fourth case: A lower function value, derivatives of the same sign,
        # and the magnitude of the derivative does not decrease. If the
        # minimum is not bracketed, the step is either minstep or maxstep,
        # otherwise the cubic step is taken.
        else:  # case4
            self.case = 4
            if self.bracket:
                theta, gamma = self._cubic_terms(fp - fy, sty - stp, gy, gp)
                if stp > sty:
                    gamma = -gamma
                p = (gamma - gp) + theta
                q = ((gamma - gp) + gamma) + gy
                r = p / q
                stpf = stp + r * (sty - stp)
            elif stp > stx:
                stpf = stpmax
            else:
                stpf = stpmin

        # Update the interval which contains a minimizer.
        if fp > fx:
            sty = stp
            fy = fp
            gy = gp
        else:
            if sign < 0:
                sty = stx
                fy = fx
                gy = gx
            stx = stp
            fx = fp
            gx = gp

        # Compute the new step.
        stp = self.determine_step(stpf)

        return stx, sty, stp, gx, fx, gy, fy

    def determine_step(self, stp):
        """Limit the move from ``self.old_stp`` to ``stp``.

        ``self.pk`` is viewed as rows of three components.  If the longest
        row displacement caused by the step change reaches ``self.maxstep``,
        the change is scaled down so that it equals ``self.maxstep``.
        """
        dr = stp - self.old_stp
        rows = np.reshape(self.pk, (-1, 3))
        steplengths = ((dr * rows) ** 2).sum(1) ** 0.5
        maxsteplength = steplengths.max()
        if maxsteplength >= self.maxstep:
            dr *= self.maxstep / maxsteplength
        return self.old_stp + dr

    def save(self, data):
        """Store the search state between calls to ``step``.

        ``data[0]`` is the stage; the remaining items replace ``dsave``.
        The current bracket flag is stored alongside in ``isave[0]``.
        """
        self.isave[0] = int(self.bracket)
        self.isave[1] = data[0]
        self.dsave = data[1:]