Evaluating the Transfer of Information in Phase Retrieval STEM Techniques

Pixelated Parallax

# enable interactive matplotlib
%matplotlib widget 

import numpy as np
import matplotlib.pyplot as plt
import ctf # import custom plotting / utils
import cmasher as cmr 
import ipywidgets
# --- grid and optics parameters ---
n = 384  # pixels per side of the square CTF grid
q_max = 2  # inverse Angstroms
q_probe = 1  # inverse Angstroms
wavelength = 0.019687  # 300kV
sampling = 1 / q_max / 2  # Angstroms
reciprocal_sampling = 2 * q_max / n  # inverse Angstroms
bin_value = n // 96  # integer factor used to bin the n-pixel CTF down to 96 pixels

# --- initial aberration coefficients ---
C10 = -128
C30 = 0

# --- plot styling ---
cmap = cmr.viola
sample_cmap = 'gray'
pixelated_parallax_line_color = 'darkred'

# --- sample projected potentials ---
# Load the three sample potentials in the order STO, MOF, apoferritin.
potentials = [
    np.load(potential_path)
    for potential_path in (
        "data/STO_projected-potential_192x192_4qprobe.npy",
        "data/MOF_projected-potential_192x192_4qprobe.npy",
        "data/apoF_projected-potential_192x192_4qprobe.npy",
    )
]

# Subtract each potential's mean in place so every array is zero-mean.
for zero_mean_target in potentials:
    zero_mean_target -= zero_mean_target.mean()

# Individual names refer to the very same arrays stored in `potentials`.
sto_potential, mof_potential, apo_potential = potentials

# Real-space pixel size of each sample: field of view divided by pixel count.
sto_sampling = 23.67 / sto_potential.shape[0]  # Å
mof_sampling = 4.48 / mof_potential.shape[0]  # nm
apo_sampling = 19.2 / apo_potential.shape[0]  # nm
def autocorrelation(array):
    """Return the real part of the circular autocorrelation of a 2D array.

    The autocorrelation is evaluated in Fourier space: the array is
    transformed with ``fft2``, the squared magnitude of the spectrum is
    taken, and the result is transformed back with ``ifft2``.

    Parameters
    ----------
    array : array_like
        Input passed to ``np.fft.fft2``.

    Returns
    -------
    np.ndarray
        Real part of the inverse transform of the squared spectrum
        magnitude, in the same (unshifted) FFT ordering as the input.
    """
    spectrum = np.fft.fft2(array)
    power_spectrum = np.abs(spectrum) ** 2
    correlation = np.fft.ifft2(power_spectrum)
    return np.real(correlation)

def return_chi(
    q,
    wavelength,
    C10,
    C30,
):
    """Evaluate the aberration phase for defocus and spherical aberration.

    Computes ``chi = (2*pi / wavelength) * (alpha**2 / 2 * C10 +
    alpha**4 / 4 * C30)`` with ``alpha = q * wavelength``.

    Parameters
    ----------
    q : float or np.ndarray
        Spatial frequency (scalar or array; the result has the same shape).
    wavelength : float
        Wavelength, in the length unit reciprocal to ``q``.
    C10 : float
        Coefficient of the second-order (``alpha**2``) term.
    C30 : float
        Coefficient of the fourth-order (``alpha**4``) term.

    Returns
    -------
    float or np.ndarray
        The aberration phase ``chi`` evaluated at ``q``.
    """
    scattering_angle = q * wavelength
    second_order_term = scattering_angle**2 / 2 * C10
    fourth_order_term = scattering_angle**4 / 4 * C30
    phase_scale = 2 * np.pi / wavelength

    return (second_order_term + fourth_order_term) * phase_scale
# Spatial-frequency grid in unshifted FFT ordering.
qx = qy = np.fft.fftfreq(n, sampling)
q2 = (qx**2)[:, None] + (qy**2)[None, :]
q = np.sqrt(q2)

# Aperture with a one-pixel-wide linear edge around q_probe, stored as an
# amplitude (square root) and normalized to unit total intensity.
soft_aperture = np.clip((q_probe - q) / reciprocal_sampling + 0.5, 0, 1)
probe_array_fourier_0 = np.sqrt(soft_aperture)
probe_array_fourier_0 /= np.sqrt(np.sum(np.abs(probe_array_fourier_0) ** 2))

# Aberration phase for the initial C10 / C30 values.
chi = return_chi(q, wavelength, C10, C30)
sin_chi = -np.sin(chi)

# 2D CTF: aperture autocorrelation modulated by -sin(chi), plus its radial average.
parallax_ctf_2D = autocorrelation(probe_array_fourier_0) * sin_chi
q_bins, I_bins = ctf.radially_average_ctf(parallax_ctf_2D, (sampling, sampling))

# Bin |CTF| by bin_value along both axes (block mean).
binned_size = n // bin_value
binned_ctf_to_96 = (
    np.abs(parallax_ctf_2D)
    .reshape(binned_size, bin_value, binned_size, bin_value)
    .mean(axis=(1, 3))
)

