Coverage for calorine/nep/model.py: 100%
757 statements
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-18 13:01 +0000
« prev ^ index » next coverage.py v7.14.1, created at 2026-06-18 13:01 +0000
1import copy
2from dataclasses import dataclass
3from itertools import product
5import numpy as np
# Neural-network weights keyed first by sub-network (a species name, or
# 'all_species' for models with a single shared network) and then by weight
# name (e.g. 'w0', 'b0', 'w1'). Note that the dicts built by
# _sort_ann_parameters also carry top-level scalar entries such as 'b1'.
NetworkWeights = dict[str, dict[str, np.ndarray]]
# Descriptor weights keyed by an ordered (species1, species2) pair.
DescriptorWeights = dict[tuple[str, str], np.ndarray]
# Restart (SNES) statistics, with entries such as 'ann_mu'/'ann_sigma' and
# '{radial,angular}_descriptor_{mu,sigma}' that mirror the layout of the
# corresponding model weights.
RestartParameters = dict[str, dict[str, dict[str, np.ndarray]]]
def _get_restart_contents(filename: str) -> tuple[list[float], list[float]]:
    """Read a ``nep.restart`` file into flat lists of means and standard deviations.

    Every line of the file must contain exactly two numbers: the mean and the
    standard deviation of one model parameter. Intended to be used by the
    :py:meth:`~Model.read_restart` method.

    Parameters
    ----------
    filename
        input file name

    Returns
    -------
    tuple[list[float], list[float]]
        Means and standard deviations, in file order.

    Raises
    ------
    AssertionError
        If the file contains an empty line.
    IOError
        If a line does not contain exactly two fields.
    """
    means: list[float] = []
    std_devs: list[float] = []
    with open(filename) as handle:
        for line_no, text in enumerate(handle):
            fields = text.split()
            assert fields, f'Empty line number {line_no}'
            if len(fields) != 2:
                raise IOError(f'Failed to parse line {line_no} from {filename}')
            mean_str, sigma_str = fields
            means.append(float(mean_str))
            std_devs.append(float(sigma_str))
    return means, std_devs
def _get_model_type(first_row: list[str]) -> str:
    """Determine the NEP model type from the first row of a ``nep.txt`` file.

    The version tag (first field) is searched for the markers ``charge``,
    ``dipole`` and ``polarizability``, in that order. The result is one of
    `potential`, `potential_with_charges`, `dipole`, or `polarizability`.

    Parameters
    ----------
    first_row
        First row of a NEP file, split by white space.
    """
    tag = first_row[0]
    # Checked in order; the first marker contained in the tag wins.
    markers = (
        ('charge', 'potential_with_charges'),
        ('dipole', 'dipole'),
        ('polarizability', 'polarizability'),
    )
    for marker, label in markers:
        if marker in tag:
            return label
    return 'potential'
def _get_nep_contents(filename: str) -> tuple[dict, list[float]]:
    """Split a ``nep.txt`` file into a header dict and a flat parameter list.

    The header is 5 rows for NEP2 and one row longer when a ``basis_size`` row
    is present at its expected position; a ``zbl`` tag in the very first row
    adds one more header row and shifts that expected position down by one.
    All remaining rows must hold a single number each.
    Intended to be used by the :func:`read_model <calorine.nep.read_model>` function.

    Parameters
    ----------
    filename
        input file name

    Returns
    -------
    tuple[dict, list[float]]
        Parsed header fields and the model parameters in file order.

    Raises
    ------
    AssertionError
        If the file contains an empty line.
    IOError
        If a non-header line holds more than one field.
    ValueError
        If a header row starts with an unknown keyword.
    """
    header_rows = []
    parameters = []
    n_header_rows = 5  # 5 rows for NEP2, 6-7 rows for NEP3 onwards
    basis_size_row = 3
    with open(filename) as handle:
        lines = handle.readlines()
    for row_index, text in enumerate(lines):
        fields = text.split()
        assert fields, f'Empty line number {row_index}'
        if row_index == 0 and 'zbl' in fields[0]:
            basis_size_row += 1
            n_header_rows += 1
        if row_index == basis_size_row and 'basis_size' in fields[0]:
            # Introduced in nep.txt after GPUMD v3.2
            n_header_rows += 1
        if row_index < n_header_rows:
            header_rows.append(tuple(fields))
        elif len(fields) == 1:
            parameters.append(float(fields[0]))
        else:
            raise IOError(f'Failed to parse line {row_index} from {filename}')

    # Translate the header rows into a dict
    data = {}
    for fields in header_rows:
        keyword = fields[0]
        if keyword in ('cutoff', 'zbl'):
            data[keyword] = tuple(float(x) for x in fields[1:])
        elif keyword in ('n_max', 'l_max', 'ANN', 'basis_size'):
            data[keyword] = tuple(int(x) for x in fields[1:])
        elif keyword.startswith('nep'):
            data['version'] = int(keyword.replace('nep', '').split('_')[0])
            data['types'] = fields[2:]
            data['model_type'] = _get_model_type(fields)
        else:
            raise ValueError(f'Unknown field: {keyword}')
    return data, parameters
def _sort_descriptor_parameters(parameters: list[float],
                                types: list[str],
                                n_max_radial: int,
                                n_basis_radial: int,
                                n_max_angular: int,
                                n_basis_angular: int) -> tuple[DescriptorWeights,
                                                               DescriptorWeights]:
    """Sort a flat sequence of descriptor parameters into radial and angular weights.

    Intended to be used by the :func:`read_model <calorine.nep.read_model>` function.

    Parameters
    ----------
    parameters
        Flat sequence of descriptor parameters in ``nep.txt`` order. Both plain
        lists and NumPy arrays are accepted.
    types
        Chemical species of the model. Species pairs are enumerated in
        row-major order over ``types x types``.
    n_max_radial, n_basis_radial
        Radial expansion order and number of radial basis functions.
    n_max_angular, n_basis_angular
        Angular expansion order and number of angular basis functions.

    Returns
    -------
    tuple[DescriptorWeights, DescriptorWeights]
        Radial and angular weights keyed by ``(species1, species2)``; the arrays
        have shapes ``(n_max_radial + 1, n_basis_radial + 1)`` and
        ``(n_max_angular + 1, n_basis_angular + 1)``, respectively.
    """
    # The annotation allows a list, but reshape below requires an array.
    parameters = np.asarray(parameters)

    # split up descriptor by chemical species and radial/angular
    n_types = len(types)
    n_pairs = n_types * n_types
    n = len(parameters) / n_pairs
    assert n.is_integer(), 'number of descriptor groups must be an integer'
    n = int(n)

    n_radial = (n_max_radial + 1) * (n_basis_radial + 1)
    # The file stores, for each of the n coefficients, one value per species
    # pair; transposing yields one row of coefficients per pair.
    weights = parameters.reshape((n, n_pairs)).T
    radial_rows = weights[:, :n_radial]
    angular_rows = weights[:, n_radial:]

    radial_descriptor_weights = {}
    angular_descriptor_weights = {}
    for row, pair in enumerate(product(types, repeat=2)):
        radial_descriptor_weights[pair] = radial_rows[row, :].reshape(
            (n_max_radial + 1, n_basis_radial + 1)
        )
        angular_descriptor_weights[pair] = angular_rows[row, :].reshape(
            (n_max_angular + 1, n_basis_angular + 1)
        )
    return radial_descriptor_weights, angular_descriptor_weights
