Fitting COVID data

I propose here a simple method to fit experimental data typical of epidemic spreads, such as the present COVID-19 pandemic, using the inverse Gaussian distribution. This follows the general incomprehension that met my answer to the question Is the COVID-19 pandemic curve a Gaussian curve? on StackOverflow. My initial point was that a Gaussian is not adapted: it defines a distribution over all real numbers, while such a curve (the variable being a number of days) is only defined on the positive half line. Inspired by the excellent A Theory of Reaction Time Distributions by Dr Fermin Moscoso del Prado Martin, a constructive approach is to propose another distribution, such as the inverse Gaussian distribution.
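To make the support argument concrete, here is a minimal sketch using scipy.stats (an assumption: scipy is not otherwise used in this notebook). It compares a Gaussian and an inverse Gaussian sharing the same mean and standard deviation; the values of 20 and 12 days are hypothetical, chosen only for illustration. The Gaussian puts some probability on negative days, while the inverse Gaussian puts none.

```python
import numpy as np
from scipy import stats

# Hypothetical epidemic curve: mean 20 days after onset, standard deviation 12 days.
mean_day, std_day = 20.0, 12.0

# A Gaussian with these moments assigns probability to days x <= 0...
gauss_mass_negative = stats.norm.cdf(0.0, loc=mean_day, scale=std_day)

# ...whereas an inverse Gaussian with the same mean and variance cannot:
# its support is the half line x > 0.
# For IG(mu, lambda): mean = mu and variance = mu**3 / lambda.
lam = mean_day**3 / std_day**2
ig = stats.invgauss(mean_day / lam, scale=lam)  # scipy's (shape, scale) parametrization
ig_mass_negative = ig.cdf(0.0)

print(f"Gaussian mass on x <= 0: {gauss_mass_negative:.3f}")
print(f"Inverse Gaussian mass on x <= 0: {ig_mass_negative:.3f}")
print(f"Inverse Gaussian mean and std: {ig.mean():.1f}, {ig.std():.1f}")
```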

This notebook develops the same idea on real data and shows numerically how poorly the Gaussian fits compared to the inverse Gaussian. Thinking about the importance of doing a proper inference in such a case, I conclude

In this notebook, I define a fitting method using PyTorch which fits data to the inverse Gaussian distribution:

$$ f ( x ; \mu , \lambda ) = \sqrt {\frac {\lambda }{2\pi x^3}} \exp (-\frac {\lambda (x-\mu )^{2}}{2\mu ^{2}x} ) $$

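Here x > 0, μ > 0 is the mean and λ > 0 is the shape parameter (the variance is μ³/λ). To learn more about the properties and implementation of this pdf, see the scipy.stats documentation (scipy.stats.invgauss).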
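As a quick sanity check of the formula above, here is a sketch of the pdf in NumPy, compared with scipy.stats.invgauss. The mapping invgauss(mu / lam, scale=lam) between scipy's standardized parametrization and the (μ, λ) form used here is derived from the pdf given in the scipy documentation; the parameter values are hypothetical.

```python
import numpy as np
from scipy import stats

def inverse_gaussian_pdf(x, mu, lam):
    """Inverse Gaussian pdf f(x; mu, lambda) as written above, valid for x > 0."""
    x = np.asarray(x, dtype=float)
    return np.sqrt(lam / (2 * np.pi * x**3)) * np.exp(-lam * (x - mu)**2 / (2 * mu**2 * x))

mu, lam = 30.0, 100.0  # hypothetical values, in days
x = np.linspace(0.5, 200, 400)

# scipy uses a standardized shape parameter; the (mu, lambda) form above
# corresponds to invgauss(mu / lam, scale=lam).
ref = stats.invgauss.pdf(x, mu / lam, scale=lam)
print(np.allclose(inverse_gaussian_pdf(x, mu, lam), ref))
```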

In [33]:
from __future__ import division, print_function
import numpy as np
np.set_printoptions(precision=6, suppress=True)
import os
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import matplotlib.pyplot as plt
phi = (np.sqrt(5)+1)/2
fig_width = 10
figsize = (fig_width, fig_width/phi)

getting real COVID data

retrieve data

We will use data provided openly by the French government, available at the following URL:

In [34]:
URL, data_cache = '', '/tmp/covid_fr.json'
In [35]:
import pandas as pd

Let's cache the data locally to avoid downloading it again and again:

In [36]:
if os.path.exists(data_cache):
    df = pd.read_json(data_cache)
    print('loading cache')
else:
    df = pd.read_json(URL)
    df.to_json(data_cache)  # save a local copy for the next runs
    print('loading from internet')
loading cache
In [37]:
df
date source sourceType deces decesEhpad hospitalises gueris nom code casConfirmes ... capaciteLitsSoinsIntensifs hospitalisation hospitalise hospitalisesReadaptation hospitalisesAuxUrgences hospitalisesConventionnelle reanimations capaciteReanimation casEhpad casPossiblesEhpad
0 2020-01-01 {'nom': 'Ministère des Solidarités et de la Sa... ministere-sante 44985.0 19780.0 24296.0 194901.0 France FRA NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
1 2020-01-07 {'nom': 'Ministère des Solidarités et de la Sa... ministere-sante 46539.0 20302.0 24521.0 200079.0 France FRA 2727321.0 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
2 2020-01-08 {'nom': 'Ministère des Solidarités et de la Sa... ministere-sante 46815.0 NaN NaN NaN France FRA NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
3 2020-01-16 {'nom': 'Ministère des Solidarités et de la Sa... ministere-sante 48783.0 21359.0 25019.0 209056.0 France FRA 2894347.0 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
4 2020-01-17 {'nom': 'Ministère des Solidarités et de la Sa... ministere-sante 48924.0 21359.0 25269.0 209343.0 France FRA 2910989.0 ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
65049 2021-08-12 {'nom': 'OpenCOVID19-fr'} opencovid19-fr 3932.0 NaN 616.0 17960.0 Nouvelle-Aquitaine REG-75 NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
65050 2021-08-12 {'nom': 'OpenCOVID19-fr'} opencovid19-fr 4720.0 NaN 1116.0 23109.0 Occitanie REG-76 NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
65051 2021-08-12 {'nom': 'OpenCOVID19-fr'} opencovid19-fr 11763.0 NaN 967.0 52545.0 Auvergne-Rhône-Alpes REG-84 NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
65052 2021-08-12 {'nom': 'OpenCOVID19-fr'} opencovid19-fr 8195.0 NaN 1422.0 42408.0 Provence-Alpes-Côte d'Azur REG-93 NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN
65053 2021-08-12 {'nom': 'OpenCOVID19-fr'} opencovid19-fr 217.0 NaN 85.0 1051.0 Corse REG-94 NaN ... NaN NaN NaN NaN NaN NaN NaN NaN NaN NaN

65054 rows × 31 columns

Let's convert the date column into a proper datetime type:

In [38]:
df['date'] = pd.to_datetime(df['date'])

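... and list the different geographical codes present in the data: the whole country (FRA), regions (REG-…), départements (DEP-…), overseas territories (COM-…) and a WORLD entry: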

In [39]:
df['code'].unique()
array(['FRA', 'DEP-16', 'DEP-17', 'DEP-19', 'DEP-23', 'DEP-24', 'DEP-33',
       'DEP-40', 'DEP-47', 'DEP-64', 'DEP-79', 'DEP-86', 'DEP-87',
       'REG-11', 'REG-75', 'WORLD', 'DEP-34', 'DEP-74', 'REG-84',
       'REG-27', 'DEP-02', 'DEP-25', 'DEP-59', 'DEP-60', 'DEP-62',
       'DEP-80', 'DEP-90', 'REG-32', 'REG-44', 'DEP-21', 'DEP-29',
       'DEP-44', 'DEP-67', 'DEP-06', 'DEP-49', 'DEP-53', 'DEP-76',
       'REG-01', 'REG-02', 'REG-03', 'REG-04', 'REG-06', 'REG-24',
       'REG-28', 'REG-52', 'REG-53', 'REG-76', 'REG-93', 'REG-94',
       'DEP-35', 'COM-977', 'COM-978', 'DEP-56', 'DEP-72', 'DEP-01',
       'DEP-08', 'DEP-10', 'DEP-27', 'DEP-51', 'DEP-52', 'DEP-54',
       'DEP-55', 'DEP-57', 'DEP-68', 'DEP-69', 'DEP-88', 'DEP-03',
       'DEP-07', 'DEP-15', 'DEP-26', 'DEP-30', 'DEP-38', 'DEP-42',
       'DEP-43', 'DEP-63', 'DEP-71', 'DEP-73', 'DEP-12', 'DEP-13',
       'DEP-22', 'DEP-28', 'DEP-37', 'DEP-70', 'DEP-84', 'DEP-973',
       'DEP-05', 'DEP-14', 'DEP-18', 'DEP-2A', 'DEP-2B', 'DEP-31',
       'DEP-36', 'DEP-41', 'DEP-45', 'DEP-50', 'DEP-75', 'DEP-77',
       'DEP-78', 'DEP-83', 'DEP-91', 'DEP-92', 'DEP-93', 'DEP-94',
       'DEP-95', 'DEP-972', 'DEP-39', 'DEP-46', 'DEP-81', 'DEP-82',
       'DEP-85', 'DEP-89', 'DEP-11', 'DEP-58', 'DEP-61', 'DEP-04',
       'DEP-32', 'COM-987', 'COM-974', 'DEP-65', 'DEP-971', 'DEP-974',
       'DEP-66', 'DEP-48', 'DEP-976', 'DEP-09', 'COM-988', 'COM-986'],

filtering data

Selecting one region, here Provence-Alpes-Côte d'Azur (PACA, code REG-93):

In [40]:
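code = 'FRA'  # whole country
code = 'REG-93'  # Provence-Alpes-Côte d'Azur (PACA); overrides the line above
df_loc = df[df['code']==code]
# optionally, keep a single data source (not applied here: the outputs below mix sources)
# sourceType = "ministere-sante"
# sourceType = 'agences-regionales-sante'
# df_loc = df_loc[df_loc['sourceType']==sourceType]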

Selecting the two columns of interest, the date and 'deces' (the cumulative number of deaths):

In [41]:
df_loc = df_loc[['date', 'deces']]
date deces
219 2020-02-28 NaN
297 2020-03-02 NaN
352 2020-03-03 NaN
439 2020-03-04 0.0
440 2020-03-04 NaN
... ... ...
64572 2021-08-08 8145.0
64692 2021-08-09 8158.0
64812 2021-08-10 8168.0
64932 2021-08-11 8181.0
65052 2021-08-12 8195.0

556 rows × 2 columns

Removing the rows containing NaNs:

In [42]:
df_loc = df_loc[df_loc['deces'].notna()]

Selecting a date range:

In [43]:
start_date = '2020-02-23'
df_loc['date'] > start_date
439      True
537      True
663      True
770      True
874      True
64572    True
64692    True
64812    True
64932    True
65052    True
Name: date, Length: 533, dtype: bool
In [44]:
df_loc = df_loc[start_date < df_loc['date']]
In [45]:
stop_date = '2020-06-23'
df_loc = df_loc[df_loc['date']<stop_date]
In [46]:
fig, ax = plt.subplots(figsize=figsize)
df_loc.plot(x='date', y='deces', ax=ax)
ax.set_title('Cumulative number of deaths');
[Figure: cumulative number of deaths]

returning a numpy array

In [47]:
death = np.array(df_loc['deces'])
df_loc['date'], death
(439     2020-03-04
 537     2020-03-05
 663     2020-03-06
 770     2020-03-07
 874     2020-03-08
 14405   2020-06-18
 14526   2020-06-19
 14647   2020-06-20
 14768   2020-06-21
 14889   2020-06-22
 Name: date, Length: 117, dtype: datetime64[ns],
 array([  0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   0.,   4.,
          7.,   7.,   9.,   9.,  11.,  11.,  13.,  13.,  15.,  14.,  20.,
         20.,  26.,  26.,  33.,  33.,  44.,  44.,  48.,  55.,  65.,  80.,
        103., 124., 141., 161., 178., 195., 231., 253., 269., 286., 299.,
        317., 328., 348., 374., 393., 421., 447., 463., 470., 499., 526.,
        561., 575., 610., 625., 631., 656., 665., 682., 696., 704., 708.,
        713., 732., 746., 756., 764., 770., 773., 777., 797., 805., 816.,
        824., 831., 836., 840., 849., 855., 862., 869., 874., 876., 878.,
        884., 889., 890., 893., 898., 898., 898., 898., 908., 911., 912.,
        913., 913., 913., 917., 919., 920., 922., 923., 923., 923., 925.,
        927., 927., 929., 933., 935., 935., 936.]))
In [48]:
death.shape, np.diff(death).shape
((117,), (116,))
In [49]:
death[1:] = np.diff(death)  # daily deaths = difference of consecutive cumulative counts
death[0] = 0  # no previous day to compare with
In [50]:
array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  4.,  3.,  0.,
        2.,  0.,  2.,  0.,  2.,  0.,  2., -1.,  6.,  0.,  6.,  0.,  7.,
        0., 11.,  0.,  4.,  7., 10., 15., 23., 21., 17., 20., 17., 17.,
       36., 22., 16., 17., 13., 18., 11., 20., 26., 19., 28., 26., 16.,
        7., 29., 27., 35., 14., 35., 15.,  6., 25.,  9., 17., 14.,  8.,
        4.,  5., 19., 14., 10.,  8.,  6.,  3.,  4., 20.,  8., 11.,  8.,
        7.,  5.,  4.,  9.,  6.,  7.,  7.,  5.,  2.,  2.,  6.,  5.,  1.,
        3.,  5.,  0.,  0.,  0., 10.,  3.,  1.,  1.,  0.,  0.,  4.,  2.,
        1.,  2.,  1.,  0.,  0.,  2.,  2.,  0.,  2.,  4.,  2.,  0.,  1.])
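The two-step conversion above can also be written in one line with np.diff and its prepend argument (available since NumPy 1.16). Here is a sketch on made-up cumulative counts; it also shows how a decrease in the cumulative series (like the 15 to 14 step in the data above) turns into a negative daily count, such as the -1 in the output.

```python
import numpy as np

# Made-up cumulative counts, including a decrease (15 -> 14).
cumulative = np.array([0., 0., 4., 7., 7., 15., 14., 20.])

# Two-step version, as in the notebook.
daily_two_steps = cumulative.copy()
daily_two_steps[1:] = np.diff(daily_two_steps)
daily_two_steps[0] = 0

# One-line version: prepending the first value makes the first difference 0.
daily = np.diff(cumulative, prepend=cumulative[0])
print(daily)
```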
In [51]:
df_loc['death'] = death
date deces death
439 2020-03-04 0.0 0.0
537 2020-03-05 0.0 0.0
663 2020-03-06 0.0 0.0
770 2020-03-07 0.0 0.0
874 2020-03-08 0.0 0.0
... ... ... ...
14405 2020-06-18 929.0 2.0
14526 2020-06-19 933.0 4.0
14647 2020-06-20 935.0 2.0
14768 2020-06-21 935.0 0.0
14889 2020-06-22 936.0 1.0

117 rows × 3 columns

df_loc['dead per day'] = df_loc['death']  # daily deaths ('deces' is the cumulative count)
In [52]:
fig, ax = plt.subplots(figsize=figsize)
df_loc.plot(x='date', y='death', ax=ax)
ax.set_title('Daily rate of deaths');
[Figure: daily rate of deaths]

performing the fit

converting data to fit

In [53]:
X, y = df_loc['date'], df_loc['death']
In [54]:
X.shape, y.shape
((117,), (117,))
In [55]:
X = X.astype(int)
X = np.array(X)
array([1583280000000000000, 1583366400000000000, 1583452800000000000,
       1583539200000000000, 1583625600000000000, 1583712000000000000,
       1583798400000000000, 1583884800000000000, 1583971200000000000,
       1584057600000000000, 1584403200000000000, 1584489600000000000,
       1584489600000000000, 1584576000000000000, 1584576000000000000,
       1584662400000000000, 1584662400000000000, 1584748800000000000,
       1584748800000000000, 1584835200000000000, 1584835200000000000,
       1584921600000000000, 1584921600000000000, 1585008000000000000,
       1585008000000000000, 1585094400000000000, 1585094400000000000,
       1585180800000000000, 1585180800000000000, 1585267200000000000,
       1585353600000000000, 1585440000000000000, 1585526400000000000,
       1585612800000000000, 1585699200000000000, 1585785600000000000,
       1585872000000000000, 1585958400000000000, 1586044800000000000,
       1586131200000000000, 1586217600000000000, 1586304000000000000,
       1586390400000000000, 1586476800000000000, 1586563200000000000,
       1586649600000000000, 1586736000000000000, 1586822400000000000,
       1586908800000000000, 1586995200000000000, 1587081600000000000,
       1587168000000000000, 1587254400000000000, 1587340800000000000,
       1587427200000000000, 1587513600000000000, 1587600000000000000,
       1587686400000000000, 1587772800000000000, 1587859200000000000,
       1587945600000000000, 1588032000000000000, 1588118400000000000,
       1588204800000000000, 1588291200000000000, 1588377600000000000,
       1588464000000000000, 1588550400000000000, 1588636800000000000,
       1588723200000000000, 1588809600000000000, 1588896000000000000,
       1588982400000000000, 1589068800000000000, 1589155200000000000,
       1589241600000000000, 1589328000000000000, 1589414400000000000,
       1589500800000000000, 1589587200000000000, 1589673600000000000,
       1589760000000000000, 1589846400000000000, 1589932800000000000,
       1590019200000000000, 1590105600000000000, 1590192000000000000,
       1590278400000000000, 1590364800000000000, 1590451200000000000,
       1590537600000000000, 1590624000000000000, 1590710400000000000,
       1590796800000000000, 1590883200000000000, 1590969600000000000,
       1591056000000000000, 1591142400000000000, 1591228800000000000,
       1591315200000000000, 1591401600000000000, 1591488000000000000,
       1591574400000000000, 1591660800000000000, 1591747200000000000,
       1591833600000000000, 1591920000000000000, 1592006400000000000,
       1592092800000000000, 1592179200000000000, 1592265600000000000,
       1592352000000000000, 1592438400000000000, 1592524800000000000,
       1592611200000000000, 1592697600000000000, 1592784000000000000])
In [56]:
INTS_PER_DAY = 86400000000000  # nanoseconds per day: 24 * 3600 * 10**9
In [57]:
X = (X-X[0])//INTS_PER_DAY  # number of days since the first date
In [58]:
array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  13,  14,  14,
        15,  15,  16,  16,  17,  17,  18,  18,  19,  19,  20,  20,  21,
        21,  22,  22,  23,  24,  25,  26,  27,  28,  29,  30,  31,  32,
        33,  34,  35,  36,  37,  38,  39,  40,  41,  42,  43,  44,  45,
        46,  47,  48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,
        59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  70,  71,
        72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,  84,
        85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
        98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110])
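An alternative sketch that stays in pandas: subtracting the first date gives Timedelta values whose .dt.days attribute is the number of whole days, without going through nanoseconds and INTS_PER_DAY. The dates below are illustrative; they include a repeated date to show how Series.duplicated() flags days reported more than once, as with the repeated values 14, 15, 16 above (presumably several sources reporting the same day).

```python
import pandas as pd

# Illustrative dates, with one date reported twice.
dates = pd.Series(pd.to_datetime(["2020-03-04", "2020-03-05", "2020-03-09", "2020-03-09"]))

# Subtracting the first date gives Timedeltas; .dt.days extracts whole days.
days = (dates - dates.iloc[0]).dt.days.to_numpy()
print(days)                      # [0 1 5 5]
print(dates.duplicated().sum())  # 1
```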
In [59]:
y = np.array(y)
array([ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  4.,  3.,  0.,
        2.,  0.,  2.,  0.,  2.,  0.,  2., -1.,  6.,  0.,  6.,  0.,  7.,
        0., 11.,  0.,  4.,  7., 10., 15., 23., 21., 17., 20., 17., 17.,
       36., 22., 16., 17., 13., 18., 11., 20., 26., 19., 28., 26., 16.,
        7., 29., 27., 35., 14., 35., 15.,  6., 25.,  9., 17., 14.,  8.,
        4.,  5., 19., 14., 10.,  8.,  6.,  3.,  4., 20.,  8., 11.,  8.,
        7.,  5.,  4.,  9.,  6.,  7.,  7.,  5.,  2.,  2.,  6.,  5.,  1.,
        3.,  5.,  0.,  0.,  0., 10.,  3.,  1.,  1.,  0.,  0.,  4.,  2.,
        1.,  2.,  1.,  0.,  0.,  2.,  2.,  0.,  2.,  4.,  2.,  0.,  1.])

using torch

or "when you have a new hammer, everything looks like a nail", see

Let's use the definition of the pdf:

$$ f ( x ; \mu , \lambda ) = \sqrt {\frac {\lambda }{2\pi x^3}} \exp (-\frac {\lambda (x-\mu )^{2}}{2\mu ^{2}x} ) $$

and a previous post where I used PyTorch to fit psychophysical data:

In [60]:
import torch
from torch.utils.data import TensorDataset, DataLoader

criterion = torch.nn.MSELoss(reduction="sum")
#criterion = torch.nn.L1Loss(reduction="sum")

class CovidRegressionModel(torch.nn.Module):
    # NOTE: the default initial values of log_amp and log_lambda were lost in this
    # copy of the notebook; the values below are assumptions.
    def __init__(self, mu=80, tau=24, log_amp=float(np.log(1000.)), log_lambda=float(np.log(100.))):
        super(CovidRegressionModel, self).__init__()
        self.tau = torch.nn.Parameter(tau * torch.ones(1))  # day of onset
        self.mu = torch.nn.Parameter(mu * torch.ones(1))  # mean of the pdf, counted from the onset
        # when modeling a strictly positive number, a good habit is to use its log as the fitted parameter:
        self.log_amp = torch.nn.Parameter(log_amp * torch.ones(1))
        self.log_lambda = torch.nn.Parameter(log_lambda * torch.ones(1))

    def forward(self, x):
        out = (x > self.tau) * torch.exp(self.log_amp)
        date = (x-self.tau) # time elapsed since the onset
        date[x<=self.tau] = 1 # to avoid NaNs in the output value (and in the gradient)
        out *= torch.sqrt(torch.exp(self.log_lambda) / (2 * np.pi * date ** 3))
        out *= torch.exp(-torch.exp(self.log_lambda) * (date - self.mu) ** 2 / (2 * self.mu ** 2 * date))
        out[x<=self.tau] = 0 # overwriting values before the onset
        return out

learning_rate = 0.005
beta1, beta2 = 0.9, 0.999
betas = (beta1, beta2)
num_epochs = 2 ** 9 + 1
batch_size = 16
amsgrad = False # gives similar results
amsgrad = True  # gives similar results

def fit_data(
    X, y,
    learning_rate=learning_rate,
    num_epochs=num_epochs,
    batch_size=batch_size,  # gamma=gamma,
    verbose=False, **kwargs
):
    variables, labels = torch.Tensor(X[:, None]), torch.Tensor(y[:, None])
    loader = DataLoader(
        TensorDataset(variables, labels), batch_size=batch_size, shuffle=True
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    covid_model = CovidRegressionModel()
    covid_model = covid_model.to(device)
    optimizer = torch.optim.Adam(
        covid_model.parameters(), lr=learning_rate, betas=betas, amsgrad=amsgrad
    )
    for epoch in range(int(num_epochs)):
        losses = []
        for variables_, labels_ in loader:
            variables_, labels_ = variables_.to(device), labels_.to(device)
            optimizer.zero_grad()
            outputs = covid_model(variables_)
            loss = criterion(outputs, labels_)
            loss.backward()
            optimizer.step()
            losses.append(loss.item())

        if verbose and (epoch % (num_epochs // 32) == 0):
            print(f"Iteration: {epoch} - Loss: {np.sum(losses)/len(variables):.5f}")

    covid_model = covid_model.to("cpu")  # back on the CPU for evaluation and plotting
    variables, labels = torch.Tensor(X[:, None]), torch.Tensor(y[:, None])
    outputs = covid_model(variables)
    loss = criterion(outputs, labels).item() / len(variables)
    return covid_model, loss
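To read the fitted parameters, it helps to recall what the model computes: A * f(x - τ; μ, λ) for x > τ, and 0 before the onset. Since f integrates to one, A is the total number of deaths over the wave; since f has mean μ, the mean date in day coordinates is τ + μ. A minimal NumPy sketch (the function covid_curve and the parameter values are hypothetical, for illustration only) checks both properties by numerical integration.

```python
import numpy as np

def covid_curve(x, amp, tau, mu, lam):
    """Hypothetical NumPy mirror of CovidRegressionModel.forward:
    amp * f(x - tau; mu, lam) for x > tau, and 0 before the onset tau."""
    x = np.asarray(x, dtype=float)
    date = np.where(x > tau, x - tau, 1.0)  # placeholder value before the onset
    pdf = np.sqrt(lam / (2 * np.pi * date**3)) * np.exp(-lam * (date - mu)**2 / (2 * mu**2 * date))
    return np.where(x > tau, amp * pdf, 0.0)

# Illustrative parameters (not fitted values).
amp, tau, mu, lam = 900.0, 10.0, 30.0, 100.0
dx = 0.01
x = np.arange(0, 400, dx)
y = covid_curve(x, amp, tau, mu, lam)

total = y.sum() * dx                   # close to amp, since the pdf integrates to 1
mean_day = (x * y).sum() * dx / total  # close to tau + mu
print(f"total = {total:.1f}, mean day = {mean_day:.2f}")
```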
In [61]:
covid_model, loss = fit_data(X, y, verbose=True)
print("Final loss =", loss)
Iteration: 0 - Loss: 95.69920
Iteration: 16 - Loss: 54.13756
Iteration: 32 - Loss: 33.30596
Iteration: 48 - Loss: 31.36121
Iteration: 64 - Loss: 30.85001
Iteration: 80 - Loss: 30.38980
Iteration: 96 - Loss: 30.06311
Iteration: 112 - Loss: 29.77887
Iteration: 128 - Loss: 29.59526
Iteration: 144 - Loss: 29.40957
Iteration: 160 - Loss: 29.16254
Iteration: 176 - Loss: 29.03642
Iteration: 192 - Loss: 28.90962
Iteration: 208 - Loss: 28.83253
Iteration: 224 - Loss: 28.75095
Iteration: 240 - Loss: 28.68247
Iteration: 256 - Loss: 28.60991
Iteration: 272 - Loss: 28.55558
Iteration: 288 - Loss: 28.55271
Iteration: 304 - Loss: 28.45341
Iteration: 320 - Loss: 28.41483
Iteration: 336 - Loss: 28.38179
Iteration: 352 - Loss: 28.37117
Iteration: 368 - Loss: 28.33166
Iteration: 384 - Loss: 28.32052
Iteration: 400 - Loss: 28.29289
Iteration: 416 - Loss: 28.27676
Iteration: 432 - Loss: 28.27867
Iteration: 448 - Loss: 28.24910
Iteration: 464 - Loss: 28.24694
Iteration: 480 - Loss: 28.20544
Iteration: 496 - Loss: 28.19613
Iteration: 512 - Loss: 28.18755
Final loss = 28.163296599808184
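As a cross-check that does not depend on PyTorch, the same model can be fitted by nonlinear least squares with scipy.optimize.curve_fit (a swapped-in technique, not the method used in this notebook). This sketch fits synthetic, noise-free data generated from known parameters, using the same log-parametrization for the positive parameters; on the real data one would pass X and y instead, and the result would depend on the initial guess p0, which is an assumption here.

```python
import numpy as np
from scipy.optimize import curve_fit

def covid_curve(x, log_amp, tau, mu, log_lam):
    """Hypothetical NumPy version of the model, with log-parametrized amplitude and lambda."""
    amp, lam = np.exp(log_amp), np.exp(log_lam)
    date = np.where(x > tau, x - tau, 1.0)  # placeholder value before the onset
    pdf = np.sqrt(lam / (2 * np.pi * date**3)) * np.exp(-lam * (date - mu)**2 / (2 * mu**2 * date))
    return np.where(x > tau, amp * pdf, 0.0)

def sse(params):
    """Sum of squared errors of the model against the synthetic data."""
    return np.sum((covid_curve(x, *params) - y) ** 2)

# Synthetic, noise-free data from known parameters (illustration only).
x = np.arange(0, 111, dtype=float)
true_params = (np.log(900.0), 10.0, 30.0, np.log(100.0))
y = covid_curve(x, *true_params)

p0 = (np.log(800.0), 12.0, 25.0, np.log(80.0))  # deliberately offset initial guess
params, _ = curve_fit(covid_curve, x, y, p0=p0, maxfev=10000)

print("initial SSE:", sse(p0), "fitted SSE:", sse(params))
print("amp, tau, mu, lambda:", np.exp(params[0]), params[1], params[2], np.exp(params[3]))
```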
In [62]:
outputs = covid_model(torch.Tensor(X[:, None]))
y_pred = outputs.detach().numpy()
In [63]:
fig, ax = plt.subplots(figsize=figsize)
ax.plot(X, y, '*', label='data')
ax.plot(X, y_pred, '--', label='fit')
# the pdf is shifted by the onset tau, so the mean date is tau + mu
mean_date = covid_model.tau.item() + covid_model.mu.item()
ax.vlines(mean_date, y.min(), y.max(), colors='g', linestyles='--', label=fr'$\tau+\mu$={mean_date:.1f} (mean date)')
ax.vlines(covid_model.tau.item(), y.min(), y.max(), colors='r', linestyles='--', label=fr'$\tau$={covid_model.tau.item():.1f} (day of onset)')
ax.plot([], [], label=f'A={torch.exp(covid_model.log_amp).item():.1f} (amplitude)')
ax.plot([], [], label=fr'$\lambda$={torch.exp(covid_model.log_lambda).item():.1f} (shape parameter, in days)')
first_date = df_loc['date'].iloc[0].date()  # X counts days from the first selected row
ax.set_xlabel(f'# number of days since {first_date}');
ax.set_ylabel('# number of dead per day');
ax.set_title('Fitting Daily rate of deaths (PACA-FR)');
ax.legend();
[Figure: fitting the daily rate of deaths (PACA-FR)]

some book keeping for the notebook

In [64]:
%load_ext watermark
%watermark -i -h -m -v -p numpy,pandas,matplotlib,torch  -r -g -b
Python implementation: CPython
Python version       : 3.11.6
IPython version      : 8.17.2

numpy     : 1.26.2
pandas    : 2.1.2
matplotlib: 3.8.1
torch     : 2.2.0.dev20231107

Compiler    : Clang 15.0.0 (clang-1500.0.40.1)
OS          : Darwin
Release     : 23.2.0
Machine     : arm64
Processor   : arm
CPU cores   : 10
Architecture: 64bit

Hostname: obiwan.local

Git hash: 688f040143fa7b3a5258594843cf6e13498821db

Git repo:

Git branch: master