Coverage for /builds/kinetik161/ase/ase/calculators/socketio.py: 90.73%
399 statements
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
« prev ^ index » next coverage.py v7.2.7, created at 2023-12-10 11:04 +0000
1import os
2import socket
3from contextlib import contextmanager
4from subprocess import PIPE, Popen
6import numpy as np
8import ase.units as units
9from ase.calculators.calculator import (Calculator,
10 PropertyNotImplementedError,
11 all_changes,
12 ArgvProfile,
13 OldShellProfile)
14from ase.calculators.genericfileio import GenericFileIOCalculator
15from ase.stress import full_3x3_to_voigt_6_stress
16from ase.utils import IOContext
def actualunixsocketname(name):
    """Return the filesystem path used for the UNIX socket called *name*.

    The path is *name* prefixed with ``/tmp/ipi_``.
    """
    return '/tmp/ipi_{}'.format(name)
class SocketClosed(OSError):
    """Raised when a read finds that the peer has closed the socket."""
class IPIProtocol:
    """Communication using IPI protocol.

    Wraps an already connected socket and implements both sides of the
    message exchange: fixed-width 12-byte ASCII headers followed by raw
    numpy buffers.  Quantities are divided/multiplied by ``units.Bohr``
    and ``units.Ha`` on the way out and converted back on the way in.
    """

    def __init__(self, socket, txt=None):
        """Store *socket* and set up logging.

        Parameters:

        socket: connected socket object
            Must provide ``sendall()`` and ``recv()``.
        txt: file object or None
            If given, every log call is printed to it with the prefix
            ``Driver:`` and flushed immediately.  If None, logging is
            a no-op.
        """
        self.socket = socket

        if txt is None:
            def log(*args):
                pass
        else:
            def log(*args):
                print('Driver:', *args, file=txt)
                txt.flush()
        self.log = log

    def sendmsg(self, msg):
        """Send the header string *msg*, space-padded to 12 bytes."""
        self.log(' sendmsg', repr(msg))
        msg = msg.encode('ascii').ljust(12)
        self.socket.sendall(msg)

    def _recvall(self, nbytes):
        """Repeatedly read chunks until we have nbytes.

        Normally we get all bytes in one read, but that is not guaranteed.

        Raises SocketClosed if the peer closes the connection before
        all bytes have arrived."""
        remaining = nbytes
        chunks = []
        while remaining > 0:
            chunk = self.socket.recv(remaining)
            if len(chunk) == 0:
                # (If socket is still open, recv returns at least one byte)
                raise SocketClosed()
            chunks.append(chunk)
            remaining -= len(chunk)
        msg = b''.join(chunks)
        assert len(msg) == nbytes and remaining == 0
        return msg

    def recvmsg(self):
        """Receive one 12-byte header and return it as a stripped string.

        Raises SocketClosed (from _recvall) if the connection is closed;
        _recvall never returns fewer than the requested 12 bytes.
        """
        msg = self._recvall(12)
        msg = msg.rstrip().decode('ascii')
        self.log(' recvmsg', repr(msg))
        return msg

    def send(self, a, dtype):
        """Send *a*, converted to an array of *dtype*, as raw bytes."""
        buf = np.asarray(a, dtype).tobytes()
        self.log(f' send {len(buf)} bytes of {dtype}')
        self.socket.sendall(buf)

    def recv(self, shape, dtype):
        """Receive and return a new array with given *shape* and *dtype*.

        All received values are asserted to be finite.
        """
        a = np.empty(shape, dtype)
        # ndarray.nbytes is a plain Python int equal to
        # itemsize * (number of elements), whatever form *shape* has.
        buf = self._recvall(a.nbytes)
        self.log(f' recv {len(buf)} bytes of {dtype}')
        a.flat[:] = np.frombuffer(buf, dtype=dtype)
        assert np.isfinite(a).all()
        return a

    def sendposdata(self, cell, icell, positions):
        """Send POSDATA: cell, inverse cell, atom count and positions.

        The cell matrices are sent transposed.  Cell and positions are
        divided by ``units.Bohr``; the inverse cell is multiplied by it.
        """
        assert cell.size == 9
        assert icell.size == 9
        assert positions.size % 3 == 0

        self.log(' sendposdata')
        self.sendmsg('POSDATA')
        self.send(cell.T / units.Bohr, np.float64)
        self.send(icell.T * units.Bohr, np.float64)
        self.send(len(positions), np.int32)
        self.send(positions / units.Bohr, np.float64)

    def recvposdata(self):
        """Receive the payload written by sendposdata().

        Returns the tuple (cell, icell, positions) with the unit
        conversion of sendposdata() undone.
        """
        cell = self.recv((3, 3), np.float64).T.copy()
        icell = self.recv((3, 3), np.float64).T.copy()
        natoms = self.recv(1, np.int32)[0]
        positions = self.recv((natoms, 3), np.float64)
        return cell * units.Bohr, icell / units.Bohr, positions * units.Bohr

    def sendrecv_force(self):
        """Send GETFORCE and read back the FORCEREADY reply.

        Returns the tuple (energy, forces, virial, morebytes) where
        energy and virial have been multiplied by ``units.Ha`` and
        forces by ``units.Ha / units.Bohr``.
        """
        self.log(' sendrecv_force')
        self.sendmsg('GETFORCE')
        msg = self.recvmsg()
        assert msg == 'FORCEREADY', msg
        e = self.recv(1, np.float64)[0]
        natoms = self.recv(1, np.int32)[0]
        assert natoms >= 0
        forces = self.recv((int(natoms), 3), np.float64)
        virial = self.recv((3, 3), np.float64).T.copy()
        nmorebytes = self.recv(1, np.int32)[0]
        morebytes = self.recv(nmorebytes, np.byte)
        return (e * units.Ha, (units.Ha / units.Bohr) * forces,
                units.Ha * virial, morebytes)

    def sendforce(self, energy, forces, virial, morebytes=None):
        """Send FORCEREADY followed by energy, forces, virial and extras.

        Parameters:

        energy: scalar
        forces: array of shape (natoms, 3)
        virial: array of shape (3, 3); sent transposed
        morebytes: byte array or None
            Extra payload.  None (default) sends a single zero byte.
        """
        if morebytes is None:
            # Built per call rather than once in the signature so that
            # no array object is shared between calls.
            morebytes = np.zeros(1, dtype=np.byte)

        assert np.array([energy]).size == 1
        assert forces.shape[1] == 3
        assert virial.shape == (3, 3)

        self.log(' sendforce')
        self.sendmsg('FORCEREADY')  # mind the units
        self.send(np.array([energy / units.Ha]), np.float64)
        natoms = len(forces)
        self.send(np.array([natoms]), np.int32)
        self.send(units.Bohr / units.Ha * forces, np.float64)
        self.send(1.0 / units.Ha * virial.T, np.float64)
        # We prefer to always send at least one byte due to trouble with
        # empty messages.  Reading a closed socket yields 0 bytes
        # and thus can be confused with a 0-length bytestring.
        self.send(np.array([len(morebytes)]), np.int32)
        self.send(morebytes, np.byte)

    def status(self):
        """Send STATUS and return the peer's reply string."""
        self.log(' status')
        self.sendmsg('STATUS')
        msg = self.recvmsg()
        return msg

    def end(self):
        """Send EXIT to the peer."""
        self.log(' end')
        self.sendmsg('EXIT')

    def recvinit(self):
        """Receive the payload of an INIT message.

        Returns (bead_index, initbytes).  Note that bead_index is
        returned as the length-1 int32 array that was read, not as a
        scalar.
        """
        self.log(' recvinit')
        bead_index = self.recv(1, np.int32)
        nbytes = self.recv(1, np.int32)
        initbytes = self.recv(nbytes, np.byte)
        return bead_index, initbytes

    def sendinit(self):
        """Send INIT with bead index 0 and a one-byte init string."""
        # XXX Not sure what this function is supposed to send.
        # It 'works' with QE, but for now we try not to call it.
        self.log(' sendinit')
        self.sendmsg('INIT')
        self.send(0, np.int32)  # 'bead index' always zero for now
        # We send one byte, which is zero, since things may not work
        # with 0 bytes.  Apparently implementations ignore the
        # initialization string anyway.
        self.send(1, np.int32)
        self.send(np.zeros(1), np.byte)  # initialization string

    def calculate(self, positions, cell):
        """Run one full request/response cycle for a geometry.

        Sends STATUS (and INIT if the peer answers NEEDINIT), sends the
        geometry, then fetches the results.

        Returns a dict with keys 'energy', 'forces', 'virial' and
        'morebytes'.
        """
        self.log('calculate')
        msg = self.status()
        # We don't know how NEEDINIT is supposed to work, but some codes
        # seem to be okay if we skip it and send the positions instead.
        if msg == 'NEEDINIT':
            self.sendinit()
            msg = self.status()
        assert msg == 'READY', msg
        icell = np.linalg.pinv(cell).transpose()
        self.sendposdata(cell, icell, positions)
        msg = self.status()
        assert msg == 'HAVEDATA', msg
        e, forces, virial, morebytes = self.sendrecv_force()
        r = dict(energy=e,
                 forces=forces,
                 virial=virial,
                 morebytes=morebytes)
        return r
