Coverage for /builds/kinetik161/ase/ase/calculators/subprocesscalculator.py: 91.41%

198 statements  

« prev     ^ index     » next       coverage.py v7.2.7, created at 2023-12-10 11:04 +0000

1import os 

2import pickle 

3import sys 

4from abc import ABC, abstractmethod 

5from subprocess import PIPE, Popen 

6 

7from ase.calculators.calculator import Calculator, all_properties 

8 

9 

class PackedCalculator(ABC):
    """Portable description of a calculator that lives in another process.

    A packed calculator is handed to PythonSubProcessCalculator, which
    creates the real calculator inside a different process, possibly
    one started with MPI or srun.

    This is meant for scripts that use ASE mostly in serial but want
    some calculations to run in a parallel Python environment.

    NamedPackedCalculator handles most existing calculators.  For
    anything else, subclass this class and implement
    unpack_calculator().

    Example::

        from ase.build import bulk

        atoms = bulk('Au')
        pack = NamedPackedCalculator('emt')

        with pack.calculator() as atoms.calc:
            energy = atoms.get_potential_energy()

    The subprocess doing the computation lives exactly as long as the
    with statement.
    """

    @abstractmethod
    def unpack_calculator(self) -> Calculator:
        """Create and return the calculator described by this object.

        This is called inside the subprocess which performs the
        computations."""

    def calculator(self, mpi_command=None) -> 'PythonSubProcessCalculator':
        """Wrap this object in a PythonSubProcessCalculator.

        The returned calculator forwards its computations to a
        subprocess holding the actual calculator.  mpi_command is
        passed through unchanged."""
        subprocess_calc = PythonSubProcessCalculator(
            self, mpi_command=mpi_command)
        return subprocess_calc

51 

52 

class NamedPackedCalculator(PackedCalculator):
    """PackedCalculator implementation which works with standard calculators.

    This works with calculators known by ase.calculators.calculator:
    the calculator class is looked up by name when unpacking.

    Parameters:

    name: str
        Calculator name understood by get_calculator_class().
    kwargs: dict or None
        Keyword arguments passed to the calculator class when it is
        instantiated.  None means no keyword arguments.
    """

    def __init__(self, name, kwargs=None):
        self._name = name
        self._kwargs = {} if kwargs is None else kwargs

    def unpack_calculator(self):
        """Look up the calculator class by name and instantiate it."""
        # Imported lazily, as in the rest of this module's worker-side code.
        from ase.calculators.calculator import get_calculator_class
        cls = get_calculator_class(self._name)
        return cls(**self._kwargs)

    def __repr__(self):
        # Use !r so that the name shows up quoted, e.g.
        # NamedPackedCalculator('emt', {}) rather than
        # NamedPackedCalculator(emt, {}).
        return f'{type(self).__name__}({self._name!r}, {self._kwargs!r})'

71 

72 

class MPICommand:
    """Command line used to launch the worker subprocess.

    argv is the complete argument list handed to Popen by execute().
    Use the serial() and parallel() constructors for the common cases.
    """

    def __init__(self, argv):
        self.argv = argv

    @classmethod
    def python_argv(cls):
        """Return argv for running this module with the current Python."""
        return [sys.executable, '-m', 'ase.calculators.subprocesscalculator']

    @classmethod
    def parallel(cls, nprocs, mpi_argv=()):
        """Return a command running the worker under mpiexec.

        nprocs is passed to ``mpiexec -n``; mpi_argv holds any extra
        mpiexec arguments.  The worker is started in 'mpi4py' mode."""
        return cls(['mpiexec', '-n', str(nprocs)]
                   + list(mpi_argv)
                   + cls.python_argv()
                   + ['mpi4py'])

    @classmethod
    def serial(cls):
        """Return a command running the worker in 'standard' mode."""
        # Instantiate cls (not MPICommand) so that subclasses get an
        # instance of their own type, consistently with parallel().
        return cls(cls.python_argv() + ['standard'])

    def execute(self):
        """Start the subprocess with piped stdin/stdout and return it."""
        # On this computer (Ubuntu 20.04 + OpenMPI) the subprocess crashes
        # without output during startup if os.environ is not passed along.
        # Hence we pass os.environ.  Not sure if this is a machine thing
        # or in general.  --askhl
        return Popen(self.argv, stdout=PIPE,
                     stdin=PIPE, env=os.environ)

