Coverage for calorine / nep / model.py: 100%

442 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-06-04 15:27 +0000

1from dataclasses import dataclass 

2from itertools import product 

3 

4import numpy as np 

5 

6NetworkWeights = dict[str, dict[str, np.ndarray]] 

7DescriptorWeights = dict[tuple[str, str], np.ndarray] 

8RestartParameters = dict[str, dict[str, dict[str, np.ndarray]]] 

9 

10 

def _get_restart_contents(filename: str) -> tuple[list[float], list[float]]:
    """Reads a ``nep.restart`` file and returns its two columns as flat lists.

    Every line of the file must consist of exactly two whitespace-separated
    numbers, the mean and the standard deviation of one model parameter.
    Intended to be used by the :py:meth:`~Model.read_restart` method.

    Parameters
    ----------
    filename
        input file name

    Returns
    -------
    tuple[list[float], list[float]]
        Means and standard deviations, in the order they appear in the file.

    Raises
    ------
    AssertionError
        If a line contains no fields at all.
    IOError
        If a line does not contain exactly two fields.
    """
    means: list[float] = []
    deviations: list[float] = []
    with open(filename) as handle:
        for index, row in enumerate(handle):
            fields = row.split()
            assert len(fields) != 0, f'Empty line number {index}'
            if len(fields) != 2:
                raise IOError(f'Failed to parse line {index} from {filename}')
            mean, deviation = fields
            means.append(float(mean))
            deviations.append(float(deviation))
    return means, deviations

33 

34 

def _get_model_type(first_row: list[str]) -> str:
    """Determines the type of NEP model from the first row of a ``nep.txt`` file.

    Only the first field of the row is inspected. The checks are substring
    tests applied in a fixed order (charge, dipole, polarizability); if none
    of them matches the model is taken to be a plain potential.

    Parameters
    ----------
    first_row
        First row of a NEP file, split by white space.

    Returns
    -------
    str
        One of `potential`, `potential_with_charges`, `dipole`, and `polarizability`.
    """
    tag = first_row[0]
    # order matters: the first matching keyword wins
    keyword_to_type = (
        ('charge', 'potential_with_charges'),
        ('dipole', 'dipole'),
        ('polarizability', 'polarizability'),
    )
    for keyword, model_type in keyword_to_type:
        if keyword in tag:
            return model_type
    return 'potential'

53 

54 

def _get_nep_contents(filename: str) -> tuple[dict, list[float]]:
    """Parses a ``nep.txt`` file into a header dict and a flat parameter list.

    The leading rows of the file form the header; every following row must
    hold exactly one number. Header rows are converted as follows:
    ``cutoff`` and ``zbl`` become tuples of floats; ``n_max``, ``l_max``,
    ``ANN`` and ``basis_size`` become tuples of ints; the row starting with
    ``nep`` yields the keys ``version``, ``types`` and ``model_type``.
    Intended to be used by the :func:`read_model <calorine.nep.read_model>` function.

    Parameters
    ----------
    filename
        input file name

    Returns
    -------
    tuple[dict, list[float]]
        Header dict and the model parameters in file order.

    Raises
    ------
    AssertionError
        If a line contains no fields at all.
    IOError
        If a non-header line does not contain exactly one field.
    ValueError
        If a header row starts with an unknown keyword.
    """
    header_rows = []
    values = []
    n_header_rows = 5  # 5 rows for NEP2, 6-7 rows for NEP3 onwards
    basis_size_row = 3  # row index at which a ``basis_size`` line is expected
    with open(filename) as handle:
        for index, row in enumerate(handle):
            fields = row.split()
            assert len(fields) != 0, f'Empty line number {index}'
            if index == 0 and 'zbl' in fields[0]:
                # a ZBL model carries one extra header row, shifting the rest down
                basis_size_row += 1
                n_header_rows += 1
            if index == basis_size_row and 'basis_size' in fields[0]:
                # Introduced in nep.txt after GPUMD v3.2
                n_header_rows += 1
            if index < n_header_rows:
                header_rows.append(tuple(fields))
                continue
            if len(fields) != 1:
                raise IOError(f'Failed to parse line {index} from {filename}')
            values.append(float(fields[0]))

    # compile data from the header into a dict
    float_keywords = ('cutoff', 'zbl')
    int_keywords = ('n_max', 'l_max', 'ANN', 'basis_size')
    contents = {}
    for row in header_rows:
        keyword = row[0]
        if keyword in float_keywords:
            contents[keyword] = tuple(float(field) for field in row[1:])
        elif keyword in int_keywords:
            contents[keyword] = tuple(int(field) for field in row[1:])
        elif keyword.startswith('nep'):
            # e.g. ``nep4_zbl`` -> 4
            contents['version'] = int(keyword.replace('nep', '').split('_')[0])
            contents['types'] = row[2:]
            contents['model_type'] = _get_model_type(row)
        else:
            raise ValueError(f'Unknown field: {keyword}')
    return contents, values

102 

103 

