Coverage for calorine / nep / io.py: 100%

235 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-03-05 13:55 +0000

1from os.path import exists 

2from os.path import join as join_path 

3from typing import Any, Iterable, NamedTuple, TextIO 

4from warnings import warn 

5 

6import numpy as np 

7from ase import Atoms 

8from ase.io import read, write 

9from ase.stress import voigt_6_to_full_3x3_stress 

10from pandas import DataFrame 

11 

12 

def read_loss(filename: str) -> DataFrame:
    """Parses a file in `loss.out` format from GPUMD and returns the
    content as a data frame. More information concerning file format,
    content and units can be found `here
    <https://gpumd.org/nep/output_files/loss_out.html>`__.

    The first column of the file is not included in the data frame. Instead the
    index is constructed as 100, 200, 300, ..., i.e., it assumes that one row is
    written every 100 generations.

    Parameters
    ----------
    filename
        input file name

    Returns
    -------
    Data frame with the columns ``total_loss``, ``L1``, ``L2`` followed by
    the RMSE columns, whose set depends on the number of columns in the file
    (6, 10, or 14).

    Raises
    ------
    ValueError
        if the file does not contain 6, 10, or 14 columns
    """
    # ndmin=2 keeps a single-row file as shape (1, n_columns) and a
    # single-column file as (n_rows, 1), so the column count is always shape[1].
    data = np.loadtxt(filename, ndmin=2)
    rmse_columns = {
        6: ['RMSE_P_train',
            'RMSE_P_test'],
        10: ['RMSE_E_train', 'RMSE_F_train', 'RMSE_V_train',
             'RMSE_E_test', 'RMSE_F_test', 'RMSE_V_test'],
        14: ['RMSE_E_train', 'RMSE_F_train', 'RMSE_V_train', 'RMSE_Q_train', 'RMSE_Z_train',
             'RMSE_E_test', 'RMSE_F_test', 'RMSE_V_test', 'RMSE_Q_test', 'RMSE_Z_test'],
    }
    n_columns = data.shape[1]
    if n_columns not in rmse_columns:
        raise ValueError(
            f'Input file contains {n_columns} data columns. Expected 6, 10, or 14 columns.'
        )
    tags = ['total_loss', 'L1', 'L2'] + rmse_columns[n_columns]
    generations = range(100, len(data) * 100 + 1, 100)
    df = DataFrame(data=data[:, 1:], columns=tags, index=generations)
    return df

48 

49 

def _write_structure_in_nep_format(structure: Atoms, f: TextIO) -> None:
    """Append one structure to a file-like object in the extended XYZ format
    read by the nep executable.

    The structure must provide at least one kind of training target:

    1. energy and forces (retrievable via ``get_potential_energy`` and ``get_forces``),
    2. a dipole, stored as ``dipole="dx dy dz"`` in ``structure.info``, or
    3. a polarizability/susceptibility, stored as
       ``pol="pxx pxy pxz pyx pyy pyz pzx pzy pzz"`` in ``structure.info``.

    Parameters
    ----------
    structure
        input structure; must hold information regarding energy and forces
        (or a dipole/polarizability target in its ``info`` dict)
    f
        file-like object to which to write

    Raises
    ------
    RuntimeError
        if none of the supported training targets is available
    ValueError
        if the volume of the cell is (close to) zero
    """
    # Keyword=value pairs understood by nep; ASE's extxyz writer is used to emit them:
    #   lattice="ax ay az bx by bz cx cy cz"          (mandatory)
    #   energy=energy_value                          (mandatory)
    #   virial="vxx vxy vxz vyx vyy vyz vzx vzy vzz"  (optional)
    #   weight=relative_weight                       (optional)
    #   properties=property_name:data_type:number_of_columns
    #       species:S:1                              (mandatory)
    #       pos:R:3                                  (mandatory)
    #       force:R:3 or forces:R:3                  (mandatory)
    try:
        structure.get_potential_energy()
        # evaluating the forces also places them on the Atoms object
        structure.get_forces()
    except RuntimeError:
        energies_and_forces_available = False
    else:
        energies_and_forces_available = True

    info = structure.info
    if not (energies_and_forces_available or 'dipole' in info or 'pol' in info):
        raise RuntimeError('Failed to retrieve target energies/forces,'
                           ' dipoles, or polarizabilities for structure')

    if np.isclose(structure.get_volume(), 0):
        raise ValueError('Structure cell must have a non-zero volume!')

    try:
        structure.get_stress()
    except RuntimeError:
        # stresses are optional, hence only notify the user
        warn('Failed to retrieve stresses for structure')

    write(filename=f, images=structure, write_info=True, format='extxyz')

