Coverage for /builds/kinetik161/ase/ase/parallel.py: 49.11%

224 statements  

coverage.py v7.2.7, created at 2023-12-10 11:04 +0000

import atexit
import functools
import os
import pickle
import sys
import time
import warnings

import numpy as np

def get_txt(txt, rank):
    if hasattr(txt, 'write'):
        # Note: User-supplied object might write to files from many ranks.
        return txt
    elif rank == 0:
        if txt is None:
            return open(os.devnull, 'w')
        elif txt == '-':
            return sys.stdout
        else:
            return open(txt, 'w', 1)
    else:
        return open(os.devnull, 'w')

def paropen(name, mode='r', buffering=-1, encoding=None, comm=None):
    """MPI-safe version of open function.

    In read mode, the file is opened on all nodes.  In write and
    append mode, the file is opened on the master only, and /dev/null
    is opened on all other nodes.
    """
    if comm is None:
        comm = world
    if comm.rank > 0 and mode[0] != 'r':
        name = os.devnull
    return open(name, mode, buffering, encoding)


def parprint(*args, **kwargs):
    """MPI-safe print - prints only from master."""
    if world.rank == 0:
        print(*args, **kwargs)
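A minimal usage sketch for the two helpers above (not part of the module source); the filename 'results.txt' is just an illustrative placeholder:

from ase.parallel import paropen, parprint

with paropen('results.txt', 'w') as fd:
    fd.write('hello from the master rank\n')   # only rank 0 writes to the real file

parprint('step finished')                      # printed once instead of once per rank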

class DummyMPI:
    rank = 0
    size = 1

    def _returnval(self, a, root=-1):
        # MPI interface works either on numbers, in which case a number is
        # returned, or on arrays, in-place.
        if np.isscalar(a):
            return a
        if hasattr(a, '__array__'):
            a = a.__array__()
        assert isinstance(a, np.ndarray)
        return None

    def sum(self, a, root=-1):
        if np.isscalar(a):
            warnings.warn('Please use sum_scalar(...) for scalar arguments',
                          FutureWarning)
        return self._returnval(a)

    def sum_scalar(self, a, root=-1):
        return a

    def product(self, a, root=-1):
        return self._returnval(a)

    def broadcast(self, a, root):
        assert root == 0
        return self._returnval(a)

    def barrier(self):
        pass
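A small sketch of how this serial fallback behaves (illustrative, not part of the module):

import numpy as np
from ase.parallel import DummyMPI

comm = DummyMPI()
assert comm.rank == 0 and comm.size == 1
assert comm.sum_scalar(5) == 5   # reductions are no-ops in a serial run
a = np.ones(3)
comm.sum(a)                      # array variant works in place and returns None
comm.barrier()                   # also a no-op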

class MPI:
    """Wrapper for MPI world object.

    Decides at runtime (after all imports) which one to use:

    * MPI4Py
    * GPAW
    * a dummy implementation for serial runs
    """

    def __init__(self):
        self.comm = None

    def __getattr__(self, name):
        # Pickling of objects that carry instances of MPI class
        # (e.g. NEB) raises RecursionError since it tries to access
        # the optional __setstate__ method (which we do not implement)
        # when unpickling.  The two lines below prevent the
        # RecursionError.  This also affects modules that use pickling,
        # e.g. multiprocessing.  For more details see:
        # https://gitlab.com/ase/ase/-/merge_requests/2695
        if name == '__setstate__':
            raise AttributeError(name)

        if self.comm is None:
            self.comm = _get_comm()
        return getattr(self.comm, name)

def _get_comm():
    """Get the correct MPI world object."""
    if 'mpi4py' in sys.modules:
        return MPI4PY()
    if '_gpaw' in sys.modules:
        import _gpaw
        if hasattr(_gpaw, 'Communicator'):
            return _gpaw.Communicator()
    if '_asap' in sys.modules:
        import _asap
        if hasattr(_asap, 'Communicator'):
            return _asap.Communicator()
    return DummyMPI()

class MPI4PY:
    def __init__(self, mpi4py_comm=None):
        if mpi4py_comm is None:
            from mpi4py import MPI
            mpi4py_comm = MPI.COMM_WORLD
        self.comm = mpi4py_comm

    @property
    def rank(self):
        return self.comm.rank

    @property
    def size(self):
        return self.comm.size

    def _returnval(self, a, b):
        """Behave correctly when working on scalars/arrays.

        Either input is an array and we in-place write b (output from
        mpi4py) back into a, or input is a scalar and we return the
        corresponding output scalar."""
        if np.isscalar(a):
            assert np.isscalar(b)
            return b
        else:
            assert not np.isscalar(b)
            a[:] = b
            return None

    def sum(self, a, root=-1):
        if root == -1:
            b = self.comm.allreduce(a)
        else:
            b = self.comm.reduce(a, root)
        if np.isscalar(a):
            warnings.warn('Please use sum_scalar(...) for scalar arguments',
                          FutureWarning)
        return self._returnval(a, b)

    def sum_scalar(self, a, root=-1):
        if root == -1:
            b = self.comm.allreduce(a)
        else:
            b = self.comm.reduce(a, root)
        return b

    def split(self, split_size=None):
        """Divide the communicator."""
        # color - subgroup id
        # key - new subgroup rank
        if not split_size:
            split_size = self.size
        color = int(self.rank // (self.size / split_size))
        key = int(self.rank % (self.size / split_size))
        comm = self.comm.Split(color, key)
        return MPI4PY(comm)

    def barrier(self):
        self.comm.barrier()

    def abort(self, code):
        self.comm.Abort(code)

    def broadcast(self, a, root):
        b = self.comm.bcast(a, root=root)
        if self.rank == root:
            if np.isscalar(a):
                return a
            return
        return self._returnval(a, b)
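A short sketch of MPI4PY.split in use, assuming mpi4py is installed and the script is launched under MPI (e.g. with 4 ranks); the serial dummy communicator does not implement split:

from mpi4py import MPI  # importing mpi4py first makes ase.parallel pick the mpi4py backend
from ase.parallel import world

sub = world.split(2)    # 4 ranks -> 2 subgroups of 2 ranks each
print(f'global rank {world.rank} is rank {sub.rank} of {sub.size} in its subgroup')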

world = None

# Check for special MPI-enabled Python interpreters:
if '_gpaw' in sys.builtin_module_names:
    # http://wiki.fysik.dtu.dk/gpaw
    import _gpaw
    world = _gpaw.Communicator()
elif '_asap' in sys.builtin_module_names:
    # Modern version of Asap
    # http://wiki.fysik.dtu.dk/asap
    # We cannot import asap3.mpi here, as that creates an import deadlock
    import _asap
    world = _asap.Communicator()

# Check if MPI implementation has been imported already:
elif '_gpaw' in sys.modules:
    # Same thing as above but for the module version
    import _gpaw
    try:
        world = _gpaw.Communicator()
    except AttributeError:
        pass
elif '_asap' in sys.modules:
    import _asap
    try:
        world = _asap.Communicator()
    except AttributeError:
        pass
elif 'mpi4py' in sys.modules:
    world = MPI4PY()

if world is None:
    world = MPI()
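A minimal sketch of using the module-level world object selected above; the same code works whether the backend is the serial dummy or a real MPI communicator:

from ase.parallel import world

print(f'this is rank {world.rank} of {world.size}')
world.barrier()   # no-op in serial, a real barrier under MPI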

def barrier():
    world.barrier()

def broadcast(obj, root=0, comm=world):
    """Broadcast a Python object across an MPI communicator and return it."""
    if comm.rank == root:
        string = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
        n = np.array([len(string)], int)
    else:
        string = None
        n = np.empty(1, int)
    comm.broadcast(n, root)
    if comm.rank == root:
        string = np.frombuffer(string, np.int8)
    else:
        string = np.zeros(n, np.int8)
    comm.broadcast(string, root)
    if comm.rank == root:
        return obj
    else:
        return pickle.loads(string.tobytes())
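A usage sketch for broadcast; the dictionary contents are an illustrative placeholder:

from ase.parallel import broadcast, world

# Rank 0 builds the object; every rank returns an identical copy.
settings = {'kpts': (4, 4, 4)} if world.rank == 0 else None
settings = broadcast(settings, root=0)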

def parallel_function(func):
    """Decorator for broadcasting from master to slaves using MPI.

    Disable by passing parallel=False to the function.  For a method,
    you can also disable the parallel behavior by giving the instance
    a self.serial = True.
    """

    @functools.wraps(func)
    def new_func(*args, **kwargs):
        if (world.size == 1 or
                args and getattr(args[0], 'serial', False) or
                not kwargs.pop('parallel', True)):
            # Disable:
            return func(*args, **kwargs)

        ex = None
        result = None
        if world.rank == 0:
            try:
                result = func(*args, **kwargs)
            except Exception as x:
                ex = x
        ex, result = broadcast((ex, result))
        if ex is not None:
            raise ex
        return result

    return new_func
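A sketch of the decorator in use; read_input and the filename are hypothetical:

from ase.parallel import parallel_function

@parallel_function
def read_input(filename):
    # Only rank 0 executes this body; the result is broadcast to all ranks.
    with open(filename) as fd:
        return fd.read()

text = read_input('input.dat')                  # same string on every rank
text = read_input('input.dat', parallel=False)  # run on every rank instead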

def parallel_generator(generator):
    """Decorator for broadcasting yields from master to slaves using MPI.

    Disable by passing parallel=False to the function.  For a method,
    you can also disable the parallel behavior by giving the instance
    a self.serial = True.
    """

    @functools.wraps(generator)
    def new_generator(*args, **kwargs):
        if (world.size == 1 or
                args and getattr(args[0], 'serial', False) or
                not kwargs.pop('parallel', True)):
            # Disable:
            for result in generator(*args, **kwargs):
                yield result
            return

        if world.rank == 0:
            try:
                for result in generator(*args, **kwargs):
                    broadcast((None, result))
                    yield result
            except Exception as ex:
                broadcast((ex, None))
                raise ex
            broadcast((None, None))
        else:
            ex2, result = broadcast((None, None))
            if ex2 is not None:
                raise ex2
            while result is not None:
                yield result
                ex2, result = broadcast((None, None))
                if ex2 is not None:
                    raise ex2

    return new_generator
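A corresponding sketch for the generator variant; read_lines and the filename are hypothetical:

from ase.parallel import parallel_generator

@parallel_generator
def read_lines(filename):
    # Only rank 0 reads the file; each yielded line is broadcast to the others.
    with open(filename) as fd:
        for line in fd:
            yield line

lines = list(read_lines('log.txt'))   # identical list on every rank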

def register_parallel_cleanup_function():
    """Call MPI_Abort if python crashes.

    This will terminate the processes on the other nodes."""

    if world.size == 1:
        return

    def cleanup(sys=sys, time=time, world=world):
        error = getattr(sys, 'last_type', None)
        if error:
            sys.stdout.flush()
            sys.stderr.write(('ASE CLEANUP (node %d): %s occurred.  ' +
                              'Calling MPI_Abort!\n') % (world.rank, error))
            sys.stderr.flush()
            # Give other nodes a moment to crash by themselves (perhaps
            # producing helpful error messages):
            time.sleep(3)
            world.abort(42)

    atexit.register(cleanup)
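A minimal sketch of registering the cleanup hook, typically done once near the top of a parallel script so that an unhandled exception on one rank aborts the whole MPI job instead of leaving the other ranks hanging:

from ase.parallel import register_parallel_cleanup_function

register_parallel_cleanup_function()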

def distribute_cpus(size, comm):
    """Distribute cpus to tasks and calculators.

    Input:
    size: number of nodes per calculator
    comm: total communicator object

    Output:
    communicator for this rank, number of calculators, index for this rank
    """

    assert size <= comm.size
    assert comm.size % size == 0

    tasks_rank = comm.rank // size

    r0 = tasks_rank * size
    ranks = np.arange(r0, r0 + size)
    mycomm = comm.new_communicator(ranks)

    return mycomm, comm.size // size, tasks_rank
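A worked example of the rank arithmetic above for a hypothetical 8-rank communicator split into calculators of 4 ranks each; note that comm must provide a new_communicator method:

# Pure-Python illustration of the arithmetic, no real communicator needed:
size = 4                      # ranks per calculator
for rank in range(8):         # pretend comm.size == 8
    tasks_rank = rank // size
    r0 = tasks_rank * size
    print(rank, '-> calculator', tasks_rank, 'ranks', list(range(r0, r0 + size)))
# Ranks 0-3 form calculator 0 and ranks 4-7 form calculator 1, so
# distribute_cpus(4, comm) would return (subcommunicator, 2, tasks_rank).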

def myslice(ntotal, comm):
    """Return the slice of your tasks for ntotal jobs"""
    n = -(-ntotal // comm.size)  # ceil divide
    return slice(n * comm.rank, n * (comm.rank + 1))
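A short usage sketch for myslice; the job list is an illustrative placeholder:

from ase.parallel import myslice, world

jobs = list(range(10))
my_jobs = jobs[myslice(len(jobs), world)]
# With 4 ranks: rank 0 gets jobs 0-2, rank 1 gets 3-5, rank 2 gets 6-8 and
# rank 3 gets job 9; in a serial run, rank 0 gets all ten jobs.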