Evaluating the Transfer of Information in Phase Retrieval STEM Techniques
Contents
Segmented Iterative Ptychography
# enable interactive matplotlib
%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import ctf # import custom plotting / utils
import cmasher as cmr
from tqdm.notebook import tqdm
import ipywidgets
from IPython.display import display
4D STEM Simulation
# ---- simulation parameters ------------------------------------------------
# grid / optics
n = 96                               # pixels per side (object, probe and detector)
q_max, q_probe = 2, 1                # max spatial frequency, aperture radius [1/Å]
wavelength = 0.019687                # electron wavelength [Å], 300 kV
sampling = 1 / q_max / 2             # real-space pixel size [Å]
reciprocal_sampling = 2 * q_max / n  # reciprocal-space pixel size [1/Å]
# scan
scan_step_size = 1                   # [pixels]
sx = sy = n // scan_step_size        # number of scan positions along each axis
phi0 = 1.0                           # Fourier amplitude of the white-noise object
# ---- plotting styles --------------------------------------------------------
cmap = cmr.eclipse        # colormap for CTF images
sample_cmap = 'gray'      # colormap for real-space sample images
# line colours for the radially averaged CTF curves
segmented_icom_line_color, segmented_ptycho_line_color, pixelated_ptycho_line_color = (
    'cornflowerblue',
    'orchid',
    'darkmagenta',
)
White Noise Potential
def white_noise_object_2D(n, phi0, rng=None):
    """Create a real ``(n, n)`` array whose FFT has constant amplitude and random phase.

    Parameters
    ----------
    n : int
        Side length of the square output array.
    phi0 : float
        Amplitude assigned to every Fourier coefficient.
    rng : numpy.random.Generator, optional
        Source of randomness. With ``None`` (default) the legacy global
        ``np.random`` state is used, exactly as before. Pass e.g.
        ``np.random.default_rng(seed)`` for a reproducible object.

    Returns
    -------
    numpy.ndarray
        Real-valued array of shape ``(n, n)``.
    """
    even = n % 2 == 0
    # Positive-frequency indices and their mirrored negatives, ordered so that
    # pos_ind[i] + neg_ind[i] == n, i.e. they are the k <-> -k pairs.
    pos_ind = np.arange(1, (n if even else n + 1) // 2)
    neg_ind = np.flip(np.arange(n // 2 + 1, n))

    # Random phases, in units of 2*pi.
    if rng is None:
        arr = np.random.randn(n, n)
    else:
        arr = rng.standard_normal((n, n))

    # Hermitian symmetry F(-k) = conj(F(k)) makes the inverse FFT real. For a
    # phasor exp(2*pi*i*a), conjugation means a(-k) = -a(k).
    # top-left // bottom-right
    arr[pos_ind[:, None], pos_ind[None, :]] = -arr[neg_ind[:, None], neg_ind[None, :]]
    # bottom-left // top-right
    arr[pos_ind[:, None], neg_ind[None, :]] = -arr[neg_ind[:, None], pos_ind[None, :]]
    arr[0, pos_ind] = -arr[0, neg_ind]   # kx = 0
    arr[pos_ind, 0] = -arr[neg_ind, 0]   # ky = 0

    # Frequencies that are their own negative (DC and, for even n, the Nyquist
    # row/column) need real coefficients. Give them zero phase; their amplitude
    # stays phi0.
    if even:
        arr[n // 2, :] = 0
        arr[:, n // 2] = 0
    arr[0, 0] = 0

    fourier = np.exp(2j * np.pi * arr) * phi0
    # The inverse FFT is real up to floating-point round-off; drop that residue.
    return np.fft.ifft2(fourier).real
# White-noise phase object used as the reference test sample,
# and its unit-modulus transmission function exp(i * potential).
potential = white_noise_object_2D(n, phi0)
complex_obj = np.exp(1j * potential)
Import sample potentials
def _load_mean_subtracted(path):
    """Load a projected potential from ``path`` and subtract its mean in place."""
    arr = np.load(path)
    arr -= arr.mean()
    return arr


sto_potential = _load_mean_subtracted("data/STO_projected-potential_192x192_4qprobe.npy")
mof_potential = _load_mean_subtracted("data/MOF_projected-potential_192x192_4qprobe.npy")
apo_potential = _load_mean_subtracted("data/apoF_projected-potential_192x192_4qprobe.npy")
Probe
# The probe is built in Fourier space from a soft-edged aperture. Its amplitude
# is the square root of a linear ramp one reciprocal pixel wide around q_probe,
# so the intensity falls off linearly across the edge.
# (Normalization and the inverse FFT to real space happen per defocus inside
# simulate_intensities.)
qx = qy = np.fft.fftfreq(n, sampling)
q2 = qx[:, None] ** 2 + qy[None, :] ** 2
q = np.sqrt(q2)
aperture_fourier = np.sqrt(np.clip((q_probe - q) / reciprocal_sampling + 0.5, 0, 1))
def simulate_intensities(defocus, batch_size=n**2, pbar=None):
    """Simulate far-field amplitudes for every scan position at a given defocus.

    Parameters
    ----------
    defocus : float
        Defocus entering the aperture phase
        ``exp(-i*pi*wavelength*defocus*q**2)`` (same length unit as ``wavelength``).
    batch_size : int
        Number of scan positions passed to ``ctf.simulate_data`` per call. It no
        longer has to divide ``n**2``; the last batch may be smaller.
    pbar : tqdm progress bar, optional
        Reset to the number of batches, advanced once per batch, coloured green
        when finished.

    Returns
    -------
    list
        ``[amplitudes, probe_array, probe_array_fourier]``. ``amplitudes`` has
        shape ``(n**2, n, n)``.
    """
    m = n**2
    # Ceiling division. The old reshape((m // batch_size, batch_size)) raised
    # whenever batch_size did not divide m.
    n_batch = -(-m // batch_size)
    order = np.arange(m)
    amplitudes = np.zeros((m, n, n))
    if pbar is not None:
        pbar.reset(n_batch)
        pbar.colour = None
        pbar.refresh()

    probe_array_fourier = aperture_fourier * np.exp(-1j * np.pi * wavelength * defocus * q**2)
    # Normalize so that sum(|probe_array_fourier|**2) == 1 ...
    probe_array_fourier /= np.sqrt(np.sum(np.abs(probe_array_fourier) ** 2))
    # ... numpy's ifft2 divides by n*n, so scaling by n also gives
    # sum(|probe_array|**2) == 1 (Parseval).
    probe_array = np.fft.ifft2(probe_array_fourier) * n

    for start in range(0, m, batch_size):
        batch_order = order[start:start + batch_size]
        amplitudes[batch_order] = ctf.simulate_data(
            complex_obj,
            probe_array,
            row[batch_order],
            col[batch_order],
        )
        if pbar is not None:
            pbar.update(1)
    if pbar is not None:
        pbar.colour = 'green'
    return [amplitudes, probe_array, probe_array_fourier]
# Raster scan: one probe position per scan step, row-major (ij) ordering,
# stored as an (N, 2) array of pixel coordinates.
x = y = np.arange(0.0, n, scan_step_size)
xx, yy = np.meshgrid(x, y, indexing="ij")
positions = np.column_stack((xx.ravel(), yy.ravel()))
row, col = ctf.return_patch_indices(positions, (n, n), (n, n))
# Initial in-focus dataset. intensities holds [4D intensities with shape
# (sx, sy, n, n), total intensity per scan position].
amplitudes_probe = simulate_intensities(defocus=0, batch_size=1024, pbar=None)
intensities = [amplitudes_probe[0].reshape((sx, sy, n, n)) ** 2 / n**2]
intensities.append(intensities[0].sum((-1, -2)))
Virtual Detectors and CoM Calculation
def annular_segmented_detectors(
    gpts,
    sampling,
    n_angular_bins,
    rotation_offset=0,
    inner_radius=0,
    outer_radius=np.inf,
):
    """Build the masks of an annular detector split into equal azimuthal segments.

    Parameters
    ----------
    gpts : tuple of int
        Detector shape ``(nx, ny)`` in pixels.
    sampling : tuple of float
        Real-space sampling ``(sx, sy)``. The detector frequencies are
        ``np.fft.fftfreq(nx, sx)`` and ``np.fft.fftfreq(ny, sy)``.
    n_angular_bins : int
        Number of azimuthal segments.
    rotation_offset : float
        Angle in radians added to every pixel's azimuth before binning.
    inner_radius, outer_radius : float
        Pixels with ``inner_radius <= |k| < outer_radius`` form the annulus.

    Returns
    -------
    list of numpy.ndarray
        ``n_angular_bins`` integer 0/1 masks of shape ``(nx, ny)``. They are
        fft-shifted (zero frequency at the center) and together partition the
        annulus.
    """
    nx, ny = gpts
    sx, sy = sampling
    k_x = np.fft.fftfreq(nx, sx)
    k_y = np.fft.fftfreq(ny, sy)
    k = np.sqrt(k_x[:, None] ** 2 + k_y[None, :] ** 2)
    radial_mask = (inner_radius <= k) & (k < outer_radius)

    theta = (np.arctan2(k_y[None, :], k_x[:, None]) + rotation_offset) % (2 * np.pi)
    # Floating-point modulo returns exactly 2*pi for tiny negative angles. That
    # would give bin index n_angular_bins + 1 and silently drop the pixel from
    # every segment, so clamp to the last segment.
    angular_bins = np.minimum(
        np.floor(n_angular_bins * (theta / (2 * np.pi))),
        n_angular_bins - 1,
    ) + 1
    angular_bins *= radial_mask.astype("int")  # 0 marks "outside the annulus"
    return [
        np.fft.fftshift((angular_bins == i).astype("int"))
        for i in range(1, n_angular_bins + 1)
    ]
def compute_com_using_virtual_detectors(
    corner_centered_intensities,
    center_centered_masks,
    corner_centered_intensities_sum,
    sx, sy,
    kxa, kya,
):
    """Estimate the diffraction center of mass (CoM) from segmented-detector signals.

    Each segment's fraction of the total intensity is weighted by the mean spatial
    frequency of the pixels the segment covers.

    Parameters
    ----------
    corner_centered_intensities : numpy.ndarray
        4D intensities of shape ``(sx, sy, nx, ny)``, zero frequency at pixel (0, 0).
    center_centered_masks : sequence of numpy.ndarray
        Segment masks of shape ``(nx, ny)``, zero frequency at the center.
    corner_centered_intensities_sum : numpy.ndarray
        Total intensity per scan position, shape ``(sx, sy)``.
    sx, sy : int
        Scan dimensions.
    kxa, kya : numpy.ndarray
        Corner-centered spatial-frequency grids of shape ``(nx, ny)``. These are now
        actually used; previously they were overwritten with grids rebuilt from
        module globals.

    Returns
    -------
    tuple of numpy.ndarray
        ``(com_x, com_y)``, each of shape ``(sx, sy)``.
    """
    masks = np.fft.ifftshift(np.asarray(center_centered_masks), axes=(-1, -2))
    com_x = np.zeros((sx, sy))
    com_y = np.zeros((sx, sy))
    for mask in masks:
        rows, cols = np.nonzero(mask)
        if rows.size == 0:
            # An empty segment collects no signal. np.mean([]) would be NaN and
            # poison the entire CoM map.
            continue
        fraction = corner_centered_intensities[:, :, rows, cols].sum(-1) / corner_centered_intensities_sum
        com_x += fraction * np.mean(kxa[rows, cols])
        com_y += fraction * np.mean(kya[rows, cols])
    return com_x, com_y
def integrate_com(
    com_x,
    com_y,
    kx_op,
    ky_op,
):
    """Integrate a CoM vector field into a scalar (iCoM) image in Fourier space.

    Returns ``Re(ifft2(fft2(com_x) * kx_op + fft2(com_y) * ky_op))``.
    """
    spectrum = np.fft.fft2(com_x) * kx_op
    spectrum = spectrum + np.fft.fft2(com_y) * ky_op
    return np.fft.ifft2(spectrum).real
def bin_amplitudes_using_virtual_detectors(
    corner_centered_amplitudes,
    center_centered_masks,
):
    """Integrate diffraction intensity over each detector segment, as amplitudes.

    Parameters
    ----------
    corner_centered_amplitudes : numpy.ndarray
        Amplitudes of shape ``(m, nx, ny)``, zero frequency at pixel (0, 0).
    center_centered_masks : sequence of numpy.ndarray
        Segment masks of shape ``(nx, ny)``, zero frequency at the center.

    Returns
    -------
    numpy.ndarray
        Shape ``(n_segments, m)``. Each entry is the square root of the summed
        intensity (amplitude squared) inside the segment.
    """
    masks = np.fft.ifftshift(
        np.asarray(center_centered_masks).astype(np.bool_), axes=(-1, -2)
    )
    # Square once, instead of building an (m, nx, ny) temporary for every segment.
    intensities = corner_centered_amplitudes ** 2
    values = np.zeros((masks.shape[0], corner_centered_amplitudes.shape[0]))
    for index, mask in enumerate(masks):
        values[index] = np.sqrt(intensities[:, mask].sum(axis=-1))
    return values
def virtual_detector_ptycho_reconstruction(
    binned_amplitude_values,
    row,
    col,
    positions,
    center_centered_masks,
    recon,
    probe_array,
    pbars,
    batch_size = n**2,
    iterations=64,
    step_size=1.0,
    ):
    """Iteratively reconstruct a complex object from segmented-detector amplitudes.

    For each batch of scan positions this function does the following:
    - Forms the modeled exit-wave spectrum ``fft2(probe * object_patch)``.
    - Zeros it outside every detector segment.
    - Multiplies the pixels of each segment by
      ``measured_amplitude / sqrt(modeled segment intensity)``.
    - Applies one global rescale factor.
    - Back-projects the spectral difference into an additive object update.

    The object modulus is clipped to [0, 1] after every batch.

    Parameters
    ----------
    binned_amplitude_values : numpy.ndarray
        Measured segment amplitudes with shape (n_segments, m) for m scan
        positions (indexed ``[:, batch_order]`` below).
    row, col : numpy.ndarray
        Per-position index arrays; object patches are cut as
        ``recon[row[i], col[i]]``.
    positions : numpy.ndarray
        Scan positions, forwarded to ``ctf.sum_patches``.
    center_centered_masks : sequence of numpy.ndarray
        Detector segment masks with zero frequency at the center.
    recon : numpy.ndarray
        Starting object estimate. NOTE(review): the first batch updates it in
        place (``recon += ...``) before the name is rebound, so the caller's
        array is modified.
    probe_array : numpy.ndarray
        Real-space probe of shape (nx, ny).
    pbars : tuple
        ``(outer_pbar, inner_pbar)`` tqdm bars for iterations and batches.
    batch_size : int
        Positions per batch. Must divide m, otherwise
        ``order.reshape((n, batch_size))`` raises.
    iterations : int
        Number of passes over all scan positions.
    step_size : float
        Factor applied to every object update.

    Returns
    -------
    numpy.ndarray
        The final object estimate.
    """
    m = binned_amplitude_values.shape[1]
    nx, ny = probe_array.shape
    # NOTE: n here is the number of batches; it shadows the module-level grid size n.
    n = int(m // batch_size)
    # random visiting order of scan positions, reshuffled after every iteration
    order = np.arange(m)
    np.random.shuffle(order)
    # segment masks moved to corner-centered (unshifted FFT) layout
    masks = np.fft.ifftshift(np.asarray(center_centered_masks,dtype=np.bool_),axes=(-1,-2))
    # True for detector pixels covered by no segment (assumes segments do not overlap)
    inverse_mask = (1-masks.sum(0)).astype(np.bool_)
    # normalization: for a unit-power probe and unit-modulus object, numpy's
    # unnormalized fft2 then gives a total modeled intensity equal to the mean
    # total measured intensity per position
    probe_normalization = np.mean(np.sum(binned_amplitude_values**2,0)) / nx / ny
    shifted_probes = probe_array * np.sqrt(probe_normalization)
    outer_pbar,inner_pbar = pbars
    outer_pbar.reset(iterations)
    outer_pbar.colour= None
    outer_pbar.refresh()
    for iter_index in range(iterations):
        inner_pbar.reset(n)
        inner_pbar.colour= None
        inner_pbar.refresh()
        for batch_index in range(n):
            batch_order = order.reshape((n,batch_size))[batch_index]
            batch_amplitudes = binned_amplitude_values[:,batch_order]
            batch_pos = positions[batch_order]
            batch_row = row[batch_order]
            batch_col = col[batch_order]
            # recon: modeled exit waves and their spectra for this batch
            obj_patches = recon[batch_row,batch_col]
            overlap = shifted_probes * obj_patches
            fourier_overlap = np.fft.fft2(overlap)
            fourier_intensities = np.abs(fourier_overlap)**2
            # preprocess fourier overlap: total modeled power over the batch (all pixels)
            old_fourier_overlap_sum = np.sum(np.abs(fourier_overlap)**2)
            # pixels outside every segment are not measured; exclude them from the update
            fourier_overlap[...,inverse_mask] = 0.0
            modified_fourier_overlap = fourier_overlap.copy()
            new_fourier_overlap_sum = 0.0
            for mask, amp_val in zip(masks,batch_amplitudes):
                # modeled intensity integrated over this segment, one value per position
                squared_val = np.sum(fourier_intensities * mask,axis=(-1,-2))
                new_fourier_overlap_sum += np.sum(squared_val)
                # scale the segment so its integrated amplitude matches the measurement
                # NOTE(review): a segment with zero modeled intensity gives inf/nan here
                modified_fourier_overlap[...,mask] *= (amp_val/np.sqrt(squared_val))[:,None]
            # global rescale by sqrt(in-segment modeled power / total modeled power)
            modified_fourier_overlap /= np.sqrt(old_fourier_overlap_sum/new_fourier_overlap_sum)
            grad = np.fft.ifft2(modified_fourier_overlap-fourier_overlap)
            # presumably accumulates each position's patch back into the (nx, ny)
            # object frame at batch_pos — TODO confirm against ctf.sum_patches
            update = ctf.sum_patches(
                grad*np.conj(shifted_probes),
                batch_pos,
                (nx,ny),
                (nx,ny),
            ) / probe_normalization
            recon += (step_size*update)
            # constrain the object modulus to [0, 1], keeping its phase
            amp = np.abs(recon).clip(0.0,1.0)
            recon = amp * np.exp(1j*np.angle(recon))
            inner_pbar.update(1)
        np.random.shuffle(order)
        # GUI side effect: redraw the ptychography panels after every iteration
        update_ptycho_panel(recon)
        outer_pbar.update(1)
    inner_pbar.colour='green'
    outer_pbar.colour='green'
    return recon
# Corner-centered spatial-frequency grids [1/Å], float32, ij-indexed.
kx = ky = np.fft.fftfreq(n, sampling).astype(np.float32)
kxa, kya = np.meshgrid(kx, ky, indexing="ij")
k2 = kxa**2 + kya**2
k = np.sqrt(k2)  # computed before the DC fix below, so k[0, 0] stays 0
# Put infinity at DC so the inverse-gradient kernels vanish there instead of
# dividing by zero.
k2[0, 0] = np.inf
# Fourier kernels that integrate a CoM vector field into a scalar image (iCoM).
kx_op, ky_op = (-1.0j * component / k2 for component in (kxa, kya))
# Frequency-proportional weighting k*pi/sqrt(2) ("inverse error").
inverse_error = k * np.pi / np.sqrt(2)
# Initial detector: one ring of four segments spanning q_probe/2 .. 1.05*q_probe.
# Kept in a one-element list so widget callbacks can swap it in place.
virtual_masks_annular = [
    annular_segmented_detectors(
        gpts=(n, n),
        sampling=(sampling, sampling),
        n_angular_bins=4,
        rotation_offset=0,
        inner_radius=q_probe / 2,
        outer_radius=q_probe * 1.05,
    )
]
# Segmented-detector iCoM for the initial geometry, plus its CTF and radial profile.
com_x, com_y = compute_com_using_virtual_detectors(
    intensities[0], virtual_masks_annular[0], intensities[1], sx, sy, kxa, kya,
)
icom_annular = integrate_com(com_x, com_y, kx_op, ky_op)
ctf_annular = ctf.compute_ctf(icom_annular)
q_bins_annular, I_bins_annular = ctf.radially_average_ctf(ctf_annular, (sampling, sampling))

# Disabled: analytical CTF from the probe autocorrelation.
# ctf_analytic = np.abs(np.real(np.fft.ifft2(np.abs(np.fft.fft2(amplitudes_probe[2]))**2)))
# q_bins_analytic, I_bins_analytic = ctf.radially_average_ctf(ctf_analytic,(sampling,sampling))
# One-element holders shared with the widget callbacks: the cached
# detector-binned amplitudes (None until computed) and the current object estimate.
amplitude_values = [None]
ptycho_recon = [np.ones((n, n), dtype=np.complex128)]
# Build the 2x4 figure with interactive mode off, so it is only shown once its
# canvas is placed in the widget layout by display(...).
# Top row: detector geometry, segmented-iCoM CTF, segmented-ptycho CTF, and
# radially averaged CTF curves.
# Bottom row: white-noise and sample potentials, filled in by update_ptycho_panel.
with plt.ioff():
    dpi=72
    fig, axs = plt.subplots(2,4,figsize=(640/dpi,400/dpi),dpi=dpi)
    # detector
    ax_detector = axs[0,0]
    im_detector = ax_detector.imshow(ctf.combined_images_rgb(virtual_masks_annular[0]))
    # annular CTF
    ax_ctf_annular_dpc = axs[0,1]
    im_ctf_dpc = ax_ctf_annular_dpc.imshow(ctf.histogram_scaling(np.fft.fftshift(ctf_annular),normalize=True),cmap=cmap)
    # ptycho CTF: blank placeholder until a reconstruction has been run
    ax_ctf_annular_ptycho = axs[0,2]
    im_ctf_ptycho = ax_ctf_annular_ptycho.imshow(np.zeros((n,n)),cmap=cmap,vmin=0,vmax=1)
    # radially averaged CTFs (the ptycho curve starts at zero)
    ax_ctf_rad = axs[0,3]
    # plot_ctf_dpc_analytical = ax_ctf_rad.plot(q_bins_analytic,I_bins_analytic,color='k',label='pixelated iCOM')[0]
    plot_ctf_dpc = ax_ctf_rad.plot(q_bins_annular, I_bins_annular, color=segmented_icom_line_color,label='segmented iCOM')[0]
    plot_ctf_ptycho = ax_ctf_rad.plot(q_bins_annular, np.zeros_like(I_bins_annular), color=segmented_ptycho_line_color,label='segmented ptycho')[0]
    ax_ctf_rad.legend()
    # remove ticks, add titles to 2D-plots (the line plot gets its ticks back below)
    for ax, title in zip(
        axs.flatten(),
        [
            "detector geometry",
            "segmented iCOM CTF",
            "segmented ptycho CTF",
            "radially averaged CTF",
            "white noise object",
            "strontium titanate",
            "metal-organic framework",
            "apoferritin protein",
        ]
    ):
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(title)
    # reciprocal-space panels: n//4 reciprocal pixels, which equals q_probe for the
    # current parameters (q_max == 2 * q_probe)
    for ax in axs[0,:3]:
        ctf.add_scalebar(ax,length=n//4,sampling=reciprocal_sampling,units=r'$q_{\mathrm{probe}}$')
    ax_ctf_rad.set_ylim([0,1])
    ax_ctf_rad.set_xlim([0,q_max])
    # dashed markers at the initial inner/outer collection angles. update_figure
    # removes ax_ctf_rad.collections[0] to redraw them, so this must remain the
    # first collection added to this axis.
    ax_ctf_rad.vlines([q_probe/2,q_probe*1.05],0,2,colors='k',linestyles='--',linewidth=1,)
    # tick positions are in 1/Å, labels in units of q_probe (identical while q_probe == 1)
    ax_ctf_rad.set_xticks([0,q_probe,q_max])
    ax_ctf_rad.set_xticklabels([0,1,2])
    ax_ctf_rad.set_aspect(2)
    ax_ctf_rad.set_xlabel(r"spatial frequency, $q/q_{\mathrm{probe}}$")
    # Bottom-row placeholders. NOTE(review): they are created at (n, n); the
    # 192x192 sample potentials are drawn later via set_data, which presumably
    # keeps this n x n extent, so the scale bars below use (field of view) / n.
    ax_white_noise_obj = axs[1,0]
    im_white_noise_obj = ax_white_noise_obj.imshow(
        np.zeros((n,n)),vmin=0,vmax=1,
        cmap=sample_cmap
    )
    ctf.add_scalebar(ax_white_noise_obj,length=n//5,sampling=sampling,units=r'Å')
    ax_sto_obj = axs[1,1]
    im_sto_obj = ax_sto_obj.imshow(
        np.zeros((n,n)),vmin=0,vmax=1,
        cmap=sample_cmap
    )
    sto_sampling = 23.67 / n # Å
    ctf.add_scalebar(ax_sto_obj,length=n//5,sampling=sto_sampling,units=r'Å')
    ax_mof_obj = axs[1,2]
    im_mof_obj = ax_mof_obj.imshow(
        np.zeros((n,n)),vmin=0,vmax=1,
        cmap=sample_cmap
    )
    mof_sampling = 4.48 / n # nm
    ctf.add_scalebar(ax_mof_obj,length=n//5,sampling=mof_sampling,units=r'nm')
    ax_apo_obj = axs[1,3]
    im_apo_obj = ax_apo_obj.imshow(
        np.zeros((n,n)),vmin=0,vmax=1,
        cmap=sample_cmap
    )
    apo_sampling = 19.2 / n # nm
    ctf.add_scalebar(ax_apo_obj,length=n//5,sampling=apo_sampling,units=r'nm')
    # dim the panels that have no reconstruction yet
    im_ctf_ptycho.set_alpha(0.25)
    im_white_noise_obj.set_alpha(0.25)
    im_sto_obj.set_alpha(0.25)
    im_mof_obj.set_alpha(0.25)
    im_apo_obj.set_alpha(0.25)
    # fix ipympl canvas from resizing
    fig.canvas.resizable = False
    fig.canvas.header_visible = False
    fig.canvas.footer_visible = False
    fig.canvas.toolbar_visible = False
    # fig.canvas.toolbar_visible = True
    # fig.canvas.toolbar_position = 'bottom'
    fig.canvas.layout.width = '640px'
    fig.canvas.layout.height = '420px'
    fig.tight_layout()
# fig
def update_ptycho_panel(ptycho_recon):
    """Refresh every ptychography panel from a reconstructed object.

    Steps:
    - Compute the CTF of ``ptycho_recon`` with ``ctf.compute_ctf`` and display it.
    - Filter the white-noise potential and the three sample potentials with it.
    - Update the radially averaged ptycho CTF curve.
    - Restore full opacity on these artists and redraw the canvas.
    """
    ctf_ptycho = ctf.compute_ctf(ptycho_recon)
    im_ctf_ptycho.set_data(
        ctf.histogram_scaling(np.fft.fftshift(ctf_ptycho), normalize=True)
    )

    # The white-noise object lives on the simulation grid, so the CTF applies directly.
    im_white_noise_obj.set_data(
        ctf.histogram_scaling(
            np.fft.ifft2(np.fft.fft2(potential) * ctf_ptycho).real,
            normalize=True,
        )
    )

    # The sample potentials use a larger grid (192 px, reaching 4 q_probe per
    # their file names). Zero-pad the centered CTF to that size so frequencies
    # beyond the simulated range are blocked. The pad width is derived from the
    # shapes instead of the hard-coded 48; an unused resampled CTF (two extra
    # FFTs per call) has been removed.
    pad_width = (sto_potential.shape[0] - ctf_ptycho.shape[0]) // 2
    zero_pad_ctf_to_4qprobe = np.fft.ifftshift(np.pad(np.fft.fftshift(ctf_ptycho), pad_width))
    for image, sample in (
        (im_sto_obj, sto_potential),
        (im_mof_obj, mof_potential),
        (im_apo_obj, apo_potential),
    ):
        filtered = np.fft.ifft2(np.fft.fft2(sample) * zero_pad_ctf_to_4qprobe).real
        image.set_data(ctf.histogram_scaling(filtered, normalize=True))

    # radially average
    _, I_bins_ptycho = ctf.radially_average_ctf(ctf_ptycho, (sampling, sampling))
    plot_ctf_ptycho.set_ydata(I_bins_ptycho)

    for artist in (
        im_ctf_ptycho,
        im_white_noise_obj,
        im_sto_obj,
        im_mof_obj,
        im_apo_obj,
        plot_ctf_ptycho,
    ):
        artist.set_alpha(1)
    fig.canvas.draw()
    return None
def compute_ptycho_updates(
    batch_size,
    iterations,
    pbars,
):
    """Run ptychographic iterations on the current segmented-detector data.

    Lazily bins the simulated amplitudes into the current detector segments,
    caching them in ``amplitude_values[0]``. Then continues the reconstruction
    from ``ptycho_recon[0]`` and stores the result back there.
    """
    binned = amplitude_values[0]
    if binned is None:
        binned = bin_amplitudes_using_virtual_detectors(
            amplitudes_probe[0],
            virtual_masks_annular[0],
        )
        amplitude_values[0] = binned
    ptycho_recon[0] = virtual_detector_ptycho_reconstruction(
        binned,
        row,
        col,
        positions,
        virtual_masks_annular[0],
        ptycho_recon[0],
        amplitudes_probe[1],
        pbars,
        batch_size=batch_size,
        iterations=iterations,
    )
    return None
# Shared widget styling; long descriptions keep their natural width.
style = {"description_width": "initial"}
layout = ipywidgets.Layout(width="320px", height="30px")
smaller_layout = ipywidgets.Layout(width="160px", height="30px")
kwargs = dict(style=style, layout=layout, continuous_update=False)
# ---- detector-geometry controls (angles in units of q_probe) ---------------
inner_collection_angle_slider = ipywidgets.FloatSlider(
    value=q_probe / 2, min=0, max=q_probe, step=q_probe / 20,
    description=r"inner collection angle [$q_{\mathrm{probe}}$]",
    **kwargs,
)
outer_collection_angle_slider = ipywidgets.FloatSlider(
    value=q_probe * 1.05, min=q_probe / 20, max=q_max, step=q_probe / 20,
    description=r"outer collection angle [$q_{\mathrm{probe}}$]",
    **kwargs,
)
number_of_segments_slider = ipywidgets.IntSlider(
    value=4, min=3, max=16, step=1,
    description="number of annular segments",
    **kwargs,
)
rotation_offset_slider = ipywidgets.IntSlider(
    value=0, min=0, max=180 / 4, step=1,
    description="rotation offset [°]",
    **kwargs,
)
number_of_rings_slider = ipywidgets.IntSlider(
    value=1, min=1, max=8, step=1,
    description="number of radial rings",
    **kwargs,
)
# Two narrow toggles that sit beside the rings slider (each gets its own Layout).
_toggle_size = dict(width="155px", height="30px")
rotate_half_the_rings = ipywidgets.ToggleButton(
    value=False,
    description='offset radial rings',
    disabled=False,
    layout=ipywidgets.Layout(**_toggle_size),
)
area_toggle = ipywidgets.ToggleButton(
    value=False,
    description='distribute by area',
    layout=ipywidgets.Layout(**_toggle_size),
)
def update_outer_collection_angle(change):
    """Keep the outer angle at least 5% above the newly chosen inner angle."""
    outer_collection_angle_slider.min = change['new'] * 1.05


def update_inner_collection_angle(change):
    """Do not let the inner angle exceed the newly chosen outer angle."""
    inner_collection_angle_slider.max = change['new']


inner_collection_angle_slider.observe(update_outer_collection_angle, names='value')
outer_collection_angle_slider.observe(update_inner_collection_angle, names='value')
# rotation offset is modulo 180/n
def update_rotation_offset_range(change):
    """Cap the rotation-offset slider at 180 / (number of segments) degrees."""
    rotation_offset_slider.max = 180 / change['new']


number_of_segments_slider.observe(update_rotation_offset_range, names='value')
# Ptychography batch size. The reconstruction reshapes the scan order into
# (m // batch_size, batch_size), so only exact divisors of the number of scan
# positions are offered. Divisors are found with integer arithmetic; the old code
# used a float modulo and first built a throw-away IntSlider.
m = n**2
candidates = np.arange(1, m + 1)
batch_sizes = candidates[m % candidates == 0]
batch_size_slider = ipywidgets.SelectionSlider(
    options=batch_sizes,
    # default: 7th-largest divisor (1024 for n = 96), clamped for small grids
    value=batch_sizes[-min(7, len(batch_sizes))],
    description="batch size",
    **kwargs,
)
# ---- reconstruction controls -----------------------------------------------
iterations_slider = ipywidgets.IntSlider(
    value=4, min=1, max=32, step=1,
    description="(outer loop) iterations",
    **kwargs,
)
iterate_button = ipywidgets.Button(description="reconstruct (expensive)", layout=smaller_layout)
reset_button = ipywidgets.Button(description="reset object", layout=smaller_layout)

# ---- simulation controls ---------------------------------------------------
defocus_slider = ipywidgets.IntSlider(
    value=0, min=-n, max=n, step=1,
    description="negative defocus, $C_{1,0}$ [Å]",
    **kwargs,
)
simulate_button = ipywidgets.Button(
    description='simulate (expensive)',
    layout=ipywidgets.Layout(width="160px", height="30px"),
)
# Hidden tqdm bar (total 9 = n**2 // 1024 simulation batches). Only its first two
# child widgets are shown, in a compact box.
simulation_pbar = tqdm(total=9, display=False)
simulation_pbar_wrapper = ipywidgets.HBox(
    simulation_pbar.container.children[:2],
    layout=ipywidgets.Layout(width="160px"),
)
def defocus_wrapper(*args):
    """Mark the current data as stale after a defocus change.

    Flags the simulate button, resets the reconstruction, dims every result
    artist and clears the simulation progress bar.
    """
    simulate_button.button_style = 'warning'
    reset_wrapper()
    stale_artists = (
        im_ctf_dpc,
        im_ctf_ptycho,
        plot_ctf_dpc,
        plot_ctf_ptycho,
        im_white_noise_obj,
        im_sto_obj,
        im_mof_obj,
        im_apo_obj,
    )
    for artist in stale_artists:
        artist.set_alpha(0.25)
    simulation_pbar.reset()


defocus_slider.observe(defocus_wrapper, names='value')
def simulate_wrapper(*args):
    """Re-simulate the 4D dataset at the selected defocus and refresh the figure.

    All controls are disabled while the (expensive) simulation runs. The
    detector-binned amplitudes are computed after ``update_figure`` because
    ``update_figure`` discards any cached ``amplitude_values[0]``. Binning before
    it, as the code previously did, was always thrown away.
    """
    disable_all(True)
    amplitudes_probe[:] = simulate_intensities(
        defocus=defocus_slider.value,
        batch_size=1024,
        pbar=simulation_pbar,
    )
    intensities[0] = amplitudes_probe[0].reshape((sx, sy, n, n)) ** 2 / n**2
    intensities[1] = intensities[0].sum((-1, -2))
    # (An analytical probe-autocorrelation CTF curve used to be refreshed here; disabled.)
    update_figure("dummy")
    # Cache the binned amplitudes for the geometry update_figure just built.
    amplitude_values[0] = bin_amplitudes_using_virtual_detectors(
        amplitudes_probe[0],
        virtual_masks_annular[0],
    )
    disable_all(False)
    iterate_button.button_style = 'warning'
    outer_reconstruct_pbar.reset()
    inner_reconstruct_pbar.reset()
    simulate_button.button_style = ''


simulate_button.on_click(simulate_wrapper)
def reset_wrapper(*args):
    """Restart the reconstruction from a uniform object and dim the ptycho panels."""
    ptycho_recon[0] = np.ones((n, n), dtype=np.complex128)
    update_ptycho_panel(ptycho_recon[0])
    # Flag "reconstruct" only when the simulation itself is not already stale.
    if simulate_button.button_style != 'warning':
        iterate_button.button_style = 'warning'
        outer_reconstruct_pbar.reset()
        inner_reconstruct_pbar.reset()
    for image in (im_ctf_ptycho, im_white_noise_obj, im_sto_obj, im_mof_obj, im_apo_obj):
        image.set_alpha(0.25)


reset_button.on_click(reset_wrapper)
def disable_all(boolean):
    """Set ``disabled`` to ``boolean`` on every interactive control (used while computing)."""
    controls = (
        inner_collection_angle_slider,
        outer_collection_angle_slider,
        number_of_segments_slider,
        rotation_offset_slider,
        number_of_rings_slider,
        rotate_half_the_rings,
        area_toggle,
        batch_size_slider,
        iterations_slider,
        reset_button,
        iterate_button,
        defocus_slider,
        simulate_button,
        simulation_pbar_wrapper,
    )
    for control in controls:
        control.disabled = boolean
def click_wrapper(*args):
    """Run the requested ptychography iterations with all controls locked."""
    progress_bars = (outer_reconstruct_pbar, inner_reconstruct_pbar)
    disable_all(True)
    compute_ptycho_updates(
        batch_size_slider.value,
        iterations_slider.value,
        progress_bars,
    )
    disable_all(False)
    iterate_button.button_style = ''


iterate_button.on_click(click_wrapper)
def _compact_progress_bar(total):
    """Create a hidden tqdm notebook bar plus a 160px HBox showing its first two children."""
    bar = tqdm(total=total, display=False)
    box = ipywidgets.HBox(
        bar.container.children[:2],
        layout=ipywidgets.Layout(width="160px"),
    )
    return bar, box


# outer bar counts iterations, inner bar counts batches
outer_reconstruct_pbar, outer_reconstruct_pbar_wrapper = _compact_progress_bar(4)
inner_reconstruct_pbar, inner_reconstruct_pbar_wrapper = _compact_progress_bar(9)
def _ring_collection_angles():
    """Return the ``number_of_rings + 1`` ring edges spanning the collection-angle sliders.

    Edges are evenly spaced in angle. With "distribute by area" on, they are evenly
    spaced in angle squared, so every ring covers the same detector area.
    """
    inner = inner_collection_angle_slider.value
    outer = outer_collection_angle_slider.value
    num = number_of_rings_slider.value + 1
    if area_toggle.value:
        return np.linspace(inner**2, outer**2, num=num) ** (1 / 2)
    return np.linspace(inner, outer, num=num)


def update_figure(
    *args,
):
    """Rebuild the virtual detector from the widgets and refresh the iCoM panels.

    Steps:
    - Recompute the segment masks, the segmented iCoM image, its CTF and radial
      profile.
    - Move the dashed collection-angle markers.
    - Invalidate any cached detector-binned amplitudes, resetting the
      reconstruction, since they belong to the previous geometry.

    ``*args`` absorbs the traitlets change payload.
    """
    ring_edges = _ring_collection_angles()
    # Optionally rotate every other ring by half a segment width.
    if rotate_half_the_rings.value:
        ring_rotation = np.deg2rad((180 / number_of_segments_slider.value))
    else:
        ring_rotation = 0
    base_rotation = np.deg2rad(rotation_offset_slider.value)

    _virtual_masks_annular = [
        annular_segmented_detectors(
            gpts=(n, n),
            sampling=(sampling, sampling),
            n_angular_bins=number_of_segments_slider.value,
            inner_radius=ring_edges[j],
            outer_radius=ring_edges[j + 1],
            rotation_offset=base_rotation + ring_rotation * (j % 2),
        )
        for j in range(number_of_rings_slider.value)
    ]
    # Flatten rings x segments into one stack of masks.
    virtual_masks_annular[0] = np.vstack(_virtual_masks_annular)

    com_x, com_y = compute_com_using_virtual_detectors(
        intensities[0],
        virtual_masks_annular[0],
        intensities[1],
        sx, sy,
        kxa, kya,
    )
    icom_annular = integrate_com(com_x, com_y, kx_op, ky_op)
    ctf_annular = ctf.compute_ctf(icom_annular)
    # An SNR-weighted profile (ctf_annular * inverse_error) used to be computed
    # here but was never displayed; dropped to save a radial average per update.
    _, I_bins_annular = ctf.radially_average_ctf(ctf_annular, (sampling, sampling))

    # 2D arrays
    im_detector.set_data(ctf.combined_images_rgb(virtual_masks_annular[0]))
    im_ctf_dpc.set_data(ctf.histogram_scaling(np.fft.fftshift(ctf_annular), normalize=True))
    # 1D lines
    plot_ctf_dpc.set_ydata(I_bins_annular)
    # Replace the dashed collection-angle markers (the first collection on this axis).
    ax_ctf_rad.collections[0].remove()
    ax_ctf_rad.vlines(
        [
            inner_collection_angle_slider.value,
            outer_collection_angle_slider.value,
        ],
        0, 2,
        colors='k', linestyles='--', linewidth=1,
    )
    im_ctf_dpc.set_alpha(1)
    plot_ctf_dpc.set_alpha(1)

    iterate_button.button_style = 'warning'
    if amplitude_values[0] is not None:
        amplitude_values[0] = None
        reset_wrapper()
    for image in (im_ctf_ptycho, im_white_noise_obj, im_sto_obj, im_mof_obj, im_apo_obj):
        image.set_alpha(0.25)

    # re-draw figure
    fig.canvas.draw_idle()
    return None
# Any change to the detector geometry rebuilds the figure.
for _geometry_widget in (
    inner_collection_angle_slider,
    outer_collection_angle_slider,
    number_of_segments_slider,
    rotation_offset_slider,
    number_of_rings_slider,
    rotate_half_the_rings,
    area_toggle,
):
    _geometry_widget.observe(update_figure, names='value')
# No reconstruction exists yet, so flag the reconstruct button.
iterate_button.button_style = 'warning'
def simulate(
    defocus,
):
    """Re-simulate the dataset at ``defocus`` and refresh the detector/iCoM panels.

    Non-interactive variant of the simulate button: no progress bar, no widget
    locking. The previous body unpacked the three values returned by
    ``simulate_intensities`` into two targets (ValueError). It also assigned into
    ``probe_array_fourier``, which is not a module-level name, and stored
    amplitudes where intensities are expected. This version mirrors the data flow
    used by the simulate button.
    """
    amplitudes_probe[:] = simulate_intensities(
        defocus=defocus,
    )
    intensities[0] = amplitudes_probe[0].reshape((sx, sy, n, n)) ** 2 / n**2
    intensities[1] = intensities[0].sum((-1, -2))
    update_figure("dummy")
    return None
def _horizontal_rule():
    """Return a fresh 640px-wide ``<hr>`` separator widget."""
    return ipywidgets.HTML("<hr>", layout=ipywidgets.Layout(width="640px"))


_control_rows = [
    ipywidgets.HBox([defocus_slider, simulate_button, simulation_pbar_wrapper]),
    _horizontal_rule(),
    ipywidgets.HBox([inner_collection_angle_slider, outer_collection_angle_slider]),
    ipywidgets.HBox([number_of_segments_slider, rotation_offset_slider]),
    ipywidgets.HBox([number_of_rings_slider, rotate_half_the_rings, area_toggle]),
    _horizontal_rule(),
    ipywidgets.HBox([batch_size_slider, iterations_slider]),
    ipywidgets.HBox([
        reset_button,
        iterate_button,
        outer_reconstruct_pbar_wrapper,
        inner_reconstruct_pbar_wrapper,
    ]),
]
# Controls stacked above the figure canvas.
display(ipywidgets.VBox([ipywidgets.VBox(_control_rows), fig.canvas]))
VBox(children=(VBox(children=(HBox(children=(IntSlider(value=0, continuous_update=False, description='negative…