Coverage for calorine/nep/model.py: 100%

757 statements  

« prev     ^ index     » next       coverage.py v7.14.1, created at 2026-06-18 13:01 +0000

1import copy 

2from dataclasses import dataclass 

3from itertools import product 

4 

5import numpy as np 

6 

# Weight arrays addressed by two levels of string keys.
NetworkWeights = dict[str, dict[str, np.ndarray]]
# Descriptor weight arrays addressed by a (species, species) pair.
DescriptorWeights = dict[tuple[str, str], np.ndarray]
# Restart statistics: one further level of string keys on top of NetworkWeights.
RestartParameters = dict[str, NetworkWeights]

10 

11 

12def _get_restart_contents(filename: str) -> tuple[list[float], list[float]]: 

13 """Parses a ``nep.restart`` file, and returns an unformatted list of the 

14 mean and standard deviation for all model parameters. 

15 Intended to be used by the py:meth:`~Model.read_restart` function. 

16 

17 Parameters 

18 ---------- 

19 filename 

20 input file name 

21 """ 

22 mu = [] # Mean 

23 sigma = [] # Standard deviation 

24 with open(filename) as f: 

25 for k, line in enumerate(f.readlines()): 

26 flds = line.split() 

27 assert len(flds) != 0, f'Empty line number {k}' 

28 if len(flds) == 2: 

29 mu.append(float(flds[0])) 

30 sigma.append(float(flds[1])) 

31 else: 

32 raise IOError(f'Failed to parse line {k} from {filename}') 

33 return mu, sigma 

34 

35 

36def _get_model_type(first_row: list[str]) -> str: 

37 """Parses a the first row of a ``nep.txt`` file, and returns the 

38 type of NEP model. Available types are `potential`, `potential_with_charges`, 

39 `dipole`, and `polarizability`. 

40 

41 Parameters 

42 ---------- 

43 first_row 

44 First row of a NEP file, split by white space. 

45 """ 

46 model_type = first_row[0] 

47 if 'charge' in model_type: 

48 return 'potential_with_charges' 

49 elif 'dipole' in model_type: 

50 return 'dipole' 

51 elif 'polarizability' in model_type: 

52 return 'polarizability' 

53 return 'potential' 

54 

55 

def _get_nep_contents(filename: str) -> tuple[dict, list[float]]:
    """Split a ``nep.txt`` file into a header dict and a flat parameter list.

    Intended to be used by the :func:`read_model <calorine.nep.read_model>` function.

    Parameters
    ----------
    filename
        input file name

    Returns
    -------
    tuple[dict, list[float]]
        The parsed header and all model parameters in file order.
        ``cutoff`` and ``zbl`` rows are stored as tuples of floats;
        ``n_max``, ``l_max``, ``ANN`` and ``basis_size`` rows as tuples of
        ints; the ``nep*`` row yields ``version``, ``types`` and ``model_type``.

    Raises
    ------
    AssertionError
        If a line contains no fields.
    IOError
        If a line after the header does not contain exactly one field.
    ValueError
        If a header row starts with an unknown keyword.
    """
    with open(filename) as handle:
        rows = handle.readlines()

    # separate header rows from parameter rows
    header_rows = []
    values = []
    n_header_rows = 5  # 5 rows for NEP2, 6-7 rows for NEP3 onwards
    basis_size_row = 3
    for index, row in enumerate(rows):
        tokens = row.split()
        assert tokens, f'Empty line number {index}'
        keyword = tokens[0]
        if index == 0 and 'zbl' in keyword:
            # an additional zbl row shifts everything below it by one
            basis_size_row += 1
            n_header_rows += 1
        if index == basis_size_row and 'basis_size' in keyword:
            # Introduced in nep.txt after GPUMD v3.2
            n_header_rows += 1
        if index < n_header_rows:
            header_rows.append(tuple(tokens))
            continue
        if len(tokens) != 1:
            raise IOError(f'Failed to parse line {index} from {filename}')
        values.append(float(keyword))

    # turn the header rows into a dict
    data = {}
    for entry in header_rows:
        keyword = entry[0]
        if keyword in ('cutoff', 'zbl'):
            data[keyword] = tuple(float(x) for x in entry[1:])
        elif keyword in ('n_max', 'l_max', 'ANN', 'basis_size'):
            data[keyword] = tuple(int(x) for x in entry[1:])
        elif keyword.startswith('nep'):
            data['version'] = int(keyword.replace('nep', '').split('_')[0])
            data['types'] = entry[2:]
            data['model_type'] = _get_model_type(entry)
        else:
            raise ValueError(f'Unknown field: {keyword}')
    return data, values

103 

104 

105def _sort_descriptor_parameters(parameters: list[float], 

106 types: list[str], 

107 n_max_radial: int, 

108 n_basis_radial: int, 

109 n_max_angular: int, 

110 n_basis_angular: int) -> tuple[DescriptorWeights, 

111 DescriptorWeights]: 

112 """Reads a list of descriptors parameters and sorts them into two 

113 appropriately structured `dicts`, one for radial and one for angular descriptor weights. 

114 Intended to be used by the :func:`read_model <calorine.nep.read_model>` function. 

115 """ 

116 # split up descriptor by chemical species and radial/angular 

117 n_types = len(types) 

118 n = len(parameters) / (n_types * n_types) 

119 assert n.is_integer(), 'number of descriptor groups must be an integer' 

120 n = int(n) 

121 

122 m = (n_max_radial + 1) * (n_basis_radial + 1) 

123 descriptor_weights = parameters.reshape((n, n_types * n_types)).T 

124 descriptor_weights_radial = descriptor_weights[:, :m] 

125 descriptor_weights_angular = descriptor_weights[:, m:] 

126 

127 # add descriptors to data dict 

128 radial_descriptor_weights = {} 

129 angular_descriptor_weights = {} 

130 m = -1 

131 for i, j in product(range(n_types), repeat=2): 

132 m += 1 

133 s1, s2 = types[i], types[j] 

134 radial_descriptor_weights[(s1, s2)] = descriptor_weights_radial[m, :].reshape( 

135 (n_max_radial + 1, n_basis_radial + 1) 

136 ) 

137 angular_descriptor_weights[(s1, s2)] = descriptor_weights_angular[m, :].reshape( 

138 (n_max_angular + 1, n_basis_angular + 1) 

139 ) 

140 return radial_descriptor_weights, angular_descriptor_weights 

141 

142 

143def _sort_ann_parameters(parameters: list[float], 

144 ann_groupings: list[str], 

145 n_neuron: int, 

146 n_networks: int, 

147 n_bias: int, 

148 n_descriptor: int, 

149 is_polarizability_model: bool, 

150 is_model_with_charges: bool 

151 ) -> NetworkWeights: 

152 """Reads a list of model parameters and sorts them into an appropriately structured `dict`. 

153 Intended to be used by the :func:`read_model <calorine.nep.read_model>` function. 

154 """ 

155 n_ann_input_weights = (n_descriptor + 1) * n_neuron # weights + bias 

156 n_ann_output_weights = 2*n_neuron if is_model_with_charges else n_neuron # only weights 

157 n_ann_parameters = ( 

158 n_ann_input_weights + n_ann_output_weights 

159 ) * n_networks + n_bias 

160 

161 # Group ANN parameters 

162 pars = {} 

163 n1 = 0 

164 n_network_params = n_ann_input_weights + n_ann_output_weights # except last bias(es) 

165 

166 n_count = 2 if is_polarizability_model else 1 

167 n_outputs = 2 if is_model_with_charges else 1 

168 for count in range(n_count): 

169 # if polarizability model, all parameters including bias are repeated 

170 # need to offset n1 by +1 to handle bias 

171 n1 += count 

172 for s in ann_groupings: 

173 # Get the parameters for the ANN; in the case of NEP4, there is effectively 

174 # one network per atomic species. 

175 ann_parameters = parameters[n1 : n1 + n_network_params] 

176 ann_input_weights = ann_parameters[:n_ann_input_weights] 

177 w0 = np.zeros((n_neuron, n_descriptor)) 

178 w0[...] = np.nan 

179 b0 = np.zeros((n_neuron, 1)) 

180 b0[...] = np.nan 

181 for n in range(n_neuron): 

182 for nu in range(n_descriptor): 

183 w0[n, nu] = ann_input_weights[n * n_descriptor + nu] 

184 b0[:, 0] = ann_input_weights[n_neuron * n_descriptor :] 

185 

186 assert np.all( 

187 w0.shape == (n_neuron, n_descriptor) 

188 ), f'w0 has invalid shape for key {s}; please submit a bug report' 

189 assert np.all( 

190 b0.shape == (n_neuron, 1) 

191 ), f'b0 has invalid shape for key {s}; please submit a bug report' 

192 assert not np.any( 

193 np.isnan(w0) 

194 ), f'some weights in w0 are nan for key {s}; please submit a bug report' 

195 assert not np.any( 

196 np.isnan(b0) 

197 ), f'some weights in b0 are nan for key {s}; please submit a bug report' 

198 

199 ann_output_weights = ann_parameters[ 

200 n_ann_input_weights : n_ann_input_weights + n_ann_output_weights 

201 ] 

202 w1 = np.zeros((1, n_neuron * n_outputs)) 

203 w1[0, :] = ann_output_weights[:] 

204 assert np.all( 

205 w1.shape == (1, n_neuron * n_outputs) 

206 ), f'w1 has invalid shape for key {s}; please submit a bug report' 

