Coverage for calorine/nep/training_factory.py: 100%

53 statements  

« prev     ^ index     » next       coverage.py v7.6.4, created at 2024-12-10 08:26 +0000

1from os import makedirs 

2from os.path import exists, join as join_path 

3from typing import List, NamedTuple, Optional 

4 

5import numpy as np 

6from ase import Atoms 

7from sklearn.model_selection import KFold 

8 

9from .io import write_nepfile, write_structures 

10 

11 

def setup_training(parameters: NamedTuple,
                   structures: List[Atoms],
                   enforced_structures: List[int] = [],
                   rootdir: str = '.',
                   mode: str = 'kfold',
                   n_splits: int = None,
                   train_fraction: float = None,
                   seed: int = 42,
                   overwrite: bool = False,
                   ) -> None:
    """Sets up the input files for training a NEP via the ``nep``
    executable of the GPUMD package.

    All arguments are validated before anything is written to disk.

    Parameters
    ----------
    parameters
        dictionary containing the parameters to be set in the nep.in file;
        see `here <https://gpumd.org/nep/input_parameters/index.html>`__
        for an overview of these parameters
    structures
        list of structures to be included
    enforced_structures
        structures that _must_ be included in the training set, provided in the form
        of a list of indices that refer to the content of the ``structures`` parameter
    rootdir
        root directory in which to create the input files
    mode
        how the test-train split is performed. Options: ``'kfold'`` and ``'bagging'``
    n_splits
        number of splits of the input structures in training and test sets that ought to be
        performed; by default no split will be done and all input structures will be used
        for training
    train_fraction
        fraction of structures to use for training when mode ``'bagging'`` is used;
        must be in the interval (0, 1]
    seed
        random number generator seed to be used; this ensures reproducibility
    overwrite
        if True overwrite the content of ``rootdir`` if it exists

    Raises
    ------
    FileExistsError
        if ``rootdir`` exists and ``overwrite`` is False
    ValueError
        if ``n_splits`` is not in the range from 1 to the number of structures,
        if ``train_fraction`` is set in mode ``'kfold'``,
        if ``train_fraction`` is missing (while splits are requested) or
        outside (0, 1] in mode ``'bagging'``,
        or if splits are requested with an unknown ``mode``
    """
    if exists(rootdir) and not overwrite:
        raise FileExistsError('Output directory exists.'
                              ' Set overwrite=True in order to override this behavior.')

    if n_splits is not None and (n_splits <= 0 or n_splits > len(structures)):
        raise ValueError(f'n_splits ({n_splits}) must be positive and'
                         f' must not exceed {len(structures)}.')

    if mode == 'kfold' and train_fraction is not None:
        raise ValueError(f'train_fraction cannot be set when mode {mode} is used')
    elif mode == 'bagging':
        if train_fraction is None:
            # The fraction is only consumed when splits are generated, so it
            # may be omitted if no split is requested. Comparing None with a
            # number would otherwise fail with an opaque TypeError.
            if n_splits is not None:
                raise ValueError(f'train_fraction must be set when mode {mode} is used')
        elif train_fraction <= 0 or train_fraction > 1:
            raise ValueError(f'train_fraction ({train_fraction}) must be in (0,1]')
    elif n_splits is not None and mode != 'kfold':
        # Reject unknown modes up front, i.e., before any file is written.
        raise ValueError(f'Unknown value for mode: {mode}.')

    rs = np.random.RandomState(seed)
    # Hand over a copy such that the (shared) default list can never be mutated.
    _prepare_training(parameters, structures, list(enforced_structures),
                      rootdir, mode, n_splits, train_fraction, rs)

67 

68 

def _write_split(parameters: NamedTuple,
                 structures: List[Atoms],
                 rootdir: str,
                 k: int,
                 train_selection: List[int],
                 test_selection: List[int]) -> None:
    """Writes the input files for split number ``k`` (zero-based) to the
    directory ``nepmodel_split<k+1>`` inside ``rootdir``."""
    # sanity check: make sure there is no overlap between train and test
    assert set(train_selection).intersection(set(test_selection)) == set(), \
        'Train and test set should not overlap'

    dirname = join_path(rootdir, f'nepmodel_split{k+1}')
    makedirs(dirname, exist_ok=True)
    _write_structures(structures, dirname, train_selection, test_selection)
    write_nepfile(parameters, dirname)