def _sort_ann_parameters(parameters: list[float],
                         ann_groupings: list[str],
                         n_neuron: int,
                         n_networks: int,
                         n_bias: int,
                         n_descriptor: int,
                         is_polarizability_model: bool,
                         is_model_with_charges: bool
                         ) -> NetworkWeights:
    """Sort a flat sequence of ANN parameters into an appropriately structured `dict`.

    Intended to be used by the :func:`read_model <calorine.nep.read_model>` function.

    Parameters
    ----------
    parameters
        Flat sequence of ANN parameters in ``nep.txt`` order.
    ann_groupings
        Keys of the sub-networks in file order (one per species for NEP4/NEP5).
    n_neuron
        Number of neurons in the hidden layer.
    n_networks
        Number of sub-networks; only used for the final consistency check.
    n_bias
        Number of trailing bias terms. Values larger than one mean that each
        sub-network carries its own ``b1`` (NEP5), unless the model has charges.
    n_descriptor
        Length of the descriptor vector.
    is_polarizability_model
        If ``True`` the whole parameter set, including the global bias, occurs
        twice; the second copy is stored under ``*_polar`` keys.
    is_model_with_charges
        If ``True`` each sub-network has a second set of output weights
        (``w1_charge``) and the trailing biases are ``sqrt_epsilon_infinity``
        followed by ``b1``.

    Returns
    -------
    NetworkWeights
        One dict per sub-network holding ``w0`` with shape ``(n_neuron, n_descriptor)``,
        ``b0`` with shape ``(n_neuron, 1)`` and ``w1`` with shape ``(1, n_neuron)``
        (or ``(n_neuron,)`` plus ``w1_charge`` for models with charges), together
        with top-level scalar entries ``b1`` and, where applicable, ``b1_polar``
        and ``sqrt_epsilon_infinity``.
    """
    n_outputs = 2 if is_model_with_charges else 1
    n_ann_input_weights = (n_descriptor + 1) * n_neuron  # weights + bias
    n_ann_output_weights = n_outputs * n_neuron  # only weights
    # Parameters per sub-network, excluding the trailing bias(es)
    n_network_params = n_ann_input_weights + n_ann_output_weights
    n_ann_parameters = n_network_params * n_networks + n_bias
    n_count = 2 if is_polarizability_model else 1
    n_w0 = n_neuron * n_descriptor

    # Group ANN parameters
    pars = {}
    n1 = 0
    for count in range(n_count):
        # if polarizability model, all parameters including bias are repeated;
        # the second pass must skip the global bias of the first pass (+1)
        n1 += count
        for s in ann_groupings:
            # Get the parameters for the ANN; in the case of NEP4, there is effectively
            # one network per atomic species.
            ann_parameters = parameters[n1 : n1 + n_network_params]
            ann_input_weights = ann_parameters[:n_ann_input_weights]
            # w0 is stored row-major (element [n, nu] at n * n_descriptor + nu) and is
            # followed by one bias per neuron. np.array copies, so the weights never
            # alias the caller's ``parameters``.
            w0 = np.array(ann_input_weights[:n_w0], dtype=float).reshape(
                (n_neuron, n_descriptor)
            )
            b0 = np.zeros((n_neuron, 1))
            b0[:, 0] = ann_input_weights[n_w0:]

            assert np.all(
                w0.shape == (n_neuron, n_descriptor)
            ), f'w0 has invalid shape for key {s}; please submit a bug report'
            assert np.all(
                b0.shape == (n_neuron, 1)
            ), f'b0 has invalid shape for key {s}; please submit a bug report'
            assert not np.any(
                np.isnan(w0)
            ), f'some weights in w0 are nan for key {s}; please submit a bug report'
            assert not np.any(
                np.isnan(b0)
            ), f'some weights in b0 are nan for key {s}; please submit a bug report'

            ann_output_weights = ann_parameters[
                n_ann_input_weights : n_ann_input_weights + n_ann_output_weights
            ]
            w1 = np.zeros((1, n_neuron * n_outputs))
            w1[0, :] = ann_output_weights[:]
            assert np.all(
                w1.shape == (1, n_neuron * n_outputs)
            ), f'w1 has invalid shape for key {s}; please submit a bug report'
            assert not np.any(
                np.isnan(w1)
            ), f'some weights in w1 are nan for key {s}; please submit a bug report'

            if count == 0 and n_outputs == 1:
                pars[s] = dict(w0=w0, b0=b0, w1=w1)
            elif count == 0 and n_outputs == 2:
                pars[s] = dict(w0=w0, b0=b0, w1=w1[0, :n_neuron], w1_charge=w1[0, n_neuron:])
            else:
                pars[s].update({'w0_polar': w0, 'b0_polar': b0, 'w1_polar': w1})
            # Jump to bias
            n1 += n_network_params
            if n_bias > 1 and not is_model_with_charges:
                # For NEP5 models we additionally have one bias term per species.
                # Currently NEP5 only exists for potential models, but we'll
                # keep it here in case it gets added down the line.
                bias_label = 'b1' if count == 0 else 'b1_polar'
                pars[s][bias_label] = parameters[n1]
                n1 += 1
        # For NEP3 and NEP4 we only have one bias.
        # For NEP4 with charges we have two biases.
        # For NEP5 we have one bias per species, and one global.
        if count == 0 and n_outputs == 1:
            pars['b1'] = parameters[n1]
        elif count == 0 and n_outputs == 2:
            pars['sqrt_epsilon_infinity'] = parameters[n1]
            pars['b1'] = parameters[n1 + 1]
        else:
            pars['b1_polar'] = parameters[n1]

    # Every parameter must have been placed exactly once.
    n_counted = 0
    for key, value in pars.items():
        if key.startswith(('b1', 'sqrt')):
            n_counted += 1
        else:
            n_counted += sum(np.size(p) for p in value.values())
    n_expected = n_ann_parameters * n_count
    assert n_counted == n_expected, (
        'Inconsistent number of parameters accounted for; please submit a bug report\n'
        f'{n_counted} != {n_expected}'
    )
    return pars
def _adaptive_sigma(mu_arr, sigma_factor: float, sigma_floor: float) -> np.ndarray:
    """Compute the adaptive SNES standard deviation for the means *mu_arr*.

    Element-wise ``sigma_factor * |mu|``, bounded from below by ``sigma_floor``;
    a scalar input yields a NumPy scalar.
    """
    proportional = sigma_factor * np.abs(mu_arr)
    return np.maximum(sigma_floor, proportional)
def _apply_adaptive_sigma_to_restart(restart_params, keys, sigma_factor, sigma_floor):
    """Apply adaptive SNES sigma to every parameter in *restart_params* in-place.

    For each per-species ANN entry listed in *keys*, every weight found in
    ``restart_params['ann_mu'][key]`` receives
    ``sigma = max(sigma_floor, sigma_factor * |mu|)``. This includes
    ``w0``/``b0``/``w1``, the optional ``w1_charge``, per-species ``b1`` (NEP5),
    and the ``*_polar`` weights of polarizability models. The same rule is
    applied to the global scalar entries of ``ann_mu`` (``b1``, ``b1_polar``,
    ``sqrt_epsilon_infinity``) and to all radial/angular descriptor weight pairs.

    Parameters
    ----------
    restart_params
        Restart dict with ``ann_mu``/``ann_sigma`` and
        ``{radial,angular}_descriptor_{mu,sigma}`` entries; modified in-place.
    keys
        Per-species ANN keys to update (``['all_species']`` for NEP3 models).
    sigma_factor
        Proportionality factor between ``|mu|`` and sigma.
    sigma_floor
        Lower bound for sigma.
    """
    ann_mu_all = restart_params['ann_mu']
    ann_sigma_all = restart_params['ann_sigma']

    # Per-species sub-networks: refresh every weight that has a mean.
    for s in keys:
        ann_sigma = ann_sigma_all[s]
        for name, mu in ann_mu_all[s].items():
            sigma = _adaptive_sigma(mu, sigma_factor, sigma_floor)
            # Scalar entries (e.g. per-species b1) are stored as plain floats.
            ann_sigma[name] = float(sigma) if np.ndim(sigma) == 0 else sigma

    # Global scalars: b1, b1_polar (polarizability), sqrt_epsilon_infinity (charges).
    # Dict-valued entries are per-species sub-networks, handled above if requested.
    for name, mu in ann_mu_all.items():
        if isinstance(mu, dict):
            continue
        ann_sigma_all[name] = float(_adaptive_sigma(mu, sigma_factor, sigma_floor))

    for desc_type in ['radial', 'angular']:
        sigma_key = f'{desc_type}_descriptor_sigma'
        mu_key = f'{desc_type}_descriptor_mu'
        for pair in restart_params[sigma_key]:
            restart_params[sigma_key][pair] = _adaptive_sigma(
                restart_params[mu_key][pair], sigma_factor, sigma_floor
            )
def _recalculate_parameter_counts(new) -> None:
    """Refresh n_ann_parameters, n_descriptor_parameters and n_parameters on *new*.

    All architectural information is read from *new* itself, so callers must
    update new.n_neuron, new.n_descriptor_radial/angular, new.model_type and
    new.types before calling this function.
    """
    n_types = len(new.types)
    n_descriptor = new.n_descriptor_radial + new.n_descriptor_angular
    has_charges = new.model_type == 'potential_with_charges'

    # NEP4/NEP5 have one sub-network per species, earlier versions a shared one.
    n_networks = n_types if new.version in (4, 5) else 1
    if has_charges:
        n_bias = 2
    elif new.version == 5:
        n_bias = 1 + n_types
    else:
        n_bias = 1
    n_outputs = 2 if has_charges else 1
    per_network = (n_descriptor + 1) * new.n_neuron + n_outputs * new.n_neuron
    new.n_ann_parameters = per_network * n_networks + n_bias

    n_radial = (new.n_max_radial + 1) * (new.n_basis_radial + 1)
    n_angular = (new.n_max_angular + 1) * (new.n_basis_angular + 1)
    new.n_descriptor_parameters = n_types ** 2 * (n_radial + n_angular)

    # The n_descriptor term accounts for the descriptor scalers; polarizability
    # models carry a second copy of the ANN.
    total = new.n_ann_parameters + new.n_descriptor_parameters + n_descriptor
    if new.model_type == 'polarizability':
        total += new.n_ann_parameters
    new.n_parameters = total
