QWA-LoRa

Context

The QWA-LoRa workflow involves determining the appropriate weight and activation encodings for the base model before performing some epochs of LoRa training. Finally, the weight and activation encodings for the updated LoRa layers are calibrated. This is expressed in the block diagram below.

../../_images/qwa_lora_block_diagram.png

Workflow

Setup

In this section, we instantiate the base model, LoRa adapters, and dataset using Huggingface APIs.

from itertools import chain

from torch.utils.data import DataLoader
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig, default_data_collator
from datasets import load_dataset

# Hugging Face identifiers of the base model and of the LoRa adapter loaded onto it
model_id = "facebook/opt-350m"
peft_model_id = "ybelkada/opt-350m-lora"

# Force use_cache and return_dict to False on the base model config.
# These settings are selected so that the model is JIT-traceable.
# NOTE: the training loop in this file unpacks the model output as a tuple,
# which relies on return_dict being False.
config = AutoConfig.from_pretrained(model_id)
config.use_cache = False
config.return_dict = False

# Load the tokenizer and the base model (with the modified config) from
# Hugging Face, then load the LoRa adapter into the model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, config=config)
model.load_adapter(peft_model_id)

# Load train and test splits of dataset
def tokenize(examples):
    """Tokenize a batch of text and regroup the tokens into fixed-size blocks.

    Every tokenizer output for the batch is flattened into one long list and
    cut into consecutive chunks of 512 entries. When at least 512 entries are
    available the trailing partial chunk is dropped; otherwise the single
    short chunk is kept. A "labels" entry is added as a copy of the
    "input_ids" chunk list.

    Args:
        examples: Batch mapping that contains a "text" entry.

    Returns:
        Dict mapping each tokenizer output key (plus "labels") to a list of
        chunks.
    """
    block_size = 512
    encoded = tokenizer(examples["text"])

    # Flatten the per-example lists into one long list per tokenizer output
    flattened = {}
    for key in encoded.keys():
        flattened[key] = list(chain(*encoded[key]))

    first_key = list(encoded.keys())[0]
    usable = len(flattened[first_key])
    if usable >= block_size:
        # Trim down to a whole number of blocks
        usable -= usable % block_size

    chunked = {}
    for key, tokens in flattened.items():
        chunked[key] = [
            tokens[start : start + block_size]
            for start in range(0, usable, block_size)
        ]
    chunked["labels"] = chunked["input_ids"].copy()
    return chunked

# Load the WikiText-2 (raw) train and test splits and run them through
# tokenize() in batches; the original 'text' column is removed from the result
train_dataset = load_dataset(path='wikitext', name='wikitext-2-raw-v1', split='train').map(tokenize, batched=True, remove_columns=['text'])
test_dataset = load_dataset(path='wikitext', name='wikitext-2-raw-v1', split='test').map(tokenize, batched=True, remove_columns=['text'])

# Wrap both splits in DataLoaders with a batch size of 1
train_dataloader = DataLoader(train_dataset, batch_size=1, collate_fn=default_data_collator)
test_dataloader = DataLoader(test_dataset, batch_size=1, collate_fn=default_data_collator)

Create QuantizationSimModel

import torch
from transformers.models import opt

from aimet_torch.quantsim import QuantizationSimModel
from aimet_torch.v2.nn.true_quant import QuantizationMixin
from aimet_torch.peft import replace_lora_layers_with_quantizable_layers

# Generate dummy data used to instantiate QuantizationSimModel.
# torch.tensor(..., dtype=...) builds the integer tensors directly; the legacy
# torch.Tensor(...) constructor would first create float32 values and then
# cast them. unsqueeze(0) adds the leading batch dimension.
tokenized_dummy_text = tokenizer("Here is some sample text used to create dummy input ids")
dummy_input_ids = torch.tensor(tokenized_dummy_text['input_ids'], dtype=torch.int32).unsqueeze(0)
dummy_attention_mask = torch.tensor(tokenized_dummy_text['attention_mask'], dtype=torch.int32).unsqueeze(0)

# Modify LoRa layers so they are quantizable
replace_lora_layers_with_quantizable_layers(model)

# Register quantized version of OPTLearnedPositionalEmbedding
@QuantizationMixin.implements(opt.modeling_opt.OPTLearnedPositionalEmbedding)
class QuantizedOPTLearnedPositionalEmbedding(QuantizationMixin, opt.modeling_opt.OPTLearnedPositionalEmbedding):
    """Placeholder quantized counterpart of OPTLearnedPositionalEmbedding.

    We don't want to quantize OPTLearnedPositionalEmbedding, so the forward of
    this class is simply the forward of the original module.
    """
    # Reuse the original module's forward unchanged
    forward = opt.modeling_opt.OPTLearnedPositionalEmbedding.forward

# Create QuantizationSimModel from the model and the dummy (input_ids, attention_mask)
# pair, with a default output bitwidth of 16 and a default parameter bitwidth of 4.
# NOTE(review): in_place=True presumably makes quantsim.model the same object as
# `model`; later code in this file uses the two names interchangeably - confirm.
quantsim = QuantizationSimModel(model=model,
                                dummy_input=(dummy_input_ids, dummy_attention_mask),
                                default_output_bw=16,
                                default_param_bw=4,
                                in_place=True)

