A Practical Guide to Scanning and Transmission Electron Microscopy Simulations

Interactive plot comparing phase contrast in TEM: in focus, defocused, and with a Zernike phase plate

Structure file (cif) — the code below reads the structure from `data/3jcl.xyz`

%matplotlib widget

import numpy as np
import matplotlib.pyplot as plt

from IPython.display import display
from ipywidgets import HBox, VBox, interact, Dropdown, Label, AppLayout, FloatSlider, FloatLogSlider, IntSlider, Layout, widgets
import abtem
from ase.io import read
from scipy.ndimage import gaussian_filter

# Copied from py4DSTEM directly
def Complex2RGB(complex_data, vmin=None, vmax=None, hue_start=0, invert=False):
    """
    Map a complex array to RGB: phase sets the hue, amplitude sets brightness.

    Adapted from py4DSTEM, with a guard against all-zero amplitudes.

    Parameters
    ----------
    complex_data : array
        Complex array to plot.
    vmin, vmax : float, optional
        If the amplitude is NOT constant, these are quantiles in [0, 1] of the
        non-NaN amplitudes (defaults 0.02 and 0.98). The amplitudes at those
        quantiles become the clipping limits.
        If the amplitude IS constant, they are absolute amplitude limits
        (defaults 0 and the constant amplitude).
    hue_start : float
        Rotational offset for the colormap [degrees].
    invert : bool
        If True, return ``1 - rgb`` (light color scheme).

    Returns
    -------
    ndarray
        Array of shape ``complex_data.shape + (3,)``. For finite input, values
        lie in [0, 1]; an all-zero input maps to black (or white if inverted).
    """
    amp = np.abs(complex_data)
    if np.isclose(np.max(amp), np.min(amp)):
        # Constant amplitude: quantiles are meaningless, so use absolute limits.
        if vmin is None:
            vmin = 0
        if vmax is None:
            vmax = np.max(amp)
    else:
        if vmin is None:
            vmin = 0.02
        if vmax is None:
            vmax = 0.98
        # Convert the quantiles into amplitude values, ignoring NaNs.
        vals = np.sort(amp[~np.isnan(amp)])
        ind_vmin = np.round((vals.shape[0] - 1) * vmin).astype("int")
        ind_vmax = np.round((vals.shape[0] - 1) * vmax).astype("int")
        ind_vmin = np.max([0, ind_vmin])
        ind_vmax = np.min([len(vals) - 1, ind_vmax])
        vmin = vals[ind_vmin]
        vmax = vals[ind_vmax]

    amp = np.where(amp < vmin, vmin, amp)
    amp = np.where(amp > vmax, vmax, amp)

    phase = np.angle(complex_data) + np.deg2rad(hue_start)

    # Normalise brightness to 1. Skip when every amplitude is zero (the
    # original divided 0/0 and returned NaN everywhere). nanmax keeps NaN
    # pixels from poisoning the normalisation of the finite ones.
    peak = np.nanmax(amp)
    if peak > 0:
        amp = amp / peak

    rgb = np.zeros(phase.shape + (3,))
    rgb[..., 0] = 0.5 * (np.sin(phase) + 1) * amp
    rgb[..., 1] = 0.5 * (np.sin(phase + np.pi / 2) + 1) * amp
    rgb[..., 2] = 0.5 * (-np.sin(phase) + 1) * amp

    return 1 - rgb if invert else rgb

