Coverage for calorine/nep/training_factory.py: 100%
53 statements
« prev ^ index » next coverage.py v7.6.4, created at 2024-12-10 08:26 +0000
« prev ^ index » next coverage.py v7.6.4, created at 2024-12-10 08:26 +0000
1from os import makedirs
2from os.path import exists, join as join_path
3from typing import List, NamedTuple, Optional
5import numpy as np
6from ase import Atoms
7from sklearn.model_selection import KFold
9from .io import write_nepfile, write_structures
def setup_training(parameters: NamedTuple,
                   structures: List[Atoms],
                   enforced_structures: Optional[List[int]] = None,
                   rootdir: str = '.',
                   mode: str = 'kfold',
                   n_splits: Optional[int] = None,
                   train_fraction: Optional[float] = None,
                   seed: int = 42,
                   overwrite: bool = False,
                   ) -> None:
    """Sets up the input files for training a NEP via the ``nep``
    executable of the GPUMD package.

    All arguments are validated before any file or directory is created.

    Parameters
    ----------
    parameters
        parameters to be set in the nep.in file;
        see `here <https://gpumd.org/nep/input_parameters/index.html>`__
        for an overview of these parameters
    structures
        list of structures to be included
    enforced_structures
        structures that _must_ be included in the training set, provided in the form
        of a list of indices that refer to the content of the ``structures`` parameter;
        ``None`` (default) is equivalent to an empty list
    rootdir
        root directory in which to create the input files
    mode
        how the test-train split is performed. Options: ``'kfold'`` and ``'bagging'``
    n_splits
        number of splits of the input structures in training and test sets that ought to be
        performed; by default no split will be done and all input structures will be used
        for training
    train_fraction
        fraction of structures to use for training when mode ``'bagging'`` is used;
        must be left unset for mode ``'kfold'`` and must be in (0,1] for mode ``'bagging'``
    seed
        random number generator seed to be used; this ensures reproducibility
    overwrite
        if True overwrite the content of ``rootdir`` if it exists

    Raises
    ------
    FileExistsError
        if ``rootdir`` exists and ``overwrite`` is False
    ValueError
        if ``n_splits`` is not in the range [1, number of structures],
        if ``train_fraction`` is set for mode ``'kfold'``, or
        if ``train_fraction`` is missing or outside (0,1] for mode ``'bagging'``
    """
    if exists(rootdir) and not overwrite:
        raise FileExistsError('Output directory exists.'
                              ' Set overwrite=True in order to override this behavior.')

    if n_splits is not None and (n_splits <= 0 or n_splits > len(structures)):
        raise ValueError(f'n_splits ({n_splits}) must be positive and'
                         f' must not exceed {len(structures)}.')

    if mode == 'kfold' and train_fraction is not None:
        raise ValueError(f'train_fraction cannot be set when mode {mode} is used')
    if mode == 'bagging':
        # Check for None explicitly: comparing None with a number would raise
        # an uninformative TypeError instead of a proper validation error.
        if train_fraction is None:
            raise ValueError(f'train_fraction must be set when mode {mode} is used')
        if train_fraction <= 0 or train_fraction > 1:
            raise ValueError(f'train_fraction ({train_fraction}) must be in (0,1]')

    # Work on a private copy so that neither a shared default nor the
    # caller's list is aliased further down.
    enforced = [] if enforced_structures is None else list(enforced_structures)

    rs = np.random.RandomState(seed)
    _prepare_training(parameters, structures, enforced,
                      rootdir, mode, n_splits, train_fraction, rs)