def _prepare_training(parameters: NamedTuple,
                      structures: List[Atoms],
                      enforced_structures: List[int],
                      rootdir: str,
                      mode: str,
                      n_splits: Optional[int],
                      train_fraction: Optional[float],
                      rs: np.random.RandomState) -> None:
    """Prepares training and test sets and writes structural data as well as parameters files.

    A directory ``nepmodel_full`` that uses all structures for training is
    always written. If ``n_splits`` is not None, one directory
    ``nepmodel_split<k>`` (with k starting at 1) is written per split.

    Parameters
    ----------
    parameters
        parameters to be written to the nep.in file of every model
    structures
        list of structures to be distributed over training and test sets
    enforced_structures
        indices (referring to ``structures``) of structures that are
        included in the training set of every split
    rootdir
        root directory in which the model directories are created
    mode
        how the test-train split is performed: ``'kfold'`` or ``'bagging'``
    n_splits
        number of splits; if None only the full model is written
    train_fraction
        fraction of structures used for training in mode ``'bagging'``
    rs
        random state used for shuffling (k-fold) or sampling (bagging)

    Raises
    ------
    ValueError
        if ``mode`` is unknown or if, in mode ``'bagging'``, there are more
        enforced structures than structures in the training set
    """
    n_structures = len(structures)

    # Full model: every structure is used for training. The structure with
    # index 0 is additionally written to the test file.
    dirname = join_path(rootdir, 'nepmodel_full')
    makedirs(dirname, exist_ok=True)
    _write_structures(structures, dirname, list(range(n_structures)), [0])
    write_nepfile(parameters, dirname)

    if n_splits is None:
        return

    # Only structures that are not enforced take part in the random selection.
    remaining_structures = list(set(range(n_structures)) - set(enforced_structures))

    if mode == 'kfold':
        kf = KFold(n_splits=n_splits, shuffle=True, random_state=rs)
        for k, (train_indices, test_indices) in enumerate(kf.split(remaining_structures)):
            # The indices returned by KFold refer to positions in
            # remaining_structures and must be mapped back to structure indices.
            train_selection = [remaining_structures[x] for x in train_indices]
            test_selection = [remaining_structures[x] for x in test_indices]

            # append enforced structures at the end of the training set
            train_selection.extend(enforced_structures)

            _write_split(parameters, structures, rootdir, k, train_selection, test_selection)

    elif mode == 'bagging':
        # The enforced structures count towards the size of the training set.
        n_random = int(train_fraction * n_structures) - len(enforced_structures)
        if n_random < 0:
            raise ValueError(
                f'The number of enforced structures ({len(enforced_structures)}) exceeds'
                f' the size of the training set ({int(train_fraction * n_structures)}).')

        for k in range(n_splits):
            train_selection = list(rs.choice(remaining_structures, size=n_random, replace=False))

            # append enforced structures at the end of the training set
            train_selection.extend(enforced_structures)

            # add the remaining structures to the test set
            test_selection = list(set(range(n_structures)) - set(train_selection))

            _write_split(parameters, structures, rootdir, k, train_selection, test_selection)

    else:
        raise ValueError(f'Unknown value for mode: {mode}.')

134 

135 

def _write_structures(structures: List[Atoms],
                      dirname: str,
                      train_selection: List[int],
                      test_selection: List[int]) -> None:
    """Writes structures in format readable by nep executable.

    The files ``train.xyz`` and ``test.xyz`` are written to ``dirname``.
    Within each file the structures appear in the order in which they
    occur in ``structures``, irrespective of the order of the indices in
    the selections.

    Parameters
    ----------
    structures
        list of all structures
    dirname
        directory in which the two files are written
    train_selection
        indices of the structures to be written to ``train.xyz``
    test_selection
        indices of the structures to be written to ``test.xyz``
    """
    # Build the lookup sets once; membership tests against the lists would
    # cost O(len(selection)) for every single structure.
    train_indices = set(train_selection)
    test_indices = set(test_selection)

    write_structures(
        join_path(dirname, 'train.xyz'),
        [s for k, s in enumerate(structures) if k in train_indices])
    write_structures(
        join_path(dirname, 'test.xyz'),
        [s for k, s in enumerate(structures) if k in test_indices])