AdaScale

Note

This feature is currently experimental. The API may change in the future.

Context

AdaScale is a post-training quantization (PTQ) technique that recovers accuracy lost during INT4 weight quantization without any fine-tuning. It works by learning optimal per-weight scaling parameters through Blockwise Knowledge Distillation (BKD): the quantized output of each transformer block is optimized to match its FP32 equivalent until the two converge.

AdaScale is based on FlexRound and integrates learnable weight clipping from OmniQuant.

A block is a transformer decoder layer that accepts a single activation tensor as input and produces a single activation tensor as output. For all supported model families, decoder layers are contiguous by default so no special configuration is required.
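The core idea can be shown with a toy example (pure Python, not the AIMET implementation): for INT4 symmetric quantization, a scale learned against the block's output error can beat the usual absmax heuristic, which is the intuition behind learning scaling parameters via BKD.

```python
# Toy illustration of learned quantization scales (not the AIMET implementation).
W = [0.02, -0.31, 0.07, 1.9, -0.45, 0.12, -0.08, 0.6]   # FP32 weights of one layer
X = [0.5, -1.2, 0.3, 0.8, -0.7, 1.1, -0.4, 0.9]          # a calibration activation

def quantize_int4(w, scale):
    # Symmetric INT4: integer grid [-8, 7], one shared scale.
    return [max(-8, min(7, round(v / scale))) * scale for v in w]

def output_error(scale):
    # Squared error of the dot-product output: the blockwise
    # distillation objective in miniature (quantized vs. FP32 output).
    wq = quantize_int4(W, scale)
    y_fp = sum(w * x for w, x in zip(W, X))
    y_q = sum(w * x for w, x in zip(wq, X))
    return (y_fp - y_q) ** 2

base = max(abs(v) for v in W) / 7                        # absmax heuristic
candidates = [base * (0.5 + 0.01 * k) for k in range(71)] + [base]
best = min(candidates, key=output_error)                 # "learned" scale
assert output_error(best) <= output_error(base)
```

AdaScale performs this optimization per block with gradient descent over far richer parameterizations (FlexRound-style per-weight scales plus OmniQuant-style clipping), but the objective is the same: match the FP32 block output.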

Workflow

Prerequisites

To use AdaScale, you need:

  • A pre-trained model loaded from HuggingFace. Supported model families: Llama, Qwen2, Mistral, Phi3, Qwen3. PyTorch additionally supports vision-language models: Qwen2.5-VL, Qwen3-VL.

  • ONNX only: the model must be exported to ONNX with the input naming convention required by AdaScale — Step 2 handles this.

Note

For a complete working example including all steps below, see Examples/torch/quantize.py or Examples/onnx/quantize.py (run with --recipe pcq_spinquant_adascale).

Procedure

Step 1: Load model

Load the HuggingFace model and wrap it with ONNXExportableModuleWithCache to enable JIT tracing with a static graph — required for both the PyTorch and ONNX workflows.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from GenAITests.shared.models.utils.model_utils import ONNXExportableModuleWithCache

SEQUENCE_LENGTH = 2048
CONTEXT_LENGTH = 4096

model_id = "meta-llama/Llama-3.2-1B-Instruct"
hf_model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float32)
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, trust_remote_code=True)

# Wrap model to satisfy static graph constraints for JIT trace
traceable_model = ONNXExportableModuleWithCache(hf_model)

Step 2: Create QuantizationSimModel

Create a QuantizationSimModel with the desired quantization configuration. For ONNX, this step also exports the model to ONNX first — the ONNX tab includes the torch.onnx.export call that produces the correctly named inputs required by AdaScale.

PyTorch:

from aimet_torch.common.defs import QuantScheme
from aimet_torch import QuantizationSimModel
from GenAITests.shared.models.base import LLM
from GenAITests.shared.models.generator import Generator

assembled_dummy_inputs = Generator.prepare_inputs(
    model=traceable_model,
    input_ids=torch.zeros((1, SEQUENCE_LENGTH), dtype=torch.int),
    attention_mask=torch.ones((1, SEQUENCE_LENGTH), dtype=torch.int),
    past_key_values=[],
    context_length=CONTEXT_LENGTH,
    sequence_length=SEQUENCE_LENGTH,
)

quantsim = QuantizationSimModel(
    model=traceable_model,
    quant_scheme=QuantScheme.post_training_tf,
    dummy_input=assembled_dummy_inputs,
    default_output_bw=16,
    default_param_bw=4,
    in_place=True,
    config_file="htp_v73",
)

generator = Generator(quantsim.model, tokenizer, SEQUENCE_LENGTH, CONTEXT_LENGTH)
ONNX:

import os
import tempfile
import onnx
from aimet_onnx.quantsim import QuantizationSimModel
from GenAITests.shared.models.base import LLM
from GenAITests.shared.models.generator import Generator
from GenAITests.onnx.models.utils.torch_onnx_interface import TorchONNXInterface
from GenAITests.onnx.models.utils.quantsim_utils import (
    _set_tensors_to_output_n_bit_symmmetric,
    _tie_quantizers_for_kv_cache,
    _set_lm_head_to_8b,
)

assembled_dummy_inputs = Generator.prepare_inputs(
    model=traceable_model,
    input_ids=torch.zeros((1, SEQUENCE_LENGTH), dtype=torch.int),
    attention_mask=torch.ones((1, SEQUENCE_LENGTH), dtype=torch.int),
    past_key_values=[],
    context_length=CONTEXT_LENGTH,
    sequence_length=SEQUENCE_LENGTH,
)

# Export to ONNX using LLM.get_backbone_input_names to produce the input naming
# convention required by AdaScale (input_ids, attention_mask, position_ids,
# past_key_0_in, past_value_0_in, ...)
with tempfile.TemporaryDirectory() as tmpdir:
    torch.onnx.export(
        traceable_model,
        assembled_dummy_inputs,
        os.path.join(tmpdir, "model.onnx"),
        input_names=LLM.get_backbone_input_names(hf_model.config.num_hidden_layers),
        output_names=LLM.get_backbone_output_names(hf_model.config.num_hidden_layers),
        opset_version=17,
        dynamo=False,
    )
    onnx_model = onnx.load(os.path.join(tmpdir, "model.onnx"))

quantsim = QuantizationSimModel(
    model=onnx_model,
    quant_scheme="min_max",
    default_activation_bw=16,
    default_param_bw=4,
    config_file="htp_v73",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
# Setting kv_cache and some other layers to 8-bit
_set_tensors_to_output_n_bit_symmmetric(quantsim, kv_bits=8)
_set_lm_head_to_8b(quantsim)
_tie_quantizers_for_kv_cache(quantsim)

quantsim_with_torch_interface = TorchONNXInterface(quantsim, hf_model.config)
generator = Generator(quantsim_with_torch_interface, tokenizer, SEQUENCE_LENGTH, CONTEXT_LENGTH)
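The input naming pattern required by AdaScale (described in the export comment above) can be illustrated with a hypothetical re-implementation. The real helper is LLM.get_backbone_input_names; this sketch only reproduces the pattern for clarity and is not part of the library.

```python
def backbone_input_names(num_hidden_layers):
    # Illustrative only: the naming convention AdaScale expects is
    # input_ids, attention_mask, position_ids, then one
    # past_key_{i}_in / past_value_{i}_in pair per decoder layer.
    names = ["input_ids", "attention_mask", "position_ids"]
    for i in range(num_hidden_layers):
        names.append(f"past_key_{i}_in")
        names.append(f"past_value_{i}_in")
    return names

names = backbone_input_names(2)  # 3 base inputs + 2 KV tensors per layer
```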

Step 3: Apply AdaScale

Apply AdaScale to find optimal weight quantization encodings for each supported block.

_prefill_inputs collects the full model inputs (including KV cache tensors) from the calibration dataset; AdaScale derives per-block activations internally and uses them for BKD. This is the most time-consuming step; expect 2–6 hours depending on model size and iteration count (see the timing column in Quantization recipes for LLMs).

ADASCALE_NUM_BATCHES and ADASCALE_NUM_ITERATIONS trade accuracy against runtime. The values below are validated per model size; if your model is not listed, start from the closest row. See Quantization recipes for LLMs for full results.

Model                               num_batches   num_iterations
Qwen/Qwen2.5-0.5B-Instruct          128           2048
meta-llama/Llama-3.2-1B-Instruct    128           2048
Qwen/Qwen2.5-1.5B-Instruct          128           1024
meta-llama/Llama-3.2-3B-Instruct    128           1024
Qwen/Qwen3-4B                       128           512
microsoft/Phi-3.5-mini-instruct     32            256
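In scripts that cover several models, the validated settings above can be kept as a simple lookup. This dict is a convenience sketch built from the table, not part of the AIMET API.

```python
# (num_batches, num_iterations) pairs validated per model (from the table above).
ADASCALE_RECIPES = {
    "Qwen/Qwen2.5-0.5B-Instruct": (128, 2048),
    "meta-llama/Llama-3.2-1B-Instruct": (128, 2048),
    "Qwen/Qwen2.5-1.5B-Instruct": (128, 1024),
    "meta-llama/Llama-3.2-3B-Instruct": (128, 1024),
    "Qwen/Qwen3-4B": (128, 512),
    "microsoft/Phi-3.5-mini-instruct": (32, 256),
}

ADASCALE_NUM_BATCHES, ADASCALE_NUM_ITERATIONS = ADASCALE_RECIPES[
    "meta-llama/Llama-3.2-1B-Instruct"
]
```

For a model not listed here, start from the row closest in parameter count and adjust from there.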

PyTorch:

from aimet_torch.experimental.adascale.adascale_optimizer import apply_adascale
from GenAITests.shared.helpers.datasets import Wikitext
from GenAITests.torch.helpers.quant_recipes import _prefill_inputs

ADASCALE_NUM_BATCHES = 128   # reduce for larger models to control runtime
ADASCALE_NUM_ITERATIONS = 2048  # reduce for larger models; see quantization recipes

train_dataset = Wikitext.load_encoded_dataset(tokenizer, CONTEXT_LENGTH, "train")
prefilled_inputs = _prefill_inputs(
    generator, train_dataset, num_batches=ADASCALE_NUM_BATCHES, device=torch.device("cpu")
)

apply_adascale(
    qsim=quantsim,
    data_loader=prefilled_inputs,
    num_iterations=ADASCALE_NUM_ITERATIONS,
)
ONNX:

from aimet_onnx.experimental.adascale.adascale_optimizer import (
    AdaScale,
    adascale_model_config_dict,
)
from GenAITests.shared.helpers.datasets import Wikitext
from GenAITests.onnx.helpers.quant_recipes import _prefill_inputs

ADASCALE_NUM_BATCHES = 128   # reduce for larger models to control runtime
ADASCALE_NUM_ITERATIONS = 2048  # reduce for larger models; see quantization recipes

train_dataset = Wikitext.load_encoded_dataset(tokenizer, CONTEXT_LENGTH, "train")
prefilled_inputs = _prefill_inputs(
    quantsim, generator, train_dataset, num_batches=ADASCALE_NUM_BATCHES
)

AdaScale.apply_adascale(
    quantsim,
    prefilled_inputs,
    adascale_model_config=adascale_model_config_dict[generator.config.model_type],
    num_iterations=ADASCALE_NUM_ITERATIONS,
)

Step 4: Compute activation encodings

AdaScale optimizes weight encodings only. This step calibrates the remaining activation quantizers by running a representative dataset through the model.

PyTorch:

import itertools
from tqdm import tqdm

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def calibration_callback(model):
    # `model` is the sim model passed in by compute_encodings; `generator`
    # already wraps it, so the batches are fed through `generator` directly.
    for batch in tqdm(
        itertools.islice(train_dataset, 20), total=20, desc="Calibrating"
    ):
        generator(input_ids=batch["input_ids"].to(device=device))


quantsim.compute_encodings(calibration_callback)
ONNX:

from tqdm import tqdm

calib_inputs = _prefill_inputs(quantsim, generator, train_dataset, num_batches=20)


def _forward(session, _):
    for batch in tqdm(calib_inputs, total=len(calib_inputs), desc="Calibrating"):
        session.run(None, batch)


quantsim.compute_encodings(_forward, tuple())

After completing these steps, export the quantized model:

  • PyTorch: quantsim.onnx.export(...)

  • ONNX: quantsim.export(...)

See Examples/torch/quantize.py or Examples/onnx/quantize.py for the export invocation.

API

Top level APIs (PyTorch)

aimet_torch.experimental.adascale.apply_adascale(qsim, data_loader, forward_fn=None, num_iterations=1500)
Parameters:
  • qsim (QuantizationSimModel) – Quantization Sim model

  • data_loader (DataLoader) – DataLoader object to load the input data

  • forward_fn (Optional[Callable[[Module, Any], Any]]) – forward function to run the forward pass of the model

  • num_iterations (int) – Number of iterations to optimize for during AdaScale BKD

Note that the forward_fn should take exactly two arguments: 1) the model, and 2) the object returned by the dataloader, regardless of whether that object is a tensor, a tuple of tensors, a dict, etc.

The forward_fn should prepare the input sample as needed and invoke the model's forward pass at the very end. It should not run any kind of evaluation or construct a full dataloader inside the function.
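For example, assuming each DataLoader batch is a dict of tensors (an assumption for illustration; adapt the unpacking to your batch structure), a conforming forward_fn could look like:

```python
def forward_fn(model, batch):
    # Contract: exactly two arguments -- the model and one DataLoader batch.
    # Unpack the batch (assumed here to be a dict of tensors) and end with a
    # single forward pass; no evaluation, no dataloader construction inside.
    return model(batch["input_ids"])
```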

Example usage:
>>> model = DummyModel()
>>> dummy_input = ...
>>> data_set = DataSet(dummy_input)
>>> data_loader = DataLoader(data_set, ...)
>>> sim = QuantizationSimModel(model, dummy_input)
>>> apply_adascale(sim, data_loader, forward_fn=forward_fn, num_iterations=1500)
>>> sim.compute_encodings(...)
>>> sim.export(...)
  1. apply_adascale modifies the weights in-place in the model

  2. compute_encodings should not be called before the apply_adascale call

  3. Activation quantizers remain uninitialized throughout this feature, so the user must call compute_encodings afterwards. This ensures activation encodings are computed with the updated weights taken into account.

Warning: This feature is currently considered experimental pending API changes

Top level APIs (ONNX)

aimet_onnx.experimental.adascale.adascale_optimizer.apply_adascale(sim, inputs, adascale_model_config, num_iterations=1500)
Parameters:
  • sim (QuantizationSimModel) – Quantization Sim model

  • inputs (Collection[Dict[str, np.ndarray]]) – The set of input samples to use during optimization.

  • adascale_model_config (AdaScaleModelConfig) – AdaScale model config. Pre-defined configs exist for Llama, Qwen2, Mistral, Qwen3, and Phi3; for other models, construct an AdaScaleModelConfig manually.

  • num_iterations (int) – Number of iterations to optimize for during AdaScale

Example usage:
>>> model = DummyModel()
>>> inputs = ...
>>> adascale_model_config = adascale_model_config_dict['llama']
>>> sim = QuantizationSimModel(model)
>>> apply_adascale(sim, inputs, adascale_model_config, num_iterations=num_iterations)
>>> sim.compute_encodings(...)
>>> sim.export(...)
  1. apply_adascale modifies the weights in-place in the model

  2. compute_encodings should not be called before the apply_adascale call

  3. Activation quantizers remain uninitialized throughout this feature, so the user must call compute_encodings afterwards. This ensures activation encodings are computed with the updated weights taken into account.

Warning: This feature is currently considered experimental pending API changes

class aimet_onnx.experimental.adascale.adascale_optimizer.AdaScaleModelConfig(model_type, beta_gamma_lr=0.001, scales_lr=0.0005)[source]