WeightAveraging

class lightning.pytorch.callbacks.WeightAveraging(device=None, use_buffers=True, **kwargs)

Bases: Callback

A callback that updates an averaged model for Stochastic Weight Averaging (SWA) or Exponential Moving Average (EMA) after each training step.

Arguments given to the constructor will be passed to the AveragedModel constructor. If no device is specified, the device of the original model will be used. Unlike AveragedModel, this callback sets use_buffers to True by default. That is, by default the callback will compute running averages of both the parameters and the buffers of the model. Setting use_buffers to False will cause only the model parameters to be averaged, leaving it to the user to update the batch normalization statistics (using torch.optim.swa_utils.update_bn()).

You can provide a custom averaging function with the avg_fn or multi_avg_fn parameter. See the AveragedModel class for details. If no averaging function is provided, the default is to compute the equally weighted average of the weights (SWA).
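As a sketch of what an averaging function looks like, the two functions below follow the avg_fn signature that AveragedModel documents: they receive the current averaged value, the current model value, and the number of models averaged so far, and return the new averaged value. The names swa_avg_fn and make_ema_avg_fn are hypothetical (torch.optim.swa_utils provides its own get_ema_avg_fn), and plain floats stand in for tensors so the arithmetic is easy to check. The loop assumes, as AveragedModel does, that the first update simply copies the parameter.

```python
def swa_avg_fn(averaged_param, current_param, num_averaged):
    """Equal-weight running mean (hypothetical name).

    After num_averaged + 1 values, the result is their arithmetic mean.
    """
    return averaged_param + (current_param - averaged_param) / (num_averaged + 1)


def make_ema_avg_fn(decay=0.999):
    """Build an exponential-moving-average function with a fixed decay (hypothetical name)."""

    def ema_avg_fn(averaged_param, current_param, num_averaged):
        # num_averaged is unused: EMA weights recent values geometrically.
        return decay * averaged_param + (1.0 - decay) * current_param

    return ema_avg_fn


snapshots = [1.0, 2.0, 3.0, 6.0]  # successive values of one scalar parameter

ema_avg_fn = make_ema_avg_fn(decay=0.5)
swa = ema = snapshots[0]  # the first update just copies the parameter
for num_averaged, value in enumerate(snapshots[1:], start=1):
    swa = swa_avg_fn(swa, value, num_averaged)
    ema = ema_avg_fn(ema, value, num_averaged)

print(swa)  # 3.0, the plain mean of the four snapshots
print(ema)  # 4.125, recent values weigh more
```

Since the arithmetic uses only +, -, * and /, the same functions should also work element-wise on tensors, so either could be passed as avg_fn to the WeightAveraging constructor, which forwards it to AveragedModel.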

You can customize when the average model is updated by overriding the should_update() method. The callback calls it with either step_idx or epoch_idx and the method returns a boolean indicating whether to update after the given step or epoch. The default is to update after every step.
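For example, a schedule that skips a warm-up phase and then updates only on every n-th optimizer step could be written as below. This is a sketch, not part of Lightning: WarmupEveryNSteps, warmup_steps and every_n_steps are hypothetical names. It is written as a mixin so that it runs without Lightning installed; the assumed usage is class MyAveraging(WarmupEveryNSteps, WeightAveraging), where listing the mixin first makes Python's method resolution order pick its should_update().

```python
from typing import Optional


class WarmupEveryNSteps:
    """Hypothetical mixin providing a custom should_update() policy.

    Assumed usage: class MyAveraging(WarmupEveryNSteps, WeightAveraging): ...
    """

    warmup_steps: int = 100  # no updates for step indices below this
    every_n_steps: int = 10  # afterwards, update on every 10th optimizer step

    def should_update(self, step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool:
        if step_idx is None:
            # Called at the end of an epoch: this policy only updates per step.
            return False
        if step_idx < self.warmup_steps:
            return False
        return (step_idx - self.warmup_steps) % self.every_n_steps == 0


policy = WarmupEveryNSteps()
updated_steps = [i for i in range(130) if policy.should_update(step_idx=i)]
print(updated_steps)  # [100, 110, 120]
```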

During validation and after the training finishes, the current model parameters will be replaced with the averaged values.

See also the documentation on the weight averaging callbacks provided by Lightning.

Note

To ensure that the AveragedModel will contain all layers, setup() will call configure_model() before instantiating the AveragedModel. However, because that hook is not called in a strategy-aware context, sharded models do not work with weight averaging, and a warning will be issued.

Example:

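from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import WeightAveraging
from torch.optim.swa_utils import get_ema_avg_fn

class EMAWeightAveraging(WeightAveraging):
    def __init__(self):
        # Use an exponential moving average instead of the default SWA average.
        super().__init__(avg_fn=get_ema_avg_fn())

    def should_update(self, step_idx=None, epoch_idx=None):
        # Update after every optimizer step from step index 100 onward,
        # never at the end of an epoch.
        return (step_idx is not None) and (step_idx >= 100)

# `model` (a LightningModule) and `dataloader` are assumed to be defined elsewhere.
trainer = Trainer(callbacks=[EMAWeightAveraging()], max_epochs=10)
trainer.fit(model, dataloader)

Parameters: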
  • device (Union[device, str, int, None]) – By default, the AveragedModel will be stored on the same device as the original model. If the device argument is provided, the AveragedModel will be stored on this device instead. If you run out of GPU memory, you might want to use "cpu".

  • use_buffers (bool) – If False, the buffers of the model will not be averaged.

  • kwargs (Any) – Additional keyword arguments to be passed to the AveragedModel constructor, such as avg_fn or multi_avg_fn.

load_state_dict(state_dict)

Called when loading a checkpoint.

Reloads the callback state given a state_dict.

Parameters:

state_dict (dict[str, Any]) – A dictionary containing the callback state.

Return type:

None

on_load_checkpoint(trainer, pl_module, checkpoint)

Called when loading a model checkpoint.

Loads the current model and the AveragedModel parameters from the checkpoint.

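Parameters:
  • trainer (Trainer) – The current Trainer instance.

  • pl_module (LightningModule) – The current LightningModule instance.

  • checkpoint (dict[str, Any]) – The checkpoint dictionary that is being loaded.

Return type: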

None

on_save_checkpoint(trainer, pl_module, checkpoint)

Called when saving a checkpoint.

Moves the current model state to the checkpoint key current_model_state and places the averaged model state under state_dict instead. Any other state variables of the AveragedModel are saved under averaging_state.

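Parameters:
  • trainer (Trainer) – The current Trainer instance.

  • pl_module (LightningModule) – The current LightningModule instance.

  • checkpoint (dict[str, Any]) – The checkpoint dictionary that will be saved.

Return type: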

None
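Based on this layout, a script that reads a saved checkpoint can choose between the averaged and the raw weights. The sketch below assumes that current_model_state, state_dict and averaging_state are top-level keys of the checkpoint dictionary (for example, the dictionary returned by torch.load() on a .ckpt file); pick_model_state is a hypothetical helper, and the fake checkpoint uses made-up values. Falling back to state_dict when current_model_state is absent is a choice of this sketch, not documented behavior.

```python
from typing import Any, Dict


def pick_model_state(checkpoint: Dict[str, Any], averaged: bool = True) -> Dict[str, Any]:
    """Return the averaged or the current (non-averaged) model weights.

    Hypothetical helper; assumes the key layout described for on_save_checkpoint().
    """
    if averaged:
        # The averaged weights replace the usual "state_dict" entry.
        return checkpoint["state_dict"]
    # Sketch-specific fallback if the current weights were not stored separately.
    return checkpoint.get("current_model_state", checkpoint["state_dict"])


# Made-up checkpoint contents, for illustration only.
fake_checkpoint = {
    "state_dict": {"layer.weight": 0.5},           # averaged weights
    "current_model_state": {"layer.weight": 0.9},  # weights of the live model
    "averaging_state": {"n_averaged": 7},          # other AveragedModel state
}

print(pick_model_state(fake_checkpoint))                  # {'layer.weight': 0.5}
print(pick_model_state(fake_checkpoint, averaged=False))  # {'layer.weight': 0.9}
```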

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)

Called when a training batch ends.

Updates the AveragedModel parameters, if requested by self.should_update().

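Parameters:
  • trainer (Trainer) – The current Trainer instance.

  • pl_module (LightningModule) – The current LightningModule instance.

  • outputs (Union[Tensor, Mapping[str, Any], None]) – The outputs of training_step() for this batch.

  • batch (Any) – The current training batch.

  • batch_idx (int) – The index of the current batch.

Return type: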

None
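should_update() is documented as being called after every optimizer step, so with gradient accumulation step_idx should count optimizer steps rather than batches. The simulation below illustrates this under stated assumptions: simulate_step_updates is hypothetical, global_step counts the optimizer steps taken so far, step_idx is its zero-based counterpart, and the sketch allows at most one update per optimizer step by remembering the last step it updated at. It applies the default policy (update whenever step_idx is given) and is not Lightning's actual implementation.

```python
from typing import List


def simulate_step_updates(num_batches: int, accumulate_grad_batches: int) -> List[int]:
    """Return the zero-based step indices at which the average would be updated.

    Simplified, hypothetical stand-in for the batch-end hook.
    """
    global_step = 0          # optimizer steps taken so far
    latest_update_step = 0   # global_step at the last averaging update
    updates = []
    for batch_idx in range(num_batches):
        # With gradient accumulation, the optimizer steps only every k batches.
        if (batch_idx + 1) % accumulate_grad_batches == 0:
            global_step += 1
        step_idx = global_step - 1  # zero-based index of the last optimizer step
        # Allow at most one update per optimizer step (default policy: always update).
        if global_step > latest_update_step:
            updates.append(step_idx)
            latest_update_step = global_step
    return updates


print(simulate_step_updates(num_batches=6, accumulate_grad_batches=2))  # [0, 1, 2]
```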

on_train_end(trainer, pl_module)

Called when training ends.

Transfers parameters from the AveragedModel to the current model.

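Parameters:
  • trainer (Trainer) – The current Trainer instance.

  • pl_module (LightningModule) – The current LightningModule instance.

Return type: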

None

on_train_epoch_end(trainer, pl_module)

Called when a training epoch ends.

Updates the AveragedModel parameters, if requested by self.should_update().

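Parameters:
  • trainer (Trainer) – The current Trainer instance.

  • pl_module (LightningModule) – The current LightningModule instance.

Return type: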

None

on_validation_epoch_end(trainer, pl_module)

Called when a validation epoch ends.

Recovers the current model parameters from the AveragedModel.

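Parameters:
  • trainer (Trainer) – The current Trainer instance.

  • pl_module (LightningModule) – The current LightningModule instance.

Return type: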

None

on_validation_epoch_start(trainer, pl_module)

Called when a validation epoch begins.

Transfers parameter values from the AveragedModel to the current model.

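Parameters:
  • trainer (Trainer) – The current Trainer instance.

  • pl_module (LightningModule) – The current LightningModule instance.

Return type: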

None
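The two validation hooks can be pictured as a swap: at the start of the validation epoch the averaged values move into the live model, and at the end the original values come back from the AveragedModel. The sketch below models this with plain dictionaries; swap_ is a hypothetical helper, and treating the two hooks as one swap applied twice is an assumption consistent with the hook descriptions, not a statement about Lightning's code.

```python
from typing import Dict


def swap_(current: Dict[str, float], averaged: Dict[str, float]) -> None:
    """Exchange the values of two parameter dictionaries in place (hypothetical helper)."""
    for name in current:
        current[name], averaged[name] = averaged[name], current[name]


live = {"w": 0.9, "b": 0.1}  # live model parameters (made-up values)
avg = {"w": 0.5, "b": 0.0}   # averaged parameters (made-up values)

swap_(live, avg)                 # validation epoch starts: evaluate averaged weights
during_validation = dict(live)
swap_(live, avg)                 # validation epoch ends: original weights restored

print(during_validation)  # {'w': 0.5, 'b': 0.0}
print(live)               # {'w': 0.9, 'b': 0.1}
```

Applying the same operation twice is what makes the start and end hooks symmetric: training resumes with exactly the weights it had before validation.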

setup(trainer, pl_module, stage)

Called when fit, validate, test, or predict begins.

Creates an AveragedModel when fit begins.

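Parameters:
  • trainer (Trainer) – The current Trainer instance.

  • pl_module (LightningModule) – The current LightningModule instance.

  • stage (str) – The stage that is beginning, for example "fit".

Return type: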

None

should_update(step_idx=None, epoch_idx=None)

Called after every optimizer step and after every training epoch to check whether the average model should be updated.

One of the two arguments is set to the zero-based index of the last training step or epoch; the other is None. The default implementation returns True whenever step_idx is provided, so the average is updated after every optimizer step and never at the end of an epoch. Override this method to customize when the average model gets updated.

Parameters:
  • step_idx (Optional[int]) – Index of the last optimizer step, or None when called at the epoch end.

  • epoch_idx (Optional[int]) – Index of the last epoch, or None when called after an optimizer step.

Return type:

bool

Returns:

True if the average model should be updated and False if not.
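To make the default concrete, the first function below restates it, and the second shows a hypothetical alternative in the spirit of classic SWA, which updates once per epoch starting at a chosen epoch. Both are plain functions with the should_update() signature (swa_start_epoch is an added, hypothetical parameter); in a WeightAveraging subclass, should_update() would return the corresponding expression.

```python
from typing import Optional


def default_should_update(step_idx: Optional[int] = None, epoch_idx: Optional[int] = None) -> bool:
    # Restates the documented default: update after every optimizer step,
    # never at the end of an epoch.
    return step_idx is not None


def swa_style_should_update(
    step_idx: Optional[int] = None,
    epoch_idx: Optional[int] = None,
    swa_start_epoch: int = 5,  # hypothetical parameter
) -> bool:
    # Hypothetical policy: one update per epoch, from epoch index swa_start_epoch on.
    return epoch_idx is not None and epoch_idx >= swa_start_epoch


print([e for e in range(8) if swa_style_should_update(epoch_idx=e)])  # [5, 6, 7]
```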

state_dict()

Called when saving a checkpoint.

Creates a state_dict of the callback state.

Return type:

dict[str, Any]

Returns:

A dictionary containing the callback state.