Coverage for calorine / calculators / cpunep.py: 100%
163 statements
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-15 13:48 +0000
« prev ^ index » next coverage.py v7.13.1, created at 2026-01-15 13:48 +0000
1from __future__ import annotations
3import contextlib
4import os
5from tempfile import TemporaryFile
6from typing import List, Union
8import numpy as np
9from ase import Atoms
10from ase.calculators.calculator import Calculator, all_changes
11from ase.stress import full_3x3_to_voigt_6_stress
13import _nepy
14from calorine.nep.model import _get_nep_contents
15from calorine.nep.nep import _check_components_polarizability_gradient, \
16 _polarizability_gradient_to_3x3
class CPUNEP(Calculator):
    """This class provides an ASE calculator for `nep_cpu`,
    the in-memory CPU implementation of GPUMD.

    Parameters
    ----------
    model_filename : str
        Path to file in ``nep.txt`` format with model parameters
    atoms : Atoms, optional
        Atoms to attach the calculator to
    label : str, optional
        Label for this calculator
    debug : bool, optional
        Flag to toggle debug mode. Prints GPUMD output. Defaults to False.

    Raises
    ------
    FileNotFoundError
        Raises :class:`FileNotFoundError` if :attr:`model_filename` does not point to a valid file.
    ValueError
        Raises :class:`ValueError` if atoms are not defined when trying to get energies and
        forces, if the atoms have no cell, or if they contain species that are not supported
        by the model.

    Example
    -------

    >>> calc = CPUNEP('nep.txt')
    >>> atoms.calc = calc
    >>> atoms.get_potential_energy()
    """

    base_implemented_properties = [
        'energy',
        'energies',
        'forces',
        'stress',
        'stresses',
    ]
    debug = False
    nepy = None

    def __init__(
        self,
        model_filename: str,
        atoms: Atoms | None = None,
        label: str | None = None,
        debug: bool = False,
    ):
        self.debug = debug

        if not os.path.exists(model_filename):
            raise FileNotFoundError(f'{model_filename} does not exist.')
        self.model_filename = str(model_filename)

        # Get model type, supported species and version from the header of nep.txt
        header, _ = _get_nep_contents(self.model_filename)
        self.model_type = header['model_type']
        self.supported_species = set(header['types'])
        self.nep_version = header['version']

        # Set implemented properties -- not use class-level property
        # to avoid leaking state between calculator instances.
        self.implemented_properties = list(self.base_implemented_properties)
        if 'charge' in self.model_type:
            # Only available for charge models
            self.implemented_properties.extend(['charges', 'born_effective_charges'])
        elif self.model_type == 'dipole':
            # Only available for dipole models
            self.implemented_properties = ['dipole']
        elif self.model_type == 'polarizability':
            # Only available for polarizability models
            self.implemented_properties = ['polarizability']

        # Initialize atoms, results and nepy - note that this is also done in Calculator.__init__()
        if atoms is not None:
            self.set_atoms(atoms)
        parameters = {'model_filename': model_filename}
        Calculator.__init__(self, label=label, atoms=atoms, **parameters)
        if atoms is not None:
            self._setup_nepy()

    def __str__(self) -> str:
        """Returns the class name followed by the calculator parameters, one per line."""
        def indent(s: str, i: int) -> str:
            s = '\n'.join([i * ' ' + line for line in s.split('\n')])
            return s

        parameters = '\n'.join(
            [f'{key}: {value}' for key, value in self.parameters.items()]
        )
        parameters = indent(parameters, 4)
        using_debug = '\nIn debug mode' if self.debug else ''

        s = f'{self.__class__.__name__}\n{parameters}{using_debug}'
        return s

    @staticmethod
    def _cell_to_list(atoms: Atoms) -> list:
        """Returns the (completed) cell flattened in the order passed to NEPY,
        i.e. ``[ax, bx, cx, ay, by, cy, az, bz, cz]`` for cell vectors a, b, c."""
        c = atoms.get_cell(complete=True).flatten()
        return [c[0], c[3], c[6], c[1], c[4], c[7], c[2], c[5], c[8]]

    @staticmethod
    def _positions_to_list(atoms: Atoms) -> list:
        """Returns the positions flattened as ``[x1, ..., xN, y1, ..., yN, z1, ..., zN]``."""
        return list(atoms.get_positions().T.flatten())

    def _check_species(self, atoms: Atoms):
        """Raises :class:`ValueError` if ``atoms`` contains species the model does not support."""
        if not set(atoms.get_chemical_symbols()).issubset(self.supported_species):
            raise ValueError('Structure contains species that are not supported by the NEP model.')

    def _setup_nepy(self):
        """
        Creates an instance of the NEPY class and attaches it to the calculator object.
        The output from `nep.cpp` is only written to STDOUT if debug == True

        Raises
        ------
        ValueError
            If no atoms are attached, if the atoms have no cell, or if they contain
            species that are not supported by the model.
        """
        if self.atoms is None:
            raise ValueError('Atoms must be defined when calculating properties.')
        if self.atoms.cell.rank == 0:
            raise ValueError('Atoms must have a defined cell.')
        # Atoms may reach this point via Calculator.calculate(atoms) without passing
        # through set_atoms, so the species are validated here as well.
        self._check_species(self.atoms)

        self.natoms = len(self.atoms)
        nepy_args = (
            self.model_filename,
            self.natoms,
            self._cell_to_list(self.atoms),
            self.atoms.get_chemical_symbols(),
            self._positions_to_list(self.atoms),
            self.atoms.get_masses(),
        )

        if self.debug:
            self.nepy = _nepy.NEPY(*nepy_args)
        else:
            # Disable output from C++ code by default.
            # NOTE(review): redirect_stdout only replaces Python's sys.stdout;
            # confirm that _nepy routes its output through it.
            with TemporaryFile('w') as f, contextlib.redirect_stdout(f):
                self.nepy = _nepy.NEPY(*nepy_args)

    def set_atoms(self, atoms: Atoms):
        """Updates the Atoms object.

        The atoms are stored by reference; cached results and the NEPY
        interface are discarded so that they are rebuilt on next use.

        Parameters
        ----------
        atoms : Atoms
            Atoms to attach the calculator to

        Raises
        ------
        ValueError
            If ``atoms`` contains species that are not supported by the model.
        """
        self._check_species(atoms)
        self.atoms = atoms
        self.results = {}
        self.nepy = None

    def _update_symbols(self):
        """Update atom symbols in NEPY."""
        self.nepy.set_symbols(self.atoms.get_chemical_symbols())

    def _update_masses(self):
        """Update atom masses in NEPY"""
        self.nepy.set_masses(self.atoms.get_masses())

    def _update_cell(self):
        """Update cell parameters in NEPY."""
        self.nepy.set_cell(self._cell_to_list(self.atoms))

    def _update_positions(self):
        """Update atom positions in NEPY."""
        self.nepy.set_positions(self._positions_to_list(self.atoms))

    def _sync_nepy(self, system_changes: List[str]):
        """Brings the NEPY interface in line with ``self.atoms``.

        A new interface is created if none exists yet or if the number of atoms
        has changed; the flat arrays returned by NEPY are reshaped using
        ``self.natoms``, so an interface built for a different atom count cannot
        be reused. Otherwise only the quantities listed in ``system_changes``
        are pushed.
        """
        if self.nepy is None or (
            self.atoms is not None and len(self.atoms) != self.natoms
        ):
            self._setup_nepy()
            return
        if 'numbers' in system_changes:
            self._check_species(self.atoms)
        for change in system_changes:
            if change == 'positions':
                self._update_positions()
            elif change == 'numbers':
                self._update_symbols()
                self._update_masses()
            elif change == 'cell':
                self._update_cell()

    def _refresh_nepy(self):
        """Pushes the full current state of ``self.atoms`` to NEPY (creating it if needed).

        ``set_atoms`` keeps a reference to the user's atoms, which may have been
        modified in place since NEPY was last updated. If no atoms are attached
        but an interface already exists, that interface is used as-is.
        """
        if self.atoms is None and self.nepy is not None:
            return
        self._sync_nepy(all_changes)

    def calculate(
        self,
        atoms: Atoms | None = None,
        properties: List[str] | None = None,
        system_changes: List[str] = all_changes,
    ):
        """Calculates the requested properties for the current structure.

        Depending on ``properties`` this computes the dipole, the polarizability
        or the descriptors; otherwise it computes the energy, per-atom energies,
        forces, stress and per-atom stresses, plus charges and Born effective
        charges for charge models. Results are stored in :attr:`results`.

        Parameters
        ----------
        atoms : Atoms, optional
            System for which to calculate properties, by default None
        properties : List[str], optional
            Properties to calculate, by default None (all implemented properties)
        system_changes : List[str], optional
            Changes to the system since last call, by default all_changes
        """
        if properties is None:
            properties = self.implemented_properties

        Calculator.calculate(self, atoms, properties, system_changes)
        if system_changes:
            # Results obtained for a previous state of the system are no longer
            # valid; drop them so that they are not reported next to new ones.
            self.results = {}
        self._sync_nepy(system_changes)

        if 'dipole' in properties:
            self.results['dipole'] = np.array(self.nepy.get_dipole())
        elif 'polarizability' in properties:
            pol = np.array(self.nepy.get_polarizability())
            # Assemble the symmetric 3x3 tensor from its six independent components
            # (indices 0-2 on the diagonal, 3/4/5 for xy/yz/xz).
            self.results['polarizability'] = np.array([
                [pol[0], pol[3], pol[5]],
                [pol[3], pol[1], pol[4]],
                [pol[5], pol[4], pol[2]],
            ])
        elif 'descriptors' in properties:
            descriptors = np.array(self.nepy.get_descriptors())
            self.results['descriptors'] = descriptors.reshape(-1, self.natoms).T
        else:
            self._calculate_potential_properties()

    def _calculate_potential_properties(self):
        """Computes energies, forces, stresses (and charges for charge models)
        and stores them in :attr:`results`."""
        is_charge_model = 'charge' in self.model_type
        if is_charge_model:
            energies, forces, virials, charges, becs = \
                self.nepy.get_potential_forces_virials_and_charges()
        else:
            energies, forces, virials = self.nepy.get_potential_forces_and_virials()

        volume = self.atoms.get_volume()
        energies_per_atom = np.array(energies)
        # NEPY returns flat arrays holding one block of natoms values per component;
        # reshape + transpose yields one row per atom.
        forces_per_atom = np.array(forces).reshape(-1, self.natoms).T
        virials_per_atom = np.array(virials).reshape(-1, self.natoms).T
        # NOTE(review): the total stress is minus the summed virial over the volume,
        # whereas the per-atom stresses keep the virial's sign. Kept for backward
        # compatibility -- confirm the intended convention.
        stresses_per_atom = virials_per_atom / volume
        stress = -(np.sum(virials_per_atom, axis=0) / volume).reshape((3, 3))

        self.results['energy'] = energies_per_atom.sum()
        # 'energies' is advertised in implemented_properties, so it must be stored too.
        self.results['energies'] = energies_per_atom
        self.results['forces'] = forces_per_atom
        self.results['stress'] = full_3x3_to_voigt_6_stress(stress)
        self.results['stresses'] = stresses_per_atom

        if is_charge_model:
            self.results['charges'] = np.array(charges)
            self.results['born_effective_charges'] = \
                np.array(becs).reshape(-1, self.natoms).T

    def get_dipole_gradient(
        self,
        displacement: float = 0.01,
        method: str = 'central difference',
        charge: float = 1.0,
    ):
        """Calculates the dipole gradient using finite differences.

        The gradient is evaluated for the atoms currently attached to the calculator.

        Parameters
        ----------
        displacement
            Displacement in Å to use for finite differences. Defaults to 0.01 Å.
        method
            Method for computing gradient with finite differences.
            One of 'forward difference', 'central difference' and
            'second order central difference'.
            Defaults to 'central difference'
        charge
            System charge in units of the elemental charge.
            Used for correcting the dipoles before computing the gradient.
            Defaults to 1.0.

        Returns
        -------
        dipole gradient with shape `(N, 3, 3)` where ``N`` are the number of atoms.

        Raises
        ------
        ValueError
            If the model is not a dipole model, if ``displacement`` is not positive,
            or if ``method`` is not recognized.
        """
        if 'dipole' not in self.implemented_properties:
            raise ValueError('Dipole gradients are only defined for dipole NEP models.')

        if displacement <= 0:
            raise ValueError('displacement must be > 0 Å')

        implemented_methods = {
            'forward difference': 0,
            'central difference': 1,
            'second order central difference': 2,
        }

        if method not in implemented_methods:
            raise ValueError(f'Invalid method {method} for calculating gradient')

        # Make sure positions, cell and species in NEPY match the attached atoms
        self._refresh_nepy()

        dipole_gradient = np.array(
            self.nepy.get_dipole_gradient(
                displacement, implemented_methods[method], charge
            )
        ).reshape(self.natoms, 3, 3)
        return dipole_gradient

    def get_polarizability(
        self,
        atoms: Atoms | None = None,
        properties: List[str] | None = None,
        system_changes: List[str] = all_changes,
    ) -> np.ndarray:
        """Calculates the polarizability tensor for the current structure.
        The model must have been trained to predict the polarizability.
        This is a wrapper function for :func:`calculate`.

        Parameters
        ----------
        atoms : Atoms, optional
            System for which to calculate properties, by default None
        properties : List[str], optional
            Properties to calculate, by default None; must contain ``'polarizability'``
        system_changes : List[str], optional
            Changes to the system since last call, by default all_changes

        Returns
        -------
        polarizability with shape ``(3, 3)``

        Raises
        ------
        ValueError
            If the model is not a polarizability model or ``properties`` does not
            include ``'polarizability'``.
        """
        if properties is None:
            properties = self.implemented_properties

        # Check the model itself (as the other gradient/charge getters do), not just
        # the caller-supplied properties.
        if ('polarizability' not in self.implemented_properties
                or 'polarizability' not in properties):
            raise ValueError('Polarizability is only defined for polarizability NEP models.')
        self.calculate(atoms, ['polarizability'], system_changes)
        return self.results['polarizability']

    def get_polarizability_gradient(
        self,
        displacement: float = 0.01,
        component: Union[str, List[str]] = 'full',
    ) -> np.ndarray:
        """Calculates the polarizability gradient for a given structure using finite differences.
        This function computes the derivatives using the second-order central difference
        method with a C++ backend. The gradient is evaluated for the atoms currently
        attached to the calculator.

        Parameters
        ----------
        displacement
            Displacement in Å to use for finite differences. Defaults to ``0.01``.
        component
            Component or components of the polarizability tensor that the gradient
            should be computed for.
            The following components are available: `x`, `y`, `z`, `full`
            Option ``full`` computes the derivative whilst moving the atoms in each Cartesian
            direction, which yields a tensor of shape ``(N, 3, 3, 3)``,
            where ``N`` is the number of atoms.
            Multiple components may be specified.
            Defaults to ``full``.

        Returns
        -------
        polarizability gradient with shape ``(N, C, 3, 3)`` with ``C`` components chosen.

        Raises
        ------
        ValueError
            If the model is not a polarizability model or ``displacement`` is not positive.
        """
        if 'polarizability' not in self.implemented_properties:
            raise ValueError('Polarizability gradients are only defined'
                             ' for polarizability NEP models.')

        if displacement <= 0:
            raise ValueError('displacement must be > 0 Å')

        # Make sure positions, cell and species in NEPY match the attached atoms
        self._refresh_nepy()

        component_array = _check_components_polarizability_gradient(component)

        pg = np.array(
            self.nepy.get_polarizability_gradient(
                displacement, component_array
            )
        ).reshape(self.natoms, 3, 6)
        polarizability_gradient = _polarizability_gradient_to_3x3(self.natoms, pg)
        return polarizability_gradient[:, component_array, :, :]

    def get_descriptors(
        self,
        atoms: Atoms | None = None,
        properties: List[str] | None = None,
        system_changes: List[str] = all_changes,
    ) -> np.ndarray:
        """Calculates the descriptor tensor for the current structure.
        This is a wrapper function for :func:`calculate`.

        Parameters
        ----------
        atoms : Atoms, optional
            System for which to calculate properties, by default None
        properties : List[str], optional
            Ignored; only the descriptors are calculated.
        system_changes : List[str], optional
            Changes to the system since last call, by default all_changes

        Returns
        -------
        descriptors with shape ``(number_of_atoms, descriptor_components)``
        """
        self.calculate(atoms, ['descriptors'], system_changes)
        return self.results['descriptors']

    def get_born_effective_charges(
        self,
        atoms: Atoms | None = None,
        properties: List[str] | None = None,
        system_changes: List[str] = all_changes,
    ) -> np.ndarray:
        """Calculates (if needed) and returns the Born effective charges.
        Note that this requires a qNEP model.

        Parameters
        ----------
        atoms
            System for which to calculate properties, by default `None`.
        properties
            Properties to calculate, by default `None`.
        system_changes
            Changes to the system since last call, by default all_changes.

        Returns
        -------
        Born effective charges, one row per atom.

        Raises
        ------
        ValueError
            If the model does not provide Born effective charges.
        """
        if 'born_effective_charges' not in self.implemented_properties:
            raise ValueError(
                'This model does not support the calculation of Born effective charges.')
        self.calculate(atoms, properties, system_changes)
        return self.results['born_effective_charges']