Spaces: Livia_Zaharia committed
Commit bacf16b · 1 Parent(s): eb420aa
added code for the first time
Browse files
- __pycache__/plot_predictions.cpython-311.pyc +0 -0
- __pycache__/routes.cpython-311.pyc +0 -0
- __pycache__/tools.cpython-311.pyc +0 -0
- data_formatter/__init__.py +0 -0
- data_formatter/__pycache__/__init__.cpython-311.pyc +0 -0
- data_formatter/__pycache__/base.cpython-311.pyc +0 -0
- data_formatter/__pycache__/types.cpython-311.pyc +0 -0
- data_formatter/__pycache__/utils.cpython-311.pyc +0 -0
- data_formatter/base.py +213 -0
- data_formatter/types.py +19 -0
- data_formatter/utils.py +323 -0
- environment.yaml +28 -0
- files/config.yaml +81 -0
- format_dexcom.py +152 -0
- gluformer/__init__.py +0 -0
- gluformer/__pycache__/__init__.cpython-311.pyc +0 -0
- gluformer/__pycache__/attention.cpython-311.pyc +0 -0
- gluformer/__pycache__/decoder.cpython-311.pyc +0 -0
- gluformer/__pycache__/embed.cpython-311.pyc +0 -0
- gluformer/__pycache__/encoder.cpython-311.pyc +0 -0
- gluformer/__pycache__/model.cpython-311.pyc +0 -0
- gluformer/__pycache__/variance.cpython-311.pyc +0 -0
- gluformer/attention.py +70 -0
- gluformer/decoder.py +50 -0
- gluformer/embed.py +69 -0
- gluformer/encoder.py +67 -0
- gluformer/model.py +334 -0
- gluformer/utils/__init__.py +0 -0
- gluformer/utils/__pycache__/__init__.cpython-311.pyc +0 -0
- gluformer/utils/__pycache__/collate.cpython-311.pyc +0 -0
- gluformer/utils/__pycache__/training.cpython-311.pyc +0 -0
- gluformer/utils/collate.py +84 -0
- gluformer/utils/evaluation.py +81 -0
- gluformer/utils/training.py +80 -0
- gluformer/variance.py +24 -0
- main.py +8 -0
- tools.py +198 -0
- utils/__init__.py +0 -0
- utils/__pycache__/__init__.cpython-311.pyc +0 -0
- utils/__pycache__/darts_dataset.cpython-311.pyc +0 -0
- utils/__pycache__/darts_processing.cpython-311.pyc +0 -0
- utils/darts_dataset.py +881 -0
- utils/darts_evaluation.py +280 -0
- utils/darts_processing.py +367 -0
- utils/darts_training.py +114 -0
__pycache__/plot_predictions.cpython-311.pyc ADDED: binary file (9.5 kB)
__pycache__/routes.cpython-311.pyc ADDED: binary file (2.33 kB)
__pycache__/tools.cpython-311.pyc ADDED: binary file (13.3 kB)
data_formatter/__init__.py ADDED: file without changes
data_formatter/__pycache__/__init__.cpython-311.pyc ADDED: binary file (182 Bytes)
data_formatter/__pycache__/base.cpython-311.pyc ADDED: binary file (16.4 kB)
data_formatter/__pycache__/types.cpython-311.pyc ADDED: binary file (1.09 kB)
data_formatter/__pycache__/utils.cpython-311.pyc ADDED: binary file (19.8 kB)
data_formatter/base.py ADDED
@@ -0,0 +1,213 @@
'''Defines a generic data formatter for CGM data sets.'''
import sys
import warnings
import numpy as np
import pandas as pd
import sklearn.preprocessing
import data_formatter.types as types
import data_formatter.utils as utils

DataTypes = types.DataTypes
InputTypes = types.InputTypes

dict_data_type = {'categorical': DataTypes.CATEGORICAL,
                  'real_valued': DataTypes.REAL_VALUED,
                  'date': DataTypes.DATE}
dict_input_type = {'target': InputTypes.TARGET,
                   'observed_input': InputTypes.OBSERVED_INPUT,
                   'known_input': InputTypes.KNOWN_INPUT,
                   'static_input': InputTypes.STATIC_INPUT,
                   'id': InputTypes.ID,
                   'time': InputTypes.TIME}


class DataFormatter:
    # Defines and formats data.

    def __init__(self, cnf):
        """Initialises formatter."""
        # load parameters from the config file
        self.params = cnf
        # write progress to file if specified

        # load column definition
        print('-'*32)
        print('Loading column definition...')
        self.__process_column_definition()

        # check that column definition is valid
        print('Checking column definition...')
        self.__check_column_definition()

        # load data
        # check if data table has index col: -1 if not, index >= 0 if yes
        print('Loading data...')
        self.params['index_col'] = False if self.params['index_col'] == -1 else self.params['index_col']
        # read data table
        self.data = pd.read_csv(self.params['data_csv_path'], index_col=self.params['index_col'])

        # drop columns / rows
        print('Dropping columns / rows...')
        self.__drop()

        # check NA values
        print('Checking for NA values...')
        self.__check_nan()

        # set data types in DataFrame to match column definition
        print('Setting data types...')
        self.__set_data_types()

        # drop columns / rows
        print('Dropping columns / rows...')
        self.__drop()

        # encode
        print('Encoding data...')
        self._encoding_params = self.params['encoding_params']
        self.__encode()

        # interpolate
        print('Interpolating data...')
        self._interpolation_params = self.params['interpolation_params']
        self._interpolation_params['interval_length'] = self.params['observation_interval']
        self.__interpolate()

        # split data
        print('Splitting data...')
        self._split_params = self.params['split_params']
        self._split_params['max_length_input'] = self.params['max_length_input']
        self.__split_data()

        # scale
        print('Scaling data...')
        self._scaling_params = self.params['scaling_params']
        self.__scale()

        print('Data formatting complete.')
        print('-'*32)

    def __process_column_definition(self):
        self._column_definition = []
        for col in self.params['column_definition']:
            self._column_definition.append((col['name'],
                                            dict_data_type[col['data_type']],
                                            dict_input_type[col['input_type']]))

    def __check_column_definition(self):
        # check that there is unique ID column
        assert len([col for col in self._column_definition if col[2] == InputTypes.ID]) == 1, 'There must be exactly one ID column.'
        # check that there is unique time column
        assert len([col for col in self._column_definition if col[2] == InputTypes.TIME]) == 1, 'There must be exactly one time column.'
        # check that there is at least one target column
        assert len([col for col in self._column_definition if col[2] == InputTypes.TARGET]) >= 1, 'There must be at least one target column.'

    def __set_data_types(self):
        # set time column as datetime format in pandas
        for col in self._column_definition:
            if col[1] == DataTypes.DATE:
                self.data[col[0]] = pd.to_datetime(self.data[col[0]])
            if col[1] == DataTypes.CATEGORICAL:
                self.data[col[0]] = self.data[col[0]].astype('category')
            if col[1] == DataTypes.REAL_VALUED:
                self.data[col[0]] = self.data[col[0]].astype(np.float32)

    def __check_nan(self):
        # delete rows where target, time, or id are na
        self.data = self.data.dropna(subset=[col[0]
                                             for col in self._column_definition
                                             if col[2] in [InputTypes.TARGET, InputTypes.TIME, InputTypes.ID]])
        # assert that there are no na values in the data
        assert self.data.isna().sum().sum() == 0, 'There are NA values in the data even after dropping with missing time, glucose, or id.'

    def __drop(self):
        # drop columns that are not in the column definition
        self.data = self.data[[col[0] for col in self._column_definition]]
        # drop rows based on conditions set in the formatter
        if self.params['drop'] is not None:
            if self.params['drop']['rows'] is not None:
                # drop row at indices in the list self.params['drop']['rows']
                self.data = self.data.drop(self.params['drop']['rows'])
                self.data = self.data.reset_index(drop=True)
            if self.params['drop']['columns'] is not None:
                for col in self.params['drop']['columns'].keys():
                    # drop rows where specified columns have values in the list self.params['drop']['columns'][col]
                    self.data = self.data.loc[~self.data[col].isin(self.params['drop']['columns'][col])].copy()

    def __interpolate(self):
        self.data, self._column_definition = utils.interpolate(self.data,
                                                               self._column_definition,
                                                               **self._interpolation_params)

    def __split_data(self):
        if self.params['split_params']['test_percent_subjects'] == 0 or \
           self.params['split_params']['length_segment'] == 0:
            print('\tNo splitting performed since test_percent_subjects or length_segment is 0.')
            self.train_idx, self.val_idx, self.test_idx, self.test_idx_ood = None, None, None, None
            self.train_data, self.val_data, self.test_data = self.data, None, None
        else:
            assert self.params['split_params']['length_segment'] > self.params['length_pred'], \
                'length_segment for test / val must be greater than length_pred.'
            self.train_idx, self.val_idx, self.test_idx, self.test_idx_ood = utils.split(self.data,
                                                                                         self._column_definition,
                                                                                         **self._split_params)
            self.train_data, self.val_data, self.test_data = self.data.iloc[self.train_idx], \
                                                             self.data.iloc[self.val_idx], \
                                                             self.data.iloc[self.test_idx + self.test_idx_ood]

    def __encode(self):
        self.data, self._column_definition, self.encoders = utils.encode(self.data,
                                                                         self._column_definition,
                                                                         **self._encoding_params)

    def __scale(self):
        self.train_data, self.val_data, self.test_data, self.scalers = utils.scale(self.train_data,
                                                                                   self.val_data,
                                                                                   self.test_data,
                                                                                   self._column_definition,
                                                                                   **self.params['scaling_params'])

    def reshuffle(self, seed):
        stdout = sys.stdout
        f = open(self.study_file, 'a')
        sys.stdout = f
        self.params['split_params']['random_state'] = seed
        # split data
        self.train_idx, self.val_idx, self.test_idx, self.test_idx_ood = utils.split(self.data,
                                                                                     self._column_definition,
                                                                                     **self._split_params)
        self.train_data, self.val_data, self.test_data = self.data.iloc[self.train_idx], \
                                                         self.data.iloc[self.val_idx], \
                                                         self.data.iloc[self.test_idx + self.test_idx_ood]
        # re-scale data
        self.train_data, self.val_data, self.test_data, self.scalers = utils.scale(self.train_data,
                                                                                   self.val_data,
                                                                                   self.test_data,
                                                                                   self._column_definition,
                                                                                   **self.params['scaling_params'])
        sys.stdout = stdout
        f.close()

    def get_column(self, column_name):
        # write cases for time, id, target, future, static, dynamic covariates
        if column_name == 'time':
            return [col[0] for col in self._column_definition if col[2] == InputTypes.TIME][0]
        elif column_name == 'id':
            return [col[0] for col in self._column_definition if col[2] == InputTypes.ID][0]
        elif column_name == 'sid':
            return [col[0] for col in self._column_definition if col[2] == InputTypes.SID][0]
        elif column_name == 'target':
            return [col[0] for col in self._column_definition if col[2] == InputTypes.TARGET]
        elif column_name == 'future_covs':
            future_covs = [col[0] for col in self._column_definition if col[2] == InputTypes.KNOWN_INPUT]
            return future_covs if len(future_covs) > 0 else None
        elif column_name == 'static_covs':
            static_covs = [col[0] for col in self._column_definition if col[2] == InputTypes.STATIC_INPUT]
            return static_covs if len(static_covs) > 0 else None
        elif column_name == 'dynamic_covs':
            dynamic_covs = [col[0] for col in self._column_definition if col[2] == InputTypes.OBSERVED_INPUT]
            return dynamic_covs if len(dynamic_covs) > 0 else None
        else:
            raise ValueError('Column {} not found.'.format(column_name))
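Since the whole preprocessing pipeline runs inside __init__, a caller only constructs the class and reads off the split attributes. A minimal usage sketch, not part of this commit, assuming the files/config.yaml added below and a readable CSV at the path it names:

import yaml
from data_formatter.base import DataFormatter

with open('files/config.yaml') as f:
    cnf = yaml.safe_load(f)

formatter = DataFormatter(cnf)           # runs load -> drop -> encode -> interpolate -> split -> scale
print(formatter.get_column('target'))    # ['gl'] under the config in this commit
print(len(formatter.train_data), len(formatter.val_data), len(formatter.test_data))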
data_formatter/types.py ADDED
@@ -0,0 +1,19 @@
'''Defines data and input types of each column in the dataset.'''

import enum

class DataTypes(enum.IntEnum):
    """Defines numerical types of each column."""
    REAL_VALUED = 0
    CATEGORICAL = 1
    DATE = 2

class InputTypes(enum.IntEnum):
    """Defines input types of each column."""
    TARGET = 0
    OBSERVED_INPUT = 1
    KNOWN_INPUT = 2
    STATIC_INPUT = 3
    ID = 4    # Single column used as an entity identifier
    SID = 5   # Single column used as a segment identifier
    TIME = 6  # Single column exclusively used as a time index
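For reference, DataFormatter.__process_column_definition in base.py turns each config entry into a (name, data_type, input_type) triple of these enums; an illustrative list matching files/config.yaml (not part of the commit):

from data_formatter.types import DataTypes, InputTypes

column_definition = [
    ('id',   DataTypes.CATEGORICAL, InputTypes.ID),
    ('time', DataTypes.DATE,        InputTypes.TIME),
    ('gl',   DataTypes.REAL_VALUED, InputTypes.TARGET),
]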
data_formatter/utils.py ADDED
@@ -0,0 +1,323 @@
# coding=utf-8
# Copyright 2019 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Lint as: python3
"""Generic helper functions used across codebase."""
import warnings
from collections import namedtuple
from datetime import datetime
import os
import math
import pathlib
import torch
import numpy as np
import pandas as pd
pd.options.mode.chained_assignment = None
from typing import List, Tuple
from sklearn import preprocessing

import data_formatter
from data_formatter import types

DataTypes = types.DataTypes
InputTypes = types.InputTypes
MINUTE = 60

# OS related functions.
def create_folder_if_not_exist(directory):
    """Creates folder if it doesn't exist.

    Args:
      directory: Folder path to create.
    """
    # Also creates directories recursively
    pathlib.Path(directory).mkdir(parents=True, exist_ok=True)


def csv_path_to_folder(path: str):
    return "/".join(path.split('/')[:-1]) + "/"


def interpolate(data: pd.DataFrame,
                column_definition: List[Tuple[str, DataTypes, InputTypes]],
                gap_threshold: int = 0,
                min_drop_length: int = 0,
                interval_length: int = 0):
    """Interpolates missing values in data.

    Args:
      data: Dataframe to interpolate on. Sorted by id and then time (a DateTime object).
      column_definition: List of tuples describing columns (column_name, data_type, input_type).
      gap_threshold: Number in minutes, maximum allowed gap for interpolation.
      min_drop_length: Number of points, minimum number within an interval to interpolate.
      interval_length: Number in minutes, length of interpolation.

    Returns:
      data: DataFrame with missing values interpolated and
            additional column ('id_segment') indicating continuous segments.
      column_definition: Updated list of tuples (column_name, data_type, input_type).
    """
    # select all real-valued columns that are not id, time, or static
    interpolation_columns = [column_name for column_name, data_type, input_type in column_definition if
                             data_type == DataTypes.REAL_VALUED and
                             input_type not in set([InputTypes.ID, InputTypes.TIME, InputTypes.STATIC_INPUT])]
    # select all other columns except time
    constant_columns = [column_name for column_name, data_type, input_type in column_definition if
                        input_type not in set([InputTypes.TIME])]
    constant_columns += ['id_segment']

    # get id and time columns
    id_col = [column_name for column_name, data_type, input_type in column_definition if input_type == InputTypes.ID][0]
    time_col = [column_name for column_name, data_type, input_type in column_definition if input_type == InputTypes.TIME][0]

    # round to minute
    data[time_col] = data[time_col].dt.round('1min')
    # count dropped segments
    dropped_segments = 0
    # count number of values that are interpolated
    interpolation_count = 0
    # store final output
    output = []
    for id, id_data in data.groupby(id_col):
        # sort values
        id_data.sort_values(time_col, inplace=True)
        # get time difference between consecutive rows
        lag = (id_data[time_col].diff().dt.total_seconds().fillna(0) / 60.0).astype(int)
        # start a new segment wherever lag > gap_threshold
        id_segment = (lag > gap_threshold).cumsum()
        id_data['id_segment'] = id_segment
        for segment, segment_data in id_data.groupby('id_segment'):
            # if segment is too short, then we don't interpolate
            if len(segment_data) < min_drop_length:
                dropped_segments += 1
                continue

            # find and print duplicated times
            duplicates = segment_data.duplicated(time_col, keep=False)
            if duplicates.any():
                print(segment_data[duplicates])
                raise ValueError('Duplicate times in segment {} of id {}'.format(segment, id))

            # reindex at interval_length minute intervals
            segment_data = segment_data.set_index(time_col)
            index_new = pd.date_range(start=segment_data.index[0],
                                      end=segment_data.index[-1],
                                      freq=interval_length)
            index_union = index_new.union(segment_data.index)
            segment_data = segment_data.reindex(index_union)
            # count nan values in interpolation columns
            interpolation_count += segment_data[interpolation_columns[0]].isna().sum()
            # interpolate
            segment_data[interpolation_columns] = segment_data[interpolation_columns].interpolate(method='index')
            # fill constant columns with last value
            segment_data[constant_columns] = segment_data[constant_columns].ffill()
            # delete rows not conforming to frequency
            segment_data = segment_data.reindex(index_new)
            # reset index, make the time a column with name time_col
            segment_data = segment_data.reset_index().rename(columns={'index': time_col})
            # set the id_segment to position in output
            segment_data['id_segment'] = len(output)
            # add to output
            output.append(segment_data)
    # print number of dropped segments and number of segments
    print('\tDropped segments: {}'.format(dropped_segments))
    print('\tExtracted segments: {}'.format(len(output)))
    # concat all segments and reset index
    output = pd.concat(output)
    output.reset_index(drop=True, inplace=True)
    # count number of interpolated values
    print('\tInterpolated values: {}'.format(interpolation_count))
    print('\tPercent of values interpolated: {:.2f}%'.format(interpolation_count / len(output) * 100))
    # add id_segment column to column_definition as SID
    column_definition += [('id_segment', DataTypes.CATEGORICAL, InputTypes.SID)]

    return output, column_definition

def create_index(time_col: pd.Series, interval_length: int):
    """Creates a new index at interval_length minute intervals.

    Args:
      time_col: Series of times.
      interval_length: Number in minutes, length of interpolation.

    Returns:
      index: New index.
    """
    # margin of error
    eps = pd.Timedelta('1min')
    new_time_col = [time_col.iloc[0]]
    for time in time_col.iloc[1:]:
        if time - new_time_col[-1] <= pd.Timedelta(interval_length) + eps:
            new_time_col.append(time)
        else:
            filler = new_time_col[-1] + pd.Timedelta(interval_length)
            while filler < time:
                new_time_col.append(filler)
                filler += pd.Timedelta(interval_length)
            new_time_col.append(time)
    return pd.to_datetime(new_time_col)

def split(df: pd.DataFrame,
          column_definition: List[Tuple[str, DataTypes, InputTypes]],
          test_percent_subjects: float,
          length_segment: int,
          max_length_input: int,
          random_state: int = 42):
    """Splits data into train, validation and test sets.

    Args:
      df: Dataframe to split.
      column_definition: List of tuples describing columns (column_name, data_type, input_type).
      test_percent_subjects: Float number from [0, 1], percentage of subjects to use for test set.
      length_segment: Number of points, length of segments saved for validation / test sets.
      max_length_input: Number of points, maximum length of input sequences for models.
      random_state: Number, random state for reproducibility.

    Returns:
      train_idx: Training set indices.
      val_idx: Validation set indices.
      test_idx: Test set indices.
      test_idx_ood: Out-of-distribution test set indices (held-out subjects).
    """
    # set random state
    np.random.seed(random_state)
    # get id and id_segment columns
    id_col = [column_name for column_name, data_type, input_type in column_definition if input_type == InputTypes.ID][0]
    id_segment_col = [column_name for column_name, data_type, input_type in column_definition if input_type == InputTypes.SID][0]
    # get unique ids
    ids = df[id_col].unique()

    # select some subjects for test data set
    test_ids = np.random.choice(ids, math.ceil(len(ids) * test_percent_subjects), replace=False)
    test_idx_ood = list(df[df[id_col].isin(test_ids)].index)
    # get the remaining data for training and validation
    df = df[~df[id_col].isin(test_ids)]

    # iterate through subjects and split into train, val and test
    train_idx = []; val_idx = []; test_idx = []
    for id, id_data in df.groupby(id_col):
        segment_ids = id_data[id_segment_col].unique()
        if len(segment_ids) >= 2:
            train_idx += list(id_data.loc[id_data[id_segment_col].isin(segment_ids[:-2])].index)
            penultimate_segment = id_data[id_data[id_segment_col] == segment_ids[-2]]
            last_segment = id_data[id_data[id_segment_col] == segment_ids[-1]]
            if len(last_segment) >= max_length_input + 3 * length_segment:
                train_idx += list(penultimate_segment.index)
                train_idx += list(last_segment.iloc[:-2*length_segment].index)
                val_idx += list(last_segment.iloc[-2*length_segment-max_length_input:-length_segment].index)
                test_idx += list(last_segment.iloc[-length_segment-max_length_input:].index)
            elif len(last_segment) >= max_length_input + 2 * length_segment:
                train_idx += list(penultimate_segment.index)
                val_idx += list(last_segment.iloc[:-length_segment].index)
                test_idx += list(last_segment.iloc[-length_segment-max_length_input:].index)
            else:
                test_idx += list(last_segment.index)
                if len(penultimate_segment) >= max_length_input + 2 * length_segment:
                    val_idx += list(penultimate_segment.iloc[-length_segment-max_length_input:].index)
                    train_idx += list(penultimate_segment.iloc[:-length_segment].index)
                else:
                    train_idx += list(penultimate_segment.index)
        else:
            if len(id_data) >= max_length_input + 3 * length_segment:
                train_idx += list(id_data.iloc[:-2*length_segment].index)
                val_idx += list(id_data.iloc[-2*length_segment-max_length_input:-length_segment].index)
                test_idx += list(id_data.iloc[-length_segment-max_length_input:].index)
            elif len(id_data) >= max_length_input + 2 * length_segment:
                train_idx += list(id_data.iloc[:-length_segment].index)
                test_idx += list(id_data.iloc[-length_segment-max_length_input:].index)
            else:
                train_idx += list(id_data.index)
    total_len = len(train_idx) + len(val_idx) + len(test_idx) + len(test_idx_ood)
    print('\tTrain: {} ({:.2f}%)'.format(len(train_idx), len(train_idx) / total_len * 100))
    print('\tVal: {} ({:.2f}%)'.format(len(val_idx), len(val_idx) / total_len * 100))
    print('\tTest: {} ({:.2f}%)'.format(len(test_idx), len(test_idx) / total_len * 100))
    print('\tTest OOD: {} ({:.2f}%)'.format(len(test_idx_ood), len(test_idx_ood) / total_len * 100))
    return train_idx, val_idx, test_idx, test_idx_ood

def encode(df: pd.DataFrame,
           column_definition: List[Tuple[str, DataTypes, InputTypes]],
           date: List,):
    """Encodes categorical columns.

    Args:
      df: Dataframe to encode.
      column_definition: List of tuples describing columns (column_name, data_type, input_type).
      date: List of str, list containing date info to extract.

    Returns:
      df: Dataframe with encoded columns.
      column_definition: Updated list of tuples containing column name and types.
      encoders: dictionary containing encoders.
    """
    encoders = {}
    new_columns = []
    for i in range(len(column_definition)):
        column, column_type, input_type = column_definition[i]
        if column_type == DataTypes.DATE:
            for extract_col in date:
                df[column + '_' + extract_col] = getattr(df[column].dt, extract_col)
                df[column + '_' + extract_col] = df[column + '_' + extract_col].astype(np.float32)
                new_columns.append((column + '_' + extract_col, DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT))
        elif column_type == DataTypes.CATEGORICAL:
            encoders[column] = preprocessing.LabelEncoder()
            df[column] = encoders[column].fit_transform(df[column]).astype(np.float32)
            column_definition[i] = (column, DataTypes.REAL_VALUED, input_type)
        else:
            continue
    column_definition += new_columns
    # print updated column definition
    print('\tUpdated column definition:')
    for column, column_type, input_type in column_definition:
        print('\t\t{}: {} ({})'.format(column,
                                       DataTypes(column_type).name,
                                       InputTypes(input_type).name))
    return df, column_definition, encoders

def scale(train_data: pd.DataFrame,
          val_data: pd.DataFrame,
          test_data: pd.DataFrame,
          column_definition: List[Tuple[str, DataTypes, InputTypes]],
          scaler: str):
    """Scales numerical data.

    Args:
      train_data: pd.Dataframe, DataFrame of training data.
      val_data: pd.Dataframe, DataFrame of validation data.
      test_data: pd.Dataframe, DataFrame of testing data.
      column_definition: List of tuples describing columns (column_name, data_type, input_type).
      scaler: String, scaler to use.

    Returns:
      train_data: pd.Dataframe, DataFrame of scaled training data.
      val_data: pd.Dataframe, DataFrame of scaled validation data.
      test_data: pd.Dataframe, DataFrame of scaled testing data.
      scalers: dictionary indexed by column names containing scalers.
    """
    # select all real-valued columns
    columns_to_scale = [column for column, data_type, input_type in column_definition if data_type == DataTypes.REAL_VALUED]
    # handle no scaling case
    if scaler == 'None':
        print('\tNo scaling applied')
        return train_data, val_data, test_data, None
    scalers = {}
    for column in columns_to_scale:
        scaler_column = getattr(preprocessing, scaler)()
        train_data[column] = scaler_column.fit_transform(train_data[column].values.reshape(-1, 1))
        # handle empty validation and test sets
        val_data[column] = scaler_column.transform(val_data[column].values.reshape(-1, 1)) if val_data.shape[0] > 0 else val_data[column]
        test_data[column] = scaler_column.transform(test_data[column].values.reshape(-1, 1)) if test_data.shape[0] > 0 else test_data[column]
        scalers[column] = scaler_column
    # print columns that were scaled
    print('\tScaled columns: {}'.format(columns_to_scale))
    return train_data, val_data, test_data, scalers
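A toy run of encode() under the same column convention may help; this is illustrative only and not part of the commit (the DataFrame and its values are made up):

import pandas as pd
from data_formatter.types import DataTypes, InputTypes
from data_formatter.utils import encode

df = pd.DataFrame({
    'id':   pd.Series(['a', 'a', 'b'], dtype='category'),
    'time': pd.to_datetime(['2024-01-01 08:00', '2024-01-01 08:05', '2024-01-01 08:10']),
    'gl':   [100.0, 105.0, 98.0],
})
cols = [('id', DataTypes.CATEGORICAL, InputTypes.ID),
        ('time', DataTypes.DATE, InputTypes.TIME),
        ('gl', DataTypes.REAL_VALUED, InputTypes.TARGET)]
df, cols, encoders = encode(df, cols, date=['day', 'hour'])
# 'id' is label-encoded in place; 'time_day' and 'time_hour' are appended
# as REAL_VALUED / KNOWN_INPUT columns.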
environment.yaml ADDED
@@ -0,0 +1,28 @@
name: glucose_genie
channels:
  - conda-forge
  - defaults
dependencies:
  - python=3.11
  - gradio
  - seaborn
  - pytorch
  - optuna
  - numpy<2.0.0
  - tensorboard
  - pip:
    - fastapi
    - uvicorn
    - thefuzz
    - pycomfort>=0.0.15
    - polars>=1.3.0
    - hybrid_search>=0.0.15
    - psutil  # compatibility
    - httpx
    - just-agents>=0.1.0
    - FlagEmbedding
    - typer
    - darts==0.29.0
    - pmdarima==2.0.4
    - numpy==1.26.4
    - peft
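Assuming the standard conda workflow (not spelled out in the commit), this environment would be created with conda env create -f environment.yaml and activated with conda activate glucose_genie; the pip section then pins darts, pmdarima, and numpy inside it.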
files/config.yaml ADDED
@@ -0,0 +1,81 @@
data_csv_path: ./raw_data/anton.csv
drop: null
ds_name: livia_mini
index_col: -1
observation_interval: 5min

column_definition:
  - data_type: categorical
    input_type: id
    name: id
  - data_type: date
    input_type: time
    name: time
  - data_type: real_valued
    input_type: target
    name: gl

encoding_params:
  date:
    - day
    - month
    - year
    - hour
    - minute
    - second

# NA values abbreviation
nan_vals: null

# Interpolation parameters
interpolation_params:
  gap_threshold: 45      # in minutes
  min_drop_length: 240   # in number of points (20 hrs)

scaling_params:
  scaler: None

split_params:
  length_segment: 13
  random_state: 0
  test_percent_subjects: 0.1


# Splitting parameters
#split_params:
#  test_percent_subjects: .1
#  length_segment: 240
#  random_state: 0

# Model params
max_length_input: 192
length_pred: 12

transformer:
  batch_size: 32
  d_model: 96
  dim_feedforward: 448
  dropout: 0.10161152207464333
  in_len: 96
  lr: 0.000840888489686657
  lr_epochs: 16
  max_grad_norm: 0.6740479322943925
  max_samples_per_ts: 50
  n_heads: 4
  num_decoder_layers: 1
  num_encoder_layers: 4

transformer_covariates:
  batch_size: 32
  d_model: 128
  dim_feedforward: 160
  dropout: 0.044926981080245884
  in_len: 108
  lr: 0.00029632347559614453
  lr_epochs: 20
  max_grad_norm: 0.8890169619043728
  max_samples_per_ts: 50
  n_heads: 2
  num_decoder_layers: 2
  num_encoder_layers: 2
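Note how these values interact with the code above: split_params.length_segment (13) only just clears the assertion in DataFormatter.__split_data, which requires it to exceed length_pred (12), and min_drop_length of 240 points at the 5-minute observation_interval is the 20-hour minimum segment the comment refers to (240 x 5 min = 1200 min).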
format_dexcom.py ADDED
@@ -0,0 +1,152 @@
import pandas as pd
from pathlib import Path
import typer


def process_csv(
    input_dir: Path = typer.Argument(help="Directory containing the input CSV files."),
    output_file: Path = typer.Argument(help="Path to save the processed CSV file."),
    event_type_filter: str = typer.Option('egv', help="Event type to filter by."),
    drop_duplicates: bool = typer.Option(True, help="Whether to drop duplicate timestamps."),
    time_diff_minutes: int = typer.Option(1, help="Minimum time difference in minutes to keep a row."),
    chunk_size: int = typer.Option(1000, help="Chunk size for the 'id' column increment. Set to 0 or None for a single id."),
) -> pd.DataFrame:

    # Read CSV file into a DataFrame
    filename = input_dir
    df = pd.read_csv(filename, low_memory=False)

    # Filter by Event Type and Event Subtype
    df = df[df['Event Type'].str.lower() == event_type_filter]
    df = df[df['Event Subtype'].isna()]

    # List of columns to keep
    columns_to_keep = [
        'Index',
        'Timestamp (YYYY-MM-DDThh:mm:ss)',
        'Glucose Value (mg/dL)',
    ]

    # Keep only the specified columns
    df = df[columns_to_keep]

    # Rename columns
    column_rename = {
        'Index': 'id',
        'Timestamp (YYYY-MM-DDThh:mm:ss)': 'time',
        'Glucose Value (mg/dL)': 'gl'
    }
    df = df.rename(columns=column_rename)

    # Handle id assignment based on chunk_size
    if chunk_size is None or chunk_size == 0:
        df['id'] = 1  # Assign the same id to all rows
    else:
        df['id'] = ((df.index // chunk_size) % (df.index.max() // chunk_size + 1)).astype(int)

    # Convert timestamp to datetime
    df['time'] = pd.to_datetime(df['time'])

    # Calculate time difference and keep rows with at least the specified time difference
    df['time_diff'] = df['time'].diff()
    df = df[df['time_diff'].isna() | (df['time_diff'] >= pd.Timedelta(minutes=time_diff_minutes))]

    # Drop the temporary time_diff column
    df = df.drop(columns=['time_diff'])

    # Ensure glucose values are in float64
    df['gl'] = df['gl'].astype('float64')

    # Optionally drop duplicate rows based on time
    if drop_duplicates:
        df = df.drop_duplicates(subset=['time'], keep='first')

    # Write the modified dataframe to a new CSV file
    df.to_csv(output_file, index=False)

    typer.echo("CSV files have been successfully merged, modified, and saved.")

    return df


def process_multiple_csv(
    input_dir: Path = typer.Argument('./raw_data/livia_unmerged', help="Directory containing the input CSV files."),
    output_file: Path = typer.Argument('./raw_data/livia_unmerged/livia_mini.csv', help="Path to save the processed CSV file."),
    event_type_filter: str = typer.Option('egv', help="Event type to filter by."),
    drop_duplicates: bool = typer.Option(True, help="Whether to drop duplicate timestamps."),
    time_diff_minutes: int = typer.Option(1, help="Minimum time difference in minutes to keep a row."),
    chunk_size: int = typer.Option(1000, help="Chunk size for the 'id' column increment. Set to 0 or None for a single id."),
):
    # Get all the CSV files in the specified directory
    all_files = list(input_dir.glob("*.csv"))

    # List to store the DataFrames
    df_list = []

    # Read each CSV file into a DataFrame and append to the list
    for filename in all_files:
        df = pd.read_csv(filename, low_memory=False)
        df_list.append(df)

    # Concatenate all DataFrames in the list
    combined_df = pd.concat(df_list, ignore_index=True)

    # Filter by Event Type and Event Subtype
    combined_df = combined_df[combined_df['Event Type'].str.lower() == event_type_filter]
    combined_df = combined_df[combined_df['Event Subtype'].isna()]

    # List of columns to keep
    columns_to_keep = [
        'Index',
        'Timestamp (YYYY-MM-DDThh:mm:ss)',
        'Glucose Value (mg/dL)',
    ]

    # Keep only the specified columns
    combined_df = combined_df[columns_to_keep]

    # Rename columns
    column_rename = {
        'Index': 'id',
        'Timestamp (YYYY-MM-DDThh:mm:ss)': 'time',
        'Glucose Value (mg/dL)': 'gl'
    }
    combined_df = combined_df.rename(columns=column_rename)

    # Sort the combined DataFrame by timestamp
    combined_df = combined_df.sort_values('time')

    # Handle id assignment based on chunk_size
    if chunk_size is None or chunk_size == 0:
        combined_df['id'] = 1  # Assign the same id to all rows
    else:
        combined_df['id'] = ((combined_df.index // chunk_size) % (combined_df.index.max() // chunk_size + 1)).astype(int)

    # Convert timestamp to datetime
    combined_df['time'] = pd.to_datetime(combined_df['time'])

    # Calculate time difference and keep rows with at least the specified time difference
    combined_df['time_diff'] = combined_df['time'].diff()
    combined_df = combined_df[combined_df['time_diff'].isna() | (combined_df['time_diff'] >= pd.Timedelta(minutes=time_diff_minutes))]

    # Drop the temporary time_diff column
    combined_df = combined_df.drop(columns=['time_diff'])

    # Ensure glucose values are in float64
    combined_df['gl'] = combined_df['gl'].astype('float64')

    # Optionally drop duplicate rows based on time
    if drop_duplicates:
        combined_df = combined_df.drop_duplicates(subset=['time'], keep='first')

    # Write the modified dataframe to a new CSV file
    combined_df.to_csv(output_file, index=False)

    typer.echo("CSV files have been successfully merged, modified, and saved.")

if __name__ == "__main__":
    typer.run(process_csv)
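Since the module ends with typer.run(process_csv), the single-file path is the CLI entry point. A plausible invocation (paths are illustrative, relying on Typer's usual hyphenated option names) is python format_dexcom.py ./raw_data/export.csv ./raw_data/anton.csv --chunk-size 0, which assigns one id to all rows; process_multiple_csv is not registered with Typer here, so merging a directory of exports would require calling it from Python.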
gluformer/__init__.py ADDED: file without changes
gluformer/__pycache__/__init__.cpython-311.pyc ADDED: binary file (177 Bytes)
gluformer/__pycache__/attention.cpython-311.pyc ADDED: binary file (5.85 kB)
gluformer/__pycache__/decoder.cpython-311.pyc ADDED: binary file (3.65 kB)
gluformer/__pycache__/embed.cpython-311.pyc ADDED: binary file (6.37 kB)
gluformer/__pycache__/encoder.cpython-311.pyc ADDED: binary file (5.28 kB)
gluformer/__pycache__/model.cpython-311.pyc ADDED: binary file (15.9 kB)
gluformer/__pycache__/variance.cpython-311.pyc ADDED: binary file (1.89 kB)
gluformer/attention.py ADDED
@@ -0,0 +1,70 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from math import sqrt

class CausalConv1d(torch.nn.Conv1d):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 dilation=1,
                 groups=1,
                 bias=True):
        self.__padding = (kernel_size - 1) * dilation

        super(CausalConv1d, self).__init__(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=self.__padding,
            dilation=dilation,
            groups=groups,
            bias=bias)

    def forward(self, input):
        result = super(CausalConv1d, self).forward(input)
        if self.__padding != 0:
            return result[:, :, :-self.__padding]
        return result

class TriangularCausalMask():
    def __init__(self, b, n, device="cpu"):
        mask_shape = [b, 1, n, n]
        with torch.no_grad():
            self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)

    @property
    def mask(self):
        return self._mask

class MultiheadAttention(nn.Module):
    def __init__(self, d_model, n_heads, d_keys, mask_flag, r_att_drop=0.1):
        super(MultiheadAttention, self).__init__()
        self.h, self.d, self.mask_flag = n_heads, d_keys, mask_flag
        self.proj_q = nn.Linear(d_model, self.h * self.d)
        self.proj_k = nn.Linear(d_model, self.h * self.d)
        self.proj_v = nn.Linear(d_model, self.h * self.d)
        self.proj_out = nn.Linear(self.h * self.d, d_model)
        self.dropout = nn.Dropout(r_att_drop)

    def forward(self, q, k, v):
        b, n_q, n_k, h, d = q.size(0), q.size(1), k.size(1), self.h, self.d

        q, k, v = self.proj_q(q), self.proj_k(k), self.proj_v(v)    # b, n_*, h*d
        q, k, v = map(lambda x: x.reshape(b, -1, h, d), [q, k, v])  # b, n_*, h, d
        scores = torch.einsum('bnhd,bmhd->bhnm', (q, k))            # b, h, n_q, n_k

        if self.mask_flag:
            att_mask = TriangularCausalMask(b, n_q, device=q.device)
            scores.masked_fill_(att_mask.mask, -np.inf)

        att = F.softmax(scores / (self.d ** .5), dim=-1)            # b, h, n_q, n_k
        att = self.dropout(att)
        att_out = torch.einsum('bhnm,bmhd->bnhd', (att, v))         # b, n_q, h, d
        att_out = att_out.reshape(b, -1, h*d)                       # b, n_q, h*d
        out = self.proj_out(att_out)                                # b, n_q, d_model
        return out
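A quick shape check for MultiheadAttention (illustrative; tensor sizes chosen to match the transformer config in files/config.yaml, not part of the commit):

import torch
from gluformer.attention import MultiheadAttention

att = MultiheadAttention(d_model=96, n_heads=4, d_keys=24, mask_flag=True)
x = torch.randn(2, 16, 96)   # (batch, length, d_model)
out = att(x, x, x)           # causal mask applied since mask_flag=True
print(out.shape)             # torch.Size([2, 16, 96])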
gluformer/decoder.py ADDED
@@ -0,0 +1,50 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .attention import *

class DecoderLayer(nn.Module):
    def __init__(self, self_att, cross_att, d_model, d_fcn,
                 r_drop, activ="relu"):
        super(DecoderLayer, self).__init__()

        self.self_att = self_att
        self.cross_att = cross_att

        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_fcn, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_fcn, out_channels=d_model, kernel_size=1)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(r_drop)
        self.activ = F.relu if activ == "relu" else F.gelu

    def forward(self, x_dec, x_enc):
        x_dec = x_dec + self.self_att(x_dec, x_dec, x_dec)
        x_dec = self.norm1(x_dec)

        x_dec = x_dec + self.cross_att(x_dec, x_enc, x_enc)
        res = x_dec = self.norm2(x_dec)

        res = self.dropout(self.activ(self.conv1(res.transpose(-1, 1))))
        res = self.dropout(self.conv2(res).transpose(-1, 1))

        return self.norm3(x_dec + res)

class Decoder(nn.Module):
    def __init__(self, layers, norm_layer=None):
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList(layers)
        self.norm = norm_layer

    def forward(self, x_dec, x_enc):
        for layer in self.layers:
            x_dec = layer(x_dec, x_enc)

        if self.norm is not None:
            x_dec = self.norm(x_dec)

        return x_dec
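One decoder layer combines causal self-attention with unmasked cross-attention, mirroring how model.py assembles the stack; an illustrative smoke test (not part of the commit):

import torch
from gluformer.attention import MultiheadAttention
from gluformer.decoder import DecoderLayer

layer = DecoderLayer(
    self_att=MultiheadAttention(d_model=96, n_heads=4, d_keys=24, mask_flag=True),
    cross_att=MultiheadAttention(d_model=96, n_heads=4, d_keys=24, mask_flag=False),
    d_model=96, d_fcn=448, r_drop=0.1)
x_dec = torch.randn(2, 12, 96)   # decoder tokens
x_enc = torch.randn(2, 48, 96)   # encoder memory (may be a different length)
print(layer(x_dec, x_enc).shape) # torch.Size([2, 12, 96])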
gluformer/embed.py ADDED
@@ -0,0 +1,69 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class PositionalEmbedding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEmbedding, self).__init__()
        # Compute the positional encodings once in log space.
        pos_emb = torch.zeros(max_len, d_model)
        pos_emb.require_grad = False

        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = (torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)).exp()

        pos_emb[:, 0::2] = torch.sin(position * div_term)
        pos_emb[:, 1::2] = torch.cos(position * div_term)

        pos_emb = pos_emb.unsqueeze(0)
        self.register_buffer('pos_emb', pos_emb)

    def forward(self, x):
        return self.pos_emb[:, :x.size(1)]

class TokenEmbedding(nn.Module):
    def __init__(self, d_model):
        super(TokenEmbedding, self).__init__()
        D_INP = 1  # one sequence
        self.conv = nn.Conv1d(in_channels=D_INP, out_channels=d_model,
                              kernel_size=3, padding=1, padding_mode='circular')
        # nn.init.kaiming_normal_(self.conv.weight, mode='fan_in', nonlinearity='leaky_relu')

    def forward(self, x):
        x = self.conv(x.transpose(-1, 1)).transpose(-1, 1)
        return x

class TemporalEmbedding(nn.Module):
    def __init__(self, d_model, num_features):
        super(TemporalEmbedding, self).__init__()
        self.embed = nn.Linear(num_features, d_model)

    def forward(self, x):
        return self.embed(x)

class SubjectEmbedding(nn.Module):
    def __init__(self, d_model, num_features):
        super(SubjectEmbedding, self).__init__()
        self.id_embedding = nn.Linear(num_features, d_model)

    def forward(self, x):
        embed_x = self.id_embedding(x)

        return embed_x

class DataEmbedding(nn.Module):
    def __init__(self, d_model, r_drop, num_dynamic_features, num_static_features):
        super(DataEmbedding, self).__init__()
        # note: d_model // 2 == 0
        self.value_embedding = TokenEmbedding(d_model)
        self.time_embedding = TemporalEmbedding(d_model, num_dynamic_features)  # alternative: TimeFeatureEmbedding
        self.positional_embedding = PositionalEmbedding(d_model)
        self.subject_embedding = SubjectEmbedding(d_model, num_static_features)
        self.dropout = nn.Dropout(r_drop)

    def forward(self, x_id, x, x_mark):
        x = self.value_embedding(x) + self.positional_embedding(x) + self.time_embedding(x_mark)
        x_id = self.subject_embedding(x_id)
        x = torch.cat((x_id.unsqueeze(1), x), dim=1)
        return self.dropout(x)
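A shape sketch for DataEmbedding (illustrative, not part of the commit): the subject embedding is prepended as an extra token, so the output is one step longer than the input series. Six dynamic features matches the six date fields extracted under encoding_params in files/config.yaml.

import torch
from gluformer.embed import DataEmbedding

emb = DataEmbedding(d_model=96, r_drop=0.1, num_dynamic_features=6, num_static_features=1)
x_id = torch.randn(2, 1)           # static (subject-level) features
x = torch.randn(2, 96, 1)          # glucose series (batch, length, 1)
x_mark = torch.randn(2, 96, 6)     # time features per step
print(emb(x_id, x, x_mark).shape)  # torch.Size([2, 97, 96])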
gluformer/encoder.py ADDED
@@ -0,0 +1,67 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .attention import *

class ConvLayer(nn.Module):
    def __init__(self, d_model):
        super(ConvLayer, self).__init__()
        self.downConv = nn.Conv1d(in_channels=d_model, out_channels=d_model,
                                  kernel_size=3, padding=1, padding_mode='circular')
        self.norm = nn.BatchNorm1d(d_model)
        self.activ = nn.ELU()
        self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        x = self.downConv(x.transpose(-1, 1))
        x = self.norm(x)
        x = self.activ(x)
        x = self.maxPool(x)
        x = x.transpose(-1, 1)
        return x

class EncoderLayer(nn.Module):
    def __init__(self, att, d_model, d_fcn, r_drop, activ="relu"):
        super(EncoderLayer, self).__init__()

        self.att = att
        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_fcn, kernel_size=1)
        self.conv2 = nn.Conv1d(in_channels=d_fcn, out_channels=d_model, kernel_size=1)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(r_drop)
        self.activ = F.relu if activ == "relu" else F.gelu

    def forward(self, x):
        new_x = self.att(x, x, x)
        x = x + self.dropout(new_x)

        res = x = self.norm1(x)
        res = self.dropout(self.activ(self.conv1(res.transpose(-1, 1))))
        res = self.dropout(self.conv2(res).transpose(-1, 1))

        return self.norm2(x + res)

class Encoder(nn.Module):
    def __init__(self, enc_layers, conv_layers=None, norm_layer=None):
        super(Encoder, self).__init__()
        self.enc_layers = nn.ModuleList(enc_layers)
        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
        self.norm = norm_layer

    def forward(self, x):
        # x [B, L, D]
        if self.conv_layers is not None:
            for enc_layer, conv_layer in zip(self.enc_layers, self.conv_layers):
                x = enc_layer(x)
                x = conv_layer(x)
            x = self.enc_layers[-1](x)
        else:
            for enc_layer in self.enc_layers:
                x = enc_layer(x)

        if self.norm is not None:
            x = self.norm(x)

        return x
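Each ConvLayer halves the sequence length via the stride-2 max pool, which is what the distil option in gluformer/model.py wires between encoder layers; an illustrative check (not part of the commit):

import torch
from gluformer.encoder import ConvLayer

layer = ConvLayer(d_model=96)
x = torch.randn(2, 96, 96)   # (batch, length, d_model)
print(layer(x).shape)        # torch.Size([2, 48, 96])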
gluformer/model.py
ADDED
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
import sys
from tqdm import tqdm

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from .embed import *
from .attention import *
from .encoder import *
from .decoder import *
from .variance import *

############################################
# Added for GluNet package
############################################
import optuna
import darts
from torch.utils.tensorboard import SummaryWriter
sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
from glucose.gluformer.utils.training import ExpLikeliLoss, \
                                             EarlyStop, \
                                             modify_collate, \
                                             adjust_learning_rate
from glucose.utils.darts_dataset import SamplingDatasetDual
############################################

class Gluformer(nn.Module):
    def __init__(self, d_model, n_heads, d_fcn, r_drop,
                 activ, num_enc_layers, num_dec_layers,
                 distil, len_seq, len_pred, num_dynamic_features,
                 num_static_features, label_len):
        super(Gluformer, self).__init__()
        # Set prediction length
        self.len_pred = len_pred
        self.label_len = label_len
        # Embedding
        # note: d_model must be even (d_model % 2 == 0) for the sinusoidal embedding
        self.enc_embedding = DataEmbedding(d_model, r_drop, num_dynamic_features, num_static_features)
        self.dec_embedding = DataEmbedding(d_model, r_drop, num_dynamic_features, num_static_features)
        # Encoding
        self.encoder = Encoder(
            [
                EncoderLayer(
                    att=MultiheadAttention(d_model=d_model, n_heads=n_heads,
                                           d_keys=d_model//n_heads, mask_flag=False,
                                           r_att_drop=r_drop),
                    d_model=d_model,
                    d_fcn=d_fcn,
                    r_drop=r_drop,
                    activ=activ) for l in range(num_enc_layers)
            ],
            [
                ConvLayer(
                    d_model) for l in range(num_enc_layers-1)
            ] if distil else None,
            norm_layer=torch.nn.LayerNorm(d_model)
        )

        # Decoding
        self.decoder = Decoder(
            [
                DecoderLayer(
                    self_att=MultiheadAttention(d_model=d_model, n_heads=n_heads,
                                                d_keys=d_model//n_heads, mask_flag=True,
                                                r_att_drop=r_drop),
                    cross_att=MultiheadAttention(d_model=d_model, n_heads=n_heads,
                                                 d_keys=d_model//n_heads, mask_flag=False,
                                                 r_att_drop=r_drop),
                    d_model=d_model,
                    d_fcn=d_fcn,
                    r_drop=r_drop,
                    activ=activ) for l in range(num_dec_layers)
            ],
            norm_layer=torch.nn.LayerNorm(d_model)
        )

        # Output
        D_OUT = 1
        self.projection = nn.Linear(d_model, D_OUT, bias=True)

        # Train variance
        self.var = Variance(d_model, r_drop, len_seq)

    def forward(self, x_id, x_enc, x_mark_enc, x_dec, x_mark_dec):
        enc_out = self.enc_embedding(x_id, x_enc, x_mark_enc)
        var_out = self.var(enc_out)
        enc_out = self.encoder(enc_out)

        dec_out = self.dec_embedding(x_id, x_dec, x_mark_dec)
        dec_out = self.decoder(dec_out, enc_out)
        dec_out = self.projection(dec_out)

        return dec_out[:, -self.len_pred:, :], var_out  # [B, L, D], log variance

    ############################################
    # Added for GluNet package
    ############################################
    def fit(self,
            train_dataset: SamplingDatasetDual,
            val_dataset: SamplingDatasetDual,
            learning_rate: float = 1e-3,
            batch_size: int = 32,
            epochs: int = 100,
            num_samples: int = 100,
            device: str = 'cuda',
            model_path: str = None,
            trial: optuna.trial.Trial = None,
            logger: SummaryWriter = None):
        """
        Fit the model to the data, using Optuna for hyperparameter tuning.

        Parameters
        ----------
        train_dataset: SamplingDatasetDual
            Training dataset.
        val_dataset: SamplingDatasetDual
            Validation dataset.
        learning_rate: float
            Learning rate for Adam.
        batch_size: int
            Batch size.
        epochs: int
            Number of epochs.
        num_samples: int
            Number of samples for the infinite mixture.
        device: str
            Device to use.
        model_path: str
            Path to save the model.
        trial: optuna.trial.Trial
            Trial for hyperparameter tuning.
        logger: SummaryWriter
            TensorBoard logger for logging.
        """
        # create data loaders, optimizer, loss, and early stopping
        collate_fn_custom = modify_collate(num_samples)
        train_loader = torch.utils.data.DataLoader(train_dataset,
                                                   batch_size=batch_size,
                                                   shuffle=True,
                                                   drop_last=True,
                                                   collate_fn=collate_fn_custom)
        val_loader = torch.utils.data.DataLoader(val_dataset,
                                                 batch_size=batch_size,
                                                 shuffle=True,
                                                 drop_last=True,
                                                 collate_fn=collate_fn_custom)
        criterion = ExpLikeliLoss(num_samples)
        optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate, betas=(0.1, 0.9))
        scaler = torch.cuda.amp.GradScaler()
        early_stop = EarlyStop(patience=10, delta=0.001)
        self.to(device)
        # train and evaluate the model
        for epoch in range(epochs):
            train_loss = []
            for i, (past_target_series,
                    past_covariates,
                    future_covariates,
                    static_covariates,
                    future_target_series) in enumerate(train_loader):
                # zero out gradient
                optimizer.zero_grad()
                # reshape static covariates to be [batch_size, num_static_covariates]
                static_covariates = static_covariates.reshape(-1, static_covariates.shape[-1])
                # create decoder input: pad the prediction sequence with zeros
                dec_inp = torch.cat([past_target_series[:, -self.label_len:, :],
                                     torch.zeros([
                                         past_target_series.shape[0],
                                         self.len_pred,
                                         past_target_series.shape[-1]
                                     ])],
                                    dim=1)
                future_covariates = torch.cat([past_covariates[:, -self.label_len:, :],
                                               future_covariates], dim=1)
                # move to device
                dec_inp = dec_inp.to(device)
                past_target_series = past_target_series.to(device)
                past_covariates = past_covariates.to(device)
                future_covariates = future_covariates.to(device)
                static_covariates = static_covariates.to(device)
                future_target_series = future_target_series.to(device)
                # forward pass with autograd
                with torch.cuda.amp.autocast():
                    pred, logvar = self(static_covariates,
                                        past_target_series,
                                        past_covariates,
                                        dec_inp,
                                        future_covariates)
                    loss = criterion(pred, future_target_series, logvar)
                # backward pass
                scaler.scale(loss).backward()
                scaler.step(optimizer)
                scaler.update()
                # log loss
                if logger is not None:
                    logger.add_scalar('train_loss', loss.item(), epoch * len(train_loader) + i)
                train_loss.append(loss.item())
            # log loss
            if logger is not None:
                logger.add_scalar('train_loss_epoch', np.mean(train_loss), epoch)
            # evaluate the model
            val_loss = []
            with torch.no_grad():
                for i, (past_target_series,
                        past_covariates,
                        future_covariates,
                        static_covariates,
                        future_target_series) in enumerate(val_loader):
                    # reshape static covariates to be [batch_size, num_static_covariates]
                    static_covariates = static_covariates.reshape(-1, static_covariates.shape[-1])
                    # create decoder input
                    dec_inp = torch.cat([past_target_series[:, -self.label_len:, :],
                                         torch.zeros([
                                             past_target_series.shape[0],
                                             self.len_pred,
                                             past_target_series.shape[-1]
                                         ])],
                                        dim=1)
                    future_covariates = torch.cat([past_covariates[:, -self.label_len:, :],
                                                   future_covariates], dim=1)
                    # move to device
                    dec_inp = dec_inp.to(device)
                    past_target_series = past_target_series.to(device)
                    past_covariates = past_covariates.to(device)
                    future_covariates = future_covariates.to(device)
                    static_covariates = static_covariates.to(device)
                    future_target_series = future_target_series.to(device)
                    # forward pass
                    pred, logvar = self(static_covariates,
                                        past_target_series,
                                        past_covariates,
                                        dec_inp,
                                        future_covariates)
                    loss = criterion(pred, future_target_series, logvar)
                    val_loss.append(loss.item())
                    # log loss
                    if logger is not None:
                        logger.add_scalar('val_loss', loss.item(), epoch * len(val_loader) + i)
            # log loss
            if logger is not None:
                logger.add_scalar('val_loss_epoch', np.mean(val_loss), epoch)
            # check early stopping
            early_stop(np.mean(val_loss), self, model_path)
            if early_stop.stop:
                break
            # check pruning
            if trial is not None:
                trial.report(np.mean(val_loss), epoch)
                if trial.should_prune():
                    raise optuna.exceptions.TrialPruned()
        # load best model
        if model_path is not None:
            self.load_state_dict(torch.load(model_path))

    def predict(self, test_dataset: SamplingDatasetDual,
                batch_size: int = 32,
                num_samples: int = 100,
                device: str = 'cuda'):
        """
        Predict the future target series given the supplied samples from the dataset.

        Parameters
        ----------
        test_dataset : SamplingDatasetInferenceDual
            The dataset to use for inference.
        batch_size : int, optional
            The batch size to use for inference, by default 32
        num_samples : int, optional
            The number of samples to use for inference, by default 100

        Returns
        -------
        Predictions
            The predicted future target series in shape n x len_pred x num_samples, where
            n is the total number of predictions.
        Logvar
            The log variance of the predicted future target series.
        """
        # define data loader
        collate_fn_custom = modify_collate(num_samples)
        test_loader = torch.utils.data.DataLoader(test_dataset,
                                                  batch_size=batch_size,
                                                  shuffle=False,
                                                  drop_last=False,
                                                  collate_fn=collate_fn_custom)
        # predict: the model stays in train mode so dropout remains active and the
        # num_samples replicated copies of each input yield distinct mixture draws
        self.train()
        # move to device
        self.to(device)
        predictions = []; logvars = []
        for i, (past_target_series,
                historic_future_covariates,
                future_covariates,
                static_covariates) in enumerate(test_loader):
            # reshape static covariates to be [batch_size, num_static_covariates]
            static_covariates = static_covariates.reshape(-1, static_covariates.shape[-1])
            # create decoder input
            dec_inp = torch.cat([past_target_series[:, -self.label_len:, :],
                                 torch.zeros([
                                     past_target_series.shape[0],
                                     self.len_pred,
                                     past_target_series.shape[-1]
                                 ])],
                                dim=1)
            future_covariates = torch.cat([historic_future_covariates[:, -self.label_len:, :],
                                           future_covariates], dim=1)
            # move to device
            dec_inp = dec_inp.to(device)
            past_target_series = past_target_series.to(device)
            historic_future_covariates = historic_future_covariates.to(device)
            future_covariates = future_covariates.to(device)
            static_covariates = static_covariates.to(device)
            # forward pass
            pred, logvar = self(static_covariates,
                                past_target_series,
                                historic_future_covariates,
                                dec_inp,
                                future_covariates)
            # transfer to numpy and arrange samples along the last axis
            pred = pred.cpu().detach().numpy()
            logvar = logvar.cpu().detach().numpy()
            pred = pred.transpose((1, 0, 2)).reshape((pred.shape[1], -1, num_samples)).transpose((1, 0, 2))
            logvar = logvar.transpose((1, 0, 2)).reshape((logvar.shape[1], -1, num_samples)).transpose((1, 0, 2))
            predictions.append(pred)
            logvars.append(logvar)
        predictions = np.concatenate(predictions, axis=0)
        logvars = np.concatenate(logvars, axis=0)
        return predictions, logvars

############################################
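For orientation, here is a minimal usage sketch of `fit` and `predict`. The dataset objects and all hyperparameter values below are illustrative assumptions, not values taken from this commit; the concrete configuration used by the app is in `tools.py` further down.

# Hypothetical usage sketch for Gluformer.fit / Gluformer.predict.
import torch
from gluformer.model import Gluformer

model = Gluformer(d_model=512, n_heads=8, d_fcn=1024, r_drop=0.2,
                  activ='gelu', num_enc_layers=2, num_dec_layers=2,
                  distil=True, len_seq=96, len_pred=12,
                  num_dynamic_features=4, num_static_features=1,
                  label_len=32)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# train_ds / val_ds are SamplingDatasetDual instances built elsewhere
model.fit(train_ds, val_ds, learning_rate=1e-3, batch_size=32,
          epochs=100, num_samples=100, device=device,
          model_path='gluformer_best.pth')

# test_ds is a SamplingDatasetInferenceDual; forecasts comes back as an
# array of shape (n, len_pred, num_samples)
forecasts, logvars = model.predict(test_ds, batch_size=32,
                                   num_samples=100, device=device)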
gluformer/utils/__init__.py
ADDED
File without changes
gluformer/utils/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (183 Bytes)

gluformer/utils/__pycache__/collate.cpython-311.pyc
ADDED
Binary file (7.14 kB)

gluformer/utils/__pycache__/training.cpython-311.pyc
ADDED
Binary file (6.84 kB)
gluformer/utils/collate.py
ADDED
@@ -0,0 +1,84 @@
r"""Contains definitions of the methods used by the _BaseDataLoaderIter workers to
collate samples fetched from dataset into Tensor(s).
These **need** to be in global scope since Py2 doesn't support serializing
static methods.
"""

import torch
import re
import collections

np_str_obj_array_pattern = re.compile(r'[SaUO]')


def default_convert(data):
    r"""Converts each NumPy array data field into a tensor"""
    elem_type = type(data)
    if isinstance(data, torch.Tensor):
        return data
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        # array of string classes and object
        if elem_type.__name__ == 'ndarray' \
                and np_str_obj_array_pattern.search(data.dtype.str) is not None:
            return data
        return torch.as_tensor(data)
    elif isinstance(data, collections.abc.Mapping):
        return {key: default_convert(data[key]) for key in data}
    elif isinstance(data, tuple) and hasattr(data, '_fields'):  # namedtuple
        return elem_type(*(default_convert(d) for d in data))
    elif isinstance(data, collections.abc.Sequence) and not isinstance(data, str):
        return [default_convert(d) for d in data]
    else:
        return data


default_collate_err_msg_format = (
    "default_collate: batch must contain tensors, numpy arrays, numbers, "
    "dicts or lists; found {}")


def default_collate(batch):
    r"""Puts each data field into a tensor with outer dimension batch size"""

    elem = batch[0]
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        out = None
        if torch.utils.data.get_worker_info() is not None:
            # If we're in a background process, concatenate directly into a
            # shared memory tensor to avoid an extra copy
            numel = sum(x.numel() for x in batch)
            storage = elem.storage()._new_shared(numel)
            out = elem.new(storage)
        return torch.stack(batch, 0, out=out)
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            # array of string classes and object
            if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
                raise TypeError(default_collate_err_msg_format.format(elem.dtype))

            return default_collate([torch.as_tensor(b) for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, str):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: default_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
        return elem_type(*(default_collate(samples) for samples in zip(*batch)))
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [default_collate(samples) for samples in transposed]

    raise TypeError(default_collate_err_msg_format.format(elem_type))
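To make the behaviour concrete, a quick illustrative check (not part of the commit): collating a batch of (input, target) tuples of NumPy arrays stacks each field along a new batch dimension.

import numpy as np
from gluformer.utils.collate import default_collate

# A batch of three (input, target) tuples, as produced by the sampling
# datasets below; shapes here are arbitrary examples.
batch = [(np.zeros((96, 1)), np.ones((12, 1))) for _ in range(3)]
inputs, targets = default_collate(batch)
print(inputs.shape, targets.shape)  # torch.Size([3, 96, 1]) torch.Size([3, 12, 1])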
gluformer/utils/evaluation.py
ADDED
@@ -0,0 +1,81 @@
import sys
import os
import yaml
import random
from typing import Any, \
                   BinaryIO, \
                   Callable, \
                   Dict, \
                   List, \
                   Optional, \
                   Sequence, \
                   Tuple, \
                   Union

import numpy as np
import scipy as sp
import pandas as pd
import torch

# import data formatter
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))

def test(series: np.ndarray,
         forecasts: np.ndarray,
         var: np.ndarray,
         cal_thresholds: Optional[np.ndarray] = np.linspace(0, 1, 11),
         ):
    """
    Test the (rescaled to original scale) forecasts on the series.

    Parameters
    ----------
    series
        The target time series of shape (n, t),
        where t is the length of the prediction.
    forecasts
        The forecasted means of the mixture components, of shape (n, t, k),
        where k is the number of mixture components.
    var
        The forecasted variances of the mixture components, of shape (n, 1, k),
        where k is the number of mixture components.
    cal_thresholds
        The thresholds to use for computing the calibration error.

    Returns
    -------
    np.ndarray
        Error array of shape (n, 2) holding the MSE and MAE per series,
        where n = series.shape[0] = forecasts.shape[0].
    float
        The estimated log-likelihood of the model on the data.
    np.ndarray
        The squared calibration error accumulated over `cal_thresholds`,
        per time point in the forecast.
    """
    # compute errors: 1) get samples 2) compute errors using the mixture mean
    samples = np.random.normal(loc=forecasts[..., None],
                               scale=np.sqrt(var)[..., None],
                               size=(forecasts.shape[0],
                                     forecasts.shape[1],
                                     forecasts.shape[2],
                                     30))
    samples = samples.reshape(samples.shape[0], samples.shape[1], -1)
    mse = np.mean((series.squeeze() - forecasts.mean(axis=-1))**2, axis=-1)
    mae = np.mean(np.abs(series.squeeze() - forecasts.mean(axis=-1)), axis=-1)
    errors = np.stack([mse, mae], axis=-1)

    # compute likelihood
    # per-component Gaussian log-density: -(x - mu)^2 / (2 var) - 0.5 log(2 pi var),
    # combined over components via logsumexp
    log_likelihood = sp.special.logsumexp(-(forecasts - series)**2 / (2 * var) -
                                          0.5 * np.log(2 * np.pi * var), axis=-1)
    log_likelihood = np.mean(log_likelihood)

    # compute calibration error:
    cal_error = np.zeros(forecasts.shape[1])
    for p in cal_thresholds:
        q = np.quantile(samples, p, axis=-1)
        est_p = np.mean(series.squeeze() <= q, axis=0)
        cal_error += (est_p - p) ** 2

    return errors, log_likelihood, cal_error
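As a sanity check, `test` can be exercised on synthetic data. This sketch is illustrative only; it assumes the target carries a trailing singleton axis (which is why the function calls `series.squeeze()`), and all shapes follow the docstring.

import numpy as np
from gluformer.utils.evaluation import test

n, t, k = 8, 12, 10                      # series, horizon, mixture components
series = np.random.randn(n, t, 1)        # targets with a trailing singleton axis
forecasts = np.random.randn(n, t, k)     # mixture means
var = np.exp(np.random.randn(n, 1, k))   # mixture variances (positive)

errors, ll, cal = test(series, forecasts, var)
print(errors.shape, ll, cal.shape)       # (8, 2), scalar, (12,)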
gluformer/utils/training.py
ADDED
@@ -0,0 +1,80 @@
from typing import List, Callable, Any, Tuple

import numpy as np
import torch
from torch import nn, Tensor

from .collate import default_collate

class EarlyStop:
    def __init__(self, patience: int, delta: float):
        self.patience: int = patience
        self.delta: float = delta
        self.counter: int = 0
        self.best_loss: float = np.Inf
        self.stop: bool = False

    def __call__(self, loss: float, model: nn.Module, path: str) -> None:
        if loss < self.best_loss:
            self.best_loss = loss
            self.counter = 0
            torch.save(model.state_dict(), path)
        elif loss > self.best_loss + self.delta:
            self.counter = self.counter + 1
            if self.counter >= self.patience:
                self.stop = True

class ExpLikeliLoss(nn.Module):
    def __init__(self, num_samples: int = 100):
        super(ExpLikeliLoss, self).__init__()
        self.num_samples: int = num_samples

    def forward(self, pred: Tensor, true: Tensor, logvar: Tensor) -> Tensor:
        b, l, d = pred.size(0), pred.size(1), pred.size(2)
        # regroup the num_samples replicated copies of each series along the last axis
        true = true.transpose(0, 1).reshape(l, -1, self.num_samples).transpose(0, 1)
        pred = pred.transpose(0, 1).reshape(l, -1, self.num_samples).transpose(0, 1)
        logvar = logvar.reshape(-1, self.num_samples)

        loss = torch.mean((-1) * torch.logsumexp((-l / 2) * logvar + (-1 / (2 * torch.exp(logvar))) * torch.sum((true - pred) ** 2, dim=1), dim=1))
        return loss

def modify_collate(num_samples: int) -> Callable[[List[Any]], Any]:
    """Repeat each sample `num_samples` times before collating, so every batch
    element yields `num_samples` stochastic forward passes through the model."""
    def wrapper(batch: List[Any]) -> Any:
        batch_rep = [sample for sample in batch for _ in range(num_samples)]
        result = default_collate(batch_rep)
        return result
    return wrapper

def adjust_learning_rate(model_optim: torch.optim.Optimizer, epoch: int, lr: float) -> None:
    lr = lr * (0.5 ** epoch)
    print("Learning rate halving...")
    print(f"New lr: {lr:.7f}")
    for param_group in model_optim.param_groups:
        param_group['lr'] = lr

def process_batch(
        subj_id: Tensor,
        batch_x: Tensor,
        batch_y: Tensor,
        batch_x_mark: Tensor,
        batch_y_mark: Tensor,
        len_pred: int,
        len_label: int,
        model: nn.Module,
        device: torch.device
) -> Tuple[Tensor, Tensor, Tensor]:
    subj_id = subj_id.long().to(device)
    batch_x = batch_x.float().to(device)
    batch_y = batch_y.float()
    batch_x_mark = batch_x_mark.float().to(device)
    batch_y_mark = batch_y_mark.float().to(device)

    true = batch_y[:, -len_pred:, :].to(device)

    dec_inp = torch.zeros([batch_y.shape[0], len_pred, batch_y.shape[-1]], dtype=torch.float, device=device)
    dec_inp = torch.cat([batch_y[:, :len_label, :].to(device), dec_inp], dim=1)

    pred, logvar = model(subj_id, batch_x, batch_x_mark, dec_inp, batch_y_mark)

    return pred, true, logvar
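Reading `ExpLikeliLoss` off the code (this is an interpretation, not a formula quoted from the commit): after regrouping the `num_samples` replicated copies of each series, the loss is the negative log of a sum of per-sample Gaussian trajectory likelihoods, up to additive constants and the mixture-weight term:

$$\mathcal{L} = -\frac{1}{G}\sum_{g=1}^{G}\log\sum_{j=1}^{S}\exp\!\Big(-\tfrac{\ell}{2}\,s_{gj} - \tfrac{1}{2e^{s_{gj}}}\sum_{t=1}^{\ell}\big(y_{g,t}-\hat{y}_{g,t,j}\big)^{2}\Big)$$

where S is `num_samples`, \ell the prediction length, and s_{gj} the predicted log-variance. The S components come from `modify_collate`, which replicates each sample before collation; a quick illustrative check:

import numpy as np
from gluformer.utils.training import modify_collate

collate = modify_collate(num_samples=3)
batch = [(np.zeros((2, 1)), np.ones((2, 1)))]  # one sample in the batch
inputs, targets = collate(batch)
print(inputs.shape)  # torch.Size([3, 2, 1]) -- the sample repeated 3 times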
gluformer/variance.py
ADDED
@@ -0,0 +1,24 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class Variance(nn.Module):
    def __init__(self, d_model, r_drop, len_seq):
        super(Variance, self).__init__()

        self.proj1 = nn.Linear(d_model, 1)
        self.dropout = nn.Dropout(r_drop)
        self.activ1 = nn.ReLU()
        # len_seq + 1: one extra position for the embedded person token
        self.proj2 = nn.Linear(len_seq + 1, 1)
        self.activ2 = nn.Tanh()

    def forward(self, x):
        x = self.proj1(x)
        x = self.activ1(x)
        x = self.dropout(x)
        x = x.transpose(-1, 1)
        x = self.proj2(x)
        # scale the log-variance to the [-10, 10] range
        x = 10 * self.activ2(x)
        return x
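A quick shape trace of this head (illustrative; the dimensions assume the embedding prepends one static "person" token to the sequence, as the `len_seq + 1` projection suggests):

import torch
from gluformer.variance import Variance

d_model, len_seq = 512, 96
head = Variance(d_model=d_model, r_drop=0.1, len_seq=len_seq)

# encoder embedding output: [batch, len_seq + 1 person token, d_model]
x = torch.randn(4, len_seq + 1, d_model)
logvar = head(x)
print(logvar.shape)  # torch.Size([4, 1, 1]) -- one log-variance per series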
main.py
ADDED
@@ -0,0 +1,8 @@
import gradio as gr
from tools import *


def gradio_output():
    return predict_glucose_tool()

# gr.Interface needs explicit inputs/outputs; the prediction figure is
# rendered in a Plot component
gr.Interface(fn=gradio_output, inputs=[], outputs=gr.Plot()).launch()
tools.py
ADDED
@@ -0,0 +1,198 @@
import sys
import os
import pickle
import gzip
import hashlib
from pathlib import Path
from typing import Any
from urllib.parse import urlparse

import seaborn as sns
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from matplotlib.figure import Figure
import torch
from scipy import stats
from huggingface_hub import hf_hub_download

from gluformer.model import Gluformer
from utils.darts_processing import *
from utils.darts_dataset import *


glucose = Path(os.path.abspath(__file__)).parent.resolve()
file_directory = glucose / "files"


def plot_forecast(forecasts: np.ndarray, scalers: Any, dataset_test_glufo: Any, filename: str):
    forecasts = (forecasts - scalers['target'].min_) / scalers['target'].scale_

    trues = [dataset_test_glufo.evalsample(i) for i in range(len(dataset_test_glufo))]
    trues = scalers['target'].inverse_transform(trues)

    trues = [ts.values() for ts in trues]  # convert TimeSeries to numpy arrays
    trues = np.array(trues)

    inputs = [dataset_test_glufo[i][0] for i in range(len(dataset_test_glufo))]
    inputs = (np.array(inputs) - scalers['target'].min_) / scalers['target'].scale_

    # Plot settings
    colors = ['#00264c', '#0a2c62', '#14437f', '#1f5a9d', '#2973bb', '#358ad9', '#4d9af4', '#7bb7ff', '#add5ff', '#e6f3ff']
    cmap = mcolors.LinearSegmentedColormap.from_list('my_colormap', colors)
    sns.set_theme(style="whitegrid")

    # Generate the plot
    fig, ax = plt.subplots(figsize=(10, 6))

    # Select a specific sample to plot
    ind = 30  # example index

    samples = np.random.normal(
        loc=forecasts[ind, :],  # mean (center) of the distribution
        scale=0.1,              # standard deviation (spread) of the distribution
        size=(forecasts.shape[1], forecasts.shape[2])
    )

    # Plot the predictive distribution as a KDE per time point
    for point in range(samples.shape[0]):
        kde = stats.gaussian_kde(samples[point, :])
        maxi, mini = 1.2 * np.max(samples[point, :]), 0.8 * np.min(samples[point, :])
        y_grid = np.linspace(mini, maxi, 200)
        x = kde(y_grid)
        ax.fill_betweenx(y_grid, x1=point, x2=point - x * 15,
                         alpha=0.7,
                         edgecolor='black',
                         color=cmap(point / samples.shape[0]))

    # Plot the median forecast
    median = np.quantile(samples, 0.5, axis=-1)
    ax.plot(np.arange(12), median, color='red', marker='o')

    # Plot the true values (last 12 inputs followed by the 12 targets)
    ax.plot(np.arange(-12, 12), np.concatenate([inputs[ind, -12:], trues[ind, :]]), color='blue')

    # Add labels and title
    ax.set_xlabel('Time (in 5 minute intervals)')
    ax.set_ylabel('Glucose (mg/dL)')
    ax.set_title('Gluformer Prediction with Gradient')

    # Adjust font sizes
    ax.xaxis.label.set_fontsize(16)
    ax.yaxis.label.set_fontsize(16)
    ax.title.set_fontsize(18)
    for item in ax.get_xticklabels() + ax.get_yticklabels():
        item.set_fontsize(14)

    # Save figure
    plt.tight_layout()
    where = file_directory / filename
    plt.savefig(str(where), dpi=300, bbox_inches='tight')

    return where, fig


def generate_filename_from_url(url: str, extension: str = "png") -> str:
    """
    Build a deterministic image filename from the last URL segment plus an
    MD5 hash of the full URL.

    :param url: source URL of the csv file
    :param extension: file extension of the generated image
    :return: the generated filename
    """
    # Extract the last segment of the URL
    last_segment = urlparse(url).path.split('/')[-1]

    # Compute the hash of the URL
    url_hash = hashlib.md5(url.encode('utf-8')).hexdigest()

    # Create the filename
    filename = f"{last_segment.replace('.', '_')}_{url_hash}.{extension}"

    return filename


def predict_glucose_tool(url: str = 'https://huggingface.co/datasets/Livia-Zaharia/glucose_processed/blob/main/livia_mini.csv',
                         model: str = 'https://huggingface.co/Livia-Zaharia/gluformer_models/blob/main/gluformer_1samples_10000epochs_10heads_32batch_geluactivation_livia_mini_weights.pth'
                         ) -> Figure:
    """
    Predict the future glucose of a user. It receives a URL with the user's csv,
    runs the model, and returns a figure with the predictions.

    :param url: URL of the csv file with glucose values
    :param model: URL of the model weights used to predict the glucose
    :return: the forecast figure
    """

    formatter, series, scalers = load_data(url=str(url), config_path=file_directory / "config.yaml", use_covs=True,
                                           cov_type='dual',
                                           use_static_covs=True)

    filename = generate_filename_from_url(url)

    formatter.params['gluformer'] = {
        'in_len': 96,         # example input length, adjust as necessary
        'd_model': 512,       # model dimension
        'n_heads': 10,        # number of attention heads
        'd_fcn': 1024,        # fully connected layer dimension
        'num_enc_layers': 2,  # number of encoder layers
        'num_dec_layers': 2,  # number of decoder layers
        'length_pred': 12     # prediction length, adjust as necessary
    }

    num_dynamic_features = series['train']['future'][-1].n_components
    num_static_features = series['train']['static'][-1].n_components

    glufo = Gluformer(
        d_model=formatter.params['gluformer']['d_model'],
        n_heads=formatter.params['gluformer']['n_heads'],
        d_fcn=formatter.params['gluformer']['d_fcn'],
        r_drop=0.2,
        activ='gelu',
        num_enc_layers=formatter.params['gluformer']['num_enc_layers'],
        num_dec_layers=formatter.params['gluformer']['num_dec_layers'],
        distil=True,
        len_seq=formatter.params['gluformer']['in_len'],
        label_len=formatter.params['gluformer']['in_len'] // 3,
        len_pred=formatter.params['gluformer']['length_pred'],
        num_dynamic_features=num_dynamic_features,
        num_static_features=num_static_features
    )
    # fetch the weights file from the Hugging Face Hub (the `model` URL points
    # at a blob in the Livia-Zaharia/gluformer_models repo)
    weights = hf_hub_download(repo_id="Livia-Zaharia/gluformer_models",
                              filename=urlparse(model).path.split('/')[-1])
    assert Path(weights).exists(), f"weights for {model} should exist"

    device = "cuda" if torch.cuda.is_available() else "cpu"
    glufo.load_state_dict(torch.load(str(weights), map_location=torch.device(device), weights_only=False))

    # Define dataset for inference
    dataset_test_glufo = SamplingDatasetInferenceDual(
        target_series=series['test']['target'],
        covariates=series['test']['future'],
        input_chunk_length=formatter.params['gluformer']['in_len'],
        output_chunk_length=formatter.params['gluformer']['length_pred'],
        use_static_covariates=True,
        array_output_only=True
    )

    forecasts, _ = glufo.predict(
        dataset_test_glufo,
        batch_size=16,
        num_samples=10,
        device='cpu'
    )
    figure_path, result = plot_forecast(forecasts, scalers, dataset_test_glufo, filename)

    return result


if __name__ == "__main__":
    predict_glucose_tool()
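For a quick local smoke test outside Gradio (illustrative sketch; it downloads the default demo dataset and weights, so it needs network access and a few minutes on CPU):

from tools import predict_glucose_tool

fig = predict_glucose_tool()          # uses the default dataset and weights URLs
fig.savefig("forecast_preview.png")   # the tool also saves a copy under files/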
utils/__init__.py
ADDED
File without changes
utils/__pycache__/__init__.cpython-311.pyc
ADDED
Binary file (173 Bytes)

utils/__pycache__/darts_dataset.cpython-311.pyc
ADDED
Binary file (38.6 kB)

utils/__pycache__/darts_processing.cpython-311.pyc
ADDED
Binary file (17.2 kB)
utils/darts_dataset.py
ADDED
@@ -0,0 +1,881 @@
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
import yaml
|
4 |
+
import random
|
5 |
+
from typing import Any, BinaryIO, Callable, Dict, List, Optional, Sequence, Tuple, Union
|
6 |
+
|
7 |
+
import numpy as np
|
8 |
+
from scipy import stats
|
9 |
+
import pandas as pd
|
10 |
+
import darts
|
11 |
+
|
12 |
+
from darts import models
|
13 |
+
from darts import metrics
|
14 |
+
from darts import TimeSeries
|
15 |
+
from darts.dataprocessing.transformers import Scaler
|
16 |
+
from pytorch_lightning.callbacks import Callback
|
17 |
+
|
18 |
+
# for darts dataset
|
19 |
+
from darts.logging import get_logger, raise_if_not
|
20 |
+
|
21 |
+
from darts.utils.data.training_dataset import PastCovariatesTrainingDataset, \
|
22 |
+
DualCovariatesTrainingDataset, \
|
23 |
+
MixedCovariatesTrainingDataset
|
24 |
+
from darts.utils.data.inference_dataset import PastCovariatesInferenceDataset, \
|
25 |
+
DualCovariatesInferenceDataset, \
|
26 |
+
MixedCovariatesInferenceDataset
|
27 |
+
from darts.utils.data.utils import CovariateType
|
28 |
+
|
29 |
+
# import data formatter
|
30 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
|
31 |
+
from data_formatter.base import *
|
32 |
+
|
33 |
+
def get_valid_sampling_locations(target_series: Union[TimeSeries, Sequence[TimeSeries]],
|
34 |
+
output_chunk_length: int = 12,
|
35 |
+
input_chunk_length: int = 12,
|
36 |
+
random_state: Optional[int] = 0,
|
37 |
+
max_samples_per_ts: Optional[int] = None):
|
38 |
+
"""
|
39 |
+
Get valid sampling indices data for the model.
|
40 |
+
|
41 |
+
Parameters
|
42 |
+
----------
|
43 |
+
target_series
|
44 |
+
The target time series.
|
45 |
+
output_chunk_length
|
46 |
+
The length of the output chunk.
|
47 |
+
input_chunk_length
|
48 |
+
The length of the input chunk.
|
49 |
+
use_static_covariates
|
50 |
+
Whether to use static covariates.
|
51 |
+
max_samples_per_ts
|
52 |
+
The maximum number of samples per time series.
|
53 |
+
"""
|
54 |
+
random.seed(random_state)
|
55 |
+
valid_sampling_locations = {}
|
56 |
+
total_length = input_chunk_length + output_chunk_length
|
57 |
+
for id, series in enumerate(target_series):
|
58 |
+
num_entries = len(series)
|
59 |
+
if num_entries >= total_length:
|
60 |
+
valid_sampling_locations[id] = [i for i in range(num_entries - total_length + 1)]
|
61 |
+
if max_samples_per_ts is not None:
|
62 |
+
updated_sampling_locations = {}
|
63 |
+
for id, locations in valid_sampling_locations.items():
|
64 |
+
if len(locations) > max_samples_per_ts:
|
65 |
+
updated_sampling_locations[id] = random.sample(locations, max_samples_per_ts)
|
66 |
+
else:
|
67 |
+
updated_sampling_locations[id] = locations
|
68 |
+
valid_sampling_locations = updated_sampling_locations
|
69 |
+
|
70 |
+
return valid_sampling_locations
|
71 |
+
|
72 |
+
class SamplingDatasetPast(PastCovariatesTrainingDataset):
|
73 |
+
def __init__(
|
74 |
+
self,
|
75 |
+
target_series: Union[TimeSeries, Sequence[TimeSeries]],
|
76 |
+
covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
|
77 |
+
output_chunk_length: int = 12,
|
78 |
+
input_chunk_length: int = 12,
|
79 |
+
use_static_covariates: bool = True,
|
80 |
+
random_state: Optional[int] = 0,
|
81 |
+
max_samples_per_ts: Optional[int] = None,
|
82 |
+
remove_nan: bool = False,
|
83 |
+
) -> None:
|
84 |
+
"""
|
85 |
+
Parameters
|
86 |
+
----------
|
87 |
+
target_series
|
88 |
+
One or a sequence of target `TimeSeries`.
|
89 |
+
covariates:
|
90 |
+
Optionally, one or a sequence of `TimeSeries` containing past-observed covariates. If this parameter is set,
|
91 |
+
the provided sequence must have the same length as that of `target_series`. Moreover, all
|
92 |
+
covariates in the sequence must have a time span large enough to contain all the required slices.
|
93 |
+
The joint slicing of the target and covariates is relying on the time axes of both series.
|
94 |
+
output_chunk_length
|
95 |
+
The length of the "output" series emitted by the model
|
96 |
+
input_chunk_length
|
97 |
+
The length of the "input" series fed to the model
|
98 |
+
use_static_covariates
|
99 |
+
Whether to use/include static covariate data from input series.
|
100 |
+
random_state
|
101 |
+
The random state to use for sampling.
|
102 |
+
max_samples_per_ts
|
103 |
+
The maximum number of samples to be drawn from each time series. If None, all samples will be drawn.
|
104 |
+
remove_nan
|
105 |
+
Whether to remove None from the output. E.g. if no covariates are provided, the covariates output will be None
|
106 |
+
or (optionally) removed from the __getitem__ output.
|
107 |
+
"""
|
108 |
+
super().__init__()
|
109 |
+
self.remove_nan = remove_nan
|
110 |
+
|
111 |
+
self.target_series = (
|
112 |
+
[target_series] if isinstance(target_series, TimeSeries) else target_series
|
113 |
+
)
|
114 |
+
self.covariates = (
|
115 |
+
[covariates] if isinstance(covariates, TimeSeries) else covariates
|
116 |
+
)
|
117 |
+
|
118 |
+
# checks
|
119 |
+
raise_if_not(
|
120 |
+
covariates is None or len(self.target_series) == len(self.covariates),
|
121 |
+
"The provided sequence of target series must have the same length as "
|
122 |
+
"the provided sequence of covariate series.",
|
123 |
+
)
|
124 |
+
|
125 |
+
# get valid sampling locations
|
126 |
+
self.valid_sampling_locations = get_valid_sampling_locations(target_series,
|
127 |
+
output_chunk_length,
|
128 |
+
input_chunk_length,
|
129 |
+
random_state,
|
130 |
+
max_samples_per_ts)
|
131 |
+
|
132 |
+
# set parameters
|
133 |
+
self.output_chunk_length = output_chunk_length
|
134 |
+
self.input_chunk_length = input_chunk_length
|
135 |
+
self.total_length = input_chunk_length + output_chunk_length
|
136 |
+
self.total_number_samples = sum([len(v) for v in self.valid_sampling_locations.values()])
|
137 |
+
self.use_static_covariates = use_static_covariates
|
138 |
+
|
139 |
+
def __len__(self):
|
140 |
+
"""
|
141 |
+
Returns the total number of possible (input, target) splits.
|
142 |
+
"""
|
143 |
+
return self.total_number_samples
|
144 |
+
|
145 |
+
def __getitem__(self, idx: int):
|
146 |
+
# get idx of target series
|
147 |
+
target_idx = 0
|
148 |
+
while idx >= len(self.valid_sampling_locations[target_idx]):
|
149 |
+
idx -= len(self.valid_sampling_locations[target_idx])
|
150 |
+
target_idx += 1
|
151 |
+
# get sampling location within the target series
|
152 |
+
sampling_location = self.valid_sampling_locations[target_idx][idx]
|
153 |
+
# get target series
|
154 |
+
target_series = self.target_series[target_idx].values()
|
155 |
+
past_target_series = target_series[sampling_location : sampling_location + self.input_chunk_length]
|
156 |
+
future_target_series = target_series[sampling_location + self.input_chunk_length : sampling_location + self.total_length]
|
157 |
+
# get covariates
|
158 |
+
if self.covariates is not None:
|
159 |
+
covariates = self.covariates[target_idx].values()
|
160 |
+
covariates = covariates[sampling_location : sampling_location + self.input_chunk_length]
|
161 |
+
else:
|
162 |
+
covariates = None
|
163 |
+
# get static covariates
|
164 |
+
if self.use_static_covariates:
|
165 |
+
static_covariates = self.target_series[target_idx].static_covariates_values(copy=True)
|
166 |
+
else:
|
167 |
+
static_covariates = None
|
168 |
+
|
169 |
+
# return elements that are not None
|
170 |
+
if self.remove_nan:
|
171 |
+
out = []
|
172 |
+
out += [past_target_series] if past_target_series is not None else []
|
173 |
+
out += [covariates] if covariates is not None else []
|
174 |
+
out += [static_covariates] if static_covariates is not None else []
|
175 |
+
out += [future_target_series] if future_target_series is not None else []
|
176 |
+
return tuple(out)
|
177 |
+
else:
|
178 |
+
return tuple([past_target_series,
|
179 |
+
covariates,
|
180 |
+
static_covariates,
|
181 |
+
future_target_series])
|
182 |
+
|
183 |
+
class SamplingDatasetDual(DualCovariatesTrainingDataset):
|
184 |
+
def __init__(
|
185 |
+
self,
|
186 |
+
target_series: Union[TimeSeries, Sequence[TimeSeries]],
|
187 |
+
covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
|
188 |
+
output_chunk_length: int = 12,
|
189 |
+
input_chunk_length: int = 12,
|
190 |
+
use_static_covariates: bool = True,
|
191 |
+
random_state: Optional[int] = 0,
|
192 |
+
max_samples_per_ts: Optional[int] = None,
|
193 |
+
remove_nan: bool = False,
|
194 |
+
) -> None:
|
195 |
+
"""
|
196 |
+
Parameters
|
197 |
+
----------
|
198 |
+
target_series
|
199 |
+
One or a sequence of target `TimeSeries`.
|
200 |
+
covariates:
|
201 |
+
Optionally, one or a sequence of `TimeSeries` containing future-known covariates. If this parameter is set,
|
202 |
+
the provided sequence must have the same length as that of `target_series`. Moreover, all
|
203 |
+
covariates in the sequence must have a time span large enough to contain all the required slices.
|
204 |
+
The joint slicing of the target and covariates is relying on the time axes of both series.
|
205 |
+
output_chunk_length
|
206 |
+
The length of the "output" series emitted by the model
|
207 |
+
input_chunk_length
|
208 |
+
The length of the "input" series fed to the model
|
209 |
+
use_static_covariates
|
210 |
+
Whether to use/include static covariate data from input series.
|
211 |
+
random_state
|
212 |
+
The random state to use for sampling.
|
213 |
+
max_samples_per_ts
|
214 |
+
The maximum number of samples to be drawn from each time series. If None, all samples will be drawn.
|
215 |
+
remove_nan
|
216 |
+
Whether to remove None from the output. E.g. if no covariates are provided, the covariates output will be None
|
217 |
+
or (optionally) removed from the __getitem__ output.
|
218 |
+
"""
|
219 |
+
super().__init__()
|
220 |
+
self.remove_nan = remove_nan
|
221 |
+
|
222 |
+
self.target_series = (
|
223 |
+
[target_series] if isinstance(target_series, TimeSeries) else target_series
|
224 |
+
)
|
225 |
+
self.covariates = (
|
226 |
+
[covariates] if isinstance(covariates, TimeSeries) else covariates
|
227 |
+
)
|
228 |
+
|
229 |
+
# checks
|
230 |
+
raise_if_not(
|
231 |
+
covariates is None or len(self.target_series) == len(self.covariates),
|
232 |
+
"The provided sequence of target series must have the same length as "
|
233 |
+
"the provided sequence of covariate series.",
|
234 |
+
)
|
235 |
+
|
236 |
+
# get valid sampling locations
|
237 |
+
self.valid_sampling_locations = get_valid_sampling_locations(target_series,
|
238 |
+
output_chunk_length,
|
239 |
+
input_chunk_length,
|
240 |
+
random_state,
|
241 |
+
max_samples_per_ts,)
|
242 |
+
|
243 |
+
# set parameters
|
244 |
+
self.output_chunk_length = output_chunk_length
|
245 |
+
self.input_chunk_length = input_chunk_length
|
246 |
+
self.total_length = input_chunk_length + output_chunk_length
|
247 |
+
self.total_number_samples = sum([len(v) for v in self.valid_sampling_locations.values()])
|
248 |
+
self.use_static_covariates = use_static_covariates
|
249 |
+
|
250 |
+
def __len__(self):
|
251 |
+
"""
|
252 |
+
Returns the total number of possible (input, target) splits.
|
253 |
+
"""
|
254 |
+
return self.total_number_samples
|
255 |
+
|
256 |
+
def __getitem__(self, idx: int):
|
257 |
+
# get idx of target series
|
258 |
+
target_idx = 0
|
259 |
+
while idx >= len(self.valid_sampling_locations[target_idx]):
|
260 |
+
idx -= len(self.valid_sampling_locations[target_idx])
|
261 |
+
target_idx += 1
|
262 |
+
# get sampling location within the target series
|
263 |
+
sampling_location = self.valid_sampling_locations[target_idx][idx]
|
264 |
+
# get target series
|
265 |
+
target_series = self.target_series[target_idx].values()
|
266 |
+
past_target_series = target_series[sampling_location : sampling_location + self.input_chunk_length]
|
267 |
+
future_target_series = target_series[sampling_location + self.input_chunk_length : sampling_location + self.total_length]
|
268 |
+
# get covariates
|
269 |
+
if self.covariates is not None:
|
270 |
+
covariates = self.covariates[target_idx].values()
|
271 |
+
past_covariates = covariates[sampling_location : sampling_location + self.input_chunk_length]
|
272 |
+
future_covariates = covariates[sampling_location + self.input_chunk_length : sampling_location + self.total_length]
|
273 |
+
else:
|
274 |
+
past_covariates = None
|
275 |
+
future_covariates = None
|
276 |
+
# get static covariates
|
277 |
+
if self.use_static_covariates:
|
278 |
+
static_covariates = self.target_series[target_idx].static_covariates_values(copy=True)
|
279 |
+
else:
|
280 |
+
static_covariates = None
|
281 |
+
|
282 |
+
# return elements that are not None
|
283 |
+
if self.remove_nan:
|
284 |
+
out = []
|
285 |
+
out += [past_target_series] if past_target_series is not None else []
|
286 |
+
out += [past_covariates] if past_covariates is not None else []
|
287 |
+
out += [future_covariates] if future_covariates is not None else []
|
288 |
+
out += [static_covariates] if static_covariates is not None else []
|
289 |
+
out += [future_target_series] if future_target_series is not None else []
|
290 |
+
return tuple(out)
|
291 |
+
else:
|
292 |
+
return tuple([past_target_series,
|
293 |
+
past_covariates,
|
294 |
+
future_covariates,
|
295 |
+
static_covariates,
|
296 |
+
future_target_series])
|
297 |
+
|
298 |
+
class SamplingDatasetMixed(MixedCovariatesTrainingDataset):
|
299 |
+
def __init__(
|
300 |
+
self,
|
301 |
+
target_series: Union[TimeSeries, Sequence[TimeSeries]],
|
302 |
+
past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
|
303 |
+
future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
|
304 |
+
output_chunk_length: int = 12,
|
305 |
+
input_chunk_length: int = 12,
|
306 |
+
use_static_covariates: bool = True,
|
307 |
+
random_state: Optional[int] = 0,
|
308 |
+
max_samples_per_ts: Optional[int] = None,
|
309 |
+
remove_nan: bool = False,
|
310 |
+
) -> None:
|
311 |
+
"""
|
312 |
+
Parameters
|
313 |
+
----------
|
314 |
+
target_series
|
315 |
+
One or a sequence of target `TimeSeries`.
|
316 |
+
past_covariates
|
317 |
+
Optionally, one or a sequence of `TimeSeries` containing past-observed covariates. If this parameter is set,
|
318 |
+
the provided sequence must have the same length as that of `target_series`. Moreover, all
|
319 |
+
covariates in the sequence must have a time span large enough to contain all the required slices.
|
320 |
+
The joint slicing of the target and covariates is relying on the time axes of both series.
|
321 |
+
future_covariates
|
322 |
+
Optionally, one or a sequence of `TimeSeries` containing future-known covariates. This has to follow
|
323 |
+
the same constraints as `past_covariates`.
|
324 |
+
output_chunk_length
|
325 |
+
The length of the "output" series emitted by the model
|
326 |
+
input_chunk_length
|
327 |
+
The length of the "input" series fed to the model
|
328 |
+
use_static_covariates
|
329 |
+
Whether to use/include static covariate data from input series.
|
330 |
+
random_state
|
331 |
+
The random state to use for sampling.
|
332 |
+
max_samples_per_ts
|
333 |
+
The maximum number of samples to be drawn from each time series. If None, all samples will be drawn.
|
334 |
+
remove_nan
|
335 |
+
Whether to remove None from the output. E.g. if no covariates are provided, the covariates output will be None
|
336 |
+
or (optionally) removed from the __getitem__ output.
|
337 |
+
"""
|
338 |
+
super().__init__()
|
339 |
+
self.remove_nan = remove_nan
|
340 |
+
|
341 |
+
self.target_series = (
|
342 |
+
[target_series] if isinstance(target_series, TimeSeries) else target_series
|
343 |
+
)
|
344 |
+
self.past_covariates = (
|
345 |
+
[past_covariates] if isinstance(past_covariates, TimeSeries) else past_covariates
|
346 |
+
)
|
347 |
+
self.future_covariates = (
|
348 |
+
[future_covariates] if isinstance(future_covariates, TimeSeries) else future_covariates
|
349 |
+
)
|
350 |
+
|
351 |
+
# checks
|
352 |
+
raise_if_not(
|
353 |
+
future_covariates is None or len(self.target_series) == len(self.future_covariates),
|
354 |
+
"The provided sequence of target series must have the same length as "
|
355 |
+
"the provided sequence of covariate series.",
|
356 |
+
)
|
357 |
+
raise_if_not(
|
358 |
+
past_covariates is None or len(self.target_series) == len(self.past_covariates),
|
359 |
+
"The provided sequence of target series must have the same length as "
|
360 |
+
"the provided sequence of covariate series.",
|
361 |
+
)
|
362 |
+
|
363 |
+
# get valid sampling locations
|
364 |
+
self.valid_sampling_locations = get_valid_sampling_locations(target_series,
|
365 |
+
output_chunk_length,
|
366 |
+
input_chunk_length,
|
367 |
+
random_state,
|
368 |
+
max_samples_per_ts,)
|
369 |
+
|
370 |
+
# set parameters
|
371 |
+
self.output_chunk_length = output_chunk_length
|
372 |
+
self.input_chunk_length = input_chunk_length
|
373 |
+
self.total_length = input_chunk_length + output_chunk_length
|
374 |
+
self.total_number_samples = sum([len(v) for v in self.valid_sampling_locations.values()])
|
375 |
+
self.use_static_covariates = use_static_covariates
|
376 |
+
|
377 |
+
def __len__(self):
|
378 |
+
"""
|
379 |
+
Returns the total number of possible (input, target) splits.
|
380 |
+
"""
|
381 |
+
return self.total_number_samples
|
382 |
+
|
383 |
+
def __getitem__(self, idx: int):
|
384 |
+
# get idx of target series
|
385 |
+
target_idx = 0
|
386 |
+
while idx >= len(self.valid_sampling_locations[target_idx]):
|
387 |
+
idx -= len(self.valid_sampling_locations[target_idx])
|
388 |
+
target_idx += 1
|
389 |
+
# get sampling location within the target series
|
390 |
+
sampling_location = self.valid_sampling_locations[target_idx][idx]
|
391 |
+
# get target series
|
392 |
+
target_series = self.target_series[target_idx].values()
|
393 |
+
past_target_series = target_series[sampling_location : sampling_location + self.input_chunk_length]
|
394 |
+
future_target_series = target_series[sampling_location + self.input_chunk_length : sampling_location + self.total_length]
|
395 |
+
# get past covariates
|
396 |
+
if self.past_covariates is not None:
|
397 |
+
past_covariates = self.past_covariates[target_idx].values()
|
398 |
+
past_covariates = past_covariates[sampling_location : sampling_location + self.input_chunk_length]
|
399 |
+
else:
|
400 |
+
past_covariates = None
|
401 |
+
# get future covariates
|
402 |
+
if self.future_covariates is not None:
|
403 |
+
future_covariates = self.future_covariates[target_idx].values()
|
404 |
+
historic_future_covariates = future_covariates[sampling_location : sampling_location + self.input_chunk_length]
|
405 |
+
future_covariates = future_covariates[sampling_location + self.input_chunk_length : sampling_location + self.total_length]
|
406 |
+
else:
|
407 |
+
future_covariates = None
|
408 |
+
historic_future_covariates = None
|
409 |
+
# get static covariates
|
410 |
+
if self.use_static_covariates:
|
411 |
+
static_covariates = self.target_series[target_idx].static_covariates_values(copy=True)
|
412 |
+
else:
|
413 |
+
static_covariates = None
|
414 |
+
|
415 |
+
# return elements that are not None
|
416 |
+
if self.remove_nan:
|
417 |
+
out = []
|
418 |
+
out += [past_target_series] if past_target_series is not None else []
|
419 |
+
out += [past_covariates] if past_covariates is not None else []
|
420 |
+
out += [historic_future_covariates] if historic_future_covariates is not None else []
|
421 |
+
out += [future_covariates] if future_covariates is not None else []
|
422 |
+
out += [static_covariates] if static_covariates is not None else []
|
423 |
+
out += [future_target_series] if future_target_series is not None else []
|
424 |
+
return tuple(out)
|
425 |
+
else:
|
426 |
+
return tuple([past_target_series,
|
427 |
+
past_covariates,
|
428 |
+
historic_future_covariates,
|
429 |
+
future_covariates,
|
430 |
+
static_covariates,
|
431 |
+
future_target_series])
|
432 |
+
|
433 |
+
class SamplingDatasetInferenceMixed(MixedCovariatesInferenceDataset):
    def __init__(
        self,
        target_series: Union[TimeSeries, Sequence[TimeSeries]],
        past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
        future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
        n: int = 1,
        input_chunk_length: int = 12,
        output_chunk_length: int = 1,
        use_static_covariates: bool = True,
        random_state: Optional[int] = 0,
        max_samples_per_ts: Optional[int] = None,
        array_output_only: bool = False,
    ):
        """
        Parameters
        ----------
        target_series
            One or a sequence of target `TimeSeries`.
        past_covariates
            Optionally, one or a sequence of `TimeSeries` containing past-observed covariates. If this parameter is
            set, the provided sequence must have the same length as that of `target_series`. Moreover, all
            covariates in the sequence must have a time span large enough to contain all the required slices.
            The joint slicing of the target and covariates relies on the time axes of both series.
        future_covariates
            Optionally, one or a sequence of `TimeSeries` containing future-known covariates. This has to follow
            the same constraints as `past_covariates`.
        n
            Number of predictions into the future. It may be greater than `output_chunk_length`, in which case the
            model is called autoregressively.
        output_chunk_length
            The length of the "output" series emitted by the model.
        input_chunk_length
            The length of the "input" series fed to the model.
        use_static_covariates
            Whether to use/include static covariate data from the input series.
        random_state
            The random state to use for sampling.
        max_samples_per_ts
            The maximum number of samples to be drawn from each time series. If None, all samples will be drawn.
        array_output_only
            Whether `__getitem__` returns only the arrays or also adds the full `TimeSeries` objects to the output
            tuple. The latter may cause problems with the torch collate and loader functions but works for Darts.
        """
        super().__init__(target_series=target_series,
                         past_covariates=past_covariates,
                         future_covariates=future_covariates,
                         n=n,
                         input_chunk_length=input_chunk_length,
                         output_chunk_length=output_chunk_length,)

        self.target_series = (
            [target_series] if isinstance(target_series, TimeSeries) else target_series
        )
        self.past_covariates = (
            [past_covariates] if isinstance(past_covariates, TimeSeries) else past_covariates
        )
        self.future_covariates = (
            [future_covariates] if isinstance(future_covariates, TimeSeries) else future_covariates
        )

        # checks
        raise_if_not(
            future_covariates is None or len(self.target_series) == len(self.future_covariates),
            "The provided sequence of target series must have the same length as "
            "the provided sequence of covariate series.",
        )
        raise_if_not(
            past_covariates is None or len(self.target_series) == len(self.past_covariates),
            "The provided sequence of target series must have the same length as "
            "the provided sequence of covariate series.",
        )

        # get valid sampling locations
        self.valid_sampling_locations = get_valid_sampling_locations(target_series,
                                                                     output_chunk_length,
                                                                     input_chunk_length,
                                                                     random_state,
                                                                     max_samples_per_ts,)

        # set parameters
        self.output_chunk_length = output_chunk_length
        self.input_chunk_length = input_chunk_length
        self.total_length = input_chunk_length + output_chunk_length
        self.total_number_samples = sum([len(v) for v in self.valid_sampling_locations.values()])
        self.use_static_covariates = use_static_covariates
        self.array_output_only = array_output_only

    def __len__(self):
        """
        Returns the total number of possible (input, target) splits.
        """
        return self.total_number_samples

    def __getitem__(self, idx: int):
        # get idx of target series
        target_idx = 0
        while idx >= len(self.valid_sampling_locations[target_idx]):
            idx -= len(self.valid_sampling_locations[target_idx])
            target_idx += 1
        # get sampling location within the target series
        sampling_location = self.valid_sampling_locations[target_idx][idx]
        # get target series
        target_series = self.target_series[target_idx]
        past_target_series_with_time = target_series[sampling_location : sampling_location + self.input_chunk_length]
        past_end = past_target_series_with_time.time_index[-1]
        target_series = self.target_series[target_idx].values()
        past_target_series = target_series[sampling_location : sampling_location + self.input_chunk_length]
        # get past covariates; both windows are sliced from the full covariate
        # array (slicing the already-cut past window would silently return an
        # empty array)
        if self.past_covariates is not None:
            covariates = self.past_covariates[target_idx].values()
            past_covariates = covariates[sampling_location : sampling_location + self.input_chunk_length]
            future_past_covariates = covariates[sampling_location + self.input_chunk_length : sampling_location + self.total_length]
        else:
            past_covariates = None
            future_past_covariates = None
        # get future covariates
        if self.future_covariates is not None:
            future_covariates = self.future_covariates[target_idx].values()
            historic_future_covariates = future_covariates[sampling_location : sampling_location + self.input_chunk_length]
            future_covariates = future_covariates[sampling_location + self.input_chunk_length : sampling_location + self.total_length]
        else:
            future_covariates = None
            historic_future_covariates = None
        # get static covariates
        if self.use_static_covariates:
            static_covariates = self.target_series[target_idx].static_covariates_values(copy=True)
        else:
            static_covariates = None

        # whether to drop the TimeSeries objects and None entries and return only arrays
        if self.array_output_only:
            out = []
            out += [past_target_series] if past_target_series is not None else []
            out += [past_covariates] if past_covariates is not None else []
            out += [historic_future_covariates] if historic_future_covariates is not None else []
            out += [future_covariates] if future_covariates is not None else []
            out += [future_past_covariates] if future_past_covariates is not None else []
            out += [static_covariates] if static_covariates is not None else []
            return tuple(out)
        else:
            return tuple([past_target_series,
                          past_covariates,
                          historic_future_covariates,
                          future_covariates,
                          future_past_covariates,
                          static_covariates,
                          past_target_series_with_time,
                          past_end + past_target_series_with_time.freq
                          ])

    def evalsample(
        self, idx: int
    ) -> TimeSeries:
        """
        Returns the future target series at the given index.
        """
        # get idx of target series
        target_idx = 0
        while idx >= len(self.valid_sampling_locations[target_idx]):
            idx -= len(self.valid_sampling_locations[target_idx])
            target_idx += 1
        # get sampling location within the target series
        sampling_location = self.valid_sampling_locations[target_idx][idx]
        # get target series
        target_series = self.target_series[target_idx][sampling_location + self.input_chunk_length : sampling_location + self.total_length]

        return target_series

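Aside: the last two elements of the non-array output are the input window as a `TimeSeries` and the timestamp of the first step to predict, which is how a forecast `TimeSeries` can be rebuilt from raw model output. A toy sketch with synthetic data (shapes and names are illustrative only):

# Toy sketch: build the dataset, pull one sample, and rebuild a forecast
# TimeSeries from the returned first-prediction timestamp.
import numpy as np
import pandas as pd
from darts import TimeSeries

times = pd.date_range("2021-01-01", periods=300, freq="5min")
targets = [TimeSeries.from_times_and_values(times, np.random.rand(300, 1))]
mixed_ds = SamplingDatasetInferenceMixed(target_series=targets,
                                         n=12, input_chunk_length=96,
                                         output_chunk_length=12)
sample = mixed_ds[0]
past_ts, first_pred_time = sample[-2], sample[-1]
fake_forecast = TimeSeries.from_times_and_values(
    pd.date_range(start=first_pred_time, periods=12, freq=past_ts.freq),
    np.random.rand(12, 1))  # stand-in for a model's 12-step output
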
class SamplingDatasetInferencePast(PastCovariatesInferenceDataset):
    def __init__(
        self,
        target_series: Union[TimeSeries, Sequence[TimeSeries]],
        covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
        n: int = 1,
        input_chunk_length: int = 12,
        output_chunk_length: int = 1,
        use_static_covariates: bool = True,
        random_state: Optional[int] = 0,
        max_samples_per_ts: Optional[int] = None,
        array_output_only: bool = False,
    ):
        """
        Parameters
        ----------
        target_series
            One or a sequence of target `TimeSeries`.
        covariates
            Optionally, one or a sequence of `TimeSeries` containing past-observed covariates. If this parameter is
            set, the provided sequence must have the same length as that of `target_series`. Moreover, all
            covariates in the sequence must have a time span large enough to contain all the required slices.
            The joint slicing of the target and covariates relies on the time axes of both series.
        n
            Number of predictions into the future. It may be greater than `output_chunk_length`, in which case the
            model is called autoregressively.
        output_chunk_length
            The length of the "output" series emitted by the model.
        input_chunk_length
            The length of the "input" series fed to the model.
        use_static_covariates
            Whether to use/include static covariate data from the input series.
        random_state
            The random state to use for sampling.
        max_samples_per_ts
            The maximum number of samples to be drawn from each time series. If None, all samples will be drawn.
        array_output_only
            Whether `__getitem__` returns only the arrays or also adds the full `TimeSeries` objects to the output
            tuple. The latter may cause problems with the torch collate and loader functions but works for Darts.
        """
        super().__init__(target_series=target_series,
                         covariates=covariates,
                         n=n,
                         input_chunk_length=input_chunk_length,
                         output_chunk_length=output_chunk_length,)

        self.target_series = (
            [target_series] if isinstance(target_series, TimeSeries) else target_series
        )
        self.covariates = (
            [covariates] if isinstance(covariates, TimeSeries) else covariates
        )

        raise_if_not(
            covariates is None or len(self.target_series) == len(self.covariates),
            "The provided sequence of target series must have the same length as "
            "the provided sequence of covariate series.",
        )

        # get valid sampling locations
        self.valid_sampling_locations = get_valid_sampling_locations(target_series,
                                                                     output_chunk_length,
                                                                     input_chunk_length,
                                                                     random_state,
                                                                     max_samples_per_ts,)

        # set parameters
        self.output_chunk_length = output_chunk_length
        self.input_chunk_length = input_chunk_length
        self.total_length = input_chunk_length + output_chunk_length
        self.total_number_samples = sum([len(v) for v in self.valid_sampling_locations.values()])
        self.use_static_covariates = use_static_covariates
        self.array_output_only = array_output_only

    def __len__(self):
        """
        Returns the total number of possible (input, target) splits.
        """
        return self.total_number_samples

    def __getitem__(self, idx: int):
        # get idx of target series
        target_idx = 0
        while idx >= len(self.valid_sampling_locations[target_idx]):
            idx -= len(self.valid_sampling_locations[target_idx])
            target_idx += 1
        # get sampling location within the target series
        sampling_location = self.valid_sampling_locations[target_idx][idx]
        # get target series
        target_series = self.target_series[target_idx]
        past_target_series_with_time = target_series[sampling_location : sampling_location + self.input_chunk_length]
        past_end = past_target_series_with_time.time_index[-1]
        target_series = self.target_series[target_idx].values()
        past_target_series = target_series[sampling_location : sampling_location + self.input_chunk_length]
        # get past covariates; both windows are sliced from the full covariate
        # array (slicing the already-cut past window would silently return an
        # empty array)
        if self.covariates is not None:
            covariates = self.covariates[target_idx].values()
            past_covariates = covariates[sampling_location : sampling_location + self.input_chunk_length]
            future_past_covariates = covariates[sampling_location + self.input_chunk_length : sampling_location + self.total_length]
        else:
            past_covariates = None
            future_past_covariates = None
        # get static covariates
        if self.use_static_covariates:
            static_covariates = self.target_series[target_idx].static_covariates_values(copy=True)
        else:
            static_covariates = None
        # return arrays or arrays with TimeSeries
        if self.array_output_only:
            out = []
            out += [past_target_series] if past_target_series is not None else []
            out += [past_covariates] if past_covariates is not None else []
            out += [future_past_covariates] if future_past_covariates is not None else []
            out += [static_covariates] if static_covariates is not None else []
            return tuple(out)
        else:
            return tuple([past_target_series,
                          past_covariates,
                          future_past_covariates,
                          static_covariates,
                          past_target_series_with_time,
                          past_end + past_target_series_with_time.freq])

    def evalsample(
        self, idx: int
    ) -> TimeSeries:
        """
        Returns the future target series at the given index.
        """
        # get idx of target series
        target_idx = 0
        while idx >= len(self.valid_sampling_locations[target_idx]):
            idx -= len(self.valid_sampling_locations[target_idx])
            target_idx += 1
        # get sampling location within the target series
        sampling_location = self.valid_sampling_locations[target_idx][idx]
        # get target series
        target_series = self.target_series[target_idx][sampling_location + self.input_chunk_length : sampling_location + self.total_length]

        return target_series

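Aside: `evalsample()` returns the ground-truth future window for the same flat index as `__getitem__`, which makes pairing predictions with targets straightforward. Reusing the toy `targets` from the sketch above:

# Toy sketch: ground truth and model inputs for the same dataset index.
past_ds = SamplingDatasetInferencePast(target_series=targets,
                                       n=12, input_chunk_length=96,
                                       output_chunk_length=12)
truth = past_ds.evalsample(0)   # TimeSeries of length output_chunk_length
inputs = past_ds[0]             # matching model inputs for the same index
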
class SamplingDatasetInferenceDual(DualCovariatesInferenceDataset):
    def __init__(
        self,
        target_series: Union[TimeSeries, Sequence[TimeSeries]],
        covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
        n: int = 12,
        input_chunk_length: int = 12,
        output_chunk_length: int = 1,
        use_static_covariates: bool = True,
        random_state: Optional[int] = 0,
        max_samples_per_ts: Optional[int] = None,
        array_output_only: bool = False,
    ):
        """
        Parameters
        ----------
        target_series
            One or a sequence of target `TimeSeries`.
        covariates
            Optionally, some future-known covariates that are used for predictions. This argument is required
            if the model was trained with future-known covariates.
        n
            Number of predictions into the future. It may be greater than `output_chunk_length`, in which case the
            model is called autoregressively.
        output_chunk_length
            The length of the "output" series emitted by the model.
        input_chunk_length
            The length of the "input" series fed to the model.
        use_static_covariates
            Whether to use/include static covariate data from the input series.
        random_state
            The random state to use for sampling.
        max_samples_per_ts
            The maximum number of samples to be drawn from each time series. If None, all samples will be drawn.
        array_output_only
            Whether `__getitem__` returns only the arrays or also adds the full `TimeSeries` objects to the output
            tuple. The latter may cause problems with the torch collate and loader functions but works for Darts.
        """
        super().__init__(target_series=target_series,
                         covariates=covariates,
                         n=n,
                         input_chunk_length=input_chunk_length,
                         output_chunk_length=output_chunk_length,)

        self.target_series = (
            [target_series] if isinstance(target_series, TimeSeries) else target_series
        )
        self.covariates = (
            [covariates] if isinstance(covariates, TimeSeries) else covariates
        )

        raise_if_not(
            covariates is None or len(self.target_series) == len(self.covariates),
            "The provided sequence of target series must have the same length as "
            "the provided sequence of covariate series.",
        )

        # get valid sampling locations
        self.valid_sampling_locations = get_valid_sampling_locations(target_series,
                                                                     output_chunk_length,
                                                                     input_chunk_length,
                                                                     random_state,
                                                                     max_samples_per_ts,)

        # set parameters
        self.output_chunk_length = output_chunk_length
        self.input_chunk_length = input_chunk_length
        self.total_length = input_chunk_length + output_chunk_length
        self.total_number_samples = sum([len(v) for v in self.valid_sampling_locations.values()])
        self.use_static_covariates = use_static_covariates
        self.array_output_only = array_output_only

    def __len__(self):
        """
        Returns the total number of possible (input, target) splits.
        """
        return self.total_number_samples

    def __getitem__(self, idx: int):
        # get idx of target series
        target_idx = 0
        while idx >= len(self.valid_sampling_locations[target_idx]):
            idx -= len(self.valid_sampling_locations[target_idx])
            target_idx += 1
        # get sampling location within the target series
        sampling_location = self.valid_sampling_locations[target_idx][idx]
        # get target series
        target_series = self.target_series[target_idx]
        past_target_series_with_time = target_series[sampling_location : sampling_location + self.input_chunk_length]
        past_end = past_target_series_with_time.time_index[-1]
        target_series = self.target_series[target_idx].values()
        past_target_series = target_series[sampling_location : sampling_location + self.input_chunk_length]
        # get future covariates
        if self.covariates is not None:
            future_covariates = self.covariates[target_idx].values()
            historic_future_covariates = future_covariates[sampling_location : sampling_location + self.input_chunk_length]
            future_covariates = future_covariates[sampling_location + self.input_chunk_length : sampling_location + self.total_length]
        else:
            historic_future_covariates = None
            future_covariates = None
        # get static covariates
        if self.use_static_covariates:
            static_covariates = self.target_series[target_idx].static_covariates_values(copy=True)
        else:
            static_covariates = None
        # return arrays or arrays with TimeSeries
        if self.array_output_only:
            out = []
            out += [past_target_series] if past_target_series is not None else []
            out += [historic_future_covariates] if historic_future_covariates is not None else []
            out += [future_covariates] if future_covariates is not None else []
            out += [static_covariates] if static_covariates is not None else []
            return tuple(out)
        else:
            return tuple([past_target_series,
                          historic_future_covariates,
                          future_covariates,
                          static_covariates,
                          past_target_series_with_time,
                          past_end + past_target_series_with_time.freq,])

    def evalsample(
        self, idx: int
    ) -> TimeSeries:
        """
        Returns the future target series at the given index.
        """
        # get idx of target series
        target_idx = 0
        while idx >= len(self.valid_sampling_locations[target_idx]):
            idx -= len(self.valid_sampling_locations[target_idx])
            target_idx += 1
        # get sampling location within the target series
        sampling_location = self.valid_sampling_locations[target_idx][idx]
        # get target series
        target_series = self.target_series[target_idx][sampling_location + self.input_chunk_length : sampling_location + self.total_length]

        return target_series

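Aside, continuing the toy sketch above: the three inference datasets map onto the covariate setups as follows (arguments here are hypothetical).

#   past-observed covariates only   -> SamplingDatasetInferencePast
#   past + future-known covariates  -> SamplingDatasetInferenceMixed
#   future-known covariates only    -> SamplingDatasetInferenceDual
future_covs = [TimeSeries.from_times_and_values(times, np.random.rand(300, 2))]
dual_ds = SamplingDatasetInferenceDual(target_series=targets,
                                       covariates=future_covs,
                                       n=12, input_chunk_length=96,
                                       output_chunk_length=12,
                                       array_output_only=True)
past_target, historic_future, future = dual_ds[0]  # static covs are None here and dropped
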
utils/darts_evaluation.py
ADDED
@@ -0,0 +1,280 @@
import sys
import os
import yaml
import random
from typing import Any, BinaryIO, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
from scipy import stats
import pandas as pd
import darts

from darts import models
from darts import metrics
from darts import TimeSeries

# import data formatter
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from data_formatter.base import *
from utils.darts_processing import *

def _get_values(
    series: TimeSeries, stochastic_quantile: Optional[float] = 0.5
) -> np.ndarray:
    """
    Returns the numpy values of a time series.
    For stochastic series, returns either all sample values (with `stochastic_quantile=None`) or the values of the
    given quantile (with `stochastic_quantile` in [0, 1]).
    """
    if series.is_deterministic:
        series_values = series.univariate_values()
    else:  # stochastic
        if stochastic_quantile is None:
            series_values = series.all_values(copy=False)
        else:
            series_values = series.quantile_timeseries(
                quantile=stochastic_quantile
            ).univariate_values()
    return series_values

def _get_values_or_raise(
    series_a: TimeSeries,
    series_b: TimeSeries,
    intersect: bool,
    stochastic_quantile: Optional[float] = 0.5,
    remove_nan_union: bool = False,
) -> Tuple[np.ndarray, np.ndarray]:
    """Returns the processed numpy values of two time series. Processing can be customized with the arguments
    `intersect`, `stochastic_quantile`, `remove_nan_union`.

    Raises a ValueError if the two time series (or their intersection) do not have the same time index.

    Parameters
    ----------
    series_a
        A univariate deterministic ``TimeSeries`` instance (the actual series).
    series_b
        A univariate (deterministic or stochastic) ``TimeSeries`` instance (the predicted series).
    intersect
        A boolean for whether or not to only consider the time intersection between `series_a` and `series_b`.
    stochastic_quantile
        Optionally, for a stochastic predicted series, return either all sample values (with `stochastic_quantile=None`)
        or the deterministic values of a given quantile by setting `stochastic_quantile` to a value in [0, 1].
    remove_nan_union
        By setting `remove_nan_union` to True, remove all indices from `series_a` and `series_b` which have a NaN value
        in either of the two input series.
    """
    series_a_common = series_a.slice_intersect(series_b) if intersect else series_a
    series_b_common = series_b.slice_intersect(series_a) if intersect else series_b

    series_a_det = _get_values(series_a_common, stochastic_quantile=stochastic_quantile)
    series_b_det = _get_values(series_b_common, stochastic_quantile=stochastic_quantile)

    if not remove_nan_union:
        return series_a_det, series_b_det

    b_is_deterministic = bool(len(series_b_det.shape) == 1)
    if b_is_deterministic:
        isnan_mask = np.logical_or(np.isnan(series_a_det), np.isnan(series_b_det))
    else:
        isnan_mask = np.logical_or(
            np.isnan(series_a_det), np.isnan(series_b_det).any(axis=2).flatten()
        )
    return np.delete(series_a_det, isnan_mask), np.delete(
        series_b_det, isnan_mask, axis=0
    )

def rescale_and_backtest(series: Union[TimeSeries,
                                       Sequence[TimeSeries]],
                         forecasts: Union[TimeSeries,
                                          Sequence[TimeSeries],
                                          Sequence[Sequence[TimeSeries]]],
                         metric: Union[
                             Callable[[TimeSeries, TimeSeries], float],
                             List[Callable[[TimeSeries, TimeSeries], float]],
                         ],
                         scaler: Callable[[TimeSeries], TimeSeries] = None,
                         reduction: Union[Callable[[np.ndarray], float], None] = np.mean,
                         likelihood: str = "GaussianMean",
                         cal_thresholds: Optional[np.ndarray] = np.linspace(0, 1, 11),
                         ):
    """
    Backtest the historical forecasts (as provided by Darts) on the series.

    Parameters
    ----------
    series
        The target time series.
    forecasts
        The forecasts.
    metric
        The metric or metrics to use for backtesting.
    scaler
        The scaler used to scale the series.
    reduction
        The reduction to apply to the metric.
    likelihood
        The likelihood to use for evaluating the model. Only "GaussianMean" is currently handled here.
    cal_thresholds
        The thresholds to use for computing the calibration error.

    Returns
    -------
    np.ndarray
        Error array. If the reduction is None, the array is of shape (n, p),
        where n is the total number of samples (forecasts) and p is the number of metrics.
        If the reduction is not None, the array is of shape (k, p), where k is the number of series.
    float
        The estimated log-likelihood of the model on the data.
    np.ndarray
        The ECE for each time point in the forecast.
    """
    series = [series] if isinstance(series, TimeSeries) else series
    forecasts = [forecasts] if isinstance(forecasts, TimeSeries) else forecasts
    metric = [metric] if not isinstance(metric, list) else metric

    # compute errors: 1) reverse the scaling of forecasts and true values, 2) compute errors
    backtest_list = []
    for idx in range(len(series)):
        if scaler is not None:
            series[idx] = scaler.inverse_transform(series[idx])
            forecasts[idx] = [scaler.inverse_transform(f) for f in forecasts[idx]]
        errors = [
            [metric_f(series[idx], f) for metric_f in metric]
            if len(metric) > 1
            else metric[0](series[idx], f)
            for f in forecasts[idx]
        ]
        if reduction is None:
            backtest_list.append(np.array(errors))
        else:
            backtest_list.append(reduction(np.array(errors), axis=0))
    backtest_list = np.vstack(backtest_list)

    if likelihood == "GaussianMean":
        # compute likelihood
        est_var = []
        for idx, target_ts in enumerate(series):
            est_var += [metrics.mse(target_ts, f) for f in forecasts[idx]]
        est_var = np.mean(est_var)
        forecast_len = forecasts[0][0].n_timesteps
        log_likelihood = -0.5*forecast_len - 0.5*np.log(2*np.pi*est_var)

        # compute calibration error: 1) cdf values, 2) calibration error
        # compute the cdf values
        cdf_vals = []
        for idx in range(len(series)):
            for forecast in forecasts[idx]:
                y_true, y_pred = _get_values_or_raise(series[idx],
                                                      forecast,
                                                      intersect=True,
                                                      remove_nan_union=True)
                y_true, y_pred = y_true.flatten(), y_pred.flatten()
                cdf_vals.append(stats.norm.cdf(y_true, loc=y_pred, scale=np.sqrt(est_var)))
        cdf_vals = np.vstack(cdf_vals)
        # compute the prediction calibration
        cal_error = np.zeros(forecasts[0][0].n_timesteps)
        for p in cal_thresholds:
            est_p = (cdf_vals <= p).astype(float)
            est_p = np.mean(est_p, axis=0)
            cal_error += (est_p - p) ** 2

    return backtest_list, log_likelihood, cal_error

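Aside: the calibration loop above is a squared-error variant of the expected calibration error; for a perfectly calibrated model, the empirical frequency of cdf values below each threshold p matches p. A self-contained toy check (synthetic numbers, not from the code):

# Toy check of the calibration logic: uniform cdf values mimic a perfectly
# calibrated model, so the accumulated error stays close to zero.
import numpy as np
cdf_vals = np.random.rand(10000, 12)   # stand-in for stats.norm.cdf(...) outputs
cal_error = np.zeros(12)
for p in np.linspace(0, 1, 11):
    est_p = np.mean((cdf_vals <= p).astype(float), axis=0)
    cal_error += (est_p - p) ** 2
print(cal_error.round(4))              # near-zero for uniform cdf values
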
def rescale_and_test(series: Union[TimeSeries,
                                   Sequence[TimeSeries]],
                     forecasts: Union[TimeSeries,
                                      Sequence[TimeSeries]],
                     metric: Union[
                         Callable[[TimeSeries, TimeSeries], float],
                         List[Callable[[TimeSeries, TimeSeries], float]],
                     ],
                     scaler: Callable[[TimeSeries], TimeSeries] = None,
                     likelihood: str = "GaussianMean",
                     cal_thresholds: Optional[np.ndarray] = np.linspace(0, 1, 11),
                     ):
    """
    Test the forecasts on the series.

    Parameters
    ----------
    series
        The target time series.
    forecasts
        The forecasts.
    metric
        The metric or metrics to use for testing.
    scaler
        The scaler used to scale the series.
    likelihood
        The likelihood to use for evaluating the likelihood and calibration of the model.
    cal_thresholds
        The thresholds to use for computing the calibration error.

    Returns
    -------
    np.ndarray
        Error array of shape (n, p), where n is the total number of samples (forecasts)
        and p is the number of metrics.
    float
        The estimated log-likelihood of the model on the data.
    np.ndarray
        The ECE for each time point in the forecast.
    """
    series = [series] if isinstance(series, TimeSeries) else series
    forecasts = [forecasts] if isinstance(forecasts, TimeSeries) else forecasts
    metric = [metric] if not isinstance(metric, list) else metric

    # compute errors: 1) reverse the scaling of forecasts and true values, 2) compute errors
    # (scaling is only reversed when a scaler is provided, since it defaults to None)
    if scaler is not None:
        series = scaler.inverse_transform(series)
        forecasts = scaler.inverse_transform(forecasts)
    errors = [
        [metric_f(t, f) for metric_f in metric]
        if len(metric) > 1
        else metric[0](t, f)
        for (t, f) in zip(series, forecasts)
    ]
    errors = np.array(errors)

    if likelihood == "GaussianMean":
        # compute likelihood
        est_var = [metrics.mse(t, f) for (t, f) in zip(series, forecasts)]
        est_var = np.mean(est_var)
        forecast_len = forecasts[0].n_timesteps
        log_likelihood = -0.5*forecast_len - 0.5*np.log(2*np.pi*est_var)

        # compute calibration error: 1) cdf values, 2) calibration error
        # compute the cdf values
        cdf_vals = []
        for t, f in zip(series, forecasts):
            t, f = _get_values_or_raise(t, f, intersect=True, remove_nan_union=True)
            t, f = t.flatten(), f.flatten()
            cdf_vals.append(stats.norm.cdf(t, loc=f, scale=np.sqrt(est_var)))
        cdf_vals = np.vstack(cdf_vals)
        # compute the prediction calibration
        cal_error = np.zeros(forecasts[0].n_timesteps)
        for p in cal_thresholds:
            est_p = (cdf_vals <= p).astype(float)
            est_p = np.mean(est_p, axis=0)
            cal_error += (est_p - p) ** 2

    if likelihood == "Quantile":
        # no likelihood since we don't have a parametric model
        log_likelihood = 0

        # compute calibration error: 1) get quantiles, 2) compute calibration error
        cal_error = np.zeros(forecasts[0].n_timesteps)
        for p in cal_thresholds:
            est_p = 0
            for t, f in zip(series, forecasts):
                q = f.quantile(p)
                t, q = _get_values_or_raise(t, q, intersect=True, remove_nan_union=True)
                t, q = t.flatten(), q.flatten()
                est_p += (t <= q).astype(float)
            est_p = (est_p / len(series)).flatten()
            cal_error += (est_p - p) ** 2

    return errors, log_likelihood, cal_error

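A hypothetical call of the function above (variable names are placeholders; `metrics.rmse` and `metrics.mse` are standard Darts metrics):

# Evaluate forecasts against held-out series with RMSE and MSE,
# undoing the min-max scaling first.
errors, log_likelihood, cal_error = rescale_and_test(
    series=test_targets,           # list of ground-truth TimeSeries
    forecasts=model_forecasts,     # list of predicted TimeSeries, same order
    metric=[metrics.rmse, metrics.mse],
    scaler=scalers['target'],
    likelihood="GaussianMean",
)
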
utils/darts_processing.py
ADDED
@@ -0,0 +1,367 @@
import sys
import os
import yaml
import random
from typing import Any, BinaryIO, Callable, Dict, List, Optional, Sequence, Tuple, Union
from pathlib import Path
import numpy as np
from scipy import stats
import pandas as pd
import darts

from darts import models
from darts import metrics
from darts import TimeSeries
from darts.dataprocessing.transformers import Scaler
from pytorch_lightning.callbacks import Callback
from sympy import pprint

# import data formatter
sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
from data_formatter.base import *

pd.set_option('display.width', None)        # set display width to None to avoid truncation
pd.set_option('display.max_columns', None)  # display all columns

def make_series(data: Dict[str, pd.DataFrame],
                time_col: str,
                group_col: str,
                value_cols: Dict[str, List[str]],
                include_sid: bool = False,
                verbose: bool = False
                ) -> Tuple[Dict[str, Dict[str, darts.TimeSeries]], Dict[str, Any]]:
    """
    Makes the TimeSeries from the data.

    Parameters
    ----------
    data
        dict of train, val, test dataframes
    time_col
        name of the time column
    group_col
        name of the group column
    value_cols
        dict with key specifying the type of covariate and value specifying the list of columns
    include_sid
        whether to include the segment id as a static covariate

    Returns
    -------
    series: Dict[str, Dict[str, darts.TimeSeries]]
        dict of train, val, test splits of target and covariate TimeSeries objects
    scalers: Dict[str, darts.preprocessing.Scaler]
        dict of scalers for target and covariates
    """
    series = {i: {j: None for j in value_cols} for i in data.keys()}
    scalers = {}
    # NOTE: the 'train' split must come first in `data`, so that scalers are
    # fitted on the training data before being applied to the other splits.
    for key, df in data.items():
        for name, cols in value_cols.items():
            if verbose:
                print(f"DATAFRAME for key {key} in NAME {name} and COLS {cols} and GROUP_COL {group_col}")
                pprint(df.head(1))
            series[key][name] = TimeSeries.from_group_dataframe(df=df,
                                                                group_cols=group_col,
                                                                time_col=time_col,
                                                                value_cols=cols) if cols is not None else None
            if series[key][name] is not None and include_sid is False:
                for i in range(len(series[key][name])):
                    series[key][name][i] = series[key][name][i].with_static_covariates(None)
            if cols is not None:
                if key == 'train':
                    scalers[name] = ScalerCustom()
                    series[key][name] = scalers[name].fit_transform(series[key][name])
                else:
                    series[key][name] = scalers[name].transform(series[key][name])
            else:
                scalers[name] = None
    return series, scalers

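A minimal sketch of how `make_series` expects its inputs (the dataframes are placeholders; the column names 'gl', 'time', and 'id' follow the sample config shown further below):

# Hypothetical call: each value_cols entry maps a covariate type to the list
# of dataframe columns that hold it, or None if that type is unused.
data = {'train': train_df, 'val': val_df, 'test': test_df}
value_cols = {'target': ['gl'], 'static': ['id'], 'dynamic': None, 'future': None}
series, scalers = make_series(data, time_col='time', group_col='id', value_cols=value_cols)
# series['train']['target'] is a list of TimeSeries, one per group in 'id'
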
def load_data(url: str,
              config_path: Path,
              use_covs: bool = False,
              cov_type: str = 'past',
              use_static_covs: bool = False, seed=0):
    """
    Load data according to the specified config file and convert it to Darts TimeSeries objects.

    Parameters
    ----------
    url: str
        Path or URL of the CSV file with the data.
    config_path: Path
        Path to the YAML config file.
    use_covs: bool
        Whether to use covariates.
    cov_type: str
        Type of covariates to use. Can be 'past' or 'mixed' or 'dual'.
    use_static_covs: bool
        Whether to use static covariates.
    seed: int
        Random seed for data splitting.

    Returns
    -------
    formatter: DataFormatter
        Data formatter object.
    series: Dict[str, Dict[str, TimeSeries]]
        The first dictionary key specifies the split, the second the type of series (target or covariate).
    scalers: Dict[str, Scaler]
        Dictionary of scalers with key indicating the type of series (target or covariate).
    """

    # Example of the expected config structure (for illustration only):
    """
    config = {
        'data_csv_path': f'{url}',
        'drop': None,
        'ds_name': 'livia_mini',
        'index_col': -1,
        'observation_interval': '5min',
        'column_definition': [
            {'data_type': 'categorical',
             'input_type': 'id',
             'name': 'id'
            },
            {'data_type': 'date',
             'input_type': 'time',
             'name': 'time'
            },
            {'data_type': 'real_valued',
             'input_type': 'target',
             'name': 'gl'
            }
        ],
        'encoding_params': {'date': ['day', 'month', 'year', 'hour', 'minute', 'second']
        },
        'nan_vals': None,
        'interpolation_params': {'gap_threshold': 45,
                                 'min_drop_length': 240
        },
        'scaling_params': {'scaler': None
        },
        'split_params': {'length_segment': 13,
                         'random_state': seed,
                         'test_percent_subjects': 0.1
        },
        'max_length_input': 192,
        'length_pred': 12,
        'params': {
            'gluformer': {'in_len': 96,
                          'd_model': 512,
                          'n_heads': 10,
                          'd_fcn': 1024,
                          'num_enc_layers': 2,
                          'num_dec_layers': 2,
                          'length_pred': 12
            }
        }
    }
    """
    with config_path.open("r") as f:
        config = yaml.safe_load(f)
    config["data_csv_path"] = url

    formatter = DataFormatter(config)
    #assert dataset is not None, 'dataset must be specified in the load_data call'
    assert use_covs is not None, 'use_covs must be specified in the load_data call'

    # convert to series
    time_col = formatter.get_column('time')
    group_col = formatter.get_column('sid')
    target_col = formatter.get_column('target')
    static_cols = formatter.get_column('static_covs')
    static_cols = static_cols + [formatter.get_column('id')] if static_cols is not None else [formatter.get_column('id')]
    dynamic_cols = formatter.get_column('dynamic_covs')
    future_cols = formatter.get_column('future_covs')

    data = {'train': formatter.train_data,
            'val': formatter.val_data,
            'test': formatter.test_data.loc[~formatter.test_data.index.isin(formatter.test_idx_ood)],
            'test_ood': formatter.test_data.loc[formatter.test_data.index.isin(formatter.test_idx_ood)]}
    value_cols = {'target': target_col,
                  'static': static_cols,
                  'dynamic': dynamic_cols,
                  'future': future_cols}
    # build series
    series, scalers = make_series(data,
                                  time_col,
                                  group_col,
                                  value_cols)
    if not use_covs:
        # set dynamic and future covariates to None
        for split in ['train', 'val', 'test', 'test_ood']:
            for cov in ['dynamic', 'future']:
                series[split][cov] = None
    elif use_covs and cov_type == 'mixed':
        pass  # this is the default for make_series()
    elif use_covs and cov_type == 'past':
        # use future covariates as dynamic (past) covariates
        if series['train']['dynamic'] is None:
            for split in ['train', 'val', 'test', 'test_ood']:
                series[split]['dynamic'] = series[split]['future']
        else:
            for split in ['train', 'val', 'test', 'test_ood']:
                for i in range(len(series[split]['future'])):
                    series[split]['dynamic'][i] = series[split]['dynamic'][i].concatenate(series[split]['future'][i], axis=1)
        # erase future covariates
        for split in ['train', 'val', 'test', 'test_ood']:
            series[split]['future'] = None
    elif use_covs and cov_type == 'dual':
        # erase dynamic (past) covariates
        for split in ['train', 'val', 'test', 'test_ood']:
            series[split]['dynamic'] = None

    if use_static_covs:
        # attach static covariates to series
        for split in ['train', 'val', 'test', 'test_ood']:
            for i in range(len(series[split]['target'])):
                static_covs = series[split]['static'][i][0].pd_dataframe()
                series[split]['target'][i] = series[split]['target'][i].with_static_covariates(static_covs)

    return formatter, series, scalers

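A hypothetical end-to-end call (the CSV path is made up; `files/config.yaml` refers to the config file added elsewhere in this commit):

from pathlib import Path

formatter, series, scalers = load_data(url='my_cgm_export.csv',
                                       config_path=Path('files/config.yaml'),
                                       use_covs=False)
train_targets = series['train']['target']  # list of TimeSeries, one per segment
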
def reshuffle_data(formatter: DataFormatter,
                   seed: int = 0,
                   use_covs: bool = None,
                   cov_type: str = 'past',
                   use_static_covs: bool = False,):
    """
    Reshuffle the data according to the seed and convert it to Darts TimeSeries objects.

    Parameters
    ----------
    formatter: DataFormatter
        Data formatter object containing the data.
    seed: int
        Random seed for data splitting.
    use_covs: bool
        Whether to use covariates.
    cov_type: str
        Type of covariates to use. Can be 'past' or 'mixed' or 'dual'.
    use_static_covs: bool
        Whether to use static covariates.

    Returns
    -------
    formatter: DataFormatter
        Reshuffled data formatter object.
    series: Dict[str, Dict[str, TimeSeries]]
        The first dictionary key specifies the split, the second the type of series (target or covariate).
    scalers: Dict[str, Scaler]
        Dictionary of scalers with key indicating the type of series (target or covariate).
    """
    # reshuffle
    formatter.reshuffle(seed)
    assert use_covs is not None, 'use_covs must be specified in the reshuffle_data call'

    # convert to series
    time_col = formatter.get_column('time')
    group_col = formatter.get_column('sid')
    target_col = formatter.get_column('target')
    static_cols = formatter.get_column('static_covs')
    static_cols = static_cols + [formatter.get_column('id')] if static_cols is not None else [formatter.get_column('id')]
    dynamic_cols = formatter.get_column('dynamic_covs')
    future_cols = formatter.get_column('future_covs')

    # build series
    series, scalers = make_series({'train': formatter.train_data,
                                   'val': formatter.val_data,
                                   'test': formatter.test_data.loc[~formatter.test_data.index.isin(formatter.test_idx_ood)],
                                   'test_ood': formatter.test_data.loc[formatter.test_data.index.isin(formatter.test_idx_ood)]},
                                  time_col,
                                  group_col,
                                  {'target': target_col,
                                   'static': static_cols,
                                   'dynamic': dynamic_cols,
                                   'future': future_cols})

    if not use_covs:
        # set dynamic and future covariates to None
        for split in ['train', 'val', 'test', 'test_ood']:
            for cov in ['dynamic', 'future']:
                series[split][cov] = None
    elif use_covs and cov_type == 'past':
        # use future covariates as dynamic covariates
        if series['train']['dynamic'] is None:
            for split in ['train', 'val', 'test', 'test_ood']:
                series[split]['dynamic'] = series[split]['future']
        # or attach them to the existing dynamic covariates
        else:
            for split in ['train', 'val', 'test', 'test_ood']:
                for i in range(len(series[split]['future'])):
                    series[split]['dynamic'][i] = series[split]['dynamic'][i].concatenate(series[split]['future'][i], axis=1)
    elif use_covs and cov_type == 'dual':
        # set dynamic covariates to None, because they are not supported
        for split in ['train', 'val', 'test', 'test_ood']:
            series[split]['dynamic'] = None

    if use_static_covs:
        # attach static covariates to series
        for split in ['train', 'val', 'test', 'test_ood']:
            for i in range(len(series[split]['target'])):
                static_covs = series[split]['static'][i][0].pd_dataframe()
                series[split]['target'][i] = series[split]['target'][i].with_static_covariates(static_covs)

    return formatter, series, scalers

class ScalerCustom:
    '''
    Min-max scaler for TimeSeries that fits on all sequences simultaneously.
    The default Darts scaler fits one scaler per sequence in the list.

    Attributes
    ----------
    scaler: Scaler
        Darts scaler object.
    min_: np.ndarray
        Per-feature adjustment for the minimum (see Scikit-learn).
    scale_: np.ndarray
        Per-feature relative scaling of the data (see Scikit-learn).
    '''
    def __init__(self):
        self.scaler = Scaler()
        self.min_ = None
        self.scale_ = None

    def fit(self, time_series: Union[List[TimeSeries], TimeSeries]) -> None:
        if isinstance(time_series, list):
            # extract the series as a single Pandas dataframe
            df = pd.concat([ts.pd_dataframe() for ts in time_series])
            value_cols = df.columns
            df.reset_index(inplace=True)
            # create a new equally spaced time grid
            # (assumes the time index column is named 'time')
            df['new_time'] = pd.date_range(start=df['time'].min(), periods=len(df), freq='1h')
            # fit the scaler on the concatenated series
            series = TimeSeries.from_dataframe(df, time_col='new_time', value_cols=value_cols)
            self.scaler.fit(series)
        else:
            self.scaler.fit(time_series)
        # extract min and scale
        self.min_ = self.scaler._fitted_params[0].min_
        self.scale_ = self.scaler._fitted_params[0].scale_

    def transform(self, time_series: Union[List[TimeSeries], TimeSeries]) -> Union[List[TimeSeries], TimeSeries]:
        if isinstance(time_series, list):
            # transform one by one
            series = [self.scaler.transform(ts) for ts in time_series]
        else:
            series = self.scaler.transform(time_series)
        return series

    def inverse_transform(self, time_series: Union[List[TimeSeries], TimeSeries]) -> Union[List[TimeSeries], TimeSeries]:
        if isinstance(time_series, list):
            # transform one by one
            series = [self.scaler.inverse_transform(ts) for ts in time_series]
        else:
            series = self.scaler.inverse_transform(time_series)
        return series

    def fit_transform(self, time_series: Union[List[TimeSeries], TimeSeries]) -> Union[List[TimeSeries], TimeSeries]:
        self.fit(time_series)
        series = self.transform(time_series)
        return series

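A toy demonstration of the design choice above: one scaler is fitted across both sequences, so both are scaled with the same min/max (unlike the per-sequence default `Scaler`). The data here is synthetic, and the time index is named 'time' as the `fit` method assumes:

import numpy as np
import pandas as pd
from darts import TimeSeries

t = pd.date_range("2021-01-01", periods=50, freq="5min", name="time")
ts_list = [TimeSeries.from_times_and_values(t, np.random.rand(50, 1) * 10),
           TimeSeries.from_times_and_values(t, np.random.rand(50, 1) * 200)]
scaler = ScalerCustom()
scaled = scaler.fit_transform(ts_list)
restored = scaler.inverse_transform(scaled)
print(scaler.min_, scaler.scale_)  # shared parameters across both sequences
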
utils/darts_training.py
ADDED
@@ -0,0 +1,114 @@
import sys
import os
import yaml
import random
from typing import Any, BinaryIO, Callable, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
from scipy import stats
import pandas as pd
import darts

from darts import models
from darts import metrics
from darts import TimeSeries
from pytorch_lightning.callbacks import Callback
from darts.logging import get_logger, raise_if_not

# for optuna callback
import warnings
import optuna
from optuna.storages._cached_storage import _CachedStorage
from optuna.storages._rdb.storage import RDBStorage
# Define key names of `Trial.system_attrs`.
_PRUNED_KEY = "ddp_pl:pruned"
_EPOCH_KEY = "ddp_pl:epoch"
with optuna._imports.try_import() as _imports:
    import pytorch_lightning as pl
    from pytorch_lightning import LightningModule
    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import Callback
if not _imports.is_successful():
    Callback = object  # type: ignore # NOQA
    LightningModule = object  # type: ignore # NOQA
    Trainer = object  # type: ignore # NOQA

def print_callback(study, trial, study_file=None):
    # write the optimization progress to a file
    with open(study_file, "a") as f:
        f.write(f"Current value: {trial.value}, Current params: {trial.params}\n")
        f.write(f"Best value: {study.best_value}, Best params: {study.best_trial.params}\n")

def early_stopping_check(study,
                         trial,
                         study_file,
                         early_stopping_rounds=10):
    """
    Early stopping callback for Optuna: stops the study when the best trial
    has not improved for `early_stopping_rounds` trials.
    """
    current_trial_number = trial.number
    best_trial_number = study.best_trial.number
    should_stop = (current_trial_number - best_trial_number) >= early_stopping_rounds
    if should_stop:
        with open(study_file, 'a') as f:
            f.write('\nEarly stopping at trial {} (best trial: {})'.format(current_trial_number, best_trial_number))
        study.stop()

class LossLogger(Callback):
    def __init__(self):
        self.train_loss = []
        self.val_loss = []

    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.train_loss.append(float(trainer.callback_metrics["train_loss"]))

    def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        self.val_loss.append(float(trainer.callback_metrics["val_loss"]))

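A hypothetical wiring of the two study callbacks above into an Optuna run; the objective here is a stand-in for a real training run, and the study file path is made up:

# Toy sketch: per-trial logging plus early stopping after 10 stagnant trials.
def objective(trial):
    x = trial.suggest_float("x", -10, 10)
    return x ** 2

study = optuna.create_study(direction="minimize")
study.optimize(
    objective,
    n_trials=50,
    callbacks=[
        lambda study, trial: print_callback(study, trial, study_file="study.log"),
        lambda study, trial: early_stopping_check(study, trial, study_file="study.log"),
    ],
)
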
class PyTorchLightningPruningCallback(Callback):
    """PyTorch Lightning callback to prune unpromising trials.
    See `the example <https://github.com/optuna/optuna-examples/blob/
    main/pytorch/pytorch_lightning_simple.py>`__
    if you want to add a pruning callback which observes accuracy.
    Args:
        trial:
            A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the
            objective function.
        monitor:
            An evaluation metric for pruning, e.g., ``val_loss`` or
            ``val_acc``. The metrics are obtained from the returned dictionaries from e.g.
            ``pytorch_lightning.LightningModule.training_step`` or
            ``pytorch_lightning.LightningModule.validation_epoch_end`` and the names thus depend on
            how this dictionary is formatted.
    """

    def __init__(self, trial: optuna.trial.Trial, monitor: str) -> None:
        super().__init__()

        self._trial = trial
        self.monitor = monitor

    def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        # When the trainer calls `on_validation_end` for the sanity check,
        # do not call `trial.report` to avoid calling `trial.report` multiple times
        # at epoch 0. The related page is
        # https://github.com/PyTorchLightning/pytorch-lightning/issues/1391.
        if trainer.sanity_checking:
            return

        epoch = pl_module.current_epoch

        current_score = trainer.callback_metrics.get(self.monitor)
        if current_score is None:
            message = (
                "The metric '{}' is not in the evaluation logs for pruning. "
                "Please make sure you set the correct metric name.".format(self.monitor)
            )
            warnings.warn(message)
            return

        self._trial.report(current_score, step=epoch)
        if self._trial.should_prune():
            message = "Trial was pruned at epoch {}.".format(epoch)
            raise optuna.TrialPruned(message)
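
A hypothetical sketch of using the pruning callback with a Darts torch model inside an Optuna objective; the model class and hyperparameters are illustrative only, and the callback is passed through `pl_trainer_kwargs` so Lightning reports the monitored metric to Optuna each epoch:

def objective(trial):
    # any Darts torch model would do; TFTModel is only an example
    model = models.TFTModel(
        input_chunk_length=96,
        output_chunk_length=12,
        pl_trainer_kwargs={
            "callbacks": [PyTorchLightningPruningCallback(trial, monitor="val_loss")]
        },
    )
    # fit on train/val series and return the validation score here (omitted)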
|