# -*- mode: python -*-
# =============================================================================
#  @@-COPYRIGHT-START-@@
#
#  Copyright (c) 2019-2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  1. Redistributions of source code must retain the above copyright notice,
#     this list of conditions and the following disclaimer.
#
#  2. Redistributions in binary form must reproduce the above copyright notice,
#     this list of conditions and the following disclaimer in the documentation
#     and/or other materials provided with the distribution.
#
#  3. Neither the name of the copyright holder nor the names of its contributors
#     may be used to endorse or promote products derived from this software
#     without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
#  POSSIBILITY OF SUCH DAMAGE.
#
#  SPDX-License-Identifier: BSD-3-Clause
#
#  @@-COPYRIGHT-END-@@
# =============================================================================
""" Implementation for simulating models running on Quantized hardware """

import contextlib
import os
import io
import copy
from typing import Tuple, List, Union, Dict, Callable, Optional, Any
import torch

from aimet_common.utils import AimetLogger
from aimet_common.defs import QuantScheme, QuantizationDataType
from aimet_torch.v1.qc_quantize_op import QcQuantizeStandAloneBase, QcQuantizeWrapper, QcQuantizeOpMode, \
    StaticGridQuantWrapper, LearnedGridQuantWrapper, NativeTorchQuantWrapper
from aimet_torch.v1.tensor_quantizer import initialize_learned_grid_quantizer_attributes
from aimet_torch.v1.qc_quantize_op import get_encoding_by_quantizer as _get_encoding_by_quantizer
from aimet_torch import utils
from aimet_torch.v1.utils import create_encoding_dict, get_v1_quant_scheme_for_initialization
from aimet_torch.onnx_utils import OnnxSaver, OnnxExportApiArgs
from aimet_torch.v1.qc_quantize_recurrent import QcQuantizeRecurrent
from aimet_torch.quantsim_config.builder import LazyQuantizeWrapper
from aimet_torch.v1._builder import _V1LazyQuantizeWrapper
from aimet_torch._base.quantsim import (
    _QuantizationSimModelBase,
    _QuantizedModuleProtocol,
    unquantizable_modules,
    QuantParams,
    ExportableQuantModule,
    save_checkpoint,
    load_checkpoint,
    check_accumulator_overflow,
)

__all__ = [
    'QuantizationSimModel',
    'QuantParams',
    'ExportableQuantModule',
    'save_checkpoint',
    'load_checkpoint',
    'check_accumulator_overflow',
    'load_encodings_to_sim',
    'compute_encodings_for_sims',
]

logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.Quant)

# If a torch module type is in this dictionary, call the corresponding quantized module constructor instead of wrapping
# it with QcQuantizeWrapper.
qc_quantize_modules_dict = {
    torch.nn.RNN: QcQuantizeRecurrent,
    torch.nn.LSTM: QcQuantizeRecurrent,
    torch.nn.GRU: QcQuantizeRecurrent
}

# Types of modules which are already quantized and therefore must not be wrapped again
quantized_modules = (
    QcQuantizeWrapper,
    QcQuantizeStandAloneBase,
    QcQuantizeRecurrent,
    _QuantizedModuleProtocol,
    LazyQuantizeWrapper,
)
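
# For illustration: a minimal sketch of how the tables above drive wrapper selection in
# _create_quantizer_module() further below. Recurrent module types resolve to
# QcQuantizeRecurrent; all other quantizable leaf modules fall back to the lazy wrapper.
# The example modules here are hypothetical:
#
#   lstm = torch.nn.LSTM(input_size=8, hidden_size=8)
#   assert qc_quantize_modules_dict.get(type(lstm), _V1LazyQuantizeWrapper) is QcQuantizeRecurrent
#
#   linear = torch.nn.Linear(8, 8)
#   assert qc_quantize_modules_dict.get(type(linear), _V1LazyQuantizeWrapper) is _V1LazyQuantizeWrapper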
class QuantizationSimModel(_QuantizationSimModelBase): # pylint: disable=missing-class-docstring
    __doc__ = _QuantizationSimModelBase.__doc__

    # pylint: disable=too-many-arguments, too-many-locals, too-many-public-methods
    _quantized_modules = quantized_modules

    def _realize_quant_wrappers_in_model(self, model: torch.nn.Module):
        """
        Realize quant wrappers using the information collected in LazyQuantWrapper.

        :param model: model containing modules wrapped with LazyQuantWrapper
        """
        for module_name, module_ref in model.named_children():
            if isinstance(module_ref, LazyQuantizeWrapper):
                quantized_module = module_ref.realize()
                setattr(model, module_name, quantized_module)

            elif not utils.is_leaf_module(module_ref):
                self._realize_quant_wrappers_in_model(module_ref)

    def __str__(self):
        """
        Pretty-printed output indicating where in the model quantizers have been activated

        :return:
        """

        def print_quantizer_state(stream, quantizer, prefix_string):
            if quantizer.enabled:
                stream.write(f'  {prefix_string}: bw={quantizer.bitwidth}, '
                             f'encoding-present={bool(quantizer.encoding)}\n')

                if quantizer.encoding:
                    stream.write(f'    {quantizer}')
            else:
                stream.write(f'  {prefix_string}: Not quantized\n')

            stream.write('  -------\n')

        stream = io.StringIO(newline='\n')
        stream.write("-------------------------\n")
        stream.write("Quantized Model Report\n")
        stream.write("-------------------------\n")

        for layer_name, layer in self._get_qc_quantized_layers(self.model):
            stream.write('----------------------------------------------------------\n')
            stream.write('Layer: {}\n'.format(layer_name))

            # Inputs
            if isinstance(layer.input_quantizers, dict):
                for name, quantizer in layer.input_quantizers.items():
                    print_quantizer_state(stream, quantizer, prefix_string=f"Input[{name}]")
            else:
                for index, quantizer in enumerate(layer.input_quantizers):
                    print_quantizer_state(stream, quantizer, prefix_string=f"Input[{index}]")

            # Params
            for param_name, quantizer in layer.param_quantizers.items():
                print_quantizer_state(stream, quantizer, prefix_string=f"Param[{param_name}]")

            # Outputs
            if isinstance(layer.output_quantizers, dict):
                for name, quantizer in layer.output_quantizers.items():
                    print_quantizer_state(stream, quantizer, prefix_string=f"Output[{name}]")
            else:
                for index, quantizer in enumerate(layer.output_quantizers):
                    print_quantizer_state(stream, quantizer, prefix_string=f"Output[{index}]")

        return stream.getvalue()

    @staticmethod
    def prepare_sim_for_compute_encodings(sim: 'QuantizationSimModel'):
        """
        Prepare QuantSim for compute encodings. Resets encodings for each quantizable layer and sets mode to Analysis.

        :param sim: QuantSim to prepare
        """
        # pylint: disable=protected-access
        quantized_layers = sim._get_qc_quantized_layers(sim.model)

        for _, layer in quantized_layers:
            # Clear stats and encodings if they are present
            layer.reset_encodings()

            # And set the mode to analysis
            layer.set_mode(QcQuantizeOpMode.ANALYSIS)

        for _, layer in quantized_layers:
            # call only when quant scheme is percentile
            if sim._quant_scheme == QuantScheme.post_training_percentile:
                layer.set_percentile_value(sim._percentile_value)

    @staticmethod
    def compute_layer_encodings_for_sim(sim: 'QuantizationSimModel'):
        """
        Compute encodings for each quantizable layer in sim after forward pass has been called.

        :param sim: QuantSim to compute encodings for
        """
        # pylint: disable=protected-access
        quantized_layers = sim._get_qc_quantized_layers(sim.model)

        # Get the computed per-layer encodings and log them
        for name, layer in quantized_layers:
            layer.compute_encoding()

            # Before we return we set the mode to active - meaning ready for quantize/de-quantize
            # for layers with valid_encoding, otherwise we set to pass through
            if isinstance(layer, QcQuantizeRecurrent):
                sim.set_mode_for_recurrent_module(layer, name)
            else:
                # By default we want to set the Quantization wrappers to ACTIVE mode
                layer.set_mode(QcQuantizeOpMode.ACTIVE)

        sim.replace_wrappers_for_quantize_dequantize()
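
    # For illustration: the two static methods above split compute_encodings() into explicit
    # phases, which can be useful when the calibration forward passes cannot be expressed as
    # a single callback. A minimal sketch, assuming `sim` is a constructed QuantizationSimModel
    # and `calibration_batches` is a hypothetical iterable of representative inputs:
    #
    #   QuantizationSimModel.prepare_sim_for_compute_encodings(sim)
    #   with utils.in_eval_mode(sim.model), torch.no_grad():
    #       for batch in calibration_batches:
    #           sim.model(batch)
    #   QuantizationSimModel.compute_layer_encodings_for_sim(sim)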
    def compute_encodings(self, forward_pass_callback, forward_pass_callback_args): # pylint: disable=arguments-differ
        """
        Computes encodings for all quantization sim nodes in the model. It is also used to find initial encodings for
        Range Learning.

        :param forward_pass_callback: A callback function that simply runs forward passes on the model. This callback
            function should use representative data for the forward pass, so the calculated encodings work for all
            data samples. This callback internally chooses the number of data samples it wants to use for calculating
            encodings.
        :param forward_pass_callback_args: These argument(s) are passed to the forward_pass_callback as-is. Up to
            the user to determine the type of this parameter. E.g. could be simply an integer representing the number
            of data samples to use. Or could be a tuple of parameters or an object representing something more
            complex. If set to None, forward_pass_callback will be invoked with no parameters.
        :return: None
        """
        QuantizationSimModel.prepare_sim_for_compute_encodings(self)

        # Run forward iterations so we can collect statistics to compute the appropriate encodings
        with utils.in_eval_mode(self.model), torch.no_grad():
            _ = forward_pass_callback(self.model, forward_pass_callback_args)

        QuantizationSimModel.compute_layer_encodings_for_sim(self)
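
    # For illustration: typical usage of compute_encodings(). A minimal sketch; `data_loader`
    # and the batch count passed through forward_pass_callback_args are hypothetical:
    #
    #   def pass_calibration_data(model, num_batches):
    #       for batch_idx, (inputs, _) in enumerate(data_loader):
    #           if batch_idx >= num_batches:
    #               break
    #           model(inputs)
    #
    #   sim.compute_encodings(pass_calibration_data, forward_pass_callback_args=10)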
    @classmethod
    def set_mode_for_recurrent_module(cls, layer: QcQuantizeRecurrent, name: str):
        """
        Sets Recurrent module to active or pass through mode based on quantizer state

        :param layer: Qc Quantizer layer for recurrent module
        :param name: layer name
        :return:
        """
        for quantizer_name, output_quantizer in layer.output_quantizers.items():
            if output_quantizer.enabled:
                if output_quantizer.encoding:
                    encoding = output_quantizer.encoding
                    logger.debug("Encoding for %s-%s: min=%f, max=%f, delta=%f, offset=%f, bw=%f",
                                 name, quantizer_name, encoding.min, encoding.max,
                                 encoding.delta, encoding.offset, encoding.bw)

        for quantizer_name, input_quantizer in layer.input_quantizers.items():
            if input_quantizer.enabled:
                if input_quantizer.encoding:
                    encoding = input_quantizer.encoding
                    logger.debug("Encoding for %s-%s: min=%f, max=%f, delta=%f, offset=%f, bw=%f",
                                 name, quantizer_name, encoding.min, encoding.max,
                                 encoding.delta, encoding.offset, encoding.bw)

        layer.set_mode(QcQuantizeOpMode.ACTIVE)

    def set_percentile_value(self, percentile_value: float):
        """
        Set the percentile value to be used while computing encodings
        """
        if percentile_value < 90 or percentile_value > 100:
            raise ValueError("Percentile value must be in range [90, 100]")
        self._percentile_value = percentile_value

    def _replace_quantization_wrapper(self, model, device):
        """
        Recursively replace StaticGridQuantWrapper modules with trainable wrappers, starting with a given module

        :param model: model for which PostTrainingWrapper gets replaced with Trainable wrapped module
        :param device: device on which model is present
        :return: None
        """
        for module_name, module_ref in model.named_children():
            if isinstance(module_ref, StaticGridQuantWrapper):
                # Create a Trainable wrapper and copy properties of PostTrainingWrapper to the Trainable wrapper
                quantized_module = self._construct_and_initialize_trainable_wrapper(module_ref, device)
                setattr(model, module_name, quantized_module)

            elif isinstance(module_ref, QcQuantizeRecurrent):
                # Set Recurrent layer for training mode
                module_ref.construct_and_initialize_trainable_quantizers(self._quant_scheme)

            # Recursively call children modules if present
            if not utils.is_leaf_module(module_ref):
                self._replace_quantization_wrapper(module_ref, device)

    def _construct_and_initialize_trainable_wrapper(self, post_training_module: StaticGridQuantWrapper,
                                                    device: torch.device) -> LearnedGridQuantWrapper:
        """
        Copies the following tensor quantizer attributes from StaticGridQuantWrapper to LearnedGridQuantWrapper
        to avoid any mismatch:
            - enabled
            - bitwidth
            - encoding
            - use_symmetric_encodings
            - use_strict_symmetric
            - use_unsigned_symmetric

        :param post_training_module: StaticGridQuantWrapper wrapped module
        :param device: device on which model is present
        :return: trainable_module: QcTrainable wrapper module
        """
        # pylint: disable=protected-access
        module = post_training_module._module_to_wrap

        num_inputs = len(post_training_module.input_quantizers)
        num_outputs = len(post_training_module.output_quantizers)

        # Creating a LearnedGridQuantWrapper module
        trainable_module = LearnedGridQuantWrapper(module, self._default_param_bw,
                                                   self._default_output_bw, self._rounding_mode,
                                                   self._quant_scheme, device=device,
                                                   num_inputs=num_inputs, num_outputs=num_outputs,
                                                   data_type=QuantizationDataType.int)

        # Copy user settable attributes for outputs
        for index, quantizer in enumerate(post_training_module.output_quantizers):
            initialize_learned_grid_quantizer_attributes(trainable_module.output_quantizers[index], quantizer)
            if trainable_module.output_quantizers[index].encoding_min_max_fixed_vals is not None:
                trainable_module.output_quantizers[index].freeze_encoding()

        # Copy user settable attributes for inputs
        for index, quantizer in enumerate(post_training_module.input_quantizers):
            initialize_learned_grid_quantizer_attributes(trainable_module.input_quantizers[index], quantizer)
            if trainable_module.input_quantizers[index].encoding_min_max_fixed_vals is not None:
                trainable_module.input_quantizers[index].freeze_encoding()

        # Copy user settable attributes for params
        for name, quantizer in post_training_module.param_quantizers.items():
            learned_grid_quantizer = trainable_module.param_quantizers[name]
            initialize_learned_grid_quantizer_attributes(learned_grid_quantizer, quantizer)
            if learned_grid_quantizer.encoding_min_max_fixed_vals is not None:
                learned_grid_quantizer.freeze_encoding()

        return trainable_module

    def replace_wrappers_for_quantize_dequantize(self):
        """
        Replaces StaticGridWrapper with LearnedGridWrapper
        """
        if self._quant_scheme in (QuantScheme.training_range_learning_with_tf_init,
                                  QuantScheme.training_range_learning_with_tf_enhanced_init):
            try:
                device = utils.get_device(self.model)
            except StopIteration:
                # Model doesn't have any parameter.
                # Set device to cpu by default.
                device = torch.device('cpu')

            self._replace_quantization_wrapper(self.model, device)

    def _create_quantizer_module(self, module_to_quantize: torch.nn.Module, num_inout_tensors: Dict,
                                 data_type: QuantizationDataType) -> torch.nn.Module:
        """Instantiates wrapper based on quant scheme
        """
        assert self._quant_scheme in [QuantScheme.post_training_tf, QuantScheme.post_training_tf_enhanced,
                                      QuantScheme.training_range_learning_with_tf_enhanced_init,
                                      QuantScheme.training_range_learning_with_tf_init,
                                      QuantScheme.post_training_percentile]

        # We lookup the number of input and output tensors already determined
        # Special case, we are adding a wrapper for a module not in the forward pass: Use default of 1, 1
        num_in_tensors, num_out_tensors = num_inout_tensors.get(module_to_quantize, (1, 1))

        # Set quantizer to be a module replacer if it is in qc_quantize_modules_dict, otherwise set as
        # StaticGridQuantWrapper.
        quantizer_wrapper_type = qc_quantize_modules_dict.get(type(module_to_quantize), _V1LazyQuantizeWrapper)

        if issubclass(quantizer_wrapper_type, LazyQuantizeWrapper):
            quant_scheme_for_initialization = self._quant_scheme
        else:
            quant_scheme_for_initialization = get_v1_quant_scheme_for_initialization(self._quant_scheme)

        # TODO add quant_scheme_for_initialization for FP8 case
        quantized_module = quantizer_wrapper_type(module_to_quantize, self._default_param_bw, self._default_output_bw,
                                                  self._rounding_mode, quant_scheme_for_initialization,
                                                  num_inputs=num_in_tensors, num_outputs=num_out_tensors,
                                                  data_type=data_type)

        return quantized_module

    @classmethod
    def _is_quantizable_module(cls, module: torch.nn.Module):
        # pylint: disable=unidiomatic-typecheck
        return type(module) != torch.nn.Module and \
               not isinstance(module, unquantizable_modules) and \
               not cls._is_quantized_module(module)

    @classmethod
    def _is_quantized_module(cls, module: torch.nn.Module):
        return isinstance(module, quantized_modules)

    def _add_quantization_wrappers(self, module, num_inout_tensors, default_data_type: QuantizationDataType):
        """Recursively add quantization wrappers to all appropriate modules starting with module
        """
        if self._is_quantized_module(module):
            return

        for module_name, module_ref in module.named_children():
            logger.debug("nn.Module found : %s", module_ref)

            if self._is_quantizable_module(module_ref) and utils.is_leaf_module(module_ref):
                # Create a new QcQuantize wrapper module
                quantized_module = self._create_quantizer_module(module_ref, num_inout_tensors, default_data_type)
                setattr(module, module_name, quantized_module)
            else:
                self._add_quantization_wrappers(module_ref, num_inout_tensors, default_data_type)

    # pylint: disable=too-many-arguments
    @classmethod
    def _update_encoding_dicts_for_layer(cls, layer: _QuantizedModuleProtocol, layer_name: str,
                                         activation_encodings_onnx: Dict,
                                         activation_encodings_torch: Dict,
                                         param_encodings: Dict,
                                         op_to_io_tensor_map: Dict,
                                         valid_param_set: set,
                                         propagate_encodings: bool,
                                         tensor_to_consumer_map: Dict[str, str],
                                         layers_to_onnx_op_names: Dict[str, str],
                                         tensor_to_quantizer_map: Dict):
        """
        Add given layer param and activation encodings to respective dictionaries to be used for exporting encodings

        :param layer: layer as torch.nn.Module
        :param layer_name: Name of the layer
        :param activation_encodings_onnx: dictionary of activation encodings which maps onnx attribute to encodings
        :param activation_encodings_torch: dictionary of activation encodings which maps pytorch names to encodings
        :param param_encodings: dictionary of param encodings
        :param op_to_io_tensor_map: ONNX or TorchScript map of layer name to its input/output tensors
        :param valid_param_set: a set of valid param input names in model
        :param propagate_encodings: If True, encoding entries for intermediate ops (when one PyTorch op results in
            multiple ONNX nodes) are filled with the same BW and data_type as the output tensor for that series of
            ops.
        :param tensor_to_consumer_map: Dictionary mapping tensor names to op names which consume the tensor
        :param layers_to_onnx_op_names: Dictionary mapping PyTorch layer names to names of corresponding ONNX ops
        """
        if isinstance(layer, QcQuantizeRecurrent):
            # Update encodings for Recurrent layers
            QuantizationSimModel._update_encoding_dict_for_recurrent_layers(layer, layer_name, op_to_io_tensor_map,
                                                                            activation_encodings_onnx,
                                                                            param_encodings, propagate_encodings,
                                                                            tensor_to_quantizer_map)
        else:
            super()._update_encoding_dicts_for_layer(layer, layer_name, activation_encodings_onnx,
                                                     activation_encodings_torch, param_encodings,
                                                     op_to_io_tensor_map, valid_param_set, propagate_encodings,
                                                     tensor_to_consumer_map, layers_to_onnx_op_names,
                                                     tensor_to_quantizer_map)

    @staticmethod
    def _update_encoding_dict_for_recurrent_layers(layer: torch.nn.Module, layer_name: str,
                                                   op_to_io_tensor_map: Dict,
                                                   activation_encodings_onnx: Dict, param_encodings: Dict,
                                                   propagate_encodings: bool, tensor_to_quantizer_map: Dict):
        """
        Update activation and param encoding dictionaries for a recurrent layer.

        :param layer: recurrent layer whose quantizer encodings are exported
        :param layer_name: name of the layer
        :param op_to_io_tensor_map: ONNX or TorchScript map of layer name to its input/output tensors
        :param activation_encodings_onnx: dictionary of activation encodings which maps onnx attribute to encodings
        :param param_encodings: dictionary of param encodings
        :param propagate_encodings: If True, fill encoding entries for intermediate ops with the output encodings
        :return:
        """
        # pylint: disable=too-many-nested-blocks
        # pylint: disable=too-many-locals

        onnx_activations_to_quantizers, onnx_params_to_quantizers = \
            layer.get_activation_param_quantizers_for_onnx_tensors(op_to_io_tensor_map[layer_name + '#root_node'])

        # ------------------
        # Activations
        # ------------------
        quantizer = None
        for tensor, quantizer in onnx_activations_to_quantizers.items():
            quantizer_encoding = _get_encoding_by_quantizer(quantizer)
            encoding = create_encoding_dict(quantizer_encoding, quantizer, propagate_encodings=False)
            activation_encodings_onnx[tensor] = [encoding]
            tensor_to_quantizer_map[tensor] = quantizer

        if propagate_encodings and quantizer:
            _, op_names = QuantizationSimModel.find_op_names_for_layer(layer_name, op_to_io_tensor_map, None, None)
            for op_name in op_names:
                io_tensor_list = op_to_io_tensor_map[op_name]
                if not isinstance(io_tensor_list, list):
                    io_tensor_list = [io_tensor_list]

                for io_tensors in io_tensor_list:
                    if io_tensors.outputs:
                        for output_tensor in io_tensors.outputs:
                            if output_tensor in onnx_activations_to_quantizers:
                                continue
                            quantizer_encoding = _get_encoding_by_quantizer(quantizer)
                            encoding = create_encoding_dict(quantizer_encoding, quantizer, True)

                            activation_encodings_onnx[output_tensor] = [encoding]
                            tensor_to_quantizer_map[output_tensor] = quantizer

        # ------------------
        # Params
        # ------------------
        for tensor, quantizer in onnx_params_to_quantizers.items():
            quantizer_encoding = _get_encoding_by_quantizer(quantizer)
            encoding = create_encoding_dict(quantizer_encoding, quantizer, propagate_encodings=False)
            param_encodings[tensor] = [encoding]
            tensor_to_quantizer_map[tensor] = quantizer

    @staticmethod
    def _get_qc_quantized_layers(model) -> List[Tuple[str, QcQuantizeWrapper]]:
        quantized_layers = []
        for name, module in model.named_modules():
            if isinstance(module, (QcQuantizeRecurrent, LazyQuantizeWrapper, _QuantizedModuleProtocol)):
                quantized_layers.append((name, module))
        return quantized_layers

    @classmethod
    def _remove_quantization_wrappers(cls, starting_module, list_of_modules_to_exclude):
        """
        Recursively remove quantization wrappers from all appropriate modules starting with a given module

        :param starting_module: Module to recursively search downstream from
        :param list_of_modules_to_exclude: List of torch modules to remove quantization wrappers from (if present)
        :return: None
        """
        for module_name, module_ref in starting_module.named_children():

            # If module is in the exclude list, remove the wrapper
            if module_ref in list_of_modules_to_exclude:

                if isinstance(module_ref, (_QuantizedModuleProtocol, QcQuantizeRecurrent)):
                    orig_module = module_ref.get_original_module()
                elif isinstance(module_ref, QcQuantizeStandAloneBase):
                    orig_module = torch.nn.Identity()
                else:
                    orig_module = None

                if orig_module:
                    setattr(starting_module, module_name, orig_module)
                    module_ref = orig_module

            # Recursively call children modules if present
            if not utils.is_leaf_module(module_ref):
                cls._remove_quantization_wrappers(module_ref, list_of_modules_to_exclude)

    @classmethod
    @torch.no_grad()
    def _apply_qdq_to_model_parameters(cls, model: torch.nn.Module):
        """
        Applies quant-dequant to the parameters of a PyTorch model
        to avoid rounding error during weight quantization.

        :param model: The PyTorch model whose parameters will be quant-dequantized.
"""# pylint: disable=protected-accessformoduleinmodel.modules():ifisinstance(module,(QcQuantizeRecurrent,StaticGridQuantWrapper)):withutils.in_eval_mode(module):module._quantize_dequantize_params()elifisinstance(module,(LearnedGridQuantWrapper)):withutils.in_eval_mode(module):module._quantize_params()cls._update_parameters_by_attr(module._module_to_wrap)defnamed_qmodules(self):"""Generator that yields all quantized modules in the model and their names """forname,moduleinself.model.named_modules():ifisinstance(module,(QcQuantizeWrapper,QcQuantizeRecurrent,LazyQuantizeWrapper)):yieldname,modulequant_wrappers=named_qmodules@staticmethoddef_replace_quantization_wrapper_with_native_torch_quantization_nodes(quant_sim_model,device:torch.device):""" Recursively remove quantization wrappers from all appropriate modules starting with a given module :param quant_sim_model: model for which QcQuantizeWrapper gets replaced with wrapped module using native torch quantization nodes :param device: device on which model is present :return: """# Recursively replace quantization wrappers to native torch quantization nodesformodule_name,module_refinquant_sim_model.named_children():# Create a native torch quantization nodeifisinstance(module_ref,QcQuantizeWrapper):embedded_module=NativeTorchQuantWrapper(module_ref,'_module_to_wrap',device)setattr(quant_sim_model,module_name,embedded_module)elifisinstance(module_ref,QcQuantizeRecurrent):logger.error('Do not support save model embedded native torch quantization nodes using QcQuantizeRecurrent.')raiseAssertionError# Recursively call children modules if presentifnotutils.is_leaf_module(module_ref):QuantizationSimModel._replace_quantization_wrapper_with_native_torch_quantization_nodes(module_ref,device)@classmethoddefsave_model_with_embedded_quantization_nodes(cls,sim_model,path:str,filename_prefix:str,dummy_input:Union[torch.Tensor,Tuple],onnx_export_args:Optional[Union[OnnxExportApiArgs,Dict]]=None,export_to_torchscript:bool=False,is_conditional:bool=False):""" Export model embedded with native torch quantization nodes. These nodes will be exported as default onnx or torch script quantized nodes. :param sim_model: model with the quantsim wrappers :param path: path where to store model pth and encodings :param filename_prefix: Prefix to use for filenames of the model pth and encodings files :param dummy_input: Dummy input to the model. Used to parse model graph :param onnx_export_args: optional export argument with onnx specific overrides if not provide export via torchscript graph. Int16 can only be exported by torchscript :param export_to_torchscript: If True, export to torchscript. Export to onnx otherwise. Defaults to False. 
        :param is_conditional: True if model is conditional, False otherwise
        :return:
        """

        def _validate_torchquantizer(quant_sim_model):
            # Prevent TorchQuantizers with bitwidths other than 8 from being exported to ONNX
            for _, module in quant_sim_model.named_modules():
                if isinstance(module, NativeTorchQuantWrapper):
                    quantizers = module.input_quantizers + module.output_quantizers
                    if 'weight' in module.param_quantizers:
                        quantizers += [module.param_quantizers['weight']]
                    if 'bias' in module.param_quantizers:
                        quantizers += [module.param_quantizers['bias']]

                    for quantizer in quantizers:
                        if quantizer.enabled and quantizer.data_type == QuantizationDataType.int \
                                and quantizer.bitwidth != 8:
                            raise ValueError('Only 8 bit quantizers are supported by exporting to ONNX model. '
                                             'Please enable export_to_torchscript if you want to export non 8 bit '
                                             'quantizers.')

        model_filename = filename_prefix + '_embedded' + '.onnx'
        model_path = os.path.join(path, model_filename)
        quant_sim_model = copy.deepcopy(sim_model)

        device = utils.get_device(quant_sim_model)
        if isinstance(dummy_input, torch.Tensor):
            dummy_input = dummy_input.to(device)
        else:
            dummy_input = tuple(input.to(device) for input in dummy_input)
        QuantizationSimModel._replace_quantization_wrapper_with_native_torch_quantization_nodes(quant_sim_model,
                                                                                                device)

        if export_to_torchscript:
            with utils.in_eval_mode(quant_sim_model), torch.no_grad():
                trace = torch.jit.trace(quant_sim_model, dummy_input)
                ts_path = os.path.join(path, filename_prefix + '_embedded' + '.torchscript.pth')
                trace.save(ts_path)
        else:
            _validate_torchquantizer(quant_sim_model)
            # pylint: disable=protected-access
            OnnxSaver._export_model_to_onnx(quant_sim_model, dummy_input, model_path, is_conditional, onnx_export_args)
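
# For illustration: exporting a simulated model with native torch quantization nodes embedded,
# via the classmethod above. A minimal sketch; `sim`, `dummy_input`, and the output path are
# assumptions:
#
#   QuantizationSimModel.save_model_with_embedded_quantization_nodes(
#       sim.model, path='./export', filename_prefix='model',
#       dummy_input=dummy_input, export_to_torchscript=True)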
def load_encodings_to_sim(quant_sim_model: _QuantizationSimModelBase, pytorch_encoding_path: str):
    """
    Loads the saved encodings to quant sim model. The encoding filename to load should end in _torch.encodings,
    generated as part of quantsim export.

    :param quant_sim_model: Quantized model to load encodings for. Note: The model configuration should be the same
        as when encodings were exported.
    :param pytorch_encoding_path: Path of the encodings file to load.
    """
    for module in quant_sim_model.model.modules():
        if isinstance(module, QcQuantizeWrapper):
            module.set_mode(QcQuantizeOpMode.ACTIVE)

    quant_sim_model.load_encodings(pytorch_encoding_path,
                                   strict=True,
                                   partial=False,
                                   requires_grad=None,
                                   allow_overwrite=None)

    if isinstance(quant_sim_model, QuantizationSimModel):
        # Only for V1 quantsim
        quant_sim_model.replace_wrappers_for_quantize_dequantize()


def compute_encodings_for_sims(sim_list: List[QuantizationSimModel], forward_pass_callback: Callable,
                               forward_pass_callback_args: Any):
    """
    Compute encodings for a list of QuantSims.

    :param sim_list: List of QuantSims to compute encodings for.
    :param forward_pass_callback: A callback function that simply runs forward passes on the models. This callback
        function should use representative data for the forward pass, so the calculated encodings work for all
        data samples. This callback internally chooses the number of data samples it wants to use for calculating
        encodings.
        The callback expects exactly two inputs:
            - List of models which are involved in the forward pass. The models are taken directly from calling
              sim.model for each sim in sim_list, passed in the same order in which the sims appear in sim_list.
            - Forward pass callback args
    :param forward_pass_callback_args: These argument(s) are passed to the forward_pass_callback as-is. Up to
        the user to determine the type of this parameter. E.g. could be simply an integer representing the number
        of data samples to use. Or could be a tuple of parameters or an object representing something more complex.
        If set to None, forward_pass_callback will be invoked with no parameters.
    """
    ctx_managers = [torch.no_grad()]
    for sim in sim_list:
        ctx_managers.append(utils.in_eval_mode(sim.model))
        QuantizationSimModel.prepare_sim_for_compute_encodings(sim)

    with contextlib.ExitStack() as stack:
        for mgr in ctx_managers:
            stack.enter_context(mgr)
        _ = forward_pass_callback([sim.model for sim in sim_list], forward_pass_callback_args)

    for sim in sim_list:
        QuantizationSimModel.compute_layer_encodings_for_sim(sim)
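
# For illustration: computing encodings for two sims whose models run in sequence. A minimal
# sketch; `sim_1`, `sim_2`, and `data_loader` are hypothetical user-side objects (gradients
# and eval mode are already handled inside compute_encodings_for_sims):
#
#   def forward_pass(models, args):
#       model_1, model_2 = models
#       for inputs, _ in data_loader:
#           model_2(model_1(inputs))
#
#   compute_encodings_for_sims([sim_1, sim_2], forward_pass, forward_pass_callback_args=None)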