Source code for aimet_torch.quantization.base.quantizer

# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause


"""Quantizer base class"""

import abc
import copy
from collections import OrderedDict
import contextlib
import weakref
from typing import Optional, List, Dict, TYPE_CHECKING, overload

import torch
from torch import nn

from packaging import version
from aimet_torch.common.utils import deprecated
from aimet_torch.quantization.base import EncodingBase
from aimet_torch.quantization.encoding_analyzer import EncodingAnalyzer

if TYPE_CHECKING:
    # pylint: disable=cyclic-import
    from aimet_torch.quantization.tensor import QuantizedTensorBase


__all__ = ["QuantizerBase"]


class QuantizerBase(abc.ABC, torch.nn.Module):
    """
    Quantizer base class
    """

    encoding_analyzer: EncodingAnalyzer

    def __init__(self):
        super().__init__()
        # param_name -> (weakref of initial parameter, version info of the initial parameter)
        # This info will be used for judging whether the current parameter has ever been
        # initialized after it was instantiated.
        self._initial_parameters = OrderedDict()
        self._is_overwrite_allowed: dict[str, bool] = OrderedDict()
    def forward(self, input: torch.Tensor) -> "QuantizedTensorBase":  # pylint: disable=redefined-builtin
        """
        Quantize the input tensor

        Args:
            input (torch.Tensor): Input tensor to quantize
        """
        # Call parent's forward to throw NotImplementedError
        return super().forward(input)
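    # Illustrative call pattern (a sketch, not part of this module; `quantizer`
    # and `x` are hypothetical names). Once a concrete subclass is calibrated,
    # calling it returns a QuantizedTensorBase, which can typically be mapped
    # back to real values:
    #
    #   out = quantizer(x)
    #   real = out.dequantize()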
    @abc.abstractmethod
    @contextlib.contextmanager
    def compute_encodings(self):
        """
        Observe inputs and update quantization parameters based on the input statistics.
        """
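    # Typical calibration flow for a concrete subclass (illustrative sketch;
    # `quantizer` and `calibration_loader` are hypothetical names):
    #
    #   with quantizer.compute_encodings():
    #       for batch in calibration_loader:
    #           quantizer(batch)            # observe input statistics
    #
    # On exiting the context, quantization parameters are derived from the
    # observed statistics and the quantizer becomes initialized.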
    @abc.abstractmethod
    def get_legacy_encodings(self) -> Optional[List[Dict]]:
        """
        Returns a list of encodings, each represented as a List of Dicts
        """

    @abc.abstractmethod
    def set_legacy_encodings(self, encodings: List[Dict]):
        """
        Set encodings represented in the same format as the output of get_legacy_encodings.
        """

    @abc.abstractmethod
    def get_encodings(self) -> Optional[EncodingBase]:
        """
        Return the quantizer's encodings as an EncodingBase object
        """

    @deprecated(f"Use {get_encodings.__qualname__} instead")
    def get_encoding(self) -> Optional[EncodingBase]:
        """
        Alias of get_encodings
        """
        return self.get_encodings()

    def set_encodings(self, encodings: EncodingBase):
        """
        Set the quantizer's encodings
        """
        raise NotImplementedError

    @classmethod
    @abc.abstractmethod
    def from_encodings(cls, encodings: EncodingBase) -> "QuantizerBase":
        """
        Create quantizer object from encoding object
        """

    def register_quantization_parameter(self, name: str, param: Optional[nn.Parameter]):
        """
        Register quantization parameter.
        """
        # pylint: disable=protected-access
        self.register_parameter(name, param)
        if param is not None:
            self._initial_parameters[name] = (weakref.ref(param), param._version)
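    # Why (weakref, _version) is recorded above (an illustrative note, not part
    # of the original source): PyTorch bumps a tensor's internal ``_version``
    # counter on every in-place write, e.g.
    #
    #   p = torch.nn.Parameter(torch.zeros(1))   # p._version == 0
    #   p.copy_(torch.ones(1))                   # p._version == 1
    #
    # so a parameter whose identity and version both match the snapshot taken
    # at registration time has never been written to since it was created.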
    def is_initialized(self) -> bool:
        """
        Returns true if the quantization parameters are initialized.
        """
        return all(
            self._is_initialized(param_name, param)
            for param_name, param in self.named_parameters()
        )
    def _is_initialized(
        self, param_name: str, current_param: torch.nn.Parameter
    ) -> bool:
        # pylint: disable=protected-access
        initial_param_weakref, initial_param_version = self._initial_parameters.get(
            param_name, (None, None)
        )
        if not initial_param_weakref:
            # parameters created using register_parameter need not be initialized
            return True

        initial_param = initial_param_weakref()
        if initial_param is None:
            # The initial parameter object doesn't exist in memory space anymore.
            return True

        if (
            current_param is initial_param
            and current_param._version == initial_param_version
        ):
            # 1. Current parameter is the identical object as the initial parameter
            # 2. The version number of the current parameter never changed
            return False

        return True

    def state_dict(self, *args, **kwargs):  # pylint: disable=arguments-differ
        state_dict = super().state_dict(*args, **kwargs)  # pylint: disable=missing-kwoa

        if version.parse(torch.__version__) < version.parse("1.10"):
            # This is for backward compatibility with torch < 1.10
            # which doesn't support get/set_extra_state() hooks
            prefix = kwargs.get("prefix", "")
            state_dict[f"{prefix}_extra_state"] = self.get_extra_state()

        if (
            version.parse(torch.__version__) < version.parse("2")
            and torch.onnx.is_in_onnx_export()
        ):
            # During ONNX export in torch 1.x, state_dict cannot contain
            # non-tensor objects due to a bug in torch.onnx.export.
            # Skip adding extra_state because extra_state is unnecessary anyway
            # for ONNX export.
            prefix = kwargs.get("prefix", "")
            state_dict.pop(f"{prefix}_extra_state", None)

        return state_dict

    def load_state_dict(self, state_dict, *args, **kwargs):  # pylint:disable=arguments-differ
        if "_extra_state" not in state_dict:
            is_initialized = OrderedDict(
                {
                    param_name: torch.tensor(True)
                    for param_name in state_dict
                    if param_name in self._parameters
                }
            )
            state_dict["_extra_state"] = is_initialized

        ret = super().load_state_dict(state_dict, *args, **kwargs)

        if version.parse(torch.__version__) < version.parse("1.10"):
            # This is for backward compatibility with torch < 1.10
            # which doesn't support get/set_extra_state() hooks
            self.set_extra_state(state_dict["_extra_state"])

        return ret

    def get_extra_state(self):
        """
        Get extra state that describes which parameters are initialized.
        """
        if torch.onnx.is_in_onnx_export():
            # Bypass get_extra_state during ONNX export.
            # ONNX export doesn't support non-tensor objects in state_dict.
            # Return empty tensor since extra state is unnecessary for ONNX export anyway
            return torch.tensor([])

        return {
            param_name: torch.tensor(self._is_initialized(param_name, param))
            for param_name, param in self.named_parameters()
        }

    @torch.no_grad()
    def set_extra_state(self, state):
        """
        Set extra state that describes which parameters are initialized.
""" is_initialized = state for param_name, param in self._parameters.items(): if param_name in is_initialized: self.register_quantization_parameter(param_name, param) if is_initialized[param_name]: # If the parameter has been already initialized, # artificially increment the parameter version to mark as initialized param.mul_(1.0) @torch.no_grad() def __deepcopy__(self, memo): cls = type(self) self_copy = cls.__new__(cls) self_copy.__dict__ = copy.deepcopy(self.__dict__, memo) self_copy.set_extra_state(self.get_extra_state()) return self_copy def __getstate__(self): getstate = getattr(super(), "__getstate__", self.__dict__.copy) state = getstate() state.pop("_initial_parameters") state["is_initialized"] = self.get_extra_state() return state @torch.no_grad() def __setstate__(self, state): self._initial_parameters = OrderedDict() is_initialized = state.pop("is_initialized") setstate = getattr(super(), "__setstate__", self.__dict__.update) setstate(state) self.set_extra_state(is_initialized) @overload def allow_overwrite(self, mode: bool): """Set allow_overwite flag""" @overload def allow_overwrite(self, **kwargs): """Set allow_overwite flag""" def allow_overwrite(self, *args, **kwargs): mode = kwargs.get("mode", args[0] if args else None) if mode is not None: allow_overwrite = { param_name: mode for param_name in self._is_overwrite_allowed.keys() } else: allow_overwrite = kwargs.copy() expected_keys = self._is_overwrite_allowed.keys() unexpected_keys = allow_overwrite.keys() - expected_keys if unexpected_keys: unexpected_keys = sorted(list(unexpected_keys)) expected_keys = sorted(list(expected_keys)) raise RuntimeError( f"'allow_overwrite' expected param names {expected_keys};" f" got unexpected parameter names {unexpected_keys}" ) self._is_overwrite_allowed.update(allow_overwrite) def is_overwrite_allowed(self, name: str): return self._is_overwrite_allowed[name] # Define _allow_overwrite getter/setter for backwards compatibility @property @deprecated( f"Use {is_overwrite_allowed.__qualname__}(<param_name: str>) function instead" ) def _allow_overwrite(self) -> bool: return any(self._is_overwrite_allowed.values()) @_allow_overwrite.setter @deprecated(f"Use {allow_overwrite.__qualname__} function instead") def _allow_overwrite(self, mode: bool): self.allow_overwrite(mode) @contextlib.contextmanager def _precompute_encodings(self, dtype: torch.dtype | None = None): # pylint: disable=unused-argument yield