# Zero-pad by 48 pixels on every side around the centered zero frequency,
# then return to unshifted FFT ordering so it can multiply an fft2 directly.
centered_binned_ctf = np.fft.fftshift(binned_ctf_to_96)
zero_pad_ctf_to_4qprobe = np.fft.ifftshift(np.pad(centered_binned_ctf, 48))

# Apply the padded CTF to each sample potential in Fourier space.
convolved_object_sto, convolved_object_mof, convolved_object_apo = (
    np.fft.ifft2(np.fft.fft2(sample_potential) * zero_pad_ctf_to_4qprobe).real
    for sample_potential in (sto_potential, mof_potential, apo_potential)
)

# [min, max] of each initial filtered image, indexed like the potentials.
limits = [
    [filtered_object.min(), filtered_object.max()]
    for filtered_object in (
        convolved_object_sto,
        convolved_object_mof,
        convolved_object_apo,
    )
]
sto_limits, mof_limits, apo_limits = limits
# Build the three-panel figure (2D CTF, radially averaged CTF, filtered object).
# The figure is created with interactive mode switched off; figsize is given
# as pixels / dpi, i.e. 640 x 270 pixels at 72 dpi.
with plt.ioff():
    dpi=72
    fig, axs = plt.subplots(1,3,figsize=(640/dpi,270/dpi),dpi=dpi)

# Panel 0: the 2D CTF, fftshift-ed so zero frequency sits at the image center.
# The image handle is kept so its data can be replaced later.
im_ctf = axs[0].imshow(
    np.fft.fftshift(
        parallax_ctf_2D
    ),
    vmin=-1,
    vmax=1,
    cmap=cmap
)
# Scalebar spanning n//4 pixels at the reciprocal-space pixel size.
ctf.add_scalebar(
    axs[0],
    length=n//4,
    sampling=reciprocal_sampling,
    units=r'$q_{\mathrm{probe}}$',
    color='black'
)
axs[0].set(xticks=[],yticks=[],title="contrast transfer function (CTF)")

# Panel 1: radially averaged CTF. `plot` returns a list of lines; keep the
# single Line2D so its y-data can be replaced later.
plot_ctf = axs[1].plot(
    q_bins,
    I_bins,
    color=pixelated_parallax_line_color
)[0]

# Dashed zero line marks where the CTF changes sign.
axs[1].axhline(0,color='black',lw=1,linestyle='--')

axs[1].set(
    xlim=[0,2],
    ylim=[-1,1],
    aspect= 1,
    xticks=[0,1,2],
    yticks=[],
    xlabel=r"spatial frequency, $q/q_{\mathrm{probe}}$",
    title="radially averaged CTF"
)

# Panel 2: the CTF-filtered object, initially the STO sample, with color
# limits fixed to that image's own min / max.
im_obj = axs[2].imshow(
    convolved_object_sto,
    cmap=sample_cmap,
    vmin=sto_limits[0],
    vmax=sto_limits[1]
)

# One scalebar per sample is added to the same axes (STO, then MOF, then
# apoferritin), each with its own pixel size and units.
ctf.add_scalebar(
    axs[2],
    length=40,
    sampling=sto_sampling,
    units=r'Å',
    size_vertical=2
)

ctf.add_scalebar(
    axs[2],
    length=40,
    sampling=mof_sampling,
    units=r'nm',
    size_vertical=2
)

ctf.add_scalebar(
    axs[2],
    length=40,
    sampling=apo_sampling,
    units=r'nm',
    size_vertical=2
)

# NOTE(review): this unpacking assumes ctf.add_scalebar adds exactly one artist
# per call to axs[2].artists, in call order -- confirm against the ctf module.
sto_scalebar, mof_scalebar, apo_scalebar = axs[2].artists
# Only the STO scalebar is visible initially, matching the STO image shown.
mof_scalebar.set_visible(False)
apo_scalebar.set_visible(False)

axs[2].set(xticks=[],yticks=[],title="CTF-convolved weak phase object")
fig.tight_layout()

# Fix the widget canvas size and hide its header, footer and toolbar.
fig.canvas.resizable = False
fig.canvas.header_visible = False
fig.canvas.footer_visible = False
fig.canvas.toolbar_visible = False
fig.canvas.layout.height = "280px"
fig.canvas.layout.width = '640px'
# Bare `None` as the final expression -- presumably to suppress notebook cell output.
None
# Shared widget styling: descriptions use their initial (natural) width, and
# controls come in two fixed sizes (half-width and quarter-width of 640px).
style = {'description_width': 'initial'}
layout_half = ipywidgets.Layout(width="320px", height="30px")
layout_quarter = ipywidgets.Layout(width="160px", height="30px")
kwargs = dict(style=style, layout=layout_half)
kwargs_quarter = dict(style=style, layout=layout_quarter)

# Defocus slider.
C10_slider = ipywidgets.FloatSlider(
    value=-128,
    min=-500,
    max=500,
    step=1,
    description=r"negative defocus, $C_{1,0}$ [Å]",
    **kwargs,
)

