Coverage for /builds/kinetik161/ase/ase/ga/data.py: 78.22%

225 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-12-10 11:04 +0000

1""" 

2 Objects which handle all communication with the SQLite database. 

3""" 

4import os 

5 

6import ase.db 

7from ase import Atoms 

8from ase.ga import get_raw_score, set_neighbor_list, set_parametrization 

9 

10 

def split_description(desc):
    """Split a ``'<kind>:<text>'`` description into ``(kind, text)``.

    Exactly one ``':'`` separator is required; otherwise an
    AssertionError carrying the offending string is raised.
    """
    pieces = desc.split(':')
    assert len(pieces) == 2, desc
    kind, text = pieces
    return kind, text

16 

17 

18def test_raw_score(atoms): 

19 """Test that raw_score can be extracted.""" 

20 err_msg = "raw_score not put in atoms.info['key_value_pairs']" 

21 assert 'raw_score' in atoms.info['key_value_pairs'], err_msg 

22 

23 

class DataConnection:
    """Class that handles all database communication.

    All data communication is collected in this class in order to
    make a decoupling of the data representation and the GA method.

    A new candidate must be added with one of the functions
    add_unrelaxed_candidate or add_relaxed_candidate; this will correctly
    initialize a configuration id used to keep track of candidates in the
    database.
    After one of the add_*_candidate functions has been used, if the
    candidate is further modified or relaxed the functions
    add_unrelaxed_step or add_relaxed_step must be used. This way the
    configuration id carries through correctly.

    Parameters:

    db_file_name: Path to the ase.db data file.
    """

    def __init__(self, db_file_name):
        self.db_file_name = db_file_name
        if not os.path.isfile(self.db_file_name):
            raise OSError(f'DB file {self.db_file_name} not found')
        self.c = ase.db.connect(self.db_file_name)
        # gaids already handed out by get_all_relaxed_candidates; consulted
        # when that method is called with only_new=True.
        self.already_returned = set()

    def get_number_of_unrelaxed_candidates(self):
        """ Returns the number of candidates not yet queued or relaxed. """
        return len(self.__get_ids_of_all_unrelaxed_candidates__())

    def get_an_unrelaxed_candidate(self):
        """Return one candidate ready for relaxation.

        Raises:
            ValueError: if no candidate is waiting (every gaid is already
                queued or relaxed).
        """
        to_get = self.__get_ids_of_all_unrelaxed_candidates__()
        if not to_get:
            raise ValueError('No unrelaxed candidate to return')
        return self._get_unrelaxed_atoms(to_get[0])

    def get_all_unrelaxed_candidates(self):
        """Return all unrelaxed candidates (an empty list if none),
        useful if they can all be evaluated quickly."""
        return [self._get_unrelaxed_atoms(confid)
                for confid in self.__get_ids_of_all_unrelaxed_candidates__()]

    def _get_unrelaxed_atoms(self, confid):
        """Latest Atoms for ``confid`` with ``info['confid']`` set and
        ``info['data']`` guaranteed to exist."""
        a = self.__get_latest_traj_for_confid__(confid)
        a.info['confid'] = confid
        if 'data' not in a.info:
            a.info['data'] = {}
        return a

    def __get_ids_of_all_unrelaxed_candidates__(self):
        """Return the gaids that have an unrelaxed row but no relaxed or
        queued row. Built from a set, so the order is unspecified."""
        all_unrelaxed_ids = {t.gaid for t in self.c.select(relaxed=0)}
        all_relaxed_ids = {t.gaid for t in self.c.select(relaxed=1)}
        all_queued_ids = {t.gaid for t in self.c.select(queued=1)}

        return [gaid for gaid in all_unrelaxed_ids
                if gaid not in all_relaxed_ids
                and gaid not in all_queued_ids]

    def __get_latest_traj_for_confid__(self, confid):
        """ Method for obtaining the latest traj
        file for a given configuration.
        There can be several traj files for
        one configuration if it has undergone
        several changes (mutations, pairings, etc.).
        The row with the largest ``mtime`` is returned; on ties the
        stable sort keeps the one selected last."""
        allcands = list(self.c.select(gaid=confid))
        allcands.sort(key=lambda x: x.mtime)
        return self.get_atoms(allcands[-1].id)

    def mark_as_queued(self, a):
        """Mark a configuration as queued for relaxation.

        Writes a new row (no atoms are passed) carrying the candidate's
        gaid, ``queued=1`` and its key-value pairs.
        """
        gaid = a.info['confid']
        self.c.write(None, gaid=gaid, queued=1,
                     key_value_pairs=a.info['key_value_pairs'])

    def add_relaxed_step(self, a, find_neighbors=None,
                         perform_parametrization=None):
        """After a candidate is relaxed it must be marked
        as such. Use this function if the candidate has already been in the
        database in an unrelaxed version, i.e. add_unrelaxed_candidate has
        been used.

        Neighbor list and parametrization parameters to screen
        candidates before relaxation can be added. Default is not to use.

        Sets ``a.info['relax_id']`` to the id of the written row.
        """
        # Same raw_score check (and message) as add_relaxed_candidate.
        test_raw_score(a)

        # confid has already been set in add_unrelaxed_candidate
        gaid = a.info['confid']

        if 'generation' not in a.info['key_value_pairs']:
            g = self.get_generation_number()
            a.info['key_value_pairs']['generation'] = g

        if find_neighbors is not None:
            set_neighbor_list(a, find_neighbors(a))
        if perform_parametrization is not None:
            set_parametrization(a, perform_parametrization(a))

        relax_id = self.c.write(a, relaxed=1, gaid=gaid,
                                key_value_pairs=a.info['key_value_pairs'],
                                data=a.info['data'])
        a.info['relax_id'] = relax_id

    def add_relaxed_candidate(self, a, find_neighbors=None,
                              perform_parametrization=None):
        """After a candidate is relaxed it must be marked
        as such. Use this function if the candidate has *not* been in the
        database in an unrelaxed version, i.e. add_unrelaxed_candidate has
        *not* been used.

        Neighbor list and parametrization parameters to screen
        candidates before relaxation can be added. Default is not to use.

        The new row id becomes the gaid, ``a.info['confid']`` and
        ``a.info['relax_id']``.
        """
        test_raw_score(a)

        if 'generation' not in a.info['key_value_pairs']:
            g = self.get_generation_number()
            a.info['key_value_pairs']['generation'] = g

        if find_neighbors is not None:
            set_neighbor_list(a, find_neighbors(a))
        if perform_parametrization is not None:
            set_parametrization(a, perform_parametrization(a))

        relax_id = self.c.write(a, relaxed=1,
                                key_value_pairs=a.info['key_value_pairs'],
                                data=a.info['data'])
        self.c.update(relax_id, gaid=relax_id)
        a.info['confid'] = relax_id
        a.info['relax_id'] = relax_id

    def add_more_relaxed_steps(self, a_list):
        """Deprecated alias of add_more_relaxed_candidates."""
        # This function will be removed soon as the function name indicates
        # that unrelaxed candidates are added beforehand
        print('Please use add_more_relaxed_candidates instead')
        self.add_more_relaxed_candidates(a_list)

    def add_more_relaxed_candidates(self, a_list):
        """Add more relaxed candidates quickly.

        All candidates are written inside one ``with self.c`` block. A
        missing raw_score only prints a warning here; it does not raise.
        """
        for a in a_list:
            try:
                a.info['key_value_pairs']['raw_score']
            except KeyError:
                print("raw_score not put in atoms.info['key_value_pairs']")

        g = self.get_generation_number()

        # Insert gaid by getting the next available id and assuming that the
        # entire a_list will be written without interruption
        next_id = self.get_next_id()
        with self.c as con:
            for j, a in enumerate(a_list):
                if 'generation' not in a.info['key_value_pairs']:
                    a.info['key_value_pairs']['generation'] = g

                gaid = next_id + j
                relax_id = con.write(a, relaxed=1, gaid=gaid,
                                     key_value_pairs=a.info['key_value_pairs'],
                                     data=a.info['data'])
                assert gaid == relax_id
                a.info['confid'] = relax_id
                a.info['relax_id'] = relax_id

    def get_next_id(self):
        """Get the id of the next candidate to be added to the database.
        This is a hacky way of obtaining the id and it only works on a
        sqlite database.
        """
        con = self.c._connect()
        try:
            last_id = self.c.get_last_id(con.cursor())
        finally:
            # Close the connection even if the id lookup fails.
            con.close()
        return last_id + 1

    def get_largest_in_db(self, var):
        """Return the value of ``var`` on the row sorted first by
        descending ``var``. Raises StopIteration if nothing matches."""
        return next(self.c.select(sort=f'-{var}')).get(var)

    def add_unrelaxed_candidate(self, candidate, description):
        """Add a new candidate which needs to be relaxed.

        ``description`` has the form ``'<kind>:<text>'``; ``<kind>`` is
        stored as a flag key set to 1. The new row id becomes the gaid and
        ``candidate.info['confid']``.
        """
        t, desc = split_description(description)
        kwargs = {'relaxed': 0,
                  'extinct': 0,
                  t: 1,
                  'description': desc}

        if 'generation' not in candidate.info['key_value_pairs']:
            kwargs.update({'generation': self.get_generation_number()})

        gaid = self.c.write(candidate,
                            key_value_pairs=candidate.info['key_value_pairs'],
                            data=candidate.info['data'],
                            **kwargs)
        self.c.update(gaid, gaid=gaid)
        candidate.info['confid'] = gaid

    def add_unrelaxed_step(self, candidate, description):
        """ Add a change to a candidate without it having been relaxed.
        This method is typically used when a
        candidate has been mutated. """

        # confid has already been set by add_unrelaxed_candidate
        gaid = candidate.info['confid']

        t, desc = split_description(description)
        kwargs = {'relaxed': 0,
                  'extinct': 0,
                  t: 1,
                  'description': desc, 'gaid': gaid}

        self.c.write(candidate,
                     key_value_pairs=candidate.info['key_value_pairs'],
                     data=candidate.info['data'],
                     **kwargs)

    def get_number_of_atoms_to_optimize(self):
        """ Get the number of atoms being optimized. """
        v = self.c.get(simulation_cell=True)
        return len(v.data.stoichiometry)

    def get_atom_numbers_to_optimize(self):
        """ Get the list of atom numbers being optimized. """
        v = self.c.get(simulation_cell=True)
        return v.data.stoichiometry

    def get_slab(self):
        """ Get the super cell, including stationary atoms, in which
        the structure is being optimized. """
        return self.c.get_atoms(simulation_cell=True)

    def get_participation_in_pairing(self):
        """ Get information about how many direct
        offsprings each candidate has, and which specific
        pairings have been made. This information is used
        for the extended fitness calculation described in
        L.B. Vilhelmsen et al., JACS, 2012, 134 (30), pp 12807-12816

        Returns ``(frequency, pairs)``: a dict mapping parent id to its
        number of pairings, and a list of sorted parent-id tuples.
        """
        frequency = {}
        pairs = []
        for e in self.c.select(pairing=1):
            c1, c2 = e.data['parents']
            pairs.append(tuple(sorted([c1, c2])))
            frequency[c1] = frequency.get(c1, 0) + 1
            frequency[c2] = frequency.get(c2, 0) + 1
        return (frequency, pairs)

    def get_all_relaxed_candidates(self, only_new=False, use_extinct=False):
        """ Returns all candidates that have been relaxed.

        Parameters:

        only_new: boolean (optional)
            Used to specify only to get candidates relaxed since last
            time this function was invoked. Default: False.

        use_extinct: boolean (optional)
            Set to True if the extinct key (and mass extinction) is going
            to be used. Default: False."""

        if use_extinct:
            entries = self.c.select('relaxed=1,extinct=0',
                                    sort='-raw_score')
        else:
            entries = self.c.select('relaxed=1', sort='-raw_score')

        trajs = []
        for v in entries:
            if only_new and v.gaid in self.already_returned:
                continue
            t = self.get_atoms(id=v.id)
            t.info['confid'] = v.gaid
            t.info['relax_id'] = v.id
            trajs.append(t)
            self.already_returned.add(v.gaid)
        return trajs

    def get_all_relaxed_candidates_after_generation(self, gen):
        """ Returns all non-extinct candidates that have been relaxed up to
        and including the specified generation (despite the method name),
        sorted by descending raw score.
        """
        entries = self.c.select(f'relaxed=1,extinct=0,generation<={gen}')

        trajs = []
        for v in entries:
            t = self.get_atoms(id=v.id)
            t.info['confid'] = v.gaid
            t.info['relax_id'] = v.id
            trajs.append(t)
        trajs.sort(key=get_raw_score,
                   reverse=True)
        return trajs

    def get_all_candidates_in_queue(self):
        """ Returns the gaids of all structures that are queued, but have
        not yet been relaxed (in queue-selection order). """
        all_queued_ids = [t.gaid for t in self.c.select(queued=1)]
        # A set makes the membership test O(1) instead of O(n).
        all_relaxed_ids = {t.gaid for t in self.c.select(relaxed=1)}

        return [qid for qid in all_queued_ids
                if qid not in all_relaxed_ids]

    def remove_from_queue(self, confid):
        """ Removes the candidate confid from the queue by deleting its
        queued rows. """
        queued_ids = self.c.select(queued=1, gaid=confid)
        self.c.delete([q.id for q in queued_ids])

    def get_generation_number(self, size=None):
        """ Returns the current generation number, by looking
        at the number of relaxed individuals and comparing
        this number to the supplied size or population size.

        If all individuals in generation 3 have been relaxed
        it will return 4 if not all in generation 4 have been
        relaxed.

        Returns 0 when no (positive) size is given or stored.
        """
        from collections import Counter

        if size is None:
            size = self.get_param('population_size')
        if size is None or size <= 0:
            # Generations cannot be delimited without a positive size; the
            # old loop silently returned None in the size <= 0 case.
            return 0
        # One pass over the relaxed rows instead of one pass per generation.
        counts = Counter(c.generation for c in self.c.select(relaxed=1))
        g = 0
        while counts[g] >= size:
            g += 1
        return g

    def get_atoms(self, id, add_info=True):
        """Return the atoms object with the specified id"""
        a = self.c.get_atoms(id, add_additional_information=add_info)
        return a

    def get_param(self, parameter):
        """Get a parameter saved when creating the database.

        Looks in the ``data`` of row 1; returns None if that row has no
        data or lacks the parameter.
        """
        first_row = self.c.get(1)  # query once instead of twice
        if first_row.get('data'):
            return first_row.data.get(parameter, None)
        return None

    def remove_old_queued(self):
        """Placeholder; currently does nothing."""
        pass

    def is_duplicate(self, **kwargs):
        """Check if the key-value pair is already present in the database"""
        # Stop at the first match instead of materialising every row.
        return any(True for _ in self.c.select(**kwargs))

    def kill_candidate(self, confid):
        """Sets extinct=1 in the key_value_pairs of the candidate
        with gaid=confid. This could be used in the
        mass extinction operator."""
        for dct in self.c.select(gaid=confid):
            self.c.update(dct.id, extinct=1)