99 

100 

def gpaw_process(ncores=1, **kwargs):
    """Return a subprocess calculator running GPAW on ncores cores.

    The keyword arguments are packed and passed to the 'gpaw'
    calculator when it is created inside the subprocess.  The worker
    is launched through ``python -m gpaw -P <ncores> python ...`` in
    'standard' mode."""
    launcher = [sys.executable, '-m', 'gpaw', '-P', str(ncores)]
    worker = ['python', '-m',
              'ase.calculators.subprocesscalculator', 'standard']
    return PythonSubProcessCalculator(
        NamedPackedCalculator('gpaw', kwargs),
        MPICommand(launcher + worker))

108 

109 

class PythonSubProcessCalculator(Calculator):
    """Calculator for running calculations in external processes.

    TODO: This should work with arbitrary commands including MPI stuff.

    This calculator runs a subprocess wherein it sets up an
    actual calculator.  Calculations are forwarded through pickle
    to that calculator, which returns results through pickle.

    The subprocess only exists while the calculator is used as a
    context manager; self.protocol is None outside the with statement.

    Parameters:

    calc_input:
        Object sent to the subprocess right after startup (in this
        module, a PackedCalculator describing the real calculator).
    mpi_command:
        Object whose execute() starts the subprocess.  None means
        MPICommand.serial().
    """
    implemented_properties = list(all_properties)

    def __init__(self, calc_input, mpi_command=None):
        super().__init__()

        self.calc_input = calc_input
        if mpi_command is None:
            mpi_command = MPICommand.serial()
        self.mpi_command = mpi_command

        self.protocol = None

    def set(self, **kwargs):
        # NOTE(review): nothing in this class assigns a ``client``
        # attribute, so this guard looks like it can never trigger and
        # set() silently ignores its arguments.  Possibly a stale name
        # for ``protocol`` -- confirm the intended behaviour before
        # changing it.
        if hasattr(self, 'client'):
            raise RuntimeError('No setting things for now, thanks')

    def __repr__(self):
        return f'{type(self).__name__}({self.calc_input})'

    def __enter__(self):
        """Start the subprocess and send it the calculator description."""
        assert self.protocol is None
        proc = self.mpi_command.execute()
        protocol = Protocol(proc)
        try:
            protocol.send(self.calc_input)
        except BaseException:
            # __exit__ is not called when __enter__ fails, so make sure
            # the freshly started subprocess is not left behind.
            proc.kill()
            proc.communicate()
            raise
        self.protocol = protocol
        return self

    def __exit__(self, *args):
        """Tell the subprocess to stop and wait for it to finish."""
        protocol = self.protocol
        try:
            protocol.send('stop')
        finally:
            # Even if the subprocess already died (so that sending
            # fails), reset our state and reap the process so that the
            # calculator can be entered again and no child is leaked.
            self.protocol = None
            protocol.proc.communicate()

    def _run_calculation(self, atoms, properties, system_changes):
        # A request is two pickles: the instruction name followed by
        # its data.
        self.protocol.send('calculate')
        self.protocol.send((atoms, properties, system_changes))

    def calculate(self, atoms, properties, system_changes):
        """Run the calculation in the subprocess and store the results.

        Must be called inside the with statement, while the subprocess
        is alive.  Exceptions reported by the subprocess are re-raised
        here by Protocol.recv()."""
        Calculator.calculate(self, atoms, properties, system_changes)
        # We send a copy of self.atoms because this is a fresh copy
        # of the input, but without an unpicklable calculator:
        self._run_calculation(self.atoms.copy(), properties, system_changes)
        results = self.protocol.recv()
        self.results.update(results)

    def backend(self):
        """Return a proxy for calling methods on the remote calculator."""
        return ParallelBackendInterface(self)

165 

166 

