Coverage for /builds/kinetik161/ase/ase/parallel.py: 49.11%
224 statements
import atexit
import functools
import os
import pickle
import sys
import time
import warnings

import numpy as np


def get_txt(txt, rank):
    if hasattr(txt, 'write'):
        # Note: User-supplied object might write to files from many ranks.
        return txt
    elif rank == 0:
        if txt is None:
            return open(os.devnull, 'w')
        elif txt == '-':
            return sys.stdout
        else:
            return open(txt, 'w', 1)
    else:
        return open(os.devnull, 'w')


def paropen(name, mode='r', buffering=-1, encoding=None, comm=None):
    """MPI-safe version of open function.

    In read mode, the file is opened on all nodes.  In write and
    append mode, the file is opened on the master only, and /dev/null
    is opened on all other nodes.
    """
    if comm is None:
        comm = world
    if comm.rank > 0 and mode[0] != 'r':
        name = os.devnull
    return open(name, mode, buffering, encoding)
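
# Usage sketch (not part of the original module; 'log.txt' is a hypothetical
# filename): only the master rank writes to the real file, every other rank
# writes to os.devnull, so the file is written once rather than comm.size
# times.
#
#     from ase.parallel import paropen, world
#
#     with paropen('log.txt', 'w') as fd:
#         fd.write('written once, not %d times\n' % world.size)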


def parprint(*args, **kwargs):
    """MPI-safe print - prints only from master."""
    if world.rank == 0:
        print(*args, **kwargs)


class DummyMPI:
    rank = 0
    size = 1

    def _returnval(self, a, root=-1):
        # MPI interface works either on numbers, in which case a number is
        # returned, or on arrays, in-place.
        if np.isscalar(a):
            return a
        if hasattr(a, '__array__'):
            a = a.__array__()
        assert isinstance(a, np.ndarray)
        return None

    def sum(self, a, root=-1):
        if np.isscalar(a):
            warnings.warn('Please use sum_scalar(...) for scalar arguments',
                          FutureWarning)
        return self._returnval(a)

    def sum_scalar(self, a, root=-1):
        return a

    def product(self, a, root=-1):
        return self._returnval(a)

    def broadcast(self, a, root):
        assert root == 0
        return self._returnval(a)

    def barrier(self):
        pass


class MPI:
    """Wrapper for MPI world object.

    Decides at runtime (after all imports) which one to use:

    * MPI4Py
    * GPAW
    * a dummy implementation for serial runs

    """

    def __init__(self):
        self.comm = None

    def __getattr__(self, name):
        # Pickling of objects that carry instances of MPI class
        # (e.g. NEB) raises RecursionError since it tries to access
        # the optional __setstate__ method (which we do not implement)
        # when unpickling.  The two lines below prevent the
        # RecursionError.  This also affects modules that use pickling,
        # e.g. multiprocessing.  For more details see:
        # https://gitlab.com/ase/ase/-/merge_requests/2695
        if name == '__setstate__':
            raise AttributeError(name)

        if self.comm is None:
            self.comm = _get_comm()
        return getattr(self.comm, name)


def _get_comm():
    """Get the correct MPI world object."""
    if 'mpi4py' in sys.modules:
        return MPI4PY()
    if '_gpaw' in sys.modules:
        import _gpaw
        if hasattr(_gpaw, 'Communicator'):
            return _gpaw.Communicator()
    if '_asap' in sys.modules:
        import _asap
        if hasattr(_asap, 'Communicator'):
            return _asap.Communicator()
    return DummyMPI()


class MPI4PY:
    def __init__(self, mpi4py_comm=None):
        if mpi4py_comm is None:
            from mpi4py import MPI
            mpi4py_comm = MPI.COMM_WORLD
        self.comm = mpi4py_comm

    @property
    def rank(self):
        return self.comm.rank

    @property
    def size(self):
        return self.comm.size

    def _returnval(self, a, b):
        """Behave correctly when working on scalars/arrays.

        Either input is an array and we in-place write b (output from
        mpi4py) back into a, or input is a scalar and we return the
        corresponding output scalar."""
        if np.isscalar(a):
            assert np.isscalar(b)
            return b
        else:
            assert not np.isscalar(b)
            a[:] = b
            return None

    def sum(self, a, root=-1):
        if root == -1:
            b = self.comm.allreduce(a)
        else:
            b = self.comm.reduce(a, root)
        if np.isscalar(a):
            warnings.warn('Please use sum_scalar(...) for scalar arguments',
                          FutureWarning)
        return self._returnval(a, b)

    def sum_scalar(self, a, root=-1):
        if root == -1:
            b = self.comm.allreduce(a)
        else:
            b = self.comm.reduce(a, root)
        return b

    def split(self, split_size=None):
        """Divide the communicator."""
        # color - subgroup id
        # key - new subgroup rank
        if not split_size:
            split_size = self.size
        color = int(self.rank // (self.size / split_size))
        key = int(self.rank % (self.size / split_size))
        comm = self.comm.Split(color, key)
        return MPI4PY(comm)

    def barrier(self):
        self.comm.barrier()

    def abort(self, code):
        self.comm.Abort(code)

    def broadcast(self, a, root):
        b = self.comm.bcast(a, root=root)
        if self.rank == root:
            if np.isscalar(a):
                return a
            return
        return self._returnval(a, b)
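
# Sketch of MPI4PY.split (only available when the backend communicator is
# mpi4py; the rank counts below are an assumed example): with 4 ranks and
# split_size=2, color = rank // 2 and key = rank % 2, giving two
# subcommunicators of two ranks each.
#
#     sub = world.split(2)
#     print(world.rank, '->', sub.rank, 'of', sub.size)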


world = None

# Check for special MPI-enabled Python interpreters:
if '_gpaw' in sys.builtin_module_names:
    # http://wiki.fysik.dtu.dk/gpaw
    import _gpaw
    world = _gpaw.Communicator()
elif '_asap' in sys.builtin_module_names:
    # Modern version of Asap
    # http://wiki.fysik.dtu.dk/asap
    # We cannot import asap3.mpi here, as that creates an import deadlock
    import _asap
    world = _asap.Communicator()

# Check if MPI implementation has been imported already:
elif '_gpaw' in sys.modules:
    # Same thing as above but for the module version
    import _gpaw
    try:
        world = _gpaw.Communicator()
    except AttributeError:
        pass
elif '_asap' in sys.modules:
    import _asap
    try:
        world = _asap.Communicator()
    except AttributeError:
        pass
elif 'mpi4py' in sys.modules:
    world = MPI4PY()

if world is None:
    world = MPI()


def barrier():
    world.barrier()


def broadcast(obj, root=0, comm=world):
    """Broadcast a Python object across an MPI communicator and return it."""
    if comm.rank == root:
        string = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
        n = np.array([len(string)], int)
    else:
        string = None
        n = np.empty(1, int)
    comm.broadcast(n, root)
    if comm.rank == root:
        string = np.frombuffer(string, np.int8)
    else:
        string = np.zeros(n, np.int8)
    comm.broadcast(string, root)
    if comm.rank == root:
        return obj
    else:
        return pickle.loads(string.tobytes())
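
# Usage sketch: broadcast() pickles an arbitrary Python object on the root
# rank, sends the byte length and payload, and returns the unpickled copy on
# every other rank.  The dictionary below is a made-up example payload.
#
#     from ase.parallel import broadcast, world
#
#     data = {'energy': -1.23} if world.rank == 0 else None
#     data = broadcast(data, root=0)   # every rank now holds the dict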


def parallel_function(func):
    """Decorator for broadcasting from master to slaves using MPI.

    Disable by passing parallel=False to the function.  For a method,
    you can also disable the parallel behavior by giving the instance
    a self.serial = True.
    """
    @functools.wraps(func)
    def new_func(*args, **kwargs):
        if (world.size == 1 or
            args and getattr(args[0], 'serial', False) or
            not kwargs.pop('parallel', True)):
            # Disable:
            return func(*args, **kwargs)

        ex = None
        result = None
        if world.rank == 0:
            try:
                result = func(*args, **kwargs)
            except Exception as x:
                ex = x
        ex, result = broadcast((ex, result))
        if ex is not None:
            raise ex
        return result

    return new_func
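
# Usage sketch with a hypothetical function: the decorated call runs only on
# the master; its result (or exception) is broadcast to all other ranks, so
# every rank returns the same value.  Passing parallel=False makes every rank
# run the function itself.
#
#     @parallel_function
#     def read_reference_energy(filename):
#         with open(filename) as fd:        # executed on rank 0 only
#             return float(fd.read())
#
#     energy = read_reference_energy('ref.txt')  # same value on all ranks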


def parallel_generator(generator):
    """Decorator for broadcasting yields from master to slaves using MPI.

    Disable by passing parallel=False to the function.  For a method,
    you can also disable the parallel behavior by giving the instance
    a self.serial = True.
    """
    @functools.wraps(generator)
    def new_generator(*args, **kwargs):
        if (world.size == 1 or
            args and getattr(args[0], 'serial', False) or
            not kwargs.pop('parallel', True)):
            # Disable:
            for result in generator(*args, **kwargs):
                yield result
            return

        if world.rank == 0:
            try:
                for result in generator(*args, **kwargs):
                    broadcast((None, result))
                    yield result
            except Exception as ex:
                broadcast((ex, None))
                raise ex
            broadcast((None, None))
        else:
            ex2, result = broadcast((None, None))
            if ex2 is not None:
                raise ex2
            while result is not None:
                yield result
                ex2, result = broadcast((None, None))
                if ex2 is not None:
                    raise ex2

    return new_generator
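
# Usage sketch with a hypothetical generator: each item is produced on the
# master, broadcast, and re-yielded on every rank, so all ranks iterate over
# identical values; a (None, None) pair signals the end of the stream.
#
#     @parallel_generator
#     def read_lines(filename):
#         with open(filename) as fd:        # opened on rank 0 only
#             for line in fd:
#                 yield line.strip()
#
#     for line in read_lines('input.txt'):
#         ...                               # every rank sees the same lines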


def register_parallel_cleanup_function():
    """Call MPI_Abort if python crashes.

    This will terminate the processes on the other nodes."""

    if world.size == 1:
        return

    def cleanup(sys=sys, time=time, world=world):
        error = getattr(sys, 'last_type', None)
        if error:
            sys.stdout.flush()
            sys.stderr.write(('ASE CLEANUP (node %d): %s occurred. ' +
                              'Calling MPI_Abort!\n') % (world.rank, error))
            sys.stderr.flush()
            # Give other nodes a moment to crash by themselves (perhaps
            # producing helpful error messages):
            time.sleep(3)
            world.abort(42)

    atexit.register(cleanup)


def distribute_cpus(size, comm):
    """Distribute cpus to tasks and calculators.

    Input:
    size: number of nodes per calculator
    comm: total communicator object

    Output:
    communicator for this rank, number of calculators, index for this rank
    """

    assert size <= comm.size
    assert comm.size % size == 0

    tasks_rank = comm.rank // size

    r0 = tasks_rank * size
    ranks = np.arange(r0, r0 + size)
    mycomm = comm.new_communicator(ranks)

    return mycomm, comm.size // size, tasks_rank
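
# Sketch (rank counts are an assumed example): with 8 total ranks and size=2,
# distribute_cpus yields 4 calculators; ranks 0-1 form calculator 0,
# ranks 2-3 calculator 1, and so on.  Note that it relies on
# comm.new_communicator(), which the wrappers defined in this file
# (DummyMPI, MPI4PY) do not themselves implement.
#
#     mycomm, ncalcs, icalc = distribute_cpus(2, world)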


def myslice(ntotal, comm):
    """Return the slice of your tasks for ntotal jobs."""
    n = -(-ntotal // comm.size)  # ceil divide
    return slice(n * comm.rank, n * (comm.rank + 1))
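
# Sketch (job count is an assumed example): myslice hands each rank a
# contiguous chunk of ceil(ntotal / comm.size) jobs.  With 10 jobs on 4 ranks,
# rank 0 gets slice(0, 3), rank 1 slice(3, 6), rank 2 slice(6, 9) and rank 3
# slice(9, 12), which clips to the last job when indexing.
#
#     jobs = list(range(10))
#     my_jobs = jobs[myslice(len(jobs), world)]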