Coverage for /builds/alexhroom/ase/ase/calculators/socketio.py: 91.65%
395 statements
« prev ^ index » next coverage.py v7.5.3, created at 2024-08-05 14:37 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2024-08-05 14:37 +0000
1import os
2import socket
3from contextlib import contextmanager
4from subprocess import PIPE, Popen
6import numpy as np
8import ase.units as units
9from ase.calculators.calculator import (StandardProfile, Calculator,
10 OldShellProfile,
11 PropertyNotImplementedError,
12 all_changes)
13from ase.calculators.genericfileio import GenericFileIOCalculator
14from ase.parallel import world
15from ase.stress import full_3x3_to_voigt_6_stress
16from ase.utils import IOContext
def actualunixsocketname(name):
    """Return the filesystem path used for the UNIX socket ``name``.

    The path is ``name`` prefixed with ``/tmp/ipi_``.
    """
    return '/tmp/ipi_{}'.format(name)
class SocketClosed(OSError):
    """Signals that a socket read returned no data (peer closed)."""
class IPIProtocol:
    """Communication using IPI protocol.

    Wraps a connected socket and implements the header messages and
    binary array transfers used on both ends of the connection.
    Lengths and energies are scaled by ``units.Bohr`` and ``units.Ha``
    when they are written to or read from the socket.
    """

    def __init__(self, socket, txt=None):
        """Wrap ``socket``.

        If ``txt`` is not None, every protocol event is printed to it
        (prefixed with ``Driver:``) and flushed immediately.
        """
        self.socket = socket

        if txt is None:
            def log(*args):
                pass
        else:
            def log(*args):
                print('Driver:', *args, file=txt)
                txt.flush()
        self.log = log

    def sendmsg(self, msg):
        """Send header string ``msg``, space-padded to 12 ASCII bytes."""
        self.log(' sendmsg', repr(msg))
        msg = msg.encode('ascii').ljust(12)
        self.socket.sendall(msg)

    def _recvall(self, nbytes):
        """Repeatedly read chunks until we have nbytes.

        Normally we get all bytes in one read, but that is not guaranteed.

        Raises SocketClosed if a read returns no data before ``nbytes``
        bytes have been received.
        """
        remaining = nbytes
        chunks = []
        while remaining > 0:
            chunk = self.socket.recv(remaining)
            if len(chunk) == 0:
                # (If socket is still open, recv returns at least one byte)
                raise SocketClosed
            chunks.append(chunk)
            remaining -= len(chunk)
        msg = b''.join(chunks)
        assert len(msg) == nbytes and remaining == 0
        return msg

    def recvmsg(self):
        """Receive a 12-byte header and return it as a stripped string.

        Raises SocketClosed (from ``_recvall``) if the connection was
        closed before a full header arrived.
        """
        msg = self._recvall(12)
        assert len(msg) == 12, msg
        msg = msg.rstrip().decode('ascii')
        self.log(' recvmsg', repr(msg))
        return msg

    def send(self, a, dtype):
        """Send ``a`` converted to an array of ``dtype`` as raw bytes."""
        buf = np.asarray(a, dtype).tobytes()
        self.log(f' send {len(buf)} bytes of {dtype}')
        self.socket.sendall(buf)

    def recv(self, shape, dtype):
        """Receive an array of the given ``shape`` and ``dtype``.

        Returns a new array; asserts that all received values are finite.
        """
        a = np.empty(shape, dtype)
        # The array knows its own size as a plain Python int, whatever
        # form ``shape`` takes (int, tuple, NumPy scalar or empty tuple).
        nbytes = a.nbytes
        buf = self._recvall(nbytes)
        assert len(buf) == nbytes, (len(buf), nbytes)
        self.log(f' recv {len(buf)} bytes of {dtype}')
        a.flat[:] = np.frombuffer(buf, dtype=dtype)
        assert np.isfinite(a).all()
        return a

    def sendposdata(self, cell, icell, positions):
        """Send a POSDATA message with cell, inverse cell and positions.

        The cell and inverse cell are sent transposed; the number of
        positions is sent as an int32 before the positions themselves.
        """
        assert cell.size == 9
        assert icell.size == 9
        assert positions.size % 3 == 0

        self.log(' sendposdata')
        self.sendmsg('POSDATA')
        self.send(cell.T / units.Bohr, np.float64)
        self.send(icell.T * units.Bohr, np.float64)
        self.send(len(positions), np.int32)
        self.send(positions / units.Bohr, np.float64)

    def recvposdata(self):
        """Receive the payload of a POSDATA message.

        Returns the tuple ``(cell, icell, positions)``, i.e. the inverse
        of what ``sendposdata`` writes after the header.
        """
        cell = self.recv((3, 3), np.float64).T.copy()
        icell = self.recv((3, 3), np.float64).T.copy()
        # Plain int, consistent with sendrecv_force().
        natoms = int(self.recv(1, np.int32)[0])
        positions = self.recv((natoms, 3), np.float64)
        return cell * units.Bohr, icell / units.Bohr, positions * units.Bohr

    def sendrecv_force(self):
        """Send GETFORCE and read the FORCEREADY reply.

        Returns ``(energy, forces, virial, morebytes)``.
        """
        self.log(' sendrecv_force')
        self.sendmsg('GETFORCE')
        msg = self.recvmsg()
        assert msg == 'FORCEREADY', msg
        e = self.recv(1, np.float64)[0]
        natoms = int(self.recv(1, np.int32)[0])
        assert natoms >= 0
        forces = self.recv((natoms, 3), np.float64)
        virial = self.recv((3, 3), np.float64).T.copy()
        nmorebytes = int(self.recv(1, np.int32)[0])
        morebytes = self.recv(nmorebytes, np.byte)
        return (e * units.Ha, (units.Ha / units.Bohr) * forces,
                units.Ha * virial, morebytes)

    def sendforce(self, energy, forces, virial, morebytes=None):
        """Send a FORCEREADY message with energy, forces and virial.

        ``morebytes`` is an optional array of extra bytes appended to the
        message; when omitted a single zero byte is sent.
        """
        if morebytes is None:
            # Built per call rather than as a shared default argument.
            morebytes = np.zeros(1, dtype=np.byte)

        assert np.array([energy]).size == 1
        assert forces.shape[1] == 3
        assert virial.shape == (3, 3)

        self.log(' sendforce')
        self.sendmsg('FORCEREADY')  # mind the units
        self.send(np.array([energy / units.Ha]), np.float64)
        natoms = len(forces)
        self.send(np.array([natoms]), np.int32)
        self.send(units.Bohr / units.Ha * forces, np.float64)
        self.send(1.0 / units.Ha * virial.T, np.float64)
        # We prefer to always send at least one byte due to trouble with
        # empty messages.  Reading a closed socket yields 0 bytes
        # and thus can be confused with a 0-length bytestring.
        self.send(np.array([len(morebytes)]), np.int32)
        self.send(morebytes, np.byte)

    def status(self):
        """Send STATUS and return the reply header."""
        self.log(' status')
        self.sendmsg('STATUS')
        msg = self.recvmsg()
        return msg

    def end(self):
        """Send the EXIT header."""
        self.log(' end')
        self.sendmsg('EXIT')

    def recvinit(self):
        """Receive the payload of an INIT message.

        Returns ``(bead_index, initbytes)``; ``bead_index`` is a
        one-element int32 array and ``initbytes`` an array of bytes.
        """
        self.log(' recvinit')
        bead_index = self.recv(1, np.int32)
        nbytes = self.recv(1, np.int32)
        initbytes = self.recv(int(nbytes[0]), np.byte)
        return bead_index, initbytes

    def sendinit(self):
        """Send an INIT message with bead index 0 and a single zero byte."""
        # XXX Not sure what this function is supposed to send.
        # It 'works' with QE, but for now we try not to call it.
        self.log(' sendinit')
        self.sendmsg('INIT')
        self.send(0, np.int32)  # 'bead index' always zero for now
        # We send one byte, which is zero, since things may not work
        # with 0 bytes.  Apparently implementations ignore the
        # initialization string anyway.
        self.send(1, np.int32)
        self.send(np.zeros(1), np.byte)  # initialization string

    def calculate(self, positions, cell):
        """Run one full exchange: send geometry, receive results.

        Returns a dict with keys ``energy``, ``forces``, ``virial`` and
        ``morebytes``.
        """
        self.log('calculate')
        msg = self.status()
        # We don't know how NEEDINIT is supposed to work, but some codes
        # seem to be okay if we skip it and send the positions instead.
        if msg == 'NEEDINIT':
            self.sendinit()
            msg = self.status()
        assert msg == 'READY', msg
        icell = np.linalg.pinv(cell).transpose()
        self.sendposdata(cell, icell, positions)
        msg = self.status()
        assert msg == 'HAVEDATA', msg
        e, forces, virial, morebytes = self.sendrecv_force()
        r = dict(energy=e,
                 forces=forces,
                 virial=virial,
                 morebytes=morebytes)
        return r
