Coverage for /builds/kinetik161/ase/ase/ga/data.py: 78.22%
225 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
1"""
2 Objects which handle all communication with the SQLite database.
3"""
4import os
6import ase.db
7from ase import Atoms
8from ase.ga import get_raw_score, set_neighbor_list, set_parametrization
def split_description(desc):
    """Split a ``'<type>:<text>'`` description string into its two parts.

    Parameters:

    desc: str
        Description containing exactly one ``':'`` separator.

    Returns the tuple ``(type, text)``.

    Raises AssertionError (carrying ``desc``) if the string does not
    split into exactly two parts.
    """
    parts = desc.split(':')
    # Raise explicitly rather than with an ``assert`` statement so that the
    # check is not stripped under ``python -O``; the exception type and its
    # payload are unchanged.
    if len(parts) != 2:
        raise AssertionError(desc)
    return parts[0], parts[1]
18def test_raw_score(atoms):
19 """Test that raw_score can be extracted."""
20 err_msg = "raw_score not put in atoms.info['key_value_pairs']"
21 assert 'raw_score' in atoms.info['key_value_pairs'], err_msg
class DataConnection:
    """Class that handles all database communication.

    All data communication is collected in this class in order to
    make a decoupling of the data representation and the GA method.

    A new candidate must be added with one of the functions
    add_unrelaxed_candidate or add_relaxed_candidate; this will correctly
    initialize a configuration id used to keep track of candidates in the
    database.
    After one of the add_*_candidate functions has been used, if the
    candidate is further modified or relaxed the functions
    add_unrelaxed_step or add_relaxed_step must be used. This way the
    configuration id carries through correctly.

    Parameters:

    db_file_name: Path to the ase.db data file.
    """

    def __init__(self, db_file_name):
        """Connect to an existing database file.

        Raises OSError if ``db_file_name`` is not an existing file.
        """
        self.db_file_name = db_file_name
        if not os.path.isfile(self.db_file_name):
            raise OSError(f'DB file {self.db_file_name} not found')
        self.c = ase.db.connect(self.db_file_name)
        # gaids already handed out by get_all_relaxed_candidates; consulted
        # when that method is called with only_new=True.
        self.already_returned = set()

    def get_number_of_unrelaxed_candidates(self):
        """ Returns the number of candidates not yet queued or relaxed. """
        return len(self.__get_ids_of_all_unrelaxed_candidates__())

    def _load_unrelaxed(self, confid):
        """Return the latest structure for ``confid`` with its ``confid``
        stored in ``info`` and a ``data`` dict guaranteed to exist."""
        a = self.__get_latest_traj_for_confid__(confid)
        a.info['confid'] = confid
        if 'data' not in a.info:
            a.info['data'] = {}
        return a

    def get_an_unrelaxed_candidate(self):
        """ Returns a candidate ready for relaxation.

        Raises ValueError if no unrelaxed candidate is available.
        """
        to_get = self.__get_ids_of_all_unrelaxed_candidates__()
        if not to_get:
            raise ValueError('No unrelaxed candidate to return')
        return self._load_unrelaxed(to_get[0])

    def get_all_unrelaxed_candidates(self):
        """Return all unrelaxed candidates,
        useful if they can all be evaluated quickly."""
        to_get = self.__get_ids_of_all_unrelaxed_candidates__()
        return [self._load_unrelaxed(confid) for confid in to_get]

    def __get_ids_of_all_unrelaxed_candidates__(self):
        """ Helper method returning the gaids that have an unrelaxed row
        but neither a relaxed nor a queued row. """
        all_unrelaxed_ids = {t.gaid for t in self.c.select(relaxed=0)}
        all_relaxed_ids = {t.gaid for t in self.c.select(relaxed=1)}
        all_queued_ids = {t.gaid for t in self.c.select(queued=1)}

        taken = all_relaxed_ids | all_queued_ids
        return [gaid for gaid in all_unrelaxed_ids if gaid not in taken]

    def __get_latest_traj_for_confid__(self, confid):
        """ Method for obtaining the latest traj
        file for a given configuration.
        There can be several traj files for
        one configuration if it has undergone
        several changes (mutations, pairings, etc.)."""
        # Stable sort on modification time; the last row is the newest one.
        allcands = sorted(self.c.select(gaid=confid), key=lambda x: x.mtime)
        return self.get_atoms(allcands[-1].id)

    def mark_as_queued(self, a):
        """ Marks a configuration as queued for relaxation.

        Writes a row without a structure that carries the candidate's
        ``confid`` as gaid together with ``queued=1``.
        """
        gaid = a.info['confid']
        self.c.write(None, gaid=gaid, queued=1,
                     key_value_pairs=a.info['key_value_pairs'])

    def _finalize_relaxed(self, a, find_neighbors, perform_parametrization):
        """Fill in the generation (if absent) and the optional neighbor
        list / parametrization of a relaxed candidate before writing it."""
        key_value_pairs = a.info['key_value_pairs']
        if 'generation' not in key_value_pairs:
            key_value_pairs['generation'] = self.get_generation_number()

        if find_neighbors is not None:
            set_neighbor_list(a, find_neighbors(a))
        if perform_parametrization is not None:
            set_parametrization(a, perform_parametrization(a))

    def add_relaxed_step(self, a, find_neighbors=None,
                         perform_parametrization=None):
        """After a candidate is relaxed it must be marked
        as such. Use this function if the candidate has already been in the
        database in an unrelaxed version, i.e. add_unrelaxed_candidate has
        been used.

        Neighbor list and parametrization parameters to screen
        candidates before relaxation can be added. Default is not to use.

        The id of the written row is stored in ``a.info['relax_id']``.
        """
        # Same raw_score validation as add_relaxed_candidate.
        test_raw_score(a)

        # confid has already been set in add_unrelaxed_candidate
        gaid = a.info['confid']

        self._finalize_relaxed(a, find_neighbors, perform_parametrization)

        relax_id = self.c.write(a, relaxed=1, gaid=gaid,
                                key_value_pairs=a.info['key_value_pairs'],
                                data=a.info['data'])
        a.info['relax_id'] = relax_id

    def add_relaxed_candidate(self, a, find_neighbors=None,
                              perform_parametrization=None):
        """After a candidate is relaxed it must be marked
        as such. Use this function if the candidate has *not* been in the
        database in an unrelaxed version, i.e. add_unrelaxed_candidate has
        *not* been used.

        Neighbor list and parametrization parameters to screen
        candidates before relaxation can be added. Default is not to use.

        The id of the written row becomes both ``a.info['confid']`` and
        ``a.info['relax_id']``.
        """
        test_raw_score(a)

        self._finalize_relaxed(a, find_neighbors, perform_parametrization)

        relax_id = self.c.write(a, relaxed=1,
                                key_value_pairs=a.info['key_value_pairs'],
                                data=a.info['data'])
        # A brand-new candidate uses its own row id as configuration id.
        self.c.update(relax_id, gaid=relax_id)
        a.info['confid'] = relax_id
        a.info['relax_id'] = relax_id

    def add_more_relaxed_steps(self, a_list):
        """Deprecated alias of add_more_relaxed_candidates."""
        # This function will be removed soon as the function name indicates
        # that unrelaxed candidates are added beforehand
        print('Please use add_more_relaxed_candidates instead')
        self.add_more_relaxed_candidates(a_list)

    def add_more_relaxed_candidates(self, a_list):
        """Add more relaxed candidates quickly.

        All candidates in ``a_list`` are written inside a single
        ``with self.c`` block. A missing raw_score is only reported with a
        printed message; the candidate is still written.
        """
        for a in a_list:
            try:
                a.info['key_value_pairs']['raw_score']
            except KeyError:
                print("raw_score not put in atoms.info['key_value_pairs']")

        g = self.get_generation_number()

        # Insert gaid by getting the next available id and assuming that the
        # entire a_list will be written without interruption
        next_id = self.get_next_id()
        with self.c as con:
            for j, a in enumerate(a_list):
                if 'generation' not in a.info['key_value_pairs']:
                    a.info['key_value_pairs']['generation'] = g

                gaid = next_id + j
                relax_id = con.write(a, relaxed=1, gaid=gaid,
                                     key_value_pairs=a.info['key_value_pairs'],
                                     data=a.info['data'])
                # Internal invariant: the predicted id must match the row id.
                assert gaid == relax_id
                a.info['confid'] = relax_id
                a.info['relax_id'] = relax_id

    def get_next_id(self):
        """Get the id of the next candidate to be added to the database.
        This is a hacky way of obtaining the id and it only works on a
        sqlite database.
        """
        con = self.c._connect()
        try:
            last_id = self.c.get_last_id(con.cursor())
        finally:
            # Close the raw connection even if the lookup fails.
            con.close()
        return last_id + 1

    def get_largest_in_db(self, var):
        """Return the value of ``var`` from the first row when sorting by
        ``var`` in descending order."""
        return next(self.c.select(sort=f'-{var}')).get(var)

    def add_unrelaxed_candidate(self, candidate, description):
        """ Adds a new candidate which needs to be relaxed.

        ``description`` has the form ``'<type>:<text>'``; the type becomes
        a key set to 1 and the text is stored under ``description``. The id
        of the written row is stored in ``candidate.info['confid']``.
        """
        t, desc = split_description(description)
        kwargs = {'relaxed': 0,
                  'extinct': 0,
                  t: 1,
                  'description': desc}

        if 'generation' not in candidate.info['key_value_pairs']:
            kwargs['generation'] = self.get_generation_number()

        gaid = self.c.write(candidate,
                            key_value_pairs=candidate.info['key_value_pairs'],
                            data=candidate.info['data'],
                            **kwargs)
        self.c.update(gaid, gaid=gaid)
        candidate.info['confid'] = gaid

    def add_unrelaxed_step(self, candidate, description):
        """ Add a change to a candidate without it having been relaxed.
        This method is typically used when a
        candidate has been mutated. """

        # confid has already been set by add_unrelaxed_candidate
        gaid = candidate.info['confid']

        t, desc = split_description(description)
        kwargs = {'relaxed': 0,
                  'extinct': 0,
                  t: 1,
                  'description': desc,
                  'gaid': gaid}

        self.c.write(candidate,
                     key_value_pairs=candidate.info['key_value_pairs'],
                     data=candidate.info['data'],
                     **kwargs)

    def get_number_of_atoms_to_optimize(self):
        """ Get the number of atoms being optimized. """
        return len(self.get_atom_numbers_to_optimize())

    def get_atom_numbers_to_optimize(self):
        """ Get the list of atom numbers being optimized. """
        v = self.c.get(simulation_cell=True)
        return v.data.stoichiometry

    def get_slab(self):
        """ Get the super cell, including stationary atoms, in which
        the structure is being optimized. """
        return self.c.get_atoms(simulation_cell=True)

    def get_participation_in_pairing(self):
        """ Get information about how many direct
        offsprings each candidate has, and which specific
        pairings have been made. This information is used
        for the extended fitness calculation described in
        L.B. Vilhelmsen et al., JACS, 2012, 134 (30), pp 12807-12816

        Returns the tuple ``(frequency, pairs)``: ``frequency`` maps a
        parent id to the number of pairings it took part in and ``pairs``
        lists each pairing as a sorted tuple of the two parent ids.
        """
        frequency = {}
        pairs = []
        for e in self.c.select(pairing=1):
            c1, c2 = e.data['parents']
            pairs.append(tuple(sorted([c1, c2])))
            for parent in (c1, c2):
                frequency[parent] = frequency.get(parent, 0) + 1
        return (frequency, pairs)

    def _relaxed_atoms_from_row(self, row):
        """Load the structure of ``row`` and attach its ``confid`` (gaid)
        and ``relax_id`` (row id) to ``info``."""
        t = self.get_atoms(id=row.id)
        t.info['confid'] = row.gaid
        t.info['relax_id'] = row.id
        return t

    def get_all_relaxed_candidates(self, only_new=False, use_extinct=False):
        """ Returns all candidates that have been relaxed.

        Parameters:

        only_new: boolean (optional)
            Used to specify only to get candidates relaxed since last
            time this function was invoked. Default: False.

        use_extinct: boolean (optional)
            Set to True if the extinct key (and mass extinction) is going
            to be used. Default: False."""

        query = 'relaxed=1,extinct=0' if use_extinct else 'relaxed=1'
        entries = self.c.select(query, sort='-raw_score')

        trajs = []
        for v in entries:
            if only_new and v.gaid in self.already_returned:
                continue
            trajs.append(self._relaxed_atoms_from_row(v))
            self.already_returned.add(v.gaid)
        return trajs

    def get_all_relaxed_candidates_after_generation(self, gen):
        """ Returns all candidates that have been relaxed up to
        and including the specified generation, sorted by descending
        raw score.
        """
        entries = self.c.select(f'relaxed=1,extinct=0,generation<={gen}')

        trajs = [self._relaxed_atoms_from_row(v) for v in entries]
        trajs.sort(key=get_raw_score, reverse=True)
        return trajs

    def get_all_candidates_in_queue(self):
        """ Returns all structures that are queued, but have not yet
        been relaxed. """
        all_queued_ids = [t.gaid for t in self.c.select(queued=1)]
        # A set keeps the membership test below O(1) per queued id.
        all_relaxed_ids = {t.gaid for t in self.c.select(relaxed=1)}

        return [qid for qid in all_queued_ids if qid not in all_relaxed_ids]

    def remove_from_queue(self, confid):
        """ Removes the candidate confid from the queue. """
        ids = [q.id for q in self.c.select(queued=1, gaid=confid)]
        self.c.delete(ids)

    def get_generation_number(self, size=None):
        """ Returns the current generation number, by looking
        at the number of relaxed individuals and comparing
        this number to the supplied size or population size.

        If all individuals in generation 3 have been relaxed
        it will return 4 if not all in generation 4 have been
        relaxed.

        Returns 0 when no size is given and no ``population_size``
        parameter is stored. For a non-positive size nothing is counted
        and None is returned.
        """
        if size is None:
            size = self.get_param('population_size')
        if size is None:
            return 0
        if size <= 0:
            return None

        # Count the relaxed rows per generation in a single pass.
        per_generation = {}
        for row in self.c.select(relaxed=1):
            generation = row.generation
            per_generation[generation] = per_generation.get(generation, 0) + 1

        # The current generation is the first one that is not yet full.
        g = 0
        while per_generation.get(g, 0) >= size:
            g += 1
        return g

    def get_atoms(self, id, add_info=True):
        """Return the atoms object with the specified id"""
        return self.c.get_atoms(id, add_additional_information=add_info)

    def get_param(self, parameter):
        """ Get a parameter saved when creating the database.

        The parameter is read from the data of the row with id 1; None is
        returned if that row has no data or the parameter is not present.
        """
        # Fetch the row once and reuse it for both the check and the lookup.
        row = self.c.get(1)
        if row.get('data'):
            return row.data.get(parameter, None)
        return None

    def remove_old_queued(self):
        """Placeholder; currently does nothing."""
        pass

    def is_duplicate(self, **kwargs):
        """Check if the key-value pair is already present in the database"""
        return len(list(self.c.select(**kwargs))) > 0

    def kill_candidate(self, confid):
        """Sets extinct=1 in the key_value_pairs of the candidate
        with gaid=confid. This could be used in the
        mass extinction operator."""
        # Collect the row ids first so that no row is updated while the
        # select is still being iterated.
        ids = [dct.id for dct in self.c.select(gaid=confid)]
        for row_id in ids:
            self.c.update(row_id, extinct=1)