@dataclass
class Model:
    r"""Objects of this class represent a NEP model in a form suitable for
    inspection and manipulation. Typically a :class:`Model` object is instantiated
    by calling the :func:`read_model <calorine.nep.read_model>` function.

    Attributes
    ----------
    version : int
        NEP version.
    model_type: str
        One of ``potential``, ``potential_with_charges``, ``dipole`` or
        ``polarizability``.
    types : tuple[str, ...]
        Chemical species that this model represents.
    radial_cutoff : float | list[float]
        The radial cutoff parameter in Å.
        Is a list of radial cutoffs ordered after ``types`` in the case of typewise cutoffs.
    angular_cutoff : float | list[float]
        The angular cutoff parameter in Å.
        Is a list of angular cutoffs ordered after ``types`` in the case of typewise cutoffs.
    max_neighbors_radial : int
        Maximum number of neighbors in neighbor list for radial terms.
    max_neighbors_angular : int
        Maximum number of neighbors in neighbor list for angular terms.
    radial_typewise_cutoff_factor : float
        The radial cutoff factor if use_typewise_cutoff is used.
    angular_typewise_cutoff_factor : float
        The angular cutoff factor if use_typewise_cutoff is used.
    zbl : tuple[float, float]
        Inner and outer cutoff for transition to ZBL potential.
    zbl_typewise_cutoff_factor : float
        Typewise cutoff when use_typewise_cutoff_zbl is used.
    n_basis_radial : int
        Number of radial basis functions :math:`n_\mathrm{basis}^\mathrm{R}`.
    n_basis_angular : int
        Number of angular basis functions :math:`n_\mathrm{basis}^\mathrm{A}`.
    n_max_radial : int
        Maximum order of Chebyshev polynomials included in
        radial expansion :math:`n_\mathrm{max}^\mathrm{R}`.
    n_max_angular : int
        Maximum order of Chebyshev polynomials included in
        angular expansion :math:`n_\mathrm{max}^\mathrm{A}`.
    l_max_3b : int
        Maximum expansion order for three-body terms :math:`l_\mathrm{max}^\mathrm{3b}`.
    l_max_4b : int
        Maximum expansion order for four-body terms :math:`l_\mathrm{max}^\mathrm{4b}`.
    l_max_5b : int
        Maximum expansion order for five-body terms :math:`l_\mathrm{max}^\mathrm{5b}`.
    has_q_112 : int
        Flag enabling the 5-body :math:`q_{112}` descriptor (0 or 1).
    has_q_123 : int
        Flag enabling the 5-body :math:`q_{123}` descriptor (0 or 1).
    has_q_233 : int
        Flag enabling the 5-body :math:`q_{233}` descriptor (0 or 1).
    has_q_134 : int
        Flag enabling the higher-body :math:`q_{134}` descriptor (0 or 1).
    n_descriptor_radial : int
        Dimension of radial part of descriptor.
    n_descriptor_angular : int
        Dimension of angular part of descriptor.
    n_neuron : int
        Number of neurons in hidden layer.
    n_parameters : int
        Total number of parameters including scalers (which are not fit parameters).
    n_descriptor_parameters : int
        Number of parameters in descriptor.
    n_ann_parameters : int
        Number of neural network weights.
    ann_parameters : dict[str, dict[str, np.ndarray]]
        Neural network weights, keyed by sub-network (species name, or
        ``all_species`` for a shared network) and weight name.
    q_scaler : list[float]
        Scaling parameters.
    radial_descriptor_weights : dict[tuple[str, str], np.ndarray]
        Radial descriptor weights by combination of species; the array for each combination
        has dimensions of
        :math:`(n_\mathrm{max}^\mathrm{R}+1) \times (n_\mathrm{basis}^\mathrm{R}+1)`.
    angular_descriptor_weights : dict[tuple[str, str], np.ndarray]
        Angular descriptor weights by combination of species; the array for each combination
        has dimensions of
        :math:`(n_\mathrm{max}^\mathrm{A}+1) \times (n_\mathrm{basis}^\mathrm{A}+1)`.
    sqrt_epsilon_infinity : float | None
        Square root of epsilon infinity :math:`\epsilon_\infty` (only for NEP models
        with charges).
    restart_parameters : dict[str, dict[str, dict[str, np.ndarray]]] | None
        NEP restart parameters. A nested dictionary that contains the mean (mu) and standard
        deviation (sigma) for the ANN and descriptor parameters. Is set using the
        :py:meth:`~Model.read_restart` method. Defaults to None.
    """

    version: int
    model_type: str
    types: tuple[str, ...]

    radial_cutoff: float | list[float]
    angular_cutoff: float | list[float]

    n_basis_radial: int
    n_basis_angular: int
    n_max_radial: int
    n_max_angular: int
    l_max_3b: int
    l_max_4b: int
    l_max_5b: int
    has_q_112: int
    has_q_123: int
    has_q_233: int
    has_q_134: int
    n_descriptor_radial: int
    n_descriptor_angular: int

    n_neuron: int
    n_parameters: int
    n_descriptor_parameters: int
    n_ann_parameters: int
    ann_parameters: NetworkWeights
    q_scaler: list[float]
    radial_descriptor_weights: DescriptorWeights
    angular_descriptor_weights: DescriptorWeights
    sqrt_epsilon_infinity: float | None = None
    restart_parameters: RestartParameters | None = None

    zbl: tuple[float, float] | None = None
    zbl_typewise_cutoff_factor: float | None = None
    max_neighbors_radial: int | None = None
    max_neighbors_angular: int | None = None
    radial_typewise_cutoff_factor: float | None = None
    angular_typewise_cutoff_factor: float | None = None

    # Container-valued fields: skipped by the field listing in __str__ and shown
    # only by their dimension in _repr_html_. Deliberately left un-annotated so
    # that @dataclass does not turn it into a (mutable-default) field.
    _special_fields = [
        'ann_parameters',
        'q_scaler',
        'radial_descriptor_weights',
        'angular_descriptor_weights',
    ]