@contextmanager
def bind_unixsocket(socketfile):
    """Context manager yielding a UNIX socket bound to ``socketfile``.

    ``socketfile`` must start with ``/tmp/ipi_``.  On exit the socket is
    closed and the socket file is removed.

    Raises OSError, with the socket path appended to the message, if
    binding fails; the socket is closed before the error propagates.
    """
    assert socketfile.startswith('/tmp/ipi_'), socketfile
    serversocket = socket.socket(socket.AF_UNIX)
    try:
        serversocket.bind(socketfile)
    except OSError as err:
        # Nothing was bound, so only the descriptor needs releasing;
        # the socket file (if any) belongs to someone else.
        serversocket.close()
        raise OSError(f'{err}: {socketfile!r}') from err

    try:
        with serversocket:
            yield serversocket
    finally:
        os.unlink(socketfile)
@contextmanager
def bind_inetsocket(port):
    """Context manager yielding an INET socket bound to ``port``.

    The socket is bound on all interfaces (empty host string) with
    SO_REUSEADDR set.  It is closed on exit, including when setting
    the option or binding fails.
    """
    # Enter the ``with`` before configuring so a failing setsockopt()
    # or bind() cannot leave the descriptor open.
    with socket.socket(socket.AF_INET) as serversocket:
        serversocket.setsockopt(socket.SOL_SOCKET,
                                socket.SO_REUSEADDR, 1)
        serversocket.bind(('', port))
        yield serversocket