@contextmanager
def bind_unixsocket(socketfile):
    """Context manager yielding a UNIX socket bound to *socketfile*.

    *socketfile* must start with ``/tmp/ipi_``.  On exit the socket is
    closed and the socket file is unlinked.

    Raises OSError, chained to the original error and mentioning the
    path, if binding fails.  The socket is closed in that case so that
    no file descriptor is leaked.
    """
    assert socketfile.startswith('/tmp/ipi_'), socketfile
    serversocket = socket.socket(socket.AF_UNIX)
    try:
        serversocket.bind(socketfile)
    except OSError as err:
        serversocket.close()
        raise OSError(f'{err}: {socketfile!r}') from err

    try:
        with serversocket:
            yield serversocket
    finally:
        os.unlink(socketfile)
@contextmanager
def bind_inetsocket(port):
    """Context manager yielding an INET socket bound to *port*.

    The socket is bound on all interfaces (empty host string) with
    SO_REUSEADDR set, and is closed on exit.  The socket is created
    inside the with-block so it is also closed if setsockopt() or
    bind() raises.
    """
    with socket.socket(socket.AF_INET) as serversocket:
        serversocket.setsockopt(socket.SOL_SOCKET,
                                socket.SO_REUSEADDR, 1)
        serversocket.bind(('', port))
        yield serversocket
