Source code for aimet_torch.v2.seq_mse

# /usr/bin/env python
# -*- mode: python -*-
# =============================================================================
#  @@-COPYRIGHT-START-@@
#
#  Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  1. Redistributions of source code must retain the above copyright notice,
#     this list of conditions and the following disclaimer.
#
#  2. Redistributions in binary form must reproduce the above copyright notice,
#     this list of conditions and the following disclaimer in the documentation
#     and/or other materials provided with the distribution.
#
#  3. Neither the name of the copyright holder nor the names of its contributors
#     may be used to endorse or promote products derived from this software
#     without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
#  POSSIBILITY OF SUCH DAMAGE.
#
#  SPDX-License-Identifier: BSD-3-Clause
#
#  @@-COPYRIGHT-END-@@
# =============================================================================
"""Sequential MSE implementation"""

from typing import List, Optional, Tuple, Callable, overload
import contextlib
import copy
import warnings
import torch
from torch.utils.data import DataLoader

from aimet_common.utils import AimetLogger, _red
from aimet_torch._base.seq_mse import SequentialMseBase, SeqMseParams, SUPPORTED_MODULES
from aimet_torch.v2.quantization.base import QuantizerBase
from aimet_torch.v2.quantization.affine import (
    AffineQuantizerBase,
    QuantizeDequantize,
    GroupedBlockQuantizeDequantize,
)
from aimet_torch.utils import place_model, get_device
from aimet_torch.v2.utils import (
    default_forward_fn,
    remove_all_quantizers,
)
from aimet_torch.v2.nn.base import BaseQuantizationMixin
from aimet_torch.v2.quantsim import QuantizationSimModel
from aimet_torch.v2.deepspeed_utils import SafeGatheredParameters
from .utils import remove_activation_quantizers, remove_param_quantizers

__all__ = [
    "SequentialMse",
    "SeqMseParams",
    "apply_seq_mse",
    "get_candidates",
    "optimize_module",
]

_logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.SeqMse)


@overload
def apply_seq_mse(
    sim: QuantizationSimModel,
    data_loader: DataLoader,
    num_candidates: int = 20,
    forward_fn: Callable = default_forward_fn,
    modules_to_exclude: Optional[List[torch.nn.Module]] = None,
    checkpoints_config: Optional[str] = None,
): ...


@overload
def apply_seq_mse(
    model: torch.nn.Module,
    sim: QuantizationSimModel,
    data_loader: DataLoader,
    params: SeqMseParams,
    modules_to_exclude: Optional[List[torch.nn.Module]] = None,
    checkpoints_config: Optional[str] = None,
):
    # Deprecated
    ...


