# -*- mode: python -*-
# =============================================================================
# @@-COPYRIGHT-START-@@
#
# Copyright (c) 2021-2023, Qualcomm Innovation Center, Inc. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice,
# this list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors
# may be used to endorse or promote products derived from this software
# without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
# POSSIBILITY OF SUCH DAMAGE.
#
# SPDX-License-Identifier: BSD-3-Clause
#
# @@-COPYRIGHT-END-@@
# =============================================================================
""" Top level API for Adaptive Rounding - Post-Training Quantization (PTQ) """
import os
import contextlib
import itertools
import json
import shutil
from typing import Tuple, Union, Dict, List, Callable, Any, Optional
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
# Import AIMET specific modules
from aimet_common.utils import AimetLogger, convert_configs_values_to_bool
from aimet_common.defs import QuantScheme, QuantizationDataType
from aimet_torch import utils
from aimet_torch.save_utils import SaveUtils
from aimet_torch.meta import connectedgraph_utils
from aimet_torch.quantsim import QuantizationSimModel, QcQuantizeWrapper, ExportableQuantModule
from aimet_torch.qc_quantize_op import StaticGridQuantWrapper, QcQuantizeOpMode
from aimet_torch.tensor_quantizer import TensorQuantizer
from aimet_torch.adaround.adaround_wrapper import AdaroundWrapper
from aimet_torch.adaround.adaround_optimizer import AdaroundOptimizer
from aimet_torch.adaround.adaround_loss import AdaroundHyperParameters
from aimet_torch.adaround.activation_sampler import create_modulelist_for_group_modules, get_block_inputs, \
get_block_outputs, create_cached_block_schedule_list
from aimet_torch.utils import get_named_module
logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant)
# The following modules with weights are supported by Adaround
AdaroundSupportedModules = (torch.nn.Conv2d, torch.nn.ConvTranspose2d, torch.nn.Linear)
WORKING_DIR = '/tmp/adaround/'
[docs]class AdaroundParameters:
"""
Configuration parameters for Adaround
"""
def __init__(self, data_loader: DataLoader, num_batches: int,
default_num_iterations: int = None, default_reg_param: float = 0.01,
default_beta_range: Tuple = (20, 2), default_warm_start: float = 0.2,
forward_fn: Callable[[torch.nn.Module, Any], Any] = None):
"""
:param data_loader: Data loader
:param num_batches: Number of batches to be used for Adaround.
A commonly recommended value for this parameter is the smaller value among (1) len(data_loader) and (2) ceil(2000/batch_size)
:param default_num_iterations: Number of iterations to adaround each layer.
The default value is 10K for models with 8- or higher bit weights, and 15K for models with lower than 8 bit weights.
:param default_reg_param: Regularization parameter, trading off between rounding loss vs reconstruction loss.
Default 0.01
:param default_beta_range: Start and stop beta parameter for annealing of rounding loss (start_beta, end_beta).
Default (20, 2)
:param default_warm_start: warm up period, during which rounding loss has zero effect. Default 20% (0.2)
:param forward_fn: Optional adapter function that performs forward pass given a model and inputs
yielded from the data loader. The function expects model as first argument and inputs to model
as second argument.
"""
if len(data_loader) < num_batches:
raise ValueError(f'Can not fetch {num_batches} batches from '
f'a data loader of length {len(data_loader)}.')
self.data_loader = data_loader
self.num_batches = num_batches
self.num_iterations = default_num_iterations
self.reg_param = default_reg_param
self.beta_range = default_beta_range
self.warm_start = default_warm_start
self.forward_fn = forward_fn
class Adaround:
"""
Weight-rounding mechanism for Post Training Quantization (PTQ)
"""
@classmethod
def apply_adaround(cls, model: torch.nn.Module, dummy_input: Union[torch.Tensor, Tuple], params: AdaroundParameters,
path: str, filename_prefix: str, default_param_bw: int = 4,
param_bw_override_list: List[Tuple[torch.nn.Module, int]] = None,
ignore_quant_ops_list: List[torch.nn.Module] = None,
default_quant_scheme: QuantScheme = QuantScheme.post_training_tf_enhanced,
default_config_file: str = None) -> torch.nn.Module:
"""
Returns model with optimized weight rounding of every module (Conv and Linear) and also saves the
corresponding quantization encodings to a separate JSON-formatted file that can then be imported by
QuantSim for inference or QAT
:param model: Model to Adaround
:param dummy_input: Dummy input to the model. Used to parse model graph. If the model has more than one input,
pass a tuple. User is expected to place the tensors on the appropriate device.
:param params: Parameters for Adaround
:param path: path where to store parameter encodings
:param filename_prefix: Prefix to use for filename of the encodings file
:param default_param_bw: Default bitwidth (4-31) to use for quantizing layer parameters
:param param_bw_override_list: List of Tuples. Each Tuple is a module and the corresponding parameter bitwidth
to be used for that module.
:param ignore_quant_ops_list: Ops listed here are skipped during quantization needed for AdaRounding. Do not
specify Conv and Linear modules in this list. Doing so, will affect accuracy.
:param default_quant_scheme: Quantization scheme. Supported options are using Quant Scheme Enum
QuantScheme.post_training_tf or QuantScheme.post_training_tf_enhanced
:param default_config_file: Default configuration file for model quantizers
:return: Model with Adarounded weights and saves corresponding parameter encodings JSON file at provided path
"""
# pylint: disable=too-many-arguments
# Create Quant sim with given parameters
quant_sim = cls._get_quantsim(model, dummy_input=dummy_input, quant_scheme=default_quant_scheme,
default_param_bw=default_param_bw,
config_file=default_config_file)
# For the modules in the param_bw_override_list, override the default parameter bitwidths in the QuantSim
if param_bw_override_list:
cls._override_param_bitwidth(model, quant_sim, param_bw_override_list)
if ignore_quant_ops_list:
cls._exclude_modules(model, quant_sim, ignore_quant_ops_list)
# Compute only param encodings
cls._compute_param_encodings(quant_sim)
return cls._apply_adaround(quant_sim, model, dummy_input, params, path, filename_prefix)
@classmethod
def _apply_adaround(cls, quant_sim: QuantizationSimModel, model: torch.nn.Module,
dummy_input: Union[torch.Tensor, Tuple], params: AdaroundParameters,
path: str, filename_prefix: str, checkpoints_config: str = None) -> torch.nn.Module:
"""
Returns model with optimized weight rounding of every module (Conv and Linear) and also saves the
corresponding quantization encodings to a separate JSON-formatted file that can then be imported by
QuantSim for inference or QAT
:param quant_sim: QuantizationSimModel object to optimize weight rounding.
The activation quantizers are expected to have been disabled.
:param model: Original fp32 model from which quant_sim was created.
:param dummy_input: Dummy input to the model. Used to parse model graph. If the model has more than one input,
pass a tuple. User is expected to place the tensors on the appropriate device.
:param params: Parameters for Adaround
:param path: path where to store parameter encodings
:param filename_prefix: Prefix to use for filename of the encodings file
:param checkpoints_config: Config files to split fp32/quant model by checkpoints
:return: Model with Adarounded weights and saves corresponding parameter encodings JSON file at provided path
"""
# Sanity check: All the input/output quantizers should be disabled
cls._check_input_output_quantizers_for_adaround(quant_sim.model)
# Get the module - activation function pair using ConnectedGraph
module_act_func_pair = connectedgraph_utils.get_module_act_func_pair(model, dummy_input)
cls._adaround_model(model, quant_sim, module_act_func_pair, params, dummy_input, checkpoints_config)
# Export quantization encodings to JSON-formatted file
cls._export_encodings_to_json(path, filename_prefix, quant_sim)
cls._remove_quantization_wrappers(quant_sim.model)
logger.info('Completed Adarounding Model')
return quant_sim.model
@classmethod
def _adaround_model(cls, model: torch.nn.Module, quant_sim: QuantizationSimModel, module_act_func_pair: Dict,
params: AdaroundParameters, dummy_input: Union[torch.Tensor, Tuple],
checkpoints_config: str = None):
"""
Optimize weight rounding of every module (AdaroundSupportedModules) of model in sequential manner
based on occurrence
NOTE: When checkpoints_config file is provided, assumption is that the outputs from previous group modules (block)
should feed directly into next group modules (block)
:param model: Original fp32 model from which quant_sim was created.
:param quant_sim: QuantizationSimModel object to optimize weight rounding.
The activation quantizers are expected to have been disabled.
:param module_act_func_pair: Dictionary of module to immediate following activation function
:param params: Adaround parameters
:param dummy_input: Dummy input to the model
:param checkpoints_config: Config files to split fp32/quant model by checkpoints to speedup activations sampling
"""
# pylint: disable=too-many-locals, protected-access, too-many-branches, too-many-statements
num_iterations = params.num_iterations
if num_iterations is None:
lowest_weight_bw = cls._get_lowest_weight_bw(quant_sim.model)
# If the lowest wegith bitwidth is < 8, then set num_iterations to 15K by default
if lowest_weight_bw < 8:
num_iterations = 15000
else:
num_iterations = 10000
try:
# Cache model input data to WORKING_DIR
cached_dataset = utils.CachedDataset(params.data_loader, params.num_batches, WORKING_DIR)
# Optimization Hyper parameters
opt_params = AdaroundHyperParameters(num_iterations, params.reg_param, params.beta_range,
params.warm_start)
# AdaRound must be applied to modules in the order of occurrence
if checkpoints_config:
# Load the predefined json file for checkpoints info
checkpoint_config = json.load(open(checkpoints_config))
convert_configs_values_to_bool(checkpoint_config)
assert 'cache_on_cpu' in checkpoint_config.keys(), \
"Please define cache_on_cpu to determine whether to cache intermediate tensors on CPU"
cache_on_cpu = checkpoint_config['cache_on_cpu']
checkpoint_type = checkpoint_config.get('checkpoint_type', 'sequential')
if checkpoint_type == 'sequential':
assert 'grouped_modules' in checkpoint_config.keys(), \
"Please provide a dictionary of grouped_modules in the file to define checkpoints"
assert 'include_static_inputs' in checkpoint_config.keys(), \
"Please provide a dictionary of include_static_inputs in the file to define checkpoints"
grouped_modules = checkpoint_config['grouped_modules']
breakpoint_module_name = checkpoint_config['grouped_modules'][list(grouped_modules.keys())[0]][0]
include_static_inputs = checkpoint_config['include_static_inputs']
cached_fp_dataset, cached_quant_dataset = get_block_inputs(model, quant_sim,
breakpoint_module_name,
cached_dataset, cache_on_cpu,
params.forward_fn, params.num_batches,
WORKING_DIR)
# Get the device of model to latter be used to place input tensor on the same device
device = utils.get_device(model)
model.cpu()
quant_sim.model.cpu()
# Forward function for the ModuleList object
def fwd_mod_ls(mod_ls, x):
for mod in mod_ls:
x = params.forward_fn(mod, x)
return x
sub_fp_models, sub_sim_models = create_modulelist_for_group_modules(model, quant_sim, grouped_modules)
for i, (fp_block, quant_sim_block, static_input) in enumerate(zip(sub_fp_models,
sub_sim_models,
include_static_inputs)):
modules = utils.get_ordered_list_of_modules(fp_block, cached_fp_dataset[0], fwd_mod_ls)
cls._run_adaround_model(modules, fp_block, quant_sim_block,
module_act_func_pair, opt_params,
fwd_mod_ls,
cached_fp_dataset, cached_quant_dataset)
# Get the outputs from the current block and assign to be the inputs for next block
# except for the last block
if i < len(sub_fp_models) - 1:
get_block_outputs(fp_block, quant_sim_block, static_input,
cached_fp_dataset, cached_quant_dataset, cache_on_cpu,
fwd_mod_ls, device, WORKING_DIR)
# After finishing Adaround, placing the quant model back to its original device
quant_sim.model.to(device)
else:
assert 'cached_blocks' in checkpoint_config.keys(), \
"Please provide a list of modules that can be cached"
block_list = create_cached_block_schedule_list(
model, dummy_input, checkpoint_config['cached_blocks'], AdaroundSupportedModules)
for block_cfg, modules in tqdm(block_list, desc='block'):
if block_cfg is None: # doesn't belong to a cached block
cls._run_adaround_model(modules, model, quant_sim.model, module_act_func_pair, opt_params,
params.forward_fn, cached_dataset)
else:
block_name, fp_block = block_cfg
quant_sim_block: torch.nn.Module = get_named_module(quant_sim.model, block_name)
cached_fp_dataset, cached_quant_dataset = get_block_inputs(model, quant_sim,
block_name,
cached_dataset, cache_on_cpu,
params.forward_fn,
params.num_batches,
WORKING_DIR,
incl_kwargs=True)
def block_fwd(_model, x):
return _model(*x)
cls._run_adaround_model(modules, fp_block, quant_sim_block, module_act_func_pair,
opt_params,
block_fwd, cached_fp_dataset, cached_quant_dataset)
del cached_fp_dataset
del cached_quant_dataset
else:
modules = utils.get_ordered_list_of_modules(model, dummy_input)
cls._run_adaround_model(modules, model, quant_sim.model, module_act_func_pair, opt_params,
params.forward_fn, cached_dataset)
finally:
try:
logger.info('Deleting model inputs from location: %s', WORKING_DIR)
shutil.rmtree(WORKING_DIR)
except FileNotFoundError:
pass
@classmethod
def _run_adaround_model(cls, modules: List, model: torch.nn.Module, quant_sim_model: torch.nn.Module,
module_act_func_pair: Dict, opt_params: AdaroundHyperParameters, forward_fn: Callable,
cached_dataset: utils.CachedDataset,
cached_quant_dataset: Optional[utils.CachedDataset] = None):
"""
Iterate through all modules to find out Adaround supported modules and
apply Adaround optimization to those modules
:param modules: Candidate modules
:param model: Original fp32 model
:param quant_sim_model: QuantSim model
:param module_act_func_pair: Activation function pairs
:param opt_params: Optimization parameters
:param forward_fn: Adapter function that performs forward pass given a model and inputs
yielded from the data loader
:param cached_dataset: Cached dataset for the fp32 model
:param cached_quant_dataset: Cached dataset for the quant model
"""
# pylint: disable=too-many-arguments, too-many-locals, protected-access
for name, module in tqdm(modules):
if isinstance(module, AdaroundSupportedModules):
# Using name, get corresponding quantized wrapper module from Quant sim model
quant_wrapper = cls._get_quant_wrapper(quant_sim_model, name)
if not quant_wrapper:
continue
# Wraps the quant module with adaround wrapper
# and temporarily replace quant module with wrapped module
with cls._replace_quantization_layer(quant_sim_model, name) as adaround_wrapper:
# Get module's next following activation function
act_func = module_act_func_pair[module]
logger.info("Started Optimizing weight rounding of module: %s", name)
AdaroundOptimizer.adaround_module(module, adaround_wrapper, model, quant_sim_model, act_func,
cached_dataset, forward_fn, opt_params, cached_quant_dataset)
weight = adaround_wrapper.weight
# Fold trained alpha to weight
with torch.no_grad():
# Use soft rounding to compute Adarounded weight
adaround_wrapper.use_soft_rounding = True
adarounded_weight = adaround_wrapper.apply_adaround(weight)
weight.copy_(adarounded_weight)
del adarounded_weight
@staticmethod
def _compute_param_encodings(quant_sim: QuantizationSimModel):
"""
Compute encodings for parameters, needed for initializing Adaround quantizers
:param quant_sim: Quant sim
"""
for quant_module in quant_sim.model.modules():
if isinstance(quant_module, StaticGridQuantWrapper):
# Adaround requires input and output quantizers to be disabled
for quatizer in quant_module.input_quantizers:
quatizer.enabled = False
for quatizer in quant_module.output_quantizers:
quatizer.enabled = False
# pylint: disable=protected-access
for name, param in quant_module._module_to_wrap.named_parameters():
param_quantizer = quant_module.param_quantizers[name]
param_quantizer.reset_encoding_stats()
param_quantizer.update_encoding_stats(param.data)
param_quantizer.compute_encoding()
# Wrapper mode must be set to ACTIVE because the wrapper's quantize_dequantize_params() will only call
# into the param tensor quantizer's quantize_dequantize() if the mode is not PASSTHROUGH.
quant_module.set_mode(QcQuantizeOpMode.ACTIVE)
@staticmethod
def _get_quantsim(model: torch.nn.Module, dummy_input: torch.Tensor,
quant_scheme: QuantScheme, default_param_bw: int, config_file: str):
return QuantizationSimModel(model, dummy_input=dummy_input, quant_scheme=quant_scheme,
default_param_bw=default_param_bw,
config_file=config_file)
@staticmethod
def _get_adaround_wrapper(quant_module: QcQuantizeWrapper):
return AdaroundWrapper(quant_module)
@staticmethod
def _remove_quantization_wrappers(module: torch.nn.Module):
SaveUtils.remove_quantization_wrappers(module)
@staticmethod
@contextlib.contextmanager
def _patch_module_layer(model, layer_name, new_layer):
"""
Temporarily replace model layer
"""
original_layer = getattr(model, layer_name)
setattr(model, layer_name, new_layer)
yield
setattr(model, layer_name, original_layer)
@staticmethod
def _validate_quant_module_for_adaround(quant_module: StaticGridQuantWrapper):
assert quant_module.param_quantizers['weight'], '%s does not have weight parameter.' % quant_module
assert quant_module.param_quantizers['weight'].encoding, '%s encoding needs to be set.' % quant_module
@staticmethod
def _check_input_output_quantizers_for_adaround(quant_model: torch.nn.Module):
_, input_quantizers, output_quantizers = utils.get_all_quantizers(quant_model)
for quantizer in itertools.chain(input_quantizers, output_quantizers):
assert not quantizer.enabled
@staticmethod
def _get_lowest_weight_bw(quant_model: torch.nn.Module):
param_quantizers, _, _ = utils.get_all_quantizers(quant_model)
return min(
quantizer.bitwidth for quantizer in param_quantizers
if quantizer.enabled and quantizer.data_type == QuantizationDataType.int
)
@classmethod
@contextlib.contextmanager
def _replace_quantization_layer(cls, quant_sim_model: torch.nn.Module, module_name: str):
"""
Replace the quantized module's weight tensor quantizer with the Adaround tensor quantizer
:param quant_module: quant module
"""
quant_module = utils.get_named_module(quant_sim_model, module_name)
cls._validate_quant_module_for_adaround(quant_module)
adaround_layer = cls._get_adaround_wrapper(quant_module)
# We need to look for the container to patch for modules inside submodule
upper_module = quant_sim_model
upper_module_name, _, target_module_name = module_name.rpartition('.')
if upper_module_name:
upper_module = utils.get_named_module(quant_sim_model, upper_module_name)
# Temporarily replace quant module with wrapped module
with cls._patch_module_layer(upper_module, target_module_name, adaround_layer):
yield adaround_layer
@staticmethod
def _get_quant_wrapper(quant_sim_model: torch.nn.Module, module_name: str) -> Union[StaticGridQuantWrapper, None]:
"""
For given module name, get the quantized wrapper module from the QuantSim model
:param quant_sim_model: Model with simulation ops
:param module_name: Module name
:return: Quantized wrapper module or None
"""
quant_module = None
for name, module in quant_sim_model.named_modules():
if name == module_name and isinstance(module, StaticGridQuantWrapper):
quant_module = module
break
return quant_module
@classmethod
def _export_encodings_to_json(cls, path: str, filename_prefix: str, quant_sim: QuantizationSimModel):
"""
Save Adadrounded module's parameter encodings to JSON file
:param path: path where to store param encodings
:param filename_prefix: filename to store exported weight encodings in JSON format
:param quant_sim: QunatSim that contains the model and Adaround tensor quantizers
"""
# pylint: disable=protected-access
# Create a dictionary to export to JSON file
param_encodings = {}
for name, quant_module in quant_sim.model.named_modules():
if isinstance(quant_module, ExportableQuantModule) and \
isinstance(quant_module.get_original_module(), AdaroundSupportedModules):
if 'weight' in quant_module.param_quantizers:
cls._update_param_encodings_dict(quant_module, name, param_encodings)
# Unify the encoding format to be same as that of full encoding export file
encoding = {'param_encodings': param_encodings}
# export encodings to JSON file
os.makedirs(os.path.abspath(path), exist_ok=True)
encoding_file_path = os.path.join(path, filename_prefix + '.encodings')
with open(encoding_file_path, 'w') as encoding_fp:
json.dump(encoding, encoding_fp, sort_keys=True, indent=4)
@classmethod
def _update_param_encodings_dict(cls, quant_module: ExportableQuantModule, name: str, param_encodings: Dict):
"""
Add module's weight parameter encodings to dictionary to be used for exporting encodings
:param quant_module: quant module
:param name: name of module
:param param_encodings: Dictionary of param encodings
"""
for orig_param_name, encodings in quant_module.export_param_encodings().items():
if orig_param_name == 'weight' and encodings:
param_name = name + '.' + orig_param_name
param_encodings[param_name] = encodings
@staticmethod
def _create_encodings_dict_for_quantizer(quantizer: TensorQuantizer) -> List[Dict]:
"""
Return encodings for given qunatizer
:param quantizer: Tensor quantizer associated with module's param
:return: Dictionary containing encodings
"""
quant_encodings = quantizer.encoding
if not isinstance(quantizer.encoding, list):
quant_encodings = [quant_encodings]
encodings_dict = []
for enc in quant_encodings:
encodings_dict.append({'min': enc.min,
'max': enc.max,
'scale': enc.delta,
'offset': int(enc.offset),
'bitwidth': enc.bw,
'is_symmetric': str(quantizer.use_symmetric_encodings),
'dtype': 'int' if quantizer.data_type == QuantizationDataType.int else 'float'})
return encodings_dict
@staticmethod
def _override_param_bitwidth(model: torch.nn.Module, quant_sim: QuantizationSimModel,
param_bw_override_list: List[Tuple[torch.nn.Module, int]]):
"""
For the QuantSim, for the list of modules in the param_bw_override_list,
overrides the default parameter bitwidths with the provided bitwidth.
:param model: The original model
:param quant_sim: The QuantSim that was created using a deepcopy of the original model.
:param param_bw_override_list: List of Tuples. Each Tuple is a module and the corresponding parameter bitwidth
to be used for that module.
"""
# Create a mapping of original model's AdaRoundable module and their name
module_to_name = {}
for name, module in model.named_modules():
if isinstance(module, AdaroundSupportedModules):
module_to_name[module] = name
# Create a mapping of QuantSim model's AdaRoundable module name and their module
name_to_module = {}
for q_name, q_module in quant_sim.model.named_modules():
if isinstance(q_module, ExportableQuantModule):
if isinstance(q_module.get_original_module(), AdaroundSupportedModules): # pylint: disable=protected-access
name_to_module[q_name] = q_module
# For the modules specified in the param_bw_override_list, set the weight quantizer bitwidth
for (module, bw) in param_bw_override_list:
module_name = module_to_name[module]
quant_wrapper = name_to_module[module_name]
quant_wrapper.param_quantizers['weight'].bitwidth = bw
@classmethod
def _exclude_modules(cls, model: torch.nn.Module, quant_sim: QuantizationSimModel,
ignore_quant_ops_list: List[torch.nn.Module]):
"""
For the modules mentioned in the ignore_quant_ops_list, remove the corresponding quant wrappers from the
quantSim and excludes modules from adaround optimization.
:param model: The original model
:param quant_sim: The QuantSim that was created using a deepcopy of the original model.
:param ignore_quant_ops_list: The list of modules for which the Quantization wrappers are removed from the
QuantSim object.
"""
quant_wrappers_to_exclude = []
for module in ignore_quant_ops_list:
for m in module.modules():
name = utils.get_layer_name(model, m)
quant_wrapper = cls._get_quant_wrapper(quant_sim.model, name)
if quant_wrapper:
quant_wrappers_to_exclude.append(quant_wrapper)
quant_sim.exclude_layers_from_quantization(quant_wrappers_to_exclude)
@classmethod
def apply_adaround_with_cache(cls, model: torch.nn.Module, dummy_input: Union[torch.Tensor, Tuple],
params: AdaroundParameters,
path: str, filename_prefix: str, default_param_bw: int = 4,
param_bw_override_list: List[Tuple[torch.nn.Module, int]] = None,
ignore_quant_ops_list: List[torch.nn.Module] = None,
default_quant_scheme: QuantScheme = QuantScheme.post_training_tf_enhanced,
default_config_file: str = None,
checkpoints_config: str = None) -> torch.nn.Module:
"""
Returns model with optimized weight rounding of every module (Conv and Linear) and also saves the
corresponding quantization encodings to a separate JSON-formatted file that can then be imported by
QuantSim for inference or QAT
:param model: Model to Adaround
:param dummy_input: Dummy input to the model. Used to parse model graph. If the model has more than one input,
pass a tuple. User is expected to place the tensors on the appropriate device.
:param params: Parameters for Adaround
:param path: path where to store parameter encodings
:param filename_prefix: Prefix to use for filename of the encodings file
:param default_param_bw: Default bitwidth (4-31) to use for quantizing layer parameters
:param param_bw_override_list: List of Tuples. Each Tuple is a module and the corresponding parameter bitwidth
to be used for that module.
:param ignore_quant_ops_list: Ops listed here are skipped during quantization needed for AdaRounding. Do not
specify Conv and Linear modules in this list. Doing so, will affect accuracy.
:param default_quant_scheme: Quantization scheme. Supported options are using Quant Scheme Enum
QuantScheme.post_training_tf or QuantScheme.post_training_tf_enhanced
:param default_config_file: Default configuration file for model quantizers
:param checkpoints_file: JSON file to define checkpoints for caching intermediate tensors of fp32/quant model
:return: Model with Adarounded weights and saves corresponding parameter encodings JSON file at provided path
"""
# pylint: disable=too-many-arguments
assert checkpoints_config is not None, "To run Adaround with cached tensors, please provide a JSON file with checkpoints defined"
# Create Quant sim with given parameters
quant_sim = cls._get_quantsim(model, dummy_input=dummy_input, quant_scheme=default_quant_scheme,
default_param_bw=default_param_bw,
config_file=default_config_file)
# For the modules in the param_bw_override_list, override the default parameter bitwidths in the QuantSim
if param_bw_override_list:
cls._override_param_bitwidth(model, quant_sim, param_bw_override_list)
if ignore_quant_ops_list:
cls._exclude_modules(model, quant_sim, ignore_quant_ops_list)
# Compute only param encodings
cls._compute_param_encodings(quant_sim)
return cls._apply_adaround(quant_sim, model, dummy_input, params, path, filename_prefix, checkpoints_config)