445 def __str__(self) -> str:
446 s = []
447 for fld in self.__dataclass_fields__:
448 if fld not in self._special_fields:
449 value = getattr(self, fld)
450 if fld == 'restart_parameters':
451 value = 'available' if value is not None else 'not available'
452 s += [f'{fld:22} : {value}']
453 return '\n'.join(s)
455 def _repr_html_(self) -> str:
456 s = []
457 s += ['<table border="1" class="dataframe"']
458 s += [
459 '<thead><tr><th style="text-align: left;">Field</th><th>Value</th></tr></thead>'
460 ]
461 s += ['<tbody>']
462 for fld in self.__dataclass_fields__:
463 if fld not in self._special_fields:
464 value = getattr(self, fld)
465 if fld == 'restart_parameters':
466 value = 'available' if value is not None else 'not available'
467 s += [
468 f'<tr><td style="text-align: left;">{fld:22}</td>'
469 f'<td>{value}</td><tr>'
470 ]
471 for fld in self._special_fields:
472 d = getattr(self, fld)
473 # print('xxx', fld, d)
474 if fld.endswith('descriptor_weights'):
475 dim = list(d.values())[0].shape
476 elif fld == 'ann_parameters' and self.version == 4:
477 dim = (len(self.types), len(list(d.values())[0]))
478 else:
479 dim = len(d)
480 s += [
481 f'<tr><td style="text-align: left;">Dimension of {fld:22}</td><td>{dim}</td><tr>'
482 ]
483 s += ['</tbody>']
484 s += ['</table>']
485 return ''.join(s)
487 @property
488 def training_parameters(self) -> dict:
489 """Return model hyperparameters in the format accepted by :func:`write_nepfile
490 <calorine.nep.write_nepfile>`.
492 Use this after any model modification (:meth:`augment`, :meth:`add_species`,
493 :meth:`remove_species`, :meth:`keep_species`) to produce the architecture fields
494 that must go into the new ``nep.in`` before training. Merge the result with your
495 existing training-specific parameters (``lambda_*``, ``generation``, ``batch``,
496 etc.) before calling :func:`write_nepfile <calorine.nep.write_nepfile>`.
498 Returns
499 -------
500 dict
501 Keys ``version``, ``type``, ``cutoff``, ``n_max``, ``basis_size``, ``l_max``,
502 and ``neuron`` (plus ``zbl`` when applicable) with values in the format
503 expected by :func:`write_nepfile <calorine.nep.write_nepfile>`.
505 """
506 l_max = [self.l_max_3b, self.l_max_4b, self.l_max_5b,
507 self.has_q_112, self.has_q_123, self.has_q_233, self.has_q_134]
508 while len(l_max) > 1 and l_max[-1] == 0:
509 l_max = l_max[:-1]
511 if isinstance(self.radial_cutoff, list):
512 cutoff = []
513 for rc, ac in zip(self.radial_cutoff, self.angular_cutoff):
514 cutoff += [rc, ac]
515 else:
516 cutoff = [self.radial_cutoff, self.angular_cutoff]
518 params = {
519 'version': self.version,
520 'type': [len(self.types)] + list(self.types),
521 'cutoff': cutoff,
522 'n_max': [self.n_max_radial, self.n_max_angular],
523 'basis_size': [self.n_basis_radial, self.n_basis_angular],
524 'l_max': l_max,
525 'neuron': self.n_neuron,
526 }
527 if self.zbl is not None:
528 params['zbl'] = list(self.zbl)
529 return params
531 def remove_species(self,
532 species: list[str],
533 sigma_factor: float = 0.1,
534 sigma_floor: float = 1e-6) -> 'Model':
535 """Remove one or more species from the model.
537 Returns a new :class:`Model` with the specified species removed.
538 The source model is not modified.
540 If ``restart_parameters`` are loaded, the surviving parameters receive
541 adaptive SNES sigma values: ``sigma = max(sigma_floor, sigma_factor * |mu|)``,
542 re-opening the search distribution while preserving dormant parameters.
544 Parameters
545 ----------
546 species
547 Species names to remove.
548 sigma_factor
549 Used only when restart is loaded: ``sigma = max(sigma_floor, sigma_factor * |mu|)``
550 for surviving parameters.
551 sigma_floor
552 Minimum sigma for surviving parameters when restart is loaded.
554 Returns
555 -------
556 Model
557 New model with the specified species removed.
559 Raises
560 ------
561 ValueError
562 If any of the provided species is not found in the model.
563 """
564 for s in species:
565 if s not in self.types:
566 raise ValueError(f'{s} is not a species supported by the NEP model')
568 new = copy.deepcopy(self)
569 types_to_keep = [t for t in self.types if t not in species]
570 new.types = tuple(types_to_keep)
572 # Prune ANN parameters (for NEP4 and NEP5)
573 if self.version in [4, 5]:
574 new.ann_parameters = {
575 key: value for key, value in new.ann_parameters.items()
576 if key in types_to_keep or key.startswith('b1')
577 }
579 # Prune descriptor weights; key is a (species1, species2) tuple
580 new.radial_descriptor_weights = {
581 key: value for key, value in new.radial_descriptor_weights.items()
582 if key[0] in types_to_keep and key[1] in types_to_keep
583 }
584 new.angular_descriptor_weights = {
585 key: value for key, value in new.angular_descriptor_weights.items()
586 if key[0] in types_to_keep and key[1] in types_to_keep
587 }
589 # Prune typewise cutoff lists so remaining species map to correct cutoffs
590 if isinstance(self.radial_cutoff, list):
591 indices = [i for i, t in enumerate(self.types) if t not in species]
592 new.radial_cutoff = [self.radial_cutoff[i] for i in indices]
593 new.angular_cutoff = [self.angular_cutoff[i] for i in indices]
595 # Prune and optionally re-open restart parameters
596 if new.restart_parameters is not None:
597 ann_keys = types_to_keep if self.version in [4, 5] else ['all_species']
598 for param_type in ['mu', 'sigma']:
599 ann_key = f'ann_{param_type}'
600 if self.version in [4, 5]:
601 # Keep per-species keys for survivors, global bias keys, and
602 # sqrt_epsilon_infinity (charge models)
603 new.restart_parameters[ann_key] = {
604 key: value for key, value in new.restart_parameters[ann_key].items()
605 if (key in types_to_keep or key.startswith('b1')
606 or key == 'sqrt_epsilon_infinity')
607 }
609 # Prune descriptor restart parameters
610 for desc_type in ['radial', 'angular']:
611 key = f'{desc_type}_descriptor_{param_type}'
612 new.restart_parameters[key] = {
613 k: v for k, v in new.restart_parameters[key].items()
614 if k[0] in types_to_keep and k[1] in types_to_keep
615 }
617 # Apply adaptive sigma to all surviving parameters
618 _apply_adaptive_sigma_to_restart(
619 new.restart_parameters, ann_keys, sigma_factor, sigma_floor
620 )
622 # Recalculate parameter counts
623 _recalculate_parameter_counts(new)
625 return new
627 def keep_species(self,
628 species: list[str],
629 sigma_factor: float = 0.1,
630 sigma_floor: float = 1e-6) -> 'Model':
631 """Retain only the specified species, removing all others.
633 Convenience complement to :meth:`remove_species`. Useful when the set
634 of species to drop is large (e.g. isolating two elements from a
635 foundation model with dozens of species).
637 Parameters
638 ----------
639 species
640 Species names to keep. All other species are removed.
641 sigma_factor
642 Passed to :meth:`remove_species`. Controls adaptive sigma for
643 surviving parameters when restart is loaded.
644 sigma_floor
645 Passed to :meth:`remove_species`. Minimum sigma for surviving
646 parameters.
648 Returns
649 -------
650 Model
651 New model containing only the requested species.
653 Raises
654 ------
655 ValueError
656 If any of the requested species is not in the model.
657 """
658 unknown = [s for s in species if s not in self.types]
659 if unknown:
660 raise ValueError(
661 f'Species not in model: {unknown}'
662 )
663 to_remove = [s for s in self.types if s not in species]
664 return self.remove_species(to_remove, sigma_factor=sigma_factor, sigma_floor=sigma_floor)
666 def add_species(self,
667 species: list[str],
668 radial_cutoff: float | list[float] = None,
669 angular_cutoff: float | list[float] = None,
670 sigma_new: float = 0.1,
671 sigma_factor: float = 0.1,
672 sigma_floor: float = 1e-6,
673 seed: int | None = None) -> 'Model':
674 """Add one or more species to the model.
676 Returns a new :class:`Model` with the requested species added. New ANN
677 sub-networks and descriptor weight pairs are initialised by drawing
678 ``mu`` uniformly from [-1, 1] (matching the GPUMD fresh-model
679 initialisation), with ``sigma = sigma_new`` in the restart.
680 Charge-specific parameters (``w1_charge``) are kept at ``mu = 0`` to
681 preserve stability, also matching GPUMD.
682 Existing parameters receive adaptive sigma:
683 ``sigma = max(sigma_floor, sigma_factor * |mu|)``.
685 Only supported for NEP4 models. For NEP3 the ANN is shared across all
686 species and adding a per-species sub-network is not meaningful.
688 Parameters
689 ----------
690 species
691 New species names to add. Appended to ``types`` in the order given.
692 radial_cutoff
693 Radial cutoff(s) for the new species, in Å. Required when the model
694 uses typewise cutoffs (i.e. ``isinstance(model.radial_cutoff, list)``
695 is ``True``). Pass a single float or a list with one value per new
696 species.
697 angular_cutoff
698 Angular cutoff(s) for the new species, in Å. Same requirements as
699 ``radial_cutoff``.
700 sigma_new
701 SNES sigma assigned to all newly created parameters. Defaults to
702 ``0.1``, matching the GPUMD ``sigma0`` default.
703 sigma_factor
704 Controls sigma for *existing* parameters:
705 ``sigma = max(sigma_floor, sigma_factor * |mu|)``.
706 sigma_floor
707 Minimum sigma for existing parameters.
708 seed
709 Seed for the random number generator used to draw the initial ``mu``
710 values. Pass an integer for reproducible initialisation.
712 Returns
713 -------
714 Model
715 New model with updated structure, weights, and restart statistics.
717 Raises
718 ------
719 ValueError
720 If the model version is not 4, if ``restart_parameters`` are not
721 loaded, if any species is already in the model, or if typewise
722 cutoffs are used and ``radial_cutoff``/``angular_cutoff`` are not
723 provided.
724 """
725 if self.version != 4:
726 raise ValueError(
727 f'add_species() only supports NEP4 models; got version {self.version}.'
728 )
729 for s in species:
730 if s in self.types:
731 raise ValueError(f'{s!r} is already in the model.')
732 if self.restart_parameters is None:
733 raise ValueError(
734 'restart_parameters must be loaded before calling add_species(). '
735 'Pass restart_file= to read_model() or call model.read_restart() first.'
736 )
738 uses_typewise = isinstance(self.radial_cutoff, list)
739 if uses_typewise:
740 if radial_cutoff is None or angular_cutoff is None:
741 raise ValueError(
742 'Model uses typewise cutoffs; provide radial_cutoff and angular_cutoff '
743 'for the new species.'
744 )
745 rc_list = ([radial_cutoff] * len(species)
746 if isinstance(radial_cutoff, (int, float)) else list(radial_cutoff))
747 ac_list = ([angular_cutoff] * len(species)
748 if isinstance(angular_cutoff, (int, float)) else list(angular_cutoff))
749 if len(rc_list) != len(species) or len(ac_list) != len(species):
750 raise ValueError(
751 'Length of radial_cutoff/angular_cutoff must match the number of new species.'
752 )
754 new = copy.deepcopy(self)
756 n_descriptor = self.n_descriptor_radial + self.n_descriptor_angular
757 n_neuron = self.n_neuron
758 is_charged = self.model_type == 'potential_with_charges'
759 all_types_after = list(self.types) + list(species)
760 rng = np.random.default_rng(seed)
762 def _rand(shape):
763 return rng.uniform(-1.0, 1.0, size=shape)
765 # Step 1: Adaptive sigma for existing parameters
766 _apply_adaptive_sigma_to_restart(
767 new.restart_parameters, list(self.types), sigma_factor, sigma_floor
768 )
770 # Step 2: New ANN sub-networks
771 w1_shape = (n_neuron,) if is_charged else (1, n_neuron)
772 for s_new in species:
773 w0_vals = _rand((n_neuron, n_descriptor))
774 b0_vals = _rand((n_neuron, 1))
775 w1_vals = _rand(w1_shape)
776 s_params = {'w0': w0_vals.copy(), 'b0': b0_vals.copy(), 'w1': w1_vals.copy()}
777 if is_charged:
778 s_params['w1_charge'] = np.zeros(n_neuron)
779 new.ann_parameters[s_new] = s_params
781 mu_entry = {'w0': w0_vals, 'b0': b0_vals, 'w1': w1_vals}
782 sigma_entry = {
783 'w0': np.full((n_neuron, n_descriptor), sigma_new),
784 'b0': np.full((n_neuron, 1), sigma_new),
785 'w1': np.full(w1_shape, sigma_new),
786 }
787 if is_charged:
788 mu_entry['w1_charge'] = np.zeros(n_neuron)
789 sigma_entry['w1_charge'] = np.full(n_neuron, sigma_new)
790 new.restart_parameters['ann_mu'][s_new] = mu_entry
791 new.restart_parameters['ann_sigma'][s_new] = sigma_entry
793 # Step 3: New descriptor weight pairs
794 n_r = (self.n_max_radial + 1, self.n_basis_radial + 1)
795 n_a = (self.n_max_angular + 1, self.n_basis_angular + 1)
796 existing_pairs = set(self.radial_descriptor_weights)
797 new_pairs = {
798 (s1, s2)
799 for s1 in all_types_after for s2 in all_types_after
800 if (s1, s2) not in existing_pairs
801 }
802 for pair in new_pairs:
803 r_vals = _rand(n_r)
804 a_vals = _rand(n_a)
805 new.radial_descriptor_weights[pair] = r_vals.copy()
806 new.angular_descriptor_weights[pair] = a_vals.copy()
807 new.restart_parameters['radial_descriptor_mu'][pair] = r_vals
808 new.restart_parameters['angular_descriptor_mu'][pair] = a_vals
809 new.restart_parameters['radial_descriptor_sigma'][pair] = np.full(n_r, sigma_new)
810 new.restart_parameters['angular_descriptor_sigma'][pair] = np.full(n_a, sigma_new)
812 # Step 4: Update types and typewise cutoffs
813 new.types = tuple(all_types_after)
814 if uses_typewise:
815 new.radial_cutoff = list(self.radial_cutoff) + rc_list
816 new.angular_cutoff = list(self.angular_cutoff) + ac_list
818 # Step 5: Recalculate parameter counts
819 _recalculate_parameter_counts(new)
821 return new
def write(self, filename: str, restart_file: str = None) -> None:
    """Write NEP model to file in `nep.txt` format.

    The file starts with a header (version and types, optional ``zbl`` line,
    ``cutoff``, ``n_max``, ``basis_size``, ``l_max`` and ``ANN`` lines). It is
    followed by one value per line: the network parameters, the descriptor
    weights and finally the descriptor scaling factors ``q_scaler``.

    Parameters
    ----------
    filename
        Output file name for the NEP model.
    restart_file
        If provided, also write restart parameters to this file in
        `nep.restart` format. Defaults to None.
    """
    with open(filename, 'w') as f:
        # header
        version_name = f'nep{self.version}'
        if self.zbl is not None:
            version_name += '_zbl'
        elif self.model_type != 'potential':
            version_name += f'_{self.model_type}'
        f.write(f'{version_name} {len(self.types)} {" ".join(self.types)}\n')
        if self.zbl is not None:
            f.write(f'zbl {" ".join(map(str, self.zbl))}\n')
        f.write('cutoff')
        # Scalar cutoffs (int, float or numpy scalar) are global. Sequences hold
        # one radial/angular pair per type. Checking the dimensionality instead of
        # ``isinstance(..., float)`` keeps integer cutoffs from being indexed.
        if np.ndim(self.radial_cutoff) == 0 and np.ndim(self.angular_cutoff) == 0:
            f.write(f' {self.radial_cutoff} {self.angular_cutoff}')
        else:
            # Typewise cutoffs: one set of cutoffs per type
            for i in range(len(self.types)):
                f.write(f' {self.radial_cutoff[i]} {self.angular_cutoff[i]}')
        f.write(f' {self.max_neighbors_radial} {self.max_neighbors_angular}\n')
        f.write(f'n_max {self.n_max_radial} {self.n_max_angular}\n')
        f.write(f'basis_size {self.n_basis_radial} {self.n_basis_angular}\n')
        # The optional higher-order flags are positional: a flag is written
        # whenever it, or any flag after it, is enabled.
        l_max_line = f'l_max {self.l_max_3b} {self.l_max_4b} {self.l_max_5b}'
        if self.has_q_112 or self.has_q_123 or self.has_q_233 or self.has_q_134:
            l_max_line += f' {self.has_q_112}'
        if self.has_q_123 or self.has_q_233 or self.has_q_134:
            l_max_line += f' {self.has_q_123}'
        if self.has_q_233 or self.has_q_134:
            l_max_line += f' {self.has_q_233}'
        if self.has_q_134:
            l_max_line += f' {self.has_q_134}'
        f.write(l_max_line + '\n')
        f.write(f'ANN {self.n_neuron} 0\n')

        # neural network weights
        n_descriptor = self.n_descriptor_radial + self.n_descriptor_angular
        keys = self.types if self.version in (4, 5) else ['all_species']
        suffixes = ['', '_polar'] if self.model_type == 'polarizability' else ['']
        for suffix in suffixes:
            for s in keys:
                params = self.ann_parameters[s]
                # Order: w0, b0, w1 (, w1_charge) (, b1 if NEP5)
                # w0 indexed as: n*N_descriptor + nu
                w0 = params[f'w0{suffix}']
                f.writelines(
                    f'{w0[n, nu]:15.7e}\n'
                    for n in range(self.n_neuron)
                    for nu in range(n_descriptor)
                )
                f.writelines(f'{b:15.7e}\n' for b in params[f'b0{suffix}'][:, 0])
                w1 = params[f'w1{suffix}']
                f.writelines(f'{v:15.7e}\n' for v in (w1[0, :] if w1.ndim == 2 else w1))
                if f'w1_charge{suffix}' in params:
                    f.writelines(f'{v:15.7e}\n' for v in params[f'w1_charge{suffix}'])
                if self.version == 5:
                    f.write(f'{params[f"b1{suffix}"]:15.7e}\n')
            if self.sqrt_epsilon_infinity is not None:
                f.write(f'{self.sqrt_epsilon_infinity:15.7e}\n')
            f.write(f'{self.ann_parameters[f"b1{suffix}"]:15.7e}\n')

        # descriptor weights: for each type pair (s1, s2) the radial weights and
        # then the angular weights. The (pairs, n) matrix is written transposed,
        # so entry k of every pair precedes entry k + 1.
        n_types = len(self.types)
        blocks = [
            np.ravel(weights[(s1, s2)])
            for s1 in self.types
            for s2 in self.types
            for weights in (self.radial_descriptor_weights, self.angular_descriptor_weights)
        ]
        # A single concatenate avoids re-copying a growing array for every pair.
        mat = np.asarray(np.concatenate(blocks), dtype=float)
        mat = mat.reshape((n_types * n_types, -1)).T
        f.writelines(f'{v:15.7e}\n' for v in mat.flatten())

        # scaler
        f.writelines(f'{v:15.7e}\n' for v in self.q_scaler)

    if restart_file is not None:
        self.write_restart(restart_file)
