Coverage for calorine/nep/model.py: 100%

403 statements  

« prev     ^ index     » next       coverage.py v7.10.5, created at 2025-09-12 22:43 +0000

1from dataclasses import dataclass 

2from itertools import product 

3 

4import numpy as np 

5 

# Neural-network weights keyed first by network group (a species symbol, or
# 'all_species' for single-network models) and then by weight name such as
# 'w0', 'b0' or 'w1'. Note that the global output biases stored under 'b1'
# (and 'b1_polar') are plain scalars rather than nested dicts.
NetworkWeights = dict[str, dict[str, np.ndarray]]

# Descriptor weights keyed by an ordered pair of species symbols.
DescriptorWeights = dict[tuple[str, str], np.ndarray]

# Restart statistics keyed by names such as 'ann_mu' or 'radial_descriptor_sigma'.
RestartParameters = dict[str, dict[str, dict[str, np.ndarray]]]

9 

10 

11def _get_restart_contents(filename: str) -> tuple[list[float], list[float]]: 

12 """Parses a ``nep.restart`` file, and returns an unformatted list of the 

13 mean and standard deviation for all model parameters. 

14 Intended to be used by the py:meth:`~Model.load_restart` function. 

15 

16 Parameters 

17 ---------- 

18 filename 

19 input file name 

20 """ 

21 mu = [] # Mean 

22 sigma = [] # Standard deviation 

23 with open(filename) as f: 

24 for k, line in enumerate(f.readlines()): 

25 flds = line.split() 

26 assert len(flds) != 0, f'Empty line number {k}' 

27 if len(flds) == 2: 

28 mu.append(float(flds[0])) 

29 sigma.append(float(flds[1])) 

30 else: 

31 raise IOError(f'Failed to parse line {k} from {filename}') 

32 return mu, sigma 

33 

34 

35def _get_model_type(first_row: list[str]) -> str: 

36 """Parses a the first row of a ``nep.txt`` file, 

37 and returns the type of NEP model. Available types 

38 are ``potential``, ``dipole`` and ``polarizability``. 

39 

40 Parameters 

41 ---------- 

42 first_row 

43 first row of a NEP file, split by white space. 

44 """ 

45 model_type = first_row[0] 

46 if 'dipole' in model_type: 

47 return 'dipole' 

48 elif 'polarizability' in model_type: 

49 return 'polarizability' 

50 return 'potential' 

51 

52 

def _get_nep_contents(filename: str) -> tuple[dict, list[float]]:
    """Parses a ``nep.txt`` file into a header dict and a flat parameter list.

    The header spans the first few rows of the file: 5 rows for NEP2, with one
    more row for a ``zbl`` line (signalled by ``zbl`` in the first token) and
    one more for a ``basis_size`` line. All following rows must hold a single
    number each. Intended to be used by the
    :func:`read_model <calorine.nep.read_model>` function.

    Parameters
    ----------
    filename
        input file name

    Returns
    -------
    a dict with the parsed header entries and the list of model parameters

    Raises
    ------
    AssertionError
        if a line is empty
    IOError
        if a non-header line does not contain exactly one field
    ValueError
        if the header contains an unknown keyword
    """
    header_rows = []
    values = []
    n_header_rows = 5  # 5 rows for NEP2, 6-7 rows for NEP3 onwards
    basis_size_row = 3  # row index at which an optional basis_size line appears
    with open(filename) as fh:
        for row_index, raw in enumerate(fh):
            fields = raw.split()
            assert len(fields) != 0, f'Empty line number {row_index}'
            if row_index == 0 and 'zbl' in fields[0]:
                # a zbl line follows the first row and shifts everything down
                basis_size_row += 1
                n_header_rows += 1
            if row_index == basis_size_row and 'basis_size' in fields[0]:
                # Introduced in nep.txt after GPUMD v3.2
                n_header_rows += 1
            if row_index < n_header_rows:
                header_rows.append(tuple(fields))
            elif len(fields) == 1:
                values.append(float(fields[0]))
            else:
                raise IOError(f'Failed to parse line {row_index} from {filename}')

    # translate the raw header rows into typed entries
    data = {}
    for fields in header_rows:
        keyword, arguments = fields[0], fields[1:]
        if keyword in ('cutoff', 'zbl'):
            data[keyword] = tuple(float(a) for a in arguments)
        elif keyword in ('n_max', 'l_max', 'ANN', 'basis_size'):
            data[keyword] = tuple(int(a) for a in arguments)
        elif keyword.startswith('nep'):
            # e.g. 'nep4_zbl' -> 4; the row reads: <tag> <n_types> <species...>
            data['version'] = int(keyword.replace('nep', '').split('_')[0])
            data['types'] = fields[2:]
            data['model_type'] = _get_model_type(fields)
        else:
            raise ValueError(f'Unknown field: {keyword}')
    return data, values

