Source code for aimet_torch.v2.nn.true_quant

# -*- mode: python -*-
# =============================================================================
#  @@-COPYRIGHT-START-@@
#
#  Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  1. Redistributions of source code must retain the above copyright notice,
#     this list of conditions and the following disclaimer.
#
#  2. Redistributions in binary form must reproduce the above copyright notice,
#     this list of conditions and the following disclaimer in the documentation
#     and/or other materials provided with the distribution.
#
#  3. Neither the name of the copyright holder nor the names of its contributors
#     may be used to endorse or promote products derived from this software
#     without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
#  POSSIBILITY OF SUCH DAMAGE.
#
#  SPDX-License-Identifier: BSD-3-Clause
#
#  @@-COPYRIGHT-END-@@
# =============================================================================
# pylint: disable=too-many-lines, wrong-import-order, redefined-builtin
""" Quantized modules"""

from packaging import version
import contextlib
import itertools
from inspect import signature
from abc import abstractmethod, ABCMeta
from collections import OrderedDict
from typing import Type, Any, Optional, Callable, Dict
from weakref import WeakKeyDictionary
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.overrides import BaseTorchFunctionMode, get_overridable_functions
from torch._VF import ( # pylint: disable=no-name-in-module
    gru as _gru,
    gru_cell as _gru_cell,
    lstm as _lstm,
    lstm_cell as _lstm_cell,
    rnn_relu as _rnn_relu,
    rnn_tanh as _rnn_tanh,
    rnn_relu_cell as _rnn_relu_cell,
    rnn_tanh_cell as _rnn_tanh_cell,
)

from aimet_torch.v2.quantization.base import QuantizerBase
from aimet_torch.v2.quantization.tensor import QuantizedTensorBase
from aimet_torch.v2.utils import patch_attr, _ContextManager, allow_recompute
from .base import BaseQuantizationMixin # pylint: disable=import-error


def _quantize_if_applicable(data: Any, quantizer: Optional[QuantizerBase]):
    """
    Quantize data if it is a quantizable type and quantize is not None
    """
    if quantizer and isinstance(data, Tensor) and data.is_floating_point():
        if isinstance(data, QuantizedTensorBase):
            data = data.dequantize()
        return quantizer(data)

    if isinstance(data, QuantizedTensorBase):
        return data.quantize()

    return data


def _dequantize_if_applicable(data: torch.Tensor):
    return data.dequantize() if isinstance(data, QuantizedTensorBase) else data


def _quantize_dequantize_if_applicable(data, quantizer):
    if quantizer and isinstance(data, Tensor) and data.is_floating_point():
        if isinstance(data, QuantizedTensorBase):
            data = data.dequantize()
        data = quantizer(data)

    if isinstance(data, QuantizedTensorBase):
        return data.dequantize()

    return data


_QUANTIZED_MODULES_UNDER_COMPUTE_ENCODINGS = WeakKeyDictionary()


def _is_computing_encodings(qmodule):
    return _QUANTIZED_MODULES_UNDER_COMPUTE_ENCODINGS.get(qmodule, 0) > 0


def _enter_computing_encodings(qmodule):
    if qmodule not in _QUANTIZED_MODULES_UNDER_COMPUTE_ENCODINGS:
        _QUANTIZED_MODULES_UNDER_COMPUTE_ENCODINGS[qmodule] = 0
    _QUANTIZED_MODULES_UNDER_COMPUTE_ENCODINGS[qmodule] += 1


def _exit_compute_encodings(qmodule):
    assert _QUANTIZED_MODULES_UNDER_COMPUTE_ENCODINGS[qmodule] > 0
    _QUANTIZED_MODULES_UNDER_COMPUTE_ENCODINGS[qmodule] -= 1


class QuantizationMixinMeta(ABCMeta):
    """Sets :meth:`forward` to :meth:`quantized_forward` if only :meth:`quantized_forward` is defined
    """

    def __new__(mcs, name, bases, namespace, **kwargs):
        if "quantized_forward" in namespace and "forward" not in namespace:
            warnings.warn("Support for defining `quantized_forward` in place of `forward` method will be deprecated, "
                          "please use `forward` instead.",
                          DeprecationWarning, stacklevel=2)
            namespace["forward"] = namespace["quantized_forward"]
        return super().__new__(mcs, name, bases, namespace, **kwargs)