def read_restart(self, filename: str):
    """Parses a file in `nep.restart` format and saves the
    content in the form of mean and standard deviation for each
    parameter in the corresponding NEP model.

    The result is stored in ``self.restart_parameters``, a dict with the keys
    ``ann_mu``, ``ann_sigma``, ``radial_descriptor_mu``, ``radial_descriptor_sigma``,
    ``angular_descriptor_mu`` and ``angular_descriptor_sigma``.

    Parameters
    ----------
    filename
        Input file name.

    Raises
    ------
    ValueError
        If the model is not a NEP3 or NEP4 model, or if the file holds fewer
        entries than the network and descriptor parameters of the model require.
    """
    mu, sigma = _get_restart_contents(filename)
    # One row per parameter: column 0 is the mean, column 1 the standard deviation.
    restart_parameters = np.array([mu, sigma]).T

    is_polarizability_model = self.model_type == 'polarizability'
    is_charged_model = self.model_type == 'potential_with_charges'

    if self.version == 3:
        n_networks = 1
    elif self.version == 4:
        # one hidden layer per atomic species
        n_networks = len(self.types)
    else:
        raise ValueError(f'Cannot load nep.restart for NEP model version {self.version}')

    # Polarizability models store two networks of n_ann_parameters entries each.
    n1 = self.n_ann_parameters * (2 if is_polarizability_model else 1)
    n2 = n1 + self.n_descriptor_parameters
    if len(restart_parameters) < n2:
        # Slicing would otherwise silently return truncated arrays.
        raise ValueError(
            f'{filename} holds {len(restart_parameters)} parameters, but the model '
            f'requires {n2}.'
        )
    ann_parameters = restart_parameters[:n1]
    descriptor_parameters = np.array(restart_parameters[n1:n2])

    ann_groups = [s for s in self.ann_parameters if not s.startswith('b1')]
    # Output-layer scalars: one per b1* entry ...
    n_bias = sum(1 for s in self.ann_parameters if s.startswith('b1'))
    if self.sqrt_epsilon_infinity is not None:
        n_bias += 1  # charge models have sqrt_epsilon_infinity before b1
    n_descriptor = self.n_descriptor_radial + self.n_descriptor_angular
    restart = {}

    for i, content_type in enumerate(['mu', 'sigma']):
        ann = _sort_ann_parameters(ann_parameters[:, i],
                                   ann_groups,
                                   self.n_neuron,
                                   n_networks,
                                   n_bias,
                                   n_descriptor,
                                   is_polarizability_model,
                                   is_charged_model)
        radial, angular = _sort_descriptor_parameters(descriptor_parameters[:, i],
                                                      self.types,
                                                      self.n_max_radial,
                                                      self.n_basis_radial,
                                                      self.n_max_angular,
                                                      self.n_basis_angular)

        restart[f'ann_{content_type}'] = ann
        restart[f'radial_descriptor_{content_type}'] = radial
        restart[f'angular_descriptor_{content_type}'] = angular
    self.restart_parameters = restart