100 

101 

def _sort_descriptor_parameters(parameters: list[float],
                                types: list[str],
                                n_max_radial: int,
                                n_basis_radial: int,
                                n_max_angular: int,
                                n_basis_angular: int) -> tuple[DescriptorWeights,
                                                               DescriptorWeights]:
    """Reads a list of descriptors parameters and sorts them into two
    appropriately structured `dicts`, one for radial and one for angular descriptor weights.
    Intended to be used by the :func:`read_model <calorine.nep.read_model>` function.

    The flat input is laid out coefficient-major: for each descriptor
    coefficient, the values for all species pairs follow each other, with the
    pairs ordered as ``itertools.product(types, repeat=2)``. Within each pair
    the first ``(n_max_radial + 1) * (n_basis_radial + 1)`` coefficients are
    radial and the remaining ones angular.

    Parameters
    ----------
    parameters
        flat sequence of descriptor parameters (a list or a 1D array)
    types
        chemical species of the model
    n_max_radial, n_basis_radial
        radial expansion order and number of radial basis functions
    n_max_angular, n_basis_angular
        angular expansion order and number of angular basis functions

    Returns
    -------
    radial and angular descriptor weights, keyed by species pair, with shapes
    ``(n_max_radial + 1, n_basis_radial + 1)`` and
    ``(n_max_angular + 1, n_basis_angular + 1)``, respectively
    """
    # Accept plain lists as annotated; arrays pass through without a copy.
    parameters = np.asarray(parameters)

    # split up descriptor by chemical species and radial/angular
    n_types = len(types)
    n_pairs = n_types * n_types
    n = len(parameters) / n_pairs
    assert n.is_integer(), 'number of descriptor groups must be an integer'
    n = int(n)

    n_radial = (n_max_radial + 1) * (n_basis_radial + 1)
    # rows: species pairs; columns: descriptor coefficients of that pair
    weights_by_pair = parameters.reshape((n, n_pairs)).T
    weights_radial = weights_by_pair[:, :n_radial]
    weights_angular = weights_by_pair[:, n_radial:]

    radial_descriptor_weights = {}
    angular_descriptor_weights = {}
    for pair_index, (s1, s2) in enumerate(product(types, repeat=2)):
        radial_descriptor_weights[(s1, s2)] = weights_radial[pair_index, :].reshape(
            (n_max_radial + 1, n_basis_radial + 1)
        )
        angular_descriptor_weights[(s1, s2)] = weights_angular[pair_index, :].reshape(
            (n_max_angular + 1, n_basis_angular + 1)
        )
    return radial_descriptor_weights, angular_descriptor_weights

138 

139 