207 assert not np.any( 

208 np.isnan(w1) 

209 ), f'some weights in w1 are nan for key {s}; please submit a bug report' 

210 

211 if count == 0 and n_outputs == 1: 

212 pars[s] = dict(w0=w0, b0=b0, w1=w1) 

213 elif count == 0 and n_outputs == 2: 

214 pars[s] = dict(w0=w0, b0=b0, w1=w1[0, :n_neuron], w1_charge=w1[0, n_neuron:]) 

215 else: 

216 pars[s].update({'w0_polar': w0, 'b0_polar': b0, 'w1_polar': w1}) 

217 # Jump to bias 

218 n1 += n_network_params 

219 if n_bias > 1 and not is_model_with_charges: 

220 # For NEP5 models we additionally have one bias term per species. 

221 # Currently NEP5 only exists for potential models, but we'll 

222 # keep it here in case it gets added down the line. 

223 bias_label = 'b1' if count == 0 else 'b1_polar' 

224 pars[s][bias_label] = parameters[n1] 

225 n1 += 1 

226 # For NEP3 and NEP4 we only have one bias. 

227 # For NEP4 with charges we have two biases. 

228 # For NEP5 we have one bias per species, and one global. 

229 if count == 0 and n_outputs == 1: 

230 pars['b1'] = parameters[n1] 

231 elif count == 0 and n_outputs == 2: 

232 pars['sqrt_epsilon_infinity'] = parameters[n1] 

233 pars['b1'] = parameters[n1+1] 

234 else: 

235 pars['b1_polar'] = parameters[n1] 

236 sum = 0 

237 for s in pars.keys(): 

238 if s.startswith('b1') or s.startswith('sqrt'): 

239 sum += 1 

240 else: 

241 sum += np.sum([np.array(p).size for p in pars[s].values()]) 

242 assert sum == n_ann_parameters * n_count, ( 

243 'Inconsistent number of parameters accounted for; please submit a bug report\n' 

244 f'{sum} != {n_ann_parameters}' 

245 ) 

246 return pars 

247 

248 

249def _adaptive_sigma(mu_arr, sigma_factor: float, sigma_floor: float) -> np.ndarray: 

250 """Return adaptive SNES sigma: ``max(sigma_floor, sigma_factor * |mu|)``.""" 

251 return np.maximum(sigma_floor, sigma_factor * np.abs(mu_arr)) 

252 

253 

def _apply_adaptive_sigma_to_restart(restart_params, keys, sigma_factor, sigma_floor):
    """Overwrite the sigma entries of *restart_params* in-place with adaptive values.

    Each sigma is recomputed from the matching mu via :func:`_adaptive_sigma`.
    This covers, for every key in *keys*, the ANN weights ``w0``, ``b0``,
    ``w1`` and (if present) ``w1_charge``; the global ``b1`` scalar; the
    optional ``sqrt_epsilon_infinity`` scalar; and every pair in the radial
    and angular descriptor sigma dicts.

    Parameters
    ----------
    restart_params
        Nested restart dict with ``ann_mu``, ``ann_sigma`` and the
        ``{radial,angular}_descriptor_{mu,sigma}`` entries.
    keys
        ANN keys whose weight sigmas are to be updated.
    sigma_factor
        Multiplier applied to ``|mu|``.
    sigma_floor
        Lower bound for sigma.
    """
    def _widen(mu):
        return _adaptive_sigma(mu, sigma_factor, sigma_floor)

    # ANN weights
    for ann_key in keys:
        mu_entry = restart_params['ann_mu'][ann_key]
        sigma_entry = restart_params['ann_sigma'][ann_key]
        weight_names = ['w0', 'b0', 'w1']
        if 'w1_charge' in mu_entry:
            weight_names.append('w1_charge')
        for name in weight_names:
            sigma_entry[name] = _widen(mu_entry[name])

    # Global scalars are stored as plain floats; b1 is required,
    # sqrt_epsilon_infinity is optional.
    b1_mu = restart_params['ann_mu']['b1']
    restart_params['ann_sigma']['b1'] = float(_widen(b1_mu))
    if 'sqrt_epsilon_infinity' in restart_params['ann_mu']:
        sei_mu = restart_params['ann_mu']['sqrt_epsilon_infinity']
        restart_params['ann_sigma']['sqrt_epsilon_infinity'] = float(_widen(sei_mu))

    # Descriptor weights
    for kind in ('radial', 'angular'):
        sigma_name = f'{kind}_descriptor_sigma'
        mu_name = f'{kind}_descriptor_mu'
        sigmas = restart_params[sigma_name]
        for pair in sigmas:
            sigmas[pair] = _widen(restart_params[mu_name][pair])

285 

286 

287def _recalculate_parameter_counts(new) -> None: 

288 """Recompute n_ann_parameters, n_descriptor_parameters, and n_parameters on *new*. 

289 

290 Reads all architectural state from *new* directly, so callers must update 

291 new.n_neuron, new.n_descriptor_radial/angular, new.model_type, and new.types 

292 before calling this function. 

293 """ 

294 n_types = len(new.types) 

295 n_desc = new.n_descriptor_radial + new.n_descriptor_angular 

296 is_charged = new.model_type == 'potential_with_charges' 

297 n_networks = n_types if new.version in (4, 5) else 1 

298 n_bias = 2 if is_charged else (1 + n_types if new.version == 5 else 1) 

299 n_ann_input_weights = (n_desc + 1) * new.n_neuron 

300 n_ann_output_weights = 2 * new.n_neuron if is_charged else new.n_neuron 

301 new.n_ann_parameters = (n_ann_input_weights + n_ann_output_weights) * n_networks + n_bias 

302 new.n_descriptor_parameters = n_types ** 2 * ( 

303 (new.n_max_radial + 1) * (new.n_basis_radial + 1) 

304 + (new.n_max_angular + 1) * (new.n_basis_angular + 1) 

305 ) 

306 new.n_parameters = new.n_ann_parameters + new.n_descriptor_parameters + n_desc 

307 if new.model_type == 'polarizability': 

308 new.n_parameters += new.n_ann_parameters 

309 

310 

311@dataclass 

312class Model: 

313 r"""Objects of this class represent a NEP model in a form suitable for 

314 inspection and manipulation. Typically a :class:`Model` object is instantiated 

315 by calling the :func:`read_model <calorine.nep.read_model>` function. 

316 

317 Attributes 

318 ---------- 

319 version : int 

320 NEP version. 

321 model_type: str 

322 One of ``potential``, ``potential_with_charges``, ``dipole`` or ``polarizability``.

323 types : tuple[str, ...] 

324 Chemical species that this model represents. 

325 radial_cutoff : float | list[float] 

326 The radial cutoff parameter in Å. 

327 Is a list of radial cutoffs ordered after ``types`` in the case of typewise cutoffs. 

328 angular_cutoff : float | list[float] 

329 The angular cutoff parameter in Å. 

330 Is a list of angular cutoffs ordered after ``types`` in the case of typewise cutoffs. 

331 max_neighbors_radial : int 

332 Maximum number of neighbors in neighbor list for radial terms. 

333 max_neighbors_angular : int 

334 Maximum number of neighbors in neighbor list for angular terms. 

335 radial_typewise_cutoff_factor : float 

336 The radial cutoff factor if use_typewise_cutoff is used. 

337 angular_typewise_cutoff_factor : float 

338 The angular cutoff factor if use_typewise_cutoff is used. 

339 zbl : tuple[float, float] 

340 Inner and outer cutoff for transition to ZBL potential. 

341 zbl_typewise_cutoff_factor : float 

342 Typewise cutoff when use_typewise_cutoff_zbl is used. 

343 n_basis_radial : int 

344 Number of radial basis functions :math:`n_\mathrm{basis}^\mathrm{R}`. 

345 n_basis_angular : int 

346 Number of angular basis functions :math:`n_\mathrm{basis}^\mathrm{A}`. 

347 n_max_radial : int 

348 Maximum order of Chebyshev polynomials included in

349 radial expansion :math:`n_\mathrm{max}^\mathrm{R}`. 

350 n_max_angular : int 

351 Maximum order of Chebyshev polynomials included in

352 angular expansion :math:`n_\mathrm{max}^\mathrm{A}`. 

353 l_max_3b : int 

354 Maximum expansion order for three-body terms :math:`l_\mathrm{max}^\mathrm{3b}`. 

355 l_max_4b : int 

356 Maximum expansion order for four-body terms :math:`l_\mathrm{max}^\mathrm{4b}`. 

357 l_max_5b : int 

358 Maximum expansion order for five-body terms :math:`l_\mathrm{max}^\mathrm{5b}`. 

359 has_q_112 : int 

360 Flag enabling the 5-body :math:`q_{112}` descriptor (0 or 1). 

361 has_q_123 : int 

362 Flag enabling the 5-body :math:`q_{123}` descriptor (0 or 1). 

363 has_q_233 : int 

364 Flag enabling the 5-body :math:`q_{233}` descriptor (0 or 1). 

365 has_q_134 : int 

366 Flag enabling the higher-body :math:`q_{134}` descriptor (0 or 1). 

367 n_descriptor_radial : int 

368 Dimension of radial part of descriptor. 

369 n_descriptor_angular : int 

370 Dimension of angular part of descriptor. 

371 n_neuron : int 

372 Number of neurons in hidden layer. 

373 n_parameters : int 

374 Total number of parameters including scalers (which are not fit parameters). 

375 n_descriptor_parameters : int 

376 Number of parameters in descriptor. 

377 n_ann_parameters : int 

378 Number of neural network weights. 

379 ann_parameters : dict[str, dict[str, np.ndarray]]

380 Neural network weights. 

381 q_scaler : List[float] 

382 Scaling parameters. 

383 radial_descriptor_weights : dict[tuple[str, str], np.ndarray] 

384 Radial descriptor weights by combination of species; the array for each combination 

385 has dimensions of 

386 :math:`(n_\mathrm{max}^\mathrm{R}+1) \times (n_\mathrm{basis}^\mathrm{R}+1)`. 

387 angular_descriptor_weights : dict[tuple[str, str], np.ndarray] 

388 Angular descriptor weights by combination of species; the array for each combination 

389 has dimensions of 

390 :math:`(n_\mathrm{max}^\mathrm{A}+1) \times (n_\mathrm{basis}^\mathrm{A}+1)`. 

391 sqrt_epsilon_infinity : Optional[float] 

392 Square root of epsilon infinity :math:`\epsilon_\infty` (only for NEP models with charges).

393 restart_parameters : dict[str, dict[str, dict[str, np.ndarray]]] 

394 NEP restart parameters. A nested dictionary that contains the mean (mu) and standard 

395 deviation (sigma) for the ANN and descriptor parameters. Is set using the 

396 :py:meth:`~Model.read_restart` method. Defaults to None.

397 """ 