def _sort_descriptor_parameters(parameters: list[float],
                                types: list[str],
                                n_max_radial: int,
                                n_basis_radial: int,
                                n_max_angular: int,
                                n_basis_angular: int) -> tuple[DescriptorWeights,
                                                               DescriptorWeights]:
    """Sorts a flat sequence of descriptor parameters into two `dicts`,
    one for the radial and one for the angular descriptor weights.
    Intended to be used by the :func:`read_model <calorine.nep.read_model>` function.

    The flat sequence is interpreted as ``n`` consecutive groups of
    ``len(types)**2`` values, i.e., the species pair is the fastest running
    index. For each pair the first
    ``(n_max_radial + 1) * (n_basis_radial + 1)`` values are the radial
    weights and the remainder are the angular weights.

    Parameters
    ----------
    parameters
        Flat sequence of descriptor parameters; any array-like is accepted.
    types
        Chemical species; pairs are enumerated in the order of
        ``itertools.product(types, repeat=2)``.
    n_max_radial
        Maximum order of the radial expansion.
    n_basis_radial
        Number of radial basis functions.
    n_max_angular
        Maximum order of the angular expansion.
    n_basis_angular
        Number of angular basis functions.

    Returns
    -------
    tuple[DescriptorWeights, DescriptorWeights]
        Radial and angular weights keyed by ``(species1, species2)``; the arrays
        have shapes ``(n_max_radial + 1, n_basis_radial + 1)`` and
        ``(n_max_angular + 1, n_basis_angular + 1)``, respectively.

    Raises
    ------
    AssertionError
        If the number of parameters is not a multiple of ``len(types)**2``.
    """
    # accept plain lists as advertised by the signature, not only numpy arrays
    weights = np.asarray(parameters)

    # split up descriptor by chemical species and radial/angular
    n_pairs = len(types) ** 2
    n, remainder = divmod(len(weights), n_pairs)
    assert remainder == 0, 'number of descriptor groups must be an integer'

    n_radial = (n_max_radial + 1) * (n_basis_radial + 1)
    # one row per species pair, one column per descriptor weight
    per_pair = weights.reshape((n, n_pairs)).T

    radial_descriptor_weights = {}
    angular_descriptor_weights = {}
    for row, pair in enumerate(product(types, repeat=2)):
        radial_descriptor_weights[pair] = per_pair[row, :n_radial].reshape(
            (n_max_radial + 1, n_basis_radial + 1)
        )
        angular_descriptor_weights[pair] = per_pair[row, n_radial:].reshape(
            (n_max_angular + 1, n_basis_angular + 1)
        )
    return radial_descriptor_weights, angular_descriptor_weights

140 

141 