Calibration Callback

from tqdm import tqdm

# Callback function to pass calibration data through the model
def generate_calibration_callback(dataloader, max_iterations: int, device: torch.device):
    """Build a forward-pass callback that feeds calibration batches to a model.

    Args:
        dataloader: Iterable of batches. Each batch is indexed with
            'input_ids' and 'attention_mask', and both values are moved to
            ``device`` with ``.to(device)``.
        max_iterations: Maximum number of batches passed through the model.
            Values below zero are treated as zero.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        A function that takes the model, runs at most ``max_iterations``
        forward passes under ``torch.no_grad()`` and returns None.
    """
    # Local import so this snippet does not depend on the file's import block
    from itertools import islice

    def forward_pass(model: torch.nn.Module):
        with torch.no_grad():
            # islice stops after exactly max_iterations batches. Checking a
            # counter after the forward pass would let one extra batch through
            # (and would run one batch even when max_iterations is 0).
            batches = islice(dataloader, max(max_iterations, 0))
            for batch in tqdm(batches, total=max_iterations):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                model(input_ids=input_ids, attention_mask=attention_mask)

    return forward_pass

Training Callback

# Function to perform one epoch of training
def train_one_epoch(model, dataloader, device=torch.device("cuda")):
    """Run one pass over ``dataloader``, updating ``model`` with AdamW.

    Args:
        model: Module called with ``input_ids`` and ``attention_mask`` keyword
            arguments; its output is unpacked as a one-element tuple holding
            the logits. ``model.config.vocab_size`` is read to flatten them.
        dataloader: Iterable of batches indexed with 'input_ids',
            'attention_mask' and 'labels'.
        device: Device the batch tensors are moved to.

    Returns:
        None. The model parameters are updated in place.
    """
    criterion = torch.nn.CrossEntropyLoss()
    adamw = torch.optim.AdamW(model.parameters())

    for batch in tqdm(dataloader):
        adamw.zero_grad()

        tokens, mask, targets = (
            batch[key].to(device)
            for key in ('input_ids', 'attention_mask', 'labels')
        )

        (logits,) = model(input_ids=tokens, attention_mask=mask)

        # Shift by one position: the output at position t is scored against
        # the label at position t + 1
        predictions = logits[..., :-1, :].contiguous()
        expected = targets[..., 1:].contiguous()

        # Flatten the tokens, then compute the loss and its gradients
        flat_predictions = predictions.view(-1, model.config.vocab_size)
        criterion(flat_predictions, expected.view(-1)).backward()

        # Adjust learning weights
        adamw.step()

Run QWA-LoRa

from aimet_torch.utils import place_model
from aimet_torch.peft import LoraLayer
import aimet_torch.v2.quantization as Q
from aimet_torch.v2.utils import remove_all_quantizers

# Collect, for every LoraLayer in the sim model, its lora_A, lora_B,
# add_lora_to_res and mul_scale sub-modules. These are the modules whose
# quantizers are temporarily removed while the base model is calibrated.
lora_a_layers = [module.lora_A for module in quantsim.model.modules() if isinstance(module, LoraLayer)]
lora_b_layers = [module.lora_B for module in quantsim.model.modules() if isinstance(module, LoraLayer)]
lora_add_layers = [module.add_lora_to_res for module in quantsim.model.modules() if isinstance(module, LoraLayer)]
lora_mul_layers = [module.mul_scale for module in quantsim.model.modules() if isinstance(module, LoraLayer)]

# The whole workflow runs with the model placed on the CUDA device
with place_model(model, torch.device("cuda")):
    # Temporarily remove all LoRa layer quantizers, leaving only base model quantizers
    with remove_all_quantizers(lora_a_layers + lora_b_layers + lora_add_layers + lora_mul_layers):
        # Step 1: only compute encodings for base model weights, activations
        calibration_callback = generate_calibration_callback(dataloader=train_dataloader, max_iterations=20, device=torch.device("cuda"))
        quantsim.compute_encodings(calibration_callback)

        # Step 2: prevent the quantization encodings computed above from getting
        # overwritten by the later quantsim.compute_encodings() call. The LoRa
        # quantizers are removed at this point, so this loop only reaches the
        # quantizers that remain in the model.
        for module_name, module in model.named_modules():
            if isinstance(module, Q.base.QuantizerBase):
                module.allow_overwrite(False)

        # Step 3: configure model so that only LoRa layers are trainable:
        # disable gradients everywhere, then re-enable them on lora_A / lora_B
        model.requires_grad_(False)
        for module_name, module in model.named_modules():
            if isinstance(module, LoraLayer):
                module.lora_A.requires_grad_(True)
                module.lora_B.requires_grad_(True)

        # Step 4: perform LoRa QAT with base model weight, activation encodings frozen
        train_one_epoch(quantsim.model, train_dataloader, torch.device("cuda"))

    # Step 5: outside the remove_all_quantizers block, compute all other encodings
    # (the ones not protected by allow_overwrite(False) above)
    calibration_callback = generate_calibration_callback(dataloader=train_dataloader, max_iterations=20, device=torch.device("cuda"))
    quantsim.compute_encodings(calibration_callback)