# -*- mode: python -*-
# =============================================================================
#  @@-COPYRIGHT-START-@@
#
#  Copyright (c) 2024, Qualcomm Innovation Center, Inc. All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  1. Redistributions of source code must retain the above copyright notice,
#     this list of conditions and the following disclaimer.
#
#  2. Redistributions in binary form must reproduce the above copyright notice,
#     this list of conditions and the following disclaimer in the documentation
#     and/or other materials provided with the distribution.
#
#  3. Neither the name of the copyright holder nor the names of its contributors
#     may be used to endorse or promote products derived from this software
#     without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
#  POSSIBILITY OF SUCH DAMAGE.
#
#  SPDX-License-Identifier: BSD-3-Clause
#
#  @@-COPYRIGHT-END-@@
# =============================================================================
""" Quant Analyzer for AIMET v2"""

import os
import contextlib
from collections import namedtuple
from typing import Tuple, List, Type, Optional, Generator

import torch

from aimet_common.quant_analyzer import export_stats_histogram_plot
from aimet_torch._base.quant_analyzer import QuantAnalyzerBase
from aimet_torch.v2.quantsim import QuantizationSimModel
from aimet_torch.v2.nn.base import BaseQuantizationMixin
from aimet_torch.v2.quantization.base import QuantizerBase
from aimet_torch.v2.quantization.encoding_analyzer import _HistogramObserver, _Histogram
from aimet_torch.v2.batch_norm_fold import fold_all_batch_norms

