Coverage for /builds/kinetik161/ase/ase/calculators/genericfileio.py: 88.28%
145 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
1from abc import ABC, abstractmethod
2from os import PathLike
3from pathlib import Path
4from typing import Any, Iterable, Mapping
6from ase.calculators.abc import GetOutputsMixin
7from ase.calculators.calculator import BaseCalculator, EnvironmentError
8from ase.config import cfg as _cfg
10from contextlib import ExitStack
class BaseProfile(ABC):
    """Describe how to launch an external program, optionally in parallel.

    Subclasses provide the program-specific command
    (``get_calculator_command``) and a way to query its ``version``.  This
    base class prepends an optional parallel launcher and its options, built
    from ``parallel_info``.
    """

    def __init__(self, parallel=True, parallel_info=None):
        """
        Parameters
        ----------
        parallel : bool
            If the calculator should be run in parallel.
        parallel_info : dict
            Additional settings for parallel execution, e.g. arguments
            for the binary for parallelization (mpiexec, srun, mpirun).
        """
        self.parallel_info = parallel_info or {}
        self.parallel = parallel

    def get_translation_keys(self):
        """
        Get the translation keys for the parallel_info dictionary.

        A translation key is specified with the syntax
        ``<key>_kwarg_trans = <option>``.  E.g. if ``nprocs_kwarg_trans = -np``
        is specified, then the key ``nprocs`` will be translated to ``-np``.
        ``nprocs`` can then be given in parallel_info and will be written as
        ``-np`` when the command is built.

        Returns
        -------
        dict
            Dictionary where the keys are the parallel_info keys that will be
            translated and the values are what each key is translated into.
        """
        suffix = '_kwarg_trans'
        translation_keys = {}
        for key, value in self.parallel_info.items():
            # endswith() already rejects keys shorter than the suffix, so no
            # separate length check is needed.
            if key.endswith(suffix):
                translation_keys[key[:-len(suffix)]] = value
        return translation_keys

    def get_command(self, inputfile, calc_command=None) -> Iterable[str]:
        """
        Get the command to run as a list of strings.

        When ``self.parallel`` is true, the command starts with
        ``parallel_info['binary']`` (if present) followed by every other
        parallel_info entry rendered as a command-line option.

        Parameters
        ----------
        inputfile : str or None
            Passed to ``get_calculator_command`` when ``calc_command`` is None.
        calc_command : list of str, optional
            Calculator command to use instead of ``get_calculator_command``
            (used for sockets).

        Returns
        -------
        list of str
            The command to run.
        """
        command = []
        if self.parallel:
            if 'binary' in self.parallel_info:
                command.append(self.parallel_info['binary'])

            translation_keys = self.get_translation_keys()

            for key, value in self.parallel_info.items():
                # The launcher itself and the translation definitions are
                # not command-line options.
                if key == 'binary' or '_kwarg_trans' in key:
                    continue

                command_key = translation_keys.get(key, key)

                # Booleans act as flags (present when True, omitted when
                # False); everything else becomes an option/value pair.
                if type(value) is not bool:
                    command.append(f'{command_key}')
                    command.append(f'{value}')
                elif value:
                    command.append(f'{command_key}')

        if calc_command is None:
            command.extend(self.get_calculator_command(inputfile))
        else:
            command.extend(calc_command)
        return command

    @abstractmethod
    def get_calculator_command(self, inputfile):
        """
        The calculator specific command as a list of strings.

        Parameters
        ----------
        inputfile : str

        Returns
        -------
        list of str
            The command to run.
        """
        ...

    def run(
        self, directory, inputfile, outputfile, errorfile=None, append=False
    ):
        """
        Run the command in the given directory.

        Parameters
        ----------
        directory : pathlib.Path
            The directory to run the command in.
        inputfile : str
            The name of the input file.
        outputfile : str
            The name of the file receiving the program's stdout.
        errorfile : str, optional
            The file receiving the program's stderr.  If None, stderr is
            not redirected.
        append : bool
            If True, append to the output/error files instead of
            truncating them.

        Raises
        ------
        subprocess.CalledProcessError
            If the command exits with a non-zero status.
        """
        from subprocess import check_call
        import os

        argv_command = self.get_command(inputfile)
        mode = 'wb' if not append else 'ab'

        # ExitStack closes whichever of the files were actually opened.
        with ExitStack() as stack:
            fd_out = stack.enter_context(open(outputfile, mode))
            if errorfile is not None:
                fd_err = stack.enter_context(open(errorfile, mode))
            else:
                fd_err = None
            check_call(
                argv_command,
                cwd=directory,
                stdout=fd_out,
                stderr=fd_err,
                env=os.environ,
            )

    @abstractmethod
    def version(self):
        """
        Get the version of the code.

        Returns
        -------
        str
            The version of the code.
        """
        ...

    @classmethod
    def from_config(cls, cfg, section_name, parallel_info=None, parallel=True):
        """
        Create a profile from a configuration file.

        Entries of the optional ``[parallel]`` section form the base
        parallel_info; entries of the ``parallel_info`` argument override
        them.  The entries of ``[section_name]`` are passed to the
        constructor as keyword arguments.

        Parameters
        ----------
        cfg : ase.config.Config
            The configuration object.
        section_name : str
            The name of the section in the configuration file. E.g. the name
            of the template that this profile is for.
        parallel_info : dict, optional
            Parallel settings overriding the ``[parallel]`` section.
        parallel : bool
            Whether the profile runs in parallel.

        Returns
        -------
        BaseProfile
            The profile object.
        """
        parser = cfg.parser
        # The [parallel] section is optional: a serial run, or one whose
        # parallel settings are passed explicitly, must not require it.
        if 'parallel' in parser:
            parallel_config = dict(parser['parallel'])
        else:
            parallel_config = {}
        if parallel_info is not None:
            parallel_config.update(parallel_info)

        return cls(
            **parser[section_name],
            parallel_info=parallel_config,
            parallel=parallel,
        )
