Coverage for calorine / nep / training_factory.py: 100%

66 statements  

« prev     ^ index     » next       coverage.py v7.13.2, created at 2026-03-05 13:55 +0000

1from os import makedirs 

2from pathlib import Path 

3from os.path import exists, join as join_path 

4from typing import List, NamedTuple, Optional 

5 

6import numpy as np 

7from ase import Atoms 

8from sklearn.model_selection import KFold 

9 

10from .io import write_nepfile, write_structures 

11 

12 

def setup_training(parameters: dict,
                   structures: List[Atoms],
                   enforced_structures: Optional[List[int]] = None,
                   rootdir: str = '.',
                   mode: str = 'kfold',
                   n_splits: Optional[int] = None,
                   train_fraction: Optional[float] = None,
                   seed: int = 42,
                   overwrite: bool = False,
                   ) -> None:
    """Sets up the input files for training a NEP via the ``nep``
    executable of the GPUMD package.

    Parameters
    ----------
    parameters
        dictionary containing the parameters to be set in the nep.in file;
        see `here <https://gpumd.org/nep/input_parameters/index.html>`__
        for an overview of these parameters
    structures
        list of structures to be included
    enforced_structures
        structures that _must_ be included in the training set, provided in the form
        of a list of indices that refer to the content of the ``structures`` parameter
    rootdir
        root directory in which to create the input files
    mode
        how the test-train split is performed. Options: ``'kfold'`` and ``'bagging'``
    n_splits
        number of splits of the input structures in training and test sets that ought to be
        performed; by default no split will be done and all input structures will be used
        for training
    train_fraction
        fraction of structures to use for training when mode ``'bagging'`` is used;
        must lie in the interval (0, 1]
    seed
        random number generator seed to be used; this ensures reproducibility
    overwrite
        if True overwrite the content of ``rootdir`` if it exists

    Raises
    ------
    FileExistsError
        if ``rootdir`` exists and ``overwrite`` is False
    ValueError
        if ``n_splits``, ``mode``, and ``train_fraction`` are inconsistent
    """
    # None default avoids the shared-mutable-default-argument pitfall.
    if enforced_structures is None:
        enforced_structures = []

    if exists(rootdir) and not overwrite:
        raise FileExistsError('Output directory exists.'
                              ' Set overwrite=True in order to override this behavior.')

    if n_splits is not None and (n_splits <= 0 or n_splits > len(structures)):
        raise ValueError(f'n_splits ({n_splits}) must be positive and'
                         f' must not exceed {len(structures)}.')

    if mode == 'kfold' and train_fraction is not None:
        raise ValueError(f'train_fraction cannot be set when mode {mode} is used')
    elif mode == 'bagging' and (train_fraction is None
                                or train_fraction <= 0 or train_fraction > 1):
        # The explicit None check turns the previous TypeError (None <= 0)
        # into a clear ValueError when train_fraction is omitted.
        raise ValueError(f'train_fraction ({train_fraction}) must be in (0,1]')

    # A fixed RandomState makes both KFold shuffling and bagging draws reproducible.
    rs = np.random.RandomState(seed)
    _prepare_training(parameters, structures, enforced_structures,
                      rootdir, mode, n_splits, train_fraction, rs)

68 

69 

def _prepare_training(parameters: dict,
                      structures: List[Atoms],
                      enforced_structures: List[int],
                      rootdir: str,
                      mode: str,
                      n_splits: Optional[int],
                      train_fraction: Optional[float],
                      rs: np.random.RandomState) -> None:
    """Prepares training and test sets and writes structural data as well as parameter files.

    A ``nepmodel_full`` directory with *all* structures in the training set is
    always written.  If ``n_splits`` is given, one additional directory
    ``nepmodel_split{k}`` (k = 1 ... n_splits) is created per split.

    See docstring for `setup_training` for documentation of parameters.
    """
    # Full model: every structure goes into the training set; structure 0
    # doubles as a one-element test set -- presumably because the nep
    # executable expects a non-empty test.xyz; TODO confirm.
    dirname = join_path(rootdir, 'nepmodel_full')
    makedirs(dirname, exist_ok=True)
    _write_structures(structures, dirname, list(range(len(structures))), [0])
    write_nepfile(parameters, dirname)

    if n_splits is None:
        return

    n_structures = len(structures)
    # Exclude enforced structures from the pool that is split/sampled; they
    # are appended to every training set below so they never land in a test set.
    remaining_structures = list(set(range(n_structures)) - set(enforced_structures))

    if mode == 'kfold':
        # KFold shuffles with the supplied RandomState, so splits are
        # reproducible for a fixed seed.
        kf = KFold(n_splits=n_splits, shuffle=True, random_state=rs)
        for k, (train_indices, test_indices) in enumerate(kf.split(remaining_structures)):
            train_selection = [remaining_structures[x] for x in list(train_indices)]
            test_selection = [remaining_structures[x] for x in list(test_indices)]

            # append enforced structures at the end of the training set
            # (bugfix: the kfold branch previously dropped them entirely,
            # contrary to the comment above and to the bagging branch)
            train_selection.extend(enforced_structures)

            # sanity check: make sure there is no overlap between train and test
            assert set(train_selection).intersection(set(test_selection)) == set(), \
                'Train and test set should not overlap'

            subdir = f'nepmodel_split{k+1}'
            dirname = join_path(rootdir, subdir)
            makedirs(dirname, exist_ok=True)
            _write_structures(structures, dirname, train_selection, test_selection)
            write_nepfile(parameters, dirname)

    elif mode == 'bagging':
        for k in range(n_splits):
            # Draw without replacement from the non-enforced pool; size is
            # reduced so that the final training set (after appending the
            # enforced structures) contains train_fraction * n_structures.
            train_selection = rs.choice(
                remaining_structures,
                size=int(train_fraction * n_structures) - len(enforced_structures),
                replace=False)

            # append enforced structures at the end of the training set
            train_selection = list(train_selection)
            train_selection.extend(enforced_structures)

            # add the remaining structures to the test set
            test_selection = list(set(range(n_structures)) - set(train_selection))

            # sanity check: make sure there is no overlap between train and test
            assert set(train_selection).intersection(set(test_selection)) == set(), \
                'Train and test set should not overlap'

            dirname = join_path(rootdir, f'nepmodel_split{k+1}')
            makedirs(dirname, exist_ok=True)
            _write_structures(structures, dirname, train_selection, test_selection)
            write_nepfile(parameters, dirname)

    else:
        raise ValueError(f'Unknown value for mode: {mode}.')

