# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
# SPDX-License-Identifier: BSD-3-Clause
# pylint: disable=too-many-lines, redefined-builtin
"""Utilities that are used for different AIMET PyTorch features"""
import itertools
from typing import (
List,
Tuple,
Union,
Dict,
Callable,
Any,
Iterable,
Optional,
TextIO,
Mapping,
)
import contextlib
from contextlib import contextmanager, ExitStack
import os
import pickle
import logging
import functools
from packaging import version
import torch.nn
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils._pytree import tree_map
from torch.nn.modules.module import (
_global_backward_hooks,
_global_forward_pre_hooks,
_global_forward_hooks,
)
try:
from torch.nn.modules.module import _global_backward_pre_hooks
except ImportError:
_global_backward_pre_hooks = None
from torchvision import datasets, transforms
from aimet_torch.common.utils import AimetLogger, Handle
from aimet_torch.common.utils import profile as _profile
logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Utils)
dtypes_to_ignore_for_quantization = (int, bool, str, tuple, type(None))
torch_dtypes_to_ignore_for_quantization = [
torch.int,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.bool,
torch.uint8,
]
allowed_output_types = (torch.Tensor, float, *dtypes_to_ignore_for_quantization)
DROPOUT_TYPES = (torch.nn.Dropout, torch.nn.Dropout2d, torch.nn.Dropout3d)
# list of modules which need to be treated as a leaf module
modules_to_treat_as_leaf = []
# list of modules not to treat as leaf
modules_not_to_treat_as_leaf = [torch.nn.ModuleList, torch.nn.ModuleDict]
class StopForwardException(Exception):
"""
Dummy exception to early-terminate forward-pass
"""
class ModuleData:
"""
Collect input and output data to and from module
"""
def __init__(
self,
model: torch.nn.Module,
module: torch.nn.Module,
forward_fn: Callable[[torch.nn.Module, Any], Any] = None,
):
"""
:param model: Pytorch model
:param module: Module reference
:param forward_fn: Adapter function that performs forward pass given a model and inputs
yielded from the data loader.
"""
self._model = model
self._module = module
self._forward_fn = forward_fn or self.default_forward_fn
def collect_inp_out_data(
self, args, kwargs: Mapping[str, Any], collect_input: bool, collect_output: bool
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Collect input and output data depending on the collect_input and collect_output flag
:param args: Positional arguments to pass to the model. Can be a single tensor or a list/tuple of tensors
:param kwargs: Keyword arguments to pass to the model
:param collect_input: Boolean to collect input or not
:param collect_output: Boolean to collect output or not
:return: Module's input and output data
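Example (a minimal sketch on a toy model; in practice this is used on real models):
>>> model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
>>> module_data = ModuleData(model, model[1])
>>> inp, out = module_data.collect_inp_out_data(
...     (torch.randn(1, 4),), {}, collect_input=True, collect_output=True
... )
>>> inp.shape
torch.Size([1, 4])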
"""
def adjust_input_dtype(module, inp):
if hasattr(module, "weight") and module.weight is not None:
dtype = module.weight.dtype
# Cast input to dtype only if it is a floating point tensor (float, half, bfloat16, etc.).
# If input is a non-float tensor (e.g. long, bool), leave the input uncasted.
return tree_map(
lambda x: x.to(dtype)
if isinstance(x, torch.Tensor) and x.is_floating_point()
else x,
inp,
)
return inp
handles = [
mod.register_forward_pre_hook(adjust_input_dtype)
for mod in self._model.modules()
]
def _hook_to_collect_inp_out_data(_, inp, out):
"""
hook to collect input and output data
"""
if collect_input:
inp_data_list.append(inp[0])
if collect_output:
out_data_list.append(out)
raise StopForwardException
inp_data_list = []
out_data_list = []
handles.append(
self._module.register_forward_hook(_hook_to_collect_inp_out_data)
)
# get the model's device placement information
device = get_device(self._model)
# place the input to appropriate device
args = change_tensor_device_placement(args, device)
kwargs = change_tensor_device_placement(kwargs, device)
# Custom injected exception is raised when the activations data from desired module is collected.
try:
with in_eval_mode(self._model), torch.no_grad():
_ = self._forward_fn(self._model, *args, **kwargs)
except StopForwardException:
pass
finally:
# remove hook handle
for handle in handles:
handle.remove()
inp_data, out_data = None, None
if inp_data_list and isinstance(inp_data_list[0], torch.Tensor):
inp_data = inp_data_list[0].detach()
if out_data_list and isinstance(out_data_list[0], torch.Tensor):
out_data = out_data_list[0].detach()
return inp_data, out_data
@staticmethod
def default_forward_fn(
model: torch.nn.Module,
inputs: Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor]],
):
"""
Default forward function that performs a forward pass given a model and inputs yielded from
the data loader. Supports data loaders that yield a torch.Tensor which can be passed directly
into the model, or data loaders that yield a tuple of length two whose first element can be
passed directly into the model.
:param model: PyTorch model.
:param inputs: Inputs passed to model.
"""
# If the data loader yields labeled batches of the form (model_inputs, labels), ignore the labels.
if isinstance(inputs, (list, tuple)):
inputs, _ = inputs
if isinstance(inputs, torch.Tensor):
inputs = [inputs]
model(*inputs)
class CachedDataset(Dataset):
"""
Cache a number of batches from the data loader at the given path location and
provide an interface to fetch a single batch of model inputs.
"""
# pylint: disable=super-init-not-called
def __init__(self, data_loader: DataLoader, num_batches: Optional[int], path: str):
"""
:param data_loader: Data loader
:param num_batches: Number of batches to fetch from data loader
:param path: Path to save model inputs
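Example (a hedged sketch; ``data_loader`` is assumed to be a user-defined DataLoader):
>>> cached_dataset = CachedDataset(data_loader, num_batches=2, path="/tmp/cached_dataset")
>>> len(cached_dataset)
2
>>> args, kwargs = cached_dataset[0]  # (args, kwargs) of the first cached batch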
"""
if data_loader:
if num_batches is not None and len(data_loader) < num_batches:
raise ValueError(
f"Can not fetch {num_batches} batches from "
f"a data loader of length {len(data_loader)}."
)
self._num_batches = None
self._path = path
if num_batches is None:
self._cache_model_inputs(data_loader)
else:
self._cache_model_inputs(itertools.islice(data_loader, num_batches))
assert self._num_batches is not None
else:
assert len(os.listdir(path)) == num_batches
self._num_batches = num_batches
self._path = path
logger.info(
"Found %d batches of data at path location: %s",
self._num_batches,
self._path,
)
def __len__(self):
return self._num_batches
def __getitem__(self, index: int):
path = os.path.join(self._path, "model_inputs_" + str(index))
with open(path, "rb") as file:
batch = pickle.load(file)
return batch
def __iter__(self):
for i in range(self.__len__()):
yield self.__getitem__(i)
def _cache_model_inputs(self, data_loader):
"""
Cache each batch from the data loader in a separate file at the provided path location
"""
if not os.path.exists(self._path):
os.makedirs(self._path)
for i, batch in enumerate(data_loader):
path = os.path.join(self._path, f"model_inputs_{i}")
args = (batch,)
kwargs = {}
with open(path, "wb") as file:
pickle.dump((args, kwargs), file)
self._num_batches = i + 1
logger.info(
"Caching %d batches from data loader at path location: %s",
self._num_batches,
self._path,
)
def run_hook_for_layers(
model: torch.nn.Module,
input_shapes: Union[Tuple, List[Tuple]],
hook,
module_type_for_attaching_hook=None,
leaf_node_only=True,
):
"""
Register the given hook function for all layers in the model
:param model: Model
:param input_shapes: Shape of inputs to pass to the model
:param hook: Hook function to register
:param module_type_for_attaching_hook: Tuple of torch.nn module types for which hook has to be attached
:param leaf_node_only: Set to False if all modules are required
:return: None
"""
# ------------------------
# Register hook function
# ------------------------
hooks = []
# All leaf modules
modules = [
module
for module in model.modules()
if not leaf_node_only or is_leaf_module(module)
]
if module_type_for_attaching_hook:
# if needed, filter by module types specified by caller
modules = [
module
for module in modules
if isinstance(module, module_type_for_attaching_hook)
]
for module in modules:
hooks.append(module.register_forward_hook(hook))
# ------------------------------------------------
# Run forward pass to execute the hook functions
# ------------------------------------------------
device = get_device(model)
dummy_tensors = create_rand_tensors_given_shapes(input_shapes, device)
with in_eval_mode(model), torch.no_grad():
_ = model(*dummy_tensors)
# --------------------------
# Remove all hooks we added
# --------------------------
for h in hooks:
h.remove()
def run_hook_for_layers_with_given_input(
model: torch.nn.Module,
input_tensor: Union[torch.Tensor, Tuple],
hook,
module_type_for_attaching_hook=None,
leaf_node_only=True,
fwd_func=None,
):
"""
Register the given hook function for all layers in the model
:param model: Model
:param input_tensor: Input tensor to the model. If the model takes more than one input, use a tuple
:param hook: Hook function to register
:param module_type_for_attaching_hook: Tuple of torch.nn module types for which hook has to be attached
:param leaf_node_only: Set to False if all modules are required
:param fwd_func: forward function for model inference
:return: None
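Example (a minimal sketch; ``model`` is assumed to be a user-defined torch.nn.Module that takes a single image tensor):
>>> visited = []
>>> def hook(module, _, __):
...     visited.append(type(module).__name__)
>>> run_hook_for_layers_with_given_input(model, torch.randn(1, 3, 224, 224), hook)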
"""
# pylint: disable=too-many-branches
# ------------------------
# Register hook function
# ------------------------
hooks = []
# All leaf modules
modules = []
# Modules in modules_to_treat_as_leaf should not be searched for further levels of submodules.
# To achieve this, save their children in modules_to_skip
modules_to_skip = set()
for module in model.modules():
if module not in modules_to_skip:
# pylint: disable=protected-access
if isinstance(module, tuple(modules_to_treat_as_leaf)):
modules.append(module)
# check for modules inside the 'module' and add them to modules_to_skip
for sub_module in module._modules.values():
modules_to_skip.add(sub_module)
else:
if leaf_node_only:
if is_leaf_module(module):
modules.append(module)
else:
modules.append(module)
if module_type_for_attaching_hook:
# if needed, filter by module types specified by caller
modules = [
module
for module in modules
if isinstance(module, module_type_for_attaching_hook)
]
try:
for module in modules:
hooks.append(module.register_forward_hook(hook))
# ------------------------------------------------
# Run forward pass to execute the hook functions
# ------------------------------------------------
with in_eval_mode(model), torch.no_grad():
if fwd_func:
_ = fwd_func(model, input_tensor)
else:
if isinstance(input_tensor, (list, tuple)):
_ = model(*input_tensor)
elif isinstance(input_tensor, dict):
try:
_ = model(**input_tensor)
except TypeError:
# Some models require inputs as dict.
# https://github.com/pytorch/vision/blob/ef2920cc80bac61282b3b19a775b3c33de4e7551/torchvision/ops/feature_pyramid_network.py#L172
_ = model(input_tensor)
else:
_ = model(input_tensor)
finally:
# --------------------------
# Remove all hooks we added
# --------------------------
for h in hooks:
h.remove()
def create_fake_data_loader(dataset_size: int, batch_size: int, image_size=(1, 28, 28)):
"""
Helper function to create a fake data loader with a default image size of (1, 28, 28)
:param dataset_size : total images in data set
:param batch_size : batch size
:param image_size : size of input
:return:
"""
transform = transforms.Compose([transforms.ToTensor()])
data_loader = torch.utils.data.DataLoader(
datasets.FakeData(
size=dataset_size,
image_size=image_size,
num_classes=10,
transform=transform,
target_transform=None,
),
batch_size=batch_size,
shuffle=False,
)
return data_loader
def get_module_to_name_dict(
model: torch.nn.Module, prefix: str = ""
) -> Dict[torch.nn.Module, str]:
"""
Get a dictionary mapping model modules to names
:param model: Model to get mapping for
:param prefix: Prefix string to prepend to names
:return: Dictionary mapping model modules to names
"""
module_to_name_dict = {}
for name, module in model.named_modules(prefix=prefix):
module_to_name_dict[module] = name
return module_to_name_dict
def get_layer_name(model, layer):
"""
Helper function to get layer name given model and layer reference
:param model: model (nn.Module)
:param layer: layer reference
:return:
"""
for name, module in model.named_modules():
if module is layer:
return name
raise KeyError(f"Couldn't find layer {layer} from model {model}")
def is_model_on_gpu(model):
"""
Function to check whether given model is created on GPU or CPU
Assumption: model is on a single device
:return:
True if the model is on GPU, False if on CPU
"""
return next(model.parameters()).is_cuda
def get_device(model):
"""
Function to find which device is model on
Assumption: model is on a single device
:param model:
:return: Device on which model is present
"""
return next(model.parameters()).device
def is_leaf_module(module):
"""Utility function to determine if the given module is a leaf module - that is, does not have children modules
:return:
True if the module is a leaf, False otherwise
"""
# pylint: disable=import-outside-toplevel
from aimet_torch._base.nn.modules._spconv import CustomSparseConv3DLayer
try:
_ = next(module.children())
except StopIteration:
has_child = False
else:
has_child = True
return (
not has_child
or type(module) in modules_to_treat_as_leaf
or (
CustomSparseConv3DLayer is not None
and isinstance(module, CustomSparseConv3DLayer)
)
) and not isinstance(module, tuple(modules_not_to_treat_as_leaf))
def has_hooks(module: torch.nn.Module):
"""Returns True if the module uses hooks."""
# pylint: disable=protected-access
return (
module._backward_hooks
or module._backward_pre_hooks
or module._forward_hooks
or module._forward_pre_hooks
or _global_backward_pre_hooks
or _global_backward_hooks
or _global_forward_hooks
or _global_forward_pre_hooks
)
def get_ordered_list_of_modules(
model: torch.nn.Module,
dummy_input: Union[torch.Tensor, List[torch.Tensor], Tuple],
fwd_func=None,
ignore_duplicates=False,
) -> List:
"""
Finds ordered modules in given model.
:param model: PyTorch model.
:param dummy_input: Dummy input to the model. Used to parse model graph.
:param fwd_func: forward function for model inference
:param ignore_duplicates: If True, don't add a module to ordered_list again if it was already seen before.
:return: List of module name, module in order.
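Example (a minimal sketch):
>>> model = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3), torch.nn.ReLU())
>>> get_ordered_list_of_modules(model, torch.randn(1, 3, 8, 8))
[['0', Conv2d(3, 3, kernel_size=(3, 3), stride=(1, 1))], ['1', ReLU()]]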
"""
seen_modules = set()
def _hook_to_collect_name_of_module(module, _, __):
"""
hook to find name of module
"""
module_name = module_to_name_dict[module]
if module_name in seen_modules and ignore_duplicates:
return
list_modules.append([module_name, module])
seen_modules.add(module_name)
module_to_name_dict = {}
for name, module in model.named_modules():
module_to_name_dict[module] = name
list_modules = []
run_hook_for_layers_with_given_input(
model, dummy_input, hook=_hook_to_collect_name_of_module, fwd_func=fwd_func
)
return list_modules
def replace_modules(
model: torch.nn.Module,
condition: Callable[[torch.nn.Module], bool],
factory: Callable[[torch.nn.Module], torch.nn.Module],
):
"""
Replace all modules that satisfy the given condition
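Example (a minimal sketch that swaps every ReLU for ReLU6):
>>> model = torch.nn.Sequential(torch.nn.Conv2d(3, 3, 3), torch.nn.ReLU())
>>> replace_modules(
...     model,
...     condition=lambda module: isinstance(module, torch.nn.ReLU),
...     factory=lambda module: torch.nn.ReLU6(),
... )
>>> isinstance(model[1], torch.nn.ReLU6)
True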
"""
def fn(parent):
for name, child in parent.named_children():
if condition(child):
setattr(parent, name, factory(child))
model.apply(fn)
def create_rand_tensors_given_shapes(input_shape, device: torch.device):
"""
Given shapes of some tensors, create one or more random tensors and return them as a list of tensors
:param input_shape: Shape of the tensor(s) to create: a tuple of integers, or a (possibly nested) list/tuple of such shapes
:param device: Device to create tensors on
:return: Created list of tensors
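Example:
>>> tensors = create_rand_tensors_given_shapes([(1, 3), (2,)], torch.device("cpu"))
>>> [t.shape for t in tensors]
[torch.Size([1, 3]), torch.Size([2])]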
"""
try:
input_shapes = [torch.Size(input_shape)]
except TypeError:
input_shapes = input_shape
rand_tensors = []
for shape in input_shapes:
try:
t = torch.rand(torch.Size(shape), device=device)
except TypeError:
t = create_rand_tensors_given_shapes(shape, device)
rand_tensors.append(t)
return rand_tensors
def get_ordered_lists_of_conv_fc(
model: torch.nn.Module, dummy_input: Union[torch.Tensor, Tuple, List]
) -> List:
"""
Finds order of nodes in graph
:param model: model
:param dummy_input: A dummy input to the model. Can be a Tensor or a Tuple of Tensors
:return: List of names in graph in order
"""
module_list = get_ordered_list_of_modules(model, dummy_input)
module_list = [
[name, module]
for name, module in module_list
if isinstance(
module,
(
torch.nn.Conv1d,
torch.nn.Conv2d,
torch.nn.Linear,
torch.nn.ConvTranspose2d,
torch.nn.Conv3d,
),
)
]
return module_list
def change_tensor_device_placement(input_data, device: torch.device):
"""
Change the tensor_data's device placement
:param input_data: torch.tensor , list of torch.tensors, or tuple of torch.tensors
:param device: device
:return: tensor_data with modified device placement
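Example:
>>> data = {"x": torch.ones(2), "mask": [torch.zeros(1)]}
>>> moved = change_tensor_device_placement(data, torch.device("cpu"))
>>> moved["x"].device
device(type='cpu')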
"""
return tree_map(
lambda x: x.to(device) if isinstance(x, torch.Tensor) else x, input_data
)
def nested_map(data, fn: Callable[[torch.Tensor], torch.Tensor]):
"""
Apply a function to a nested tuple, list, or dict of tensors.
:param data: Tensor, or a nested tuple, list, or dict of tensors.
:param fn: Function to apply to the tensors
:return: Nested structure of tensors with function applied
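Example:
>>> nested_map((torch.ones(2), [torch.zeros(1)]), lambda t: t + 1)
(tensor([2., 2.]), [tensor([1.])])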
"""
if isinstance(data, torch.Tensor):
return fn(data)
if isinstance(data, (tuple, list)):
cls = tuple if isinstance(data, tuple) else list
return cls(nested_map(x, fn) for x in data)
if isinstance(data, dict):
return {key: nested_map(value, fn) for key, value in data.items()}
logger.debug(
"unexpected input type=%s, expecting torch.Tensor, tuple, list, or dict. skipping..",
type(data),
)
return data
def find_num_inout_tensors_per_module(model: torch.nn.Module, input_tensor) -> Dict:
"""
Returns a map of module -> (number of input tensors, number of output tensors) for all the
child modules of the provided module
:param model: Torch module to find children modules for
:param input_tensor: Input tensor to use to run forward pass for the model. If model needs more than one input
tensor, pass a tuple
:return: Map of module -> (number of input tensors, number of output tensors)
"""
num_inout_map = {}
def record_num_outputs(module, inputs, outputs):
num_inputs = len(inputs) if isinstance(inputs, (list, tuple)) else 1
num_outputs = len(outputs) if isinstance(outputs, (list, tuple)) else 1
num_inout_map[module] = (num_inputs, num_outputs)
run_hook_for_layers_with_given_input(model, input_tensor, record_num_outputs)
return num_inout_map
def get_reused_modules(
model: torch.nn.Module, model_input: Union[torch.Tensor, Tuple]
) -> List[Tuple[str, torch.nn.Module]]:
"""
Identify modules which are used more than once in the model
:param model: Model to check for modules used more than once
:param model_input: Input to the model
:return: List of tuples of name and module for modules in the model which are used more than once
"""
module_set = set()
reused_modules_set = set()
def forward_hook(curr_module, _, _1):
"""
Custom forward hook function to add modules to module_set and reused_module_set.
:param curr_module: Current module being traversed during forward pass.
:param _: Unused param (module inputs)
:param _1: Unused param (module outputs)
"""
if curr_module in module_set:
reused_modules_set.add(curr_module)
else:
module_set.add(curr_module)
run_hook_for_layers_with_given_input(model, model_input, forward_hook)
reused_modules_list = []
for name, module in model.named_modules():
if is_leaf_module(module) and module in reused_modules_set:
reused_modules_list.append((name, module))
return reused_modules_list
@contextlib.contextmanager
def in_eval_mode(module: Union[torch.nn.Module, Iterable[torch.nn.Module]]):
"""
Utility to temporarily put model in eval mode using context manager.
:param module: PyTorch module or a list of modules
:return: None
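Example (``model`` is assumed to be a user-defined module):
>>> _ = model.train()
>>> with in_eval_mode(model):
...     assert not model.training
>>> assert model.training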
"""
with _in_mode(module, train=False):
yield
@contextlib.contextmanager
def in_train_mode(module: Union[torch.nn.Module, Iterable[torch.nn.Module]]):
"""
Utility to temporarily put model in train mode using context manager.
:param module: PyTorch module or a list of modules
:return: None
"""
with _in_mode(module, train=True):
yield
@contextlib.contextmanager
def _in_mode(modules: Union[torch.nn.Module, Iterable[torch.nn.Module]], train: bool):
if isinstance(modules, torch.nn.Module):
modules = (modules,)
modules = set(itertools.chain(*(m.modules() for m in modules)))
original_modes = {module: module.training for module in modules}
try:
for module in modules:
module.training = train
yield
finally:
for module, original_mode in original_modes.items():
module.training = original_mode
def is_torch_nn_module(module: torch.nn.Module) -> bool:
"""
Utility function to determine if the given module is from torch.nn class or not.
For modules like torch.nn.Conv2d, the utility will return True.
:param module: PyTorch module.
:return: True if the module from torch.nn class, False otherwise
"""
return (
isinstance(module, torch.nn.Module)
and type(module) in torch.nn.__dict__.values()
)
def is_torch_nn_leaf_module(module: torch.nn.Module) -> bool:
"""
Utility function to determine if the given module is leaf and from torch.nn class or not.
:param module: PyTorch module.
:return: True if the module is leaf and from torch.nn class, False otherwise
"""
torch_nn_leaf_module = False
if is_leaf_module(module) and is_torch_nn_module(module):
torch_nn_leaf_module = True
return torch_nn_leaf_module
def get_torch_tensortype_shape(
torch_graph_output: torch._C.TensorType,
) -> Union[None, List[int]]:
"""
Given an output tensor from a torch graph, return its shape, or return None if the output tensor is not a
tensortype.
"""
# pylint: disable=protected-access
shape = None
if isinstance(torch_graph_output.type(), torch._C.TensorType):
shape = torch_graph_output.type().sizes()
return shape
def get_all_quantizers(model: torch.nn.Module):
"""
Get all the quantizers in the model
:param model: Root module
:returns: List of parameter, input, and output quantizers
"""
param_quantizers = []
input_quantizers = []
output_quantizers = []
for module in model.modules():
_param_qtzrs = getattr(module, "param_quantizers", {}).values()
_input_qtzrs = getattr(module, "input_quantizers", [])
_output_qtzrs = getattr(module, "output_quantizers", [])
if _param_qtzrs:
param_quantizers.extend(_param_qtzrs)
if _input_qtzrs:
input_quantizers.extend(
_input_qtzrs.values()
if isinstance(_input_qtzrs, dict)
else _input_qtzrs
)
if _output_qtzrs:
output_quantizers.extend(
_output_qtzrs.values()
if isinstance(_output_qtzrs, dict)
else _output_qtzrs
)
return param_quantizers, input_quantizers, output_quantizers
def disable_all_quantizers(model: torch.nn.Module):
"""
Temporarily disable all quantizers in the model within a with-as block, or permanently
disable them without employing a context manager.
:param model: Root module
:returns: Handle that re-enables all quantizers in the model upon handle.remove().
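Example (a hedged sketch; ``sim`` is assumed to be a QuantizationSimModel and ``dummy_input`` a valid model input):
>>> with disable_all_quantizers(sim.model):
...     fp_output = sim.model(dummy_input)  # forward pass without quantization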
"""
# pylint: disable=import-outside-toplevel, cyclic-import
from aimet_torch.nn.base import BaseQuantizationMixin
if any(isinstance(m, BaseQuantizationMixin) for m in model.modules()):
return remove_all_quantizers(model)
param_quantizers, input_quantizers, output_quantizers = get_all_quantizers(model)
all_quantizers = param_quantizers + input_quantizers + output_quantizers
active_quantizers = set(
quantizer for quantizer in all_quantizers if quantizer.enabled
)
def cleanup():
for quantizer in active_quantizers:
quantizer.enabled = True
try:
for quantizer in active_quantizers:
quantizer.enabled = False
return Handle(cleanup)
except:
cleanup()
raise
def save_to_cache(tensor, dir_path, idx):
"""
Save tensor data into provided path with index
:param tensor: Tensor
:param dir_path: Provided path to save data
:param idx: Index of the file
"""
if not os.path.exists(dir_path):
os.makedirs(dir_path)
path = os.path.join(dir_path, f"model_inputs_{idx}")
with open(path, "wb") as cache:
pickle.dump(tensor, cache)
def cache_intermediate_datasets(
cached_dataset,
cache_on_cpu,
model,
module_name,
forward_fn,
path=None,
incl_kwargs: bool = False,
):
"""
Cache the input tensors of the target module and save them to CPU or disk for later use
:param cached_dataset: Cached dataset
:param cache_on_cpu: True if caching data on CPU, False if caching to disk
:param model: Model that contains the target module
:param module_name: Name of the target module
:param forward_fn: Forward function that performs forward pass given a model and inputs
:param path: Location to save cached data if caching to disk
:param incl_kwargs: if True, capture kwargs, normalize and attach to inputs.
:return: Cached data when caching on CPU; an empty list when caching to disk
"""
# pylint: disable=cell-var-from-loop, too-many-locals, missing-class-docstring, missing-function-docstring
cached_data = []
*parent_name, child_name = module_name.split(".")
parent = model.get_submodule(".".join(parent_name))
orig_child = getattr(parent, child_name)
class CachingHelper(torch.nn.Module):
def forward(self, *args, **kwargs):
if not incl_kwargs:
kwargs = {}
if cache_on_cpu:
cached_data.append(
change_tensor_device_placement((args, kwargs), torch.device("cpu"))
)
else:
save_to_cache((args, kwargs), path, idx)
raise StopForwardException
caching_helper = CachingHelper()
try:
setattr(parent, child_name, caching_helper)
iterator = iter(cached_dataset)
for idx in range(len(cached_dataset)):
args, kwargs = next(iterator)
try:
with in_eval_mode(model), torch.no_grad():
_ = forward_fn(model, *args, **kwargs)
except StopForwardException:
pass
return cached_data
finally:
setattr(parent, child_name, orig_child)
def profile(
label: str,
file: Union[str, os.PathLike, TextIO] = None,
new_file: bool = False,
logger: Optional[logging.Logger] = None, # pylint: disable=redefined-outer-name
):
"""
Profile a block of code and save profiling information into a file.
:param label: String label associated with the block of code to profile (shows up in the profiling print)
:param file: File path and name or a file-like object to send output text to (Default: stdout)
:param new_file: True if a new file is to be created to hold profiling info, False if an existing file should be
appended to. This flag is only valid when ``file`` is a path, not a file-like object.
:param logger: If logger is provided, profiling string will also be printed with INFO logging level
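Example (a minimal sketch; ``model`` and ``dummy_input`` are assumed to be user-defined):
>>> with profile("forward pass", file="profile.txt", new_file=True):
...     _ = model(dummy_input)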
"""
if not torch.cuda.is_available():
return profile_async(label, file, new_file, logger)
ctx = _profile(label, file, new_file, logger, cleanup=torch.cuda.synchronize)
return _ContextManager(
action=ctx.__enter__, cleanup=lambda: ctx.__exit__(None, None, None)
) # pylint: disable=no-member
def profile_async(
label: str,
file: Union[str, os.PathLike, TextIO] = None,
new_file: bool = False,
logger: Optional[logging.Logger] = None, # pylint: disable=redefined-outer-name
):
"""
Profile a block of code and save profiling information into a file.
:param label: String label associated with the block of code to profile (shows up in the profiling print)
:param file: File path and name or a file-like object to send output text to (Default: stdout)
:param new_file: True if a new file is to be created to hold profiling info, False if an existing file should be
appended to. This flag is only valid when ``file`` is a path, not a file-like object.
:param logger: If logger is provided, profiling string will also be printed with INFO logging level
"""
ctx = _profile(label, file, new_file, logger, cleanup=None)
return _ContextManager(
action=ctx.__enter__, cleanup=lambda: ctx.__exit__(None, None, None)
) # pylint: disable=no-member
def is_vector_encoding(encoding: Optional[List[Dict]]) -> bool:
"""
Check if encoding is from vector quantization
:param encoding: List of encoding dictionaries
:return: True if all required vector quantization properties are included in encoding
"""
if encoding is None:
return False
required_properties = (
"rows_per_block",
"cols_per_block",
"vector_dim",
"vector_stride",
"index_bw",
)
return all((property_ in encoding[0] for property_ in required_properties))
def get_all_named_parameters(model: torch.nn.Module):
"""
Yields all (name, parameter) pairs in model including redundant parameters.
:param model: torch.nn.Module from which to retrieve parameters
"""
for name, module in model.named_modules(remove_duplicate=False):
for param_name, parameter in module.named_parameters(recurse=False):
if name:
yield name + "." + param_name, parameter
else:
# Don't prepend . if module name is "" (Parameter owned by base model)
yield param_name, parameter
@contextlib.contextmanager
def place_model(model: torch.nn.Module, device: torch.device):
"""
Temporarily place model on given device
"""
original_device = get_device(model)
try:
model.to(device=device)
yield
finally:
model.to(device=original_device)
def get_param_channel_axis(module: torch.nn.Module, param_name: str):
"""
Given a module and its param name, this method returns the channel axis of the given parameter.
:param module: torch.nn.Module
:param param_name: str representing the name of the parameter
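Example:
>>> get_param_channel_axis(torch.nn.Conv2d(3, 6, 3), "weight")
0
>>> get_param_channel_axis(torch.nn.ConvTranspose2d(3, 6, 3), "weight")
1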
"""
channel_axis = 0
if isinstance(
module,
(torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d),
):
channel_axis = 1 if param_name == "weight" else 0
return channel_axis
def _is_expandable(src_shape: Tuple[int, ...], target_shape: Tuple[int, ...]) -> bool:
"""
Returns True if the source shape can be expanded to the target shape
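Example:
>>> _is_expandable((1, 3), (2, 3))
True
>>> _is_expandable((2,), (2, 3))
False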
"""
if len(src_shape) > len(target_shape):
return False
for src_dim, dst_dim in zip(src_shape[::-1], target_shape[::-1]):
if src_dim not in (1, dst_dim):
return False
return True
def _is_reducible(src_shape: Tuple[int, ...], target_shape: Tuple[int, ...]) -> bool:
"""
Returns True if the source shape can be reduced to the target shape
"""
return _is_expandable(target_shape, src_shape) # pylint: disable=arguments-out-of-order
def reduce(
input: torch.Tensor,
shape: tuple[int, ...],
reduce_op: Callable,
block_size: tuple[int, ...] | None = None,
):
"""
Reduce input into given shape.
:param input: Input to reduce
:param shape: Shape of the reduced output
:param reduce_op: Reduce operation
:param block_size: Block size for block-wise reduction
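Example (a minimal sketch without block-wise reduction):
>>> reduce(torch.ones(2, 3), shape=(3,), reduce_op=torch.sum)
tensor([2., 2., 2.])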
"""
from .quantization._utils import interleave, concretize_block_size
output_shape = shape
if block_size is not None:
block_size = concretize_block_size(input.shape, shape, block_size)
input = input.reshape(-1, *interleave(shape, block_size))
shape = interleave(shape, 1)
if not _is_reducible(input.shape, shape):
raise RuntimeError(
f"Input of shape {list(input.shape)} can't be reduced to shape {list(shape)}"
)
padded_shape = (*itertools.repeat(1, len(input.shape) - len(shape)), *shape)
reduce_dims = tuple(axis for axis, dim in enumerate(padded_shape) if dim == 1)
other_dims = tuple(axis for axis, dim in enumerate(padded_shape) if dim > 1)
permute_dims = reduce_dims + other_dims
return reduce_op(
input.permute(permute_dims).reshape(-1, *output_shape), dim=0, keepdim=False
)
class _ContextManager:
def __init__(self, action: Callable[[], Any], cleanup: Callable[[], Any]):
self._action = action
self._cleanup = cleanup
def __enter__(self):
self._action()
return self
def __exit__(self, *_):
self._cleanup()
def __call__(self, fn: Callable):
@functools.wraps(fn)
def wrapper(*args, **kwargs):
with self:
return fn(*args, **kwargs)
return wrapper
class _NullAttribute:
pass
def patch_attr(obj, attr_name, new_attr) -> _ContextManager:
"""
Temporarily overwrite object attribute
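Example:
>>> class Foo:
...     pass
>>> foo = Foo()
>>> foo.attr = 1
>>> with patch_attr(foo, "attr", 2):
...     print(foo.attr)
2
>>> foo.attr
1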
"""
if isinstance(obj, torch.nn.Module):
if attr_name in obj._parameters or attr_name in obj._buffers: # pylint: disable=protected-access
return _patch_param_or_buffer(obj, attr_name, new_attr)
if hasattr(obj, attr_name):
old_attr = getattr(obj, attr_name)
else:
old_attr = _NullAttribute()
action = lambda: setattr(obj, attr_name, new_attr)
def cleanup():
try:
delattr(obj, attr_name)
except AttributeError:
pass
if not hasattr(obj, attr_name) and not isinstance(old_attr, _NullAttribute):
setattr(obj, attr_name, old_attr)
return _ContextManager(action, cleanup)
def _patch_param_or_buffer(
module: torch.nn.Module,
param_or_buffer_name: str,
new_param_or_buffer: torch.Tensor,
):
"""
Temporarily substitute the reference to a parameter with the quantized parameter.
Under the scope of this function, ``getattr(module, param_or_buffer_name)`` will return
``new_param_or_buffer`` instead of the original parameter.
:param module: Module that owns the parameter
:param param_or_buffer_name: Name of the parameter
:param new_param_or_buffer: New parameter to replace the original parameter
"""
# pylint: disable=protected-access
orig_param_or_buffer = getattr(module, param_or_buffer_name)
if orig_param_or_buffer is not None:
assert new_param_or_buffer.shape == orig_param_or_buffer.shape
if param_or_buffer_name in module._parameters:
container = module._parameters
elif param_or_buffer_name in module._buffers:
container = module._buffers
elif param_or_buffer_name in module.__dict__:
# Some non-standard modules (e.g. replicas of torch.nn.DataParallel) store their parameters
container = module.__dict__
else:
raise RuntimeError(
f"'{param_or_buffer_name}' is not a valid name of parameter of buffer of {type(module)}."
)
action = lambda: container.update({param_or_buffer_name: new_param_or_buffer})
cleanup = lambda: container.update({param_or_buffer_name: orig_param_or_buffer})
return _ContextManager(action, cleanup)
class _StraightThroughEstimator(torch.autograd.Function): # pylint: disable=abstract-method
@staticmethod
def forward(ctx, op, *args, **kwargs): # pylint:disable=arguments-differ, unused-argument
return op(*args, **kwargs)
@staticmethod
def backward(ctx, *grad):
return (None, *grad)
def ste_round(*args, **kwargs):
"""
Applies straight-through rounding
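Example:
>>> x = torch.tensor([1.2, 2.6], requires_grad=True)
>>> ste_round(x).sum().backward()
>>> x.grad  # gradient passes straight through the rounding
tensor([1., 1.])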
"""
return _StraightThroughEstimator.apply(torch.round, *args, **kwargs)
class StatisticsNotFoundError(RuntimeError):
"""
Error raised when compute_encodings() is invoked without statistics
"""
_ENABLE_RECOMPUTE = False
def _set_enable_recompute(mode: bool):
original_mode = _ENABLE_RECOMPUTE
def action():
global _ENABLE_RECOMPUTE # pylint: disable=global-statement
_ENABLE_RECOMPUTE = mode
def cleanup():
global _ENABLE_RECOMPUTE # pylint: disable=global-statement
_ENABLE_RECOMPUTE = original_mode
return _ContextManager(action, cleanup)
def is_recompute_enabled():
"""
Returns True if recomputation for memory saving is enabled; False otherwise.
"""
return _ENABLE_RECOMPUTE
def enable_recompute():
"""
Enable recomputation for memory saving.
"""
return _set_enable_recompute(True)
def no_recompute():
"""
Disable recomputation for memory saving.
"""
return _set_enable_recompute(False)
def allow_recompute(fn):
"""
Allow recomputation of the given function's activations during training
if recompute is enabled.
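Example (a hedged sketch; ``ffn`` is a hypothetical function used for illustration):
>>> @allow_recompute
... def ffn(x):
...     return torch.nn.functional.gelu(x)
>>> with enable_recompute():
...     out = ffn(torch.randn(2, 2, requires_grad=True))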
"""
@functools.wraps(fn)
def wrapper(*args, **kwargs):
if is_recompute_enabled():
# Enable activation recompute (a.k.a. activation checkpointing)
# to reduce memory footprint of training
return torch.utils.checkpoint.checkpoint(
fn, *args, use_reentrant=False, **kwargs
)
return fn(*args, **kwargs)
return wrapper
def flatten_nn_module_list(module):
"""
Flatten nested list of nn.Modules into a flat list
"""
def flat_iter(mod):
if isinstance(mod, (list, tuple, torch.nn.ModuleList)):
for x in mod:
yield from flat_iter(x)
else:
yield mod
return list(flat_iter(module))
def _map_qmodule(modules, func):
# pylint: disable=import-outside-toplevel
# pylint: disable=protected-access, cyclic-import
from aimet_torch.nn import BaseQuantizationMixin
contexts = []
ctx = _ContextManager(
action=lambda: None,
cleanup=lambda: [context._cleanup() for context in contexts],
)
if isinstance(modules, torch.nn.Module):
modules = [modules]
try:
for module_elem in modules:
for module in module_elem.modules():
if isinstance(module, BaseQuantizationMixin):
context = func(module)
contexts.append(context)
except Exception:
ctx._cleanup()
raise
return ctx
def remove_output_quantizers(modules):
"""
Temporarily remove all output quantizers
Example:
>>> print(sim.model)
Sequential(
(0): QuantizedConv2d(
3, 3, kernel_size=(3, 3), stride=(1, 1)
(param_quantizers): ModuleDict(
(weight): QuantizeDequantize(shape=(3, 1, 1, 1), qmin=-128, qmax=127, symmetric=True)
(bias): None
)
(input_quantizers): ModuleList(
(0): QuantizeDequantize(shape=(), qmin=0, qmax=255, symmetric=False)
)
(output_quantizers): ModuleList(
(0): QuantizeDequantize(shape=(), qmin=0, qmax=255, symmetric=False)
)
)
)
>>> with remove_output_quantizers(sim.model):
... print(sim.model)
...
Sequential(
(0): QuantizedConv2d(
3, 3, kernel_size=(3, 3), stride=(1, 1)
(param_quantizers): ModuleDict(
(weight): QuantizeDequantize(shape=(3, 1, 1, 1), qmin=-128, qmax=127, symmetric=True)
(bias): None
)
(input_quantizers): ModuleList(
(0): QuantizeDequantize(shape=(), qmin=0, qmax=255, symmetric=False)
)
(output_quantizers): ModuleList(
(0): None
)
)
)
"""
# pylint: disable=protected-access
return _map_qmodule(modules, lambda qmodule: qmodule._remove_output_quantizers())
def remove_param_quantizers(modules):
"""
Temporarily remove all parameter quantizers
Example:
>>> print(sim.model)
Sequential(
(0): QuantizedConv2d(
3, 3, kernel_size=(3, 3), stride=(1, 1)
(param_quantizers): ModuleDict(
(weight): QuantizeDequantize(shape=(3, 1, 1, 1), qmin=-128, qmax=127, symmetric=True)
(bias): None
)
(input_quantizers): ModuleList(
(0): QuantizeDequantize(shape=(), qmin=0, qmax=255, symmetric=False)
)
(output_quantizers): ModuleList(
(0): QuantizeDequantize(shape=(), qmin=0, qmax=255, symmetric=False)
)
)
)
>>> with remove_param_quantizers(sim.model):
... print(sim.model)
...
Sequential(
(0): QuantizedConv2d(
3, 3, kernel_size=(3, 3), stride=(1, 1)
(param_quantizers): ModuleDict(
(weight): None
(bias): None
)
(input_quantizers): ModuleList(
(0): QuantizeDequantize(shape=(), qmin=0, qmax=255, symmetric=False)
)
(output_quantizers): ModuleList(
(0): QuantizeDequantize(shape=(), qmin=0, qmax=255, symmetric=False)
)
)
)
"""
# pylint: disable=protected-access
return _map_qmodule(modules, lambda qmodule: qmodule._remove_param_quantizers())
def remove_activation_quantizers(modules):
"""
Temporarily remove all input and output quantizers
Example:
>>> print(sim.model)
Sequential(
(0): QuantizedConv2d(
3, 3, kernel_size=(3, 3), stride=(1, 1)
(param_quantizers): ModuleDict(
(weight): QuantizeDequantize(shape=(3, 1, 1, 1), qmin=-128, qmax=127, symmetric=True)
(bias): None
)
(input_quantizers): ModuleList(
(0): QuantizeDequantize(shape=(), qmin=0, qmax=255, symmetric=False)
)
(output_quantizers): ModuleList(
(0): QuantizeDequantize(shape=(), qmin=0, qmax=255, symmetric=False)
)
)
)
>>> with remove_activation_quantizers(sim.model):
... print(sim.model)
...
Sequential(
(0): QuantizedConv2d(
3, 3, kernel_size=(3, 3), stride=(1, 1)
(param_quantizers): ModuleDict(
(weight): QuantizeDequantize(shape=(3, 1, 1, 1), qmin=-128, qmax=127, symmetric=True)
(bias): None
)
(input_quantizers): ModuleList(
(0): None
)
(output_quantizers): ModuleList(
(0): None
)
)
)
"""
if not isinstance(modules, torch.nn.Module):
# Shallow copy in case modules is an iterator
modules = list(modules)
context_1 = remove_input_quantizers(modules)
context_2 = remove_output_quantizers(modules)
# pylint: disable=protected-access
return _ContextManager(
action=lambda: None,
cleanup=lambda: (context_1._cleanup(), context_2._cleanup()),
)
def remove_all_quantizers(modules):
"""
Temporarily remove all quantizers
Example:
>>> print(sim.model)
Sequential(
(0): QuantizedConv2d(
3, 3, kernel_size=(3, 3), stride=(1, 1)
(param_quantizers): ModuleDict(
(weight): QuantizeDequantize(shape=(3, 1, 1, 1), qmin=-128, qmax=127, symmetric=True)
(bias): None
)
(input_quantizers): ModuleList(
(0): QuantizeDequantize(shape=(), qmin=0, qmax=255, symmetric=False)
)
(output_quantizers): ModuleList(
(0): QuantizeDequantize(shape=(), qmin=0, qmax=255, symmetric=False)
)
)
)
>>> with remove_all_quantizers(sim.model):
... print(sim.model)
...
Sequential(
(0): QuantizedConv2d(
3, 3, kernel_size=(3, 3), stride=(1, 1)
(param_quantizers): ModuleDict(
(weight): None
(bias): None
)
(input_quantizers): ModuleList(
(0): None
)
(output_quantizers): ModuleList(
(0): None
)
)
)
"""
if not isinstance(modules, torch.nn.Module):
# Shallow copy in case modules is an iterator
modules = list(modules)
context_1 = remove_activation_quantizers(modules)
context_2 = remove_param_quantizers(modules)
# pylint: disable=protected-access
return _ContextManager(
action=lambda: None,
cleanup=lambda: (context_1._cleanup(), context_2._cleanup()),
)
def has_no_quantizers(module, ignore_params: bool = False) -> bool:
"""
Helper function to check if a module has no quantizers; parameter quantizers are ignored when ``ignore_params`` is True
"""
return (
all(inp_qtzr is None for inp_qtzr in module.input_quantizers)
and all(out_qtzr is None for out_qtzr in module.output_quantizers)
and (
ignore_params
or all(
param_qtzr is None for param_qtzr in module.param_quantizers.values()
)
)
)
def rgetattr(obj, attr):
"""Drop in replacement for __getattr__ that can handle dotted attribute strings"""
return functools.reduce(getattr, [obj] + attr.split("."))
def rsetattr(obj, attr, val):
"""Drop in replacement for __setattr__ that can handle dotted attribute strings"""
pre, _, post = attr.rpartition(".")
pre_obj = rgetattr(obj, pre) if pre else obj
return setattr(pre_obj, post, val)
def apply_fn_recursively_to_all_elems(fn, container):
"""Apply fn to all elements in recursively composed container"""
if container is None:
return None
if isinstance(container, (list, tuple)):
return [apply_fn_recursively_to_all_elems(fn, elem) for elem in container]
if isinstance(container, dict):
return {
key: apply_fn_recursively_to_all_elems(fn, elem)
for key, elem in container.items()
}
return fn(container)
def flatten_list(container):
"""Helper function to flatten nested list/tuple into 1D"""
if not container:
return container
if not isinstance(container, (list, tuple)):
return [container]
if isinstance(container[0], (list, tuple)):
return flatten_list(container[0]) + flatten_list(container[1:])
if len(container) == 1:
return container
return container[:1] + flatten_list(container[1:])
def default_forward_fn(model, inputs):
"""
Default forward function.
:param model: pytorch model
:param inputs: model inputs
"""
if isinstance(inputs, torch.Tensor):
inputs = [inputs]
return model(*inputs)
_torch_compiler_is_compiling: Callable[[], bool]
_torch_compiler_is_dynamo_compiling: Callable[[], bool]
_torch_compiler_is_exporting: Callable[[], bool]
if version.parse(torch.__version__) >= version.parse("2.7"):
_torch_compiler_is_compiling = torch.compiler.is_compiling
_torch_compiler_is_dynamo_compiling = torch.compiler.is_dynamo_compiling
_torch_compiler_is_exporting = torch.compiler.is_exporting
else:
# torch < 2.7.0 doesn't have torch.compiler.is_compiling/exporting API
def _torch_compiler_is_compiling() -> bool:
return False
def _torch_compiler_is_dynamo_compiling() -> bool:
return False
def _torch_compiler_is_exporting() -> bool:
return False
_qtensor_enabled = True
@contextmanager
def _enable_qtensor_casting(enable: bool):
global _qtensor_enabled # pylint: disable=global-statement
original_value = _qtensor_enabled
try:
_qtensor_enabled = enable
yield
finally:
_qtensor_enabled = original_value
def _is_qtensor_casting_enabled():
return _qtensor_enabled
@contextmanager
def _inference_mode(model: torch.nn.Module, prequantize_parameters: bool):
# pylint: disable=protected-access
from .quantsim import QuantizationSimModel
from .quantization.tensor import DequantizedTensor
from .quantization.base import QuantizerBase
with ExitStack() as stack:
for q in model.modules():
if isinstance(q, QuantizerBase):
dtype = next(
p.dtype for p in itertools.chain(q.parameters(), q.buffers())
)
stack.enter_context(q._precompute_encodings(dtype=dtype))
stack.enter_context(_enable_qtensor_casting(False))
if prequantize_parameters:
stack.enter_context(
QuantizationSimModel._apply_qdq_to_model_parameters(model)
)
stack.enter_context(remove_param_quantizers(model))
def cast_param_to_plain_tensor(module):
for name, param in module.named_parameters(recurse=False):
if isinstance(param, DequantizedTensor):
stack.enter_context(
patch_attr(module, name, param.as_subclass(torch.Tensor))
)
model.apply(cast_param_to_plain_tensor)
yield