Training procedure
Training follows a supervised learning approach: the model is trained to predict atomic energies and, optionally, forces from reference data. Force predictions are obtained by computing the negative gradient of the predicted energy with respect to the atomic positions using automatic differentiation.
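As a minimal illustration of this (not the repository's actual API), the following PyTorch sketch uses a toy differentiable energy function in place of the trained model and derives forces with `torch.autograd.grad`:

```python
import torch

# Toy differentiable energy function standing in for the trained model;
# any map from atomic positions to a scalar total energy works for this sketch.
def toy_energy(positions: torch.Tensor) -> torch.Tensor:
    return (positions ** 2).sum()

positions = torch.randn(5, 3, requires_grad=True)  # (N_atoms, 3)
energy = toy_energy(positions)                      # scalar predicted energy

# Forces are the negative gradient of the energy w.r.t. atomic positions.
# create_graph=True keeps the graph so a force loss can itself be backpropagated.
forces = -torch.autograd.grad(energy, positions, create_graph=True)[0]
print(forces.shape)  # torch.Size([5, 3])
```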
Loss Functions
The training loss is a weighted sum of energy and force losses:

\[
L = \lambda_E \, L_E + \lambda_F \, L_F
\]

where \(L_E\) is the energy loss, \(L_F\) is the force loss, and \(\lambda_E\), \(\lambda_F\) are weighting factors that balance the contribution of each loss component.
Each datapoint consists of atomic positions, species, a reference energy, and reference forces. The energy loss is computed as the MSE between predicted and true energies; the force loss is computed as the MAE or MSE (selectable in config.yaml) between predicted and true forces.
While one datapoint corresponds to a single energy value (per molecule), forces are predicted per atom and per spatial dimension (x, y, z). The force loss is therefore averaged over all atoms and dimensions in the batch:

\[
L_F = \frac{1}{3 N_{atoms}} \sum_{i=1}^{N_{atoms}} \sum_{\alpha \in \{x,y,z\}} \mathcal{L}\left(F_{i\alpha}^{\mathrm{pred}}, F_{i\alpha}^{\mathrm{ref}}\right)
\]

where \(N_{atoms}\) is the total number of atoms in the batch and \(\mathcal{L}\) is the chosen loss function (MAE or MSE).
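A minimal sketch of this combined loss, assuming batched predicted/reference energies and per-atom forces as plain tensors; the `lambda_e`, `lambda_f`, and `force_loss` arguments stand in for the corresponding config.yaml settings:

```python
import torch
import torch.nn.functional as F

def total_loss(e_pred, e_ref, f_pred, f_ref,
               lambda_e=1.0, lambda_f=1.0, force_loss="MSE"):
    """Weighted sum of energy and force losses (sketch, not the repo's exact code)."""
    # Energy loss: MSE over the structures in the batch.
    loss_e = F.mse_loss(e_pred, e_ref)
    # Force loss: MAE or MSE, averaged over all atoms and x/y/z components.
    if force_loss == "MAE":
        loss_f = F.l1_loss(f_pred, f_ref)
    else:
        loss_f = F.mse_loss(f_pred, f_ref)
    return lambda_e * loss_e + lambda_f * loss_f
```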
Training Loop
The training loop iterates over the dataset in mini-batches for a specified number of epochs. In each epoch, the following steps are performed (a minimal sketch follows the list):
1. Forward pass: Compute predicted energies and forces for the input batch.
2. Loss computation: Calculate the total loss using the defined loss functions.
3. Backward pass: Compute gradients of the loss with respect to model parameters.
4. Parameter update: Update model parameters using an optimizer (e.g., Adam) based on the computed gradients.
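The following sketch puts these four steps together using stand-in data and a toy model; the actual training script differs in its data handling and model API:

```python
import torch

# Hypothetical stand-ins; in the real code the model, dataset, and loss settings
# come from the repository's own modules and config.yaml.
model = torch.nn.Sequential(torch.nn.Linear(3, 32), torch.nn.SiLU(), torch.nn.Linear(32, 1))
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Toy "dataset": random 5-atom structures with dummy reference labels.
batches = [(torch.randn(5, 3), torch.randn(()), torch.randn(5, 3)) for _ in range(4)]

for epoch in range(3):
    for positions, e_ref, f_ref in batches:
        positions = positions.clone().requires_grad_(True)
        # 1. Forward pass: predict the energy and derive forces via autograd.
        e_pred = model(positions).sum()
        f_pred = -torch.autograd.grad(e_pred, positions, create_graph=True)[0]
        # 2. Loss computation: weighted sum of energy and force losses (see above).
        loss = (e_pred - e_ref) ** 2 + torch.nn.functional.mse_loss(f_pred, f_ref)
        # 3. Backward pass: gradients of the loss w.r.t. model parameters.
        optimizer.zero_grad()
        loss.backward()
        # 4. Parameter update with the Adam optimizer.
        optimizer.step()
```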
Early Stopping, Validation and Regularization
To prevent overfitting, early stopping is employed based on the validation loss. If the validation loss does not improve for a specified number of consecutive epochs (patience), training is halted and the best model (with the lowest validation loss) is restored. Additionally, the Peanut and Peanut dual-cutoff models can apply dropout in the energy-prediction MLP to reduce overfitting (if enabled in the config file).
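A minimal sketch of patience-based early stopping with best-model restoration; the validation helper, patience value, and variable names are illustrative assumptions, not the repository's implementation:

```python
import copy
import random
import torch

# Illustrative setup; the real training and validation routines live in the repo.
model = torch.nn.Linear(3, 1)
def validation_loss(model):            # stand-in for evaluating on the validation set
    return random.random()

patience = 10
best_val_loss = float("inf")
best_state = None
epochs_without_improvement = 0

for epoch in range(200):
    # ... run one training epoch here (omitted in this sketch) ...
    val_loss = validation_loss(model)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_state = copy.deepcopy(model.state_dict())    # remember the best weights
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:         # patience exhausted
            break

model.load_state_dict(best_state)                           # restore the best model
```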
Hyperparameter Optimization
Hyperparameters are optimized by sweeping over a fixed, pre-selected set of values. Slurm job scripts are used to run multiple training experiments in parallel on a computing cluster. Currently, the following hyperparameters are varied:
- model_type_list=("peanut" "peanut_dualcutoff")
- symmetry_functions_b_list=(16 32)
- update_mlp_out_list=(32 64)
- linear_depth_list=(64 128)
- loss_function_forces_list=("MSE" "MAE")
- dropout_energy_pass_list=("True" "False")
Additionally, the Peanut mini model is trained with a smaller set of hyperparameters for quicker evaluation:
- model_type_list=("peanut_mini")
- symmetry_functions_b_list=(16 32)
- linear_depth_list=(64 128)
- loss_function_forces_list=("MSE" "MAE")
All other hyperparameters are kept constant during these sweeps.
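As a rough illustration (the actual sweep is driven by Slurm job scripts), the combinations of the main sweep could be enumerated like this:

```python
from itertools import product

# Sweep values as listed above; in practice each combination is submitted as
# its own Slurm job on the cluster.
sweep = {
    "model_type": ["peanut", "peanut_dualcutoff"],
    "symmetry_functions_b": [16, 32],
    "update_mlp_out": [32, 64],
    "linear_depth": [64, 128],
    "loss_function_forces": ["MSE", "MAE"],
    "dropout_energy_pass": [True, False],
}

# Enumerate all 2^6 = 64 combinations of the main sweep.
for values in product(*sweep.values()):
    config = dict(zip(sweep.keys(), values))
    print(config)
```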
Logging and Monitoring
Training progress, including training and validation losses, is logged to CSV files and plots. Key metrics are monitored to assess model performance and convergence during training. (To be fixed: there is currently an inconsistency in the logging, where the total validation loss is saved alongside the average training loss, which can cause confusion when comparing the two.)
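A minimal sketch of per-epoch CSV logging with illustrative column names and file name; logging both losses on the same footing (e.g. per-batch averages) would resolve the mismatch mentioned above:

```python
import csv

# Dummy per-epoch history: (average train loss, average validation loss).
history = [(0.90, 1.10), (0.70, 0.95), (0.60, 0.88)]

# Illustrative CSV logging; column names and file name are assumptions.
with open("training_log.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["epoch", "train_loss", "val_loss"])
    for epoch, (train_loss, val_loss) in enumerate(history):
        writer.writerow([epoch, train_loss, val_loss])
```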
Model Checkpointing
Model checkpoints are saved at regular intervals or when the validation loss improves. This allows for resuming training from the last checkpoint in case of interruptions and for retaining the best-performing model.
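A sketch of saving and restoring checkpoints, assuming a PyTorch model and optimizer; the file name and saved fields are illustrative:

```python
import torch

model = torch.nn.Linear(3, 1)                     # stand-in model
optimizer = torch.optim.Adam(model.parameters())

def save_checkpoint(epoch, val_loss, path="checkpoint_best.pt"):
    # Saving the optimizer state as well allows training to resume after interruptions.
    torch.save({
        "epoch": epoch,
        "val_loss": val_loss,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
    }, path)

def load_checkpoint(path="checkpoint_best.pt"):
    ckpt = torch.load(path)
    model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    return ckpt["epoch"], ckpt["val_loss"]

save_checkpoint(epoch=0, val_loss=1.23)
print(load_checkpoint())
```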
Visualization
Post-training, predicted vs. true energy and force values are visualized using scatter plots to assess model accuracy. Loss curves for both training and validation over epochs are also plotted to analyze training dynamics.
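A minimal matplotlib sketch of the two plot types (parity scatter for energies and loss curves); the arrays are placeholders for actual predictions and logged losses:

```python
import numpy as np
import matplotlib.pyplot as plt

# Placeholder data standing in for actual predictions and logged losses.
e_true = np.random.randn(100)
e_pred = e_true + 0.1 * np.random.randn(100)
train_loss = np.exp(-np.linspace(0, 3, 50))
val_loss = train_loss + 0.05

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))

# Parity plot: predicted vs. true energies; points on the diagonal are perfect.
ax1.scatter(e_true, e_pred, s=10)
ax1.plot([e_true.min(), e_true.max()], [e_true.min(), e_true.max()], "k--")
ax1.set_xlabel("true energy")
ax1.set_ylabel("predicted energy")

# Loss curves over epochs for training and validation.
ax2.plot(train_loss, label="train")
ax2.plot(val_loss, label="validation")
ax2.set_xlabel("epoch")
ax2.set_ylabel("loss")
ax2.legend()

plt.tight_layout()
plt.savefig("training_summary.png")
```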