Source code for aimet_torch.v2.visualization_tools.quant_stats_visualization

# -*- mode: python -*-
# =============================================================================
#  @@-COPYRIGHT-START-@@
#
#  Copyright (c) 2023-2023, Qualcomm Innovation Center, Inc. All rights reserved.
#
#  Redistribution and use in source and binary forms, with or without
#  modification, are permitted provided that the following conditions are met:
#
#  1. Redistributions of source code must retain the above copyright notice,
#     this list of conditions and the following disclaimer.
#
#  2. Redistributions in binary form must reproduce the above copyright notice,
#     this list of conditions and the following disclaimer in the documentation
#     and/or other materials provided with the distribution.
#
#  3. Neither the name of the copyright holder nor the names of its contributors
#     may be used to endorse or promote products derived from this software
#     without specific prior written permission.
#
#  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
#  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
#  IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
#  ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
#  LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
#  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
#  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
#  INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
#  CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
#  ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
#  POSSIBILITY OF SUCH DAMAGE.
#
#  SPDX-License-Identifier: BSD-3-Clause
#
#  @@-COPYRIGHT-END-@@
# =============================================================================

""" Tool to visualize min and max activations/weights of quantized modules in a given model"""


import os
from pathlib import Path
import torch
from bokeh.events import DocumentReady, Reset
from bokeh.layouts import row, column
from bokeh.models import ColumnDataSource, TextInput, CustomJS, Range1d, HoverTool, CustomJSHover, Div, \
    BooleanFilter, CDSView, Spacer, DataTable, StringFormatter, ScientificFormatter, TableColumn, Tooltip, Select, \
    Whisker, FactorRange
from bokeh.models.tools import ResetTool
from bokeh.models.dom import HTML
from bokeh.plotting import figure, save, curdoc
from aimet_torch.v2.quantsim import QuantizationSimModel
from aimet_torch.utils import get_ordered_list_of_modules
from aimet_torch.v2.quantization.base import QuantizerBase
from aimet_torch.v2.quantization.encoding_analyzer import _MinMaxObserver, _HistogramObserver


def _visualize(sim: QuantizationSimModel, dummy_input, mode: str, save_path: str = "./quant_stats_visualization.html", percentile_list: list = None) -> None:
    """
    Helper function for the visualization APIs.

    :param sim: Calibrated QuantSim Object.
    :param dummy_input: Dummy Input.
    :param mode: Whether to plot basic or advanced stats.
    :param save_path: Path for saving the visualization. Format is 'path_to_dir/file_name.html'. Default is './quant_stats_visualization.html'.
    :param percentile_list: Percentiles to be extracted and used.
    """

    # Ensure that sim is an instance of aimet_torch.v1.quantsim.QuantizationSimModel
    if not isinstance(sim, QuantizationSimModel):
        raise TypeError(f"Expected type 'aimet_torch.v2.quantsim.QuantizationSimModel', got '{type(sim)}'.")

    if percentile_list is None:
        raise ValueError("percentile_list cannot be None. Consider providing an empty percentile_list if needed.")

    # Ensure that the save path is valid
    _check_path(save_path)

    # Topologically sort the quantized modules into an ordered list for easier indexing in the plots
    ordered_list = get_ordered_list_of_modules(sim.model, dummy_input)
    stats_list = []

    # Collect stats from observers
    for i in ordered_list:
        module_stats = _get_observer_stats(i, percentile_list=percentile_list)
        if module_stats is not None:
            stats_list.append(module_stats)

    # Raise an error if no stats were found
    if len(stats_list) == 0:
        raise RuntimeError(
            "No stats found to plot. Either there were no quantized modules, or calibration was not performed before calling this function, or no observers of type _MinMaxObserver or _HistogramObserver were present.")


    stats_dict = dict()
    keys_list = ["name", 0, 100] + percentile_list
    stats_dict["idx"] = list(range(len(stats_list)))
    for i in keys_list:
        stats_dict[i] = [None] * len(stats_list)
    for idx, stats in enumerate(stats_list):
        for i in keys_list:
            stats_dict[i][idx] = stats[i]

    if mode == "advanced":
        _get_additional_boxplot_stats(stats_dict)

    visualizer = QuantStatsVisualizer(stats_dict, percentile_list)

    # Save an interactive bokeh plot as a standalone html
    visualizer.export_plot_as_html(save_path, mode)


