aimet_torch.seq_mse

Top level APIs

aimet_torch.seq_mse.apply_seq_mse(*args, **kwargs)

Sequentially minimizes the activation MSE loss, layer by layer, to determine optimal parameter quantization encodings.

  1. Disable all input/output quantizers, as well as the parameter quantizers of non-supported modules.

  2. Find and freeze the optimal parameter encoding candidates for the remaining supported modules.

  3. Re-enable the quantizers disabled in step 1.
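The candidate search in step 2 can be illustrated with a framework-free sketch. It is not AIMET's implementation: it assumes symmetric per-output-channel weight quantization, candidate clipping ranges spaced evenly as fractions of each channel's max |w|, and an MSE loss on that channel's outputs. It also uses one fixed set of layer inputs, whereas the real flow visits layers in order. The helper names (quantize, channel_output, search_channel_scale, search_layer_scales) are hypothetical.

```python
def quantize(value, scale, qmax):
    """Symmetric fake-quantization: scale, round, clamp to the signed grid, dequantize."""
    q = max(-qmax - 1, min(qmax, round(value / scale)))
    return q * scale


def channel_output(weights, inputs):
    """Outputs of one linear output channel (no bias) for each input vector."""
    return [sum(w * x for w, x in zip(weights, row)) for row in inputs]


def search_channel_scale(weights, inputs, num_candidates=20, bitwidth=8):
    """Grid-search the scale whose quantized weights best reproduce the FP32 outputs."""
    qmax = 2 ** (bitwidth - 1) - 1
    max_abs = max(abs(w) for w in weights)
    if max_abs == 0:
        return 1.0, 0.0  # all-zero channel: any scale reproduces it exactly
    reference = channel_output(weights, inputs)
    best_scale, best_loss = None, float("inf")
    for i in range(1, num_candidates + 1):
        # Candidate i clips the weight range to i/num_candidates of the max magnitude.
        scale = (max_abs * i / num_candidates) / qmax
        q_weights = [quantize(w, scale, qmax) for w in weights]
        output = channel_output(q_weights, inputs)
        loss = sum((r - o) ** 2 for r, o in zip(reference, output)) / len(reference)
        if loss < best_loss:
            best_scale, best_loss = scale, loss
    return best_scale, best_loss


def search_layer_scales(weight_matrix, inputs, num_candidates=20, bitwidth=8):
    """Search every output channel (one row of weight_matrix) independently."""
    return [search_channel_scale(row, inputs, num_candidates, bitwidth)[0]
            for row in weight_matrix]


# Usage: two output channels, two input vectors of length 3.
weights = [[0.5, -0.25, 1.0], [0.02, -0.9, 0.3]]
layer_inputs = [[1.0, 2.0, 3.0], [-1.0, 0.5, 0.25]]
print(search_layer_scales(weights, layer_inputs, num_candidates=10))
```

Because the full-range candidate is always part of the grid, the selected scale is never worse (by this loss) than plain min/max calibration of the weights.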

Example user flow:

    model = Model().eval()
    sim = QuantizationSimModel(…)
    apply_seq_mse(…)
    # Compute encodings for all activations and for parameters of non-supported modules
    sim.compute_encodings(…)
    sim.export(…)

NOTE: Modules listed in modules_to_exclude are skipped when applying Sequential MSE and won't be quantized by it.

Parameters:
  • sim – QuantizationSimModel object

  • data_loader – Data loader

  • num_candidates – Number of candidate encodings to evaluate for each layer

  • forward_fn – Callback function that performs a forward pass; it accepts the model and the inputs as arguments

  • modules_to_exclude – List of modules of a supported type to exclude when applying Sequential MSE

  • checkpoints_config – Configuration used to split the FP32 and quantized models at checkpoints, which speeds up activation sampling
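forward_fn lets Sequential MSE run the model on batches whose layout it does not know in advance. A minimal sketch, assuming a data loader that yields (inputs, labels) pairs; forward_fn here, toy_model and toy_loader are hypothetical plain-Python stand-ins (not PyTorch objects) used only to show the calling convention of model first, inputs second.

```python
# Hypothetical adapter: assumes each data-loader batch is an (inputs, labels)
# pair and that the model only needs the inputs.
def forward_fn(model, batch):
    inputs, _labels = batch  # labels are not needed to collect activations
    return model(inputs)


# Toy stand-ins for a model and a data loader, only to exercise the adapter.
def toy_model(x):
    return [2 * v for v in x]


toy_loader = [([1.0, 2.0], 0), ([3.0, 4.0], 1)]
outputs = [forward_fn(toy_model, batch) for batch in toy_loader]
print(outputs)
```

An adapter of this shape would then be supplied as the forward_fn argument.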

Sequential MSE parameters

class aimet_torch.seq_mse.SeqMseParams(num_batches, num_candidates=20, inp_symmetry='symqt', loss_fn='mse', forward_fn=<function default_forward_fn>)

Sequential MSE parameters

Parameters:
  • num_batches (Optional[int]) – Number of batches to use from the data loader.

  • num_candidates (int) – Number of candidate encodings to evaluate in the grid search. Default 20.

  • inp_symmetry (str) – Input symmetry. Available options are ‘asym’, ‘symfp’ and ‘symqt’. Default ‘symqt’.

  • loss_fn (str) – Loss function. Available options are ‘mse’, ‘l1’ and ‘sqnr’. Default ‘mse’.

  • forward_fn (Callable) – Optional adapter function that performs a forward pass given a model and the inputs yielded from the data loader. The function expects the model as its first argument and the model inputs as its second argument.
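The documented fields and defaults can be mirrored in a plain dataclass to show how the options fit together. SeqMseParamsSketch and placeholder_forward_fn are hypothetical names for illustration only; this is not the aimet_torch class, and the validation in __post_init__ is an assumption rather than documented behavior.

```python
from dataclasses import dataclass
from typing import Callable, Optional

INP_SYMMETRY_OPTIONS = ("asym", "symfp", "symqt")
LOSS_FN_OPTIONS = ("mse", "l1", "sqnr")


def placeholder_forward_fn(model, inputs):
    """Hypothetical stand-in for the library's default forward function."""
    return model(inputs)


@dataclass
class SeqMseParamsSketch:
    """Hypothetical mirror of the documented SeqMseParams fields and defaults."""
    num_batches: Optional[int]
    num_candidates: int = 20
    inp_symmetry: str = "symqt"
    loss_fn: str = "mse"
    forward_fn: Callable = placeholder_forward_fn

    def __post_init__(self):
        # Reject values outside the documented option sets early.
        if self.inp_symmetry not in INP_SYMMETRY_OPTIONS:
            raise ValueError(f"inp_symmetry must be one of {INP_SYMMETRY_OPTIONS}")
        if self.loss_fn not in LOSS_FN_OPTIONS:
            raise ValueError(f"loss_fn must be one of {LOSS_FN_OPTIONS}")
        if self.num_candidates < 1:
            raise ValueError("num_candidates must be positive")


params = SeqMseParamsSketch(num_batches=4, loss_fn="l1")
print(params)
```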

forward_fn(model, inputs)

Default forward function.

Parameters:
  • model – PyTorch model

  • inputs – Model inputs
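One plausible behavior for a default forward function is sketched below, assuming that a list or tuple of inputs is unpacked into positional arguments and anything else is passed as a single argument. The name default_forward_fn_sketch is hypothetical, and the actual AIMET default may handle batches differently.

```python
def default_forward_fn_sketch(model, inputs):
    """Hypothetical default: unpack a list/tuple batch into positional arguments."""
    if isinstance(inputs, (list, tuple)):
        return model(*inputs)  # e.g. a batch of (input_a, input_b) for a two-input model
    return model(inputs)  # a single tensor-like input


print(default_forward_fn_sketch(lambda a, b: a + b, (2, 3)))
```

Under this sketch, batches that bundle labels with the inputs would need a custom forward_fn that drops the labels first.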

get_loss_fn()

Returns the loss function selected by loss_fn.

Return type:

Callable
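The three loss_fn options can be illustrated on flat lists of activations. A sketch under stated assumptions: mse and l1 are the usual mean squared and mean absolute errors, while the exact way AIMET turns SQNR into a quantity to minimize is not documented here, so sqnr_loss assumes the noise-to-signal power ratio (inverse linear SQNR). All function names are hypothetical.

```python
def mse_loss(x, y):
    """Mean squared error between reference activations x and quantized activations y."""
    return sum((a - b) ** 2 for a, b in zip(x, y)) / len(x)


def l1_loss(x, y):
    """Mean absolute error."""
    return sum(abs(a - b) for a, b in zip(x, y)) / len(x)


def sqnr_loss(x, y, eps=1e-12):
    """Assumed SQNR-based loss: noise power / signal power, so lower is better."""
    noise = sum((a - b) ** 2 for a, b in zip(x, y))
    signal = sum(a ** 2 for a in x)
    return noise / (signal + eps)  # eps guards against an all-zero reference


_LOSSES = {"mse": mse_loss, "l1": l1_loss, "sqnr": sqnr_loss}


def get_loss_fn_sketch(name):
    """Map a loss_fn option string to a callable, mirroring the role of get_loss_fn()."""
    try:
        return _LOSSES[name]
    except KeyError:
        raise ValueError(f"Unknown loss_fn {name!r}; expected one of {sorted(_LOSSES)}") from None


print(get_loss_fn_sketch("mse")([1.0, 2.0, 3.0], [1.0, 2.0, 5.0]))
```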