Coverage for calorine / nep / training_factory.py: 100%
66 statements
« prev ^ index » next coverage.py v7.13.2, created at 2026-03-05 13:55 +0000
« prev ^ index » next coverage.py v7.13.2, created at 2026-03-05 13:55 +0000
1from os import makedirs
2from pathlib import Path
3from os.path import exists, join as join_path
4from typing import List, NamedTuple, Optional
6import numpy as np
7from ase import Atoms
8from sklearn.model_selection import KFold
10from .io import write_nepfile, write_structures
def setup_training(parameters: NamedTuple,
                   structures: List[Atoms],
                   enforced_structures: Optional[List[int]] = None,
                   rootdir: str = '.',
                   mode: str = 'kfold',
                   n_splits: Optional[int] = None,
                   train_fraction: Optional[float] = None,
                   seed: int = 42,
                   overwrite: bool = False,
                   ) -> None:
    """Sets up the input files for training a NEP via the ``nep``
    executable of the GPUMD package.

    Parameters
    ----------
    parameters
        dictionary containing the parameters to be set in the nep.in file;
        see `here <https://gpumd.org/nep/input_parameters/index.html>`__
        for an overview of these parameters
    structures
        list of structures to be included
    enforced_structures
        structures that _must_ be included in the training set, provided in the form
        of a list of indices that refer to the content of the ``structures`` parameter
    rootdir
        root directory in which to create the input files
    mode
        how the test-train split is performed. Options: ``'kfold'`` and ``'bagging'``
    n_splits
        number of splits of the input structures in training and test sets that ought to be
        performed; by default no split will be done and all input structures will be used
        for training
    train_fraction
        fraction of structures to use for training when mode ``'bagging'`` is used
    seed
        random number generator seed to be used; this ensures reproducibility
    overwrite
        if True overwrite the content of ``rootdir`` if it exists

    Raises
    ------
    FileExistsError
        if ``rootdir`` exists and ``overwrite`` is False
    ValueError
        if ``n_splits``, ``train_fraction``, or ``mode`` are inconsistent
    """
    # Use None as default instead of a mutable list to avoid the shared
    # mutable-default-argument pitfall; None means "no enforced structures".
    if enforced_structures is None:
        enforced_structures = []

    if exists(rootdir) and not overwrite:
        raise FileExistsError('Output directory exists.'
                              ' Set overwrite=True in order to override this behavior.')

    if n_splits is not None and (n_splits <= 0 or n_splits > len(structures)):
        raise ValueError(f'n_splits ({n_splits}) must be positive and'
                         f' must not exceed {len(structures)}.')

    if mode == 'kfold' and train_fraction is not None:
        raise ValueError(f'train_fraction cannot be set when mode {mode} is used')
    elif mode == 'bagging' and (train_fraction is None
                                or train_fraction <= 0 or train_fraction > 1):
        # The explicit None check avoids an opaque TypeError from
        # comparing None against a number when train_fraction is unset.
        raise ValueError(f'train_fraction ({train_fraction}) must be in (0,1]')

    # Seeded RandomState ensures the splits are reproducible.
    rs = np.random.RandomState(seed)
    _prepare_training(parameters, structures, enforced_structures,
                      rootdir, mode, n_splits, train_fraction, rs)
