Phase Contrast Imaging Notebook

%matplotlib widget

import tcbf
import numpy as np
import matplotlib.pyplot as plt

from IPython.display import display
import ipywidgets
# Load the binned potential volume from the local data directory and collapse
# axis 0 (summed below) to obtain the projected potential for the first panel.
file_name = "apoF-ice-embedded-potential-binned.npy"
binned_volume_zxy = np.load(f"data/{file_name}")
projected_potential = binned_volume_zxy.sum(axis=0)
# Shared widget appearance: descriptions keep their natural width and every
# slider gets the same fixed box size.
style = dict(description_width="initial")

layout = ipywidgets.Layout(width="250px", height="30px")

# Defocus control, in micrometres per its label.
defocus_slider = ipywidgets.FloatSlider(
    value=0,
    min=-2,
    max=2,
    step=0.05,
    description=r"defocus [$\mu$m]",
    style=style,
    layout=layout,
)


# Dose control on a logarithmic scale: min/max/step are exponents of `base`,
# so the slider spans 10**1 .. 10**3.
electrons_per_area_slider = ipywidgets.FloatLogSlider(
    value=10,
    base=10,
    min=1,
    max=3,
    step=0.05,
    description=r"dose [e/A$^2$]",
    style=style,
    layout=layout,
)

# show_zernike_switch = ipywidgets.Checkbox(
#     value=False,
#     description="Zernike phase-plate",
#     style=style,
#     layout = layout,
# )

# def toggle_zernike(change):
#     show_zernike = change['new']
#     axs[1].axis("on" if show_zernike else "off")
#     for artist in zernike_artists: 
#         artist.set_visible(show_zernike)
#     fig.canvas.draw_idle()
#     return None

# show_zernike_switch.observe(toggle_zernike,names='value')
# Imaging constants
semiangle = 4  # mrad
wavelength = 0.0197  # A (300kV)
sigma = 0.00065  # 1/V (300kV)
rolloff = 0.125  # mrad

# Sampling of the binned volume: the raw pixel size times the binning factors
# gives the lateral sampling and the slice thickness.
pixel_size = 2 / 3
bin_factor_xy = 2
bin_factor_z = 6

lateral_sampling = pixel_size * bin_factor_xy
potential = tcbf.PotentialArray(
    binned_volume_zxy,
    slice_thickness=pixel_size * bin_factor_z,
    sampling=(lateral_sampling, lateral_sampling),
)

# Fold the slider's defocus (scaled by 1e4) into the slice thickness, spread
# evenly over the number of slices along axis 0.
potential.slice_thickness = (
    pixel_size * bin_factor_z
    + 1e4 * defocus_slider.value / binned_volume_zxy.shape[0]
)
# Plane wave with zero tilt, sampled on the same grid as the potential.
tilted_plane_wave = tcbf.Waves(
    array=np.ones(potential.gpts, dtype=np.complex64),
    sampling=potential.sampling,
    wavelength=wavelength,
    sigma=sigma,
    tilt=(0, 0),
)

# Objective aperture: cutoff semiangle and rolloff, both in mrad.
ctf = tcbf.CTF(semiangle_cutoff=semiangle, rolloff=rolloff)

# Evaluate the aperture on the wave's scattering angles and fft-shift it so
# the zero frequency sits at the array centre.
alpha, phi = tilted_plane_wave.get_scattering_angles()
aperture = ctf.evaluate_aperture(alpha, phi)
bright_field_disk = np.fft.fftshift(aperture)
# Exit Waves

# Keep the complex exit wave separate from its detected intensity. The Zernike
# phase plate further down acts on the wave's spectrum, so it must receive the
# complex wave, not the Poisson-sampled counts (this matches update_figure).
complex_exit_wave = tilted_plane_wave.multislice(potential)

# Detected CTEM image: |psi|^2 times pixel area times dose, clipped at 0 and
# Poisson-sampled to add shot noise. `exit_wave` holds these counts.
exit_wave = np.random.poisson(
    (
        np.abs(complex_exit_wave) ** 2
        * np.prod(potential.sampling)
        * electrons_per_area_slider.value
    ).clip(0)
)
# Static Figure

# Build the figure with interactive mode off so it is only shown when the
# canvas is explicitly displayed.
with plt.ioff():
    dpi = 72
    fig, axs = plt.subplots(1, 3, figsize=(675 / dpi, (275 + 12) / dpi), dpi=dpi)

# projected potential
tcbf.show(
    projected_potential,
    ticks=False,
    figax=(fig, axs[0]),
    cbar=False,
    cmap='magma',
)
axs[0].set_title(
    "Projected Potential of Sample",
    fontsize=12,
)
axs[0].axis("off")
# Scalebar sampling is pixel size / 10 so the bar is labelled in nm.
tcbf.add_scalebar(
    axs[0],
    color="white",
    sampling=pixel_size * bin_factor_xy / 10,
    length=30,
    units="nm",
)