class FileIOSocketClientLauncher:
    """Callable that writes input files for a calculator and starts it.

    Calling the object writes the input through the wrapped calculator
    and returns the Popen object of the launched client process.
    """

    def __init__(self, calc):
        """Remember the file-I/O calculator *calc* to be launched."""
        self.calc = calc

    def __call__(self, atoms, properties=None, port=None, unixsocket=None):
        """Write input for *atoms* and launch the client process.

        Parameters:

        atoms: the structure handed to the calculator's input writer
        properties: passed on to the calculator's input writer
        port: port number substituted into the command / argv
        unixsocket: socket name substituted into the command / argv

        Returns the subprocess.Popen object of the client.

        Raises TypeError if an old-style calculator has a profile of a
        type this launcher cannot turn into a command.
        """
        assert self.calc is not None
        cwd = self.calc.directory

        profile = getattr(self.calc, 'profile', None)
        if isinstance(self.calc, GenericFileIOCalculator):
            # New GenericFileIOCalculator:
            self.calc.write_inputfiles(atoms, properties)
            if unixsocket is not None:
                argv = profile.socketio_argv_unix(socket=unixsocket)
            else:
                argv = profile.socketio_argv_inet(port=port)
            return Popen(argv, cwd=cwd, env=os.environ)

        # Old FileIOCalculator:
        if profile is None:
            cmd = self.calc.command.replace('PREFIX', self.calc.prefix)
            cmd = cmd.format(port=port, unixsocket=unixsocket)
        elif isinstance(profile, OldShellProfile):
            cmd = profile.command
            if "PREFIX" in cmd:
                cmd = cmd.replace("PREFIX", profile.prefix)
        elif isinstance(profile, ArgvProfile):
            cmd = " ".join(profile.argv)
        else:
            # Without this branch ``cmd`` would be unbound below and the
            # failure would surface as an obscure UnboundLocalError.
            raise TypeError(
                'Cannot launch socket client: unsupported profile type '
                f'{type(profile).__name__}')

        self.calc.write_input(atoms, properties=properties,
                              system_changes=all_changes)
        return Popen(cmd, shell=True, cwd=cwd)
