A Practical Guide to Scanning and Transmission Electron Microscopy Simulations

Interactive figure for 2D Fourier transforms

TODO: add colorwheel legend?

%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
from IPython.display import display
from ipywidgets import HBox, VBox, Dropdown, Layout, Label, Output, widgets
from matplotlib import cm
from matplotlib.colors import hsv_to_rgb
# Screen resolution used to turn the pixel sizes below into matplotlib inches.
dpi = 72

# Grid sizes: the clickable Fourier-space array, and the finer real-space
# grid onto which it is zero-padded before the inverse FFT.
shape_fft = (32, 32)
shape_real = (128, 128)

# Fourier coordinates in unshifted np.fft order (cycles per pixel, d=1.0),
# stored as a column (kx) and a row (ky) so they broadcast to the 2D grid.
kx = np.fft.fftfreq(shape_fft[0], d=1.0).reshape(-1, 1)
ky = np.fft.fftfreq(shape_fft[1], d=1.0).reshape(1, -1)
k2 = np.square(kx) + np.square(ky)  # |k|^2 on the full grid
k1 = np.sqrt(k2)                    # |k| on the full grid


def make_im_real(
    im_fft,
    shape_real,
):
    """Zero-pad a Fourier-space array onto a larger grid and inverse-FFT it.

    This is band-limited (Fourier) interpolation: the coefficients of
    ``im_fft`` keep their frequencies, and the empty high-frequency bins of
    the larger array are filled with zeros.

    Parameters
    ----------
    im_fft : array_like, 2D
        Fourier coefficients in unshifted ``np.fft`` order (zero frequency
        at ``[0, 0]``, negative frequencies in the second half of each axis).
        Any size is accepted, even or odd.
    shape_real : pair of int
        Shape of the output grid; must be at least ``im_fft.shape`` along
        each axis.

    Returns
    -------
    numpy.ndarray
        Complex array of shape ``shape_real``. ``np.fft.ifft2`` divides by
        ``prod(shape_real)``, so amplitudes scale with the output size.

    Raises
    ------
    ValueError
        If ``shape_real`` is smaller than ``im_fft`` along either axis
        (the quadrants would otherwise overlap and overwrite each other).
    """
    im_fft = np.asarray(im_fft)
    n_rows, n_cols = im_fft.shape
    big_rows, big_cols = shape_real
    if big_rows < n_rows or big_cols < n_cols:
        raise ValueError(
            f"shape_real {tuple(shape_real)} must be at least as large as "
            f"im_fft.shape {im_fft.shape}"
        )

    # Split each axis into non-negative (0 .. +k) and negative (-k .. -1)
    # frequency bins, matching np.fft.fftfreq: for even sizes the Nyquist
    # bin is counted as negative; for odd sizes there is one more
    # non-negative bin than negative ones.
    pos_r, neg_r = (n_rows + 1) // 2, n_rows // 2
    pos_c, neg_c = (n_cols + 1) // 2, n_cols // 2

    padded = np.zeros(shape_real, dtype=complex)
    # Non-negative frequencies stay at the start of each axis; negative
    # frequencies move to the end of the larger array.
    padded[:pos_r, :pos_c] = im_fft[:pos_r, :pos_c]
    padded[big_rows - neg_r:, :pos_c] = im_fft[pos_r:, :pos_c]
    padded[:pos_r, big_cols - neg_c:] = im_fft[:pos_r, pos_c:]
    padded[big_rows - neg_r:, big_cols - neg_c:] = im_fft[pos_r:, pos_c:]
    return np.fft.ifft2(padded)

