A Practical Guide to Scanning and Transmission Electron Microscopy Simulations

Interactive CTF/PSF

%matplotlib widget
import numpy as np
import abtem

from matplotlib import cm, colors as mcolors, pyplot as plt
import numpy as np
from mpl_toolkits.axes_grid1 import make_axes_locatable
from colorspacious import cspace_convert

from IPython.display import display
import ipywidgets
def Complex2RGB(complex_data, vmin=None, vmax=None, power=None, chroma_boost=1):
    """
    Convert a complex array to an sRGB image (amplitude -> lightness, phase -> hue).

    complex_data (array): complex array to plot
    vmin (float)        : lower amplitude limit. Used as an absolute amplitude
                          when the amplitude is (nearly) constant (default 0);
                          otherwise interpreted as a quantile in [0, 1] of the
                          non-NaN amplitudes (default 0.02)
    vmax (float)        : upper amplitude limit, with the same two meanings
                          (defaults: maximum amplitude / quantile 0.98)
    power (float)       : power to raise amplitude to
    chroma_boost (float): boosts chroma for higher-contrast (~1-2.5)

    Returns
    -------
    rgb (array): sRGB values clipped to [0, 1], with a trailing colour axis
                 of length 3 appended to complex_data's shape
    """
    amp = np.abs(complex_data)
    phase = np.angle(complex_data)

    if power is not None:
        amp = amp**power

    if np.isclose(np.max(amp), np.min(amp)):
        # Flat amplitude: quantiles would be meaningless, so the limits are
        # taken as absolute amplitudes.
        if vmin is None:
            vmin = 0
        if vmax is None:
            vmax = np.max(amp)
    else:
        # vmin / vmax are quantiles here; convert them to amplitude values.
        if vmin is None:
            vmin = 0.02
        if vmax is None:
            vmax = 0.98
        vals = np.sort(amp[~np.isnan(amp)])
        ind_vmin = np.round((vals.shape[0] - 1) * vmin).astype("int")
        ind_vmax = np.round((vals.shape[0] - 1) * vmax).astype("int")
        ind_vmin = np.max([0, ind_vmin])
        ind_vmax = np.min([len(vals) - 1, ind_vmax])
        vmin = vals[ind_vmin]
        vmax = vals[ind_vmax]

    # Stretch the window [vmin, vmax] onto [0, 1]. Dividing by the window
    # width (not by vmax alone) makes vmax map to full brightness even when
    # vmin > 0. A zero-width window (e.g. an all-zero input) maps to 0 rather
    # than producing 0/0 = NaN.
    amp = np.clip(amp, vmin, vmax)
    amp_range = vmax - vmin
    if amp_range > 0:
        amp = (amp - vmin) / amp_range
    else:
        amp = np.zeros_like(amp, dtype=float)
    amp = amp.clip(1e-16, 1)

    J = amp * 61.5  # Note we restrict luminance to the monotonic chroma cutoff
    C = np.minimum(chroma_boost * 98 * J / 123, 110)
    h = np.rad2deg(phase) + 180  # phase in [-pi, pi] -> hue in [0, 360] degrees

    JCh = np.stack((J, C, h), axis=-1)
    rgb = cspace_convert(JCh, "JCh", "sRGB1").clip(0, 1)

    return rgb