[docs]def visualize_stats(sim: QuantizationSimModel, dummy_input, save_path: str = "./quant_stats_visualization.html") -> None: """Produces an interactive html to view the stats collected by each quantizer during calibration .. note:: The QuantizationSimModel input is expected to have been calibrated before using this function. Stats will only be plotted for activations/parameters with quantizers containing calibration statistics. Creates an interactive visualization of min and max activations/weights of all quantized modules in the input QuantSim object. The features include: - Adjustable threshold values to flag layers whose min or max activations/weights exceed the set thresholds - Tables containing names and ranges for layers exceeding threshold values Saves the visualization as a .html at the given path. Example: >>> sim = aimet_torch.v2.quantsim.QuantizationSimModel(model, dummy_input, quant_scheme=QuantScheme.post_training_tf) >>> with aimet_torch.v2.nn.compute_encodings(sim.model): ... for data, _ in data_loader: ... sim.model(data) ... >>> visualize_stats(sim, dummy_input, save_path="./quant_stats_visualization.html") :param sim: Calibrated QuantizationSimModel :param dummy_input: Sample input used to trace the model :param save_path: Path for saving the visualization. Default is "./quant_stats_visualization.html" """ percentile_list = [] _visualize(sim, dummy_input, mode="basic", save_path=save_path, percentile_list=percentile_list)
def visualize_advanced_stats(sim: QuantizationSimModel, dummy_input, save_path: str = "./quant_advanced_stats_visualization.html", additional_percentiles: tuple = (1, 99)) -> None: """Produces an interactive html to view the advanced stats collected by each quantizer during calibration .. note:: The QuantizationSimModel input is expected to have been calibrated before using this function. Stats will only be plotted for activations/parameters with quantizers containing calibration statistics. .. note:: For plotting advanced stats, the quantizer encoding analyzers should have observers of type :class:`_HistogramObserver`. If observers are of the type :class:`_MinMaxObserver`, then advanced stats cannot be extracted and only the min and max values are shown in the boxplots. .. note:: In case of Per-channel or Blockwise quantizers, percentiles are not extracted due to the presence of multiple histograms. For these quantizers, only min and max values are shown in the boxplots. Creates an interactive visualization of min and max activations/weights of all quantized modules in the input QuantSim object. The features include: - Adjustable threshold values to flag layers whose min or max activations/weights exceed the set thresholds - Table containing names and ranges for layers exceeding threshold values - Select different views of the table to group layers exceeding threshold values - Filter layers listed in the table by name - Select one or more layers from the table for viewing their boxplots and highlighting them in the main plot Saves the visualization as a .html at the given path. Example: >>> sim = aimet_torch.v2.quantsim.QuantizationSimModel(model, dummy_input, quant_scheme=QuantScheme.post_training_tf_enhanced) >>> with aimet_torch.v2.nn.compute_encodings(sim.model): ... for data, _ in data_loader: ... sim.model(data) ... >>> visualize_advanced_stats(sim, dummy_input, save_path="./quant_advanced_stats_visualization.html", additional_percentiles=(1, 99)) :param sim: Calibrated QuantizationSimModel :param dummy_input: Sample input used to trace the model :param save_path: Path for saving the visualization. Default is "./quant_advanced_stats_visualization.html" :param additional_percentiles: Percentiles other than those related to the boxplot (25, 50, 75) to be shown. """ percentile_list = _add_key_percentiles(additional_percentiles) percentile_list = sorted(percentile_list) _visualize(sim, dummy_input, mode="advanced", save_path=save_path, percentile_list=percentile_list) def _check_path(path: str): """ Function for sanity check on the given path """ path_to_directory = os.path.dirname(path) if path_to_directory != '' and not os.path.exists(path_to_directory): raise NotADirectoryError(f"'{path_to_directory}' is not a directory.") if not path.endswith('.html'): raise ValueError("'save_path' must end with '.html'.") def _get_observer_stats(module, percentile_list): """ Function to extract stats from an observer. Handles observers of types _MinMaxObserver, _HistogramObserver. """ module_name, module_quantizer = module[0], module[1] if isinstance(module_quantizer, QuantizerBase): if isinstance(module_quantizer.encoding_analyzer.observer, _MinMaxObserver): rng = module_quantizer.encoding_analyzer.observer.get_stats() if rng.min is not None: stats = dict() stats["name"] = module_name stats[0] = torch.min(rng.min).item() stats[100] = torch.max(rng.max).item() for p in percentile_list: stats[p] = None return stats elif isinstance(module_quantizer.encoding_analyzer.observer, _HistogramObserver): histogram_list = module_quantizer.encoding_analyzer.observer.get_stats() if len(histogram_list) == 1: histogram = histogram_list[0] if histogram.min is not None: stats = dict() stats["name"] = module_name stats[0] = histogram.min.item() stats[100] = histogram.max.item() _get_advanced_stats_from_histogram(histogram, stats, percentile_list) return stats elif len(histogram_list) > 1: stats = dict() stats["name"] = module_name curmin = float("inf") curmax = float("-inf") for histogram in histogram_list: if histogram.min is not None: curmin = min(curmin, histogram.min.item()) curmax = max(curmax, histogram.max.item()) if curmin < float("inf"): stats[0] = curmin stats[100] = curmax for p in percentile_list: stats[p] = None return stats return None def _add_key_percentiles(percentiles: tuple): """ Add percentiles required for boxplot if not already present """ percentile_list = list(percentiles) for p in [25, 50, 75]: if p not in percentile_list: percentile_list.append(p) return percentile_list def _get_advanced_stats_from_histogram(histogram, stats, percentile_list): """ High level function to extract advanced stats from a histogram object """ if len(percentile_list) > 0: percentile_stats = _get_percentile_stats_from_histogram(histogram, percentile_list) for i, percentile in enumerate(percentile_list): stats[percentile] = percentile_stats[i] def _get_percentile_stats_from_histogram(histogram, percentile_list): """ Function to extract percentile stats from a histogram object """ if len(percentile_list) == 0: raise RuntimeError("'percentile_list' cannot be empty.'") if not _is_sorted(percentile_list): raise RuntimeError("'percentile_list' must be sorted before calling this function.") n = torch.sum(histogram.histogram).item() cum_f = 0 idx = 0 percentile_stats = [] for i in range(len(histogram.histogram)): f = histogram.histogram[i].item() if f > 0: bin_low = histogram.bin_edges[i].item() bin_high = histogram.bin_edges[i + 1].item() while True: if (cum_f + f) / n >= percentile_list[idx] / 100: percentile_stats.append(bin_low + ((n * percentile_list[idx] / 100 - cum_f) / f) * (bin_high - bin_low)) idx += 1 if idx == len(percentile_list): return percentile_stats else: break cum_f += f return None def _get_additional_boxplot_stats(stats_dict: dict): """ Get additional values required to plot a boxplot""" stats_dict["stridx"] = [] stats_dict["boxplot_upper"] = [] stats_dict["boxplot_lower"] = [] for i in range(len(stats_dict["idx"])): stats_dict["stridx"].append(str(stats_dict["idx"][i])) if (stats_dict[25][i] is not None) and (stats_dict[75][i] is not None): inter_quantile_range = stats_dict[75][i] - stats_dict[25][i] stats_dict["boxplot_upper"].append(stats_dict[75][i] + 1.5 * inter_quantile_range) stats_dict["boxplot_lower"].append(stats_dict[25][i] - 1.5 * inter_quantile_range) else: stats_dict["boxplot_upper"].append(None) stats_dict["boxplot_lower"].append(None) def _is_sorted(arr: list): for i in range(len(arr) - 1): if arr[i] > arr[i + 1]: return False return True class DataSources: """ Class to hold the Bokeh ColumnDataSource objects needed in the visualization. """ def __init__(self, stats_dict: dict, plot: figure, default_values: dict, percentiles: list ): self.data_source = ColumnDataSource( data=dict(idx=stats_dict["idx"], namelist=stats_dict["name"], minlist=stats_dict[0], min_namelist=["Min"] * len(stats_dict["idx"]), maxlist=stats_dict[100], max_namelist=["Max"] * len(stats_dict["idx"]), marker_yminlist=[default_values['default_ymin']] * len(stats_dict["idx"]), marker_ymaxlist=[default_values['default_ymax']] * len(stats_dict["idx"]), selected=[False] * len(stats_dict["idx"]))) if "stridx" in stats_dict.keys(): self.data_source.add(data=stats_dict["stridx"], name="stridx") if "boxplot_upper" in stats_dict.keys(): self.data_source.add(data=stats_dict["boxplot_upper"], name="boxplot_upper_list") if "boxplot_lower" in stats_dict.keys(): self.data_source.add(data=stats_dict["boxplot_lower"], name="boxplot_lower_list") for key in [25, 50, 75]: if key in stats_dict: self.data_source.add(data=stats_dict[key], name=str(key) + "%ilelist") for key in percentiles: self.data_source.add(data=stats_dict[key], name=str(key) + "%ilelist") self.data_source.add(data=[str(key)+" %ile" for _ in range(len(stats_dict["idx"]))], name=str(key) + "%ile_namelist") self.default_values_source = ColumnDataSource( data=dict(default_ymax=[default_values['default_ymax']], default_ymin=[default_values['default_ymin']], default_maxclip=[default_values['default_maxclip']], default_minclip=[default_values['default_minclip']], default_xmax=[default_values['default_xmax']], default_xmin=[default_values['default_xmin']])) self.limits_source = ColumnDataSource( data=dict(ymax=[default_values['default_ymax']], ymin=[default_values['default_ymin']], xmin=[plot.x_range.start], xmax=[plot.x_range.end], minclip=[default_values['default_minclip']], maxclip=[default_values['default_maxclip']])) self.table_data_source = ColumnDataSource( data=dict(idx=[], namelist=[], minlist=[], maxlist=[])) self.selected_data_source = ColumnDataSource( data=dict(idx=[], namelist=[], floor=[], ceil=[], minlist=[], min_namelist=[], maxlist=[], max_namelist=[]) ) if "stridx" in stats_dict.keys(): self.selected_data_source.add(data=[], name="stridx") if "boxplot_upper" in stats_dict.keys(): self.selected_data_source.add(data=[], name="boxplot_upper_list") if "boxplot_lower" in stats_dict.keys(): self.selected_data_source.add(data=[], name="boxplot_lower_list") for key in [25, 50, 75]: if key in stats_dict: self.selected_data_source.add(data=[], name=str(key) + "%ilelist") for key in percentiles: self.selected_data_source.add(data=[], name=str(key) + "%ilelist") self.selected_data_source.add(data=[], name=str(key) + "%ile_namelist") class TableFilters: """ Class for holding data filters. """ def __init__(self, data_sources: DataSources): self.name_filter = BooleanFilter() self.name_filter.booleans = [True for _ in range(len(data_sources.data_source.data['idx']))] self.min_thresh_filter = BooleanFilter() self.min_thresh_filter.booleans = [True for _ in range(len(data_sources.data_source.data['idx']))] self.max_thresh_filter = BooleanFilter() self.max_thresh_filter.booleans = [True for _ in range(len(data_sources.data_source.data['idx']))] class TableViews: """ Class for holding views of the data sources. """ def __init__(self, tablefilters: TableFilters): self.min_thresh_view = CDSView(filter=tablefilters.min_thresh_filter) self.max_thresh_view = CDSView(filter=tablefilters.max_thresh_filter) class TableObjects: """ Class for holding various objects related to the table elements in the visualization. """ def __init__(self, datasources: DataSources): self.filters = TableFilters(datasources) self.views = TableViews(self.filters) columns = [ TableColumn(field="idx", title="Layer Index", width=QuantStatsVisualizer.table_column_widths["Layer Index"]), TableColumn(field="namelist", title="Layer Name", formatter=StringFormatter(font_style="bold"), width=QuantStatsVisualizer.table_column_widths["Layer Name"]), TableColumn(field="minlist", title="Min Activation", formatter=ScientificFormatter(precision=3), width=QuantStatsVisualizer.table_column_widths["Min Activation"]), TableColumn(field="maxlist", title="Max Activation", formatter=ScientificFormatter(precision=3), width=QuantStatsVisualizer.table_column_widths["Max Activation"]), ] self.data_table = DataTable(source=datasources.table_data_source, columns=columns, sortable=True, width=QuantStatsVisualizer.plot_dims["table_width"], selectable="checkbox", index_position=None, ) class InputWidgets: """ Class to hold various input widgets. """ def __init__(self, default_values: dict): self.ymin_input = TextInput(value=str(default_values['default_ymin']), title="Enter lower display limit of the plot") self.ymax_input = TextInput(value=str(default_values['default_ymax']), title="Enter upper display limit of the plot") self.minclip_input = TextInput(value=str(default_values['default_minclip']), title="Enter lower threshold value for activations/weights") self.maxclip_input = TextInput(value=str(default_values['default_maxclip']), title="Enter upper threshold value for activations/weights") self.name_input = TextInput(value="", title="Enter Name Filter") tooltip_table_mode = Tooltip(content=HTML(""" <h3> Select Table View </h3> <p> Following table views are available </p> <ol> <li> <b> All: </b> All quantized layers </li> <li> <b> Min: </b> Quantized layers with min activation below lower threshold value </li> <li> <b> Max: </b> Quantized layers with max activation above upper threshold value </li> <li> <b> Min | Max: </b> Union of Min and Max </li> <li> <b> Min & Max: </b> Intersection of Min and Max </li> </ol> """), position="right") self.table_view_select = Select(title="Select Table View", value="Min | Max", options=["All", "Min", "Max", "Min | Max", "Min & Max"], width=200, description=tooltip_table_mode ) class CustomCallbacks: """ Class to hold Custom JavaScript Callbacks for interactivity in the visualization. """ def __init__(self): self.limit_change_callback = None self.reset_callback = None self.name_filter_callback = None self.select_table_view_callback = None self.table_selection_callback = None class QuantStatsVisualizer: """ Class for constructing the visualization with functionality to export the plot as :param stats_dict: Dictionary containing the module names, indices, and other extracted statistics """ # Class level constants plot_dims = {"plot_width": 700, "plot_height": 400, "table_width": 800, "boxplot_unit_width": 150, "boxplot_height": 400, "whisker_head": 20, "boxplot_vbar_width": 0.7, "boxplot_text_font_size": "10px"} initial_vals = {"default_ymin": -1e5, "default_ymax": 1e5} spacer_dims = {"sp1_width": 50, "sp1_height": 40, "sp2_width": 100} table_column_widths = {"Layer Index": 100, "Layer Name": 400, "Min Activation": 100, "Max Activation": 100} def __init__(self, stats_dict: dict, percentile_list: list): self.stats_dict = stats_dict self.plot = figure( title="Min Max Activations/Weights of quantized modules for given model", x_axis_label="Layer index", y_axis_label="Activation/Weight", tools="pan,wheel_zoom,box_zoom") self.boxplot = figure( x_range=FactorRange(), title="Boxplots of selected layers", x_axis_label="Layer index", y_axis_label="Activation/Weight", tools="pan,wheel_zoom,box_zoom") self.default_values = dict() self.percentiles = [] for percentile in percentile_list: if percentile not in [25, 50, 75]: self.percentiles.append(percentile) def _add_plot_lines(self, datasources: DataSources): self.plot.segment(x0='xmin', x1='xmax', y0='ymin', y1='ymin', line_width=4, line_color='black', source=datasources.limits_source) self.plot.segment(x0='xmin', x1='xmax', y0='ymax', y1='ymax', line_width=4, line_color='black', source=datasources.limits_source) self.plot.segment(x0='xmin', x1='xmax', y0='minclip', y1='minclip', line_width=2, line_color='black', line_dash='dashed', source=datasources.limits_source) self.plot.segment(x0='xmin', x1='xmax', y0='maxclip', y1='maxclip', line_width=2, line_color='black', line_dash='dashed', source=datasources.limits_source) self.plot.line('idx', 'maxlist', source=datasources.data_source, legend_label="Max Activation", line_width=2, line_color="red") self.plot.line('idx', 'minlist', source=datasources.data_source, legend_label="Min Activation", line_width=2, line_color="blue") selections = self.plot.segment(x0='idx', x1='idx', y0='floor', y1='ceil', line_width=2, line_color='goldenrod', line_alpha=0.5, source=datasources.selected_data_source) return selections def _add_min_max_markers(self, datasources: DataSources, tableobjects: TableObjects): min_markers = self.plot.circle_x('idx', 'marker_yminlist', source=datasources.data_source, size=10, color='orange', line_color="navy") min_markers.view = tableobjects.views.min_thresh_view max_markers = self.plot.circle_x('idx', 'marker_ymaxlist', source=datasources.data_source, size=10, color='orange', line_color="navy") max_markers.view = tableobjects.views.max_thresh_view return min_markers, max_markers def _add_boxplots(self, datasources: DataSources): whisker = Whisker(base="stridx", upper="boxplot_upper_list", lower="boxplot_lower_list", source=datasources.selected_data_source) whisker.upper_head.size = whisker.lower_head.size = QuantStatsVisualizer.plot_dims["whisker_head"] self.boxplot.add_layout(whisker) self.boxplot.vbar("stridx", QuantStatsVisualizer.plot_dims["boxplot_vbar_width"], "50%ilelist", "75%ilelist", source=datasources.selected_data_source, line_color="black") self.boxplot.vbar("stridx", QuantStatsVisualizer.plot_dims["boxplot_vbar_width"], "25%ilelist", "50%ilelist", source=datasources.selected_data_source, line_color="black") self.boxplot.circle(x="stridx", y="minlist", source=datasources.selected_data_source, name="min_points", color="orange") self.boxplot.text(x="stridx", y="minlist", text="min_namelist", source=datasources.selected_data_source, x_offset=5, y_offset=5, text_font_size=QuantStatsVisualizer.plot_dims["boxplot_text_font_size"], name="min_labels") self.boxplot.circle(x="stridx", y="maxlist", source=datasources.selected_data_source, name="max_points", color="orange") self.boxplot.text(x="stridx", y="maxlist", text="max_namelist", source=datasources.selected_data_source, x_offset=5, y_offset=5, text_font_size=QuantStatsVisualizer.plot_dims["boxplot_text_font_size"], name="max_labels") for percentile in self.percentiles: self.boxplot.circle(x="stridx", y=str(percentile)+"%ilelist", source=datasources.selected_data_source, name=str(percentile) + "_" + "points", color="orange") self.boxplot.text(x="stridx", y=str(percentile)+"%ilelist", text=str(percentile)+"%ile_namelist", source=datasources.selected_data_source, x_offset=5, y_offset=5, text_font_size=QuantStatsVisualizer.plot_dims["boxplot_text_font_size"], name=str(percentile) + "_" + "labels") @staticmethod def _get_marker_hovertool(min_markers, max_markers): format_code = """ if (Math.abs(value) < 1e-3 || Math.abs(value) > 1e5) { return value.toExponential(3); } else { return value.toFixed(3); } """ format_hover = CustomJSHover(code=format_code) marker_hover = HoverTool(renderers=[min_markers, max_markers], tooltips=[ ("Layer Index", "@idx"), ("Name", "@namelist"), ("Max Activation", "@maxlist{custom}"), ("Min Activation", "@minlist{custom}"), ], formatters={ "@minlist": format_hover, "@maxlist": format_hover, }) return marker_hover @staticmethod def _get_selection_hovertool(selections): format_code = """ if (Math.abs(value) < 1e-3 || Math.abs(value) > 1e5) { return value.toExponential(3); } else { return value.toFixed(3); } """ format_hover = CustomJSHover(code=format_code) selection_hover = HoverTool(renderers=[selections], tooltips=[ ("Layer Index", "@idx"), ("Name", "@namelist"), ("Max Activation", "@maxlist{custom}"), ("Min Activation", "@minlist{custom}"), ], formatters={ "@minlist": format_hover, "@maxlist": format_hover, }) return selection_hover def _define_callbacks(self, datasources, tableobjects, inputwidgets, mode): customcallbacks = CustomCallbacks() table_columns = ["idx", "namelist", "minlist", "maxlist"] selection_columns = [] if mode == "basic": selection_columns += ["idx", "namelist", "minlist", "min_namelist", "maxlist", "max_namelist"] elif mode == "advanced": selection_columns += ["idx", "namelist", "minlist", "min_namelist", "maxlist", "max_namelist", "stridx", "boxplot_upper_list", "boxplot_lower_list", "25%ilelist", "50%ilelist", "75%ilelist"] for percentile in self.percentiles: selection_columns.append(str(percentile) + "%ilelist") selection_columns.append(str(percentile) + "%ile_namelist") customcallbacks.limit_change_callback = CustomJS(args=dict( limits_source=datasources.limits_source, data_source=datasources.data_source, table_data_source=datasources.table_data_source, selected_data_source=datasources.selected_data_source, min_marker_source=datasources.data_source, max_marker_source=datasources.data_source, ymax_input=inputwidgets.ymax_input, ymin_input=inputwidgets.ymin_input, maxclip_input=inputwidgets.maxclip_input, minclip_input=inputwidgets.minclip_input, plot=self.plot, min_thresh_filter=tableobjects.filters.min_thresh_filter, max_thresh_filter=tableobjects.filters.max_thresh_filter, name_filter=tableobjects.filters.name_filter, select=inputwidgets.table_view_select, table_columns=table_columns, ), code=(Path(__file__).parent / "quant_stats_visualization_JS_code/utils.js").read_text("utf8") + (Path(__file__).parent / "quant_stats_visualization_JS_code/limit_change_callback.js").read_text("utf8")) customcallbacks.reset_callback = CustomJS(args=dict( limits_source=datasources.limits_source, data_source=datasources.data_source, table_data_source=datasources.table_data_source, selected_data_source=datasources.selected_data_source, default_values_source=datasources.default_values_source, min_marker_source=datasources.data_source, max_marker_source=datasources.data_source, ymax_input=inputwidgets.ymax_input, ymin_input=inputwidgets.ymin_input, maxclip_input=inputwidgets.maxclip_input, minclip_input=inputwidgets.minclip_input, select=inputwidgets.table_view_select, name_input=inputwidgets.name_input, plot=self.plot, boxplot=self.boxplot, min_thresh_filter=tableobjects.filters.min_thresh_filter, max_thresh_filter=tableobjects.filters.max_thresh_filter, name_filter=tableobjects.filters.name_filter, selection_columns=selection_columns, table_columns=table_columns, mode=mode, boxplot_unit_width=QuantStatsVisualizer.plot_dims["boxplot_unit_width"], ), code=(Path(__file__).parent / "quant_stats_visualization_JS_code/utils.js").read_text("utf8") + (Path(__file__).parent / "quant_stats_visualization_JS_code/reset_callback.js").read_text("utf8")) customcallbacks.name_filter_callback = CustomJS(args=dict( data_source=datasources.data_source, table_data_source=datasources.table_data_source, limits_source=datasources.limits_source, min_thresh_filter=tableobjects.filters.min_thresh_filter, max_thresh_filter=tableobjects.filters.max_thresh_filter, name_filter=tableobjects.filters.name_filter, select=inputwidgets.table_view_select, table_columns=table_columns, ), code=(Path(__file__).parent / "quant_stats_visualization_JS_code/utils.js").read_text("utf8") + (Path(__file__).parent / "quant_stats_visualization_JS_code/name_filter_callback.js").read_text("utf8")) customcallbacks.select_table_view_callback = CustomJS(args=dict( data_source=datasources.data_source, table_data_source=datasources.table_data_source, select=inputwidgets.table_view_select, min_thresh_filter=tableobjects.filters.min_thresh_filter, max_thresh_filter=tableobjects.filters.max_thresh_filter, name_filter=tableobjects.filters.name_filter, table=tableobjects.data_table, table_columns=table_columns, ), code=(Path(__file__).parent / "quant_stats_visualization_JS_code/utils.js").read_text("utf8") + (Path(__file__).parent / "quant_stats_visualization_JS_code/select_table_view_callback.js").read_text("utf8")) customcallbacks.table_selection_callback = CustomJS(args=dict( data_source=datasources.data_source, table_data_source=datasources.table_data_source, selected_data_source=datasources.selected_data_source, limits_source=datasources.limits_source, boxplot=self.boxplot, selection_columns=selection_columns, mode=mode, boxplot_unit_width=QuantStatsVisualizer.plot_dims["boxplot_unit_width"], ), code=(Path(__file__).parent / "quant_stats_visualization_JS_code/utils.js").read_text("utf8") + (Path(__file__).parent / "quant_stats_visualization_JS_code/table_selection_callback.js").read_text("utf8")) return customcallbacks def _attach_callbacks(self, datasources, inputwidgets, customcallbacks): self.plot.js_on_event(Reset, customcallbacks.reset_callback) inputwidgets.ymax_input.js_on_change('value', customcallbacks.limit_change_callback) inputwidgets.ymin_input.js_on_change('value', customcallbacks.limit_change_callback) inputwidgets.maxclip_input.js_on_change('value', customcallbacks.limit_change_callback) inputwidgets.minclip_input.js_on_change('value', customcallbacks.limit_change_callback) inputwidgets.name_input.js_on_change("value", customcallbacks.name_filter_callback) inputwidgets.table_view_select.js_on_change('value', customcallbacks.select_table_view_callback) datasources.table_data_source.selected.js_on_change('indices', customcallbacks.table_selection_callback) def _create_layout(self, inputwidgets, tableobjects, mode): heading_1 = Div(text="<h2>Quant Stats Visualizer</h2>") heading_2 = Div(text="<h2>Quant Stats Data Table</h2>") sp1 = Spacer(width=QuantStatsVisualizer.spacer_dims["sp1_width"], height=QuantStatsVisualizer.spacer_dims["sp1_height"]) row1 = row(inputwidgets.ymin_input, inputwidgets.ymax_input) row2 = row(inputwidgets.minclip_input, inputwidgets.maxclip_input) inputs1 = column(row1, row2) if mode == "basic": layout = column(heading_1, inputs1, sp1, self.plot, column(heading_2, row(inputwidgets.table_view_select, inputwidgets.name_input), tableobjects.data_table)) elif mode == "advanced": sp2 = Spacer(width=QuantStatsVisualizer.spacer_dims["sp2_width"], height=QuantStatsVisualizer.plot_dims["boxplot_height"]) layout = column(heading_1, inputs1, sp1, row(self.plot, sp2, self.boxplot), column(heading_2, row(inputwidgets.table_view_select, inputwidgets.name_input), tableobjects.data_table)) else: raise ValueError(f"Expected mode to be 'basic' or 'advanced', got '{mode}'.") return layout def export_plot_as_html(self, save_path: str, mode: str) -> None: """ Method for constructing the visualization and saving it to the given path. :param save_path: Path for saving the visualization. :param mode: Whether to plot basic or advanced stats. """ curdoc().theme = 'light_minimal' self.plot.width = QuantStatsVisualizer.plot_dims["plot_width"] self.plot.height = QuantStatsVisualizer.plot_dims["plot_height"] # Defining the default values of plotting parameters self.default_values['default_ymax'] = QuantStatsVisualizer.initial_vals["default_ymax"] self.default_values['default_ymin'] = QuantStatsVisualizer.initial_vals["default_ymin"] self.default_values['default_xmax'] = len(self.stats_dict["idx"]) - 1 self.default_values['default_xmin'] = 0 self.default_values['default_maxclip'] = self.default_values['default_ymax'] / 2 self.default_values['default_minclip'] = self.default_values['default_ymin'] / 2 self.plot.x_range = Range1d(0, len(self.stats_dict["idx"])) self.plot.y_range = Range1d(self.default_values['default_ymax'] * 1.05, self.default_values['default_ymin'] * 1.05) # Creating and adding a reset tool rt = ResetTool() self.plot.add_tools(rt) # Defining Bokeh ColumnDataSources datasources = DataSources(stats_dict=self.stats_dict, plot=self.plot, default_values=self.default_values, percentiles=self.percentiles ) # Creating plot objects selections = self._add_plot_lines(datasources) if mode == "advanced": self.boxplot.width = 5 * QuantStatsVisualizer.plot_dims["boxplot_unit_width"] self.boxplot.height = QuantStatsVisualizer.plot_dims["boxplot_height"] self.boxplot.y_range = Range1d() self._add_boxplots(datasources) # Defining the table objects and name filter views tableobjects = TableObjects(datasources) # Marker points to see which layers cross the thresholds min_markers, max_markers = self._add_min_max_markers(datasources, tableobjects) # Defining a hover functionality to see layer details on hovering on the marker points and selections marker_hover = self._get_marker_hovertool(min_markers, max_markers) selection_hover = self._get_selection_hovertool(selections) self.plot.add_tools(marker_hover, selection_hover) # Creating the input widgets inputwidgets = InputWidgets(self.default_values) # Defining Custom JavaScript callbacks customcallbacks = self._define_callbacks(datasources, tableobjects, inputwidgets, mode) # Attach events to corresponding callbacks curdoc().js_on_event(DocumentReady, customcallbacks.reset_callback) self._attach_callbacks(datasources, inputwidgets, customcallbacks) # Define the formatting layout = self._create_layout(inputwidgets, tableobjects, mode) # Save as standalone html save(layout, save_path)