def write_restart(self, filename: str):
    """Write NEP restart parameters to file in `nep.restart` format.

    Each output line holds the mean and the standard deviation of one
    parameter. The network parameters come first, then the descriptor weights.
    ``q_scaler`` is not written.

    Parameters
    ----------
    filename
        Output file name.
    """
    keys = self.types if self.version in (4, 5) else ['all_species']
    suffixes = ['', '_polar'] if self.model_type == 'polarizability' else ['']
    n_descriptor = self.n_descriptor_radial + self.n_descriptor_angular
    n_types = len(self.types)

    columns = []
    for parameter in ('mu', 'sigma'):
        column = []

        # neural network weights
        ann_parameters = self.restart_parameters[f'ann_{parameter}']
        for suffix in suffixes:
            for s in keys:
                params = ann_parameters[s]
                # Order: w0, b0, w1 (, w1_charge)
                # w0 indexed as: n*N_descriptor + nu
                w0 = params[f'w0{suffix}']
                column.extend(
                    f'{w0[n, nu]:15.7e}'
                    for n in range(self.n_neuron)
                    for nu in range(n_descriptor)
                )
                column.extend(f'{b:15.7e}' for b in params[f'b0{suffix}'][:, 0])
                w1 = params[f'w1{suffix}']
                column.extend(f'{v:15.7e}' for v in (w1[0, :] if w1.ndim == 2 else w1))
                if f'w1_charge{suffix}' in params:
                    column.extend(f'{v:15.7e}' for v in params[f'w1_charge{suffix}'])
            if f'sqrt_epsilon_infinity{suffix}' in ann_parameters:
                column.append(f'{ann_parameters[f"sqrt_epsilon_infinity{suffix}"]:15.7e}')
            column.append(f'{ann_parameters[f"b1{suffix}"]:15.7e}')

        # descriptor weights: radial then angular weights for each type pair.
        # The (pairs, n) matrix is written transposed.
        radial = self.restart_parameters[f'radial_descriptor_{parameter}']
        angular = self.restart_parameters[f'angular_descriptor_{parameter}']
        blocks = [
            np.ravel(table[(s1, s2)])
            for s1 in self.types
            for s2 in self.types
            for table in (radial, angular)
        ]
        # A single concatenate avoids re-copying a growing array for every pair.
        mat = np.asarray(np.concatenate(blocks), dtype=float)
        mat = mat.reshape((n_types * n_types, -1)).T
        column.extend(f'{v:15.7e}' for v in mat.flatten())
        columns.append(column)

    # Join the mean and standard deviation columns
    mu_column, sigma_column = columns
    assert len(mu_column) == len(sigma_column), 'Length of means must match standard deviation'
    with open(filename, 'w') as f:
        f.writelines(f'{m} {s}\n' for m, s in zip(mu_column, sigma_column))
def augment(self,
            n_neuron: int = None,
            l_max_4b: int = None,
            l_max_5b: int = None,
            has_q_112: bool = None,
            has_q_123: bool = None,
            has_q_233: bool = None,
            has_q_134: bool = None,
            charge_head: bool = False,
            sigma_new: float = 0.01,
            sigma_factor: float = 0.1,
            sigma_floor: float = 1e-6) -> 'Model':
    """Augment the model by adding neurons, descriptor terms, or a charge output head.

    Returns a new :class:`Model` with the requested structural changes applied.
    The source model is not modified. Existing parameter values are preserved exactly;
    new parameters are initialized to zero. The restart SNES statistics are updated
    as follows:

    - Existing parameters: ``sigma = max(sigma_floor, sigma_factor * |mu|)``, which
      re-opens the SNES search distribution while keeping parameters that were driven
      toward zero effectively dormant.
    - New parameters: ``mu = 0``, ``sigma = sigma_new``.

    Newly enabled descriptor blocks are inserted at their own position in the
    descriptor vector. The layout is radial, three-body, then the 4b, 5b, q_112,
    q_123, q_233 and q_134 blocks of ``n_max_angular + 1`` entries each. Weights of
    blocks that were already enabled therefore stay attached to the same
    descriptors. New ``q_scaler`` entries are set to 1.0. For polarizability models
    the ``_polar`` network is expanded together with the scalar network.

    Parameters
    ----------
    n_neuron
        Target neuron count; must be >= current. ``None`` leaves unchanged.
    l_max_4b
        Target 4-body l_max value; must be >= current. ``None`` leaves unchanged.
    l_max_5b
        Target 5-body l_max value; must be >= current. ``None`` leaves unchanged.
    has_q_112
        ``True`` enables the q_112 5-body descriptor; ``None`` or ``False`` leaves
        the current state unchanged (disabling an already-enabled term raises).
    has_q_123
        Same as ``has_q_112`` but for the q_123 term.
    has_q_233
        Same as ``has_q_112`` but for the q_233 term.
    has_q_134
        Same as ``has_q_112`` but for the q_134 term.
    charge_head
        If ``True``, promote a ``potential`` model to ``potential_with_charges`` by
        adding a charge output head (w1_charge per species and sqrt_epsilon_infinity).
    sigma_new
        SNES sigma assigned to all newly created parameters.
    sigma_factor
        Controls the sigma for *existing* parameters:
        ``sigma = max(sigma_floor, sigma_factor * |mu|)``.
    sigma_floor
        Minimum sigma for existing parameters; keeps near-zero (dormant) parameters
        from being accidentally re-activated.

    Returns
    -------
    Model
        New model with updated structure, weights, and restart statistics.

    Raises
    ------
    ValueError
        If ``restart_parameters`` is not loaded, if ``n_neuron`` or an ``l_max_*``
        target is smaller than the current value, if a ``has_q_*`` flag attempts to
        disable an already-enabled term, or if ``charge_head=True`` on a model that
        is not of type ``potential``.
    """
    # Structural checks (independent of restart)
    if self.version not in (3, 4):
        raise ValueError(
            f'augment() only supports NEP versions 3 and 4; got version {self.version}.'
        )
    if n_neuron is not None and n_neuron < self.n_neuron:
        raise ValueError(
            f'n_neuron ({n_neuron}) must be >= current n_neuron ({self.n_neuron}); '
            'use prune() to reduce.'
        )
    if l_max_4b is not None and l_max_4b < self.l_max_4b:
        raise ValueError(
            f'l_max_4b ({l_max_4b}) must be >= current l_max_4b ({self.l_max_4b}); '
            'use prune() to disable.'
        )
    if l_max_5b is not None and l_max_5b < self.l_max_5b:
        raise ValueError(
            f'l_max_5b ({l_max_5b}) must be >= current l_max_5b ({self.l_max_5b}); '
            'use prune() to disable.'
        )
    for flag_val, name in [
        (has_q_112, 'has_q_112'), (has_q_123, 'has_q_123'), (has_q_233, 'has_q_233'),
        (has_q_134, 'has_q_134')
    ]:
        if flag_val is False and getattr(self, name):
            raise ValueError(
                f'Cannot disable {name} via augment(); '
                'use prune() to disable descriptor terms.'
            )
    if charge_head and self.model_type != 'potential':
        raise ValueError(
            f'charge_head=True requires model_type="potential"; '
            f'got "{self.model_type}".'
        )
    if self.restart_parameters is None:
        raise ValueError(
            'restart_parameters must be loaded before calling augment(). '
            'Pass restart_file= to read_model() or call model.read_restart() first.'
        )

    new = copy.deepcopy(self)

    # Resolve new structural parameters
    new_l_max_4b = l_max_4b if l_max_4b is not None else self.l_max_4b
    new_l_max_5b = l_max_5b if l_max_5b is not None else self.l_max_5b
    new_has_q_112 = int(has_q_112) if has_q_112 is not None else self.has_q_112
    new_has_q_123 = int(has_q_123) if has_q_123 is not None else self.has_q_123
    new_has_q_233 = int(has_q_233) if has_q_233 is not None else self.has_q_233
    new_has_q_134 = int(has_q_134) if has_q_134 is not None else self.has_q_134
    new_n_neuron = n_neuron if n_neuron is not None else self.n_neuron

    keys = self.types if self.version in (4, 5) else ['all_species']
    # Polarizability models carry a second network whose keys end in '_polar'.
    suffixes = ['', '_polar'] if self.model_type == 'polarizability' else ['']

    # Map every column of the new descriptor vector to its column in the old one
    # (-1 marks a column of a newly enabled block). The block order is the one
    # prune() assumes when removing descriptor columns.
    n_per = self.n_max_angular + 1
    hb_terms = [
        (self.l_max_4b, new_l_max_4b),
        (self.l_max_5b, new_l_max_5b),
        (self.has_q_112, new_has_q_112),
        (self.has_q_123, new_has_q_123),
        (self.has_q_233, new_has_q_233),
        (self.has_q_134, new_has_q_134),
    ]
    col_map = list(range(self.n_descriptor_radial + n_per * self.l_max_3b))
    old_col = len(col_map)
    for old_val, new_val in hb_terms:
        if old_val > 0:
            col_map.extend(range(old_col, old_col + n_per))
            old_col += n_per
        elif new_val > 0:
            col_map.extend([-1] * n_per)
    is_old_col = np.array([j >= 0 for j in col_map], dtype=bool)

    old_n_desc = self.n_descriptor_radial + self.n_descriptor_angular
    new_n_desc = len(col_map)
    new_n_desc_angular = new_n_desc - self.n_descriptor_radial
    delta_desc = new_n_desc - old_n_desc
    delta_neuron = new_n_neuron - self.n_neuron

    def _parameter_groups(s):
        # (parameter dict, fill value for new entries) for the weights of group s
        # and its restart mean / standard deviation.
        return [
            (new.ann_parameters[s], 0.0),
            (new.restart_parameters['ann_mu'][s], 0.0),
            (new.restart_parameters['ann_sigma'][s], sigma_new),
        ]

    def _insert_descriptor_columns(arr, fill):
        # Widen (n_neuron, old_n_desc) to (n_neuron, new_n_desc), filling new columns.
        out = np.full((arr.shape[0], new_n_desc), fill, dtype=float)
        out[:, is_old_col] = arr
        return out

    def _append_neuron_rows(arr, fill):
        # Append delta_neuron rows (w0, b0).
        return np.vstack([arr, np.full((delta_neuron, arr.shape[1]), fill)])

    def _append_neuron_outputs(arr, fill):
        # Append delta_neuron entries along the neuron axis (w1 as 1D or (1, n), w1_charge).
        shape = delta_neuron if arr.ndim == 1 else (1, delta_neuron)
        return np.hstack([arr, np.full(shape, fill)])

    # Step 1: Apply adaptive sigma to all existing parameters (re-open SNES search width)
    _apply_adaptive_sigma_to_restart(new.restart_parameters, keys, sigma_factor, sigma_floor)

    # Step 2: Expand descriptor dimensions (new columns in w0, new q_scaler entries)
    if delta_desc > 0:
        for s in keys:
            for params, fill in _parameter_groups(s):
                for suffix in suffixes:
                    key = f'w0{suffix}'
                    params[key] = _insert_descriptor_columns(params[key], fill)
        old_q_scaler = list(new.q_scaler)
        new.q_scaler = [old_q_scaler[j] if j >= 0 else 1.0 for j in col_map]

    # Step 3: Expand neuron count (new rows in w0/b0, new entries in w1/w1_charge)
    if delta_neuron > 0:
        for s in keys:
            for params, fill in _parameter_groups(s):
                for suffix in suffixes:
                    for key in (f'w0{suffix}', f'b0{suffix}'):
                        params[key] = _append_neuron_rows(params[key], fill)
                    params[f'w1{suffix}'] = _append_neuron_outputs(params[f'w1{suffix}'], fill)
                    charge_key = f'w1_charge{suffix}'
                    if charge_key in params:
                        params[charge_key] = _append_neuron_outputs(params[charge_key], fill)

    # Step 4: Add charge output head
    if charge_head:
        new.model_type = 'potential_with_charges'
        new.sqrt_epsilon_infinity = 1.0
        for s in keys:
            for params, fill in _parameter_groups(s):
                # The charge model stores w1 as a 1D vector next to w1_charge.
                w1 = params['w1']
                params['w1'] = w1[0, :] if w1.ndim == 2 else w1
                params['w1_charge'] = np.full(new_n_neuron, fill)
        new.restart_parameters['ann_mu']['sqrt_epsilon_infinity'] = 1.0
        new.restart_parameters['ann_sigma']['sqrt_epsilon_infinity'] = float(sigma_new)

    # Step 5: Update header metadata
    new.l_max_4b = new_l_max_4b
    new.l_max_5b = new_l_max_5b
    new.has_q_112 = new_has_q_112
    new.has_q_123 = new_has_q_123
    new.has_q_233 = new_has_q_233
    new.has_q_134 = new_has_q_134
    new.n_descriptor_angular = new_n_desc_angular
    new.n_neuron = new_n_neuron

    # Step 6: Recalculate parameter counts
    _recalculate_parameter_counts(new)

    return new