class SocketServer(IOContext):
    """Server side of a socket connection to a single client."""

    default_port = 31415

    def __init__(self,  # launch_client=None,
                 port=None, unixsocket=None, timeout=None,
                 log=None):
        """Create server and listen for connections.

        The constructor only binds and listens.  The client connection
        is accepted lazily, the first time calculate() is called, which
        then blocks until a client has connected.

        Parameters:

        port: integer or None
            Port on which to listen for INET connections.  Defaults
            to 31415 if neither this nor unixsocket is specified.
        unixsocket: string or None
            Name of the unix socket; the actual socket file is given by
            actualunixsocketname().
        timeout: float or None
            Timeout in seconds, or unlimited by default.
            This parameter is passed to the Python socket object; see
            the documentation thereof.
        log: file object or None
            Useful debug messages are written to this.

        Raises ValueError if both port and unixsocket are given."""
        if unixsocket is not None and port is not None:
            raise ValueError('Specify only one of unixsocket and port')
        if unixsocket is None and port is None:
            port = self.default_port

        self.port = port
        self.unixsocket = unixsocket
        self.timeout = timeout
        self._closed = False

        if unixsocket is None:
            conn_name = f'INET port {port}'
            socket_context = bind_inetsocket(port)
        else:
            actualsocket = actualunixsocketname(unixsocket)
            conn_name = f'UNIX-socket {actualsocket}'
            socket_context = bind_unixsocket(actualsocket)

        # closelater() comes from IOContext; its return value is used
        # as the bound server socket from here on.
        self.serversocket = self.closelater(socket_context)

        if log:
            print(f'Accepting clients on {conn_name}', file=log)

        self.serversocket.settimeout(timeout)
        self.serversocket.listen(1)

        self.log = log

        # Assigned from outside when the owner launches the client
        # process itself; used for crash detection and in close().
        self.proc = None

        # All of these are filled in by _accept().
        self.protocol = None
        self.clientsocket = None
        self.address = None

    def _accept(self):
        """Wait for client and establish connection."""
        # It should perhaps be possible for process to be launched by user
        log = self.log
        if log:
            print('Awaiting client', file=log)

        # If we launched the subprocess, the process may crash.
        # We want to detect this, using loop with timeouts, and
        # raise an error rather than blocking forever.
        if self.proc is not None:
            self.serversocket.settimeout(1.0)

        # NOTE(review): when no subprocess is attached, a socket.timeout
        # raised by accept() is swallowed and the loop simply retries,
        # so ``timeout`` does not bound this wait -- confirm intended.
        while True:
            try:
                self.clientsocket, self.address = self.serversocket.accept()
                self.closelater(self.clientsocket)
            except socket.timeout:
                if self.proc is None:
                    continue
                status = self.proc.poll()
                if status is not None:
                    raise OSError('Subprocess terminated unexpectedly'
                                  f' with status {status}')
            else:
                break

        # Restore the user-requested timeout on both sockets.
        self.serversocket.settimeout(self.timeout)
        self.clientsocket.settimeout(self.timeout)

        if log:
            # For unix sockets, address is b''.
            source = self.address
            if source == b'':
                source = 'client'
            print(f'Accepted connection from {source}', file=log)

        self.protocol = IPIProtocol(self.clientsocket, txt=log)

    def close(self):
        """Close the server; subsequent calls do nothing.

        Closes everything registered with the IOContext, then waits for
        the subprocess (if any) and warns if its exit status is nonzero.
        """
        if self._closed:
            return

        super().close()

        if self.log:
            print('Close socket server', file=self.log)
        self._closed = True

        # Proper way to close sockets?
        # And indeed i-pi connections...
        # (No end-of-communication message is sent to the client here.)
        self.protocol = None

        if self.proc is None:
            return

        exitcode = self.proc.wait()
        if exitcode != 0:
            import warnings

            # Quantum Espresso seems to always exit with status 128,
            # even if successful.
            # Should investigate at some point
            warnings.warn(f'Subprocess exited with status {exitcode}')

    def calculate(self, atoms):
        """Send geometry to client and return calculated things as dict.

        This will block until client has established connection, then
        wait for the client to finish the calculation."""
        assert not self._closed

        # If we have not established connection yet, we must block
        # until the client catches up:
        if self.protocol is None:
            self._accept()
        protocol = self.protocol
        return protocol.calculate(atoms.positions, atoms.cell)
