# Source code for aimet_torch.v2.quantization.base.quantizer

# -*- mode: python -*-
# =============================================================================
#  @@-COPYRIGHT-START-@@
#
#  Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  1. Redistributions of source code must retain the above copyright notice,
#     this list of conditions and the following disclaimer.
#
#  2. Redistributions in binary form must reproduce the above copyright notice,
#     this list of conditions and the following disclaimer in the documentation
#     and/or other materials provided with the distribution.
#
#  3. Neither the name of the copyright holder nor the names of its contributors
#     may be used to endorse or promote products derived from this software
#     without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
#  POSSIBILITY OF SUCH DAMAGE.
#
#  SPDX-License-Identifier: BSD-3-Clause
#
#  @@-COPYRIGHT-END-@@
# =============================================================================
""" Quantizer base class """

import abc
import copy
from collections import OrderedDict
import contextlib
import weakref
from typing import Optional, List, Dict
import functools

import torch
from torch import nn
from torch.utils._pytree import tree_map

from packaging import version  # pylint: disable=wrong-import-order
from aimet_torch.v2.quantization.base import EncodingBase
from aimet_torch.v2.quantization.encoding_analyzer import EncodingAnalyzer


# Names exported by ``from <module> import *``; only the base class is public.
__all__ = [
    "QuantizerBase",
]