def complex_im(
    im,
    amp_range=None,
    amp_power=0.75,
):
    """Render a 2D (complex) array as an RGB image.

    Phase is mapped to hue (``angle / 2pi``, wrapped into [0, 1)), and
    normalised amplitude is mapped to value (brightness) in HSV space.
    After conversion, the green channel is boosted by part of the blue
    channel, and the result is clipped to [0, 1].

    Parameters
    ----------
    im : array_like, 2D
        Real or complex image.
    amp_range : pair of float, optional
        ``(lo, hi)`` amplitude window: ``|im| <= lo`` is black and
        ``|im| >= hi`` is full brightness. Defaults to ``(0, max|im|)``.
        If ``hi <= lo`` (e.g. an all-zero image), amplitudes above ``lo``
        are shown at full brightness and the rest black, instead of
        producing NaNs from a 0/0 division.
    amp_power : float, optional
        Exponent applied to the normalised amplitude (values < 1 brighten
        weak features).

    Returns
    -------
    numpy.ndarray
        Float array of shape ``(im.shape[0], im.shape[1], 3)`` in [0, 1].
    """
    amp = np.abs(im)
    if amp_range is None:
        amp_range = (0, np.max(amp))
    lo, hi = amp_range[0], amp_range[1]
    span = hi - lo
    if span > 0:
        norm = (amp - lo) / span
    else:
        # Degenerate window: dividing by zero would give NaN (0/0).
        norm = (amp > lo).astype(float)
    a = np.clip(norm, 0, 1) ** amp_power

    hsv = np.ones((im.shape[0], im.shape[1], 3))
    hsv[:, :, 0] = np.mod(np.angle(im) / (2 * np.pi), 1.0)
    hsv[:, :, 2] = a
    im_rgb = hsv_to_rgb(hsv)

    # Ad-hoc colour tweak: lift the green channel using part of the blue one.
    im_rgb[:, :, 1] = im_rgb[:, :, 1] * 0.7 + im_rgb[:, :, 2] * 0.7
    return np.clip(im_rgb, 0, 1)
   
# Initial images: a soft-edged disc of radius k_max in Fourier space (1 inside,
# 0 outside, with a linear ramp one Fourier pixel, dk, wide across the edge),
# and its zero-padded inverse FFT on the real-space grid.
# NOTE(review): kx is a column array, so dk has shape (1,) rather than being a
# scalar; it broadcasts correctly, but kx[1, 0] - kx[0, 0] may be intended.
k_max = 0.39
dk = kx[1] - kx[0]
ramp = (k_max - k1) / dk + 0.5
im_fft = np.clip(ramp, 0, 1)
del ramp
im_real = make_im_real(im_fft, shape_real)

# fig = plt.figure(figsize=(680/dpi, 340/dpi), dpi=dpi)
# Create the figure with interactive mode off (plt.ioff) so it is not
# displayed on creation; its canvas is shown later inside the widget layout.
with plt.ioff():
    fig = plt.figure(figsize=(500/dpi, 300/dpi), dpi=dpi)
# Left panel: Fourier-space coefficients; right panel: real-space image.
ax_fft = fig.add_axes([0.01,  0.02,  0.41, 0.96])
ax_real  = fig.add_axes([0.51, 0.02, 0.41, 0.96])

# global variables
# Interaction state is stored as attributes on the figure so that the
# callbacks can read and update it:
#   dragging   - True while a left-button drag on the real panel is active
#   xy0        - (row, col) data position where the drag started
#   im_fft     - current Fourier coefficients (unshifted np.fft order)
#   im_fft_0   - snapshot of im_fft taken at the start of a drag
# NOTE(review): fig.im_fft aliases the module-level im_fft array, so in-place
# pixel edits made by the click handler also modify im_fft.
fig.dragging = False
fig.xy0 = (0,0)
fig.im_fft = im_fft
fig.im_fft_0 = im_fft.copy()

