aimet_torch.experimental.omniquant

Top-level APIs

aimet_torch.experimental.omniquant.apply_omniquant(quant_sim, dataloader, forward_fn, num_iterations=800, output_path='./aimet_omniquant_artifact/')

Returns the model with Omniquant weights and saves the metadata in safetensors format to output_path. The metadata safetensors file can be passed to update_lora_weights to update the LoRA adapter weights of a PEFT LoRA model.

Parameters:
  • quant_sim (QuantizationSimModel) – QuantizationSimModel object to optimize with Omniquant.

  • dataloader – Dataloader used to train the model.

  • forward_fn (Callable) – Model forward function used to cache intermediate data. It is expected to take the model and the inputs as arguments, e.g. lambda model, inputs: model(*inputs)

  • num_iterations (int) – Number of iterations for which each block is trained with Omniquant.

  • output_path (str) – Path at which to save the {layer_name: scale} metadata as a safetensors file.
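The forward_fn contract above can be illustrated without AIMET. The sketch below is a minimal, hypothetical example: ToyModel, the list used as a stand-in dataloader, and run_batches are illustrative names only, and the loop is an assumption about how a caching pass might call forward_fn, not a description of apply_omniquant's internals. It shows that forward_fn receives the model plus one batch and unpacks that batch as positional arguments.

```python
from typing import Any, Callable, Iterable, List, Sequence


class ToyModel:
    """Hypothetical stand-in for a model: any callable taking positional inputs."""

    def __init__(self, weight: float) -> None:
        self.weight = weight

    def __call__(self, x: float, bias: float = 0.0) -> float:
        return self.weight * x + bias


# Matches the documented example: receive the model and one batch of inputs,
# then unpack the batch as positional arguments to the model.
forward_fn: Callable[[Any, Sequence[Any]], Any] = lambda model, inputs: model(*inputs)

# Hypothetical stand-in for a dataloader: an iterable of input tuples.
dataloader: Iterable[tuple] = [(1.0,), (2.0, 0.5), (3.0, -1.0)]


def run_batches(model: Any, loader: Iterable[tuple], fn: Callable) -> List[Any]:
    """Call fn(model, batch) once per batch (assumed caching-style loop)."""
    return [fn(model, batch) for batch in loader]


outputs = run_batches(ToyModel(weight=2.0), dataloader, forward_fn)
print(outputs)  # [2.0, 4.5, 5.0]
```

Because model(*inputs) unpacks whatever the dataloader yields, forward_fn should be written to match the batch structure of your own dataloader; for example, a bare PyTorch tensor would be unpacked along its first dimension rather than passed as a single argument.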

Returns:

Model with Omniquant weights.