#!/usr/bin/env python
# -*- mode: python -*-
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# SPDX-License-Identifier: BSD-3-Clause
#
# @@-COPYRIGHT-END-@@
# =============================================================================
"""Sequential MSE implementation"""
from typing import List, Optional, Tuple, Callable, overload
import contextlib
import copy
import warnings
import torch
from torch.utils.data import DataLoader
from aimet_common.utils import AimetLogger, _red
from aimet_torch._base.seq_mse import SequentialMseBase, SeqMseParams, SUPPORTED_MODULES
from aimet_torch.v2.quantization.base import QuantizerBase
from aimet_torch.v2.quantization.affine import (
AffineQuantizerBase,
QuantizeDequantize,
GroupedBlockQuantizeDequantize,
)
from aimet_torch.utils import place_model, get_device
from aimet_torch.v2.utils import (
default_forward_fn,
remove_all_quantizers,
)
from aimet_torch.v2.nn.base import BaseQuantizationMixin
from aimet_torch.v2.quantsim import QuantizationSimModel
from aimet_torch.v2.deepspeed_utils import SafeGatheredParameters
from .utils import remove_activation_quantizers, remove_param_quantizers
__all__ = [
"SequentialMse",
"SeqMseParams",
"apply_seq_mse",
"get_candidates",
"optimize_module",
]
_logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.SeqMse)
@overload
def apply_seq_mse(
sim: QuantizationSimModel,
data_loader: DataLoader,
num_candidates: int = 20,
forward_fn: Callable = default_forward_fn,
modules_to_exclude: Optional[List[torch.nn.Module]] = None,
checkpoints_config: Optional[str] = None,
): ...
@overload
def apply_seq_mse(
model: torch.nn.Module,
sim: QuantizationSimModel,
data_loader: DataLoader,
params: SeqMseParams,
modules_to_exclude: Optional[List[torch.nn.Module]] = None,
checkpoints_config: Optional[str] = None,
):
# Deprecated
...
def apply_seq_mse(*args, **kwargs):
"""
    Sequentially minimize activation MSE loss in a layer-wise way to decide optimal param quantization encodings.

    1. Disable all input/output quantizers and the param quantizers of non-supported modules
    2. Find and freeze the optimal parameter encoding candidate for each remaining supported module
    3. Re-enable the quantizers disabled in step 1

    Example user flow:
        model = Model().eval()
        sim = QuantizationSimModel(...)
        apply_seq_mse(...)
        sim.compute_encodings(...)  # compute encodings for all activations and parameters of non-supported modules
        sim.export(...)

    NOTE: modules in modules_to_exclude won't be quantized and will be skipped when applying sequential MSE.

    :param sim: QuantizationSimModel object
    :param data_loader: Data loader
    :param num_candidates: Number of candidate encodings to evaluate for each layer
    :param forward_fn: Callback function to perform a forward pass; accepts (model, inputs)
    :param modules_to_exclude: List of supported-type module(s) to exclude when applying Sequential MSE
    :param checkpoints_config: Config file used to split the fp32/quant model by checkpoints to speed up activation sampling
"""
if "model" in kwargs or (args and isinstance(args[0], torch.nn.Module)):
warnings.warn(
_red(
"apply_seq_mse was called using a deprecated function signature. This will raise an error in future releases."
),
DeprecationWarning,
stacklevel=2,
)
return SequentialMse.apply_seq_mse(*args, **kwargs)
return _apply_seq_mse(*args, **kwargs)
def _apply_seq_mse(
sim: QuantizationSimModel,
data_loader: DataLoader,
num_candidates: int = 20,
forward_fn: Callable = default_forward_fn,
modules_to_exclude: Optional[List[torch.nn.Module]] = None,
checkpoints_config: Optional[str] = None,
):
params = SeqMseParams(
num_batches=None,
num_candidates=num_candidates,
forward_fn=forward_fn,
inp_symmetry=SequentialMse.inp_symmetry,
loss_fn=SequentialMse.loss_fn,
)
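    # Build an FP32 copy of sim.model that shares its parameters and buffers, so the
    # FP and quantized forward passes can be compared without duplicating weights.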
with (
place_model(sim.model, torch.device("cpu")),
remove_all_quantizers(sim.model),
):
weight_shared_fp_model = _copy_as_fp_model_with_shared_weights(sim.model)
model_device = get_device(sim.model)
with place_model(weight_shared_fp_model, model_device):
return SequentialMse.apply_seq_mse(
model=weight_shared_fp_model,
sim=sim,
data_loader=data_loader,
params=params,
modules_to_exclude=modules_to_exclude,
checkpoints_config=checkpoints_config,
)
class SequentialMse(SequentialMseBase):
"""
    Sequentially minimize activation MSE loss in a layer-wise way to decide optimal param quantization encodings.
"""
inp_symmetry: str = "symqt"
loss_fn: str = "mse"
@classmethod
def apply_seq_mse(
cls,
model: torch.nn.Module,
sim: QuantizationSimModel,
data_loader: DataLoader,
params: SeqMseParams,
modules_to_exclude: Optional[List[torch.nn.Module]] = None,
checkpoints_config: Optional[str] = None,
):
        # Copy the user-provided list so the caller's list is not mutated
        modules_to_exclude = list(modules_to_exclude or [])
        modules_to_exclude.extend(
            cls._get_grouped_convs_with_blockwise_quantization(sim)
        )
with cls._handle_grouped_block_quantizers(sim):
super().apply_seq_mse(
model, sim, data_loader, params, modules_to_exclude, checkpoints_config
)
@staticmethod
def _get_grouped_convs_with_blockwise_quantization(sim):
"""Return a list of all grouped conv modules using blockwise quantization for weights"""
grouped_convs_with_blockwise_quantization = []
for module in sim.model.modules():
if (
isinstance(module, torch.nn.Conv2d)
and isinstance(module, BaseQuantizationMixin)
and module.groups != 1
and module.param_quantizers["weight"].block_size is not None
and module.param_quantizers["weight"].block_size[1]
!= module.weight.shape[1]
):
grouped_convs_with_blockwise_quantization.append(module)
return grouped_convs_with_blockwise_quantization
@staticmethod
@contextlib.contextmanager
def _handle_grouped_block_quantizers(sim: QuantizationSimModel):
"""Set all grouped block quantizers to regular blockwise quantization for the duration of the context manager"""
grouped_block_quantize_dequantizers = []
for module in sim.model.modules():
if isinstance(module, GroupedBlockQuantizeDequantize):
grouped_block_quantize_dequantizers.append(
(module, module.block_grouping)
)
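                # A block grouping of all ones makes the grouped-block quantizer behave
                # like a regular blockwise quantizer while the context is active.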
                module.block_grouping = tuple(1 for _ in module.shape)
        try:
            yield
        finally:
            # Restore the original block grouping even if the body raised
            for module, block_grouping in grouped_block_quantize_dequantizers:
                module.block_grouping = block_grouping
@classmethod
def compute_all_param_encodings(cls, sim: QuantizationSimModel):
"""
Compute encodings for all parameters, needed for initializing Sequential MSE
:param sim: Quant sim
"""
for _, qmodule in sim.named_qmodules():
qmodule._compute_param_encodings(overwrite=True) # pylint: disable=protected-access
@classmethod
@contextlib.contextmanager
def temporarily_disable_quantizers(
cls,
model: torch.nn.Module,
sim: QuantizationSimModel,
modules_to_exclude: Optional[List[torch.nn.Module]],
):
"""
        For the given quantsim model, temporarily disable the quantizers that must be disabled before applying sequential MSE.

        :param model: Original fp32 model
        :param sim: QuantizationSimModel object
        :param modules_to_exclude: List of supported modules to exclude when applying Sequential MSE
"""
# pylint: disable=protected-access
fp_modules_to_exclude = set(modules_to_exclude or [])
qmodules_to_exclude = set(
sim.model.get_submodule(name)
for name, fp_module in model.named_modules()
if fp_module in fp_modules_to_exclude
)
with contextlib.ExitStack() as stack:
for _, qmodule in sim.named_qmodules():
ctx = remove_activation_quantizers(qmodule)
stack.enter_context(ctx)
if (
not isinstance(qmodule, SUPPORTED_MODULES)
or qmodule in qmodules_to_exclude
):
ctx = remove_param_quantizers(qmodule)
stack.enter_context(ctx)
yield
@classmethod
def compute_param_encodings(
cls, quantizer: QuantizerBase, x_min: torch.Tensor, x_max: torch.Tensor
):
"""
Compute encodings for parameter quantizer using given x_min and x_max values.
:param quantizer: Tensor quantizer
:param x_min: min values
:param x_max: max values
"""
quantize_dequantize = QuantizeDequantize(
quantizer.shape,
quantizer.bitwidth,
quantizer.symmetric,
block_size=quantizer.block_size,
).to(x_min.device)
min_tensor = x_min
max_tensor = x_max
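        # For blockwise quantization, expand the per-block min/max back to the full
        # parameter shape so that the calibration below sees tensors of matching dimensions.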
if quantizer.block_size:
for axis, blk_size in enumerate(quantizer.block_size):
if blk_size == -1:
continue
min_tensor = min_tensor.repeat_interleave(blk_size, axis)
max_tensor = max_tensor.repeat_interleave(blk_size, axis)
with quantize_dequantize.compute_encodings():
_ = quantize_dequantize(torch.stack([min_tensor, max_tensor])) # pylint: disable=not-callable
# (pylint throws a false alarm)
quantizer.set_range(quantize_dequantize.min, quantize_dequantize.max)
@classmethod
def _is_symmetric_quantizer(cls, quantizer: AffineQuantizerBase):
# pylint: disable=protected-access
return quantizer._symmetric
@classmethod
def _freeze_quantizer_encoding(cls, quantizer: QuantizerBase):
# pylint: disable=protected-access
quantizer.requires_grad_(False)
quantizer.allow_overwrite(False)
@classmethod
def _get_quantized_weight(cls, quant_module: BaseQuantizationMixin):
w = quant_module.weight
return quant_module.param_quantizers["weight"](w)
@classmethod
def _get_original_module(cls, quant_module: BaseQuantizationMixin):
return quant_module
@staticmethod
def _get_input_channel_block_size(quant_module):
if not isinstance(quant_module, (torch.nn.Linear, torch.nn.Conv2d)):
            raise NotImplementedError(f"Unsupported module type: {type(quant_module)}")
if quant_module.param_quantizers["weight"].block_size is None:
# Per tensor or per channel case. For either one, treat loss computation as per channel
return quant_module.weight.shape[1]
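        # Blockwise case: the block size along the input-channel axis equals the weight's
        # input-channel dimension divided by the number of blocks the quantizer defines there.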
return (
quant_module.weight.shape[1]
// quant_module.param_quantizers["weight"].shape[1]
)
@staticmethod
def _get_indices_to_reduce(block_size, reshaped_weight):
"""
Return indices in reshaped_weight corresponding to block_sizes. Reshaped_weight is expected to contain
alternating dimensions of num_blocks and block_sizes.
"""
indices_to_reduce = []
for idx, _ in enumerate(block_size):
indices_to_reduce.insert(0, (len(reshaped_weight.shape) - 2 * idx) - 1)
return indices_to_reduce
@classmethod
def get_min_and_max_for_candidate_selection(
cls, quant_module: BaseQuantizationMixin
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Get min/max values for candidate selection.
:param quant_module: Quant module to be optimized
:return: Tuple of min and max values for candidate selection.
"""
# pylint: disable=protected-access
assert hasattr(quant_module.param_quantizers["weight"], "block_size")
if not isinstance(quant_module, (torch.nn.Conv2d, torch.nn.Linear)):
            raise ValueError(f"Unsupported module: {quant_module}")
max_tensor = quant_module.param_quantizers["weight"].get_max()
min_tensor = quant_module.param_quantizers["weight"].get_min()
return min_tensor, max_tensor
@classmethod
def _get_candidate(
cls,
candidate_idx: int,
num_candidates: int,
min_tensor: torch.Tensor,
max_tensor: torch.Tensor,
):
"""
Get candidate min and max tensors
"""
cand_max = max_tensor / num_candidates * (candidate_idx + 1)
cand_min = min_tensor / num_candidates * (candidate_idx + 1)
return cand_min, cand_max
@classmethod
def _compute_loss(
cls,
quant_module: BaseQuantizationMixin,
x: torch.Tensor,
xq: torch.Tensor,
w: torch.Tensor,
wq: torch.Tensor,
params: SeqMseParams,
) -> torch.Tensor:
"""
        Compute loss for the given (x, w) and (xq, wq) input/weight pairs. Assumes that the block size lies along the
        input_channel dimension.
"""
# pylint: disable=too-many-locals
block_size = cls._get_input_channel_block_size(quant_module)
if isinstance(quant_module, torch.nn.Linear):
# General strategy (Linear):
# Compute blockwise reconstruction loss using batched matrix multiplication
out_channels = quant_module.weight.shape[0]
in_channels = quant_module.weight.shape[1]
num_blocks = in_channels // block_size
# Reshape and permute x and w
# reshape permute
# * x: (N, Cin) -> (N, NUM_BLK, BLK_SIZE) -> (NUM_BLK, N, BLK_SIZE)
# * w: (Cout, Cin) -> (Cout, NUM_BLK, BLK_SIZE) -> (NUM_BLK, BLK_SIZE, Cout)
x = x.reshape(-1, num_blocks, block_size).permute(1, 0, 2)
w = w.reshape(out_channels, num_blocks, block_size).permute(1, 2, 0)
xq = xq.reshape(-1, num_blocks, block_size).permute(1, 0, 2)
wq = wq.reshape(out_channels, num_blocks, block_size).permute(1, 2, 0)
# Blockwise batched matmul
# xw = x @ w
# (NUM_BLK, N, Cout) (NUM_BLK, N, BLK_SIZE) (NUM_BLK, BLK_SIZE, Cout)
xw = torch.bmm(x, w)
xqwq = torch.bmm(xq, wq)
# Permute to restore axis 0 back to batch dimension
# permute
# * xw: (NUM_BLK, N, Cout) -> (N, Cout, NUM_BLK)
xw = xw.permute(1, 2, 0)
xqwq = xqwq.permute(1, 2, 0)
loss = (
params.get_loss_fn()(xw, xqwq, reduction="none")
.sum(0)
.view(out_channels, num_blocks)
)
return loss
# General strategy (Conv):
# Split weights and input per block, and run a separate forward pass for each split.
# In the case of per tensor and per channel, the entire input channel is treated as one block.
# NOTE: Similar to Linear, convolution can be also vectorized with depthwise grouped conv.
# However, vectorizing convolution in this manner harms the performance
# because PyTorch grouped convolution kernels are much slower than regular convolution
assert isinstance(quant_module, torch.nn.Conv2d)
w_blocks = torch.split(w, block_size, dim=1)
wq_blocks = torch.split(wq, block_size, dim=1)
groups = quant_module.groups
x_blocks = torch.split(x, block_size * groups, dim=-3)
xq_blocks = torch.split(xq, block_size * groups, dim=-3)
block_losses = []
for idx, x_block in enumerate(x_blocks):
xqwq, xw = cls.compute_outputs(
quant_module, x_block, xq_blocks[idx], w_blocks[idx], wq_blocks[idx]
)
block_losses.append(cls.compute_recon_loss(xqwq, xw, params))
# Stack losses in the input channel dimension
block_losses = torch.stack(block_losses, dim=-1)
return block_losses
@classmethod
def optimize_module(
cls,
quant_module: BaseQuantizationMixin,
x: torch.Tensor,
xq: torch.Tensor,
params: SeqMseParams,
):
"""
Find and freeze optimal parameter encodings candidate for given module.
:param quant_module: Quant module to be optimized
:param x: Inputs to module from FP32 model
:param xq: Inputs to module from QuantSim model
        :param params: Sequential MSE parameters
"""
# pylint: disable=too-many-locals
with SafeGatheredParameters(quant_module.parameters(recurse=True)):
min_tensor, max_tensor = cls.get_min_and_max_for_candidate_selection(
quant_module
)
total_loss = []
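            # Sweep num_candidates scaled versions of the observed (min, max) range and
            # accumulate the reconstruction loss of each candidate over the sampled batches.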
for i in range(params.num_candidates):
cand_min, cand_max = cls._get_candidate(
i, params.num_candidates, min_tensor, max_tensor
)
cls.compute_param_encodings(
quant_module.param_quantizers["weight"], cand_min, cand_max
)
w = quant_module.weight
wq = cls._get_quantized_weight(quant_module)
with torch.no_grad():
loss = 0
for batch_idx, (batch_x, batch_xq) in enumerate(zip(x, xq)):
if params.num_batches and batch_idx >= params.num_batches:
break
loss += cls._compute_loss(
quant_module, batch_x, batch_xq, w, wq, params
)
total_loss.append(loss)
best_indices = torch.stack(total_loss).min(0)[1]
# Unsqueeze best_indices until it matches dim length of max_tensor
while best_indices.dim() < max_tensor.dim():
best_indices = best_indices[..., None]
min_tensor, max_tensor = cls._get_candidate(
best_indices, params.num_candidates, min_tensor, max_tensor
)
# Compute and freeze parameter encodings using best candidate
cls.compute_param_encodings(
quant_module.param_quantizers["weight"], min_tensor, max_tensor
)
cls._freeze_quantizer_encoding(quant_module.param_quantizers["weight"])
def _copy_as_fp_model_with_shared_weights(model):
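    """
    Return a shallow copy of ``model`` in which every quantized module is replaced by
    its original floating-point counterpart; parameters and buffers remain shared.
    """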
new_model = copy.copy(model)
# pylint: disable=protected-access
new_model._modules = copy.copy(new_model._modules)
new_model._parameters = copy.copy(new_model._parameters)
new_model._buffers = copy.copy(new_model._buffers)
for name, child in model.named_children():
if isinstance(child, BaseQuantizationMixin):
setattr(new_model, name, child.get_original_module())
else:
setattr(new_model, name, _copy_as_fp_model_with_shared_weights(child))
return new_model
# Module-level aliases kept for backward compatibility
get_candidates = SequentialMse.get_candidates
optimize_module = SequentialMse.optimize_module