aimet_torch.v1.seq_mse

Note

This module is also available in the default aimet_torch namespace with the same top-level API.

Top level APIs

aimet_torch.v1.seq_mse.apply_seq_mse(model, sim, data_loader, params, modules_to_exclude=None, checkpoints_config=None)

Sequentially minimizes the activation MSE loss, layer by layer, to determine optimal parameter quantization encodings.

1. Disable all input/output quantizers and the param quantizers of non-supported modules.
2. Find and freeze the optimal parameter encoding candidate for the remaining supported modules.
3. Re-enable the quantizers disabled in step 1.

Example userflow:

    model = Model().eval()
    sim = QuantizationSimModel(…)
    apply_seq_mse(…)
    sim.compute_encodings(…)  # compute encodings for all activations and parameters of non-supported modules
    sim.export(…)

NOTE:

1. Module references passed to modules_to_exclude must come from the FP32 model.
2. Modules listed in modules_to_exclude will not be quantized and will be skipped when applying Sequential MSE.
3. Apart from the parameter encodings found for supported modules, the config JSON file is respected and the final state of sim is unchanged.

Parameters:
  • model (Module) – Original fp32 model

  • sim (QuantizationSimModel) – Corresponding QuantizationSimModel object

  • data_loader (DataLoader) – Data loader

  • params (SeqMseParams) – Sequential MSE parameters

  • modules_to_exclude (Optional[List[Module]]) – List of supported type module(s) to exclude when applying Sequential MSE

  • checkpoints_config (Optional[str]) – Config file used to split the fp32/quant model by checkpoints to speed up activation sampling
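
The example userflow above can be sketched as a single helper. This is a minimal sketch, not an official recipe: the helper name `quantize_with_seq_mse`, the `dummy_input` and `export_dir` arguments, the file prefix, the choice of `num_batches=4`, and the assumption that the data loader yields tensors the model accepts directly are all illustrative. The `QuantizationSimModel` import path and its `compute_encodings` / `export` call shapes are assumed to follow the v1 quantsim API; check them against the installed AIMET version.

```python
def quantize_with_seq_mse(model, data_loader, dummy_input, export_dir,
                          modules_to_exclude=None):
    """Hypothetical helper that follows the documented userflow step by step."""
    # Imports are deferred so this sketch can be read without AIMET installed.
    import torch
    from aimet_torch.v1.quantsim import QuantizationSimModel  # assumed import path
    from aimet_torch.v1.seq_mse import SeqMseParams, apply_seq_mse

    # The FP32 model must be in eval mode before the sim is created.
    model = model.eval()
    sim = QuantizationSimModel(model, dummy_input=dummy_input)

    # Find and freeze parameter encodings for the supported modules.
    params = SeqMseParams(num_batches=4)  # illustrative value
    apply_seq_mse(model, sim, data_loader, params,
                  modules_to_exclude=modules_to_exclude)

    # Compute the remaining encodings: all activations, plus parameters of
    # modules that Sequential MSE does not support.
    def forward_pass(sim_model, _unused):
        with torch.no_grad():
            for batch in data_loader:  # assumes batches are plain input tensors
                sim_model(batch)

    sim.compute_encodings(forward_pass, None)
    sim.export(export_dir, "model_seq_mse", dummy_input)  # prefix is illustrative
    return sim
```

Any module passed in `modules_to_exclude` has to be taken from `model` (the FP32 model), not from `sim.model`, as the note above states.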

Sequential MSE parameters

class aimet_torch.v1.seq_mse.SeqMseParams(num_batches, num_candidates=20, inp_symmetry='symqt', loss_fn='mse', forward_fn=<function default_forward_fn>)

Sequential MSE parameters

Parameters:
  • num_batches (int) – Number of batches.

  • num_candidates (int) – Number of candidates to perform grid search. Default 20.

  • inp_symmetry (str) – Input symmetry. Available options are ‘asym’, ‘symfp’ and ‘symqt’. Default ‘symqt’.

  • loss_fn (str) – Loss function. Available options are ‘mse’, ‘l1’ and ‘sqnr’. Default ‘mse’.

  • forward_fn (Callable) – Optional adapter function that performs forward pass given a model and inputs yielded from the data loader. The function expects model as first argument and inputs to model as second argument.
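
To make the roles of num_candidates and loss_fn concrete, here is a toy grid search in plain Python. It is only a sketch of the general idea (try several clipping ranges for a weight vector and keep the one whose layer output is closest to the unquantized output) and is not AIMET's implementation. The function names, the 4-bit symmetric quantizer, and the way candidates are spaced are assumptions made for illustration; the 'sqnr' option and inp_symmetry are left out.

```python
def quantize_dequantize(values, max_abs, bitwidth=4):
    """Symmetric quantize-dequantize of a list of floats (illustrative only)."""
    levels = 2 ** (bitwidth - 1) - 1          # e.g. 7 for 4 bits
    scale = max_abs / levels
    return [max(-levels, min(levels, round(v / scale))) * scale for v in values]


def dot(weights, row):
    return sum(w * x for w, x in zip(weights, row))


def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def l1(a, b):
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


LOSSES = {"mse": mse, "l1": l1}


def search_weight_range(weights, inputs, num_candidates=20, loss_fn="mse"):
    """Return (best_max, best_loss) over num_candidates clipping ranges."""
    loss = LOSSES[loss_fn]
    full_max = max(abs(w) for w in weights)
    if full_max == 0:
        return 0.0, 0.0                        # nothing to search for all-zero weights
    # Reference output of a one-neuron "layer" with unquantized weights.
    reference = [dot(weights, row) for row in inputs]
    best = None
    for i in range(1, num_candidates + 1):
        candidate_max = full_max * i / num_candidates
        quantized = quantize_dequantize(weights, candidate_max)
        outputs = [dot(quantized, row) for row in inputs]
        err = loss(reference, outputs)
        if best is None or err < best[1]:
            best = (candidate_max, err)
    return best
```

Because the full weight range is always one of the candidates, the selected range can never produce a larger loss than plain min-max clipping on the sampled inputs. More candidates give a finer grid at the cost of more quantize-and-compare passes.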

forward_fn(model, inputs)

Default forward function.

Parameters:
  • model – PyTorch model

  • inputs – Model inputs
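
A custom forward_fn is useful when the data loader yields something the model cannot consume directly. The two adapters below are a minimal sketch; their names and the assumed batch layouts, an (inputs, labels) tuple and a dict of keyword arguments, are hypothetical. Both follow the documented contract of taking the model first and the yielded batch second.

```python
def tuple_batch_forward_fn(model, inputs):
    """Adapter for loaders that yield (features, labels) pairs (assumed layout)."""
    features, _labels = inputs   # labels are not needed for the forward pass
    return model(features)


def dict_batch_forward_fn(model, inputs):
    """Adapter for loaders that yield a dict of keyword arguments (assumed layout)."""
    return model(**inputs)
```

Such an adapter would then be passed through the documented parameter, for example `SeqMseParams(num_batches=4, forward_fn=tuple_batch_forward_fn)`, where the value of `num_batches` is illustrative.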