aimet_torch.bn_reestimation

Top-level API

aimet_torch.bn_reestimation.reestimate_bn_stats(model, dataloader, num_batches=100, forward_fn=None)

Reestimate BatchNorm statistics (running mean and running variance).

Parameters:
  • model (Module) – Model whose BN statistics are to be reestimated.

  • dataloader (DataLoader) – Data loader over the training dataset.

  • num_batches (int) – The number of batches to be used for reestimation.

  • forward_fn (Optional[Callable[[Module, Any], Any]]) – Optional adapter function that performs a forward pass, given a model and an input batch yielded from the data loader.
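A forward_fn is useful when the batches yielded by the data loader cannot be fed to the model as-is. The sketch below assumes, hypothetically, a loader that yields (inputs, labels) 2-tuples; the adapter name forward_inputs_only is made up for illustration. It matches the documented Callable[[Module, Any], Any] shape: it receives the model and one batch, and runs the forward pass on the inputs only.

```python
from typing import Any


def forward_inputs_only(model: Any, batch: Any) -> Any:
    """Hypothetical adapter: run the model on the inputs of an (inputs, labels) batch.

    In practice `model` would be a torch.nn.Module; any callable works here.
    """
    inputs, _labels = batch  # assumption: the loader yields 2-tuples
    return model(inputs)     # labels are not needed to update BN statistics
```

It would then be passed as reestimate_bn_stats(model, dataloader, forward_fn=forward_inputs_only).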

Return type:

Handle

Returns:

Handle that undoes the effect of BN reestimation when handle.remove() is called.
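To illustrate what BN reestimation amounts to, here is a minimal pure-PyTorch sketch written under stated assumptions; it is not AIMET's implementation. The names reestimate_bn_stats_sketch and _RestoreHandle are hypothetical. The sketch resets each BatchNorm layer's running statistics, switches only the BN layers to train mode with momentum=None (so PyTorch keeps a cumulative average instead of an exponential moving average), runs up to num_batches forward passes without gradients, and returns an object whose remove() restores the saved BN state, mirroring the documented Handle behavior.

```python
import torch
from torch import nn

_BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


class _RestoreHandle:
    """Hypothetical stand-in for AIMET's Handle: remove() restores the saved BN state."""

    def __init__(self, saved):
        self._saved = saved  # list of (bn_module, state_dict_copy, original_momentum)

    def remove(self):
        for bn, state, momentum in self._saved:
            bn.load_state_dict(state)
            bn.momentum = momentum


def reestimate_bn_stats_sketch(model, dataloader, num_batches=100, forward_fn=None):
    """Illustrative BN reestimation; not AIMET's implementation."""
    if forward_fn is None:
        # Assumption: each batch can be passed to the model directly.
        def forward_fn(m, batch):
            return m(batch)

    bns = [m for m in model.modules() if isinstance(m, _BN_TYPES)]
    # Snapshot every BN layer's parameters and buffers so the change can be undone.
    saved = [
        (bn, {k: v.clone() for k, v in bn.state_dict().items()}, bn.momentum)
        for bn in bns
    ]
    modes = [bn.training for bn in bns]
    try:
        for bn in bns:
            bn.reset_running_stats()  # running_mean=0, running_var=1, count=0
            bn.momentum = None        # cumulative average instead of EMA
            bn.train()                # only BN layers switch to batch statistics
        with torch.no_grad():
            for i, batch in enumerate(dataloader):
                if i >= num_batches:
                    break
                forward_fn(model, batch)
    finally:
        # Restore each BN layer's original mode and momentum; keep the new statistics.
        for bn, mode, (_, _, momentum) in zip(bns, modes, saved):
            bn.train(mode)
            bn.momentum = momentum
    return _RestoreHandle(saved)
```

A caller would run handle = reestimate_bn_stats_sketch(model, loader, num_batches=50), evaluate the model, and call handle.remove() to roll back. Setting momentum=None makes the result an equal-weight average of the per-batch means and variances, so the order in which batches arrive does not affect it.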