Extending Olshausen's classical SparseNet

This is an old blog post; see the newer version in this post

  • In a previous notebook, we tried to reproduce the learning strategy specified in the framework of the SparseNet algorithm from Bruno Olshausen. It efficiently codes natural image patches by constraining the code to be sparse. In particular, we saw that in order to optimize competition, it is important to control cooperation, and we implemented a heuristic to do just this.

  • In this notebook, we provide an extension to the SparseNet algorithm. We will study how homeostasis (cooperation) may be an essential ingredient for this algorithm, which otherwise works on a winner-take-all basis (competition). This extension was published as Perrinet, Neural Computation (2010) (see https://laurentperrinet.github.io/publication/perrinet-10-shl ):

@article{Perrinet10shl,
    Title = {Role of homeostasis in learning sparse representations},
    Author = {Perrinet, Laurent U.},
    Journal = {Neural Computation},
    Year = {2010},
    Doi = {10.1162/neco.2010.05-08-795},
    Keywords = {Neural population coding, Unsupervised learning, Statistics of natural images, Simple cell receptive fields, Sparse Hebbian Learning, Adaptive Matching Pursuit, Cooperative Homeostasis, Competition-Optimized Matching Pursuit},
    Month = {July},
    Number = {7},
    Url = {https://laurentperrinet.github.io/publication/perrinet-10-shl},
    Volume = {22},
}

dictionary learning on natural images

Here, we reproduce the dictionary learning obtained with the SparseNet algorithm while using Orthogonal Matching Pursuit (OMP) as the sparse coding step.
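As a side note, the sparse coding step alone can be sketched with scikit-learn's SparseCoder; this toy example uses a random dictionary in place of the learned one:

    import numpy as np
    from sklearn.decomposition import SparseCoder

    # toy stand-in for a learned dictionary: 14 unit-norm atoms for 144-dimensional (12x12) patches
    D = np.random.randn(14, 144)
    D /= np.sqrt(np.sum(D**2, axis=1))[:, np.newaxis]

    # sparse coding with OMP, keeping at most 5 non-zero coefficients per patch
    coder = SparseCoder(dictionary=D, transform_algorithm='omp', transform_n_nonzero_coefs=5)
    patches = np.random.randn(10, 144)
    sparse_code = coder.transform(patches)
    print(sparse_code.shape)  # (10, 14): one sparse coefficient vector per patch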

In [7]:
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format='svg'
import numpy as np
np.set_printoptions(precision=2, suppress=True)
import pandas as pd
import seaborn as sns
%load_ext autoreload
%autoreload 2
In [8]:
from shl_scripts import SHL
database = '/Users/lolo/pool/science/BICV/SHL_scripts/database/'
DEBUG_DOWNSCALE, verbose = 10, 100  # downscale the problem for a quick run; set DEBUG_DOWNSCALE=1 for the full-scale experiment
shl = SHL(database=database, DEBUG_DOWNSCALE=DEBUG_DOWNSCALE, verbose=verbose)
dico = shl.learn_dico()
shl.show_dico(dico)
Extracting data... (long list of image filenames omitted)
Data is of shape : (20000, 144), done in 7.79s.
Learning the dictionary... Training on 20000 patches...
Iteration   0 /   500 (elapsed time:   0s,  0.0mn)
...
Iteration  495 /   500 (elapsed time:  78s,  1.3mn)
done in 79.32s.
Out[8]:
(<matplotlib.figure.Figure at 0x1119e5128>,
 <matplotlib.axes._subplots.AxesSubplot at 0x11beb0320>)
In [9]:
for eta in np.logspace(-4, 0, int(15/(DEBUG_DOWNSCALE)**.3), base=10, endpoint=False):
    shl = SHL(database=database, DEBUG_DOWNSCALE=DEBUG_DOWNSCALE, eta=eta, verbose=verbose)
    dico = shl.learn_dico()
    shl.show_dico(dico, title='eta={}'.format(eta))

a quick diagnostic of what is wrong

The previous code relies on a heuristic to control how dictionary elements are chosen: we set the norm of every dictionary element to the inverse of an estimate of the mean variance of its coefficients, such that a high variance yields a lower norm and thus a lower coefficient, making that element less likely to be selected again.
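To make this explicit, here is a minimal sketch of such a gain rule (hypothetical function and variable names; the actual implementation lives in shl_scripts):

    import numpy as np

    # hypothetical sketch of the gain heuristic: atoms whose coefficients have a
    # high variance get their norm reduced, making them less likely to win the
    # next matching pursuit competition
    def homeostatic_gain(dictionary, code, gain_rate=0.01, eps=1e-12):
        variance = np.mean(code**2, axis=0)                      # per-atom variance of coefficients
        gain = (variance.mean() / (variance + eps))**gain_rate   # smooth correction toward the mean
        return dictionary * gain[:, np.newaxis]                  # rescale each atom's norm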

In [ ]:
data = shl.get_data()
In [ ]:
dico.set_params(transform_algorithm='omp', transform_n_nonzero_coefs=shl.n_components)
code = dico.transform(data)
In [ ]:
Z = np.mean(code**2)
fig = plt.figure(figsize=(12, 4))
ax = fig.add_subplot(111)
ax.bar(np.arange(shl.n_components), np.mean(code**2/Z, axis=0))#, yerr=np.std(code**2/Z, axis=0))
ax.set_title('Variance of coefficients')
ax.set_ylabel('Variance')
ax.set_xlabel('#')
ax.axis('tight')
fig.show()
In [ ]:
fig = plt.figure(figsize=(6, 4))
ax = fig.add_subplot(111)
data = pd.DataFrame(np.mean(code**2/Z, axis=0), columns=['Variance'])
with sns.axes_style("white"):
    ax = sns.distplot(data['Variance'],  kde_kws={'clip':(0., 5.)})
ax.set_title('distribution of the mean variance of coefficients')
ax.set_ylabel('pdf')
fig.show()

In particular, the elements with high variance are more likely features that were actually learned, while those with lower variance correspond to textures closer to the initial state of the dictionary. This is shown by ordering the filters in the dictionary (reading from the top line, left to right, then downwards):

In [ ]:
sorted_idx = np.argsort(np.mean(code**2/Z, axis=0))
dico.components_ = dico.components_[sorted_idx, :]
fig = shl.show_dico(dico)
fig.show()
In [ ]:
print(dico.components_.shape, data.shape)
In [ ]:
data = shl.get_data()
#dico.set_params(transform_algorithm='omp', transform_n_nonzero_coefs=shl.n_components)
#code = dico.transform(data)
coef = np.dot(dico.components_, data.T).T / np.sum(dico.components_**2, axis=1)[np.newaxis, :]  # linear coefficients, normalized by each atom's energy
#Z = np.mean(code**2)
fig = plt.figure(figsize=(12, 4))
ax = fig.add_subplot(111)
ax.bar(np.arange(shl.n_components), np.mean(coef**2, axis=0))#, yerr=np.std(code**2/Z, axis=0))
ax.set_title('Variance of linear coefficients')
ax.set_ylabel('Variance')
ax.set_xlabel('#')
ax.axis('tight')
fig.show()

By the scaling law of a linear filter, one may normalize the norm of the filters relative to the mean deviation (the square root of the variance) measured on the input. This operation normalizes the observed variance of the linear coefficients to the same value: this is what the heuristic implemented in the previous notebook achieves.
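Concretely, for the linear coefficients computed above, the normalization reads as follows (a sketch; the learning code applies the equivalent rescaling to the filters themselves):

    std = np.sqrt(np.mean(coef**2, axis=0))      # per-atom mean deviation (root of variance)
    coef_normalized = coef / std[np.newaxis, :]  # equalize the variance across atoms
    print(np.mean(coef_normalized**2, axis=0))   # now ~1 for every atom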

However, the learning rate (gain_rate) of this heuristic may be hard to tune. Most importantly, the heuristic assumes that all components are merely scaled (for instance, that the first elements to be learned have a higher variance while the shape of their distribution is unchanged), and thus that they all have the same type of distribution. This assumption misses an important aspect of this unsupervised learning scheme.

Indeed, when a feature emerges during learning, its response is not only scaled: its distribution also becomes more kurtotic. This may be understood using the central limit theorem: when mixing signals, as when computing the average of random variables, the distribution converges to a Gaussian, whereas a filter tuned to a specific feature responds selectively, and its coefficients become sparse (heavy-tailed).
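A quick numerical illustration of this effect (a sketch using scipy; the Laplacian stands in for a selective, sparse response, and the average of many such signals for an untuned, mixing filter):

    import numpy as np
    from scipy.stats import kurtosis

    rng = np.random.RandomState(42)
    # a "selective" response: mostly silent, occasionally strong (sparse, kurtotic)
    sparse_response = rng.laplace(size=20000)
    # a "mixing" response: average of many independent signals -> Gaussian-like (CLT)
    mixed_response = rng.laplace(size=(20000, 50)).mean(axis=1)
    print(kurtosis(sparse_response))  # ~3 (excess kurtosis of a Laplacian)
    print(kurtosis(mixed_response))   # ~0 (close to a Gaussian)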

Let's show the distribution of the two dictionary elements with respectively the highest (left) and lowest (right) variance:

In [ ]:
sorted_idx = np.argsort(np.mean(coef**2, axis=0))
# plot distribution of most kurtotic vs least
fig = plt.figure(figsize=(12, 4))
for i, idx in enumerate([0, -1]):
    #data = pd.DataFrame(code[:, sorted_idx[idx]], columns=['coefficient'])
    with sns.axes_style("white"):
        ax = fig.add_subplot(1, 2, i + 1)  # subplot indices are 1-based
        n_bins = 30
        n, bins = np.histogram((np.abs(coef[:, sorted_idx[idx]])), n_bins)# = 
        ax.bar(bins[:-1], np.log2(n), width=(np.abs(coef[:, sorted_idx[idx]].max()))/n_bins)
        #ax = sns.distplot(data['coefficient'])#,  kde_kws={'clip':(0., 5.)})
        ax.set_title('Variance = {}'.format(np.mean(coef[:, sorted_idx[idx]]**2)))
        ax.set_ylabel('log2-density')
        ax.set_xlabel('absolute coefficient')
fig.suptitle('distribution of the value of coefficients', fontsize=16)
        
fig.show()

To circumvent this problem, instead of changing the learning scheme, it is possible to change the coding procedure such that we are sure that every coefficient is selected a priori with the same probability. To perform that in a non-parametric fashion, one may use histogram equalization.

In [ ]:
n_samples, n_components = coef.shape
np.sort(np.abs(coef[:(n_samples//2), sorted_idx[idx]])).shape  # use integer division so the slice index is an int
#print(-np.sort(-np.abs(coef[:(n_samples//2), sorted_idx[idx]])))
#print(-np.sort(-np.abs(coef[:, sorted_idx[idx]])))
In [ ]:
# plot distribution of most kurtotic vs least
fig = plt.figure(figsize=(12, 4))
for i, idx in enumerate([0, -1]):
    with sns.axes_style("white"):
        ax = fig.add_subplot(1, 2, i + 1)
        ax.semilogx(-np.sort(-np.abs(coef[:(n_samples//2), sorted_idx[idx]])))
        ax.set_title('Variance = {}'.format(np.mean(coef[:, sorted_idx[idx]]**2)))
        ax.set_ylabel('value')
        ax.set_xlabel('rank')
        ax.axis('tight')
fig.suptitle('distribution of the value of linear coefficients', fontsize=16)
fig.show()
In [ ]:
# what we need is some sort of histogram normalization
def histeq(abs_code, mod):
    # use linear interpolation of cdf to find z-values
    # use the fact that the sorted coeffs give the inverse cdf
    # moreover we use a hack to ensure that np.interp uses an increasing sequence of x-coordinates
    z = np.interp(-abs_code, -mod, np.linspace(0, 1., mod.size, endpoint=True))
    return z

fig = plt.figure(figsize=(12, 4))
for i, idx in enumerate([0, -1]):
    mod = -np.sort(-np.abs(coef[:(n_samples//2), sorted_idx[idx]]))
    # learn the distribution of z-values on the first half of the data
    z = histeq(np.abs(coef[(n_samples//2):, sorted_idx[idx]]), mod)

    #z = histeq(np.abs(coef[:(n_samples//2), sorted_idx[idx]]), mod) - to check it is exact
    # plot the distribution of z-values on the second half of the data
    with sns.axes_style("white"):
        ax = fig.add_subplot(1, 2, i + 1)
        n_bins = 30
        n, bins = np.histogram(z, n_bins)# = 
        ax.bar(bins[:-1], n, width=.5/n_bins)
        #ax.scatter(np.linspace(0, 1., z.size, endpoint=True), z)
        #ax.set_title('Variance = {}'.format(np.mean(code[:, sorted_idx[idx]]**2/Z, axis=0)))
        ax.set_ylabel('density')
        ax.set_xlabel('z-score')
        ax.axis('tight')
        
fig.suptitle('distribution of the value of z-scores', fontsize=16)
        
fig.show()

To learn this modulation function, we may simply estimate it online, beginning with a flat one:

In [ ]:
mod = np.dot(np.linspace(1., 0, n_samples, endpoint=True)[:, np.newaxis], np.ones((1, n_components)))
print(mod.shape)
plt.plot(mod[:, 0])
In [ ]:
print(coef.shape)
print(np.sort(np.abs(coef), axis=0).shape)
mod_ = -np.sort(-np.abs(coef), axis=0)  # empirical (decreasing) inverse cdf of each atom's coefficients
gain_rate = .1
mod = (1 - gain_rate)*mod + gain_rate * mod_  # online (recursive) update of the modulation function
plt.semilogx(mod[:, 0])
In [ ]:
shl = SHL(database=database, DEBUG_DOWNSCALE=DEBUG_DOWNSCALE)
verbose = 0
dico = shl.learn_dico(learning_algorithm='comp', transform_n_nonzero_coefs=shl.n_components, gain_rate=0.001, verbose=verbose)
fig = shl.show_dico(dico)
fig.show()
In [ ]:
data = shl.get_data()
#dico.set_params(transform_algorithm='omp', transform_n_nonzero_coefs=shl.n_components)
#code = dico.transform(data)
coef = np.dot(dico.components_, data.T).T / np.sum(dico.components_**2, axis=1)[np.newaxis, :]
#Z = np.mean(code**2)
fig = plt.figure(figsize=(12, 4))
ax = fig.add_subplot(111)
ax.bar(np.arange(shl.n_components), np.mean(coef**2, axis=0))#, yerr=np.std(code**2/Z, axis=0))
ax.set_title('Variance of linear coefficients')
ax.set_ylabel('Variance')
ax.set_xlabel('#')
ax.axis('tight')
fig.show()
In [ ]:
# plot distribution of most kurtotic vs least
sorted_idx = np.argsort(np.mean(coef**2, axis=0))
fig = plt.figure(figsize=(12, 4))
for i, idx in enumerate([0, -1]):
    #data = pd.DataFrame(code[:, sorted_idx[idx]], columns=['coefficient'])
    with sns.axes_style("white"):
        ax = fig.add_subplot(1, 2, i + 1)
        n_bins = 30
        n, bins = np.histogram((np.abs(coef[:, sorted_idx[idx]])), n_bins)# = 
        ax.bar(bins[:-1], np.log2(n), width=(np.abs(coef[:, sorted_idx[idx]].max()))/n_bins)
        #ax = sns.distplot(data['coefficient'])#,  kde_kws={'clip':(0., 5.)})
        ax.set_title('Variance = {}'.format(np.mean(coef[:, sorted_idx[idx]]**2)))
        ax.set_ylabel('log2-density')
        ax.set_xlabel('absolute coefficient')
fig.suptitle('distribution of the value of coefficients', fontsize=16)
        
fig.show()
In [ ]:
# quick estimation of the z-scores
#z = np.interp(-np.abs(coef[(n_samples//2):, sorted_idx[idx]]), -mod, np.linspace(0, 1., mod.size, endpoint=True))
from shl_scripts import SHL
DEBUG_DOWNSCALE=1
for gain_rate in np.logspace(-4, 0, 15, base=10):
    #
    shl = SHL(database=database, DEBUG_DOWNSCALE=DEBUG_DOWNSCALE)
    dico = shl.learn_dico(learning_algorithm='comp', transform_n_nonzero_coefs=10, gain_rate=gain_rate)
    fig = shl.show_dico(dico)
    fig.show()
    data = shl.get_data()           
    dico.set_params(transform_algorithm='omp', transform_n_nonzero_coefs=shl.n_components)
    code = dico.transform(data)
    n_samples, n_components = code.shape  # code is (n_samples, n_components)
    Z = np.mean(code**2)
    fig = plt.figure(figsize=(12, 4))
    ax = fig.add_subplot(111)
    ax.bar(np.arange(shl.n_components), np.mean(code**2, axis=0)/Z)
    ax.set_title('Variance of coefficients - gain {}'.format(gain_rate))
    ax.set_ylabel('Variance')
    ax.set_xlabel('#')      
    ax.axis('tight')
    fig.show()
    fig = plt.figure(figsize=(6, 4))
    ax = fig.add_subplot(111)
    data = pd.DataFrame(np.mean(code**2, axis=0)/np.mean(code**2), columns=['Variance'])
    with sns.axes_style("white"):
        ax = sns.distplot(data['Variance'],  kde_kws={'clip':(0., 5.)})
    ax.set_title('distribution of the mean variance of coefficients')
    ax.set_ylabel('pdf')
    fig.show()

great, let's try that new version of the classical algorithm

contributing code to sklearn

As previously, let's contribute this bit of code to sklearn:

  • set-up variables

     cd ~/pool/libs/
     github_user='bicv'
     lib='scikit-learn'
     project='comp'
     git clone https://github.com/$github_user/$lib
     cd $lib
    
    
  • creating a new branch for this project

    git branch
    git remote add $github_user https://github.com/$github_user/$lib
    git checkout -b $project
    
    
  • installing the library in dev mode:

     pip uninstall $lib
     pip install -e .
    
    
  • making the one-line change to scikit-learn's code and testing it with the minimum working example above:

     mvim -p sklearn/decomposition/dict_learning.py sklearn/linear_model/omp.py
    
    
    

More details on MiniBatchDictionaryLearning:

In [ ]:
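For instance, a minimal way to inspect its parameters from the notebook (a sketch; any introspection tool works here):

    from sklearn.decomposition import MiniBatchDictionaryLearning
    help(MiniBatchDictionaryLearning)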