Coverage for /builds/kinetik161/ase/ase/calculators/subprocesscalculator.py: 91.41%
198 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
1import os
2import pickle
3import sys
4from abc import ABC, abstractmethod
5from subprocess import PIPE, Popen
7from ase.calculators.calculator import Calculator, all_properties
class PackedCalculator(ABC):
    """Portable calculator for use via PythonSubProcessCalculator.

    Instances describe a calculator in a picklable form; the real
    calculator is only created inside a separate process, which may be
    launched through MPI or srun.

    Use this when the main script runs ASE mostly in serial but some
    calculations should run in a parallel Python environment.

    Most existing calculators can be used this way through the
    NamedPackedCalculator implementation.  For other calculators,
    subclass this class and implement ``unpack_calculator``.

    Example::

        from ase.build import bulk

        atoms = bulk('Au')
        pack = NamedPackedCalculator('emt')

        with pack.calculator() as atoms.calc:
            energy = atoms.get_potential_energy()

    The computation takes place inside a subprocess which lives only as
    long as the ``with`` statement.
    """

    @abstractmethod
    def unpack_calculator(self) -> Calculator:
        """Build and return the calculator described by this object.

        Called inside the subprocess that performs the computations."""

    def calculator(self, mpi_command=None) -> 'PythonSubProcessCalculator':
        """Wrap this packed calculator in a PythonSubProcessCalculator.

        The returned calculator owns a subprocess that holds the actual
        calculator; all computations are carried out in that subprocess.
        ``mpi_command`` selects how the subprocess is launched (serial
        when None)."""
        subprocess_calc = PythonSubProcessCalculator(
            self, mpi_command=mpi_command)
        return subprocess_calc
class NamedPackedCalculator(PackedCalculator):
    """PackedCalculator implementation which works with standard calculators.

    The calculator class is looked up by ``name`` with
    ``ase.calculators.calculator.get_calculator_class`` inside the
    subprocess and instantiated with ``kwargs`` as keyword arguments.
    """

    def __init__(self, name, kwargs=None):
        self._name = name
        # Take a snapshot so that later changes to the caller's dict do not
        # silently alter the calculator that is built in the subprocess.
        self._kwargs = {} if kwargs is None else dict(kwargs)

    def unpack_calculator(self):
        """Instantiate the named calculator class with the stored kwargs."""
        from ase.calculators.calculator import get_calculator_class
        cls = get_calculator_class(self._name)
        return cls(**self._kwargs)

    def __repr__(self):
        # !r quotes the name, e.g. NamedPackedCalculator('emt', {}).
        return f'{type(self).__name__}({self._name!r}, {self._kwargs!r})'
class MPICommand:
    """Command line used to launch the calculator subprocess.

    ``argv`` is the full argument list passed to ``subprocess.Popen``.
    """

    def __init__(self, argv):
        self.argv = argv

    def __repr__(self):
        return f'{type(self).__name__}({self.argv!r})'

    @classmethod
    def python_argv(cls):
        """Arguments that run this module with the current interpreter."""
        return [sys.executable, '-m', 'ase.calculators.subprocesscalculator']

    @classmethod
    def parallel(cls, nprocs, mpi_argv=()):
        """Launch ``nprocs`` processes via ``mpiexec`` in 'mpi4py' mode.

        ``mpi_argv`` holds extra arguments inserted right after
        ``mpiexec -n nprocs``."""
        return cls(['mpiexec', '-n', str(nprocs)]
                   + list(mpi_argv)
                   + cls.python_argv()
                   + ['mpi4py'])

    @classmethod
    def serial(cls):
        """Launch a single process in 'standard' mode."""
        # Use cls (like parallel() does) so subclasses get their own type.
        return cls(cls.python_argv() + ['standard'])

    def execute(self):
        """Start the subprocess with piped stdin/stdout and return the Popen."""
        # On this computer (Ubuntu 20.04 + OpenMPI) the subprocess crashes
        # without output during startup if os.environ is not passed along.
        # Hence we pass os.environ. Not sure if this is a machine thing
        # or in general. --askhl
        return Popen(self.argv, stdout=PIPE,
                     stdin=PIPE, env=os.environ)