class PrepareDB:
    """ Helper that creates and seeds a new GA database.

    It is meant to be used once, when a run is set up.

    Parameters:

    db_file_name: Database file to use; it must not exist yet.

    simulation_cell: structure written as the first row, flagged with
        ``simulation_cell=True``. An empty ``Atoms`` object is used when
        omitted.

    Any further keyword arguments are stored in the ``data`` of that
    first row.
    """

    def __init__(self, db_file_name, simulation_cell=None, **kwargs):
        if os.path.exists(db_file_name):
            full_path = os.path.abspath(db_file_name)
            raise OSError(f'DB file {full_path} already exists')
        self.db_file_name = db_file_name

        cell = Atoms() if simulation_cell is None else simulation_cell

        self.c = ase.db.connect(self.db_file_name)

        # The run parameters go into ``data`` rather than key-value pairs,
        # because the database is never searched for them.
        self.c.write(cell, data=dict(kwargs), simulation_cell=True)

    def add_unrelaxed_candidate(self, candidate, **kwargs):
        """ Write an unrelaxed starting candidate (generation 0).

        The id of the new row is used as the candidate's gaid and stored
        in ``candidate.info['confid']``.
        """
        row_id = self.c.write(candidate,
                              origin='StartingCandidateUnrelaxed',
                              relaxed=0,
                              generation=0,
                              extinct=0,
                              **kwargs)
        self.c.update(row_id, gaid=row_id)
        candidate.info['confid'] = row_id

    def add_relaxed_candidate(self, candidate, **kwargs):
        """ Write a relaxed starting candidate (generation 0).

        The candidate must carry a raw score. The id of the new row is
        used as the candidate's gaid and stored in
        ``candidate.info['confid']``.
        """
        test_raw_score(candidate)

        # Candidates without attached data are written with an empty dict.
        payload = candidate.info.get('data', {})

        row_id = self.c.write(candidate,
                              origin='StartingCandidateRelaxed',
                              relaxed=1,
                              generation=0,
                              extinct=0,
                              key_value_pairs=candidate.info['key_value_pairs'],
                              data=payload,
                              **kwargs)
        self.c.update(row_id, gaid=row_id)
        candidate.info['confid'] = row_id