class Zernike(abtem.transfer.BaseAperture):
    """
    Zernike phase-plate aperture.

    Transmits everything inside ``semiangle_cutoff``. Scattering angles outside
    the central hole (``alpha > hole_cutoff``) additionally pick up a constant
    phase ``exp(i * phase_shift)``. Angles inside the hole are left unshifted.

    Parameters
    ----------
    hole_cutoff : float
        Cutoff semiangle of the central hole in the phase film [mrad].
    phase_shift : float
        Phase shift imparted by the film [rad].
    semiangle_cutoff : float
        Cutoff semiangle of the objective aperture [mrad].
    energy : float, optional
        Electron energy [eV]. If not provided, inferred from the wave functions.
    extent : float or two float, optional
        Lateral extent of the wave functions [Å] in `x` and `y`. A single float
        sets both.
    gpts : two ints, optional
        Number of grid points describing the wave functions.
    sampling : two float, optional
        Lateral real-space sampling of the wave functions [Å]. Ignored if `gpts`
        is also given.
    """

    def __init__(
        self,
        hole_cutoff: float,
        phase_shift: float,
        semiangle_cutoff: float,
        energy: float = None,
        extent: float | tuple[float, float] = None,
        gpts: int | tuple[int, int] = None,
        sampling: float | tuple[float, float] = None,
    ):
        # Store the Zernike-specific settings before the base class runs.
        self._hole_cutoff = hole_cutoff
        self._phase_shift = phase_shift
        super().__init__(
            energy=energy,
            semiangle_cutoff=semiangle_cutoff,
            extent=extent,
            gpts=gpts,
            sampling=sampling,
        )

    @property
    def hole_cutoff(self) -> float:
        """Cutoff semiangle of the central (unshifted) hole [mrad]."""
        return self._hole_cutoff

    @property
    def phase_shift(self) -> float:
        """Phase shift applied outside the hole [rad]."""
        return self._phase_shift

    def _evaluate_from_angular_grid(
        self, alpha: np.ndarray, phi: np.ndarray
    ) -> np.ndarray:
        """
        Evaluate the complex transfer function on a grid of scattering angles.

        ``alpha`` is compared against cutoffs converted from mrad to rad, so it
        is expected in radians. ``phi`` is unused: the aperture is radially
        symmetric. As a side effect, the pieces of the last evaluation are cached
        on the instance:
        ``_amplitude`` (0/1 mask), ``_phase`` (film phase factor) and
        ``_save_array`` (fftshifted result, used elsewhere for display).
        """
        xp = abtem.core.backend.get_array_module(alpha)
        angles = xp.array(alpha)

        aperture_edge = self.semiangle_cutoff / 1e3
        hole_edge = self.hole_cutoff / 1e3

        # 1 inside the objective aperture, 0 outside.
        transmitted = xp.asarray(angles < aperture_edge, dtype="float")
        # 1 where the film covers the angle, 0 inside the hole.
        on_film = xp.asarray(angles > hole_edge, dtype="float")
        film_phase = xp.exp(1j * self.phase_shift * on_film)

        self._amplitude = transmitted
        self._phase = film_phase
        result = transmitted * film_phase
        self._save_array = xp.fft.fftshift(result)
        return result

# Load the structure and translate it so its bounding box starts at the origin.
atoms = read('data/3jcl.xyz')
atoms.positions[:] -= atoms.positions.min(axis=0)

# Fit the cell diagonal to the bounding box. This assumes the cell read from
# the xyz file has no off-diagonal components; TODO confirm for other inputs.
atoms.cell[0][0] = atoms.positions[:, 0].max()
atoms.cell[1][1] = atoms.positions[:, 1].max()
atoms.cell[2][2] = atoms.positions[:, 2].max()

# Add 10 Å of vacuum on each side in x and y.
atoms.center(vacuum=10, axis=(0, 1))

# Make the cell square in x-y. Use the larger side so no atom falls outside
# the cell. Then re-centre, because resizing the cell moves its centre
# (previously x was overwritten with y without re-centring).
side = max(atoms.cell[0][0], atoms.cell[1][1])
atoms.cell[0][0] = side
atoms.cell[1][1] = side
atoms.center(axis=(0, 1))

sampling = 1
slice_thickness = 1
potential = abtem.Potential(
    atoms,
    sampling=sampling,
    slice_thickness=slice_thickness,
    projection='infinite',
    parametrization='kirkland',
)

potential = potential.build()