def _sort_ann_parameters(parameters: list[float],
                         ann_groupings: list[str],
                         n_neuron: int,
                         n_networks: int,
                         n_bias: int,
                         n_descriptor: int,
                         is_polarizability_model: bool,
                         is_model_with_charges: bool
                         ) -> NetworkWeights:
    """Sorts a flat sequence of ANN parameters into an appropriately structured `dict`.
    Intended to be used by the :func:`read_model <calorine.nep.read_model>` function.

    For every entry of `ann_groupings` the parameters are consumed in the
    order ``w0`` (row-major, ``n_neuron x n_descriptor``), ``b0``
    (``n_neuron``) and ``w1`` (``n_neuron``, or ``2 * n_neuron`` for models
    with charges, in which case the second half is stored as ``w1_charge``).
    If ``n_bias > 1`` and the model has no charges, one bias per grouping
    follows its network. The global bias(es) come after all networks. For
    polarizability models the whole sequence is repeated once more and the
    second set is stored under keys with the suffix ``_polar``.

    Parameters
    ----------
    parameters
        Flat sequence of ANN parameters.
    ann_groupings
        Keys under which the individual networks are stored.
    n_neuron
        Number of neurons in the hidden layer.
    n_networks
        Number of networks; only used for the final consistency check.
    n_bias
        Number of output biases; a value larger than one enables the per-grouping
        bias for models without charges.
    n_descriptor
        Total dimension of the descriptor.
    is_polarizability_model
        Whether the parameters contain a second (``_polar``) set.
    is_model_with_charges
        Whether the output layer has a second set of weights for the charges.

    Returns
    -------
    NetworkWeights
        One dict per grouping plus the global entries ``b1``, and where
        applicable ``b1_polar`` and ``sqrt_epsilon_infinity``.

    Raises
    ------
    AssertionError
        If any weight is NaN or if the number of parameters that were sorted
        does not match the expected number.
    """
    n_outputs = 2 if is_model_with_charges else 1
    n_ann_input_weights = (n_descriptor + 1) * n_neuron  # weights + bias
    n_ann_output_weights = n_outputs * n_neuron  # only weights
    n_network_params = n_ann_input_weights + n_ann_output_weights  # except last bias(es)
    n_ann_parameters = n_network_params * n_networks + n_bias
    n_count = 2 if is_polarizability_model else 1
    n_w0 = n_neuron * n_descriptor

    # Group ANN parameters
    pars = {}
    n1 = 0
    for count in range(n_count):
        # if polarizability model, all parameters including bias are repeated
        # need to offset n1 by +1 to handle bias
        n1 += count
        for s in ann_groupings:
            # Get the parameters for the ANN; in the case of NEP4, there is effectively
            # one network per atomic species.
            ann_parameters = parameters[n1:n1 + n_network_params]
            ann_input_weights = ann_parameters[:n_ann_input_weights]
            ann_output_weights = ann_parameters[
                n_ann_input_weights:n_ann_input_weights + n_ann_output_weights
            ]

            # w0 is stored row-major, i.e., indexed as n * n_descriptor + nu
            w0 = np.array(ann_input_weights[:n_w0], dtype=float).reshape(
                (n_neuron, n_descriptor))
            b0 = np.array(ann_input_weights[n_w0:], dtype=float).reshape((n_neuron, 1))
            w1 = np.array(ann_output_weights, dtype=float).reshape((1, n_ann_output_weights))

            assert not np.any(
                np.isnan(w0)
            ), f'some weights in w0 are nan for key {s}; please submit a bug report'
            assert not np.any(
                np.isnan(b0)
            ), f'some weights in b0 are nan for key {s}; please submit a bug report'
            assert not np.any(
                np.isnan(w1)
            ), f'some weights in w1 are nan for key {s}; please submit a bug report'

            if count > 0:
                pars[s].update({'w0_polar': w0, 'b0_polar': b0, 'w1_polar': w1})
            elif n_outputs == 1:
                pars[s] = dict(w0=w0, b0=b0, w1=w1)
            else:
                pars[s] = dict(w0=w0, b0=b0, w1=w1[0, :n_neuron], w1_charge=w1[0, n_neuron:])

            # Jump to bias
            n1 += n_network_params
            if n_bias > 1 and not is_model_with_charges:
                # For NEP5 models we additionally have one bias term per species.
                # Currently NEP5 only exists for potential models, but we'll
                # keep it here in case it gets added down the line.
                bias_label = 'b1' if count == 0 else 'b1_polar'
                pars[s][bias_label] = parameters[n1]
                n1 += 1

        # For NEP3 and NEP4 we only have one bias.
        # For NEP4 with charges we have two biases.
        # For NEP5 we have one bias per species, and one global.
        if count > 0:
            pars['b1_polar'] = parameters[n1]
        elif n_outputs == 1:
            pars['b1'] = parameters[n1]
        else:
            pars['sqrt_epsilon_infinity'] = parameters[n1]
            pars['b1'] = parameters[n1 + 1]

    # consistency check: every expected parameter must have ended up in `pars`
    n_expected = n_ann_parameters * n_count
    n_accounted = 0
    for key, value in pars.items():
        if key.startswith(('b1', 'sqrt')):
            n_accounted += 1  # global scalar entry
        else:
            n_accounted += sum(int(np.size(p)) for p in value.values())
    assert n_accounted == n_expected, (
        'Inconsistent number of parameters accounted for; please submit a bug report\n'
        f'{n_accounted} != {n_expected}'
    )
    return pars

246 

247 