class Protocol:
    """Parent-side end of the pickle-based pipe to the subprocess.

    Objects are pickled to the stdin of proc and responses are
    unpickled from its stdout."""

    def __init__(self, proc):
        self.proc = proc

    def send(self, obj):
        """Pickle obj to the subprocess and flush immediately."""
        stream = self.proc.stdin
        pickle.dump(obj, stream)
        stream.flush()

    def recv(self):
        """Read one response; return its value or raise its exception.

        A response is a (response_type, value) pair where
        response_type is either 'return' or 'raise'."""
        kind, payload = pickle.load(self.proc.stdout)

        if kind != 'raise':
            assert kind == 'return'
            return payload

        # The subprocess reported an exception: re-raise it here.
        raise payload

183 

184 

class MockMethod:
    """Callable standing in for a method of the remote calculator.

    Calling it sends a 'callmethod' request through the protocol of
    calc and returns whatever the protocol receives in response."""

    def __init__(self, name, calc):
        self.name = name
        self.calc = calc

    def __call__(self, *args, **kwargs):
        channel = self.calc.protocol
        request = [self.name, args, kwargs]
        # Instruction name first, then its data.
        channel.send('callmethod')
        channel.send(request)
        return channel.recv()

195 

196 

class ParallelBackendInterface:
    """Proxy exposing the remote calculator's methods as attributes.

    Any ordinary attribute access returns a MockMethod which, when
    called, forwards the call to the calculator in the subprocess."""

    def __init__(self, calc):
        self.calc = calc

    def __getattr__(self, name):
        # __getattr__ is only consulted when normal lookup fails.  Do not
        # pretend to have special methods: copy/pickle probe for e.g.
        # __setstate__ on an instance whose __dict__ is still empty, and
        # answering that probe would look up self.calc, which re-enters
        # __getattr__ and recurses forever.  For the same reason a
        # missing 'calc' must be a plain AttributeError.
        if name == 'calc' or (name.startswith('__') and name.endswith('__')):
            raise AttributeError(name)
        return MockMethod(name, self.calc)

203 

204 

# Accepted values of sys.argv[1] when this module runs as the worker
# script; 'mpi4py' makes the worker import mpi4py during startup.
run_modes = {'standard', 'mpi4py'}

206 

207 

def callmethod(calc, attrname, args, kwargs):
    """Call the method of calc named attrname and return its result.

    args and kwargs are the positional and keyword arguments of the
    call."""
    return getattr(calc, attrname)(*args, **kwargs)

212 

213 

def callfunction(func, args, kwargs):
    """Call func with the given positional and keyword arguments."""
    outcome = func(*args, **kwargs)
    return outcome

216 

217 

def calculate(calc, atoms, properties, system_changes):
    """Run calc.calculate() from a clean slate and return calc.results."""
    # There is no formal way yet to access all available outputs
    # programmatically, so as a stopgap we simply hand back the whole
    # results dictionary.
    #
    # The results are cleared first: according to the original author,
    # caching is otherwise broken for stress (though not for forces),
    # for reasons that were never understood.
    calc.results.clear()
    request = dict(atoms=atoms, properties=properties,
                   system_changes=system_changes)
    calc.calculate(**request)
    return calc.results

230 

231 

def bad_mode():
    """Return (without raising) the SystemExit for an invalid run mode."""
    message = f'sys.argv[1] must be one of {run_modes}'
    return SystemExit(message)

234 

235 

def parallel_startup():
    """Validate the run mode, claim stdout and return the worker Client.

    The run mode is read from sys.argv[1] and must be one of run_modes;
    otherwise the SystemExit built by bad_mode() is raised."""
    argv = sys.argv
    try:
        mode = argv[1]
    except IndexError:
        raise bad_mode()

    if mode not in run_modes:
        raise bad_mode()

    if mode == 'mpi4py':
        # mpi4py has to be imported before the rest of ASE, otherwise
        # world will not be correctly initialized.
        import mpi4py  # noqa

    # Keep the binary stdout for the protocol and point sys.stdout at
    # stderr, so stray print statements cannot corrupt the output stream:
    binary_stdout, sys.stdout = sys.stdout.buffer, sys.stderr

    return Client(input_fd=sys.stdin.buffer, output_fd=binary_stdout)

256 

257 

