# Source code for aimet_tensorflow.auto_quant

# -*- mode: python -*-
# =============================================================================
#  @@-COPYRIGHT-START-@@
#
#  Copyright (c) 2022, Qualcomm Innovation Center, Inc. All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  1. Redistributions of source code must retain the above copyright notice,
#     this list of conditions and the following disclaimer.
#
#  2. Redistributions in binary form must reproduce the above copyright notice,
#     this list of conditions and the following disclaimer in the documentation
#     and/or other materials provided with the distribution.
#
#  3. Neither the name of the copyright holder nor the names of its contributors
#     may be used to endorse or promote products derived from this software
#     without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
#  POSSIBILITY OF SUCH DAMAGE.
#
#  SPDX-License-Identifier: BSD-3-Clause
#
#  @@-COPYRIGHT-END-@@
# =============================================================================

"""Automatic Post-Training Quantization"""
import contextlib
from dataclasses import dataclass
import os
from typing import Any, Callable, Dict, List, Optional, Tuple
import tensorflow as tf
from tqdm import tqdm

import jinja2
from bokeh.resources import CDN

from aimet_tensorflow.adaround.adaround_weight import Adaround, AdaroundParameters
from aimet_tensorflow.cross_layer_equalization import equalize_model
from aimet_tensorflow.batch_norm_fold import fold_all_batch_norms
from aimet_tensorflow.quantsim import QuantizationSimModel
from aimet_tensorflow.utils.graph_saver import load_model_from_meta
from aimet_tensorflow.utils.common import (
    create_input_feed_dict,
    deepcopy_tf_session,
    iterate_tf_dataset,
)
from aimet_tensorflow.cache import TfSessionSerializationProtocol

from aimet_common.auto_quant import Diagnostics
from aimet_common.cache import Cache
from aimet_common.defs import QuantScheme
from aimet_common.utils import AimetLogger, Spinner
from aimet_common.quantsim import validate_quantsim_inputs


# This module builds and runs models through tf.compat.v1 sessions and graphs,
# so eager execution is switched off at import time.
# NOTE: disable_eager_execution() is a global TensorFlow setting; it also
# affects the program that imports this module.
tf.compat.v1.disable_eager_execution()

_logger = AimetLogger.get_area_logger(AimetLogger.LogAreas.AutoQuant)


# Module-level cache shared by the AutoQuant stages decorated/wrapped with
# `cache.mark(...)`; AutoQuant enables it per run via `cache.enable(cache_dir)`.
cache = Cache()


# The number of samples to be used for performance evaluation.
# NOTE: None means "all". This value is forwarded as the second argument of
# the user's `eval_callback`.
NUM_SAMPLES_FOR_PERFORMANCE_EVALUATION = None