@dataclass
class Model:
    r"""Objects of this class represent a NEP model in a form suitable for
    inspection and manipulation. Typically a :class:`Model` object is instantiated
    by calling the :func:`read_model <calorine.nep.read_model>` function.

    Attributes
    ----------
    version : int
        NEP version.
    model_type: str
        One of ``potential``, ``potential_with_charges``, ``dipole`` or ``polarizability``.
    types : tuple[str, ...]
        Chemical species that this model represents.
    radial_cutoff : float | list[float]
        The radial cutoff parameter in Å.
        Is a list of radial cutoffs ordered after ``types`` in the case of typewise cutoffs.
    angular_cutoff : float | list[float]
        The angular cutoff parameter in Å.
        Is a list of angular cutoffs ordered after ``types`` in the case of typewise cutoffs.
    max_neighbors_radial : int
        Maximum number of neighbors in neighbor list for radial terms.
    max_neighbors_angular : int
        Maximum number of neighbors in neighbor list for angular terms.
    radial_typewise_cutoff_factor : float
        The radial cutoff factor if use_typewise_cutoff is used.
    angular_typewise_cutoff_factor : float
        The angular cutoff factor if use_typewise_cutoff is used.
    zbl : tuple[float, float]
        Inner and outer cutoff for transition to ZBL potential.
    zbl_typewise_cutoff_factor : float
        Typewise cutoff when use_typewise_cutoff_zbl is used.
    n_basis_radial : int
        Number of radial basis functions :math:`n_\mathrm{basis}^\mathrm{R}`.
    n_basis_angular : int
        Number of angular basis functions :math:`n_\mathrm{basis}^\mathrm{A}`.
    n_max_radial : int
        Maximum order of Chebyshev polynomials included in
        radial expansion :math:`n_\mathrm{max}^\mathrm{R}`.
    n_max_angular : int
        Maximum order of Chebyshev polynomials included in
        angular expansion :math:`n_\mathrm{max}^\mathrm{A}`.
    l_max_3b : int
        Maximum expansion order for three-body terms :math:`l_\mathrm{max}^\mathrm{3b}`.
    l_max_4b : int
        Maximum expansion order for four-body terms :math:`l_\mathrm{max}^\mathrm{4b}`.
    l_max_5b : int
        Maximum expansion order for five-body terms :math:`l_\mathrm{max}^\mathrm{5b}`.
    has_q_112 : int
        Flag enabling the 5-body :math:`q_{112}` descriptor (0 or 1).
    has_q_123 : int
        Flag enabling the 5-body :math:`q_{123}` descriptor (0 or 1).
    has_q_233 : int
        Flag enabling the 5-body :math:`q_{233}` descriptor (0 or 1).
    has_q_134 : int
        Flag enabling the higher-body :math:`q_{134}` descriptor (0 or 1).
    n_descriptor_radial : int
        Dimension of radial part of descriptor.
    n_descriptor_angular : int
        Dimension of angular part of descriptor.
    n_neuron : int
        Number of neurons in hidden layer.
    n_parameters : int
        Total number of parameters including scalers (which are not fit parameters).
    n_descriptor_parameters : int
        Number of parameters in descriptor.
    n_ann_parameters : int
        Number of neural network weights.
    ann_parameters : dict[str, dict[str, np.ndarray]]
        Neural network weights.
    q_scaler : list[float]
        Scaling parameters.
    radial_descriptor_weights : dict[tuple[str, str], np.ndarray]
        Radial descriptor weights by combination of species; the array for each combination
        has dimensions of
        :math:`(n_\mathrm{max}^\mathrm{R}+1) \times (n_\mathrm{basis}^\mathrm{R}+1)`.
    angular_descriptor_weights : dict[tuple[str, str], np.ndarray]
        Angular descriptor weights by combination of species; the array for each combination
        has dimensions of
        :math:`(n_\mathrm{max}^\mathrm{A}+1) \times (n_\mathrm{basis}^\mathrm{A}+1)`.
    sqrt_epsilon_infinity : Optional[float]
        Square root of epsilon infinity :math:`\epsilon_\infty`
        (only for NEP models with charges).
    restart_parameters : dict[str, dict[str, dict[str, np.ndarray]]]
        NEP restart parameters. A nested dictionary that contains the mean (mu) and standard
        deviation (sigma) for the ANN and descriptor parameters. Is set using the
        :py:meth:`~Model.read_restart` method. Defaults to None.
    """

    # --- required fields (no default) ---
    version: int
    model_type: str
    types: tuple[str, ...]

    radial_cutoff: float | list[float]
    angular_cutoff: float | list[float]

    n_basis_radial: int
    n_basis_angular: int
    n_max_radial: int
    n_max_angular: int
    l_max_3b: int
    l_max_4b: int
    l_max_5b: int
    has_q_112: int
    has_q_123: int
    has_q_233: int
    has_q_134: int
    n_descriptor_radial: int
    n_descriptor_angular: int

    n_neuron: int
    n_parameters: int
    n_descriptor_parameters: int
    n_ann_parameters: int
    ann_parameters: NetworkWeights
    q_scaler: list[float]
    radial_descriptor_weights: DescriptorWeights
    angular_descriptor_weights: DescriptorWeights

    # --- optional fields; None means "not set" ---
    sqrt_epsilon_infinity: float = None
    restart_parameters: RestartParameters = None

    zbl: tuple[float, float] = None
    zbl_typewise_cutoff_factor: float = None
    max_neighbors_radial: int = None
    max_neighbors_angular: int = None
    radial_typewise_cutoff_factor: float = None
    angular_typewise_cutoff_factor: float = None

    # Not annotated, hence a plain class attribute rather than a dataclass field.
    # Names of the fields that __str__ skips and for which _repr_html_ only
    # reports a dimension instead of the value.
    _special_fields = [
        'ann_parameters',
        'q_scaler',
        'radial_descriptor_weights',
        'angular_descriptor_weights',
    ]

381 

382 def __str__(self) -> str: 

383 s = [] 

384 for fld in self.__dataclass_fields__: 

385 if fld not in self._special_fields: 

386 s += [f'{fld:22} : {getattr(self, fld)}'] 

387 return '\n'.join(s) 

388 

389 def _repr_html_(self) -> str: 

390 s = [] 

391 s += ['<table border="1" class="dataframe"'] 

392 s += [ 

393 '<thead><tr><th style="text-align: left;">Field</th><th>Value</th></tr></thead>' 

394 ] 

395 s += ['<tbody>'] 

396 for fld in self.__dataclass_fields__: 

397 if fld not in self._special_fields: 

398 s += [ 

399 f'<tr><td style="text-align: left;">{fld:22}</td>' 

400 f'<td>{getattr(self, fld)}</td><tr>' 

401 ] 

402 for fld in self._special_fields: 

403 d = getattr(self, fld) 

404 # print('xxx', fld, d) 

405 if fld.endswith('descriptor_weights'): 

406 dim = list(d.values())[0].shape 

407 elif fld == 'ann_parameters' and self.version == 4: 

408 dim = (len(self.types), len(list(d.values())[0])) 

409 else: 

410 dim = len(d) 