def read_stdout(args, createfile=None):
    """Run command in tempdir and return standard output.

    Helper function for getting version numbers of DFT codes.
    Most DFT codes don't implement a --version flag, so in order to
    determine the code version, we just run the code until it prints
    a version number.

    Parameters
    ----------
    args : list of str
        Command to execute (passed to ``subprocess.Popen``, no shell).
    createfile : str, optional
        Name of an empty file to create in the temporary directory before
        running, presumably for codes that expect some input file to exist.

    Returns
    -------
    str
        Everything the command wrote to stdout, decoded as ASCII.  Bytes
        that are not valid ASCII are replaced with U+FFFD rather than
        raising.  The exit status is ignored.
    """
    import tempfile
    from subprocess import PIPE, Popen

    with tempfile.TemporaryDirectory() as directory:
        if createfile is not None:
            path = Path(directory) / createfile
            path.touch()
        proc = Popen(
            args,
            stdout=PIPE,
            stderr=PIPE,
            stdin=PIPE,
            cwd=directory,
            encoding='ascii',
            # Both stdout and the discarded stderr are decoded; a single
            # non-ASCII byte in a banner must not raise UnicodeDecodeError.
            errors='replace',
        )
        stdout, _ = proc.communicate()
        # Exit code will be != 0 because there isn't an input file
        return stdout
class CalculatorTemplate(ABC):
    """Abstract recipe for driving an external code through files.

    Concrete templates say how to write input files, execute the code with a
    profile, read the results back and load a profile from configuration.
    """

    def __init__(self, name: str, implemented_properties: Iterable[str]):
        self.name = name
        # Stored as a frozenset: duplicates collapse and the set is immutable.
        self.implemented_properties = frozenset(implemented_properties)

    @abstractmethod
    def write_input(self, profile, directory, atoms, parameters, properties):
        ...

    @abstractmethod
    def execute(self, directory, profile):
        ...

    @abstractmethod
    def read_results(self, directory: PathLike) -> Mapping[str, Any]:
        ...

    @abstractmethod
    def load_profile(self, cfg, parallel_info=None, parallel=True):
        ...

    def socketio_calculator(
        self,
        profile,
        parameters,
        directory,
        # We may need quite a few socket kwargs here
        # if we want to expose all the timeout etc. from
        # SocketIOCalculator.
        unixsocket=None,
        port=None,
    ):
        """Build a SocketIOCalculator whose client runs this template's code.

        Exactly one of ``unixsocket`` and ``port`` must be truthy.  The
        template must provide ``socketio_argv`` and ``socketio_parameters``;
        ``parameters`` are layered on top of ``socketio_parameters(...)``.
        The client's stdout goes to ``directory / self.outputname``
        (``outputname`` is assumed to be defined by the concrete template).

        Raises
        ------
        TypeError
            If both or neither socket kinds are given, or the template lacks
            the socket hooks.
        """
        import os
        from subprocess import Popen

        from ase.calculators.socketio import SocketIOCalculator

        if port and unixsocket:
            raise TypeError(
                'For the socketio_calculator only a UNIX '
                '(unixsocket) or INET (port) socket can be used'
                ' not both.'
            )

        if not (port or unixsocket):
            raise TypeError(
                'For the socketio_calculator either a '
                'UNIX (unixsocket) or INET (port) socket '
                'must be used'
            )

        required_hooks = ('socketio_argv', 'socketio_parameters')
        if not all(hasattr(self, hook) for hook in required_hooks):
            raise TypeError(
                f'Template {self} does not implement mandatory '
                'socketio_argv() and socketio_parameters()'
            )

        # XXX need socketio ABC or something
        client_argv = self.socketio_argv(profile, unixsocket, port)
        argv = profile.get_command(inputfile=None, calc_command=client_argv)

        # Caller-supplied parameters take precedence over socket defaults.
        merged_parameters = {
            **self.socketio_parameters(unixsocket, port),
            **parameters,
        }

        # Not so elegant that socket args are passed to this function
        # via socketiocalculator when we could make a closure right here.
        # The port/unixsocket arguments below are accepted but unused.
        def launch(atoms, properties, port, unixsocket):
            directory.mkdir(exist_ok=True, parents=True)

            self.write_input(
                atoms=atoms,
                profile=profile,
                parameters=merged_parameters,
                properties=properties,
                directory=directory,
            )

            logpath = directory / self.outputname
            with open(logpath, 'w') as log_fd:
                return Popen(argv, stdout=log_fd, cwd=directory,
                             env=os.environ)

        return SocketIOCalculator(
            launch_client=launch, unixsocket=unixsocket, port=port
        )
