Coverage for calorine/nep/model.py: 100%
281 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-12-10 08:26 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2024-12-10 08:26 +0000
1from dataclasses import dataclass
2from itertools import product
3from typing import Dict, List, Tuple
5import numpy as np
8def _get_model_type(first_row: List[str]) -> str:
9 """Parses a the first row of a ``nep.txt`` file,
10 and returns the type of NEP model. Available types
11 are ``potential``, ``dipole`` and ``polarizability``.
13 Parameters
14 ----------
15 first_row
16 first row of a NEP file, split by white space.
17 """
18 model_type = first_row[0]
19 if 'dipole' in model_type:
20 return 'dipole'
21 elif 'polarizability' in model_type:
22 return 'polarizability'
23 return 'potential'
def _get_nep_contents(filename: str) -> Tuple[Dict, List]:
    """Read a ``nep.txt`` file and split it into a header dict and a flat
    list of model parameters.
    Intended to be used by the :func:`read_model <calorine.nep.read_model>` function.

    The header spans the first five lines. It is extended by one line if the
    first token of the file contains ``zbl``, and by one more if line index 3
    (4 for ZBL models) starts with ``basis_size``. Every line after the
    header must contain exactly one number.

    Parameters
    ----------
    filename
        input file name

    Returns
    -------
    A tuple ``(data, parameters)``. ``data`` maps ``cutoff``/``zbl`` to tuples
    of floats, ``n_max``/``l_max``/``ANN``/``basis_size`` to tuples of ints,
    and additionally holds ``version`` (int), ``types`` (tuple of str) and
    ``model_type`` (str) taken from the ``nep*`` line. ``parameters`` is the
    list of all remaining values as floats.

    Raises
    ------
    AssertionError
        If the file contains an empty line.
    IOError
        If a line after the header does not contain exactly one value.
    ValueError
        If a header line starts with an unrecognized keyword.
    """
    header_rows = []
    parameters = []
    n_header = 5  # 5 rows for NEP2, 6-7 rows for NEP3 onwards
    basis_size_row = 3
    with open(filename) as fobj:
        for lineno, line in enumerate(fobj):
            tokens = line.split()
            assert len(tokens) != 0, f'Empty line number {lineno}'
            if lineno == 0 and 'zbl' in tokens[0]:
                # ZBL models carry an extra 'zbl' header row, which shifts
                # the position of a possible 'basis_size' row by one
                basis_size_row += 1
                n_header += 1
            if lineno == basis_size_row and 'basis_size' in tokens[0]:
                # Introduced in nep.txt after GPUMD v3.2
                n_header += 1
            if lineno < n_header:
                header_rows.append(tuple(tokens))
                continue
            if len(tokens) != 1:
                raise IOError(f'Failed to parse line {lineno} from {filename}')
            parameters.append(float(tokens[0]))

    # translate the raw header rows into typed entries
    data = {}
    for row in header_rows:
        key, values = row[0], row[1:]
        if key in ('cutoff', 'zbl'):
            data[key] = tuple(float(v) for v in values)
        elif key in ('n_max', 'l_max', 'ANN', 'basis_size'):
            data[key] = tuple(int(v) for v in values)
        elif key.startswith('nep'):
            # e.g. 'nep4_zbl' -> 4; the second token is the number of species
            data['version'] = int(key.replace('nep', '').split('_')[0])
            data['types'] = row[2:]
            data['model_type'] = _get_model_type(row)
        else:
            raise ValueError(f'Unknown field: {key}')
    return data, parameters
