Evaluating the Transfer of Information in Phase Retrieval STEM Techniques

Pixelated single-sideband (SSB) ptychography

# enable interactive matplotlib
%matplotlib widget 

import numpy as np
import matplotlib.pyplot as plt
import ctf # import custom plotting / utils
import cmasher as cmr 
import ipywidgets
# parameters
# Reciprocal-space grid for the CTF calculation. With sampling = 1/(2*q_max),
# np.fft.fftfreq(n, sampling) spans spatial frequencies [-q_max, q_max).
n = 96
q_max = 2 # inverse Angstroms; highest spatial frequency on the grid
q_probe = 1 # inverse Angstroms; probe-forming aperture radius in reciprocal space
wavelength = 0.019687 # Angstroms; electron wavelength at 300 kV
sampling = 1 / q_max / 2 # Angstroms; real-space pixel size (Nyquist frequency = q_max)
reciprocal_sampling = 2 * q_max / n # inverse Angstroms; reciprocal pixel size of the n-pixel grid
bin_value = n // 96 # integer binning factor taking the n x n CTF to 96 x 96 (1 for n = 96);
                    # NOTE(review): assumes n is a multiple of 96 (n < 96 gives 0 and breaks the reshape)

# initial aberration coefficients in Angstroms (aberration-free start)
C10 = 0
C30 = 0

# colormaps / colours; alternative CTF colormap: cmr.eclipse
cmap = cmr.viola
sample_cmap = 'gray'
pixelated_ssb_line_color = 'darkgreen'
# Projected potentials of the three test specimens, mean-subtracted so the
# zero-frequency (DC) component is removed. The file names suggest 192x192
# arrays reaching 4 q_probe — TODO confirm against the data files.
sto_potential = np.load("data/STO_projected-potential_192x192_4qprobe.npy")
sto_potential -= sto_potential.mean()
mof_potential = np.load("data/MOF_projected-potential_192x192_4qprobe.npy")
mof_potential -= mof_potential.mean()
apo_potential = np.load("data/apoF_projected-potential_192x192_4qprobe.npy")
apo_potential -= apo_potential.mean()

# order matches the object_dropdown option indices 0, 1, 2
potentials = [sto_potential,mof_potential,apo_potential]

# Real-space pixel size of each specimen (presumably field-of-view width / pixel count),
# used for the scale bars; note the mixed units.
sto_sampling = 23.67 / sto_potential.shape[0]  # Å
mof_sampling = 4.48 / mof_potential.shape[0]  # nm
apo_sampling = 19.2 / apo_potential.shape[0]  # nm
def return_chi(
    q,
    wavelength,
    C10,
    C30,
):
    """Return the aberration phase chi(q) in radians.

    chi = (2*pi / wavelength) * (alpha**2 / 2 * C10 + alpha**4 / 4 * C30),
    with the scattering angle approximated as alpha = q * wavelength.

    Parameters
    ----------
    q : float or array_like
        Spatial frequency (inverse Angstroms).
    wavelength : float
        Electron wavelength (Angstroms).
    C10, C30 : float
        Defocus and spherical-aberration coefficients (Angstroms).

    Returns
    -------
    float or numpy.ndarray
        Phase with the same shape as ``q``.
    """
    angle = q * wavelength
    defocus_phase = angle**2 / 2 * C10
    spherical_phase = angle**4 / 4 * C30
    wavenumber = 2 * np.pi / wavelength
    return (defocus_phase + spherical_phase) * wavenumber

def return_complex_probe(
    q,
    wavelength,
    C10,
    C30,
    aperture=None,
):
    """Return the aberrated probe wavefunction in reciprocal space.

    The aperture amplitude is multiplied by the aberration phase factor
    ``exp(-1j * chi(q))``, with chi computed by ``return_chi``.

    Parameters
    ----------
    q : array_like
        Spatial-frequency magnitudes of the probe grid (inverse Angstroms).
    wavelength : float
        Electron wavelength (Angstroms).
    C10, C30 : float
        Defocus and spherical-aberration coefficients (Angstroms).
    aperture : array_like, optional
        Aperture amplitude, broadcastable with ``q``. Defaults to the
        module-level ``probe_array_fourier_0``, which is looked up at call
        time and therefore must be defined before the first call.

    Returns
    -------
    numpy.ndarray
        Complex probe ``aperture * exp(-1j * chi)``.
    """
    # Explicit argument removes the hidden dependency on a global that is
    # defined further down the notebook; the default keeps old callers working.
    if aperture is None:
        aperture = probe_array_fourier_0
    chi = return_chi(
        q,
        wavelength,
        C10,
        C30
    )
    return aperture * np.exp(-1j * chi)

