train
The train module provides functionality for training Nenya models using self-supervised contrastive learning.
Functions
- nenya.train.main(opt_path, debug=False, save_file=None)
Main function for training a Nenya model.
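- Parameters:
  - opt_path (str): Path to the JSON file of training parameters.
  - debug (bool, optional): If True, run in debug mode with a reduced number of epochs. Default: False.
  - save_file (str, optional): Custom path to which the trained model is saved. Default: None.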
This function:
- Loads parameters from the JSON file
- Sets up the model and criterion
- Configures the optimizer
- Trains the model for the specified number of epochs
- Periodically validates the model
- Saves model checkpoints and learning curves
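Parameter loading is handled by nenya.params. As a rough stdlib sketch of the idea, assume the JSON file is a flat object whose keys become attributes such as opt.model_folder and opt.model_name (the only option names this page mentions). The load_opts helper and the epochs key are hypothetical; the real loader may also fill defaults or derive extra fields.

```python
import json
import tempfile
from pathlib import Path
from types import SimpleNamespace


def load_opts(opt_path):
    """Read a JSON options file and expose its keys as attributes.

    Hypothetical stand-in for nenya.params, not its actual API.
    """
    with open(opt_path) as f:
        raw = json.load(f)
    return SimpleNamespace(**raw)


if __name__ == "__main__":
    # Hypothetical values; "epochs" is an assumed key name.
    example = {"model_folder": "models/nenya_v5", "model_name": "nenya_v5", "epochs": 10}
    with tempfile.TemporaryDirectory() as tmp:
        path = Path(tmp) / "opts.json"
        path.write_text(json.dumps(example))
        opt = load_opts(path)
        print(opt.model_folder, opt.epochs)
```

Attribute access on a namespace mirrors the opt.model_folder style used in the output paths later on this page.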
Dependencies
This module depends on:
- nenya.io: For loading and saving data
- nenya.train_util: For model setup and training utilities
- nenya.params: For parameter management
- nenya.util: For optimization and model saving utilities
Training Process
The training process includes:
- Loading parameters from a JSON file
- Setting up the model and criterion
- Setting up the optimizer
- For each epoch:
  - Creating a data loader
  - Adjusting the learning rate
  - Training for one epoch
  - Recording losses
- Optionally validating the model at specified intervals
- Saving model checkpoints
- Saving learning curves
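The per-epoch steps above can be sketched as a plain loop. This is a hypothetical skeleton, not Nenya's code: the cosine-annealed learning rate, the valid_freq and save_freq names, and the callback functions are assumptions standing in for the real helpers in nenya.train_util and nenya.util. Data-loader creation is assumed to happen inside the train_one_epoch callback.

```python
import math


def cosine_lr(base_lr, epoch, epochs, lr_decay_rate=0.1):
    """Cosine-annealed learning rate (an assumed schedule).

    Decays from base_lr toward eta_min = base_lr * lr_decay_rate**3.
    """
    eta_min = base_lr * lr_decay_rate ** 3
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / epochs)) / 2


def run_schedule(epochs, base_lr, valid_freq, save_freq, train_one_epoch, validate):
    """Skeleton of the epoch loop; callbacks replace real training and validation."""
    losses_train, losses_valid, saved = [], [], []
    for epoch in range(1, epochs + 1):
        lr = cosine_lr(base_lr, epoch, epochs)            # adjust the learning rate
        losses_train.append(train_one_epoch(epoch, lr))   # train one epoch, record loss
        if epoch % valid_freq == 0:                       # optional periodic validation
            losses_valid.append(validate(epoch))
        if epoch % save_freq == 0:                        # periodic checkpoint name
            saved.append(f"ckpt_epoch_{epoch}.pth")
    saved.append("last.pth")                              # final model
    return losses_train, losses_valid, saved
```

Passing simple lambdas makes the schedule visible without any model: run_schedule(4, 0.05, 2, 2, lambda e, lr: 1.0 / e, lambda e: 2.0 / e) validates at epochs 2 and 4 and names checkpoints for those epochs plus last.pth.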
Example Usage
```python
from nenya.train import main as train_main

# Train with parameters from a JSON file
train_main("opts_nenya_modis_v5.json", debug=False)

# Train in debug mode (reduced epochs)
train_main("opts_nenya_modis_v5.json", debug=True)

# Train and save to a custom location
train_main("opts_nenya_modis_v5.json", save_file="/custom/path/model.pth")
```
Output Files
After training, the following files are created:
- Model checkpoints: {opt.model_folder}/ckpt_epoch_{epoch}.pth
- Final model: {opt.model_folder}/last.pth
- Learning curves in {opt.model_folder}/learning_curve/:
  - {opt.model_name}_losses_train.h5: Training losses
  - {opt.model_name}_losses_valid.h5: Validation losses
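The layout above can be expressed as a small path helper. output_paths is hypothetical and simply mirrors the naming on this page; it does not account for save_file, which may change where the model is written.

```python
from pathlib import Path


def output_paths(model_folder, model_name, epochs_saved):
    """Build the output file paths described above (hypothetical helper)."""
    folder = Path(model_folder)
    curves = folder / "learning_curve"
    return {
        # One checkpoint per saved epoch
        "checkpoints": [folder / f"ckpt_epoch_{e}.pth" for e in epochs_saved],
        # Final model
        "final": folder / "last.pth",
        # Learning-curve HDF5 files
        "loss_train": curves / f"{model_name}_losses_train.h5",
        "loss_valid": curves / f"{model_name}_losses_valid.h5",
    }
```

For example, output_paths("models/v5", "nenya", [2, 4])["final"] is models/v5/last.pth.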
Model File Structure
The saved model files have the following structure:
```python
{
    'opt': opt,                            # Training options
    'model': model.state_dict(),           # Model weights
    'optimizer': optimizer.state_dict(),   # Optimizer state
    'epoch': epoch,                        # Current epoch
}
```
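To inspect a checkpoint, one would typically load it with torch.load(path, map_location="cpu"). Because 'opt' holds an arbitrary options object, recent PyTorch versions (which default to weights_only=True) may require weights_only=False for files you trust. The hypothetical helper below only checks an already-loaded dictionary against the layout above, so it runs without PyTorch.

```python
# Top-level keys listed in the model file structure above
REQUIRED_KEYS = ("opt", "model", "optimizer", "epoch")


def summarize_checkpoint(ckpt):
    """Validate a loaded checkpoint dict (hypothetical helper).

    Returns (epoch, sorted parameter names from the model state dict).
    Raises KeyError listing any missing top-level keys.
    """
    missing = [k for k in REQUIRED_KEYS if k not in ckpt]
    if missing:
        raise KeyError(f"checkpoint missing keys: {missing}")
    return ckpt["epoch"], sorted(ckpt["model"].keys())
```

With PyTorch available, summarize_checkpoint(torch.load("last.pth", map_location="cpu", weights_only=False)) would report the final epoch and the saved parameter names.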
Learning Curve Files
The learning curve HDF5 files contain:
- loss_train: Array of training losses per epoch
- loss_step_train: Array of per-step losses during training
- loss_avg_train: Array of running average losses during training
- Similar arrays for validation losses
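This page does not say exactly how loss_avg_train is computed. A common pattern, assumed here rather than confirmed for Nenya, is a batch-size-weighted running mean that resets each epoch, with the epoch's final average recorded as its entry in loss_train. AverageMeter and epoch_curves are illustrative names only.

```python
class AverageMeter:
    """Batch-size-weighted running average (assumed pattern)."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # Weight each step's loss by its batch size n
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def epoch_curves(step_losses, batch_size=1):
    """Return (per-step losses, running averages, epoch loss) for one epoch."""
    meter = AverageMeter()
    avgs = []
    for loss in step_losses:
        meter.update(loss, batch_size)
        avgs.append(meter.avg)
    return list(step_losses), avgs, meter.avg
```

With a constant batch size the weighting cancels out, so epoch_curves([4.0, 2.0, 3.0]) gives running averages of 4.0, 3.0 and 3.0 and an epoch loss of 3.0.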