class SocketClient:
    """Client side of a socket connection, driving a calculation loop."""

    def __init__(self, host='localhost', port=None,
                 unixsocket=None, timeout=None, log=None, comm=None):
        """Create client and connect to server.

        Parameters:

        host: string
            Hostname of server.  Defaults to localhost
        port: integer or None
            Port to which to connect.  By default 31415.
        unixsocket: string or None
            If specified, use corresponding UNIX socket.
            See documentation of unixsocket for SocketIOCalculator.
        timeout: float or None
            See documentation of timeout for SocketIOCalculator.
        log: file object or None
            Log events to this file
        comm: communicator or None
            MPI communicator object.  Defaults to ase.parallel.world.
            When ASE runs in parallel, only the process with world.rank == 0
            will communicate over the socket.  The received information
            will then be broadcast on the communicator.  The SocketClient
            must be created on all ranks of world, and will see the same
            Atoms objects."""
        if comm is None:
            from ase.parallel import world
            comm = world

        # Only rank0 actually does the socket work.
        # The other ranks only need to follow.
        #
        # Note: We actually refrain from assigning all the
        # socket-related things except on master
        self.comm = comm

        if self.comm.rank == 0:
            if unixsocket is not None:
                sock = socket.socket(socket.AF_UNIX)
                address = actualunixsocketname(unixsocket)
            else:
                if port is None:
                    port = SocketServer.default_port
                sock = socket.socket(socket.AF_INET)
                address = (host, port)

            try:
                sock.connect(address)
                sock.settimeout(timeout)
            except BaseException:
                # Do not leak the file descriptor if connecting fails.
                sock.close()
                raise

            self.host = host
            self.port = port
            self.unixsocket = unixsocket

            self.protocol = IPIProtocol(sock, txt=log)
            self.log = self.protocol.log
            self.closed = False

        self.bead_index = 0
        self.bead_initbytes = b''
        self.state = 'READY'

    def close(self):
        """Close the socket; subsequent calls do nothing.

        Relies on attributes that __init__ only assigns on rank 0.
        """
        if not self.closed:
            self.log('Close SocketClient')
            self.closed = True
            self.protocol.socket.close()

    def calculate(self, atoms, use_stress):
        """Broadcast the geometry from rank 0 and evaluate *atoms*.

        Returns (energy, forces, virial).  The virial is minus the
        volume times the 3x3 stress when use_stress is true, and a
        3x3 array of zeros otherwise.
        """
        # We should also broadcast the bead index, once we support doing
        # multiple beads.
        self.comm.broadcast(atoms.positions, 0)
        self.comm.broadcast(np.ascontiguousarray(atoms.cell), 0)

        energy = atoms.get_potential_energy()
        forces = atoms.get_forces()
        if use_stress:
            stress = atoms.get_stress(voigt=False)
            virial = -atoms.get_volume() * stress
        else:
            virial = np.zeros((3, 3))
        return energy, forces, virial

    def irun(self, atoms, use_stress=None):
        """Return a generator that yields once per calculation.

        If use_stress is None, stress is calculated when any direction
        of *atoms* is periodic.
        """
        if use_stress is None:
            use_stress = any(atoms.pbc)

        my_irun = self.irun_rank0 if self.comm.rank == 0 else self.irun_rankN
        return my_irun(atoms, use_stress)

    def irun_rankN(self, atoms, use_stress=True):
        """Generator for ranks other than 0.

        Each step waits for the stop flag broadcast by rank 0 and
        either returns or takes part in the calculation.
        """
        stop_criterion = np.zeros(1, bool)
        while True:
            self.comm.broadcast(stop_criterion, 0)
            if stop_criterion[0]:
                return

            self.calculate(atoms, use_stress)
            yield

    def irun_rank0(self, atoms, use_stress=True):
        """Generator for rank 0: serve messages arriving on the socket.

        Yields after each calculation triggered by a POSDATA message.
        Returns on EXIT or when the server closes the connection.  The
        socket is closed when the generator finishes for any reason.

        Raises KeyError on an unrecognized message.
        """
        # For every step we either calculate or quit.  We need to
        # tell other MPI processes (if this is MPI-parallel) whether they
        # should calculate or quit.
        try:
            while True:
                try:
                    msg = self.protocol.recvmsg()
                except SocketClosed:
                    # Server closed the connection, but we want to
                    # exit gracefully anyway
                    msg = 'EXIT'

                if msg == 'EXIT':
                    # Send stop signal to clients:
                    self.comm.broadcast(np.ones(1, bool), 0)
                    # (When otherwise exiting, things crashed and we should
                    # let MPI_ABORT take care of the mess instead of trying
                    # to synchronize the exit)
                    return
                elif msg == 'STATUS':
                    self.protocol.sendmsg(self.state)
                elif msg == 'POSDATA':
                    assert self.state == 'READY'
                    cell, icell, positions = self.protocol.recvposdata()
                    atoms.cell[:] = cell
                    atoms.positions[:] = positions

                    # User may wish to do something with the atoms object now.
                    # Should we provide option to yield here?
                    #
                    # (In that case we should MPI-synchronize *before*
                    #  whereas now we do it after.)

                    # Send signal for other ranks to proceed with calculation:
                    self.comm.broadcast(np.zeros(1, bool), 0)
                    energy, forces, virial = self.calculate(atoms, use_stress)

                    self.state = 'HAVEDATA'
                    yield
                elif msg == 'GETFORCE':
                    # The state check guarantees that energy, forces and
                    # virial were assigned by a preceding POSDATA.
                    assert self.state == 'HAVEDATA', self.state
                    self.protocol.sendforce(energy, forces, virial)
                    self.state = 'NEEDINIT'
                elif msg == 'INIT':
                    assert self.state == 'NEEDINIT'
                    bead_index, initbytes = self.protocol.recvinit()
                    self.bead_index = bead_index
                    self.bead_initbytes = initbytes
                    self.state = 'READY'
                else:
                    raise KeyError('Bad message', msg)
        finally:
            self.close()

    def run(self, atoms, use_stress=False):
        """Run irun() to completion, discarding the intermediate steps."""
        for _ in self.irun(atoms, use_stress=use_stress):
            pass
