A Practical Guide to Scanning and Transmission Electron Microscopy Simulations

TEM Imaging

Imports

%matplotlib widget
import ase
import abtem

import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial import Voronoi, voronoi_plot_2d, cKDTree

from matplotlib.patches import Circle
from ipywidgets import HBox, VBox, widgets, interact, Dropdown, Label, Layout

Atomic Model

# Import the ASE submodules used below explicitly: `import ase` by itself only
# exposes `ase.io` / `ase.build` if some other package happened to load them.
import ase.build
import ase.io

# Large periodic SrTiO3 (110) slab; every grain is carved out of a rotated
# copy of this supercell.
STO_unit_cell   = ase.io.read('data/SrTiO3.cif')
STO_atoms       = ase.build.surface(STO_unit_cell, (1,1,0), 4, periodic=True)*(8*4,12*4,4)

# Grain seeds: 8 random points plus the 4 corners of the unit square, mapped
# onto the central half of the supercell in x and y.
np.random.seed(111111)
points = np.random.rand(8,2)
points = np.vstack((points,np.array([[0,0],[0,1],[1,0],[1,1]])))
points[:,0] *= STO_atoms.cell[0,0]/2
points[:,1] *= STO_atoms.cell[1,1]/2
points[:,0] += STO_atoms.cell[0,0]/4
points[:,1] += STO_atoms.cell[1,1]/4

voronoi = Voronoi(points)
# Nearest-seed lookup: the index returned by the query is the Voronoi region
# (grain) that a position falls into.
voronoi_kdtree = cKDTree(points)

grains = ase.Atoms(cell=STO_atoms.cell,pbc=True)

# Pool of candidate in-plane rotation angles in [-90, 90) degrees.
np.random.seed(111111)
random_angles = np.random.rand(8)*180-90

# One grain per seed point (the count follows `points` instead of being
# hard-coded).
for grain in range(len(points)):
    STO_atoms_rotated = STO_atoms.copy()
    angle = np.random.choice(random_angles)
    #print(f"{grain=}, {angle=:.3}")
    STO_atoms_rotated.rotate(
        angle,
        (0,0,1),
        center=STO_atoms.cell.lengths()/2
    )
    # Keep only the atoms whose nearest seed is this grain's seed.
    _, reg = voronoi_kdtree.query(STO_atoms_rotated.positions[:,:2])
    del STO_atoms_rotated[reg != grain]

    grains+=STO_atoms_rotated

# Crop to the central half of the supercell in x and y (single combined mask
# instead of four successive deletions; the surviving atoms are the same).
cell_x = STO_atoms.cell[0,0]
cell_y = STO_atoms.cell[1,1]
pos_x = grains.positions[:,0]
pos_y = grains.positions[:,1]
outside_crop = (
    (pos_x < cell_x/4) | (pos_x > 3*cell_x/4)
    | (pos_y < cell_y/4) | (pos_y > 3*cell_y/4)
)
del grains[outside_crop]

# NOTE(review): the following `center(..., vacuum=20)` call appears to resize
# the x/y cell lengths itself, which would make this assignment redundant -
# confirm before removing.
grains.cell[0,0] = grains.cell[1,1] = 80
grains.center(axis=(0,1), vacuum = 20)

# abtem.show_atoms(grains);

abTEM potential and wave

# Potential of the polycrystalline model on a 512 x 512 grid, using the
# Kirkland parametrization with infinite projection, evaluated on the CPU.
potential = abtem.Potential(
    grains,
    gpts=(512, 512),
    parametrization='kirkland',
    projection='infinite',
    device='cpu',
)
potential = potential.build()

# Incident plane wave with energy 300e3 (presumably eV, i.e. 300 keV -
# confirm against the abTEM documentation).
wave = abtem.PlaneWave(energy=300e3)

TEM Simulation