def gpaw_process(ncores=1, **kwargs):
    """Return a subprocess calculator running GPAW on ``ncores`` cores.

    The subprocess is launched as ``<sys.executable> -m gpaw -P <ncores>
    python -m ase.calculators.subprocesscalculator standard`` and the
    calculator inside it is built from the name 'gpaw' with ``kwargs``.
    """
    argv = [sys.executable, '-m', 'gpaw', '-P', str(ncores)]
    argv += ['python', '-m', 'ase.calculators.subprocesscalculator',
             'standard']
    return PythonSubProcessCalculator(NamedPackedCalculator('gpaw', kwargs),
                                      MPICommand(argv))
class PythonSubProcessCalculator(Calculator):
    """Calculator for running calculations in external processes.

    This calculator runs a subprocess wherein it sets up an actual
    calculator from ``calc_input``.  Calculations are forwarded through
    pickle to that calculator, which returns results through pickle.

    The subprocess is launched with ``mpi_command`` (serial by default)
    when entering the ``with`` statement and stopped when leaving it, so
    calculations must happen inside the ``with`` block."""
    implemented_properties = list(all_properties)

    def __init__(self, calc_input, mpi_command=None):
        super().__init__()
        # Picklable description (e.g. a PackedCalculator) sent to the
        # subprocess, which unpacks it into the real calculator.
        self.calc_input = calc_input
        if mpi_command is None:
            mpi_command = MPICommand.serial()
        self.mpi_command = mpi_command
        # Connection to the running subprocess; None while not running.
        self.protocol = None

    def set(self, **kwargs):
        """Refuse parameter changes while the subprocess is running.

        Outside a running subprocess, keyword arguments are ignored.
        ``getattr`` with a default is used because ``set`` may be called
        before ``__init__`` has assigned ``self.protocol``."""
        # The original check looked for a never-assigned 'client' attribute,
        # so it could never fire; 'protocol' is what marks a live subprocess.
        if getattr(self, 'protocol', None) is not None:
            raise RuntimeError('No setting things for now, thanks')

    def __repr__(self):
        return '{}({})'.format(type(self).__name__,
                               self.calc_input)

    def __enter__(self):
        assert self.protocol is None
        proc = self.mpi_command.execute()
        self.protocol = Protocol(proc)
        self.protocol.send(self.calc_input)
        return self

    def __exit__(self, *args):
        protocol = self.protocol
        # Mark as stopped first so the object is reusable even on failure.
        self.protocol = None
        try:
            protocol.send('stop')
        finally:
            # Always reap the subprocess, even if the pipe is already broken.
            protocol.proc.communicate()

    def _run_calculation(self, atoms, properties, system_changes):
        self.protocol.send('calculate')
        self.protocol.send((atoms, properties, system_changes))

    def calculate(self, atoms, properties, system_changes):
        """Forward the calculation to the subprocess and store its results."""
        Calculator.calculate(self, atoms, properties, system_changes)
        # We send a pickle of self.atoms because this is a fresh copy
        # of the input, but without an unpicklable calculator:
        self._run_calculation(self.atoms.copy(), properties, system_changes)
        results = self.protocol.recv()
        self.results.update(results)

    def backend(self):
        """Return a proxy whose methods call the calculator in the subprocess."""
        return ParallelBackendInterface(self)