def _sort_ann_parameters(parameters: list[float],
                         ann_groupings: list[str],
                         n_neuron: int,
                         n_networks: int,
                         n_bias: int,
                         n_descriptor: int,
                         is_polarizability_model: bool
                         ) -> NetworkWeights:
    """Reads a list of model parameters and sorts them into an appropriately structured `dict`.
    Intended to be used by the :func:`read_model <calorine.nep.read_model>` function.

    Per network the flat layout is ``w0`` (row-major, element ``(n, nu)`` at
    ``n * n_descriptor + nu``), then ``b0``, then ``w1``, followed by a
    per-network output bias when ``n_bias > 1``. After all networks comes one
    global output bias. For polarizability models the whole set is repeated a
    second time and stored with a ``_polar`` suffix.

    Parameters
    ----------
    parameters
        flat sequence of ANN parameters (a list or a 1D array)
    ann_groupings
        keys of the individual networks, e.g. the species (NEP4/NEP5) or
        ``['all_species']`` (NEP3)
    n_neuron
        number of neurons in the hidden layer
    n_networks
        number of networks; used to compute the expected parameter count
    n_bias
        number of output biases per parameter set; values larger than one mean
        a per-network bias is present in addition to the global one
    n_descriptor
        length of the descriptor vector
    is_polarizability_model
        whether a second (``_polar``) parameter set follows the first one

    Returns
    -------
    dict mapping each grouping to its weight arrays (``w0``, ``b0``, ``w1``,
    possibly ``b1`` and the ``_polar`` variants), plus the global output biases
    under ``'b1'`` (and ``'b1_polar'``)
    """
    n_ann_input_weights = (n_descriptor + 1) * n_neuron  # weights + bias
    n_ann_output_weights = n_neuron  # only weights
    n_network_params = n_ann_input_weights + n_ann_output_weights  # except last bias
    n_ann_parameters = n_network_params * n_networks + n_bias
    n_w0 = n_neuron * n_descriptor

    pars = {}
    n1 = 0
    suffixes = ['', '_polar'] if is_polarizability_model else ['']
    for suffix in suffixes:
        for s in ann_groupings:
            # Get the parameters for the ANN; in the case of NEP4, there is effectively
            # one network per atomic species.
            ann_parameters = parameters[n1 : n1 + n_network_params]
            # np.array copies, so the returned weights never alias the input array.
            w0 = np.array(ann_parameters[:n_w0], dtype=float).reshape(n_neuron, n_descriptor)
            b0 = np.array(
                ann_parameters[n_w0:n_ann_input_weights], dtype=float
            ).reshape(n_neuron, 1)
            w1 = np.array(
                ann_parameters[n_ann_input_weights:n_network_params], dtype=float
            ).reshape(1, n_neuron)
            for name, weights in (('w0', w0), ('b0', b0), ('w1', w1)):
                assert not np.any(np.isnan(weights)), (
                    f'some weights in {name} are nan for key {s}; please submit a bug report'
                )

            pars.setdefault(s, {}).update(
                {f'w0{suffix}': w0, f'b0{suffix}': b0, f'w1{suffix}': w1}
            )
            n1 += n_network_params
            if n_bias > 1:
                # For NEP5 models we additionally have one bias term per species.
                # Currently NEP5 only exists for potential models, but we'll
                # keep it here in case it gets added down the line.
                pars[s][f'b1{suffix}'] = parameters[n1]
                n1 += 1
        # For NEP3 and NEP4 we only have one bias.
        # For NEP5 we have one bias per species, and one global.
        pars[f'b1{suffix}'] = parameters[n1]
        n1 += 1  # the next parameter set (if any) starts after the global bias

    n_counted = 0
    for key, value in pars.items():
        if key.startswith('b1'):
            n_counted += 1
        else:
            n_counted += sum(np.size(p) for p in value.values())
    n_expected = n_ann_parameters * len(suffixes)
    assert n_counted == n_expected, (
        'Inconsistent number of parameters accounted for; please submit a bug report\n'
        f'{n_counted} != {n_expected}'
    )
    return pars

237 

238 

