# -*- mode: python -*-
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# SPDX-License-Identifier: BSD-3-Clause
#
# @@-COPYRIGHT-END-@@
# =============================================================================
""" Quantized tensor class implementation """
import abc
import copy
import torch
from torch.utils._pytree import tree_map, tree_flatten
from aimet_torch.v2.quantization.base import EncodingBase
# Public names exported by ``from <module> import *``
__all__ = ['QuantizedTensorBase', 'QuantizedTensor', 'DequantizedTensor', 'EncodingError']
# Registry mapping a torch function to the override that QuantizedTensorBase.__torch_function__
# dispatches to instead of the default torch implementation
HANDLED_FUNCTIONS = {}


def implements(torch_function):
    """
    Register an override for QuantizedTensorBase

    :param torch_function: The torch callable (e.g. ``torch.clone``) whose behavior is being overridden
    :return: A decorator that stores the decorated function in ``HANDLED_FUNCTIONS`` under ``torch_function``
        and returns the decorated function unchanged
    """
    def decorator(func):
        # A later registration for the same torch function replaces the earlier one
        HANDLED_FUNCTIONS[torch_function] = func
        return func

    return decorator
class QuantizedTensorBase(torch.Tensor):
    """
    Abstract base class to define quantized tensor behavior.

    Represents a quantized or dequantized tensor as a subclass of :class:`torch.Tensor` which also holds the quantization encodings.
    This object can be safely quantized or dequantized through the :meth:`quantize` and :meth:`dequantize` methods without
    changing the represented data values.

    Example:

        >>> from aimet_torch.v2 import quantization as Q
        >>> quantizer = Q.affine.Quantize(shape=(2, 1), bitwidth=8, symmetric=True)
        >>> x = torch.tensor([[-1.20, 4.1, -0.21, 2.3],
        ...                   [0.2, 5.6, -1.0, -.1]])
        >>> with quantizer.compute_encodings():
        ...     x_q = quantizer(x)
        >>> torch.equal(x_q.encoding.scale, quantizer.get_scale())
        True
        >>> x_q
        QuantizedTensor([[-37., 127.,  -7.,  71.],
                         [  5., 127., -23.,  -2.]])
        >>> x_q.quantized_repr()
        tensor([[-37, 127,  -7,  71],
                [  5, 127, -23,  -2]], dtype=torch.int8)
        >>> x_q.dequantize()
        DequantizedTensor([[-1.1945,  4.1000, -0.2260,  2.2921],
                           [ 0.2205,  5.6000, -1.0142, -0.0882]])
    """

    # Encoding associated with this tensor; ``None`` when no encoding is attached
    encoding: EncodingBase

    # Type/device casting operations whose outputs inherit the encoding of the input
    _cast_ops = [
        torch.Tensor.half,
        torch.Tensor.float,
        torch.Tensor.double,
        torch.Tensor.char,
        torch.Tensor.short,
        torch.Tensor.int,
        torch.Tensor.long,
        torch.Tensor.cuda,
        torch.Tensor.cpu,
        torch.Tensor.to,
    ]

    # Operations that an encoding can always pass through
    _passthrough_ops = {
        torch.Tensor.contiguous,
    }

    # Operations that a per-tensor encoding can pass through
    _pertensor_passthrough_ops = {
        torch.Tensor.broadcast_to,
        torch.Tensor.expand,
        torch.Tensor.expand_as,
        torch.Tensor.flatten,
        torch.Tensor.masked_select,
        torch.Tensor.permute,
        torch.Tensor.repeat,
        torch.Tensor.reshape,
        torch.Tensor.reshape_as,
        torch.Tensor.resize,
        torch.Tensor.resize_as,
        torch.Tensor.select,
        torch.Tensor.squeeze,
        torch.Tensor.swapaxes,
        torch.Tensor.swapdims,
        torch.Tensor.t,
        torch.Tensor.transpose,
        torch.Tensor.unflatten,
        torch.Tensor.unsqueeze,
        torch.Tensor.view,
        torch.Tensor.view_as,
        torch.as_strided,
        #torch.as_strided_copy, TODO: Uncomment when pytorch 1.9 support is fully deprecated
        #torch.expand_copy,
        torch.flatten,
        torch.permute,
        #torch.permute_copy,
        torch.reshape,
        torch.squeeze,
        #torch.squeeze_copy,
        torch.swapdims,
        torch.t,
        #torch.t_copy,
        #torch.unflatten,
        torch.unsqueeze,
        #torch.unsqueeze_copy,
        #torch.view_copy
    }

    @abc.abstractmethod
    def quantize(self) -> "QuantizedTensor":
        """
        Quantizes ``self`` with the associated encoding

        .. note::
            This method must be an IDEMPOTENT function.
            The result of calling this method multiple times should be equal to calling it only once.
            In other words, calling this method multiple times should not result in duplicate quantization.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def dequantize(self) -> "DequantizedTensor":
        """
        Dequantizes ``self`` with the associated encoding

        .. note::
            This method must be an IDEMPOTENT function.
            The result of calling this method multiple times should be equal to calling it only once.
            In other words, calling this method multiple times should not result in duplicate dequantization.
        """
        raise NotImplementedError

    @abc.abstractmethod
    def quantized_repr(self) -> torch.Tensor:
        """
        Return the quantized representation of ``self`` as a :class:`torch.Tensor` with data type :attr:`self.encoding.dtype`

        .. note::
            The result of this function may not be able to carry a gradient depending on the quantized data type.
            Thus, it may be necessary to call this only within an autograd function to allow for backpropagation.

        Example:

            >>> from aimet_torch.v2 import quantization as Q
            >>> quantizer = Q.affine.Quantize(shape=(2, 1), bitwidth=8, symmetric=True)
            >>> x = torch.randn((2, 4), requires_grad=True)
            >>> with quantizer.compute_encodings():
            ...     x_q = quantizer(x)
            >>> x_q
            QuantizedTensor([[  11.,  -57., -128.,   38.],
                             [  28.,   -0., -128.,  -40.]], grad_fn=<AliasBackward0>)
            >>> x_q.quantized_repr()
            tensor([[  11,  -57, -128,   38],
                    [  28,    0, -128,  -40]], dtype=torch.int8)
        """
        raise NotImplementedError

    # NOTE: ``__new__`` is deliberately wrapped in ``classmethod``. When Python instantiates the class it
    # calls ``__new__(klass, *ctor_args)``; with the classmethod binding, ``klass`` therefore arrives as the
    # first element of ``args`` and is forwarded to the parent ``__new__`` through ``*args``.
    @classmethod
    def __new__(cls, *args, **kwargs):
        """
        Create a new tensor and attach the optional ``encoding`` keyword argument to it

        :param args: Target class followed by the positional arguments forwarded to the parent ``__new__``
        :param kwargs: Keyword arguments forwarded to the parent ``__new__``;
            ``encoding`` (default ``None``) is consumed here
        :raises RuntimeError: If the created tensor does not have a floating point dtype
        """
        encoding = kwargs.pop('encoding', None)
        ret = super().__new__(*args, **kwargs)
        if not ret.is_floating_point():
            raise RuntimeError(f"Non-floating point dtype `{ret.dtype}` is not allowed for quantized tensors.")
        ret.encoding = encoding
        return ret

    @implements(torch.clone)
    def clone(self, *, memory_format=torch.preserve_format):
        """
        Returns a copy of self

        :param memory_format: Desired memory format of the returned tensor (default=torch.preserve_format)
        """
        # Note: use encoding.clone() here instead of deepcopy to propagate gradient through operation
        encoding_clone = self.encoding._clone() # pylint:disable = protected-access
        self_clone = super().clone(memory_format=memory_format).as_subclass(self.__class__)
        self_clone.encoding = encoding_clone
        return self_clone

    @implements(torch.detach)
    def detach(self) -> "QuantizedTensorBase":
        """
        Returns a new QuantizedTensorBase with data and encoding detached from the current graph
        """
        self_detached = super().detach().as_subclass(self.__class__)
        self_detached.encoding = self.encoding._detach() # pylint:disable = protected-access
        return self_detached

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        """
        Dispatch torch operations involving quantized tensors and decide how the encoding propagates

        - Functions registered in ``HANDLED_FUNCTIONS`` are routed to their override.
        - If the result is one of the input objects, it is returned untouched.
        - Outputs of cast ops and always-pass-through ops receive a shallow copy of the first argument's encoding.
        - Outputs of per-tensor pass-through ops receive a shallow copy of the first argument's encoding only
          if that encoding has ``"pertensor"`` granularity; otherwise the output keeps its class
          but has ``encoding = None``.
        - For every other function, output tensors of type ``cls`` without an encoding are returned as plain
          :class:`torch.Tensor`, and the remaining ones have their encoding moved to the tensor's device.

        :raises RuntimeError: If a cast op produces a non-floating point dtype
        """
        if func in HANDLED_FUNCTIONS:
            kwargs = kwargs if kwargs is not None else {}
            return HANDLED_FUNCTIONS[func](*args, **kwargs)

        ret = super().__torch_function__(func, types, args, kwargs)

        flattened_args, _ = tree_flatten((args, kwargs))

        if any(ret is arg for arg in flattened_args):
            # Return value is the same object as one of the arguments.
            # This implies that func is likely (but not necessarily) an in-place operator.
            return ret

        if func in cls._cast_ops:
            if not ret.dtype.is_floating_point:
                raise RuntimeError(
                    f"Type casting to non-floating point dtype `{ret.dtype}` is not allowed for quantized tensors. "
                    "To cast quantized tensors to integer, use `qtensor.quantized_repr()`."
                )
            # Outputs of cast ops can inherit the same encoding as its parents.
            # The first argument may be a plain tensor (e.g. ``plain.to(qtensor)``), hence getattr
            ret.encoding = copy.copy(getattr(args[0], 'encoding', None)) # shallow copy

        if func in cls._passthrough_ops:
            ret.encoding = copy.copy(getattr(args[0], 'encoding', None))

        if func in cls._pertensor_passthrough_ops:
            # The first argument may be a plain tensor (e.g. ``plain.view_as(qtensor)``)
            # or a quantized tensor whose encoding is None
            parent_encoding = getattr(args[0], 'encoding', None)
            if parent_encoding is not None and parent_encoding.granularity == "pertensor":
                # Return a cls object with the same encoding which can later be quantized or dequantized
                ret.encoding = copy.copy(parent_encoding)
            else:
                # Return a cls object with no encoding
                # If the user later tries to quantize or dequantize this, an error will be thrown
                ret.encoding = None
            return ret

        def set_encoding(qtensor):
            if not hasattr(qtensor, 'encoding'):
                qtensor.encoding = None

            if qtensor.encoding is None:
                # If encoding does not exist, return a plain torch.Tensor
                return qtensor.as_subclass(torch.Tensor)

            # Change device of encoding
            # NOTE: We don't change the dtypes of encoding because scale/offset
            #       are sensitive to dtype
            qtensor.encoding = qtensor.encoding.to(device=qtensor.device)
            return qtensor

        return tree_map(lambda t: set_encoding(t) if isinstance(t, cls) else t, ret)
class QuantizedTensor(QuantizedTensorBase):
    """
    Represents a quantized tensor object. The object holds quantized values stored in a floating-point tensor along with
    an :class:`EncodingBase` object which holds the information necessary to map the quantized values back to the
    real/represented values.
    """

    def quantize(self) -> "QuantizedTensor":
        """
        Returns ``self``

        :raises EncodingError: If ``self.encoding`` is ``None``
        """
        if self.encoding is None:
            raise EncodingError("Encoding does not exist")
        return self

    def dequantize(self) -> "DequantizedTensor":
        """
        Dequantizes ``self`` using :attr:`self.encoding` to produce a :class:`DequantizedTensor` with the same encoding
        information.

        :raises EncodingError: If ``self.encoding`` is ``None``

        Example:

            >>> import aimet_torch.v2.quantization as Q
            >>> x = torch.tensor([[2.57, -2.312],
            ...                   [0.153, 0.205]])
            >>> quantizer = Q.affine.Quantize(shape=(1, ), bitwidth=8, symmetric=True)
            >>> quantizer.set_range(-128 * 0.1, 127 * 0.1)
            >>> x_q = quantizer(x)
            >>> x_q
            QuantizedTensor([[ 26., -23.],
                             [  2.,   2.]], grad_fn=<AliasBackward0>)
            >>> x_dq = x_q.dequantize()
            >>> x_dq
            DequantizedTensor([[ 2.6000, -2.3000],
                               [ 0.2000,  0.2000]], grad_fn=<AliasBackward0>)
            >>> torch.equal(x_dq.encoding.scale, x_q.encoding.scale)
            True
        """
        if self.encoding is None:
            raise EncodingError("Encoding does not exist")

        # Hand the encoding a plain torch.Tensor view of the data, then re-wrap the result
        qtensor = self.encoding.dequantize(self.as_subclass(torch.Tensor))
        qtensor = qtensor.as_subclass(DequantizedTensor)
        qtensor.encoding = copy.copy(self.encoding)  # shallow copy
        return qtensor

    def quantized_repr(self) -> torch.Tensor:
        """
        Return the quantized representation of ``self`` as a plain :class:`torch.Tensor`
        cast to the data type :attr:`self.encoding.dtype`

        :raises EncodingError: If ``self.encoding`` is ``None``
        """
        # FIXME(kyunggeu): This only works for affine encodings.
        #                  Needs to be generalized for any kind of encodings
        return self.quantize().as_subclass(torch.Tensor).to(self.encoding.dtype)
class DequantizedTensor(QuantizedTensorBase):
    """
    Represents a tensor which has been quantized and subsequently dequantized. This object contains real floating point
    data as well as an :class:`EncodingBase` object which holds information about the quantization parameters with which
    the data was quantized. With this, a :class:`DequantizedTensor` can be converted back to its quantized representation
    without further loss in information.
    """

    def __getstate__(self):
        """
        Return a state dict holding the instance attributes plus the ``"data"`` and ``"encoding"`` entries
        """
        # Work on a copy so that the extra "data" entry is not written into the live instance ``__dict__``
        state = self.__dict__.copy()
        state["data"] = self.data
        state["encoding"] = self.encoding
        return state

    def __setstate__(self, state):
        """
        Restore ``data`` and ``encoding`` from a state dict produced by :meth:`__getstate__`
        """
        self.data = state["data"]
        self.encoding = state["encoding"]

    def __deepcopy__(self, memo):
        """
        Return a new instance of the same type holding deep copies of the data and the encoding

        :param memo: Memo dictionary supplied by :func:`copy.deepcopy`
        """
        new_instance = type(self).__new__(type(self))
        # Register the copy before recursing so that references back to ``self`` resolve to it
        memo[id(self)] = new_instance
        state = self.__getstate__()
        new_instance.__setstate__(state)
        # Replace the references installed by __setstate__ with independent deep copies
        new_instance.encoding = copy.deepcopy(state["encoding"], memo)
        new_instance.data = copy.deepcopy(state["data"], memo)
        return new_instance

    def quantize(self) -> QuantizedTensor:
        """
        Quantizes ``self`` using :attr:`self.encoding` to produce a :class:`QuantizedTensor` with the same encoding
        information.

        :raises EncodingError: If ``self.encoding`` is ``None``

        Example:

            >>> import aimet_torch.v2.quantization as Q
            >>> x = torch.tensor([[0.39, 51.0], [3.521, 9.41]])
            >>> quant_dequant = Q.affine.QuantizeDequantize((1, ), 8, symmetric=False)
            >>> quant_dequant.set_range(-10, 41)
            >>> x_qdq = quant_dequant(x)
            >>> x_qdq
            DequantizedTensor([[ 0.4000, 41.0000],
                               [ 3.6000,  9.4000]], grad_fn=<AliasBackward0>)
            >>> x_qdq.quantize()
            QuantizedTensor([[ 52., 255.],
                             [ 68.,  97.]], grad_fn=<AliasBackward0>)
        """
        if self.encoding is None:
            raise EncodingError("Encoding does not exist")

        # Hand the encoding a plain torch.Tensor view of the data, then re-wrap the result
        qtensor = self.encoding.quantize(self.as_subclass(torch.Tensor))
        qtensor = qtensor.as_subclass(QuantizedTensor)
        qtensor.encoding = copy.copy(self.encoding)  # shallow copy
        return qtensor

    def dequantize(self) -> "DequantizedTensor":
        """
        Returns ``self``

        :raises EncodingError: If ``self.encoding`` is ``None``
        """
        if self.encoding is None:
            raise EncodingError("Encoding does not exist")
        return self

    def quantized_repr(self) -> torch.Tensor:
        """
        Return the quantized representation of ``self`` as a :class:`torch.Tensor` with data type :attr:`self.encoding.dtype`.

        .. note::
            The result of this function may not be able to carry a gradient depending on the quantized data type.
            Thus, it may be necessary to call this only within an autograd function to allow for backpropagation.

        :raises EncodingError: If ``self.encoding`` is ``None``

        Example:

            >>> import aimet_torch.v2.quantization as Q
            >>> x = torch.tensor([[0.39, 51.0], [3.521, 9.41]])
            >>> quant_dequant = Q.affine.QuantizeDequantize((1, ), 8, symmetric=False)
            >>> quant_dequant.set_range(-10, 41)
            >>> x_qdq = quant_dequant(x)
            >>> x_qdq
            DequantizedTensor([[ 0.4000, 41.0000],
                               [ 3.6000,  9.4000]], grad_fn=<AliasBackward0>)
            >>> x_qdq.quantized_repr()
            tensor([[ 52, 255],
                    [ 68,  97]], dtype=torch.uint8)
        """
        # FIXME(kyunggeu): This only works for affine encodings.
        #                  Needs to be generalized for any kind of encodings
        return self.quantize().as_subclass(torch.Tensor).to(self.encoding.dtype)
class EncodingError(RuntimeError):
    """Error that indicates an encoding is missing or invalid"""