SpinQuant¶
Context¶
SpinQuant is a post-training quantization (PTQ) technique that improves the accuracy of the quantized model by inserting rotations at specific points in the model to mitigate outliers in activation quantization: https://arxiv.org/pdf/2405.16406.
The paper describes four rotation types: R1, R2, R3, and R4. The current AIMET implementation of SpinQuant enables only R1 rotations, without optimization. Because these rotations can be merged into adjacent layer weights, the final model architecture is unchanged.
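To see why the merge is possible, here is a minimal numerical sketch (plain PyTorch, not AIMET code): for an orthogonal rotation R, folding R into one linear layer's weights and its counter-rotation R^T into the next leaves the floating-point output unchanged.
import torch

d = 16
x = torch.randn(4, d, dtype=torch.float64)
W1 = torch.randn(d, d, dtype=torch.float64)  # weight producing the rotated activations
W2 = torch.randn(d, d, dtype=torch.float64)  # weight consuming them

# Random orthogonal rotation (R @ R.T == identity)
R, _ = torch.linalg.qr(torch.randn(d, d, dtype=torch.float64))

out_original = x @ W1.t() @ W2.t()

W1_rot = R @ W1      # rotate the activations W1 produces
W2_rot = W2 @ R.t()  # absorb the counter-rotation into the next weight
out_rotated = x @ W1_rot.t() @ W2_rot.t()

assert torch.allclose(out_original, out_rotated)  # output and architecture unchanged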
Applying rotations does not require a quantized model, so either a floating-point model or a quantized model can be used as input. Since rotations must be inserted at well-known points in the model, the feature uses a mapping table that defines pre-determined insertion points for known model types.
Currently supported model types include:
LlamaForCausalLM
Qwen2ForCausalLM
MistralForCausalLM
We expose the mapping dictionary as a module-level object in case users need to register their own insertion points for other model types.
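For example, before applying SpinQuant to a model type not listed above, you could check whether insertion points are registered in the module-level dictionary (a minimal sketch; registration itself is shown in Step 1 below):
from aimet_torch.experimental.spinquant import spinquant_optimizer

# Check whether this model type already has registered insertion points
if type(model) not in spinquant_optimizer.SUPPORTED_MODULE_DICT:
    raise RuntimeError(f"{type(model).__name__} is not registered; "
                       "register RMSNorm fusion and R1 placement functions first (see Step 1).")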
Note
This feature is currently marked as experimental. The API may change in the future.
Note
This feature is currently only supported for PyTorch framework.
Note
Only R1 rotations without optimization are currently supported.
Workflow¶
Prerequisites¶
To use SpinQuant, you must:
Load a pre-trained model
Procedure¶
Setup¶
# Load the model
# General setup that can be changed as needed
import torch
from transformers import AutoConfig, AutoTokenizer
from transformers.models.llama import modeling_llama

device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "meta-llama/Llama-3.2-1B-Instruct"

model_config = AutoConfig.from_pretrained(model_id)
model_config.return_dict = False
model_config.use_cache = False
model_config.tie_word_embeddings = False

model = modeling_llama.LlamaForCausalLM.from_pretrained(model_id, config=model_config).to(device)
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True, trust_remote_code=True)
Step 1¶
Register RMSNorm fusion locations and R1 insertion points if needed. The code below shows an example for a model of type MyModel. RMSNorm fusion is a transformation that folds RMSNorm weights into adjacent linear layers. The resulting RMSNorm op will have weights of all 1's and a bias of 0. The folded model is mathematically equivalent to the original model in floating-point computation, and the fusion is necessary for the R1 transforms added later.
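To see why the fusion is equivalent, here is a minimal numerical sketch (plain PyTorch, not the AIMET implementation): the per-channel RMSNorm weight can be absorbed into the input columns of the adjacent linear layer's weight, leaving the RMSNorm weight as all 1's.
import torch

d = 8
x = torch.randn(2, d, dtype=torch.float64)
gamma = torch.randn(d, dtype=torch.float64)  # RMSNorm weight
W = torch.randn(16, d, dtype=torch.float64)  # adjacent linear layer weight

def rmsnorm(x, weight, eps=1e-6):
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

out_original = rmsnorm(x, gamma) @ W.t()

W_fused = W * gamma  # scale each input column of W by the RMSNorm weight
out_fused = rmsnorm(x, torch.ones(d, dtype=torch.float64)) @ W_fused.t()

assert torch.allclose(out_original, out_fused)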
To prepare SpinQuant to take effect for a model other than what is already supported, users need to register the model type with two functions:
A function that, when given a model object, returns a list of tuples, where each tuple consists of an RMSNorm layer and the list of linear layers it should fuse with
A function that, when given a model object, returns a list of tuples, where each tuple consists of a linear layer and a boolean. A value of True denotes R1 fusion occurring before the linear layer, while False denotes R1 fusion occurring after it.
For typical HuggingFace models that share a similar architecture, _default_rmsnorm_linear_pairs_func() and _default_r1_fusion_func() can be used. For example, Llama, Qwen, and Mistral all share the same functions.
# Example model type to register
class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.embed_tokens = ...
        self.rmsnorm = ...
        self.q_proj = ...
        self.k_proj = ...
        self.v_proj = ...
        self.o_proj = ...
        self.gate_proj = ...
        self.up_proj = ...
        self.down_proj = ...
        self.lm_head = ...
        ...

    def forward(self, input):
        ...
from aimet_torch.experimental.spinquant import spinquant_optimizer

# Define a function to identify RMSNorm fusion pairs
def rmsnorm_fusion_pairs(model):
    return [(model.rmsnorm, [model.q_proj, model.k_proj, model.v_proj])]

# Define a function to identify R1 placement
def r1_placement(model):
    return [(model.embed_tokens, False),
            (model.q_proj, True),
            (model.k_proj, True),
            (model.v_proj, True),
            (model.o_proj, False),
            (model.gate_proj, True),
            (model.up_proj, True),
            (model.down_proj, False),
            (model.lm_head, True)]

# Register the MyModel type with RMSNorm fusion and R1 placement functions
spinquant_optimizer.SUPPORTED_MODULE_DICT[MyModel] = {
    spinquant_optimizer.RMSNORM_LINEAR_PAIRS: rmsnorm_fusion_pairs,
    spinquant_optimizer.R1_FUSION_PAIRS: r1_placement,
}
Step 2¶
Apply SpinQuant to the model.
from aimet_torch.experimental.spinquant import apply_spinquant

# The model is updated in place
apply_spinquant(model=model)
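Because the rotations are folded into existing weights, the rotated model remains numerically equivalent to the original in floating point. Below is a hypothetical sanity check (a sketch; the reference output must be captured before the in-place update, so it wraps the call shown above):
# Capture reference logits BEFORE apply_spinquant, then compare afterwards;
# any difference should be small floating-point noise.
check_ids = torch.randint(0, model_config.vocab_size, (1, 8), device=device)
with torch.no_grad():
    reference_logits = model(check_ids)[0]

apply_spinquant(model=model)

with torch.no_grad():
    rotated_logits = model(check_ids)[0]
print("max abs diff:", (reference_logits - rotated_logits).abs().max().item())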
Step 3¶
The subsequent steps are not strictly part of SpinQuant, but serve as an example of how to quantize and evaluate the model. Use AIMET's quantization simulation to create a QuantizationSimModel object.
from aimet_common.defs import QuantScheme
from aimet_torch.quantsim import QuantizationSimModel
seq_length = 2048
input_ids = torch.randint(0, model_config.vocab_size, (1, seq_length), device=device)
attention_mask = torch.ones((1, seq_length), dtype=torch.long, device=device)
dummy_input = (input_ids, attention_mask)
sim = QuantizationSimModel(model,
                           dummy_input=dummy_input,
                           quant_scheme=QuantScheme.training_range_learning_with_tf_init,
                           default_param_bw=4,
                           default_output_bw=16,
                           in_place=True)
Step 4¶
Instantiate a dataloader and compute encodings for the remaining parameters of the model.
from itertools import chain

from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import default_data_collator

def tokenize(examples):
    seq_length = 2048
    examples = tokenizer(examples["text"])
    concatenated_examples = {k: list(chain(*examples[k])) for k in examples.keys()}
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    if total_length >= seq_length:
        total_length = (total_length // seq_length) * seq_length
    result = {
        k: [t[i : i + seq_length] for i in range(0, total_length, seq_length)]
        for k, t in concatenated_examples.items()
    }
    result["labels"] = result["input_ids"].copy()
    return result

train_dataset = load_dataset(path='wikitext', name='wikitext-2-raw-v1', split='train').map(tokenize, batched=True, remove_columns=['text'])
train_dataloader = DataLoader(train_dataset, shuffle=False, batch_size=1, collate_fn=default_data_collator)
# Custom class to use a limited number of samples from the dataloader
class LimitedBatchDataLoader:
    """Internal helper class to reduce the number of accessible batches in the dataloader"""
    def __init__(self, dataloader, num_batches):
        self.dataloader = dataloader
        self.num_batches = num_batches
        self.current_batch = 0

    def __iter__(self):
        # pylint: disable=attribute-defined-outside-init
        self.iterator = iter(self.dataloader)
        self.current_batch = 0
        return self

    def __next__(self):
        if self.current_batch < self.num_batches:
            self.current_batch += 1
            return next(self.iterator)
        raise StopIteration

    def __len__(self):
        return min(len(self.dataloader), self.num_batches)
def calibration_wrapper(model, dataloader):
    for batch in dataloader:
        batch = tuple(d.to(device) for d in batch.values())
        model(*batch)

# Compute the quantization encodings
# Compute encodings for all activations and parameters of uninitialized layer(s)/operation(s)
sim.compute_encodings(calibration_wrapper, LimitedBatchDataLoader(train_dataloader, num_batches=40))
Step 5¶
At this point, the quantized model is ready to be evaluated.
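For example, here is a minimal perplexity evaluation sketch (assumptions: the wikitext test split, reuse of the tokenize function and LimitedBatchDataLoader helper from Step 4, and return_dict=False so the model returns a tuple with logits first):
import math

test_dataset = load_dataset(path='wikitext', name='wikitext-2-raw-v1', split='test').map(tokenize, batched=True, remove_columns=['text'])
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=1, collate_fn=default_data_collator)

loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
total_loss, total_tokens = 0.0, 0
with torch.no_grad():
    for batch in LimitedBatchDataLoader(test_dataloader, num_batches=40):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        logits = sim.model(input_ids, attention_mask)[0]
        # Next-token prediction: shift logits and labels by one position
        shift_logits = logits[:, :-1, :].reshape(-1, logits.size(-1))
        shift_labels = input_ids[:, 1:].reshape(-1)
        total_loss += loss_fn(shift_logits, shift_labels).item()
        total_tokens += shift_labels.numel()

print(f"Perplexity: {math.exp(total_loss / total_tokens):.2f}")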
API¶
Top-level APIs
- aimet_torch.experimental.spinquant.apply_spinquant(model)[source]¶
Apply SpinQuant to the model, modifying weights in place (https://arxiv.org/pdf/2405.16406). Currently, only R1 rotations without optimization are supported.
- Parameters:
model (Module) – The model to apply SpinQuant to.