@dataclass
class Model:
    r"""Objects of this class represent a NEP model in a form suitable for
    inspection and manipulation. Typically a :class:`Model` object is instantiated
    by calling the :func:`read_model <calorine.nep.read_model>` function.

    Attributes
    ----------
    version : int
        NEP version.
    model_type: str
        One of ``potential``, ``dipole`` or ``polarizability``.
    types : tuple[str, ...]
        Chemical species that this model represents.
    radial_cutoff : float
        The radial cutoff parameter in Å.
    angular_cutoff : float
        The angular cutoff parameter in Å.
    max_neighbors_radial : int
        Maximum number of neighbors in neighbor list for radial terms.
    max_neighbors_angular : int
        Maximum number of neighbors in neighbor list for angular terms.
    radial_typewise_cutoff_factor : float
        The radial cutoff factor if use_typewise_cutoff is used.
    angular_typewise_cutoff_factor : float
        The angular cutoff factor if use_typewise_cutoff is used.
    zbl : tuple[float, float]
        Inner and outer cutoff for transition to ZBL potential.
    zbl_typewise_cutoff_factor : float
        Typewise cutoff when use_typewise_cutoff_zbl is used.
    n_basis_radial : int
        Number of radial basis functions :math:`n_\mathrm{basis}^\mathrm{R}`.
    n_basis_angular : int
        Number of angular basis functions :math:`n_\mathrm{basis}^\mathrm{A}`.
    n_max_radial : int
        Maximum order of Chebyshev polynomials included in
        radial expansion :math:`n_\mathrm{max}^\mathrm{R}`.
    n_max_angular : int
        Maximum order of Chebyshev polynomials included in
        angular expansion :math:`n_\mathrm{max}^\mathrm{A}`.
    l_max_3b : int
        Maximum expansion order for three-body terms :math:`l_\mathrm{max}^\mathrm{3b}`.
    l_max_4b : int
        Maximum expansion order for four-body terms :math:`l_\mathrm{max}^\mathrm{4b}`.
    l_max_5b : int
        Maximum expansion order for five-body terms :math:`l_\mathrm{max}^\mathrm{5b}`.
    n_descriptor_radial : int
        Dimension of radial part of descriptor.
    n_descriptor_angular : int
        Dimension of angular part of descriptor.
    n_neuron : int
        Number of neurons in hidden layer.
    n_parameters : int
        Total number of parameters including scalers (which are not fit parameters).
    n_descriptor_parameters : int
        Number of parameters in descriptor.
    n_ann_parameters : int
        Number of neural network weights.
    ann_parameters : dict[str, dict[str, np.ndarray]]
        Neural network weights.
    q_scaler : list[float]
        Scaling parameters.
    radial_descriptor_weights : dict[tuple[str, str], np.ndarray]
        Radial descriptor weights by combination of species; the array for each combination
        has dimensions of
        :math:`(n_\mathrm{max}^\mathrm{R}+1) \times (n_\mathrm{basis}^\mathrm{R}+1)`.
    angular_descriptor_weights : dict[tuple[str, str], np.ndarray]
        Angular descriptor weights by combination of species; the array for each combination
        has dimensions of
        :math:`(n_\mathrm{max}^\mathrm{A}+1) \times (n_\mathrm{basis}^\mathrm{A}+1)`.
    restart_parameters : dict[str, dict[str, dict[str, np.ndarray]]]
        NEP restart parameters. A nested dictionary that contains the mean (mu) and standard
        deviation (sigma) for the ANN and descriptor parameters. Is set using the
        :py:meth:`~Model.load_restart` method. Defaults to None.
    """

    version: int
    model_type: str
    types: tuple[str, ...]

    radial_cutoff: float
    angular_cutoff: float

    n_basis_radial: int
    n_basis_angular: int
    n_max_radial: int
    n_max_angular: int
    l_max_3b: int
    l_max_4b: int
    l_max_5b: int
    n_descriptor_radial: int
    n_descriptor_angular: int

    n_neuron: int
    n_parameters: int
    n_descriptor_parameters: int
    n_ann_parameters: int
    ann_parameters: NetworkWeights
    q_scaler: list[float]
    radial_descriptor_weights: DescriptorWeights
    angular_descriptor_weights: DescriptorWeights
    restart_parameters: RestartParameters = None

    zbl: tuple[float, float] = None
    zbl_typewise_cutoff_factor: float = None
    max_neighbors_radial: int = None
    max_neighbors_angular: int = None
    radial_typewise_cutoff_factor: float = None
    angular_typewise_cutoff_factor: float = None

    # Container-valued fields: omitted from __str__ and shown only by their
    # dimensions in _repr_html_. (No annotation, so not a dataclass field.)
    _special_fields = [
        'ann_parameters',
        'q_scaler',
        'radial_descriptor_weights',
        'angular_descriptor_weights',
    ]

    def __str__(self) -> str:
        """Returns one ``name : value`` line per non-special field."""
        return '\n'.join(
            f'{fld:22} : {getattr(self, fld)}'
            for fld in self.__dataclass_fields__
            if fld not in self._special_fields
        )

    def _repr_html_(self) -> str:
        """Returns an HTML table of all scalar fields plus the dimensions of the
        container-valued fields (used for rich display, e.g. in notebooks)."""
        s = ['<table border="1" class="dataframe">']
        s += [
            '<thead><tr><th style="text-align: left;">Field</th><th>Value</th></tr></thead>'
        ]
        s += ['<tbody>']
        for fld in self.__dataclass_fields__:
            if fld not in self._special_fields:
                s += [
                    f'<tr><td style="text-align: left;">{fld:22}</td>'
                    f'<td>{getattr(self, fld)}</td></tr>'
                ]
        for fld in self._special_fields:
            d = getattr(self, fld)
            if fld.endswith('descriptor_weights') and d:
                # all species pairs share the same array shape
                dim = next(iter(d.values())).shape
            elif fld == 'ann_parameters' and self.version in (4, 5) and self.types:
                # one network per species; report (n_species, n_entries per network)
                dim = (len(self.types), len(next(iter(d.values()))))
            else:
                dim = len(d)
            s += [
                f'<tr><td style="text-align: left;">Dimension of {fld:22}</td>'
                f'<td>{dim}</td></tr>'
            ]
        s += ['</tbody>']
        s += ['</table>']
        return ''.join(s)

    def _ann_layout(self) -> tuple[int, int]:
        """Returns ``(n_networks, n_bias)`` for the current version and species.

        NEP3 has a single network and one output bias; NEP4 has one network per
        species and one output bias; NEP5 (any other version) additionally has
        one output bias per species.
        """
        n_types = len(self.types)
        if self.version == 3:
            return 1, 1
        if self.version == 4:
            return n_types, 1
        return n_types, 1 + n_types

    def _flatten_ann_parameters(self, ann_parameters: NetworkWeights) -> list:
        """Returns ANN parameters as a flat list in ``nep.txt``/``nep.restart`` order.

        Per network: w0 (row-major, element ``(n, nu)`` at ``n * N_descriptor + nu``),
        b0, w1 and, for NEP5, the per-network b1; after all networks the global b1.
        For polarizability models the ``_polar`` set follows the regular one.
        """
        keys = self.types if self.version in (4, 5) else ['all_species']
        suffixes = ['', '_polar'] if self.model_type == 'polarizability' else ['']
        values = []
        for suffix in suffixes:
            for s in keys:
                network = ann_parameters[s]
                values.extend(network[f'w0{suffix}'].ravel())
                values.extend(network[f'b0{suffix}'][:, 0])
                values.extend(network[f'w1{suffix}'][0, :])
                if self.version == 5:
                    values.append(network[f'b1{suffix}'])
            values.append(ann_parameters[f'b1{suffix}'])
        return values

    def _flatten_descriptor_weights(self,
                                    radial: DescriptorWeights,
                                    angular: DescriptorWeights) -> np.ndarray:
        """Returns descriptor weights as a flat array in ``nep.txt`` order.

        Each species pair contributes its radial then angular coefficients; the
        file stores them coefficient-major, i.e. the first coefficient of all
        pairs, then the second coefficient of all pairs, and so on.
        """
        per_pair = [
            np.concatenate([radial[(s1, s2)].ravel(), angular[(s1, s2)].ravel()])
            for s1, s2 in product(self.types, repeat=2)
        ]
        return np.array(per_pair).T.flatten()

    def remove_species(self, species: list[str]):
        """Removes one or more species from the model.

        This method modifies the model in-place by removing all parameters
        associated with the specified chemical species. It prunes the species
        list, the Artificial Neural Network (ANN) parameters, and the
        descriptor weights. It also recalculates the total number of
        parameters in the model.

        Parameters
        ----------
        species
            A list of species names (str) to remove from the model.

        Raises
        ------
        ValueError
            If any of the provided species is not found in the model.
        """
        for s in species:
            if s not in self.types:
                raise ValueError(f'{s} is not a species supported by the NEP model')

        # --- Prune attributes based on species ---
        types_to_keep = [t for t in self.types if t not in species]
        self.types = tuple(types_to_keep)

        def keep_network(key: str) -> bool:
            # per-species networks of kept species plus the global b1/b1_polar biases
            return key in types_to_keep or key.startswith('b1')

        def keep_pair(key: tuple[str, str]) -> bool:
            return key[0] in types_to_keep and key[1] in types_to_keep

        # Prune ANN parameters (for NEP4 and NEP5; NEP3 has a single shared network)
        if self.version in [4, 5]:
            self.ann_parameters = {
                k: v for k, v in self.ann_parameters.items() if keep_network(k)
            }

        # Prune descriptor weights
        self.radial_descriptor_weights = {
            k: v for k, v in self.radial_descriptor_weights.items() if keep_pair(k)
        }
        self.angular_descriptor_weights = {
            k: v for k, v in self.angular_descriptor_weights.items() if keep_pair(k)
        }

        # Prune restart parameters if they have been loaded
        if self.restart_parameters is not None:
            for param_type in ['mu', 'sigma']:
                ann_key = f'ann_{param_type}'
                if self.version in [4, 5]:
                    self.restart_parameters[ann_key] = {
                        k: v for k, v in self.restart_parameters[ann_key].items()
                        if keep_network(k)
                    }
                for desc_type in ['radial', 'angular']:
                    key = f'{desc_type}_descriptor_{param_type}'
                    self.restart_parameters[key] = {
                        k: v for k, v in self.restart_parameters[key].items()
                        if keep_pair(k)
                    }

        # --- Recalculate parameter counts ---
        n_types = len(self.types)
        n_descriptor = self.n_descriptor_radial + self.n_descriptor_angular

        self.n_descriptor_parameters = n_types**2 * (
            (self.n_max_radial + 1) * (self.n_basis_radial + 1)
            + (self.n_max_angular + 1) * (self.n_basis_angular + 1)
        )

        n_networks, n_bias = self._ann_layout()
        n_network_params = (n_descriptor + 1) * self.n_neuron + self.n_neuron
        self.n_ann_parameters = n_network_params * n_networks + n_bias

        self.n_parameters = (
            self.n_ann_parameters
            + self.n_descriptor_parameters
            + n_descriptor  # q_scaler parameters
        )
        if self.model_type == 'polarizability':
            # polarizability models carry a second full set of ANN parameters
            self.n_parameters += self.n_ann_parameters

    def write(self, filename: str) -> None:
        """Write NEP model to file in `nep.txt` format.

        Parameters
        ----------
        filename
            Output file name; an existing file is overwritten.
        """
        with open(filename, 'w') as f:
            # header
            version_name = f'nep{self.version}'
            if self.zbl is not None:
                version_name += '_zbl'
            elif self.model_type != 'potential':
                version_name += f'_{self.model_type}'
            f.write(f'{version_name} {len(self.types)} {" ".join(self.types)}\n')
            if self.zbl is not None:
                f.write(f'zbl {" ".join(map(str, self.zbl))}\n')
            f.write(f'cutoff {self.radial_cutoff} {self.angular_cutoff}')
            f.write(f' {self.max_neighbors_radial} {self.max_neighbors_angular}')
            if (
                self.radial_typewise_cutoff_factor is not None
                and self.angular_typewise_cutoff_factor is not None
            ):
                f.write(f' {self.radial_typewise_cutoff_factor}'
                        f' {self.angular_typewise_cutoff_factor}')
            # compare with None so that a legitimate value of 0.0 is written too
            if self.zbl_typewise_cutoff_factor is not None:
                f.write(f' {self.zbl_typewise_cutoff_factor}')
            f.write('\n')
            f.write(f'n_max {self.n_max_radial} {self.n_max_angular}\n')
            f.write(f'basis_size {self.n_basis_radial} {self.n_basis_angular}\n')
            f.write(f'l_max {self.l_max_3b} {self.l_max_4b} {self.l_max_5b}\n')
            f.write(f'ANN {self.n_neuron} 0\n')

            # neural network weights
            for v in self._flatten_ann_parameters(self.ann_parameters):
                f.write(f'{v:15.7e}\n')

            # descriptor weights
            for v in self._flatten_descriptor_weights(self.radial_descriptor_weights,
                                                      self.angular_descriptor_weights):
                f.write(f'{v:15.7e}\n')

            # scaler
            for v in self.q_scaler:
                f.write(f'{v:15.7e}\n')

    def load_restart(self, filename: str):
        """Parses a file in `nep.restart` format and saves the
        content in the form of mean and standard deviation for each
        parameter in the corresponding NEP model.

        Parameters
        ----------
        filename
            Input file name.

        Raises
        ------
        ValueError
            If the model is not a NEP3 or NEP4 model.
        """
        mu, sigma = _get_restart_contents(filename)
        restart_parameters = np.array([mu, sigma]).T

        if self.version not in (3, 4):
            raise ValueError(f'Cannot load nep.restart for NEP model version {self.version}')

        is_polarizability_model = self.model_type == 'polarizability'
        n1 = self.n_ann_parameters * (2 if is_polarizability_model else 1)
        n2 = n1 + self.n_descriptor_parameters
        ann_parameters = restart_parameters[:n1]
        descriptor_parameters = np.array(restart_parameters[n1:n2])

        # n_bias is the number of output biases per parameter set, NOT the number
        # of 'b1*' keys (which doubles for polarizability models and would make
        # _sort_ann_parameters misread per-species biases).
        n_networks, n_bias = self._ann_layout()
        ann_groups = [s for s in self.ann_parameters if not s.startswith('b1')]
        n_descriptor = self.n_descriptor_radial + self.n_descriptor_angular

        restart = {}
        for i, content_type in enumerate(['mu', 'sigma']):
            ann = _sort_ann_parameters(ann_parameters[:, i],
                                       ann_groups,
                                       self.n_neuron,
                                       n_networks,
                                       n_bias,
                                       n_descriptor,
                                       is_polarizability_model)
            radial, angular = _sort_descriptor_parameters(descriptor_parameters[:, i],
                                                          self.types,
                                                          self.n_max_radial,
                                                          self.n_basis_radial,
                                                          self.n_max_angular,
                                                          self.n_basis_angular)
            restart[f'ann_{content_type}'] = ann
            restart[f'radial_descriptor_{content_type}'] = radial
            restart[f'angular_descriptor_{content_type}'] = angular
        self.restart_parameters = restart

    def write_restart(self, filename: str):
        """Write NEP restart parameters to file in `nep.restart` format.

        Each output line holds the mean and standard deviation of one parameter,
        with the parameters ordered as in ``nep.txt`` (ANN weights followed by
        descriptor weights).

        Parameters
        ----------
        filename
            Output file name; an existing file is overwritten.

        Raises
        ------
        ValueError
            If no restart parameters are available.
        """
        if self.restart_parameters is None:
            raise ValueError('No restart parameters available; call load_restart first')
        columns = []
        for parameter in ['mu', 'sigma']:
            values = list(self._flatten_ann_parameters(
                self.restart_parameters[f'ann_{parameter}']))
            values.extend(self._flatten_descriptor_weights(
                self.restart_parameters[f'radial_descriptor_{parameter}'],
                self.restart_parameters[f'angular_descriptor_{parameter}']))
            columns.append([f'{v:15.7e}' for v in values])

        # Join the mean and standard deviation columns
        assert len(columns[0]) == len(columns[1]), 'Length of means must match standard deviation'
        joined = [f'{s1} {s2}\n' for s1, s2 in zip(*columns)]
        with open(filename, 'w') as f:
            f.writelines(joined)

