Deep Learning Applications in Microscopy: Segmentation and Tracking

Interactive comparison of IoU and Dice coefficients

%matplotlib widget

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display
from matplotlib.ticker import MaxNLocator

# Turn off matplotlib interactive mode so the figure is not shown automatically
# when it is created; it is displayed explicitly inside the widget layout.
plt.ioff()

# Plot at most this many epochs from each training log.
MAX_EPOCHS = 1000


def _read_log(path):
    """Read a training-log CSV and strip surrounding whitespace from its column names."""
    frame = pd.read_csv(path)
    frame.columns = [col.strip() for col in frame.columns]
    return frame


# Read data
file_path_esam = 'train_test_log/log_esam_tiny.csv'
data_esam = _read_log(file_path_esam)

file_path_swin = 'train_test_log/log_swin_unet.csv'
data_swin = _read_log(file_path_swin)

file_path_vmamba = 'train_test_log/log_mamba.csv'
data_vmamba = _read_log(file_path_vmamba)

# Which log belongs to which plotted model, and which CSV column backs each
# metric key used by the plotting code.
_model_logs = {
    'EfficientSAM': data_esam,
    'Swin-UNet': data_swin,
    'VMamba': data_vmamba,
}
_metric_columns = {
    'train_iou': 'Train IOU',
    'test_iou': 'Test IOU',
    'train_dice': 'Train Dice',
    'test_dice': 'Test Dice',
}

# Truncate every series to one common length -- at most MAX_EPOCHS and never
# more than the shortest log provides -- so each series has exactly
# len(epochs) points (ax.plot requires x and y of equal length).
n_epochs = min(MAX_EPOCHS, *(len(frame) for frame in _model_logs.values()))

# Integrate data into a dictionary: model name -> metric key -> Series.
models = {
    name: {key: frame[column].iloc[:n_epochs] for key, column in _metric_columns.items()}
    for name, frame in _model_logs.items()
}

# Shared x-axis values; epochs are numbered from 1.
epochs = np.arange(1, n_epochs + 1)

# Create the figure once; the interactive callback redraws into this same Axes.
fig, ax = plt.subplots(figsize=(4, 4))

# Widget-backend (%matplotlib widget) canvas options: a fixed-size canvas with
# no header/footer and the toolbar docked underneath the plot.
fig.canvas.resizable = False
fig.canvas.header_visible = False
fig.canvas.footer_visible = False
fig.canvas.toolbar_visible = True
fig.canvas.layout.width = '670px'
fig.canvas.layout.height = "420px"
fig.canvas.toolbar_position = 'bottom'

# Placeholder axis labels shown before the first redraw.
ax.set_xlabel("Epoch")
ax.set_ylabel("Metric")

# Create interactive widgets

# Which metric pair (train/test) to plot.
switch = widgets.Dropdown(
    options=['IoU', 'Dice'],
    value='IoU',
    description='Metric:',
    layout=widgets.Layout(width='200px')
)

# Scale of the epoch axis.
x_axis_scale = widgets.Dropdown(
    options=['Linear', 'Log'],
    value='Linear',
    description='X-Axis Scale:',
    layout=widgets.Layout(width='200px')
)

# Moving-average window length in epochs; 0 disables smoothing. A rolling
# window must be a whole number of points, so use an integer slider instead
# of fractional steps that would just be truncated to the same window.
smoothing_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=10,
    step=1,
    description='Smoothing:',
    layout=widgets.Layout(width='200px')
)

# One checkbox per model (all enabled by default), keyed by model name.
checkboxes = {
    model: widgets.Checkbox(value=True, description=model)
    for model in models
}

# update plot function
# Metric selector value -> (train key, test key) in each ``models`` entry.
_METRIC_KEYS = {
    'IoU': ('train_iou', 'test_iou'),
    'Dice': ('train_dice', 'test_dice'),
}


def _smooth(values, window):
    """Return *values* smoothed by a centred moving average of *window* points.

    A window of 0 returns *values* unchanged. ``min_periods=1`` averages over
    fewer points at the edges, so the output has the same length as the input.
    """
    if window <= 0:
        return values
    return pd.Series(values).rolling(window, min_periods=1, center=True).mean().to_numpy()


def update_plot(*args):
    """Redraw the metric curves from the current widget state.

    Serves both as the ``observe`` callback of every control (the change
    notification passed by the widget is ignored via ``*args``) and as the
    initial render. Plots the train and test curve of the selected metric for
    every checked model, applying the selected x-axis scale and smoothing.
    """
    ax.clear()
    ax.set_xlabel("Epoch")
    ax.set_ylabel(switch.value)

    if x_axis_scale.value == 'Log':
        # Keep the log locator installed by set_xscale; a MaxNLocator here
        # would place linearly spaced ticks that bunch up on a log axis.
        ax.set_xscale('log')
    else:
        ax.set_xscale('linear')
        # Epochs are whole numbers, so avoid fractional tick labels.
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))

    smoothing_window = int(smoothing_slider.value)
    # Same metric pair for every model, so look it up once.
    train_key, test_key = _METRIC_KEYS[switch.value]

    for model_name, checkbox in checkboxes.items():
        if not checkbox.value:
            continue
        model = models[model_name]
        for split, key in (("Train", train_key), ("Test", test_key)):
            values = _smooth(model[key], smoothing_window)
            # Build x from the series' own length so a log whose length differs
            # from the module-level ``epochs`` cannot cause a shape mismatch.
            x = np.arange(1, len(values) + 1)
            ax.plot(x, values, label=f"{model_name} {split} {switch.value}")

    # legend() warns "No artists with labels found" when every model is unchecked.
    handles, _ = ax.get_legend_handles_labels()
    if handles:
        ax.legend()
    fig.canvas.draw_idle()

# Redraw the plot whenever any control's value changes.
for control in (switch, x_axis_scale, smoothing_slider, *checkboxes.values()):
    control.observe(update_plot, 'value')

# Stack the controls vertically: dropdowns, slider, then the model checkboxes.
controls = widgets.VBox(
    [switch, x_axis_scale, smoothing_slider, widgets.VBox(list(checkboxes.values()))],
    layout=widgets.Layout(padding='0px', align_items='flex-start', width='300px'),
)

# Controls on the left, figure canvas on the right.
output = widgets.HBox([controls, fig.canvas], layout=widgets.Layout(align_items='flex-start'))

# Display the layout
display(output)

# Draw the initial plot for the default widget values.
update_plot()