398 

399 version: int 

400 model_type: str 

401 types: tuple[str, ...] 

402 

403 radial_cutoff: float | list[float] 

404 angular_cutoff: float | list[float] 

405 

406 n_basis_radial: int 

407 n_basis_angular: int 

408 n_max_radial: int 

409 n_max_angular: int 

410 l_max_3b: int 

411 l_max_4b: int 

412 l_max_5b: int 

413 has_q_112: int 

414 has_q_123: int 

415 has_q_233: int 

416 has_q_134: int 

417 n_descriptor_radial: int 

418 n_descriptor_angular: int 

419 

420 n_neuron: int 

421 n_parameters: int 

422 n_descriptor_parameters: int 

423 n_ann_parameters: int 

424 ann_parameters: NetworkWeights 

425 q_scaler: list[float] 

426 radial_descriptor_weights: DescriptorWeights 

427 angular_descriptor_weights: DescriptorWeights 

428 sqrt_epsilon_infinity: float = None 

429 restart_parameters: RestartParameters = None 

430 

431 zbl: tuple[float, float] = None 

432 zbl_typewise_cutoff_factor: float = None 

433 max_neighbors_radial: int = None 

434 max_neighbors_angular: int = None 

435 radial_typewise_cutoff_factor: float = None 

436 angular_typewise_cutoff_factor: float = None 

437 

438 _special_fields = [ 

439 'ann_parameters', 

440 'q_scaler', 

441 'radial_descriptor_weights', 

442 'angular_descriptor_weights', 

443 ] 

444 

445 def __str__(self) -> str: 

446 s = [] 

447 for fld in self.__dataclass_fields__: 

448 if fld not in self._special_fields: 

449 value = getattr(self, fld) 

450 if fld == 'restart_parameters': 

451 value = 'available' if value is not None else 'not available' 

452 s += [f'{fld:22} : {value}'] 

453 return '\n'.join(s) 

454 

455 def _repr_html_(self) -> str: 

456 s = [] 

457 s += ['<table border="1" class="dataframe"'] 

458 s += [ 

459 '<thead><tr><th style="text-align: left;">Field</th><th>Value</th></tr></thead>' 

460 ] 

461 s += ['<tbody>'] 

462 for fld in self.__dataclass_fields__: 

463 if fld not in self._special_fields: 

464 value = getattr(self, fld) 

465 if fld == 'restart_parameters': 

466 value = 'available' if value is not None else 'not available' 

467 s += [ 

468 f'<tr><td style="text-align: left;">{fld:22}</td>' 

469 f'<td>{value}</td><tr>' 

470 ] 

471 for fld in self._special_fields: 

472 d = getattr(self, fld) 

473 # print('xxx', fld, d) 

474 if fld.endswith('descriptor_weights'): 

475 dim = list(d.values())[0].shape 

476 elif fld == 'ann_parameters' and self.version == 4: 

477 dim = (len(self.types), len(list(d.values())[0])) 

478 else: 

479 dim = len(d) 

480 s += [ 

481 f'<tr><td style="text-align: left;">Dimension of {fld:22}</td><td>{dim}</td><tr>' 

482 ] 

483 s += ['</tbody>'] 

484 s += ['</table>'] 

485 return ''.join(s) 

486 

487 @property 

def training_parameters(self) -> dict:
    """Return model hyperparameters in the format accepted by :func:`write_nepfile
    <calorine.nep.write_nepfile>`.

    Call this after modifying a model (:meth:`augment`, :meth:`add_species`,
    :meth:`remove_species`, :meth:`keep_species`) to obtain the architecture
    fields for the new ``nep.in``. Combine the result with the
    training-specific parameters (``lambda_*``, ``generation``, ``batch``,
    etc.) before calling :func:`write_nepfile <calorine.nep.write_nepfile>`.

    Returns
    -------
    dict
        Keys ``version``, ``type``, ``cutoff``, ``n_max``, ``basis_size``, ``l_max``,
        and ``neuron`` (plus ``zbl`` when applicable) with values in the format
        expected by :func:`write_nepfile <calorine.nep.write_nepfile>`.

    """
    # Trailing zeros are dropped, but the first entry is always kept.
    orders = [
        self.l_max_3b, self.l_max_4b, self.l_max_5b,
        self.has_q_112, self.has_q_123, self.has_q_233, self.has_q_134,
    ]
    while len(orders) > 1 and orders[-1] == 0:
        orders.pop()

    if isinstance(self.radial_cutoff, list):
        # typewise cutoffs: radial and angular values interleaved per type
        cutoff = [
            value
            for pair in zip(self.radial_cutoff, self.angular_cutoff)
            for value in pair
        ]
    else:
        cutoff = [self.radial_cutoff, self.angular_cutoff]

    params = dict(
        version=self.version,
        type=[len(self.types), *self.types],
        cutoff=cutoff,
        n_max=[self.n_max_radial, self.n_max_angular],
        basis_size=[self.n_basis_radial, self.n_basis_angular],
        l_max=orders,
        neuron=self.n_neuron,
    )
    if self.zbl is not None:
        params['zbl'] = list(self.zbl)
    return params

530 

def remove_species(self,
                   species: list[str],
                   sigma_factor: float = 0.1,
                   sigma_floor: float = 1e-6) -> 'Model':
    """Return a copy of the model from which the given species have been removed.

    The source model is left untouched. ANN parameters (NEP4 and NEP5
    only), descriptor weights, typewise cutoffs and, if loaded,
    ``restart_parameters`` are pruned so that only the remaining species
    are left, after which the parameter counts are recomputed.

    When ``restart_parameters`` are loaded, the surviving parameters also
    receive adaptive SNES sigma values,
    ``sigma = max(sigma_floor, sigma_factor * |mu|)``.

    Parameters
    ----------
    species
        Species names to remove.
    sigma_factor
        Used only when restart is loaded: ``sigma = max(sigma_floor, sigma_factor * |mu|)``
        for surviving parameters.
    sigma_floor
        Minimum sigma for surviving parameters when restart is loaded.

    Returns
    -------
    Model
        New model with the specified species removed.

    Raises
    ------
    ValueError
        If any of the provided species is not found in the model.
    """
    for name in species:
        if name not in self.types:
            raise ValueError(f'{name} is not a species supported by the NEP model')

    per_species_ann = self.version in [4, 5]
    survivors = [t for t in self.types if t not in species]

    def _keep_pair(pair):
        # descriptor keys are (species1, species2) tuples
        return pair[0] in survivors and pair[1] in survivors

    def _keep_ann_key(key):
        # per-species networks of survivors plus the global bias entries
        return key in survivors or key.startswith('b1')

    new = copy.deepcopy(self)
    new.types = tuple(survivors)

    # ANN parameters are only stored per species for NEP4 and NEP5
    if per_species_ann:
        new.ann_parameters = {
            k: v for k, v in new.ann_parameters.items() if _keep_ann_key(k)
        }

    new.radial_descriptor_weights = {
        k: v for k, v in new.radial_descriptor_weights.items() if _keep_pair(k)
    }
    new.angular_descriptor_weights = {
        k: v for k, v in new.angular_descriptor_weights.items() if _keep_pair(k)
    }

    # Typewise cutoffs are ordered like ``types``; drop the same positions
    if isinstance(self.radial_cutoff, list):
        kept = [i for i, t in enumerate(self.types) if t not in species]
        new.radial_cutoff = [self.radial_cutoff[i] for i in kept]
        new.angular_cutoff = [self.angular_cutoff[i] for i in kept]

    restart = new.restart_parameters
    if restart is not None:
        ann_keys = survivors if per_species_ann else ['all_species']
        for stat in ('mu', 'sigma'):
            if per_species_ann:
                # sqrt_epsilon_infinity (charge models) is global and survives as well
                ann_name = f'ann_{stat}'
                restart[ann_name] = {
                    k: v for k, v in restart[ann_name].items()
                    if (_keep_ann_key(k) or k == 'sqrt_epsilon_infinity')
                }
            for kind in ('radial', 'angular'):
                desc_name = f'{kind}_descriptor_{stat}'
                restart[desc_name] = {
                    k: v for k, v in restart[desc_name].items() if _keep_pair(k)
                }

        # Re-open the search distribution for everything that is left
        _apply_adaptive_sigma_to_restart(restart, ann_keys, sigma_factor, sigma_floor)

    _recalculate_parameter_counts(new)
    return new