411 s += [ 

412 f'<tr><td style="text-align: left;">Dimension of {fld:22}</td><td>{dim}</td><tr>' 

413 ] 

414 s += ['</tbody>'] 

415 s += ['</table>'] 

416 return ''.join(s) 

417 

418 def remove_species(self, species: list[str]): 

419 """Removes one or more species from the model. 

420 

421 This method modifies the model in-place by removing all parameters 

422 associated with the specified chemical species. It prunes the species 

423 list, the Artificial Neural Network (ANN) parameters, and the 

424 descriptor weights. It also recalculates the total number of 

425 parameters in the model. 

426 

427 Parameters 

428 ---------- 

429 species 

430 A list of species names (str) to remove from the model. 

431 

432 Raises 

433 ------ 

434 ValueError 

435 If any of the provided species is not found in the model. 

436 """ 

437 for s in species: 

438 if s not in self.types: 

439 raise ValueError(f'{s} is not a species supported by the NEP model') 

440 

441 # --- Prune attributes based on species --- 

442 types_to_keep = [t for t in self.types if t not in species] 

443 self.types = tuple(types_to_keep) 

444 

445 # Prune ANN parameters (for NEP4 and NEP5) 

446 if self.version in [4, 5]: 

447 self.ann_parameters = { 

448 key: value for key, value in self.ann_parameters.items() 

449 if key in types_to_keep or key.startswith('b1') 

450 } 

451 

452 # Prune descriptor weights 

453 # key is here a tuple, (species1, species2) 

454 self.radial_descriptor_weights = { 

455 key: value for key, value in self.radial_descriptor_weights.items() 

456 if key[0] in types_to_keep and key[1] in types_to_keep 

457 } 

458 self.angular_descriptor_weights = { 

459 key: value for key, value in self.angular_descriptor_weights.items() 

460 if key[0] in types_to_keep and key[1] in types_to_keep 

461 } 

462 

463 # Prune restart parameters if they have been loaded 

464 if self.restart_parameters is not None: 

465 for param_type in ['mu', 'sigma']: 

466 # Prune ANN restart parameters 

467 ann_key = f'ann_{param_type}' 

468 if self.version in [4, 5]: 

469 self.restart_parameters[ann_key] = { 

470 key: value for key, value in self.restart_parameters[ann_key].items() 

471 if key in types_to_keep or key.startswith('b1') 

472 } 

473 

474 # Prune descriptor restart parameters 

475 for desc_type in ['radial', 'angular']: 

476 key = f'{desc_type}_descriptor_{param_type}' 

477 self.restart_parameters[key] = { 

478 k: v for k, v in self.restart_parameters[key].items() 

479 if k[0] in types_to_keep and k[1] in types_to_keep 

480 } 

481 

482 # --- Recalculate parameter counts --- 

483 n_types = len(self.types) 

484 n_descriptor = self.n_descriptor_radial + self.n_descriptor_angular 

485 

486 # Recalculate descriptor parameter count 

487 self.n_descriptor_parameters = n_types**2 * ( 

488 (self.n_max_radial + 1) * (self.n_basis_radial + 1) 

489 + (self.n_max_angular + 1) * (self.n_basis_angular + 1) 

490 ) 

491 

492 # Recalculate ANN parameter count 

493 if self.version == 3: 

494 n_networks = 1 

495 n_bias = 1 

496 elif self.version == 4: 

497 n_networks = n_types 

498 n_bias = 1 

499 else: # NEP5 

500 n_networks = n_types 

501 n_bias = 1 + n_types 

502 

503 n_ann_input_weights = (n_descriptor + 1) * self.n_neuron 

504 n_ann_output_weights = self.n_neuron 

505 self.n_ann_parameters = ( 

506 n_ann_input_weights + n_ann_output_weights 

507 ) * n_networks + n_bias 

508 

509 # Recalculate total parameter count 

510 self.n_parameters = ( 

511 self.n_ann_parameters 

512 + self.n_descriptor_parameters 

513 + n_descriptor # q_scaler parameters 

514 ) 

515 if self.model_type == 'polarizability': 

516 self.n_parameters += self.n_ann_parameters 

517 

518 def write(self, filename: str, restart_file: str = None) -> None: 

519 """Write NEP model to file in `nep.txt` format. 

520 

521 Parameters 

522 ---------- 

523 filename 

524 Output file name for the NEP model. 

525 restart_file 

526 If provided, also write restart parameters to this file in 

527 `nep.restart` format. Defaults to None. 

528 """ 

529 with open(filename, 'w') as f: 

530 # header 

531 version_name = f'nep{self.version}' 

532 if self.zbl is not None: 

533 version_name += '_zbl' 

534 elif self.model_type != 'potential': 

535 version_name += f'_{self.model_type}' 

536 f.write(f'{version_name} {len(self.types)} {" ".join(self.types)}\n') 

537 if self.zbl is not None: 

538 f.write(f'zbl {" ".join(map(str, self.zbl))}\n') 

539 f.write('cutoff') 

540 if isinstance(self.radial_cutoff, float) and isinstance(self.angular_cutoff, float): 

541 f.write(f' {self.radial_cutoff} {self.angular_cutoff}') 

