QWA-LoRa¶
Context¶
The QWA-LoRa workflow first determines the appropriate weight and activation encodings for the base model, then performs one or more epochs of LoRa training with those base model encodings frozen. Finally, the weights and activations of the updated LoRa layers are calibrated. This workflow is illustrated in the block diagram below.
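In code terms, the three phases map directly onto the Workflow section below. The outline here is only a sketch for orientation; the names it uses (quantsim, generate_calibration_callback, train_one_epoch) are defined step by step in the sections that follow.

device = torch.device("cuda")

# Phase 1: calibrate the base model while the LoRa layer quantizers are temporarily removed
quantsim.compute_encodings(generate_calibration_callback(train_dataloader, max_iterations=20, device=device))

# Phase 2: LoRa QAT with the base model weight and activation encodings frozen
train_one_epoch(quantsim.model, train_dataloader, device)

# Phase 3: calibrate the quantizers of the updated LoRa layers
quantsim.compute_encodings(generate_calibration_callback(train_dataloader, max_iterations=20, device=device))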

Workflow¶
Setup¶
In this section, we instantiate the base model, LoRa adapters, and dataset using Huggingface APIs.
from itertools import chain
from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, default_data_collator
from datasets import load_dataset
model_id = "facebook/opt-350m"
peft_model_id = "ybelkada/opt-350m-lora"
# This ensures that use_cache and return_dict are always set to False
# These settings are selected so that the model is JIT-traceable
config = AutoConfig.from_pretrained(model_id)
config.use_cache = False
config.return_dict = False
# Load model and LoRa adapter from Huggingface
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, config=config)
model.load_adapter(peft_model_id)
# Load train and test splits of dataset
def tokenize(examples):
    seq_length = 512
    examples = tokenizer(examples["text"])
    concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    if total_length >= seq_length:
        total_length = (total_length // seq_length) * seq_length
    result = {
        k: [t[i : i + seq_length] for i in range(0, total_length, seq_length)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result
train_dataset = load_dataset(path='wikitext', name='wikitext-2-raw-v1', split='train').map(tokenize, batched=True, remove_columns=['text'])
test_dataset = load_dataset(path='wikitext', name='wikitext-2-raw-v1', split='test').map(tokenize, batched=True, remove_columns=['text'])
train_dataloader = DataLoader(train_dataset, batch_size=1, collate_fn=default_data_collator)
test_dataloader = DataLoader(test_dataset, batch_size=1, collate_fn=default_data_collator)
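Before moving on, it can be useful to confirm that the dataloaders produce what the rest of the workflow expects. The quick check below is not part of the recipe; it simply prints the tensors in one collated batch, which should each have shape (1, 512) given batch_size=1 and seq_length=512.

# Optional sanity check: inspect one batch produced by the dataloader
batch = next(iter(train_dataloader))
print({k: v.shape for k, v in batch.items()})
# Expect input_ids, attention_mask, and labels, each of shape torch.Size([1, 512])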
Create QuantizationSimModel¶
import torch
from transformers.models import opt
from aimet_torch.quantsim import QuantizationSimModel
from aimet_torch.v2.nn.true_quant import QuantizationMixin
from aimet_torch.peft import replace_lora_layers_with_quantizable_layers
# Generate dummy data used to instantiate QuantizationSimModel
tokenized_dummy_text = tokenizer("Here is some sample text used to create dummy input ids")
dummy_input_ids = torch.Tensor(tokenized_dummy_text['input_ids']).to(dtype=torch.int32).unsqueeze(0)
dummy_attention_mask = torch.Tensor(tokenized_dummy_text['attention_mask']).to(dtype=torch.int32).unsqueeze(0)
# Modify LoRa layers so they are quantizable
replace_lora_layers_with_quantizable_layers(model)
# Register quantized version of OPTLearnedPositionalEmbedding
@QuantizationMixin.implements(opt.modeling_opt.OPTLearnedPositionalEmbedding)
class QuantizedOPTLearnedPositionalEmbedding(QuantizationMixin, opt.modeling_opt.OPTLearnedPositionalEmbedding):
""" Dummy placeholder - we don't want to quantize OPTLearnedPositionalEmbedding """
forward = opt.modeling_opt.OPTLearnedPositionalEmbedding.forward
# Create QuantizationSimModel
quantsim = QuantizationSimModel(model=model,
                                dummy_input=(dummy_input_ids, dummy_attention_mask),
                                default_output_bw=16,
                                default_param_bw=4,
                                in_place=True)
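As an optional check (not part of the recipe), you can verify that modules in the sim model were wrapped with quantized equivalents. QuantizationMixin, imported above, is the mixin that quantized modules in aimet_torch.v2.nn derive from, so counting its instances gives a rough picture of what will be quantized.

# Optional: count the modules that were replaced with quantized equivalents
num_quantized = sum(1 for m in quantsim.model.modules() if isinstance(m, QuantizationMixin))
print(f"Quantized modules in sim model: {num_quantized}")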
Calibration Callback¶
from tqdm import tqdm
# Callback function to pass calibration data through the model
def generate_calibration_callback(dataloader, max_iterations: int, device: torch.device):
    def forward_pass(model: torch.nn.Module):
        with torch.no_grad():
            for batch_id, batch in enumerate(tqdm(dataloader, total=max_iterations)):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                model(input_ids=input_ids, attention_mask=attention_mask)
                if batch_id >= max_iterations:
                    break
    return forward_pass
Training Callback¶
# Function to perform one epoch of training
def train_one_epoch(model, dataloader, device=torch.device("cuda")):
    optimizer = torch.optim.AdamW(model.parameters())
    loss_fn = torch.nn.CrossEntropyLoss()
    for batch_id, batch in enumerate(tqdm(dataloader)):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        logits, = model(input_ids=input_ids, attention_mask=attention_mask)

        # Compute the loss and its gradients
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten the tokens
        loss = loss_fn(shift_logits.view(-1, model.config.vocab_size), shift_labels.view(-1))
        loss.backward()

        # Adjust learning weights
        optimizer.step()
Run QWA-LoRa¶
from aimet_torch.utils import place_model
from aimet_torch.peft import LoraLayer
import aimet_torch.v2.quantization as Q
from aimet_torch.v2.utils import remove_all_quantizers
lora_a_layers = [module.lora_A for module in quantsim.model.modules() if isinstance(module, LoraLayer)]
lora_b_layers = [module.lora_B for module in quantsim.model.modules() if isinstance(module, LoraLayer)]
lora_add_layers = [module.add_lora_to_res for module in quantsim.model.modules() if isinstance(module, LoraLayer)]
lora_mul_layers = [module.mul_scale for module in quantsim.model.modules() if isinstance(module, LoraLayer)]
with place_model(model, torch.device("cuda")):
    # Temporarily remove all LoRa layer quantizers, leaving only base model quantizers
    with remove_all_quantizers(lora_a_layers + lora_b_layers + lora_add_layers + lora_mul_layers):
        # Only compute encodings for base model weights and activations
        calibration_callback = generate_calibration_callback(dataloader=train_dataloader, max_iterations=20, device=torch.device("cuda"))
        quantsim.compute_encodings(calibration_callback)

        # Prevent the base model encodings from being overwritten by the final quantsim.compute_encodings() call
        for module_name, module in model.named_modules():
            if isinstance(module, Q.base.QuantizerBase):
                module.allow_overwrite(False)

        # Configure model so that only LoRa layers are trainable
        model.requires_grad_(False)
        for module_name, module in model.named_modules():
            if isinstance(module, LoraLayer):
                module.lora_A.requires_grad_(True)
                module.lora_B.requires_grad_(True)

        # Perform LoRa QAT with base model weight and activation encodings frozen
        train_one_epoch(quantsim.model, train_dataloader, torch.device("cuda"))

    # Compute encodings for the remaining (LoRa layer) quantizers
    calibration_callback = generate_calibration_callback(dataloader=train_dataloader, max_iterations=20, device=torch.device("cuda"))
    quantsim.compute_encodings(calibration_callback)
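The test split loaded in the Setup section is a natural way to check the resulting quantized model. The loop below is only a sketch (evaluate_perplexity is a name introduced here, not part of the recipe) that reuses the same shift-and-flatten loss computation as train_one_epoch to report perplexity for quantsim.model on the test set.

# Sketch: evaluate perplexity of the quantized model on the test split
def evaluate_perplexity(model, dataloader, device=torch.device("cuda")):
    loss_fn = torch.nn.CrossEntropyLoss()
    total_loss, num_batches = 0.0, 0
    with torch.no_grad():
        for batch in tqdm(dataloader):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)
            logits, = model(input_ids=input_ids, attention_mask=attention_mask)
            # Same next-token loss as in train_one_epoch
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            total_loss += loss_fn(shift_logits.view(-1, model.config.vocab_size), shift_labels.view(-1)).item()
            num_batches += 1
    return torch.exp(torch.tensor(total_loss / num_batches)).item()

with place_model(model, torch.device("cuda")):
    quantsim.model.eval()
    print(f"Quantized model perplexity: {evaluate_perplexity(quantsim.model, test_dataloader):.2f}")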