96 

97 

def write_structures(outfile: str, structures: list[Atoms]) -> None:
    """Writes structures for training/testing in format readable by nep executable.

    An existing ``outfile`` is overwritten. Each structure is appended in turn
    via :func:`_write_structure_in_nep_format`, which raises if a structure has
    no training target (energies/forces, dipole, or polarizability).

    Parameters
    ----------
    outfile
        output filename
    structures
        list of structures with energy, forces, and (possibly) stresses
    """
    with open(outfile, 'w') as fd:
        for atoms in structures:
            _write_structure_in_nep_format(atoms, fd)

111 

112 

def write_nepfile(parameters: dict[str, Any], dirname: str) -> None:
    """Writes parameters file for NEP construction.

    One line ``<key> <value>`` is written per parameter to ``nep.in`` in
    ``dirname``; an existing file is overwritten. Iterable values (lists,
    tuples, arrays, ...) are written as space-separated entries. All other
    values, including strings, are written as-is.

    Parameters
    ----------
    parameters
        input parameters as a mapping from keyword to value (a named tuple is
        accepted as well); see `here <https://gpumd.org/nep/input_parameters/index.html>`__
    dirname
        directory in which to place the input file
    """
    if hasattr(parameters, '_asdict'):
        # named tuples do not provide `items`
        parameters = parameters._asdict()
    lines = []
    for key, val in parameters.items():
        # Strings are iterable too, but they must be written verbatim rather
        # than being split into space-separated characters.
        if isinstance(val, Iterable) and not isinstance(val, str):
            value_str = ' '.join(f'{v}' for v in val)
        else:
            value_str = f'{val}'
        lines.append(f'{key} {value_str}\n')
    with open(join_path(dirname, 'nep.in'), 'w') as f:
        f.writelines(lines)

131 

132 

def read_nepfile(filename: str) -> dict[str, Any]:
    """Returns the content of a configuration file (`nep.in`) as a dictionary.

    Everything following a ``#`` on a line is discarded, and empty lines are
    skipped. The first token of each remaining line is the keyword. The other
    tokens, joined by single spaces, form its value. If a keyword occurs more
    than once, the last occurrence wins.

    Known keywords are converted as follows:

    * integer keywords (e.g., ``version``, ``neuron``, ``model_type``) -> ``int``
    * scalar float keywords (e.g., ``lambda_1``, ``zbl``) -> ``float``
    * ``cutoff``, ``n_max``, ``l_max``, ``basis_size``, ``type_weight``
      -> list of ``float``
    * ``type`` -> list of strings whose first entry is converted to ``int``
      (presumably the number of species)

    All other keywords are kept as strings.

    Parameters
    ----------
    filename
        input file name
    """
    int_keys = {'version', 'neuron', 'generation', 'batch', 'population',
                'mode', 'model_type', 'charge_mode'}
    float_keys = {'lambda_1', 'lambda_2', 'lambda_e', 'lambda_f', 'lambda_v',
                  'lambda_q', 'lambda_shear', 'force_delta', 'zbl'}
    float_list_keys = {'cutoff', 'n_max', 'l_max', 'basis_size', 'type_weight'}

    # first pass: collect raw keyword -> value strings
    raw_settings = {}
    with open(filename) as fd:
        for line in fd:
            tokens = line.partition('#')[0].split()
            if not tokens:
                continue
            keyword, *values = tokens
            raw_settings[keyword] = ' '.join(values)

    # second pass: convert the values of known keywords
    settings = {}
    for keyword, text in raw_settings.items():
        if keyword in int_keys:
            value = int(text)
        elif keyword in float_keys:
            value = float(text)
        elif keyword in float_list_keys:
            value = [float(token) for token in text.split()]
        elif keyword == 'type':
            value = text.split()
            value[0] = int(value[0])
        else:
            value = text
        settings[keyword] = value
    return settings

166 

167 

