Deep Learning Applications in Microscopy: Segmentation and Tracking
Contents
Interactive plots of IoU and Dice coefficients
%matplotlib widget
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display
from matplotlib.ticker import MaxNLocator
# Turn off interactive mode so figures are not rendered automatically; the
# figure canvas is shown explicitly as part of a widget layout instead.
plt.ioff()


def _read_log(path):
    """Load one training/test log CSV with whitespace stripped from its headers.

    Stripping lets columns be looked up by their bare names (e.g. ``'Train IOU'``)
    even if the CSV header has spaces around the separators.
    """
    log = pd.read_csv(path)
    log.columns = [name.strip() for name in log.columns]
    return log


file_path_esam = 'train_test_log/log_esam_tiny.csv'
file_path_swin = 'train_test_log/log_swin_unet.csv'
file_path_vmamba = 'train_test_log/log_mamba.csv'

data_esam = _read_log(file_path_esam)
data_swin = _read_log(file_path_swin)
data_vmamba = _read_log(file_path_vmamba)
# Maximum number of leading log rows (epochs) to plot per model.
MAX_EPOCHS = 1000

# Metric key used in `models` -> column name in the (header-stripped) log CSVs.
_LOG_COLUMNS = {
    'train_iou': 'Train IOU',
    'test_iou': 'Test IOU',
    'train_dice': 'Train Dice',
    'test_dice': 'Test Dice',
}

# Display name -> loaded log DataFrame; this order is the order models appear in.
_model_logs = {
    'EfficientSAM': data_esam,
    'Swin-UNet': data_swin,
    'VMamba': data_vmamba,
}

# Plot at most MAX_EPOCHS rows, but never more than the longest log provides.
n_epochs = min(MAX_EPOCHS, max(len(log) for log in _model_logs.values()))

# Every series is cut to n_epochs rows and NaN-padded to exactly n_epochs, so a
# shorter log no longer causes a length mismatch against `epochs` when plotted
# (matplotlib leaves NaN points undrawn).
models = {
    name: {
        key: log[column].iloc[:n_epochs].reset_index(drop=True).reindex(range(n_epochs))
        for key, column in _LOG_COLUMNS.items()
    }
    for name, log in _model_logs.items()
}

# 1-based epoch numbers matching the length of every series in `models`.
epochs = np.arange(1, n_epochs + 1)
# Single square figure whose canvas is embedded in the widget layout.
fig, ax = plt.subplots(figsize=(4, 4))

# Fixed-size canvas: no header/footer, toolbar shown beneath the plot.
_canvas = fig.canvas
_canvas.resizable = False
_canvas.header_visible = False
_canvas.footer_visible = False
_canvas.toolbar_visible = True
_canvas.layout.width = '670px'
_canvas.layout.height = "420px"
_canvas.toolbar_position = 'bottom'

ax.set_xlabel("Epoch")
ax.set_ylabel("Metric")
# Interactive controls; the plot-update callback reads each widget's `.value`.
switch = widgets.Dropdown(
    options=['IoU', 'Dice'],
    value='IoU',
    description='Metric:',
    layout=widgets.Layout(width='200px'),
)
x_axis_scale = widgets.Dropdown(
    options=['Linear', 'Log'],
    value='Linear',
    description='X-Axis Scale:',
    layout=widgets.Layout(width='200px'),
)
# Rolling-mean window length in epochs (0 or 1 = no smoothing). An integer
# slider is used because the window must be a whole number of points:
# fractional steps would only be truncated to the same integer window.
smoothing_slider = widgets.IntSlider(
    value=0,
    min=0,
    max=10,
    step=1,
    description='Smoothing:',
    layout=widgets.Layout(width='200px'),
)
# One visibility checkbox per model, all checked initially, in `models` order.
checkboxes = {
    model: widgets.Checkbox(value=True, description=model)
    for model in models
}
# Dropdown metric name -> (train key, test key) in each `models` entry.
_METRIC_KEYS = {
    'IoU': ('train_iou', 'test_iou'),
    'Dice': ('train_dice', 'test_dice'),
}


def _smooth(values, window):
    """Return a centered rolling mean of ``values`` over ``window`` points.

    ``min_periods=1`` keeps the output the same length as the input by
    averaging partial windows at both ends. A window of 0 or 1 returns the
    data unchanged (a 1-point mean is the identity).
    """
    if window <= 1:
        return values
    return pd.Series(values).rolling(window, min_periods=1, center=True).mean().to_numpy()


def update_plot(*args):
    """Redraw the train/test curves from the current widget state.

    Positional arguments are accepted and ignored so the function can be
    registered directly as an ipywidgets ``observe`` callback (which passes a
    change record) and also be called with no arguments.
    """
    ax.clear()
    ax.set_xlabel("Epoch")
    ax.set_ylabel(switch.value)

    if x_axis_scale.value == 'Log':
        ax.set_xscale('log')
    else:
        ax.set_xscale('linear')
        # Epochs are whole numbers, so keep linear-axis ticks on integers.
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))

    # rolling() needs an integer window size.
    window = int(smoothing_slider.value)
    train_key, test_key = _METRIC_KEYS[switch.value]

    for model_name, checkbox in checkboxes.items():
        if not checkbox.value:
            continue
        for split, key in (('Train', train_key), ('Test', test_key)):
            values = _smooth(models[model_name][key], window)
            # x-values come from the series itself so every curve matches its
            # own length (1-based epoch numbers).
            x = np.arange(1, len(values) + 1)
            ax.plot(x, values, label=f"{model_name} {split} {switch.value}")

    # With every model unchecked there is nothing to label; calling legend()
    # then would only emit a matplotlib "no artists with labels" warning.
    if ax.lines:
        ax.legend()
    fig.canvas.draw_idle()
# Re-render the plot whenever the value of any control changes.
for control in (switch, x_axis_scale, smoothing_slider, *checkboxes.values()):
    control.observe(update_plot, names='value')
# Control panel, placed to the LEFT of the plot (it is the first child of the
# HBox below): metric, x-axis scale, smoothing window, then model checkboxes.
controls = widgets.VBox(
    children=[
        switch,
        x_axis_scale,
        smoothing_slider,
        widgets.VBox(children=list(checkboxes.values())),
    ],
    layout=widgets.Layout(
        padding='0px',
        align_items='flex-start',
        width='300px',
    ),
)

# Controls and figure canvas side by side, aligned to the top.
output = widgets.HBox(
    children=[controls, fig.canvas],
    layout=widgets.Layout(align_items='flex-start'),
)

# Display the assembled layout.
display(output)

# Draw the initial curves; subsequent redraws come from the widget observers.
update_plot()
HBox(children=(VBox(children=(Dropdown(description='Metric:', layout=Layout(width='200px'), options=('IoU', 'D…