class GenericFileIOCalculator(BaseCalculator, GetOutputsMixin):
    """Calculator that hands all file I/O to a ``CalculatorTemplate``.

    The template writes the input files, runs the external program through
    the profile and parses the results; this class wires those steps into
    the ASE calculator interface.
    """

    cfg = _cfg

    def __init__(
        self,
        *,
        template,
        profile,
        directory,
        parameters=None,
        parallel_info=None,
        parallel=True,
    ):
        """
        Parameters
        ----------
        template : CalculatorTemplate
            Writes input, executes the code and reads output.
        profile : BaseProfile or None
            How to run the program.  If None, it is loaded from ``self.cfg``
            with ``template.load_profile``.
        directory : str or PathLike
            Working directory of the calculation (stored as ``Path``).
        parameters : dict, optional
            Calculator parameters, forwarded to the base class.
        parallel_info : dict, optional
            Only used when the profile is loaded from configuration.
        parallel : bool
            Only used when the profile is loaded from configuration.

        Raises
        ------
        EnvironmentError
            If a profile must be loaded but the configuration has no section
            for the template, or loading it fails.
        """
        self.template = template

        if profile is None:
            if template.name not in self.cfg.parser:
                raise EnvironmentError(f'No configuration of {template.name}')
            try:
                profile = template.load_profile(
                    self.cfg, parallel_info=parallel_info, parallel=parallel
                )
            except Exception as exc:
                # Dump the configuration into the message to ease debugging.
                config_snapshot = self.cfg.as_dict()
                raise EnvironmentError(
                    f'Failed to load section [{template.name}] '
                    f'from configuration: {config_snapshot}'
                ) from exc

        self.profile = profile

        # Maybe we should allow directory to be a factory, so
        # calculators e.g. produce new directories on demand.
        self.directory = Path(directory)
        super().__init__(parameters)

    def set(self, *args, **kwargs):
        """Refuse in-place parameter changes; build a new calculator."""
        raise RuntimeError(
            'No setting parameters for now, please. '
            'Just create new calculators.'
        )

    def __repr__(self):
        return f'{type(self).__name__}({self.template.name})'

    @property
    def implemented_properties(self):
        """Properties supported by the underlying template."""
        return self.template.implemented_properties

    @property
    def name(self):
        """Name of the underlying template."""
        return self.template.name

    def write_inputfiles(self, atoms, properties):
        """Create the working directory and let the template write inputs.

        Separate from ``calculate`` because SocketIOCalculators like to
        write input files without calculating.
        """
        workdir = self.directory
        workdir.mkdir(exist_ok=True, parents=True)
        self.template.write_input(
            profile=self.profile,
            atoms=atoms,
            parameters=self.parameters,
            properties=properties,
            directory=workdir,
        )

    def calculate(self, atoms, properties, system_changes):
        """Write inputs, run the program and store the parsed results."""
        template = self.template
        workdir = self.directory
        self.write_inputfiles(atoms, properties)
        template.execute(workdir, self.profile)
        self.results = template.read_results(workdir)
        # XXX Return something useful?

    def _outputmixin_get_results(self):
        return self.results

    def socketio(self, **socketkwargs):
        """Return a socket-driven calculator built by the template."""
        return self.template.socketio_calculator(
            directory=self.directory,
            parameters=self.parameters,
            profile=self.profile,
            **socketkwargs,
        )