class FileIOSocketClientLauncher:
    """Callable that writes input files for ``calc`` and starts the
    external client process that will connect to the socket server."""

    def __init__(self, calc):
        self.calc = calc

    def __call__(self, atoms, properties=None, port=None, unixsocket=None):
        """Write the input files and launch the client process.

        Exactly how the process is started depends on the calculator
        type and on its ``profile`` attribute (if any).

        Returns the launched process object.

        Raises TypeError if the calculator has a profile of a type this
        launcher does not know how to execute.
        """
        assert self.calc is not None
        cwd = self.calc.directory

        profile = getattr(self.calc, 'profile', None)
        if isinstance(self.calc, GenericFileIOCalculator):
            # New GenericFileIOCalculator:
            template = getattr(self.calc, 'template')

            self.calc.write_inputfiles(atoms, properties)
            # Exactly one of unixsocket/port is forwarded; a unix socket
            # takes precedence over the port.
            if unixsocket is not None:
                argv = template.socketio_argv(
                    profile, unixsocket=unixsocket, port=None
                )
            else:
                argv = template.socketio_argv(
                    profile, unixsocket=None, port=port
                )
            return Popen(argv, cwd=cwd, env=os.environ)

        # Old FileIOCalculator:
        self.calc.write_input(atoms, properties=properties,
                              system_changes=all_changes)

        if profile is None:
            cmd = self.calc.command.replace('PREFIX', self.calc.prefix)
            cmd = cmd.format(port=port, unixsocket=unixsocket)
            # The command must actually be started here; previously it
            # was built and then discarded, so no client ever ran.
            return Popen(cmd, shell=True, cwd=cwd)
        if isinstance(profile, OldShellProfile):
            cmd = profile.command.replace("PREFIX", self.calc.prefix)
            return Popen(cmd, shell=True, cwd=cwd)
        if isinstance(profile, StandardProfile):
            return profile.execute_nonblocking(self.calc)

        # Fail loudly rather than returning None, which would leave the
        # server waiting for a client that was never started.
        raise TypeError(
            f'Cannot launch socket client for profile of type '
            f'{type(profile).__name__}')
