Contents
Interactive plot for contrast in TEM
%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display
from ipywidgets import HBox, VBox, interact, Dropdown, Label, AppLayout, FloatSlider, FloatLogSlider, IntSlider, Layout, widgets
import abtem
from ase.io import read
from scipy.ndimage import gaussian_filter
# Adapted from py4DSTEM (Complex2RGB); normalisation made NaN-, zero- and int-safe.
def Complex2RGB(complex_data, vmin=None, vmax=None, hue_start=0, invert=False):
    """
    Map a complex array to an RGB image: phase sets the hue, amplitude the brightness.

    complex_data (array): complex array to plot
    vmin (float) : lower amplitude clip. When the amplitude is (nearly) constant
        this is an absolute value (default 0); otherwise it is a quantile
        fraction in [0, 1] of the non-NaN amplitudes (default 0.02).
    vmax (float) : upper amplitude clip. When the amplitude is (nearly) constant
        this is an absolute value (default: the maximum amplitude); otherwise it
        is a quantile fraction in [0, 1] (default 0.98).
    hue_start (float) : rotational offset for colormap (degrees)
    invert (bool) : if True, uses light color scheme (returns 1 - rgb)

    Returns an array of shape ``complex_data.shape + (3,)``. Pixels whose
    amplitude is NaN stay NaN without affecting the others; an all-zero input
    maps to black (white if ``invert``).
    """
    # Float copy so integer (real) input can be normalised by true division.
    amp = np.abs(complex_data).astype(float)
    if np.isclose(np.max(amp), np.min(amp)):
        if vmin is None:
            vmin = 0
        if vmax is None:
            vmax = np.max(amp)
    else:
        if vmin is None:
            vmin = 0.02
        if vmax is None:
            vmax = 0.98
        # Here vmin / vmax are quantiles: convert them to amplitude values.
        vals = np.sort(amp[~np.isnan(amp)])
        if vals.size:  # all-NaN input: nothing to take quantiles of
            last = vals.shape[0] - 1
            ind_vmin = max(0, int(np.round(last * vmin)))
            ind_vmax = min(last, int(np.round(last * vmax)))
            vmin = vals[ind_vmin]
            vmax = vals[ind_vmax]
    amp = np.where(amp < vmin, vmin, amp)
    amp = np.where(amp > vmax, vmax, amp)

    # Normalise by the largest finite amplitude so a NaN pixel cannot poison the
    # whole image, and skip the division when everything is zero.
    finite = amp[np.isfinite(amp)]
    peak = finite.max() if finite.size else 0.0
    if peak > 0:
        amp = amp / peak

    phase = np.angle(complex_data) + np.deg2rad(hue_start)
    rgb = np.zeros(phase.shape + (3,))
    rgb[..., 0] = 0.5 * (np.sin(phase) + 1) * amp
    rgb[..., 1] = 0.5 * (np.sin(phase + np.pi / 2) + 1) * amp
    rgb[..., 2] = 0.5 * (-np.sin(phase) + 1) * amp
    return 1 - rgb if invert else rgb
class Zernike(abtem.transfer.BaseAperture):
    """
    Zernike phase-plate aperture.

    Within ``semiangle_cutoff`` the transmission has unit amplitude, and
    scattering angles larger than ``hole_cutoff`` additionally pick up the
    phase factor ``exp(1j * phase_shift)``; the central hole is left unshifted.
    Beyond ``semiangle_cutoff`` the transmission is zero.

    Parameters
    ----------
    hole_cutoff : float
        Cutoff semiangle of the central hole in the phase plate [mrad].
    phase_shift : float
        Phase shift imposed by the Zernike film [rad].
    semiangle_cutoff : float
        Cutoff semiangle of the objective aperture [mrad].
    energy : float, optional
        Electron energy [eV]. If not provided, inferred from the wave functions.
    extent : float or two float, optional
        Lateral extent of wave functions [Å] in `x` and `y` directions. If a
        single float is given, both are set equal.
    gpts : two ints, optional
        Number of grid points describing the wave functions.
    sampling : two float, optional
        Lateral sampling of wave functions [Å]. If 'gpts' is also given, will
        be ignored.

    Notes
    -----
    Every evaluation overwrites ``_amplitude`` and ``_phase`` with the two
    factors, and ``_save_array`` with an fftshift-ed copy of the full aperture
    (used by the notebook for plotting).
    """

    def __init__(
        self,
        hole_cutoff: float,
        phase_shift: float,
        semiangle_cutoff: float,
        energy: float = None,
        extent: float | tuple[float, float] = None,
        gpts: int | tuple[int, int] = None,
        sampling: float | tuple[float, float] = None,
    ):
        # Stored before the base initialiser runs, as in the original order.
        self._hole_cutoff = hole_cutoff
        self._phase_shift = phase_shift
        base_kwargs = dict(
            energy=energy,
            semiangle_cutoff=semiangle_cutoff,
            extent=extent,
            gpts=gpts,
            sampling=sampling,
        )
        super().__init__(**base_kwargs)

    @property
    def hole_cutoff(self) -> float:
        """Cutoff semiangle of the central hole [mrad]."""
        return self._hole_cutoff

    @property
    def phase_shift(self) -> float:
        """Phase shift imposed by the Zernike film [rad]."""
        return self._phase_shift

    def _evaluate_from_angular_grid(
        self, alpha: np.ndarray, phi: np.ndarray
    ) -> np.ndarray:
        # phi is not used: the plate only depends on the scattering angle alpha.
        xp = abtem.core.backend.get_array_module(alpha)
        alpha = xp.array(alpha)

        # Cutoffs are given in mrad; alpha is compared against them in rad.
        aperture_rad = self.semiangle_cutoff / 1e3
        hole_rad = self.hole_cutoff / 1e3

        inside_aperture = xp.asarray(alpha < aperture_rad, dtype="float")
        on_film = xp.asarray(alpha > hole_rad, dtype="float")
        film_phase = xp.exp(1j * self.phase_shift * on_film)
        transmission = inside_aperture * film_phase

        # Side effects kept for inspection / plotting of the last evaluation.
        self._amplitude = inside_aperture
        self._phase = film_phase
        self._save_array = xp.fft.fftshift(transmission)
        return transmission