class QuantizerBase(abc.ABC, torch.nn.Module):
    """ Quantizer base class """

    encoding_analyzer: EncodingAnalyzer

    # torch < 1.10 has no get_extra_state()/set_extra_state() hooks, so the extra state
    # (which parameters are initialized) has to be saved/restored manually.
    _LEGACY_EXTRA_STATE = version.parse(torch.__version__) < version.parse("1.10")

    # Key under which the extra state is stored in a state dict. This is the same key
    # load_state_dict() checks for and that torch >= 1.10 uses for extra state.
    _EXTRA_STATE_KEY = '_extra_state'

    def __init__(self):
        super().__init__()
        # param_name -> (weakref of initial parameter, version info of the initial parameter)
        # This info will be used for judging whether the current parameter has ever been
        # initialized after it was instantiated.
        self._initial_parameters = OrderedDict()
        self._allow_overwrite = True

    @abc.abstractmethod
    @contextlib.contextmanager
    def compute_encodings(self):
        """
        Observe inputs and update quantization parameters based on the input statistics.

        Subclasses implement this as a generator; it is used as a context manager.
        """

    @abc.abstractmethod
    def get_legacy_encodings(self) -> Optional[List[Dict]]:
        """
        Returns a list of encodings, each represented as a List of Dicts
        """

    @abc.abstractmethod
    def set_legacy_encodings(self, encodings: List[Dict]):
        """
        Set encodings represented in the same format as the output of get_legacy_encodings.
        """

    @abc.abstractmethod
    def get_encoding(self) -> Optional[EncodingBase]:
        """
        Return the quantizer's encodings as an EncodingBase object
        """

    def register_quantization_parameter(self, name: str, param: nn.Parameter):
        """
        Register quantization parameter.

        Registers ``param`` as a regular module parameter under ``name`` and records a weak
        reference to it together with its current ``_version`` counter. ``is_initialized``
        later compares against this snapshot to decide whether the parameter was modified.

        :param name: Attribute name of the parameter.
        :param param: Parameter to register.
        """
        # pylint: disable=protected-access
        self.register_parameter(name, param)
        param = getattr(self, name)
        self._initial_parameters[name] = (weakref.ref(param), param._version)

    def is_initialized(self) -> bool:
        """
        Returns true if the quantization parameters are initialized.

        :return: True iff every parameter yielded by ``named_parameters()`` is initialized.
        """
        return all(self._is_initialized(param_name) for param_name, _ in self.named_parameters())

    def _is_initialized(self, param_name) -> bool:
        """
        Return whether ``param_name`` has been initialized since registration.

        A parameter is considered *uninitialized* only if it is still the very object that
        was registered and its in-place version counter has not changed since.
        """
        # pylint: disable=protected-access
        initial_param_weakref, initial_param_version = self._initial_parameters[param_name]
        initial_param = initial_param_weakref()

        if initial_param is None:
            # The initial parameter object doesn't exist in memory space anymore.
            return True

        current_param = getattr(self, param_name)

        if current_param is initial_param and current_param._version == initial_param_version:
            # 1. Current parameter is the identical object as the initial parameter
            # 2. The version number of the current parameter never changed
            return False

        return True

    def state_dict(self, *args, **kwargs):  # pylint: disable=arguments-differ
        """
        Return the module state dict.

        On torch < 1.10 the extra state is added manually under ``<prefix>_extra_state``,
        the same key ``load_state_dict`` reads.
        """
        state_dict = super().state_dict(*args, **kwargs)  # pylint: disable=missing-kwoa

        if self._LEGACY_EXTRA_STATE:
            # This is for backward compatibility with torch < 1.10
            # which doesn't support get/set_extra_state() hooks.
            # ``prefix`` may be passed positionally (signature is
            # ``state_dict(destination, prefix, keep_vars)``) or not at all.
            if 'prefix' in kwargs:
                prefix = kwargs['prefix']
            elif len(args) > 1:
                prefix = args[1]
            else:
                prefix = ''
            state_dict[f'{prefix}{self._EXTRA_STATE_KEY}'] = self.get_extra_state()

        return state_dict

    def load_state_dict(self, state_dict, strict: bool = True):  # pylint:disable=arguments-differ
        """
        Load a state dict, restoring which parameters are initialized.

        If ``state_dict`` carries no extra state, every quantization parameter present in
        it is treated as initialized. The caller's ``state_dict`` is never modified.
        """
        # Work on a shallow copy so the caller's dict is not mutated; keep the
        # ``_metadata`` attribute the same way nn.Module.load_state_dict does.
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = OrderedDict(state_dict)
        if metadata is not None:
            state_dict._metadata = metadata  # pylint: disable=protected-access

        if self._EXTRA_STATE_KEY not in state_dict:
            is_initialized = OrderedDict({
                param_name: torch.tensor(True)
                for param_name in state_dict if param_name in self._parameters
            })
            state_dict[self._EXTRA_STATE_KEY] = is_initialized

        if not self._LEGACY_EXTRA_STATE:
            # torch >= 1.10 dispatches the extra state to set_extra_state() itself.
            return super().load_state_dict(state_dict, strict)

        # This is for backward compatibility with torch < 1.10 which doesn't support
        # get/set_extra_state() hooks. Remove the extra state first so that strict
        # loading doesn't report it as an unexpected key, then apply it manually.
        extra_state = state_dict.pop(self._EXTRA_STATE_KEY)
        ret = super().load_state_dict(state_dict, strict)
        self.set_extra_state(extra_state)
        return ret

    def get_extra_state(self):
        """
        Get extra state that describes which parameters are initialized.

        :return: OrderedDict mapping parameter name to a boolean scalar tensor.
        """
        extra_state_dict = OrderedDict({
            param_name: torch.tensor(self._is_initialized(param_name))
            for param_name, _ in self.named_parameters()
        })

        # NOTE: This is a hack to bypass a bug in PyTorch onnx export
        #       where it assumes state dict is always Mapping[str, Tensor]
        #       and tries to `.detach()` all the values in the state dict.
        setattr(extra_state_dict, 'detach',
                functools.partial(tree_map, torch.Tensor.detach, extra_state_dict))
        return extra_state_dict

    @torch.no_grad()
    def set_extra_state(self, state):
        """
        Set extra state that describes which parameters are initialized.

        Every parameter named in ``state`` is re-registered, which snapshots its current
        version. Parameters flagged as initialized then get their version bumped so they
        compare as modified.

        :param state: Mapping of parameter name to truthy/falsy (typically bool tensors).
        """
        is_initialized = state
        # Iterate over a snapshot: re-registering writes back into self._parameters.
        for param_name, param in list(self._parameters.items()):
            if param is None or param_name not in is_initialized:
                continue
            self.register_quantization_parameter(param_name, param)
            if is_initialized[param_name]:
                # If the parameter has been already initialized,
                # artificially increment the parameter version to mark as initialized
                param.mul_(1.)

    @torch.no_grad()
    def __deepcopy__(self, memo):
        """
        Deep-copy the quantizer.

        Initialization tracking is re-established on the copied parameters rather than
        pointing at the original ones.
        """
        self_copy = self.__new__(type(self))
        # Register the copy before recursing so that references back to ``self``
        # inside __dict__ resolve to the copy instead of being duplicated.
        memo[id(self)] = self_copy
        self_copy.__dict__ = copy.deepcopy(self.__dict__, memo)
        self_copy.set_extra_state(self.get_extra_state())
        return self_copy

    def __getstate__(self):
        """
        Return picklable state.

        Weakrefs in ``_initial_parameters`` are replaced by an ``is_initialized`` snapshot.
        """
        getstate = getattr(super(), '__getstate__', self.__dict__.copy)
        # Shallow-copy: object.__getstate__ (Python >= 3.11) may return the live
        # instance __dict__, which must not be modified here.
        state = dict(getstate())
        state.pop('_initial_parameters', None)
        state['is_initialized'] = self.get_extra_state()
        return state

    @torch.no_grad()
    def __setstate__(self, state):
        """
        Restore from ``__getstate__`` output and rebuild initialization tracking.
        """
        self._initial_parameters = OrderedDict()
        is_initialized = state.pop('is_initialized')
        setstate = getattr(super(), '__setstate__', self.__dict__.update)
        setstate(state)
        self.set_extra_state(is_initialized)

    def allow_overwrite(self, mode: bool):
        """
        Set allow_overwrite flag
        """
        self._allow_overwrite = mode