class SocketServer(IOContext):
    """Server end of the socket connection to a single client."""

    default_port = 31415

    def __init__(self, port=None, unixsocket=None, timeout=None,
                 log=None):
        """Create server and listen for connections.

        The client connection itself is accepted lazily, the first time
        calculate() is called; that call blocks until a client connects.

        Parameters:

        port: integer or None
            Port on which to listen for INET connections.  Defaults
            to 31415 if neither this nor unixsocket is specified.
        unixsocket: string or None
            Filename for unix socket.
        timeout: float or None
            timeout in seconds, or unlimited by default.
            This parameter is passed to the Python socket object; see
            documentation thereof.
        log: file object or None
            useful debug messages are written to this.

        Raises ValueError if both port and unixsocket are given.
        """
        use_unix = unixsocket is not None
        if use_unix and port is not None:
            raise ValueError('Specify only one of unixsocket and port')
        if not use_unix and port is None:
            port = self.default_port

        self.port = port
        self.unixsocket = unixsocket
        self.timeout = timeout
        self._closed = False

        if use_unix:
            actualsocket = actualunixsocketname(unixsocket)
            conn_name = f'UNIX-socket {actualsocket}'
            socket_context = bind_unixsocket(actualsocket)
        else:
            conn_name = f'INET port {port}'
            socket_context = bind_inetsocket(port)

        # closelater() hands back the bound socket and registers it for
        # cleanup when this object is closed.
        self.serversocket = self.closelater(socket_context)

        if log:
            print(f'Accepting clients on {conn_name}', file=log)

        self.serversocket.settimeout(timeout)
        self.serversocket.listen(1)

        self.log = log

        # Set from outside when the caller launched the client itself.
        self.proc = None

        # Filled in by _accept() once a client has connected.
        self.protocol = None
        self.clientsocket = None
        self.address = None

    def _accept(self):
        """Wait for client and establish connection.

        Raises OSError if the launched subprocess (``self.proc``)
        terminates before connecting.
        """
        log = self.log
        if log:
            print('Awaiting client', file=self.log)

        # If we launched the subprocess, the process may crash.
        # We want to detect this, using loop with timeouts, and
        # raise an error rather than blocking forever.
        if self.proc is not None:
            self.serversocket.settimeout(1.0)

        connected = False
        while not connected:
            try:
                self.clientsocket, self.address = self.serversocket.accept()
                self.closelater(self.clientsocket)
            except socket.timeout:
                if self.proc is None:
                    # No process to check on; simply keep waiting.
                    continue
                status = self.proc.poll()
                if status is not None:
                    raise OSError(
                        f'Subprocess terminated unexpectedly'
                        f' with status {status}')
            else:
                connected = True

        # Restore the user-requested timeout on both sockets.
        self.serversocket.settimeout(self.timeout)
        self.clientsocket.settimeout(self.timeout)

        if log:
            # For unix sockets, address is b''.
            if self.address == b'':
                source = 'client'
            else:
                source = self.address
            print(f'Accepted connection from {source}', file=log)

        self.protocol = IPIProtocol(self.clientsocket, txt=log)

    def close(self):
        """Close the sockets and wait for the subprocess, if any.

        Calling close() again after a successful close does nothing.
        A warning is issued if the subprocess exits with nonzero status.
        """
        if self._closed:
            return

        super().close()

        if self.log:
            print('Close socket server', file=self.log)
        self._closed = True

        # Proper way to close sockets?
        # And indeed i-pi connections...
        self.protocol = None
        if self.proc is None:
            return

        exitcode = self.proc.wait()
        if exitcode != 0:
            import warnings

            # Quantum Espresso seems to always exit with status 128,
            # even if successful.
            # Should investigate at some point
            warnings.warn(f'Subprocess exited with status {exitcode}')

    def calculate(self, atoms):
        """Send geometry to client and return calculated things as dict.

        This will block until client has established connection, then
        wait for the client to finish the calculation."""
        assert not self._closed

        # If we have not established connection yet, we must block
        # until the client catches up:
        if self.protocol is None:
            self._accept()
        return self.protocol.calculate(atoms.positions, atoms.cell)