class AutoQuant:
    """
    Integrate and apply post-training quantization techniques.

    AutoQuant includes 1) batchnorm folding, 2) cross-layer equalization,
    and 3) Adaround.
    These techniques will be applied in a best-effort manner until the model
    meets the evaluation goal given as allowed_accuracy_drop.
    """

    def __init__( # pylint: disable=too-many-arguments
            self,
            allowed_accuracy_drop: float,
            unlabeled_dataset: tf.compat.v1.data.Dataset,
            eval_callback: Callable[[tf.compat.v1.Session, Optional[int]], float],
            default_param_bw: int = 8,
            default_output_bw: int = 8,
            default_quant_scheme: QuantScheme = QuantScheme.post_training_tf_enhanced,
            default_rounding_mode: str = 'nearest',
            default_config_file: Optional[str] = None,
    ) -> None:
        """
        :param allowed_accuracy_drop: Maximum allowed accuracy drop. Must be non-negative.
        :param unlabeled_dataset: An unlabeled dataset for encoding computation.
                By default, this dataset will be also used for Adaround unless
                otherwise specified by `self.set_adaround_params`.
        :param eval_callback: A function that maps a tf session and the number of samples
                to the evaluation score. This callback is expected to return a
                scalar value representing the model performance evaluated
                against exactly `N` samples, where `N` is the number of samples
                passed as the second argument of this callback.
                NOTE: If `N` is None, the model is expected to be evaluated against
                the whole evaluation dataset.
        :param default_param_bw: Default bitwidth (4-31) to use for quantizing layer parameters.
        :param default_output_bw: Default bitwidth (4-31) to use for quantizing layer
                inputs and outputs.
        :param default_quant_scheme: Quantization scheme. Supported values are
                QuantScheme.post_training_tf or QuantScheme.post_training_tf_enhanced.
        :param default_rounding_mode: Rounding mode. Supported options are 'nearest' or 'stochastic'
        :param default_config_file: Path to configuration file for model quantizers
        :raises ValueError: If `allowed_accuracy_drop` is negative.
        """
        # Zero is a valid value (no accuracy drop allowed), so the check is `< 0`
        # and the message says "non-negative".
        if allowed_accuracy_drop < 0:
            raise ValueError(
                "`allowed_accuracy_drop` must be a non-negative value. Got {:.2f}"
                .format(allowed_accuracy_drop)
            )

        validate_quantsim_inputs(default_quant_scheme,
                                 default_rounding_mode,
                                 default_output_bw,
                                 default_param_bw)

        self.allowed_accuracy_drop = allowed_accuracy_drop
        self.eval_callback = eval_callback
        self.default_param_bw = default_param_bw
        self.default_output_bw = default_output_bw
        self.default_quant_scheme = default_quant_scheme
        self.default_rounding_mode = default_rounding_mode
        self.default_config_file = default_config_file

        self._unlabeled_dataset = unlabeled_dataset
        # Number of batches in `unlabeled_dataset`. Unknown (None) until the
        # forward pass in `_create_quantsim_and_encodings` has counted it.
        self._unlabled_dataset_length = None

        self._adaround_params = None

    @property
    def adaround_params(self):
        """
        Returns the adaround parameters.

        If parameters were set explicitly via `set_adaround_params`, those are
        returned. Otherwise, default parameters built from `unlabeled_dataset`
        are returned once the dataset length is known; before that, None.
        """
        # If adaround_params is manually set, return it.
        if self._adaround_params is not None:
            return self._adaround_params

        # Otherwise, return the default adaround params if the length of the
        # dataset is known.
        if self._unlabled_dataset_length is not None:
            return AdaroundParameters(self._unlabeled_dataset,
                                      self._unlabled_dataset_length)

        return None

    def _evaluate_model_performance(self, sess: tf.compat.v1.Session) -> float:
        """
        Evaluate the model performance.

        :param sess: tf.Session associated with the model to evaluate.
        :return: Evaluation score.
        """
        return self.eval_callback(sess, NUM_SAMPLES_FOR_PERFORMANCE_EVALUATION)

    def set_adaround_params(self, adaround_params: AdaroundParameters) -> None:
        """
        Set Adaround parameters.
        If this method is not called explicitly by the user, AutoQuant will use
        `unlabeled_dataset` (passed to `__init__`) for Adaround.

        :param adaround_params: Adaround parameters.
        """
        self._adaround_params = adaround_params

    def _create_quantsim_and_encodings( # pylint: disable=too-many-arguments
            self,
            sess: tf.compat.v1.Session,
            starting_op_names: List[str],
            output_op_names: List[str],
            quant_scheme: QuantScheme = None,
            rounding_mode: str = None,
            default_output_bw: int = None,
            default_param_bw: int = None,
            config_file: str = None,
            encoding_path: str = None,
    ) -> QuantizationSimModel:
        """
        Create a QuantizationSimModel and compute encoding. If `encoding_path` is not None,
        it is prioritized over other arguments (`default_output_bw`, `default_param_bw`, ...).

        NOTE: Input session is not mutated.

        :param sess: The input model as session to add quantize ops to.
        :param starting_op_names: List of starting op names of the model.
        :param output_op_names: List of output op names of the model.
        :param quant_scheme: Quantization scheme. Defaults to self.default_quant_scheme.
        :param rounding_mode: Rounding mode. Defaults to self.default_rounding_mode.
        :param default_output_bw: Default bitwidth (4-31) to use for quantizing layer inputs
                and outputs. Defaults to self.default_output_bw.
        :param default_param_bw: Default bitwidth (4-31) to use for quantizing layer parameters.
                Defaults to self.default_param_bw.
        :param config_file: Path to configuration file for model quantizers.
                Defaults to self.default_config_file.
        :param encoding_path: Path to parameter encodings file.
        :return: Quantsim model.
        """
        kwargs = dict(
            quant_scheme=(quant_scheme or self.default_quant_scheme),
            rounding_mode=(rounding_mode or self.default_rounding_mode),
            default_output_bw=(default_output_bw or self.default_output_bw),
            default_param_bw=(default_param_bw or self.default_param_bw),
            config_file=(config_file or self.default_config_file),
        )

        with deepcopy_tf_session(sess) as sess: # pylint: disable=redefined-argument-from-local
            sim = QuantizationSimModel(sess, starting_op_names, output_op_names, **kwargs)

        if encoding_path:
            sim.set_and_freeze_param_encodings(encoding_path)

        def forward_pass_callback(sess: tf.compat.v1.Session, _: Any = None):
            output_ops = [
                sess.graph.get_operation_by_name(op_name)
                for op_name in output_op_names
            ]

            count = 0
            iterator = iterate_tf_dataset(self._unlabeled_dataset)
            for inputs in tqdm(iterator, total=self._unlabled_dataset_length):
                feed_dict = create_input_feed_dict(sess.graph, starting_op_names, inputs)
                sess.run(output_ops, feed_dict=feed_dict)
                count += 1

            # Remember the dataset length so default Adaround parameters can
            # be derived (see the `adaround_params` property).
            self._unlabled_dataset_length = count

        sim.compute_encodings(forward_pass_callback, None)

        return sim

    def _apply_batchnorm_folding( # pylint: disable=no-self-use
            self,
            sess: tf.compat.v1.Session,
            starting_op_names: List[str],
            output_op_names: List[str],
    ) -> Tuple[tf.compat.v1.Session, List[Tuple[tf.Operation, tf.Operation]]]:
        """
        Apply batchnorm folding.

        NOTE: Input session is not mutated.

        :param sess: tf.Session associated with the model to apply batchnorm folding.
        :param starting_op_names: List of starting op names of the model.
        :param output_op_names: List of output op names of the model.
        :return: Output session and folded pairs.
        """
        # NOTE: We don't apply caching to batchnorm folding because caching is
        #       likely going to have an adverse effect on the performance.
        #       Since a tf.Operation contains a reference to the graph it belongs
        #       to, serializing a subset of operations of a tf.Graph requires
        #       serializing the whole graph, making the serialization cost very
        #       likely to exceed the evaluation cost.
        with deepcopy_tf_session(sess) as sess: # pylint: disable=redefined-argument-from-local
            return fold_all_batch_norms(sess, starting_op_names, output_op_names)

    @cache.mark("cle", TfSessionSerializationProtocol())
    def _apply_cross_layer_equalization( # pylint: disable=no-self-use
            self,
            sess: tf.compat.v1.Session,
            starting_op_names: List[str],
            output_op_names: List[str],
    ) -> tf.compat.v1.Session:
        """
        Apply cross-layer equalization.

        NOTE: Input session is not mutated.

        :param sess: tf.Session associated with the model to apply cle.
        :param starting_op_names: List of starting op names of the model.
        :param output_op_names: List of output op names of the model.
        :return: Output session.
        """
        with deepcopy_tf_session(sess) as sess: # pylint: disable=redefined-argument-from-local
            return equalize_model(sess, starting_op_names, output_op_names)

    def _apply_adaround(
            self,
            sess: tf.compat.v1.Session,
            starting_op_names: List[str],
            output_op_names: List[str],
            results_dir: str,
    ) -> Tuple[tf.compat.v1.Session, str]:
        """
        Apply adaround.

        :param sess: tf.Session associated with the model to apply adaround.
        :param starting_op_names: List of starting op names of the model.
        :param output_op_names: List of output op names of the model.
        :param results_dir: Directory to save the results of AdaRound.
        :return: Output session and the path to the parameter encoding file.
        :raises RuntimeError: If no Adaround parameters are available.
        """
        # NOTE: We dont need to make a deepcopy of model here, since Adaround.apply_adaround
        # internally creates and returns a deepcopy of model.

        # Evaluate the property once: it builds a fresh AdaroundParameters
        # object on every access when the defaults are used.
        adaround_params = self.adaround_params
        if adaround_params is None:
            raise RuntimeError(
                "Adaround parameters are not available. Call `set_adaround_params` "
                "or make sure the unlabeled dataset has been iterated at least once "
                "(e.g. by computing quantsim encodings) so that default Adaround "
                "parameters can be derived."
            )

        filename_prefix = "adaround"
        adaround_encoding_path = os.path.join(results_dir,
                                              "{}.encodings".format(filename_prefix))
        _apply_adaround_cached =\
            cache.mark("adaround", TfSessionSerializationProtocol())\
            (Adaround.apply_adaround)

        sess = _apply_adaround_cached(sess,
                                      starting_op_names,
                                      output_op_names,
                                      adaround_params,
                                      path=results_dir,
                                      filename_prefix=filename_prefix,
                                      default_param_bw=self.default_param_bw,
                                      default_quant_scheme=self.default_quant_scheme,
                                      default_config_file=self.default_config_file)

        return sess, adaround_encoding_path

    def apply(
            self,
            fp32_sess: tf.compat.v1.Session,
            starting_op_names: List[str],
            output_op_names: List[str],
            results_dir: str = "/tmp",
            cache_id: Optional[str] = None,
    ) -> Tuple[tf.compat.v1.Session, float, str]:
        """
        Apply post-training quantization techniques.

        :param fp32_sess: tf.Session associated with the model to apply PTQ techniques.
        :param starting_op_names: List of starting op names of the model.
        :param output_op_names: List of output op names of the model.
        :param results_dir: Directory to save the results.
        :param cache_id: Identifier of the cache. If given, results of the cached
                stages are stored under `<results_dir>/.auto_quant_cache/<cache_id>`.
                If None, no cache directory is passed to the cache.
        :return: Tuple of (best session, eval score, encoding path).
        """
        result = self._apply_helper(self._auto_quant_main,
                                    fp32_sess,
                                    starting_op_names,
                                    output_op_names,
                                    results_dir,
                                    cache_id)
        return result["model"],\
               result["accuracy"],\
               result["encoding_path"]

    def _apply_helper(
            self,
            auto_quant_main_fn: Callable,
            fp32_sess: tf.compat.v1.Session,
            starting_op_names: List[str],
            output_op_names: List[str],
            results_dir: str = "/tmp",
            cache_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """
        Helper for self.apply().

        :param auto_quant_main_fn: Function that runs the PTQ stages and returns
                the best result as a dictionary (normally self._auto_quant_main).
        :param fp32_sess: tf.Session associated with the model to apply PTQ techniques.
        :param starting_op_names: List of starting op names of the model.
        :param output_op_names: List of output op names of the model.
        :param results_dir: Directory to save the results.
        :param cache_id: Identifier of the cache; see `apply`.
        :return: The best ptq result as a dictionary.
        """
        results_dir = os.path.abspath(results_dir)
        os.makedirs(results_dir, exist_ok=True)

        if cache_id is None:
            cache_dir = None
        else:
            cache_dir = os.path.join(results_dir, ".auto_quant_cache", cache_id)

        with cache.enable(cache_dir):
            _logger.info("Starting AutoQuant")

            fp32_acc = self._evaluate_model_performance(fp32_sess)
            target_acc = fp32_acc - self.allowed_accuracy_drop

            _logger.info("Target eval score: %f", target_acc)
            _logger.info("FP32 eval score (W32A32): %f", fp32_acc)

            eval_manager = _EvalManager(
                quantsim_factory=self._create_quantsim_and_encodings,
                eval_func=self._evaluate_model_performance,
                starting_op_names=starting_op_names,
                output_op_names=output_op_names,
                results_dir=results_dir,
            )

            # Export diagnostics even if one of the stages raises, so the
            # results collected up to the failure can still be inspected.
            try:
                ret = auto_quant_main_fn(fp32_sess, target_acc, starting_op_names,
                                         output_op_names, eval_manager, results_dir)

                acc = ret["accuracy"]
                _logger.info("Best eval score: %f", acc)

                if acc < target_acc:
                    _logger.info(
                        "AutoQuant is unable to match the target accuracy. "
                        "Consider Quantization Aware Training."
                    )
            finally:
                eval_manager.export_diagnostics()

            return ret

    def _auto_quant_main(
            self,
            fp32_sess: tf.compat.v1.Session,
            target_acc: float,
            starting_op_names: List[str],
            output_op_names: List[str],
            eval_manager: "_EvalManager",
            results_dir: str = "/tmp",
    ) -> Dict[str, Any]:
        """
        Helper function of apply().

        Runs the analysis sessions, then batchnorm folding, cross-layer
        equalization and AdaRound in order, returning early as soon as the
        best result so far reaches `target_acc`.

        :param fp32_sess: Model to apply PTQ techniques.
        :param target_acc: Target eval score.
        :param starting_op_names: List of starting op names of the model.
        :param output_op_names: List of output op names of the model.
        :param eval_manager: _Evalmanager object.
        :param results_dir: Directory to save the results.
        :return: The best ptq result as a dictionary.
        """
        # Sensitivity analysis: quantize only weights (activations kept at 32 bit).
        with eval_manager.analysis_session("Weight Quantization Sensitivity") as s:
            acc = s.eval(fp32_sess, default_output_bw=32)
            s.diagnostics.add(
                f"Weight-quantized eval score (W{self.default_param_bw}A32): {acc:f}"
            )

        # Sensitivity analysis: quantize only activations (weights kept at 32 bit).
        with eval_manager.analysis_session("Activation Quantization Sensitivity") as s:
            acc = s.eval(fp32_sess, default_param_bw=32)
            s.diagnostics.add(
                f"Activation-quantized eval score (W32A{self.default_output_bw}): {acc:f}"
            )

        # Batchnorm Folding
        with eval_manager.ptq_session("Batchnorm Folding") as s:
            sess, folded_pairs = self._apply_batchnorm_folding(fp32_sess,
                                                               starting_op_names,
                                                               output_op_names)
            for conv, bn in folded_pairs:
                s.diagnostics.add(f"{conv} was merged with {bn}.")
            s.set_ptq_result(sess=sess, applied_techniques=["batchnorm_folding"])

        best_result = eval_manager.get_best_ptq_result()
        if best_result.accuracy >= target_acc:
            return best_result.as_dict()

        # Cross-Layer Equalization (applied to the original FP32 session)
        with eval_manager.ptq_session("Cross-Layer Equalization") as s:
            sess = self._apply_cross_layer_equalization(fp32_sess,
                                                        starting_op_names,
                                                        output_op_names)
            s.set_ptq_result(sess=sess, applied_techniques=["cross_layer_equalization"])

        best_result = eval_manager.get_best_ptq_result()
        if best_result.accuracy >= target_acc:
            return best_result.as_dict()

        # AdaRound, applied on top of the best result so far
        with eval_manager.ptq_session("AdaRound") as s:
            sess, encoding_path = self._apply_adaround(best_result.load_model(),
                                                       starting_op_names,
                                                       output_op_names,
                                                       results_dir)
            s.set_ptq_result(sess=sess, encoding_path=encoding_path,
                             applied_techniques=[*best_result.applied_techniques, "adaround"])

        return eval_manager.get_best_ptq_result().as_dict()
@dataclass
class PtqResult:
    """
    Evaluation results.

    :param meta_path: Path to the serialized meta graph (`<checkpoint_path>.meta`).
    :param checkpoint_path: Path prefix of the serialized model checkpoint.
    :param encoding_path: Path to the encoding file.
    :param accuracy: Accuracy of the model.
    :param applied_techniques: Names of the PTQ techniques applied to the model.
    """
    meta_path: str
    checkpoint_path: str
    encoding_path: str
    accuracy: float
    applied_techniques: List[str]

    def load_model(self) -> tf.compat.v1.Session:
        """
        Load model.

        :return: Loaded model.
        """
        return load_model_from_meta(self.meta_path, self.checkpoint_path)

    def as_dict(self):
        """Convert to dictionary"""
        return dict(model=self.load_model(),
                    accuracy=self.accuracy,
                    encoding_path=self.encoding_path,
                    applied_techniques=self.applied_techniques)


class _EvalManager:
    """
    Evaluation manager for AutoQuant.
    """
    def __init__(self,
                 quantsim_factory: Callable,
                 eval_func: Callable[[tf.compat.v1.Session], float],
                 starting_op_names: List[str],
                 output_op_names: List[str],
                 results_dir: str):
        """
        :param quantsim_factory: A factory function that returns QuantizationSimModel.
        :param eval_func: Evaluation function.
        :param starting_op_names: List of starting op names of the model.
        :param output_op_names: List of output op names of the model.
        :param results_dir: Base directory to save the temporary serialized model.
        """
        self._quantsim_factory = quantsim_factory
        self._eval_func = eval_func
        self._starting_op_names = starting_op_names
        self._output_op_names = output_op_names
        self._results_dir = results_dir

        os.makedirs(self._results_dir, exist_ok=True)

        self._all_sessions: List[_EvalSession] = []
        self._ptq_sessions: List[_PtqSession] = []

    def get_best_ptq_result(self) -> PtqResult:
        """
        Get the results with the highest evaluation score among the ptq results evaluated so far.

        :return: The best evaluation result so far.
        :raises RuntimeError: If no PTQ session has been created yet.
        """
        if not self._ptq_sessions:
            raise RuntimeError("No PTQ session has been run yet.")

        ptq_results = [sess.ptq_result for sess in self._ptq_sessions]
        return max(ptq_results, key=lambda ptq_result: ptq_result.accuracy)

    def analysis_session(self, title: str) -> "_EvalSession":
        """
        Return a session for analysis only.

        :param title: Title of the session.
        :return: Analysis session.
        """
        return self._get_session(title, _EvalSession)

    def ptq_session(self, title: str) -> "_PtqSession":
        """
        Return a session for a PTQ technique. Its result is taken into account
        by `get_best_ptq_result`.

        :param title: Title of the session.
        :return: PTQ session.
        """
        sess = self._get_session(title, _PtqSession)
        self._ptq_sessions.append(sess)
        return sess

    def _get_session(self, title: str, session_cls: type):
        """
        Session factory.

        :param title: Title of the session.
        :param session_cls: Class of the session.
        :return: Session object.
        """
        session = session_cls(title,
                              self._quantsim_factory,
                              self._eval_func,
                              self._starting_op_names,
                              self._output_op_names,
                              results_dir=os.path.join(self._results_dir, ".trace"))
        self._all_sessions.append(session)
        return session

    def export_diagnostics(self) -> str:
        """
        Export diagnostics in html format.

        Writes `diagnostics.html` into the results directory.

        :return: Diagnostics string in html format.
        """
        loader = jinja2.FileSystemLoader(os.path.dirname(os.path.abspath(__file__)))
        env = jinja2.Environment(loader=loader)
        template = env.get_template("auto_quant_diagnostics_template.html")

        # Only pull in the bokeh CDN resources when some session produced a bokeh figure.
        if any(sess.diagnostics.contains_bokeh() for sess in self._all_sessions):
            head = CDN.render()
        else:
            head = ""

        body = {
            sess.title: sess.diagnostics
            for sess in self._all_sessions
            if not sess.diagnostics.is_empty()
        }

        html = template.render(head=head, body=body)
        filename = os.path.join(self._results_dir, "diagnostics.html")
        # Explicit encoding so the output does not depend on the platform locale.
        with open(filename, "w", encoding="utf-8") as f:
            f.write(html)
        return html


class _EvalSession:
    """
    Evaluation session for AutoQuant.

    Each session object contains a title and diagnostics produced during the session.
    The collected diagnostics will be exported into a html file by _EvalManager.
    """
    def __init__(
            self,
            title: str,
            quantsim_factory: Callable,
            eval_func: Callable[[tf.compat.v1.Session], float],
            starting_op_names: List[str],
            output_op_names: List[str],
            results_dir: str
    ):
        """
        :param title: Title of the session.
        :param quantsim_factory: A factory function that returns QuantizationSimModel.
        :param eval_func: Evaluation function.
        :param starting_op_names: List of starting op names of the model.
        :param output_op_names: List of output op names of the model.
        :param results_dir: Base directory to save the temporary serialized model.
        """
        self._title = title
        self._quantsim_factory = quantsim_factory
        self._eval_func = eval_func
        self._starting_op_names = starting_op_names
        self._output_op_names = output_op_names
        self._results_dir = results_dir
        self._spinner = None

        os.makedirs(self._results_dir, exist_ok=True)

        self._diagnostics = Diagnostics()

        # Map session title to file name.
        # e.g. title: "Cross-Layer Equalization" -> filename: "cross_layer_equalization"
        self._filename = self._title.lower().replace("-", " ")
        self._filename = "_".join(self._filename.split())

    @property
    def title(self):
        """Getter of self._title."""
        return self._title

    @property
    def diagnostics(self):
        """Getter of self._diagnostics."""
        return self._diagnostics

    def eval(self, sess: tf.compat.v1.Session, **kwargs):
        """
        Evaluate the model.

        :param sess: tf.Session associated with the model to evaluate.
        :param **kwargs: Additional arguments to the quantsim factory.
        :return: Eval score.
        """
        sim = self._quantsim_factory(sess,
                                     self._starting_op_names,
                                     self._output_op_names,
                                     **kwargs)
        acc = self._eval_func(sim.session)
        return acc

    def __enter__(self):
        self._spinner = Spinner(self._title)
        self._spinner.__enter__()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        if self._spinner is not None:
            self._spinner.__exit__(exc_type, exc_val, exc_tb)
        # Returning False makes Python re-raise any exception from the
        # with-block unchanged; no need to raise it manually here.
        return False


class _PtqSession(_EvalSession):
    """
    PTQ session.

    Each PTQ session object should call `set_ptq_result` exactly once
    inside a with-as block.
    """
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._ptq_result = None

    @property
    def ptq_result(self) -> PtqResult:
        """
        Getter of self._ptq_result.

        :raises RuntimeError: If `set_ptq_result` has not been called yet.
        """
        if self._ptq_result is None:
            raise RuntimeError(
                f"PTQ result of session '{self._title}' is not available: "
                "set_ptq_result() has not been called."
            )
        return self._ptq_result

    def set_ptq_result(
            self,
            applied_techniques: List[str],
            sess: tf.compat.v1.Session = None,
            sim: QuantizationSimModel = None,
            acc: float = None,
            **kwargs
    ) -> None:
        """
        Set the result of PTQ. Should be called exactly once inside a with-as block.

        Exactly one among sess and (sim, acc) pair should be specified.
        1) If sim and acc is specified, save them as the result of this session.
        2) If sess is specified, evaluate the quantized accuracy of the model and save the result.

        :param applied_techniques: Names of the PTQ techniques applied to the model.
        :param sess: Result of PTQ.
        :param sim: Result of PTQ. The quantization encoding (compute_encodings()) is
                    assumed to have been computed in advance.
        :param acc: Eval score.
        :param **kwargs: Additional arguments to the quantsim factory.
        :return: None
        """
        if sim is None:
            assert acc is None
            assert sess is not None
            sim = self._quantsim_factory(sess,
                                         self._starting_op_names,
                                         self._output_op_names,
                                         **kwargs)
            acc = self._eval_func(sim.session)
        else:
            assert acc is not None
            assert sess is None

        self._set_ptq_result(sim, acc, applied_techniques)

    def _set_ptq_result(
            self,
            sim: QuantizationSimModel,
            acc: float,
            applied_techniques: List[str],
    ) -> PtqResult:
        """
        Set the result of PTQ. Should be called exactly once inside a with-as block.

        :param sim: Result of PTQ. The quantization encoding (compute_encodings()) is
                    assumed to have been computed in advance.
        :param acc: Eval score.
        :param applied_techniques: Names of the PTQ techniques applied to the model.
        :return: PtqResult object.
        :raises RuntimeError: If a result has already been set for this session.
        """
        if self._ptq_result is not None:
            raise RuntimeError(
                "set_ptq_result() can be called only once per each _PtqSession instance."
            )

        meta_path, checkpoint_path, encoding_path = self._export(sim)
        self._ptq_result = PtqResult(
            meta_path=meta_path,
            checkpoint_path=checkpoint_path,
            encoding_path=encoding_path,
            accuracy=acc,
            applied_techniques=applied_techniques,
        )
        return self._ptq_result

    def _export(self, sim: QuantizationSimModel) -> Tuple[str, str, str]:
        """
        Export quantsim.

        :param sim: QuantizationSimModel object to export.
        :return: The paths where model and encoding are saved
        """
        sim.export(path=self._results_dir, filename_prefix=self._filename)
        checkpoint_path = os.path.join(self._results_dir, self._filename)
        meta_path = f"{checkpoint_path}.meta"
        encoding_path = f"{checkpoint_path}.encodings"
        _logger.info("The results of %s is saved in %s, %s, and %s.",
                     self._title, checkpoint_path, meta_path, encoding_path)
        return meta_path, checkpoint_path, encoding_path

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Raises error if set_ptq_result is not called."""
        super().__exit__(exc_type, exc_val, exc_tb)

        # If the with-block raised, let that exception propagate untouched
        # instead of masking it with the "missing result" error below.
        if exc_type is None:
            if self._ptq_result is None:
                raise RuntimeError(
                    f"set_ptq_result() was not called in session '{self._title}'."
                )

            _logger.info("Session finished: %s. (eval score: %f)",
                         self._title, self._ptq_result.accuracy)
        return False


@contextlib.contextmanager
def spy_auto_quant(auto_quant: AutoQuant):
    """
    Install a spy that collects the handles to the ptq result of
    each stage of AutoQuant.

    Typical usage::
        >>> auto_quant = AutoQuant(...)
        ... with spy_auto_quant(auto_quant) as spy:
        ...     _ = auto_quant.apply(...)
        ...
        ... for result in spy.get_all_ptq_results():
        ...     print(result.applied_techniques)
        ...     print(result.accuracy)
        ...     print(result.encoding_path)
        ...     model = result.load_model()
        ...     ...
    """
    # pylint: disable=protected-access
    class Spy:
        """
        Spy that collects the handles to the ptq result of
        each stage of AutoQuant.
        """
        def __init__(self):
            self._eval_manager = None

        def get_all_ptq_results(self) -> List[PtqResult]:
            """Return handles to the results of AutoQuant"""
            if self._eval_manager is None:
                return []
            return [sess.ptq_result for sess in self._eval_manager._ptq_sessions]

    spy = Spy()

    # Remember whether `_auto_quant_main` was already overridden on the instance
    # so that exactly that state is restored when the context exits.
    had_instance_override = "_auto_quant_main" in vars(auto_quant)
    _auto_quant_main = auto_quant._auto_quant_main

    def _auto_quant_main_wrapper(fp32_sess, target_acc, starting_op_names,
                                 output_op_names, eval_manager, results_dir="/tmp"):
        spy._eval_manager = eval_manager
        return _auto_quant_main(fp32_sess, target_acc, starting_op_names,
                                output_op_names, eval_manager, results_dir)

    try:
        setattr(auto_quant, "_auto_quant_main", _auto_quant_main_wrapper)
        yield spy
    finally:
        if had_instance_override:
            setattr(auto_quant, "_auto_quant_main", _auto_quant_main)
        else:
            # Drop the instance attribute so lookups fall back to the class
            # method again (avoids keeping a bound-method reference cycle).
            delattr(auto_quant, "_auto_quant_main")