542 else: 

543 # Typewise cutoffs: one set of cutoffs per type 

544 for i in range(len(self.types)): 

545 f.write(f' {self.radial_cutoff[i]} {self.angular_cutoff[i]}') 

546 f.write(f' {self.max_neighbors_radial} {self.max_neighbors_angular}') 

547 f.write('\n') 

548 f.write(f'n_max {self.n_max_radial} {self.n_max_angular}\n') 

549 f.write(f'basis_size {self.n_basis_radial} {self.n_basis_angular}\n') 

550 l_max_line = f'l_max {self.l_max_3b} {self.l_max_4b} {self.l_max_5b}' 

551 if self.has_q_112 or self.has_q_123 or self.has_q_233 or self.has_q_134: 

552 l_max_line += f' {self.has_q_112}' 

553 if self.has_q_123 or self.has_q_233 or self.has_q_134: 

554 l_max_line += f' {self.has_q_123}' 

555 if self.has_q_233 or self.has_q_134: 

556 l_max_line += f' {self.has_q_233}' 

557 if self.has_q_134: 

558 l_max_line += f' {self.has_q_134}' 

559 f.write(l_max_line + '\n') 

560 f.write(f'ANN {self.n_neuron} 0\n') 

561 

562 # neural network weights 

563 keys = self.types if self.version in (4, 5) else ['all_species'] 

564 suffixes = ['', '_polar'] if self.model_type == 'polarizability' else [''] 

565 for suffix in suffixes: 

566 for s in keys: 

567 # Order: w0, b0, w1 (, b1 if NEP5) 

568 # w0 indexed as: n*N_descriptor + nu 

569 w0 = self.ann_parameters[s][f'w0{suffix}'] 

570 b0 = self.ann_parameters[s][f'b0{suffix}'] 

571 w1 = self.ann_parameters[s][f'w1{suffix}'] 

572 for n in range(self.n_neuron): 

573 for nu in range( 

574 self.n_descriptor_radial + self.n_descriptor_angular 

575 ): 

576 f.write(f'{w0[n, nu]:15.7e}\n') 

577 for b in b0[:, 0]: 

578 f.write(f'{b:15.7e}\n') 

579 for v in w1[0, :]: 

580 f.write(f'{v:15.7e}\n') 

581 if self.version == 5: 

582 b1 = self.ann_parameters[s][f'b1{suffix}'] 

583 f.write(f'{b1:15.7e}\n') 

584 b1 = self.ann_parameters[f'b1{suffix}'] 

585 f.write(f'{b1:15.7e}\n') 

586 

587 # descriptor weights 

588 mat = [] 

589 for s1 in self.types: 

590 for s2 in self.types: 

591 mat = np.hstack( 

592 [mat, self.radial_descriptor_weights[(s1, s2)].flatten()] 

593 ) 

594 mat = np.hstack( 

595 [mat, self.angular_descriptor_weights[(s1, s2)].flatten()] 

596 ) 

597 n_types = len(self.types) 

598 n = int(len(mat) / (n_types * n_types)) 

599 mat = mat.reshape((n_types * n_types, n)).T 

600 for v in mat.flatten(): 

601 f.write(f'{v:15.7e}\n') 

602 

603 # scaler 

604 for v in self.q_scaler: 

605 f.write(f'{v:15.7e}\n') 

606 

607 if restart_file is not None: 

608 self.write_restart(restart_file) 

609 

610 def read_restart(self, filename: str): 

611 """Parses a file in `nep.restart` format and saves the 

612 content in the form of mean and standard deviation for each 

613 parameter in the corresponding NEP model. 

614 

615 Parameters 

616 ---------- 

617 filename 

618 Input file name. 

619 """ 

620 mu, sigma = _get_restart_contents(filename) 

621 restart_parameters = np.array([mu, sigma]).T 

622 

623 is_polarizability_model = self.model_type == 'polarizability' 

624 is_charged_model = self.model_type == 'potential_with_charges' 

625 

626 n1 = self.n_ann_parameters 

627 n1 *= 2 if is_polarizability_model else 1 

628 n2 = n1 + self.n_descriptor_parameters 

629 ann_parameters = restart_parameters[:n1] 

630 descriptor_parameters = np.array(restart_parameters[n1:n2]) 

631 

632 if self.version == 3: 

633 n_networks = 1 

634 n_bias = 1 

635 elif self.version == 4: 

636 # one hidden layer per atomic species 

637 n_networks = len(self.types) 

638 n_bias = 1 

639 else: 

640 raise ValueError(f'Cannot load nep.restart for NEP model version {self.version}') 

641 

642 ann_groups = [s for s in self.ann_parameters.keys() if not s.startswith('b1')] 

643 n_bias = len([s for s in self.ann_parameters.keys() if s.startswith('b1')]) 

644 n_descriptor = self.n_descriptor_radial + self.n_descriptor_angular 

645 restart = {} 

646 

647 for i, content_type in enumerate(['mu', 'sigma']): 