class SocketClient:
    """Client end of the socket connection; computes results on request."""

    def __init__(self, host='localhost', port=None,
                 unixsocket=None, timeout=None, log=None, comm=world):
        """Create client and connect to server.

        Parameters:

        host: string
            Hostname of server.  Defaults to localhost
        port: integer or None
            Port to which to connect.  By default 31415.
        unixsocket: string or None
            If specified, use corresponding UNIX socket.
            See documentation of unixsocket for SocketIOCalculator.
        timeout: float or None
            See documentation of timeout for SocketIOCalculator.
        log: file object or None
            Log events to this file
        comm: communicator or None
            MPI communicator object.  Defaults to ase.parallel.world.
            When ASE runs in parallel, only the process with world.rank == 0
            will communicate over the socket.  The received information
            will then be broadcast on the communicator.  The SocketClient
            must be created on all ranks of world, and will see the same
            Atoms objects."""
        # Only rank0 actually does the socket work.
        # The other ranks only need to follow.
        #
        # Note: We actually refrain from assigning all the
        # socket-related things except on master
        self.comm = comm

        if self.comm.rank == 0:
            if unixsocket is not None:
                family = socket.AF_UNIX
                address = actualunixsocketname(unixsocket)
            else:
                if port is None:
                    port = SocketServer.default_port
                family = socket.AF_INET
                address = (host, port)

            sock = socket.socket(family)
            try:
                sock.connect(address)
                sock.settimeout(timeout)
            except BaseException:
                # Do not leave the descriptor open when the connection
                # cannot be established.
                sock.close()
                raise

            self.host = host
            self.port = port
            self.unixsocket = unixsocket

            self.protocol = IPIProtocol(sock, txt=log)
            self.log = self.protocol.log
            self.closed = False

            self.bead_index = 0
            self.bead_initbytes = b''
            self.state = 'READY'

    def close(self):
        """Close the socket; does nothing if already closed.

        Relies on attributes that __init__ only sets on rank 0.
        """
        if not self.closed:
            self.log('Close SocketClient')
            self.closed = True
            self.protocol.socket.close()

    def calculate(self, atoms, use_stress):
        """Broadcast the geometry from rank 0 and evaluate ``atoms``.

        Returns ``(energy, forces, virial)``.  The virial is
        ``-volume * stress`` when ``use_stress`` is true, else zeros.
        """
        # We should also broadcast the bead index, once we support doing
        # multiple beads.
        self.comm.broadcast(atoms.positions, 0)
        self.comm.broadcast(np.ascontiguousarray(atoms.cell), 0)

        energy = atoms.get_potential_energy()
        forces = atoms.get_forces()
        if use_stress:
            stress = atoms.get_stress(voigt=False)
            virial = -atoms.get_volume() * stress
        else:
            virial = np.zeros((3, 3))
        return energy, forces, virial

    def irun(self, atoms, use_stress=None):
        """Return a generator that yields once per calculation.

        If ``use_stress`` is None, stress is used when any direction of
        ``atoms`` is periodic.
        """
        if use_stress is None:
            use_stress = any(atoms.pbc)

        my_irun = self.irun_rank0 if self.comm.rank == 0 else self.irun_rankN
        return my_irun(atoms, use_stress)

    def irun_rankN(self, atoms, use_stress=True):
        """Generator for ranks other than 0: follow rank 0's decisions."""
        stop_criterion = np.zeros(1, bool)
        while True:
            # Rank 0 broadcasts True to stop, False to calculate.
            self.comm.broadcast(stop_criterion, 0)
            if stop_criterion[0]:
                return

            self.calculate(atoms, use_stress)
            yield

    def irun_rank0(self, atoms, use_stress=True):
        """Generator for rank 0: serve requests arriving on the socket.

        Yields after each calculation.  The socket is closed when the
        generator finishes, whether normally or by an exception.
        """
        # For every step we either calculate or quit.  We need to
        # tell other MPI processes (if this is MPI-parallel) whether they
        # should calculate or quit.
        try:
            while True:
                try:
                    msg = self.protocol.recvmsg()
                except SocketClosed:
                    # Server closed the connection, but we want to
                    # exit gracefully anyway
                    msg = 'EXIT'

                if msg == 'EXIT':
                    # Send stop signal to clients:
                    self.comm.broadcast(np.ones(1, bool), 0)
                    # (When otherwise exiting, things crashed and we should
                    # let MPI_ABORT take care of the mess instead of trying
                    # to synchronize the exit)
                    return
                elif msg == 'STATUS':
                    self.protocol.sendmsg(self.state)
                elif msg == 'POSDATA':
                    assert self.state == 'READY'
                    cell, icell, positions = self.protocol.recvposdata()
                    atoms.cell[:] = cell
                    atoms.positions[:] = positions

                    # User may wish to do something with the atoms object now.
                    # Should we provide option to yield here?
                    #
                    # (In that case we should MPI-synchronize *before*
                    #  whereas now we do it after.)

                    # Send signal for other ranks to proceed with calculation:
                    self.comm.broadcast(np.zeros(1, bool), 0)
                    energy, forces, virial = self.calculate(atoms, use_stress)

                    self.state = 'HAVEDATA'
                    yield
                elif msg == 'GETFORCE':
                    assert self.state == 'HAVEDATA', self.state
                    self.protocol.sendforce(energy, forces, virial)
                    self.state = 'NEEDINIT'
                elif msg == 'INIT':
                    assert self.state == 'NEEDINIT'
                    bead_index, initbytes = self.protocol.recvinit()
                    self.bead_index = bead_index
                    self.bead_initbytes = initbytes
                    self.state = 'READY'
                else:
                    raise KeyError('Bad message', msg)
        finally:
            self.close()

    def run(self, atoms, use_stress=False):
        """Serve requests until the server sends EXIT or disconnects."""
        for _ in self.irun(atoms, use_stress=use_stress):
            pass