@dataclass
class Model:
    r"""Objects of this class represent a NEP model in a form suitable for
    inspection and manipulation. Typically a :class:`Model` object is instantiated
    by calling the :func:`read_model <calorine.nep.read_model>` function.

    Attributes
    ----------
    version : int
        NEP version.
    model_type: str
        One of ``potential``, ``dipole`` or ``polarizability``.
    types : Tuple[str, ...]
        Chemical species that this model represents.
    radial_cutoff : float
        The radial cutoff parameter in Å.
    angular_cutoff : float
        The angular cutoff parameter in Å.
    max_neighbors_radial : int
        Maximum number of neighbors in neighbor list for radial terms.
    max_neighbors_angular : int
        Maximum number of neighbors in neighbor list for angular terms.
    radial_typewise_cutoff_factor : float
        The radial cutoff factor if use_typewise_cutoff is used.
    angular_typewise_cutoff_factor : float
        The angular cutoff factor if use_typewise_cutoff is used.
    zbl : Tuple[float, float]
        Inner and outer cutoff for transition to ZBL potential.
    zbl_typewise_cutoff_factor : float
        Typewise cutoff when use_typewise_cutoff_zbl is used.
    n_basis_radial : int
        Number of radial basis functions :math:`n_\mathrm{basis}^\mathrm{R}`.
    n_basis_angular : int
        Number of angular basis functions :math:`n_\mathrm{basis}^\mathrm{A}`.
    n_max_radial : int
        Maximum order of Chebyshev polynomials included in
        radial expansion :math:`n_\mathrm{max}^\mathrm{R}`.
    n_max_angular : int
        Maximum order of Chebyshev polynomials included in
        angular expansion :math:`n_\mathrm{max}^\mathrm{A}`.
    l_max_3b : int
        Maximum expansion order for three-body terms :math:`l_\mathrm{max}^\mathrm{3b}`.
    l_max_4b : int
        Maximum expansion order for four-body terms :math:`l_\mathrm{max}^\mathrm{4b}`.
    l_max_5b : int
        Maximum expansion order for five-body terms :math:`l_\mathrm{max}^\mathrm{5b}`.
    n_descriptor_radial : int
        Dimension of radial part of descriptor.
    n_descriptor_angular : int
        Dimension of angular part of descriptor.
    n_neuron : int
        Number of neurons in hidden layer.
    n_parameters : int
        Total number of parameters including scalers (which are not fit parameters).
    n_descriptor_parameters : int
        Number of parameters in descriptor.
    n_ann_parameters : int
        Number of neural network weights.
    ann_parameters : Dict[str, Union[Dict[str, np.ndarray], float]]
        Neural network weights. Weights ``w0``, ``b0`` and ``w1`` (plus the
        ``*_polar`` variants for polarizability models, and ``b1`` for NEP5)
        are stored per species for NEP4/NEP5 and under the key
        ``all_species`` for NEP3. The output bias is stored as a scalar under
        ``b1`` (and ``b1_polar``).
    q_scaler : List[float]
        Scaling parameters.
    radial_descriptor_weights : Dict[Tuple[str, str], np.ndarray]
        Radial descriptor weights by combination of species; the array for each combination
        has dimensions of
        :math:`(n_\mathrm{max}^\mathrm{R}+1) \times (n_\mathrm{basis}^\mathrm{R}+1)`.
    angular_descriptor_weights : Dict[Tuple[str, str], np.ndarray]
        Angular descriptor weights by combination of species; the array for each combination
        has dimensions of
        :math:`(n_\mathrm{max}^\mathrm{A}+1) \times (n_\mathrm{basis}^\mathrm{A}+1)`.
    """

    version: int
    model_type: str
    types: Tuple[str, ...]

    radial_cutoff: float
    angular_cutoff: float

    n_basis_radial: int
    n_basis_angular: int
    n_max_radial: int
    n_max_angular: int
    l_max_3b: int
    l_max_4b: int
    l_max_5b: int
    n_descriptor_radial: int
    n_descriptor_angular: int

    n_neuron: int
    n_parameters: int
    n_descriptor_parameters: int
    n_ann_parameters: int
    ann_parameters: Dict[str, Dict[str, np.ndarray]]
    q_scaler: List[float]
    radial_descriptor_weights: Dict[Tuple[str, str], np.ndarray]
    angular_descriptor_weights: Dict[Tuple[str, str], np.ndarray]

    zbl: Tuple[float, float] = None
    zbl_typewise_cutoff_factor: float = None
    max_neighbors_radial: int = None
    max_neighbors_angular: int = None
    radial_typewise_cutoff_factor: float = None
    angular_typewise_cutoff_factor: float = None

    # bulky fields that the text/HTML summaries do not print verbatim
    # (no annotation, so this is a class attribute, not a dataclass field)
    _special_fields = [
        'ann_parameters',
        'q_scaler',
        'radial_descriptor_weights',
        'angular_descriptor_weights',
    ]

    def __str__(self) -> str:
        """Return one ``name : value`` line per field, skipping the bulky fields."""
        s = []
        for fld in self.__dataclass_fields__:
            if fld not in self._special_fields:
                s += [f'{fld:22} : {getattr(self, fld)}']
        return '\n'.join(s)

    def _repr_html_(self) -> str:
        """Return an HTML table summarizing the model (rich display, e.g. Jupyter).

        Regular fields are listed with their values. For the bulky fields in
        ``_special_fields`` only their dimensions are shown.
        """
        s = []
        s += ['<table border="1" class="dataframe">']
        s += [
            '<thead><tr><th style="text-align: left;">Field</th><th>Value</th></tr></thead>'
        ]
        s += ['<tbody>']
        for fld in self.__dataclass_fields__:
            if fld not in self._special_fields:
                s += [
                    f'<tr><td style="text-align: left;">{fld:22}</td>'
                    f'<td>{getattr(self, fld)}</td></tr>'
                ]
        for fld in self._special_fields:
            d = getattr(self, fld)
            if fld.endswith('descriptor_weights') and len(d) > 0:
                # shape of the first species pair's array
                # (read_model gives every pair the same shape)
                dim = next(iter(d.values())).shape
            elif fld == 'ann_parameters' and self.version in (4, 5):
                # one sub-network per species:
                # (number of species, entries stored per species)
                dim = (len(self.types), len(next(iter(d.values()))))
            else:
                dim = len(d)
            s += [
                f'<tr><td style="text-align: left;">Dimension of {fld:22}</td>'
                f'<td>{dim}</td></tr>'
            ]
        s += ['</tbody>']
        s += ['</table>']
        return ''.join(s)

    def write(self, filename: str) -> None:
        """Write NEP model to file in ``nep.txt`` format.

        Parameters
        ----------
        filename
            Output file name; an existing file is overwritten.
        """
        with open(filename, 'w') as f:
            # header
            version_name = f'nep{self.version}'
            if self.zbl is not None:
                version_name += '_zbl'
            elif self.model_type != 'potential':
                version_name += f'_{self.model_type}'
            f.write(f'{version_name} {len(self.types)} {" ".join(self.types)}\n')
            if self.zbl is not None:
                f.write(f'zbl {" ".join(map(str, self.zbl))}\n')
            f.write(f'cutoff {self.radial_cutoff} {self.angular_cutoff}')
            f.write(f' {self.max_neighbors_radial} {self.max_neighbors_angular}')
            if (
                self.radial_typewise_cutoff_factor is not None
                and self.angular_typewise_cutoff_factor is not None
            ):
                f.write(f' {self.radial_typewise_cutoff_factor}'
                        f' {self.angular_typewise_cutoff_factor}')
                # The ZBL factor can only appear as the 7th cutoff entry, i.e.
                # after both typewise factors. Compare against None so that a
                # factor of 0.0 is written too.
                if self.zbl_typewise_cutoff_factor is not None:
                    f.write(f' {self.zbl_typewise_cutoff_factor}')
            f.write('\n')
            f.write(f'n_max {self.n_max_radial} {self.n_max_angular}\n')
            f.write(f'basis_size {self.n_basis_radial} {self.n_basis_angular}\n')
            f.write(f'l_max {self.l_max_3b} {self.l_max_4b} {self.l_max_5b}\n')
            f.write(f'ANN {self.n_neuron} 0\n')

            # neural network weights
            keys = self.types if self.version in (4, 5) else ['all_species']
            suffixes = ['', '_polar'] if self.model_type == 'polarizability' else ['']
            n_descriptor = self.n_descriptor_radial + self.n_descriptor_angular
            for suffix in suffixes:
                for s in keys:
                    # Order: w0, b0, w1 (, b1 if NEP5)
                    # w0 indexed as: n*N_descriptor + nu
                    w0 = self.ann_parameters[s][f'w0{suffix}']
                    b0 = self.ann_parameters[s][f'b0{suffix}']
                    w1 = self.ann_parameters[s][f'w1{suffix}']
                    for n in range(self.n_neuron):
                        for nu in range(n_descriptor):
                            f.write(f'{w0[n, nu]:15.7e}\n')
                    for b in b0[:, 0]:
                        f.write(f'{b:15.7e}\n')
                    for v in w1[0, :]:
                        f.write(f'{v:15.7e}\n')
                    if self.version == 5:
                        b1 = self.ann_parameters[s][f'b1{suffix}']
                        f.write(f'{b1:15.7e}\n')
                # global output bias follows all per-species networks
                b1 = self.ann_parameters[f'b1{suffix}']
                f.write(f'{b1:15.7e}\n')

            # descriptor weights: for each species pair the radial weights
            # followed by the angular weights. They are written parameter by
            # parameter, each parameter covering all species pairs in turn
            # (transpose of the pair-major stacking below).
            chunks = []
            for s1 in self.types:
                for s2 in self.types:
                    chunks.append(self.radial_descriptor_weights[(s1, s2)].flatten())
                    chunks.append(self.angular_descriptor_weights[(s1, s2)].flatten())
            mat = np.concatenate(chunks).astype(float)
            n_pairs = len(self.types) ** 2
            mat = mat.reshape((n_pairs, len(mat) // n_pairs)).T
            for v in mat.flatten():
                f.write(f'{v:15.7e}\n')

            # scaler
            for v in self.q_scaler:
                f.write(f'{v:15.7e}\n')
def read_model(filename: str) -> Model:
    """Parses a file in ``nep.txt`` format and returns the
    content in the form of a :class:`Model <calorine.nep.model.Model>`
    object.

    Parameters
    ----------
    filename
        Input file name.

    Raises
    ------
    AssertionError
        If a required header line is missing, the NEP version is not 3, 4 or
        5, a header line has an unexpected number of entries, or the number
        of parameters in the file does not match the number implied by the
        header.
    """
    data, parameters = _get_nep_contents(filename)

    # sanity checks
    for fld in ['cutoff', 'basis_size', 'n_max', 'l_max', 'ANN']:
        assert fld in data, f'Invalid model file; {fld} line is missing'
    assert data['version'] in [
        3,
        4,
        5,
    ], 'Invalid model file; only NEP versions 3, 4 and 5 are currently supported'

    # split up cutoff tuple: radial/angular cutoff and max neighbors, optionally
    # followed by the radial/angular typewise factors and the ZBL typewise factor
    cutoff = data.pop('cutoff')
    assert len(cutoff) in [4, 6, 7]
    data['radial_cutoff'] = cutoff[0]
    data['angular_cutoff'] = cutoff[1]
    data['max_neighbors_radial'] = int(cutoff[2])
    data['max_neighbors_angular'] = int(cutoff[3])
    if len(cutoff) >= 6:
        data['radial_typewise_cutoff_factor'] = cutoff[4]
        data['angular_typewise_cutoff_factor'] = cutoff[5]
    if len(cutoff) == 7:
        data['zbl_typewise_cutoff_factor'] = cutoff[6]

    # split up basis_size tuple
    basis_size = data.pop('basis_size')
    assert len(basis_size) == 2
    data['n_basis_radial'], data['n_basis_angular'] = basis_size

    # split up n_max tuple
    n_max = data.pop('n_max')
    assert len(n_max) == 2
    data['n_max_radial'], data['n_max_angular'] = n_max

    # split up l_max tuple; missing four- and five-body orders default to 0
    l_max = data.pop('l_max')
    assert len(l_max) in [1, 2, 3]
    data['l_max_3b'] = l_max[0]
    data['l_max_4b'] = l_max[1] if len(l_max) > 1 else 0
    data['l_max_5b'] = l_max[2] if len(l_max) > 2 else 0

    # compute dimensions of descriptor components
    data['n_descriptor_radial'] = data['n_max_radial'] + 1
    # each non-zero four-/five-body order adds one to the angular multiplier
    l_max_enh = data['l_max_3b'] + (data['l_max_4b'] > 0) + (data['l_max_5b'] > 0)
    data['n_descriptor_angular'] = (data['n_max_angular'] + 1) * l_max_enh
    n_descriptor = data['n_descriptor_radial'] + data['n_descriptor_angular']

    # compute number of parameters
    data['n_neuron'] = data.pop('ANN')[0]
    n_types = len(data['types'])
    if data['version'] == 3:
        # a single network shared by all species
        n = 1
        n_bias = 1
    elif data['version'] == 4:
        # one hidden layer per atomic species
        n = n_types
        n_bias = 1
    else:  # NEP5
        # like nep4, but additionally has an
        # individual bias term in the output
        # layer for each species.
        n = n_types
        n_bias = 1 + n_types  # one global bias + one per species

    n_ann_input_weights = (n_descriptor + 1) * data['n_neuron']  # weights + bias
    n_ann_output_weights = data['n_neuron']  # only weights
    n_ann_parameters = (n_ann_input_weights + n_ann_output_weights) * n + n_bias

    n_descriptor_weights = n_types**2 * (
        (data['n_max_radial'] + 1) * (data['n_basis_radial'] + 1)
        + (data['n_max_angular'] + 1) * (data['n_basis_angular'] + 1)
    )
    data['n_parameters'] = n_ann_parameters + n_descriptor_weights + n_descriptor
    is_polarizability_model = data['model_type'] == 'polarizability'
    if data['n_parameters'] + n_ann_parameters == len(parameters):
        # the file holds a second, complete set of ANN parameters
        data['n_parameters'] += n_ann_parameters
        assert is_polarizability_model, (
            'Model is not labelled as a polarizability model, but the number of '
            'parameters matches a polarizability model.\n'
            'If this is a polarizability model trained with GPUMD <=v3.8, please '
            'modify the header in the nep.txt file to read '
            f'`nep{data["version"]}_polarizability`.\n'
        )
    assert data['n_parameters'] == len(parameters), (
        'Parsing of parameters inconsistent; please submit a bug report\n'
        f'{data["n_parameters"]} != {len(parameters)}'
    )
    data['n_ann_parameters'] = n_ann_parameters

    # split up parameters into the ANN weights, descriptor weights, and scaling parameters
    n_count = 2 if is_polarizability_model else 1
    n1 = n_ann_parameters * n_count
    n2 = n1 + n_descriptor_weights
    ann_flat = parameters[:n1]
    descriptor_weights = np.array(parameters[n1:n2])
    data['q_scaler'] = parameters[n2:]

    # Group ANN parameters
    pars = {}
    n1 = 0
    n_network_params = n_ann_input_weights + n_ann_output_weights  # except last bias
    n_neuron = data['n_neuron']
    n_w0 = n_neuron * n_descriptor
    keys = data['types'] if data['version'] in (4, 5) else ['all_species']

    for count in range(n_count):
        # if polarizability model, all parameters including bias are repeated
        # need to offset n1 by +1 to handle bias
        n1 += count
        for s in keys:
            # Get the parameters for the ANN; in the case of NEP4, there is effectively
            # one network per atomic species.
            ann_parameters = ann_flat[n1 : n1 + n_network_params]
            ann_input_weights = ann_parameters[:n_ann_input_weights]
            ann_output_weights = ann_parameters[n_ann_input_weights:]

            # input weights are stored row-major (index n * n_descriptor + nu),
            # followed by one bias per neuron
            w0 = np.array(ann_input_weights[:n_w0], dtype=float).reshape(
                (n_neuron, n_descriptor)
            )
            b0 = np.array(ann_input_weights[n_w0:], dtype=float).reshape((n_neuron, 1))
            w1 = np.array(ann_output_weights, dtype=float).reshape((1, n_neuron))

            # the file itself may contain 'nan' entries
            assert not np.any(
                np.isnan(w0)
            ), f'some weights in w0 are nan for key {s}; please submit a bug report'
            assert not np.any(
                np.isnan(b0)
            ), f'some weights in b0 are nan for key {s}; please submit a bug report'
            assert not np.any(
                np.isnan(w1)
            ), f'some weights in w1 are nan for key {s}; please submit a bug report'

            if count == 0:
                pars[s] = dict(w0=w0, b0=b0, w1=w1)
            else:
                pars[s].update({'w0_polar': w0, 'b0_polar': b0, 'w1_polar': w1})
            # Jump to bias
            n1 += n_network_params
            if n_bias > 1:
                # For NEP5 models we additionally have one bias term per species.
                # Currently NEP5 only exists for potential models, but we'll
                # keep it here in case it gets added down the line.
                bias_label = 'b1' if count == 0 else 'b1_polar'
                pars[s][bias_label] = ann_flat[n1]
                n1 += 1
        # For NEP3 and NEP4 we only have one bias.
        # For NEP5 we have one bias per species, and one global.
        if count == 0:
            pars['b1'] = ann_flat[n1]
        else:
            pars['b1_polar'] = ann_flat[n1]

    # Count entries by size (not by non-zero value) so that weights that are
    # exactly zero do not trigger a spurious inconsistency.
    n_accounted = 0
    for value in pars.values():
        if isinstance(value, dict):
            n_accounted += sum(int(np.size(p)) for p in value.values())
        else:
            # global output bias ('b1' / 'b1_polar') is a single scalar
            n_accounted += 1
    assert n_accounted == n_ann_parameters * n_count, (
        'Inconsistent number of parameters accounted for; please submit a bug report\n'
        f'{n_accounted} != {n_ann_parameters * n_count}'
    )
    data['ann_parameters'] = pars

    # split up descriptor by chemical species and radial/angular
    data['n_descriptor_parameters'] = len(descriptor_weights)
    n = len(descriptor_weights) // (n_types * n_types)
    n_max_radial = data['n_max_radial']
    n_max_angular = data['n_max_angular']
    n_basis_radial = data['n_basis_radial']
    n_basis_angular = data['n_basis_angular']
    m = (n_max_radial + 1) * (n_basis_radial + 1)
    # stored parameter-major; transpose to one row per species pair
    descriptor_weights = descriptor_weights.reshape((n, n_types * n_types)).T
    descriptor_weights_radial = descriptor_weights[:, :m]
    descriptor_weights_angular = descriptor_weights[:, m:]

    # add descriptors to data dict
    data['radial_descriptor_weights'] = {}
    data['angular_descriptor_weights'] = {}
    for pair_index, (i, j) in enumerate(product(range(n_types), repeat=2)):
        s1, s2 = data['types'][i], data['types'][j]
        data['radial_descriptor_weights'][(s1, s2)] = descriptor_weights_radial[
            pair_index, :
        ].reshape((n_max_radial + 1, n_basis_radial + 1))
        data['angular_descriptor_weights'][(s1, s2)] = descriptor_weights_angular[
            pair_index, :
        ].reshape((n_max_angular + 1, n_basis_angular + 1))

    return Model(**data)