AdaScale¶
Note
This feature is currently experimental. The API may change in the future.
Context¶
AdaScale is a post-training quantization (PTQ) technique that recovers accuracy lost during INT4 weight quantization without any fine-tuning. It works by learning optimal per-weight scaling parameters through Blockwise Knowledge Distillation (BKD): the quantized output of each transformer block is optimized to match its FP32 equivalent until the two converge.
AdaScale is based on FlexRound and integrates learnable weight clipping from OmniQuant.
A block is a transformer decoder layer that accepts a single activation tensor as input and produces a single activation tensor as output. For all supported model families, decoder layers are contiguous by default so no special configuration is required.
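To make the BKD objective concrete, here is a minimal, self-contained sketch of the idea (not AIMET's implementation): a single linear layer stands in for a transformer block, a learnable per-channel scale stands in for AdaScale's FlexRound-style parameters, and a straight-through estimator lets gradients reach the scale through the rounding step while the block's quantized output is optimized to match its FP32 output.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# One transformer "block" is stood in for by a single linear layer.
fp_block = nn.Linear(8, 8)

# Learnable per-output-channel scale on the weight (a simplified stand-in
# for AdaScale's FlexRound-style learnable scaling parameters).
scale = nn.Parameter(torch.ones(8, 1))
opt = torch.optim.Adam([scale], lr=1e-2)

def fake_quant_ste(w, bits=4):
    # Symmetric fake-quantization to the INT4 grid, with a straight-through
    # estimator so gradients can flow back to the learnable scale.
    qmax = 2 ** (bits - 1) - 1
    step = w.abs().max().detach() / qmax
    q = torch.round(w / step).clamp(-qmax - 1, qmax) * step
    return w + (q - w).detach()

x = torch.randn(64, 8)
with torch.no_grad():
    target = fp_block(x)  # FP32 block output: the distillation target

losses = []
for _ in range(300):
    opt.zero_grad()
    # Scale the weight, fake-quantize, then undo the scale (FlexRound-style).
    w_q = fake_quant_ste(fp_block.weight * scale) / scale
    out = F.linear(x, w_q, fp_block.bias)
    loss = F.mse_loss(out, target)  # blockwise distillation loss
    loss.backward()
    opt.step()
    losses.append(loss.item())
```

Only the scale is optimized; the underlying weights and the FP32 target stay fixed, mirroring how BKD runs per block until the quantized and FP32 outputs converge.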
Workflow¶
Prerequisites¶
To use AdaScale, you need:
A pre-trained model loaded from HuggingFace. Supported model families:
Llama, Qwen2, Mistral, Phi3, Qwen3. PyTorch additionally supports vision-language models: Qwen2.5-VL, Qwen3-VL.
ONNX only: the model must be exported to ONNX with the input naming convention required by AdaScale; Step 2 handles this.
Note
For a complete working example including all steps below, see
Examples/torch/quantize.py
or
Examples/onnx/quantize.py
(run with --recipe pcq_spinquant_adascale).
Procedure¶
Step 1: Load model¶
Load the HuggingFace model and wrap it with ONNXExportableModuleWithCache to enable JIT tracing
with a static graph — required for both the PyTorch and ONNX workflows.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from GenAITests.shared.models.utils.model_utils import ONNXExportableModuleWithCache
SEQUENCE_LENGTH = 2048
CONTEXT_LENGTH = 4096
model_id = "meta-llama/Llama-3.2-1B-Instruct"
hf_model = AutoModelForCausalLM.from_pretrained(model_id, dtype=torch.float32)
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, trust_remote_code=True)
# Wrap model to satisfy static graph constraints for JIT trace
traceable_model = ONNXExportableModuleWithCache(hf_model)
Step 2: Create QuantizationSimModel¶
Create a QuantizationSimModel with the desired quantization configuration.
For ONNX, this step also exports the model to ONNX first — the ONNX tab includes the
torch.onnx.export call that produces the correctly named inputs required by AdaScale.
from aimet_torch.common.defs import QuantScheme
from aimet_torch import QuantizationSimModel
from GenAITests.shared.models.base import LLM
from GenAITests.shared.models.generator import Generator
assembled_dummy_inputs = Generator.prepare_inputs(
    model=traceable_model,
    input_ids=torch.zeros((1, SEQUENCE_LENGTH), dtype=torch.int),
    attention_mask=torch.ones((1, SEQUENCE_LENGTH), dtype=torch.int),
    past_key_values=[],
    context_length=CONTEXT_LENGTH,
    sequence_length=SEQUENCE_LENGTH,
)
quantsim = QuantizationSimModel(
    model=traceable_model,
    quant_scheme=QuantScheme.post_training_tf,
    dummy_input=assembled_dummy_inputs,
    default_output_bw=16,
    default_param_bw=4,
    in_place=True,
    config_file="htp_v73",
)
generator = Generator(quantsim.model, tokenizer, SEQUENCE_LENGTH, CONTEXT_LENGTH)
import os
import tempfile
import onnx
from aimet_onnx.quantsim import QuantizationSimModel
from GenAITests.shared.models.base import LLM
from GenAITests.shared.models.generator import Generator
from GenAITests.onnx.models.utils.torch_onnx_interface import TorchONNXInterface
from GenAITests.onnx.models.utils.quantsim_utils import (
    _set_tensors_to_output_n_bit_symmmetric,
    _tie_quantizers_for_kv_cache,
    _set_lm_head_to_8b,
)
assembled_dummy_inputs = Generator.prepare_inputs(
    model=traceable_model,
    input_ids=torch.zeros((1, SEQUENCE_LENGTH), dtype=torch.int),
    attention_mask=torch.ones((1, SEQUENCE_LENGTH), dtype=torch.int),
    past_key_values=[],
    context_length=CONTEXT_LENGTH,
    sequence_length=SEQUENCE_LENGTH,
)
# Export to ONNX using LLM.get_backbone_input_names to produce the input naming
# convention required by AdaScale (input_ids, attention_mask, position_ids,
# past_key_0_in, past_value_0_in, ...)
with tempfile.TemporaryDirectory() as tmpdir:
    torch.onnx.export(
        traceable_model,
        assembled_dummy_inputs,
        os.path.join(tmpdir, "model.onnx"),
        input_names=LLM.get_backbone_input_names(hf_model.config.num_hidden_layers),
        output_names=LLM.get_backbone_output_names(hf_model.config.num_hidden_layers),
        opset_version=17,
        dynamo=False,
    )
    onnx_model = onnx.load(os.path.join(tmpdir, "model.onnx"))
quantsim = QuantizationSimModel(
    model=onnx_model,
    quant_scheme="min_max",
    default_activation_bw=16,
    default_param_bw=4,
    config_file="htp_v73",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
# Setting kv_cache and some other layers to 8-bit
_set_tensors_to_output_n_bit_symmmetric(quantsim, kv_bits=8)
_set_lm_head_to_8b(quantsim)
_tie_quantizers_for_kv_cache(quantsim)
quantsim_with_torch_interface = TorchONNXInterface(quantsim, hf_model.config)
generator = Generator(quantsim_with_torch_interface, tokenizer, SEQUENCE_LENGTH, CONTEXT_LENGTH)
Step 3: Apply AdaScale¶
Apply AdaScale to find optimal weight quantization encodings for each supported block.
_prefill_inputs collects the full model inputs (including KV cache tensors) from the calibration
dataset; AdaScale derives per-block activations internally and uses them for BKD.
This is the most time-consuming step; expect 2–6 hours depending on model size and iteration count
(see the timing column in Quantization recipes for LLMs).
ADASCALE_NUM_BATCHES and ADASCALE_NUM_ITERATIONS trade accuracy against runtime. The values
below are validated per model size; if your model is not listed, start from the closest row.
See Quantization recipes for LLMs for full results.
| Model | ADASCALE_NUM_BATCHES | ADASCALE_NUM_ITERATIONS |
|---|---|---|
| Qwen/Qwen2.5-0.5B-Instruct | 128 | 2048 |
| meta-llama/Llama-3.2-1B-Instruct | 128 | 2048 |
| Qwen/Qwen2.5-1.5B-Instruct | 128 | 1024 |
| meta-llama/Llama-3.2-3B-Instruct | 128 | 1024 |
| Qwen/Qwen3-4B | 128 | 512 |
| microsoft/Phi-3.5-mini-instruct | 32 | 256 |
from aimet_torch.experimental.adascale.adascale_optimizer import apply_adascale
from GenAITests.shared.helpers.datasets import Wikitext
from GenAITests.torch.helpers.quant_recipes import _prefill_inputs
ADASCALE_NUM_BATCHES = 128 # reduce for larger models to control runtime
ADASCALE_NUM_ITERATIONS = 2048 # reduce for larger models; see quantization recipes
train_dataset = Wikitext.load_encoded_dataset(tokenizer, CONTEXT_LENGTH, "train")
prefilled_inputs = _prefill_inputs(
    generator, train_dataset, num_batches=ADASCALE_NUM_BATCHES, device=torch.device("cpu")
)
apply_adascale(
    qsim=quantsim,
    data_loader=prefilled_inputs,
    num_iterations=ADASCALE_NUM_ITERATIONS,
)
from aimet_onnx.experimental.adascale.adascale_optimizer import (
    AdaScale,
    adascale_model_config_dict,
)
from GenAITests.shared.helpers.datasets import Wikitext
from GenAITests.onnx.helpers.quant_recipes import _prefill_inputs
ADASCALE_NUM_BATCHES = 128  # reduce for larger models to control runtime
ADASCALE_NUM_ITERATIONS = 2048  # reduce for larger models; see quantization recipes
train_dataset = Wikitext.load_encoded_dataset(tokenizer, CONTEXT_LENGTH, "train")
prefilled_inputs = _prefill_inputs(
    quantsim, generator, train_dataset, num_batches=ADASCALE_NUM_BATCHES
)
AdaScale.apply_adascale(
    quantsim,
    prefilled_inputs,
    adascale_model_config=adascale_model_config_dict[generator.config.model_type],
    num_iterations=ADASCALE_NUM_ITERATIONS,
)
Step 4: Compute activation encodings¶
AdaScale optimizes weight encodings only. This step calibrates the remaining activation quantizers by running the model through a representative dataset.
import itertools
from tqdm import tqdm
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def calibration_callback(model):
    for batch in tqdm(
        itertools.islice(train_dataset, 20), total=20, desc="Calibrating"
    ):
        generator(input_ids=batch["input_ids"].to(device=device))
quantsim.compute_encodings(calibration_callback)
from tqdm import tqdm
calib_inputs = _prefill_inputs(quantsim, generator, train_dataset, num_batches=20)
def _forward(session, _):
    for batch in tqdm(calib_inputs, total=len(calib_inputs), desc="Calibrating"):
        session.run(None, batch)
quantsim.compute_encodings(_forward, tuple())
After completing these steps, export the quantized model:
PyTorch: quantsim.onnx.export(...)
ONNX: quantsim.export(...)
See Examples/torch/quantize.py or Examples/onnx/quantize.py for the export invocation.
API¶
Top level APIs
- aimet_torch.experimental.adascale.apply_adascale(qsim, data_loader, forward_fn=None, num_iterations=1500)¶
- Parameters:
- qsim (QuantizationSimModel) – Quantization Sim model
- data_loader (DataLoader) – DataLoader object to load the input data
- forward_fn (Optional[Callable[[Module, Any], Any]]) – forward function to run the forward pass of the model
- num_iterations (int) – Number of iterations to optimize for during AdaScale BKD
Note that forward_fn must take exactly two arguments: (1) the model, and (2) the object returned by the data loader, whether that is a tensor, a tuple of tensors, a dict, or any other type.
The forward_fn should prepare the input sample as needed and call the model's forward pass at the very end; it should not run any evaluation or construct a full dataloader inside the function.
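For illustration, a minimal forward_fn satisfying this contract might look like the following; the dict-style batch and the stand-in embedding model are assumptions for the sketch, not part of the AIMET API.

```python
import torch

# A hypothetical forward_fn with the required (model, batch) signature.
# It unpacks one object yielded by the DataLoader (a dict here, by
# assumption), prepares the inputs, and runs a single forward pass.
def forward_fn(model, batch):
    input_ids = batch["input_ids"]
    with torch.no_grad():
        model(input_ids)

# Usage with a stand-in model: an embedding layer accepting token ids.
dummy_model = torch.nn.Embedding(100, 16)
forward_fn(dummy_model, {"input_ids": torch.tensor([[1, 2, 3]])})
```

The function deliberately returns nothing: AdaScale only needs the forward pass to run so it can capture per-block activations.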
- Example usage:
>>> model = DummyModel()
>>> dummy_input = ...
>>> data_set = DataSet(dummy_input)
>>> data_loader = DataLoader(data_set, ...)
>>> sim = QuantizationSimModel(model, dummy_input)
>>> apply_adascale(sim, data_loader, forward_fn=forward_fn, num_iterations=1500)
>>> sim.compute_encodings(...)
>>> sim.export(...)
- apply_adascale modifies the model's weights in place.
- compute_encodings should not be called before the apply_adascale call.
- Activation quantizers remain uninitialized throughout this feature, so the user must call compute_encodings afterwards; this ensures activation encodings are computed with the updated weights taken into account.
Warning: This feature is currently considered experimental pending API changes
Top level APIs
- aimet_onnx.experimental.adascale.adascale_optimizer.apply_adascale(sim, inputs, adascale_model_config, num_iterations=1500)¶
- Parameters:
- sim (QuantizationSimModel) – Quantization Sim model
- inputs (Collection[Dict[str, np.ndarray]]) – The set of input samples to use during optimization
- adascale_model_config (AdaScaleModelConfig) – AdaScale model config. There are pre-defined configs for Llama, Qwen2, Mistral, Qwen3, and Phi3; for other models, construct an AdaScaleModelConfig
- num_iterations (int) – Number of iterations to optimize for during AdaScale
- Example usage:
>>> model = DummyModel()
>>> inputs = ...
>>> adascale_model_config = adascale_model_config_dict['llama']
>>> sim = QuantizationSimModel(model)
>>> apply_adascale(sim, inputs, adascale_model_config, num_iterations=num_iterations)
>>> sim.compute_encodings(...)
>>> sim.export(...)
- apply_adascale modifies the model's weights in place.
- compute_encodings should not be called before the apply_adascale call.
- Activation quantizers remain uninitialized throughout this feature, so the user must call compute_encodings afterwards; this ensures activation encodings are computed with the updated weights taken into account.
Warning: This feature is currently considered experimental pending API changes