aimet_torch.batch_norm_fold
Top-level API
- aimet_torch.batch_norm_fold.fold_all_batch_norms(model, input_shapes, dummy_input=None)
Folds all BatchNorm layers in a model into the weights and biases of the corresponding conv or linear layers.
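Folding works because, in eval mode, BatchNorm is a per-channel affine map y = γ·(z − μ)/√(σ² + ε) + β, so it can be absorbed into the conv as W′ = W·s and b′ = (b − μ)·s + β with s = γ/√(σ² + ε). The sketch below applies these formulas to one Conv2d followed by one BatchNorm2d and checks that the outputs match. `fold_bn_into_conv` is a hypothetical, hand-written helper for illustration only; it is not AIMET's implementation and ignores the other layer types the function supports.

```python
import torch
import torch.nn as nn


def fold_bn_into_conv(conv: nn.Conv2d, bn: nn.BatchNorm2d) -> None:
    """Hypothetical helper: fold an eval-mode BatchNorm2d that follows `conv` into conv (in place)."""
    with torch.no_grad():
        # Per-output-channel scale s = gamma / sqrt(running_var + eps)
        scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
        # W' = W * s, broadcast over (out_channels, in_channels // groups, kH, kW)
        conv.weight.mul_(scale.reshape(-1, 1, 1, 1))
        # A conv without bias behaves as if b = 0
        bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
        # b' = (b - running_mean) * s + beta
        new_bias = (bias - bn.running_mean) * scale + bn.bias
    conv.bias = nn.Parameter(new_bias)


torch.manual_seed(0)
conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
bn = nn.BatchNorm2d(8)
# Non-trivial affine parameters (the defaults are gamma=1, beta=0)
nn.init.uniform_(bn.weight, 0.5, 1.5)
nn.init.uniform_(bn.bias, -0.5, 0.5)

# Populate the running statistics with a few training-mode batches
bn.train()
with torch.no_grad():
    for _ in range(5):
        bn(conv(torch.randn(4, 3, 16, 16)))

# Folding is only valid with frozen (eval-mode) statistics
conv.eval()
bn.eval()
x = torch.randn(2, 3, 16, 16)
with torch.no_grad():
    reference = bn(conv(x))
    fold_bn_into_conv(conv, bn)
    folded = conv(x)

print(torch.allclose(reference, folded, atol=1e-5))
```

After folding, the BN layer is redundant in inference: the conv alone reproduces conv followed by BN.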
- Parameters:
  - model (Module) – The model whose BatchNorm layers are folded. It is modified in place.
  - input_shapes (Union[Tuple, List[Tuple]]) – Input shape(s) of the model: a single tuple for one input, or a list of tuples for multiple inputs.
  - dummy_input (Union[Tensor, Tuple, None]) – Optional dummy input to the model: a Tensor or a Tuple of Tensors. Defaults to None.
- Return type:
  List[Tuple[Union[Linear, Conv1d, Conv2d, ConvTranspose2d], Union[BatchNorm1d, BatchNorm2d]]]
- Returns:
  A list of layer pairs, one per folded BatchNorm layer: [(Conv/Linear layer, BN layer that was folded into it)]
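Both `input_shapes` and `dummy_input` describe the model's inputs, the first as shapes and the second as actual tensors. The sketch below shows how the documented types line up: a single shape tuple corresponds to one Tensor, and a list of shape tuples to a Tuple of Tensors. `make_dummy_input` is a hypothetical helper, not part of AIMET; it assumes random tensors are an acceptable stand-in for real inputs, and how AIMET itself uses `input_shapes` when `dummy_input` is None may differ.

```python
from typing import List, Tuple, Union

import torch


def make_dummy_input(
    input_shapes: Union[Tuple[int, ...], List[Tuple[int, ...]]],
) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
    """Hypothetical helper: build random tensors matching `input_shapes`."""
    # One input: a single shape tuple of ints, e.g. (1, 3, 32, 32) -> one Tensor
    if isinstance(input_shapes, tuple) and all(isinstance(d, int) for d in input_shapes):
        return torch.randn(*input_shapes)
    # Several inputs: a list of shape tuples -> a Tuple of Tensors, one per shape
    return tuple(torch.randn(*shape) for shape in input_shapes)


single = make_dummy_input((1, 3, 32, 32))
multi = make_dummy_input([(1, 3, 32, 32), (1, 10)])
print(single.shape)
print([t.shape for t in multi])
```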
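The return value pairs each folded BN layer with the Conv/Linear layer it was folded into. As an illustration of that format only, the hypothetical `find_conv_bn_pairs` below (not an AIMET function) covers just the simplest case, a BN layer immediately following a foldable layer in an `nn.Sequential`, and does not fold anything. The real function takes input shapes or a dummy input, which suggests it inspects the model by running it, so its pairing rules may differ from this sketch.

```python
from typing import List, Tuple

import torch.nn as nn

# Layer types taken from the documented return type
FOLDABLE = (nn.Linear, nn.Conv1d, nn.Conv2d, nn.ConvTranspose2d)
BATCH_NORMS = (nn.BatchNorm1d, nn.BatchNorm2d)


def find_conv_bn_pairs(model: nn.Sequential) -> List[Tuple[nn.Module, nn.Module]]:
    """Hypothetical helper: pair each BN layer with the foldable layer directly before it."""
    layers = list(model)
    pairs = []
    for prev, curr in zip(layers, layers[1:]):
        if isinstance(prev, FOLDABLE) and isinstance(curr, BATCH_NORMS):
            pairs.append((prev, curr))
    return pairs


model = nn.Sequential(
    nn.Conv2d(3, 16, 3), nn.BatchNorm2d(16), nn.ReLU(),
    nn.Conv2d(16, 32, 3), nn.BatchNorm2d(32), nn.ReLU(),
    nn.AdaptiveAvgPool2d(1), nn.Flatten(),
    nn.Linear(32, 10), nn.BatchNorm1d(10),
)

pairs = find_conv_bn_pairs(model)
# Same shape as the documented result: [(Conv/Linear layer, BN layer), ...]
for layer, bn in pairs:
    print(type(layer).__name__, "->", type(bn).__name__)
```

With AIMET installed, a call matching the documented signature, such as `fold_all_batch_norms(model, input_shapes=(1, 3, 32, 32))`, returns a list in this same (layer, BN) format and can be iterated the same way.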