[docs] class QuantizationMixin(BaseQuantizationMixin, metaclass=QuantizationMixinMeta): # pylint: disable=abstract-method """Quantization mixin class for torch.nn.Module. Specifically, a quantized module will quantize input, output, and parameter tensors with its held :class:`QuantizerBase` objects during the :meth:`forward` method and use the inherited :class:`torch.nn.Module` forward method to compute the layer operation. If all input, output, and parameter quantizers are ``None``, a quantized module will behave exactly the same as its parent :class:`torch.nn.Module`. Attributes: input_quantizers: :class:`torch.nn.ModuleList` containing :class:`QuantizerBase` objects to be applied to the layer's input tensors output_quantizers: :class:`torch.nn.ModuleList` containing :class:`QuantizerBase` objects to be applied to the layer's output tensors param_quantizers: :class:`torch.nn.ModuleDict` mapping parameter names to associated :class:`QuantizerBase` objects Examples: >>> qlinear = QuantizedLinear(in_features=10, out_features=10) >>> print(qlinear) QuantizedLinear( in_features=10, out_features=10, bias=True (param_quantizers): ModuleDict( (weight): None (bias): None ) (input_quantizers): ModuleList( (0): None ) (output_quantizers): ModuleList( (0): None ) ) """ cls_to_qcls = OrderedDict() # quantized class -> original class qcls_to_cls = OrderedDict() # original class -> quantized class _default_kernel: Optional[Callable] = None _kernels = WeakKeyDictionary() # instance -> instance_kernel
[docs] @abstractmethod def forward(self, *args, **kwargs): """Computes a quantized version of the parent module's forward method. The :meth:`forward` method should perform the following logic in order: 1) Apply existing input quantizers to input tensors 2) Apply existing param quantizers to the layer's parameters 3) Call the inherited :class:`torch.nn.Module` forward method with quantized inputs and parameters 4) Apply existing output quantizers to the outputs of the forward method If all input, output, and parameter quantizers are ``None``, this method will behave exactly the same as its parent module's forward pass. """ return super().forward(*args, **kwargs)
[docs] @classmethod def set_default_kernel(cls, kernel: Callable): """Set default kernel for the class. In general, this signature will follow the signature of the equivalent :mod:`torch.nn.functional` function, but should return a :class:`QuantizedTensor` object and take in the additional keyword argument ``output_encodings``. Once set, all instances of cls will call into kernel in the forward pass unless: 1) The instance is within the :meth:`compute_encodings` context, or 2) The kernel has been overridden by a :meth:`set_kernel` call Args: kernel: Callable object to be used as the default kernel by all the instances of this class. Example: >>> from aimet_torch.v2 import quantization as Q >>> def int_multiply(a, b, output_encodings=None): ... encodings = [a.encoding, b.encoding, output_encodings] ... if not all(enc.mapping == "affine" for enc in encodings): ... raise NotImplementedError ... q_output = (a.quantized_repr() + a.encoding.offset) * (b.quantized_repr() + b.encoding.offset) ... dq_output = q_output * (a.encoding.scale * b.encoding.scale) ... return Q.QuantizedTensor(output_encodings.quantize(dq_output), encoding=output_encodings) ... >>> QuantizedMultiply.set_default_kernel(int_multiply) >>> qmult = QuantizedMultiply() >>> qmult.get_kernel() <function int_multiply at ...> """ cls._default_kernel = kernel
[docs] @classmethod def get_default_kernel(cls) -> Optional[Callable]: """Return the default kernel of the class Returns: Default kernel of the class. None if the default kernel is not set. """ return cls._default_kernel
[docs] def set_kernel(self, kernel: Callable): """Set kernel for this instance of quantized module. In general, this signature will follow the signature of the equivalent :mod:`torch.nn.functional` function, but should return a :class:`QuantizedTensor` object and take in the additional keyword argument ``output_encodings``. Once set, the layer will call into ``kernel`` in the forward pass unless within the :meth:`compute_encodings` context. Args: kernel: Callable object to be used as the underlying kernel. Example: >>> from aimet_torch.v2 import quantization as Q >>> def int_multiply(a, b, output_encodings=None): ... encodings = [a.encoding, b.encoding, output_encodings] ... if not all(enc.mapping == "affine" for enc in encodings): ... raise NotImplementedError ... q_output = (a.quantized_repr() + a.encoding.offset) * (b.quantized_repr() + b.encoding.offset) ... dq_output = q_output * (a.encoding.scale * b.encoding.scale) ... return Q.QuantizedTensor(output_encodings.quantize(dq_output), encoding=output_encodings) ... >>> qmult = QuantizedMultiply() >>> qmult.set_kernel(int_multiply) """ QuantizationMixin._kernels[self] = kernel
[docs] def get_kernel(self) -> Optional[Callable]: """Return the kernel to be used by this instance of quantized module. If the current instance does not have any kernel set, it will retrieve the default kernel of the class. Returns: The kernel to be used by this instance. """ if self in QuantizationMixin._kernels: return QuantizationMixin._kernels[self] return self.get_default_kernel()
[docs] @contextlib.contextmanager def compute_encodings(self): # pylint: disable=missing-function-docstring ctx = _ContextManager(action=lambda: _enter_computing_encodings(self), cleanup=lambda: _exit_compute_encodings(self)) with super().compute_encodings(), ctx: yield
def _patch_dequantized_parameters(self): stack = contextlib.ExitStack() for param_name, _ in self.param_quantizers.items(): qparam = getattr(self, param_name) ctx = patch_attr(self, param_name, _dequantize_if_applicable(qparam)) stack.enter_context(ctx) return stack
[docs] @classmethod def wrap(cls, module_cls: Type[nn.Module]) -> Type[nn.Module]: """ Wrap a regular module class into a quantized module class """ if not issubclass(module_cls, nn.Module): raise ValueError("Expected module_cls to be a subclass of torch.nn.Module. " f"Got {module_cls}.") if module_cls in cls.cls_to_qcls: return cls.cls_to_qcls[module_cls] quantized_cls_name = f"Quantized{module_cls.__name__}" base_classes = (cls, module_cls) quantized_cls = type(quantized_cls_name, base_classes, {'__module__': __name__}) return cls.implements(module_cls)(quantized_cls)
[docs] @classmethod def from_module(cls, module: nn.Module): r"""Create an instance of quantized module from a regular module instance. The resulting quantized module contains the same attributes and parameters as the original module, but may be assigned input, output and parameter quantizers. :param module: Floating point module to quantize :return: Quantized version of the original module Example: >>> linear = torch.nn.Linear(10, 10) >>> quantized_linear = QuantizationMixin.from_module(linear) >>> print(quantized_linear.param_quantizers) QuantizedLinear( in_features=10, out_features=10, bias=True (param_quantizers): ModuleDict( (weight): None (bias): None ) (input_quantizers): ModuleList( (0): None ) (output_quantizers): ModuleList( (0): None ) ) >>> print(quantized_linear.weight is linear.weight) True """ return super().from_module(module)
[docs] @classmethod def implements(cls, module_cls): r""" Decorator for registering quantized definition of the given base class. Even though AIMET supports quantization of all built-in modules in torch.nn subpackage such as ``torch.nn.Conv2d`` or ``torch.nn.Linear`` that AIMET is already aware of, :class:`QuantizationSimModel` will throw a runtime error when it encounters custom modules defined by the users, asking the users to provide the quantized definition of the custom modules that AIMET doesn't know of. To declare the quantized definition of a module, :class:`QuantizationSimModel` requires you to define a subclass of your module decorated with :meth:`implements`, in which you will implement ``__quant_init__`` and ``forward`` methods. As an example, given a custom module as below:: class MaskedAdd(torch.nn.Module): def forward(self, input: torch.Tensor, mask: torch.Tensor, value: torch.Tensor): return input + mask * value its quantized definition should be declared before creating :class:`QuantizationSimModel`, typically as below:: @QuantizationMixin.implements(MaskedAdd) class QuantizedMaskedAdd(QuantizationMixin, MaskedAdd): # The quantized definition of MaskedAdd should be a subclass of # QuantizationMixin and MaskedAdd (Order matters!) def __quant_init__(self): super().__quant_init__() # Declare the number of input/output quantizers self.input_quantizers = torch.nn.ModuleList([None, None, None]) self.output_quantizers = torch.nn.ModuleList([None]) def forward(self, input: torch.Tensor, mask: torch.Tensor, value: torch.Tensor): input_qtzr = self.input_quantizers[0] _ = self.input_quantizers[1] # I don't want to quantize the boolean masks! value_qtzr = self.input_quantizers[2] output_qtzr = self.output_quantizers[0] if input_qtzr is not None: input = input_qtzr(input) if value_qtzr is not None: value = value_qtzr(value) output = super().forward(input, mask, value) if output_qtzr is not None: output = output_qtzr(output) return output """ return super().implements(module_cls)
# pylint: disable=too-many-ancestors _dispatch_table: Dict[Callable, Optional[Callable]] _dispatch_table = { torch_fn: None for torch_fn in itertools.chain(*get_overridable_functions().values()) } # NOTE: ``torch.overrides.get_overridable_functions()`` doesn't include # F.hardswish, F.hardsigmoid, or Tensor.unflatten, even though # they are implemented in a perfectly dispatchable manner. _dispatch_table[F.hardswish] = None _dispatch_table[F.hardsigmoid] = None _dispatch_table[Tensor.unflatten] = None class _Dispatcher(BaseTorchFunctionMode): def __torch_function__(self, func, types, args=(), kwargs=None): impl = _dispatch_table.get(func, None) if impl is None: impl = func return super().__torch_function__(impl, types, args, kwargs) @contextlib.contextmanager def _dispatch(torch_func: Callable, custom_impl: Callable): try: orig = _dispatch_table[torch_func] except KeyError as e: raise RuntimeError(f"PyTorch doesn't support overriding {torch_func}") from e try: _dispatch_table[torch_func] = custom_impl with _Dispatcher(): yield finally: _dispatch_table[torch_func] = orig class _DispatchMeta(QuantizationMixinMeta): def __new__(mcs, name, bases, namespace, **kwargs): """ Sanity check for class definitions of dispatch-based quantized modules """ if '_builtin_torch_fn' in namespace: torch_fn = namespace['_builtin_torch_fn'] if torch_fn and torch_fn not in _dispatch_table: raise RuntimeError(f"PyTorch doesn't support overriding {torch_fn}") return super().__new__(mcs, name, bases, namespace, **kwargs) class _DispatchMixin(metaclass=_DispatchMeta): _builtin_torch_fn: Optional[Callable] = None def _get_builtin_torch_fn(self): return type(self)._builtin_torch_fn def forward(self, *args, **kwargs): # pylint: disable=missing-function-docstring kernel = self.get_kernel() builtin_torch_fn = self._get_builtin_torch_fn() if not kernel or _is_computing_encodings(self): kernel = self._builtin_torch_fn_helper(builtin_torch_fn) else: kernel = self._custom_kernel_helper(kernel) with self._patch_quantized_parameters(): with _dispatch(builtin_torch_fn, kernel): output = super().forward(*args, **kwargs) return _dequantize_if_applicable(output) def _builtin_torch_fn_helper(self, fn: Callable[..., Tensor]): def wrapper(*args, **kwargs): qtzd_args = ( _quantize_dequantize_if_applicable(x, qtzr) for x, qtzr in zip(args, self.input_quantizers) ) others = ( _dequantize_if_applicable(x) for x in args[len(self.input_quantizers):] ) kwargs = { key: _dequantize_if_applicable(value) for key, value in kwargs.items() } output = fn(*qtzd_args, *others, **kwargs) return _quantize_dequantize_if_applicable(output, self.output_quantizers[0]) return wrapper def _custom_kernel_helper(self, fn: Callable[..., QuantizedTensorBase]): def wrapper(*args, **kwargs): qtzd_args = ( _quantize_if_applicable(x, qtzr) for x, qtzr in zip(args, self.input_quantizers) ) others = args[len(self.input_quantizers):] output_encodings = self.output_quantizers[0].get_encodings() if self.output_quantizers[0] else None kwargs.update(output_encodings=output_encodings) return fn(*qtzd_args, *others, **kwargs) return wrapper def __nullary__(self): super(type(self), self).__quant_init__() self.input_quantizers = nn.ModuleList([]) def __unary__(self): super(type(self), self).__quant_init__() def __binary__(self): super(type(self), self).__quant_init__() self.input_quantizers = nn.ModuleList([None, None]) def __ternary__(self): super(type(self), self).__quant_init__() self.input_quantizers = nn.ModuleList([None, None, None]) def _generate_docstring(parent_cls): return \ f""" Quantized subclass of torch.nn.{parent_cls.__name__} .. method:: forward{str(signature(parent_cls.forward))} :noindex: Quantized forward of torch.nn.{parent_cls.__name__}. The input(s), parameter(s) (if any), and output(s) will be quantized with ``self.input_quantizers``, ``self.param_quantizers``, and ``self.output_quantizers`` respectively. For more information, see :class:`QuantizationMixin`. """
[docs] @QuantizationMixin.implements(nn.AdaptiveAvgPool1d) class QuantizedAdaptiveAvgPool1d(_DispatchMixin, QuantizationMixin, nn.AdaptiveAvgPool1d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(nn.AdaptiveAvgPool1d) _builtin_torch_fn = F.adaptive_avg_pool1d __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.AdaptiveAvgPool2d) class QuantizedAdaptiveAvgPool2d(_DispatchMixin, QuantizationMixin, nn.AdaptiveAvgPool2d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.AdaptiveAvgPool2d) _builtin_torch_fn = F.adaptive_avg_pool2d __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.AdaptiveAvgPool3d) class QuantizedAdaptiveAvgPool3d(_DispatchMixin, QuantizationMixin, nn.AdaptiveAvgPool3d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.AdaptiveAvgPool3d) _builtin_torch_fn = F.adaptive_avg_pool3d __quant_init__ = __unary__
# @QuantizationMixin.implements(nn.AdaptiveLogSoftmaxWithLoss) # class QuantizedAdaptiveLogSoftmaxWithLoss(_DispatchMixin, QuantizationMixin, nn.AdaptiveLogSoftmaxWithLoss): # _builtin_torch_fn = ...
[docs] @QuantizationMixin.implements(nn.AdaptiveMaxPool1d) class QuantizedAdaptiveMaxPool1d(_DispatchMixin, QuantizationMixin, nn.AdaptiveMaxPool1d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.AdaptiveMaxPool1d) _builtin_torch_fn = F.adaptive_max_pool1d __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.AdaptiveMaxPool2d) class QuantizedAdaptiveMaxPool2d(_DispatchMixin, QuantizationMixin, nn.AdaptiveMaxPool2d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.AdaptiveMaxPool2d) _builtin_torch_fn = F.adaptive_max_pool2d __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.AdaptiveMaxPool3d) class QuantizedAdaptiveMaxPool3d(_DispatchMixin, QuantizationMixin, nn.AdaptiveMaxPool3d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.AdaptiveMaxPool3d) _builtin_torch_fn = F.adaptive_max_pool3d __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.AlphaDropout) class QuantizedAlphaDropout(_DispatchMixin, QuantizationMixin, nn.AlphaDropout): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.AlphaDropout) _builtin_torch_fn = F.alpha_dropout __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.AvgPool1d) class QuantizedAvgPool1d(_DispatchMixin, QuantizationMixin, nn.AvgPool1d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.AvgPool1d) _builtin_torch_fn = F.avg_pool1d __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.AvgPool2d) class QuantizedAvgPool2d(_DispatchMixin, QuantizationMixin, nn.AvgPool2d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.AvgPool2d) _builtin_torch_fn = F.avg_pool2d __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.AvgPool3d) class QuantizedAvgPool3d(_DispatchMixin, QuantizationMixin, nn.AvgPool3d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.AvgPool3d) _builtin_torch_fn = F.avg_pool3d __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.BCELoss) class QuantizedBCELoss(_DispatchMixin, QuantizationMixin, nn.BCELoss): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.BCELoss) _builtin_torch_fn = F.binary_cross_entropy __quant_init__ = __binary__
[docs] @QuantizationMixin.implements(nn.BCEWithLogitsLoss) class QuantizedBCEWithLogitsLoss(_DispatchMixin, QuantizationMixin, nn.BCEWithLogitsLoss): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.BCEWithLogitsLoss) _builtin_torch_fn = F.binary_cross_entropy_with_logits __quant_init__ = __binary__
[docs] @QuantizationMixin.implements(nn.BatchNorm1d) class QuantizedBatchNorm1d(_DispatchMixin, QuantizationMixin, nn.BatchNorm1d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.BatchNorm1d) _builtin_torch_fn = F.batch_norm __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.BatchNorm2d) class QuantizedBatchNorm2d(_DispatchMixin, QuantizationMixin, nn.BatchNorm2d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.BatchNorm2d) _builtin_torch_fn = F.batch_norm __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.BatchNorm3d) class QuantizedBatchNorm3d(_DispatchMixin, QuantizationMixin, nn.BatchNorm3d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.BatchNorm3d) _builtin_torch_fn = F.batch_norm __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.Bilinear) class QuantizedBilinear(_DispatchMixin, QuantizationMixin, nn.Bilinear): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.Bilinear) _builtin_torch_fn = F.bilinear __quant_init__ = __binary__
[docs] @QuantizationMixin.implements(nn.CELU) class QuantizedCELU(_DispatchMixin, QuantizationMixin, nn.CELU): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.CELU) _builtin_torch_fn = F.celu __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.CTCLoss) class QuantizedCTCLoss(_DispatchMixin, QuantizationMixin, nn.CTCLoss): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.CTCLoss) _builtin_torch_fn = F.ctc_loss __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.ChannelShuffle) class QuantizedChannelShuffle(_DispatchMixin, QuantizationMixin, nn.ChannelShuffle): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.ChannelShuffle) _builtin_torch_fn = F.channel_shuffle __quant_init__ = __unary__
if version.parse(torch.__version__) >= version.parse("2.1.0"):
[docs] @QuantizationMixin.implements(nn.CircularPad1d) class QuantizedCircularPad1d(_DispatchMixin, QuantizationMixin, nn.CircularPad1d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.CircularPad1d) _builtin_torch_fn = F.pad __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.CircularPad2d) class QuantizedCircularPad2d(_DispatchMixin, QuantizationMixin, nn.CircularPad2d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.CircularPad2d) _builtin_torch_fn = F.pad __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.CircularPad3d) class QuantizedCircularPad3d(_DispatchMixin, QuantizationMixin, nn.CircularPad3d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.CircularPad3d) _builtin_torch_fn = F.pad __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.ConstantPad1d) class QuantizedConstantPad1d(_DispatchMixin, QuantizationMixin, nn.ConstantPad1d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.ConstantPad2d) _builtin_torch_fn = F.pad __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.ConstantPad2d) class QuantizedConstantPad2d(_DispatchMixin, QuantizationMixin, nn.ConstantPad2d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.ConstantPad2d) _builtin_torch_fn = F.pad __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.ConstantPad3d) class QuantizedConstantPad3d(_DispatchMixin, QuantizationMixin, nn.ConstantPad3d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.ConstantPad3d) _builtin_torch_fn = F.pad __quant_init__ = __unary__
# @QuantizationMixin.implements(nn.Container) # class QuantizedContainer(_DispatchMixin, QuantizationMixin, nn.Container): # _builtin_torch_fn = ...
[docs] @QuantizationMixin.implements(nn.Conv1d) class QuantizedConv1d(_DispatchMixin, QuantizationMixin, nn.Conv1d): # pylint: disable=too-many-ancestors # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.Conv1d) _builtin_torch_fn = F.conv1d __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.Conv2d) class QuantizedConv2d(_DispatchMixin, QuantizationMixin, nn.Conv2d): # pylint: disable=too-many-ancestors # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.Conv2d) _builtin_torch_fn = F.conv2d __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.Conv3d) class QuantizedConv3d(_DispatchMixin, QuantizationMixin, nn.Conv3d): # pylint: disable=too-many-ancestors # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.Conv3d) _builtin_torch_fn = F.conv3d __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.ConvTranspose1d) class QuantizedConvTranspose1d(_DispatchMixin, QuantizationMixin, nn.ConvTranspose1d): # pylint: disable=too-many-ancestors # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.ConvTranspose1d) _builtin_torch_fn = F.conv_transpose1d __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.ConvTranspose2d) class QuantizedConvTranspose2d(_DispatchMixin, QuantizationMixin, nn.ConvTranspose2d): # pylint: disable=too-many-ancestors # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.ConvTranspose2d) _builtin_torch_fn = F.conv_transpose2d __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.ConvTranspose3d) class QuantizedConvTranspose3d(_DispatchMixin, QuantizationMixin, nn.ConvTranspose3d): # pylint: disable=too-many-ancestors # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.ConvTranspose3d) _builtin_torch_fn = F.conv_transpose3d __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.CosineEmbeddingLoss) class QuantizedCosineEmbeddingLoss(_DispatchMixin, QuantizationMixin, nn.CosineEmbeddingLoss): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.CosineEmbeddingLoss) _builtin_torch_fn = F.cosine_embedding_loss __quant_init__ = __binary__
[docs] @QuantizationMixin.implements(nn.CosineSimilarity) class QuantizedCosineSimilarity(_DispatchMixin, QuantizationMixin, nn.CosineSimilarity): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.CosineSimilarity) _builtin_torch_fn = F.cosine_similarity __quant_init__ = __binary__
[docs] @QuantizationMixin.implements(nn.CrossEntropyLoss) class QuantizedCrossEntropyLoss(_DispatchMixin, QuantizationMixin, nn.CrossEntropyLoss): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.CrossEntropyLoss) _builtin_torch_fn = F.cross_entropy __quant_init__ = __binary__
# @QuantizationMixin.implements(nn.CrossMapLRN2d) # class QuantizedCrossMapLRN2d(_DispatchMixin, QuantizationMixin, nn.CrossMapLRN2d): # _builtin_torch_fn = ...
[docs] @QuantizationMixin.implements(nn.Dropout) class QuantizedDropout(_DispatchMixin, QuantizationMixin, nn.Dropout): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.Dropout) _builtin_torch_fn = F.dropout __quant_init__ = __unary__
if version.parse(torch.__version__) >= version.parse("1.12.0"):
[docs] @QuantizationMixin.implements(nn.Dropout1d) class QuantizedDropout1d(_DispatchMixin, QuantizationMixin, nn.Dropout1d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.Dropout1d) _builtin_torch_fn = F.dropout1d __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.Dropout2d) class QuantizedDropout2d(_DispatchMixin, QuantizationMixin, nn.Dropout2d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.Dropout2d) _builtin_torch_fn = F.dropout2d __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.Dropout3d) class QuantizedDropout3d(_DispatchMixin, QuantizationMixin, nn.Dropout3d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.Dropout3d) _builtin_torch_fn = F.dropout3d __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.ELU) class QuantizedELU(_DispatchMixin, QuantizationMixin, nn.ELU): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.ELU) _builtin_torch_fn = F.elu __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.Embedding) class QuantizedEmbedding(_DispatchMixin, QuantizationMixin, nn.Embedding): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.Embedding) _builtin_torch_fn = F.embedding __quant_init__ = __nullary__
[docs] @QuantizationMixin.implements(nn.EmbeddingBag) class QuantizedEmbeddingBag(_DispatchMixin, QuantizationMixin, nn.EmbeddingBag): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.EmbeddingBag) _builtin_torch_fn = F.embedding_bag def _builtin_torch_fn_helper(self, fn: Callable[..., Tensor]): def embedding_bag(input: Tensor, # pylint: disable=redefined-builtin, too-many-arguments weight: Tensor, offsets: Optional[Tensor] = None, max_norm: Optional[float] = None, norm_type: float = 2, scale_grad_by_freq: bool = False, mode: str = "mean", sparse: bool = False, per_sample_weights: Optional[Tensor] = None, include_last_offset: bool = False, padding_idx: Optional[int] = None): if per_sample_weights is not None: qtzr = self.input_quantizers[0] per_sample_weights = _quantize_dequantize_if_applicable(per_sample_weights, qtzr) output = fn(input, weight, offsets=offsets, max_norm=max_norm, norm_type=norm_type, scale_grad_by_freq=scale_grad_by_freq, mode=mode, sparse=sparse, per_sample_weights=per_sample_weights, include_last_offset=include_last_offset, padding_idx=padding_idx) return _quantize_dequantize_if_applicable(output, self.output_quantizers[0]) return embedding_bag def _custom_kernel_helper(self, fn: Callable[..., QuantizedTensorBase]): def embedding_bag(input: Tensor, # pylint: disable=redefined-builtin, too-many-arguments weight: Tensor, offsets: Optional[Tensor] = None, max_norm: Optional[float] = None, norm_type: float = 2, scale_grad_by_freq: bool = False, mode: str = "mean", sparse: bool = False, per_sample_weights: Optional[Tensor] = None, include_last_offset: bool = False, padding_idx: Optional[int] = None): if per_sample_weights is not None: qtzr = self.input_quantizers[0] per_sample_weights = _quantize_if_applicable(per_sample_weights, qtzr) output_encodings = self.output_quantizers[0].get_encodings() if self.output_quantizers[0] else None return fn(input, weight, offsets=offsets, max_norm=max_norm, norm_type=norm_type, scale_grad_by_freq=scale_grad_by_freq, mode=mode, sparse=sparse, per_sample_weights=per_sample_weights, include_last_offset=include_last_offset, padding_idx=padding_idx, output_encodings=output_encodings) return embedding_bag
[docs] @QuantizationMixin.implements(nn.FeatureAlphaDropout) class QuantizedFeatureAlphaDropout(_DispatchMixin, QuantizationMixin, nn.FeatureAlphaDropout): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.FeatureAlphaDropout) _builtin_torch_fn = F.feature_alpha_dropout __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.Flatten) class QuantizedFlatten(_DispatchMixin, QuantizationMixin, nn.Flatten): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.Flatten) def _get_builtin_torch_fn(self): return Tensor.flatten __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.Fold) class QuantizedFold(_DispatchMixin, QuantizationMixin, nn.Fold): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.Fold) _builtin_torch_fn = F.fold __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.FractionalMaxPool2d) class QuantizedFractionalMaxPool2d(_DispatchMixin, QuantizationMixin, nn.FractionalMaxPool2d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.FractionalMaxPool2d) _builtin_torch_fn = F.fractional_max_pool2d __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.FractionalMaxPool3d) class QuantizedFractionalMaxPool3d(_DispatchMixin, QuantizationMixin, nn.FractionalMaxPool3d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.FractionalMaxPool3d) _builtin_torch_fn = F.fractional_max_pool3d __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.GELU) class QuantizedGELU(_DispatchMixin, QuantizationMixin, nn.GELU): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.GELU) _builtin_torch_fn = F.gelu __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.GLU) class QuantizedGLU(_DispatchMixin, QuantizationMixin, nn.GLU): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.GLU) _builtin_torch_fn = F.glu __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.GRU) class QuantizedGRU(_DispatchMixin, QuantizationMixin, nn.GRU): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.GRU) _builtin_torch_fn = _gru def __quant_init__(self): super().__quant_init__() # pylint: disable=attribute-defined-outside-init self.input_quantizers = nn.ModuleList([None, None]) self.output_quantizers = nn.ModuleList([None, None]) def _quantize_inputs(self, args, apply): if args[1].is_floating_point(): input, hx, *others = args batch_sizes = None else: input, batch_sizes, hx, *others = args input = apply(input, self.input_quantizers[0]) hx = apply(hx, self.input_quantizers[1]) if batch_sizes is None: return input, hx, *others return input, batch_sizes, hx, *others def _builtin_torch_fn_helper(self, fn: Callable[..., Tensor]): assert fn == _gru apply = _quantize_dequantize_if_applicable def gru(*args): args = self._quantize_inputs(args, apply) output, h_n = fn(*args) return ( apply(output, self.output_quantizers[0]), apply(h_n, self.output_quantizers[1]), ) return gru def _custom_kernel_helper(self, fn: Callable[..., QuantizedTensorBase]): apply = _quantize_if_applicable def gru(*args): args = self._quantize_inputs(args, apply) output_encodings = tuple(qtzr and qtzr.get_encodings() for qtzr in self.output_quantizers) return fn(*args, output_encodings=output_encodings) return gru
[docs] @QuantizationMixin.implements(nn.GRUCell) class QuantizedGRUCell(_DispatchMixin, QuantizationMixin, nn.GRUCell): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.GRUCell) _builtin_torch_fn = _gru_cell def __quant_init__(self): super().__quant_init__() # pylint: disable=attribute-defined-outside-init self.input_quantizers = nn.ModuleList([None, None]) self.output_quantizers = nn.ModuleList([None]) def _builtin_torch_fn_helper(self, fn: Callable[..., Tensor]): assert fn == _gru_cell apply = _quantize_dequantize_if_applicable def gru_cell(input, hx, *args, **kwargs): input = apply(input, self.input_quantizers[0]) hx = apply(hx, self.input_quantizers[1]) output = fn(input, hx, *args, **kwargs) return apply(output, self.output_quantizers[0]) return gru_cell def _custom_kernel_helper(self, fn: Callable[..., QuantizedTensorBase]): apply = _quantize_if_applicable def gru_cell(input, hx, *args, **kwargs): input = apply(input, self.input_quantizers[0]) hx = apply(hx, self.input_quantizers[1]) output_encodings = self.output_quantizers[0] and self.output_quantizers[0].get_encodings() return fn(input, hx, *args, **kwargs, output_encodings=output_encodings) return gru_cell
[docs] @QuantizationMixin.implements(nn.GaussianNLLLoss) class QuantizedGaussianNLLLoss(_DispatchMixin, QuantizationMixin, nn.GaussianNLLLoss): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.GaussianNLLLoss) _builtin_torch_fn = F.gaussian_nll_loss __quant_init__ = __ternary__
[docs] @QuantizationMixin.implements(nn.GroupNorm) class QuantizedGroupNorm(_DispatchMixin, QuantizationMixin, nn.GroupNorm): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.GroupNorm) _builtin_torch_fn = F.group_norm __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.Hardshrink) class QuantizedHardshrink(_DispatchMixin, QuantizationMixin, nn.Hardshrink): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.Hardshrink) _builtin_torch_fn = F.hardshrink __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.Hardsigmoid) class QuantizedHardsigmoid(_DispatchMixin, QuantizationMixin, nn.Hardsigmoid): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.Hardsigmoid) _builtin_torch_fn = F.hardsigmoid __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.Hardswish) class QuantizedHardswish(_DispatchMixin, QuantizationMixin, nn.Hardswish): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.Hardswish) _builtin_torch_fn = F.hardswish __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.Hardtanh) class QuantizedHardtanh(_DispatchMixin, QuantizationMixin, nn.Hardtanh): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.Hardtanh) _builtin_torch_fn = F.hardtanh __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.HingeEmbeddingLoss) class QuantizedHingeEmbeddingLoss(_DispatchMixin, QuantizationMixin, nn.HingeEmbeddingLoss): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.HingeEmbeddingLoss) _builtin_torch_fn = F.hinge_embedding_loss __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.HuberLoss) class QuantizedHuberLoss(_DispatchMixin, QuantizationMixin, nn.HuberLoss): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.HuberLoss) _builtin_torch_fn = F.huber_loss __quant_init__ = __binary__
# @QuantizationMixin.implements(nn.Identity) # class QuantizedIdentity(_DispatchMixin, QuantizationMixin, nn.Identity): # _builtin_torch_fn = ...
[docs] @QuantizationMixin.implements(nn.InstanceNorm1d) class QuantizedInstanceNorm1d(_DispatchMixin, QuantizationMixin, nn.InstanceNorm1d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.InstanceNorm1d) _builtin_torch_fn = F.instance_norm __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.InstanceNorm2d) class QuantizedInstanceNorm2d(_DispatchMixin, QuantizationMixin, nn.InstanceNorm2d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.InstanceNorm2d) _builtin_torch_fn = F.instance_norm __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.InstanceNorm3d) class QuantizedInstanceNorm3d(_DispatchMixin, QuantizationMixin, nn.InstanceNorm3d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.InstanceNorm3d) _builtin_torch_fn = F.instance_norm __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.KLDivLoss) class QuantizedKLDivLoss(_DispatchMixin, QuantizationMixin, nn.KLDivLoss): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.KLDivLoss) _builtin_torch_fn = F.kl_div __quant_init__ = __binary__
[docs] @QuantizationMixin.implements(nn.L1Loss) class QuantizedL1Loss(_DispatchMixin, QuantizationMixin, nn.L1Loss): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.L1Loss) _builtin_torch_fn = F.l1_loss __quant_init__ = __binary__
[docs] @QuantizationMixin.implements(nn.LPPool1d) class QuantizedLPPool1d(_DispatchMixin, QuantizationMixin, nn.LPPool1d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.LPPool1d) _builtin_torch_fn = F.lp_pool1d __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.LPPool2d) class QuantizedLPPool2d(_DispatchMixin, QuantizationMixin, nn.LPPool2d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.LPPool2d) _builtin_torch_fn = F.lp_pool2d __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.LSTM) class QuantizedLSTM(_DispatchMixin, QuantizationMixin, nn.LSTM): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.LSTM) _builtin_torch_fn = _lstm def __quant_init__(self): super().__quant_init__() # pylint: disable=attribute-defined-outside-init self.input_quantizers = nn.ModuleList([None, None, None]) self.output_quantizers = nn.ModuleList([None, None, None]) def _quantize_inputs(self, args, apply): if isinstance(args[1], Tensor): input, batch_sizes, hx, *others = args else: input, hx, *others = args batch_sizes = None input = apply(input, self.input_quantizers[0]) h, c = hx h_qtzr, c_qtzr = self.input_quantizers[1:] hx = (apply(h, h_qtzr), apply(c, c_qtzr)) if batch_sizes is None: return input, hx, *others return input, batch_sizes, hx, *others def _builtin_torch_fn_helper(self, fn: Callable[..., Tensor]): assert fn == _lstm apply = _quantize_dequantize_if_applicable def lstm(*args): args = self._quantize_inputs(args, apply) output, h_n, c_n = fn(*args) return ( apply(output, self.output_quantizers[0]), apply(h_n, self.output_quantizers[1]), apply(c_n, self.output_quantizers[2]), ) return lstm def _custom_kernel_helper(self, fn: Callable[..., QuantizedTensorBase]): apply = _quantize_if_applicable def lstm(*args): args = self._quantize_inputs(args, apply) output_encodings = tuple(qtzr and qtzr.get_encodings() for qtzr in self.output_quantizers) return fn(*args, output_encodings=output_encodings) return lstm
[docs] @QuantizationMixin.implements(nn.LSTMCell) class QuantizedLSTMCell(_DispatchMixin, QuantizationMixin, nn.LSTMCell): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.LSTMCell) _builtin_torch_fn = _lstm_cell def __quant_init__(self): super().__quant_init__() # pylint: disable=attribute-defined-outside-init self.input_quantizers = nn.ModuleList([None, None, None]) self.output_quantizers = nn.ModuleList([None, None]) def _builtin_torch_fn_helper(self, fn: Callable[..., Tensor]): assert fn == _lstm_cell apply = _quantize_dequantize_if_applicable def lstm_cell(input, hx, *args, **kwargs): input = apply(input, self.input_quantizers[0]) h, c = hx h_qtzr, c_qtzr = self.input_quantizers[1:] hx = (apply(h, h_qtzr), apply(c, c_qtzr)) hx, cx = fn(input, hx, *args, **kwargs) return ( apply(hx, self.output_quantizers[0]), apply(cx, self.output_quantizers[1]), ) return lstm_cell def _custom_kernel_helper(self, fn: Callable[..., QuantizedTensorBase]): apply = _quantize_if_applicable def lstm_cell(input, hx, *args, **kwargs): input = apply(input, self.input_quantizers[0]) h, c = hx h_qtzr, c_qtzr = self.input_quantizers[1:] hx = (apply(h, h_qtzr), apply(c, c_qtzr)) output_encodings = tuple(qtzr and qtzr.get_encodings() for qtzr in self.output_quantizers) return fn(input, hx, *args, **kwargs, output_encodings=output_encodings) return lstm_cell
[docs] @QuantizationMixin.implements(nn.LayerNorm) class QuantizedLayerNorm(_DispatchMixin, QuantizationMixin, nn.LayerNorm): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.LayerNorm) _builtin_torch_fn = F.layer_norm __quant_init__ = __unary__
# @QuantizationMixin.implements(nn.LazyBatchNorm1d) # class QuantizedLazyBatchNorm1d(_DispatchMixin, QuantizationMixin, nn.LazyBatchNorm1d): # _builtin_torch_fn = ... # @QuantizationMixin.implements(nn.LazyBatchNorm2d) # class QuantizedLazyBatchNorm2d(_DispatchMixin, QuantizationMixin, nn.LazyBatchNorm2d): # _builtin_torch_fn = ... # @QuantizationMixin.implements(nn.LazyBatchNorm3d) # class QuantizedLazyBatchNorm3d(_DispatchMixin, QuantizationMixin, nn.LazyBatchNorm3d): # _builtin_torch_fn = ... # @QuantizationMixin.implements(nn.LazyConv1d) # class QuantizedLazyConv1d(_DispatchMixin, QuantizationMixin, nn.LazyConv1d): # _builtin_torch_fn = ... # @QuantizationMixin.implements(nn.LazyConv2d) # class QuantizedLazyConv2d(_DispatchMixin, QuantizationMixin, nn.LazyConv2d): # _builtin_torch_fn = ... # @QuantizationMixin.implements(nn.LazyConv3d) # class QuantizedLazyConv3d(_DispatchMixin, QuantizationMixin, nn.LazyConv3d): # _builtin_torch_fn = ... # @QuantizationMixin.implements(nn.LazyConvTranspose1d) # class QuantizedLazyConvTranspose1d(_DispatchMixin, QuantizationMixin, nn.LazyConvTranspose1d): # _builtin_torch_fn = ... # @QuantizationMixin.implements(nn.LazyConvTranspose2d) # class QuantizedLazyConvTranspose2d(_DispatchMixin, QuantizationMixin, nn.LazyConvTranspose2d): # _builtin_torch_fn = ... # @QuantizationMixin.implements(nn.LazyConvTranspose3d) # class QuantizedLazyConvTranspose3d(_DispatchMixin, QuantizationMixin, nn.LazyConvTranspose3d): # _builtin_torch_fn = ... # @QuantizationMixin.implements(nn.LazyInstanceNorm1d) # class QuantizedLazyInstanceNorm1d(_DispatchMixin, QuantizationMixin, nn.LazyInstanceNorm1d): # _builtin_torch_fn = ... # @QuantizationMixin.implements(nn.LazyInstanceNorm2d) # class QuantizedLazyInstanceNorm2d(_DispatchMixin, QuantizationMixin, nn.LazyInstanceNorm2d): # _builtin_torch_fn = ... # @QuantizationMixin.implements(nn.LazyInstanceNorm3d) # class QuantizedLazyInstanceNorm3d(_DispatchMixin, QuantizationMixin, nn.LazyInstanceNorm3d): # _builtin_torch_fn = ... # @QuantizationMixin.implements(nn.LazyLinear) # class QuantizedLazyLinear(_DispatchMixin, QuantizationMixin, nn.LazyLinear): # _builtin_torch_fn = ...
[docs] @QuantizationMixin.implements(nn.LeakyReLU) class QuantizedLeakyReLU(_DispatchMixin, QuantizationMixin, nn.LeakyReLU): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.LeakyReLU) _builtin_torch_fn = F.leaky_relu __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.Linear) class QuantizedLinear(_DispatchMixin, QuantizationMixin, nn.Linear): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.Linear) _builtin_torch_fn = F.linear __quant_init__ = __unary__ # Only allow activation recompute (a.k.a activation checkpointing) for QuantizedLinear. # This is mainly to reduce memory footprint of QAT of large language models.
[docs] @allow_recompute def forward(self, *args, **kwargs): # Workaround for deepspeed. # Deepspeed zero3 sometimes forcefully mokey-patches F.linear to torch.addmm, # which collides with the core assumption of our dispatch mechanism # that nn.Linear invokes F.linear. # To circumvent this issue, we temporarily restore the original F.linear # before running forward. with patch_attr(F, 'linear', type(self)._builtin_torch_fn): return super().forward(*args, **kwargs)
[docs] @QuantizationMixin.implements(nn.LocalResponseNorm) class QuantizedLocalResponseNorm(_DispatchMixin, QuantizationMixin, nn.LocalResponseNorm): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.LocalResponseNorm) _builtin_torch_fn = F.local_response_norm __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.LogSigmoid) class QuantizedLogSigmoid(_DispatchMixin, QuantizationMixin, nn.LogSigmoid): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.LogSigmoid) _builtin_torch_fn = F.logsigmoid __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.LogSoftmax) class QuantizedLogSoftmax(_DispatchMixin, QuantizationMixin, nn.LogSoftmax): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.LogSoftmax) _builtin_torch_fn = F.log_softmax __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.MSELoss) class QuantizedMSELoss(_DispatchMixin, QuantizationMixin, nn.MSELoss): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.MSELoss) _builtin_torch_fn = F.mse_loss __quant_init__ = __binary__
[docs] @QuantizationMixin.implements(nn.MarginRankingLoss) class QuantizedMarginRankingLoss(_DispatchMixin, QuantizationMixin, nn.MarginRankingLoss): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.MarginRankingLoss) _builtin_torch_fn = F.margin_ranking_loss __quant_init__ = __binary__
[docs] @QuantizationMixin.implements(nn.MaxPool1d) class QuantizedMaxPool1d(_DispatchMixin, QuantizationMixin, nn.MaxPool1d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.MaxPool1d) _builtin_torch_fn = F.max_pool1d __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.MaxPool2d) class QuantizedMaxPool2d(_DispatchMixin, QuantizationMixin, nn.MaxPool2d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.MaxPool2d) _builtin_torch_fn = F.max_pool2d __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.MaxPool3d) class QuantizedMaxPool3d(_DispatchMixin, QuantizationMixin, nn.MaxPool3d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.MaxPool3d) _builtin_torch_fn = F.max_pool3d __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.MaxUnpool1d) class QuantizedMaxUnpool1d(_DispatchMixin, QuantizationMixin, nn.MaxUnpool1d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.MaxUnpool1d) _builtin_torch_fn = F.max_unpool1d __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.MaxUnpool2d) class QuantizedMaxUnpool2d(_DispatchMixin, QuantizationMixin, nn.MaxUnpool2d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.MaxUnpool2d) _builtin_torch_fn = F.max_unpool2d __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.MaxUnpool3d) class QuantizedMaxUnpool3d(_DispatchMixin, QuantizationMixin, nn.MaxUnpool3d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.MaxUnpool3d) _builtin_torch_fn = F.max_unpool3d __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.Mish) class QuantizedMish(_DispatchMixin, QuantizationMixin, nn.Mish): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.Mish) _builtin_torch_fn = F.mish __quant_init__ = __unary__
# @QuantizationMixin.implements(nn.Module) # class QuantizedModule(_DispatchMixin, QuantizationMixin, nn.Module): # _builtin_torch_fn = ... # @QuantizationMixin.implements(nn.ModuleDict) # class QuantizedModuleDict(_DispatchMixin, QuantizationMixin, nn.ModuleDict): # _builtin_torch_fn = ... # @QuantizationMixin.implements(nn.ModuleList) # class QuantizedModuleList(_DispatchMixin, QuantizationMixin, nn.ModuleList): # _builtin_torch_fn = ...
[docs] @QuantizationMixin.implements(nn.MultiLabelMarginLoss) class QuantizedMultiLabelMarginLoss(_DispatchMixin, QuantizationMixin, nn.MultiLabelMarginLoss): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.MultiLabelMarginLoss) _builtin_torch_fn = F.multilabel_margin_loss __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.MultiLabelSoftMarginLoss) class QuantizedMultiLabelSoftMarginLoss(_DispatchMixin, QuantizationMixin, nn.MultiLabelSoftMarginLoss): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.MultiLabelSoftMarginLoss) _builtin_torch_fn = F.multilabel_soft_margin_loss __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.MultiMarginLoss) class QuantizedMultiMarginLoss(_DispatchMixin, QuantizationMixin, nn.MultiMarginLoss): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.MultiMarginLoss) _builtin_torch_fn = F.multi_margin_loss __quant_init__ = __unary__
# @QuantizationMixin.implements(nn.MultiheadAttention) # class QuantizedMultiheadAttention(_DispatchMixin, QuantizationMixin, nn.MultiheadAttention): # _builtin_torch_fn = ...
[docs] @QuantizationMixin.implements(nn.NLLLoss) class QuantizedNLLLoss(_DispatchMixin, QuantizationMixin, nn.NLLLoss): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.NLLLoss) _builtin_torch_fn = F.nll_loss __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.NLLLoss2d) class QuantizedNLLLoss2d(_DispatchMixin, QuantizationMixin, nn.NLLLoss2d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.NLLLoss2d) _builtin_torch_fn = F.nll_loss __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.PReLU) class QuantizedPReLU(_DispatchMixin, QuantizationMixin, nn.PReLU): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.PReLU) _builtin_torch_fn = F.prelu __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.PairwiseDistance) class QuantizedPairwiseDistance(_DispatchMixin, QuantizationMixin, nn.PairwiseDistance): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.PairwiseDistance) _builtin_torch_fn = F.pairwise_distance __quant_init__ = __binary__
# @QuantizationMixin.implements(nn.ParameterDict) # class QuantizedParameterDict(_DispatchMixin, QuantizationMixin, nn.ParameterDict): # _builtin_torch_fn = ... # @QuantizationMixin.implements(nn.ParameterList) # class QuantizedParameterList(_DispatchMixin, QuantizationMixin, nn.ParameterList): # _builtin_torch_fn = ...
[docs] @QuantizationMixin.implements(nn.PixelShuffle) class QuantizedPixelShuffle(_DispatchMixin, QuantizationMixin, nn.PixelShuffle): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.PixelShuffle) _builtin_torch_fn = F.pixel_shuffle __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.PixelUnshuffle) class QuantizedPixelUnshuffle(_DispatchMixin, QuantizationMixin, nn.PixelUnshuffle): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.PixelUnshuffle) _builtin_torch_fn = F.pixel_unshuffle __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.PoissonNLLLoss) class QuantizedPoissonNLLLoss(_DispatchMixin, QuantizationMixin, nn.PoissonNLLLoss): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.PoissonNLLLoss) _builtin_torch_fn = F.poisson_nll_loss __quant_init__ = __binary__
[docs] @QuantizationMixin.implements(nn.RNN) class QuantizedRNN(_DispatchMixin, QuantizationMixin, nn.RNN): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.RNN) def _get_builtin_torch_fn(self): assert self.mode in ('RNN_TANH', 'RNN_RELU') if self.mode == 'RNN_TANH': return _rnn_tanh return _rnn_relu def __quant_init__(self): super().__quant_init__() # pylint: disable=attribute-defined-outside-init self.input_quantizers = nn.ModuleList([None, None]) self.output_quantizers = nn.ModuleList([None, None]) def _quantize_inputs(self, args, apply): if args[1].is_floating_point(): input, hx, *others = args batch_sizes = None else: input, batch_sizes, hx, *others = args input = apply(input, self.input_quantizers[0]) hx = apply(hx, self.input_quantizers[1]) if batch_sizes is None: return input, hx, *others return input, batch_sizes, hx, *others def _builtin_torch_fn_helper(self, fn: Callable[..., Tensor]): assert fn in (_rnn_tanh, _rnn_relu) apply = _quantize_dequantize_if_applicable def rnn(*args): args = self._quantize_inputs(args, apply) output, h_n = fn(*args) return ( apply(output, self.output_quantizers[0]), apply(h_n, self.output_quantizers[1]), ) return rnn def _custom_kernel_helper(self, fn: Callable[..., QuantizedTensorBase]): apply = _quantize_if_applicable def rnn(*args): args = self._quantize_inputs(args, apply) output_encodings = tuple(qtzr and qtzr.get_encodings() for qtzr in self.output_quantizers) return fn(*args, output_encodings=output_encodings) return rnn
# @QuantizationMixin.implements(nn.RNNBase) # class QuantizedRNNBase(_DispatchMixin, QuantizationMixin, nn.RNNBase): # _builtin_torch_fn = ...
[docs] @QuantizationMixin.implements(nn.RNNCell) class QuantizedRNNCell(_DispatchMixin, QuantizationMixin, nn.RNNCell): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.RNNCell) def _get_builtin_torch_fn(self): assert self.nonlinearity in ("tanh", "relu") if self.nonlinearity == "tanh": return _rnn_tanh_cell return _rnn_relu_cell def __quant_init__(self): super().__quant_init__() # pylint: disable=attribute-defined-outside-init self.input_quantizers = nn.ModuleList([None, None]) self.output_quantizers = nn.ModuleList([None]) def _builtin_torch_fn_helper(self, fn: Callable[..., Tensor]): assert fn in (_rnn_tanh_cell, _rnn_relu_cell) apply = _quantize_dequantize_if_applicable def rnn_cell(input, hx, *args, **kwargs): input = apply(input, self.input_quantizers[0]) hx = apply(hx, self.input_quantizers[1]) output = fn(input, hx, *args, **kwargs) return apply(output, self.output_quantizers[0]) return rnn_cell def _custom_kernel_helper(self, fn: Callable[..., QuantizedTensorBase]): apply = _quantize_if_applicable def rnn_cell(input, hx, *args, **kwargs): input = apply(input, self.input_quantizers[0]) hx = apply(hx, self.input_quantizers[1]) output_encodings = self.output_quantizers[0] and self.output_quantizers[0].get_encodings() return fn(input, hx, *args, **kwargs, output_encodings=output_encodings) return rnn_cell
# @QuantizationMixin.implements(nn.RNNCellBase) # class QuantizedRNNCellBase(_DispatchMixin, QuantizationMixin, nn.RNNCellBase): # _builtin_torch_fn = ...
[docs] @QuantizationMixin.implements(nn.RReLU) class QuantizedRReLU(_DispatchMixin, QuantizationMixin, nn.RReLU): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.RReLU) _builtin_torch_fn = F.rrelu __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.ReLU) class QuantizedReLU(_DispatchMixin, QuantizationMixin, nn.ReLU): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.ReLU) _builtin_torch_fn = F.relu __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.ReLU6) class QuantizedReLU6(_DispatchMixin, QuantizationMixin, nn.ReLU6): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.ReLU6) _builtin_torch_fn = F.hardtanh __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.ReflectionPad1d) class QuantizedReflectionPad1d(_DispatchMixin, QuantizationMixin, nn.ReflectionPad1d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.ReflectionPad1d) _builtin_torch_fn = F.pad __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.ReflectionPad2d) class QuantizedReflectionPad2d(_DispatchMixin, QuantizationMixin, nn.ReflectionPad2d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.ReflectionPad2d) _builtin_torch_fn = F.pad __quant_init__ = __unary__
if version.parse(torch.__version__) >= version.parse("1.10.0"):
[docs] @QuantizationMixin.implements(nn.ReflectionPad3d) class QuantizedReflectionPad3d(_DispatchMixin, QuantizationMixin, nn.ReflectionPad3d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.ReflectionPad3d) _builtin_torch_fn = F.pad __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.ReplicationPad1d) class QuantizedReplicationPad1d(_DispatchMixin, QuantizationMixin, nn.ReplicationPad1d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.ReplicationPad1d) _builtin_torch_fn = F.pad __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.ReplicationPad2d) class QuantizedReplicationPad2d(_DispatchMixin, QuantizationMixin, nn.ReplicationPad2d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.ReplicationPad2d) _builtin_torch_fn = F.pad __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.ReplicationPad3d) class QuantizedReplicationPad3d(_DispatchMixin, QuantizationMixin, nn.ReplicationPad3d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.ReplicationPad3d) _builtin_torch_fn = F.pad __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.SELU) class QuantizedSELU(_DispatchMixin, QuantizationMixin, nn.SELU): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.SELU) _builtin_torch_fn = F.selu __quant_init__ = __unary__
# @QuantizationMixin.implements(nn.Sequential) # class QuantizedSequential(_DispatchMixin, QuantizationMixin, nn.Sequential): # _builtin_torch_fn = ...
[docs] @QuantizationMixin.implements(nn.SiLU) class QuantizedSiLU(_DispatchMixin, QuantizationMixin, nn.SiLU): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.SiLU) _builtin_torch_fn = F.silu __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.Sigmoid) class QuantizedSigmoid(_DispatchMixin, QuantizationMixin, nn.Sigmoid): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.Sigmoid) _builtin_torch_fn = torch.sigmoid __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.SmoothL1Loss) class QuantizedSmoothL1Loss(_DispatchMixin, QuantizationMixin, nn.SmoothL1Loss): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.SmoothL1Loss) _builtin_torch_fn = F.smooth_l1_loss __quant_init__ = __binary__
[docs] @QuantizationMixin.implements(nn.SoftMarginLoss) class QuantizedSoftMarginLoss(_DispatchMixin, QuantizationMixin, nn.SoftMarginLoss): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.SoftMarginLoss) _builtin_torch_fn = F.soft_margin_loss __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.Softmax) class QuantizedSoftmax(_DispatchMixin, QuantizationMixin, nn.Softmax): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.Softmax) _builtin_torch_fn = F.softmax __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.Softmax2d) class QuantizedSoftmax2d(_DispatchMixin, QuantizationMixin, nn.Softmax2d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.Softmax2d) _builtin_torch_fn = F.softmax __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.Softmin) class QuantizedSoftmin(_DispatchMixin, QuantizationMixin, nn.Softmin): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.Softmin) _builtin_torch_fn = F.softmin __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.Softplus) class QuantizedSoftplus(_DispatchMixin, QuantizationMixin, nn.Softplus): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.Softplus) _builtin_torch_fn = F.softplus __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.Softshrink) class QuantizedSoftshrink(_DispatchMixin, QuantizationMixin, nn.Softshrink): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.Softshrink) _builtin_torch_fn = F.softshrink __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.Softsign) class QuantizedSoftsign(_DispatchMixin, QuantizationMixin, nn.Softsign): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.Softsign) _builtin_torch_fn = F.softsign __quant_init__ = __unary__
# @QuantizationMixin.implements(nn.SyncBatchNorm) # class QuantizedSyncBatchNorm(_DispatchMixin, QuantizationMixin, nn.SyncBatchNorm): # _builtin_torch_fn = ...
[docs] @QuantizationMixin.implements(nn.Tanh) class QuantizedTanh(_DispatchMixin, QuantizationMixin, nn.Tanh): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.Tanh) _builtin_torch_fn = torch.tanh __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.Tanhshrink) class QuantizedTanhshrink(_DispatchMixin, QuantizationMixin, nn.Tanhshrink): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.Tanhshrink) _builtin_torch_fn = F.tanhshrink __quant_init__ = __unary__
# @QuantizationMixin.implements(nn.Threshold)
[docs] @QuantizationMixin.implements(nn.Threshold) class QuantizedThreshold(_DispatchMixin, QuantizationMixin, nn.Threshold): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.Threshold) _builtin_torch_fn = F.threshold __quant_init__ = __unary__
# @QuantizationMixin.implements(nn.Transformer) # class QuantizedTransformer(_DispatchMixin, QuantizationMixin, nn.Transformer): # _builtin_torch_fn = ... # @QuantizationMixin.implements(nn.TransformerDecoder) # class QuantizedTransformerDecoder(_DispatchMixin, QuantizationMixin, nn.TransformerDecoder): # _builtin_torch_fn = ... # @QuantizationMixin.implements(nn.TransformerDecoderLayer) # class QuantizedTransformerDecoderLayer(_DispatchMixin, QuantizationMixin, nn.TransformerDecoderLayer): # _builtin_torch_fn = ... # @QuantizationMixin.implements(nn.TransformerEncoder) # class QuantizedTransformerEncoder(_DispatchMixin, QuantizationMixin, nn.TransformerEncoder): # _builtin_torch_fn = ... # @QuantizationMixin.implements(nn.TransformerEncoderLayer) # class QuantizedTransformerEncoderLayer(_DispatchMixin, QuantizationMixin, nn.TransformerEncoderLayer): # _builtin_torch_fn = ...
[docs] @QuantizationMixin.implements(nn.TripletMarginLoss) class QuantizedTripletMarginLoss(_DispatchMixin, QuantizationMixin, nn.TripletMarginLoss): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.TripletMarginLoss) _builtin_torch_fn = F.triplet_margin_loss __quant_init__ = __ternary__
[docs] @QuantizationMixin.implements(nn.TripletMarginWithDistanceLoss) class QuantizedTripletMarginWithDistanceLoss(_DispatchMixin, QuantizationMixin, nn.TripletMarginWithDistanceLoss): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.TripletMarginWithDistanceLoss) _builtin_torch_fn = F.triplet_margin_with_distance_loss __quant_init__ = __ternary__
[docs] @QuantizationMixin.implements(nn.Unflatten) class QuantizedUnflatten(_DispatchMixin, QuantizationMixin, nn.Unflatten): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.Unflatten) def _get_builtin_torch_fn(self): return Tensor.unflatten
[docs] @QuantizationMixin.implements(nn.Unfold) class QuantizedUnfold(_DispatchMixin, QuantizationMixin, nn.Unfold): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.Unfold) _builtin_torch_fn = F.unfold __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.Upsample) class QuantizedUpsample(_DispatchMixin, QuantizationMixin, nn.Upsample): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.Upsample) _builtin_torch_fn = F.interpolate __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.UpsamplingBilinear2d) class QuantizedUpsamplingBilinear2d(_DispatchMixin, QuantizationMixin, nn.UpsamplingBilinear2d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.UpsamplingBilinear2d) _builtin_torch_fn = F.interpolate __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.UpsamplingNearest2d) class QuantizedUpsamplingNearest2d(_DispatchMixin, QuantizationMixin, nn.UpsamplingNearest2d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.UpsamplingNearest2d) _builtin_torch_fn = F.interpolate __quant_init__ = __unary__
if version.parse(torch.__version__) >= version.parse("2.1.0"):
[docs] @QuantizationMixin.implements(nn.ZeroPad1d) class QuantizedZeroPad1d(_DispatchMixin, QuantizationMixin, nn.ZeroPad1d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.ZeroPad1d) _builtin_torch_fn = F.pad __quant_init__ = __unary__
[docs] @QuantizationMixin.implements(nn.ZeroPad2d) class QuantizedZeroPad2d(_DispatchMixin, QuantizationMixin, nn.ZeroPad2d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.ZeroPad2d) _builtin_torch_fn = F.pad __quant_init__ = __unary__
if version.parse(torch.__version__) >= version.parse("2.1.0"):
[docs] @QuantizationMixin.implements(nn.ZeroPad3d) class QuantizedZeroPad3d(_DispatchMixin, QuantizationMixin, nn.ZeroPad3d): # pylint: disable=missing-class-docstring __doc__ = _generate_docstring(parent_cls=nn.ZeroPad3d) _builtin_torch_fn = F.pad __quant_init__ = __unary__
del __nullary__ del __unary__ del __binary__ del __ternary__