# Propagate the incident plane wave through the potential with abTEM's
# multislice method to obtain the exit wave.
exit_wave = wave.multislice(potential)
# Evaluate the result now (the progress bar in the cell output suggests the
# computation is deferred until this call). As the last expression of the
# cell, the returned object is also echoed by the notebook.
exit_wave.compute()
[########################################] | 100% Completed | 1.68 sms
<abtem.waves.Waves at 0x154167e90>
def return_scaled_histogram_ordering(array, vmin=None, vmax=None, normalize=False):
    """
    Clip an array to a pair of intensity limits chosen from the ordering
    (sorted rank) of its pixel values, for display purposes.

    Parameters
    ----------
    array: np.array
        array to be plotted; it is not modified
    vmin: float
        lower cut off, as a fraction (0 to 1) of the sorted pixel values.
        Defaults to 0.02
    vmax: float
        upper cut off, as a fraction (0 to 1) of the sorted pixel values.
        Defaults to 0.98
    normalize: bool
        if True, rescales the clipped array from 0 to 1. Integer input is
        converted to float for this; a constant array is returned as zeros

    Returns
    -------
    scaled_array: np.array
        array clipped outside vmin and vmax
    vmin: float
        lower value to be plotted (0 if normalize is True)
    vmax: float
        upper value to be plotted (1 if normalize is True)
    """

    if vmin is None:
        vmin = 0.02
    if vmax is None:
        vmax = 0.98

    # Translate the fractional cut offs into indices of the sorted values,
    # clamped to the valid index range.
    vals = np.sort(array.ravel())
    last_index = vals.shape[0] - 1
    ind_vmin = max(0, int(np.round(last_index * vmin)))
    ind_vmax = min(last_index, int(np.round(last_index * vmax)))
    vmin = vals[ind_vmin]
    vmax = vals[ind_vmax]

    # Degenerate window (both cut offs hit the same value): fall back to the
    # full data range.
    if vmax == vmin:
        vmin = vals[0]
        vmax = vals[-1]

    # np.where builds new arrays, so the caller's array is left untouched.
    scaled_array = np.where(array < vmin, vmin, array)
    scaled_array = np.where(scaled_array > vmax, vmax, scaled_array)

    if normalize:
        # Integer data cannot hold the 0..1 result (and in-place true
        # division into an integer array raises), so promote it first.
        if not np.issubdtype(scaled_array.dtype, np.inexact):
            scaled_array = scaled_array.astype(np.float64)
        scaled_array = scaled_array - scaled_array.min()
        peak = scaled_array.max()
        # A constant array has zero range; skip the 0/0 division and leave
        # it at zero instead of producing NaN.
        if peak > 0:
            scaled_array = scaled_array / peak
        vmin = 0
        vmax = 1

    return scaled_array, vmin, vmax
# Drop an 80-pixel border on every side of the exit wave (presumably the
# vacuum padding around the grains - confirm).
exit_wave_array = exit_wave.array[80:-80,80:-80]
s = exit_wave_array.shape

# Integer spatial-frequency indices of each axis in unshifted FFT order
# (0, 1, ..., -2, -1), so they line up with the output of np.fft.fft2.
x = np.fft.fftfreq(s[0], 1/s[0])
y = np.fft.fftfreq(s[1], 1/s[1])

# 'ij' indexing gives both grids the shape `s` of the image: xx holds the
# axis-0 (row) frequency and yy the axis-1 (column) frequency of each pixel.
# For a square image this is identical to the default-indexed
# `yy, xx = np.meshgrid(x, y)`, but it also has the right shape otherwise.
xx, yy = np.meshgrid(x, y, indexing="ij")
# Create the two-panel figure with matplotlib's interactive mode switched
# off; it is placed on screen later through its canvas widget.
with plt.ioff():
    dpi = 72
    fig, axs = plt.subplots(1,2, figsize=(675/dpi, 400/dpi), dpi=dpi)

#plot FFT
# 2D Hann window of shape `s`, applied before the FFT that is displayed.
w0 = np.hanning(s[1]) * np.hanning(s[0])[:, None]
axs[0].imshow(
    return_scaled_histogram_ordering(
        np.fft.fftshift(
            np.abs(
                np.fft.fft2(
                    exit_wave_array * w0
                )
            )
        )
    )[0],
    cmap = "gray"
)


#make mask
radius = 35
x0, y0 = 0,0

# Circular aperture in unshifted FFT coordinates: x0 offsets the centre along
# the xx grid and y0 along the yy grid.
mask = np.zeros(s)
mask[(xx-x0)**2 + (yy-y0) **2 < radius**2] = 1

# add circles
# The displayed FFT is fftshift-ed, so the aperture centre sits half the image
# size away from the origin. Matplotlib's first coordinate is the horizontal
# (column) one, hence the column count s[1] there and the row count s[0] for
# the vertical coordinate.
circle = Circle((
    y0+s[1]/2, x0 + s[0]/2
), radius, fill = False, color = "red", linewidth = 5)

axs[0].add_patch(circle)


#plot image
# Image formed from only the Fourier components inside the aperture.
axs[1].imshow(
    return_scaled_histogram_ordering(np.abs(np.fft.ifft2((np.fft.fft2(exit_wave_array) * mask))))[0],
    cmap = "gray",

)

axs[0].set_xticks([])
axs[0].set_yticks([])

axs[1].set_xticks([])
axs[1].set_yticks([])


axs[0].set_title('diffraction plane')
axs[1].set_title('imaging plane')


# interact
def update_ims(x0, y0, radius):
    """
    Redraw both panels for a circular aperture placed in the diffraction plane.

    Parameters
    ----------
    x0: int
        horizontal offset of the aperture centre from the middle of the
        displayed diffraction pattern, in pixels
    y0: int
        vertical offset of the aperture centre from the middle of the
        displayed diffraction pattern, in pixels (positive is downwards,
        following the image rows)
    radius: int
        aperture radius in pixels

    Notes
    -----
    Reads the module-level `s`, `exit_wave_array`, `axs` and `fig`.
    """
    # Frequency index of every row / column in unshifted FFT order, so the
    # mask lines up element-for-element with the output of np.fft.fft2.
    row_freq = np.fft.fftfreq(s[0], 1/s[0])[:, None]
    col_freq = np.fft.fftfreq(s[1], 1/s[1])[None, :]

    # x0 acts on columns (horizontal) and y0 on rows (vertical), which is
    # where the red circle is drawn below; previously the two were swapped
    # between the mask and the circle.
    mask = np.zeros(s)
    mask[(col_freq-x0)**2 + (row_freq-y0)**2 < radius**2] = 1
    # NOTE(review): in unshifted FFT order these indices are the highest
    # (Nyquist) frequencies, not the central beam - confirm the intent.
    mask[s[0]//2-1:s[0]//2+1, s[1]//2-1:s[1]//2+1] = 1

    # One FFT serves both the displayed pattern and the filtered image.
    exit_wave_fft = np.fft.fft2(exit_wave_array)

    axs[0].clear()
    axs[0].imshow(
        return_scaled_histogram_ordering(
            np.fft.fftshift(np.abs(exit_wave_fft))
        )[0],
        cmap = "gray"
    )

    # fftshift moves frequency 0 to index n//2 along each axis; matplotlib
    # takes (horizontal, vertical) coordinates.
    circle = Circle(
        (x0 + s[1]//2, y0 + s[0]//2),
        radius, fill = False, color = "red", linewidth = 5,
    )
    axs[0].add_patch(circle)

    axs[0].set_xticks([])
    axs[0].set_yticks([])
    axs[0].set_title('diffraction plane')

    # Clear the image panel as well, otherwise every update stacks another
    # image artist on the axes.
    axs[1].clear()
    axs[1].imshow(
        return_scaled_histogram_ordering(
            np.abs(np.fft.ifft2(exit_wave_fft * mask))
        )[0],
        cmap = "gray",
    )
    axs[1].set_xticks([])
    axs[1].set_yticks([])
    axs[1].set_title('imaging plane')

    # Lay out this figure explicitly rather than whichever one is current.
    fig.tight_layout()
    fig.canvas.draw_idle()

# Shared appearance of the three sliders.
style = {
    'description_width': 'initial',
}

layout = Layout(width='250px',height='30px')

# Integer slider limit: the aperture centre may move up to half the (smaller)
# image dimension away from the middle of the diffraction pattern.
half_extent = min(s[0], s[1]) // 2

x0 = widgets.IntSlider(
    value = 0, min = -half_extent, max = half_extent,
    step = 2,
    description = "x position",
    style = style,
    layout = layout,
)


y0 = widgets.IntSlider(
    value = 0, min = -half_extent, max = half_extent,
    step = 2,
    description = "y position",
    style = style,
    layout = layout,
)


radius = widgets.IntSlider(
    value = 30, min = 0, max = 100,
    step = 5,
    description = "aperture radius",
    style = style,
    layout = layout,
)

# Re-run update_ims whenever a slider changes. The returned output widget is
# kept and displayed below so that anything printed or raised inside the
# callback is visible instead of being silently discarded.
update_output = widgets.interactive_output(
    update_ims,
    {
        'x0':x0,
        'y0':y0,
        'radius':radius,
    },
)


# Fixed-size canvas with the toolbar underneath and no header/footer.
fig.canvas.resizable = False
fig.canvas.header_visible = False
fig.canvas.footer_visible = False
fig.canvas.toolbar_visible = True
fig.canvas.layout.width = '675px'
fig.canvas.layout.height = '400px'
fig.canvas.toolbar_position = 'bottom'

# Figure on top, the row of sliders below it, callback output last.
widget = widgets.VBox(
    [
        fig.canvas,
        HBox(
            [x0,y0,radius]
        ),
        update_output,
    ],
)
display(widget);