626 

def keep_species(self,
                 species: list[str],
                 sigma_factor: float = 0.1,
                 sigma_floor: float = 1e-6) -> 'Model':
    """Return a copy of the model that contains only the given species.

    This is the complement of :meth:`remove_species`: every species of the
    model that is not listed in ``species`` is removed. It is handy when
    far more species are to be dropped than kept.

    Parameters
    ----------
    species
        Species names to keep. All other species are removed.
    sigma_factor
        Passed to :meth:`remove_species`.
    sigma_floor
        Passed to :meth:`remove_species`.

    Returns
    -------
    Model
        New model containing only the requested species.

    Raises
    ------
    ValueError
        If any of the requested species is not in the model.
    """
    present = self.types
    missing = [name for name in species if name not in present]
    if len(missing) > 0:
        raise ValueError(
            f'Species not in model: {missing}'
        )
    discard = [name for name in present if name not in species]
    return self.remove_species(
        discard, sigma_factor=sigma_factor, sigma_floor=sigma_floor
    )

665 

666 def add_species(self, 

667 species: list[str], 

668 radial_cutoff: float | list[float] = None, 

669 angular_cutoff: float | list[float] = None, 

670 sigma_new: float = 0.1, 

671 sigma_factor: float = 0.1, 

672 sigma_floor: float = 1e-6, 

673 seed: int | None = None) -> 'Model': 

674 """Add one or more species to the model. 

675 

676 Returns a new :class:`Model` with the requested species added. New ANN 

677 sub-networks and descriptor weight pairs are initialised by drawing 

678 ``mu`` uniformly from [-1, 1] (matching the GPUMD fresh-model 

679 initialisation), with ``sigma = sigma_new`` in the restart. 

680 Charge-specific parameters (``w1_charge``) are kept at ``mu = 0`` to 

681 preserve stability, also matching GPUMD. 

682 Existing parameters receive adaptive sigma: 

683 ``sigma = max(sigma_floor, sigma_factor * |mu|)``. 

684 

685 Only supported for NEP4 models. For NEP3 the ANN is shared across all 

686 species and adding a per-species sub-network is not meaningful. 

687 

688 Parameters 

689 ---------- 

690 species 

691 New species names to add. Appended to ``types`` in the order given. 

692 radial_cutoff 

693 Radial cutoff(s) for the new species, in Å. Required when the model 

694 uses typewise cutoffs (i.e. ``isinstance(model.radial_cutoff, list)`` 

695 is ``True``). Pass a single float or a list with one value per new 

696 species. 

697 angular_cutoff 

698 Angular cutoff(s) for the new species, in Å. Same requirements as 

699 ``radial_cutoff``. 

700 sigma_new 

701 SNES sigma assigned to all newly created parameters. Defaults to 

702 ``0.1``, matching the GPUMD ``sigma0`` default. 

703 sigma_factor 

704 Controls sigma for *existing* parameters: 

705 ``sigma = max(sigma_floor, sigma_factor * |mu|)``. 

706 sigma_floor 

707 Minimum sigma for existing parameters. 

708 seed 

709 Seed for the random number generator used to draw the initial ``mu`` 

710 values. Pass an integer for reproducible initialisation. 

711 

712 Returns 

713 ------- 

714 Model 

715 New model with updated structure, weights, and restart statistics. 

716 

717 Raises 

718 ------ 

719 ValueError 

720 If the model version is not 4, if ``restart_parameters`` are not 

721 loaded, if any species is already in the model, or if typewise 

722 cutoffs are used and ``radial_cutoff``/``angular_cutoff`` are not 

723 provided. 

724 """ 

725 if self.version != 4: 

726 raise ValueError( 

727 f'add_species() only supports NEP4 models; got version {self.version}.' 

728 ) 

729 for s in species: 

730 if s in self.types: 

731 raise ValueError(f'{s!r} is already in the model.') 

732 if self.restart_parameters is None: 

733 raise ValueError( 

734 'restart_parameters must be loaded before calling add_species(). ' 

735 'Pass restart_file= to read_model() or call model.read_restart() first.' 

736 ) 

737 

738 uses_typewise = isinstance(self.radial_cutoff, list) 

739 if uses_typewise: 

740 if radial_cutoff is None or angular_cutoff is None: 

741 raise ValueError( 

742 'Model uses typewise cutoffs; provide radial_cutoff and angular_cutoff ' 

743 'for the new species.' 

744 ) 

745 rc_list = ([radial_cutoff] * len(species) 

746 if isinstance(radial_cutoff, (int, float)) else list(radial_cutoff)) 

747 ac_list = ([angular_cutoff] * len(species) 

748 if isinstance(angular_cutoff, (int, float)) else list(angular_cutoff)) 

749 if len(rc_list) != len(species) or len(ac_list) != len(species): 

750 raise ValueError( 

751 'Length of radial_cutoff/angular_cutoff must match the number of new species.' 

752 ) 

753 

754 new = copy.deepcopy(self) 

755 

756 n_descriptor = self.n_descriptor_radial + self.n_descriptor_angular 

757 n_neuron = self.n_neuron 

758 is_charged = self.model_type == 'potential_with_charges' 

759 all_types_after = list(self.types) + list(species) 

760 rng = np.random.default_rng(seed) 

761 

762 def _rand(shape): 

763 return rng.uniform(-1.0, 1.0, size=shape) 

764 

765 # Step 1: Adaptive sigma for existing parameters 

766 _apply_adaptive_sigma_to_restart( 

767 new.restart_parameters, list(self.types), sigma_factor, sigma_floor 

768 ) 

769 

770 # Step 2: New ANN sub-networks 

771 w1_shape = (n_neuron,) if is_charged else (1, n_neuron) 

772 for s_new in species: 

773 w0_vals = _rand((n_neuron, n_descriptor)) 

774 b0_vals = _rand((n_neuron, 1)) 

775 w1_vals = _rand(w1_shape) 

776 s_params = {'w0': w0_vals.copy(), 'b0': b0_vals.copy(), 'w1': w1_vals.copy()} 

777 if is_charged: 

778 s_params['w1_charge'] = np.zeros(n_neuron) 

779 new.ann_parameters[s_new] = s_params 

780 

781 mu_entry = {'w0': w0_vals, 'b0': b0_vals, 'w1': w1_vals} 

782 sigma_entry = { 

783 'w0': np.full((n_neuron, n_descriptor), sigma_new), 

784 'b0': np.full((n_neuron, 1), sigma_new), 

785 'w1': np.full(w1_shape, sigma_new), 

786 } 

787 if is_charged: 

788 mu_entry['w1_charge'] = np.zeros(n_neuron) 

789 sigma_entry['w1_charge'] = np.full(n_neuron, sigma_new) 

790 new.restart_parameters['ann_mu'][s_new] = mu_entry 

791 new.restart_parameters['ann_sigma'][s_new] = sigma_entry 

792 

793 # Step 3: New descriptor weight pairs 

794 n_r = (self.n_max_radial + 1, self.n_basis_radial + 1) 

795 n_a = (self.n_max_angular + 1, self.n_basis_angular + 1) 

796 existing_pairs = set(self.radial_descriptor_weights) 

797 new_pairs = { 

798 (s1, s2) 

799 for s1 in all_types_after for s2 in all_types_after 

800 if (s1, s2) not in existing_pairs 

801 } 

802 for pair in new_pairs: 

803 r_vals = _rand(n_r) 

804 a_vals = _rand(n_a) 

805 new.radial_descriptor_weights[pair] = r_vals.copy() 

806 new.angular_descriptor_weights[pair] = a_vals.copy() 

807 new.restart_parameters['radial_descriptor_mu'][pair] = r_vals 

808 new.restart_parameters['angular_descriptor_mu'][pair] = a_vals 

809 new.restart_parameters['radial_descriptor_sigma'][pair] = np.full(n_r, sigma_new) 

810 new.restart_parameters['angular_descriptor_sigma'][pair] = np.full(n_a, sigma_new) 

811 

812 # Step 4: Update types and typewise cutoffs 

813 new.types = tuple(all_types_after) 

814 if uses_typewise: 

815 new.radial_cutoff = list(self.radial_cutoff) + rc_list 

816 new.angular_cutoff = list(self.angular_cutoff) + ac_list 

817 

818 # Step 5: Recalculate parameter counts 

819 _recalculate_parameter_counts(new) 

820 

821 return new 

822 

