aimet_torch.experimental.omniquant
Top level APIs
- aimet_torch.experimental.omniquant.apply_omniquant(quant_sim, dataloader, forward_fn, num_iterations=800, output_path='./aimet_omniquant_artifact/')
Returns the model with Omniquant weights and saves metadata in safetensors format to output_path. The metadata safetensors file can be passed to update_lora_weights to update the LoRA adapter weights of a PEFT LoRA model.
- Parameters:
quant_sim (QuantizationSimModel) – QuantizationSimModel object to optimize with Omniquant.
dataloader – Dataloader used to train the model.
forward_fn (Callable) – Model forward function used to cache intermediate data. It must take the model and the inputs as its arguments, e.g. lambda model, inputs: model(*inputs).
num_iterations (int) – Number of iterations used to train each block with Omniquant.
output_path (str) – Path where the {layer_name: scale} metadata safetensors file is saved.
- Returns:
Model with Omniquant weights.
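The forward_fn contract can be illustrated without AIMET or a real model. The sketch below is a minimal illustration under stated assumptions: ToyModel and run_calibration are hypothetical stand-ins, not part of aimet_torch, and the assumption is that each dataloader batch is a tuple of positional inputs, so the documented example lambda model, inputs: model(*inputs) unpacks it into the model call. How apply_omniquant iterates the dataloader internally is not specified here.

```python
from typing import Any, Callable, Iterable, List, Sequence, Tuple

# Same shape as the documented example: a callable taking (model, inputs)
# that runs the model on the unpacked positional inputs.
forward_fn: Callable[[Callable[..., Any], Sequence[Any]], Any] = (
    lambda model, inputs: model(*inputs)
)


class ToyModel:
    """Hypothetical stand-in for a torch module; any callable works here."""

    def __init__(self) -> None:
        self.calls: List[Tuple[Any, ...]] = []

    def __call__(self, input_ids: int, attention_mask: int) -> int:
        # Record the positional arguments to show how the batch was unpacked.
        self.calls.append((input_ids, attention_mask))
        return input_ids * attention_mask


def run_calibration(
    model: Callable[..., Any],
    dataloader: Iterable[Sequence[Any]],
    fn: Callable[[Callable[..., Any], Sequence[Any]], Any],
) -> List[Any]:
    """Hypothetical helper: feed every batch through fn, as a caching pass might."""
    return [fn(model, batch) for batch in dataloader]


# Each batch is a tuple of positional inputs, so model(*inputs) unpacks it.
toy_loader = [(2, 3), (4, 5)]
model = ToyModel()
outputs = run_calibration(model, toy_loader, forward_fn)
print(outputs)  # [6, 20]
```

With a real QuantizationSimModel, the same kind of callable would be passed to the documented signature, e.g. apply_omniquant(quant_sim, dataloader, forward_fn, num_iterations=800). If your dataloader yields dictionaries instead of tuples, a callable such as lambda model, inputs: model(**inputs) would be the analogous choice; that variant is an assumption, not something this page documents.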