# plot initial images
# Both panels are fftshifted so that the origin / zero frequency is drawn at
# the centre. The returned image handles are updated in place by update_images.
# NOTE(review): np.fft.fftshift with no `axes` argument shifts every axis,
# including the length-3 RGB axis returned by complex_im, so the colour
# channels are rolled (R,G,B -> B,R,G). If unintended, use axes=(0, 1) here and
# in update_images together, so the first draw and later redraws stay consistent.
h_im_fft = ax_fft.imshow(
    np.fft.fftshift(complex_im(fig.im_fft)),
)
h_im_real = ax_real.imshow(
    np.fft.fftshift(complex_im(im_real)),
)
# origins
# Red '+' markers at index n//2, where fftshift places index 0 (the zero
# frequency in the left panel, the real-space origin in the right panel).
ax_fft.scatter(
    shape_fft[1]//2,
    shape_fft[0]//2,
    marker = '+',
    color = 'r',
    s = 100,
)
ax_real.scatter(
    shape_real[1]//2,
    shape_real[0]//2,
    marker = '+',
    color = 'r',
    s = 100,
)

# Appearance
# Canvas options for the interactive widget backend (%matplotlib widget):
# fixed 500x300 px canvas, toolbar at the bottom, no header/footer.
fig.canvas.resizable=False
fig.canvas.toolbar_visible = True
fig.canvas.header_visible = False
fig.canvas.footer_visible = False
fig.canvas.layout.width = '500px'
fig.canvas.layout.height = '300px'
fig.canvas.toolbar_position = 'bottom'

# Axes appearance
# Hide ticks/frames; the titles tell the user how to interact with each panel.
ax_fft.axis('off');
ax_real.axis('off');
ax_fft.set_title('Click to switch pixels on and off');
ax_real.set_title('Drag to move with Fourier shift theorem');

def update_images():
    """Redraw both panels from the current Fourier coefficients ``fig.im_fft``.

    The left image shows the coefficients themselves. The right image shows
    their inverse FFT after zero-padding to ``shape_real``. Both are converted
    to RGB with ``complex_im`` and fftshifted so the origin is centred.

    NOTE(review): ``np.fft.fftshift`` is called without ``axes``, so it also
    rolls the RGB channel axis of the (H, W, 3) image; kept as-is to match the
    initial draw.
    """
    coeffs = fig.im_fft
    fft_rgb = complex_im(coeffs)
    h_im_fft.set_data(np.fft.fftshift(fft_rgb))

    real_space = make_im_real(coeffs, shape_real)
    real_rgb = complex_im(real_space)
    h_im_real.set_data(np.fft.fftshift(real_rgb))