def write(self, filename: str, restart_file: str = None) -> None:
    """Write NEP model to file in `nep.txt` format.

    The file consists of a text header (version/species line, optional
    ``zbl`` line, ``cutoff``, ``n_max``, ``basis_size``, ``l_max`` and
    ``ANN`` lines) followed by one parameter per line: the neural network
    weights, the descriptor weights and finally the descriptor scalers.

    Parameters
    ----------
    filename
        Output file name for the NEP model.
    restart_file
        If provided, also write restart parameters to this file in
        `nep.restart` format. Defaults to None.
    """
    n_types = len(self.types)
    n_descriptor = self.n_descriptor_radial + self.n_descriptor_angular

    # --- header ---
    version_name = f'nep{self.version}'
    if self.zbl is not None:
        version_name += '_zbl'
    elif self.model_type != 'potential':
        version_name += f'_{self.model_type}'
    lines = [f'{version_name} {n_types} {" ".join(self.types)}\n']
    if self.zbl is not None:
        lines.append(f'zbl {" ".join(map(str, self.zbl))}\n')

    # Global cutoffs are scalars; anything with a dimension is treated as
    # typewise. Testing the dimensionality (rather than ``isinstance(.., float)``)
    # also accepts integer or NumPy scalar cutoffs.
    if np.ndim(self.radial_cutoff) == 0 and np.ndim(self.angular_cutoff) == 0:
        cutoffs = f' {self.radial_cutoff} {self.angular_cutoff}'
    else:
        # Typewise cutoffs: one (radial, angular) pair per type
        cutoffs = ''.join(
            f' {self.radial_cutoff[i]} {self.angular_cutoff[i]}'
            for i in range(n_types)
        )
    lines.append(
        f'cutoff{cutoffs} {self.max_neighbors_radial} {self.max_neighbors_angular}\n'
    )
    lines.append(f'n_max {self.n_max_radial} {self.n_max_angular}\n')
    lines.append(f'basis_size {self.n_basis_radial} {self.n_basis_angular}\n')

    # The optional q flags are positional: every flag up to and including
    # the last enabled one is written, trailing disabled flags are omitted.
    q_flags = [self.has_q_112, self.has_q_123, self.has_q_233, self.has_q_134]
    last_enabled = max((i for i, flag in enumerate(q_flags) if flag), default=-1)
    l_max_line = f'l_max {self.l_max_3b} {self.l_max_4b} {self.l_max_5b}'
    l_max_line += ''.join(f' {flag}' for flag in q_flags[:last_enabled + 1])
    lines.append(l_max_line + '\n')
    lines.append(f'ANN {self.n_neuron} 0\n')

    # --- neural network weights ---
    values = []
    keys = self.types if self.version in (4, 5) else ['all_species']
    suffixes = ['', '_polar'] if self.model_type == 'polarizability' else ['']
    for suffix in suffixes:
        for s in keys:
            # Order: w0, b0, w1 (, w1_charge) (, b1 if NEP5)
            group = self.ann_parameters[s]
            w0 = group[f'w0{suffix}']
            b0 = group[f'b0{suffix}']
            w1 = group[f'w1{suffix}']
            # w0 indexed as: n*N_descriptor + nu
            values.extend(
                w0[n, nu]
                for n in range(self.n_neuron)
                for nu in range(n_descriptor)
            )
            values.extend(b0[:, 0])
            values.extend(w1[0, :] if w1.ndim == 2 else w1)
            if f'w1_charge{suffix}' in group:
                values.extend(group[f'w1_charge{suffix}'])
            if self.version == 5:
                values.append(group[f'b1{suffix}'])
        if self.sqrt_epsilon_infinity is not None:
            values.append(self.sqrt_epsilon_infinity)
        values.append(self.ann_parameters[f'b1{suffix}'])

    # --- descriptor weights ---
    # One row per (s1, s2) pair holding the flattened radial then angular
    # weights; the transpose makes the pair index the fastest running one.
    blocks = []
    for s1 in self.types:
        for s2 in self.types:
            blocks.append(self.radial_descriptor_weights[(s1, s2)].flatten())
            blocks.append(self.angular_descriptor_weights[(s1, s2)].flatten())
    mat = np.hstack([[], *blocks])
    n = len(mat) // (n_types * n_types)
    mat = mat.reshape((n_types * n_types, n)).T
    values.extend(mat.flatten())

    # --- scaler ---
    values.extend(self.q_scaler)

    lines.extend(f'{v:15.7e}\n' for v in values)
    with open(filename, 'w') as f:
        f.writelines(lines)

    if restart_file is not None:
        self.write_restart(restart_file)

919 

def read_restart(self, filename: str):
    """Parses a file in `nep.restart` format and saves the
    content in the form of mean and standard deviation for each
    parameter in the corresponding NEP model.

    The result is stored in ``self.restart_parameters`` as a dict with the
    keys ``ann_mu``, ``ann_sigma``, ``radial_descriptor_mu``,
    ``radial_descriptor_sigma``, ``angular_descriptor_mu`` and
    ``angular_descriptor_sigma``.

    Parameters
    ----------
    filename
        Input file name.

    Raises
    ------
    ValueError
        If the model version is neither 3 nor 4.
    """
    mu, sigma = _get_restart_contents(filename)
    # One row per parameter; column 0 holds the mean, column 1 the sigma.
    restart_parameters = np.array([mu, sigma]).T

    is_polarizability_model = self.model_type == 'polarizability'
    is_charged_model = self.model_type == 'potential_with_charges'

    # Polarizability models carry two sets of ANN parameters.
    n1 = self.n_ann_parameters * (2 if is_polarizability_model else 1)
    n2 = n1 + self.n_descriptor_parameters
    ann_parameters = restart_parameters[:n1]
    descriptor_parameters = np.array(restart_parameters[n1:n2])

    if self.version == 3:
        n_networks = 1
    elif self.version == 4:
        # one hidden layer per atomic species
        n_networks = len(self.types)
    else:
        raise ValueError(f'Cannot load nep.restart for NEP model version {self.version}')

    # Entries of ann_parameters starting with 'b1' are bias terms; all
    # remaining keys are treated as network groups.
    ann_groups = [s for s in self.ann_parameters.keys() if not s.startswith('b1')]
    n_bias = len([s for s in self.ann_parameters.keys() if s.startswith('b1')])
    if self.sqrt_epsilon_infinity is not None:
        n_bias += 1  # charge models have sqrt_epsilon_infinity before b1
    n_descriptor = self.n_descriptor_radial + self.n_descriptor_angular
    restart = {}

    for i, content_type in enumerate(['mu', 'sigma']):
        ann = _sort_ann_parameters(ann_parameters[:, i],
                                   ann_groups,
                                   self.n_neuron,
                                   n_networks,
                                   n_bias,
                                   n_descriptor,
                                   is_polarizability_model,
                                   is_charged_model)
        radial, angular = _sort_descriptor_parameters(descriptor_parameters[:, i],
                                                      self.types,
                                                      self.n_max_radial,
                                                      self.n_basis_radial,
                                                      self.n_max_angular,
                                                      self.n_basis_angular)

        restart[f'ann_{content_type}'] = ann
        restart[f'radial_descriptor_{content_type}'] = radial
        restart[f'angular_descriptor_{content_type}'] = angular
    self.restart_parameters = restart

979 

def write_restart(self, filename: str):
    """Write NEP restart parameters to file in `nep.restart` format.

    Each line of the file holds the mean and the standard deviation of one
    parameter, written in the order ANN parameters followed by descriptor
    weights.

    Parameters
    ----------
    filename
        Output file name.

    Raises
    ------
    ValueError
        If ``restart_parameters`` is not loaded, or if the number of means
        differs from the number of standard deviations.
    """
    if self.restart_parameters is None:
        raise ValueError(
            'restart_parameters must be loaded before calling write_restart(). '
            'Pass restart_file= to read_model() or call model.read_restart() first.'
        )

    keys = self.types if self.version in (4, 5) else ['all_species']
    suffixes = ['', '_polar'] if self.model_type == 'polarizability' else ['']
    n_descriptor = self.n_descriptor_radial + self.n_descriptor_angular
    n_types = len(self.types)

    columns = []
    for parameter in ('mu', 'sigma'):
        values = []

        # neural network weights
        ann_parameters = self.restart_parameters[f'ann_{parameter}']
        for suffix in suffixes:
            for s in keys:
                # Order: w0, b0, w1 (, w1_charge)
                group = ann_parameters[s]
                w0 = group[f'w0{suffix}']
                b0 = group[f'b0{suffix}']
                w1 = group[f'w1{suffix}']
                # w0 indexed as: n*N_descriptor + nu
                values.extend(
                    w0[n, nu]
                    for n in range(self.n_neuron)
                    for nu in range(n_descriptor)
                )
                values.extend(b0[:, 0])
                values.extend(w1[0, :] if w1.ndim == 2 else w1)
                if f'w1_charge{suffix}' in group:
                    values.extend(group[f'w1_charge{suffix}'])
            if f'sqrt_epsilon_infinity{suffix}' in ann_parameters:
                values.append(ann_parameters[f'sqrt_epsilon_infinity{suffix}'])
            values.append(ann_parameters[f'b1{suffix}'])

        # descriptor weights: one row per (s1, s2) pair holding the flattened
        # radial then angular weights; the transpose makes the pair index the
        # fastest running one.
        radial = self.restart_parameters[f'radial_descriptor_{parameter}']
        angular = self.restart_parameters[f'angular_descriptor_{parameter}']
        blocks = []
        for s1 in self.types:
            for s2 in self.types:
                blocks.append(radial[(s1, s2)].flatten())
                blocks.append(angular[(s1, s2)].flatten())
        mat = np.hstack([[], *blocks])
        n = len(mat) // (n_types * n_types)
        mat = mat.reshape((n_types * n_types, n)).T
        values.extend(mat.flatten())

        columns.append([f'{v:15.7e}' for v in values])

    # Join the mean and standard deviation columns
    if len(columns[0]) != len(columns[1]):
        raise ValueError('Length of means must match standard deviation')
    joined = [f'{s1} {s2}\n' for s1, s2 in zip(*columns)]
    with open(filename, 'w') as f:
        f.writelines(joined)