[docs] def apply_seq_mse(*args, **kwargs): """ Sequentially minimizing activation MSE loss in layer-wise way to decide optimal param quantization encodings. 1 Disable all input/output quantizers, param quantizers of non-supported modules 2 Find and feeze optimal parameter encodings candidate for remaining supported modules 3 Re-enable disabled quantizers from step 1 Example userflow: model = Model().eval() sim = QuantizationSimModel(...) apply_seq_mse(...) sim.compute_encodings(...) [compute encodings for all activations and parameters of non-supported modules] sim.export(...) NOTE: modules in modules_to_exclude won't be quantized and skipped when applying sequential MSE. :param sim: QuantizationSimModel object :param data_loader: Data loader :param num_candidates: Number of candidate encodings to evaluate for each layer :param forward_fn: callback function to perform forward pass given accepts model, inputs :param modules_to_exclude: List of supported type module(s) to exclude when applying Sequential MSE :param checkpoints_config: Config files to split fp32/quant model by checkpoints to speedup activations sampling """ if "model" in kwargs or (args and isinstance(args[0], torch.nn.Module)): warnings.warn( _red( "apply_seq_mse was called using a deprecated function signature. This will raise an error in future releases." ), DeprecationWarning, stacklevel=2, ) return SequentialMse.apply_seq_mse(*args, **kwargs) return _apply_seq_mse(*args, **kwargs)
def _apply_seq_mse( sim: QuantizationSimModel, data_loader: DataLoader, num_candidates: int = 20, forward_fn: Callable = default_forward_fn, modules_to_exclude: Optional[List[torch.nn.Module]] = None, checkpoints_config: Optional[str] = None, ): params = SeqMseParams( num_batches=None, num_candidates=num_candidates, forward_fn=forward_fn, inp_symmetry=SequentialMse.inp_symmetry, loss_fn=SequentialMse.loss_fn, ) with ( place_model(sim.model, torch.device("cpu")), remove_all_quantizers(sim.model), ): weight_shared_fp_model = _copy_as_fp_model_with_shared_weights(sim.model) model_device = get_device(sim.model) with place_model(weight_shared_fp_model, model_device): return SequentialMse.apply_seq_mse( model=weight_shared_fp_model, sim=sim, data_loader=data_loader, params=params, modules_to_exclude=modules_to_exclude, checkpoints_config=checkpoints_config, ) class SequentialMse(SequentialMseBase): """ Sequentially minimizing activation MSE loss in layer-wise way to decide optimal param quantization encodings. """ inp_symmetry: str = "symqt" loss_fn: str = "mse" @classmethod def apply_seq_mse( cls, model: torch.nn.Module, sim: QuantizationSimModel, data_loader: DataLoader, params: SeqMseParams, modules_to_exclude: Optional[List[torch.nn.Module]] = None, checkpoints_config: Optional[str] = None, ): if not modules_to_exclude: modules_to_exclude = [] modules_to_exclude.extend( cls._get_grouped_convs_with_blockwise_quantization(sim) ) with cls._handle_grouped_block_quantizers(sim): super().apply_seq_mse( model, sim, data_loader, params, modules_to_exclude, checkpoints_config ) @staticmethod def _get_grouped_convs_with_blockwise_quantization(sim): """Return a list of all grouped conv modules using blockwise quantization for weights""" grouped_convs_with_blockwise_quantization = [] for module in sim.model.modules(): if ( isinstance(module, torch.nn.Conv2d) and isinstance(module, BaseQuantizationMixin) and module.groups != 1 and module.param_quantizers["weight"].block_size is not None and module.param_quantizers["weight"].block_size[1] != module.weight.shape[1] ): grouped_convs_with_blockwise_quantization.append(module) return grouped_convs_with_blockwise_quantization @staticmethod @contextlib.contextmanager def _handle_grouped_block_quantizers(sim: QuantizationSimModel): """Set all grouped block quantizers to regular blockwise quantization for the duration of the context manager""" grouped_block_quantize_dequantizers = [] for module in sim.model.modules(): if isinstance(module, GroupedBlockQuantizeDequantize): grouped_block_quantize_dequantizers.append( (module, module.block_grouping) ) module.block_grouping = tuple(1 for _ in enumerate(module.shape)) yield for module, block_grouping in grouped_block_quantize_dequantizers: module.block_grouping = block_grouping @classmethod def compute_all_param_encodings(cls, sim: QuantizationSimModel): """ Compute encodings for all parameters, needed for initializing Sequential MSE :param sim: Quant sim """ for _, qmodule in sim.named_qmodules(): qmodule._compute_param_encodings(overwrite=True) # pylint: disable=protected-access @classmethod @contextlib.contextmanager def temporarily_disable_quantizers( cls, model: torch.nn.Module, sim: QuantizationSimModel, modules_to_exclude: Optional[List[torch.nn.Module]], ): """ For given quantsim model, disable quantizers needed to be diabled before applying sequential MSE. :param model: Original fp32 model :param sim: QuantizationSimModel object :param modules_to_exclude: List of supported modules to exclude when applying Sequential MSE :return: List of quantizers to be disabled. """ # pylint: disable=protected-access fp_modules_to_exclude = set(modules_to_exclude or []) qmodules_to_exclude = set( sim.model.get_submodule(name) for name, fp_module in model.named_modules() if fp_module in fp_modules_to_exclude ) with contextlib.ExitStack() as stack: for _, qmodule in sim.named_qmodules(): ctx = remove_activation_quantizers(qmodule) stack.enter_context(ctx) if ( not isinstance(qmodule, SUPPORTED_MODULES) or qmodule in qmodules_to_exclude ): ctx = remove_param_quantizers(qmodule) stack.enter_context(ctx) yield @classmethod def compute_param_encodings( cls, quantizer: QuantizerBase, x_min: torch.Tensor, x_max: torch.Tensor ): """ Compute encodings for parameter quantizer using given x_min and x_max values. :param quantizer: Tensor quantizer :param x_min: min values :param x_max: max values """ quantize_dequantize = QuantizeDequantize( quantizer.shape, quantizer.bitwidth, quantizer.symmetric, block_size=quantizer.block_size, ).to(x_min.device) min_tensor = x_min max_tensor = x_max if quantizer.block_size: for axis, blk_size in enumerate(quantizer.block_size): if blk_size == -1: continue min_tensor = min_tensor.repeat_interleave(blk_size, axis) max_tensor = max_tensor.repeat_interleave(blk_size, axis) with quantize_dequantize.compute_encodings(): _ = quantize_dequantize(torch.stack([min_tensor, max_tensor])) # pylint: disable=not-callable # (pylint throws a false alarm) quantizer.set_range(quantize_dequantize.min, quantize_dequantize.max) @classmethod def _is_symmetric_quantizer(cls, quantizer: AffineQuantizerBase): # pylint: disable=protected-access return quantizer._symmetric @classmethod def _freeze_quantizer_encoding(cls, quantizer: QuantizerBase): # pylint: disable=protected-access quantizer.requires_grad_(False) quantizer.allow_overwrite(False) @classmethod def _get_quantized_weight(cls, quant_module: BaseQuantizationMixin): w = quant_module.weight return quant_module.param_quantizers["weight"](w) @classmethod def _get_original_module(cls, quant_module: BaseQuantizationMixin): return quant_module @staticmethod def _get_input_channel_block_size(quant_module): if not isinstance(quant_module, (torch.nn.Linear, torch.nn.Conv2d)): raise NotImplementedError("Unsupported module type: ", type(quant_module)) if quant_module.param_quantizers["weight"].block_size is None: # Per tensor or per channel case. For either one, treat loss computation as per channel return quant_module.weight.shape[1] return ( quant_module.weight.shape[1] // quant_module.param_quantizers["weight"].shape[1] ) @staticmethod def _get_indices_to_reduce(block_size, reshaped_weight): """ Return indices in reshaped_weight corresponding to block_sizes. Reshaped_weight is expected to contain alternating dimensions of num_blocks and block_sizes. """ indices_to_reduce = [] for idx, _ in enumerate(block_size): indices_to_reduce.insert(0, (len(reshaped_weight.shape) - 2 * idx) - 1) return indices_to_reduce @classmethod def get_min_and_max_for_candidate_selection( cls, quant_module: BaseQuantizationMixin ) -> Tuple[torch.Tensor, torch.Tensor]: """ Get min/max values for candidate selection. :param quant_module: Quant module to be optimized :return: Tuple of min and max values for candidate selection. """ # pylint: disable=protected-access assert hasattr(quant_module.param_quantizers["weight"], "block_size") if not isinstance(quant_module, (torch.nn.Conv2d, torch.nn.Linear)): raise ValueError("Unsupported module: ", quant_module) max_tensor = quant_module.param_quantizers["weight"].get_max() min_tensor = quant_module.param_quantizers["weight"].get_min() return min_tensor, max_tensor @classmethod def _get_candidate( cls, candidate_idx: int, num_candidates: int, min_tensor: torch.Tensor, max_tensor: torch.Tensor, ): """ Get candidate min and max tensors """ cand_max = max_tensor / num_candidates * (candidate_idx + 1) cand_min = min_tensor / num_candidates * (candidate_idx + 1) return cand_min, cand_max @classmethod def _compute_loss( cls, quant_module: BaseQuantizationMixin, x: torch.Tensor, xq: torch.Tensor, w: torch.Tensor, wq: torch.Tensor, params: SeqMseParams, ) -> torch.Tensor: """ Compute loss for the given (x, w) and (xq, wq) input/weight pairs. Assumes that block size will be on input_channel dimension. """ # pylint: disable=too-many-locals block_size = cls._get_input_channel_block_size(quant_module) if isinstance(quant_module, torch.nn.Linear): # General strategy (Linear): # Compute blockwise reconstruction loss using batched matrix multiplication out_channels = quant_module.weight.shape[0] in_channels = quant_module.weight.shape[1] num_blocks = in_channels // block_size # Reshape and permute x and w # reshape permute # * x: (N, Cin) -> (N, NUM_BLK, BLK_SIZE) -> (NUM_BLK, N, BLK_SIZE) # * w: (Cout, Cin) -> (Cout, NUM_BLK, BLK_SIZE) -> (NUM_BLK, BLK_SIZE, Cout) x = x.reshape(-1, num_blocks, block_size).permute(1, 0, 2) w = w.reshape(out_channels, num_blocks, block_size).permute(1, 2, 0) xq = xq.reshape(-1, num_blocks, block_size).permute(1, 0, 2) wq = wq.reshape(out_channels, num_blocks, block_size).permute(1, 2, 0) # Blockwise batched matmul # xw = x @ w # (NUM_BLK, N, Cout) (NUM_BLK, N, BLK_SIZE) (NUM_BLK, BLK_SIZE, Cout) xw = torch.bmm(x, w) xqwq = torch.bmm(xq, wq) # Permute to restore axis 0 back to batch dimension # permute # * xw: (NUM_BLK, N, Cout) -> (N, Cout, NUM_BLK) xw = xw.permute(1, 2, 0) xqwq = xqwq.permute(1, 2, 0) loss = ( params.get_loss_fn()(xw, xqwq, reduction="none") .sum(0) .view(out_channels, num_blocks) ) return loss # General strategy (Conv): # Split weights and input per block, and run a separate forward pass for each split. # In the case of per tensor and per channel, the entire input channel is treated as one block. # NOTE: Similar to Linear, convolution can be also vectorized with depthwise grouped conv. # However, vectorizing convolution in this manner harms the performance # because PyTorch grouped convolution kernels are much slower than regular convolution assert isinstance(quant_module, torch.nn.Conv2d) w_blocks = torch.split(w, block_size, dim=1) wq_blocks = torch.split(wq, block_size, dim=1) groups = quant_module.groups x_blocks = torch.split(x, block_size * groups, dim=-3) xq_blocks = torch.split(xq, block_size * groups, dim=-3) block_losses = [] for idx, x_block in enumerate(x_blocks): xqwq, xw = cls.compute_outputs( quant_module, x_block, xq_blocks[idx], w_blocks[idx], wq_blocks[idx] ) block_losses.append(cls.compute_recon_loss(xqwq, xw, params)) # Stack losses in the input channel dimension block_losses = torch.stack(block_losses, dim=-1) return block_losses @classmethod def optimize_module( cls, quant_module: BaseQuantizationMixin, x: torch.Tensor, xq: torch.Tensor, params: SeqMseParams, ): """ Find and freeze optimal parameter encodings candidate for given module. :param quant_module: Quant module to be optimized :param x: Inputs to module from FP32 model :param xq: Inputs to module from QuantSim model :param params: Sequenial MSE parameters """ # pylint: disable=too-many-locals with SafeGatheredParameters(quant_module.parameters(recurse=True)): min_tensor, max_tensor = cls.get_min_and_max_for_candidate_selection( quant_module ) total_loss = [] for i in range(params.num_candidates): cand_min, cand_max = cls._get_candidate( i, params.num_candidates, min_tensor, max_tensor ) cls.compute_param_encodings( quant_module.param_quantizers["weight"], cand_min, cand_max ) w = quant_module.weight wq = cls._get_quantized_weight(quant_module) with torch.no_grad(): loss = 0 for batch_idx, (batch_x, batch_xq) in enumerate(zip(x, xq)): if params.num_batches and batch_idx >= params.num_batches: break loss += cls._compute_loss( quant_module, batch_x, batch_xq, w, wq, params ) total_loss.append(loss) best_indices = torch.stack(total_loss).min(0)[1] # Unsqueeze best_indices until it matches dim length of max_tensor while best_indices.dim() < max_tensor.dim(): best_indices = best_indices[..., None] min_tensor, max_tensor = cls._get_candidate( best_indices, params.num_candidates, min_tensor, max_tensor ) # Compute and freeze parameter encodings using best candidate cls.compute_param_encodings( quant_module.param_quantizers["weight"], min_tensor, max_tensor ) cls._freeze_quantizer_encoding(quant_module.param_quantizers["weight"]) def _copy_as_fp_model_with_shared_weights(model): new_model = copy.copy(model) # pylint: disable=protected-access new_model._modules = copy.copy(new_model._modules) new_model._parameters = copy.copy(new_model._parameters) new_model._buffers = copy.copy(new_model._buffers) for name, child in model.named_children(): if isinstance(child, BaseQuantizationMixin): setattr(new_model, name, child.get_original_module()) else: setattr(new_model, name, _copy_as_fp_model_with_shared_weights(child)) return new_model # Global variables for compatibility get_candidates = SequentialMse.get_candidates optimize_module = SequentialMse.optimize_module