Extending Olshausen's classical SparseNet
This is an old blog post, see the newer version in this post
- In a previous notebook, we tried to reproduce the learning strategy specified in the framework of the SparseNet algorithm from Bruno Olshausen. It allows natural image patches to be coded efficiently by constraining the code to be sparse. In particular, we saw that in order to optimize competition, it is important to control cooperation, and we implemented a heuristic to do just this.
- In this notebook, we provide an extension to the SparseNet algorithm. We will study how homeostasis (cooperation) may be an essential ingredient for this algorithm, which works on a winner-take-all basis (competition). This extension has been published as Perrinet, Neural Computation (2010) (see https://laurentperrinet.github.io/publication/perrinet-10-shl ):
@article{Perrinet10shl,
Title = {Role of homeostasis in learning sparse representations},
Author = {Perrinet, Laurent U.},
Journal = {Neural Computation},
Year = {2010},
Doi = {10.1162/neco.2010.05-08-795},
Keywords = {Neural population coding, Unsupervised learning, Statistics of natural images, Simple cell receptive fields, Sparse Hebbian Learning, Adaptive Matching Pursuit, Cooperative Homeostasis, Competition-Optimized Matching Pursuit},
Month = {July},
Number = {7},
Url = {https://laurentperrinet.github.io/publication/perrinet-10-shl},
Volume = {22},
}
- find updates at https://laurentperrinet.github.io/publication/perrinet-19-hulk and https://github.com/bicv/SHL_scripts
dictionary learning on natural images¶
Here, we reproduce the dictionary learning obtained with the SparseNet algorithm, while using Orthogonal Matching Pursuit for the sparse coding step.
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format='svg'
import numpy as np
np.set_printoptions(precision=2, suppress=True)
import pandas as pd
import seaborn as sns
%load_ext autoreload
%autoreload 2
from shl_scripts import SHL
database = '/Users/lolo/pool/science/BICV/SHL_scripts/database/'
DEBUG_DOWNSCALE, verbose = 1, 100
DEBUG_DOWNSCALE, verbose = 1, 0
DEBUG_DOWNSCALE, verbose = 10, 100
shl = SHL(database=database, DEBUG_DOWNSCALE=DEBUG_DOWNSCALE, verbose=verbose)
dico = shl.learn_dico()
shl.show_dico(dico)
Extracting data...Mda_archi56.jpg, Mda_arnat105.jpg, Mda_art509.jpg, Fda_art464.jpg, Hda_obj84.jpg, Fdn_open12.jpg, Fdn_open19.jpg, Bda_art1164.jpg, Hdn_objn18.jpg, Mdn_text38.jpg, Fda_A223024.jpg, Hdn_objn98.jpg, Mdn_land159.jpg, Hda_obj172.jpg, Fda_art1017.jpg, Bda_art1271.jpg, Hdn_objn41.jpg, Fdn_nat297.jpg, Mdn_natu961.jpg, Mdn_sharp49.jpg, Fdn_sclos1.jpg, Mdn_N344070.jpg, Hda_obj96.jpg, Mda_gre329.jpg, Mdn_land243.jpg, Hda_obj12.jpg, Fda_art36.jpg, Fda_art1015.jpg, Bda_room43.jpg, Bda_int110.jpg, Mdn_N295008.jpg, Mdn_land313.jpg, Bda_art1168.jpg, Fda_art1249.jpg, Fda_A683045.jpg, Hdn_objn36.jpg, Fdn_bea1.jpg, Bda_art1117.jpg, Hdn_N124002.jpg, Bdn_land71.jpg, Hda_obj348.jpg, Mdn_N344038.jpg, Mdn_N347003.jpg, Mda_art337.jpg, Bdn_N44094.jpg, Mdn_nat1253.jpg, Hdn_objn113.jpg, Hdn_objn190.jpg, Bda_art1171.jpg, Hdn_objn122.jpg, Fdn_open6.jpg, Bdn_text64.jpg, Fdn_natu815.jpg, Fdn_land362.jpg, Bda_art1067.jpg, Hdn_for104.jpg, Fdn_nat462.jpg, Mda_archi100.jpg, Hda_obj414.jpg, Bdn_objn146.jpg, Fda_gre536.jpg, Mdn_N228004.jpg, Mda_enc30.jpg, Hda_obj390.jpg, Mdn_N344035.jpg, Bda_int30.jpg, Bdn_land861.jpg, Mdn_N328009.jpg, Fdn_land317.jpg, Hda_obj65.jpg, Mda_urb556.jpg, Bda_art685.jpg, Fdn_for7.jpg, Fdn_open18.jpg, Fdn_land83.jpg, Hdn_natu369.jpg, Mda_archi303.jpg, Fdn_nat498.jpg, Bdn_nat267.jpg, Fdn_bea33.jpg, Mdn_nat1247.jpg, Bda_int770.jpg, Mdn_N295019.jpg, Mda_urb382.jpg, Bda_art673.jpg, Fda_A698003.jpg, Mdn_natu35.jpg, Hdn_objn24.jpg, Hda_room419.jpg, Mda_art1665.jpg, Mdn_N347022.jpg, Fda_art1486.jpg, Mda_art438.jpg, Fda_A673086.jpg, Hda_int888.jpg, Fda_enc51.jpg, Bdn_natu676.jpg, Hda_obj343.jpg, Bdn_natu977.jpg, Hda_obj265.jpg, Fda_art772.jpg, Fda_street63.jpg, Bda_gre298.jpg, Mdn_N291096.jpg, Hdn_objn177.jpg, Bdn_text15.jpg, Hda_obj95.jpg, Fda_A462040.jpg, Bda_int741.jpg, Fdn_open17.jpg, Mda_urb276.jpg, Mdn_nat492.jpg, Hda_obj454.jpg, Fdn_nat895.jpg, Hdn_objn26.jpg, Mda_archi233.jpg, Mda_art1136.jpg, Bda_art1147.jpg, Hdn_N124064.jpg, Bda_room185.jpg, Hda_obj7.jpg, 
Bdn_land810.jpg, Mda_urb902.jpg, Fdn_bea5.jpg, Bda_urb291.jpg, Bda_art921.jpg, Mdn_natu302.jpg, Mdn_nat954.jpg, Fda_A683031.jpg, Mdn_for35.jpg, Fda_A277094.jpg, Fda_art1469.jpg, Bdn_text6.jpg, Fdn_open16.jpg, Bda_art923.jpg, Bda_art1175.jpg, Bdn_natu160.jpg, Mda_archi327.jpg, Fda_art779.jpg, Bda_art1185.jpg, Hda_obj94.jpg, Bdn_text19.jpg, Hda_obj89.jpg, Hda_obj2.jpg, Hda_int170.jpg, Bdn_for82.jpg, Bda_room452.jpg, Hda_obj98.jpg, Mdn_for83.jpg, Mda_arnat70.jpg, Fda_A463033.jpg, Fda_urb128.jpg, Mda_art1517.jpg, Bdn_natu419.jpg, Bdn_nat1157.jpg, Fda_A244041.jpg, Mdn_nat211.jpg, Hdn_objn23.jpg, Hdn_objn42.jpg, Fda_art560.jpg, Mda_art361.jpg, Mdn_N344067.jpg, Hdn_objn37.jpg, Hda_obj5.jpg, Mda_gre163.jpg, Bda_art142.jpg, Fda_art1420.jpg, Mdn_land503.jpg, Hda_obj47.jpg, Mda_art1680.jpg, Fdn_nat168.jpg, Fda_art1464.jpg, Fda_obj230.jpg, Fda_urb342.jpg, Mda_art1030.jpg, Fdn_open20.jpg, Hdn_objn200.jpg, Fdn_open5.jpg, Fda_A673062.jpg, Mda_gre199.jpg, Mdn_N228075.jpg, Hda_obj88.jpg, Mdn_N344026.jpg, Mda_art22.jpg, Hda_int168.jpg, Mda_par140.jpg, Hdn_objn129.jpg, Mda_art1312.jpg, Bdn_text21.jpg, Bdn_land759.jpg, Bdn_text111.jpg, Bda_room188.jpg, Hdn_objn22.jpg, Hdn_objn25.jpg, Hda_obj59.jpg, Fda_city14.jpg, Fda_art1529.jpg, Mda_hous31.jpg, Mdn_N266002.jpg, Mda_art600.jpg, Data is of shape : (20000, 144)done in 7.79s.Learning the dictionary... Training on 20000 patches... 
Iteration 0 / 500 (elapsed time: 0s, 0.0mn) Iteration 5 / 500 (elapsed time: 0s, 0.0mn) Iteration 10 / 500 (elapsed time: 1s, 0.0mn) Iteration 15 / 500 (elapsed time: 2s, 0.0mn) Iteration 20 / 500 (elapsed time: 3s, 0.1mn) Iteration 25 / 500 (elapsed time: 4s, 0.1mn) Iteration 30 / 500 (elapsed time: 4s, 0.1mn) Iteration 35 / 500 (elapsed time: 5s, 0.1mn) Iteration 40 / 500 (elapsed time: 6s, 0.1mn) Iteration 45 / 500 (elapsed time: 7s, 0.1mn) Iteration 50 / 500 (elapsed time: 8s, 0.1mn) Iteration 55 / 500 (elapsed time: 8s, 0.1mn) Iteration 60 / 500 (elapsed time: 9s, 0.2mn) Iteration 65 / 500 (elapsed time: 10s, 0.2mn) Iteration 70 / 500 (elapsed time: 11s, 0.2mn) Iteration 75 / 500 (elapsed time: 12s, 0.2mn) Iteration 80 / 500 (elapsed time: 12s, 0.2mn) Iteration 85 / 500 (elapsed time: 13s, 0.2mn) Iteration 90 / 500 (elapsed time: 14s, 0.2mn) Iteration 95 / 500 (elapsed time: 15s, 0.3mn) Iteration 100 / 500 (elapsed time: 16s, 0.3mn) Iteration 105 / 500 (elapsed time: 16s, 0.3mn) Iteration 110 / 500 (elapsed time: 17s, 0.3mn) Iteration 115 / 500 (elapsed time: 18s, 0.3mn) Iteration 120 / 500 (elapsed time: 19s, 0.3mn) Iteration 125 / 500 (elapsed time: 20s, 0.3mn) Iteration 130 / 500 (elapsed time: 21s, 0.4mn) Iteration 135 / 500 (elapsed time: 21s, 0.4mn) Iteration 140 / 500 (elapsed time: 22s, 0.4mn) Iteration 145 / 500 (elapsed time: 23s, 0.4mn) Iteration 150 / 500 (elapsed time: 24s, 0.4mn) Iteration 155 / 500 (elapsed time: 25s, 0.4mn) Iteration 160 / 500 (elapsed time: 25s, 0.4mn) Iteration 165 / 500 (elapsed time: 26s, 0.4mn) Iteration 170 / 500 (elapsed time: 27s, 0.5mn) Iteration 175 / 500 (elapsed time: 28s, 0.5mn) Iteration 180 / 500 (elapsed time: 28s, 0.5mn) Iteration 185 / 500 (elapsed time: 29s, 0.5mn) Iteration 190 / 500 (elapsed time: 30s, 0.5mn) Iteration 195 / 500 (elapsed time: 31s, 0.5mn) Iteration 200 / 500 (elapsed time: 32s, 0.5mn) Iteration 205 / 500 (elapsed time: 32s, 0.5mn) Iteration 210 / 500 (elapsed time: 33s, 0.6mn) Iteration 215 
/ 500 (elapsed time: 34s, 0.6mn) Iteration 220 / 500 (elapsed time: 35s, 0.6mn) Iteration 225 / 500 (elapsed time: 36s, 0.6mn) Iteration 230 / 500 (elapsed time: 36s, 0.6mn) Iteration 235 / 500 (elapsed time: 37s, 0.6mn) Iteration 240 / 500 (elapsed time: 38s, 0.6mn) Iteration 245 / 500 (elapsed time: 39s, 0.7mn) Iteration 250 / 500 (elapsed time: 39s, 0.7mn) Iteration 255 / 500 (elapsed time: 40s, 0.7mn) Iteration 260 / 500 (elapsed time: 41s, 0.7mn) Iteration 265 / 500 (elapsed time: 42s, 0.7mn) Iteration 270 / 500 (elapsed time: 43s, 0.7mn) Iteration 275 / 500 (elapsed time: 44s, 0.7mn) Iteration 280 / 500 (elapsed time: 44s, 0.7mn) Iteration 285 / 500 (elapsed time: 45s, 0.8mn) Iteration 290 / 500 (elapsed time: 46s, 0.8mn) Iteration 295 / 500 (elapsed time: 47s, 0.8mn) Iteration 300 / 500 (elapsed time: 48s, 0.8mn) Iteration 305 / 500 (elapsed time: 48s, 0.8mn) Iteration 310 / 500 (elapsed time: 49s, 0.8mn) Iteration 315 / 500 (elapsed time: 50s, 0.8mn) Iteration 320 / 500 (elapsed time: 51s, 0.9mn) Iteration 325 / 500 (elapsed time: 51s, 0.9mn) Iteration 330 / 500 (elapsed time: 52s, 0.9mn) Iteration 335 / 500 (elapsed time: 53s, 0.9mn) Iteration 340 / 500 (elapsed time: 54s, 0.9mn) Iteration 345 / 500 (elapsed time: 55s, 0.9mn) Iteration 350 / 500 (elapsed time: 55s, 0.9mn) Iteration 355 / 500 (elapsed time: 56s, 0.9mn) Iteration 360 / 500 (elapsed time: 57s, 1.0mn) Iteration 365 / 500 (elapsed time: 58s, 1.0mn) Iteration 370 / 500 (elapsed time: 58s, 1.0mn) Iteration 375 / 500 (elapsed time: 59s, 1.0mn) Iteration 380 / 500 (elapsed time: 60s, 1.0mn) Iteration 385 / 500 (elapsed time: 61s, 1.0mn) Iteration 390 / 500 (elapsed time: 62s, 1.0mn) Iteration 395 / 500 (elapsed time: 62s, 1.0mn) Iteration 400 / 500 (elapsed time: 63s, 1.1mn) Iteration 405 / 500 (elapsed time: 64s, 1.1mn) Iteration 410 / 500 (elapsed time: 65s, 1.1mn) Iteration 415 / 500 (elapsed time: 66s, 1.1mn) Iteration 420 / 500 (elapsed time: 66s, 1.1mn) Iteration 425 / 500 (elapsed time: 67s, 
1.1mn) Iteration 430 / 500 (elapsed time: 68s, 1.1mn) Iteration 435 / 500 (elapsed time: 69s, 1.2mn) Iteration 440 / 500 (elapsed time: 69s, 1.2mn) Iteration 445 / 500 (elapsed time: 70s, 1.2mn) Iteration 450 / 500 (elapsed time: 71s, 1.2mn) Iteration 455 / 500 (elapsed time: 72s, 1.2mn) Iteration 460 / 500 (elapsed time: 73s, 1.2mn) Iteration 465 / 500 (elapsed time: 73s, 1.2mn) Iteration 470 / 500 (elapsed time: 74s, 1.2mn) Iteration 475 / 500 (elapsed time: 75s, 1.3mn) Iteration 480 / 500 (elapsed time: 76s, 1.3mn) Iteration 485 / 500 (elapsed time: 76s, 1.3mn) Iteration 490 / 500 (elapsed time: 77s, 1.3mn) Iteration 495 / 500 (elapsed time: 78s, 1.3mn) done in 79.32s.
(<matplotlib.figure.Figure at 0x1119e5128>, <matplotlib.axes._subplots.AxesSubplot at 0x11beb0320>)
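As a reminder of what the sparse coding step does, here is a minimal NumPy sketch of Orthogonal Matching Pursuit. It is only an illustration under simple assumptions (unit-norm atoms stored as rows, a hypothetical orthonormal dictionary so that the result is easy to check), not the implementation used by shl_scripts or scikit-learn; the names `omp`, `dictionary` and `x` are introduced here for the example.

```python
import numpy as np

def omp(dictionary, x, n_nonzero):
    """Greedy OMP sketch; `dictionary` holds unit-norm atoms as rows."""
    support, residual = [], x.copy()
    for _ in range(n_nonzero):
        # competition: select the atom most correlated with the residual
        k = int(np.argmax(np.abs(dictionary @ residual)))
        if k not in support:
            support.append(k)
        # orthogonal step: least-squares fit of x on all selected atoms
        sol, *_ = np.linalg.lstsq(dictionary[support].T, x, rcond=None)
        residual = x - dictionary[support].T @ sol
    coef = np.zeros(dictionary.shape[0])
    coef[support] = sol
    return coef

rng = np.random.default_rng(0)
# hypothetical orthonormal dictionary (16 atoms of dimension 16)
dictionary = np.linalg.qr(rng.normal(size=(16, 16)))[0].T
# signal built from two atoms only
x = 2.0 * dictionary[3] - 1.0 * dictionary[7]
coef = omp(dictionary, x, n_nonzero=2)
print(np.flatnonzero(coef), coef[[3, 7]])
```

The `argmax` line is the winner-take-all competition discussed in this post: whichever atom has the largest coefficient wins, which is why the relative scaling of the atoms matters so much.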
for eta in np.logspace(-4, 0, int(15/(DEBUG_DOWNSCALE)**.3), base=10, endpoint=False):
    shl = SHL(DEBUG_DOWNSCALE=DEBUG_DOWNSCALE, eta=eta, verbose=verbose)
    dico = shl.learn_dico()
    shl.show_dico(dico, title='eta={}'.format(eta))
Extracting data...
--- Logging error --- Traceback (most recent call last): File "/Users/lolo/pool/science/BICV/SLIP/src/SLIP.py", line 201, in list_database filelist = os.listdir(self.full_url(name_database)) FileNotFoundError: [Errno 2] No such file or directory: 'database/serre07_distractors' During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/usr/local/Cellar/python3/3.5.1/Frameworks/Python.framework/Versions/3.5/lib/python3.5/logging/__init__.py", line 980, in emit msg = self.format(record) File "/usr/local/Cellar/python3/3.5.1/Frameworks/Python.framework/Versions/3.5/lib/python3.5/logging/__init__.py", line 830, in format return fmt.format(record) File "/usr/local/Cellar/python3/3.5.1/Frameworks/Python.framework/Versions/3.5/lib/python3.5/logging/__init__.py", line 567, in format record.message = record.getMessage() File "/usr/local/Cellar/python3/3.5.1/Frameworks/Python.framework/Versions/3.5/lib/python3.5/logging/__init__.py", line 330, in getMessage msg = msg % self.args TypeError: not all arguments converted during string formatting Call stack: File "/usr/local/Cellar/python3/3.5.1/Frameworks/Python.framework/Versions/3.5/lib/python3.5/runpy.py", line 170, in _run_module_as_main "__main__", mod_spec) File "/usr/local/Cellar/python3/3.5.1/Frameworks/Python.framework/Versions/3.5/lib/python3.5/runpy.py", line 85, in _run_code exec(code, run_globals) File "/usr/local/lib/python3.5/site-packages/ipykernel/__main__.py", line 3, in <module> app.launch_new_instance() File "/usr/local/lib/python3.5/site-packages/traitlets/config/application.py", line 592, in launch_instance app.start() File "/usr/local/lib/python3.5/site-packages/ipykernel/kernelapp.py", line 403, in start ioloop.IOLoop.instance().start() File "/usr/local/lib/python3.5/site-packages/zmq/eventloop/ioloop.py", line 151, in start super(ZMQIOLoop, self).start() File "/usr/local/lib/python3.5/site-packages/tornado/ioloop.py", line 866, in start 
handler_func(fd_obj, events) File "/usr/local/lib/python3.5/site-packages/tornado/stack_context.py", line 275, in null_wrapper return fn(*args, **kwargs) File "/usr/local/lib/python3.5/site-packages/zmq/eventloop/zmqstream.py", line 433, in _handle_events self._handle_recv() File "/usr/local/lib/python3.5/site-packages/zmq/eventloop/zmqstream.py", line 465, in _handle_recv self._run_callback(callback, msg) File "/usr/local/lib/python3.5/site-packages/zmq/eventloop/zmqstream.py", line 407, in _run_callback callback(*args, **kwargs) File "/usr/local/lib/python3.5/site-packages/tornado/stack_context.py", line 275, in null_wrapper return fn(*args, **kwargs) File "/usr/local/lib/python3.5/site-packages/ipykernel/kernelbase.py", line 260, in dispatcher return self.dispatch_shell(stream, msg) File "/usr/local/lib/python3.5/site-packages/ipykernel/kernelbase.py", line 212, in dispatch_shell handler(stream, idents, msg) File "/usr/local/lib/python3.5/site-packages/ipykernel/kernelbase.py", line 370, in execute_request user_expressions, allow_stdin) File "/usr/local/lib/python3.5/site-packages/ipykernel/ipkernel.py", line 175, in do_execute shell.run_cell(code, store_history=store_history, silent=silent) File "/usr/local/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 2902, in run_cell interactivity=interactivity, compiler=compiler, result=result) File "/usr/local/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 3006, in run_ast_nodes if self.run_code(code, result): File "/usr/local/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 3066, in run_code exec(code_obj, self.user_global_ns, self.user_ns) File "<ipython-input-9-99cafb0443b4>", line 3, in <module> dico = shl.learn_dico() File "/Users/lolo/pool/science/BICV/SHL_scripts/src/shl_scripts.py", line 147, in learn_dico data = self.get_data(name_database) File "/Users/lolo/pool/science/BICV/SHL_scripts/src/shl_scripts.py", line 117, in get_data imagelist = 
self.slip.make_imagelist(name_database=name_database)#, seed=seed) File "/Users/lolo/pool/science/BICV/SLIP/src/SLIP.py", line 249, in make_imagelist filelist = self.list_database(name_database) File "/Users/lolo/pool/science/BICV/SLIP/src/SLIP.py", line 206, in list_database self.log.error('XX failed opening database ', self.full_url(name_database)) Message: 'XX failed opening database ' Arguments: ('database/serre07_distractors',) --- Logging error --- Traceback (most recent call last): File "/Users/lolo/pool/science/BICV/SLIP/src/SLIP.py", line 201, in list_database filelist = os.listdir(self.full_url(name_database)) FileNotFoundError: [Errno 2] No such file or directory: 'database/serre07_distractors' During handling of the above exception, another exception occurred: Traceback (most recent call last): File "/usr/local/Cellar/python3/3.5.1/Frameworks/Python.framework/Versions/3.5/lib/python3.5/logging/__init__.py", line 980, in emit msg = self.format(record) File "/usr/local/Cellar/python3/3.5.1/Frameworks/Python.framework/Versions/3.5/lib/python3.5/logging/__init__.py", line 830, in format return fmt.format(record) File "/usr/local/Cellar/python3/3.5.1/Frameworks/Python.framework/Versions/3.5/lib/python3.5/logging/__init__.py", line 567, in format record.message = record.getMessage() File "/usr/local/Cellar/python3/3.5.1/Frameworks/Python.framework/Versions/3.5/lib/python3.5/logging/__init__.py", line 330, in getMessage msg = msg % self.args TypeError: not all arguments converted during string formatting Call stack: File "/usr/local/Cellar/python3/3.5.1/Frameworks/Python.framework/Versions/3.5/lib/python3.5/runpy.py", line 170, in _run_module_as_main "__main__", mod_spec) File "/usr/local/Cellar/python3/3.5.1/Frameworks/Python.framework/Versions/3.5/lib/python3.5/runpy.py", line 85, in _run_code exec(code, run_globals) File "/usr/local/lib/python3.5/site-packages/ipykernel/__main__.py", line 3, in <module> app.launch_new_instance() File 
"/usr/local/lib/python3.5/site-packages/traitlets/config/application.py", line 592, in launch_instance app.start() File "/usr/local/lib/python3.5/site-packages/ipykernel/kernelapp.py", line 403, in start ioloop.IOLoop.instance().start() File "/usr/local/lib/python3.5/site-packages/zmq/eventloop/ioloop.py", line 151, in start super(ZMQIOLoop, self).start() File "/usr/local/lib/python3.5/site-packages/tornado/ioloop.py", line 866, in start handler_func(fd_obj, events) File "/usr/local/lib/python3.5/site-packages/tornado/stack_context.py", line 275, in null_wrapper return fn(*args, **kwargs) File "/usr/local/lib/python3.5/site-packages/zmq/eventloop/zmqstream.py", line 433, in _handle_events self._handle_recv() File "/usr/local/lib/python3.5/site-packages/zmq/eventloop/zmqstream.py", line 465, in _handle_recv self._run_callback(callback, msg) File "/usr/local/lib/python3.5/site-packages/zmq/eventloop/zmqstream.py", line 407, in _run_callback callback(*args, **kwargs) File "/usr/local/lib/python3.5/site-packages/tornado/stack_context.py", line 275, in null_wrapper return fn(*args, **kwargs) File "/usr/local/lib/python3.5/site-packages/ipykernel/kernelbase.py", line 260, in dispatcher return self.dispatch_shell(stream, msg) File "/usr/local/lib/python3.5/site-packages/ipykernel/kernelbase.py", line 212, in dispatch_shell handler(stream, idents, msg) File "/usr/local/lib/python3.5/site-packages/ipykernel/kernelbase.py", line 370, in execute_request user_expressions, allow_stdin) File "/usr/local/lib/python3.5/site-packages/ipykernel/ipkernel.py", line 175, in do_execute shell.run_cell(code, store_history=store_history, silent=silent) File "/usr/local/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 2902, in run_cell interactivity=interactivity, compiler=compiler, result=result) File "/usr/local/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 3006, in run_ast_nodes if self.run_code(code, result): File 
"/usr/local/lib/python3.5/site-packages/IPython/core/interactiveshell.py", line 3066, in run_code exec(code_obj, self.user_global_ns, self.user_ns) File "<ipython-input-9-99cafb0443b4>", line 3, in <module> dico = shl.learn_dico() File "/Users/lolo/pool/science/BICV/SHL_scripts/src/shl_scripts.py", line 147, in learn_dico data = self.get_data(name_database) File "/Users/lolo/pool/science/BICV/SHL_scripts/src/shl_scripts.py", line 117, in get_data imagelist = self.slip.make_imagelist(name_database=name_database)#, seed=seed) File "/Users/lolo/pool/science/BICV/SLIP/src/SLIP.py", line 261, in make_imagelist image_, filename, croparea = self.patch(name_database, i_image=shuffling[i_image % N_image_db], verbose=verbose) File "/Users/lolo/pool/science/BICV/SLIP/src/SLIP.py", line 307, in patch image, filename = self.load_in_database(name_database, i_image=i_image, filename=filename, verbose=verbose) File "/Users/lolo/pool/science/BICV/SLIP/src/SLIP.py", line 224, in load_in_database filelist = self.list_database(name_database=name_database) File "/Users/lolo/pool/science/BICV/SLIP/src/SLIP.py", line 206, in list_database self.log.error('XX failed opening database ', self.full_url(name_database)) Message: 'XX failed opening database ' Arguments: ('database/serre07_distractors',) ERROR:SLIP:failed opening database/serre07_distractors/d
--------------------------------------------------------------------------- AttributeError Traceback (most recent call last) <ipython-input-9-99cafb0443b4> in <module>() 1 for eta in np.logspace(-4, 0, int(15/(DEBUG_DOWNSCALE)**.3), base=10, endpoint=False): 2 shl = SHL(DEBUG_DOWNSCALE=DEBUG_DOWNSCALE, eta=eta, verbose=verbose) ----> 3 dico = shl.learn_dico() 4 shl.show_dico(dico, title='eta={}'.format(eta)) /Users/lolo/pool/science/BICV/SHL_scripts/src/shl_scripts.py in learn_dico(self, name_database, **kwargs) 145 146 def learn_dico(self, name_database='serre07_distractors', **kwargs): --> 147 data = self.get_data(name_database) 148 # Learn the dictionary from reference patches 149 if self.verbose: print('Learning the dictionary...', end=' ') /Users/lolo/pool/science/BICV/SHL_scripts/src/shl_scripts.py in get_data(self, name_database, seed, patch_norm) 115 sys.stdout.write("\b" * (toolbar_width+1)) # return to start of line, after '[' 116 t0 = time.time() --> 117 imagelist = self.slip.make_imagelist(name_database=name_database)#, seed=seed) 118 for filename, croparea in imagelist: 119 # whitening /Users/lolo/pool/science/BICV/SLIP/src/SLIP.py in make_imagelist(self, name_database, verbose) 259 imagelist = [] 260 for i_image in range(N_image): --> 261 image_, filename, croparea = self.patch(name_database, i_image=shuffling[i_image % N_image_db], verbose=verbose) 262 imagelist.append([filename, croparea]) 263 /Users/lolo/pool/science/BICV/SLIP/src/SLIP.py in patch(self, name_database, i_image, filename, croparea, threshold, verbose, preprocess, center, use_max) 309 310 if (croparea is None): --> 311 image_size_h, image_size_v = image.shape 312 if self.N_X > image_size_h or self.N_Y > image_size_v: 313 print('N_X patch_h patch_v ', self.N_X, image_size_h, image_size_v) AttributeError: 'str' object has no attribute 'shape'
a quick diagnostic of what is wrong¶
One assumption in the previous code lies in the heuristic used to control how elements are chosen. Basically, we set the norm of every dictionary element to the inverse of an estimate of its mean variance, such that a high variance leads to a lower norm and a lower corresponding coefficient: the element thus becomes less likely to be selected again.
data = shl.get_data()
dico.set_params(transform_algorithm='omp', transform_n_nonzero_coefs=shl.n_components)
code = dico.transform(data)
Z = np.mean(code**2)
fig = plt.figure(figsize=(12, 4))
ax = fig.add_subplot(111)
ax.bar(np.arange(shl.n_components), np.mean(code**2/Z, axis=0))#, yerr=np.std(code**2/Z, axis=0))
ax.set_title('Variance of coefficients')
ax.set_ylabel('Variance')
ax.set_xlabel('#')
ax.axis('tight')
fig.show()
fig = plt.figure(figsize=(6, 4))
ax = fig.add_subplot(111)
data = pd.DataFrame(np.mean(code**2/Z, axis=0), columns=['Variance'])
with sns.axes_style("white"):
    ax = sns.distplot(data['Variance'], kde_kws={'clip':(0., 5.)})
ax.set_title('distribution of the mean variance of coefficients')
ax.set_ylabel('pdf')
fig.show()
In particular, elements with high variance are more likely to be features that were learned more thoroughly, while those with lower variance correspond to textures closer to the initial state of the dictionary. This is shown by ordering the filters in the dictionary by increasing variance (from left to right along the top line, then line by line down to the bottom):
sorted_idx = np.argsort(np.mean(code**2/Z, axis=0))
dico.components_ = dico.components_[sorted_idx, :]
fig = shl.show_dico(dico)
fig.show()
print(dico.components_.shape, data.shape)
data = shl.get_data()
#dico.set_params(transform_algorithm='omp', transform_n_nonzero_coefs=shl.n_components)
#code = dico.transform(data)
coef = np.dot(dico.components_, data.T).T / np.sum(dico.components_**2, axis=1)[np.newaxis, :]
#Z = np.mean(code**2)
fig = plt.figure(figsize=(12, 4))
ax = fig.add_subplot(111)
ax.bar(np.arange(shl.n_components), np.mean(coef**2, axis=0))#, yerr=np.std(code**2/Z, axis=0))
ax.set_title('Variance of linear coefficients')
ax.set_ylabel('Variance')
ax.set_xlabel('#')
ax.axis('tight')
fig.show()
By the scaling law of a linear filter, one may normalize the norm of the filters relative to the standard deviation (square root of the variance) measured on the input. After this operation, the observed variance of the linear coefficients is normalized to the same value for all filters. This is what the heuristic implemented in the previous notebook achieves.
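The scaling law can be checked numerically. The sketch below uses hypothetical random data and filters (not the learned SHL dictionary) and takes the linear coefficient to be the plain dot product between a filter and a patch; dividing each filter by the standard deviation of its response brings every variance to the same value.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features = 5000, 16

# hypothetical input patches and four filters with very different norms
data = rng.normal(size=(n_samples, n_features))
norms = np.array([0.1, 1.0, 3.0, 10.0])
filters = rng.normal(size=(norms.size, n_features)) * norms[:, np.newaxis]

# linear coefficients and the standard deviation of each filter's response
coef = data @ filters.T
std = coef.std(axis=0)

# scaling law: dividing a filter by a factor divides its response by the same factor
filters_normalized = filters / std[:, np.newaxis]
coef_normalized = data @ filters_normalized.T

print(coef.var(axis=0))             # variances before normalization
print(coef_normalized.var(axis=0))  # variances after normalization
```

Note that this only equalizes the second moment: the shape of each distribution is left untouched.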
However, the learning rate (gain_rate) of this heuristic may be hard to tune. Most importantly, this heuristic assumes that components differ only by their scale (for instance, that the first component to be learned will have a higher variance and that its whole distribution will simply be scaled), and thus assumes that they all share the same distribution up to that scale. This assumption misses an important aspect of this unsupervised learning scheme.
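Indeed, when a feature emerges during learning, its response is scaled, but its distribution also becomes more kurtotic. This may be understood using the central limit theorem: when mixing signals, such as when computing the average of random variables, the distribution converges to a Gaussian. A filter that is still close to its initial state mixes many image features and thus has a near-Gaussian response, whereas a filter that has learned a specific feature responds rarely but strongly.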
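This argument can be illustrated on synthetic data. In the sketch below, the sources are hypothetical independent Laplacian variables (a common stand-in for sparse causes, not something measured on the image database), and `excess_kurtosis` is a helper defined here for the example. A response that isolates one source keeps its heavy tails, while a response that averages many sources is close to Gaussian.

```python
import numpy as np

def excess_kurtosis(x):
    """Fourth standardized moment minus 3 (zero for a Gaussian)."""
    x = x - x.mean()
    return np.mean(x**4) / np.mean(x**2)**2 - 3.0

rng = np.random.default_rng(0)
n_samples, n_sources = 100_000, 32
# hypothetical sparse sources; a Laplacian has a theoretical excess kurtosis of 3
sources = rng.laplace(size=(n_samples, n_sources))

selective = sources[:, 0]       # "learned" filter: responds to a single source
mixture = sources.mean(axis=1)  # "unlearned" filter: mixes all sources

print(excess_kurtosis(selective))  # theoretical value: 3
print(excess_kurtosis(mixture))    # theoretical value: 3 / n_sources
```

Rescaling either response changes its variance but not its kurtosis, which is why a gain on the norm alone cannot make the two distributions comparable.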
Let's show the distribution of the two dictionary elements with respectively the highest (left) and lowest (right) variance:
sorted_idx = np.argsort(np.mean(coef**2, axis=0))
# plot distribution of most kurtotic vs least
fig = plt.figure(figsize=(12, 4))
# sorted_idx is in ascending order: -1 is the highest variance (left panel), 0 the lowest (right panel)
for i, idx in enumerate([-1, 0]):
    #data = pd.DataFrame(code[:, sorted_idx[idx]], columns=['coefficient'])
    with sns.axes_style("white"):
        ax = fig.add_subplot(1, 2, i + 1)  # subplot indices start at 1
        n_bins = 30
        n, bins = np.histogram((np.abs(coef[:, sorted_idx[idx]])), n_bins)# =
        ax.bar(bins[:-1], np.log2(n), width=(np.abs(coef[:, sorted_idx[idx]].max()))/n_bins)
        #ax = sns.distplot(data['coefficient'])#, kde_kws={'clip':(0., 5.)})
        ax.set_title('Variance = {}'.format(np.mean(coef[:, sorted_idx[idx]]**2)))
        ax.set_ylabel('log2-density')
        ax.set_xlabel('absolute coefficient')
fig.suptitle('distribution of the value of coefficients', fontsize=16)
fig.show()
To circumvent this problem, instead of changing the learning scheme, it is possible to change the coding procedure such that we are sure that every coefficient is selected a priori with the same probability. To perform that in a non-parametric fashion, one may use histogram equalization.
n_samples, n_components = coef.shape
# integer division: slice bounds must be integers
np.sort(np.abs(coef[:(n_samples//2), sorted_idx[idx]])).shape
#print(-np.sort(-np.abs(coef[:(n_samples//2), sorted_idx[idx]])))
#print(-np.sort(-np.abs(coef[:, sorted_idx[idx]])))
# plot distribution of most kurtotic vs least
fig = plt.figure(figsize=(12, 4))
for i, idx in enumerate([-1, 0]):
    with sns.axes_style("white"):
        ax = fig.add_subplot(1, 2, i + 1)  # subplot indices start at 1
        ax.semilogx(-np.sort(-np.abs(coef[:(n_samples//2), sorted_idx[idx]])))
        ax.set_title('Variance = {}'.format(np.mean(coef[:, sorted_idx[idx]]**2)))
        ax.set_ylabel('value')
        ax.set_xlabel('rank')
        ax.axis('tight')
fig.suptitle('distribution of the value of linear coefficients', fontsize=16)
fig.show()
# what we need is some sort of histogram normalization
def histeq(abs_code, mod):
    # use linear interpolation of cdf to find z-values
    # use the fact that the sorted coeffs give the inverse cdf
    # moreover we use a hack to ensure that np.interp uses an increasing sequence of x-coordinates
    z = np.interp(-abs_code, -mod, np.linspace(0, 1., mod.size, endpoint=True))
    return z
fig = plt.figure(figsize=(12, 4))
for i, idx in enumerate([-1, 0]):
    # learn the modulation function (sorted absolute coefficients) on the first half of the data
    mod = -np.sort(-np.abs(coef[:(n_samples//2), sorted_idx[idx]]))
    # compute the z-values on the second half of the data
    z = histeq(np.abs(coef[(n_samples//2):, sorted_idx[idx]]), mod)
    #z = histeq(np.abs(coef[:(n_samples//2), sorted_idx[idx]]), mod) - to check it is exact
    # plot distribution of z-values on the second half of the data
    with sns.axes_style("white"):
        ax = fig.add_subplot(1, 2, i + 1)  # subplot indices start at 1
        n_bins = 30
        n, bins = np.histogram(z, n_bins)# =
        ax.bar(bins[:-1], n, width=.5/n_bins)
        #ax.scatter(np.linspace(0, 1., z.size, endpoint=True), z)
        #ax.set_title('Variance = {}'.format(np.mean(code[:, sorted_idx[idx]]**2/Z, axis=0)))
        ax.set_ylabel('density')
        ax.set_xlabel('z-score')
        ax.axis('tight')
fig.suptitle('distribution of the value of z-scores', fontsize=16)
fig.show()
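To see why equalization balances the competition, here is a small synthetic check. It is a sketch under stated assumptions: the two components are hypothetical (one heavy-tailed with a small scale, one Gaussian with a large scale), and `rank_equalize` is a helper written for this example that maps values through the empirical CDF in increasing order (the largest coefficient maps to 1), which is the mirror image of the convention used in histeq above.

```python
import numpy as np

def rank_equalize(abs_coef, reference):
    """Map absolute coefficients to [0, 1] through the empirical CDF of `reference`."""
    ref_sorted = np.sort(reference)
    cdf = np.linspace(0., 1., ref_sorted.size, endpoint=True)
    return np.interp(abs_coef, ref_sorted, cdf)

rng = np.random.default_rng(0)
n_samples = 10_000
half = n_samples // 2
# two hypothetical components with different scales and different shapes
coef = np.stack([0.1 * rng.laplace(size=n_samples),
                 rng.normal(size=n_samples)], axis=1)
abs_coef = np.abs(coef)

# learn the CDF on the first half, equalize the second half
z = np.stack([rank_equalize(abs_coef[half:, k], abs_coef[:half, k])
              for k in range(2)], axis=1)

# winner-take-all on raw values vs. on equalized values
raw_winner = abs_coef[half:].argmax(axis=1)
eq_winner = z.argmax(axis=1)
print(np.bincount(raw_winner, minlength=2) / half)  # selection frequencies, raw
print(np.bincount(eq_winner, minlength=2) / half)   # selection frequencies, equalized
```

With raw values the large-scale component wins most of the time; after equalization both components are selected with roughly the same probability, whatever the shape of their distributions.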
To learn this modulation function, we may just estimate it online beginning with a flat one:
mod = np.dot(np.linspace(1., 0, n_samples, endpoint=True)[:, np.newaxis], np.ones((1, n_components)))
print(mod.shape)
plt.plot(mod[:, 0])
print(coef.shape)
print(np.sort(np.abs(coef), axis=0).shape)
mod_ = -np.sort(-np.abs(coef), axis=0)
gain_rate = .1
mod = (1 - gain_rate)*mod + gain_rate * mod_
plt.semilogx(mod[:, 0])
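The single update above can be repeated over successive batches. The following sketch assumes hypothetical Laplacian coefficients with a different scale for each component and simply iterates the same running average; it is meant to show that the estimate forgets the flat initialization and converges to the sorted absolute coefficients of each component, not to reproduce the learning loop of shl_scripts.

```python
import numpy as np

rng = np.random.default_rng(0)
batch_size, gain_rate = 1000, 0.1
scales = np.array([0.5, 1.0, 2.0])  # hypothetical per-component scales
n_components = scales.size

# flat initial modulation function, decreasing linearly from 1 to 0
mod = np.linspace(1., 0., batch_size)[:, np.newaxis] * np.ones((1, n_components))

for _ in range(100):
    # hypothetical batch of coefficients
    coef_batch = rng.laplace(size=(batch_size, n_components)) * scales
    # sorted absolute values (descending) give the inverse CDF of this batch
    mod_batch = -np.sort(-np.abs(coef_batch), axis=0)
    # running average with learning rate gain_rate
    mod = (1 - gain_rate) * mod + gain_rate * mod_batch

# the middle of each curve estimates the median of |coef| for that component
print(np.median(mod, axis=0) / scales)
```

A small gain_rate gives a smoother but slower estimate; since each batch is sorted before averaging, the modulation function stays monotonic at every step.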
shl = SHL(DEBUG_DOWNSCALE=DEBUG_DOWNSCALE)
verbose = 0
dico = shl.learn_dico(learning_algorithm='comp', transform_n_nonzero_coefs=shl.n_components, gain_rate=0.001, verbose=verbose)
fig = shl.show_dico(dico)
fig.show()
data = shl.get_data()
#dico.set_params(transform_algorithm='omp', transform_n_nonzero_coefs=shl.n_components)
#code = dico.transform(data)
coef = np.dot(dico.components_, data.T).T / np.sum(dico.components_**2, axis=1)[np.newaxis, :]
#Z = np.mean(code**2)
fig = plt.figure(figsize=(12, 4))
ax = fig.add_subplot(111)
ax.bar(np.arange(shl.n_components), np.mean(coef**2, axis=0))#, yerr=np.std(code**2/Z, axis=0))
ax.set_title('Variance of linear coefficients')
ax.set_ylabel('Variance')
ax.set_xlabel('#')
ax.axis('tight')
fig.show()
# plot distribution of most kurtotic vs least
sorted_idx = np.argsort(np.mean(coef**2, axis=0))
fig = plt.figure(figsize=(12, 4))
# sorted_idx is in ascending order: -1 is the highest variance (left panel), 0 the lowest (right panel)
for i, idx in enumerate([-1, 0]):
    #data = pd.DataFrame(code[:, sorted_idx[idx]], columns=['coefficient'])
    with sns.axes_style("white"):
        ax = fig.add_subplot(1, 2, i + 1)  # subplot indices start at 1
        n_bins = 30
        n, bins = np.histogram((np.abs(coef[:, sorted_idx[idx]])), n_bins)# =
        ax.bar(bins[:-1], np.log2(n), width=(np.abs(coef[:, sorted_idx[idx]].max()))/n_bins)
        #ax = sns.distplot(data['coefficient'])#, kde_kws={'clip':(0., 5.)})
        ax.set_title('Variance = {}'.format(np.mean(coef[:, sorted_idx[idx]]**2)))
        ax.set_ylabel('log2-density')
        ax.set_xlabel('absolute coefficient')
fig.suptitle('distribution of the value of coefficients', fontsize=16)
fig.show()
# quick estimation of the z-scores
#z = np.interp(-np.abs(coef[(n_samples/2):, sorted_idx[idx]]), -mod, np.linspace(0, 1., mod.size, endpoint=True))
from shl_scripts import SHL
DEBUG_DOWNSCALE=1
for gain_rate in np.logspace(-4, 0, 15, base=10):
    #
    shl = SHL(DEBUG_DOWNSCALE=DEBUG_DOWNSCALE)
    dico = shl.learn_dico(learning_algorithm='comp', transform_n_nonzero_coefs=10, gain_rate=gain_rate)
    fig = shl.show_dico(dico)
    fig.show()
    data = shl.get_data()
    dico.set_params(transform_algorithm='omp', transform_n_nonzero_coefs=shl.n_components)
    code = dico.transform(data)
    # code has one row per patch and one column per dictionary element
    n_samples, n_components = code.shape
    Z = np.mean(code**2)
    fig = plt.figure(figsize=(12, 4))
    ax = fig.add_subplot(111)
    ax.bar(np.arange(shl.n_components), np.mean(code**2, axis=0)/Z)
    ax.set_title('Variance of coefficients - gain {}'.format(gain_rate))
    ax.set_ylabel('Variance')
    ax.set_xlabel('#')
    ax.axis('tight')
    fig.show()
    fig = plt.figure(figsize=(6, 4))
    ax = fig.add_subplot(111)
    data = pd.DataFrame(np.mean(code**2, axis=0)/np.mean(code**2), columns=['Variance'])
    with sns.axes_style("white"):
        ax = sns.distplot(data['Variance'], kde_kws={'clip':(0., 5.)})
    ax.set_title('distribution of the mean variance of coefficients')
    ax.set_ylabel('pdf')
    fig.show()
great, let's try that new version of the classical algorithm¶
contributing code to sklearn¶
As previously, let's contribute this bit of code to sklearn:
- set-up variables
    cd ~/pool/libs/
    github_user='bicv'
    lib='scikit-learn'
    project='comp'
    git clone https://github.com/$github_user/$lib
    cd $lib
- creating a new branch for this project
    git branch
    git remote add $github_user https://github.com/$github_user/$lib
    git checkout -b $project
- installing the library in dev mode:
    pip uninstall $lib
    pip install -e .
- making the change to sklearn's code and testing it with the minimum working example above:
    mvim -p sklearn/decomposition/dict_learning.py sklearn/linear_model/omp.py
More details on MiniBatchDictionaryLearning: