Sparse coding of large images

In this post, we show how one can infer the sparse representation of an image, knowing an appropriate generative model for its synthesis. We will start with a linear inversion (pseudo-inverse deconvolution), then move to a gradient descent algorithm. Finally, we will implement a convolutional version of the iterative shrinkage-thresholding algorithm (ISTA) and of its fast variant, FISTA.

For computational efficiency, all convolutions will be implemented with a Fast Fourier Transform, so that each is mathematically equivalent to a standard (circular) convolution. We will benchmark this on a realistic image size of $512 \times 512$ and report some timing results obtained on a standard laptop.

Let's first initialize the notebook:

In [1]:
from __future__ import division, print_function
import numpy as np
np.set_printoptions(precision=6, suppress=True)
import os
%matplotlib inline
#%config InlineBackend.figure_format='retina'
%config InlineBackend.figure_format = 'svg'
import matplotlib.pyplot as plt
phi = (np.sqrt(5)+1)/2
fig_width = 15
figsize = (fig_width, fig_width/phi)
%load_ext autoreload
%autoreload 2

Problem statement: sparse events in a static image

In the first case we will study, images are produced by the combination of a single kernel with sparse events. Think of it as brushstrokes on a canvas, or drops of rain falling on the surface of a pond...

For simplicity, the combination of these events will be considered linear. This is mostly true, for instance, in watercolor or when combining transparent surfaces, provided we ignore the saturation of the sensor. To handle linearity, we will use linear algebra and matrices: we will denote the (raveled) image as $x$ and these "events" as $y$. The kernel that defines the brushstrokes is denoted as $A$ and will be made explicit below. In matrix form, the generative model of the image is written:

$$ x = A y $$
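To make the matrix notation concrete, here is a minimal sketch (the size and the kernel are purely illustrative) showing, on a tiny 1D signal, that when $A$ is a circulant matrix built from a convolution kernel, the matrix product $A y$ is exactly a circular convolution, which is what the FFT implementation below computes:

import numpy as np

n = 8
rng = np.random.RandomState(0)
kernel = rng.randn(n)   # illustrative 1D kernel
y = np.zeros(n)
y[2] = 1.               # a single sparse "event"

# circulant matrix: column j holds the kernel circularly shifted by j
A = np.stack([np.roll(kernel, j) for j in range(n)], axis=1)

x_matrix = A @ y                                              # matrix form x = A y
x_fft = np.fft.ifft(np.fft.fft(kernel) * np.fft.fft(y)).real  # FFT-based circular convolution
print(np.allclose(x_matrix, x_fft))                           # expected: True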

First, coefficients are drawn from a Laplace distribution and sparse events are generated by thresholding these coefficients. These will be convolved with a kernel defined using the MotionClouds library. Let's first illustrate that for a small image size:

In [2]:
N_X, N_Y = 2**9, 2**9 # BENCHMARK
# N_X, N_Y = 2**8, 2**8 # DEBUG

rho = 1.e-3
sf_0 = .15
B_sf = .15
y_lambda = 1.61803
In [3]:
seed = 2020
rng = np.random.RandomState(seed)
events = y_lambda * rng.laplace(size=(N_X, N_Y))
In [4]:
events_max = np.max(np.abs(events))
fig, ax = plt.subplots(figsize=figsize)
ax.hist(events.ravel(), bins=np.linspace(-events_max, events_max, 100, endpoint=True))
ax.set_xlabel('coefficients')
ax.set_yscale('log')
ax.set_ylabel('#');
[Figure: histogram of the raw coefficients, with a logarithmic count axis.]
In [5]:
print('mean, std=', events.mean(), events.std())
mean, std= -0.008859802575012121 2.2900579569626447

From this continuous distribution of coefficients, we will zero out the smallest values to achieve a desired sparsity (in the sense of the $\ell_0$ pseudo-norm):

In [6]:
threshold = np.quantile(np.absolute(events).ravel(), 1-rho)
print('threshold=', threshold)
threshold= 11.432255842859089
In [7]:
events_thr = events  * ((events < -threshold) + (events > threshold))
In [8]:
fig, ax = plt.subplots(figsize=figsize)
ax.hist(events_thr.ravel(), bins=np.linspace(-events_max, events_max, 100, endpoint=True))
ax.set_xlabel('coefficients')
ax.set_yscale('log')
ax.set_ylabel('#');
[Figure: histogram of the thresholded coefficients, with a logarithmic count axis.]
In [9]:
print('mean, std=', events_thr.mean(), events_thr.std())
mean, std= 0.00031693423294735503 0.4103409801780267
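As a quick sanity check (a sketch using the variables defined above), the fraction of non-zero coefficients should match the target sparsity rho:

sparseness = np.mean(np.abs(events_thr) > 0)  # fraction of surviving (non-zero) coefficients
print('achieved sparseness =', sparseness, '; target rho =', rho)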
In [10]:
fig, ax = plt.subplots(figsize=(fig_width, fig_width))
ax.imshow(events_thr, cmap=plt.viridis());
[Figure: map of the sparse events (events_thr).]

Let's now generate the image by convolving this sparse event map with an isotropic kernel, as defined in the https://github.com/NeuralEnsemble/MotionClouds library:

In [11]:
import MotionClouds as mc
fx, fy, ft = mc.get_grids(N_X, N_Y, 1)
opts = dict(V_X=0., V_Y=0., B_V=0, B_theta=np.inf, sf_0=sf_0, B_sf=B_sf)
envelope = mc.envelope_gabor(fx, fy, ft, **opts).squeeze()
#env_smooth = mc.retina(fx, fy, ft, df=0.07)[:, :, 0]
In [12]:
fig, ax = plt.subplots(figsize=(fig_width, fig_width))
#ax.imshow(envelope[(N_X//2-N_X//4):(N_X//2+N_X//4), (N_Y//2-N_Y//4):(N_Y//2+N_Y//4)], cmap=plt.plasma()); 
ax.imshow(envelope, cmap=plt.plasma()); 
[Figure: the isotropic envelope in the Fourier domain.]

(This shape reminds me of something... D'OH!)

Then we use the Fourier transform to actually perform the convolution (as implemented in the Motion Clouds library):

In [13]:
x = mc.random_cloud(envelope[:, :, None], events=events_thr[:, :, None])
x = x.reshape((N_X, N_Y))
print('x.shape=', x.shape)
x.shape= (512, 512)

This has all the nice properties of the 2D Discrete Fourier Transform, see for example this textbook (in French).
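One such property, shown here as a sanity-check sketch (reusing the mc.random_cloud call from above), is equivariance to circular shifts: translating the event map simply translates the synthesized image, with wrap-around at the borders.

shift = 17  # arbitrary shift, in pixels
x_shifted = mc.random_cloud(envelope[:, :, None],
                            events=np.roll(events_thr, shift, axis=0)[:, :, None])
x_shifted = x_shifted.reshape((N_X, N_Y))
print(np.allclose(x_shifted, np.roll(x, shift, axis=0)))  # expected: True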

In [14]:
fig, ax = plt.subplots(figsize=(fig_width, fig_width))
vmax = np.absolute(x).max()
ax.imshow(x, cmap=plt.gray(), vmin=-vmax, vmax=vmax);
[Figure: the synthesized image x.]
In [15]:
x_max = np.max(np.abs(x))
fig, ax = plt.subplots(figsize=figsize)
ax.hist(x.ravel(), bins=np.linspace(-x_max, x_max, 100, endpoint=True))
ax.set_xlabel('luminance')
ax.set_yscale('log')
ax.set_ylabel('#');
[Figure: histogram of the image's luminance values, with a logarithmic count axis.]

Note: one could also have used shot noise (https://en.wikipedia.org/wiki/Shot_noise) to generate the sparse events.

All in one function:

In [16]:
def MC_env(N_X, N_Y, opts=dict(V_X=0., V_Y=0., B_V=0, B_theta=np.inf, sf_0=sf_0, B_sf=B_sf, alpha=0), do_norm=True, verbose=False):
    fx, fy, ft = mc.get_grids(N_X, N_Y, 1)
    envelope = mc.envelope_gabor(fx, fy, ft, **opts).squeeze()
    envelope /= envelope.max()
    if do_norm: envelope /= np.sqrt((envelope**2).sum())
    if verbose: print('(envelope**2).sum()=', (envelope**2).sum())
    return envelope
envelope = MC_env(N_X, N_Y, verbose=True)
(envelope**2).sum()= 1.0000000000000002
In [17]:
%%timeit
envelope = MC_env(N_X, N_Y)
13.5 ms ± 398 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [18]:
def random_cloud(envelope, events):
    (N_X, N_Y) = envelope.shape
    # Fourier transform of the event map
    F_events = np.fft.fftn(events)
    # center the spectrum so that it aligns with the (centered) envelope
    F_events = np.fft.fftshift(F_events)
    # filtering in the Fourier domain amounts to a convolution in the image domain
    Fz = F_events * envelope
    # de-centering the spectrum
    Fz = np.fft.ifftshift(Fz)
    #Fz[0, 0] = 0. # removing the DC component
    z = np.fft.ifftn(Fz).real
    return z
In [19]:
%%timeit
random_cloud(envelope, events_thr)
15 ms ± 307 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [20]:
def model(envelope, events, verbose=False):
    if verbose: print('envelope.shape = ', envelope.shape)
    if verbose: print('events.shape = ', events.shape)
    N_X, N_Y = envelope.shape
    x = random_cloud(envelope, events=events)
    #x = x.reshape((N_X, N_Y))
    if verbose: print('x.shape=', x.shape)
    return x
x = model(envelope, events_thr, verbose=True)
envelope.shape =  (512, 512)
events.shape =  (512, 512)
x.shape= (512, 512)
In [21]:
%%timeit
model(envelope, events_thr)
14.7 ms ± 449 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Merging both functions:

In [22]:
def model(envelope, events, verbose=False):
    if verbose: print('envelope.shape = ', envelope.shape)
    if verbose: print('events.shape = ', events.shape)
    N_X, N_Y = envelope.shape
    F_events = np.fft.fftn(events)
    F_events = np.fft.fftshift(F_events)
    Fz = F_events * envelope
    Fz = np.fft.ifftshift(Fz)
    x = np.fft.ifftn(Fz).real
    if verbose: print('x.shape=', x.shape)
    return x
x = model(envelope, events_thr, verbose=True)
envelope.shape =  (512, 512)
events.shape =  (512, 512)
x.shape= (512, 512)
In [23]:
%%timeit
model(envelope, events_thr)
14.9 ms ± 433 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
In [24]:
def generative_model(envelope, rho=rho, y_lambda=y_lambda, seed=seed, verbose=False):
    N_X, N_Y = envelope.shape
    if verbose: print('N_X, N_Y = ', envelope.shape)
    rng = np.random.RandomState(seed)
    events = y_lambda * rng.laplace(size=(N_X, N_Y))

    threshold = np.quantile(np.absolute(events).ravel(), 1.-rho)
    events = events  * ((events < -threshold) + (events > threshold))

    x = model(envelope, events, verbose=verbose)
    return events, x
events, x = generative_model(envelope, verbose=True)
N_X, N_Y =  (512, 512)
envelope.shape =  (512, 512)
events.shape =  (512, 512)
x.shape= (512, 512)
In [25]:
fig, ax = plt.subplots(figsize=(fig_width, fig_width))
vmax = np.absolute(x).max()
ax.imshow(x, cmap=plt.gray(), vmin=-vmax, vmax=vmax);
[Figure: the image synthesized by generative_model.]
In [26]:
envelope = MC_env(N_X, N_Y)
In [27]:
%%timeit
events, x = generative_model(envelope)
25.9 ms ± 591 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

Now, how do I retrieve the events knowing $x$? This is a form of deconvolution:

Deconvolution in the noiseless case: invertibility of the image to retrieve events

It's relatively easy to use the shape of the envelope to retrieve the events:
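Concretely (a sketch of the idea; the actual deconv function below adds a band-pass mask), since the model acts as a pointwise multiplication by the real, non-negative envelope $A(f)$ in the Fourier domain, a regularized inverse filter can be written element-wise as

$$ \tilde{A}(f) = \frac{A(f)}{A(f)^2 + \epsilon} $$

where the small constant $\epsilon$ avoids dividing by zero at frequencies where the envelope vanishes; the result is finally rescaled so that the product $A \cdot \tilde{A}$ has unit mean gain.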

In [28]:
eps = 2.e-6
thr = 5.e-2
def deconv(envelope, eps=eps, thr=thr, do_norm=False):
    # mask = 1 / ( 1 + np.exp( -(envelope/envelope.max() - thr) / eps ) ) # coefficients near zero
    fr = mc.frequency_radius(fx, fy, ft).squeeze()
    mask = 1 / (1 + np.exp((fr - .45) / .025))   # sigmoid cutting off high frequencies
    mask *= 1 / (1 + np.exp(-(fr - thr) / .01))  # sigmoid cutting off low frequencies

    # regularized inverse filter: eps avoids dividing by zero where the envelope vanishes
    F_deconv = envelope / (envelope**2 + eps*(1-mask))
    F_deconv *= mask
    if do_norm: F_deconv /= np.sqrt((F_deconv**2).sum())
    # rescale so that the product envelope * F_deconv has unit mean gain
    return F_deconv / (envelope*F_deconv).mean()

F_deconv = deconv(envelope, eps=eps, thr=thr)

fig, ax = plt.subplots(figsize=(fig_width, fig_width))
#cmap = ax.imshow(F_deconv[(N_X//2-N_X//4):(N_X//2+N_X//4), (N_Y//2-N_Y//4):(N_Y//2+N_Y//4)], cmap=plt.plasma())
cmap = ax.imshow(F_deconv, cmap=plt.plasma())
plt.colorbar(cmap, shrink=.8);
[Figure: the deconvolution filter F_deconv in the Fourier domain.]
In [29]:
fig, ax = plt.subplots(figsize=(fig_width, fig_width))
cmap = ax.imshow(envelope*F_deconv, cmap=plt.plasma())
plt.colorbar(cmap, shrink=.8); 
[Figure: the product envelope * F_deconv in the Fourier domain.]
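As a quick check (a sketch reusing the model function and the x and events generated above), applying this filter with the same FFT-based machinery should approximately recover the event map; the recovery is only partial, since the frequencies masked out by F_deconv are lost:

events_rec = model(F_deconv, x)  # pseudo-inverse deconvolution in the Fourier domain
print('correlation with the true events:', np.corrcoef(events.ravel(), events_rec.ravel())[0, 1])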