# Fitting COVID data

I propose here a simple method to fit the kind of experimental data that is typical of epidemic spread, such as the present COVID-19 pandemic, using the inverse Gaussian distribution. This follows the general incomprehension that met my answer to the question "Is the COVID-19 pandemic curve a Gaussian curve?" on StackOverflow. My initial point was that a Gaussian is ill-suited because it is a distribution over all real numbers, whereas such a curve (the variable being a number of days) lives on the half-line. Inspired by the excellent "A Theory of Reaction Time Distributions" by Dr Fermin Moscoso del Prado Martin, a more constructive approach is to propose another distribution, such as the inverse Gaussian distribution.

This notebook develops this same idea on real data and proves numerically how bad the Gaussian fit is compared to the latter. Thinking about the importance of doing a proper inference in such a case, I conclude

In this notebook, I define a fitting method using PyTorch which fits the data to the inverse Gaussian distribution:

$$f ( x ; \mu , \lambda ) = \sqrt {\frac {\lambda }{2\pi x^3}} \exp (-\frac {\lambda (x-\mu )^{2}}{2\mu ^{2}x} )$$
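As a quick sanity check of the formula above, here is a minimal NumPy sketch of the density. The function name `inverse_gaussian_pdf` and the values `mu=80`, `lam=100` are illustrative assumptions, not fitted results; the check simply verifies numerically that the density integrates to about 1 and has mean close to $\mu$.

```python
import numpy as np

def inverse_gaussian_pdf(x, mu, lam):
    """Inverse Gaussian density f(x; mu, lambda), set to 0 for x <= 0."""
    x = np.asarray(x, dtype=float)
    pdf = np.zeros_like(x)
    positive = x > 0                      # the density lives on the half-line
    xp = x[positive]
    pdf[positive] = np.sqrt(lam / (2 * np.pi * xp**3)) * np.exp(
        -lam * (xp - mu) ** 2 / (2 * mu**2 * xp)
    )
    return pdf

# illustrative parameter values only (assumption, not fitted to any data)
x = np.linspace(0.0, 2000.0, 200_001)
dx = x[1] - x[0]
pdf = inverse_gaussian_pdf(x, mu=80.0, lam=100.0)
total_mass = pdf.sum() * dx               # Riemann sum of the density
mean = (x * pdf).sum() * dx               # Riemann sum of x * density
print(total_mass, mean)
```

Unlike a Gaussian, this density is exactly zero for negative arguments, which is the property motivating its use here.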

In [33]:
from __future__ import division, print_function
import numpy as np
np.set_printoptions(precision=6, suppress=True)
import os
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import matplotlib.pyplot as plt
phi = (np.sqrt(5)+1)/2
fig_width = 10
figsize = (fig_width, fig_width/phi)



## getting real COVID data

### retrieve data

We use here data openly provided by the French government at https://github.com/opencovid19-fr/data, through the following URL:

In [34]:
URL, data_cache = 'https://raw.githubusercontent.com/opencovid19-fr/data/master/dist/chiffres-cles.json', '/tmp/covid_fr.json'

In [35]:
import pandas as pd


In [36]:
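try:
    df = pd.read_json(data_cache)  # body lost in the export; reconstructed as an assumption
    print('loading cache')
except Exception:
    df = pd.read_json(URL)  # assumed fallback: download the data, then cache it
    df.to_json(data_cache)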

loading cache

In [37]:
df

Out[37]:
0 2020-01-01 {'nom': 'Ministère des Solidarités et de la Sa... ministere-sante 44985.0 19780.0 24296.0 194901.0 France FRA NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
1 2020-01-07 {'nom': 'Ministère des Solidarités et de la Sa... ministere-sante 46539.0 20302.0 24521.0 200079.0 France FRA 2727321.0 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2 2020-01-08 {'nom': 'Ministère des Solidarités et de la Sa... ministere-sante 46815.0 NaN NaN NaN France FRA NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
3 2020-01-16 {'nom': 'Ministère des Solidarités et de la Sa... ministere-sante 48783.0 21359.0 25019.0 209056.0 France FRA 2894347.0 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
4 2020-01-17 {'nom': 'Ministère des Solidarités et de la Sa... ministere-sante 48924.0 21359.0 25269.0 209343.0 France FRA 2910989.0 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
65049 2021-08-12 {'nom': 'OpenCOVID19-fr'} opencovid19-fr 3932.0 NaN 616.0 17960.0 Nouvelle-Aquitaine REG-75 NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
65050 2021-08-12 {'nom': 'OpenCOVID19-fr'} opencovid19-fr 4720.0 NaN 1116.0 23109.0 Occitanie REG-76 NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
65051 2021-08-12 {'nom': 'OpenCOVID19-fr'} opencovid19-fr 11763.0 NaN 967.0 52545.0 Auvergne-Rhône-Alpes REG-84 NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
65052 2021-08-12 {'nom': 'OpenCOVID19-fr'} opencovid19-fr 8195.0 NaN 1422.0 42408.0 Provence-Alpes-Côte d'Azur REG-93 NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
65053 2021-08-12 {'nom': 'OpenCOVID19-fr'} opencovid19-fr 217.0 NaN 85.0 1051.0 Corse REG-94 NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN

65054 rows × 31 columns

Let's convert the date column into a proper date:

In [38]:
df['date'] = pd.to_datetime(df['date'])
df

Out[38]:
0 2020-01-01 {'nom': 'Ministère des Solidarités et de la Sa... ministere-sante 44985.0 19780.0 24296.0 194901.0 France FRA NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
1 2020-01-07 {'nom': 'Ministère des Solidarités et de la Sa... ministere-sante 46539.0 20302.0 24521.0 200079.0 France FRA 2727321.0 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2 2020-01-08 {'nom': 'Ministère des Solidarités et de la Sa... ministere-sante 46815.0 NaN NaN NaN France FRA NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
3 2020-01-16 {'nom': 'Ministère des Solidarités et de la Sa... ministere-sante 48783.0 21359.0 25019.0 209056.0 France FRA 2894347.0 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
4 2020-01-17 {'nom': 'Ministère des Solidarités et de la Sa... ministere-sante 48924.0 21359.0 25269.0 209343.0 France FRA 2910989.0 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
65049 2021-08-12 {'nom': 'OpenCOVID19-fr'} opencovid19-fr 3932.0 NaN 616.0 17960.0 Nouvelle-Aquitaine REG-75 NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
65050 2021-08-12 {'nom': 'OpenCOVID19-fr'} opencovid19-fr 4720.0 NaN 1116.0 23109.0 Occitanie REG-76 NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
65051 2021-08-12 {'nom': 'OpenCOVID19-fr'} opencovid19-fr 11763.0 NaN 967.0 52545.0 Auvergne-Rhône-Alpes REG-84 NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
65052 2021-08-12 {'nom': 'OpenCOVID19-fr'} opencovid19-fr 8195.0 NaN 1422.0 42408.0 Provence-Alpes-Côte d'Azur REG-93 NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
65053 2021-08-12 {'nom': 'OpenCOVID19-fr'} opencovid19-fr 217.0 NaN 85.0 1051.0 Corse REG-94 NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN

65054 rows × 31 columns

... and extract the different data codes:

In [39]:
df['code'].unique()

Out[39]:
array(['FRA', 'DEP-16', 'DEP-17', 'DEP-19', 'DEP-23', 'DEP-24', 'DEP-33',
'DEP-40', 'DEP-47', 'DEP-64', 'DEP-79', 'DEP-86', 'DEP-87',
'REG-11', 'REG-75', 'WORLD', 'DEP-34', 'DEP-74', 'REG-84',
'REG-27', 'DEP-02', 'DEP-25', 'DEP-59', 'DEP-60', 'DEP-62',
'DEP-80', 'DEP-90', 'REG-32', 'REG-44', 'DEP-21', 'DEP-29',
'DEP-44', 'DEP-67', 'DEP-06', 'DEP-49', 'DEP-53', 'DEP-76',
'REG-01', 'REG-02', 'REG-03', 'REG-04', 'REG-06', 'REG-24',
'REG-28', 'REG-52', 'REG-53', 'REG-76', 'REG-93', 'REG-94',
'DEP-35', 'COM-977', 'COM-978', 'DEP-56', 'DEP-72', 'DEP-01',
'DEP-08', 'DEP-10', 'DEP-27', 'DEP-51', 'DEP-52', 'DEP-54',
'DEP-55', 'DEP-57', 'DEP-68', 'DEP-69', 'DEP-88', 'DEP-03',
'DEP-07', 'DEP-15', 'DEP-26', 'DEP-30', 'DEP-38', 'DEP-42',
'DEP-43', 'DEP-63', 'DEP-71', 'DEP-73', 'DEP-12', 'DEP-13',
'DEP-22', 'DEP-28', 'DEP-37', 'DEP-70', 'DEP-84', 'DEP-973',
'DEP-05', 'DEP-14', 'DEP-18', 'DEP-2A', 'DEP-2B', 'DEP-31',
'DEP-36', 'DEP-41', 'DEP-45', 'DEP-50', 'DEP-75', 'DEP-77',
'DEP-78', 'DEP-83', 'DEP-91', 'DEP-92', 'DEP-93', 'DEP-94',
'DEP-95', 'DEP-972', 'DEP-39', 'DEP-46', 'DEP-81', 'DEP-82',
'DEP-85', 'DEP-89', 'DEP-11', 'DEP-58', 'DEP-61', 'DEP-04',
'DEP-32', 'COM-987', 'COM-974', 'DEP-65', 'DEP-971', 'DEP-974',
'DEP-66', 'DEP-48', 'DEP-976', 'DEP-09', 'COM-988', 'COM-986'],
dtype=object)

### filtering data

Selecting one region:

In [40]:
code = 'FRA'
code = 'REG-93'
df_loc = df[df['code']==code]

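Optionally, the rows can also be restricted to one source type (this filter was not applied to produce the outputs shown here, which still include `opencovid19-fr` rows):

sourceType = "ministere-sante"
sourceType = 'agences-regionales-sante'
df_loc = df_loc[df_loc['sourceType']==sourceType]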

Selecting the two columns of interest:

In [41]:
df_loc = df_loc[['date', 'deces']]
df_loc

Out[41]:
date deces
219 2020-02-28 NaN
297 2020-03-02 NaN
352 2020-03-03 NaN
439 2020-03-04 0.0
440 2020-03-04 NaN
... ... ...
64572 2021-08-08 8145.0
64692 2021-08-09 8158.0
64812 2021-08-10 8168.0
64932 2021-08-11 8181.0
65052 2021-08-12 8195.0

556 rows × 2 columns

Removing the rows containing NaNs:

In [42]:
df_loc = df_loc[df_loc['deces'].notna()]


Selecting a date range:

In [43]:
start_date = '2020-02-23'
df_loc['date'] > start_date

Out[43]:
439      True
537      True
663      True
770      True
874      True
...
64572    True
64692    True
64812    True
64932    True
65052    True
Name: date, Length: 533, dtype: bool
In [44]:
df_loc = df_loc[start_date < df_loc['date']]

In [45]:
stop_date = '2020-06-23'
df_loc = df_loc[df_loc['date']<stop_date]

In [46]:
fig, ax = plt.subplots(figsize=figsize)
df_loc.plot(x='date', y='deces', ax=ax)
ax.set_title('Cumulative number of deaths');


#### returning a numpy array

In [47]:
death = np.array(df_loc['deces'])
df_loc['date'], death

Out[47]:
(439     2020-03-04
537     2020-03-05
663     2020-03-06
770     2020-03-07
874     2020-03-08
...
14405   2020-06-18
14526   2020-06-19
14647   2020-06-20
14768   2020-06-21
14889   2020-06-22
Name: date, Length: 117, dtype: datetime64[ns],
array([  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   4.,
7.,   7.,   9.,   9.,  11.,  11.,  13.,  13.,  15.,  14.,  20.,
20.,  26.,  26.,  33.,  33.,  44.,  44.,  48.,  55.,  65.,  80.,
103., 124., 141., 161., 178., 195., 231., 253., 269., 286., 299.,
317., 328., 348., 374., 393., 421., 447., 463., 470., 499., 526.,
561., 575., 610., 625., 631., 656., 665., 682., 696., 704., 708.,
713., 732., 746., 756., 764., 770., 773., 777., 797., 805., 816.,
824., 831., 836., 840., 849., 855., 862., 869., 874., 876., 878.,
884., 889., 890., 893., 898., 898., 898., 898., 908., 911., 912.,
913., 913., 913., 917., 919., 920., 922., 923., 923., 923., 925.,
927., 927., 929., 933., 935., 935., 936.]))
In [48]:
death.shape, np.diff(death).shape

Out[48]:
((117,), (116,))
In [49]:
death[1:] = np.diff(death)
death[0] = 0

In [50]:
death

Out[50]:
array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  4.,  3.,  0.,
2.,  0.,  2.,  0.,  2.,  0.,  2., -1.,  6.,  0.,  6.,  0.,  7.,
0., 11.,  0.,  4.,  7., 10., 15., 23., 21., 17., 20., 17., 17.,
36., 22., 16., 17., 13., 18., 11., 20., 26., 19., 28., 26., 16.,
7., 29., 27., 35., 14., 35., 15.,  6., 25.,  9., 17., 14.,  8.,
4.,  5., 19., 14., 10.,  8.,  6.,  3.,  4., 20.,  8., 11.,  8.,
7.,  5.,  4.,  9.,  6.,  7.,  7.,  5.,  2.,  2.,  6.,  5.,  1.,
3.,  5.,  0.,  0.,  0., 10.,  3.,  1.,  1.,  0.,  0.,  4.,  2.,
1.,  2.,  1.,  0.,  0.,  2.,  2.,  0.,  2.,  4.,  2.,  0.,  1.])
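The two-step conversion from cumulative counts to daily counts used above can also be written in one call. This is a minimal sketch on a toy array (the values are illustrative, chosen to include a decrease like the one visible in the data); it relies on the `prepend` argument of `np.diff`.

```python
import numpy as np

# toy cumulative counts (illustrative values, including one downward correction)
cumulative = np.array([0., 0., 4., 7., 7., 9., 8., 14.])

# one-call version: prepend the first value so that the first difference is 0
daily = np.diff(cumulative, prepend=cumulative[0])

# the two-step version used in the notebook, for comparison
two_step = cumulative.copy()
two_step[1:] = np.diff(cumulative)
two_step[0] = 0

print(daily)
```

A negative daily value, as in this toy example, appears whenever the cumulative count decreases from one row to the next.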
In [51]:
df_loc['death'] = death
df_loc

Out[51]:
date deces death
439 2020-03-04 0.0 0.0
537 2020-03-05 0.0 0.0
663 2020-03-06 0.0 0.0
770 2020-03-07 0.0 0.0
874 2020-03-08 0.0 0.0
... ... ... ...
14405 2020-06-18 929.0 2.0
14526 2020-06-19 933.0 4.0
14647 2020-06-20 935.0 2.0
14768 2020-06-21 935.0 0.0
14889 2020-06-22 936.0 1.0

117 rows × 3 columns

In [52]:
fig, ax = plt.subplots(figsize=figsize)
df_loc.plot(x='date', y='death', ax=ax)
ax.set_title('Daily rate of deaths');


### performing the fit

#### converting data to fit

In [53]:
X, y = df_loc['date'], df_loc['death']

In [54]:
X.shape, y.shape

Out[54]:
((117,), (117,))
In [55]:
X = X.astype(int)
X = np.array(X)
X

Out[55]:
array([1583280000000000000, 1583366400000000000, 1583452800000000000,
1583539200000000000, 1583625600000000000, 1583712000000000000,
1583798400000000000, 1583884800000000000, 1583971200000000000,
1584057600000000000, 1584403200000000000, 1584489600000000000,
1584489600000000000, 1584576000000000000, 1584576000000000000,
1584662400000000000, 1584662400000000000, 1584748800000000000,
1584748800000000000, 1584835200000000000, 1584835200000000000,
1584921600000000000, 1584921600000000000, 1585008000000000000,
1585008000000000000, 1585094400000000000, 1585094400000000000,
1585180800000000000, 1585180800000000000, 1585267200000000000,
1585353600000000000, 1585440000000000000, 1585526400000000000,
1585612800000000000, 1585699200000000000, 1585785600000000000,
1585872000000000000, 1585958400000000000, 1586044800000000000,
1586131200000000000, 1586217600000000000, 1586304000000000000,
1586390400000000000, 1586476800000000000, 1586563200000000000,
1586649600000000000, 1586736000000000000, 1586822400000000000,
1586908800000000000, 1586995200000000000, 1587081600000000000,
1587168000000000000, 1587254400000000000, 1587340800000000000,
1587427200000000000, 1587513600000000000, 1587600000000000000,
1587686400000000000, 1587772800000000000, 1587859200000000000,
1587945600000000000, 1588032000000000000, 1588118400000000000,
1588204800000000000, 1588291200000000000, 1588377600000000000,
1588464000000000000, 1588550400000000000, 1588636800000000000,
1588723200000000000, 1588809600000000000, 1588896000000000000,
1588982400000000000, 1589068800000000000, 1589155200000000000,
1589241600000000000, 1589328000000000000, 1589414400000000000,
1589500800000000000, 1589587200000000000, 1589673600000000000,
1589760000000000000, 1589846400000000000, 1589932800000000000,
1590019200000000000, 1590105600000000000, 1590192000000000000,
1590278400000000000, 1590364800000000000, 1590451200000000000,
1590537600000000000, 1590624000000000000, 1590710400000000000,
1590796800000000000, 1590883200000000000, 1590969600000000000,
1591056000000000000, 1591142400000000000, 1591228800000000000,
1591315200000000000, 1591401600000000000, 1591488000000000000,
1591574400000000000, 1591660800000000000, 1591747200000000000,
1591833600000000000, 1591920000000000000, 1592006400000000000,
1592092800000000000, 1592179200000000000, 1592265600000000000,
1592352000000000000, 1592438400000000000, 1592524800000000000,
1592611200000000000, 1592697600000000000, 1592784000000000000])
In [56]:
INTS_PER_DAY = 86400000000000

In [57]:
X = (X-X[0])//INTS_PER_DAY

In [58]:
np.array(X)
X

Out[58]:
array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  13,  14,  14,
15,  15,  16,  16,  17,  17,  18,  18,  19,  19,  20,  20,  21,
21,  22,  22,  23,  24,  25,  26,  27,  28,  29,  30,  31,  32,
33,  34,  35,  36,  37,  38,  39,  40,  41,  42,  43,  44,  45,
46,  47,  48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,
59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  70,  71,
72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,  84,
85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110])
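The conversion above goes through nanosecond integers and a hand-written `INTS_PER_DAY` constant. As a hedged alternative sketch (not the notebook's code), the same day indices can be obtained by dividing date differences by a one-day `timedelta64`; the dates below are a toy sample.

```python
import numpy as np

# toy sample of dates with a gap (illustrative values)
dates = np.array(
    ['2020-03-04', '2020-03-05', '2020-03-13', '2020-03-14'],
    dtype='datetime64[D]',
)

# integer number of days elapsed since the first date
days = (dates - dates[0]) // np.timedelta64(1, 'D')
print(days)
```

This avoids depending on the internal integer resolution of the timestamps.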
In [59]:
y = np.array(y)
y

Out[59]:
array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  4.,  3.,  0.,
2.,  0.,  2.,  0.,  2.,  0.,  2., -1.,  6.,  0.,  6.,  0.,  7.,
0., 11.,  0.,  4.,  7., 10., 15., 23., 21., 17., 20., 17., 17.,
36., 22., 16., 17., 13., 18., 11., 20., 26., 19., 28., 26., 16.,
7., 29., 27., 35., 14., 35., 15.,  6., 25.,  9., 17., 14.,  8.,
4.,  5., 19., 14., 10.,  8.,  6.,  3.,  4., 20.,  8., 11.,  8.,
7.,  5.,  4.,  9.,  6.,  7.,  7.,  5.,  2.,  2.,  6.,  5.,  1.,
3.,  5.,  0.,  0.,  0., 10.,  3.,  1.,  1.,  0.,  0.,  4.,  2.,
1.,  2.,  1.,  0.,  0.,  2.,  2.,  0.,  2.,  4.,  2.,  0.,  1.])
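The day indices printed above contain repeated values (14, 14, 15, 15, ...), which suggests that several rows report the same date; this is also where the alternating zeros in the daily counts come from. The notebook keeps these rows as they are. As a minimal sketch of one alternative (an assumption, not the method used here), repeated days could be collapsed to a single cumulative value before differencing. The arrays below copy the first repeated entries from the outputs above, and taking the maximum per day is only one possible choice.

```python
import numpy as np

# first repeated entries copied from the outputs above (days must be sorted)
days = np.array([13, 14, 14, 15, 15, 16])
cumulative = np.array([4., 7., 7., 9., 9., 11.])

# index of the first row of each distinct day
unique_days, first_idx = np.unique(days, return_index=True)

# maximum cumulative count within each run of identical days
cum_per_day = np.maximum.reduceat(cumulative, first_idx)

# daily counts, one value per distinct day
daily = np.diff(cum_per_day, prepend=cum_per_day[0])
print(unique_days, cum_per_day, daily)
```

Note that this sketch does not fill in missing days, so a difference may still span more than one day when dates are absent from the data.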

#### using torch

or "when you have a new hammer, everything looks like a nail", see https://en.wikipedia.org/wiki/Law_of_the_instrument

Let's use the definition of the pdf:

$$f ( x ; \mu , \lambda ) = \sqrt {\frac {\lambda }{2\pi x^3}} \exp (-\frac {\lambda (x-\mu )^{2}}{2\mu ^{2}x} )$$

and a previous post where I used PyTorch to fit psychophysical data:

In [60]:
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.set_default_tensor_type("torch.DoubleTensor")
criterion = torch.nn.MSELoss(reduction="sum")
# https://pytorch.org/docs/stable/generated/torch.nn.L1Loss.html#torch.nn.L1Loss
#criterion = torch.nn.L1Loss(reduction="sum")

class CovidRegressionModel(torch.nn.Module):
    def __init__(self, mu=80, tau=24,
                 log_amp=torch.log(500*torch.ones(1)),
                 log_lambda=torch.log(100*torch.ones(1)),
                 ):
        super(CovidRegressionModel, self).__init__()
        # tau: day of onset; mu: mean of the distribution, counted from the onset
        self.tau = torch.nn.Parameter(tau * torch.ones(1))
        self.mu = torch.nn.Parameter(mu * torch.ones(1))
        # when modeling a strictly positive number, a good habit is to use its log as the fitted parameter:
        self.log_amp = torch.nn.Parameter(log_amp * torch.ones(1))
        self.log_lambda = torch.nn.Parameter(log_lambda * torch.ones(1))

    def forward(self, x):
        out = (x > self.tau) * torch.exp(self.log_amp)
        date = (x-self.tau) # time elapsed since the onset
        date[x<=self.tau] = 1 # to avoid NaNs in the output value (and in the gradient)
        out *= torch.sqrt(torch.exp(self.log_lambda) / (2 * np.pi * date ** 3))
        out *= torch.exp(-torch.exp(self.log_lambda) * (date - self.mu) ** 2 / (2 * self.mu ** 2 * date))
        out[x<=self.tau] = 0 # overwriting values before the onset
        return out

learning_rate = 0.005
beta1, beta2 = 0.9, 0.999
betas = (beta1, beta2)
num_epochs = 2 ** 9 + 1
batch_size = 16
amsgrad = False # gives similar results
amsgrad = True  # gives similar results

def fit_data(
    X,
    y,
    learning_rate=learning_rate,
    batch_size=batch_size,  # gamma=gamma,
    num_epochs=num_epochs,
    betas=betas,
    verbose=False, **kwargs
):

    variables, labels = torch.Tensor(X[:, None]), torch.Tensor(y[:, None])
    # NOTE: the opening of this call was lost in the export; reconstructed as a DataLoader
    loader = DataLoader(
        TensorDataset(variables, labels), batch_size=batch_size, shuffle=True
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    covid_model = CovidRegressionModel()
    covid_model = covid_model.to(device)
    covid_model.train()
    # NOTE: the optimizer definition was lost in the export; Adam is an assumption,
    # suggested by the `betas` and `amsgrad` settings defined above
    optimizer = torch.optim.Adam(
        covid_model.parameters(), lr=learning_rate, betas=betas, amsgrad=amsgrad
    )
    for epoch in range(int(num_epochs)):
        covid_model.train()
        losses = []
        # NOTE: the mini-batch loop header and zero_grad() call are reconstructed
        for variables_, labels_ in loader:
            optimizer.zero_grad()
            variables_, labels_ = variables_.to(device), labels_.to(device)
            outputs = covid_model(variables_)
            loss = criterion(outputs, labels_)

            loss.backward()
            optimizer.step()
            losses.append(loss.item())

        if verbose and (epoch % (num_epochs // 32) == 0):
            print(f"Iteration: {epoch} - Loss: {np.sum(losses)/len(variables):.5f}")

    # evaluate (and return) the model on the CPU, so that plain CPU tensors can be fed to it
    covid_model = covid_model.cpu()
    covid_model.eval()
    variables, labels = torch.Tensor(X[:, None]), torch.Tensor(y[:, None])
    outputs = covid_model(variables)
    loss = criterion(outputs, labels).item() / len(variables)
    return covid_model, loss
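Written out as equations, the `forward` method above evaluates a scaled inverse Gaussian that is shifted by the onset $\tau$, with $A = \exp(\texttt{log\_amp})$ and $\lambda = \exp(\texttt{log\_lambda})$. This is only a transcription of the code, using the same symbols as the density above:

$$\hat{y}(x) = \begin{cases} A \sqrt{\frac{\lambda}{2\pi (x-\tau)^3}} \exp \left( -\frac{\lambda (x-\tau-\mu)^{2}}{2\mu^{2}(x-\tau)} \right) & \text{if } x > \tau \\ 0 & \text{if } x \le \tau \end{cases}$$

The criterion minimized over mini-batches is the summed squared error, and the final loss that is reported is this sum divided by the number $N$ of data points:

$$\mathcal{L} = \frac{1}{N} \sum_{i=1}^{N} \left( \hat{y}(x_i) - y_i \right)^2$$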

In [61]:
covid_model, loss = fit_data(X, y, verbose=True)
print("Final loss =", loss)

Iteration: 0 - Loss: 95.69920
Iteration: 16 - Loss: 54.13756
Iteration: 32 - Loss: 33.30596
Iteration: 48 - Loss: 31.36121
Iteration: 64 - Loss: 30.85001
Iteration: 80 - Loss: 30.38980
Iteration: 96 - Loss: 30.06311
Iteration: 112 - Loss: 29.77887
Iteration: 128 - Loss: 29.59526
Iteration: 144 - Loss: 29.40957
Iteration: 160 - Loss: 29.16254
Iteration: 176 - Loss: 29.03642
Iteration: 192 - Loss: 28.90962
Iteration: 208 - Loss: 28.83253
Iteration: 224 - Loss: 28.75095
Iteration: 240 - Loss: 28.68247
Iteration: 256 - Loss: 28.60991
Iteration: 272 - Loss: 28.55558
Iteration: 288 - Loss: 28.55271
Iteration: 304 - Loss: 28.45341
Iteration: 320 - Loss: 28.41483
Iteration: 336 - Loss: 28.38179
Iteration: 352 - Loss: 28.37117
Iteration: 368 - Loss: 28.33166
Iteration: 384 - Loss: 28.32052
Iteration: 400 - Loss: 28.29289
Iteration: 416 - Loss: 28.27676
Iteration: 432 - Loss: 28.27867
Iteration: 448 - Loss: 28.24910
Iteration: 464 - Loss: 28.24694
Iteration: 480 - Loss: 28.20544
Iteration: 496 - Loss: 28.19613
Iteration: 512 - Loss: 28.18755
Final loss = 28.163296599808184

In [62]:
covid_model.eval()
outputs = covid_model(torch.Tensor(X[:, None]))
y_pred = outputs.detach().numpy()

In [63]:
fig, ax = plt.subplots(figsize=figsize)
ax.plot(X, y, '*', label='data')
ax.plot(X, y_pred, '--', label='fit')
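# mu is counted from the onset tau in forward(), so the mean date on this axis is tau + mu
ax.vlines(covid_model.tau.item() + covid_model.mu.item(), y.min(), y.max(), colors='g', linestyles='--', label=fr'$\tau+\mu$={covid_model.tau.item() + covid_model.mu.item():.1f} (mean date)')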
ax.vlines(covid_model.tau.item(), y.min(), y.max(), colors='r', linestyles='--', label=fr'$\tau$={covid_model.tau.item():.1f} (day of onset)')
ax.plot([], [], label=f'A={torch.exp(covid_model.log_amp).item():.1f} (amplitude)')
ax.plot([], [], label=fr'$\lambda$={torch.exp(covid_model.log_lambda).item():.1f} (spread in days)')
ax.legend()
# X counts days from the first retained row, not from start_date
ax.set_xlabel(f"number of days since {df_loc['date'].iloc[0].date()}");
ax.set_ylabel('number of deaths per day');
ax.set_title('Fitting Daily rate of deaths (PACA-FR)');
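To read the fitted parameters in terms of calendar days, the standard moments of the inverse Gaussian can be shifted by the onset. The sketch below is an illustration under stated assumptions: the helper name `shifted_ig_summary` is hypothetical, and the parameter values are the model's initial guesses, not the fitted values. It uses the known mean $\mu$, variance $\mu^3/\lambda$ and mode $\mu \left[ \sqrt{1 + 9\mu^2/(4\lambda^2)} - 3\mu/(2\lambda) \right]$ of the distribution.

```python
import numpy as np

def shifted_ig_summary(tau, mu, lam):
    """Summaries, in days, of an inverse Gaussian shifted by an onset tau."""
    mean_day = tau + mu                          # mean, shifted by the onset
    std_days = np.sqrt(mu**3 / lam)              # standard deviation
    mode_delay = mu * (np.sqrt(1 + 9 * mu**2 / (4 * lam**2)) - 3 * mu / (2 * lam))
    return {"mean_day": mean_day, "std_days": std_days, "peak_day": tau + mode_delay}

# hypothetical values (the initial guesses of the model), not fitted results
summary = shifted_ig_summary(tau=24.0, mu=80.0, lam=100.0)
print(summary)
```

Because the distribution is skewed, the peak of the curve falls well before its mean, so the two should not be confused when reading the plot.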


### some book keeping for the notebook

In [64]:
%load_ext watermark
%watermark -i -h -m -v -p numpy,pandas,matplotlib,torch  -r -g -b

Python implementation: CPython
Python version       : 3.11.6
IPython version      : 8.17.2

numpy     : 1.26.2
pandas    : 2.1.2
matplotlib: 3.8.1
torch     : 2.2.0.dev20231107

Compiler    : Clang 15.0.0 (clang-1500.0.40.1)
OS          : Darwin
Release     : 23.2.0
Machine     : arm64
Processor   : arm
CPU cores   : 10
Architecture: 64bit

Hostname: obiwan.local

Git hash: 688f040143fa7b3a5258594843cf6e13498821db

Git repo: https://github.com/laurentperrinet/sciblog

Git branch: master