# Load the structure, shift it to the origin and give it an orthorhombic cell
# that just encloses it.
atoms = read('data/3jcl.xyz')
atoms.positions -= atoms.positions.min(axis=0)
for axis in range(3):
    atoms.cell[axis][axis] = atoms.positions[:, axis].max()

# Pad laterally with vacuum, then make the lateral cell square. Take the larger
# of the two lateral sides (copying y onto x would leave atoms outside the cell
# whenever the structure is wider in x), and re-centre the atoms in the square
# cell without changing it.
atoms.center(vacuum=10, axis=(0, 1))
side = max(atoms.cell[0][0], atoms.cell[1][1])
atoms.cell[0][0] = side
atoms.cell[1][1] = side
atoms.center(axis=(0, 1))

sampling = 1
slice_thickness = 1
potential = abtem.Potential(
    atoms,
    sampling=sampling,
    slice_thickness=slice_thickness,
    projection='infinite',
    parametrization='kirkland',
)
potential = potential.build()

# Gaussian blur with sigma = 0.5 voxel along every array axis (about 0.5 Å, since
# sampling and slice thickness are both 1 Å).
potential_blurred = abtem.PotentialArray(
    gaussian_filter(potential.array, 0.5),
    potential.slice_thickness,
    potential.extent,
    potential.sampling,
)

energy = 300e3  # electron energy [eV] for the incident plane wave
wave = abtem.waves.PlaneWave(energy=energy)
exit_waves = wave.multislice(potential_blurred).compute()
# Three-panel figure (in focus / defocused / Zernike) plus an inset over the third
# panel showing the Zernike aperture, drawn for the initial slider values.
with plt.ioff():
    dpi = 72
    fig = plt.figure(figsize=(675 / dpi, 225 / dpi), dpi=dpi)
    ax0 = fig.add_axes([0.04, 0.05, 0.28, 0.75])
    ax1 = fig.add_axes([0.37, 0.05, 0.28, 0.75])
    ax2 = fig.add_axes([0.70, 0.05, 0.28, 0.75])
    ax3 = fig.add_axes([0.7, 0.6, 0.1, 0.2])  # inset overlapping ax2

dose = 100
energy = 300e3
semiangle_cutoff = 8
defocus = 1000

ctf_focus = abtem.transfer.CTF(
    semiangle_cutoff=semiangle_cutoff,
    energy=energy,
)
ctf_defocus = abtem.transfer.CTF(
    defocus=defocus,
    semiangle_cutoff=semiangle_cutoff,
    energy=energy,
)
ctf_zernike = Zernike(
    energy=energy,
    hole_cutoff=1,
    semiangle_cutoff=semiangle_cutoff,
    phase_shift=np.pi / 2,
)

noisy_focus = exit_waves.apply_transform(ctf_focus).intensity().poisson_noise(dose).compute()
noisy_defocus = exit_waves.apply_transform(ctf_defocus).intensity().poisson_noise(dose).compute()
# Computed eagerly like the other two panels; the inset below reads
# ctf_zernike._save_array, which only exists once the aperture has been evaluated.
noisy_zernike = exit_waves.apply_transform(ctf_zernike).intensity().poisson_noise(dose).compute()

# One shared grey scale across the three panels so they are directly comparable.
cmap = 'gray'
images = [noisy_focus.array, noisy_defocus.array, noisy_zernike.array]
vmax = np.max(images)
vmin = np.min(images)
im0 = ax0.imshow(noisy_focus.array, vmax=vmax, vmin=vmin, cmap=cmap)
im1 = ax1.imshow(noisy_defocus.array, vmax=vmax, vmin=vmin, cmap=cmap)
im2 = ax2.imshow(noisy_zernike.array, vmax=vmax, vmin=vmin, cmap=cmap)
im3 = ax3.imshow(Complex2RGB(ctf_zernike._save_array, 0, 1))

