Coverage for /builds/kinetik161/ase/ase/calculators/genericfileio.py: 88.28%

145 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-12-10 11:04 +0000

1from abc import ABC, abstractmethod 

2from os import PathLike 

3from pathlib import Path 

4from typing import Any, Iterable, Mapping 

5 

6from ase.calculators.abc import GetOutputsMixin 

7from ase.calculators.calculator import BaseCalculator, EnvironmentError 

8from ase.config import cfg as _cfg 

9 

10from contextlib import ExitStack 

11 

12 

class BaseProfile(ABC):
    """Describe how to launch an external calculator code.

    A profile combines the code-specific command line (supplied by
    subclasses through :meth:`get_calculator_command`) with optional
    parallel-launcher settings stored in ``parallel_info``.
    """

    def __init__(self, parallel=True, parallel_info=None):
        """
        Parameters
        ----------
        parallel : bool
            If the calculator should be run in parallel.
        parallel_info : dict
            Additional settings for parallel execution, e.g. arguments
            for the binary for parallelization (mpiexec, srun, mpirun).
        """
        self.parallel_info = parallel_info or {}
        self.parallel = parallel

    def get_translation_keys(self):
        """
        Get the translation keys for the parallel_info dictionary.

        A translation key is specified in a config file with the syntax
        ``key_kwarg_trans = command``, e.g. if ``nprocs_kwarg_trans = -np``
        is specified in the config file, then the key ``nprocs`` will be
        translated to ``-np``. Then ``nprocs`` can be specified in
        parallel_info and will be translated to ``-np`` when the command is
        built.

        Returns
        -------
        dict
            Maps each parallel_info key that will be translated to the
            command-line option it is translated into.
        """
        suffix = '_kwarg_trans'
        translation_keys = {}
        for key, value in self.parallel_info.items():
            # endswith() already implies len(key) >= len(suffix), so no
            # separate length guard is needed.
            if key.endswith(suffix):
                translation_keys[key[:-len(suffix)]] = value
        return translation_keys

    def get_command(self, inputfile, calc_command=None) -> Iterable[str]:
        """
        Get the command to run. This should be a list of strings.

        When ``self.parallel`` is true, the command starts with
        ``parallel_info['binary']`` (if present). The remaining
        parallel_info entries follow as arguments:

        * non-boolean values become ``key value``;
        * ``True`` becomes a bare ``key`` flag;
        * ``False`` is omitted.

        Keys that have a translation (see :meth:`get_translation_keys`)
        are replaced by their translation.

        Parameters
        ----------
        inputfile : str
            Passed to :meth:`get_calculator_command` when ``calc_command``
            is None.
        calc_command : list of str, optional
            Calculator command used instead of
            :meth:`get_calculator_command` (used for sockets).

        Returns
        -------
        list of str
            The command to run.
        """
        command = []
        if self.parallel:
            if 'binary' in self.parallel_info:
                command.append(self.parallel_info['binary'])

            translation_keys = self.get_translation_keys()

            for key, value in self.parallel_info.items():
                # 'binary' was handled above; translation rules are
                # configuration, not launcher arguments.
                if key == 'binary' or '_kwarg_trans' in key:
                    continue

                command_key = translation_keys.get(key, key)

                # Exact type check: booleans act as on/off flags, anything
                # else (including ints) is emitted as ``key value``.
                if type(value) is not bool:
                    command.append(f'{command_key}')
                    command.append(f'{value}')
                elif value:
                    command.append(f'{command_key}')

        if calc_command is None:
            command.extend(self.get_calculator_command(inputfile))
        else:
            command.extend(calc_command)
        return command

    @abstractmethod
    def get_calculator_command(self, inputfile):
        """
        The calculator specific command as a list of strings.

        Parameters
        ----------
        inputfile : str

        Returns
        -------
        list of str
            The command to run.
        """
        ...

    def run(
        self, directory, inputfile, outputfile, errorfile=None, append=False
    ):
        """
        Run the command in the given directory.

        Parameters
        ----------
        directory : pathlib.Path
            The directory to run the command in.
        inputfile : str
            The name of the input file.
        outputfile : str
            The name of the output file.
        errorfile: str
            The file that receives the process's stderr. If None, stderr
            is inherited from the current process.
        append: bool
            If True, open the output/error files in append mode.

        Raises
        ------
        subprocess.CalledProcessError
            If the command exits with a non-zero status.

        Notes
        -----
        NOTE(review): ``outputfile`` and ``errorfile`` are opened relative
        to the current working directory, not ``directory``. Callers
        presumably pass paths that already include the directory; confirm
        before changing.
        """

        from subprocess import check_call
        import os

        argv_command = self.get_command(inputfile)
        mode = 'wb' if not append else 'ab'

        # ExitStack lets the optional error file share the same cleanup
        # path as the mandatory output file.
        with ExitStack() as stack:
            fd_out = stack.enter_context(open(outputfile, mode))
            if errorfile is not None:
                fd_err = stack.enter_context(open(errorfile, mode))
            else:
                fd_err = None
            check_call(
                argv_command,
                cwd=directory,
                stdout=fd_out,
                stderr=fd_err,
                env=os.environ,
            )

    @abstractmethod
    def version(self):
        """
        Get the version of the code.

        Returns
        -------
        str
            The version of the code.
        """
        ...

    @classmethod
    def from_config(cls, cfg, section_name, parallel_info=None, parallel=True):
        """
        Create a profile from a configuration file.

        Parameters
        ----------
        cfg : ase.config.Config
            The configuration object.
        section_name : str
            The name of the section in the configuration file. E.g. the name
            of the template that this profile is for.
        parallel_info : dict, optional
            Parallel settings that override those from the ``[parallel]``
            section of the configuration.
        parallel : bool
            If the calculator should be run in parallel.

        Returns
        -------
        BaseProfile
            The profile object.
        """
        # The [parallel] section is optional. Without it, only the
        # explicitly passed parallel_info is used instead of failing with
        # a KeyError.
        if 'parallel' in cfg.parser:
            parallel_config = dict(cfg.parser['parallel'])
        else:
            parallel_config = {}
        if parallel_info is not None:
            parallel_config.update(parallel_info)

        return cls(
            **cfg.parser[section_name],
            parallel_info=parallel_config,
            parallel=parallel,
        )