class Protocol:
    """Pickle-based message channel over a subprocess' stdin/stdout.

    Messages are pickled to ``proc.stdin`` and unpickled from
    ``proc.stdout``.  Only use this with a subprocess you started
    yourself: unpickling data can execute arbitrary code."""

    def __init__(self, proc):
        self.proc = proc

    def send(self, obj):
        """Pickle ``obj`` to the subprocess and flush so it is seen at once."""
        pickle.dump(obj, self.proc.stdin)
        self.proc.stdin.flush()

    def recv(self):
        """Read one ``(response_type, value)`` reply from the subprocess.

        Returns ``value`` for a ``'return'`` reply; raises ``value`` for a
        ``'raise'`` reply.

        Raises:
            RuntimeError: if the reply type is neither of those.
        """
        response_type, value = pickle.load(self.proc.stdout)

        if response_type == 'raise':
            raise value

        # Explicit check instead of assert, which is stripped under -O.
        if response_type != 'return':
            raise RuntimeError(
                f'Unexpected response type from subprocess: {response_type!r}')
        return value
class MockMethod:
    """Callable proxy for the method ``name`` of a remote calculator.

    Calling it sends ``'callmethod'`` followed by ``[name, args, kwargs]``
    over ``calc.protocol`` and returns the reply received there."""

    def __init__(self, name, calc):
        self.name = name
        self.calc = calc

    def __call__(self, *args, **kwargs):
        channel = self.calc.protocol
        request = [self.name, args, kwargs]
        for message in ('callmethod', request):
            channel.send(message)
        return channel.recv()
class ParallelBackendInterface:
    """Attribute-style access to methods of the calculator in the subprocess.

    ``backend.foo(*args, **kwargs)`` calls ``foo`` on the remote calculator
    through a MockMethod proxy."""

    def __init__(self, calc):
        self.calc = calc

    def __getattr__(self, name):
        # Only reached for attributes not found normally.  Generic
        # machinery (copy, pickle, inspect, ...) probes instances for
        # optional special attributes such as __setstate__ or __wrapped__
        # and must get AttributeError rather than a remote proxy.
        if name.startswith('__') and name.endswith('__'):
            raise AttributeError(name)
        # Read calc via __dict__: going through self.calc on an instance
        # created without __init__ would re-enter __getattr__ forever.
        try:
            calc = self.__dict__['calc']
        except KeyError:
            raise AttributeError(name) from None
        return MockMethod(name, calc)
# Accepted values of sys.argv[1] for the subprocess (validated in
# parallel_startup); 'mpi4py' additionally imports mpi4py at startup.
run_modes = {'standard', 'mpi4py'}
def callmethod(calc, attrname, args, kwargs):
    """Call ``calc.<attrname>(*args, **kwargs)`` and return the result."""
    bound = getattr(calc, attrname)
    return bound(*args, **kwargs)
def callfunction(func, args, kwargs):
    """Invoke ``func`` with positional ``args`` and keyword ``kwargs``."""
    result = func(*args, **kwargs)
    return result
def calculate(calc, atoms, properties, system_changes):
    """Run ``calc`` on ``atoms`` and return ``calc.results``.

    ``calc.results`` is emptied before the calculation, so the returned
    dict only holds what this calculation stored in it.
    """
    # NOTE(review): per the original authors, skipping clear() broke the
    # caching of stress (but not forces); the root cause was not found.
    # Ideally results/outputs would be formalized so that all available
    # properties can be accessed programmatically.
    calc.results.clear()
    calc.calculate(atoms=atoms, properties=properties,
                   system_changes=system_changes)
    return calc.results
def bad_mode():
    """Build (but do not raise) the SystemExit for an invalid run mode."""
    message = f'sys.argv[1] must be one of {run_modes}'
    return SystemExit(message)
def parallel_startup():
    """Validate the run mode in ``sys.argv[1]`` and create the Client.

    Raises SystemExit (from bad_mode) when the run mode is missing or not
    in ``run_modes``.  Redirects ``sys.stdout`` to ``sys.stderr``.
    """
    args = sys.argv[1:]
    if not args or args[0] not in run_modes:
        raise bad_mode()

    if args[0] == 'mpi4py':
        # mpi4py must be imported before the rest of ASE, otherwise
        # world will not be correctly initialized.
        import mpi4py  # noqa

    # Keep the real binary stdout for protocol messages, then send stray
    # print output to stderr so it cannot corrupt the pickle stream.
    protocol_out = sys.stdout.buffer
    sys.stdout = sys.stderr

    return Client(input_fd=sys.stdin.buffer, output_fd=protocol_out)