# Spherical-aberration slider.
C30_slider = ipywidgets.FloatSlider(
    value=0,
    min=-100,
    max=100,
    step=0.1,
    description=r"spherical aberration, $C_{3,0}$ [µm]",
    **kwargs,
)

# One-shot action button.
scherzer_button = ipywidgets.Button(
    description="use Scherzer defocus",
    **kwargs_quarter,
)

# Toggle for the object panel's color scaling (off by default).
clim_button = ipywidgets.ToggleButton(
    value=False,
    description="use relative scaling",
    **kwargs_quarter,
)

# Toggle for phase-flip correction (on by default).
phase_flip_button = ipywidgets.ToggleButton(
    value=True,
    description="correct phase flipping",
    **kwargs_quarter,
)

# Sample selector; each option's value is an integer sample index.
sample_options = [
    ("strontium titanate", 0),
    ("metal-organic framework", 1),
    ("apoferritin protein", 2),
]
object_dropdown = ipywidgets.Dropdown(
    options=sample_options,
    **kwargs_quarter,
)
def update_ctf(*args):
    """Recompute the CTF from the current widget state and redraw the figure.

    Reads the defocus and spherical-aberration sliders, the sample dropdown
    and both toggle buttons, rebuilds the 2D CTF and its radial average,
    applies the (optionally phase-flipped) CTF to the selected potential,
    and updates the three panels plus the visible scalebar.

    Parameters
    ----------
    *args
        Ignored; accepted so the function can be used as a widget callback.
    """
    selected = object_dropdown.value
    defocus = C10_slider.value
    # The slider label gives C30 in µm; the 1e4 factor presumably converts to Å.
    spherical_aberration = C30_slider.value * 1e4

    # 2D CTF: aperture autocorrelation modulated by -sin(chi).
    aberration_phase = return_chi(q, wavelength, defocus, spherical_aberration)
    negative_sine = -np.sin(aberration_phase)
    ctf_2d = autocorrelation(probe_array_fourier_0) * negative_sine
    radial_q, radial_ctf = ctf.radially_average_ctf(ctf_2d, (sampling, sampling))

    # With phase-flip correction on, the object is filtered with |CTF|.
    transfer = np.abs(ctf_2d) if phase_flip_button.value else ctf_2d

    # Bin by bin_value along both axes, then zero-pad by 48 pixels per side
    # around the centered zero frequency and return to unshifted FFT ordering.
    binned_size = n // bin_value
    binned_transfer = transfer.reshape(
        binned_size, bin_value, binned_size, bin_value
    ).mean(axis=(1, 3))
    padded_transfer = np.fft.ifftshift(
        np.pad(np.fft.fftshift(binned_transfer), 48)
    )

    # Filter the chosen potential in Fourier space.
    potential_spectrum = np.fft.fft2(potentials[selected])
    filtered_object = np.fft.ifft2(potential_spectrum * padded_transfer).real

    im_ctf.set_data(np.fft.fftshift(ctf_2d))
    plot_ctf.set_ydata(radial_ctf)

    if clim_button.value:
        # Relative scaling: rescale the image and show it on a 0..1 range.
        filtered_object = ctf.histogram_scaling(filtered_object, normalize=True)
        lower, upper = 0, 1
    else:
        # Otherwise use the precomputed limits stored for this sample.
        lower, upper = limits[selected][0], limits[selected][1]
    im_obj.set_data(filtered_object)
    im_obj.set_clim(vmin=lower, vmax=upper)

    # Show only the scalebar that belongs to the selected sample.
    scalebars = (sto_scalebar, mof_scalebar, apo_scalebar)
    for sample_index, scalebar in enumerate(scalebars):
        scalebar.set_visible(selected == sample_index)

    fig.canvas.draw_idle()
    return None

# Redraw whenever a slider or the sample selection changes value.
for _observed_widget in (C10_slider, C30_slider, object_dropdown):
    _observed_widget.observe(update_ctf, "value")

def apply_scherzer(*args):
    """Set the defocus slider from the current spherical-aberration slider.

    The new value is ``-sign(Cs) * sqrt(3/2 * |Cs| * wavelength)``, where
    ``Cs`` is the C30 slider value multiplied by 1e4.

    Parameters
    ----------
    *args
        Ignored; accepted so the function can be used as a button callback.
    """
    spherical_aberration = C30_slider.value * 1e4
    magnitude = np.sqrt(3 / 2 * np.abs(spherical_aberration) * wavelength)
    C10_slider.value = -np.sign(spherical_aberration) * magnitude
    return None

# The Scherzer button triggers a one-shot action on click.
scherzer_button.on_click(apply_scherzer)

# Both toggles trigger a redraw when their value changes.
for _toggle in (clim_button, phase_flip_button):
    _toggle.observe(update_ctf, "value")

# Layout: sliders on the first row, buttons and dropdown on the second,
# figure canvas underneath. The bare expression is the cell's output.
slider_row = ipywidgets.HBox([C10_slider, C30_slider])
control_row = ipywidgets.HBox(
    [scherzer_button, clim_button, phase_flip_button, object_dropdown]
)
ipywidgets.VBox([slider_row, control_row, fig.canvas])