Coverage for /builds/alexhroom/ase/ase/optimize/optimize.py: 94.38%
178 statements
« prev ^ index » next coverage.py v7.5.3, created at 2024-08-05 14:37 +0000
« prev ^ index » next coverage.py v7.5.3, created at 2024-08-05 14:37 +0000
1"""Structure optimization. """
2import time
3import warnings
4from collections.abc import Callable
5from math import sqrt
6from os.path import isfile
7from typing import IO, Any, Dict, List, Optional, Tuple, Union
9from ase import Atoms
10from ase.calculators.calculator import PropertyNotImplementedError
11from ase.filters import UnitCellFilter
12from ase.parallel import world
13from ase.utils import IOContext, lazyproperty
14from ase.utils.abc import Optimizable
# Default upper bound on the number of steps taken by run()/irun(); large
# enough that, in practice, a run stops on convergence instead.
DEFAULT_MAX_STEPS = 10 ** 8
class RestartError(RuntimeError):
    """Raised when an optimizer restart file cannot be decoded."""
class OptimizableAtoms(Optimizable):
    """Expose an Atoms object through the Optimizable interface.

    Every method forwards to the wrapped ``atoms`` object.
    """

    def __init__(self, atoms):
        self.atoms = atoms

    def get_positions(self):
        """Return the positions of the wrapped atoms."""
        return self.atoms.get_positions()

    def set_positions(self, positions):
        """Set the positions of the wrapped atoms."""
        self.atoms.set_positions(positions)

    def get_forces(self):
        """Return the forces acting on the wrapped atoms."""
        return self.atoms.get_forces()

    @lazyproperty
    def _use_force_consistent_energy(self):
        """Whether force-consistent energies can be requested.

        Probed once by asking for a force-consistent energy; a
        PropertyNotImplementedError means they are unavailable.  The
        cached answer is not invalidated if the calculator is swapped
        later, which can cause surprises in multi-step optimizations.
        """
        try:
            self.atoms.get_potential_energy(force_consistent=True)
        except PropertyNotImplementedError:
            # NOTE(review): a FutureWarning for calculators lacking
            # 'free_energy' was once drafted here but never enabled.
            return False
        return True

    def get_potential_energy(self):
        """Return the energy, force-consistent whenever that is available."""
        use_fc = self._use_force_consistent_energy
        return self.atoms.get_potential_energy(force_consistent=use_fc)

    def iterimages(self):
        """Forward to ``atoms.iterimages()``."""
        # XXX document purpose of iterimages
        return self.atoms.iterimages()

    def __len__(self):
        """Return the number of atoms (not the number of DOFs)."""
        # TODO: return 3 * len(self.atoms), because we want the length
        # of this to be the number of DOFs
        return len(self.atoms)