189 

190 

def read_stdout(args, createfile=None):
    """Run a command in a temporary directory and return its standard output.

    Helper function for getting version numbers of DFT codes. Most DFT
    codes don't implement a ``--version`` flag, so in order to determine
    the code version, we just run the code until it prints a version number.

    Parameters
    ----------
    args : list of str
        Command line (argv) to execute.
    createfile : str, optional
        Name of an empty file to create in the temporary directory before
        running the command (for example a placeholder input file).

    Returns
    -------
    str
        Everything the command wrote to standard output. Bytes that are not
        valid ASCII are replaced with U+FFFD rather than raising
        ``UnicodeDecodeError``. The exit status is ignored.
    """
    import tempfile
    from subprocess import PIPE, Popen

    with tempfile.TemporaryDirectory() as directory:
        if createfile is not None:
            path = Path(directory) / createfile
            path.touch()
        # errors='replace' keeps version detection working when a code
        # prints non-ASCII characters (e.g. in its banner).
        with Popen(
            args,
            stdout=PIPE,
            stderr=PIPE,
            stdin=PIPE,
            cwd=directory,
            encoding='ascii',
            errors='replace',
        ) as proc:
            stdout, _ = proc.communicate()
        # Exit code will be != 0 because there isn't an input file,
        # so it is deliberately not checked.
        return stdout

216 

217 

class CalculatorTemplate(ABC):
    """Abstract recipe for driving one external calculator code.

    A template knows how to write input files, execute the code via a
    profile, parse the results and load a profile from configuration.
    ``implemented_properties`` is stored as a frozenset.
    """

    def __init__(self, name: str, implemented_properties: Iterable[str]):
        self.name = name
        self.implemented_properties = frozenset(implemented_properties)

    @abstractmethod
    def write_input(self, profile, directory, atoms, parameters, properties):
        """Write the calculation's input file(s) into ``directory``."""

    @abstractmethod
    def execute(self, directory, profile):
        """Run the calculation in ``directory`` using ``profile``."""

    @abstractmethod
    def read_results(self, directory: PathLike) -> Mapping[str, Any]:
        """Parse the outputs found in ``directory`` into a results mapping."""

    @abstractmethod
    def load_profile(self, cfg, parallel_info=None, parallel=True):
        """Build a profile for this code from the configuration ``cfg``."""

    def socketio_calculator(
        self,
        profile,
        parameters,
        directory,
        # We may need quite a few socket kwargs here
        # if we want to expose all the timeout etc. from
        # SocketIOCalculator.
        unixsocket=None,
        port=None,
    ):
        """Create a SocketIOCalculator that launches this code as a client.

        Exactly one of ``unixsocket`` and ``port`` must be given. The
        template must provide ``socketio_argv`` and ``socketio_parameters``.
        The client's stdout is written to ``directory / self.outputname``,
        so subclasses are expected to define ``outputname``.

        Raises
        ------
        TypeError
            If both or neither socket kind is given, or if the template
            lacks the socket I/O hooks.
        """
        import os
        from subprocess import Popen

        from ase.calculators.socketio import SocketIOCalculator

        # Exactly one socket kind must be selected.
        if not (port or unixsocket):
            raise TypeError(
                'For the socketio_calculator either a '
                'UNIX (unixsocket) or INET (port) socket '
                'must be used'
            )
        if port and unixsocket:
            raise TypeError(
                'For the socketio_calculator only a UNIX '
                '(unixsocket) or INET (port) socket can be used'
                ' not both.'
            )

        hooks = ('socketio_argv', 'socketio_parameters')
        if not all(hasattr(self, hook) for hook in hooks):
            raise TypeError(
                f'Template {self} does not implement mandatory '
                'socketio_argv() and socketio_parameters()'
            )

        # XXX need socketio ABC or something
        socket_argv = self.socketio_argv(profile, unixsocket, port)
        client_argv = profile.get_command(
            inputfile=None,
            calc_command=socket_argv,
        )
        socket_defaults = self.socketio_parameters(unixsocket, port)
        # Caller-supplied parameters override the socket defaults.
        merged_parameters = {**socket_defaults, **parameters}

        # Not so elegant that socket args are passed to this function
        # via socketiocalculator when we could make a closure right here.
        def launch(atoms, properties, port, unixsocket):
            directory.mkdir(exist_ok=True, parents=True)

            self.write_input(
                atoms=atoms,
                profile=profile,
                parameters=merged_parameters,
                properties=properties,
                directory=directory,
            )

            log_path = directory / self.outputname
            with open(log_path, 'w') as log_fd:
                return Popen(
                    client_argv, stdout=log_fd, cwd=directory, env=os.environ
                )

        return SocketIOCalculator(
            launch_client=launch, unixsocket=unixsocket, port=port
        )

