Quantization recipes for LLMs

This document presents the quantization and evaluation workflow for large language models (LLMs) using aimet-torch and aimet-onnx. The objective is to communicate performance expectations through two reference recipes applied to the following LLMs: meta-llama/Llama-3.2-1B-Instruct, meta-llama/Llama-3.2-3B-Instruct, Qwen/Qwen2.5-0.5B-Instruct, Qwen/Qwen2.5-1.5B-Instruct, Qwen/Qwen3-4B, and microsoft/Phi-3.5-mini-instruct.

The exported artifacts from these recipes are not directly compatible with Qualcomm® AI Engine Direct (QAIRT) and require additional adaptation steps before deployment on the target hardware. Refer to the model adaptation guide for details.

System Requirements

The quantization process requires a machine with:

  • Operating System: Linux

  • Hardware: CUDA-enabled GPU

GPU Memory Requirements:

  • Minimum: 40GB VRAM

Recipes

We present two recipes for INT4-weight, INT16-activation quantization, using combinations of Post-Training Quantization (PTQ) techniques available in aimet-torch and aimet-onnx:

  1. PCQ + SpinQuant + AdaScale
    • Per-Channel Quantization (PCQ) - Uses one scale per output channel for the weights of linear layers (see the weight-scale sketch after this list).

    • SpinQuant - A PTQ technique that improves the accuracy by inserting rotations at specific points in the model to mitigate activation outliers.

    • AdaScale - A PTQ technique that enhances accuracy by introducing learnable parameters in the weight quantizers and performing Block-wise Knowledge Distillation (BKD) against FP outputs.

  2. LPBQ + SequentialMSE
    • Low Power Blockwise Quantization (LPBQ) - Uses blockwise scales for weights (one scale per block of input channels within each output channel), providing finer granularity than PCQ.

    • SequentialMSE - A PTQ technique that optimizes weight encodings layer by layer to minimize the mean squared error between FP and quantized layer outputs on calibration data.

To maintain accuracy, activations are primarily kept at INT16, with a mixed-precision profile using INT8 activations selectively where feasible—such as for the KV cache.
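
To make the weight-granularity difference between the two recipes concrete, here is a minimal plain-PyTorch sketch (not the AIMET encoding computation) of how per-channel and blockwise scales could be derived for a symmetric INT4 weight quantizer. The block size of 64 and the helper names are illustrative assumptions.

import torch

def per_channel_scales(weight, bits=4):
    # PCQ: one symmetric scale per output channel -> shape [out_channels, 1].
    max_abs = weight.abs().amax(dim=1, keepdim=True)
    return max_abs / (2 ** (bits - 1) - 1)

def per_block_scales(weight, block_size=64, bits=4):
    # Blockwise granularity (as used by LPBQ): one symmetric scale per block
    # of input channels within each output channel
    # -> shape [out_channels, in_features // block_size].
    out_ch, in_feat = weight.shape
    blocks = weight.reshape(out_ch, in_feat // block_size, block_size)
    return blocks.abs().amax(dim=-1) / (2 ** (bits - 1) - 1)

w = torch.randn(128, 512)                  # a toy linear-layer weight
print(per_channel_scales(w).shape)         # torch.Size([128, 1])
print(per_block_scales(w, 64).shape)       # torch.Size([128, 8])

Note that LPBQ also constrains how the per-block scales are represented relative to per-channel scales to keep them hardware friendly; the sketch only illustrates the difference in granularity.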

Workflow Overview

  1. Load the HuggingFace model
    • Start by loading the pretrained model using the HuggingFace transformers library.

  2. Apply the selected Quantization recipe
    • Use aimet-torch for PyTorch based workflows or

    • Use aimet-onnx for ONNX based workflows.

    • Note

      The SpinQuant technique is currently available only in aimet-torch. You can apply SpinQuant to the FP32 model before exporting it to ONNX, and then continue the workflow using aimet-onnx. See the rotation-folding sketch after this list for the intuition behind why this works.

  3. Compute activation encodings
    • Both aimet-torch and aimet-onnx compute activation encodings using representative data. In this tutorial, we use WikiText (English) for calibration.

    • For aimet-torch only:
      • Due to PyTorch limitations, quantizers cannot be inserted around certain functional operations (torch.nn.functional), which makes implementing a mixed-precision profile (e.g., KV cache in INT8) challenging.

      • To address this, include an aimet-onnx evaluation step within the aimet-torch workflow. aimet-onnx provides a static graph, ensuring correct quantizer insertion for all activations and delivering a more accurate quantization simulation.

  4. Export for deployment
    • Export the ONNX model along with the encodings file for on-target inference.
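
For intuition on the SpinQuant step, the following plain-PyTorch sketch (not the aimet-torch API) folds a random orthogonal rotation into two adjacent linear weights. The floating-point output is unchanged, while the hidden activations between the layers are rotated, which spreads outliers across channels and makes them easier to quantize. The QR-based rotation is an illustrative stand-in for the Hadamard-based or optimized rotations SpinQuant actually uses.

import torch

torch.manual_seed(0)
d = 8
W1 = torch.randn(d, d)                    # first linear layer weight
W2 = torch.randn(d, d)                    # second linear layer weight
x = torch.randn(4, d)                     # a batch of input activations

# Illustrative orthogonal rotation (QR decomposition of a Gaussian matrix).
R, _ = torch.linalg.qr(torch.randn(d, d))

# Fold the rotation into the adjacent weights: the hidden activations become
# rotated, but the end-to-end function is mathematically identical.
W1_rot = R.T @ W1
W2_rot = W2 @ R

out_ref = x @ W1.T @ W2.T
out_rot = x @ W1_rot.T @ W2_rot.T
assert torch.allclose(out_ref, out_rot, atol=1e-4)   # FP32 output unchanged

Because folded rotations leave the FP32 model mathematically unchanged, you can apply SpinQuant in aimet-torch, export the model to ONNX, and continue with aimet-onnx as described in the note above.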

Quick Start

This section provides a quick example of applying a quantization recipe using either aimet-torch or aimet-onnx.

In this tutorial, we apply the quantization recipe to the Llama 3.2 1B model. The steps work for all fine-tuned variants that share the same tokenizer and network architecture.

The example scripts are intentionally flat, so all AIMET API calls and HuggingFace API calls are visible at the top level. A minimal sketch of the HuggingFace-facing parts is shown below.

To understand how this works under the hood for PyTorch and ONNX models using the same driver code, refer to the Generator class in GenAITests.
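
As a rough sketch of the HuggingFace-facing parts of the scripts (loading the model and preparing WikiText calibration batches), assuming the transformers and datasets packages are installed; the exact WikiText configuration name is an assumption, and the recipe application and export calls are hypothetical placeholders, since the real calls live in the Examples.torch.quantize and Examples.onnx.quantize scripts:

import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "meta-llama/Llama-3.2-1B-Instruct"
seq_len = 2048                                  # AR-2048
num_batches = 20                                # calibration batches

# Step 1: load the pretrained HuggingFace model and tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float32)

# Step 3: build representative calibration batches from WikiText (English).
raw = load_dataset("wikitext", "wikitext-2-raw-v1", split="train")
tokens = tokenizer("\n\n".join(raw["text"]), return_tensors="pt").input_ids[0]
calibration_batches = [
    tokens[i * seq_len : (i + 1) * seq_len].unsqueeze(0)
    for i in range(num_batches)
]

# Steps 2 and 4 (hypothetical placeholders for the AIMET-specific calls):
# sim = apply_recipe(model, "pcq_spinquant_adascale", calibration_batches)
# export_for_deployment(sim, "./torch_pcq")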

Quantize

Example: Apply Recipe 1 (pcq_spinquant_adascale)

Using aimet-torch:

python -m Examples.torch.quantize \
 --model-id "meta-llama/Llama-3.2-1B-Instruct" \
 --recipe "pcq_spinquant_adascale" \
 --export-path "./torch_pcq" \
 --adascale-num-batches 128 --adascale-num-iterations 2048

Using aimet-onnx:

python -m Examples.onnx.quantize \
 --model-id "meta-llama/Llama-3.2-1B-Instruct" \
 --recipe "pcq_spinquant_adascale" \
 --export-path "./onnx_pcq" \
 --adascale-num-batches 128 --adascale-num-iterations 2048

Example: Apply Recipe 2 (lpbq_seqmse)

Using aimet-torch:

python -m Examples.torch.quantize \
 --model-id "meta-llama/Llama-3.2-1B-Instruct" \
 --recipe "lpbq_seqmse" \
 --export-path "./torch_lpbq" \
 --seqmse-num-batches 20

Using aimet-onnx:

python -m Examples.onnx.quantize \
 --model-id "meta-llama/Llama-3.2-1B-Instruct" \
 --recipe "lpbq_seqmse" \
 --export-path "./onnx_lpbq" \
 --seqmse-num-batches 20

Evaluate

Use the checkpoint generated in the previous step to evaluate the quantized model.

  • ONNX evaluation works for models quantized with either aimet-torch or aimet-onnx.

  • PyTorch evaluation works only for models quantized with aimet-torch.

Using aimet-onnx:

python -m Examples.onnx.evaluate \
 --model-id "meta-llama/Llama-3.2-1B-Instruct" \
 --checkpoint "./torch_lpbq" \
 --eval-ppl

Now, we will go through the performance numbers for the selected LLMs.

Performance Summary

Once the model is quantized, it is essential to evaluate its accuracy to ensure it meets acceptable thresholds. The same evaluation can also be performed on the original (unquantized) model to establish a strong baseline.

We demonstrate quantitative evaluation using two key metrics:

  • Perplexity (PPL)

  • MMLU accuracy

Additionally, we report:

  • End-to-end runtime for each quantization recipe

  • Peak CUDA memory usage during quantization

The consolidated performance tables below summarize results for the selected LLMs, with numbers for both recipes using aimet-torch and aimet-onnx.

Note

For models quantized using aimet-torch, we include results from evaluation on aimet-onnx. This ensures accurate activation quantizer placement and mixed-precision simulation (e.g., INT8 KV Cache).

To avoid confusion, we explicitly report two fields for each result:

  • Quantized With – the AIMET package used to create the quantized model

  • Evaluated On – the AIMET package used to measure accuracy and performance

During quantization and evaluation, we use a sequence length of 2048 tokens (referred to as AR-2048) and a context length of 4096 tokens.
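
For reference, the PPL column below is perplexity, i.e. the exponential of the average next-token negative log-likelihood. A minimal sketch of the metric over non-overlapping AR-2048 windows (plain HuggingFace/PyTorch, not the packaged Examples.onnx.evaluate implementation, which additionally simulates the quantized graph):

import torch

@torch.no_grad()
def perplexity(model, token_ids, seq_len=2048):
    # Average next-token negative log-likelihood over non-overlapping
    # windows of seq_len tokens, exponentiated.
    total_nll, total_tokens = 0.0, 0
    for start in range(0, token_ids.size(0) - seq_len + 1, seq_len):
        window = token_ids[start : start + seq_len].unsqueeze(0)
        # With labels == input_ids, HuggingFace causal LMs return the mean
        # cross-entropy of the shifted next-token predictions as .loss.
        loss = model(input_ids=window, labels=window).loss
        total_nll += loss.item() * seq_len
        total_tokens += seq_len
    return torch.exp(torch.tensor(total_nll / total_tokens))

Here token_ids would be the tokenized evaluation text (for example, a WikiText test split).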

1. meta-llama/Llama-3.2-1B-Instruct

Precision settings:

  • Weights: INT4, except for:
    • LM Head: INT8

  • Activations: INT16, except for:
    • KV Cache: INT8

Hyperparameters:

  • AdaScale: num_batches=128, num_iterations=2048

  • SequentialMSE: num_batches=20

  • Calibration: num_batches=20

Technique                  | Quantized With | Evaluated On | PPL   | MMLU  | Time (hh:mm:ss) | CUDA (GB)
---------------------------|----------------|--------------|-------|-------|-----------------|----------
FP32                       | N/A            | Both         | 12.14 | 46.06 | 00:00:14        | 6.34
PCQ + SpinQuant + AdaScale | aimet-torch    | aimet-onnx   | 13.67 | 42.25 | 02:31:06        | 20.89
PCQ + SpinQuant + AdaScale | aimet-onnx     | aimet-onnx   | 13.68 | 41.82 | 01:53:17        | 46.38
LPBQ + SequentialMSE       | aimet-torch    | aimet-onnx   | 14.07 | 43.09 | 00:44:38        | 28.52
LPBQ + SequentialMSE       | aimet-onnx     | aimet-onnx   | 13.84 | 43.53 | 00:20:44        | 34.79

2. meta-llama/Llama-3.2-3B-Instruct

Precision settings:

  • Weights: INT4, except for:
    • LM Head: INT8

  • Activations: INT16, except for:
    • KV Cache: INT8

Hyperparameters:

  • AdaScale: num_batches=128, num_iterations=1024

  • SequentialMSE: num_batches=20

  • Calibration: num_batches=20

Technique                  | Quantized With | Evaluated On | PPL   | MMLU  | Time (hh:mm:ss) | CUDA (GB)
---------------------------|----------------|--------------|-------|-------|-----------------|----------
FP32                       | N/A            | Both         | 10.13 | 60.74 | 00:00:10        | 13.90
PCQ + SpinQuant + AdaScale | aimet-torch    | aimet-onnx   | 11.01 | 58.09 | 06:35:22        | 41.24
PCQ + AdaScale             | aimet-onnx     | aimet-onnx   | 11.14 | 56.79 | 04:49:36        | 47.35
LPBQ + SequentialMSE       | aimet-torch    | aimet-onnx   | 10.69 | 59.08 | 02:41:44        | 51.11
LPBQ + SequentialMSE       | aimet-onnx     | aimet-onnx   | 10.55 | 59.29 | 01:13:12        | 59.41

3. Qwen/Qwen2.5-0.5B-Instruct

Precision settings:

  • Weights: INT4, except for:
    • LM Head: INT8

  • Activations: INT16

Hyperparameters:

  • AdaScale: num_batches=128, num_iterations=2048

  • SequentialMSE: num_batches=20

  • Calibration: num_batches=20

Technique                  | Quantized With | Evaluated On | PPL   | MMLU  | Time (hh:mm:ss) | CUDA (GB)
---------------------------|----------------|--------------|-------|-------|-----------------|----------
FP32                       | N/A            | Both         | 13.14 | 46.30 | 00:00:13        | 3.68
PCQ + SpinQuant + AdaScale | aimet-torch    | aimet-onnx   | 13.89 | 44.19 | 03:19:37        | 13.37
PCQ + SpinQuant + AdaScale | aimet-onnx     | aimet-onnx   | 13.82 | 42.65 | 01:16:54        | 34.01
LPBQ + SequentialMSE       | aimet-torch    | aimet-onnx   | 15.32 | 42.33 | 00:22:39        | 14.25
LPBQ + SequentialMSE       | aimet-onnx     | aimet-onnx   | 15.30 | 43.26 | 00:11:33        | 20.43

4. Qwen/Qwen2.5-1.5B-Instruct

Precision settings:

  • Weights: INT4, except for:
    • LM Head: INT8

  • Activations: INT16

Hyperparameters:

  • AdaScale: num_batches=128, num_iterations=1024

  • SequentialMSE: num_batches=20

  • Calibration: num_batches=20

Technique                  | Quantized With | Evaluated On | PPL   | MMLU  | Time (hh:mm:ss) | CUDA (GB)
---------------------------|----------------|--------------|-------|-------|-----------------|----------
FP32                       | N/A            | Both         | 12.41 | 54.65 | 00:00:10        | 7.78
PCQ + SpinQuant + AdaScale | aimet-torch    | aimet-onnx   | 13.57 | 49.81 | 03:03:17        | 22.62
PCQ + SpinQuant + AdaScale | aimet-onnx     | aimet-onnx   | 13.35 | 50.27 | 02:13:33        | 42.97
LPBQ + SequentialMSE       | aimet-torch    | aimet-onnx   | 14.86 | 49.25 | 01:07:43        | 26.01
LPBQ + SequentialMSE       | aimet-onnx     | aimet-onnx   | 14.33 | 49.97 | 00:37:52        | 34.40

5. Qwen/Qwen3-4B

Precision settings:

  • Weights: INT4, except for:
    • LM Head: INT8

  • Activations: INT16, except for:
    • KV Cache: INT8

Hyperparameters:

  • AdaScale: num_batches=128, num_iterations=512

  • SequentialMSE: num_batches=20

  • Calibration: num_batches=20

Technique                  | Quantized With | Evaluated On | PPL   | MMLU  | Time (hh:mm:ss) | CUDA (GB)
---------------------------|----------------|--------------|-------|-------|-----------------|----------
FP32                       | N/A            | Both         | 12.41 | 70.06 | 00:00:10        | 17.02
PCQ + SpinQuant + AdaScale | aimet-torch    | aimet-onnx   | 13.85 | 65.07 | 06:41:32        | 47.71
PCQ + AdaScale             | aimet-onnx     | aimet-onnx   | 13.79 | 62.33 | 04:34:22        | 71.3
LPBQ + SequentialMSE       | aimet-torch    | aimet-onnx   | 13.10 | 65.66 | 02:41:48        | 39.42
LPBQ + SequentialMSE       | aimet-onnx     | aimet-onnx   | 12.77 | 65.36 | 01:35:29        | 63.61

6. microsoft/Phi-3.5-mini-instruct

Precision settings:

  • Weights: INT4, except for:
    • LM Head: INT8

  • Activations: INT16, except for:
    • KV Cache: INT8

Hyperparameters:

  • AdaScale: num_batches=128, num_iterations=256

  • SequentialMSE: num_batches=20

  • Calibration: num_batches=20

Technique                  | Quantized With | Evaluated On | PPL   | MMLU  | Time (hh:mm:ss) | CUDA (GB)
---------------------------|----------------|--------------|-------|-------|-----------------|----------
FP32                       | N/A            | Both         | 5.77  | 68.89 | 00:00:08        | 16.17
PCQ + SpinQuant + AdaScale | aimet-torch    | aimet-onnx   | 6.58  | 62.62 | 04:16:53        | 48.03
PCQ + SpinQuant + AdaScale | aimet-onnx     | aimet-onnx   | 6.50  | 62.51 | 01:51:43        | 61.85
LPBQ + SequentialMSE       | aimet-torch    | aimet-onnx   | 6.45  | 64.63 | 02:03:41        | 37.64
LPBQ + SequentialMSE       | aimet-onnx     | aimet-onnx   | 6.41  | 63.90 | 01:32:36        | 75.62

FAQs

  1. When should I choose aimet-torch vs aimet-onnx?
    • Choose aimet-torch when:
      • You want to apply quantization directly on a PyTorch model and keep the workflow within the PyTorch ecosystem.

      • You plan to apply Quantization-Aware Training (QAT) or run calibration using PyTorch datasets and dataloaders.

      • You need flexibility for dynamic graph operations.

    • Choose aimet-onnx when:
      • You need a static graph representation for deployment.

      • You want full quantization coverage, including functional operations that aimet-torch cannot instrument easily.

      • You are preparing the model for hardware adaptation (e.g. QAIRT) or other runtimes which consume ONNX graphs.

  2. When should I choose Recipe 1 vs Recipe 2?
    • Choose Recipe 1: PCQ + SpinQuant + AdaScale
      • Uses Per-channel Quantization (PCQ), which provides good granularity for weights.

      • Performance KPIs (token rate, time-to-first-token, etc.) are better on the target device.

      • Recommended when you can afford longer calibration time and prioritize throughput over accuracy.

    • Choose Recipe 2: LPBQ + SequentialMSE
      • Uses Blockwise quantization, which provides finer granularity than PCQ.

      • Recommended when the accuracy is the top priority.

      • Trade-off: slight impact on performance KPIs due to INT4 -> INT8 decoding.

  3. Can I run the artifacts generated from the recipes as-is on target hardware?
    • No. The generated artifacts from the recipes are not directly compatible with QAIRT and require non-trivial adaptation steps for deployment on target hardware. Refer to the model adaptation guide for details.

  4. Why does computing MMLU take a long time?
    • MMLU evaluation can be slow even on high-end GPUs because it involves thousands of questions across 57 subjects. You can trade off accuracy for speed by reducing the number of samples.

  5. Why is the INT8 KV cache not “good enough” for Qwen 2.5?
    • The KV-cache activations of Qwen 2.5 (0.5B and 1.5B) have a wide dynamic range, so the INT8 path, with only 256 discrete levels, introduces large quantization error. INT16 offers 65,536 discrete levels, which drastically reduces that error. For Qwen 2.5, INT16 doubles the KV-cache memory compared to INT8 but keeps performance much closer to FP32, making it the better choice when quality matters and memory allows. The toy example below illustrates the effect.
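
A toy per-tensor symmetric quantizer makes the gap described in FAQ 5 visible (illustrative only; the actual KV-cache quantizers use calibrated encodings):

import torch

def fake_quant(x, bits):
    # Symmetric uniform quantization with a single per-tensor scale.
    levels = 2 ** (bits - 1) - 1
    scale = x.abs().max() / levels
    return torch.clamp(torch.round(x / scale), -levels, levels) * scale

# Wide dynamic range: mostly small values plus a few large outliers,
# loosely mimicking a problematic KV-cache tensor.
torch.manual_seed(0)
x = torch.cat([torch.randn(100_000) * 0.05, torch.randn(100) * 8.0])

for bits in (8, 16):
    rms_err = (fake_quant(x, bits) - x).pow(2).mean().sqrt()
    print(f"INT{bits}: RMS quantization error = {rms_err:.2e}")

With the same dynamic range, INT16 shrinks the quantization step size by a factor of roughly 256, which is why it recovers near-FP32 quality for these models.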

Contact Us

Please reach out to us if you encounter any issues with this tutorial or with applying the recipes to similar models.