def add_colorbar_arg(ax, chroma_boost=1, c=49, j=61.5):
    """
    Attach a cyclic phase ("arg") colorbar to the right-hand side of ``ax``.

    The colormap sweeps hue through 0-360 degrees at constant luminance and
    chroma and spans -pi..pi. This matches the phase-to-hue mapping used by
    Complex2RGB (hue = phase in degrees + 180).

    ax                  : matplotlib axis to add the cbar next to
    chroma_boost (float): boosts chroma for higher-contrast (~1-2.25)
    c (float)           : constant chroma value (capped at 110 after boosting)
    j (float)           : constant luminance value
    """
    # Split off a thin axis on the right of `ax` to hold the colorbar.
    divider = make_axes_locatable(ax)
    cax = divider.append_axes("right", size="5%", pad="2.5%")

    # 256 hues around the colour wheel. endpoint=False so that 360 deg
    # (== 0 deg) is not duplicated in the cyclic map.
    h = np.linspace(0, 360, 256, endpoint=False)
    J = np.full_like(h, j)
    C = np.full_like(h, np.minimum(c * chroma_boost, 110))
    JCh = np.stack((J, C, h), axis=-1)
    rgb_vals = cspace_convert(JCh, "JCh", "sRGB1").clip(0, 1)
    newcmp = mcolors.ListedColormap(rgb_vals)
    norm = mcolors.Normalize(vmin=-np.pi, vmax=np.pi)

    # Draw through the figure that owns `cax`. pyplot's implicit current
    # figure (used by plt.colorbar) need not be ax's figure when several
    # figures are open.
    cb = cax.figure.colorbar(cm.ScalarMappable(norm=norm, cmap=newcmp), cax=cax)

    cb.set_label("arg", rotation=0, ha="center", va="bottom")
    cb.ax.yaxis.set_label_coords(0.5, 1.01)
    cb.set_ticks(np.array([-np.pi, -np.pi / 2, 0, np.pi / 2, np.pi]))
    cb.set_ticklabels(
        [r"$-\pi$", r"$-\dfrac{\pi}{2}$", "$0$", r"$\dfrac{\pi}{2}$", r"$\pi$"]
    )
    return None
# Starting probe used to draw the initial figure: 300 kV, 25 mrad aperture,
# 100 A defocus.
probe_init = abtem.Probe(
    energy=300e3, semiangle_cutoff=25, gpts=(256, 256), sampling=0.4
).build()
ctf = abtem.CTF(energy=300e3, semiangle_cutoff=25, defocus=100)

probe = probe_init.apply_ctf(ctf)
array = probe.array.compute()  # real-space probe -> PSF panel

# Centred 2-D FFT of the real-space probe -> CTF panel
array_fourier = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(array)))
rgb_psf, rgb_ctf = (Complex2RGB(a, vmin=0, vmax=1) for a in (array, array_fourier))

# widget figure generation
# Build the figure with interactive mode off so pyplot does not display it on
# its own. Its canvas is embedded in the ipywidgets layout further down.
with plt.ioff():
    dpi = 72
    fig, (ax_ctf, ax_psf) = plt.subplots(
        1, 2, figsize=(675 / dpi, 300 / dpi), dpi=dpi
    )

# Initial images; update_aberrations later replaces their data via set_data.
im_ctf = ax_ctf.imshow(rgb_ctf)
im_psf = ax_psf.imshow(rgb_psf)

panel_titles = (
    (ax_ctf, "contrast transfer function (CTF)"),
    (ax_psf, "point spread function (PSF)"),
)
for ax, title in panel_titles:
    ax.set_xticks([])
    ax.set_yticks([])
    add_colorbar_arg(ax)
    ax.set_title(title)

fig.tight_layout()