class Dynamics(IOContext):
    """Base-class for all MD and structure optimization classes."""

    def __init__(
        self,
        atoms: Atoms,
        logfile: Optional[Union[IO, str]] = None,
        trajectory: Optional[str] = None,
        append_trajectory: bool = False,
        master: Optional[bool] = None,
        comm=world,
        *,
        loginterval: int = 1,
    ):
        """Dynamics object.

        Parameters
        ----------
        atoms : Atoms object
            The Atoms object to operate on.

        logfile : file object or str
            If *logfile* is a string, a file with that name will be opened.
            Use '-' for stdout.

        trajectory : Trajectory object or str
            Attach trajectory object. If *trajectory* is a string a
            Trajectory will be constructed. Use *None* for no
            trajectory.

        append_trajectory : bool
            Defaults to False, which causes the trajectory file to be
            overwritten each time the dynamics is restarted from scratch.
            If True, the new structures are appended to the trajectory
            file instead.

        master : bool
            Defaults to None, which causes only rank 0 to save files. If set to
            true, this rank will save files.

        comm : Communicator object
            Communicator to handle parallel file reading and writing.

        loginterval : int, default: 1
            Only write a log line for every *loginterval* time steps.
        """
        self.atoms = atoms
        self.optimizable = atoms.__ase_optimizable__()
        self.logfile = self.openfile(file=logfile, comm=comm, mode='a')
        self.observers: List[Tuple[Callable, int, Tuple, Dict[str, Any]]] = []
        self.nsteps = 0
        self.max_steps = 0  # to be updated in run or irun
        self.comm = comm

        if trajectory is not None:
            if isinstance(trajectory, str):
                from ase.io.trajectory import Trajectory
                mode = "a" if append_trajectory else "w"
                # closelater() ties the file's lifetime to this object.
                trajectory = self.closelater(Trajectory(
                    trajectory, mode=mode, master=master, comm=comm
                ))
            # The trajectory is written as an observer, passing the
            # optimizable (not the raw atoms) as the object to store.
            self.attach(
                trajectory,
                interval=loginterval,
                atoms=self.optimizable,
            )

        self.trajectory = trajectory

    def todict(self) -> Dict[str, Any]:
        """Return a description of this dynamics; subclasses implement it."""
        raise NotImplementedError

    def get_number_of_steps(self):
        """Return the number of steps taken so far."""
        return self.nsteps

    def insert_observer(
        self, function, position=0, interval=1, *args, **kwargs
    ):
        """Insert an observer.

        This can be used for pre-processing before logging and dumping.

        Parameters
        ----------
        function : callable or object with a ``write`` method
            Observer to call. A non-callable object is replaced by its
            ``write`` method. Unlike :meth:`attach`, no description is
            passed to objects that provide ``set_description``.
        position : int, default: 0
            Index in the observer list at which to insert. Observers are
            called in list order, so 0 runs this observer first.
        interval : int, default: 1
            Calling schedule, with the same meaning as in :meth:`attach`.
        *args, **kwargs
            Passed to *function* on every call.

        Examples
        --------
        >>> from ase.build import bulk
        >>> from ase.calculators.emt import EMT
        >>> from ase.optimize import BFGS
        ...
        ...
        >>> def update_info(atoms, opt):
        ...     atoms.info["nsteps"] = opt.nsteps
        ...
        ...
        >>> atoms = bulk("Cu", cubic=True) * 2
        >>> atoms.rattle()
        >>> atoms.calc = EMT()
        >>> with BFGS(atoms, logfile=None, trajectory="opt.traj") as opt:
        ...     opt.insert_observer(update_info, atoms=atoms, opt=opt)
        ...     opt.run(fmax=0.05, steps=10)
        True
        """
        if not isinstance(function, Callable):
            function = function.write
        self.observers.insert(position, (function, interval, args, kwargs))

    def attach(self, function, interval=1, *args, **kwargs):
        """Attach callback function.

        If *interval > 0*, call *function* with arguments *args* and
        keyword arguments *kwargs* at every step whose number is a
        multiple of *interval* (including step 0).

        If *interval <= 0*, call *function* only once, at the step whose
        number equals ``abs(interval)``. Step numbers count from 0.

        A non-callable *function* is replaced by its ``write`` method.
        If it provides ``set_description``, that method first receives
        :meth:`todict` extended with the *interval*.
        """
        if hasattr(function, "set_description"):
            d = self.todict()
            d.update(interval=interval)
            function.set_description(d)
        if not isinstance(function, Callable):
            function = function.write
        self.observers.append((function, interval, args, kwargs))

    def call_observers(self):
        """Call every observer that is due at the current step.

        Observers with a positive interval run whenever ``nsteps`` is a
        multiple of the interval. Observers with a non-positive interval
        run only when ``nsteps == abs(interval)``.
        """
        for function, interval, args, kwargs in self.observers:
            if interval > 0:
                due = (self.nsteps % interval) == 0
            else:
                due = self.nsteps == abs(interval)
            if due:
                function(*args, **kwargs)

    def irun(self, steps=DEFAULT_MAX_STEPS):
        """Run dynamics algorithm as generator.

        Parameters
        ----------
        steps : int, default=DEFAULT_MAX_STEPS
            Number of dynamics steps to be run.

        Yields
        ------
        converged : bool
            True if the forces on atoms are converged.

        Examples
        --------
        This method allows, e.g., to run two optimizers or MD thermostats at
        the same time.
        >>> opt1 = BFGS(atoms)
        >>> opt2 = BFGS(StrainFilter(atoms)).irun()
        >>> for _ in opt2:
        ...     opt1.run()
        """

        # update the maximum number of steps
        self.max_steps = self.nsteps + steps

        # evaluate forces for the initial configuration before logging
        self.optimizable.get_forces()

        # log the initial step
        if self.nsteps == 0:
            self.log()

        # we write a trajectory file if it is None
        if self.trajectory is None:
            self.call_observers()
        # We do not write on restart w/ an existing trajectory file
        # present. This duplicates the same entry twice
        elif len(self.trajectory) == 0:
            self.call_observers()

        # check convergence
        is_converged = self.converged()
        yield is_converged

        # run the algorithm until converged or max_steps reached
        while not is_converged and self.nsteps < self.max_steps:
            # compute the next step
            self.step()
            self.nsteps += 1

            # log the step
            self.log()
            self.call_observers()

            # check convergence
            is_converged = self.converged()
            yield is_converged

    def run(self, steps=DEFAULT_MAX_STEPS):
        """Run dynamics algorithm.

        This method will return when the forces on all individual
        atoms are less than *fmax* or when the number of steps exceeds
        *steps*.

        Parameters
        ----------
        steps : int, default=DEFAULT_MAX_STEPS
            Number of dynamics steps to be run.

        Returns
        -------
        converged : bool
            True if the forces on atoms are converged.
        """
        # Dynamics.irun always yields at least once, so ``converged`` is
        # bound when the loop ends.
        for converged in Dynamics.irun(self, steps=steps):
            pass
        return converged

    def converged(self):
        """Placeholder convergence criterion; always returns False.

        Subclasses such as Optimizer override this with a real check.
        """
        return False

    def log(self, *args):
        """Placeholder logger that does nothing and returns True.

        Subclasses such as Optimizer override this with a real logger.
        """
        return True

    def step(self):
        """Advance the dynamics by one step; must be implemented by
        subclasses."""
        raise RuntimeError("step not implemented.")
