Evaluating the Transfer of Information in Phase Retrieval STEM Techniques
Contents
Segmented SSB
# enable interactive matplotlib
%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import matplotlib.patches as mpatches
import ctf # import custom plotting / utils
import cmasher as cmr
from tqdm.notebook import tqdm
import ipywidgets
from IPython.display import display
4D STEM Simulation
# --- simulation parameters ------------------------------------------------
n = 96  # side length of the detector and scan grids (pixels)
q_max, q_probe = 2, 1  # inverse Angstroms: grid Nyquist frequency, probe aperture radius
wavelength = 0.019687  # Angstroms, 300kV
sampling = 1 / q_max / 2  # Angstroms, so that q_max is the Nyquist frequency
reciprocal_sampling = 2 * q_max / n  # inverse Angstroms
scan_step_size = 1  # pixels
sx = sy = n // scan_step_size
phi0 = 1.0  # Fourier amplitude of the white-noise object

# --- colours shared by all figures ----------------------------------------
cmap, sample_cmap = cmr.eclipse, 'gray'
icom_line_color = 'cornflowerblue'
iter_ptycho_line_color = 'mediumvioletred'
pixelated_ssb_line_color = 'darkgreen'
segmented_ssb_line_color = 'yellowgreen'
White Noise Potential
def white_noise_object_2D(n, phi0, rng=None):
    """Create a real (n, n) array whose 2D FFT has random phase and constant amplitude.

    Parameters
    ----------
    n : int
        Side length of the square output.
    phi0 : float
        Magnitude given to every Fourier coefficient.
    rng : numpy.random.Generator, optional
        Source of the random phases. ``None`` (default) draws from the legacy
        global ``np.random`` state exactly as before, so ``np.random.seed``
        still reproduces previously generated objects.

    Returns
    -------
    numpy.ndarray
        Real-valued (n, n) array.
    """
    even_n = n % 2 == 0
    # pos_ind[i] and neg_ind[i] are the FFT indices of +k and -k (they sum to n)
    pos_ind = np.arange(1, (n if even_n else n + 1) // 2)
    neg_ind = np.flip(np.arange(n // 2 + 1, n))

    # random phases, in units of 2*pi
    if rng is None:
        arr = np.random.randn(n, n)
    else:
        arr = rng.standard_normal((n, n))

    # enforce phase(-k) = -phase(k) in both quadrant pairs and along both axes,
    # which makes the Fourier array Hermitian
    arr[pos_ind[:, None], pos_ind[None, :]] = -arr[neg_ind[:, None], neg_ind[None, :]]
    arr[pos_ind[:, None], neg_ind[None, :]] = -arr[neg_ind[:, None], pos_ind[None, :]]
    arr[0, pos_ind] = -arr[0, neg_ind]
    arr[pos_ind, 0] = -arr[neg_ind, 0]

    # entries not covered by the +k/-k pairing above (DC and, for even n, the
    # Nyquist row and column) get zero *phase*, i.e. a real coefficient phi0;
    # their amplitude is not zeroed
    if even_n:
        arr[n // 2, :] = 0
        arr[:, n // 2] = 0
    arr[0, 0] = 0

    fourier = np.exp(2j * np.pi * arr) * phi0
    # Hermitian symmetry makes the inverse FFT real up to rounding error
    return np.fft.ifft2(fourier).real
# white-noise phase object, its transmission function exp(i * potential),
# and the scan grid size taken from the object
potential = white_noise_object_2D(n, phi0)
complex_obj = np.exp(1j * potential)
sx, sy = potential.shape
Import Sample Potentials
def _load_mean_subtracted(path):
    """Load a projected potential from ``path`` and subtract its mean in place."""
    arr = np.load(path)
    arr -= arr.mean()
    return arr


# zero-mean projected potentials of the three example samples
sto_potential = _load_mean_subtracted("data/STO_projected-potential_192x192_4qprobe.npy")
mof_potential = _load_mean_subtracted("data/MOF_projected-potential_192x192_4qprobe.npy")
apo_potential = _load_mean_subtracted("data/apoF_projected-potential_192x192_4qprobe.npy")
Probe
def soft_aperture(q, q_probe, reciprocal_sampling):
    """Anti-aliased aperture amplitude.

    The intensity ramps linearly from 1 to 0 across one ``reciprocal_sampling``
    wide band centred on ``q_probe``; the returned amplitude is its square root.
    """
    edge_intensity = (q_probe - q) / reciprocal_sampling + 0.5
    return np.sqrt(np.clip(edge_intensity, 0, 1))
def hard_aperture(q, q_probe, reciprocal_sampling):
    """Binary aperture: 1.0 strictly inside radius ``q_probe``, 0.0 elsewhere.

    ``reciprocal_sampling`` is unused; it keeps the signature interchangeable
    with ``soft_aperture``.
    """
    inside = q < q_probe
    return inside.astype(np.float64)
# reciprocal-space grid (inverse Angstroms) in corner-centred FFT ordering
qx = qy = np.fft.fftfreq(n, sampling)
q = np.sqrt(qx[:, None] ** 2 + qy[None, :] ** 2)

# detector-plane (K) and scan-frequency (Q) coordinates reuse the same grid
Kx, Ky = qx, qy
K = np.sqrt(Kx[:, None] ** 2 + Ky[None, :] ** 2)
Qx, Qy = qx, qy

# scan positions (pixels) and the matching patch indices into the object
x = y = np.arange(0., n, scan_step_size)
xx, yy = np.meshgrid(x, y, indexing='ij')
positions = np.column_stack((xx.ravel(), yy.ravel()))
row, col = ctf.return_patch_indices(positions, (n, n), (n, n))
def simulate_intensities(defocus, use_soft_aperture, batch_size=n**2, pbar=None):
    """Simulate 4D-STEM data of ``complex_obj`` for every scan position.

    Parameters
    ----------
    defocus : float
        Defocus entering the probe phase exp(-i pi wavelength defocus q^2).
    use_soft_aperture : bool
        Use ``soft_aperture`` instead of ``hard_aperture`` for the probe.
    batch_size : int
        Maximum number of probe positions passed to ``ctf.simulate_data`` at
        once. It no longer has to divide the number of positions.
    pbar : tqdm progress bar, optional
        Reset to the number of batches and advanced once per batch.

    Returns
    -------
    numpy.ndarray
        Shape (sx, sy, n, n): the ``ctf.simulate_data`` output squared and
        divided by n**2. NOTE(review): presumably simulate_data returns
        diffraction amplitudes, so this yields intensities; confirm in ``ctf``.
        The reshape assumes sx * sy == n**2.
    """
    m = n**2
    # ceil(m / batch_size) batches so a non-dividing (or oversized) batch_size
    # still covers every position; array_split keeps each batch <= batch_size
    n_batch = max(1, -(-m // batch_size))
    batches = np.array_split(np.arange(m), n_batch)
    intensities = np.zeros((m, n, n))

    if pbar is not None:
        pbar.reset(n_batch)
        pbar.colour = None
        pbar.refresh()

    aperture = soft_aperture if use_soft_aperture else hard_aperture
    probe_array_fourier = aperture(q, q_probe, reciprocal_sampling) * np.exp(-1j*np.pi*wavelength*defocus*q**2)
    # unit total intensity in Fourier space; the factor n below keeps unit
    # total intensity in real space too (ifft2 divides by n**2)
    probe_array_fourier /= np.sqrt(np.sum(np.abs(probe_array_fourier)**2))
    probe_array = np.fft.ifft2(probe_array_fourier) * n

    for batch_order in batches:
        intensities[batch_order] = ctf.simulate_data(
            complex_obj,
            probe_array,
            row[batch_order],
            col[batch_order],
        )
        if pbar is not None:
            pbar.update(1)

    if pbar is not None:
        pbar.colour = 'green'

    return intensities.reshape((sx, sy, n, n))**2 / n**2
# Initial dataset: in focus, hard aperture. Kept in one-element lists so the
# widget callbacks can swap in new arrays via ``intensities[0] = ...``.
_initial_kwargs = dict(defocus=0, use_soft_aperture=False, batch_size=1024, pbar=None)
intensities = [simulate_intensities(**_initial_kwargs)]
# Fourier transform over the two scan axes only
intensities_FFT = [np.fft.fft2(intensities[0], axes=(0, 1))]
def annular_segmented_detectors(
    gpts,
    sampling,
    n_angular_bins,
    rotation_offset = 0,
    inner_radius = 0,
    outer_radius = np.inf,
):
    """Build one binary mask per angular segment of an annular detector.

    Parameters
    ----------
    gpts : tuple of int
        Detector grid shape (nx, ny).
    sampling : tuple of float
        Real-space sampling used with ``np.fft.fftfreq`` to get |k|.
    n_angular_bins : int
        Number of equal angular segments.
    rotation_offset : float
        Angle (radians) added to every pixel's polar angle before binning.
    inner_radius, outer_radius : float
        A pixel belongs to the annulus when inner_radius <= |k| < outer_radius.

    Returns
    -------
    list of numpy.ndarray
        ``n_angular_bins`` int masks of shape ``gpts``, fftshifted so that
        k = 0 is at the centre. Inside the annulus every pixel lies in
        exactly one mask; outside it, in none.
    """
    nx, ny = gpts
    sx, sy = sampling
    k_x = np.fft.fftfreq(nx, sx)
    k_y = np.fft.fftfreq(ny, sy)
    k = np.sqrt(k_x[:, None]**2 + k_y[None, :]**2)
    radial_mask = (inner_radius <= k) & (k < outer_radius)

    theta = (np.arctan2(k_y[None, :], k_x[:, None]) + rotation_offset) % (2 * np.pi)
    # Float `%` can round a tiny negative angle up to exactly 2*pi, which would
    # give bin n_angular_bins + 1 and drop the pixel from every segment.
    # Wrapping the 0-based index maps it back onto the first segment.
    angular_bins = np.floor(n_angular_bins * (theta / (2 * np.pi))) % n_angular_bins + 1
    angular_bins *= radial_mask.astype("int")

    return [
        np.fft.fftshift((angular_bins == i).astype("int"))
        for i in range(1, n_angular_bins + 1)
    ]
def mask_intensities_using_virtual_detectors(
    corner_centered_intensities,
    corner_centered_masked_intensities,
    center_centered_masks,
):
    """Replace every detector pixel by the mean over its virtual-detector segment.

    Parameters
    ----------
    corner_centered_intensities : numpy.ndarray
        Data of shape (..., nx, ny); the last two axes are the detector plane
        in corner-centred (unshifted FFT) order.
    corner_centered_masked_intensities : numpy.ndarray
        Output array with the same shape; written in place.
    center_centered_masks : sequence of array-like
        One (nx, ny) segment mask per virtual detector, k = 0 at the centre.

    Pixels covered by no segment are set to 0. Returns None.
    """
    masks = np.fft.ifftshift(
        np.asarray(center_centered_masks).astype(np.bool_), axes=(-1, -2)
    )
    # `any` stays correct if segments overlap; the previous `1 - sum` turned a
    # doubly-covered pixel into -1 (truthy) and zeroed it
    outside = ~masks.any(axis=0)
    for mask in masks:
        # boolean indexing touches only this segment's pixels instead of
        # materialising a full-size product array for every segment
        segment_mean = corner_centered_intensities[..., mask].mean(axis=-1)
        corner_centered_masked_intensities[..., mask] = segment_mean[..., None]
    corner_centered_masked_intensities[..., outside] = 0.0
    return None
def mask_gamma_using_virtual_detectors(
    corner_centered_gamma,
    center_centered_masks,
):
    """Average a detector-plane array over each virtual-detector segment, in place.

    Parameters
    ----------
    corner_centered_gamma : numpy.ndarray
        (nx, ny) array in corner-centred (unshifted FFT) order; overwritten.
    center_centered_masks : sequence of array-like
        One (nx, ny) segment mask per virtual detector, k = 0 at the centre.

    Each segment's pixels receive the segment mean; pixels covered by no
    segment are set to 0. Returns None.
    """
    masks = np.fft.ifftshift(
        np.asarray(center_centered_masks).astype(np.bool_), axes=(-1, -2)
    )
    # `any` stays correct if segments overlap (`1 - sum` would not)
    outside = ~masks.any(axis=0)
    for mask in masks:
        corner_centered_gamma[mask] = corner_centered_gamma[mask].mean()
    corner_centered_gamma[outside] = 0.0
    return None
# Placeholders (a single bright pixel), replaced in place once the detector
# geometry is built from the widgets.
masked_intensities_FFT = [np.zeros_like(intensities_FFT[0])]
masked_intensities_FFT[0][0, 0] = 1
virtual_masks_annular = [np.zeros((n, n))]
virtual_masks_annular[0][0, 0] = 1
def ptychography_reconstruction(
    masked_intensities_FFT,
    virtual_masks_annular,
    defocus,
    use_soft_aperture,
    use_OBF_weighting,
    intensities_FFT=None,
    pbar=None
):
    """Single-side-band (SSB) reconstruction from segmented (virtual-detector) data.

    For every scan frequency index (ind_x, ind_y) other than (0, 0), the data
    G are summed against conj(gamma) / normalization over the pixels where
    |gamma| > threshold, with
        gamma = A*(K) A(K - Q) - A(K) A*(K + Q),
    after gamma has been averaged over the same virtual-detector segments as
    the data. The per-frequency estimates psi are inverse-FFT'd at the end.

    Parameters
    ----------
    masked_intensities_FFT : array
        Segment-averaged data indexed as [ind_x, ind_y] -> detector-plane
        array (assumed corner-centred, like the output of
        mask_intensities_using_virtual_detectors — TODO confirm).
    virtual_masks_annular : array-like
        Centre-centred segment masks, passed to mask_gamma_using_virtual_detectors.
    defocus : float
        Defocus in the probe phase exp(-i pi wavelength defocus |K|^2).
    use_soft_aperture : bool
        Use soft_aperture and a 1e-3 |gamma| cut-off instead of hard_aperture
        and a cut-off of 0.
    use_OBF_weighting : bool
        Normalise with the segment-averaged, normalised probe intensity
        |A(K)|^2 (OBF weights) instead of |gamma|.
    intensities_FFT : array, optional
        Unsegmented data. If given, a pixelated estimate psi_0 (unmasked gamma,
        |gamma| normalisation) is computed in the same loop.
    pbar : tqdm progress bar, optional

    Returns
    -------
    tuple
        (ifft2(psi), None), or (ifft2(psi), ifft2(psi_0)) when intensities_FFT is given.
    """
    aperture = soft_aperture if use_soft_aperture else hard_aperture
    # |gamma| cut-off: 1e-3 for the soft aperture, any non-zero value otherwise
    threshold = 1e-3 if use_soft_aperture else 0.0
    psi = np.empty_like(complex_obj)
    if intensities_FFT is not None:
        psi_0 = np.empty_like(complex_obj)
    # probe at the detector-plane frequencies K: aperture times defocus phase
    A_q = aperture(K,q_probe,reciprocal_sampling) * np.exp(-1j*np.pi*wavelength*defocus*K**2)
    A_q_conj = A_q.conj()
    if use_OBF_weighting:
        # normalised probe intensity, averaged over segments (in place) like the data
        probe_normalization = np.abs(A_q)**2
        probe_normalization /= probe_normalization.sum()
        mask_gamma_using_virtual_detectors(
            probe_normalization,
            virtual_masks_annular,
        )
    if pbar is not None:
        pbar.reset(sx*sy)
        pbar.colour = None
        pbar.refresh()
    for ind_x in range(sx):
        for ind_y in range(sy):
            G = masked_intensities_FFT[ind_x,ind_y]
            if intensities_FFT is not None:
                G_0 = intensities_FFT[ind_x,ind_y]
            if ind_x == 0 and ind_y == 0 :
                # Qx[0] = Qy[0] = 0: use the total magnitude of the data
                psi[ind_x,ind_y] = np.abs(G).sum()
                if intensities_FFT is not None:
                    psi_0[ind_x,ind_y] = np.abs(G_0).sum()
            else:
                # probe shifted by +Q and -Q for this scan frequency
                q_plus_Q = np.sqrt((Kx[:,None]+Qx[ind_x])**2 + (Ky[None,:]+Qy[ind_y])**2)
                A_q_plus_Q = aperture(q_plus_Q,q_probe,reciprocal_sampling) * np.exp(-1j*np.pi*wavelength*defocus*q_plus_Q**2)
                q_minus_Q = np.sqrt((Kx[:,None]-Qx[ind_x])**2 + (Ky[None,:]-Qy[ind_y])**2)
                A_q_minus_Q = aperture(q_minus_Q,q_probe,reciprocal_sampling) * np.exp(-1j*np.pi*wavelength*defocus*q_minus_Q**2)
                gamma = A_q_conj * A_q_minus_Q - A_q * A_q_plus_Q.conj()
                if intensities_FFT is not None:
                    # pixelated estimate must use gamma BEFORE the in-place
                    # segment averaging below
                    gamma_abs = np.abs(gamma)
                    gamma_ind = gamma_abs > threshold
                    psi_0[ind_x,ind_y] = (G_0[gamma_ind] * np.conj(gamma[gamma_ind])/gamma_abs[gamma_ind]).sum()
                # average gamma over each segment (in place) to match the segmented data
                mask_gamma_using_virtual_detectors(
                    gamma,
                    virtual_masks_annular,
                )
                gamma_abs = np.abs(gamma)
                gamma_ind = gamma_abs > threshold
                normalization = gamma_abs[gamma_ind]
                if use_OBF_weighting:
                    d = probe_normalization[gamma_ind]
                    normalization = d * np.sqrt(np.sum(normalization**2 / d))
                psi[ind_x,ind_y] = (G[gamma_ind] * np.conj(gamma[gamma_ind])/normalization).sum()
            if pbar is not None:
                pbar.update(1)
    if pbar is not None:
        pbar.colour = 'green'
    return_val = (np.fft.ifft2(psi),None) if intensities_FFT is None else (np.fft.ifft2(psi),np.fft.ifft2(psi_0))
    return return_val
def ptychography_reconstruction_pixelated(
    intensities_FFT,
    defocus,
    use_soft_aperture,
    use_OBF_weighting,
    pbar=None,
):
    """Single-side-band (SSB) reconstruction from the full pixelated dataset.

    For each scan frequency index other than (0, 0), the data are summed
    against conj(gamma) / normalisation over the pixels where
    |gamma| > threshold, with
        gamma = A*(K) A(K - Q) - A(K) A*(K + Q).
    The normalisation is |gamma|, or, with ``use_OBF_weighting``, an OBF form
    built from the normalised probe intensity |A(K)|^2.

    Returns
    -------
    numpy.ndarray
        Inverse FFT of the per-frequency estimates psi.
    """
    aperture = soft_aperture if use_soft_aperture else hard_aperture
    # |gamma| cut-off: 1e-3 for the soft aperture, any non-zero value otherwise
    threshold = 1e-3 if use_soft_aperture else 0.0

    def probe_at(q_abs):
        # aperture amplitude times defocus phase at spatial frequency |q|
        return aperture(q_abs, q_probe, reciprocal_sampling) * np.exp(-1j*np.pi*wavelength*defocus*q_abs**2)

    psi = np.empty_like(complex_obj)
    A_q = probe_at(K)
    A_q_conj = A_q.conj()
    if use_OBF_weighting:
        probe_weight = np.abs(A_q)**2
        probe_weight /= probe_weight.sum()

    if pbar is not None:
        pbar.reset(sx*sy)
        pbar.colour = None
        pbar.refresh()

    for ix in range(sx):
        for iy in range(sy):
            trotter = intensities_FFT[ix, iy]
            if ix == 0 and iy == 0:
                # zero scan frequency: total magnitude of the data
                psi[ix, iy] = np.abs(trotter).sum()
            else:
                A_plus = probe_at(np.sqrt((Kx[:, None]+Qx[ix])**2 + (Ky[None, :]+Qy[iy])**2))
                A_minus = probe_at(np.sqrt((Kx[:, None]-Qx[ix])**2 + (Ky[None, :]-Qy[iy])**2))
                gamma = A_q_conj * A_minus - A_q * A_plus.conj()
                gamma_abs = np.abs(gamma)
                support = gamma_abs > threshold
                denominator = gamma_abs[support]
                if use_OBF_weighting:
                    weights = probe_weight[support]
                    denominator = weights * np.sqrt(np.sum(denominator**2 / weights))
                psi[ix, iy] = (trotter[support] * np.conj(gamma[support]) / denominator).sum()
            if pbar is not None:
                pbar.update(1)

    if pbar is not None:
        pbar.colour = 'green'
    return np.fft.ifft2(psi)
# Reference curve: in-focus pixelated SSB reconstruction and its radially
# averaged transfer function (DC term suppressed).
_reference_kwargs = dict(
    intensities_FFT=intensities_FFT[0],
    defocus=0,
    use_soft_aperture=False,
    use_OBF_weighting=False,
)
recon_0 = ptychography_reconstruction_pixelated(**_reference_kwargs)
numeric_ctf_0 = np.abs(np.fft.fft2(np.angle(recon_0))) / 2
numeric_ctf_0[0, 0] = 0.0
q_bins_pixelated, I_bins_pixelated = ctf.radially_average_ctf(
    numeric_ctf_0, (sampling, sampling)
)
def mask_opacities(virtual_masks):
    """Per-pixel opacity for overlaying virtual-detector masks.

    Masks get alternating weights 0.25 / 0.375 (all shifted by +0.125 when
    the number of masks is odd); the result is 1 minus the weighted sum of
    the masks.
    """
    count = len(virtual_masks)
    weights = np.where(np.arange(count) % 2 == 0, 0.25, 0.375)
    if count % 2:
        weights = weights + 0.125
    stacked = np.array(virtual_masks)
    return 1 - np.tensordot(stacked, weights, axes=(0, 0))
with plt.ioff():
    # Build the figure with interactive mode off so it is not auto-displayed;
    # its canvas is embedded in the widget layout shown by display() later.
    dpi=72
    # 2 x 4 grid: top row = trotters, 2D CTF, radial CTF; bottom row = samples
    fig, axs = plt.subplots(2,4,figsize=(640/dpi,400/dpi),dpi=dpi)
    # placeholder image (single bright pixel) until a reconstruction is run
    empty = np.zeros((n,n))
    empty[0,0] = 1
    ax_trotter_pixelated = axs[0,0]
    im_trotter_pixelated = ax_trotter_pixelated.imshow(virtual_masks_annular[0])
    ax_trotter_annular = axs[0,1]
    im_trotter_annular = ax_trotter_annular.imshow(virtual_masks_annular[0])
    ax_ctf = axs[0,2]
    im_ctf = ax_ctf.imshow(empty,cmap=cmap)
    ax_ctf_rad = axs[0,3]
    plot_ctf_pixelated = ax_ctf_rad.plot(q_bins_pixelated,I_bins_pixelated,color=pixelated_ssb_line_color,label='pixelated SSB')[0]
    # segmented curve starts flat at zero; its y-data is replaced after reconstruction
    plot_ctf = ax_ctf_rad.plot(np.linspace(0,q_max,n//2 + 1),np.zeros(n//2 + 1),color=segmented_ssb_line_color,label='segmented SSB')[0]
    for ax, title in zip(
        axs.flatten(),
        [
            "pixelated trotter",
            "segmented trotter",
            "segmented CTF",
            "radially-averaged CTF",
            "white noise object",
            "strontium titanate",
            "metal-organic framework",
            "apoferritin protein",
        ]
    ):
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(title)
    # reciprocal-space panels: scale bar of n//4 pixels
    for ax in axs[0,:3]:
        ctf.add_scalebar(ax,length=n//4,sampling=reciprocal_sampling,units=r'$q_{\mathrm{probe}}$')
    # remove y-ticks, add x-label, add vlines to radial avg. plot
    ax_ctf_rad.set_ylim([0,1])
    ax_ctf_rad.set_xlim([0,q_max])
    # dashed markers at the initial inner/outer collection angles; the first
    # collection on this axis is later removed and redrawn on reconstruction
    ax_ctf_rad.vlines([q_probe/2,q_probe],0,2,colors='k',linestyles='--',linewidth=1,)
    ax_ctf_rad.set_xticks([0,q_probe,q_max])
    ax_ctf_rad.set_xticklabels([0,1,2])
    ax_ctf_rad.set_yticks([])
    ax_ctf_rad.set_aspect(2)
    ax_ctf_rad.set_xlabel(r"spatial frequency, $q/q_{\mathrm{probe}}$")
    ax_ctf_rad.legend()
    # bottom row: sample images, initially blank; the scale bars use each
    # sample's field of view divided by n
    ax_white_noise_obj = axs[1,0]
    im_white_noise_obj = ax_white_noise_obj.imshow(
        np.zeros((n,n)),vmin=0,vmax=1,
        cmap=sample_cmap
    )
    ctf.add_scalebar(ax_white_noise_obj,length=n//5,sampling=sampling,units=r'Å')
    ax_sto_obj = axs[1,1]
    im_sto_obj = ax_sto_obj.imshow(
        np.zeros((n,n)),vmin=0,vmax=1,
        cmap=sample_cmap
    )
    sto_sampling = 23.67 / n # Å
    ctf.add_scalebar(ax_sto_obj,length=n//5,sampling=sto_sampling,units=r'Å')
    ax_mof_obj = axs[1,2]
    im_mof_obj = ax_mof_obj.imshow(
        np.zeros((n,n)),vmin=0,vmax=1,
        cmap=sample_cmap
    )
    mof_sampling = 4.48 / n # nm
    ctf.add_scalebar(ax_mof_obj,length=n//5,sampling=mof_sampling,units=r'nm')
    ax_apo_obj = axs[1,3]
    im_apo_obj = ax_apo_obj.imshow(
        np.zeros((n,n)),vmin=0,vmax=1,
        cmap=sample_cmap
    )
    apo_sampling = 19.2 / n # nm
    ctf.add_scalebar(ax_apo_obj,length=n//5,sampling=apo_sampling,units=r'nm')
    fig.tight_layout()
    # fixed-size canvas without header/footer/toolbar
    fig.canvas.resizable = False
    fig.canvas.header_visible = False
    fig.canvas.footer_visible = False
    fig.canvas.toolbar_visible = False
    # fig.canvas.toolbar_visible = True
    # fig.canvas.toolbar_position = 'bottom'
    fig.canvas.layout.width = '640px'
    fig.canvas.layout.height = '420px'
    fig.tight_layout()
# bare expression, presumably to suppress the cell's output
None
# --- shared widget styling --------------------------------------------------
style = {'description_width': 'initial'}
layout = ipywidgets.Layout(width="320px", height="30px")
layout_half = ipywidgets.Layout(width="160px", height="30px")
kwargs = dict(style=style, layout=layout, continuous_update=False)
kwargs_half = dict(style=style, layout=layout_half, continuous_update=False)

# --- detector geometry controls (angles in units of q_probe) ----------------
inner_collection_angle_slider = ipywidgets.FloatSlider(
    value=q_probe/2, min=0, max=q_probe, step=q_probe/20,
    description=r"inner collection angle [$q_{\mathrm{probe}}$]",
    **kwargs,
)
outer_collection_angle_slider = ipywidgets.FloatSlider(
    value=q_probe, min=q_probe/20, max=q_max, step=q_probe/20,
    description=r"outer collection angle [$q_{\mathrm{probe}}$]",
    **kwargs,
)
number_of_segments_slider = ipywidgets.IntSlider(
    value=4, min=3, max=16, step=1,
    description="number of annular segments",
    **kwargs,
)
# upper bound matches the default of 4 segments; kept in sync by an observer
rotation_offset_slider = ipywidgets.IntSlider(
    value=0, min=0, max=180/4, step=1,
    description="rotation offset [°]",
    **kwargs,
)
number_of_rings_slider = ipywidgets.IntSlider(
    value=1, min=1, max=8, step=1,
    description="number of radial rings",
    **kwargs,
)

# --- geometry toggles ---------------------------------------------------------
rotate_half_the_rings = ipywidgets.ToggleButton(
    value=False,
    description='offset radial rings',
    disabled=False,
    layout=ipywidgets.Layout(width="155px", height="30px"),
)
area_toggle = ipywidgets.ToggleButton(
    value=False,
    description='distribute by area',
    layout=ipywidgets.Layout(width="155px", height="30px"),
)
def update_outer_collection_angle(change):
    """Keep the outer collection angle at least 5% above the new inner one."""
    outer_collection_angle_slider.min = change['new'] * 1.05


def update_inner_collection_angle(change):
    """Keep the inner collection angle no larger than the new outer one."""
    inner_collection_angle_slider.max = change['new']


inner_collection_angle_slider.observe(update_outer_collection_angle, names='value')
outer_collection_angle_slider.observe(update_inner_collection_angle, names='value')


# rotation offset is modulo 180/n
def update_rotation_offset_range(change):
    """Limit the rotation offset to 180 / (number of segments) degrees."""
    rotation_offset_slider.max = 180 / change['new']


number_of_segments_slider.observe(update_rotation_offset_range, names='value')
# --- simulation / reconstruction controls -------------------------------------
defocus_slider = ipywidgets.IntSlider(
    value=0, min=-n, max=n, step=1,
    description="negative defocus, $C_{1,0}$ [Å]",
    **kwargs,
)
frequencies_toggle = ipywidgets.ToggleButton(
    description='show low freq trotter', value=True, **kwargs_half,
)
OBF_toggle = ipywidgets.ToggleButton(
    value=False, description='use OBF weights', **kwargs_half,
)

# Each "expensive" button gets a compact progress bar built from the first
# two children of a tqdm notebook container.
simulate_button = ipywidgets.Button(
    description='simulate (expensive)',
    layout=ipywidgets.Layout(width="160px", height="30px"),
)
simulation_pbar = tqdm(total=9, display=False)
simulation_pbar_wrapper = ipywidgets.HBox(
    simulation_pbar.container.children[:2],
    layout=ipywidgets.Layout(width="160px"),
)

reconstruct_button = ipywidgets.Button(
    description='reconstruct (expensive)',
    layout=ipywidgets.Layout(width="160px", height="30px"),
)
reconstruction_pbar = tqdm(total=9, display=False)
reconstruction_pbar_wrapper = ipywidgets.HBox(
    reconstruction_pbar.container.children[:2],
    layout=ipywidgets.Layout(width="160px"),
)
def disable_all(boolean):
    """Set ``disabled`` to ``boolean`` on every interactive control.

    (A soft-aperture toggle is currently not part of the UI.)
    """
    controls = (
        inner_collection_angle_slider,
        outer_collection_angle_slider,
        number_of_segments_slider,
        rotation_offset_slider,
        number_of_rings_slider,
        rotate_half_the_rings,
        area_toggle,
        defocus_slider,
        frequencies_toggle,
        OBF_toggle,
        simulate_button,
        reconstruct_button,
    )
    for control in controls:
        control.disabled = boolean
    return None
def defocus_aperture_wrapper(*args):
    """Defocus observer: dim every output and flag that a re-simulation is needed."""
    stale_artists = (
        im_trotter_annular,
        im_trotter_pixelated,
        im_ctf,
        im_white_noise_obj,
        im_sto_obj,
        im_mof_obj,
        im_apo_obj,
        plot_ctf,
        plot_ctf_pixelated,
    )
    for artist in stale_artists:
        artist.set_alpha(0.25)
    simulate_button.button_style = 'warning'
    simulation_pbar.reset()


defocus_slider.observe(defocus_aperture_wrapper, names='value')
def simulate_wrapper(*args):
    """Button callback: re-simulate the dataset at the current defocus.

    Controls are disabled while the simulation runs and are re-enabled
    afterwards, even if it raises. On success the reconstruct button is
    flagged as stale and the simulate button is cleared.
    """
    disable_all(True)
    try:
        simulate_and_update_trotters(
            defocus_slider.value,
            False,  # use_soft_aperture: the soft-aperture toggle is not exposed
            OBF_toggle.value,
            pbar=simulation_pbar,
        )
        reconstruct_button.button_style = 'warning'
        reconstruction_pbar.reset()
        simulate_button.button_style = ''
    finally:
        disable_all(False)


simulate_button.on_click(simulate_wrapper)
def reconstruct_wrapper(*args):
    """Button callback: run the segmented SSB reconstruction with current settings.

    Controls are disabled while it runs and are re-enabled afterwards, even if
    the reconstruction raises.
    """
    disable_all(True)
    try:
        update_ctfs(
            defocus_slider.value,
            False,  # use_soft_aperture: the soft-aperture toggle is not exposed
            OBF_toggle.value,
            pbar=reconstruction_pbar,
        )
    finally:
        disable_all(False)


reconstruct_button.on_click(reconstruct_wrapper)
def simulate_and_update_trotters(
    defocus,
    use_soft_aperture,
    use_OBF_weighting,
    pbar=None,
):
    """Re-simulate the dataset, then refresh the trotters and the pixelated SSB curve.

    Replaces ``intensities[0]`` and ``intensities_FFT[0]`` in place so every
    other callback sees the new data.

    Parameters
    ----------
    defocus, use_soft_aperture
        Forwarded to ``simulate_intensities`` and the pixelated reconstruction.
    use_OBF_weighting : bool
        Forwarded to the pixelated reconstruction.
    pbar : tqdm progress bar, optional
        Progress of the simulation.
    """
    intensities[0] = simulate_intensities(
        defocus=defocus,
        use_soft_aperture=use_soft_aperture,
        pbar=pbar,
        batch_size=1024,
    )
    # FFT over the two scan axes only
    intensities_FFT[0] = np.fft.fft2(intensities[0], axes=(0, 1))
    update_virtual_and_pixelated_trotters()
    update_pixelated_ctf(
        defocus=defocus,
        use_soft_aperture=use_soft_aperture,
        use_OBF_weighting=use_OBF_weighting,
        intensities_FFT=intensities_FFT[0],
    )
    return None
def update_pixelated_ctf(
    defocus,
    use_soft_aperture,
    use_OBF_weighting,
    intensities_FFT,
):
    """Recompute the pixelated SSB reconstruction and redraw its radial CTF curve."""
    reconstruction = ptychography_reconstruction_pixelated(
        intensities_FFT=intensities_FFT,
        defocus=defocus,
        use_soft_aperture=use_soft_aperture,
        use_OBF_weighting=use_OBF_weighting,
    )
    transfer = ctf.compute_ctf(np.angle(reconstruction)) / 2
    transfer[0, 0] = 0.0  # suppress the DC term
    _, radial_profile = ctf.radially_average_ctf(transfer, (sampling, sampling))
    plot_ctf_pixelated.set_ydata(radial_profile)
    plot_ctf_pixelated.set_alpha(1)
    fig.canvas.draw()
    return None
def update_ctfs(
    defocus,
    use_soft_aperture,
    use_OBF_weighting,
    intensities_FFT=None,
    pbar=None,
):
    """Run the segmented SSB reconstruction and refresh every CTF-dependent panel.

    Updates the 2D CTF image, the four CTF-filtered sample images, the radially
    averaged segmented-SSB curve and the dashed collection-angle markers. When
    ``intensities_FFT`` is given, it also updates the pixelated-SSB curve.

    Parameters
    ----------
    defocus, use_soft_aperture, use_OBF_weighting
        Forwarded to ``ptychography_reconstruction``.
    intensities_FFT : array, optional
        Unsegmented data; if given, the pixelated curve is recomputed too.
    pbar : tqdm progress bar, optional
        Forwarded to ``ptychography_reconstruction``.
    """
    recon, recon_0 = ptychography_reconstruction(
        masked_intensities_FFT[0],
        virtual_masks_annular[0],
        defocus=defocus,
        use_soft_aperture=use_soft_aperture,
        use_OBF_weighting=use_OBF_weighting,
        intensities_FFT=intensities_FFT,
        pbar=pbar,
    )

    numeric_ctf = ctf.compute_ctf(np.angle(recon)) / 2
    numeric_ctf[0, 0] = 0.0
    im_ctf.set_data(
        ctf.histogram_scaling(np.fft.fftshift(numeric_ctf), normalize=True)
    )

    def _filtered_image(sample_potential, transfer):
        # multiply the potential's spectrum by `transfer`, back-transform,
        # keep the real part, and histogram-scale for display
        return ctf.histogram_scaling(
            np.fft.ifft2(np.fft.fft2(sample_potential) * transfer).real,
            normalize=True,
        )

    # real space samples
    im_white_noise_obj.set_data(_filtered_image(potential, numeric_ctf))

    # The loaded sample potentials live on a larger grid (192x192 per their
    # file names); zero-pad the centred CTF by n//2 on every side to match.
    zero_pad_ctf_to_4qprobe = np.fft.ifftshift(np.pad(np.fft.fftshift(numeric_ctf), n // 2))
    im_sto_obj.set_data(_filtered_image(sto_potential, zero_pad_ctf_to_4qprobe))
    im_mof_obj.set_data(_filtered_image(mof_potential, zero_pad_ctf_to_4qprobe))
    im_apo_obj.set_data(_filtered_image(apo_potential, zero_pad_ctf_to_4qprobe))

    _, I_bins = ctf.radially_average_ctf(numeric_ctf, (sampling, sampling))
    plot_ctf.set_ydata(I_bins)

    if recon_0 is not None:
        numeric_ctf_0 = ctf.compute_ctf(np.angle(recon_0)) / 2
        numeric_ctf_0[0, 0] = 0.0
        _, I_bins_pixelated = ctf.radially_average_ctf(numeric_ctf_0, (sampling, sampling))
        plot_ctf_pixelated.set_ydata(I_bins_pixelated)

    # replace the dashed markers with the collection angles just used
    ax_ctf_rad.collections[0].remove()
    ax_ctf_rad.vlines(
        [
            inner_collection_angle_slider.value,
            outer_collection_angle_slider.value
        ], 0, 2,
        colors='k', linestyles='--', linewidth=1,
    )

    for artist in (
        im_ctf,
        im_white_noise_obj,
        im_sto_obj,
        im_mof_obj,
        im_apo_obj,
        plot_ctf,
        plot_ctf_pixelated,
    ):
        artist.set_alpha(1)
    reconstruct_button.button_style = ''
    fig.canvas.draw()
    return None
def update_virtual_and_pixelated_trotters(
    *args,
):
    """Rebuild the virtual detector masks from the widgets and redraw both trotters.

    Accepts and ignores ``*args`` so it can be registered directly as an
    observer. Controls are disabled while it runs and re-enabled afterwards,
    even if it raises. Replaces ``virtual_masks_annular[0]``, rewrites
    ``masked_intensities_FFT[0]`` in place, and dims the reconstruction
    outputs, which are now stale.
    """
    disable_all(True)
    try:
        # ring boundaries, equally spaced in radius or (distribute by area) in radius**2
        if area_toggle.value:
            ring_collection_angles = np.linspace(
                inner_collection_angle_slider.value**2,
                outer_collection_angle_slider.value**2,
                num=number_of_rings_slider.value + 1
            )**(1/2)
        else:
            ring_collection_angles = np.linspace(
                inner_collection_angle_slider.value,
                outer_collection_angle_slider.value,
                num=number_of_rings_slider.value + 1
            )

        # optionally rotate every other ring by half a segment
        if rotate_half_the_rings.value:
            ring_rotation = np.deg2rad((180/number_of_segments_slider.value))
        else:
            ring_rotation = 0

        _virtual_masks_annular = []
        for i in range(1, number_of_rings_slider.value + 1):
            j = i - 1
            _virtual_masks_annular.append(
                annular_segmented_detectors(
                    gpts=(n, n),
                    sampling=(sampling, sampling),
                    n_angular_bins=number_of_segments_slider.value,
                    inner_radius=ring_collection_angles[j],
                    outer_radius=ring_collection_angles[i],
                    rotation_offset=np.deg2rad(rotation_offset_slider.value) + ring_rotation*(j % 2),
                )
            )
        # all segments of all rings stacked along axis 0
        virtual_masks_annular[0] = np.vstack(_virtual_masks_annular)

        # Mask the FFT of the intensities directly instead of masking first and
        # transforming afterwards: segment averaging is linear per detector
        # pixel, so it commutes with the FFT over the scan axes.
        mask_intensities_using_virtual_detectors(
            intensities_FFT[0],
            masked_intensities_FFT[0],
            virtual_masks_annular[0]
        )

        # scan-frequency index (ind, 2*ind) shown in the trotter panels
        ind = n//12 if frequencies_toggle.value else n//8
        im_trotter_pixelated.set_data(
            np.dstack(
                (
                    ctf.complex_to_rgb(
                        np.fft.fftshift(intensities_FFT[0][ind, 2*ind])
                    ),
                    mask_opacities(virtual_masks_annular[0])
                )
            )
        )
        im_trotter_annular.set_data(
            ctf.complex_to_rgb(
                np.fft.fftshift(masked_intensities_FFT[0][ind, 2*ind])
            )
        )

        im_trotter_annular.set_alpha(1)
        im_trotter_pixelated.set_alpha(1)
        for stale in (im_ctf, im_white_noise_obj, im_sto_obj, im_mof_obj, im_apo_obj, plot_ctf):
            stale.set_alpha(0.25)
        reconstruct_button.button_style = 'warning'
        reconstruction_pbar.reset()
        fig.canvas.draw()
    finally:
        disable_all(False)
    return None
# any change to the detector geometry rebuilds the virtual detectors and trotters
for _geometry_control in (
    inner_collection_angle_slider,
    outer_collection_angle_slider,
    number_of_segments_slider,
    rotation_offset_slider,
    number_of_rings_slider,
    rotate_half_the_rings,
    area_toggle,
):
    _geometry_control.observe(update_virtual_and_pixelated_trotters, names='value')
def update_trotters_frequency(*args):
    """Redraw both trotter panels at the scan frequency chosen by frequencies_toggle."""
    # low-frequency view uses index (n//12, n//6), otherwise (n//8, n//4)
    ind = n//12 if frequencies_toggle.value else n//8
    frequency = (ind, 2*ind)

    pixelated_rgb = ctf.complex_to_rgb(np.fft.fftshift(intensities_FFT[0][frequency]))
    opacity = mask_opacities(virtual_masks_annular[0])
    im_trotter_pixelated.set_data(np.dstack((pixelated_rgb, opacity)))

    annular_rgb = ctf.complex_to_rgb(np.fft.fftshift(masked_intensities_FFT[0][frequency]))
    im_trotter_annular.set_data(annular_rgb)

    fig.canvas.draw()
    return None


frequencies_toggle.observe(update_trotters_frequency, 'value')
def update_weights(change):
    """OBF_toggle observer: redraw the pixelated SSB curve with the new weighting.

    The segmented outputs are dimmed and the reconstruct button flagged, since
    they no longer match the weighting. Controls are re-enabled afterwards
    even if the reconstruction raises.
    """
    disable_all(True)
    try:
        recon_0 = ptychography_reconstruction_pixelated(
            intensities_FFT=intensities_FFT[0],
            defocus=defocus_slider.value,
            use_soft_aperture=False,
            use_OBF_weighting=change.new,
        )
        numeric_ctf_0 = ctf.compute_ctf(np.angle(recon_0)) / 2
        numeric_ctf_0[0, 0] = 0.0
        _, I_bins_pixelated = ctf.radially_average_ctf(numeric_ctf_0, (sampling, sampling))
        plot_ctf_pixelated.set_ydata(I_bins_pixelated)
        plot_ctf_pixelated.set_alpha(1)
        for stale in (im_ctf, im_white_noise_obj, im_sto_obj, im_mof_obj, im_apo_obj, plot_ctf):
            stale.set_alpha(0.25)
        fig.canvas.draw()
        reconstruct_button.button_style = 'warning'
        reconstruction_pbar.reset()
    finally:
        disable_all(False)
    return None


OBF_toggle.observe(update_weights, 'value')
# build the initial virtual detectors / trotters, then show the controls above the figure
update_virtual_and_pixelated_trotters()

_control_rows = [
    ipywidgets.HBox([defocus_slider, simulate_button, simulation_pbar_wrapper]),
    ipywidgets.HBox([frequencies_toggle, OBF_toggle, reconstruct_button, reconstruction_pbar_wrapper]),
    ipywidgets.HTML("<hr>", layout=ipywidgets.Layout(width="640px")),
    ipywidgets.HBox([inner_collection_angle_slider, outer_collection_angle_slider]),
    ipywidgets.HBox([number_of_segments_slider, rotation_offset_slider]),
    ipywidgets.HBox([number_of_rings_slider, rotate_half_the_rings, area_toggle]),
]
display(ipywidgets.VBox([ipywidgets.VBox(_control_rows), fig.canvas]))
VBox(children=(VBox(children=(HBox(children=(IntSlider(value=0, continuous_update=False, description='negative…