OmniQuant

Context

OmniQuant is a PTQ technique that improves the accuracy of a quantized model by computing optimal quantization parameters for weights. OmniQuant is based on https://arxiv.org/abs/2308.13137 and comprises two components: Learnable Weight Clipping (LWC) and Learnable Equivalent Transformation (LET).

OmniQuant introduces a trainable scale parameter in the weight quantizers of every supported module and performs blockwise knowledge distillation (BKD) by comparing the quantized output of every supported block with its FP32 equivalent. The trainable scale parameters are learned pairwise in OmniQuant. From the OmniQuant perspective, a block is a non-leaf module that takes one activation input tensor and outputs one activation tensor. OmniQuant also requires blocks to be contiguous to perform optimization.

Warning: This feature is currently experimental. It is currently supported for Llama 3.2, Qwen 2.5, and DeepSeek distill models for Qwen 2.5.
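As a rough sketch of the blockwise optimization described above (not AIMET's internal implementation; names such as quant_block, fp32_block, and cached_block_inputs are hypothetical), each supported block is trained to minimize the reconstruction error between its quantized output and its FP32 counterpart, updating only the learnable quantization parameters:

import torch

def train_block(quant_block, fp32_block, cached_block_inputs, num_iterations=800, lr=1e-3):
    # Only the learnable quantization parameters (e.g. weight-clipping scales) should require grad
    trainable_params = [p for p in quant_block.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(trainable_params, lr=lr)
    for step in range(num_iterations):
        x = cached_block_inputs[step % len(cached_block_inputs)]
        # Blockwise knowledge distillation: match the FP32 block's output
        loss = torch.nn.functional.mse_loss(quant_block(x), fp32_block(x))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()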

Workflow

Prerequisites

To use OmniQuant, you must:

  • Use PyTorch. OmniQuant does not support other frameworks yet.

  • Load a pre-trained model

  • Create a dataloader for the model

  • Choose a model with contiguous blocks, where each block takes one activation input tensor and outputs one activation tensor. Example block: LlamaDecoderLayer in LlamaModel (see the sketch below)
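For instance, in a Hugging Face Llama checkpoint the decoder layers form exactly such a sequence of contiguous blocks. A minimal structural check (assuming model is a LlamaForCausalLM loaded as in the Setup below):

# Rough check: the decoder layers are a contiguous ModuleList of blocks, each taking
# one hidden-state tensor in and producing one hidden-state tensor out
from transformers.models.llama.modeling_llama import LlamaDecoderLayer

blocks = model.model.layers
assert all(isinstance(block, LlamaDecoderLayer) for block in blocks)
print(f"{len(blocks)} contiguous decoder blocks found")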

Procedure

Setup

# Load the model
# General setup that can be changed as needed
from itertools import chain, islice

import torch
from torch.utils.data import DataLoader
from datasets import load_dataset
from transformers import AutoConfig, AutoTokenizer, default_data_collator
from transformers.models.llama import modeling_llama

device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "meta-llama/Llama-3.2-1B-Instruct"
model_config = AutoConfig.from_pretrained(model_id)
model_config.return_dict = False
model_config.use_cache = False

model = modeling_llama.LlamaForCausalLM.from_pretrained(model_id, config=model_config)
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, trust_remote_code=True)

def tokenize(examples):
    seq_length = 2048
    examples = tokenizer(examples["text"])
    concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    if total_length >= seq_length:
        total_length = (total_length // seq_length) * seq_length
    result = {
        k: [t[i : i + seq_length] for i in range(0, total_length, seq_length)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

train_dataset = load_dataset(path='wikitext', name='wikitext-2-raw-v1', split='train').map(tokenize, batched=True, remove_columns=['text'])
test_dataset = load_dataset(path='wikitext', name='wikitext-2-raw-v1', split='test').map(tokenize, batched=True, remove_columns=['text'])
train_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=1, collate_fn=default_data_collator)
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=1, collate_fn=default_data_collator)

# Custom wrapper to expose only a limited number of batches from a dataloader
dataloader_wrapper_len = 40

class LimitedBatchDataLoader(DataLoader):
    def __init__(self, data_loader):
        self.data_loader = data_loader

    def __len__(self):
        return dataloader_wrapper_len

    def __iter__(self):
        # Stop after dataloader_wrapper_len batches instead of iterating the full dataloader
        return islice(iter(self.data_loader), dataloader_wrapper_len)
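If you want downstream steps to see only a fixed number of batches, the wrapper above can be applied to the training dataloader (an optional usage sketch; the plain train_dataloader is used in the steps below):

# Optional: expose only dataloader_wrapper_len batches to downstream steps
limited_train_dataloader = LimitedBatchDataLoader(train_dataloader)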


Step 1

Use AIMET’s quantization simulation to create a QuantSimModel object.

from aimet_common.defs import QuantScheme
from aimet_torch.quantsim import QuantizationSimModel

seq_length = 2048
input_ids = torch.randint(0, model_config.vocab_size, (1, seq_length), device=device)
attention_mask = torch.ones((1, seq_length), dtype=torch.long, device=device)
dummy_input = (input_ids, attention_mask)

model = model.to(device)  # keep the model on the same device as the dummy input
sim = QuantizationSimModel(model,
                           dummy_input=dummy_input,
                           quant_scheme=QuantScheme.training_range_learning_with_tf_init,
                           default_param_bw=4,
                           default_output_bw=16,
                           in_place=True)
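To sanity-check the inserted quantizers and their bitwidths (4-bit parameters and 16-bit outputs here), the sim object can be printed; an optional quick check:

# Optional: printing the sim lists the quantized modules and their configured quantizers
print(sim)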


Step 2

Apply apply_omniquant to compute optimal quantization encodings for the parameters of supported layers. A minimum of 800 iterations is recommended when calling apply_omniquant, regardless of the dataloader batch size. The learned scales are dumped in safetensors format when apply_omniquant runs; these scales can be used for quantizing LoRA adapters, but using the dumped scales is not supported in the current release.

# Find and freeze optimal encoding candidates for weight parameters of supported layers
from aimet_torch.experimental.omniquant import apply_omniquant

apply_omniquant(quant_sim=sim,
                dataloader=train_dataloader,
                forward_fn=lambda model, input: model.forward(**input),
                num_iterations=800)
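The learned scales are written to output_path (by default ./aimet_omniquant_artifact/). A hedged sketch for inspecting them (the exact file name inside the artifact directory is an assumption, so glob for it rather than hard-coding it):

# Inspect the dumped {layer_name: scale} metadata; the file name is assumed, so glob for it
import glob
from safetensors.torch import load_file

artifact_files = glob.glob('./aimet_omniquant_artifact/*.safetensors')
if artifact_files:
    scales = load_file(artifact_files[0])
    print(list(scales.keys())[:5])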


Step 3

Compute encodings for remaining parameters of the model.

def calibration_wrapper(model, dataloader, max_iterations: int):
    model = model.to(device)
    for batch_id, batch in enumerate(dataloader):
        if batch_id < max_iterations:
            # Pass the batch by keyword so input_ids, attention_mask and labels map to the right arguments
            batch = {k: v.to(device) for k, v in batch.items()}
            model(**batch)
        else:
            break

# Compute the Quantization Encodings
# compute encodings for all activations and parameters of uninitialized layer(s)/operation(s)
sim.compute_encodings(calibration_wrapper, dataloader=train_dataloader, max_iterations=40)


Step 4

Evaluate the quantized model.

# Determine simulated quantized accuracy
...
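A minimal perplexity sketch over the test split is shown below (illustrative only; substitute your own evaluation harness as needed). With return_dict=False, the loss is the first element of the output tuple when labels are passed:

import math

sim.model.eval()
total_loss, num_batches = 0.0, 0
with torch.no_grad():
    for batch_id, batch in enumerate(test_dataloader):
        if batch_id >= 40:  # limit batches for a quick check
            break
        batch = {k: v.to(device) for k, v in batch.items()}
        loss = sim.model(**batch)[0]  # loss is first when labels are provided and return_dict=False
        total_loss += loss.item()
        num_batches += 1
print(f"Perplexity: {math.exp(total_loss / num_batches):.2f}")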


Step 5

If the resulting quantized accuracy is satisfactory, export the model.

# Export the model for on-target inference
path = './'
filename = 'dummy_model'
sim.export(path=path, filename_prefix="quantized_" + filename,
           dummy_input=tuple(t.cpu() for t in dummy_input))


API

Top level APIs

aimet_torch.experimental.omniquant.apply_omniquant(quant_sim, dataloader, forward_fn, num_iterations=800, output_path='./aimet_omniquant_artifact/')

Returns the model with OmniQuant-optimized weights and saves metadata in safetensors format to the output path. The metadata safetensors file can be used in update_lora_weights to update LoRA adapter weights for a PEFT LoRA model.

Parameters:
  • quant_sim (QuantizationSimModel) – QuantizationSimModel object to optimize with Omniquant.

  • dataloader – Dataloader used to train the model.

  • forward_fn (Callable) – Model forward function used to cache intermediate data. Expected to take the model and inputs as arguments, e.g. lambda model, inputs: model(*inputs)

  • num_iterations (int) – Number of iterations to train each block with OmniQuant.

  • output_path (str) – Path to save {layer_name: scale} metadata safetensor.

Returns:

Model with Omniquant weights.