def button_press_callback(event):
    """Handle a mouse-button press on the figure.

    Only left clicks (``event.button == 1``) inside an axes are handled.

    * In the Fourier panel (``ax_fft``): the clicked coefficient is toggled.
      If its amplitude is below 0.1 it is set to amplitude 1.0, otherwise to
      1e-8. Its phase is kept in both cases. Both panels are then redrawn.
    * In the real-space panel (``ax_real``): a drag starts. The press
      position ``(row, col)`` and a copy of the current coefficients are
      stored on ``fig`` for ``motion_notify_callback``.
    """
    if event.inaxes is None or event.button != 1:
        return

    if event.inaxes is ax_fft:
        # imshow draws pixel centres at integer data coordinates, so rounding
        # the cursor position gives the displayed (fftshifted) pixel. Using
        # the axes itself, not a hand-written bounds test, keeps the outer
        # half-pixel of the first row/column clickable.
        row_shown = int(np.clip(np.round(event.ydata), 0, shape_fft[0] - 1))
        col_shown = int(np.clip(np.round(event.xdata), 0, shape_fft[1] - 1))
        # Undo the display fftshift (a roll by n//2) to index fig.im_fft;
        # subtracting n//2 is also correct for odd sizes.
        x = (row_shown - shape_fft[0] // 2) % shape_fft[0]
        y = (col_shown - shape_fft[1] // 2) % shape_fft[1]

        # Store coefficients as complex so that writing a phase factor never
        # casts complex -> real. The initial aperture and presets are real.
        fig.im_fft = np.asarray(fig.im_fft, dtype=complex)
        current = fig.im_fft[x, y]
        phase = np.exp(1j * np.angle(current))
        # 1e-8 rather than 0 keeps the stored phase for the next toggle.
        amp = 1.0 if np.abs(current) < 0.1 else 1e-8
        fig.im_fft[x, y] = amp * phase
        update_images()

    elif event.inaxes is ax_real:
        fig.dragging = True
        fig.xy0 = (event.ydata, event.xdata)  # (row, col) in real-panel pixels
        fig.im_fft_0 = fig.im_fft.copy()

def motion_notify_callback(event):
    """Translate the real-space image while the left mouse button is dragged.

    Uses the Fourier shift theorem: shifting the image by ``(dx, dy)``
    Fourier-grid pixels multiplies each coefficient by
    ``exp(-2*pi*i*(kx*dx + ky*dy))``. The phase ramp is always applied to the
    snapshot ``fig.im_fft_0`` taken when the drag started, so successive
    motion events do not accumulate.
    """
    if not fig.dragging:
        return
    if event.button != 1:
        return

    # Convert the cursor position to real-panel data coordinates as
    # (row, col). This uses event.x/event.y instead of event.xdata, so the
    # drag keeps working when the cursor leaves the real-space axes.
    inv = ax_real.transData.inverted()
    xy_real = inv.transform([event.x, event.y])[::-1]

    # kx and ky are in cycles per Fourier-grid pixel, but the cursor moves in
    # real-space pixels. The real grid is shape_real/shape_fft times finer,
    # so dividing by that ratio makes the image follow the cursor exactly.
    row_scale = shape_real[0] / shape_fft[0]
    col_scale = shape_real[1] / shape_fft[1]
    dx = (xy_real[0] - fig.xy0[0]) / row_scale
    dy = (xy_real[1] - fig.xy0[1]) / col_scale

    fig.im_fft = fig.im_fft_0 * np.exp((-1j * 2.0 * np.pi) * (kx * dx + ky * dy))
    update_images()

def button_release_callback(event):
    """End any drag in progress and redraw both panels.

    Any button release ends the drag; ``event`` itself is not inspected.
    """
    # Clearing the flag makes motion_notify_callback ignore further motion.
    fig.dragging = False
    update_images()

# Connect the mouse handlers to the canvas: a press toggles a Fourier pixel
# or starts a drag, a release ends the drag, and motion applies the shift.
for _event_name, _handler in (
    ('button_press_event', button_press_callback),
    ('button_release_event', button_release_callback),
    ('motion_notify_event', motion_notify_callback),
):
    fig.canvas.mpl_connect(_event_name, _handler)
del _event_name, _handler

# Names of the preset Fourier-space patterns shown in the dropdown. The order
# matters: select_preset_eventhandler refers to entries by index.
complex_functions = (
    'circular aperture - large',
    'circular aperture - medium',
    'circular aperture - small',
    'circular aperture - prism 2',
    'circular aperture - prism 4',
    'lattice - hexagonal',
    'lattice - square',
    'lattice - complex square',
    'two beam',
    'systematic row',
)
# update the plots with a pre-selected function
def select_preset_eventhandler(change):
    """Load the preset named by ``change.new`` into ``fig.im_fft`` and redraw.

    Registered with ``dropdown.observe``, so ``change.new`` is one of the
    names in ``complex_functions``. An unrecognised name leaves
    ``fig.im_fft`` unchanged but still triggers a redraw.

    Presets (indices are into unshifted np.fft order; negative indices wrap
    to negative frequencies):

    * circular aperture - large / medium / small: soft-edged discs of radius
      0.29 / 0.19 / 0.09 (same units as ``k1``).
    * circular aperture - prism 2 / prism 4: the large disc, keeping only
      coefficients whose row and column indices are both multiples of 2 / 4.
    * lattice - hexagonal: central peak (2) plus six peaks at (+-8, 0) and
      (+-4, +-7), an approximately hexagonal ring.
    * lattice - square: central peak (4) plus peaks on a centred square grid
      with spacing 4.
    * lattice - complex square: peaks on a square grid with spacing 4 whose
      amplitude depends on the parity of the site.
    * two beam: central peak plus one peak at (5, 8).
    * systematic row: nine equally spaced peaks n * (-3, 2), n = -4 .. 4.
    """

    def aperture(k_max):
        # Soft-edged disc: 1 for |k| < k_max and 0 beyond, with a linear ramp
        # one Fourier pixel (dk) wide across the edge.
        return np.clip((k_max - k1) / dk + 0.5, 0, 1)

    def prism(step):
        # Large aperture with every row and column whose index is not a
        # multiple of `step` zeroed.
        im = aperture(0.29)
        for offset in range(1, step):
            im[offset::step, :] = 0.0
            im[:, offset::step] = 0.0
        return im

    def peaks(points):
        # Zero array (same shape/dtype as the current coefficients) with the
        # given ((row, col), amplitude) entries set.
        im = np.zeros_like(fig.im_fft)
        for (row, col), amp in points:
            im[row, col] = amp
        return im

    hexagonal = [((0, 0), 2)] + [
        (rc, 1)
        for rc in ((8, 0), (-8, 0), (-4, 7), (4, 7), (-4, -7), (4, -7))
    ]
    # Sites (4i, 4j) with i + j even and |i| + |j| <= 4, restricted to
    # |i|, |j| <= 3 (which excludes the Nyquist row/column at index 16).
    square = [((0, 0), 4)] + [
        ((4 * i, 4 * j), 1)
        for i in range(-3, 4)
        for j in range(-3, 4)
        if (i, j) != (0, 0) and (i + j) % 2 == 0 and abs(i) + abs(j) <= 4
    ]
    # All sites (4i, 4j) with |i|, |j| <= 3. The amplitude is 1 when i and j
    # are both even, 0.5 when both are odd, and 0.25 when the parities differ.
    parity_amp = {0: 1, 1: 0.25, 2: 0.5}
    complex_square = [
        ((4 * i, 4 * j), parity_amp[i % 2 + j % 2])
        for i in range(-3, 4)
        for j in range(-3, 4)
    ]
    two_beam = [((0, 0), 1), ((5, 8), 1)]
    systematic_row = [((-3 * n, 2 * n), 1) for n in range(-4, 5)]

    builders = {
        complex_functions[0]: lambda: aperture(0.29),
        complex_functions[1]: lambda: aperture(0.19),
        complex_functions[2]: lambda: aperture(0.09),
        complex_functions[3]: lambda: prism(2),
        complex_functions[4]: lambda: prism(4),
        complex_functions[5]: lambda: peaks(hexagonal),
        complex_functions[6]: lambda: peaks(square),
        complex_functions[7]: lambda: peaks(complex_square),
        complex_functions[8]: lambda: peaks(two_beam),
        complex_functions[9]: lambda: peaks(systematic_row),
    }
    build = builders.get(change.new)
    if build is not None:
        fig.im_fft = build()

    update_images()
        
# Widgets
# Preset selector: changing its value calls select_preset_eventhandler with
# the change record, which loads the chosen pattern and redraws.
# NOTE(review): the dropdown presumably starts with complex_functions[0]
# ('circular aperture - large', k_max = 0.29) selected, while the initial
# figure uses k_max = 0.39, so the label and the picture disagree until
# another preset is chosen. Confirm whether this is intended.
dropdown = Dropdown(
    options=complex_functions,
    layout=Layout(width='180px', height='30px'),
)
dropdown.observe(select_preset_eventhandler, names='value')

# Layout: the figure canvas on the left; a caption above the dropdown on the right.
widget = HBox(
    [
        fig.canvas,
        VBox(
            [
                Label(
                    'Example Functions',
                    layout=Layout(width='180px', height='30px'),
                ),
                dropdown,
            ]
        ),
    ]
)
display(widget);