def read_structures(dirname: str) -> tuple[list[Atoms], list[Atoms]]:
    """Parses the output files with training and test data from a nep run and returns their
    content as two lists of structures, representing training and test data, respectively.
    Target and predicted data are included in the :attr:`info` dict of the :class:`Atoms`
    objects.

    The structures are read from ``train.xyz`` and ``test.xyz``. A missing file only
    triggers a warning and yields an empty list. Which ``<name>_train.out`` and
    ``<name>_test.out`` files are parsed depends on ``model_type`` (and, for
    ``model_type`` 0, on ``charge_mode``) in ``nep.in``. For every parsed property
    ``<name>``, the keys ``<name>_predicted`` and, if the file contains target values,
    ``<name>_target`` are added to the ``info`` dict of each structure.

    Parameters
    ----------
    dirname
        Directory from which to read output files.

    Raises
    ------
    FileNotFoundError
        if ``dirname`` or a mandatory output file does not exist
    ValueError
        if ``model_type`` or ``charge_mode`` is unknown, or if the content of an
        output file is inconsistent with the structures
    """
    if not exists(dirname):
        raise FileNotFoundError(f'Directory {dirname} does not exist')

    # fetch model type from nep input file
    nep_info = read_nepfile(join_path(dirname, 'nep.in'))
    model_type = nep_info.get('model_type', 0)

    # Set up which files to parse depending on the type of model. Each entry is
    #   (sname, size, mandatory, includes_target, per_atom)
    # sname:           property name; data are read from `<sname>_<train|test>.out`
    # size:            number of components per row
    # mandatory:       raise if the file is missing (otherwise it is skipped)
    # includes_target: each row holds the predicted values followed by the targets
    # per_atom:        one row per atom (True) or one row per structure (False)
    if model_type == 0:
        charge_mode = int(nep_info.get('charge_mode', 0))
        if charge_mode not in [0, 1, 2]:
            raise ValueError(f'Unknown charge_mode: {charge_mode}')
        files_to_parse = [
            ('energy', 1, True, True, False),
            ('force', 3, True, True, True),
            ('virial', 6, True, True, False),
            ('stress', 6, True, True, False),
        ]
        if charge_mode in [1, 2]:
            files_to_parse += [
                ('charge', 1, True, False, True),
                ('bec', 9, False, True, True),
            ]
    elif model_type == 1:
        files_to_parse = [('dipole', 3, True, True, False)]
    elif model_type == 2:
        files_to_parse = [('polarizability', 6, True, True, False)]
    else:
        raise ValueError(f'Unknown model_type: {model_type}')

    # read training and test data
    structures = {}
    for stype in ['train', 'test']:
        filename = join_path(dirname, f'{stype}.xyz')
        try:
            structures[stype] = read(filename, format='extxyz', index=':')
        except FileNotFoundError:
            warn(f'File {filename} not found.')
            structures[stype] = []
            continue

        n_structures = len(structures[stype])

        # loop over files from which to read target data and predictions
        for sname, size, mandatory, includes_target, per_atom in files_to_parse:
            infile = f'{sname}_{stype}.out'
            path = join_path(dirname, infile)
            if not exists(path):
                if mandatory:
                    raise FileNotFoundError(f'File {path} does not exist')
                continue
            ts, ps = _read_data_file(path, includes_target=includes_target)

            # An empty output file yields 1-D arrays of length zero. Give them the
            # expected 2-D shape so that the checks below report a meaningful error
            # (or pass for an empty set of structures) instead of an IndexError.
            if ps.size == 0:
                ps = ps.reshape(0, size)
            if ts is not None and ts.size == 0:
                ts = ts.reshape(0, size)

            if ts is not None and ts.shape[1] != size:
                raise ValueError(f'Target data in {infile} has unexpected shape:'
                                 f' {ts.shape} (expected: (-1, {size}))')
            if ps.shape[1] != size:
                raise ValueError(f'Predicted data in {infile} has unexpected shape:'
                                 f' {ps.shape} (expected: (-1, {size}))')

            if per_atom:
                # data per atom, e.g., forces, charges, Born effective charges
                n_atoms_total = sum(len(s) for s in structures[stype])
                if len(ps) != n_atoms_total:
                    raise ValueError(f'Number of atoms in {infile} ({len(ps)})'
                                     f' and {stype}.xyz ({n_atoms_total}) inconsistent.')
                n = 0
                for structure in structures[stype]:
                    nat = len(structure)
                    if ts is not None:
                        structure.info[f'{sname}_target'] = \
                            np.array(ts[n: n + nat]).reshape(nat, size)
                    structure.info[f'{sname}_predicted'] = \
                        np.array(ps[n: n + nat]).reshape(nat, size)
                    n += nat
            else:
                # data per structure, e.g., energy, virial, stress
                if len(ps) != n_structures:
                    raise ValueError(f'Number of structures in {infile} ({len(ps)})'
                                     f' and {stype}.xyz ({n_structures}) inconsistent.')
                for k, structure in enumerate(structures[stype]):
                    # all per-structure files include target values (see files_to_parse)
                    assert ts is not None, 'This should not occur. Please report.'
                    t = ts[k]
                    assert np.shape(t) == (size,)
                    structure.info[f'{sname}_target'] = t
                    p = ps[k]
                    assert np.shape(p) == (size,)
                    structure.info[f'{sname}_predicted'] = p

        # Special handling of target data for BECs:
        # the target data for BECs need not be complete. In that case nep writes
        # zeros for every component. Such targets are replaced by NaN so that these
        # structures can be filtered out quickly when analyzing the data.
        for s in structures[stype]:
            if 'bec_target' in s.info and np.allclose(s.info['bec_target'], 0):
                s.info['bec_target'] = np.full((len(s), 9), np.nan)

    return structures['train'], structures['test']