def prune(self,
          n_neuron: int = None,
          l_max_4b: int = None,
          l_max_5b: int = None,
          has_q_112: bool = None,
          has_q_123: bool = None,
          has_q_233: bool = None,
          has_q_134: bool = None,
          charge_head: bool = False,
          sigma_factor: float = 0.1,
          sigma_floor: float = 1e-6) -> 'Model':
    """Prune the model by removing neurons, disabling descriptor terms, or removing
    the charge output head.

    Returns a new :class:`Model` with the requested structural changes applied.
    The source model is not modified. When reducing ``n_neuron``, neurons are
    selected by importance score averaged over species:
    ``importance[n] = mean_s(||w0_s[n,:]||_2 * (|w1_s[n]| + |w1_charge_s[n]|))``.
    The ``w1_charge`` term is only included for charge models whose charge head
    is kept. For polarizability models the scalar and the ``_polar`` networks
    select their neurons independently.

    All surviving parameters receive adaptive SNES sigma:
    ``sigma = max(sigma_floor, sigma_factor * |mu|)``.

    Parameters
    ----------
    n_neuron
        Target neuron count; must be between 1 and the current value.
        ``None`` leaves unchanged.
    l_max_4b
        Target 4-body l_max; must be <= current. Setting to ``0`` removes the
        4-body angular descriptor block. Reducing to a lower non-zero value is
        a header-only change (descriptor dimensions unchanged). ``None`` leaves
        unchanged.
    l_max_5b
        Same as ``l_max_4b`` but for five-body terms.
    has_q_112
        ``False`` disables and removes the q_112 descriptor block. ``None``
        leaves unchanged. ``True`` is not valid; use :meth:`augment` instead.
    has_q_123
        Same as ``has_q_112`` but for the q_123 term.
    has_q_233
        Same as ``has_q_112`` but for the q_233 term.
    has_q_134
        Same as ``has_q_112`` but for the q_134 term.
    charge_head
        If ``True``, remove the charge output head from a
        ``potential_with_charges`` model, converting it back to ``potential``.
        Removes ``w1_charge`` per species and ``sqrt_epsilon_infinity`` from
        the restart.
    sigma_factor
        Controls sigma for surviving parameters:
        ``sigma = max(sigma_floor, sigma_factor * |mu|)``.
    sigma_floor
        Minimum sigma for surviving parameters.

    Returns
    -------
    Model
        New model with reduced structure, weights, and restart statistics.

    Raises
    ------
    ValueError
        If ``restart_parameters`` is not loaded, if ``n_neuron`` is smaller than 1,
        if any target value would expand the model (use :meth:`augment` instead),
        if a ``has_q_*`` flag is set to ``True``, or if ``charge_head=True`` on a
        model without charges.
    """
    # --- Resolve target values ---
    new_n_neuron = n_neuron if n_neuron is not None else self.n_neuron
    new_l_max_4b = l_max_4b if l_max_4b is not None else self.l_max_4b
    new_l_max_5b = l_max_5b if l_max_5b is not None else self.l_max_5b
    new_has_q_112 = 0 if has_q_112 is False else self.has_q_112
    new_has_q_123 = 0 if has_q_123 is False else self.has_q_123
    new_has_q_233 = 0 if has_q_233 is False else self.has_q_233
    new_has_q_134 = 0 if has_q_134 is False else self.has_q_134

    # --- Validate ---
    if self.version not in (3, 4):
        raise ValueError(
            f'prune() only supports NEP versions 3 and 4; got version {self.version}.'
        )
    if new_n_neuron < 1:
        # argsort(...)[-0:] would otherwise keep every neuron.
        raise ValueError(f'n_neuron ({new_n_neuron}) must be >= 1.')
    if new_n_neuron > self.n_neuron:
        raise ValueError(
            f'n_neuron ({new_n_neuron}) must be <= current n_neuron ({self.n_neuron}); '
            'use augment() to increase.'
        )
    if new_l_max_4b > self.l_max_4b:
        raise ValueError(
            f'l_max_4b ({new_l_max_4b}) must be <= current l_max_4b ({self.l_max_4b}); '
            'use augment() to increase.'
        )
    if new_l_max_5b > self.l_max_5b:
        raise ValueError(
            f'l_max_5b ({new_l_max_5b}) must be <= current l_max_5b ({self.l_max_5b}); '
            'use augment() to increase.'
        )
    for flag_val, name in [
        (has_q_112, 'has_q_112'), (has_q_123, 'has_q_123'),
        (has_q_233, 'has_q_233'), (has_q_134, 'has_q_134')
    ]:
        if flag_val is True:
            raise ValueError(
                f'Cannot enable {name} via prune(); '
                'use augment() to enable descriptor terms.'
            )
    if charge_head and self.model_type != 'potential_with_charges':
        raise ValueError(
            f'charge_head=True requires model_type="potential_with_charges"; '
            f'got "{self.model_type}".'
        )
    if self.restart_parameters is None:
        raise ValueError(
            'restart_parameters must be loaded before calling prune(). '
            'Pass restart_file= to read_model() or call model.read_restart() first.'
        )

    new = copy.deepcopy(self)
    keys = self.types if self.version in (4, 5) else ['all_species']
    # Polarizability models carry a second network whose keys end in '_polar'.
    suffixes = ['', '_polar'] if self.model_type == 'polarizability' else ['']

    def _parameter_groups(s):
        # Weights of group s plus its restart mean and standard deviation.
        return [
            new.ann_parameters[s],
            new.restart_parameters['ann_mu'][s],
            new.restart_parameters['ann_sigma'][s],
        ]

    # Step 1: Adaptive sigma for all existing parameters
    _apply_adaptive_sigma_to_restart(new.restart_parameters, keys, sigma_factor, sigma_floor)

    # Step 2: Neuron pruning — keep the most important neurons of each network
    if new_n_neuron < self.n_neuron:
        for suffix in suffixes:
            charge_key = f'w1_charge{suffix}'
            importances = []
            for s in keys:
                params = self.ann_parameters[s]  # (n_neuron, n_desc) for w0
                output_norm = np.abs(np.ravel(params[f'w1{suffix}']))
                # A charge head that is about to be removed should not keep
                # neurons alive.
                if charge_key in params and not charge_head:
                    output_norm = output_norm + np.abs(params[charge_key])
                importances.append(np.linalg.norm(params[f'w0{suffix}'], axis=1) * output_norm)

            keep_idx = np.sort(np.argsort(np.mean(importances, axis=0))[-new_n_neuron:])

            for s in keys:
                for params in _parameter_groups(s):
                    params[f'w0{suffix}'] = params[f'w0{suffix}'][keep_idx, :]
                    params[f'b0{suffix}'] = params[f'b0{suffix}'][keep_idx, :]
                    w1 = params[f'w1{suffix}']
                    params[f'w1{suffix}'] = w1[:, keep_idx] if w1.ndim == 2 else w1[keep_idx]
                    if charge_key in params:
                        params[charge_key] = params[charge_key][keep_idx]

    # Step 3: Descriptor column pruning (disabling higher-body terms)
    n_per = self.n_max_angular + 1
    hb_terms = [
        (self.l_max_4b, new_l_max_4b),
        (self.l_max_5b, new_l_max_5b),
        (self.has_q_112, new_has_q_112),
        (self.has_q_123, new_has_q_123),
        (self.has_q_233, new_has_q_233),
        (self.has_q_134, new_has_q_134),
    ]
    keep_cols = list(range(self.n_descriptor_radial + n_per * self.l_max_3b))
    col_offset = len(keep_cols)
    for old_val, new_val in hb_terms:
        if old_val > 0:
            if new_val > 0:
                keep_cols.extend(range(col_offset, col_offset + n_per))
            col_offset += n_per

    old_n_desc = self.n_descriptor_radial + self.n_descriptor_angular
    if len(keep_cols) < old_n_desc:
        keep_cols = np.array(keep_cols, dtype=int)
        for s in keys:
            for params in _parameter_groups(s):
                for suffix in suffixes:
                    params[f'w0{suffix}'] = params[f'w0{suffix}'][:, keep_cols]
        new.q_scaler = [new.q_scaler[i] for i in keep_cols]

    # Step 4: Charge head removal
    if charge_head:
        new.model_type = 'potential'
        new.sqrt_epsilon_infinity = None
        for s in keys:
            for params in _parameter_groups(s):
                # Back to the (1, n_neuron) layout of single-output models.
                params['w1'] = params['w1'].reshape(1, -1)
                del params['w1_charge']
        del new.restart_parameters['ann_mu']['sqrt_epsilon_infinity']
        del new.restart_parameters['ann_sigma']['sqrt_epsilon_infinity']

    # Step 5: Update header fields
    new.n_neuron = new_n_neuron
    new.l_max_4b = new_l_max_4b
    new.l_max_5b = new_l_max_5b
    new.has_q_112 = new_has_q_112
    new.has_q_123 = new_has_q_123
    new.has_q_233 = new_has_q_233
    new.has_q_134 = new_has_q_134

    new_l_max_enh = (self.l_max_3b
                     + (new_l_max_4b > 0) + (new_l_max_5b > 0)
                     + (new_has_q_112 > 0) + (new_has_q_123 > 0) + (new_has_q_233 > 0)
                     + (new_has_q_134 > 0))
    new.n_descriptor_angular = (self.n_max_angular + 1) * new_l_max_enh

    # Step 6: Recalculate parameter counts
    _recalculate_parameter_counts(new)

    return new