def return_ssb_ctf(
    complex_probe,
    q,
    q_cutoff=None,
):
    """Return the 2D single-sideband (SSB) contrast transfer function.

    For every spatial frequency Q = (sx, sy) on the FFT-ordered grid (zero
    frequency at index [0, 0]) the double-overlap function

        gamma(k) = A*(k) A(k - Q) - A(k) A*(k + Q)

    is built from the complex probe A using periodic shifts (``np.roll``),
    and the CTF value is ``sum_k |gamma(k)| / 2``. Frequencies that are not
    below ``q_cutoff`` (including NaN entries of ``q``) get zero transfer.

    Parameters
    ----------
    complex_probe : numpy.ndarray
        2D complex probe in reciprocal space.
    q : numpy.ndarray
        Spatial-frequency magnitudes, same shape as ``complex_probe``.
    q_cutoff : float, optional
        Transfer is only computed where ``q < q_cutoff``. Defaults to the
        module-level ``q_max`` (looked up at call time).

    Returns
    -------
    numpy.ndarray
        Real float array with the shape of ``complex_probe``.
    """
    if q_cutoff is None:
        q_cutoff = q_max
    # Take the grid size from the probe itself rather than the global `n`,
    # so probes of any 2D shape are handled.
    nx, ny = complex_probe.shape
    complex_probe_conj = complex_probe.conj()
    ssb_ctf = np.zeros((nx, ny))
    for sx in range(nx):
        for sy in range(ny):
            # Written as `q < cutoff` (not `q >= cutoff: skip`) so NaN maps to zero.
            if q[sx, sy] < q_cutoff:
                # np.roll(a, s)[k] == a[k - s]  ->  A(k - Q)
                shifted_probe_minus = np.roll(complex_probe, (sx, sy), axis=(0, 1))
                # A*(k + Q); rolling the conjugate equals conjugating the roll.
                shifted_conj_plus = np.roll(complex_probe_conj, (-sx, -sy), axis=(0, 1))
                gamma = (
                    complex_probe_conj * shifted_probe_minus
                    - complex_probe * shifted_conj_plus
                )
                ssb_ctf[sx, sy] = np.abs(gamma).sum() / 2
    return ssb_ctf
# Reciprocal-space grid in FFT order (zero frequency at index [0, 0]).
qx = qy = np.fft.fftfreq(n, sampling)
q2 = qx[:, None] ** 2 + qy[None, :] ** 2
q = np.sqrt(q2)

# Soft-edged circular aperture of radius q_probe: the clipped intensity ramps
# linearly from 1 to 0 across one reciprocal pixel centred on q_probe; the
# amplitude is its square root, normalised to unit total intensity.
probe_array_fourier_0 = np.sqrt(
    np.clip((q_probe - q) / reciprocal_sampling + 0.5, 0, 1),
)
probe_array_fourier_0 /= np.sqrt(np.sum(np.abs(probe_array_fourier_0) ** 2))

# Probe for the initial aberrations, its SSB CTF and the radial profile.
complex_probe = return_complex_probe(q, wavelength, C10, C30)
ssb_ctf_2D = return_ssb_ctf(complex_probe, q)
q_bins, I_bins = ctf.radially_average_ctf(ssb_ctf_2D, (sampling, sampling))