672 

673 

def read_model(filename: str) -> Model:
    """Reads a ``nep.txt`` file and returns its content as a
    :class:`Model <calorine.nep.model.Model>` object.

    The header is validated and its compound entries (``cutoff``,
    ``basis_size``, ``n_max``, ``l_max``, ``ANN``) are split into individual
    fields. The parameter counts implied by the header are then checked
    against the number of parameters in the file. Finally the flat parameter
    list is divided into ANN weights, descriptor weights and scalers.

    Parameters
    ----------
    filename
        Input file name.
    """
    data, parameters = _get_nep_contents(filename)

    # the header must contain all mandatory rows and a supported version
    for required in ('cutoff', 'basis_size', 'n_max', 'l_max', 'ANN'):
        assert required in data, f'Invalid model file; {required} line is missing'
    assert data['version'] in (3, 4, 5), \
        'Invalid model file; only NEP versions 3, 4 and 5 are currently supported'

    # cutoff: rc_radial rc_angular nn_radial nn_angular [factor_r factor_a [factor_zbl]]
    cutoff = data.pop('cutoff')
    assert len(cutoff) in [4, 6, 7]
    data['radial_cutoff'], data['angular_cutoff'] = cutoff[0], cutoff[1]
    data['max_neighbors_radial'] = int(cutoff[2])
    data['max_neighbors_angular'] = int(cutoff[3])
    if len(cutoff) >= 6:
        data['radial_typewise_cutoff_factor'] = cutoff[4]
        data['angular_typewise_cutoff_factor'] = cutoff[5]
    if len(cutoff) == 7:
        data['zbl_typewise_cutoff_factor'] = cutoff[6]

    basis_size = data.pop('basis_size')
    assert len(basis_size) == 2
    data['n_basis_radial'], data['n_basis_angular'] = basis_size

    n_max = data.pop('n_max')
    assert len(n_max) == 2
    data['n_max_radial'], data['n_max_angular'] = n_max

    # l_max may list 1-3 values; missing 4-body/5-body orders default to 0
    l_max = data.pop('l_max')
    assert len(l_max) in [1, 2, 3]
    data['l_max_3b'], data['l_max_4b'], data['l_max_5b'] = \
        tuple(l_max) + (0,) * (3 - len(l_max))

    # descriptor dimensions
    data['n_descriptor_radial'] = data['n_max_radial'] + 1
    l_max_enh = data['l_max_3b'] + (data['l_max_4b'] > 0) + (data['l_max_5b'] > 0)
    data['n_descriptor_angular'] = (data['n_max_angular'] + 1) * l_max_enh
    n_descriptor = data['n_descriptor_radial'] + data['n_descriptor_angular']

    # network layout by version: NEP3 has one network, NEP4 one per species,
    # NEP5 additionally one output bias per species besides the global one
    data['n_neuron'] = data.pop('ANN')[0]
    n_types = len(data['types'])
    version = data['version']
    if version == 3:
        n_networks, n_bias = 1, 1
    elif version == 4:
        n_networks, n_bias = n_types, 1
    else:
        n_networks, n_bias = n_types, 1 + n_types

    n_per_network = (n_descriptor + 1) * data['n_neuron'] + data['n_neuron']
    n_ann_parameters = n_per_network * n_networks + n_bias
    n_descriptor_weights = n_types**2 * (
        (data['n_max_radial'] + 1) * (data['n_basis_radial'] + 1)
        + (data['n_max_angular'] + 1) * (data['n_basis_angular'] + 1)
    )

    # a second set of ANN parameters identifies a polarizability model
    data['n_parameters'] = n_ann_parameters + n_descriptor_weights + n_descriptor
    is_polarizability_model = data['model_type'] == 'polarizability'
    if data['n_parameters'] + n_ann_parameters == len(parameters):
        data['n_parameters'] += n_ann_parameters
        assert is_polarizability_model, (
            'Model is not labelled as a polarizability model, but the number of '
            'parameters matches a polarizability model.\n'
            'If this is a polarizability model trained with GPUMD <=v3.8, please '
            'modify the header in the nep.txt file to read '
            f'`nep{data["version"]}_polarizability`.\n'
        )
    assert data['n_parameters'] == len(parameters), (
        'Parsing of parameters inconsistent; please submit a bug report\n'
        f'{data["n_parameters"]} != {len(parameters)}'
    )
    data['n_ann_parameters'] = n_ann_parameters

    # slice the flat list: ANN weights | descriptor weights | scalers
    ann_end = n_ann_parameters * (2 if is_polarizability_model else 1)
    descriptor_end = ann_end + n_descriptor_weights
    descriptor_weights = np.array(parameters[ann_end:descriptor_end])
    data['q_scaler'] = parameters[descriptor_end:]

    ann_groups = data['types'] if version in (4, 5) else ['all_species']
    data['ann_parameters'] = _sort_ann_parameters(parameters[:ann_end],
                                                  ann_groups,
                                                  data['n_neuron'],
                                                  n_networks,
                                                  n_bias,
                                                  n_descriptor,
                                                  is_polarizability_model)

    data['n_descriptor_parameters'] = len(descriptor_weights)
    (data['radial_descriptor_weights'],
     data['angular_descriptor_weights']) = _sort_descriptor_parameters(
        descriptor_weights,
        data['types'],
        data['n_max_radial'],
        data['n_basis_radial'],
        data['n_max_angular'],
        data['n_basis_angular'],
    )

    return Model(**data)