def _prepare_training(parameters: NamedTuple,
                      structures: List[Atoms],
                      enforced_structures: List[int],
                      rootdir: str,
                      mode: str,
                      n_splits: Optional[int],
                      train_fraction: Optional[float],
                      rs: np.random.RandomState) -> None:
    """Prepares training and test sets and writes structural data as well as parameters files.

    A directory ``nepmodel_full`` containing all structures is always written.
    If ``n_splits`` is not None, one directory ``nepmodel_split<k>``
    (``k = 1 .. n_splits``) is written per train-test split.

    See the docstring of :func:`setup_training` for documentation of parameters;
    ``rs`` is the random state that drives shuffling and sampling.

    Raises
    ------
    ValueError
        if ``mode`` is neither ``'kfold'`` nor ``'bagging'`` (only checked when
        ``n_splits`` is not None), or if in mode ``'bagging'`` the number of
        enforced structures exceeds the size of the training set
    """
    n_structures = len(structures)

    # Full model: every structure goes into train.xyz, while test.xyz only
    # receives the structure with index 0.
    dirname = join_path(rootdir, 'nepmodel_full')
    makedirs(dirname, exist_ok=True)
    _write_structures(structures, dirname, list(range(n_structures)), [0])
    write_nepfile(parameters, dirname)

    if n_splits is None:
        return

    # Only structures that are not enforced take part in the splitting.
    remaining_structures = list(set(range(n_structures)) - set(enforced_structures))

    if mode == 'kfold':
        kf = KFold(n_splits=n_splits, shuffle=True, random_state=rs)
        for k, (train_indices, test_indices) in enumerate(kf.split(remaining_structures)):
            # KFold yields positions in remaining_structures; map them back to
            # indices into structures
            train_selection = [remaining_structures[x] for x in train_indices]
            test_selection = [remaining_structures[x] for x in test_indices]

            # append enforced structures at the end of the training set
            train_selection.extend(enforced_structures)

            # sanity check: make sure there is no overlap between train and test
            assert set(train_selection).isdisjoint(test_selection), \
                'Train and test set should not overlap'

            dirname = join_path(rootdir, f'nepmodel_split{k+1}')
            makedirs(dirname, exist_ok=True)
            _write_structures(structures, dirname, train_selection, test_selection)
            write_nepfile(parameters, dirname)

    elif mode == 'bagging':
        # The enforced structures count towards the training set size, so
        # only the remainder is drawn at random.
        n_sampled = int(train_fraction * n_structures) - len(enforced_structures)
        if n_sampled < 0:
            raise ValueError(
                f'The number of enforced structures ({len(enforced_structures)}) exceeds'
                f' the size of the training set ({int(train_fraction * n_structures)})'
                f' implied by train_fraction ({train_fraction}).')

        for k in range(n_splits):
            train_selection = rs.choice(remaining_structures, size=n_sampled, replace=False)

            # append enforced structures at the end of the training set
            train_selection = list(train_selection)
            train_selection.extend(enforced_structures)

            # add the remaining structures to the test set
            test_selection = list(set(range(n_structures)) - set(train_selection))

            # sanity check: make sure there is no overlap between train and test
            assert set(train_selection).isdisjoint(test_selection), \
                'Train and test set should not overlap'

            dirname = join_path(rootdir, f'nepmodel_split{k+1}')
            makedirs(dirname, exist_ok=True)
            _write_structures(structures, dirname, train_selection, test_selection)
            write_nepfile(parameters, dirname)

    else:
        raise ValueError(f'Unknown value for mode: {mode}.')
def _write_structures(structures: List[Atoms],
                      dirname: str,
                      train_selection: List[int],
                      test_selection: List[int]):
    """Writes structures in format readable by nep executable.

    The structures selected for training and testing are written to
    ``train.xyz`` and ``test.xyz`` in ``dirname``, respectively. Within each
    file the structures appear in the order in which they occur in
    ``structures``, irrespective of the order of the indices in the selection.

    Parameters
    ----------
    structures
        list of all structures
    dirname
        directory in which to write the files
    train_selection
        indices of the structures in ``structures`` to write to ``train.xyz``
    test_selection
        indices of the structures in ``structures`` to write to ``test.xyz``
    """
    # Build the lookup sets once; testing membership against the lists for
    # every structure would scale quadratically with the number of structures.
    train_indices = set(train_selection)
    test_indices = set(test_selection)

    write_structures(
        join_path(dirname, 'train.xyz'),
        [s for k, s in enumerate(structures) if k in train_indices])
    write_structures(
        join_path(dirname, 'test.xyz'),
        [s for k, s in enumerate(structures) if k in test_indices])