Deep Learning Applications in Microscopy: Segmentation and Tracking

Interactive train and test loss curves

%matplotlib widget

import pandas as pd
import matplotlib.pyplot as plt
import ipywidgets as widgets
from IPython.display import display

# Turn off pyplot's interactive mode so the figure created below is not
# rendered automatically; it is shown only once, as part of the widget box
# handed to display() at the end.
plt.ioff()

# Training/testing log (CSV) for each model, keyed by the name that is later
# used as that model's subplot title.
file_paths = {
    'YOLOv8n-seg': 'train_test_log/log_yolo_v8.csv',
    'EfficientSAM': 'train_test_log/log_esam_tiny.csv',
    'vmamba': 'train_test_log/log_mamba.csv',
    'Swin-UNet': 'train_test_log/log_swin_unet.csv',
}


def _read_log(csv_path):
    """Read one log CSV and strip surrounding whitespace from its column names."""
    frame = pd.read_csv(csv_path)
    frame.columns = frame.columns.str.strip()
    return frame


# Model name -> parsed log DataFrame, in the same order as file_paths.
data_dict = {name: _read_log(path) for name, path in file_paths.items()}

# A 2x2 grid of subplots (one panel per model) on a 6.7 x 6.7 inch figure,
# with extra spacing so neighbouring titles and labels don't collide.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(6.7, 6.7))
fig.subplots_adjust(wspace=0.3, hspace=0.3)

# Widget-backend canvas options: fixed size, no header/footer, toolbar shown
# at the bottom; the layout height leaves room for that toolbar.
canvas = fig.canvas
canvas.resizable = False
canvas.header_visible = False
canvas.footer_visible = False
canvas.toolbar_visible = True
canvas.layout.width = '670px'
canvas.layout.height = '710px'
canvas.toolbar_position = 'bottom'

# Turn the 2x2 array of Axes into a flat sequence so it can be zipped with the models.
axes = axes.flatten()

# Y-axis scale choices; the first entry is the initial selection.
_SCALE_OPTIONS = ['Linear', 'Log']

# Toggle whose value update_plots() checks to pick a linear or log y-axis.
log_scale_toggle = widgets.ToggleButtons(
    options=_SCALE_OPTIONS,
    value=_SCALE_OPTIONS[0],
    description='Y-Axis Scale:',
    tooltip='Toggle log scaling of y-axis',
)

# Column names (epoch, train loss, test loss) for each log format. Only the
# YOLOv8n-seg log uses its own headers; every other model uses the default set.
_LOSS_COLUMNS = {
    'YOLOv8n-seg': ('epoch', 'train/seg_loss', 'val/seg_loss'),
}
_DEFAULT_LOSS_COLUMNS = ('E', 'Train Loss', 'Test Loss')


def update_plots(*args):
    """Redraw the train/test loss curves for every model.

    Any positional arguments are ignored. This lets the function be
    registered directly as a widget observer (which passes a change record)
    and also be called with no arguments for the initial draw.
    """
    scale = 'log' if log_scale_toggle.value == 'Log' else 'linear'

    for ax, (model_name, data) in zip(axes, data_dict.items()):
        ax.clear()  # start each panel from a blank Axes

        columns = _LOSS_COLUMNS.get(model_name, _DEFAULT_LOSS_COLUMNS)
        epochs, train_loss, val_loss = (data[col] for col in columns)

        ax.plot(epochs, train_loss, label='Train Loss')
        ax.plot(epochs, val_loss, linestyle='--', label='Test Loss')

        ax.set_title(model_name)
        ax.set_xlabel('Epochs')
        ax.set_ylabel('Loss')
        ax.set_yscale(scale)
        ax.legend()

    # Ask the canvas to repaint at its next opportunity.
    fig.canvas.draw_idle()

# Redraw whenever the selected scale changes, then render the first view now.
log_scale_toggle.observe(update_plots, names='value')
update_plots()

# Put the scale selector above the figure canvas and show both as one widget.
combined_box = widgets.VBox(children=[log_scale_toggle, fig.canvas])
display(combined_box)