1038 

def augment(self,
            n_neuron: int = None,
            l_max_4b: int = None,
            l_max_5b: int = None,
            has_q_112: bool = None,
            has_q_123: bool = None,
            has_q_233: bool = None,
            has_q_134: bool = None,
            charge_head: bool = False,
            sigma_new: float = 0.01,
            sigma_factor: float = 0.1,
            sigma_floor: float = 1e-6) -> 'Model':
    """Augment the model by adding neurons, descriptor terms, or a charge output head.

    Returns a new :class:`Model` with the requested structural changes applied.
    The source model is not modified. Existing parameter values are preserved exactly;
    new network weights are initialized to zero, while new descriptor scalers and
    ``sqrt_epsilon_infinity`` are initialized to one. The restart SNES statistics are
    updated as follows:

    - Existing parameters: ``sigma = max(sigma_floor, sigma_factor * |mu|)``, which
      re-opens the SNES search distribution while keeping parameters that were driven
      toward zero effectively dormant.
    - New parameters: ``mu`` equal to the initial value, ``sigma = sigma_new``.

    Descriptor columns of a newly enabled higher-body term are inserted at the
    position of that term in the order 4-body, 5-body, q_112, q_123, q_233, q_134
    (the same block order that :meth:`prune` uses when removing terms), so that
    existing weights stay attached to their descriptor components.

    Parameters
    ----------
    n_neuron
        Target neuron count; must be >= current. ``None`` leaves unchanged.
    l_max_4b
        Target 4-body l_max value; must be >= current. ``None`` leaves unchanged.
    l_max_5b
        Target 5-body l_max value; must be >= current. ``None`` leaves unchanged.
    has_q_112
        ``True`` enables the q_112 5-body descriptor; ``None`` or ``False`` leaves
        the current state unchanged (disabling an already-enabled term raises).
    has_q_123
        Same as ``has_q_112`` but for the q_123 term.
    has_q_233
        Same as ``has_q_112`` but for the q_233 term.
    has_q_134
        Same as ``has_q_112`` but for the q_134 term.
    charge_head
        If ``True``, promote a ``potential`` model to ``potential_with_charges`` by
        adding a charge output head (w1_charge per species and sqrt_epsilon_infinity).
    sigma_new
        SNES sigma assigned to all newly created parameters.
    sigma_factor
        Controls the sigma for *existing* parameters:
        ``sigma = max(sigma_floor, sigma_factor * |mu|)``.
    sigma_floor
        Minimum sigma for existing parameters; keeps near-zero (dormant) parameters
        from being accidentally re-activated.

    Returns
    -------
    Model
        New model with updated structure, weights, and restart statistics.

    Raises
    ------
    ValueError
        If the model version is not 3 or 4, if ``restart_parameters`` is not
        loaded, if ``n_neuron`` or an ``l_max_*`` target is smaller than the
        current value, if a ``has_q_*`` flag attempts to disable an
        already-enabled term, or if ``charge_head=True`` on a model that is
        not of type ``potential``.
    """
    # Structural checks (independent of restart)
    if self.version not in (3, 4):
        raise ValueError(
            f'augment() only supports NEP versions 3 and 4; got version {self.version}.'
        )
    if n_neuron is not None and n_neuron < self.n_neuron:
        raise ValueError(
            f'n_neuron ({n_neuron}) must be >= current n_neuron ({self.n_neuron}); '
            'use prune() to reduce.'
        )
    if l_max_4b is not None and l_max_4b < self.l_max_4b:
        raise ValueError(
            f'l_max_4b ({l_max_4b}) must be >= current l_max_4b ({self.l_max_4b}); '
            'use prune() to disable.'
        )
    if l_max_5b is not None and l_max_5b < self.l_max_5b:
        raise ValueError(
            f'l_max_5b ({l_max_5b}) must be >= current l_max_5b ({self.l_max_5b}); '
            'use prune() to disable.'
        )
    for flag_val, name in [
        (has_q_112, 'has_q_112'), (has_q_123, 'has_q_123'), (has_q_233, 'has_q_233'),
        (has_q_134, 'has_q_134')
    ]:
        if flag_val is False and getattr(self, name):
            raise ValueError(
                f'Cannot disable {name} via augment(); '
                'use prune() to disable descriptor terms.'
            )
    if charge_head and self.model_type != 'potential':
        raise ValueError(
            f'charge_head=True requires model_type="potential"; '
            f'got "{self.model_type}".'
        )
    if self.restart_parameters is None:
        raise ValueError(
            'restart_parameters must be loaded before calling augment(). '
            'Pass restart_file= to read_model() or call model.read_restart() first.'
        )

    new = copy.deepcopy(self)

    # Resolve new structural parameters
    new_l_max_4b = l_max_4b if l_max_4b is not None else self.l_max_4b
    new_l_max_5b = l_max_5b if l_max_5b is not None else self.l_max_5b
    new_has_q_112 = int(has_q_112) if has_q_112 is not None else self.has_q_112
    new_has_q_123 = int(has_q_123) if has_q_123 is not None else self.has_q_123
    new_has_q_233 = int(has_q_233) if has_q_233 is not None else self.has_q_233
    new_has_q_134 = int(has_q_134) if has_q_134 is not None else self.has_q_134
    new_n_neuron = n_neuron if n_neuron is not None else self.n_neuron

    n_per = self.n_max_angular + 1  # descriptor columns per angular term
    new_l_max_enh = (self.l_max_3b
                     + (new_l_max_4b > 0) + (new_l_max_5b > 0)
                     + (new_has_q_112 > 0) + (new_has_q_123 > 0) + (new_has_q_233 > 0)
                     + (new_has_q_134 > 0))
    new_n_desc_angular = n_per * new_l_max_enh
    old_n_desc = self.n_descriptor_radial + self.n_descriptor_angular
    new_n_desc = self.n_descriptor_radial + new_n_desc_angular
    delta_desc = new_n_desc - old_n_desc
    delta_neuron = new_n_neuron - self.n_neuron

    # Boolean mask over the NEW descriptor columns: True where the column
    # already exists in the source model, False where it is newly created.
    # Higher-body blocks follow the same order as used by prune().
    hb_terms = [
        (self.l_max_4b, new_l_max_4b),
        (self.l_max_5b, new_l_max_5b),
        (self.has_q_112, new_has_q_112),
        (self.has_q_123, new_has_q_123),
        (self.has_q_233, new_has_q_233),
        (self.has_q_134, new_has_q_134),
    ]
    is_existing = [True] * (self.n_descriptor_radial + n_per * self.l_max_3b)
    for old_val, new_val in hb_terms:
        if old_val > 0:
            is_existing += [True] * n_per
        elif new_val > 0:
            is_existing += [False] * n_per
    is_existing = np.array(is_existing, dtype=bool)
    if len(is_existing) != new_n_desc or int(is_existing.sum()) != old_n_desc:
        # The stored descriptor counts do not match the header flags;
        # fall back to appending the new columns at the end.
        is_existing = np.arange(new_n_desc) < old_n_desc

    keys = self.types if self.version in (4, 5) else ['all_species']
    # NOTE(review): only the unsuffixed 'w0'/'b0'/'w1' entries are resized below;
    # '_polar' entries of polarizability models are not touched -- confirm whether
    # such models are meant to be supported here.

    # Containers that hold one parameter group per key, paired with the value
    # used to initialise newly created entries.
    targets = [
        (new.ann_parameters, 0.0),
        (new.restart_parameters['ann_mu'], 0.0),
        (new.restart_parameters['ann_sigma'], sigma_new),
    ]

    def _expand_columns(old, fill):
        """Place the columns of `old` at their positions in the new layout."""
        expanded = np.full((old.shape[0], new_n_desc), fill, dtype=float)
        expanded[:, is_existing] = old
        return expanded

    def _append_rows(old, fill):
        """Append `delta_neuron` rows filled with `fill`."""
        return np.vstack([old, np.full((delta_neuron, old.shape[1]), fill)])

    def _append_outputs(old, fill):
        """Append `delta_neuron` output weights; handles 1D and (1, n) arrays."""
        pad = (np.full(delta_neuron, fill) if old.ndim == 1
               else np.full((1, delta_neuron), fill))
        return np.hstack([old, pad])

    # Step 1: Apply adaptive sigma to all existing parameters (re-open SNES search width)
    _apply_adaptive_sigma_to_restart(new.restart_parameters, keys, sigma_factor, sigma_floor)

    # Step 2: Expand descriptor dimensions (new columns in w0, new q_scaler entries)
    if delta_desc > 0:
        for s in keys:
            for container, fill in targets:
                container[s]['w0'] = _expand_columns(container[s]['w0'], fill)
        scaler = np.ones(new_n_desc)
        scaler[is_existing] = list(new.q_scaler)
        new.q_scaler = scaler.tolist()

    # Step 3: Expand neuron count (new rows in w0/b0, new columns in w1)
    if delta_neuron > 0:
        for s in keys:
            for container, fill in targets:
                group = container[s]
                group['w0'] = _append_rows(group['w0'], fill)
                group['b0'] = _append_rows(group['b0'], fill)
                group['w1'] = _append_outputs(group['w1'], fill)
                if 'w1_charge' in group:
                    group['w1_charge'] = _append_outputs(group['w1_charge'], fill)

    # Step 4: Add charge output head
    if charge_head:
        new.model_type = 'potential_with_charges'
        new.sqrt_epsilon_infinity = 1.0
        for s in keys:
            for container, fill in targets:
                group = container[s]
                # w1 is stored as (1, n_neuron) in plain potentials and as a
                # 1D array in charge models
                group['w1'] = group['w1'][0, :]
                group['w1_charge'] = np.full(new_n_neuron, fill)

        new.restart_parameters['ann_mu']['sqrt_epsilon_infinity'] = 1.0
        new.restart_parameters['ann_sigma']['sqrt_epsilon_infinity'] = float(sigma_new)

    # Step 5: Update header metadata
    new.l_max_4b = new_l_max_4b
    new.l_max_5b = new_l_max_5b
    new.has_q_112 = new_has_q_112
    new.has_q_123 = new_has_q_123
    new.has_q_233 = new_has_q_233
    new.has_q_134 = new_has_q_134
    new.n_descriptor_angular = new_n_desc_angular
    new.n_neuron = new_n_neuron

    # Step 6: Recalculate parameter counts
    _recalculate_parameter_counts(new)

    return new

