# -*- mode: python -*-
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
#    this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
#    this list of conditions and the following disclaimer in the documentation
#    and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors
#    may be used to endorse or promote products derived from this software
#    without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# SPDX-License-Identifier: BSD-3-Clause
#
# @@-COPYRIGHT-END-@@
# =============================================================================
# pylint: disable=too-many-lines, wrong-import-order, redefined-builtin
""" Quantized modules"""

from packaging import version
import contextlib
import itertools
from inspect import signature
from abc import abstractmethod, ABCMeta
from collections import OrderedDict
from typing import Type, Any, Optional, Callable, Dict
from weakref import WeakKeyDictionary
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.overrides import BaseTorchFunctionMode, get_overridable_functions
from torch._VF import ( # pylint: disable=no-name-in-module
    gru as _gru,
    gru_cell as _gru_cell,
    lstm as _lstm,
    lstm_cell as _lstm_cell,
    rnn_relu as _rnn_relu,
    rnn_tanh as _rnn_tanh,
    rnn_relu_cell as _rnn_relu_cell,
    rnn_tanh_cell as _rnn_tanh_cell,
)

from aimet_torch.v2.quantization.base import QuantizerBase
from aimet_torch.v2.quantization.tensor import QuantizedTensorBase
from aimet_torch.v2.utils import patch_attr, _ContextManager, allow_recompute
from .base import BaseQuantizationMixin # pylint: disable=import-error


def _quantize_if_applicable(data: Any, quantizer: Optional[QuantizerBase]):
    """ Quantize data if it is a quantizable type and quantizer is not None
"""ifquantizerandisinstance(data,Tensor)anddata.is_floating_point():ifisinstance(data,QuantizedTensorBase):data=data.dequantize()returnquantizer(data)ifisinstance(data,QuantizedTensorBase):returndata.quantize()returndatadef_dequantize_if_applicable(data:torch.Tensor):returndata.dequantize()ifisinstance(data,QuantizedTensorBase)elsedatadef_quantize_dequantize_if_applicable(data,quantizer):ifquantizerandisinstance(data,Tensor)anddata.is_floating_point():ifisinstance(data,QuantizedTensorBase):data=data.dequantize()data=quantizer(data)ifisinstance(data,QuantizedTensorBase):returndata.dequantize()returndata_QUANTIZED_MODULES_UNDER_COMPUTE_ENCODINGS=WeakKeyDictionary()def_is_computing_encodings(qmodule):return_QUANTIZED_MODULES_UNDER_COMPUTE_ENCODINGS.get(qmodule,0)>0def_enter_computing_encodings(qmodule):ifqmodulenotin_QUANTIZED_MODULES_UNDER_COMPUTE_ENCODINGS:_QUANTIZED_MODULES_UNDER_COMPUTE_ENCODINGS[qmodule]=0_QUANTIZED_MODULES_UNDER_COMPUTE_ENCODINGS[qmodule]+=1def_exit_compute_encodings(qmodule):assert_QUANTIZED_MODULES_UNDER_COMPUTE_ENCODINGS[qmodule]>0_QUANTIZED_MODULES_UNDER_COMPUTE_ENCODINGS[qmodule]-=1classQuantizationMixinMeta(ABCMeta):"""Sets :meth:`forward` to :meth:`quantized_forward` if only :meth:`quantized_forward` is defined """def__new__(mcs,name,bases,namespace,**kwargs):if"quantized_forward"innamespaceand"forward"notinnamespace:warnings.warn("Support for defining `quantized_forward` in place of `forward` method will be deprecated, ""please use `forward` instead.",DeprecationWarning,stacklevel=2)namespace["forward"]=namespace["quantized_forward"]returnsuper().__new__(mcs,name,bases,namespace,**kwargs)
class QuantizationMixin(BaseQuantizationMixin, metaclass=QuantizationMixinMeta): # pylint: disable=abstract-method
    """Quantization mixin class for :class:`torch.nn.Module`.

    Specifically, a quantized module will quantize input, output, and parameter tensors with
    its held :class:`QuantizerBase` objects during the :meth:`forward` method and use the inherited
    :class:`torch.nn.Module` forward method to compute the layer operation. If all input, output, and
    parameter quantizers are ``None``, a quantized module will behave exactly the same as its parent
    :class:`torch.nn.Module`.

    Attributes:
        input_quantizers: :class:`torch.nn.ModuleList` containing :class:`QuantizerBase` objects
            to be applied to the layer's input tensors
        output_quantizers: :class:`torch.nn.ModuleList` containing :class:`QuantizerBase` objects
            to be applied to the layer's output tensors
        param_quantizers: :class:`torch.nn.ModuleDict` mapping parameter names to associated
            :class:`QuantizerBase` objects

    Examples:

        >>> qlinear = QuantizedLinear(in_features=10, out_features=10)
        >>> print(qlinear)
        QuantizedLinear(
          in_features=10, out_features=10, bias=True
          (param_quantizers): ModuleDict(
            (weight): None
            (bias): None
          )
          (input_quantizers): ModuleList(
            (0): None
          )
          (output_quantizers): ModuleList(
            (0): None
          )
        )
    """

    cls_to_qcls = OrderedDict()  # original class -> quantized class
    qcls_to_cls = OrderedDict()  # quantized class -> original class

    _default_kernel: Optional[Callable] = None
    _kernels = WeakKeyDictionary()  # instance -> instance_kernel
    @abstractmethod
    def forward(self, *args, **kwargs):
        """Computes a quantized version of the parent module's forward method.

        The :meth:`forward` method should perform the following logic in order:

            1) Apply existing input quantizers to input tensors
            2) Apply existing param quantizers to the layer's parameters
            3) Call the inherited :class:`torch.nn.Module` forward method with quantized inputs and parameters
            4) Apply existing output quantizers to the outputs of the forward method

        If all input, output, and parameter quantizers are ``None``, this method will behave exactly the same as
        its parent module's forward pass.
        """
        return super().forward(*args, **kwargs)
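
    # A minimal sketch of a concrete override following the four steps above, for a
    # hypothetical single-input, parameterless layer (see the ``implements`` docstring
    # below for a complete multi-input example):
    #
    #   def forward(self, input):
    #       if self.input_quantizers[0] is not None:
    #           input = self.input_quantizers[0](input)
    #       output = super().forward(input)
    #       if self.output_quantizers[0] is not None:
    #           output = self.output_quantizers[0](output)
    #       return output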
    @classmethod
    def set_default_kernel(cls, kernel: Callable):
        """Set default kernel for the class.

        In general, this signature will follow the signature of the equivalent :mod:`torch.nn.functional` function,
        but should return a :class:`QuantizedTensor` object and take in the additional keyword argument
        ``output_encodings``.

        Once set, all instances of ``cls`` will call into ``kernel`` in the forward pass unless:

            1) The instance is within the :meth:`compute_encodings` context, or
            2) The kernel has been overridden by a :meth:`set_kernel` call

        Args:
            kernel: Callable object to be used as the default kernel by all the instances of this class.

        Example:

            >>> from aimet_torch.v2 import quantization as Q
            >>> def int_multiply(a, b, output_encodings=None):
            ...     encodings = [a.encoding, b.encoding, output_encodings]
            ...     if not all(enc.mapping == "affine" for enc in encodings):
            ...         raise NotImplementedError
            ...     q_output = (a.quantized_repr() + a.encoding.offset) * (b.quantized_repr() + b.encoding.offset)
            ...     dq_output = q_output * (a.encoding.scale * b.encoding.scale)
            ...     return Q.QuantizedTensor(output_encodings.quantize(dq_output), encoding=output_encodings)
            ...
            >>> QuantizedMultiply.set_default_kernel(int_multiply)
            >>> qmult = QuantizedMultiply()
            >>> qmult.get_kernel()
            <function int_multiply at ...>
        """
        cls._default_kernel = kernel
    @classmethod
    def get_default_kernel(cls) -> Optional[Callable]:
        """Return the default kernel of the class

        Returns:
            Default kernel of the class. None if the default kernel is not set.
        """
        return cls._default_kernel
    def set_kernel(self, kernel: Callable):
        """Set kernel for this instance of quantized module.

        In general, this signature will follow the signature of the equivalent :mod:`torch.nn.functional` function,
        but should return a :class:`QuantizedTensor` object and take in the additional keyword argument
        ``output_encodings``.

        Once set, the layer will call into ``kernel`` in the forward pass unless within the
        :meth:`compute_encodings` context.

        Args:
            kernel: Callable object to be used as the underlying kernel.

        Example:

            >>> from aimet_torch.v2 import quantization as Q
            >>> def int_multiply(a, b, output_encodings=None):
            ...     encodings = [a.encoding, b.encoding, output_encodings]
            ...     if not all(enc.mapping == "affine" for enc in encodings):
            ...         raise NotImplementedError
            ...     q_output = (a.quantized_repr() + a.encoding.offset) * (b.quantized_repr() + b.encoding.offset)
            ...     dq_output = q_output * (a.encoding.scale * b.encoding.scale)
            ...     return Q.QuantizedTensor(output_encodings.quantize(dq_output), encoding=output_encodings)
            ...
            >>> qmult = QuantizedMultiply()
            >>> qmult.set_kernel(int_multiply)
        """
        QuantizationMixin._kernels[self] = kernel
    def get_kernel(self) -> Optional[Callable]:
        """Return the kernel to be used by this instance of quantized module.

        If the current instance does not have any kernel set, it will retrieve the default kernel of the class.

        Returns:
            The kernel to be used by this instance.
        """
        if self in QuantizationMixin._kernels:
            return QuantizationMixin._kernels[self]
        return self.get_default_kernel()
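
    # Kernel resolution order, as implied by the two methods above (illustrative sketch;
    # ``my_kernel`` and ``my_default_kernel`` are hypothetical callables):
    #
    #   >>> qmult = QuantizedMultiply()
    #   >>> qmult.get_kernel()                                   # None: nothing set yet
    #   >>> QuantizedMultiply.set_default_kernel(my_default_kernel)
    #   >>> qmult.get_kernel()                                   # falls back to the class default
    #   >>> qmult.set_kernel(my_kernel)
    #   >>> qmult.get_kernel()                                   # the instance kernel takes precedence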
    @classmethod
    def wrap(cls, module_cls: Type[nn.Module]) -> Type[nn.Module]:
        """
        Wrap a regular module class into a quantized module class
        """
        if not issubclass(module_cls, nn.Module):
            raise ValueError("Expected module_cls to be a subclass of torch.nn.Module. "
                             f"Got {module_cls}.")

        if module_cls in cls.cls_to_qcls:
            return cls.cls_to_qcls[module_cls]

        quantized_cls_name = f"Quantized{module_cls.__name__}"
        base_classes = (cls, module_cls)
        quantized_cls = type(quantized_cls_name, base_classes, {'__module__': __name__})
        return cls.implements(module_cls)(quantized_cls)
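
    # Example usage of ``wrap`` (sketch; ``MyCustomModule`` is a hypothetical
    # user-defined nn.Module, not part of this module):
    #
    #   >>> QuantizedMyCustomModule = QuantizationMixin.wrap(MyCustomModule)
    #   >>> issubclass(QuantizedMyCustomModule, (QuantizationMixin, MyCustomModule))
    #   True
    #
    # Calling ``wrap`` again with the same class returns the already-registered
    # quantized class instead of creating a new one.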
    @classmethod
    def from_module(cls, module: nn.Module):
        r"""Create an instance of quantized module from a regular module instance.

        The resulting quantized module contains the same attributes and parameters as the original module, but
        may be assigned input, output and parameter quantizers.

        :param module: Floating point module to quantize
        :return: Quantized version of the original module

        Example:

            >>> linear = torch.nn.Linear(10, 10)
            >>> quantized_linear = QuantizationMixin.from_module(linear)
            >>> print(quantized_linear)
            QuantizedLinear(
              in_features=10, out_features=10, bias=True
              (param_quantizers): ModuleDict(
                (weight): None
                (bias): None
              )
              (input_quantizers): ModuleList(
                (0): None
              )
              (output_quantizers): ModuleList(
                (0): None
              )
            )
            >>> print(quantized_linear.weight is linear.weight)
            True
        """
        return super().from_module(module)
    @classmethod
    def implements(cls, module_cls):
        r"""
        Decorator for registering the quantized definition of the given base class.

        Even though AIMET supports quantization of the built-in modules in the ``torch.nn`` subpackage,
        such as ``torch.nn.Conv2d`` or ``torch.nn.Linear``, :class:`QuantizationSimModel` will throw a
        runtime error when it encounters custom modules defined by users, asking them to provide the
        quantized definition of those custom modules.

        To declare the quantized definition of a module, :class:`QuantizationSimModel` requires you to define a
        subclass of your module decorated with :meth:`implements`, in which you implement the ``__quant_init__``
        and ``forward`` methods.

        As an example, given a custom module as below::

            class MaskedAdd(torch.nn.Module):
                def forward(self, input: torch.Tensor, mask: torch.Tensor, value: torch.Tensor):
                    return input + mask * value

        its quantized definition should be declared before creating :class:`QuantizationSimModel`,
        typically as below::

            @QuantizationMixin.implements(MaskedAdd)
            class QuantizedMaskedAdd(QuantizationMixin, MaskedAdd):
                # The quantized definition of MaskedAdd should be a subclass of
                # QuantizationMixin and MaskedAdd (Order matters!)

                def __quant_init__(self):
                    super().__quant_init__()
                    # Declare the number of input/output quantizers
                    self.input_quantizers = torch.nn.ModuleList([None, None, None])
                    self.output_quantizers = torch.nn.ModuleList([None])

                def forward(self, input: torch.Tensor, mask: torch.Tensor, value: torch.Tensor):
                    input_qtzr = self.input_quantizers[0]
                    _ = self.input_quantizers[1] # I don't want to quantize the boolean masks!
                    value_qtzr = self.input_quantizers[2]
                    output_qtzr = self.output_quantizers[0]

                    if input_qtzr is not None:
                        input = input_qtzr(input)

                    if value_qtzr is not None:
                        value = value_qtzr(value)

                    output = super().forward(input, mask, value)

                    if output_qtzr is not None:
                        output = output_qtzr(output)

                    return output
        """
        return super().implements(module_cls)
# pylint: disable=too-many-ancestors
_dispatch_table: Dict[Callable, Optional[Callable]]
_dispatch_table = {
    torch_fn: None
    for torch_fn in itertools.chain(*get_overridable_functions().values())
}

# NOTE: ``torch.overrides.get_overridable_functions()`` doesn't include
#       F.hardswish, F.hardsigmoid, or Tensor.unflatten, even though
#       they are implemented in a perfectly dispatchable manner.
_dispatch_table[F.hardswish] = None
_dispatch_table[F.hardsigmoid] = None
_dispatch_table[Tensor.unflatten] = None


class _Dispatcher(BaseTorchFunctionMode):
    def __torch_function__(self, func, types, args=(), kwargs=None):
        impl = _dispatch_table.get(func, None)

        if impl is None:
            impl = func

        return super().__torch_function__(impl, types, args, kwargs)


@contextlib.contextmanager
def _dispatch(torch_func: Callable, custom_impl: Callable):
    try:
        orig = _dispatch_table[torch_func]
    except KeyError as e:
        raise RuntimeError(f"PyTorch doesn't support overriding {torch_func}") from e

    try:
        _dispatch_table[torch_func] = custom_impl
        with _Dispatcher():
            yield
    finally:
        _dispatch_table[torch_func] = orig


class _DispatchMeta(QuantizationMixinMeta):
    def __new__(mcs, name, bases, namespace, **kwargs):
        """
        Sanity check for class definitions of dispatch-based quantized modules
        """
        if '_builtin_torch_fn' in namespace:
            torch_fn = namespace['_builtin_torch_fn']
            if torch_fn and torch_fn not in _dispatch_table:
                raise RuntimeError(f"PyTorch doesn't support overriding {torch_fn}")

        return super().__new__(mcs, name, bases, namespace, **kwargs)


class _DispatchMixin(metaclass=_DispatchMeta):
    _builtin_torch_fn: Optional[Callable] = None

    def _get_builtin_torch_fn(self):
        return type(self)._builtin_torch_fn

    def forward(self, *args, **kwargs): # pylint: disable=missing-function-docstring
        kernel = self.get_kernel()
        builtin_torch_fn = self._get_builtin_torch_fn()

        if not kernel or _is_computing_encodings(self):
            kernel = self._builtin_torch_fn_helper(builtin_torch_fn)
        else:
            kernel = self._custom_kernel_helper(kernel)

        with self._patch_quantized_parameters():
            with _dispatch(builtin_torch_fn, kernel):
                output = super().forward(*args, **kwargs)

        return _dequantize_if_applicable(output)

    def _builtin_torch_fn_helper(self, fn: Callable[..., Tensor]):
        def wrapper(*args, **kwargs):
            qtzd_args = (_quantize_dequantize_if_applicable(x, qtzr)
                         for x, qtzr in zip(args, self.input_quantizers))
            others = (_dequantize_if_applicable(x)
                      for x in args[len(self.input_quantizers):])
            kwargs = {
                key: _dequantize_if_applicable(value)
                for key, value in kwargs.items()
            }
            output = fn(*qtzd_args, *others, **kwargs)
            return _quantize_dequantize_if_applicable(output, self.output_quantizers[0])

        return wrapper

    def _custom_kernel_helper(self, fn: Callable[..., QuantizedTensorBase]):
        def wrapper(*args, **kwargs):
            qtzd_args = (_quantize_if_applicable(x, qtzr)
                         for x, qtzr in zip(args, self.input_quantizers))
            others = args[len(self.input_quantizers):]
            output_encodings = self.output_quantizers[0].get_encodings() if self.output_quantizers[0] else None
            kwargs.update(output_encodings=output_encodings)
            return fn(*qtzd_args, *others, **kwargs)

        return wrapper


def __nullary__(self):
    super(type(self), self).__quant_init__()
    self.input_quantizers = nn.ModuleList([])


def __unary__(self):
    super(type(self), self).__quant_init__()


def __binary__(self):
    super(type(self), self).__quant_init__()
    self.input_quantizers = nn.ModuleList([None, None])


def __ternary__(self):
    super(type(self), self).__quant_init__()
    self.input_quantizers = nn.ModuleList([None, None, None])


def _generate_docstring(parent_cls):
    return \
f""" Quantized subclass of torch.nn.{parent_cls.__name__} .. method:: forward{str(signature(parent_cls.forward))} :noindex: Quantized forward of torch.nn.{parent_cls.__name__}. The input(s), parameter(s) (if any), and output(s) will be quantized with ``self.input_quantizers``, ``self.param_quantizers``, and ``self.output_quantizers`` respectively. For more information, see :class:`QuantizationMixin`. """
@QuantizationMixin.implements(nn.Linear)
class QuantizedLinear(_DispatchMixin, QuantizationMixin, nn.Linear): # pylint: disable=missing-class-docstring
    __doc__ = _generate_docstring(parent_cls=nn.Linear)
    _builtin_torch_fn = F.linear
    __quant_init__ = __unary__

    # Only allow activation recompute (a.k.a. activation checkpointing) for QuantizedLinear.
    # This is mainly to reduce the memory footprint of QAT of large language models.
    @allow_recompute
    def forward(self, *args, **kwargs):
        # Workaround for deepspeed.
        # Deepspeed zero3 sometimes forcefully monkey-patches F.linear to torch.addmm,
        # which collides with the core assumption of our dispatch mechanism
        # that nn.Linear invokes F.linear.
        # To circumvent this issue, we temporarily restore the original F.linear
        # before running forward.
        with patch_attr(F, 'linear', type(self)._builtin_torch_fn):
            return super().forward(*args, **kwargs)
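
# Typical usage of QuantizedLinear (illustrative sketch; the ``QuantizeDequantize``
# construction below assumes the affine quantizer API from
# ``aimet_torch.v2.quantization`` and is not verified here; ``calibration_batch``
# is a placeholder tensor):
#
#   >>> import aimet_torch.v2.quantization as Q
#   >>> qlinear = QuantizedLinear(10, 10)
#   >>> _ = qlinear(torch.randn(2, 10))       # all quantizers None: behaves like nn.Linear
#   >>> qlinear.input_quantizers[0] = Q.affine.QuantizeDequantize((1,), bitwidth=8, symmetric=False)
#   >>> with qlinear.compute_encodings():     # calibrate the quantizer's scale/offset
#   ...     _ = qlinear(calibration_batch)
#   >>> out = qlinear(torch.randn(2, 10))     # forward now quantize-dequantizes the input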