648 ann = _sort_ann_parameters(ann_parameters[:, i], 

649 ann_groups, 

650 self.n_neuron, 

651 n_networks, 

652 n_bias, 

653 n_descriptor, 

654 is_polarizability_model, 

655 is_charged_model) 

656 radial, angular = _sort_descriptor_parameters(descriptor_parameters[:, i], 

657 self.types, 

658 self.n_max_radial, 

659 self.n_basis_radial, 

660 self.n_max_angular, 

661 self.n_basis_angular) 

662 

663 restart[f'ann_{content_type}'] = ann 

664 restart[f'radial_descriptor_{content_type}'] = radial 

665 restart[f'angular_descriptor_{content_type}'] = angular 

666 self.restart_parameters = restart 

667 

668 def write_restart(self, filename: str): 

669 """Write NEP restart parameters to file in `nep.restart` format.""" 

670 keys = self.types if self.version in (4, 5) else ['all_species'] 

671 suffixes = ['', '_polar'] if self.model_type == 'polarizability' else [''] 

672 columns = [] 

673 for i, parameter in enumerate(['mu', 'sigma']): 

674 # neural network weights 

675 ann_parameters = self.restart_parameters[f'ann_{parameter}'] 

676 column = [] 

677 for suffix in suffixes: 

678 for s in keys: 

679 # Order: w0, b0, w1 (, b1 if NEP5) 

680 # w0 indexed as: n*N_descriptor + nu 

681 w0 = ann_parameters[s][f'w0{suffix}'] 

682 b0 = ann_parameters[s][f'b0{suffix}'] 

683 w1 = ann_parameters[s][f'w1{suffix}'] 

684 for n in range(self.n_neuron): 

685 for nu in range( 

686 self.n_descriptor_radial + self.n_descriptor_angular 

687 ): 

688 column.append(f'{w0[n, nu]:15.7e}') 

689 for b in b0[:, 0]: 

690 column.append(f'{b:15.7e}') 

691 for v in w1[0, :]: 

692 column.append(f'{v:15.7e}') 

693 b1 = ann_parameters[f'b1{suffix}'] 

694 column.append(f'{b1:15.7e}') 

695 columns.append(column) 

696 

697 # descriptor weights 

698 radial_descriptor_parameters = self.restart_parameters[f'radial_descriptor_{parameter}'] 

699 angular_descriptor_parameters = self.restart_parameters[ 

700 f'angular_descriptor_{parameter}'] 

701 mat = [] 

702 for s1 in self.types: 

703 for s2 in self.types: 

704 mat = np.hstack( 

705 [mat, radial_descriptor_parameters[(s1, s2)].flatten()] 

706 ) 

707 mat = np.hstack( 

708 [mat, angular_descriptor_parameters[(s1, s2)].flatten()] 

709 ) 

710 n_types = len(self.types) 

711 n = int(len(mat) / (n_types * n_types)) 

712 mat = mat.reshape((n_types * n_types, n)).T 

713 for v in mat.flatten(): 

714 column.append(f'{v:15.7e}') 

715 

716 # Join the mean and standard deviation columns 

717 assert len(columns[0]) == len(columns[1]), 'Length of means must match standard deviation' 

718 joined = [f'{s1} {s2}\n' for s1, s2 in zip(*columns)] 

719 with open(filename, 'w') as f: 

720 f.writelines(joined) 

721 

722 