# Lightweight (min, max) record; built per encoding entry by
# QuantAnalyzer._get_quantizer_encodings and handed to the histogram exporter.
V1Encoding = namedtuple('V1Encoding', ['min', 'max'])
class QuantAnalyzer(QuantAnalyzerBase):
    """
    QuantAnalyzer tool provides

     1) model sensitivity to weight and activation quantization
     2) per layer sensitivity analysis
     3) per layer encoding (min - max range)
     4) per PDF analysis and
     5) per layer MSE analysis
    """

    @staticmethod
    def _get_quantsim_cls() -> Type[QuantizationSimModel]:
        """
        Return the quantization simulation class used by this analyzer.

        :return: QuantizationSimModel class (v2).
        """
        return QuantizationSimModel

    @staticmethod
    def _get_quant_wrapper_type() -> Tuple[Type]:
        """
        Return the module types that are treated as quantized modules.

        :return: One-element tuple holding BaseQuantizationMixin.
        """
        return (BaseQuantizationMixin,)

    # pylint: disable=no-self-use
    def _create_and_export_stats_histogram_plot(
        self,
        quantizer: QuantizerBase,
        results_dir: str,
        title: str,
    ):
        """
        For given quantizer, create and export histogram (PDF) of statistics in html format.

        One plot is exported per (histogram, encoding) pair; the pair's index
        is appended to the title as ``{title}_{index}``.

        :param quantizer: Quantizer.
        :param results_dir: Directory to save the results.
        :param title: Title of the plot.
        """
        os.makedirs(results_dir, exist_ok=True)

        assert isinstance(quantizer.encoding_analyzer.observer, _HistogramObserver)
        v2_histograms = quantizer.encoding_analyzer.observer.get_stats()
        histograms = self._convert_to_v1_histograms(v2_histograms)
        # NOTE(review): _get_quantizer_encodings returns None when
        # quantizer.get_encodings() is falsy; zip() below would then raise
        # TypeError. Presumably callers compute encodings first - confirm.
        encodings = self._get_quantizer_encodings(quantizer)

        for index, (histogram, encoding) in enumerate(zip(histograms, encodings)):
            export_stats_histogram_plot(
                histogram, encoding, results_dir, title=f"{title}_{index}"
            )

    @staticmethod
    def _enable_disable_quantizers(quantizers: List[QuantizerBase], enabled: bool):
        """
        For given list of quantizers, set (enable/disable) quantizer's enabled.

        Not supported here: this implementation always raises.

        :param quantizers: List of quantizers.
        :param enabled: Enabled flag.
        :raises RuntimeError: Always.
        """
        raise RuntimeError("Changing enabled attribute is not allowed in quantsim v2")

    @classmethod
    def _disable_param_quantizers(cls, sim: QuantizationSimModel):
        """
        Enter ``_remove_param_quantizers()`` on every quantized module of sim.

        :param sim: Quantization simulation model.
        :return: ExitStack owning all entered contexts; closing it (or leaving
         a ``with`` block on it) exits them in reverse order.
        """
        # pylint: disable=protected-access
        # The guarding ``with`` unwinds contexts already entered if a later
        # enter_context() raises, so no module is left half-modified.
        with contextlib.ExitStack() as stack:
            for quant_wrapper in cls._get_quantized_modules(sim):
                stack.enter_context(quant_wrapper._remove_param_quantizers())
            # All contexts entered: transfer ownership to the caller.
            return stack.pop_all()

    @classmethod
    def _disable_activation_quantizers(cls, sim: QuantizationSimModel):
        """
        Enter ``_remove_activation_quantizers()`` on every quantized module of sim.

        :param sim: Quantization simulation model.
        :return: ExitStack owning all entered contexts; closing it (or leaving
         a ``with`` block on it) exits them in reverse order.
        """
        # pylint: disable=protected-access
        # The guarding ``with`` unwinds contexts already entered if a later
        # enter_context() raises, so no module is left half-modified.
        with contextlib.ExitStack() as stack:
            for quant_wrapper in cls._get_quantized_modules(sim):
                stack.enter_context(quant_wrapper._remove_activation_quantizers())
            # All contexts entered: transfer ownership to the caller.
            return stack.pop_all()

    @staticmethod
    def _disable_quant_wrapper(module: BaseQuantizationMixin):
        """
        Return the result of ``module._remove_all_quantizers()``.

        :param module: Quantized module.
        :return: Whatever ``_remove_all_quantizers()`` returns (presumably a
         context manager, like the per-kind variants - confirm).
        """
        # pylint: disable=protected-access
        return module._remove_all_quantizers()

    @staticmethod
    def _convert_to_v1_histograms(histograms: List[_Histogram]) -> List:
        """
        Convert v2 histograms into lists of (bin_edge, normalized_count) tuples.

        Each count is divided by the histogram's total count. zip() pairs edges
        with counts and stops at the shorter sequence.

        :param histograms: List of v2 histogram objects.
        :return: One list of (bin_edge, count / total) tuples per histogram.
        """
        v1_histograms = []
        for hist in histograms:
            assert hist is not None, "Cannot find histogram data in quantizer"
            hist_sum = torch.sum(hist.histogram).item()
            v1_histograms.append(
                [
                    (bin_edge.item(), hist_value.item() / hist_sum)
                    for bin_edge, hist_value in zip(hist.bin_edges, hist.histogram)
                ]
            )
        return v1_histograms

    @staticmethod
    def _is_quantizer_enabled(quantizer: Optional[QuantizerBase]):
        """
        Report whether a quantizer is enabled, i.e. whether it is not None.

        :param quantizer: Quantizer or None.
        :return: True if quantizer is not None, False otherwise.
        """
        return quantizer is not None

    @classmethod
    def _get_quantizer_encodings(cls, quantizer: QuantizerBase) -> Optional[List]:
        """
        Flatten the quantizer's encoding min/max into a list of V1Encoding.

        :param quantizer: Quantizer.
        :return: List of V1Encoding(min, max), or None when
         ``quantizer.get_encodings()`` is falsy.
        """
        encoding = quantizer.get_encodings()
        if not encoding:
            return None

        flatten_min = encoding.min.flatten()
        flatten_max = encoding.max.flatten()
        return [
            V1Encoding(min=encoding_min.item(), max=encoding_max.item())
            for encoding_min, encoding_max in zip(flatten_min, flatten_max)
        ]

    @staticmethod
    def _get_quantized_modules(
        sim: QuantizationSimModel,
    ) -> Generator[BaseQuantizationMixin, None, None]:
        """
        Yield every module of ``sim.model`` that is a BaseQuantizationMixin.

        :param sim: Quantization simulation model.
        :return: Generator over the quantized modules.
        """
        for module in sim.model.modules():
            if isinstance(module, BaseQuantizationMixin):
                yield module

    @staticmethod
    def _fold_all_batch_norms(*args, **kwargs):
        """
        Forward all arguments to ``fold_all_batch_norms`` and return its result.
        """
        return fold_all_batch_norms(*args, **kwargs)