307 

308 

class GenericFileIOCalculator(BaseCalculator, GetOutputsMixin):
    """Calculator that runs an external code through files on disk.

    Input writing, execution and result parsing are delegated to
    ``template``; how the code is launched is described by ``profile``.
    Parameters are fixed at construction time, and :meth:`set` is disabled.
    """

    cfg = _cfg

    def __init__(
        self,
        *,
        template,
        profile,
        directory,
        parameters=None,
        parallel_info=None,
        parallel=True,
    ):
        """
        Parameters
        ----------
        template : CalculatorTemplate
            Template for the external code.
        profile : BaseProfile or None
            Launch profile. If None, it is loaded from ``self.cfg`` via
            ``template.load_profile``.
        directory : str or path-like
            Directory in which calculations are run.
        parameters : dict, optional
            Calculator parameters, passed to ``BaseCalculator``.
        parallel_info, parallel
            Forwarded to ``template.load_profile`` when ``profile`` is None.

        Raises
        ------
        EnvironmentError
            If ``profile`` is None and either the configuration has no
            section named after the template or loading the profile fails.
        """
        self.template = template

        if profile is None:
            # Fall back to the configuration section named after the
            # template.
            if template.name not in self.cfg.parser:
                raise EnvironmentError(f'No configuration of {template.name}')
            try:
                loaded = template.load_profile(
                    self.cfg, parallel_info=parallel_info, parallel=parallel
                )
            except Exception as err:
                # Include the whole configuration to help diagnose the
                # failure.
                configvars = self.cfg.as_dict()
                raise EnvironmentError(
                    f'Failed to load section [{template.name}] '
                    f'from configuration: {configvars}'
                ) from err
            profile = loaded

        self.profile = profile

        # Maybe we should allow directory to be a factory, so
        # calculators e.g. produce new directories on demand.
        self.directory = Path(directory)
        super().__init__(parameters)

    def set(self, *args, **kwargs):
        """Always raises: parameters cannot be changed after creation."""
        raise RuntimeError(
            'No setting parameters for now, please. '
            'Just create new calculators.'
        )

    def __repr__(self):
        class_name = type(self).__name__
        return f'{class_name}({self.template.name})'

    @property
    def implemented_properties(self):
        """Properties supported by the underlying template."""
        return self.template.implemented_properties

    @property
    def name(self):
        """Name of the underlying template."""
        return self.template.name

    def write_inputfiles(self, atoms, properties):
        """Create ``self.directory`` if needed and write the input files."""
        # SocketIOCalculators like to write inputfiles
        # without calculating.
        target = self.directory
        target.mkdir(exist_ok=True, parents=True)
        self.template.write_input(
            profile=self.profile,
            atoms=atoms,
            parameters=self.parameters,
            properties=properties,
            directory=target,
        )

    def calculate(self, atoms, properties, system_changes):
        """Write inputs, run the code and store the parsed results."""
        self.write_inputfiles(atoms, properties)
        self.template.execute(self.directory, self.profile)
        parsed = self.template.read_results(self.directory)
        self.results = parsed
        # XXX Return something useful?

    def _outputmixin_get_results(self):
        return self.results

    def socketio(self, **socketkwargs):
        """Return a socket I/O calculator built by the template."""
        return self.template.socketio_calculator(
            directory=self.directory,
            parameters=self.parameters,
            profile=self.profile,
            **socketkwargs,
        )