class Client:
    """Server side of the protocol, running inside the subprocess.

    Rank 0 of ``ase.parallel.world`` reads pickled messages from
    ``input_fd`` and broadcasts them to all ranks; only rank 0 writes
    pickled replies to ``output_fd``."""

    def __init__(self, input_fd, output_fd):
        from ase.parallel import world
        self._world = world
        self.input_fd = input_fd
        self.output_fd = output_fd

    def recv(self):
        """Return the next message on all ranks (read by rank 0, broadcast)."""
        from ase.parallel import broadcast
        if self._world.rank == 0:
            obj = pickle.load(self.input_fd)
        else:
            obj = None

        obj = broadcast(obj, 0, self._world)
        return obj

    def send(self, obj):
        """Pickle ``obj`` back to the parent process (rank 0 only)."""
        if self._world.rank == 0:
            pickle.dump(obj, self.output_fd)
            self.output_fd.flush()

    def mainloop(self, calc):
        """Serve instructions for ``calc`` until a ``'stop'`` instruction."""
        while True:
            instruction = self.recv()
            if instruction == 'stop':
                return

            instruction_data = self.recv()

            response_type, value = self.process_instruction(
                calc, instruction, instruction_data)
            self.send((response_type, value))

    def process_instruction(self, calc, instruction, instruction_data):
        """Execute one instruction and return ``(response_type, value)``.

        ``response_type`` is ``'return'`` with the result as ``value``, or
        ``'raise'`` with the exception raised while executing it.

        Raises:
            RuntimeError: for an unknown instruction.
        """
        if instruction == 'callmethod':
            function = callmethod
            args = (calc, *instruction_data)
        elif instruction == 'calculate':
            function = calculate
            args = (calc, *instruction_data)
        elif instruction == 'callfunction':
            function = callfunction
            args = instruction_data
        else:
            raise RuntimeError(f'Bad instruction: {instruction}')

        try:
            value = function(*args)
        except Exception as ex:
            # Hand the exception back to the parent instead of dying, and
            # log the traceback on stderr for debugging.
            import traceback
            traceback.print_exc()
            response_type = 'raise'
            value = ex
        else:
            response_type = 'return'
        return response_type, value
class ParallelDispatch:
    """Utility class to run functions in parallel.

    Usage::

        with ParallelDispatch(mpicommand) as parallel:
            result = parallel.call(function, *args, **kwargs)

    Each ``call`` pickles ``function`` with its arguments, runs it in the
    subprocess started by ``mpicommand`` and returns the result (or raises
    the exception it raised).
    """

    def __init__(self, mpicommand):
        self._mpicommand = mpicommand
        # Connection to the running subprocess; None while not running.
        self._protocol = None

    def call(self, func, *args, **kwargs):
        """Run ``func(*args, **kwargs)`` in the subprocess; return its result."""
        self._protocol.send('callfunction')
        self._protocol.send((func, args, kwargs))
        return self._protocol.recv()

    def __enter__(self):
        assert self._protocol is None
        self._protocol = Protocol(self._mpicommand.execute())

        # Even if we are not using a calculator, we have to send one:
        pack = NamedPackedCalculator('emt', {})
        self._protocol.send(pack)
        # (We should get rid of that requirement.)

        return self

    def __exit__(self, *args):
        protocol = self._protocol
        # Mark as stopped first so the object is reusable even on failure.
        self._protocol = None
        try:
            protocol.send('stop')
        finally:
            # Always reap the subprocess, even if the pipe is already broken.
            protocol.proc.communicate()
def main():
    """Subprocess entry point: build the calculator, then serve requests."""
    client = parallel_startup()
    packed = client.recv()
    client.mainloop(packed.unpack_calculator())


if __name__ == '__main__':
    main()