for ax in (ax0, ax1, ax2, ax3):
    ax.set_xticks([])
    ax.set_yticks([])
ax3.set_xlabel('')
ax3.set_ylabel('')

ax0.set_title('In focus intensity');
ax1.set_title('Defocused intensity');
ax2.set_title('Intensity with Zernike\nphase plate (In focus)');
def update_ims(dose, defocus, phase_shift, zernike_radius):
    """
    Recompute the three noisy images and the aperture inset for new slider values.

    Parameters
    ----------
    dose : float
        Dose passed to ``poisson_noise`` (the slider labels it e-/A^2).
    defocus : float
        Defocus of the middle panel (the slider labels it in A).
    phase_shift : float
        Zernike phase shift in degrees; converted to radians here.
    zernike_radius : float
        Cutoff semiangle of the Zernike hole [mrad].

    Reads the module-level ``exit_waves``, ``semiangle_cutoff``, ``energy`` and the
    image handles ``im0``-``im3``, and redraws ``fig`` in place.
    """
    phase_shift_rad = np.deg2rad(phase_shift)

    ctf_focus = abtem.transfer.CTF(
        defocus=0,
        semiangle_cutoff=semiangle_cutoff,
        energy=energy,
    )
    ctf_defocus = abtem.transfer.CTF(
        defocus=defocus,
        semiangle_cutoff=semiangle_cutoff,
        energy=energy,
    )
    ctf_zernike = Zernike(
        energy=energy,
        hole_cutoff=zernike_radius,
        semiangle_cutoff=semiangle_cutoff,
        phase_shift=phase_shift_rad,
    )

    def _noisy_intensity(transfer):
        # Same pipeline as the initial figure; computed eagerly so that
        # ctf_zernike._save_array is populated before the inset reads it.
        return exit_waves.apply_transform(transfer).intensity().poisson_noise(dose).compute()

    noisy_focus = _noisy_intensity(ctf_focus)
    noisy_defocus = _noisy_intensity(ctf_defocus)
    noisy_zernike = _noisy_intensity(ctf_zernike)

    # Shared colour limits so the three panels stay directly comparable.
    images = [noisy_focus.array, noisy_defocus.array, noisy_zernike.array]
    vmax = np.max(images)
    vmin = np.min(images)
    for handle, image in zip((im0, im1, im2), images):
        handle.set_data(image)
        handle.set_clim(vmax=vmax, vmin=vmin)
    im3.set_data(Complex2RGB(ctf_zernike._save_array, 0, 1))
    fig.canvas.draw_idle()
# Slider widgets driving update_ims; they share one style and one layout.
style = {
    'description_width': 'initial',
}
layout = Layout(width="325px", height="30px")

defocus = IntSlider(
    value=1000, min=-10000, max=10000,
    step=100,
    description="defocus (A)",
    style=style,
    layout=layout,
)
dose = FloatLogSlider(
    value=100,
    base=10,
    min=0,  # min exponent of base
    max=5,  # max exponent of base
    step=0.05,  # exponent step
    description=r"dose (e$^-$/A$^2$)",
    style=style,
    layout=layout,
)
# IntSlider coerces its step to an integer, so the previous step=np.pi/8 became 0.
# The value is in degrees (update_ims converts it), so step in whole degrees.
phase_shift = IntSlider(
    value=90, min=0, max=180,
    step=1,
    description=r"phase shift ($^\circ$)",
    style=style,
    layout=layout,
)
zernike_radius = IntSlider(
    value=1, min=1, max=8,
    step=1,
    description=r'radius of shift (mrad)',
    style=style,
    layout=layout,
)
# Call update_ims now and whenever a slider changes. interactive_output returns
# an Output widget that captures anything update_ims prints or raises; it is kept
# and placed under the sliders so errors in the callback are not silently lost.
update_output = widgets.interactive_output(
    update_ims,
    {
        'dose': dose,
        'defocus': defocus,
        'phase_shift': phase_shift,
        'zernike_radius': zernike_radius,
    },
)

fig.canvas.resizable = False
fig.canvas.header_visible = False
fig.canvas.footer_visible = False
fig.canvas.toolbar_visible = True
fig.canvas.layout.width = '675px'
fig.canvas.layout.height = '255px'
fig.canvas.toolbar_position = 'bottom'

widget = widgets.VBox(
    [
        fig.canvas,
        HBox([
            VBox([dose, defocus]),
            VBox([phase_shift, zernike_radius]),
        ]),
        update_output,
    ],
)
[########################################] | 100% Completed | 506.21 ms
display(widget)  # render the figure canvas together with its slider controls
VBox(children=(Canvas(footer_visible=False, header_visible=False, layout=Layout(height='255px', width='675px')…