1283 

def prune(self,
          n_neuron: int = None,
          l_max_4b: int = None,
          l_max_5b: int = None,
          has_q_112: bool = None,
          has_q_123: bool = None,
          has_q_233: bool = None,
          has_q_134: bool = None,
          charge_head: bool = False,
          sigma_factor: float = 0.1,
          sigma_floor: float = 1e-6) -> 'Model':
    """Prune the model by removing neurons, disabling descriptor terms, or removing
    the charge output head.

    Returns a new :class:`Model` with the requested structural changes applied.
    The source model is not modified. When reducing ``n_neuron``, neurons are
    selected by importance score averaged over species:
    ``importance[n] = mean_s(||w0_s[n,:]||_2 * |w1_s[n]|)``; for models with a
    charge head ``|w1_s[n]| + |w1_charge_s[n]|`` is used in place of ``|w1_s[n]|``.

    All surviving parameters receive adaptive SNES sigma:
    ``sigma = max(sigma_floor, sigma_factor * |mu|)``.

    Parameters
    ----------
    n_neuron
        Target neuron count; must be >= 1 and <= current. ``None`` leaves
        unchanged.
    l_max_4b
        Target 4-body l_max; must be <= current. Setting to ``0`` removes the
        4-body angular descriptor block. Reducing to a lower non-zero value is
        a header-only change (descriptor dimensions unchanged). ``None`` leaves
        unchanged.
    l_max_5b
        Same as ``l_max_4b`` but for five-body terms.
    has_q_112
        ``False`` disables and removes the q_112 descriptor block. ``None``
        leaves unchanged. ``True`` is not valid; use :meth:`augment` instead.
    has_q_123
        Same as ``has_q_112`` but for the q_123 term.
    has_q_233
        Same as ``has_q_112`` but for the q_233 term.
    has_q_134
        Same as ``has_q_112`` but for the q_134 term.
    charge_head
        If ``True``, remove the charge output head from a
        ``potential_with_charges`` model, converting it back to ``potential``.
        Removes ``w1_charge`` per species and ``sqrt_epsilon_infinity`` from
        the restart.
    sigma_factor
        Controls sigma for surviving parameters:
        ``sigma = max(sigma_floor, sigma_factor * |mu|)``.
    sigma_floor
        Minimum sigma for surviving parameters.

    Returns
    -------
    Model
        New model with reduced structure, weights, and restart statistics.

    Raises
    ------
    ValueError
        If the model version is not 3 or 4, if ``restart_parameters`` is not
        loaded, if ``n_neuron`` is smaller than one, if any target value would
        expand the model (use :meth:`augment` instead), if a ``has_q_*``
        flag is set to ``True``, or if ``charge_head=True`` on a model
        without charges.
    """
    # --- Resolve target values ---
    new_n_neuron = n_neuron if n_neuron is not None else self.n_neuron
    new_l_max_4b = l_max_4b if l_max_4b is not None else self.l_max_4b
    new_l_max_5b = l_max_5b if l_max_5b is not None else self.l_max_5b
    new_has_q_112 = 0 if has_q_112 is False else self.has_q_112
    new_has_q_123 = 0 if has_q_123 is False else self.has_q_123
    new_has_q_233 = 0 if has_q_233 is False else self.has_q_233
    new_has_q_134 = 0 if has_q_134 is False else self.has_q_134

    # --- Validate ---
    if self.version not in (3, 4):
        raise ValueError(
            f'prune() only supports NEP versions 3 and 4; got version {self.version}.'
        )
    if new_n_neuron > self.n_neuron:
        raise ValueError(
            f'n_neuron ({new_n_neuron}) must be <= current n_neuron ({self.n_neuron}); '
            'use augment() to increase.'
        )
    if new_n_neuron < 1:
        # A target of 0 would turn the selection slice below into ``[-0:]``,
        # which keeps every neuron while the header claims none; negative
        # targets would select an arbitrary subset.
        raise ValueError(f'n_neuron ({new_n_neuron}) must be >= 1.')
    if new_l_max_4b > self.l_max_4b:
        raise ValueError(
            f'l_max_4b ({new_l_max_4b}) must be <= current l_max_4b ({self.l_max_4b}); '
            'use augment() to increase.'
        )
    if new_l_max_5b > self.l_max_5b:
        raise ValueError(
            f'l_max_5b ({new_l_max_5b}) must be <= current l_max_5b ({self.l_max_5b}); '
            'use augment() to increase.'
        )
    for flag_val, name in [
        (has_q_112, 'has_q_112'), (has_q_123, 'has_q_123'),
        (has_q_233, 'has_q_233'), (has_q_134, 'has_q_134')
    ]:
        if flag_val is True:
            raise ValueError(
                f'Cannot enable {name} via prune(); '
                'use augment() to enable descriptor terms.'
            )
    if charge_head and self.model_type != 'potential_with_charges':
        raise ValueError(
            f'charge_head=True requires model_type="potential_with_charges"; '
            f'got "{self.model_type}".'
        )
    if self.restart_parameters is None:
        raise ValueError(
            'restart_parameters must be loaded before calling prune(). '
            'Pass restart_file= to read_model() or call model.read_restart() first.'
        )

    new = copy.deepcopy(self)
    keys = self.types if self.version in (4, 5) else ['all_species']
    # NOTE(review): only the unsuffixed 'w0'/'b0'/'w1' entries are resized below;
    # '_polar' entries of polarizability models are not touched -- confirm whether
    # such models are meant to be supported here.

    # Containers holding one parameter group per key (weights, restart mu/sigma)
    containers = [
        new.ann_parameters,
        new.restart_parameters['ann_mu'],
        new.restart_parameters['ann_sigma'],
    ]

    # Step 1: Adaptive sigma for all existing parameters
    _apply_adaptive_sigma_to_restart(new.restart_parameters, keys, sigma_factor, sigma_floor)

    # Step 2: Neuron pruning -- keep the most important neurons
    if new_n_neuron < self.n_neuron:
        importances = []
        for s in keys:
            group = self.ann_parameters[s]
            w0 = group['w0']  # indexed as [neuron, descriptor]
            output_norm = np.abs(group['w1'].ravel())
            if 'w1_charge' in group:
                output_norm = output_norm + np.abs(group['w1_charge'])
            importances.append(np.linalg.norm(w0, axis=1) * output_norm)

        # Indices of the highest-scoring neurons, restored to ascending order
        # so that the surviving neurons keep their relative order.
        ranking = np.argsort(np.mean(importances, axis=0))
        keep_idx = np.sort(ranking[-new_n_neuron:])

        for s in keys:
            for container in containers:
                group = container[s]
                group['w0'] = group['w0'][keep_idx, :]
                group['b0'] = group['b0'][keep_idx, :]
                w1 = group['w1']
                group['w1'] = w1[:, keep_idx] if w1.ndim == 2 else w1[keep_idx]
                if 'w1_charge' in group:
                    group['w1_charge'] = group['w1_charge'][keep_idx]

    # Step 3: Descriptor column pruning (disabling higher-body terms)
    # Each enabled higher-body term occupies a block of n_per columns after the
    # radial and three-body columns, in the order listed in hb_terms.
    n_per = self.n_max_angular + 1
    hb_terms = [
        (self.l_max_4b, new_l_max_4b),
        (self.l_max_5b, new_l_max_5b),
        (self.has_q_112, new_has_q_112),
        (self.has_q_123, new_has_q_123),
        (self.has_q_233, new_has_q_233),
        (self.has_q_134, new_has_q_134),
    ]
    keep_cols = list(range(self.n_descriptor_radial + n_per * self.l_max_3b))
    col_offset = len(keep_cols)
    for old_val, new_val in hb_terms:
        if old_val > 0:
            if new_val > 0:
                keep_cols.extend(range(col_offset, col_offset + n_per))
            col_offset += n_per

    old_n_desc = self.n_descriptor_radial + self.n_descriptor_angular
    if len(keep_cols) < old_n_desc:
        keep_cols = np.array(keep_cols, dtype=int)
        for s in keys:
            for container in containers:
                container[s]['w0'] = container[s]['w0'][:, keep_cols]
        new.q_scaler = [new.q_scaler[i] for i in keep_cols]

    # Step 4: Charge head removal
    if charge_head:
        new.model_type = 'potential'
        new.sqrt_epsilon_infinity = None
        for s in keys:
            for container in containers:
                group = container[s]
                # charge models store w1 as a 1D array; plain potentials as (1, n_neuron)
                group['w1'] = group['w1'].reshape(1, -1)
                del group['w1_charge']
        del new.restart_parameters['ann_mu']['sqrt_epsilon_infinity']
        del new.restart_parameters['ann_sigma']['sqrt_epsilon_infinity']

    # Step 5: Update header fields
    new.n_neuron = new_n_neuron
    new.l_max_4b = new_l_max_4b
    new.l_max_5b = new_l_max_5b
    new.has_q_112 = new_has_q_112
    new.has_q_123 = new_has_q_123
    new.has_q_233 = new_has_q_233
    new.has_q_134 = new_has_q_134

    new_l_max_enh = (self.l_max_3b
                     + (new_l_max_4b > 0) + (new_l_max_5b > 0)
                     + (new_has_q_112 > 0) + (new_has_q_123 > 0) + (new_has_q_233 > 0)
                     + (new_has_q_134 > 0))
    new.n_descriptor_angular = n_per * new_l_max_enh

    # Step 6: Recalculate parameter counts
    _recalculate_parameter_counts(new)

    return new

