# -*- mode: python -*-
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# SPDX-License-Identifier: BSD-3-Clause
#
# @@-COPYRIGHT-END-@@
# =============================================================================
""" Quantized tensor class implementation """
import abc
import copy
import torch
from torch.utils._pytree import tree_map
from aimet_torch.v2.quantization.base import EncodingBase
__all__ = ['QuantizedTensorBase', 'QuantizedTensor', 'DequantizedTensor', 'EncodingError']
HANDLED_FUNCTIONS = {}
def implements(torch_function):
"""
Register an override for QuantizedTensorBase
"""
def decorator(func):
HANDLED_FUNCTIONS[torch_function] = func
return func
return decorator
class QuantizedTensorBase(torch.Tensor):
"""
Abstract base class to define quantized tensor behavior.
Represents a quantized or dequantized tensor as a subclass of :class:`torch.Tensor` which also holds the quantization encodings.
This object can be safely quantized or dequantized through the :meth:`quantize` and :meth:`dequantize` methods without
changing the represented data values.
Example:
>>> from aimet_torch.v2 import quantization as Q
>>> quantizer = Q.affine.Quantize(shape=(2, 1), bitwidth=8, symmetric=True)
>>> x = torch.tensor([[-1.20, 4.1, -0.21, 2.3],
... [0.2, 5.6, -1.0, -.1]])
>>> with quantizer.compute_encodings():
... x_q = quantizer(x)
>>> torch.equal(x_q.encoding.scale, quantizer.get_scale())
True
>>> x_q
QuantizedTensor([[-37., 127., -7., 71.],
[ 5., 127., -23., -2.]])
>>> x_q.quantized_repr()
tensor([[-37, 127, -7, 71],
[ 5, 127, -23, -2]], dtype=torch.int8)
>>> x_q.dequantize()
DequantizedTensor([[-1.1945, 4.1000, -0.2260, 2.2921],
[ 0.2205, 5.6000, -1.0142, -0.0882]])
"""
encoding: EncodingBase
_cast_ops = {
torch.Tensor.half,
torch.Tensor.float,
torch.Tensor.double,
torch.Tensor.char,
torch.Tensor.short,
torch.Tensor.int,
torch.Tensor.long,
torch.Tensor.cuda,
torch.Tensor.cpu,
torch.Tensor.to,
torch.Tensor.type,
torch.Tensor.type_as,
}
# Operations that an encoding can always pass through
_passthrough_ops = {
torch.Tensor.contiguous,
torch.Tensor.requires_grad_,
torch.Tensor.share_memory_,
}
# Operations that a per-tensor encoding can pass through
_pertensor_passthrough_ops = {
torch.Tensor.__getitem__,
torch.Tensor.as_strided,
torch.Tensor.broadcast_to,
torch.Tensor.chunk,
torch.Tensor.dsplit,
torch.Tensor.expand,
torch.Tensor.expand_as,
torch.Tensor.flatten,
torch.Tensor.flip,
torch.Tensor.fliplr,
torch.Tensor.flipud,
torch.Tensor.gather,
torch.Tensor.hsplit,
torch.Tensor.index_select,
torch.Tensor.kthvalue,
torch.Tensor.masked_select,
torch.Tensor.movedim,
torch.Tensor.moveaxis,
torch.Tensor.msort,
torch.Tensor.narrow,
torch.Tensor.permute,
torch.Tensor.repeat,
torch.Tensor.reshape,
torch.Tensor.reshape_as,
torch.Tensor.resize,
torch.Tensor.resize_,
torch.Tensor.resize_as,
torch.Tensor.resize_as_,
torch.Tensor.select,
torch.Tensor.split,
torch.Tensor.squeeze,
torch.Tensor.squeeze_,
torch.Tensor.swapaxes,
torch.Tensor.swapdims,
torch.Tensor.t,
torch.Tensor.t_,
torch.Tensor.take,
torch.Tensor.take_along_dim,
torch.Tensor.tensor_split,
torch.Tensor.tile,
torch.Tensor.transpose,
torch.Tensor.unflatten,
torch.Tensor.unsqueeze,
torch.Tensor.unsqueeze_,
torch.Tensor.view,
torch.Tensor.view_as,
torch.as_strided,
torch.as_strided_copy,
torch.chunk,
torch.dsplit,
torch.expand_copy,
torch.flatten,
torch.flip,
torch.fliplr,
torch.flipud,
torch.gather,
torch.hsplit,
torch.index_select,
torch.masked_select,
torch.moveaxis,
torch.movedim,
torch.narrow,
torch.narrow_copy,
torch.permute,
torch.permute_copy,
torch.reshape,
torch.select,
torch.split,
torch.squeeze,
torch.squeeze_copy,
torch.swapaxes,
torch.swapdims,
torch.t,
torch.take,
torch.take_along_dim,
torch.tensor_split,
torch.tile,
torch.t_copy,
torch.unbind,
torch.unflatten,
torch.unsqueeze,
torch.unsqueeze_copy,
torch.vsplit,
torch.view_copy,
}
@abc.abstractmethod
def quantize(self) -> "QuantizedTensor":
"""
Quantizes ``self`` with the associated encoding
.. note::
This method must be an IDEMPOTENT function.
The result of calling this method multiple times should be equal to calling it only once.
In other words, calling this method multiple times should not result in duplicate quantization.
"""
raise NotImplementedError
@abc.abstractmethod
def dequantize(self) -> "DequantizedTensor":
"""
Dequantizes ``self`` with the associated encoding
.. note::
This method must be an IDEMPOTENT function.
The result of calling this method multiple times should be equal to calling it only once.
In other words, calling this method multiple times should not result in duplicate dequantization.
"""
raise NotImplementedError
@abc.abstractmethod
def quantized_repr(self) -> torch.Tensor:
"""
Return the quantized representation of ``self`` as a :class:`torch.Tensor` with data type :attr:`self.encoding.dtype`
.. note::
The result of this function may not be able to carry a gradient depending on the quantized data type.
Thus, it may be necessary to call this only within an autograd function to allow for backpropagation.
Example:
>>> from aimet_torch.v2 import quantization as Q
>>> quantizer = Q.affine.Quantize(shape=(2, 1), bitwidth=8, symmetric=True)
>>> x = torch.randn((2, 4), requires_grad=True)
>>> with quantizer.compute_encodings():
... x_q = quantizer(x)
>>> x_q
QuantizedTensor([[ 11., -57., -128., 38.],
[ 28., -0., -128., -40.]], grad_fn=<AliasBackward0>)
>>> x_q.quantized_repr()
tensor([[ 11, -57, -128, 38],
[ 28, 0, -128, -40]], dtype=torch.int8)
"""
raise NotImplementedError
@classmethod
def __new__(cls, *args, **kwargs):
encoding = kwargs.pop('encoding', None)
ret = super().__new__(*args, **kwargs)
if not ret.is_floating_point():
raise RuntimeError(f"Non-floating point dtype `{ret.dtype}` is not allowed for quantized tensors.")
ret.encoding = encoding
return ret
def new_empty(self, size, *, dtype=None, device=None, requires_grad=False,
layout=torch.strided, pin_memory=False, **kwargs) -> "QuantizedTensorBase":
# PyTorch requires subclasses of torch.Tensor to override this method such that
# it returns an instance of the subclass, not a plain torch.Tensor,
# for the subclass to be deep-copyable
encoding = kwargs.pop('encoding', None)
t = super().new_empty(size, dtype=dtype, device=device, requires_grad=requires_grad,
layout=layout, pin_memory=pin_memory, **kwargs).as_subclass(type(self))
t.encoding = encoding
return t
@implements(torch.clone)
def clone(self, *, memory_format=torch.preserve_format):
"""
Returns a copy of self
:param memory_format: Desired memory format of the returned tensor (default=torch.preserve_format)
"""
# Note: use encoding.clone() here instead of deepcopy to propagate gradient through operation
encoding_clone = self.encoding and self.encoding._clone() # pylint:disable = protected-access
self_clone = super().clone(memory_format=memory_format).as_subclass(self.__class__)
self_clone.encoding = encoding_clone
return self_clone
@implements(torch.detach)
def detach(self) -> "QuantizedTensorBase":
"""
Returns a new QuantizedTensorBase with data and encoding detached from the current graph
"""
self_detached = super().detach().as_subclass(self.__class__)
self_detached.encoding = self.encoding and self.encoding._detach() # pylint:disable = protected-access
return self_detached
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None): # pylint: disable=too-many-return-statements
if func in HANDLED_FUNCTIONS:
kwargs = kwargs if kwargs is not None else {}
return HANDLED_FUNCTIONS[func](*args, **kwargs)
ret = super().__torch_function__(func, types, args, kwargs)
self, *_ = args
if func in cls._cast_ops:
if not ret.dtype.is_floating_point:
raise RuntimeError(
f"Type casting to non-floating point dtype `{ret.dtype}` is not allowed for quantized tensors. "
"To cast quantized tensors to integer, use `qtensor.quantzed_repr()`."
)
# Outputs of cast ops can inherit the same encoding as its parents
ret.encoding = self.encoding and self.encoding.to(device=ret.device)
return ret
def propagate_encoding(qtensor, encoding):
if isinstance(qtensor, QuantizedTensorBase):
qtensor.encoding = copy.copy(encoding)
if func in cls._passthrough_ops:
if self is ret:
return ret
tree_map(lambda t: propagate_encoding(t, self.encoding), ret)
return ret
if func in cls._pertensor_passthrough_ops:
if self is ret:
return ret
if self.encoding and self.encoding.granularity == "pertensor":
# Return a cls object with the same encoding which can later be quantized or dequantized
tree_map(lambda t: propagate_encoding(t, self.encoding), ret)
else:
# Return a cls object with no encoding
# If the user later tries to quantize or dequantize this, an error will be thrown
tree_map(lambda t: propagate_encoding(t, None), ret)
return ret
if self is ret:
# Non-passthrough in-place functions invalidate the encoding of the input tensor.
# Discard the stale encoding.
ret.encoding = None
return ret
def set_encoding(qtensor):
if not hasattr(qtensor, 'encoding'):
qtensor.encoding = None
if qtensor.encoding is None:
# If encoding does not exist, return a plain torch.Tensor
return qtensor.as_subclass(torch.Tensor)
return qtensor
return tree_map(lambda t: set_encoding(t) if isinstance(t, cls) else t, ret)
[docs]class QuantizedTensor(QuantizedTensorBase):
"""
Represents a quantized tensor object. The object holds quantized values stored in a floating-point tensor along with
an :class:`EncodingBase` object which holds the information necessary to map the quantized values back to the
real/represented values.
"""
[docs] def quantize(self) -> "QuantizedTensor":
"""
Returns ``self``
"""
return self
[docs] def dequantize(self) -> "DequantizedTensor":
"""
Dequantizes ``self`` using :attr:`self.encoding` to produce a :class:`DequantizedTensor` with the same encoding
information.
Example:
>>> from aimet_torch.v2.quantization as Q
>>> x = torch.tensor([[2.57, -2.312],
... [0.153, 0.205]])
>>> quantizer = Q.affine.Quantize(shape=(), bitwidth=8, symmetric=True)
>>> quantizer.set_range(-128 * 0.1, 127 * 0.1)
>>> x_q = quantizer(x)
>>> x_q
QuantizedTensor([[ 26., -23.],
[ 2., 2.]], grad_fn=<AliasBackward0>)
>>> x_dq = x_q.dequantize()
>>> x_dq
DequantizedTensor([[ 2.6000, -2.3000],
[ 0.2000, 0.2000]], grad_fn=<AliasBackward0>)
>>> torch.equal(x_dq.encoding.scale, x_q.encoding.scale)
True
"""
if self.encoding is None:
raise EncodingError("Encoding does not exist")
qtensor = self.encoding.dequantize(self.as_subclass(torch.Tensor))
qtensor = qtensor.as_subclass(DequantizedTensor)
qtensor.encoding = copy.copy(self.encoding)
return qtensor
[docs] def quantized_repr(self) -> torch.Tensor:
# FIXME(kyunggeu): This only works for affine encodings.
# Needs to be generalized for any kind of encodings
if self.encoding is None:
raise EncodingError("Encoding does not exist")
return self.quantize().as_subclass(torch.Tensor).to(self.encoding.dtype)
[docs]class DequantizedTensor(QuantizedTensorBase):
"""
Represents a tensor which has been quantized and subsequently dequantized. This object contains real floating point
data as well as an :class:`EncodingBase` object which holds information about the quantization parameters with which
the data was quantized. With this, a :class:`DequantizedTensor` can be converted back to its quantized representation
without further loss in information.
"""
[docs] def quantize(self) -> QuantizedTensor:
"""
Quantizes ``self`` using :attr:`self.encoding` to produce a :class:`QuantizedTensor` with the same encoding
information.
Example:
>>> import aimet_torch.v2.quantization as Q
>>> x = torch.tensor([[0.39, 51.0], [3.521, 9.41]])
>>> quant_dequant = Q.affine.QuantizeDequantize(shape=(), bitwidth=8, symmetric=False)
>>> quant_dequant.set_range(-10, 41)
>>> x_qdq = quant_dequant(x)
>>> x_qdq
DequantizedTensor([[ 0.4000, 41.0000],
[ 3.6000, 9.4000]], grad_fn=<AliasBackward0>)
>>> x_qdq.quantize()
QuantizedTensor([[ 52., 255.],
[ 68., 97.]], grad_fn=<AliasBackward0>)
"""
if self.encoding is None:
raise EncodingError("Encoding does not exist")
qtensor = self.encoding.quantize(self.as_subclass(torch.Tensor))
qtensor = qtensor.as_subclass(QuantizedTensor)
qtensor.encoding = copy.copy(self.encoding)
return qtensor
[docs] def dequantize(self) -> "DequantizedTensor":
"""
Returns ``self``
"""
return self
[docs] def quantized_repr(self) -> torch.Tensor:
"""
Return the quantized representation of ``self`` as a :class:`torch.Tensor` with data type :attr:`self.encoding.dtype`.
.. note::
The result of this function may not be able to carry a gradient depending on the quantized data type.
Thus, it may be necessary to call this only within an autograd function to allow for backpropagation.
Example:
>>> import aimet_torch.v2.quantization as Q
>>> x = torch.tensor([[0.39, 51.0], [3.521, 9.41]])
>>> quant_dequant = Q.affine.QuantizeDequantize(shape=(), bitwidth=8, symmetric=False)
>>> quant_dequant.set_range(-10, 41)
>>> x_qdq = quant_dequant(x)
>>> x_qdq
DequantizedTensor([[ 0.4000, 41.0000],
[ 3.6000, 9.4000]], grad_fn=<AliasBackward0>)
>>> x_qdq.quantized_repr()
tensor([[ 52, 255],
[ 68, 97]], dtype=torch.uint8)
"""
# FIXME(kyunggeu): This only works for affine encodings.
# Needs to be generalized for any kind of encodings
if self.encoding is None:
raise EncodingError("Encoding does not exist")
return self.quantize().as_subclass(torch.Tensor).to(self.encoding.dtype)
class EncodingError(RuntimeError):
"""Error that indicates an encoding is missing or invalid"""