class SocketIOCalculator(Calculator, IOContext):
    """Calculator that obtains its results from a client over a socket."""

    implemented_properties = ['energy', 'free_energy', 'forces', 'stress']
    supported_changes = {'positions', 'cell'}

    def __init__(self, calc=None, port=None,
                 unixsocket=None, timeout=None, log=None, *,
                 launch_client=None, comm=world):
        """Initialize socket I/O calculator.

        This calculator launches a server which passes atomic
        coordinates and unit cells to an external code via a socket,
        and receives energy, forces, and stress in return.

        ASE integrates this with the Quantum Espresso, FHI-aims and
        Siesta calculators.  This works with any external code that
        supports running as a client over the i-PI protocol.

        Parameters:

        calc: calculator or None

            If calc is not None, a client process will be launched
            using calc.command, and the input file will be generated
            using ``calc.write_input()``.  Otherwise only the server will
            run, and it is up to the user to launch a compliant client
            process.

        port: integer

            port number for socket.  Should normally be between 1025
            and 65535.  Typical ports are 31415 (default) or 3141.

        unixsocket: str or None

            if not None, ignore host and port, creating instead a
            unix socket using this name prefixed with ``/tmp/ipi_``.
            The socket is deleted when the calculator is closed.

        timeout: float >= 0 or None

            timeout for connection, by default infinite.  See
            documentation of Python sockets.  For longer jobs it is
            recommended to set a timeout in case of undetected
            client-side failure.

        log: file object or None (default)

            logfile for communication over socket.  For debugging or
            the curious.

        launch_client: callable or None

            called as ``launch_client(atoms, properties, port=...,
            unixsocket=...)`` on the first calculation to start the
            client process.  Cannot be combined with calc.

        In order to correctly close the sockets, it is
        recommended to use this class within a with-block:

        >>> from ase.calculators.socketio import SocketIOCalculator

        >>> with SocketIOCalculator(...) as calc:  # doctest:+SKIP
        ...    atoms.calc = calc
        ...    atoms.get_forces()
        ...    atoms.rattle()
        ...    atoms.get_forces()

        It is also possible to call calc.close() after
        use.  This is best done in a finally-block."""

        Calculator.__init__(self)

        if calc is not None:
            if launch_client is not None:
                raise ValueError('Cannot pass both calc and launch_client')
            launch_client = FileIOSocketClientLauncher(calc)

        self.launch_client = launch_client
        self.timeout = timeout
        self.server = None

        self.log = self.openfile(file=log, comm=comm)

        # We only hold these so we can pass them on to the server.
        # They may both be None as stored here.
        self._port = port
        self._unixsocket = unixsocket

        # If there is a calculator, we will launch in calculate() because
        # we are responsible for executing the external process, too, and
        # should do so before blocking.  Without a calculator we want to
        # block immediately:
        if self.launch_client is None:
            self.server = self.launch_server()

    def todict(self):
        """Return a minimal dictionary description of this calculator."""
        return {'type': 'calculator', 'name': 'socket-driver'}

    def launch_server(self):
        """Create the SocketServer and register it to be closed later."""
        server = SocketServer(
            port=self._port,
            unixsocket=self._unixsocket,
            timeout=self.timeout,
            log=self.log,
        )
        return self.closelater(server)

    def calculate(self, atoms=None, properties=['energy'],
                  system_changes=all_changes):
        """Send ``atoms`` to the client and store the returned results.

        Raises PropertyNotImplementedError if, after the first call,
        anything other than positions or cell has changed.
        """
        unsupported = [change for change in system_changes
                       if change not in self.supported_changes]

        # First time calculate() is called, system_changes will be
        # all_changes.  After that, only positions and cell may change.
        if self.atoms is not None and any(unsupported):
            if len(unsupported) > 1:
                what = unsupported
            else:
                what = unsupported[0]
            raise PropertyNotImplementedError(
                f'Cannot change {what} through IPI protocol.  '
                'Please create new socket calculator.')

        self.atoms = atoms.copy()

        if self.server is None:
            self.server = self.launch_server()
            client_process = self.launch_client(
                atoms, properties,
                port=self._port, unixsocket=self._unixsocket)
            # XXX nasty hack: lets the server watch the client process.
            self.server.proc = client_process

        results = self.server.calculate(atoms)
        results['free_energy'] = results['energy']
        virial = results.pop('virial')
        # Stress is only reported for a full-rank cell with some
        # periodic direction.
        if self.atoms.cell.rank == 3 and any(self.atoms.pbc):
            volume = atoms.get_volume()
            results['stress'] = -full_3x3_to_voigt_6_stress(virial) / volume
        self.results.update(results)

    def close(self):
        """Drop the server reference and close registered resources."""
        self.server = None
        super().close()
