# /usr/bin/env python3.6
# -*- mode: python -*-
# =============================================================================
#  @@-COPYRIGHT-START-@@
#
#  Copyright (c) 2022, Qualcomm Innovation Center, Inc. All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  1. Redistributions of source code must retain the above copyright notice,
#     this list of conditions and the following disclaimer.
#
#  2. Redistributions in binary form must reproduce the above copyright notice,
#     this list of conditions and the following disclaimer in the documentation
#     and/or other materials provided with the distribution.
#
#  3. Neither the name of the copyright holder nor the names of its contributors
#     may be used to endorse or promote products derived from this software
#     without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
#  POSSIBILITY OF SUCH DAMAGE.
#
#  SPDX-License-Identifier: BSD-3-Clause
#
#  @@-COPYRIGHT-END-@@
# =============================================================================
"""BatchNorm Reestimation"""
import itertools
from typing import Iterable, List, Callable, Any
from tqdm import tqdm
import torch
from torch.utils.data import DataLoader
from torch.nn.modules.batchnorm import _BatchNorm
from aimet_torch.utils import in_eval_mode, in_train_mode
from aimet_common.utils import Handle
def _get_active_bn_modules(model: torch.nn.Module) -> Iterable[_BatchNorm]:
    for module in model.modules():
        if isinstance(module, _BatchNorm):
            bn = module
            if bn.running_mean is not None and bn.running_var is not None:
                yield bn
def _for_each_module(modules: Iterable[torch.nn.Module],
                     action: Callable[[torch.nn.Module], Handle]) -> Handle:
    """
    Run an undoable action on every module and bundle the undo steps.
    :param modules: Modules the action is applied to, in iteration order.
    :param action: Callable applied to each module; it must return a Handle.
    :returns: Handle whose remove() removes every per-module handle.
    """
    applied: List[Handle] = []

    def _remove_all():
        for applied_handle in applied:
            applied_handle.remove()

    try:
        for target in modules:
            result = action(target)
            assert isinstance(result, Handle)
            applied.append(result)
        return Handle(_remove_all)
    except BaseException:
        # Roll back whatever was applied so far, then let the error propagate.
        _remove_all()
        raise
def _reset_bn_stats(module: _BatchNorm) -> Handle:
    """
    Reset the running statistics of a BN module via module.reset_running_stats().
    :param module: BatchNorm module.
    :returns: Handle that copies the saved running_mean, running_var and
              num_batches_tracked back into the module upon handle.remove().
    """
    stat_names = ("running_mean", "running_var", "num_batches_tracked")
    # Snapshot the current buffers before they get reset.
    snapshots = [(name, getattr(module, name).clone()) for name in stat_names]

    def _restore():
        # Look the buffers up again at restore time and overwrite them in place.
        for name, saved in snapshots:
            getattr(module, name).copy_(saved)

    try:
        module.reset_running_stats()
        return Handle(_restore)
    except BaseException:
        _restore()
        raise
def _reset_momentum(module: _BatchNorm) -> Handle:
    """
    Force the BN momentum to 1.0.
    :param module: BatchNorm module.
    :returns: Handle that puts the previous momentum back upon handle.remove().
    """
    previous_momentum = module.momentum

    def _restore():
        module.momentum = previous_momentum

    try:
        module.momentum = 1.0
        return Handle(_restore)
    except BaseException:
        # Leave the module as it was found before propagating the error.
        _restore()
        raise
# Default value of the num_batches argument of reestimate_bn_stats().
DEFAULT_NUM_BATCHES = 100
def reestimate_bn_stats(model: torch.nn.Module,
                        dataloader: DataLoader,
                        num_batches: int = DEFAULT_NUM_BATCHES,
                        forward_fn: Callable[[torch.nn.Module, Any], Any] = None) -> Handle:
    """
    Reestimate BatchNorm statistics (running mean and var).

    The running statistics of every BatchNorm module that tracks them are
    replaced by the average of the statistics observed after each forward pass
    over at most ``num_batches`` batches.
    :param model: Model to reestimate the BN stats.
    :param dataloader: Training dataset.
    :param num_batches: The maximum number of batches to be used for reestimation.
    :param forward_fn: Optional adapter function that performs forward pass
                       given a model and an input batch yielded from the data loader.
    :returns: Handle that undoes the effect of BN reestimation upon handle.remove().
    :raises ValueError: If the model has BN modules to reestimate but no batch
                        could be drawn from the data loader. The original BN
                        statistics are restored before the error propagates.
    """
    forward_fn = forward_fn or (lambda model, data: model(data))
    bn_modules = tuple(_get_active_bn_modules(model))

    # Never ask for more batches than the loader holds. A loader without a
    # length cannot be capped up front; islice simply stops when it runs out.
    try:
        loader_len = len(dataloader)
    except TypeError:
        loader_len = None
    max_batches = num_batches if loader_len is None else min(loader_len, num_batches)

    # Set all the layers to eval mode except batchnorm layers
    with in_eval_mode(model), in_train_mode(bn_modules), torch.no_grad():
        with _for_each_module(bn_modules, action=_reset_momentum):
            handle = _for_each_module(bn_modules, action=_reset_bn_stats)
            try:
                # Batchnorm statistics accumulation buffer
                buffer = {
                    bn: {"sum_mean": torch.zeros_like(bn.running_mean),
                         "sum_var": torch.zeros_like(bn.running_var)}
                    for bn in bn_modules
                }

                num_processed = 0
                dataloader_slice = itertools.islice(dataloader, max_batches)
                for data in tqdm(dataloader_slice,
                                 total=max_batches,
                                 desc="batchnorm reestimation"):
                    forward_fn(model, data)
                    # _reset_momentum set momentum to 1.0, so the running stats
                    # read here are the ones produced by this forward pass.
                    for bn in bn_modules:
                        buffer[bn]["sum_mean"] += bn.running_mean
                        buffer[bn]["sum_var"] += bn.running_var
                    num_processed += 1

                if bn_modules and num_processed == 0:
                    # Averaging over zero batches would write NaN into the BN stats.
                    raise ValueError(
                        "BN reestimation needs at least one batch, but none was "
                        "available (empty data loader or num_batches=0)."
                    )

                for bn in bn_modules:
                    # Override BN stats with the reestimated stats.
                    bn.running_mean.copy_(buffer[bn]["sum_mean"] / num_processed)
                    bn.running_var.copy_(buffer[bn]["sum_var"] / num_processed)

                return handle
            except BaseException:
                # Put the original statistics back before propagating any failure.
                handle.remove()
                raise