def read_model(filename: str, restart_file: str = None) -> Model:
    """Read a NEP model from a file in ``nep.txt`` format.

    The header lines are unpacked into individual fields. The number of
    parameters implied by the header is checked against the file. The flat
    parameter list is then split into network weights, descriptor weights and
    ``q_scaler``, and a :class:`Model <calorine.nep.model.Model>` is built
    from the result.

    Parameters
    ----------
    filename
        Input file name.
    restart_file
        If provided, also read restart parameters from this file in
        `nep.restart` format and attach them to the returned model.
        Defaults to None.
    """
    data, parameters = _get_nep_contents(filename)

    # header sanity checks
    for key in ('cutoff', 'basis_size', 'n_max', 'l_max', 'ANN'):
        assert key in data, f'Invalid model file; {key} line is missing'
    assert data['version'] in (3, 4, 5), (
        'Invalid model file; only NEP versions 3, 4 and 5 are currently supported'
    )

    n_types = len(data['types'])

    # Cutoffs: either one global (radial, angular) pair or one pair per type
    # (radial at even, angular at odd positions), followed by the two
    # max-neighbor counts.
    cutoff = data.pop('cutoff')
    assert len(cutoff) in [4, 2 * n_types + 2]
    data['max_neighbors_radial'] = int(cutoff[-2])
    data['max_neighbors_angular'] = int(cutoff[-1])
    if len(cutoff) == 2 * n_types + 2:
        data['radial_cutoff'] = list(cutoff[0:2 * n_types:2])
        data['angular_cutoff'] = list(cutoff[1:2 * n_types:2])
    else:
        data['radial_cutoff'] = cutoff[0]
        data['angular_cutoff'] = cutoff[1]

    basis_size = data.pop('basis_size')
    assert len(basis_size) == 2
    data['n_basis_radial'], data['n_basis_angular'] = basis_size

    n_max = data.pop('n_max')
    assert len(n_max) == 2
    data['n_max_radial'], data['n_max_angular'] = n_max

    # l_max: only the three-body entry is mandatory; missing entries default to 0.
    l_max_fields = ('l_max_3b', 'l_max_4b', 'l_max_5b',
                    'has_q_112', 'has_q_123', 'has_q_233', 'has_q_134')
    l_max = data.pop('l_max')
    assert len(l_max) in [1, 2, 3, 4, 5, 6, 7]
    padding = [0] * (len(l_max_fields) - len(l_max))
    data.update(zip(l_max_fields, list(l_max) + padding))

    # descriptor dimensions: every enabled higher-order term adds one angular block
    data['n_descriptor_radial'] = data['n_max_radial'] + 1
    l_max_enh = data['l_max_3b'] + sum(data[name] > 0 for name in l_max_fields[1:])
    data['n_descriptor_angular'] = (data['n_max_angular'] + 1) * l_max_enh
    n_descriptor = data['n_descriptor_radial'] + data['n_descriptor_angular']

    is_charged_model = data['model_type'] == 'potential_with_charges'
    is_polarizability_model = data['model_type'] == 'polarizability'

    data['n_neuron'] = data.pop('ANN')[0]
    n_neuron = data['n_neuron']
    if data['version'] == 3:
        n_networks, n_bias = 1, 1
    elif data['version'] == 4:
        # one hidden layer per atomic species; charge models have two output scalars
        n_networks, n_bias = n_types, (2 if is_charged_model else 1)
    else:
        # NEP5: like NEP4, plus one output-layer bias per species next to the global one
        n_networks, n_bias = n_types, 1 + n_types

    # per network: hidden weights + biases and the output weights (two outputs for
    # charge models); plus the output-layer biases shared by all networks
    n_outputs = 2 if is_charged_model else 1
    n_ann_parameters = (
        (n_descriptor + 1) * n_neuron + n_outputs * n_neuron
    ) * n_networks + n_bias
    n_descriptor_weights = n_types ** 2 * (
        (data['n_max_radial'] + 1) * (data['n_basis_radial'] + 1)
        + (data['n_max_angular'] + 1) * (data['n_basis_angular'] + 1)
    )

    n_parameters = n_ann_parameters + n_descriptor_weights + n_descriptor
    if n_parameters + n_ann_parameters == len(parameters):
        # a second network of the same size is present
        n_parameters += n_ann_parameters
        assert is_polarizability_model, (
            'Model is not labelled as a polarizability model, but the number of '
            'parameters matches a polarizability model.\n'
            'If this is a polarizability model trained with GPUMD <=v3.8, please '
            'modify the header in the nep.txt file to enable parsing '
            f'`nep{data["version"]}_polarizability`.\n'
        )
    assert n_parameters == len(parameters), (
        'Parsing of parameters inconsistent; please submit a bug report\n'
        f'{n_parameters} != {len(parameters)}'
    )
    data['n_parameters'] = n_parameters
    data['n_ann_parameters'] = n_ann_parameters

    # split the flat list into ANN weights, descriptor weights and scaling factors
    n1 = n_ann_parameters * (2 if is_polarizability_model else 1)
    n2 = n1 + n_descriptor_weights
    descriptor_weights = np.array(parameters[n1:n2])
    data['q_scaler'] = parameters[n2:]

    ann_groups = data['types'] if data['version'] in (4, 5) else ['all_species']
    ann_parameters = _sort_ann_parameters(parameters[:n1],
                                          ann_groups,
                                          n_neuron,
                                          n_networks,
                                          n_bias,
                                          n_descriptor,
                                          is_polarizability_model,
                                          is_charged_model)
    if 'sqrt_epsilon_infinity' in ann_parameters:
        data['sqrt_epsilon_infinity'] = ann_parameters.pop('sqrt_epsilon_infinity')
    data['ann_parameters'] = ann_parameters

    data['n_descriptor_parameters'] = len(descriptor_weights)
    (data['radial_descriptor_weights'],
     data['angular_descriptor_weights']) = _sort_descriptor_parameters(descriptor_weights,
                                                                       data['types'],
                                                                       data['n_max_radial'],
                                                                       data['n_basis_radial'],
                                                                       data['n_max_angular'],
                                                                       data['n_basis_angular'])

    model = Model(**data)
    if restart_file is not None:
        model.read_restart(restart_file)
    return model