Evaluating the Transfer of Information in Phase Retrieval STEM Techniques
Contents
Segmented Parallax
# enable interactive matplotlib
%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import matplotlib.patches as mpatches
import ctf # import custom plotting / utils
import cmasher as cmr
import ipywidgets
4D STEM Simulation¶
# parameters
# Global simulation / plotting constants read by every later cell.
n = 96  # number of grid points along each array axis
q_max = 2 # inverse Angstroms
q_probe = 1 # inverse Angstroms
wavelength = 0.019687 # 300kV
sampling = 1 / q_max / 2 # Angstroms
reciprocal_sampling = 2 * q_max / n # inverse Angstroms
scan_step_size = 1 # pixels
sx = sy = n//scan_step_size  # number of scan positions along each axis
phi0 = 1.0  # Fourier amplitude handed to white_noise_object_2D
C10 = 46  # initial defocus; same number as the defocus slider's starting value
cmap = cmr.eclipse  # colormap for the CTF images
sample_cmap = 'gray'  # colormap for the real-space sample images
# segmented_icom_line_color = 'cornflowerblue'
segmented_parallax_line_color = 'darksalmon'
pixelated_parallax_line_color = 'darkred'
White Noise Potential¶
def white_noise_object_2D(n, phi0):
    """
    Return an (n, n) real array whose 2D FFT has amplitude phi0 everywhere
    and a random phase.

    Phases (in units of 2*pi) are drawn with np.random.randn and then made
    antisymmetric under k -> -k, which makes the Fourier array Hermitian so
    that its inverse FFT is real. Frequencies that are their own negative
    (DC, plus the Nyquist row/column when n is even) are given zero phase.
    """
    is_even = n % 2 == 0

    # paired frequency indices: positive[i] and negative[i] are +k and -k
    positive = np.arange(1, (n if is_even else n + 1) // 2)
    negative = np.arange(n // 2 + 1, n)[::-1]

    phase = np.random.randn(n, n)

    # (+kx, +ky) mirrors (-kx, -ky)
    phase[np.ix_(positive, positive)] = -phase[np.ix_(negative, negative)]
    # (+kx, -ky) mirrors (-kx, +ky)
    phase[np.ix_(positive, negative)] = -phase[np.ix_(negative, positive)]
    # the kx = 0 row and the ky = 0 column
    phase[0, positive] = -phase[0, negative]
    phase[positive, 0] = -phase[negative, 0]

    # self-conjugate frequencies get zero phase (their amplitude stays phi0)
    if is_even:
        phase[n // 2, :] = 0
        phase[:, n // 2] = 0
    phase[0, 0] = 0

    fourier_array = np.exp(2j * np.pi * phase) * phi0

    # the imaginary part is only floating-point residue; drop it
    return np.fft.ifft2(fourier_array).real
# potential
# white-noise test object on the n x n grid, Fourier amplitude phi0
potential = white_noise_object_2D(n,phi0)
# unit-modulus complex object whose phase is the potential
complex_obj = np.exp(1j*potential)
Import sample potentials¶
# Load the three pre-computed sample potentials and subtract each array's mean in place.
# The file names suggest 192x192 arrays sampled out to 4 q_probe — TODO confirm against the data files.
sto_potential = np.load("data/STO_projected-potential_192x192_4qprobe.npy")
sto_potential -= sto_potential.mean()
mof_potential = np.load("data/MOF_projected-potential_192x192_4qprobe.npy")
mof_potential -= mof_potential.mean()
apo_potential = np.load("data/apoF_projected-potential_192x192_4qprobe.npy")
apo_potential -= apo_potential.mean()
Probe¶
# we build probe in Fourier space, using a soft aperture
# qx and qy are the same fftfreq array (zero frequency at index 0)
qx = qy = np.fft.fftfreq(n,sampling)
q2 = qx[:,None]**2 + qy[None,:]**2
q = np.sqrt(q2)
# scan positions in pixels, one every scan_step_size pixels along each axis
x = y = np.arange(0.,n,scan_step_size)
xx, yy = np.meshgrid(x,y,indexing='ij')
positions = np.stack((xx.ravel(),yy.ravel()),axis=-1)
# NOTE(review): ctf.return_patch_indices is defined outside this file; presumably it returns
# the row/column indices of the object patch seen at each scan position — confirm.
row, col = ctf.return_patch_indices(positions,(n,n),(n,n))
# aperture amplitude: 1 inside q_probe, 0 outside, with a linear ramp (before the sqrt)
# that is one reciprocal-space pixel wide and centred on q_probe
probe_array_fourier_0 = np.sqrt(
np.clip(
(q_probe - q)/reciprocal_sampling + 0.5,
0,
1,
),
)
def simulate_intensities(C10):
    """
    Simulate the 4D intensity dataset for a probe with defocus coefficient C10.

    The soft aperture probe_array_fourier_0 is multiplied by the phase factor
    exp(-1j * pi * wavelength * q**2 * C10), normalized to unit total
    intensity, and transformed to real space before being passed to
    ctf.simulate_data together with the complex object and patch indices.

    Returns
    -------
    intensities : array reshaped to (sx, sy, n, n)
        The squared output of ctf.simulate_data divided by n**2.
        NOTE(review): ctf.simulate_data is defined outside this file; its
        output is assumed to be an amplitude — confirm.
    probe_array_fourier : (n, n) complex array
        The normalized Fourier-space probe.
    """
    defocus_phase = np.exp(-1j * np.pi * wavelength * q**2 * C10)
    fourier_probe = probe_array_fourier_0 * defocus_phase

    # normalized s.t. np.sum(np.abs(fourier_probe)**2) = 1.0
    total_intensity = np.sum(np.abs(fourier_probe)**2)
    fourier_probe = fourier_probe / np.sqrt(total_intensity)

    # the factor n keeps np.sum(np.abs(real_space_probe)**2) = 1.0 after ifft2
    real_space_probe = np.fft.ifft2(fourier_probe) * n

    simulated = ctf.simulate_data(
        complex_obj,
        real_space_probe,
        row,
        col,
    )
    diffraction_intensities = simulated.reshape((sx, sy, n, n))**2 / n**2

    return diffraction_intensities, fourier_probe
# Run the simulation once at the initial defocus.
ints, probe = simulate_intensities(C10=C10)
# Single-element lists act as mutable holders: `simulate` overwrites element 0
# so the callbacks always read the most recent dataset and probe.
intensities = [ints]
probe_array_fourier = [probe]
Virtual Detectors and CoM calculation¶
def annular_segmented_detectors(
    gpts,
    sampling,
    n_angular_bins,
    rotation_offset = 0,
    inner_radius = 0,
    outer_radius = np.inf,
):
    """
    Build the binary masks of an annular detector split into angular segments.

    Parameters
    ----------
    gpts : (int, int)
        Number of grid points along each axis.
    sampling : (float, float)
        Sample spacing along each axis, passed to np.fft.fftfreq.
    n_angular_bins : int
        Number of equal-angle segments.
    rotation_offset : float
        Angle in radians added to every pixel's azimuth before binning.
    inner_radius, outer_radius : float
        A pixel belongs to the annulus when inner_radius <= |k| < outer_radius.

    Returns
    -------
    list of n_angular_bins integer arrays of shape gpts
        One mask per segment (1 inside the segment, 0 elsewhere), fftshifted
        so that zero frequency sits at the array centre.
    """
    nx, ny = gpts
    dx, dy = sampling

    k_x = np.fft.fftfreq(nx, dx)
    k_y = np.fft.fftfreq(ny, dy)
    k = np.sqrt(k_x[:,None]**2 + k_y[None,:]**2)
    radial_mask = ((inner_radius <= k) & (k < outer_radius))

    theta = (np.arctan2(k_y[None,:], k_x[:,None]) + rotation_offset) % (2 * np.pi)

    # 1-based segment index; 0 is reserved for "outside the annulus"
    angular_bins = np.floor(n_angular_bins * (theta / (2 * np.pi))) + 1
    # A tiny negative angle wraps to exactly 2*pi in floating point, which
    # would yield index n_angular_bins + 1 and drop the pixel from every
    # segment. Such an angle sits just below 2*pi, i.e. in the last segment.
    angular_bins = np.minimum(angular_bins, n_angular_bins)
    angular_bins *= radial_mask.astype("int")

    return [
        np.fft.fftshift((angular_bins == i).astype("int"))
        for i in range(1, n_angular_bins + 1)
    ]
Compute CTFs and initial values¶
# Spatial frequencies
# float32 copies of the frequency grid (zero frequency at index 0)
kx = ky = np.fft.fftfreq(n,sampling).astype(np.float32)
kxa, kya = np.meshgrid(kx, ky, indexing='ij')
k2 = kxa**2 + kya**2
# k is computed before k2[0, 0] is overwritten, so k[0, 0] stays 0
k = np.sqrt(k2)
# NOTE(review): presumably guards a later division by k2 at zero frequency;
# no such division appears in the visible code — confirm it is still needed.
k2[0, 0] = np.inf
# Initial masks and CoM
# one ring of four segments spanning q_probe/2 <= |k| < 1.05*q_probe
virtual_masks_annular = annular_segmented_detectors(
gpts=(n,n),
sampling=(sampling,sampling),
n_angular_bins=4,
inner_radius=q_probe/2,
outer_radius=q_probe*1.05,
rotation_offset=0,
)
# Fourier shift-theorem factors: exp(qx_shift * dx + qy_shift * dy) shifts an image by (dx, dy)
qx_shift = -2j * np.pi * qx[:,None]
qy_shift = -2j * np.pi * qy[None,:]
def return_segmented_parallax_ctf(
    intensities,
    virtual_masks_annular,
    C10,
    number_of_rings=None,
):
    """
    Align and average the virtual images of a segmented detector.

    For every retained detector segment the intensities are summed over the
    segment's pixels to give one virtual image per scan, which is rescaled to
    `image / mean - 1`. Each image is then shifted (via the Fourier shift
    theorem) by the segment-averaged value of `k * wavelength * C10`, and the
    shifted images are averaged.

    Parameters
    ----------
    intensities : array whose last two axes are the detector axes
        Detector axes have zero frequency at index 0.
    virtual_masks_annular : sequence of (n, n) masks
        Segment masks with zero frequency at the array centre, as returned by
        annular_segmented_detectors.
    C10 : float
        Defocus coefficient used to compute the per-segment shifts.
    number_of_rings : int, optional
        Number of radial rings the masks were built from. A segment whose
        mean |k| exceeds 1.05 * q_probe is kept only when this equals 1.
        When omitted, the value is read from the notebook-level
        `number_of_rings_slider`, and only if such a segment is encountered.

    Returns
    -------
    (n, n) real array, the mean of the shifted virtual images.
    """
    # Move zero frequency back to the corner on the two detector axes only;
    # the leading axis indexes the masks and must not be rolled.
    corner_shifted_masks = np.fft.ifftshift(
        np.asarray(virtual_masks_annular), axes=(-2, -1)
    )

    def _single_ring():
        # Looked up lazily so the slider is only needed when a segment
        # actually lies outside the bright-field disk.
        if number_of_rings is not None:
            return number_of_rings == 1
        return number_of_rings_slider.value == 1

    # pixel indices of every segment that is kept
    segment_pixels = []
    for mask in corner_shifted_masks:
        rows, cols = np.where(mask)
        if rows.size == 0:
            # an empty segment carries no signal and would only inject NaNs
            continue
        centroid_norm = k[rows, cols].mean()
        if centroid_norm <= q_probe*1.05 or _single_ring():
            segment_pixels.append((rows, cols))

    num_masks = len(segment_pixels)
    vbfs = np.empty((num_masks, n, n))
    shifts_ang = np.empty((num_masks, 2))

    grad_x, grad_y = np.meshgrid(
        kx * wavelength * C10,
        ky * wavelength * C10,
        indexing='ij'
    )

    for i, (rows, cols) in enumerate(segment_pixels):
        vbf = intensities[..., rows, cols].sum(-1)
        vbfs[i] = vbf / vbf.mean() - 1
        shifts_ang[i, 0] = grad_x[rows, cols].mean()
        shifts_ang[i, 1] = grad_y[rows, cols].mean()

    Gs = np.fft.fft2(vbfs)
    shift_op = np.exp(
        qx_shift[None] * shifts_ang[:, 0, None, None]
        + qy_shift[None] * shifts_ang[:, 1, None, None]
    )
    shifted_stack = np.fft.ifft2(Gs * shift_op).real.mean(0)

    return shifted_stack
def return_analytical_parallax_ctf(
    C10,
):
    """
    Return the analytical CTF for defocus coefficient C10.

    The result is |sin(pi * wavelength * q**2 * C10)| multiplied by the
    autocorrelation of the soft aperture, where the aperture is first
    normalized to unit total intensity. The array has zero frequency at
    index [0, 0], like q.
    """
    sin_chi = -np.sin(np.pi * wavelength * q**2 * C10)

    aperture_norm = np.sqrt(np.sum(np.abs(probe_array_fourier_0)**2))
    aperture = probe_array_fourier_0 / aperture_norm

    # autocorrelation computed as the inverse FFT of the power spectrum
    power_spectrum = np.abs(np.fft.fft2(aperture))**2
    aperture_autocorrelation = np.fft.ifft2(power_spectrum).real

    return np.abs(sin_chi) * aperture_autocorrelation
# Initial values shown when the figure is first built: the analytical CTF, the
# segmented reconstruction with the default four-segment detector, and the
# radial averages of both.
analytical_parallax_ctf = return_analytical_parallax_ctf(
C10
)
parallax_recon = return_segmented_parallax_ctf(
intensities[0],
virtual_masks_annular,
C10
)
# NOTE(review): ctf.compute_ctf and ctf.radially_average_ctf are defined outside this file;
# the latter is used as returning (bin positions, bin values) — confirm.
ctf_parallax = ctf.compute_ctf(parallax_recon)
q_bins_parallax, I_bins_parallax = ctf.radially_average_ctf(
ctf_parallax,
(sampling,sampling)
)
q_bins_parallax_analytic, I_bins_parallax_analytic = ctf.radially_average_ctf(
analytical_parallax_ctf,
(sampling,sampling)
)
Visualization¶
Base Plot¶
We make the interactive plot using the initial values, and name the artists (imshow, plot) we want to modify later.
Note: I use 2-98% histogram scaling, and I normalize the values to lie within 0-1 (to avoid having to modify the clims)
# Build the static 2x4 figure with interactive mode switched off.
# NOTE(review): presumably so the figure is only displayed once fig.canvas is
# placed in the widget layout at the end of the notebook — confirm.
with plt.ioff():
dpi=72
fig, axs = plt.subplots(2,4,figsize=(640/dpi,400/dpi),dpi=dpi)
# detector
ax_detector = axs[0,0]
im_detector = ax_detector.imshow(ctf.combined_images_rgb(virtual_masks_annular))
# NOTE(review): axs[0,0] is passed to ctf.add_scalebar again by the loop over
# axs[0,:3] further down, so this axis gets the same scalebar call twice.
ctf.add_scalebar(ax_detector,length=n//4,sampling=reciprocal_sampling,units=r'$q_{\mathrm{probe}}$')
# analytical CTF
ax_ctf_annular = axs[0,1]
im_ctf = ax_ctf_annular.imshow(ctf.histogram_scaling(np.fft.fftshift(analytical_parallax_ctf),normalize=True),cmap=cmap)
# annular parallax
ax_ctf_parallax = axs[0,2]
im_ctf_parallax = ax_ctf_parallax.imshow(ctf.histogram_scaling(np.fft.fftshift(ctf_parallax),normalize=True),cmap=cmap)
# CTF radially-averaged
ax_ctf_rad = axs[0,3]
# keep the Line2D handles ([0]) so the callbacks can replace their y-data later
plot_ctf_parallax_analytic = ax_ctf_rad.plot(q_bins_parallax_analytic,I_bins_parallax_analytic,color=pixelated_parallax_line_color,label='pixelated tcBF')[0]
plot_ctf_parallax = ax_ctf_rad.plot(q_bins_parallax,I_bins_parallax,color=segmented_parallax_line_color,label='segmented tcBF')[0]
ax_ctf_rad.legend()
# remove ticks, add titles to 2D-plots
for ax, title in zip(
axs.flatten(),
[
"detector geometry",
"pixelated CTF",
"segmented CTF",
"radially averaged CTF",
"white noise object",
"strontium titanate",
"metal-organic framework",
"apoferritin protein",
]
):
ax.set_xticks([])
ax.set_yticks([])
ax.set_title(title)
# reciprocal-space scalebars on the three 2D panels of the top row
for ax in axs[0,:3]:
ctf.add_scalebar(ax,length=n//4,sampling=reciprocal_sampling,units=r'$q_{\mathrm{probe}}$')
# remove y-ticks, add x-label, add vlines to radial avg. plot
ax_ctf_rad.set_ylim([0,1])
ax_ctf_rad.set_xlim([0,q_max])
# dashed lines at the inner / outer radii used for the initial detector masks
ax_ctf_rad.vlines([q_probe/2,q_probe*1.05],0,2,colors='k',linestyles='--',linewidth=1,)
ax_ctf_rad.set_xticks([0,q_probe,q_max])
ax_ctf_rad.set_xticklabels([0,1,2])
ax_ctf_rad.set_aspect(2)
ax_ctf_rad.set_xlabel(r"spatial frequency, $q/q_{\mathrm{probe}}$")
# bottom row: each sample potential multiplied by the segmented CTF in Fourier space
ax_white_noise_obj = axs[1,0]
im_white_noise_obj = ax_white_noise_obj.imshow(
ctf.histogram_scaling(
np.fft.ifft2(np.fft.fft2(potential) * ctf_parallax).real,
normalize=True
),
cmap=sample_cmap
)
ctf.add_scalebar(ax_white_noise_obj,length=n//5,sampling=sampling,units=r'Å')
# Zero-pad the centred CTF by 48 pixels on every side (96 -> 192 for n = 96),
# then move zero frequency back to the corner, so it can multiply the FFT of
# the larger sample potentials.
zero_pad_ctf_to_4qprobe = np.fft.ifftshift(
np.pad(np.fft.fftshift(ctf_parallax),48)
)
# NOTE(review): resample_2qprobe_ctf_to_192 is not read anywhere in the visible code.
resample_2qprobe_ctf_to_192 = np.fft.fft2(
np.fft.ifftshift(
np.pad(np.fft.fftshift(np.fft.ifft2(ctf_parallax).real),48)
)
)
ax_sto_obj = axs[1,1]
im_sto_obj = ax_sto_obj.imshow(
ctf.histogram_scaling(
np.fft.ifft2(np.fft.fft2(sto_potential) * zero_pad_ctf_to_4qprobe).real,
normalize=True
),
cmap=sample_cmap
)
# per-pixel sampling from a hard-coded field of view divided by the array size
sto_sampling = 23.67 / sto_potential.shape[0] # Å
ctf.add_scalebar(ax_sto_obj,length=40,sampling=sto_sampling,units=r'Å',size_vertical=2)
ax_mof_obj = axs[1,2]
im_mof_obj = ax_mof_obj.imshow(
ctf.histogram_scaling(
np.fft.ifft2(np.fft.fft2(mof_potential) * zero_pad_ctf_to_4qprobe).real,
normalize=True
),
cmap=sample_cmap
)
mof_sampling = 4.48 / mof_potential.shape[0] # nm
ctf.add_scalebar(ax_mof_obj,length=40,sampling=mof_sampling,units=r'nm',size_vertical=2)
ax_apo_obj = axs[1,3]
im_apo_obj = ax_apo_obj.imshow(
ctf.histogram_scaling(
np.fft.ifft2(np.fft.fft2(apo_potential) * zero_pad_ctf_to_4qprobe).real,
normalize=True
),
cmap=sample_cmap
)
apo_sampling = 19.2 / apo_potential.shape[0] # nm
ctf.add_scalebar(ax_apo_obj,length=40,sampling=apo_sampling,units=r'nm',size_vertical=2)
# fix ipympl canvas from resizing
fig.canvas.resizable = False
fig.canvas.header_visible = False
fig.canvas.footer_visible = False
fig.canvas.toolbar_visible = False
# fig.canvas.toolbar_visible = True
# fig.canvas.toolbar_position = 'bottom'
fig.canvas.layout.width = '640px'
fig.canvas.layout.height = '420px'
fig.tight_layout()
# bare None as the final expression; presumably there to keep the cell from echoing a value
None
Interactive Updating¶
We need to update the following parts of the plot each time a detector widget changes:
- The virtual masks (`im_detector`)
- The segmented CTF (`im_ctf_parallax`)
- The radially-averaged segmented CTF (`plot_ctf_parallax`)
- The vlines on the radially-averaged plot
  - This one doesn’t have a single artist we can update. Instead we remove the existing LineCollection and replot
- The four CTF-filtered real-space sample images
Widget¶
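We define our sliders and toggle buttons, as well as a callback function that updates the meaningful rotation offset range whenever the number of segments changes. (The two callbacks that would couple the inner and outer collection angle limits are currently commented out.)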
# Shared look-and-feel for the sliders. continuous_update=False makes a slider
# report its value only when released, not on every drag step.
style = {'description_width': 'initial'}
layout = ipywidgets.Layout(width="320px",height="30px")
kwargs = {'style':style,'layout':layout,'continuous_update':False}
# NOTE(review): the slider starts at q_probe/4, whereas the initial masks and
# vlines were built with an inner radius of q_probe/2 — confirm which is intended.
inner_collection_angle_slider = ipywidgets.FloatSlider(
value = q_probe/4,
min = 0,
max = q_probe/2,
step = q_probe/20,
description = r"inner collection angle [$q_{\mathrm{probe}}$]",
**kwargs
)
# starts at 1.05*q_probe, the outer radius used for the initial masks
outer_collection_angle_slider = ipywidgets.FloatSlider(
value = q_probe + q_probe/20,
min = q_probe/2 + q_probe/10,
max = q_max,
step = q_probe/20,
description = r"outer collection angle [$q_{\mathrm{probe}}$]",
**kwargs
)
number_of_segments_slider = ipywidgets.IntSlider(
value = 4,
min = 3,
max = 16,
step = 1,
description = "number of annular segments",
**kwargs
)
# the upper bound (180/4 for the initial four segments) is rewritten by
# update_rotation_offset_range whenever the number of segments changes
rotation_offset_slider = ipywidgets.IntSlider(
value = 0, min = 0, max = 180/4, step = 1,
description = "rotation offset [°]",
**kwargs
)
number_of_rings_slider = ipywidgets.IntSlider(
value = 1,
min = 1,
max = 8,
step = 1,
description = "number of radial rings",
**kwargs
)
# when set, every second ring is rotated by an extra 180/segments degrees (see update_figure)
rotate_half_the_rings = ipywidgets.ToggleButton(
value = False,
description = 'offset radial rings',
disabled = False,
layout=ipywidgets.Layout(width="155px",height="30px")
)
# when set, ring edges are spaced linearly in radius squared instead of radius (see update_figure)
area_toggle = ipywidgets.ToggleButton(
value = False,
description = 'distribute by area',
layout=ipywidgets.Layout(width="155px",height="30px")
)
# def update_outer_collection_angle(change):
# value = change['new']
# outer_collection_angle_slider.min = value*1.05
# def update_inner_collection_angle(change):
# value = change['new']
# inner_collection_angle_slider.max = value
# inner_collection_angle_slider.observe(update_outer_collection_angle, names='value')
# outer_collection_angle_slider.observe(update_inner_collection_angle, names='value')
# rotation offset is modulo 180/n
def update_rotation_offset_range(change):
    """
    Observer for the number-of-segments slider: reset the upper bound of the
    rotation offset slider to 180 / (number of segments) degrees.

    `change` is the change dictionary passed to observers; only its 'new'
    entry (the new number of segments) is read.
    """
    value = change['new']
    # rotation_offset_slider is an integer slider, so hand it an integer
    # bound explicitly instead of relying on a float being truncated.
    rotation_offset_slider.max = 180 // value


number_of_segments_slider.observe(update_rotation_offset_range, names='value')
def disable_all(boolean):
    """Set the `disabled` flag of every control widget to `boolean`."""
    controls = (
        inner_collection_angle_slider,
        outer_collection_angle_slider,
        number_of_segments_slider,
        rotation_offset_slider,
        number_of_rings_slider,
        rotate_half_the_rings,
        area_toggle,
        defocus_slider,
        simulate_button,
    )
    for control in controls:
        control.disabled = boolean
    return None
# Defocus control. Its starting value is the same number as the module-level
# C10 used for the initial simulation; the range is -n..n in steps of 1.
defocus_slider = ipywidgets.IntSlider(
value = 46,
min = -n,
max = n,
step = 1,
description = r'negative defocus, $C_{1,0}$ [Å]',
**kwargs
)
def defocus_wrapper(*args):
    """
    Observer for the defocus slider: dim every defocus-dependent artist and
    flag the simulate button, without recomputing anything.
    """
    outdated_artists = (
        im_ctf,
        im_ctf_parallax,
        plot_ctf_parallax,
        plot_ctf_parallax_analytic,
        im_white_noise_obj,
        im_sto_obj,
        im_mof_obj,
        im_apo_obj,
    )
    for artist in outdated_artists:
        artist.set_alpha(0.25)
    simulate_button.button_style = 'warning'


defocus_slider.observe(defocus_wrapper,names='value')
# Button that triggers a fresh simulation at the current defocus
# (its click handler, simulate_wrapper, is attached further down).
simulate_button = ipywidgets.Button(
description='simulate (expensive)',
layout=ipywidgets.Layout(width="315px",height="30px")
)
def simulate_wrapper(*args):
    """
    Click handler for the simulate button.

    Disables every control, re-runs the simulation at the current defocus
    slider value, then restores full opacity on the artists dimmed by
    defocus_wrapper and clears the button's warning style.

    The controls are re-enabled even if the simulation raises, so a failure
    cannot leave the interface locked. In that case the artists stay dimmed
    and the button keeps its warning style, since the plots are still stale.
    """
    disable_all(True)
    try:
        simulate(
            defocus_slider.value,
        )
        refreshed_artists = (
            im_ctf,
            im_ctf_parallax,
            plot_ctf_parallax,
            plot_ctf_parallax_analytic,
            im_white_noise_obj,
            im_sto_obj,
            im_mof_obj,
            im_apo_obj,
        )
        for artist in refreshed_artists:
            artist.set_alpha(1)
        simulate_button.button_style = ''
    finally:
        disable_all(False)


simulate_button.on_click(simulate_wrapper)
def update_figure(
    *args,
):
    """
    Rebuild the segmented detector from the widget values, recompute the
    segmented reconstruction and its CTF, and refresh the dependent artists.

    Any positional arguments are ignored, so the function can be used both as
    a widget observer and called directly.
    """
    n_rings = number_of_rings_slider.value
    n_segments = number_of_segments_slider.value
    inner = inner_collection_angle_slider.value
    outer = outer_collection_angle_slider.value

    # ring edges: equally spaced in radius, or in radius squared when the
    # area toggle is set
    if area_toggle.value:
        ring_collection_angles = np.linspace(
            inner**2,
            outer**2,
            num=n_rings + 1
        )**(1/2)
    else:
        ring_collection_angles = np.linspace(
            inner,
            outer,
            num=n_rings + 1
        )

    # optional extra rotation applied to every second ring
    if rotate_half_the_rings.value:
        ring_rotation = np.deg2rad((180/n_segments))
    else:
        ring_rotation = 0
    base_rotation = np.deg2rad(rotation_offset_slider.value)

    # compute new datasets
    virtual_masks_annular = np.vstack(
        [
            annular_segmented_detectors(
                gpts=(n,n),
                sampling=(sampling,sampling),
                n_angular_bins=n_segments,
                inner_radius=ring_collection_angles[j],
                outer_radius=ring_collection_angles[j+1],
                rotation_offset=base_rotation + ring_rotation*(j%2),
            )
            for j in range(n_rings)
        ]
    )

    # only keep nonzero masks
    virtual_masks_annular = virtual_masks_annular[virtual_masks_annular.sum((-1,-2))>0]

    parallax_annular = return_segmented_parallax_ctf(
        intensities[0],
        virtual_masks_annular,
        defocus_slider.value
    )
    ctf_parallax = ctf.compute_ctf(parallax_annular)
    q_bins_parallax, I_bins_parallax = ctf.radially_average_ctf(
        ctf_parallax,
        (sampling,sampling)
    )

    # update data
    # 2D arrays
    im_detector.set_data(ctf.combined_images_rgb(virtual_masks_annular))
    im_ctf_parallax.set_data(ctf.histogram_scaling(np.fft.fftshift(ctf_parallax),normalize=True))

    # set radial average
    plot_ctf_parallax.set_ydata(I_bins_parallax)

    # collections (vlines): there is no single artist to update, so the
    # existing collection is removed and the lines are drawn again
    axs[0,3].collections[0].remove()
    axs[0,3].vlines(
        [
            inner,
            outer
        ],0,2,
        colors='k',linestyles='--',linewidth=1,
    )

    # real space samples
    im_white_noise_obj.set_data(
        ctf.histogram_scaling(
            np.fft.ifft2(
                np.fft.fft2(potential) * ctf_parallax
            ).real,
            normalize=True
        )
    )

    # zero-pad the centred CTF by 48 pixels per side (96 -> 192 for n = 96)
    # so it can multiply the FFT of the larger sample potentials
    zero_pad_ctf_to_4qprobe = np.fft.ifftshift(
        np.pad(np.fft.fftshift(ctf_parallax),48)
    )

    for image, sample in (
        (im_sto_obj, sto_potential),
        (im_mof_obj, mof_potential),
        (im_apo_obj, apo_potential),
    ):
        image.set_data(
            ctf.histogram_scaling(
                np.fft.ifft2(
                    np.fft.fft2(sample) * zero_pad_ctf_to_4qprobe
                ).real,
                normalize=True
            )
        )

    # re-draw figure
    fig.canvas.draw_idle()
    return None
# every detector-geometry widget triggers a full refresh of the figure
for _geometry_widget in (
    inner_collection_angle_slider,
    outer_collection_angle_slider,
    number_of_segments_slider,
    rotation_offset_slider,
    number_of_rings_slider,
    rotate_half_the_rings,
    area_toggle,
):
    _geometry_widget.observe(update_figure,names='value')
def simulate(
    C10,
):
    """
    Re-run the 4D simulation for defocus coefficient C10, store the result in
    the shared `intensities` / `probe_array_fourier` holders, and refresh both
    the analytical and the segmented parts of the figure.
    """
    new_intensities, new_probe = simulate_intensities(C10=C10)
    intensities[0] = new_intensities
    probe_array_fourier[0] = new_probe

    update_analytical()
    # update_figure ignores its arguments; the string is only a placeholder
    update_figure("dummy")
    return None
def update_analytical():
    """
    Recompute the analytical CTF at the current defocus slider value and
    refresh its image and its radially-averaged curve.
    """
    analytic_ctf = return_analytical_parallax_ctf(
        defocus_slider.value
    )

    # 2D image, with zero frequency moved to the centre for display
    centered_ctf = np.fft.fftshift(analytic_ctf)
    im_ctf.set_data(ctf.histogram_scaling(centered_ctf,normalize=True))

    # radially-averaged curve (only the y-data changes)
    _, radial_profile = ctf.radially_average_ctf(analytic_ctf,(sampling,sampling))
    plot_ctf_parallax_analytic.set_ydata(radial_profile)

    fig.canvas.draw_idle()
    return None
# Final cell expression: the widget controls stacked above the figure canvas.
ipywidgets.VBox(
[
ipywidgets.VBox(
[
# defocus controls (changing defocus requires pressing the simulate button)
ipywidgets.HBox([defocus_slider, simulate_button]),
ipywidgets.HTML("<hr>",layout=ipywidgets.Layout(width="640px")),
# detector-geometry controls (each one calls update_figure when changed)
ipywidgets.HBox([inner_collection_angle_slider,outer_collection_angle_slider]),
ipywidgets.HBox([number_of_segments_slider,rotation_offset_slider]),
ipywidgets.HBox([number_of_rings_slider,rotate_half_the_rings,area_toggle]),
]
),
fig.canvas
]
)
VBox(children=(VBox(children=(HBox(children=(IntSlider(value=46, continuous_update=False, description='negativ…