# Smooth the potential with a Gaussian of sigma = 0.5 voxels. A scalar sigma
# applies along every array axis, presumably including the slice axis.
potential_blurred = abtem.PotentialArray(
    gaussian_filter(potential.array, 0.5),
    potential.slice_thickness,
    potential.extent,
    potential.sampling,
)

# Propagate a 300 keV plane wave through the blurred potential and evaluate
# eagerly. Everything downstream reuses these exit waves.
wave = abtem.waves.PlaneWave(energy=300e3)
exit_waves = wave.multislice(potential_blurred).compute()

# Create the figure with interactive mode off so it is not rendered until the
# widget layout below displays its canvas.
with plt.ioff():
    dpi = 72
    fig = plt.figure(figsize=(675/dpi, 225/dpi), dpi=dpi)

# Three side-by-side image panels, plus a small inset (ax3) over the upper-left
# part of the Zernike panel that shows the phase-plate transfer function.
ax0 = fig.add_axes([0.04,  0.05,  0.28, 0.75])
ax1 = fig.add_axes([0.37,  0.05,  0.28, 0.75])
ax2 = fig.add_axes([0.70,  0.05,  0.28, 0.75])
ax3 = fig.add_axes([0.7,  0.6,  0.1, 0.2])

# Initial parameters. The widget sliders below start at the same values.
dose = 100
energy = 300e3
semiangle_cutoff = 8
defocus = 1000

ctf_focus = abtem.transfer.CTF(
    semiangle_cutoff=semiangle_cutoff,
    energy=energy,
)

ctf_defocus = abtem.transfer.CTF(
    defocus=defocus,
    semiangle_cutoff=semiangle_cutoff,
    energy=energy,
)

ctf_zernike = Zernike(
    energy=energy,
    hole_cutoff=1,
    semiangle_cutoff=semiangle_cutoff,
    phase_shift=np.pi/2,
)

# All three images are computed the same way. The Zernike one previously
# lacked `.compute()`; it is now consistent with the other two.
noisy_focus = exit_waves.apply_transform(ctf_focus).intensity().poisson_noise(dose).compute()
noisy_defocus = exit_waves.apply_transform(ctf_defocus).intensity().poisson_noise(dose).compute()
noisy_zernike = exit_waves.apply_transform(ctf_zernike).intensity().poisson_noise(dose).compute()

cmap = 'gray'

# Share one colour scale across the panels so their contrast is comparable.
arrays = [noisy_focus.array, noisy_defocus.array, noisy_zernike.array]
vmax = np.max(arrays)
vmin = np.min(arrays)
im0 = ax0.imshow(noisy_focus.array, vmax=vmax, vmin=vmin, cmap=cmap)
im1 = ax1.imshow(noisy_defocus.array, vmax=vmax, vmin=vmin, cmap=cmap)
im2 = ax2.imshow(noisy_zernike.array, vmax=vmax, vmin=vmin, cmap=cmap)
# `_save_array` is cached on the aperture by its last evaluation (above).
im3 = ax3.imshow(Complex2RGB(ctf_zernike._save_array, 0, 1))

for ax in (ax0, ax1, ax2, ax3):
    ax.set_xticks([])
    ax.set_yticks([])

ax0.set_title('In focus intensity')
ax1.set_title('Defocused intensity')
# Trailing ';' suppresses the Text repr when run as a notebook cell.
ax2.set_title('Intensity with Zernike\nphase plate (In focus)');