def _prepare_training(parameters: NamedTuple,
                      structures: List[Atoms],
                      enforced_structures: List[int],
                      rootdir: str,
                      mode: str,
                      n_splits: Optional[int],
                      train_fraction: Optional[float],
                      rs: np.random.RandomState) -> None:
    """Prepares training and test sets and writes structural data as well as parameters files.

    See docstring for `setup_training` for documentation of parameters.
    """
    # Always write one model directory that trains on the full data set.
    dirname = join_path(rootdir, 'nepmodel_full')
    makedirs(dirname, exist_ok=True)
    _write_structures(structures, dirname, list(set(range(len(structures)))), [0])
    write_nepfile(parameters, dirname)

    if n_splits is None:
        return

    n_structures = len(structures)
    # Enforced structures are excluded from the splitting below and
    # appended to every training set afterwards.
    remaining_structures = list(set(range(n_structures)) - set(enforced_structures))

    if mode == 'kfold':
        kf = KFold(n_splits=n_splits, shuffle=True, random_state=rs)
        for k, (train_indices, test_indices) in enumerate(kf.split(remaining_structures)):
            # map split positions back to indices into `structures`
            train_selection = [remaining_structures[x] for x in list(train_indices)]
            test_selection = [remaining_structures[x] for x in list(test_indices)]

            # append enforced structures at the end of the training set
            # (bug fix: previously they were silently dropped in kfold mode,
            # violating the contract documented in setup_training)
            train_selection.extend(enforced_structures)

            # sanity check: make sure there is no overlap between train and test
            assert set(train_selection).intersection(set(test_selection)) == set(), \
                'Train and test set should not overlap'

            subdir = f'nepmodel_split{k+1}'
            dirname = join_path(rootdir, subdir)
            makedirs(dirname, exist_ok=True)
            _write_structures(structures, dirname, train_selection, test_selection)
            write_nepfile(parameters, dirname)

    elif mode == 'bagging':
        for k in range(n_splits):
            # draw without replacement so that the final training set
            # (draw + enforced structures) has size train_fraction * n_structures
            train_selection = rs.choice(
                remaining_structures,
                size=int(train_fraction * n_structures) - len(enforced_structures),
                replace=False)

            # append enforced structures at the end of the training set
            train_selection = list(train_selection)
            train_selection.extend(enforced_structures)

            # add the remaining structures to the test set
            test_selection = list(set(range(n_structures)) - set(train_selection))

            # sanity check: make sure there is no overlap between train and test
            assert set(train_selection).intersection(set(test_selection)) == set(), \
                'Train and test set should not overlap'

            dirname = join_path(rootdir, f'nepmodel_split{k+1}')
            makedirs(dirname, exist_ok=True)
            _write_structures(structures, dirname, train_selection, test_selection)
            write_nepfile(parameters, dirname)

    else:
        raise ValueError(f'Unknown value for mode: {mode}.')
def _write_structures(structures: List[Atoms],
                      dirname: str,
                      train_selection: List[int],
                      test_selection: List[int]):
    """Writes structures in format readable by nep executable.

    See docstring for `setup_training` for documentation of parameters.
    """
    # Write the training and test sets side by side; membership tests use
    # sets of indices for efficiency (order of structures is preserved by
    # iterating over `structures` itself).
    for filename, selection in (('train.xyz', set(train_selection)),
                                ('test.xyz', set(test_selection))):
        selected = [atoms for index, atoms in enumerate(structures)
                    if index in selection]
        write_structures(join_path(dirname, filename), selected)
def setup_fine_tuning_nep89(parameters: NamedTuple,
                            nep: Path,
                            restart: Path,
                            **kwargs_to_setup_training) -> None:
    """
    Sets up a fine-tuning of the NEP89 foundation model.

    Note that only the types, the number of generations, the batch,
    the population, and the regularization parameters are allowed
    to be changed.

    The types must be a subset of the 89 atomic species supported by
    the NEP89 foundation model.

    This function wraps :func:`setup_training`.

    Parameters
    ----------
    parameters
        Dictionary containing the parameters to be set in the `nep.in` file;
        see `here <https://gpumd.org/nep/input_parameters/index.html>`__
        for an overview of these parameters.
        Note that only `lambda_1`, `lambda_2`, `lambda_e`, `lambda_f`, `lambda_v`,
        `generation`, `population`, `type`, and `batch` are allowed parameters when fine-tuning.
    nep:
        Path to the `nep.txt` file for NEP89.
    restart:
        Path to the `nep.restart` file for NEP89.
    kwargs_to_setup_training:
        See the docstring for `setup_training` for the rest of the parameters.

    Raises
    ------
    ValueError
        If ``parameters`` contains a key that may not be changed when fine-tuning.
    FileNotFoundError
        If ``nep`` or ``restart`` does not point to an existing file.
    """
    # Only hyperparameters that do not alter the NEP89 network
    # architecture may be changed when fine-tuning.
    allowed_parameters = frozenset(('type',
                                    'lambda_1',
                                    'lambda_2',
                                    'lambda_e',
                                    'lambda_f',
                                    'lambda_v',
                                    'generation',
                                    'population',
                                    'batch'))

    for param in parameters.keys():
        if param not in allowed_parameters:
            raise ValueError(f'Parameter {param} not allowed when fine-tuning.')

    if not Path(nep).is_file():
        raise FileNotFoundError(f'{nep} does not exist.')
    if not Path(restart).is_file():
        raise FileNotFoundError(f'{restart} does not exist.')

    # Architecture parameters fixed by NEP89.
    nep89_parameters = dict(version=4,
                            zbl=2,
                            cutoff=[6, 5],
                            n_max=[4, 4],
                            basis_size=[8, 8],
                            l_max=[4, 2, 1],
                            neuron=80)

    # Merge order matters: the rightmost mapping wins, so user-supplied
    # parameters can never override the fixed NEP89 architecture settings.
    fine_tuning = (dict(fine_tune=[str(nep), str(restart)]) | parameters | nep89_parameters)
    setup_training(fine_tuning, **kwargs_to_setup_training)