def read_model(filename: str, restart_file: str = None) -> Model:
    """Parses a file in ``nep.txt`` format and returns the
    content in the form of a :class:`Model <calorine.nep.model.Model>`
    object.

    The header entries are unpacked into the individual model attributes,
    the dimensions of the descriptor and the expected number of parameters
    are derived from them, and the flat list of parameters is then split
    into ANN parameters, descriptor weights and scaling parameters.

    Parameters
    ----------
    filename
        Input file name.
    restart_file
        If provided, also read restart parameters from this file in
        `nep.restart` format and attach them to the returned model.
        Defaults to None.

    Raises
    ------
    AssertionError
        If a mandatory header line is missing, if the NEP version is not
        supported, if a header line has an unexpected number of entries, or
        if the number of parameters does not match the header.
    """
    data, parameters = _get_nep_contents(filename)

    # sanity checks
    for required in ('cutoff', 'basis_size', 'n_max', 'l_max', 'ANN'):
        assert required in data, f'Invalid model file; {required} line is missing'
    assert data['version'] in [3, 4, 5], (
        'Invalid model file; only NEP versions 3, 4 and 5 are currently supported'
    )
    n_types = len(data['types'])

    # split up cutoff tuple
    # Either global cutoffs + max neighbors, or typewise cutoffs + max neighbors
    cutoff = data['cutoff']
    assert len(cutoff) in [4, 2 * n_types + 2]
    data['max_neighbors_radial'] = int(cutoff[-2])
    data['max_neighbors_angular'] = int(cutoff[-1])
    if len(cutoff) == 2 * n_types + 2:
        # Typewise cutoffs: radial are even, angular are odd
        data['radial_cutoff'] = list(cutoff[0:2 * n_types:2])
        data['angular_cutoff'] = list(cutoff[1:2 * n_types:2])
    else:
        data['radial_cutoff'] = cutoff[0]
        data['angular_cutoff'] = cutoff[1]
    del data['cutoff']

    # split up basis_size tuple
    basis_size = data['basis_size']
    assert len(basis_size) == 2
    data['n_basis_radial'], data['n_basis_angular'] = basis_size
    del data['basis_size']

    # split up n_max tuple
    n_max = data['n_max']
    assert len(n_max) == 2
    data['n_max_radial'], data['n_max_angular'] = n_max
    del data['n_max']

    # split up l_max tuple; entries that are missing in the file default to zero
    l_max = data['l_max']
    assert len(l_max) in [1, 2, 3, 4, 5, 6, 7]
    l_max_fields = ('l_max_3b', 'l_max_4b', 'l_max_5b',
                    'has_q_112', 'has_q_123', 'has_q_233', 'has_q_134')
    padded = list(l_max) + [0] * (len(l_max_fields) - len(l_max))
    for name, value in zip(l_max_fields, padded):
        data[name] = value
    del data['l_max']

    # compute dimensions of descriptor components
    data['n_descriptor_radial'] = data['n_max_radial'] + 1
    # every enabled higher-order term adds one angular component per n
    l_max_enh = data['l_max_3b'] + sum(data[name] > 0 for name in l_max_fields[1:])
    data['n_descriptor_angular'] = (data['n_max_angular'] + 1) * l_max_enh
    n_descriptor = data['n_descriptor_radial'] + data['n_descriptor_angular']

    is_charged_model = data['model_type'] == 'potential_with_charges'
    is_polarizability_model = data['model_type'] == 'polarizability'

    # compute number of parameters
    n_neuron = data['ANN'][0]
    data['n_neuron'] = n_neuron
    del data['ANN']
    if data['version'] == 3:
        n_networks = 1
        n_bias = 1
    elif data['version'] == 4:
        # one hidden layer per atomic species;
        # models with charges have two output nodes and two global scalars
        n_networks = n_types
        n_bias = 2 if is_charged_model else 1
    else:  # NEP5
        # like nep4, but additionally has an
        # individual bias term in the output
        # layer for each species.
        n_networks = n_types
        n_bias = 1 + n_types  # one global bias + one per species

    n_ann_input_weights = (n_descriptor + 1) * n_neuron  # weights + bias
    n_ann_output_weights = 2 * n_neuron if is_charged_model else n_neuron  # weights
    n_ann_parameters = (n_ann_input_weights + n_ann_output_weights) * n_networks + n_bias

    n_descriptor_weights = n_types**2 * (
        (data['n_max_radial'] + 1) * (data['n_basis_radial'] + 1)
        + (data['n_max_angular'] + 1) * (data['n_basis_angular'] + 1)
    )
    data['n_parameters'] = n_ann_parameters + n_descriptor_weights + n_descriptor
    if data['n_parameters'] + n_ann_parameters == len(parameters):
        # a second full set of ANN parameters is present
        data['n_parameters'] += n_ann_parameters
        assert is_polarizability_model, (
            'Model is not labelled as a polarizability model, but the number of '
            'parameters matches a polarizability model.\n'
            'If this is a polarizability model trained with GPUMD <=v3.8, please '
            'modify the header in the nep.txt file to enable parsing '
            f'`nep{data["version"]}_polarizability`.\n'
        )
    assert data['n_parameters'] == len(parameters), (
        'Parsing of parameters inconsistent; please submit a bug report\n'
        f'{data["n_parameters"]} != {len(parameters)}'
    )
    data['n_ann_parameters'] = n_ann_parameters

    # split up parameters into the ANN weights, descriptor weights, and scaling parameters
    n1 = 2 * n_ann_parameters if is_polarizability_model else n_ann_parameters
    n2 = n1 + n_descriptor_weights
    descriptor_weights = np.array(parameters[n1:n2])
    data['q_scaler'] = parameters[n2:]

    # add ann parameters to data dict
    ann_groups = data['types'] if data['version'] in (4, 5) else ['all_species']
    sorted_ann_parameters = _sort_ann_parameters(parameters[:n1],
                                                 ann_groups,
                                                 n_neuron,
                                                 n_networks,
                                                 n_bias,
                                                 n_descriptor,
                                                 is_polarizability_model,
                                                 is_charged_model)
    if 'sqrt_epsilon_infinity' in sorted_ann_parameters.keys():
        # stored as a separate attribute of the model
        data['sqrt_epsilon_infinity'] = sorted_ann_parameters.pop('sqrt_epsilon_infinity')
    data['ann_parameters'] = sorted_ann_parameters

    # add descriptors to data dict
    data['n_descriptor_parameters'] = len(descriptor_weights)
    radial, angular = _sort_descriptor_parameters(descriptor_weights,
                                                  data['types'],
                                                  data['n_max_radial'],
                                                  data['n_basis_radial'],
                                                  data['n_max_angular'],
                                                  data['n_basis_angular'])
    data['radial_descriptor_weights'] = radial
    data['angular_descriptor_weights'] = angular

    model = Model(**data)
    if restart_file is not None:
        model.read_restart(restart_file)
    return model

888 return model