AdaScale¶
Context¶
AdaScale is a post-training quantization (PTQ) technique that improves the accuracy of the quantized model by computing optimal quantization parameters for weights. AdaScale is based on FlexRound (https://arxiv.org/abs/2306.00317) and integrates Learnable Weight Clipping from OmniQuant (https://arxiv.org/abs/2308.13137).
AdaScale introduces trainable parameters (gamma, beta, s2, s3) in the weight quantizers of every supported module and performs BKD (Blockwise Knowledge Distillation) by comparing the quantized output of every supported block with its FP32 equivalent.
From the AdaScale perspective, a block is defined as a non-leaf module that takes one activation input tensor and produces one activation output tensor.
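Conceptually, the per-block optimization can be pictured as minimizing the reconstruction error between each quantized block and its FP32 counterpart while only the AdaScale parameters are trained. The following is a minimal, illustrative sketch of that idea, not AIMET's implementation; fp32_block, quantized_block, adascale_params, and block_inputs are hypothetical handles.
import torch

def blockwise_kd_sketch(fp32_block, quantized_block, adascale_params, block_inputs, num_epochs, lr=1e-3):
    # Illustrative only: optimize the trainable quantizer parameters (gamma, beta, s2, s3)
    # of one block so its quantized output matches the FP32 output on cached inputs
    optimizer = torch.optim.Adam(adascale_params, lr=lr)
    for _ in range(num_epochs):
        for x in block_inputs:
            with torch.no_grad():
                target = fp32_block(x)  # FP32 reference output of the block
            loss = torch.nn.functional.mse_loss(quantized_block(x), target)
            optimizer.zero_grad()
            loss.backward()   # gradients only reach the AdaScale parameters
            optimizer.step()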
Warning: This feature is currently experimental.
Workflow¶
Prerequisites¶
To use AdaScale, you must:
Use PyTorch. AdaScale does not support other frameworks yet.
Load a pre-trained model
Create a dataloader for the model
Procedure¶
Setup¶
# General setup that can be changed as needed
import torch
from torch.utils.data import DataLoader

# Load the model
device = "cuda:0" if torch.cuda.is_available() else "cpu"
model = ModelWithConsecutiveLinearBlocks().eval().to(device)

# Build a calibration dataloader
num_batches = 32
num_samples = 96
dummy_input = torch.rand(num_samples, 3, 32, 64).to(device)
data_set = CustomDataset(dummy_input)
data_loader = DataLoader(data_set, batch_size=num_samples // num_batches, shuffle=True)
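The snippet above assumes ModelWithConsecutiveLinearBlocks and CustomDataset are defined elsewhere; they are placeholders, not AIMET classes. A minimal sketch of compatible definitions for the shapes used above could look like this:
import torch
from torch.utils.data import Dataset

class ModelWithConsecutiveLinearBlocks(torch.nn.Module):
    # Toy stand-in: each inner Sequential is a non-leaf block with one input and one output
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.Sequential(
            torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()),
            torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU()),
        )

    def forward(self, x):
        return self.blocks(x)

class CustomDataset(Dataset):
    # Wraps a pre-generated tensor so a DataLoader can serve individual samples
    def __init__(self, data):
        self.data = data

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return len(self.data)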
Step 1¶
Use AIMET’s quantization simulation to create a QuantizationSimModel object
from aimet_common.defs import QuantScheme
from aimet_torch.quantsim import QuantizationSimModel

sim = QuantizationSimModel(model,
                           dummy_input=dummy_input,
                           quant_scheme=QuantScheme.training_range_learning_with_tf_init,
                           default_param_bw=4,
                           default_output_bw=16)
Step 2¶
Apply AdaScale to compute optimal quantization encodings for the parameters of supported layers
# Find and freeze optimal encoding candidates for weight parameters of supported layers
from aimet_torch.experimental.adascale import apply_adascale
from aimet_torch.v2.utils import default_forward_fn

apply_adascale(qsim=sim,
               data_loader=data_loader,
               forward_fn=default_forward_fn,
               num_epochs=10)
Step 3¶
Compute encodings for the activations and remaining parameters of the model
def forward_pass(model: torch.nn.Module, _):
    with torch.no_grad():
        for data in data_loader:
            model(data)

# Compute the quantization encodings
# (calibrates all activations and parameters of uninitialized layer(s)/operation(s))
sim.compute_encodings(forward_pass, None)
Step 4¶
Evaluate the quantized model
# Determine simulated quantized accuracy
...
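A minimal sketch of this step, assuming a labeled evaluation dataloader (eval_data_loader below is hypothetical; the calibration loader built in the setup has no labels) and a simple classification accuracy metric:
def evaluate(model: torch.nn.Module, eval_data_loader) -> float:
    # Placeholder evaluation loop; replace with the task's real metric and data
    correct = total = 0
    model.eval()
    with torch.no_grad():
        for inputs, labels in eval_data_loader:
            predictions = model(inputs.to(device)).argmax(dim=-1)
            correct += (predictions.cpu() == labels).sum().item()
            total += labels.numel()
    return correct / total

accuracy = evaluate(sim.model, eval_data_loader)  # eval_data_loader: hypothetical labeled loader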
Step 5¶
If the resulting quantized accuracy is satisfactory, export the model.
# Export the model for on-target inference
path = './'
filename = 'dummy_model'
sim.export(path=path, filename_prefix="quantized_" + filename, dummy_input=dummy_input.cpu())
API¶
Top level APIs
- aimet_torch.experimental.adascale.apply_adascale(qsim, data_loader, forward_fn=None, num_epochs=1)¶
  - Parameters:
    - qsim (QuantizationSimModel) – Quantization Sim model
    - data_loader (DataLoader) – DataLoader object to load the input data
    - forward_fn (Optional[Callable[[Module, Any], Any]]) – forward function to run the forward pass of the model
    - num_epochs (int) – Number of epochs to perform the AdaScale BKD
Note that the forward_fn should take exactly two arguments: 1) the model, and 2) the object returned from the dataloader, irrespective of whether it is a tensor, a tuple of tensors, a dict, etc.
The forward_fn should prepare the input sample as needed and call the forward pass at the very end. It should not run any evaluation, create a full dataloader inside the method, etc.
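For illustration, a custom forward_fn that satisfies this contract might look like the sketch below; the batch unpacking depends on what the dataloader yields, and device is assumed to be defined as in the setup above.
def forward_fn(model: torch.nn.Module, batch):
    # `batch` is one object yielded by the dataloader; prepare the input and
    # end with a single forward pass -- no evaluation, no extra data loading
    inputs = batch[0] if isinstance(batch, (tuple, list)) else batch
    return model(inputs.to(device))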
- Example usage:
>>> model = DummyModel()
>>> dummy_input = ...
>>> data_set = DataSet(dummy_input)
>>> data_loader = DataLoader(data_set, ...)
>>> sim = QuantizationSimModel(model, dummy_input)
>>> apply_adascale(sim, data_loader, forward_fn=forward_fn, num_epochs=10)
>>> sim.compute_encodings(...)
>>> sim.export(...)
apply_adascale modifies the model's weights in place.
compute_encodings should not be called before calling apply_adascale.
Activation quantizers remain uninitialized throughout this procedure, so the user needs to call compute_encodings afterwards. This ensures that activation encodings are computed with the updated weights taken into account.
Warning: This feature is currently considered experimental pending API changes