# Bin the CTF by bin_value (to 96x96 for the default n), then zero-pad 48
# pixels per side about the centred zero frequency so it matches the shape of
# the potential arrays used in the product below.
binned_ctf_to_96 = ssb_ctf_2D.reshape(
    (n // bin_value, bin_value, n // bin_value, bin_value)
).mean((1, 3))
zero_pad_ctf_to_4qprobe = np.fft.ifftshift(
    np.pad(np.fft.fftshift(binned_ctf_to_96), 48)
)

# Filter each specimen potential by the CTF in Fourier space.
convolved_object_sto, convolved_object_mof, convolved_object_apo = (
    np.fft.ifft2(np.fft.fft2(potential) * zero_pad_ctf_to_4qprobe).real
    for potential in (sto_potential, mof_potential, apo_potential)
)

# Display ranges for the "absolute" (non-histogram) colour scaling.
sto_limits, mof_limits, apo_limits = (
    [image.min(), image.max()]
    for image in (convolved_object_sto, convolved_object_mof, convolved_object_apo)
)
limits = [sto_limits, mof_limits, apo_limits]
# Build the three-panel figure. Only the subplots() call sits inside ioff(),
# which keeps the figure from being displayed when it is created.
with plt.ioff():
    dpi=72
    fig, axs = plt.subplots(1,3,figsize=(640/dpi,270/dpi),dpi=dpi)

# Panel 0: 2D CTF with zero frequency shifted to the centre.
im_ctf = axs[0].imshow(
    np.fft.fftshift(
        ssb_ctf_2D
    ),
    vmin=-1,
    vmax=1,
    cmap=cmap
)
# n//4 reciprocal pixels * reciprocal_sampling = q_max/2, which equals
# q_probe (1 inverse Angstrom) for the default parameters.
ctf.add_scalebar(
    axs[0],
    length=n//4,
    sampling=reciprocal_sampling,
    units=r'$q_{\mathrm{probe}}$',
    color="black"
)
axs[0].set(xticks=[],yticks=[],title="contrast transfer function (CTF)")

# Panel 1: radially averaged CTF. The line handle is kept so update_ctf can
# replace its y-data. NOTE(review): q_bins is plotted on a q/q_probe axis, which
# assumes q_bins is in inverse Angstroms and q_probe == 1 — confirm.
plot_ctf = axs[1].plot(
    q_bins,
    I_bins,
    color=pixelated_ssb_line_color
)[0]

axs[1].axhline(0,color='black',lw=1,linestyle='--')

# aspect=2 with an x-range of 2 and a y-range of 1 gives a square plot box.
axs[1].set(
    xlim=[0,2],
    ylim=[0,1],
    aspect= 2,
    xticks=[0,1,2],
    yticks=[],
    xlabel=r"spatial frequency, $q/q_{\mathrm{probe}}$",
    title="radially averaged CTF"
)

# Panel 2: CTF-filtered STO potential with histogram scaling, shown on a fixed
# [0, 1] colour range (presumably the range produced by normalize=True).
im_obj = axs[2].imshow(
    ctf.histogram_scaling(
        convolved_object_sto,
        normalize=True,
    ),
    cmap=sample_cmap,
    vmin=0,
    vmax=1
)

# One 40-pixel scale bar per specimen, each with its own pixel size and unit;
# update_ctf shows only the bar for the selected specimen.
ctf.add_scalebar(
    axs[2],
    length=40,
    sampling=sto_sampling,
    units=r'Å',
    size_vertical=2,
)

ctf.add_scalebar(
    axs[2],
    length=40,
    sampling=mof_sampling,
    units=r'nm',
    size_vertical=2
)

ctf.add_scalebar(
    axs[2],
    length=40,
    sampling=apo_sampling,
    units=r'nm',
    size_vertical=2
)

# NOTE(review): relies on ctf.add_scalebar adding exactly one artist per call,
# in call order (STO, MOF, apoF) — confirm if that helper changes.
sto_scalebar, mof_scalebar, apo_scalebar = axs[2].artists
mof_scalebar.set_visible(False)
apo_scalebar.set_visible(False)

axs[2].set(xticks=[],yticks=[],title="CTF-convolved weak phase object")
fig.tight_layout()

# ipympl canvas options (requires the %matplotlib widget backend): fixed size,
# no header, footer or toolbar.
fig.canvas.resizable = False
fig.canvas.header_visible = False
fig.canvas.footer_visible = False
fig.canvas.toolbar_visible = False
fig.canvas.layout.height = "280px"
fig.canvas.layout.width = '640px'
# no-op final expression; the canvas is displayed later inside the widget VBox
None
# Shared widget styling: 'initial' lets long descriptions use their natural width.
style = {'description_width': 'initial'}
layout_half = ipywidgets.Layout(width="320px",height="30px")
layout_quarter = ipywidgets.Layout(width="160px",height="30px")
# Sliders only fire on release: each update recomputes the CTF over every
# frequency of the grid in return_ssb_ctf.
kwargs = {'style':style,'layout':layout_half,"continuous_update":False,}
# Button, ToggleButton and Dropdown define no `continuous_update` trait. Passing
# it only forwards an unrecognised kwarg to traitlets (DeprecationWarning,
# documented as a future error), so it is left out for those widgets.
kwargs_quarter = {'style':style,'layout':layout_quarter,}

# Defocus slider, in Angstroms (passed to the probe model unchanged).
C10_slider = ipywidgets.FloatSlider(
    value = 0,
    min = -500,
    max = 500, 
    step = 1,
    description = r"negative defocus, $C_{1,0}$ [Å]",
    **kwargs
)

# Spherical-aberration slider, in micrometres (converted to Å by the callbacks).
C30_slider = ipywidgets.FloatSlider(
    value = 0,
    min = -100,
    max = 100, 
    step = 0.1,
    description = r"spherical aberration, $C_{3,0}$ [µm]",
    **kwargs
)

scherzer_button = ipywidgets.Button(
    description="use Scherzer defocus",
    **kwargs_quarter
)

# True: histogram-normalised display; False: fixed per-specimen limits.
clim_button = ipywidgets.ToggleButton(
    value=True,
    description="use relative scaling",
    **kwargs_quarter
)

# Option values index into `potentials` / `limits`.
object_dropdown = ipywidgets.Dropdown(
    options=[("strontium titanate",0),("metal-organic framework",1),("apoferritin protein",2)],
    style=style,
    layout=layout_half,
)
def update_ctf(*args):
    """Recompute the SSB CTF from the widget state and refresh every panel.

    Registered as an observer on the sliders, dropdown and scaling toggle;
    ``*args`` receives the change notification and is ignored.
    """
    index = object_dropdown.value
    probe = return_complex_probe(
        q,
        wavelength,
        C10_slider.value,
        C30_slider.value * 1e4,  # µm -> Å
    )

    ctf_2d = return_ssb_ctf(probe, q)
    _, radial_profile = ctf.radially_average_ctf(ctf_2d, (sampling, sampling))

    # Bin, then zero-pad about the centred zero frequency to the potential grid.
    binned = ctf_2d.reshape(
        (n // bin_value, bin_value, n // bin_value, bin_value)
    ).mean((1, 3))
    padded_ctf = np.fft.ifftshift(np.pad(np.fft.fftshift(binned), 48))

    image = np.fft.ifft2(np.fft.fft2(potentials[index]) * padded_ctf).real

    if clim_button.value:
        im_obj.set_data(ctf.histogram_scaling(image, normalize=True))
        im_obj.set_clim(vmin=0, vmax=1)
    else:
        low, high = limits[index]
        im_obj.set_data(image)
        im_obj.set_clim(vmin=low, vmax=high)

    im_ctf.set_data(np.fft.fftshift(ctf_2d))
    plot_ctf.set_ydata(radial_profile)
    # Show only the scale bar belonging to the selected specimen.
    for position, scalebar in enumerate((sto_scalebar, mof_scalebar, apo_scalebar)):
        scalebar.set_visible(index == position)
    fig.canvas.draw_idle()
    return None

# Redraw whenever either aberration slider or the specimen selection changes.
for widget in (C10_slider, C30_slider, object_dropdown):
    widget.observe(update_ctf, "value")
del widget  # keep the loop variable out of the notebook namespace

def apply_scherzer(*args):
    """Move the defocus slider to the Scherzer value for the current C3,0.

    C1,0 = -sign(C3,0) * sqrt(3/2 * |C3,0| * wavelength), with C3,0 converted
    from the slider's micrometres to Angstroms. ``*args`` receives the clicked
    button from ``on_click`` and is ignored.
    """
    spherical = C30_slider.value * 1e4  # µm -> Å
    magnitude = np.sqrt(3 / 2 * np.abs(spherical) * wavelength)
    C10_slider.value = -np.sign(spherical) * magnitude
    return None

# Scaling toggle redraws directly; the Scherzer button changes the C1,0 slider,
# whose own observer triggers the redraw when the value actually changes.
clim_button.observe(update_ctf, "value")
scherzer_button.on_click(apply_scherzer)

# Last expression of the cell: Jupyter renders the assembled layout.
ipywidgets.VBox([
    ipywidgets.HBox([C10_slider, C30_slider]),
    ipywidgets.HBox([scherzer_button, clim_button, object_dropdown]),
    fig.canvas,
])