aimet_torch.batch_norm_fold

Top-level API

aimet_torch.batch_norm_fold.fold_all_batch_norms(model, input_shapes, dummy_input=None)

Fold all batch norm layers in a model into the weights of the corresponding conv or linear layers.
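The arithmetic behind folding can be sketched in plain Python for a linear layer followed by an inference-mode batch norm (fixed running mean and variance). This is the general technique, not AIMET's implementation; the helper names `fold_bn_into_linear`, `linear`, and `batch_norm`, and the default `eps`, are assumptions made for illustration.

```python
import math


def linear(x, weight, bias):
    """Compute y = W x + b, with weight given as one row per output channel."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]


def batch_norm(y, gamma, beta, mean, var, eps=1e-5):
    """Inference-mode batch norm using fixed per-channel statistics."""
    return [
        g * (yi - m) / math.sqrt(v + eps) + bt
        for yi, g, bt, m, v in zip(y, gamma, beta, mean, var)
    ]


def fold_bn_into_linear(weight, bias, gamma, beta, mean, var, eps=1e-5):
    """Return (folded_weight, folded_bias) such that
    linear(x, folded_weight, folded_bias) == batch_norm(linear(x, weight, bias), ...).
    """
    folded_weight, folded_bias = [], []
    for row, b, g, bt, m, v in zip(weight, bias, gamma, beta, mean, var):
        scale = g / math.sqrt(v + eps)              # per-output-channel scale
        folded_weight.append([w * scale for w in row])
        folded_bias.append((b - m) * scale + bt)    # shift absorbed into the bias
    return folded_weight, folded_bias
```

After folding, a single linear call reproduces the output of the original linear + batch norm pair up to floating-point error, which is why the batch norm layer is no longer needed at inference time.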

Parameters:
  • model (Module) – Model whose batch norm layers are to be folded

  • input_shapes (Union[Tuple, List[Tuple]]) – Input shapes for the model (can be one or multiple inputs)

  • dummy_input (Union[Tensor, Tuple, None]) – A dummy input to the model. Can be a Tensor or a Tuple of Tensors

Return type:

List[Tuple[Union[Linear, Conv1d, Conv2d, ConvTranspose2d], Union[BatchNorm1d, BatchNorm2d]]]

Returns:

A list of layer pairs [(Conv/Linear layer, BN layer that was folded into it)]
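A possible usage sketch follows. It assumes PyTorch and AIMET are installed, and it has not been verified against any particular AIMET release. The toy model, the input shape, and the helpers `summarize_folded_pairs` and `fold_example` are hypothetical; only `fold_all_batch_norms` and its parameters come from the signature above.

```python
def summarize_folded_pairs(pairs):
    """Format each returned (layer, batch norm) pair as 'LayerType <- BatchNormType'."""
    return [f"{type(layer).__name__} <- {type(bn).__name__}" for layer, bn in pairs]


def fold_example():
    # Local imports: PyTorch and AIMET are only needed when this function is called.
    import torch
    from aimet_torch.batch_norm_fold import fold_all_batch_norms

    # Hypothetical toy model: a Conv2d followed directly by a BatchNorm2d.
    model = torch.nn.Sequential(
        torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
        torch.nn.BatchNorm2d(8),
        torch.nn.ReLU(),
    ).eval()

    # A single input, so input_shapes is one tuple; a multi-input model
    # would take a list of tuples instead.
    input_shape = (1, 3, 32, 32)
    pairs = fold_all_batch_norms(model, input_shape)
    return summarize_folded_pairs(pairs)
```

Calling `fold_example()` should return one string per folded pair. The local imports are a deliberate choice so that the formatting helper can be used and tested without the heavier dependencies.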