class PySocketIOClient:
    """Launcher that runs a Python calculator as a socket client in a
    separate interpreter process."""

    def __init__(self, calculator_factory):
        """Store ``calculator_factory``, which is pickled and later
        called without arguments in the child process."""
        self._calculator_factory = calculator_factory

    def __call__(self, atoms, properties=None, port=None, unixsocket=None):
        """Start the child process and return it.

        The socket information, a copy of ``atoms`` and the calculator
        factory are pickled and written to the child's stdin.
        """
        import pickle
        import sys

        # We pickle everything first, so we won't need to bother with the
        # process as long as it succeeds.
        socketinfo = {'unixsocket': unixsocket, 'port': port}
        payload = pickle.dumps(
            [socketinfo, atoms.copy(), self._calculator_factory])

        command = [sys.executable, '-m', 'ase.calculators.socketio']
        child = Popen(command, stdin=PIPE)

        child.stdin.write(payload)
        child.stdin.close()
        return child

    @staticmethod
    def main():
        """Entry point of the child process.

        Reads the pickled payload from stdin, builds the calculator and
        runs a SocketClient until the server ends the session.
        """
        import pickle
        import sys

        # NOTE: unpickling executes arbitrary code; stdin is expected to
        # be the payload written by __call__ in the parent process.
        socketinfo, atoms, get_calculator = pickle.load(sys.stdin.buffer)
        atoms.calc = get_calculator()
        client = SocketClient(
            host='localhost',
            unixsocket=socketinfo.get('unixsocket'),
            port=socketinfo.get('port'),
        )
        # XXX In principle we could avoid calculating stress until
        # someone requests the stress, could we not?
        # Which would make use_stress boolean unnecessary.
        client.run(atoms, use_stress=True)
# Executed as ``python -m ase.calculators.socketio``: act as the child
# process started by PySocketIOClient.__call__.
if __name__ == '__main__':
    PySocketIOClient.main()