1499 

1500 

1501def read_model(filename: str, restart_file: str = None) -> Model: 

1502 """Parses a file in ``nep.txt`` format and returns the 

1503 content in the form of a :class:`Model <calorine.nep.model.Model>` 

1504 object. 

1505 

1506 Parameters 

1507 ---------- 

1508 filename 

1509 Input file name. 

1510 restart_file 

1511 If provided, also read restart parameters from this file in 

1512 `nep.restart` format and attach them to the returned model. 

1513 Defaults to None. 

1514 """ 

1515 data, parameters = _get_nep_contents(filename) 

1516 

1517 # sanity checks 

1518 for fld in ['cutoff', 'basis_size', 'n_max', 'l_max', 'ANN']: 

1519 assert fld in data, f'Invalid model file; {fld} line is missing' 

1520 assert data['version'] in [ 

1521 3, 

1522 4, 

1523 5, 

1524 ], 'Invalid model file; only NEP versions 3, 4 and 5 are currently supported' 

1525 

1526 # split up cutoff tuple 

1527 N_types = len(data['types']) 

1528 # Either global cutoffs + max neighbirs, or typewise cutoffs + max_neighbors 

1529 assert len(data['cutoff']) in [4, 2*N_types+2] 

1530 data['max_neighbors_radial'] = int(data['cutoff'][-2]) 

1531 data['max_neighbors_angular'] = int(data['cutoff'][-1]) 

1532 if len(data['cutoff']) == 2*N_types+2: 

1533 # Typewise cutoffs: radial are even, angular are odd 

1534 data['radial_cutoff'] = [data['cutoff'][i*2] for i in range(N_types)] 

1535 data['angular_cutoff'] = [data['cutoff'][i*2+1] for i in range(N_types)] 

1536 else: 

1537 data['radial_cutoff'] = data['cutoff'][0] 

1538 data['angular_cutoff'] = data['cutoff'][1] 

1539 del data['cutoff'] 

1540 

1541 # split up basis_size tuple 

1542 assert len(data['basis_size']) == 2 

1543 data['n_basis_radial'] = data['basis_size'][0] 

1544 data['n_basis_angular'] = data['basis_size'][1] 

1545 del data['basis_size'] 

1546 

1547 # split up n_max tuple 

1548 assert len(data['n_max']) == 2 

1549 data['n_max_radial'] = data['n_max'][0] 

1550 data['n_max_angular'] = data['n_max'][1] 

1551 del data['n_max'] 

1552 

1553 # split up nl_max tuple 

1554 len_l = len(data['l_max']) 

1555 assert len_l in [1, 2, 3, 4, 5, 6, 7] 

1556 data['l_max_3b'] = data['l_max'][0] 

1557 data['l_max_4b'] = data['l_max'][1] if len_l > 1 else 0 

1558 data['l_max_5b'] = data['l_max'][2] if len_l > 2 else 0 

1559 data['has_q_112'] = data['l_max'][3] if len_l > 3 else 0 

1560 data['has_q_123'] = data['l_max'][4] if len_l > 4 else 0 

1561 data['has_q_233'] = data['l_max'][5] if len_l > 5 else 0 

1562 data['has_q_134'] = data['l_max'][6] if len_l > 6 else 0 

1563 del data['l_max'] 

1564 

1565 # compute dimensions of descriptor components 

1566 data['n_descriptor_radial'] = data['n_max_radial'] + 1 

1567 l_max_enh = (data['l_max_3b'] 

1568 + (data['l_max_4b'] > 0) 

1569 + (data['l_max_5b'] > 0) 

1570 + (data['has_q_112'] > 0) 

1571 + (data['has_q_123'] > 0) 

1572 + (data['has_q_233'] > 0) 

1573 + (data['has_q_134'] > 0)) 

1574 data['n_descriptor_angular'] = (data['n_max_angular'] + 1) * l_max_enh 

1575 n_descriptor = data['n_descriptor_radial'] + data['n_descriptor_angular'] 

1576 

1577 is_charged_model = data['model_type'] == 'potential_with_charges' 

1578 # compute number of parameters 

1579 data['n_neuron'] = data['ANN'][0] 

1580 del data['ANN'] 

1581 n_types = len(data['types']) 

1582 if data['version'] == 3: 

1583 n = 1 

1584 n_bias = 1 

1585 elif data['version'] == 4 and is_charged_model: 

1586 # one hidden layer per atomic species, but two output nodes 

1587 n = n_types 

1588 n_bias = 2 

1589 elif data['version'] == 4: 

1590 # one hidden layer per atomic species 

1591 n = n_types 

1592 n_bias = 1 

1593 else: # NEP5 

1594 # like nep4, but additionally has an 

1595 # individual bias term in the output 

1596 # layer for each species. 

1597 n = n_types 

1598 n_bias = 1 + n_types # one global bias + one per species 

1599 

1600 n_ann_input_weights = (n_descriptor + 1) * data['n_neuron'] # weights + bias 

1601 n_ann_output_weights = 2*data['n_neuron'] if is_charged_model else data['n_neuron'] # weights 

1602 n_ann_parameters = ( 

1603 n_ann_input_weights + n_ann_output_weights 

1604 ) * n + n_bias 

1605 

1606 n_descriptor_weights = n_types**2 * ( 

1607 (data['n_max_radial'] + 1) * (data['n_basis_radial'] + 1) 

1608 + (data['n_max_angular'] + 1) * (data['n_basis_angular'] + 1) 

1609 ) 

1610 data['n_parameters'] = n_ann_parameters + n_descriptor_weights + n_descriptor 

1611 is_polarizability_model = data['model_type'] == 'polarizability' 

1612 if data['n_parameters'] + n_ann_parameters == len(parameters): 

1613 data['n_parameters'] += n_ann_parameters 

1614 assert is_polarizability_model, ( 

1615 'Model is not labelled as a polarizability model, but the number of ' 

1616 'parameters matches a polarizability model.\n' 

1617 'If this is a polarizability model trained with GPUMD <=v3.8, please ' 

1618 'modify the header in the nep.txt file to enable parsing ' 

1619 f'`nep{data["version"]}_polarizability`.\n' 

1620 ) 

1621 assert data['n_parameters'] == len(parameters), ( 

1622 'Parsing of parameters inconsistent; please submit a bug report\n' 

1623 f'{data["n_parameters"]} != {len(parameters)}' 

1624 ) 

1625 data['n_ann_parameters'] = n_ann_parameters 

1626 

1627 # split up parameters into the ANN weights, descriptor weights, and scaling parameters 

1628 n1 = n_ann_parameters 

1629 n1 *= 2 if is_polarizability_model else 1 

1630 n2 = n1 + n_descriptor_weights 

1631 data['ann_parameters'] = parameters[:n1] 

1632 descriptor_weights = np.array(parameters[n1:n2]) 

1633 data['q_scaler'] = parameters[n2:] 

1634 

1635 # add ann parameters to data dict 

1636 ann_groups = data['types'] if data['version'] in (4, 5) else ['all_species'] 

1637 sorted_ann_parameters = _sort_ann_parameters(data['ann_parameters'], 

1638 ann_groups, 

1639 data['n_neuron'], 

1640 n, 

1641 n_bias, 

1642 n_descriptor, 

1643 is_polarizability_model, 

1644 is_charged_model) 

1645 

1646 data['ann_parameters'] = sorted_ann_parameters 

1647 if 'sqrt_epsilon_infinity' in sorted_ann_parameters.keys(): 

1648 data['sqrt_epsilon_infinity'] = sorted_ann_parameters['sqrt_epsilon_infinity'] 

1649 sorted_ann_parameters.pop('sqrt_epsilon_infinity') 

1650 data['ann_parameters'] = sorted_ann_parameters 

1651 

1652 # add descriptors to data dict 

1653 data['n_descriptor_parameters'] = len(descriptor_weights) 

1654 radial, angular = _sort_descriptor_parameters(descriptor_weights, 

1655 data['types'], 

1656 data['n_max_radial'], 

1657 data['n_basis_radial'], 

1658 data['n_max_angular'], 

1659 data['n_basis_angular']) 

1660 data['radial_descriptor_weights'] = radial 

1661 data['angular_descriptor_weights'] = angular 

1662 

1663 model = Model(**data) 

1664 if restart_file is not None: 

1665 model.read_restart(restart_file) 

1666 return model