289 

290 

def _read_data_file(
    path: str,
    includes_target: bool = True,
):
    """Private function that parses *.out files and
    returns their content for further processing.

    Parameters
    ----------
    path
        path to the output file
    includes_target
        if True, the first half of the columns of each row holds the predicted
        values and the second half the target values; otherwise all columns
        are predicted values

    Returns
    -------
    Tuple ``(target, predicted)`` of numpy arrays with one row per non-empty
    line. ``target`` is None if ``includes_target`` is False.

    Raises
    ------
    ValueError
        if ``includes_target`` is True and a row has an odd number of columns,
        or if a field cannot be converted to float
    """
    target, predicted = [], []
    with open(path, 'r') as f:
        for line in f:
            flds = line.split()
            if not flds:
                # Skip blank lines (e.g., a trailing empty line). They would otherwise
                # add empty rows and make the resulting arrays ragged.
                continue
            if includes_target:
                if len(flds) % 2 != 0:
                    raise ValueError(f'Incorrect number of columns in {path} ({len(flds)}).')
                n = len(flds) // 2
                predicted.append([float(s) for s in flds[:n]])
                target.append([float(s) for s in flds[n:]])
            else:
                predicted.append([float(s) for s in flds])
    # Return None for the target whenever the file has no target columns,
    # also if the file is empty.
    target_array = np.array(target) if includes_target else None
    return target_array, np.array(predicted)

316 

317 