class Client:
    """Worker-side end of the pickle protocol.

    Requests are unpickled from input_fd and responses are pickled to
    output_fd.  Only rank 0 of the ASE world touches the streams:
    received objects are broadcast to all ranks, and only rank 0 writes
    responses.

    The streams are read with pickle, so they must come from a trusted
    parent process.
    """

    def __init__(self, input_fd, output_fd):
        from ase.parallel import world
        self._world = world
        self.input_fd = input_fd
        self.output_fd = output_fd

    def recv(self):
        """Read one object on rank 0 and broadcast it to all ranks."""
        from ase.parallel import broadcast
        if self._world.rank == 0:
            obj = pickle.load(self.input_fd)
        else:
            obj = None

        return broadcast(obj, 0, self._world)

    def send(self, obj):
        """Pickle obj to the output stream (rank 0 only) and flush."""
        if self._world.rank == 0:
            pickle.dump(obj, self.output_fd)
            self.output_fd.flush()

    def mainloop(self, calc):
        """Serve requests until the 'stop' instruction is received.

        Every other instruction is followed by a second object holding
        its data, and is answered with a (response_type, value) pair."""
        while True:
            instruction = self.recv()
            if instruction == 'stop':
                return

            instruction_data = self.recv()

            response_type, value = self.process_instruction(
                calc, instruction, instruction_data)
            self.send((response_type, value))

    def process_instruction(self, calc, instruction, instruction_data):
        """Execute one instruction and return (response_type, value).

        response_type is 'return' with the result as value, or 'raise'
        with the exception raised by the instruction as value.

        Raises RuntimeError if the instruction is unknown; that is a
        protocol error and is not reported back as a response."""
        if instruction == 'callmethod':
            function = callmethod
            args = (calc, *instruction_data)
        elif instruction == 'calculate':
            function = calculate
            args = (calc, *instruction_data)
        elif instruction == 'callfunction':
            function = callfunction
            args = instruction_data
        else:
            raise RuntimeError(f'Bad instruction: {instruction}')

        # (A leftover debugging print of the arguments used to live
        # here; it dumped every request, on every rank.)
        try:
            value = function(*args)
        except Exception as ex:
            # Show the traceback on this side too, since the traceback
            # object itself does not travel with the pickled exception.
            import traceback
            traceback.print_exc()
            return 'raise', ex

        return 'return', value

316 

317 

class ParallelDispatch:
    """Utility class to run functions in parallel.

    with ParallelDispatch(...) as parallel:
        parallel.call(function, *args, **kwargs)

    The function and its arguments are pickled and sent to the worker
    subprocess, so they must be picklable.
    """

    def __init__(self, mpicommand):
        self._mpicommand = mpicommand
        self._protocol = None

    def call(self, func, *args, **kwargs):
        """Run func(*args, **kwargs) in the subprocess; return the result.

        Must be called inside the with statement.  An exception reported
        by the subprocess is re-raised here by Protocol.recv()."""
        self._protocol.send('callfunction')
        self._protocol.send((func, args, kwargs))
        return self._protocol.recv()

    def __enter__(self):
        """Start the subprocess and perform the initial handshake."""
        assert self._protocol is None
        proc = self._mpicommand.execute()
        protocol = Protocol(proc)

        # Even if we are not using a calculator, we have to send one:
        pack = NamedPackedCalculator('emt', {})
        try:
            protocol.send(pack)
        except BaseException:
            # __exit__ is not called when __enter__ fails, so make sure
            # the freshly started subprocess is not left behind.
            proc.kill()
            proc.communicate()
            raise
        # (We should get rid of that requirement.)

        self._protocol = protocol
        return self

    def __exit__(self, *args):
        """Tell the subprocess to stop and wait for it to finish."""
        protocol = self._protocol
        try:
            protocol.send('stop')
        finally:
            # Even if the subprocess already died (so that sending
            # fails), reset our state and reap the process so that this
            # object can be entered again and no child is leaked.
            self._protocol = None
            protocol.proc.communicate()

350 

351 

def main():
    """Entry point of the worker process.

    Sets up the client, receives the packed calculator sent first by
    the parent, unpacks it and serves requests until told to stop."""
    client = parallel_startup()
    calc = client.recv().unpack_calculator()
    client.mainloop(calc)

357 

358 

# Run the worker loop when this module is executed as a script, which
# is how MPICommand.python_argv() launches it (``python -m ...``).
if __name__ == '__main__':
    main()