
Periodic Data Saving During Training

During training, problems can interrupt the run: a GPU dropping offline, GPU or other hardware failure, network fluctuations or disconnection, excessive traffic load, a machine crash, or the system automatically killing the training process at some batch N because it ran out of memory (OOM). If training progress has not been saved when this happens, the results obtained so far may be lost and training must restart from scratch. This wastes time and computing resources and adds to the research and development workload.


Therefore, it is very important to save the training state to disk periodically. This includes not only the model parameters (weights and biases) but also other critical information, such as:

  1. Current epoch: records how far training has progressed.
  2. Optimizer state: the optimizer's hyperparameters (e.g., learning rate and momentum) and its internal state (e.g., the first and second moment estimates in Adam), which are needed to continue optimization exactly where it stopped.
  3. Loss history: helps monitor how performance changes during training.
  4. Learning rate scheduler state (if a scheduler is used): records the dynamic learning rate adjustments made so far.
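A minimal PyTorch sketch of a checkpoint that bundles the items above, assuming a toy nn.Linear model, the Adam optimizer, and a StepLR scheduler. The function name save_checkpoint, the dictionary keys, and the file name checkpoint.pt are illustrative choices, not a fixed format. Writing to a temporary file and then renaming it is an added precaution so that a crash in the middle of a save cannot replace the last good checkpoint with a truncated one.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint(path, model, optimizer, scheduler, epoch, loss_history):
    """Bundle everything needed to resume training into one file (hypothetical helper)."""
    checkpoint = {
        "epoch": epoch,                                  # last completed epoch
        "model_state_dict": model.state_dict(),          # weights and biases
        "optimizer_state_dict": optimizer.state_dict(),  # lr, momentum, Adam moments
        "loss_history": loss_history,                    # per-epoch losses so far
        "scheduler_state_dict": scheduler.state_dict() if scheduler is not None else None,
    }
    # Write to a temporary file in the same directory, then rename it over the
    # target, so an interrupted write never corrupts the previous checkpoint.
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    os.close(fd)
    torch.save(checkpoint, tmp_path)
    os.replace(tmp_path, path)


# Usage with a toy setup.
model = nn.Linear(4, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)

# One dummy training step so Adam has first/second moment estimates to save.
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()
optimizer.step()
scheduler.step()

save_checkpoint("checkpoint.pt", model, optimizer, scheduler,
                epoch=0, loss_history=[loss.item()])
```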

Saving this information allows training to resume from the last saved state after an interruption instead of restarting. Deep learning frameworks such as PyTorch and TensorFlow provide APIs that make this straightforward. The practice is especially important for long-running or large-scale training jobs, where it greatly reduces the resources and time lost to unexpected interruptions.
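The resume logic can be sketched as follows. This is a hedged example, assuming a toy linear regression task, SGD with momentum, and a hypothetical checkpoint file resume_demo.pt; the stop_after argument is a stand-in for a real interruption such as an OOM kill. On restart, the loop reloads the model and optimizer state and continues from the epoch after the last saved one.

```python
import os

import torch
import torch.nn as nn

CHECKPOINT_PATH = "resume_demo.pt"  # hypothetical file name
SAVE_EVERY = 1                      # save after every N epochs


def build():
    torch.manual_seed(0)
    model = nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    return model, optimizer


def train(total_epochs, stop_after=None):
    """Train up to total_epochs, resuming from CHECKPOINT_PATH if it exists.

    stop_after simulates an interruption right after that epoch.
    """
    model, optimizer = build()
    start_epoch, losses = 0, []
    if os.path.exists(CHECKPOINT_PATH):
        ckpt = torch.load(CHECKPOINT_PATH, map_location="cpu")
        model.load_state_dict(ckpt["model_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])  # momentum buffers too
        start_epoch = ckpt["epoch"] + 1                           # next unfinished epoch
        losses = ckpt["loss_history"]

    # Fixed toy dataset: the target is the sum of the four features.
    x = torch.randn(32, 4, generator=torch.Generator().manual_seed(1))
    y = x.sum(dim=1, keepdim=True)

    for epoch in range(start_epoch, total_epochs):
        optimizer.zero_grad()
        loss = nn.functional.mse_loss(model(x), y)
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
        if (epoch + 1) % SAVE_EVERY == 0:
            torch.save({"epoch": epoch,
                        "model_state_dict": model.state_dict(),
                        "optimizer_state_dict": optimizer.state_dict(),
                        "loss_history": losses}, CHECKPOINT_PATH)
        if stop_after is not None and epoch == stop_after:
            return losses  # simulated crash after the checkpoint was written
    return losses


# Usage: the first call "crashes" after epoch 2; the second resumes at epoch 3.
if os.path.exists(CHECKPOINT_PATH):
    os.remove(CHECKPOINT_PATH)
partial = train(total_epochs=5, stop_after=2)
resumed = train(total_epochs=5)
```

Saving after every epoch (SAVE_EVERY = 1) minimizes lost work; for large models where each save is slow, a larger interval trades some recomputation after a crash for less I/O.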

With PyTorch checkpoints (written with torch.save()) or the TensorFlow ModelCheckpoint callback, developers can manage model state throughout long training runs so that, if an interruption occurs, training resumes from the latest saved state rather than from the beginning.


PyTorch Checkpoint

The PyTorch framework provides a flexible mechanism for saving and loading models, including model parameters, optimizer states, and any other necessary information. In PyTorch, this is typically achieved using the torch.save() and torch.load() functions.

The official PyTorch documentation provides detailed guidance on saving and loading models for different scenarios, including saving only model parameters, saving the entire model, and saving multiple components (e.g., model and optimizer states).
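As a rough illustration of the first two scenarios, assuming PyTorch 1.13 or newer (for the weights_only argument of torch.load) and a hypothetical TinyNet module: saving the state_dict stores only tensors, while saving the whole model pickles the Python object and ties the file to the class definition.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


model = TinyNet()

# 1) Parameters only: save the state_dict, then load it into a freshly
#    constructed instance of the same class.
torch.save(model.state_dict(), "tinynet_weights.pt")
restored = TinyNet()
restored.load_state_dict(torch.load("tinynet_weights.pt"))

# 2) Entire model: pickles the module object, so TinyNet must be importable
#    at load time. Recent PyTorch releases default to weights_only=True,
#    which rejects arbitrary pickled objects, so this trusted file is loaded
#    with weights_only=False explicitly.
torch.save(model, "tinynet_full.pt")
full = torch.load("tinynet_full.pt", weights_only=False)
full.eval()
```

The state_dict route is generally the more portable choice, because refactoring or moving the model class does not break files saved that way.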

Documentation: Saving and Loading Models - PyTorch


TensorFlow ModelCheckpoint

TensorFlow/Keras provides the ModelCheckpoint callback to save the model at specific points during training, for example at the end of every epoch or whenever a monitored metric (such as validation loss) improves.
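A minimal sketch of per-epoch checkpointing, assuming TensorFlow 2.x with Keras (the .weights.h5 suffix is what Keras 3 requires when save_weights_only=True). The toy model, random data, and the ckpt/ directory are placeholders. Keras fills in the {epoch:02d} placeholder in filepath, so each epoch gets its own file.

```python
import os

import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

x = np.random.rand(64, 4).astype("float32")
y = x.sum(axis=1, keepdims=True)

os.makedirs("ckpt", exist_ok=True)  # placeholder output directory

# Save the weights at the end of every epoch, one file per epoch.
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath="ckpt/epoch_{epoch:02d}.weights.h5",
    save_weights_only=True,
    save_freq="epoch",
)
model.fit(x, y, epochs=3, batch_size=16, verbose=0, callbacks=[checkpoint_cb])
```

To resume from such files, one would rebuild and compile the same model, call model.load_weights() on the most recent file, and pass initial_epoch to model.fit() so epoch numbering continues.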

ModelCheckpoint can keep either the latest state of the model or only the best-performing model seen so far during training.

It allows flexible configuration of what is saved (weights only with save_weights_only=True, or the entire model) and when it is saved (at every save point, or only when the monitored metric improves with save_best_only=True).
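The options above map onto constructor arguments such as monitor, mode, save_best_only, and save_weights_only. The sketch below assumes TensorFlow 2.13 or newer for the native .keras format and uses a placeholder model and random data; it keeps a single file, best_model.keras, that is overwritten only when val_loss improves.

```python
import numpy as np
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.Input(shape=(4,)), tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

rng = np.random.default_rng(0)
x = rng.random((128, 4), dtype=np.float32)
y = x.sum(axis=1, keepdims=True)

# Keep only the best model seen so far, judged by validation loss.
best_cb = tf.keras.callbacks.ModelCheckpoint(
    filepath="best_model.keras",  # full model: architecture, weights, optimizer
    monitor="val_loss",
    mode="min",
    save_best_only=True,
    save_weights_only=False,
)
# validation_split=0.25 holds out the last 32 samples for validation.
history = model.fit(x, y, epochs=5, batch_size=32, validation_split=0.25,
                    verbose=0, callbacks=[best_cb])

best = tf.keras.models.load_model("best_model.keras")
```

Because save_best_only overwrites the same path, disk usage stays constant; the per-epoch pattern with an {epoch} placeholder is the better fit when every intermediate state must be kept.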

Documentation: Save and load models - TensorFlow