def get_parity_data(
    structures: list[Atoms],
    property: str,
    selection: list[str] = None,
    flatten: bool = True,
) -> DataFrame:
    """Returns the predicted and target energies, forces, virials, stresses, Born
    effective charges, dipoles or polarizabilities from a list of structures in a
    format suitable for generating parity plots.

    The structures should have been read using :func:`read_structures
    <calorine.nep.read_structures>`, such that the `info` object is
    populated with keys of the form `<property>_<type>` where `<property>`
    is, e.g., `energy` or `force` and `<type>` is one of `predicted` or `target`.

    The resulting parity data is returned as a data frame with the columns
    `predicted` and `target`. If `flatten` is True, the per-atom properties
    `force` and `bec` additionally get a `species` column. The properties
    `force`, `virial`, `stress`, and `bec` additionally get a `component`
    column that labels each row with its component (or selection entry,
    e.g., `norm`).

    Parameters
    ----------
    structures
        List of structures as read with :func:`read_structures <calorine.nep.read_structures>`.
    property
        One of `energy`, `force`, `virial`, `stress`, `bec`, `dipole`, or `polarizability`.
    selection
        A list containing which components to return, and/or the norm.
        Possible values are `x`, `y`, `z`, `xx`, `yy`,
        `zz`, `yz`, `xz`, `xy`, `norm`, `pressure`.
        For `bec` the components are `xx`, `xy`, `xz`, `yx`, `yy`, `yz`,
        `zx`, `zy`, and `zz`. `norm` is supported for `force`, `virial`,
        `stress`, and `dipole`; `pressure` only for `stress`.
    flatten
        if True return flattened lists; this is useful for flattening
        the components of force or virials into a simple list

    Raises
    ------
    ValueError
        if `property` or an entry of `selection` is not supported
    KeyError
        if a structure lacks the requested data in its `info` dict
    """
    voigt_mapping = {
        'x': 0, 'y': 1, 'z': 2, 'xx': 0, 'yy': 1, 'zz': 2, 'yz': 3, 'xz': 4, 'xy': 5,
    }
    # Component labels per value when no selection is given. Born effective
    # charges are stored with 9 components (full 3x3 tensors in row-major order),
    # not in Voigt notation, hence they need their own labels and index mapping.
    component_labels = {
        'force': ['x', 'y', 'z'],
        'virial': ['xx', 'yy', 'zz', 'yz', 'xz', 'xy'],
        'stress': ['xx', 'yy', 'zz', 'yz', 'xz', 'xy'],
        'bec': ['xx', 'xy', 'xz', 'yx', 'yy', 'yz', 'zx', 'zy', 'zz'],
    }
    if property not in ('energy', 'force', 'virial', 'stress', 'polarizability', 'dipole', 'bec'):
        raise ValueError(
            "`property` must be one of 'energy', 'force', 'virial', 'stress',"
            " 'polarizability', 'dipole', or 'bec'."
        )
    if property in ['energy'] and selection:
        raise ValueError('Selection cannot be applied to scalars.')
    if property != 'stress' and selection and 'pressure' in selection:
        raise ValueError(f'Cannot calculate pressure for `{property}`.')

    if property == 'bec':
        index_mapping = {label: k for k, label in enumerate(component_labels['bec'])}
    else:
        index_mapping = voigt_mapping
    per_atom = property in ['force', 'bec']

    data = {'predicted': [], 'target': []}
    if per_atom and flatten:
        data['species'] = []
    components = []  # label for every flattened value

    for structure in structures:
        # number of values per component: one per atom for per-atom properties
        n_rows = len(structure) if per_atom else 1
        if selection:
            if 'species' in data:
                # selected values are grouped by selection entry, each entry
                # covering all atoms of the structure in order
                data['species'].extend(list(structure.symbols) * len(selection))
            components.extend(select for select in selection for _ in range(n_rows))
        else:
            if 'species' in data:
                size = len(component_labels[property])
                data['species'].extend(np.repeat(structure.symbols, size).tolist())
            components.extend(component_labels.get(property, []) * n_rows)

        for stype in ['predicted', 'target']:
            property_with_stype = f'{property}_{stype}'
            if property_with_stype not in structure.info:
                raise KeyError(f'{property_with_stype} not available in info field of structure')
            extracted_property = np.array(structure.info[property_with_stype])
            if selection:
                extracted_property = _select_parity_components(
                    extracted_property, property, selection, index_mapping)
            data[stype].append(extracted_property)

    if flatten:
        for stype in ['target', 'predicted']:
            values = data[stype]
            if values and np.ndim(values[0]) > 0:
                # Ravel each structure separately since structures may differ in size.
                data[stype] = np.concatenate([np.ravel(v) for v in values]).tolist()
        if property in component_labels:
            data['component'] = components
    df = DataFrame(data)
    # In case of flatten, cast to float64 for compatibility
    # with e.g. seaborn.
    # Casting in this way breaks tensorial properties though,
    # so skip it there.
    if flatten:
        df['target'] = df.target.astype('float64')
        df['predicted'] = df.predicted.astype('float64')
    return df


def _select_parity_components(
    values: np.ndarray,
    prop: str,
    selection: list[str],
    index_mapping: dict[str, int],
) -> list:
    """Return the entries of `selection` extracted from one structure's `values`."""
    if prop in ['force', 'bec']:
        # Flip (n_atoms, n_components) to (n_components, n_atoms) exactly once,
        # so that each selected component corresponds to one row.
        values = values.T
    selected_values = []
    for select in selection:
        if select == 'norm':
            if prop == 'force':
                selected_values.append(np.linalg.norm(values, axis=0))
            elif prop in ['virial', 'stress']:
                full_tensor = voigt_6_to_full_3x3_stress(values)
                selected_values.append(np.linalg.norm(full_tensor))
            elif prop in ['dipole']:
                selected_values.append(np.linalg.norm(values))
            else:
                raise ValueError(
                    f'Cannot handle selection=`norm` with property=`{prop}`.')
            continue

        if select == 'pressure' and prop == 'stress':
            selected_values.append(-np.sum(values[:3]) / 3)
            continue

        if select not in index_mapping:
            raise ValueError(f'Selection `{select}` is not allowed.')
        index = index_mapping[select]
        if index >= values.shape[0]:
            raise ValueError(
                f'Selection `{select}` is not compatible with property `{prop}`.'
            )
        selected_values.append(values[index])
    return selected_values