class SocketIOCalculator(Calculator, IOContext):
    """Calculator that obtains its results from a client over a socket."""

    implemented_properties = ['energy', 'free_energy', 'forces', 'stress']
    supported_changes = {'positions', 'cell'}

    def __init__(self, calc=None, port=None,
                 unixsocket=None, timeout=None, log=None, *,
                 launch_client=None):
        """Initialize socket I/O calculator.

        This calculator launches a server which passes atomic
        coordinates and unit cells to an external code via a socket,
        and receives energy, forces, and stress in return.

        ASE integrates this with the Quantum Espresso, FHI-aims and
        Siesta calculators.  This works with any external code that
        supports running as a client over the i-PI protocol.

        Parameters:

        calc: calculator or None

            If calc is not None, a client process will be launched
            using calc.command, and the input file will be generated
            using ``calc.write_input()``.  Otherwise only the server will
            run, and it is up to the user to launch a compliant client
            process.

        port: integer

            port number for socket.  Should normally be between 1025
            and 65535.  Typical ports are 31415 (default) or 3141.

        unixsocket: str or None

            if not None, ignore host and port, creating instead a
            unix socket using this name prefixed with ``/tmp/ipi_``.
            The socket is deleted when the calculator is closed.

        timeout: float >= 0 or None

            timeout for connection, by default infinite.  See
            documentation of Python sockets.  For longer jobs it is
            recommended to set a timeout in case of undetected
            client-side failure.

        log: file object or None (default)

            logfile for communication over socket.  For debugging or
            the curious.

        launch_client: callable or None (keyword-only)

            called as ``launch_client(atoms, properties, port=...,
            unixsocket=...)`` on the first calculation; its return
            value is stored as the server's subprocess.  Cannot be
            combined with calc (ValueError).

        In order to correctly close the sockets, it is
        recommended to use this class within a with-block:

        >>> from ase.calculators.socketio import SocketIOCalculator

        >>> with SocketIOCalculator(...) as calc:  # doctest:+SKIP
        ...    atoms.calc = calc
        ...    atoms.get_forces()
        ...    atoms.rattle()
        ...    atoms.get_forces()

        It is also possible to call calc.close() after
        use.  This is best done in a finally-block."""

        Calculator.__init__(self)

        if calc is not None:
            if launch_client is not None:
                raise ValueError('Cannot pass both calc and launch_client')
            launch_client = FileIOSocketClientLauncher(calc)
        self.launch_client = launch_client
        self.timeout = timeout
        self.server = None

        self.log = self.openfile(log)

        # We only hold these so we can pass them on to the server.
        # They may both be None as stored here.
        self._port = port
        self._unixsocket = unixsocket

        # If there is a calculator, we will launch in calculate() because
        # we are responsible for executing the external process, too, and
        # should do so before blocking.  Without a calculator we want to
        # block immediately:
        if self.launch_client is None:
            self.server = self.launch_server()

    def todict(self):
        """Return a minimal dict description of this calculator."""
        return dict(type='calculator', name='socket-driver')

    def launch_server(self):
        """Create a SocketServer, register it for closing and return it."""
        server = SocketServer(
            port=self._port,
            unixsocket=self._unixsocket,
            timeout=self.timeout,
            log=self.log,
        )
        return self.closelater(server)

    def calculate(self, atoms=None, properties=['energy'],
                  system_changes=all_changes):
        """Send *atoms* to the client and store the returned results.

        On the first call the server (and the client process, if this
        calculator is responsible for one) is started.

        Raises PropertyNotImplementedError if anything other than
        positions or cell changed after the first calculation.
        """
        unsupported = [change for change in system_changes
                       if change not in self.supported_changes]

        # First time calculate() is called, system_changes will be
        # all_changes.  After that, only positions and cell may change.
        if self.atoms is not None and any(unsupported):
            what = unsupported if len(unsupported) > 1 else unsupported[0]
            raise PropertyNotImplementedError(
                'Cannot change {} through IPI protocol.  '
                'Please create new socket calculator.'
                .format(what))

        self.atoms = atoms.copy()

        if self.server is None:
            self.server = self.launch_server()
            proc = self.launch_client(atoms, properties,
                                      port=self._port,
                                      unixsocket=self._unixsocket)
            self.server.proc = proc  # XXX nasty hack

        results = self.server.calculate(atoms)
        results['free_energy'] = results['energy']
        virial = results.pop('virial')
        # The stress is only derived for a rank-3 cell with at least one
        # periodic direction.
        has_stress = self.atoms.cell.rank == 3 and any(self.atoms.pbc)
        if has_stress:
            vol = atoms.get_volume()
            results['stress'] = -full_3x3_to_voigt_6_stress(virial) / vol
        self.results.update(results)

    def close(self):
        """Forget the server and close everything held by the IOContext."""
        self.server = None
        super().close()