135 

136 

def _write_structures(structures: List[Atoms],
                      dirname: str,
                      train_selection: List[int],
                      test_selection: List[int]) -> None:
    """Writes structures in format readable by nep executable.

    The structures whose indices appear in ``train_selection`` are written to
    ``<dirname>/train.xyz`` and those in ``test_selection`` to
    ``<dirname>/test.xyz``, in both cases preserving their order in
    ``structures``.

    See docstring for `setup_training` for documentation of parameters.
    """
    # Sets give O(1) membership tests instead of O(len(selection)) per
    # structure; output order (enumeration order) is unchanged.
    train_set = set(train_selection)
    test_set = set(test_selection)
    write_structures(
        join_path(dirname, 'train.xyz'),
        [s for k, s in enumerate(structures) if k in train_set])
    write_structures(
        join_path(dirname, 'test.xyz'),
        [s for k, s in enumerate(structures) if k in test_set])

151 

152 

def setup_fine_tuning_nep89(parameters: dict,
                            nep: Path,
                            restart: Path,
                            **kwargs_to_setup_training) -> None:
    """
    Sets up a fine-tuning of the NEP89 foundation model.

    Note that only the types, the number of generations, the batch,
    the population, and the regularization parameters are allowed
    to be changed.

    The types must be a subset of the 89 atomic species supported by
    the NEP89 foundation model.

    This function wraps :func:`setup_training`.

    Parameters
    ----------
    parameters
        Dictionary containing the parameters to be set in the `nep.in` file;
        see `here <https://gpumd.org/nep/input_parameters/index.html>`__
        for an overview of these parameters.
        Note that only `lambda_1`, `lambda_2`, `lambda_e`, `lambda_f`, `lambda_v`,
        `generation`, `population`, `type`, and `batch` are allowed parameters when fine-tuning.
    nep:
        Path to the `nep.txt` file for NEP89.
    restart:
        Path to the `nep.restart` file for NEP89.
    kwargs_to_setup_training:
        See the docstring for `setup_training` for the rest of the parameters.

    Raises
    ------
    ValueError
        if ``parameters`` contains a key that may not be changed when fine-tuning
    FileNotFoundError
        if ``nep`` or ``restart`` does not point to an existing file
    """
    # Only hyperparameters that do not alter the network architecture may be
    # changed when fine-tuning; a set gives O(1) membership tests.
    allowed_parameters = {'type',
                          'lambda_1',
                          'lambda_2',
                          'lambda_e',
                          'lambda_f',
                          'lambda_v',
                          'generation',
                          'population',
                          'batch'}

    for param in parameters:
        if param not in allowed_parameters:
            raise ValueError(f'Parameter {param} not allowed when fine-tuning.')

    if not Path(nep).is_file():
        raise FileNotFoundError(f'{nep} does not exist.')
    if not Path(restart).is_file():
        raise FileNotFoundError(f'{restart} does not exist.')

    # Default parameters that need to be set for NEP89.
    nep89_parameters = dict(version=4,
                            zbl=2,
                            cutoff=[6, 5],
                            n_max=[4, 4],
                            basis_size=[8, 8],
                            l_max=[4, 2, 1],
                            neuron=80)

    # Merge order matters: the fixed NEP89 architecture parameters take
    # precedence over anything supplied by the user.
    fine_tuning = (dict(fine_tune=[str(nep), str(restart)]) | parameters | nep89_parameters)
    setup_training(fine_tuning, **kwargs_to_setup_training)