# HRTEM exit wave
tcbf.show(
    exit_wave,
    ticks=False,
    figax=(fig, axs[1]),
    cbar=False,
)
axs[1].set_title(
    "CTEM Image Intensity",
    fontsize=12,
)
tcbf.add_scalebar(
    axs[1],
    color="black",
    sampling=pixel_size * bin_factor_xy / 10,
    length=30,
    units="nm",
)

# Zernike phase plate: fft2 is unshifted, so index [0, 0] is the zero-frequency
# term. Multiply it by exp(i*pi/2) and leave every other frequency unchanged,
# then sample the resulting intensity like the CTEM panel.
exit_wave_zernike = np.fft.fft2(complex_exit_wave)
zernike_kernel = np.zeros_like(np.abs(exit_wave_zernike))
zernike_kernel[0, 0] = np.pi / 2
zernike_kernel = np.exp(1j * zernike_kernel)
exit_wave_zernike = np.fft.ifft2(exit_wave_zernike * zernike_kernel)
exit_wave_zernike = np.random.poisson(
    (
        np.abs(exit_wave_zernike) ** 2
        * np.prod(potential.sampling)
        * electrons_per_area_slider.value
    ).clip(0)
)

tcbf.show(
    exit_wave_zernike,
    ticks=False,
    figax=(fig, axs[2]),
    cbar=False,
)

_, bar = tcbf.add_scalebar(
    axs[2],
    color="white",
    sampling=pixel_size * bin_factor_xy / 10,
    length=30,
    units="nm",
)

text = axs[2].set_title(
    "Zernike Phase Plate Intensity",
    fontsize=12,
)

im = axs[2].get_images()[0]

# Artists of the Zernike panel; element 0 (the image) is updated by the
# slider callback.
zernike_artists = [im, bar, text]
# for artist in zernike_artists: 
#     artist.set_visible(False)
    
# axs[2].patch.set_visible(False)
# axs[2].axis("off")
fig.tight_layout()
def update_figure(
    defocus,
    electrons_per_area,
    # show_zernike,
):
    """Re-simulate and redraw the CTEM and Zernike-phase-plate panels.

    Slider callback. It has two side effects on module-level state:

    - it overwrites ``potential.slice_thickness``;
    - it replaces the image data and colour limits in ``axs[1]`` (CTEM) and
      ``zernike_artists[0]`` (Zernike).

    Each call draws fresh Poisson noise.

    Parameters
    ----------
    defocus : float
        Defocus from the defocus slider (labelled in micrometres). It is
        scaled by 1e4 and spread evenly across all slices by lengthening
        each slice.
    electrons_per_area : float
        Dose from the dose slider. It is multiplied by the pixel area
        ``np.prod(potential.sampling)`` to get expected counts per pixel.
    """
    n_slices = binned_volume_zxy.shape[0]
    potential.slice_thickness = pixel_size * bin_factor_z + 1e4 * defocus / n_slices
    exit_wave = tilted_plane_wave.multislice(potential)

    pixel_area = np.prod(potential.sampling)

    def _detect(wave):
        """Poisson-sample the expected counts |wave|^2 * pixel_area * dose."""
        expected = (np.abs(wave) ** 2 * pixel_area * electrons_per_area).clip(0)
        return np.random.poisson(expected)

    def _redraw(image, counts):
        """Push histogram-scaled counts and matching colour limits into an image artist."""
        scaled, vmin, vmax = tcbf.visualize.return_scaled_histogram(counts)
        image.set_data(scaled)
        image.set_clim(vmin=vmin, vmax=vmax)

    # Conventional TEM panel.
    _redraw(axs[1].get_images()[0], _detect(exit_wave))

    # Zernike phase plate. fft2 is unshifted, so [0, 0] is the zero-frequency
    # term; multiply it by exp(i*pi/2) and leave all other frequencies as is.
    spectrum = np.fft.fft2(exit_wave)
    spectrum[0, 0] *= np.exp(1j * np.pi / 2)
    _redraw(zernike_artists[0], _detect(np.fft.ifft2(spectrum)))

    fig.canvas.draw_idle()

# Re-run update_figure whenever either slider changes.
# interactive_output returns an Output widget that captures whatever the
# callback emits. Per the ipywidgets docs this includes tracebacks, so display
# it: otherwise errors in update_figure vanish into a widget that is never
# shown. It stays empty, and invisible, while the callback succeeds.
update_output = ipywidgets.widgets.interactive_output(
    update_figure,
    {
        'defocus': defocus_slider,
        'electrons_per_area': electrons_per_area_slider,
        # 'show_zernike':show_zernike_switch
    },
)
display(update_output)

# Fixed-size canvas: no header or footer, toolbar shown at the bottom.
canvas = fig.canvas
canvas.resizable = False
canvas.header_visible = False
canvas.footer_visible = False
canvas.toolbar_visible = True
canvas.layout.width = '680px'
canvas.layout.height = "292px"
canvas.toolbar_position = 'bottom'

# Sliders in a centred row beneath the figure.
controls = ipywidgets.HBox(
    [
        defocus_slider,
        electrons_per_area_slider,
        # show_zernike_switch,
    ],
    layout=ipywidgets.Layout(justify_content="center", width="680px"),
)
display(ipywidgets.VBox([canvas, controls]))