# ipympl canvas settings: fixed size, no header/footer, toolbar below the plot.
fig.canvas.resizable = False
fig.canvas.header_visible = False
fig.canvas.footer_visible = False
fig.canvas.toolbar_visible = True
fig.canvas.toolbar_position = 'bottom'
fig.canvas.layout.width = '675px'
fig.canvas.layout.height = '330px'
/tmp/ipykernel_33501/13827917.py:9: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`). Consider using `matplotlib.pyplot.close()`.
  fig, (ax_ctf,ax_psf) = plt.subplots(1,2,figsize=(675/dpi, 300/dpi), dpi=dpi)
def update_aberrations(energy, defocus, C30, semiangle):
    """
    Recompute the probe for the given slider values and refresh both panels.

    energy (float)   : accelerating voltage in kV (multiplied by 1e3 for abtem)
    defocus (float)  : defocus in Angstrom, passed straight to abtem.CTF
    C30 (float)      : spherical aberration in mm (multiplied by 1e7 -> Angstrom)
    semiangle (float): aperture semiangle in mrad
    """
    energy_ev = energy * 1e3

    aberrations = abtem.CTF(
        energy=energy_ev,
        semiangle_cutoff=semiangle,
        defocus=defocus,
        C30=C30 * 1e7,
    )
    aperture_probe = abtem.Probe(
        energy=energy_ev,
        semiangle_cutoff=semiangle,
        gpts=(256, 256),
        sampling=0.4,
    ).build()

    # Real-space probe (PSF panel) and its centred Fourier transform (CTF panel).
    psf = aperture_probe.apply_ctf(aberrations).array.compute()
    ctf_image = np.fft.fftshift(np.fft.fft2(np.fft.ifftshift(psf)))

    im_ctf.set_data(Complex2RGB(ctf_image, vmin=0, vmax=1))
    im_psf.set_data(Complex2RGB(psf, vmin=0, vmax=1))
    fig.canvas.draw_idle()
# Dropdown entries: a placeholder first, then one entry per preset probe
# handled in select_preset_eventhandler.
option_list = (
    "select an option",
    "uncorrected STEM",
    "uncorrected STEM at Scherzer",
    "aberration corrected STEM",
    "SEM",
    "boundary artifacts",
)

# update the plots with a pre-selected set of probe parameters
def select_preset_eventhandler(change):
    """
    Dropdown observer: copy the chosen preset into the aberration sliders.

    change : traitlets change notification from
             ``dropdown.observe(..., names='value')``; ``change.new`` is the
             newly selected entry of ``option_list``.

    Each assignment that changes a slider's value triggers its own redraw
    through the interactive_output wiring. Selecting the placeholder entry
    (option_list[0]) leaves the sliders untouched.
    """
    selection = change.new

    # Presets are (energy [kV], defocus [A], C3 [mm], semiangle [mrad]).
    if selection == option_list[1]:  # uncorrected STEM
        preset = (300, 0, 1.3, 10)
    elif selection == option_list[2]:  # uncorrected STEM @ Scherzer
        # Scherzer-type defocus 0.87 * sqrt(Cs * lambda), with Cs = 1.3 mm
        # written in Angstrom (1.3e7), as update_aberrations does via C30*1e7.
        lambda_300kv = abtem.core.energy.energy2wavelength(300e3)
        preset = (300, 0.87 * (1.3e7 * lambda_300kv) ** 0.5, 1.3, 10)
    elif selection == option_list[3]:  # corrected STEM
        preset = (300, 0, 0.001, 10)
    elif selection == option_list[4]:  # SEM
        preset = (20, 0, 5, 2.5)
    elif selection == option_list[5]:  # boundary artifacts
        preset = (10, 2000, 0, 20)
    else:
        return

    preset_energy, preset_defocus, preset_C30, preset_semiangle = preset
    energy.value = preset_energy
    defocus.value = preset_defocus  # previously mistyped as `defocus.vaue` for SEM
    C30.value = preset_C30
    semiangle.value = preset_semiangle
            
# Widgets
# Preset selector; its observer copies a preset into the sliders below.
dropdown = ipywidgets.Dropdown(
    options=option_list,
    layout=ipywidgets.Layout(width='200px', height='30px'),
)
dropdown.observe(select_preset_eventhandler, names='value')

# Let the (long) slider descriptions take the width they need.
style = {'description_width': 'initial'}

energy = ipywidgets.IntSlider(
    value=60,
    min=10,
    max=300,
    step=2,
    description="energy (kV)",
    style=style,
)

defocus = ipywidgets.IntSlider(
    value=0,
    min=-2000,
    max=2000,
    step=20,
    description="defocus / -1*C1 (A)",
    style=style,
)

C30 = ipywidgets.FloatSlider(
    value=0,
    min=0,
    max=5,
    step=0.1,
    description="C3 (mm)",
    style=style,
)

# Logarithmic slider: min/max/step are exponents of `base`, i.e. the
# semiangle ranges from 10**0 to 10**1.6021 mrad.
semiangle = ipywidgets.FloatLogSlider(
    value=20,
    base=10,
    min=0,
    max=1.6021,
    step=0.05,
    description="semiangle (mrad)",
    style=style,
)

# Call update_aberrations whenever any slider changes. The returned Output
# widget is discarded (it is not placed in the layout).
ipywidgets.interactive_output(
    update_aberrations,
    dict(energy=energy, defocus=defocus, semiangle=semiangle, C30=C30),
)
None
# Assemble the UI: figure canvas on top; sliders and the preset selector side
# by side underneath.
slider_column = ipywidgets.VBox([energy, defocus, semiangle, C30])
preset_column = ipywidgets.VBox([
    ipywidgets.Label(
        'Preset probes',
        layout=ipywidgets.Layout(width='100px', height='30px'),
    ),
    dropdown,
])
widget = ipywidgets.VBox(
    [fig.canvas, ipywidgets.HBox([slider_column, preset_column])]
)

display(widget);