def update_ims(dose, defocus, phase_shift, zernike_radius):
    """
    Re-simulate the three noisy images and the aperture inset for new settings.

    Parameters
    ----------
    dose : float
        Dose passed to ``poisson_noise`` (labelled e/Å^2 on the slider).
    defocus : float
        Defocus for the middle panel (labelled Å on the slider).
    phase_shift : float
        Zernike film phase shift [degrees]; converted to radians here.
    zernike_radius : float
        Cutoff semiangle of the Zernike hole [mrad].

    Reads the module-level ``exit_waves``, ``semiangle_cutoff``, ``energy``,
    ``im0``-``im3`` and ``fig``. Updates the images in place and requests a
    redraw.
    """
    phase_shift = np.deg2rad(phase_shift)

    ctf_focus = abtem.transfer.CTF(
        defocus=0,
        semiangle_cutoff=semiangle_cutoff,
        energy=energy,
    )

    ctf_defocus = abtem.transfer.CTF(
        defocus=defocus,
        semiangle_cutoff=semiangle_cutoff,
        energy=energy,
    )

    ctf_zernike = Zernike(
        energy=energy,
        hole_cutoff=zernike_radius,
        semiangle_cutoff=semiangle_cutoff,
        phase_shift=phase_shift,
    )

    # Fresh Poisson noise on every call. The Zernike image is now computed
    # eagerly like the other two (previously it lacked `.compute()`).
    noisy_focus = exit_waves.apply_ctf(ctf_focus).intensity().poisson_noise(dose).compute()
    noisy_defocus = exit_waves.apply_ctf(ctf_defocus).intensity().poisson_noise(dose).compute()
    noisy_zernike = exit_waves.apply_ctf(ctf_zernike).intensity().poisson_noise(dose).compute()

    # One shared colour scale so the three panels stay directly comparable.
    arrays = [noisy_focus.array, noisy_defocus.array, noisy_zernike.array]
    vmin = min(np.min(a) for a in arrays)
    vmax = max(np.max(a) for a in arrays)

    for image, array in zip((im0, im1, im2), arrays):
        image.set_data(array)
        image.set_clim(vmin=vmin, vmax=vmax)

    # `_save_array` is cached on the aperture when it is evaluated above.
    im3.set_data(Complex2RGB(ctf_zernike._save_array, 0, 1))

    fig.canvas.draw_idle()

style = {
    'description_width': 'initial',
}

layout = Layout(width="325px", height="30px")

defocus = IntSlider(
    value=1000, min=-10000, max=10000,
    step=100,
    description="defocus (A)",
    style=style,
    layout=layout,
)

# Log-scaled dose from 10^0 to 10^5.
dose = FloatLogSlider(
    value=100,
    base=10,
    min=0,  # min exponent of base
    max=5,  # max exponent of base
    step=0.05,  # exponent step
    description=r"dose (e$^-$/A$^2$)",
    style=style,
    layout=layout,
)

# The slider is in degrees (update_ims converts to radians). IntSlider coerces
# `step` to an int, so the previous step=np.pi/8 (~0.39) became 0.
phase_shift = IntSlider(
    value=90, min=0, max=180,
    step=1,
    description=r"phase shift ($^\circ$)",
    style=style,
    layout=layout,
)

zernike_radius = IntSlider(
    value=1, min=1, max=8,
    step=1,
    description=r'radius of shift (mrad)',
    style=style,
    layout=layout,
)

# interactive_output runs update_ims inside the Output widget it returns.
# Exceptions raised there are captured by that widget. Keep it and show it
# in the layout so errors are visible instead of silently discarded.
out = widgets.interactive_output(
    update_ims,
    {
        'dose': dose,
        'defocus': defocus,
        'phase_shift': phase_shift,
        'zernike_radius': zernike_radius,
    },
)

fig.canvas.resizable = False
fig.canvas.header_visible = False
fig.canvas.footer_visible = False
fig.canvas.toolbar_visible = True
fig.canvas.layout.width = '675px'
fig.canvas.layout.height = '255px'
fig.canvas.toolbar_position = 'bottom'

widget = widgets.VBox(
    [
        fig.canvas,
        HBox([
            VBox([dose, defocus]),
            VBox([phase_shift, zernike_radius]),
        ]),
        out,
    ],
)
# Captured notebook output: [########################################] | 100% Completed | 506.21 ms
display(widget);