Source code for aimet_torch.v2.nn.base

# -*- mode: python -*-
# =============================================================================
#  @@-COPYRIGHT-START-@@
#
#  Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  1. Redistributions of source code must retain the above copyright notice,
#     this list of conditions and the following disclaimer.
#
#  2. Redistributions in binary form must reproduce the above copyright notice,
#     this list of conditions and the following disclaimer in the documentation
#     and/or other materials provided with the distribution.
#
#  3. Neither the name of the copyright holder nor the names of its contributors
#     may be used to endorse or promote products derived from this software
#     without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
#  POSSIBILITY OF SUCH DAMAGE.
#
#  SPDX-License-Identifier: BSD-3-Clause
#
#  @@-COPYRIGHT-END-@@
# =============================================================================
"""Base class of quantized modules"""

import abc
import contextlib
import itertools
from typing import Type, List, Dict, Union, Iterable, Mapping, Optional

import torch.nn as nn
from torch import Tensor

from aimet_torch.utils import is_vector_encoding
from aimet_torch.v2.quantization.affine.encoding import VectorEncoding, AffineEncoding

from aimet_torch.v2.quantization.tensor import QuantizedTensorBase
from aimet_torch.v2.quantization.base import QuantizerBase
from aimet_torch.v2.utils import (
    patch_attr,
    _ContextManager,
    flatten_nn_module_list,
)

def _no_op(in_tensor):
    return in_tensor

class BaseQuantizationMixin(abc.ABC):
    """Mixin that implements quantization on top of regular PyTorch modules.

    Attributes:
        input_quantizers (nn.ModuleList): :class:`ModuleList` containing :class:`QuantizerBase` objects to be applied
            to the layer's input tensors
        output_quantizers (nn.ModuleList): :class:`ModuleList` containing :class:`QuantizerBase` objects to be applied
            to the layer's output tensors
        param_quantizers (nn.ModuleDict): :class:`ModuleDict` mapping parameter names to associated
            :class:`QuantizerBase` objects
    """

    input_quantizers: nn.ModuleList
    output_quantizers: nn.ModuleList
    param_quantizers: nn.ModuleDict

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__quant_init__()
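    # Illustrative sketch (not part of the original source): after construction, the
    # quantizer containers declared above can be inspected directly. QuantizedLinear
    # is assumed to be a registered quantized class from aimet_torch.v2.nn.
    #
    #   >>> qlinear = QuantizedLinear(10, 10)
    #   >>> qlinear.input_quantizers    # ModuleList([None]) until a quantizer is assigned
    #   >>> qlinear.output_quantizers   # ModuleList([None])
    #   >>> qlinear.param_quantizers    # ModuleDict with 'weight' and 'bias' mapped to None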
    def __quant_init__(self):
        """Initializer for quantized module.

        This method will be invoked right after :meth:`__init__`.

        This method initializes the :attr:`input_quantizers`, :attr:`output_quantizers`, and :attr:`param_quantizers`
        structures to the appropriate sizes based on the number of input tensors, output tensors, and parameters of
        the base :class:`nn.Module` class. All quantizers are initialized to ``None``.

        For custom quantized classes, this method should be overridden to set the appropriate lengths of
        :attr:`input_quantizers` and :attr:`output_quantizers` for the given base class.
        """
        self.param_quantizers = nn.ModuleDict({
            name: None for name, _ in self.named_parameters(recurse=False)
        })
        # Currently assume single input & output
        self.input_quantizers = nn.ModuleList([None])
        self.output_quantizers = nn.ModuleList([None])
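    # Illustrative sketch (not part of the original source): a custom quantized class
    # that takes two inputs would override __quant_init__ to resize the quantizer
    # lists, mirroring _BaseQuantizedBinaryOpMixin defined later in this module.
    #
    #   def __quant_init__(self):
    #       super().__quant_init__()
    #       self.input_quantizers = nn.ModuleList([None, None])   # one slot per input
    #       self.output_quantizers = nn.ModuleList([None])        # single output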
    def __call__(self, *args, **kwargs):
        self._compute_param_encodings(overwrite=False)
        return super().__call__(*args, **kwargs)
    @abc.abstractmethod
    def forward(self, *args, **kwargs):
        """Forward function for quantized module.

        This method will replace the original forward function of the base :class:`nn.Module` class and is
        responsible for computing a quantized version of the base class' forward function using the configuration of
        the layer's :class:`QuantizerBase` objects.
        """
        return super().forward(*args, **kwargs)
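    # Illustrative sketch (not part of the original source): a typical single-input
    # override quantizes the input, runs the base forward with quantized parameters
    # patched in, then quantizes the output (see _BaseQuantizedUnaryOpMixin below).
    #
    #   def forward(self, x):
    #       if self.input_quantizers[0]:
    #           x = self.input_quantizers[0](x)
    #       with self._patch_quantized_parameters():
    #           out = super().forward(x)
    #       if self.output_quantizers[0]:
    #           out = self.output_quantizers[0](out)
    #       return out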
    @contextlib.contextmanager
    def _patch_quantized_parameters(self):
        with contextlib.ExitStack() as stack:
            for param_name, param_quantizer in self.param_quantizers.items():
                if param_quantizer:
                    orig_param = getattr(self, param_name)
                    quantized_param = param_quantizer(orig_param)
                    ctx = patch_attr(self, param_name, quantized_param)
                    stack.enter_context(ctx)
            yield

    def _compute_param_encodings(self, overwrite: bool):
        """
        :param bool overwrite: If True, the quantizers that are already initialized will also recompute encodings.
            Otherwise, only the uninitialized quantizers will compute encodings.
        """
        for param_name, param_quantizer in self.param_quantizers.items():
            if not param_quantizer:
                continue

            if not param_quantizer._allow_overwrite: # pylint: disable=protected-access
                continue

            if not param_quantizer.is_initialized() or overwrite:
                param = getattr(self, param_name)
                if param is not None:
                    with patch_attr(param_quantizer, "forward", _no_op), param_quantizer.compute_encodings():
                        _ = param_quantizer(param)

    def compute_param_encodings(self):
        """ Compute encodings of parameter quantizers """
        self._compute_param_encodings(overwrite=True)
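    # Illustrative sketch (not part of the original source): parameter encodings can
    # be computed without running any data through the layer. QuantizedLinear and
    # Quantize are assumed to come from aimet_torch.v2; exact constructor arguments
    # may differ.
    #
    #   >>> qlinear = QuantizedLinear(10, 10)
    #   >>> qlinear.param_quantizers['weight'] = Quantize((), 8, symmetric=True)
    #   >>> qlinear.compute_param_encodings()
    #   >>> qlinear.param_quantizers['weight'].is_initialized()
    #   True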
    @contextlib.contextmanager
    def compute_encodings(self):
        """Enters the :meth:`compute_encodings` context for all :class:`QuantizerBase` objects in the layer.

        Inside this context, each quantizer will observe all inputs passed to the quantizer and will compute
        quantization encodings upon exiting the context.

        Example:

            >>> qlinear = QuantizedLinear(10, 10)
            >>> qlinear.output_quantizers[0] = Quantize((), 8, symmetric=False)
            >>> with qlinear.compute_encodings():
            ...     qlinear(torch.randn(16, 10))
            >>> print(qlinear.output_quantizers[0].is_initialized())
            True
        """
        self._compute_param_encodings(overwrite=True)

        with contextlib.ExitStack() as stack:
            input_quantizers = flatten_nn_module_list(self.input_quantizers)
            output_quantizers = flatten_nn_module_list(self.output_quantizers)

            for quantizer in itertools.chain(input_quantizers, output_quantizers):
                if not isinstance(quantizer, QuantizerBase):
                    continue

                if not quantizer._allow_overwrite: # pylint: disable=protected-access
                    continue

                # Set input/output quantizers into pass-through mode during compute_encodings
                # NOTE: This behavior is for backward-compatibility with V1 quantsim.
                stack.enter_context(patch_attr(quantizer, 'forward', _no_op))

                ctx = quantizer.compute_encodings()
                stack.enter_context(ctx)

            yield
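    # Illustrative sketch (not part of the original source): calibrating over several
    # batches. Every quantizer observes all tensors seen inside the context and
    # computes its encodings on exit. `calibration_data` is a hypothetical iterable
    # of input tensors; `qlinear` is set up as in the docstring example above.
    #
    #   >>> with qlinear.compute_encodings():
    #   ...     for batch in calibration_data:
    #   ...         _ = qlinear(batch)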
    @classmethod
    @abc.abstractmethod
    def wrap(cls, module_cls: Type[nn.Module]):
        """
        Wrap a regular module class into a quantized module class
        """

    @classmethod
    def from_module(cls, module: nn.Module):
        r"""Create an instance of quantized module from a regular module instance.

        The resulting quantized module contains the same attributes and parameters as the original module, but may
        be assigned input, output and parameter quantizers.

        :param module: Floating point module to quantize
        :return: Quantized version of the original module

        Example:

            >>> linear = torch.nn.Linear(10, 10)
            >>> quantized_linear = FakeQuantizationMixin.from_module(linear)
            >>> print(quantized_linear.weight is linear.weight)
            True
            >>> print(quantized_linear.param_quantizers)
            ModuleDict(
                (weight): None
                (bias): None
            )
        """
        # pylint: disable=protected-access
        module_cls = type(module)
        qtzn_module_cls = cls.cls_to_qcls.get(module_cls, None)

        if not qtzn_module_cls:
            raise RuntimeError(
                f'The quantized module definition of {module_cls} is not registered. '
                f'Please register the quantized module definition of {module_cls} '
                f'using `@{cls.__name__}.implements({module_cls.__name__})` decorator.'
            )

        qtzn_module = cls.__new__(qtzn_module_cls)

        qtzn_module.__dict__ = module.__dict__.copy()
        qtzn_module._modules = module._modules.copy()
        qtzn_module._parameters = module._parameters.copy()
        qtzn_module._buffers = module._buffers.copy()

        qtzn_module.__quant_init__()
        return qtzn_module

    def export_input_encodings(self) -> List[List[Dict]]:
        """
        Returns a list of input encodings, each represented as a List of Dicts
        """
        return [
            quantizer.get_legacy_encodings() if isinstance(quantizer, QuantizerBase) else None
            for quantizer in flatten_nn_module_list(self.input_quantizers)
        ]

    def import_input_encodings(self,
                               encodings: Mapping[str, Mapping],
                               strict: bool,
                               partial: bool,
                               requires_grad: Optional[bool],
                               allow_overwrite: bool):
        """
        Import input encodings represented in below format:

        {
            '0': dict,
            '1': dict,
            ...
        }

        :param encodings: Dictionary mapping quantizer index (str) to encoding (dict)
        :param strict: If True, raise an error when an encoding is provided for a quantizer that does not exist
        :param partial: If False, quantizers with no corresponding entry in `encodings` are removed
        :param requires_grad: If not None, set requires_grad of the loaded encodings to this value
        :param allow_overwrite: If False, prevent the loaded encodings from being overwritten or recomputed later
        """
        for i, quantizer in enumerate(list(self.input_quantizers)):
            if quantizer and not quantizer._allow_overwrite: # pylint: disable=protected-access
                continue

            encoding = encodings.get(str(i), None)
            if not encoding:
                if not partial:
                    # Dangling quantizers have to be removed when importing non-partial encodings
                    self.input_quantizers[i] = None
                continue
            if quantizer is None:
                if strict:
                    raise RuntimeError
                continue
            if isinstance(encoding, dict):
                encoding = [encoding]
            quantizer.set_legacy_encodings(encoding)

            if requires_grad is not None:
                quantizer.requires_grad_(requires_grad)

            quantizer.allow_overwrite(allow_overwrite)

    def export_output_encodings(self) -> List[List[Dict]]:
        """
        Returns a list of output encodings, each represented as a List of Dicts
        """
        return [
            quantizer.get_legacy_encodings() if isinstance(quantizer, QuantizerBase) else None
            for quantizer in flatten_nn_module_list(self.output_quantizers)
        ]

    def import_output_encodings(self,
                                encodings: Mapping[str, Mapping],
                                strict: bool,
                                partial: bool,
                                requires_grad: Optional[bool],
                                allow_overwrite: bool):
        """
        Import output encodings represented in below format:

        {
            '0': dict,
            '1': dict,
            ...
        }

        :param encodings: Dictionary mapping quantizer index (str) to encoding (dict)
        :param strict: If True, raise an error when an encoding is provided for a quantizer that does not exist
        :param partial: If False, quantizers with no corresponding entry in `encodings` are removed
        :param requires_grad: If not None, set requires_grad of the loaded encodings to this value
        :param allow_overwrite: If False, prevent the loaded encodings from being overwritten or recomputed later
        """
        for i, quantizer in enumerate(list(self.output_quantizers)):
            if quantizer and not quantizer._allow_overwrite: # pylint: disable=protected-access
                continue

            encoding = encodings.get(str(i), None)
            if not encoding:
                if not partial:
                    # Dangling quantizers have to be removed when importing non-partial encodings
                    self.output_quantizers[i] = None
                continue
            if quantizer is None:
                if strict:
                    raise RuntimeError
                continue
            if isinstance(encoding, dict):
                encoding = [encoding]
            quantizer.set_legacy_encodings(encoding)

            if requires_grad is not None:
                quantizer.requires_grad_(requires_grad)

            quantizer.allow_overwrite(allow_overwrite)

    def export_param_encodings(self) -> Dict[str, List[Dict]]:
        """
        Returns a dict of {param name: param encodings}, with each encoding represented as a List of Dicts
        """
        encodings = {
            param_name: quantizer.get_legacy_encodings() if isinstance(quantizer, QuantizerBase) else None
            for param_name, quantizer in self.param_quantizers.items()
        }
        for param_name, quantizer in self.param_quantizers.items():
            param = getattr(self, param_name)
            if isinstance(quantizer, QuantizerBase):
                e = encodings[param_name]
            elif isinstance(param, QuantizedTensorBase) and param.encoding is not None:
                # If parameter itself is an already-quantized tensor,
                # export the encoding held by the parameter
                e = param.encoding._to_legacy_format() # pylint: disable=protected-access
            else:
                e = None
            encodings[param_name] = e
        return encodings

    def import_param_encodings(self,
                               encodings: Mapping[str, Mapping],
                               strict: bool,
                               partial: bool,
                               requires_grad: Optional[bool],
                               allow_overwrite: bool):
        """
        Import parameter encodings represented in below format:

        {
            'param_name_0': [dict, dict, ...],
            'param_name_1': [dict, dict, ...],
            ...
        }

        :param encodings: Dictionary mapping parameter name (str) to a list of encodings (dict)
        :param strict: If True, raise an error when an encoding is provided for a quantizer that does not exist
        :param partial: If False, quantizers with no corresponding entry in `encodings` are removed
        :param requires_grad: If not None, set requires_grad of the loaded encodings to this value
        :param allow_overwrite: If False, prevent the loaded encodings from being overwritten or recomputed later
        """
        for param_name, quantizer in dict(self.param_quantizers).items():
            if quantizer and not quantizer._allow_overwrite: # pylint: disable=protected-access
                continue

            encoding = encodings.get(param_name, None)

            if is_vector_encoding(encoding):
                # Vector encodings will be held directly by weights, not by quantizers.
                quantizer.set_legacy_encodings(encoding)
                param = getattr(self, param_name)
                rounded_weight = quantizer(param)
                # At this point, rounded_weight is a quantized tensor with affine encoding
                # since quantizer is an affine quantizer
                assert isinstance(rounded_weight, QuantizedTensorBase)
                assert isinstance(rounded_weight.encoding, AffineEncoding)
                e = rounded_weight.encoding
                # Convert affine encoding to vector encoding
                vector_encoding_properties = {
                    "rows_per_block": encoding[0]["rows_per_block"],
                    "cols_per_block": encoding[0]["cols_per_block"],
                    "vector_dim": encoding[0]["vector_dim"],
                    "vector_stride": encoding[0]["vector_stride"],
                    "index_bw": encoding[0]["index_bw"],
                }
                rounded_weight.encoding = VectorEncoding(e.scale,
                                                         e.offset,
                                                         e.bitwidth,
                                                         e.signed,
                                                         e.symmetry,
                                                         block_size=None,
                                                         **vector_encoding_properties)
                setattr(self, param_name, nn.Parameter(rounded_weight))
                # Remove associated quantizer since the weight is holding already-quantized values
                self.param_quantizers[param_name] = None
                continue

            if not encoding:
                if not partial:
                    # Dangling quantizers have to be removed when importing non-partial encodings
                    self.param_quantizers[param_name] = None
                continue
            if quantizer is None:
                if strict:
                    raise RuntimeError
                continue
            if isinstance(encoding, dict):
                encoding = [encoding]
            quantizer.set_legacy_encodings(encoding)

            if requires_grad is not None:
                quantizer.requires_grad_(requires_grad)

            quantizer.allow_overwrite(allow_overwrite)

    def get_original_module(self) -> nn.Module:
        """Returns the floating point version of the quantized module

        Returns:
            A floating point module with quantizers removed

        Example:

            >>> qlinear = QuantizedLinear(10, 20, bias=False)
            >>> linear = qlinear.get_original_module()
            >>> linear
            Linear(in_features=10, out_features=20, bias=False)
            >>> linear.weight is qlinear.weight
            True
        """
        # pylint: disable=protected-access

        qtzn_module_cls = type(self)
        orig_module_cls = self.qcls_to_cls.get(qtzn_module_cls)

        orig_module = self.__new__(orig_module_cls)
        orig_module.__dict__ = self.__dict__.copy()
        orig_module.__dict__.pop('forward', None)

        orig_module._parameters = self._parameters.copy()
        orig_module._buffers = self._buffers.copy()
        orig_module._modules = self._modules.copy()
        del orig_module._modules['input_quantizers']
        del orig_module._modules['output_quantizers']
        del orig_module._modules['param_quantizers']

        return orig_module

    def _remove_input_quantizers(self, indices: Union[int, Iterable[int]] = None):
        """
        Remove input quantizers
        :param indices: Indices of input quantizers to remove.
                If None, all input quantizers will be removed.
        """
        if isinstance(indices, int):
            indices = [indices]
        elif indices is None:
            indices = list(range(len(self.input_quantizers)))
        return _remove_quantizers(self.input_quantizers, indices)

    def _remove_param_quantizers(self, keys: Union[str, Iterable[str]] = None):
        """
        Remove parameter quantizers
        :param keys: Names of parameter quantizers to remove.
                If None, all parameter quantizers will be removed.
        """
        if isinstance(keys, str):
            keys = [keys]
        elif keys is None:
            keys = list(self.param_quantizers.keys())
        return _remove_quantizers(self.param_quantizers, keys)

    def _remove_output_quantizers(self, indices: Union[int, Iterable[int]] = None):
        """
        Remove output quantizers
        :param indices: Indices of output quantizers to remove.
                If None, all output quantizers will be removed.
""" if isinstance(indices, int): indices = [indices] elif indices is None: indices = list(range(len(self.output_quantizers))) return _remove_quantizers(self.output_quantizers, indices) def _remove_activation_quantizers(self): """ Remove all activation quantizers """ # pylint: disable=protected-access ctx_1 = self._remove_output_quantizers() ctx_2 = self._remove_input_quantizers() return _ContextManager(action=lambda: None, cleanup=lambda: (ctx_1._cleanup(), ctx_2._cleanup())) def _remove_all_quantizers(self): """ Remove all quantizers """ # pylint: disable=protected-access ctx_1 = self._remove_activation_quantizers() ctx_2 = self._remove_param_quantizers() return _ContextManager(action=lambda: None, cleanup=lambda: (ctx_1._cleanup(), ctx_2._cleanup()))
class _BaseQuantizedUnaryOpMixin(BaseQuantizationMixin):
    def forward(self, *args, **kwargs) -> Tensor: # pylint: disable=missing-function-docstring
        x, *others = args

        if isinstance(x, Tensor) and x.is_floating_point() and self.input_quantizers[0]:
            x = self.input_quantizers[0](x)

        with self._patch_quantized_parameters():
            output = super().forward(x, *others, **kwargs)

        if isinstance(output, Tensor) and output.is_floating_point() and self.output_quantizers[0]:
            output = self.output_quantizers[0](output)

        return output


class _BaseQuantizedBinaryOpMixin(BaseQuantizationMixin):
    def __quant_init__(self):
        super().__quant_init__()
        self.input_quantizers = nn.ModuleList([None, None])

    def forward(self, *args, **kwargs) -> Tensor: # pylint: disable=missing-function-docstring
        x, y, *others = args

        if isinstance(x, Tensor) and x.is_floating_point() and self.input_quantizers[0]:
            x = self.input_quantizers[0](x)

        if isinstance(y, Tensor) and y.is_floating_point() and self.input_quantizers[1]:
            y = self.input_quantizers[1](y)

        with self._patch_quantized_parameters():
            output = super().forward(x, y, *others, **kwargs)

        if isinstance(output, Tensor) and output.is_floating_point() and self.output_quantizers[0]:
            output = self.output_quantizers[0](output)

        return output


class _BaseQuantizedTernaryOpMixin(BaseQuantizationMixin):
    def __quant_init__(self):
        super().__quant_init__()
        self.input_quantizers = nn.ModuleList([None, None, None])

    def forward(self, *args, **kwargs) -> Tensor: # pylint: disable=missing-function-docstring
        x, y, z, *others = args

        if isinstance(x, Tensor) and x.is_floating_point() and self.input_quantizers[0]:
            x = self.input_quantizers[0](x)

        if isinstance(y, Tensor) and y.is_floating_point() and self.input_quantizers[1]:
            y = self.input_quantizers[1](y)

        if isinstance(z, Tensor) and z.is_floating_point() and self.input_quantizers[2]:
            z = self.input_quantizers[2](z)

        with self._patch_quantized_parameters():
            output = super().forward(x, y, z, *others, **kwargs)

        if isinstance(output, Tensor) and output.is_floating_point() and self.output_quantizers[0]:
            output = self.output_quantizers[0](output)

        return output


def _remove_quantizers(quantizers, keys):
    orig_quantizers = {key: quantizers[key] for key in keys}

    def restore_quantizers():
        for key, orig_qtzr in orig_quantizers.items():
            quantizers[key] = orig_qtzr

    ctx = _ContextManager(action=lambda: None,
                          cleanup=restore_quantizers)

    try:
        for key in keys:
            quantizers[key] = None
    except Exception:
        ctx._cleanup() # pylint: disable=protected-access
        raise
    else:
        return ctx
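# Illustrative sketch (not part of the original source): converting a floating-point
# module to its quantized counterpart and back. FakeQuantizationMixin is assumed to
# be a concrete subclass with registered quantized definitions (see the from_module
# and get_original_module docstrings above); parameters are shared, not copied.
#
#   >>> import torch
#   >>> linear = torch.nn.Linear(10, 10)
#   >>> qlinear = FakeQuantizationMixin.from_module(linear)   # shares weight/bias with `linear`
#   >>> restored = qlinear.get_original_module()              # floating-point view, quantizers dropped
#   >>> restored.weight is linear.weight
#   True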