class PySocketIOClient:
    """Launcher that runs a Python calculator as a socket client.

    Calling the object starts ``python -m ase.calculators.socketio`` and
    sends it everything it needs as a pickle on standard input; main()
    is the entry point executed inside that subprocess.
    """

    def __init__(self, calculator_factory):
        """Store *calculator_factory*, a picklable zero-argument callable
        that the subprocess calls to obtain the calculator."""
        self._calculator_factory = calculator_factory

    def __call__(self, atoms, properties=None, port=None, unixsocket=None):
        """Start the client subprocess and return its Popen object.

        *properties* is accepted but not used here.
        """
        import pickle
        import sys

        # We pickle everything first, so we won't need to bother with the
        # process as long as it succeeds.
        socketinfo = dict(unixsocket=unixsocket, port=port)
        payload = [socketinfo, atoms.copy(), self._calculator_factory]
        transferbytes = pickle.dumps(payload)

        command = [sys.executable, '-m', 'ase.calculators.socketio']
        proc = Popen(command, stdin=PIPE)

        proc.stdin.write(transferbytes)
        proc.stdin.close()
        return proc

    @staticmethod
    def main():
        """Entry point of the subprocess started by __call__()."""
        import pickle
        import sys

        # NOTE: unpickling executes arbitrary code.  Standard input is
        # expected to carry only the payload written by __call__().
        socketinfo, atoms, get_calculator = pickle.load(sys.stdin.buffer)
        atoms.calc = get_calculator()
        client = SocketClient(host='localhost',
                              unixsocket=socketinfo.get('unixsocket'),
                              port=socketinfo.get('port'))
        # XXX In principle we could avoid calculating stress until
        # someone requests the stress, could we not?
        # Which would make use_stress boolean unnecessary.
        client.run(atoms, use_stress=True)
# Executed as ``python -m ase.calculators.socketio``: act as the client
# subprocess, reading the pickled job description from standard input.
if __name__ == '__main__':
    PySocketIOClient.main()