408 

409 

class PrepareDB:
    """Create and seed a new GA database.

    This class is used once to set up the database. The constructor
    writes a ``simulation_cell`` row holding the cell and, in its
    ``data``, every extra keyword argument. The ``add_*_candidate``
    methods then add generation-0 starting candidates.

    Parameters:

    db_file_name: Database file to use; it must not exist yet.

    simulation_cell: Atoms object for the cell; an empty ``Atoms()`` is
        stored when omitted.

    **kwargs: stored in the ``data`` of the simulation-cell row.
    """

    def __init__(self, db_file_name, simulation_cell=None, **kwargs):
        if os.path.exists(db_file_name):
            full_path = os.path.abspath(db_file_name)
            raise OSError(f'DB file {full_path} already exists')
        self.db_file_name = db_file_name
        cell = Atoms() if simulation_cell is None else simulation_cell

        self.c = ase.db.connect(self.db_file_name)

        # Settings go into ``data`` rather than key-value pairs because
        # they never need to be searched for in the database.
        self.c.write(cell, data=dict(kwargs), simulation_cell=True)

    def add_unrelaxed_candidate(self, candidate, **kwargs):
        """Add an unrelaxed generation-0 starting candidate.

        The new row id is stored as its ``gaid`` and in
        ``candidate.info['confid']``.
        """
        row_id = self.c.write(candidate, origin='StartingCandidateUnrelaxed',
                              relaxed=0, generation=0, extinct=0, **kwargs)
        self.c.update(row_id, gaid=row_id)
        candidate.info['confid'] = row_id

    def add_relaxed_candidate(self, candidate, **kwargs):
        """Add a relaxed generation-0 starting candidate.

        Requires ``raw_score`` in ``candidate.info['key_value_pairs']``.
        ``candidate.info['data']`` is stored when present (otherwise an
        empty dict). The new row id is stored as its ``gaid`` and in
        ``candidate.info['confid']``.
        """
        test_raw_score(candidate)

        extra_data = candidate.info.get('data', {})
        key_values = candidate.info['key_value_pairs']

        row_id = self.c.write(candidate, origin='StartingCandidateRelaxed',
                              relaxed=1, generation=0, extinct=0,
                              key_value_pairs=key_values,
                              data=extra_data, **kwargs)
        self.c.update(row_id, gaid=row_id)
        candidate.info['confid'] = row_id