class Optimizer(Dynamics):
    """Base-class for all structure optimization classes."""

    # default maxstep for all optimizers
    defaults = {'maxstep': 0.2}
    # Sentinel default for the deprecated ``force_consistent`` keyword. It
    # lets check_deprecated tell "not passed" apart from any explicit value,
    # including None.
    _deprecated = object()

    def __init__(
        self,
        atoms: Atoms,
        restart: Optional[str] = None,
        logfile: Optional[Union[IO, str]] = None,
        trajectory: Optional[str] = None,
        master: Optional[bool] = None,
        comm=world,
        append_trajectory: bool = False,
        force_consistent=_deprecated,
    ):
        """Structure optimizer object.

        Parameters
        ----------
        atoms : Atoms object
            The Atoms object to relax.

        restart : str
            Filename for restart file. Default value is *None*. If the
            file exists, the optimizer state is restored from it via
            :meth:`read`; otherwise :meth:`initialize` is called.

        logfile : file object or str
            If *logfile* is a string, a file with that name will be opened.
            Use '-' for stdout.

        trajectory : Trajectory object or str
            Attach trajectory object. If *trajectory* is a string a
            Trajectory will be constructed. Use *None* for no
            trajectory.

        master : bool
            Defaults to None, which causes only rank 0 to save files. If
            set to true, this rank will save files.

        comm : Communicator object
            Communicator to handle parallel file reading and writing.

        append_trajectory : bool
            Append to the trajectory file instead of overwriting it.

        force_consistent : deprecated
            Ignored. Passing any value, including None, emits a
            FutureWarning.
        """
        self.check_deprecated(force_consistent)

        super().__init__(
            atoms=atoms,
            logfile=logfile,
            trajectory=trajectory,
            append_trajectory=append_trajectory,
            master=master,
            comm=comm)

        self.restart = restart

        self.fmax = None

        if restart is None or not isfile(restart):
            self.initialize()
        else:
            self.read()
            # Keep all ranks in step after restoring state from file.
            self.comm.barrier()

    @classmethod
    def check_deprecated(cls, force_consistent):
        """Warn if the deprecated ``force_consistent`` keyword was given.

        Returns
        -------
        bool
            True if the keyword was passed (and a FutureWarning was
            emitted), False otherwise.
        """
        if force_consistent is cls._deprecated:
            return False

        warnings.warn(
            'force_consistent keyword is deprecated and will '
            'be ignored. This will raise an error in future versions '
            'of ASE.',
            FutureWarning)
        return True

    def read(self):
        """Restore optimizer state from the restart file.

        Must be implemented by subclasses that support restarting.
        """
        raise NotImplementedError

    def todict(self):
        """Return a dict describing this optimizer and its settings."""
        description = {
            "type": "optimization",
            "optimizer": self.__class__.__name__,
        }
        # add custom attributes from subclasses
        for attr in ('maxstep', 'alpha', 'max_steps', 'restart',
                     'fmax'):
            if hasattr(self, attr):
                description.update({attr: getattr(self, attr)})
        return description

    def initialize(self):
        """Set up a fresh optimizer state; a no-op in the base class."""
        pass

    def irun(self, fmax=0.05, steps=DEFAULT_MAX_STEPS):
        """Run optimizer as generator.

        Parameters
        ----------
        fmax : float
            Convergence criterion of the forces on atoms.
        steps : int, default=DEFAULT_MAX_STEPS
            Number of optimizer steps to be run.

        Yields
        ------
        converged : bool
            True if the forces on atoms are converged.
        """
        self.fmax = fmax
        return Dynamics.irun(self, steps=steps)

    def run(self, fmax=0.05, steps=DEFAULT_MAX_STEPS):
        """Run optimizer.

        Parameters
        ----------
        fmax : float
            Convergence criterion of the forces on atoms.
        steps : int, default=DEFAULT_MAX_STEPS
            Number of optimizer steps to be run.

        Returns
        -------
        converged : bool
            True if the forces on atoms are converged.
        """
        self.fmax = fmax
        return Dynamics.run(self, steps=steps)

    def converged(self, forces=None):
        """Did the optimization converge?

        The check is delegated to ``optimizable.converged`` with the
        current ``fmax``. The forces are fetched from the optimizable
        when not supplied.
        """
        if forces is None:
            forces = self.optimizable.get_forces()
        return self.optimizable.converged(forces, self.fmax)

    def log(self, forces=None):
        """Write one log line for the current step to the logfile.

        A column header is written first when ``nsteps == 0``. The logged
        ``fmax`` is the largest per-row norm of *forces*. This presumably
        assumes an (natoms, 3) array; confirm for non-Atoms optimizables.
        Nothing is written when there is no logfile.
        """
        if forces is None:
            forces = self.optimizable.get_forces()
        fmax = sqrt((forces ** 2).sum(axis=1).max())
        e = self.optimizable.get_potential_energy()
        T = time.localtime()
        if self.logfile is not None:
            name = self.__class__.__name__
            if self.nsteps == 0:
                args = (" " * len(name), "Step", "Time", "Energy", "fmax")
                msg = "%s %4s %8s %15s %12s\n" % args
                self.logfile.write(msg)

            args = (name, self.nsteps, T[3], T[4], T[5], e, fmax)
            msg = "%s: %3d %02d:%02d:%02d %15.6f %15.6f\n" % args
            self.logfile.write(msg)
            self.logfile.flush()

    def dump(self, data):
        """Write *data* as JSON to the restart file.

        Only rank 0 writes, and only if a restart file was given.
        """
        from ase.io.jsonio import write_json
        if self.comm.rank == 0 and self.restart is not None:
            with open(self.restart, 'w') as fd:
                write_json(fd, data)

    def load(self):
        """Read and return the data stored in the restart file.

        Raises
        ------
        RestartError
            If the restart file cannot be decoded as JSON.
        """
        from ase.io.jsonio import read_json

        # Imported locally, presumably to avoid a circular import with
        # the ase.optimize package. Kept outside the ``try`` below so an
        # import failure is not misreported as a JSON decoding problem.
        from ase.optimize import BFGS

        with open(self.restart) as fd:
            if not isinstance(self, BFGS) and isinstance(
                self.atoms, UnitCellFilter
            ):
                warnings.warn(
                    "WARNING: restart function is untested and may result "
                    "in unintended behavior. Namely orig_cell is not "
                    "loaded in the UnitCellFilter. Please test on your own"
                    " to ensure consistent results."
                )
            # Only the decoding step is guarded. Any failure here means
            # the file content is unusable.
            try:
                return read_json(fd, always_array=False)
            except Exception as ex:
                msg = ('Could not decode restart file as JSON. '
                       'You may need to delete the restart file '
                       f'{self.restart}')
                raise RestartError(msg) from ex