Livia_Zaharia committed on
Commit bacf16b · 1 Parent(s): eb420aa

added code for the first time
Files changed (45)
  1. __pycache__/plot_predictions.cpython-311.pyc +0 -0
  2. __pycache__/routes.cpython-311.pyc +0 -0
  3. __pycache__/tools.cpython-311.pyc +0 -0
  4. data_formatter/__init__.py +0 -0
  5. data_formatter/__pycache__/__init__.cpython-311.pyc +0 -0
  6. data_formatter/__pycache__/base.cpython-311.pyc +0 -0
  7. data_formatter/__pycache__/types.cpython-311.pyc +0 -0
  8. data_formatter/__pycache__/utils.cpython-311.pyc +0 -0
  9. data_formatter/base.py +213 -0
  10. data_formatter/types.py +19 -0
  11. data_formatter/utils.py +323 -0
  12. environment.yaml +28 -0
  13. files/config.yaml +81 -0
  14. format_dexcom.py +152 -0
  15. gluformer/__init__.py +0 -0
  16. gluformer/__pycache__/__init__.cpython-311.pyc +0 -0
  17. gluformer/__pycache__/attention.cpython-311.pyc +0 -0
  18. gluformer/__pycache__/decoder.cpython-311.pyc +0 -0
  19. gluformer/__pycache__/embed.cpython-311.pyc +0 -0
  20. gluformer/__pycache__/encoder.cpython-311.pyc +0 -0
  21. gluformer/__pycache__/model.cpython-311.pyc +0 -0
  22. gluformer/__pycache__/variance.cpython-311.pyc +0 -0
  23. gluformer/attention.py +70 -0
  24. gluformer/decoder.py +50 -0
  25. gluformer/embed.py +69 -0
  26. gluformer/encoder.py +67 -0
  27. gluformer/model.py +334 -0
  28. gluformer/utils/__init__.py +0 -0
  29. gluformer/utils/__pycache__/__init__.cpython-311.pyc +0 -0
  30. gluformer/utils/__pycache__/collate.cpython-311.pyc +0 -0
  31. gluformer/utils/__pycache__/training.cpython-311.pyc +0 -0
  32. gluformer/utils/collate.py +84 -0
  33. gluformer/utils/evaluation.py +81 -0
  34. gluformer/utils/training.py +80 -0
  35. gluformer/variance.py +24 -0
  36. main.py +8 -0
  37. tools.py +198 -0
  38. utils/__init__.py +0 -0
  39. utils/__pycache__/__init__.cpython-311.pyc +0 -0
  40. utils/__pycache__/darts_dataset.cpython-311.pyc +0 -0
  41. utils/__pycache__/darts_processing.cpython-311.pyc +0 -0
  42. utils/darts_dataset.py +881 -0
  43. utils/darts_evaluation.py +280 -0
  44. utils/darts_processing.py +367 -0
  45. utils/darts_training.py +114 -0
__pycache__/plot_predictions.cpython-311.pyc ADDED
Binary file (9.5 kB).
__pycache__/routes.cpython-311.pyc ADDED
Binary file (2.33 kB).
__pycache__/tools.cpython-311.pyc ADDED
Binary file (13.3 kB).
data_formatter/__init__.py ADDED
File without changes
data_formatter/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (182 Bytes).
data_formatter/__pycache__/base.cpython-311.pyc ADDED
Binary file (16.4 kB).
data_formatter/__pycache__/types.cpython-311.pyc ADDED
Binary file (1.09 kB).
data_formatter/__pycache__/utils.cpython-311.pyc ADDED
Binary file (19.8 kB).
data_formatter/base.py ADDED
@@ -0,0 +1,213 @@
+'''Defines a generic data formatter for CGM data sets.'''
+import sys
+import warnings
+import numpy as np
+import pandas as pd
+import sklearn.preprocessing
+import data_formatter.types as types
+import data_formatter.utils as utils
+
+DataTypes = types.DataTypes
+InputTypes = types.InputTypes
+
+dict_data_type = {'categorical': DataTypes.CATEGORICAL,
+                  'real_valued': DataTypes.REAL_VALUED,
+                  'date': DataTypes.DATE}
+dict_input_type = {'target': InputTypes.TARGET,
+                   'observed_input': InputTypes.OBSERVED_INPUT,
+                   'known_input': InputTypes.KNOWN_INPUT,
+                   'static_input': InputTypes.STATIC_INPUT,
+                   'id': InputTypes.ID,
+                   'time': InputTypes.TIME}
+
+
+class DataFormatter:
+    # Defines and formats data.
+
+    def __init__(self, cnf):
+        """Initialises formatter."""
+        # load parameters from the config file
+        self.params = cnf
+        # write progress to file if specified
+
+        # load column definition
+        print('-' * 32)
+        print('Loading column definition...')
+        self.__process_column_definition()
+
+        # check that column definition is valid
+        print('Checking column definition...')
+        self.__check_column_definition()
+
+        # load data
+        # check if data table has index col: -1 if not, index >= 0 if yes
+        print('Loading data...')
+        self.params['index_col'] = False if self.params['index_col'] == -1 else self.params['index_col']
+        # read data table
+        self.data = pd.read_csv(self.params['data_csv_path'], index_col=self.params['index_col'])
+
+        # drop columns / rows
+        print('Dropping columns / rows...')
+        self.__drop()
+
+        # check NA values
+        print('Checking for NA values...')
+        self.__check_nan()
+
+        # set data types in DataFrame to match column definition
+        print('Setting data types...')
+        self.__set_data_types()
+
+        # drop columns / rows
+        print('Dropping columns / rows...')
+        self.__drop()
+
+        # encode
+        print('Encoding data...')
+        self._encoding_params = self.params['encoding_params']
+        self.__encode()
+
+        # interpolate
+        print('Interpolating data...')
+        self._interpolation_params = self.params['interpolation_params']
+        self._interpolation_params['interval_length'] = self.params['observation_interval']
+        self.__interpolate()
+
+        # split data
+        print('Splitting data...')
+        self._split_params = self.params['split_params']
+        self._split_params['max_length_input'] = self.params['max_length_input']
+        self.__split_data()
+
+        # scale
+        print('Scaling data...')
+        self._scaling_params = self.params['scaling_params']
+        self.__scale()
+
+        print('Data formatting complete.')
+        print('-' * 32)
+
+
+    def __process_column_definition(self):
+        self._column_definition = []
+        for col in self.params['column_definition']:
+            self._column_definition.append((col['name'],
+                                            dict_data_type[col['data_type']],
+                                            dict_input_type[col['input_type']]))
+
+    def __check_column_definition(self):
+        # check that there is unique ID column
+        assert len([col for col in self._column_definition if col[2] == InputTypes.ID]) == 1, 'There must be exactly one ID column.'
+        # check that there is unique time column
+        assert len([col for col in self._column_definition if col[2] == InputTypes.TIME]) == 1, 'There must be exactly one time column.'
+        # check that there is at least one target column
+        assert len([col for col in self._column_definition if col[2] == InputTypes.TARGET]) >= 1, 'There must be at least one target column.'
+
+    def __set_data_types(self):
+        # set time column as datetime format in pandas
+        for col in self._column_definition:
+            if col[1] == DataTypes.DATE:
+                self.data[col[0]] = pd.to_datetime(self.data[col[0]])
+            if col[1] == DataTypes.CATEGORICAL:
+                self.data[col[0]] = self.data[col[0]].astype('category')
+            if col[1] == DataTypes.REAL_VALUED:
+                self.data[col[0]] = self.data[col[0]].astype(np.float32)
+
+    def __check_nan(self):
+        # delete rows where target, time, or id are na
+        self.data = self.data.dropna(subset=[col[0]
+                                             for col in self._column_definition
+                                             if col[2] in [InputTypes.TARGET, InputTypes.TIME, InputTypes.ID]])
+        # assert that there are no na values in the data
+        assert self.data.isna().sum().sum() == 0, 'There are NA values in the data even after dropping with missing time, glucose, or id.'
+
+    def __drop(self):
+        # drop columns that are not in the column definition
+        self.data = self.data[[col[0] for col in self._column_definition]]
+        # drop rows based on conditions set in the formatter
+        if self.params['drop'] is not None:
+            if self.params['drop']['rows'] is not None:
+                # drop rows at indices in the list self.params['drop']['rows']
+                self.data = self.data.drop(self.params['drop']['rows'])
+                self.data = self.data.reset_index(drop=True)
+            if self.params['drop']['columns'] is not None:
+                for col in self.params['drop']['columns'].keys():
+                    # drop rows where specified columns have values in the list self.params['drop']['columns'][col]
+                    self.data = self.data.loc[~self.data[col].isin(self.params['drop']['columns'][col])].copy()
+
+    def __interpolate(self):
+        self.data, self._column_definition = utils.interpolate(self.data,
+                                                               self._column_definition,
+                                                               **self._interpolation_params)
+
+    def __split_data(self):
+        if self.params['split_params']['test_percent_subjects'] == 0 or \
+           self.params['split_params']['length_segment'] == 0:
+            print('\tNo splitting performed since test_percent_subjects or length_segment is 0.')
+            self.train_idx, self.val_idx, self.test_idx, self.test_idx_ood = None, None, None, None
+            self.train_data, self.val_data, self.test_data = self.data, None, None
+        else:
+            assert self.params['split_params']['length_segment'] > self.params['length_pred'], \
+                'length_segment for test / val must be greater than length_pred.'
+            self.train_idx, self.val_idx, self.test_idx, self.test_idx_ood = utils.split(self.data,
+                                                                                         self._column_definition,
+                                                                                         **self._split_params)
+            self.train_data, self.val_data, self.test_data = self.data.iloc[self.train_idx], \
+                                                             self.data.iloc[self.val_idx], \
+                                                             self.data.iloc[self.test_idx + self.test_idx_ood]
+
+    def __encode(self):
+        self.data, self._column_definition, self.encoders = utils.encode(self.data,
+                                                                         self._column_definition,
+                                                                         **self._encoding_params)
+
+    def __scale(self):
+        self.train_data, self.val_data, self.test_data, self.scalers = utils.scale(self.train_data,
+                                                                                   self.val_data,
+                                                                                   self.test_data,
+                                                                                   self._column_definition,
+                                                                                   **self.params['scaling_params'])
+
+    def reshuffle(self, seed):
+        stdout = sys.stdout
+        f = open(self.study_file, 'a')
+        sys.stdout = f
+        self.params['split_params']['random_state'] = seed
+        # split data
+        self.train_idx, self.val_idx, self.test_idx, self.test_idx_ood = utils.split(self.data,
+                                                                                     self._column_definition,
+                                                                                     **self._split_params)
+        self.train_data, self.val_data, self.test_data = self.data.iloc[self.train_idx], \
+                                                         self.data.iloc[self.val_idx], \
+                                                         self.data.iloc[self.test_idx + self.test_idx_ood]
+        # re-scale data
+        self.train_data, self.val_data, self.test_data, self.scalers = utils.scale(self.train_data,
+                                                                                   self.val_data,
+                                                                                   self.test_data,
+                                                                                   self._column_definition,
+                                                                                   **self.params['scaling_params'])
+        sys.stdout = stdout
+        f.close()
+
+    def get_column(self, column_name):
+        # write cases for time, id, target, future, static, dynamic covariates
+        if column_name == 'time':
+            return [col[0] for col in self._column_definition if col[2] == InputTypes.TIME][0]
+        elif column_name == 'id':
+            return [col[0] for col in self._column_definition if col[2] == InputTypes.ID][0]
+        elif column_name == 'sid':
+            return [col[0] for col in self._column_definition if col[2] == InputTypes.SID][0]
+        elif column_name == 'target':
+            return [col[0] for col in self._column_definition if col[2] == InputTypes.TARGET]
+        elif column_name == 'future_covs':
+            future_covs = [col[0] for col in self._column_definition if col[2] == InputTypes.KNOWN_INPUT]
+            return future_covs if len(future_covs) > 0 else None
+        elif column_name == 'static_covs':
+            static_covs = [col[0] for col in self._column_definition if col[2] == InputTypes.STATIC_INPUT]
+            return static_covs if len(static_covs) > 0 else None
+        elif column_name == 'dynamic_covs':
+            dynamic_covs = [col[0] for col in self._column_definition if col[2] == InputTypes.OBSERVED_INPUT]
+            return dynamic_covs if len(dynamic_covs) > 0 else None
+        else:
+            raise ValueError('Column {} not found.'.format(column_name))
+
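A minimal usage sketch for the formatter above (not part of the commit): it assumes files/config.yaml from this commit, a CSV at the data_csv_path it names, and PyYAML for loading; the variable names are illustrative.

# Sketch only: drive DataFormatter from the YAML config added in this commit.
import yaml
from data_formatter.base import DataFormatter

with open('files/config.yaml', 'r') as f:
    cnf = yaml.safe_load(f)

formatter = DataFormatter(cnf)        # runs load / encode / interpolate / split / scale
print(formatter.get_column('time'))   # 'time'
print(formatter.get_column('target')) # ['gl']
print(formatter.train_data.shape)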
data_formatter/types.py ADDED
@@ -0,0 +1,19 @@
+'''Defines data and input types of each column in the dataset.'''
+
+import enum
+
+class DataTypes(enum.IntEnum):
+    """Defines numerical types of each column."""
+    REAL_VALUED = 0
+    CATEGORICAL = 1
+    DATE = 2
+
+class InputTypes(enum.IntEnum):
+    """Defines input types of each column."""
+    TARGET = 0
+    OBSERVED_INPUT = 1
+    KNOWN_INPUT = 2
+    STATIC_INPUT = 3
+    ID = 4  # Single column used as an entity identifier
+    SID = 5  # Single column used as a segment identifier
+    TIME = 6  # Single column exclusively used as a time index
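For reference, this is the shape of the column definition that base.py builds from these enums; a hand-written sketch whose tuples mirror files/config.yaml and are not part of the committed code.

from data_formatter.types import DataTypes, InputTypes

# (column name, data type, input type) tuples as consumed throughout data_formatter
column_definition = [
    ('id',   DataTypes.CATEGORICAL, InputTypes.ID),
    ('time', DataTypes.DATE,        InputTypes.TIME),
    ('gl',   DataTypes.REAL_VALUED, InputTypes.TARGET),
]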
data_formatter/utils.py ADDED
@@ -0,0 +1,323 @@
+# coding=utf-8
+# Copyright 2019 The Google Research Authors.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# Lint as: python3
+"""Generic helper functions used across codebase."""
+import warnings
+from collections import namedtuple
+from datetime import datetime
+import os
+import math
+import pathlib
+import torch
+import numpy as np
+import pandas as pd
+pd.options.mode.chained_assignment = None
+from typing import List, Tuple
+from sklearn import preprocessing
+
+import data_formatter
+from data_formatter import types
+
+DataTypes = types.DataTypes
+InputTypes = types.InputTypes
+MINUTE = 60
+
+# OS related functions.
+def create_folder_if_not_exist(directory):
+    """Creates folder if it doesn't exist.
+
+    Args:
+        directory: Folder path to create.
+    """
+    # Also creates directories recursively
+    pathlib.Path(directory).mkdir(parents=True, exist_ok=True)
+
+
+def csv_path_to_folder(path: str):
+    return "/".join(path.split('/')[:-1]) + "/"
+
+
+def interpolate(data: pd.DataFrame,
+                column_definition: List[Tuple[str, DataTypes, InputTypes]],
+                gap_threshold: int = 0,
+                min_drop_length: int = 0,
+                interval_length: int = 0):
+    """Interpolates missing values in data.
+
+    Args:
+        data: Dataframe to interpolate on. Sorted by id and then time (a DateTime object).
+        column_definition: List of tuples describing columns (column_name, data_type, input_type).
+        gap_threshold: Number in minutes, maximum allowed gap for interpolation.
+        min_drop_length: Number of points, minimum number within an interval to interpolate.
+        interval_length: Number in minutes, length of interpolation.
+
+    Returns:
+        data: DataFrame with missing values interpolated and
+            additional column ('segment') indicating continuous segments.
+        column_definition: Updated list of tuples (column_name, data_type, input_type).
+    """
+    # select all real-valued columns that are not id, time, or static
+    interpolation_columns = [column_name for column_name, data_type, input_type in column_definition if
+                             data_type == DataTypes.REAL_VALUED and
+                             input_type not in set([InputTypes.ID, InputTypes.TIME, InputTypes.STATIC_INPUT])]
+    # select all other columns except time
+    constant_columns = [column_name for column_name, data_type, input_type in column_definition if
+                        input_type not in set([InputTypes.TIME])]
+    constant_columns += ['id_segment']
+
+    # get id and time columns
+    id_col = [column_name for column_name, data_type, input_type in column_definition if input_type == InputTypes.ID][0]
+    time_col = [column_name for column_name, data_type, input_type in column_definition if input_type == InputTypes.TIME][0]
+
+    # round to minute
+    data[time_col] = data[time_col].dt.round('1min')
+    # count dropped segments
+    dropped_segments = 0
+    # count number of values that are interpolated
+    interpolation_count = 0
+    # store final output
+    output = []
+    for id, id_data in data.groupby(id_col):
+        # sort values
+        id_data.sort_values(time_col, inplace=True)
+        # get time difference between consecutive rows
+        lag = (id_data[time_col].diff().dt.total_seconds().fillna(0) / 60.0).astype(int)
+        # if lag > gap_threshold
+        id_segment = (lag > gap_threshold).cumsum()
+        id_data['id_segment'] = id_segment
+        for segment, segment_data in id_data.groupby('id_segment'):
+            # if segment is too short, then we don't interpolate
+            if len(segment_data) < min_drop_length:
+                dropped_segments += 1
+                continue
+
+            # find and print duplicated times
+            duplicates = segment_data.duplicated(time_col, keep=False)
+            if duplicates.any():
+                print(segment_data[duplicates])
+                raise ValueError('Duplicate times in segment {} of id {}'.format(segment, id))
+
+            # reindex at interval_length minute intervals
+            segment_data = segment_data.set_index(time_col)
+            index_new = pd.date_range(start=segment_data.index[0],
+                                      end=segment_data.index[-1],
+                                      freq=interval_length)
+            index_union = index_new.union(segment_data.index)
+            segment_data = segment_data.reindex(index_union)
+            # count nan values in interpolation columns
+            interpolation_count += segment_data[interpolation_columns[0]].isna().sum()
+            # interpolate
+            segment_data[interpolation_columns] = segment_data[interpolation_columns].interpolate(method='index')
+            # fill constant columns with last value
+            segment_data[constant_columns] = segment_data[constant_columns].ffill()
+            # delete rows not conforming to frequency
+            segment_data = segment_data.reindex(index_new)
+            # reset index, make the time a column with name time_col
+            segment_data = segment_data.reset_index().rename(columns={'index': time_col})
+            # set the id_segment to position in output
+            segment_data['id_segment'] = len(output)
+            # add to output
+            output.append(segment_data)
+    # print number of dropped segments and number of segments
+    print('\tDropped segments: {}'.format(dropped_segments))
+    print('\tExtracted segments: {}'.format(len(output)))
+    # concat all segments and reset index
+    output = pd.concat(output)
+    output.reset_index(drop=True, inplace=True)
+    # count number of interpolated values
+    print('\tInterpolated values: {}'.format(interpolation_count))
+    print('\tPercent of values interpolated: {:.2f}%'.format(interpolation_count / len(output) * 100))
+    # add id_segment column to column_definition as ID
+    column_definition += [('id_segment', DataTypes.CATEGORICAL, InputTypes.SID)]
+
+    return output, column_definition
+
+def create_index(time_col: pd.Series, interval_length: int):
+    """Creates a new index at interval_length minute intervals.
+
+    Args:
+        time_col: Series of times.
+        interval_length: Number in minutes, length of interpolation.
+
+    Returns:
+        index: New index.
+    """
+    # margin of error
+    eps = pd.Timedelta('1min')
+    new_time_col = [time_col.iloc[0]]
+    for time in time_col.iloc[1:]:
+        if time - new_time_col[-1] <= pd.Timedelta(interval_length) + eps:
+            new_time_col.append(time)
+        else:
+            filler = new_time_col[-1] + pd.Timedelta(interval_length)
+            while filler < time:
+                new_time_col.append(filler)
+                filler += pd.Timedelta(interval_length)
+            new_time_col.append(time)
+    return pd.to_datetime(new_time_col)
+
+def split(df: pd.DataFrame,
+          column_definition: List[Tuple[str, DataTypes, InputTypes]],
+          test_percent_subjects: float,
+          length_segment: int,
+          max_length_input: int,
+          random_state: int = 42):
+    """Splits data into train, validation and test sets.
+
+    Args:
+        df: Dataframe to split.
+        column_definition: List of tuples describing columns (column_name, data_type, input_type).
+        test_percent_subjects: Float number from [0, 1], percentage of subjects to use for test set.
+        length_segment: Number of points, length of segments saved for validation / test sets.
+        max_length_input: Number of points, maximum length of input sequences for models.
+        random_state: Number, random state for reproducibility.
+
+    Returns:
+        train_idx: Training set indices.
+        val_idx: Validation set indices.
+        test_idx: Test set indices.
+    """
+    # set random state
+    np.random.seed(random_state)
+    # get id and id_segment columns
+    id_col = [column_name for column_name, data_type, input_type in column_definition if input_type == InputTypes.ID][0]
+    id_segment_col = [column_name for column_name, data_type, input_type in column_definition if input_type == InputTypes.SID][0]
+    # get unique ids
+    ids = df[id_col].unique()
+
+    # select some subjects for test data set
+    test_ids = np.random.choice(ids, math.ceil(len(ids) * test_percent_subjects), replace=False)
+    test_idx_ood = list(df[df[id_col].isin(test_ids)].index)
+    # get the remaining data for training and validation
+    df = df[~df[id_col].isin(test_ids)]
+
+    # iterate through subjects and split into train, val and test
+    train_idx = []; val_idx = []; test_idx = []
+    for id, id_data in df.groupby(id_col):
+        segment_ids = id_data[id_segment_col].unique()
+        if len(segment_ids) >= 2:
+            train_idx += list(id_data.loc[id_data[id_segment_col].isin(segment_ids[:-2])].index)
+            penultimate_segment = id_data[id_data[id_segment_col] == segment_ids[-2]]
+            last_segment = id_data[id_data[id_segment_col] == segment_ids[-1]]
+            if len(last_segment) >= max_length_input + 3 * length_segment:
+                train_idx += list(penultimate_segment.index)
+                train_idx += list(last_segment.iloc[:-2*length_segment].index)
+                val_idx += list(last_segment.iloc[-2*length_segment-max_length_input:-length_segment].index)
+                test_idx += list(last_segment.iloc[-length_segment-max_length_input:].index)
+            elif len(last_segment) >= max_length_input + 2 * length_segment:
+                train_idx += list(penultimate_segment.index)
+                val_idx += list(last_segment.iloc[:-length_segment].index)
+                test_idx += list(last_segment.iloc[-length_segment-max_length_input:].index)
+            else:
+                test_idx += list(last_segment.index)
+                if len(penultimate_segment) >= max_length_input + 2 * length_segment:
+                    val_idx += list(penultimate_segment.iloc[-length_segment-max_length_input:].index)
+                    train_idx += list(penultimate_segment.iloc[:-length_segment].index)
+                else:
+                    train_idx += list(penultimate_segment.index)
+        else:
+            if len(id_data) >= max_length_input + 3 * length_segment:
+                train_idx += list(id_data.iloc[:-2*length_segment].index)
+                val_idx += list(id_data.iloc[-2*length_segment-max_length_input:-length_segment].index)
+                test_idx += list(id_data.iloc[-length_segment-max_length_input:].index)
+            elif len(id_data) >= max_length_input + 2 * length_segment:
+                train_idx += list(id_data.iloc[:-length_segment].index)
+                test_idx += list(id_data.iloc[-length_segment-max_length_input:].index)
+            else:
+                train_idx += list(id_data.index)
+    total_len = len(train_idx) + len(val_idx) + len(test_idx) + len(test_idx_ood)
+    print('\tTrain: {} ({:.2f}%)'.format(len(train_idx), len(train_idx) / total_len * 100))
+    print('\tVal: {} ({:.2f}%)'.format(len(val_idx), len(val_idx) / total_len * 100))
+    print('\tTest: {} ({:.2f}%)'.format(len(test_idx), len(test_idx) / total_len * 100))
+    print('\tTest OOD: {} ({:.2f}%)'.format(len(test_idx_ood), len(test_idx_ood) / total_len * 100))
+    return train_idx, val_idx, test_idx, test_idx_ood
+
+def encode(df: pd.DataFrame,
+           column_definition: List[Tuple[str, DataTypes, InputTypes]],
+           date: List,):
+    """Encodes categorical columns.
+
+    Args:
+        df: Dataframe to split.
+        column_definition: List of tuples describing columns (column_name, data_type, input_type).
+        date: List of str, list containing date info to extract.
+
+    Returns:
+        df: Dataframe with encoded columns.
+        column_definition: Updated list of tuples containing column name and types.
+        encoders: dictionary containing encoders.
+    """
+    encoders = {}
+    new_columns = []
+    for i in range(len(column_definition)):
+        column, column_type, input_type = column_definition[i]
+        if column_type == DataTypes.DATE:
+            for extract_col in date:
+                df[column + '_' + extract_col] = getattr(df[column].dt, extract_col)
+                df[column + '_' + extract_col] = df[column + '_' + extract_col].astype(np.float32)
+                new_columns.append((column + '_' + extract_col, DataTypes.REAL_VALUED, InputTypes.KNOWN_INPUT))
+        elif column_type == DataTypes.CATEGORICAL:
+            encoders[column] = preprocessing.LabelEncoder()
+            df[column] = encoders[column].fit_transform(df[column]).astype(np.float32)
+            column_definition[i] = (column, DataTypes.REAL_VALUED, input_type)
+        else:
+            continue
+    column_definition += new_columns
+    # print updated column definition
+    print('\tUpdated column definition:')
+    for column, column_type, input_type in column_definition:
+        print('\t\t{}: {} ({})'.format(column,
+                                       DataTypes(column_type).name,
+                                       InputTypes(input_type).name))
+    return df, column_definition, encoders
+
+def scale(train_data: pd.DataFrame,
+          val_data: pd.DataFrame,
+          test_data: pd.DataFrame,
+          column_definition: List[Tuple[str, DataTypes, InputTypes]],
+          scaler: str):
+    """Scales numerical data.
+
+    Args:
+        train_data: pd.Dataframe, DataFrame of training data.
+        val_data: pd.Dataframe, DataFrame of validation data.
+        test_data: pd.Dataframe, DataFrame of testing data.
+        column_definition: List of tuples describing columns (column_name, data_type, input_type).
+        scaler: String, scaler to use.
+
+    Returns:
+        train_data: pd.Dataframe, DataFrame of scaled training data.
+        val_data: pd.Dataframe, DataFrame of scaled validation data.
+        test_data: pd.Dataframe, DataFrame of scaled testing data.
+        scalers: dictionary indexed by column names containing scalers.
+    """
+    # select all real-valued columns
+    columns_to_scale = [column for column, data_type, input_type in column_definition if data_type == DataTypes.REAL_VALUED]
+    # handle no scaling case
+    if scaler == 'None':
+        print('\tNo scaling applied')
+        return train_data, val_data, test_data, None
+    scalers = {}
+    for column in columns_to_scale:
+        scaler_column = getattr(preprocessing, scaler)()
+        train_data[column] = scaler_column.fit_transform(train_data[column].values.reshape(-1, 1))
+        # handle empty validation and test sets
+        val_data[column] = scaler_column.transform(val_data[column].values.reshape(-1, 1)) if val_data.shape[0] > 0 else val_data[column]
+        test_data[column] = scaler_column.transform(test_data[column].values.reshape(-1, 1)) if test_data.shape[0] > 0 else test_data[column]
+        scalers[column] = scaler_column
+    # print columns that were scaled
+    print('\tScaled columns: {}'.format(columns_to_scale))
+    return train_data, val_data, test_data, scalers
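A quick worked example of the thresholds split() uses, with the values from files/config.yaml plugged in; this is an illustration only, not part of the commit.

# Thresholds used by split(), with max_length_input=192 and length_segment=13.
max_length_input, length_segment = 192, 13

print(max_length_input + 3 * length_segment)  # 231: segment long enough for train, val, and test windows
print(max_length_input + 2 * length_segment)  # 218: segment is only split two ways
# Shorter segments are assigned wholesale: to train for single-segment subjects,
# or to test when they are the last segment of a multi-segment subject.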
environment.yaml ADDED
@@ -0,0 +1,28 @@
+name: glucose_genie
+channels:
+  - conda-forge
+  - defaults
+dependencies:
+  - python=3.11
+  - gradio
+  - seaborn
+  - pytorch
+  - optuna
+  - numpy<2.0.0
+  - tensorboard
+  - pip:
+    - fastapi
+    - uvicorn
+    - thefuzz
+    - pycomfort>=0.0.15
+    - polars>=1.3.0
+    - hybrid_search>=0.0.15
+    - psutil  # compatibility
+    - httpx
+    - just-agents>=0.1.0
+    - FlagEmbedding
+    - typer
+    - darts==0.29.0
+    - pmdarima==2.0.4
+    - numpy==1.26.4
+    - peft
files/config.yaml ADDED
@@ -0,0 +1,81 @@
+data_csv_path: ./raw_data/anton.csv
+drop: null
+ds_name: livia_mini
+index_col: -1
+observation_interval: 5min
+
+column_definition:
+  - data_type: categorical
+    input_type: id
+    name: id
+  - data_type: date
+    input_type: time
+    name: time
+  - data_type: real_valued
+    input_type: target
+    name: gl
+
+encoding_params:
+  date:
+    - day
+    - month
+    - year
+    - hour
+    - minute
+    - second
+
+# NA values abbreviation
+nan_vals: null
+
+# Interpolation parameters
+interpolation_params:
+  gap_threshold: 45 # in minutes
+  min_drop_length: 240 # in number of points (20 hrs)
+
+scaling_params:
+  scaler: None
+
+split_params:
+  length_segment: 13
+  random_state: 0
+  test_percent_subjects: 0.1
+
+
+# Splitting parameters
+#split_params:
+#  test_percent_subjects: .1
+#  length_segment: 240
+#  random_state: 0
+
+# Model params
+max_length_input: 192
+length_pred: 12
+
+transformer:
+  batch_size: 32
+  d_model: 96
+  dim_feedforward: 448
+  dropout: 0.10161152207464333
+  in_len: 96
+  lr: 0.000840888489686657
+  lr_epochs: 16
+  max_grad_norm: 0.6740479322943925
+  max_samples_per_ts: 50
+  n_heads: 4
+  num_decoder_layers: 1
+  num_encoder_layers: 4
+
+transformer_covariates:
+  batch_size: 32
+  d_model: 128
+  dim_feedforward: 160
+  dropout: 0.044926981080245884
+  in_len: 108
+  lr: 0.00029632347559614453
+  lr_epochs: 20
+  max_grad_norm: 0.8890169619043728
+  max_samples_per_ts: 50
+  n_heads: 2
+  num_decoder_layers: 2
+  num_encoder_layers: 2
+
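One note worth illustrating: the scaler value under scaling_params is resolved by name inside data_formatter/utils.scale(), so the string 'None' above disables scaling, while any sklearn.preprocessing class name would be looked up dynamically. A hedged sketch; 'MinMaxScaler' is a hypothetical alternative, not what this commit uses.

from sklearn import preprocessing

scaler_name = 'MinMaxScaler'                  # hypothetical; the committed config uses the string 'None'
scaler = getattr(preprocessing, scaler_name)()
print(type(scaler).__name__)                  # MinMaxScaler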
format_dexcom.py ADDED
@@ -0,0 +1,152 @@
+import pandas as pd
+from pathlib import Path
+import typer
+
+
+def process_csv(
+        input_dir: Path = typer.Argument(help="Directory containing the input CSV files."),
+        output_file: Path = typer.Argument(help="Path to save the processed CSV file."),
+        event_type_filter: str = typer.Option('egv', help="Event type to filter by."),
+        drop_duplicates: bool = typer.Option(True, help="Whether to drop duplicate timestamps."),
+        time_diff_minutes: int = typer.Option(1, help="Minimum time difference in minutes to keep a row."),
+        chunk_size: int = typer.Option(1000, help="Chunk size for the 'id' column increment. Set to 0 or None for a single id."),
+) -> pd.DataFrame:
+
+    # Read CSV file into a DataFrame
+    filename = input_dir
+    df = pd.read_csv(filename, low_memory=False)
+
+    # Filter by Event Type and Event Subtype
+    df = df[df['Event Type'].str.lower() == event_type_filter]
+    df = df[df['Event Subtype'].isna()]
+
+    # List of columns to keep
+    columns_to_keep = [
+        'Index',
+        'Timestamp (YYYY-MM-DDThh:mm:ss)',
+        'Glucose Value (mg/dL)',
+    ]
+
+    # Keep only the specified columns
+    df = df[columns_to_keep]
+
+    # Rename columns
+    column_rename = {
+        'Index': 'id',
+        'Timestamp (YYYY-MM-DDThh:mm:ss)': 'time',
+        'Glucose Value (mg/dL)': 'gl'
+    }
+    df = df.rename(columns=column_rename)
+
+    # Handle id assignment based on chunk_size
+    if chunk_size is None or chunk_size == 0:
+        df['id'] = 1  # Assign the same id to all rows
+    else:
+        df['id'] = ((df.index // chunk_size) % (df.index.max() // chunk_size + 1)).astype(int)
+
+    # Convert timestamp to datetime
+    df['time'] = pd.to_datetime(df['time'])
+
+    # Calculate time difference and keep rows with at least the specified time difference
+    df['time_diff'] = df['time'].diff()
+    df = df[df['time_diff'].isna() | (df['time_diff'] >= pd.Timedelta(minutes=time_diff_minutes))]
+
+    # Drop the temporary time_diff column
+    df = df.drop(columns=['time_diff'])
+
+    # Ensure glucose values are in float64
+    df['gl'] = df['gl'].astype('float64')
+
+    # Optionally drop duplicate rows based on time
+    if drop_duplicates:
+        df = df.drop_duplicates(subset=['time'], keep='first')
+
+    # Write the modified dataframe to a new CSV file
+    df.to_csv(output_file, index=False)
+
+    typer.echo("CSV files have been successfully merged, modified, and saved.")
+
+    return df
+
+
+def process_multiple_csv(
+        input_dir: Path = typer.Argument('./raw_data/livia_unmerged', help="Directory containing the input CSV files."),
+        output_file: Path = typer.Argument('./raw_data/livia_unmerged/livia_mini.csv', help="Path to save the processed CSV file."),
+        event_type_filter: str = typer.Option('egv', help="Event type to filter by."),
+        drop_duplicates: bool = typer.Option(True, help="Whether to drop duplicate timestamps."),
+        time_diff_minutes: int = typer.Option(1, help="Minimum time difference in minutes to keep a row."),
+        chunk_size: int = typer.Option(1000, help="Chunk size for the 'id' column increment. Set to 0 or None for a single id."),
+):
+    # Get all the CSV files in the specified directory
+    all_files = list(input_dir.glob("*.csv"))
+
+    # List to store the DataFrames
+    df_list = []
+
+    # Read each CSV file into a DataFrame and append to the list
+    for filename in all_files:
+        df = pd.read_csv(filename, low_memory=False)
+        df_list.append(df)
+
+    # Concatenate all DataFrames in the list
+    combined_df = pd.concat(df_list, ignore_index=True)
+
+    # Filter by Event Type and Event Subtype
+    combined_df = combined_df[combined_df['Event Type'].str.lower() == event_type_filter]
+    combined_df = combined_df[combined_df['Event Subtype'].isna()]
+
+    # List of columns to keep
+    columns_to_keep = [
+        'Index',
+        'Timestamp (YYYY-MM-DDThh:mm:ss)',
+        'Glucose Value (mg/dL)',
+    ]
+
+    # Keep only the specified columns
+    combined_df = combined_df[columns_to_keep]
+
+    # Rename columns
+    column_rename = {
+        'Index': 'id',
+        'Timestamp (YYYY-MM-DDThh:mm:ss)': 'time',
+        'Glucose Value (mg/dL)': 'gl'
+    }
+    combined_df = combined_df.rename(columns=column_rename)
+
+    # Sort the combined DataFrame by timestamp
+    combined_df = combined_df.sort_values('time')
+
+    # Handle id assignment based on chunk_size
+    if chunk_size is None or chunk_size == 0:
+        combined_df['id'] = 1  # Assign the same id to all rows
+    else:
+        combined_df['id'] = ((combined_df.index // chunk_size) % (combined_df.index.max() // chunk_size + 1)).astype(int)
+
+    # Convert timestamp to datetime
+    combined_df['time'] = pd.to_datetime(combined_df['time'])
+
+    # Calculate time difference and keep rows with at least the specified time difference
+    combined_df['time_diff'] = combined_df['time'].diff()
+    combined_df = combined_df[combined_df['time_diff'].isna() | (combined_df['time_diff'] >= pd.Timedelta(minutes=time_diff_minutes))]
+
+    # Drop the temporary time_diff column
+    combined_df = combined_df.drop(columns=['time_diff'])
+
+    # Ensure glucose values are in float64
+    combined_df['gl'] = combined_df['gl'].astype('float64')
+
+    # Optionally drop duplicate rows based on time
+    if drop_duplicates:
+        combined_df = combined_df.drop_duplicates(subset=['time'], keep='first')
+
+    # Write the modified dataframe to a new CSV file
+    combined_df.to_csv(output_file, index=False)
+
+    typer.echo("CSV files have been successfully merged, modified, and saved.")
+
+if __name__ == "__main__":
+    typer.run(process_csv)
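Since the module is wired to Typer, process_csv can also be called directly from Python; a sketch under assumed file paths. When bypassing the CLI, every option has to be passed explicitly, because the declared defaults are typer objects rather than plain values.

from pathlib import Path
from format_dexcom import process_csv

df = process_csv(
    input_dir=Path('./raw_data/dexcom_export.csv'),  # hypothetical Dexcom export
    output_file=Path('./raw_data/formatted.csv'),    # hypothetical output path
    event_type_filter='egv',
    drop_duplicates=True,
    time_diff_minutes=1,
    chunk_size=0,                                    # 0 -> a single id for the whole series
)
print(df[['id', 'time', 'gl']].head())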
gluformer/__init__.py ADDED
File without changes
gluformer/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (177 Bytes).
gluformer/__pycache__/attention.cpython-311.pyc ADDED
Binary file (5.85 kB).
gluformer/__pycache__/decoder.cpython-311.pyc ADDED
Binary file (3.65 kB).
gluformer/__pycache__/embed.cpython-311.pyc ADDED
Binary file (6.37 kB).
gluformer/__pycache__/encoder.cpython-311.pyc ADDED
Binary file (5.28 kB).
gluformer/__pycache__/model.cpython-311.pyc ADDED
Binary file (15.9 kB).
gluformer/__pycache__/variance.cpython-311.pyc ADDED
Binary file (1.89 kB).
gluformer/attention.py ADDED
@@ -0,0 +1,70 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from math import sqrt
+
+class CausalConv1d(torch.nn.Conv1d):
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 kernel_size,
+                 stride=1,
+                 dilation=1,
+                 groups=1,
+                 bias=True):
+        self.__padding = (kernel_size - 1) * dilation
+
+        super(CausalConv1d, self).__init__(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+            padding=self.__padding,
+            dilation=dilation,
+            groups=groups,
+            bias=bias)
+
+    def forward(self, input):
+        result = super(CausalConv1d, self).forward(input)
+        if self.__padding != 0:
+            return result[:, :, :-self.__padding]
+        return result
+
+class TriangularCausalMask():
+    def __init__(self, b, n, device="cpu"):
+        mask_shape = [b, 1, n, n]
+        with torch.no_grad():
+            self._mask = torch.triu(torch.ones(mask_shape, dtype=torch.bool), diagonal=1).to(device)
+
+    @property
+    def mask(self):
+        return self._mask
+
+class MultiheadAttention(nn.Module):
+    def __init__(self, d_model, n_heads, d_keys, mask_flag, r_att_drop=0.1):
+        super(MultiheadAttention, self).__init__()
+        self.h, self.d, self.mask_flag = n_heads, d_keys, mask_flag
+        self.proj_q = nn.Linear(d_model, self.h * self.d)
+        self.proj_k = nn.Linear(d_model, self.h * self.d)
+        self.proj_v = nn.Linear(d_model, self.h * self.d)
+        self.proj_out = nn.Linear(self.h * self.d, d_model)
+        self.dropout = nn.Dropout(r_att_drop)
+
+    def forward(self, q, k, v):
+        b, n_q, n_k, h, d = q.size(0), q.size(1), k.size(1), self.h, self.d
+
+        q, k, v = self.proj_q(q), self.proj_k(k), self.proj_v(v)    # b, n_*, h*d
+        q, k, v = map(lambda x: x.reshape(b, -1, h, d), [q, k, v])  # b, n_*, h, d
+        scores = torch.einsum('bnhd,bmhd->bhnm', (q, k))            # b, h, n_q, n_k
+
+        if self.mask_flag:
+            att_mask = TriangularCausalMask(b, n_q, device=q.device)
+            scores.masked_fill_(att_mask.mask, -np.inf)
+
+        att = F.softmax(scores / (self.d ** .5), dim=-1)            # b, h, n_q, n_k
+        att = self.dropout(att)
+        att_out = torch.einsum('bhnm,bmhd->bnhd', (att, v))         # b, n_q, h, d
+        att_out = att_out.reshape(b, -1, h*d)                       # b, n_q, h*d
+        out = self.proj_out(att_out)                                # b, n_q, d_model
+        return out
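A shape sketch for the attention block above; the sizes are illustrative and not taken from the commit.

import torch
from gluformer.attention import MultiheadAttention

att = MultiheadAttention(d_model=96, n_heads=4, d_keys=96 // 4, mask_flag=False)
x = torch.randn(8, 32, 96)   # (batch, sequence length, d_model)
out = att(x, x, x)           # self-attention: query = key = value
print(out.shape)             # torch.Size([8, 32, 96])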
gluformer/decoder.py ADDED
@@ -0,0 +1,50 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .attention import *
+
+class DecoderLayer(nn.Module):
+    def __init__(self, self_att, cross_att, d_model, d_fcn,
+                 r_drop, activ="relu"):
+        super(DecoderLayer, self).__init__()
+
+        self.self_att = self_att
+        self.cross_att = cross_att
+
+        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_fcn, kernel_size=1)
+        self.conv2 = nn.Conv1d(in_channels=d_fcn, out_channels=d_model, kernel_size=1)
+
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.norm3 = nn.LayerNorm(d_model)
+
+        self.dropout = nn.Dropout(r_drop)
+        self.activ = F.relu if activ == "relu" else F.gelu
+
+    def forward(self, x_dec, x_enc):
+        x_dec = x_dec + self.self_att(x_dec, x_dec, x_dec)
+        x_dec = self.norm1(x_dec)
+
+        x_dec = x_dec + self.cross_att(x_dec, x_enc, x_enc)
+        res = x_dec = self.norm2(x_dec)
+
+        res = self.dropout(self.activ(self.conv1(res.transpose(-1, 1))))
+        res = self.dropout(self.conv2(res).transpose(-1, 1))
+
+        return self.norm3(x_dec + res)
+
+class Decoder(nn.Module):
+    def __init__(self, layers, norm_layer=None):
+        super(Decoder, self).__init__()
+        self.layers = nn.ModuleList(layers)
+        self.norm = norm_layer
+
+    def forward(self, x_dec, x_enc):
+        for layer in self.layers:
+            x_dec = layer(x_dec, x_enc)
+
+        if self.norm is not None:
+            x_dec = self.norm(x_dec)
+
+        return x_dec
gluformer/embed.py ADDED
@@ -0,0 +1,69 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+
+class PositionalEmbedding(nn.Module):
+    def __init__(self, d_model, max_len=5000):
+        super(PositionalEmbedding, self).__init__()
+        # Compute the positional encodings once in log space.
+        pos_emb = torch.zeros(max_len, d_model)
+        pos_emb.require_grad = False
+
+        position = torch.arange(0, max_len).unsqueeze(1)
+        div_term = (torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)).exp()
+
+        pos_emb[:, 0::2] = torch.sin(position * div_term)
+        pos_emb[:, 1::2] = torch.cos(position * div_term)
+
+        pos_emb = pos_emb.unsqueeze(0)
+        self.register_buffer('pos_emb', pos_emb)
+
+    def forward(self, x):
+        return self.pos_emb[:, :x.size(1)]
+
+class TokenEmbedding(nn.Module):
+    def __init__(self, d_model):
+        super(TokenEmbedding, self).__init__()
+        D_INP = 1  # one sequence
+        self.conv = nn.Conv1d(in_channels=D_INP, out_channels=d_model,
+                              kernel_size=3, padding=1, padding_mode='circular')
+        # nn.init.kaiming_normal_(self.conv.weight, mode='fan_in', nonlinearity='leaky_relu')
+
+    def forward(self, x):
+        x = self.conv(x.transpose(-1, 1)).transpose(-1, 1)
+        return x
+
+class TemporalEmbedding(nn.Module):
+    def __init__(self, d_model, num_features):
+        super(TemporalEmbedding, self).__init__()
+        self.embed = nn.Linear(num_features, d_model)
+
+    def forward(self, x):
+        return self.embed(x)
+
+class SubjectEmbedding(nn.Module):
+    def __init__(self, d_model, num_features):
+        super(SubjectEmbedding, self).__init__()
+        self.id_embedding = nn.Linear(num_features, d_model)
+
+    def forward(self, x):
+        embed_x = self.id_embedding(x)
+
+        return embed_x
+
+class DataEmbedding(nn.Module):
+    def __init__(self, d_model, r_drop, num_dynamic_features, num_static_features):
+        super(DataEmbedding, self).__init__()
+        # note: d_model // 2 == 0
+        self.value_embedding = TokenEmbedding(d_model)
+        self.time_embedding = TemporalEmbedding(d_model, num_dynamic_features)  # alternative: TimeFeatureEmbedding
+        self.positional_embedding = PositionalEmbedding(d_model)
+        self.subject_embedding = SubjectEmbedding(d_model, num_static_features)
+        self.dropout = nn.Dropout(r_drop)
+
+    def forward(self, x_id, x, x_mark):
+        x = self.value_embedding(x) + self.positional_embedding(x) + self.time_embedding(x_mark)
+        x_id = self.subject_embedding(x_id)
+        x = torch.cat((x_id.unsqueeze(1), x), dim=1)
+        return self.dropout(x)
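A shape sketch for DataEmbedding with sizes matching the six date features produced by the encoding step and a single static (subject) feature; the tensor values and batch size are illustrative.

import torch
from gluformer.embed import DataEmbedding

emb = DataEmbedding(d_model=96, r_drop=0.1, num_dynamic_features=6, num_static_features=1)
x_id = torch.randn(8, 1)           # static (subject-level) covariates
x = torch.randn(8, 96, 1)          # past glucose values
x_mark = torch.randn(8, 96, 6)     # date features (day, month, year, hour, minute, second)
print(emb(x_id, x, x_mark).shape)  # torch.Size([8, 97, 96]); the subject token is prepended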
gluformer/encoder.py ADDED
@@ -0,0 +1,67 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .attention import *
+
+class ConvLayer(nn.Module):
+    def __init__(self, d_model):
+        super(ConvLayer, self).__init__()
+        self.downConv = nn.Conv1d(in_channels=d_model, out_channels=d_model,
+                                  kernel_size=3, padding=1, padding_mode='circular')
+        self.norm = nn.BatchNorm1d(d_model)
+        self.activ = nn.ELU()
+        self.maxPool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)
+
+    def forward(self, x):
+        x = self.downConv(x.transpose(-1, 1))
+        x = self.norm(x)
+        x = self.activ(x)
+        x = self.maxPool(x)
+        x = x.transpose(-1, 1)
+        return x
+
+class EncoderLayer(nn.Module):
+    def __init__(self, att, d_model, d_fcn, r_drop, activ="relu"):
+        super(EncoderLayer, self).__init__()
+
+        self.att = att
+        self.conv1 = nn.Conv1d(in_channels=d_model, out_channels=d_fcn, kernel_size=1)
+        self.conv2 = nn.Conv1d(in_channels=d_fcn, out_channels=d_model, kernel_size=1)
+        self.norm1 = nn.LayerNorm(d_model)
+        self.norm2 = nn.LayerNorm(d_model)
+        self.dropout = nn.Dropout(r_drop)
+        self.activ = F.relu if activ == "relu" else F.gelu
+
+    def forward(self, x):
+        new_x = self.att(x, x, x)
+        x = x + self.dropout(new_x)
+
+        res = x = self.norm1(x)
+        res = self.dropout(self.activ(self.conv1(res.transpose(-1, 1))))
+        res = self.dropout(self.conv2(res).transpose(-1, 1))
+
+        return self.norm2(x + res)
+
+class Encoder(nn.Module):
+    def __init__(self, enc_layers, conv_layers=None, norm_layer=None):
+        super(Encoder, self).__init__()
+        self.enc_layers = nn.ModuleList(enc_layers)
+        self.conv_layers = nn.ModuleList(conv_layers) if conv_layers is not None else None
+        self.norm = norm_layer
+
+    def forward(self, x):
+        # x [B, L, D]
+        if self.conv_layers is not None:
+            for enc_layer, conv_layer in zip(self.enc_layers, self.conv_layers):
+                x = enc_layer(x)
+                x = conv_layer(x)
+            x = self.enc_layers[-1](x)
+        else:
+            for enc_layer in self.enc_layers:
+                x = enc_layer(x)
+
+        if self.norm is not None:
+            x = self.norm(x)
+
+        return x
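A shape sketch for the distilling ConvLayer: its strided max-pool roughly halves the sequence length between encoder layers when distil is enabled in the model below. Sizes are illustrative.

import torch
from gluformer.encoder import ConvLayer

conv = ConvLayer(d_model=96)
x = torch.randn(8, 97, 96)   # (batch, sequence length, d_model)
print(conv(x).shape)         # torch.Size([8, 49, 96])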
gluformer/model.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from tqdm import tqdm
4
+
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+ from .embed import *
10
+ from .attention import *
11
+ from .encoder import *
12
+ from .decoder import *
13
+ from .variance import *
14
+
15
+ ############################################
16
+ # Added for GluNet package
17
+ ############################################
18
+ import optuna
19
+ import darts
20
+ from torch.utils.tensorboard import SummaryWriter
21
+ sys.path.append(os.path.join(os.path.dirname(__file__), '../..'))
22
+ from glucose.gluformer.utils.training import ExpLikeliLoss, \
23
+ EarlyStop, \
24
+ modify_collate, \
25
+ adjust_learning_rate
26
+ from glucose.utils.darts_dataset import SamplingDatasetDual
27
+ ############################################
28
+
29
+ class Gluformer(nn.Module):
30
+ def __init__(self, d_model, n_heads, d_fcn, r_drop,
31
+ activ, num_enc_layers, num_dec_layers,
32
+ distil, len_seq, len_pred, num_dynamic_features,
33
+ num_static_features, label_len):
34
+ super(Gluformer, self).__init__()
35
+ # Set prediction length
36
+ self.len_pred = len_pred
37
+ self.label_len = label_len
38
+ # Embedding
39
+ # note: d_model // 2 == 0
40
+ self.enc_embedding = DataEmbedding(d_model, r_drop, num_dynamic_features, num_static_features)
41
+ self.dec_embedding = DataEmbedding(d_model, r_drop, num_dynamic_features, num_static_features)
42
+ # Encoding
43
+ self.encoder = Encoder(
44
+ [
45
+ EncoderLayer(
46
+ att=MultiheadAttention(d_model=d_model, n_heads=n_heads,
47
+ d_keys=d_model//n_heads, mask_flag=False,
48
+ r_att_drop=r_drop),
49
+ d_model=d_model,
50
+ d_fcn=d_fcn,
51
+ r_drop=r_drop,
52
+ activ=activ) for l in range(num_enc_layers)
53
+ ],
54
+ [
55
+ ConvLayer(
56
+ d_model) for l in range(num_enc_layers-1)
57
+ ] if distil else None,
58
+ norm_layer=torch.nn.LayerNorm(d_model)
59
+ )
60
+
61
+ # Decoding
62
+ self.decoder = Decoder(
63
+ [
64
+ DecoderLayer(
65
+ self_att=MultiheadAttention(d_model=d_model, n_heads=n_heads,
66
+ d_keys=d_model//n_heads, mask_flag=True,
67
+ r_att_drop=r_drop),
68
+ cross_att=MultiheadAttention(d_model=d_model, n_heads=n_heads,
69
+ d_keys=d_model//n_heads, mask_flag=False,
70
+ r_att_drop=r_drop),
71
+ d_model=d_model,
72
+ d_fcn=d_fcn,
73
+ r_drop=r_drop,
74
+ activ=activ) for l in range(num_dec_layers)
75
+ ],
76
+ norm_layer=torch.nn.LayerNorm(d_model)
77
+ )
78
+
79
+ # Output
80
+ D_OUT = 1
81
+ self.projection = nn.Linear(d_model, D_OUT, bias=True)
82
+
83
+ # Train variance
84
+ self.var = Variance(d_model, r_drop, len_seq)
85
+
86
+ def forward(self, x_id, x_enc, x_mark_enc, x_dec, x_mark_dec):
87
+ enc_out = self.enc_embedding(x_id, x_enc, x_mark_enc)
88
+ var_out = self.var(enc_out)
89
+ enc_out = self.encoder(enc_out)
90
+
91
+ dec_out = self.dec_embedding(x_id, x_dec, x_mark_dec)
92
+ dec_out = self.decoder(dec_out, enc_out)
93
+ dec_out = self.projection(dec_out)
94
+
95
+ return dec_out[:, -self.len_pred:, :], var_out # [B, L, D], log variance
96
+
97
+ ############################################
98
+ # Added for GluNet package
99
+ ############################################
100
+ def fit(self,
101
+ train_dataset: SamplingDatasetDual,
102
+ val_dataset: SamplingDatasetDual,
103
+ learning_rate: float = 1e-3,
104
+ batch_size: int = 32,
105
+ epochs: int = 100,
106
+ num_samples: int = 100,
107
+ device: str = 'cuda',
108
+ model_path: str = None,
109
+ trial: optuna.trial.Trial = None,
110
+ logger: SummaryWriter = None,):
111
+ """
112
+ Fit the model to the data, using Optuna for hyperparameter tuning.
113
+
114
+ Parameters
115
+ ----------
116
+ train_dataset: SamplingDatasetPast
117
+ Training dataset.
118
+ val_dataset: SamplingDatasetPast
119
+ Validation dataset.
120
+ learning_rate: float
121
+ Learning rate for Adam.
122
+ batch_size: int
123
+ Batch size.
124
+ epochs: int
125
+ Number of epochs.
126
+ num_samples: int
127
+ Number of samples for infinite mixture
128
+ device: str
129
+ Device to use.
130
+ model_path: str
131
+ Path to save the model.
132
+ trial: optuna.trial.Trial
133
+ Trial for hyperparameter tuning.
134
+ logger: SummaryWriter
135
+ Tensorboard logger for logging.
136
+ """
137
+ # create data loaders, optimizer, loss, and early stopping
138
+ collate_fn_custom = modify_collate(num_samples)
139
+ train_loader = torch.utils.data.DataLoader(train_dataset,
140
+ batch_size=batch_size,
141
+ shuffle=True,
142
+ drop_last=True,
143
+ collate_fn=collate_fn_custom)
144
+ val_loader = torch.utils.data.DataLoader(val_dataset,
145
+ batch_size=batch_size,
146
+ shuffle=True,
147
+ drop_last=True,
148
+ collate_fn=collate_fn_custom)
149
+ criterion = ExpLikeliLoss(num_samples)
150
+ optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate, betas=(0.1, 0.9))
151
+ scaler = torch.cuda.amp.GradScaler()
152
+ early_stop = EarlyStop(patience=10, delta=0.001)
153
+ self.to(device)
154
+ # train and evaluate the model
155
+ for epoch in range(epochs):
156
+ train_loss = []
157
+ for i, (past_target_series,
158
+ past_covariates,
159
+ future_covariates,
160
+ static_covariates,
161
+ future_target_series) in enumerate(train_loader):
162
+ # zero out gradient
163
+ optimizer.zero_grad()
164
+ # reshape static covariates to be [batch_size, num_static_covariates]
165
+ static_covariates = static_covariates.reshape(-1, static_covariates.shape[-1])
166
+ # create decoder input: pad with zeros the prediction sequence
167
+ dec_inp = torch.cat([past_target_series[:, -self.label_len:, :],
168
+ torch.zeros([
169
+ past_target_series.shape[0],
170
+ self.len_pred,
171
+ past_target_series.shape[-1]
172
+ ])],
173
+ dim=1)
174
+ future_covariates = torch.cat([past_covariates[:, -self.label_len:, :],
175
+ future_covariates], dim=1)
176
+ # move to device
177
+ dec_inp = dec_inp.to(device)
178
+ past_target_series = past_target_series.to(device)
179
+ past_covariates = past_covariates.to(device)
180
+ future_covariates = future_covariates.to(device)
181
+ static_covariates = static_covariates.to(device)
182
+ future_target_series = future_target_series.to(device)
183
+ # forward pass with autograd
184
+ with torch.cuda.amp.autocast():
185
+ pred, logvar = self(static_covariates,
186
+ past_target_series,
187
+ past_covariates,
188
+ dec_inp,
189
+ future_covariates)
190
+ loss = criterion(pred, future_target_series, logvar)
191
+ # backward pass
192
+ scaler.scale(loss).backward()
193
+ scaler.step(optimizer)
194
+ scaler.update()
195
+ # log loss
196
+ if logger is not None:
197
+ logger.add_scalar('train_loss', loss.item(), epoch * len(train_loader) + i)
198
+ train_loss.append(loss.item())
199
+ # log loss
200
+ if logger is not None:
201
+ logger.add_scalar('train_loss_epoch', np.mean(train_loss), epoch)
202
+ # evaluate the model
203
+ val_loss = []
204
+ with torch.no_grad():
205
+ for i, (past_target_series,
206
+ past_covariates,
207
+ future_covariates,
208
+ static_covariates,
209
+ future_target_series) in enumerate(val_loader):
210
+ # reshape static covariates to be [batch_size, num_static_covariates]
211
+ static_covariates = static_covariates.reshape(-1, static_covariates.shape[-1])
212
+ # create decoder input
213
+ dec_inp = torch.cat([past_target_series[:, -self.label_len:, :],
214
+ torch.zeros([
215
+ past_target_series.shape[0],
216
+ self.len_pred,
217
+ past_target_series.shape[-1]
218
+ ])],
219
+ dim=1)
220
+ future_covariates = torch.cat([past_covariates[:, -self.label_len:, :],
221
+ future_covariates], dim=1)
222
+ # move to device
223
+ dec_inp = dec_inp.to(device)
224
+ past_target_series = past_target_series.to(device)
225
+ past_covariates = past_covariates.to(device)
226
+ future_covariates = future_covariates.to(device)
227
+ static_covariates = static_covariates.to(device)
228
+ future_target_series = future_target_series.to(device)
229
+ # forward pass
230
+ pred, logvar = self(static_covariates,
231
+ past_target_series,
232
+ past_covariates,
233
+ dec_inp,
234
+ future_covariates)
235
+ loss = criterion(pred, future_target_series, logvar)
236
+ val_loss.append(loss.item())
237
+ # log loss
238
+ if logger is not None:
239
+ logger.add_scalar('val_loss', loss.item(), epoch * len(val_loader) + i)
240
+ # log loss
241
+ logger.add_scalar('val_loss_epoch', np.mean(val_loss), epoch)
242
+ # check early stopping
243
+ early_stop(np.mean(val_loss), self, model_path)
244
+ if early_stop.stop:
245
+ break
246
+ # check pruning
247
+ if trial is not None:
248
+ trial.report(np.mean(val_loss), epoch)
249
+ if trial.should_prune():
250
+ raise optuna.exceptions.TrialPruned()
251
+ # load best model
252
+ if model_path is not None:
253
+ self.load_state_dict(torch.load(model_path))
254
+
255
+ def predict(self, test_dataset: SamplingDatasetDual,
256
+ batch_size: int = 32,
257
+ num_samples: int = 100,
258
+ device: str = 'cuda'):
259
+ """
260
+ Predict the future target series given the supplied samples from the dataset.
261
+
262
+ Parameters
263
+ ----------
264
+ test_dataset : SamplingDatasetInferenceDual
265
+ The dataset to use for inference.
266
+ batch_size : int, optional
267
+ The batch size to use for inference, by default 32
268
+ num_samples : int, optional
269
+ The number of samples to use for inference, by default 100
270
+
271
+ Returns
272
+ -------
273
+ Predictions
274
+ The predicted future target series in shape n x len_pred x num_samples, where
275
+ n is total number of predictions.
276
+ Logvar
277
+ The logvariance of the predicted future target series in shape n x len_pred.
278
+ """
279
+ # define data loader
280
+ collate_fn_custom = modify_collate(num_samples)
281
+ test_loader = torch.utils.data.DataLoader(test_dataset,
282
+ batch_size=batch_size,
283
+ shuffle=False,
284
+ drop_last=False,
285
+ collate_fn=collate_fn_custom)
286
+ # predict: keep the model in train mode so dropout stays active (Monte Carlo sampling)
287
+ self.train()
288
+ # move to device
289
+ self.to(device)
290
+ predictions = []; logvars = []
291
+ for i, (past_target_series,
292
+ historic_future_covariates,
293
+ future_covariates,
294
+ static_covariates) in enumerate(test_loader):
295
+ # reshape static covariates to be [batch_size, num_static_covariates]
296
+ static_covariates = static_covariates.reshape(-1, static_covariates.shape[-1])
297
+ # create decoder input
298
+ dec_inp = torch.cat([past_target_series[:, -self.label_len:, :],
299
+ torch.zeros([
300
+ past_target_series.shape[0],
301
+ self.len_pred,
302
+ past_target_series.shape[-1]
303
+ ])],
304
+ dim=1)
305
+ future_covariates = torch.cat([historic_future_covariates[:, -self.label_len:, :],
306
+ future_covariates], dim=1)
307
+ # move to device
308
+ dec_inp = dec_inp.to(device)
309
+ past_target_series = past_target_series.to(device)
310
+ historic_future_covariates = historic_future_covariates.to(device)
311
+ future_covariates = future_covariates.to(device)
312
+ static_covariates = static_covariates.to(device)
313
+ # forward pass
314
+ pred, logvar = self(static_covariates,
315
+ past_target_series,
316
+ historic_future_covariates,
317
+ dec_inp,
318
+ future_covariates)
319
+ # transfer in numpy and arrange sample along last axis
320
+ pred = pred.cpu().detach().numpy()
321
+ logvar = logvar.cpu().detach().numpy()
322
+ pred = pred.transpose((1, 0, 2)).reshape((pred.shape[1], -1, num_samples)).transpose((1, 0, 2))
323
+ logvar = logvar.transpose((1, 0, 2)).reshape((logvar.shape[1], -1, num_samples)).transpose((1, 0, 2))
324
+ predictions.append(pred)
325
+ logvars.append(logvar)
326
+ predictions = np.concatenate(predictions, axis=0)
327
+ logvars = np.concatenate(logvars, axis=0)
328
+ return predictions, logvars
329
+
330
+ ############################################
331
+
332
+
333
+
334
+
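A minimal sketch (a hypothetical helper, not defined in this commit) of how the arrays returned by predict can be expanded into Monte Carlo scenarios, assuming the component variance is exp(logvar) as in the likelihood loss used for training:

import numpy as np

def sample_glucose_scenarios(predictions: np.ndarray, logvars: np.ndarray, draws: int = 30) -> np.ndarray:
    # predictions: (n, len_pred, num_samples) mixture means returned by Gluformer.predict
    # logvars:     (n, 1, num_samples) log-variances; per-component std is sqrt(exp(logvar))
    std = np.sqrt(np.exp(logvars))
    scenarios = np.random.normal(loc=predictions[..., None],
                                 scale=std[..., None],
                                 size=predictions.shape + (draws,))
    # flatten mixture components and draws into one sample axis: (n, len_pred, num_samples * draws)
    return scenarios.reshape(scenarios.shape[0], scenarios.shape[1], -1)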
gluformer/utils/__init__.py ADDED
File without changes
gluformer/utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (183 Bytes). View file
 
gluformer/utils/__pycache__/collate.cpython-311.pyc ADDED
Binary file (7.14 kB). View file
 
gluformer/utils/__pycache__/training.cpython-311.pyc ADDED
Binary file (6.84 kB). View file
 
gluformer/utils/collate.py ADDED
@@ -0,0 +1,84 @@
1
+ r""""Contains definitions of the methods used by the _BaseDataLoaderIter workers to
2
+ collate samples fetched from dataset into Tensor(s).
3
+ These **needs** to be in global scope since Py2 doesn't support serializing
4
+ static methods.
5
+ """
6
+
7
+ import torch
8
+ import re
9
+ import collections
10
+
11
+ np_str_obj_array_pattern = re.compile(r'[SaUO]')
12
+
13
+
14
+ def default_convert(data):
15
+ r"""Converts each NumPy array data field into a tensor"""
16
+ elem_type = type(data)
17
+ if isinstance(data, torch.Tensor):
18
+ return data
19
+ elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
20
+ and elem_type.__name__ != 'string_':
21
+ # array of string classes and object
22
+ if elem_type.__name__ == 'ndarray' \
23
+ and np_str_obj_array_pattern.search(data.dtype.str) is not None:
24
+ return data
25
+ return torch.as_tensor(data)
26
+ elif isinstance(data, collections.abc.Mapping):
27
+ return {key: default_convert(data[key]) for key in data}
28
+ elif isinstance(data, tuple) and hasattr(data, '_fields'): # namedtuple
29
+ return elem_type(*(default_convert(d) for d in data))
30
+ elif isinstance(data, collections.abc.Sequence) and not isinstance(data, str):
31
+ return [default_convert(d) for d in data]
32
+ else:
33
+ return data
34
+
35
+
36
+ default_collate_err_msg_format = (
37
+ "default_collate: batch must contain tensors, numpy arrays, numbers, "
38
+ "dicts or lists; found {}")
39
+
40
+
41
+ def default_collate(batch):
42
+ r"""Puts each data field into a tensor with outer dimension batch size"""
43
+
44
+ elem = batch[0]
45
+ elem_type = type(elem)
46
+ if isinstance(elem, torch.Tensor):
47
+ out = None
48
+ if torch.utils.data.get_worker_info() is not None:
49
+ # If we're in a background process, concatenate directly into a
50
+ # shared memory tensor to avoid an extra copy
51
+ numel = sum(x.numel() for x in batch)
52
+ storage = elem.storage()._new_shared(numel)
53
+ out = elem.new(storage)
54
+ return torch.stack(batch, 0, out=out)
55
+ elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
56
+ and elem_type.__name__ != 'string_':
57
+ if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
58
+ # array of string classes and object
59
+ if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
60
+ raise TypeError(default_collate_err_msg_format.format(elem.dtype))
61
+
62
+ return default_collate([torch.as_tensor(b) for b in batch])
63
+ elif elem.shape == (): # scalars
64
+ return torch.as_tensor(batch)
65
+ elif isinstance(elem, float):
66
+ return torch.tensor(batch, dtype=torch.float64)
67
+ elif isinstance(elem, int):
68
+ return torch.tensor(batch)
69
+ elif isinstance(elem, str):
70
+ return batch
71
+ elif isinstance(elem, collections.abc.Mapping):
72
+ return {key: default_collate([d[key] for d in batch]) for key in elem}
73
+ elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
74
+ return elem_type(*(default_collate(samples) for samples in zip(*batch)))
75
+ elif isinstance(elem, collections.abc.Sequence):
76
+ # check to make sure that the elements in batch have consistent size
77
+ it = iter(batch)
78
+ elem_size = len(next(it))
79
+ if not all(len(elem) == elem_size for elem in it):
80
+ raise RuntimeError('each element in list of batch should be of equal size')
81
+ transposed = zip(*batch)
82
+ return [default_collate(samples) for samples in transposed]
83
+
84
+ raise TypeError(default_collate_err_msg_format.format(elem_type))
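A minimal usage sketch of the vendored default_collate above; the sample shapes are illustrative:

import numpy as np
from gluformer.utils.collate import default_collate

# a batch of three samples, each a (past_target, static_covariates) pair of numpy arrays
batch = [(np.zeros((96, 1), dtype=np.float32), np.array([0.0])) for _ in range(3)]
past, static = default_collate(batch)
print(past.shape, static.shape)  # torch.Size([3, 96, 1]) torch.Size([3, 1])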
gluformer/utils/evaluation.py ADDED
@@ -0,0 +1,81 @@
1
+ import sys
2
+ import os
3
+ import yaml
4
+ import random
5
+ from typing import Any, \
6
+ BinaryIO, \
7
+ Callable, \
8
+ Dict, \
9
+ List, \
10
+ Optional, \
11
+ Sequence, \
12
+ Tuple, \
13
+ Union
14
+
15
+ import numpy as np
16
+ import scipy as sp
17
+ import pandas as pd
18
+ import torch
19
+
20
+ # import data formatter
21
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
22
+
23
+ def test(series: np.ndarray,
24
+ forecasts: np.ndarray,
25
+ var: np.ndarray,
26
+ cal_thresholds: Optional[np.ndarray] = np.linspace(0, 1, 11),
27
+ ):
28
+ """
29
+ Test the (rescaled to original scale) forecasts on the series.
30
+
31
+ Parameters
32
+ ----------
33
+ series
34
+ The target time series of shape (n, t),
35
+ where t is length of prediction.
36
+ forecasts
37
+ The forecasted means of mixture components of shape (n, t, k),
38
+ where k is the number of mixture components.
39
+ var
40
+ The forecasted variances of mixture components of shape (n, 1, k),
41
+ where k is the number of mixture components.
42
+ cal_thresholds
45
+ The thresholds to use for computing the calibration error.
46
+
47
+ Returns
48
+ -------
49
+ np.ndarray
50
+ Error array. Array of shape (n, p)
51
+ where n = series.shape[0] = forecasts.shape[0] and p = 2 (MSE and MAE).
52
+ float
53
+ The estimated log-likelihood of the model on the data.
54
+ np.ndarray
55
+ The ECE for each time point in the forecast.
56
+ """
57
+ # compute errors: 1) get samples 2) compute errors using median
58
+ samples = np.random.normal(loc=forecasts[..., None],
59
+ scale=np.sqrt(var)[..., None],
60
+ size=(forecasts.shape[0],
61
+ forecasts.shape[1],
62
+ forecasts.shape[2],
63
+ 30))
64
+ samples = samples.reshape(samples.shape[0], samples.shape[1], -1)
65
+ mse = np.mean((series.squeeze() - forecasts.mean(axis=-1))**2, axis=-1)
66
+ mae = np.mean(np.abs(series.squeeze() - forecasts.mean(axis=-1)), axis=-1)
67
+ errors = np.stack([mse, mae], axis=-1)
68
+
69
+ # compute likelihood
70
+ log_likelihood = sp.special.logsumexp(-(forecasts - series)**2 / (2 * var) -
71
+ 0.5 * np.log(2 * np.pi * var), axis=-1)
72
+ log_likelihood = np.mean(log_likelihood)
73
+
74
+ # compute calibration error:
75
+ cal_error = np.zeros(forecasts.shape[1])
76
+ for p in cal_thresholds:
77
+ q = np.quantile(samples, p, axis=-1)
78
+ est_p = np.mean(series.squeeze() <= q, axis=0)
79
+ cal_error += (est_p - p) ** 2
80
+
81
+ return errors, log_likelihood, cal_error
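A small sketch of calling test with toy arrays; the shapes follow the docstring, and the random values only illustrate the returned shapes:

import numpy as np
from gluformer.utils.evaluation import test

n, t, k = 8, 12, 10                           # predictions, horizon, mixture components
series = 100 + 50 * np.random.rand(n, t, 1)   # true glucose values
forecasts = 100 + 50 * np.random.rand(n, t, k)
var = 1.0 + np.random.rand(n, 1, k)           # per-component variances

errors, log_likelihood, cal_error = test(series, forecasts, var)
print(errors.shape)     # (8, 2) -> MSE and MAE per prediction window
print(cal_error.shape)  # (12,) -> calibration error per forecast step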
gluformer/utils/training.py ADDED
@@ -0,0 +1,80 @@
1
+ from typing import List, Callable, Any, Tuple
2
+
3
+ import numpy as np
4
+ import torch
6
+ from torch import nn, Tensor
7
+
8
+ from .collate import default_collate
9
+
10
+ class EarlyStop:
11
+ def __init__(self, patience: int, delta: float):
12
+ self.patience: int = patience
13
+ self.delta: float = delta
14
+ self.counter: int = 0
15
+ self.best_loss: float = np.Inf
16
+ self.stop: bool = False
17
+
18
+ def __call__(self, loss: float, model: nn.Module, path: str) -> None:
19
+ if loss < self.best_loss:
20
+ self.best_loss = loss
21
+ self.counter = 0
22
+ torch.save(model.state_dict(), path)
23
+ elif loss > self.best_loss + self.delta:
24
+ self.counter = self.counter + 1
25
+ if self.counter >= self.patience:
26
+ self.stop = True
27
+
28
+ class ExpLikeliLoss(nn.Module):
29
+ def __init__(self, num_samples: int = 100):
30
+ super(ExpLikeliLoss, self).__init__()
31
+ self.num_samples: int = num_samples
32
+
33
+ def forward(self, pred: Tensor, true: Tensor, logvar: Tensor) -> Tensor:
34
+ b, l, d = pred.size(0), pred.size(1), pred.size(2)
35
+ true = true.transpose(0,1).reshape(l, -1, self.num_samples).transpose(0, 1)
36
+ pred = pred.transpose(0,1).reshape(l, -1, self.num_samples).transpose(0, 1)
37
+ logvar = logvar.reshape(-1, self.num_samples)
38
+
39
+ loss = torch.mean((-1) * torch.logsumexp((-l / 2) * logvar + (-1 / (2 * torch.exp(logvar))) * torch.sum((true - pred) ** 2, dim=1), dim=1))
40
+ return loss
41
+
42
+ def modify_collate(num_samples: int) -> Callable[[List[Any]], Any]:
43
+ def wrapper(batch: List[Any]) -> Any:
44
+ batch_rep = [sample for sample in batch for _ in range(num_samples)]
45
+ result = default_collate(batch_rep)
46
+ return result
47
+ return wrapper
48
+
49
+ def adjust_learning_rate(model_optim: torch.optim.Optimizer, epoch: int, lr: float) -> None:
50
+ lr = lr * (0.5 ** epoch)
51
+ print("Learning rate halving...")
52
+ print(f"New lr: {lr:.7f}")
53
+ for param_group in model_optim.param_groups:
54
+ param_group['lr'] = lr
55
+
56
+ def process_batch(
57
+ subj_id: Tensor,
58
+ batch_x: Tensor,
59
+ batch_y: Tensor,
60
+ batch_x_mark: Tensor,
61
+ batch_y_mark: Tensor,
62
+ len_pred: int,
63
+ len_label: int,
64
+ model: nn.Module,
65
+ device: torch.device
66
+ ) -> Tuple[Tensor, Tensor, Tensor]:
67
+ subj_id = subj_id.long().to(device)
68
+ batch_x = batch_x.float().to(device)
69
+ batch_y = batch_y.float()
70
+ batch_x_mark = batch_x_mark.float().to(device)
71
+ batch_y_mark = batch_y_mark.float().to(device)
72
+
73
+ true = batch_y[:, -len_pred:, :].to(device)
74
+
75
+ dec_inp = torch.zeros([batch_y.shape[0], len_pred, batch_y.shape[-1]], dtype=torch.float, device=device)
76
+ dec_inp = torch.cat([batch_y[:, :len_label, :].to(device), dec_inp], dim=1)
77
+
78
+ pred, logvar = model(subj_id, batch_x, batch_x_mark, dec_inp, batch_y_mark)
79
+
80
+ return pred, true, logvar
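A short sketch of the helpers above on a toy model; the validation-loss trajectory and checkpoint path are made up:

import torch
from torch import nn
from gluformer.utils.training import EarlyStop, modify_collate

collate_fn = modify_collate(10)                # each dataset sample is repeated 10 times per batch

model = nn.Linear(4, 1)                        # stand-in for a Gluformer instance
early_stop = EarlyStop(patience=3, delta=1e-4)
for epoch, val_loss in enumerate([1.0, 0.9, 0.95, 0.96, 0.97, 0.98]):
    early_stop(val_loss, model, "best_model.pt")   # checkpoints whenever the loss improves
    if early_stop.stop:
        print(f"early stop at epoch {epoch}")
        break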
gluformer/variance.py ADDED
@@ -0,0 +1,24 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class Variance(nn.Module):
6
+ def __init__(self, d_model, r_drop, len_seq):
7
+ super(Variance, self).__init__()
8
+
9
+ self.proj1 = nn.Linear(d_model, 1)
10
+ self.dropout = nn.Dropout(r_drop)
11
+ self.activ1 = nn.ReLU()
12
+ # + 1 (for seq) for embedded person token
13
+ self.proj2 = nn.Linear(len_seq+1, 1)
14
+ self.activ2 = nn.Tanh()
15
+
16
+ def forward(self, x):
17
+ x = self.proj1(x)
18
+ x = self.activ1(x)
19
+ x = self.dropout(x)
20
+ x = x.transpose(-1, 1)
21
+ x = self.proj2(x)
22
+ # scale to [-10, 10] range
23
+ x = 10 * self.activ2(x)
24
+ return x
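A quick sketch of the tensor shapes this variance head expects and produces; the dimensions are illustrative, and the extra position is the embedded person token noted in the comment:

import torch
from gluformer.variance import Variance

d_model, len_seq, batch = 512, 96, 4
var_head = Variance(d_model=d_model, r_drop=0.1, len_seq=len_seq)

x = torch.randn(batch, len_seq + 1, d_model)   # encoder output incl. the person token
logvar = var_head(x)
print(logvar.shape)                            # torch.Size([4, 1, 1]), values in (-10, 10)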
main.py ADDED
@@ -0,0 +1,8 @@
1
+ import gradio as gr
2
+ from tools import *
3
+
4
+
5
+ def gradio_output():
6
+ return (predict_glucose_tool())
7
+
8
+ gr.Interface(fn=gradio_output, inputs=[], outputs="plot").launch()
tools.py ADDED
@@ -0,0 +1,198 @@
1
+ import sys
2
+ import os
3
+ import pickle
4
+ import gzip
5
+ from pathlib import Path
6
+
7
+ import seaborn as sns
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ import matplotlib.colors as mcolors
11
+ from matplotlib.figure import Figure
12
+ import torch
13
+ from scipy import stats
14
+
15
+ from gluformer.model import Gluformer
16
+ from utils.darts_processing import *
17
+ from utils.darts_dataset import *
18
+
19
+
20
+ import hashlib
21
+ from urllib.parse import urlparse
22
+
23
+ from huggingface_hub import hf_hub_download
24
+ import typer
25
+
26
+
27
+ glucose = Path(os.path.abspath(__file__)).parent.resolve()
28
+ file_directory = glucose / "files"
29
+
30
+
31
+ def plot_forecast(forecasts: np.ndarray, scalers: Any, dataset_test_glufo: Any, filename: str):
33
+ forecasts = (forecasts - scalers['target'].min_) / scalers['target'].scale_
34
+
35
+ trues = [dataset_test_glufo.evalsample(i) for i in range(len(dataset_test_glufo))]
36
+ trues = scalers['target'].inverse_transform(trues)
37
+
38
+ trues = [ts.values() for ts in trues] # Convert TimeSeries to numpy arrays
39
+ trues = np.array(trues)
40
+
41
+ inputs = [dataset_test_glufo[i][0] for i in range(len(dataset_test_glufo))]
42
+ inputs = (np.array(inputs) - scalers['target'].min_) / scalers['target'].scale_
43
+
44
+ # Plot settings
45
+ colors = ['#00264c', '#0a2c62', '#14437f', '#1f5a9d', '#2973bb', '#358ad9', '#4d9af4', '#7bb7ff', '#add5ff', '#e6f3ff']
46
+ cmap = mcolors.LinearSegmentedColormap.from_list('my_colormap', colors)
47
+ sns.set_theme(style="whitegrid")
48
+
49
+ # Generate the plot
50
+ fig, ax = plt.subplots(figsize=(10, 6))
51
+
52
+
53
+ # Select a specific sample to plot
54
+ ind = 30 # Example index
55
+
56
+ samples = np.random.normal(
57
+ loc=forecasts[ind, :], # Mean (center) of the distribution
58
+ scale=0.1, # Standard deviation (spread) of the distribution
59
+ size=(forecasts.shape[1], forecasts.shape[2])
60
+ )
61
+ #samples = samples.reshape(samples.shape[0], samples.shape[1], -1)
62
+ #print ("samples",samples.shape)
63
+
64
+ # Plot predictive distribution
65
+ for point in range(samples.shape[0]):
66
+ kde = stats.gaussian_kde(samples[point,:])
67
+ maxi, mini = 1.2 * np.max(samples[point, :]), 0.8 * np.min(samples[point, :])
68
+ y_grid = np.linspace(mini, maxi, 200)
69
+ x = kde(y_grid)
70
+ ax.fill_betweenx(y_grid, x1=point, x2=point - x * 15,
71
+ alpha=0.7,
72
+ edgecolor='black',
73
+ color=cmap(point / samples.shape[0]))
74
+
75
+ # Plot median
76
+ forecast = samples[:, :]
77
+ median = np.quantile(forecast, 0.5, axis=-1)
78
+ ax.plot(np.arange(12), median, color='red', marker='o')
79
+
80
+ # Plot true values
81
+ ax.plot(np.arange(-12, 12), np.concatenate([inputs[ind, -12:], trues[ind, :]]), color='blue')
82
+
83
+ # Add labels and title
84
+ ax.set_xlabel('Time (in 5 minute intervals)')
85
+ ax.set_ylabel('Glucose (mg/dL)')
86
+ ax.set_title('Gluformer Prediction with Gradient for dataset')
87
+
88
+ # Adjust font sizes
89
+ ax.xaxis.label.set_fontsize(16)
90
+ ax.yaxis.label.set_fontsize(16)
91
+ ax.title.set_fontsize(18)
92
+ for item in ax.get_xticklabels() + ax.get_yticklabels():
93
+ item.set_fontsize(14)
94
+
95
+ # Save figure
96
+ plt.tight_layout()
97
+ where = file_directory /filename
98
+ plt.savefig(str(where), dpi=300, bbox_inches='tight')
99
+
100
+ return where,ax
101
+
102
+
103
+
104
+ def generate_filename_from_url(url: str, extension: str = "png") -> str:
105
+ """
106
+ :param url:
107
+ :param extension:
108
+ :return:
109
+ """
110
+ # Extract the last segment of the URL
111
+ last_segment = urlparse(url).path.split('/')[-1]
112
+
113
+ # Compute the hash of the URL
114
+ url_hash = hashlib.md5(url.encode('utf-8')).hexdigest()
115
+
116
+ # Create the filename
117
+ filename = f"{last_segment.replace('.','_')}_{url_hash}.{extension}"
118
+
119
+ return filename
120
+
121
+
122
+
123
+ def predict_glucose_tool(url: str= 'https://huggingface.co/datasets/Livia-Zaharia/glucose_processed/blob/main/livia_mini.csv',
124
+ model: str = 'https://huggingface.co/Livia-Zaharia/gluformer_models/blob/main/gluformer_1samples_10000epochs_10heads_32batch_geluactivation_livia_mini_weights.pth'
125
+ ) -> Figure:
126
+ """
127
+ Predict future glucose values for a user. Receives the URL of the user's CSV file with glucose readings, runs the Gluformer model, and returns a plot of the predictions.
128
+ :param url: URL of the csv file with glucose values
129
+ :param model: URL of the model weights used to predict the glucose
130
+ :return: matplotlib plot of the forecast
133
+ """
134
+
135
+ formatter, series, scalers = load_data(url=str(url), config_path=file_directory / "config.yaml", use_covs=True,
136
+ cov_type='dual',
137
+ use_static_covs=True)
138
+
139
+ filename = generate_filename_from_url(url)
140
+
141
+ formatter.params['gluformer'] = {
142
+ 'in_len': 96, # example input length, adjust as necessary
143
+ 'd_model': 512, # model dimension
144
+ 'n_heads': 10, # number of attention heads
145
+ 'd_fcn': 1024, # fully connected layer dimension
146
+ 'num_enc_layers': 2, # number of encoder layers
147
+ 'num_dec_layers': 2, # number of decoder layers
148
+ 'length_pred': 12 # prediction length, adjust as necessary
149
+ }
150
+
151
+ num_dynamic_features = series['train']['future'][-1].n_components
152
+ num_static_features = series['train']['static'][-1].n_components
153
+
154
+ glufo = Gluformer(
155
+ d_model=formatter.params['gluformer']['d_model'],
156
+ n_heads=formatter.params['gluformer']['n_heads'],
157
+ d_fcn=formatter.params['gluformer']['d_fcn'],
158
+ r_drop=0.2,
159
+ activ='gelu',
160
+ num_enc_layers=formatter.params['gluformer']['num_enc_layers'],
161
+ num_dec_layers=formatter.params['gluformer']['num_dec_layers'],
162
+ distil=True,
163
+ len_seq=formatter.params['gluformer']['in_len'],
164
+ label_len=formatter.params['gluformer']['in_len'] // 3,
165
+ len_pred=formatter.params['length_pred'],
166
+ num_dynamic_features=num_dynamic_features,
167
+ num_static_features=num_static_features
168
+ )
169
+ # download the checkpoint from the Hugging Face Hub (repo id and filename are parsed from the model URL)
+ model_path_parts = urlparse(model).path.split('/')
170
+ weights = Path(hf_hub_download(repo_id='/'.join(model_path_parts[1:3]), filename=model_path_parts[-1]))
+ assert weights.exists(), f"weights for {model} should exist"
171
+
172
+ device = "cuda" if torch.cuda.is_available() else "cpu"
173
+ glufo.load_state_dict(torch.load(str(weights), map_location=torch.device(device), weights_only=False))
174
+
175
+ # Define dataset for inference
176
+ dataset_test_glufo = SamplingDatasetInferenceDual(
177
+ target_series=series['test']['target'],
178
+ covariates=series['test']['future'],
179
+ input_chunk_length=formatter.params['gluformer']['in_len'],
180
+ output_chunk_length=formatter.params['length_pred'],
181
+ use_static_covariates=True,
182
+ array_output_only=True
183
+ )
184
+
185
+ forecasts, _ = glufo.predict(
186
+ dataset_test_glufo,
187
+ batch_size=16,
188
+ num_samples=10,
189
+ device='cpu'
190
+ )
191
+ figure_path, result = plot_forecast(forecasts, scalers, dataset_test_glufo,filename)
192
+
193
+ return result
194
+
195
+
196
+
197
+ if __name__ == "__main__":
198
+ predict_glucose_tool()
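For reference, generate_filename_from_url yields a deterministic name per URL, so repeated runs on the same CSV overwrite the same plot; a minimal sketch (the hash in the comment is a placeholder):

from tools import generate_filename_from_url

url = "https://huggingface.co/datasets/Livia-Zaharia/glucose_processed/blob/main/livia_mini.csv"
print(generate_filename_from_url(url))
# -> "livia_mini_csv_<md5 of the full url>.png"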
utils/__init__.py ADDED
File without changes
utils/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (173 Bytes). View file
 
utils/__pycache__/darts_dataset.cpython-311.pyc ADDED
Binary file (38.6 kB). View file
 
utils/__pycache__/darts_processing.cpython-311.pyc ADDED
Binary file (17.2 kB). View file
 
utils/darts_dataset.py ADDED
@@ -0,0 +1,881 @@
1
+ import sys
2
+ import os
3
+ import yaml
4
+ import random
5
+ from typing import Any, BinaryIO, Callable, Dict, List, Optional, Sequence, Tuple, Union
6
+
7
+ import numpy as np
8
+ from scipy import stats
9
+ import pandas as pd
10
+ import darts
11
+
12
+ from darts import models
13
+ from darts import metrics
14
+ from darts import TimeSeries
15
+ from darts.dataprocessing.transformers import Scaler
16
+ from pytorch_lightning.callbacks import Callback
17
+
18
+ # for darts dataset
19
+ from darts.logging import get_logger, raise_if_not
20
+
21
+ from darts.utils.data.training_dataset import PastCovariatesTrainingDataset, \
22
+ DualCovariatesTrainingDataset, \
23
+ MixedCovariatesTrainingDataset
24
+ from darts.utils.data.inference_dataset import PastCovariatesInferenceDataset, \
25
+ DualCovariatesInferenceDataset, \
26
+ MixedCovariatesInferenceDataset
27
+ from darts.utils.data.utils import CovariateType
28
+
29
+ # import data formatter
30
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
31
+ from data_formatter.base import *
32
+
33
+ def get_valid_sampling_locations(target_series: Union[TimeSeries, Sequence[TimeSeries]],
34
+ output_chunk_length: int = 12,
35
+ input_chunk_length: int = 12,
36
+ random_state: Optional[int] = 0,
37
+ max_samples_per_ts: Optional[int] = None):
38
+ """
39
+ Get valid sampling indices data for the model.
40
+
41
+ Parameters
42
+ ----------
43
+ target_series
44
+ The target time series.
45
+ output_chunk_length
46
+ The length of the output chunk.
47
+ input_chunk_length
48
+ The length of the input chunk.
49
+ use_static_covariates
50
+ Whether to use static covariates.
51
+ max_samples_per_ts
52
+ The maximum number of samples per time series.
53
+ """
54
+ random.seed(random_state)
55
+ valid_sampling_locations = {}
56
+ total_length = input_chunk_length + output_chunk_length
57
+ for id, series in enumerate(target_series):
58
+ num_entries = len(series)
59
+ if num_entries >= total_length:
60
+ valid_sampling_locations[id] = [i for i in range(num_entries - total_length + 1)]
61
+ if max_samples_per_ts is not None:
62
+ updated_sampling_locations = {}
63
+ for id, locations in valid_sampling_locations.items():
64
+ if len(locations) > max_samples_per_ts:
65
+ updated_sampling_locations[id] = random.sample(locations, max_samples_per_ts)
66
+ else:
67
+ updated_sampling_locations[id] = locations
68
+ valid_sampling_locations = updated_sampling_locations
69
+
70
+ return valid_sampling_locations
71
+
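As an illustration of the helper above (not part of the file): with a 30-point series, input_chunk_length=12 and output_chunk_length=12 there are 30 - 24 + 1 = 7 valid start positions, optionally subsampled by max_samples_per_ts. A minimal sketch, assuming the repository root is on the Python path:

import numpy as np
import pandas as pd
from darts import TimeSeries
from utils.darts_dataset import get_valid_sampling_locations

idx = pd.date_range("2021-01-01", periods=30, freq="5min")
series = TimeSeries.from_series(pd.Series(np.arange(30, dtype=float), index=idx))
locations = get_valid_sampling_locations([series],
                                         output_chunk_length=12,
                                         input_chunk_length=12,
                                         random_state=0,
                                         max_samples_per_ts=5)
print(locations)  # {0: [five start indices drawn from range(7)]}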
72
+ class SamplingDatasetPast(PastCovariatesTrainingDataset):
73
+ def __init__(
74
+ self,
75
+ target_series: Union[TimeSeries, Sequence[TimeSeries]],
76
+ covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
77
+ output_chunk_length: int = 12,
78
+ input_chunk_length: int = 12,
79
+ use_static_covariates: bool = True,
80
+ random_state: Optional[int] = 0,
81
+ max_samples_per_ts: Optional[int] = None,
82
+ remove_nan: bool = False,
83
+ ) -> None:
84
+ """
85
+ Parameters
86
+ ----------
87
+ target_series
88
+ One or a sequence of target `TimeSeries`.
89
+ covariates:
90
+ Optionally, one or a sequence of `TimeSeries` containing past-observed covariates. If this parameter is set,
91
+ the provided sequence must have the same length as that of `target_series`. Moreover, all
92
+ covariates in the sequence must have a time span large enough to contain all the required slices.
93
+ The joint slicing of the target and covariates is relying on the time axes of both series.
94
+ output_chunk_length
95
+ The length of the "output" series emitted by the model
96
+ input_chunk_length
97
+ The length of the "input" series fed to the model
98
+ use_static_covariates
99
+ Whether to use/include static covariate data from input series.
100
+ random_state
101
+ The random state to use for sampling.
102
+ max_samples_per_ts
103
+ The maximum number of samples to be drawn from each time series. If None, all samples will be drawn.
104
+ remove_nan
105
+ Whether to remove None from the output. E.g. if no covariates are provided, the covariates output will be None
106
+ or (optionally) removed from the __getitem__ output.
107
+ """
108
+ super().__init__()
109
+ self.remove_nan = remove_nan
110
+
111
+ self.target_series = (
112
+ [target_series] if isinstance(target_series, TimeSeries) else target_series
113
+ )
114
+ self.covariates = (
115
+ [covariates] if isinstance(covariates, TimeSeries) else covariates
116
+ )
117
+
118
+ # checks
119
+ raise_if_not(
120
+ covariates is None or len(self.target_series) == len(self.covariates),
121
+ "The provided sequence of target series must have the same length as "
122
+ "the provided sequence of covariate series.",
123
+ )
124
+
125
+ # get valid sampling locations
126
+ self.valid_sampling_locations = get_valid_sampling_locations(target_series,
127
+ output_chunk_length,
128
+ input_chunk_length,
129
+ random_state,
130
+ max_samples_per_ts)
131
+
132
+ # set parameters
133
+ self.output_chunk_length = output_chunk_length
134
+ self.input_chunk_length = input_chunk_length
135
+ self.total_length = input_chunk_length + output_chunk_length
136
+ self.total_number_samples = sum([len(v) for v in self.valid_sampling_locations.values()])
137
+ self.use_static_covariates = use_static_covariates
138
+
139
+ def __len__(self):
140
+ """
141
+ Returns the total number of possible (input, target) splits.
142
+ """
143
+ return self.total_number_samples
144
+
145
+ def __getitem__(self, idx: int):
146
+ # get idx of target series
147
+ target_idx = 0
148
+ while idx >= len(self.valid_sampling_locations[target_idx]):
149
+ idx -= len(self.valid_sampling_locations[target_idx])
150
+ target_idx += 1
151
+ # get sampling location within the target series
152
+ sampling_location = self.valid_sampling_locations[target_idx][idx]
153
+ # get target series
154
+ target_series = self.target_series[target_idx].values()
155
+ past_target_series = target_series[sampling_location : sampling_location + self.input_chunk_length]
156
+ future_target_series = target_series[sampling_location + self.input_chunk_length : sampling_location + self.total_length]
157
+ # get covariates
158
+ if self.covariates is not None:
159
+ covariates = self.covariates[target_idx].values()
160
+ covariates = covariates[sampling_location : sampling_location + self.input_chunk_length]
161
+ else:
162
+ covariates = None
163
+ # get static covariates
164
+ if self.use_static_covariates:
165
+ static_covariates = self.target_series[target_idx].static_covariates_values(copy=True)
166
+ else:
167
+ static_covariates = None
168
+
169
+ # return elements that are not None
170
+ if self.remove_nan:
171
+ out = []
172
+ out += [past_target_series] if past_target_series is not None else []
173
+ out += [covariates] if covariates is not None else []
174
+ out += [static_covariates] if static_covariates is not None else []
175
+ out += [future_target_series] if future_target_series is not None else []
176
+ return tuple(out)
177
+ else:
178
+ return tuple([past_target_series,
179
+ covariates,
180
+ static_covariates,
181
+ future_target_series])
182
+
183
+ class SamplingDatasetDual(DualCovariatesTrainingDataset):
184
+ def __init__(
185
+ self,
186
+ target_series: Union[TimeSeries, Sequence[TimeSeries]],
187
+ covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
188
+ output_chunk_length: int = 12,
189
+ input_chunk_length: int = 12,
190
+ use_static_covariates: bool = True,
191
+ random_state: Optional[int] = 0,
192
+ max_samples_per_ts: Optional[int] = None,
193
+ remove_nan: bool = False,
194
+ ) -> None:
195
+ """
196
+ Parameters
197
+ ----------
198
+ target_series
199
+ One or a sequence of target `TimeSeries`.
200
+ covariates:
201
+ Optionally, one or a sequence of `TimeSeries` containing future-known covariates. If this parameter is set,
202
+ the provided sequence must have the same length as that of `target_series`. Moreover, all
203
+ covariates in the sequence must have a time span large enough to contain all the required slices.
204
+ The joint slicing of the target and covariates is relying on the time axes of both series.
205
+ output_chunk_length
206
+ The length of the "output" series emitted by the model
207
+ input_chunk_length
208
+ The length of the "input" series fed to the model
209
+ use_static_covariates
210
+ Whether to use/include static covariate data from input series.
211
+ random_state
212
+ The random state to use for sampling.
213
+ max_samples_per_ts
214
+ The maximum number of samples to be drawn from each time series. If None, all samples will be drawn.
215
+ remove_nan
216
+ Whether to remove None from the output. E.g. if no covariates are provided, the covariates output will be None
217
+ or (optionally) removed from the __getitem__ output.
218
+ """
219
+ super().__init__()
220
+ self.remove_nan = remove_nan
221
+
222
+ self.target_series = (
223
+ [target_series] if isinstance(target_series, TimeSeries) else target_series
224
+ )
225
+ self.covariates = (
226
+ [covariates] if isinstance(covariates, TimeSeries) else covariates
227
+ )
228
+
229
+ # checks
230
+ raise_if_not(
231
+ covariates is None or len(self.target_series) == len(self.covariates),
232
+ "The provided sequence of target series must have the same length as "
233
+ "the provided sequence of covariate series.",
234
+ )
235
+
236
+ # get valid sampling locations
237
+ self.valid_sampling_locations = get_valid_sampling_locations(target_series,
238
+ output_chunk_length,
239
+ input_chunk_length,
240
+ random_state,
241
+ max_samples_per_ts,)
242
+
243
+ # set parameters
244
+ self.output_chunk_length = output_chunk_length
245
+ self.input_chunk_length = input_chunk_length
246
+ self.total_length = input_chunk_length + output_chunk_length
247
+ self.total_number_samples = sum([len(v) for v in self.valid_sampling_locations.values()])
248
+ self.use_static_covariates = use_static_covariates
249
+
250
+ def __len__(self):
251
+ """
252
+ Returns the total number of possible (input, target) splits.
253
+ """
254
+ return self.total_number_samples
255
+
256
+ def __getitem__(self, idx: int):
257
+ # get idx of target series
258
+ target_idx = 0
259
+ while idx >= len(self.valid_sampling_locations[target_idx]):
260
+ idx -= len(self.valid_sampling_locations[target_idx])
261
+ target_idx += 1
262
+ # get sampling location within the target series
263
+ sampling_location = self.valid_sampling_locations[target_idx][idx]
264
+ # get target series
265
+ target_series = self.target_series[target_idx].values()
266
+ past_target_series = target_series[sampling_location : sampling_location + self.input_chunk_length]
267
+ future_target_series = target_series[sampling_location + self.input_chunk_length : sampling_location + self.total_length]
268
+ # get covariates
269
+ if self.covariates is not None:
270
+ covariates = self.covariates[target_idx].values()
271
+ past_covariates = covariates[sampling_location : sampling_location + self.input_chunk_length]
272
+ future_covariates = covariates[sampling_location + self.input_chunk_length : sampling_location + self.total_length]
273
+ else:
274
+ past_covariates = None
275
+ future_covariates = None
276
+ # get static covariates
277
+ if self.use_static_covariates:
278
+ static_covariates = self.target_series[target_idx].static_covariates_values(copy=True)
279
+ else:
280
+ static_covariates = None
281
+
282
+ # return elements that are not None
283
+ if self.remove_nan:
284
+ out = []
285
+ out += [past_target_series] if past_target_series is not None else []
286
+ out += [past_covariates] if past_covariates is not None else []
287
+ out += [future_covariates] if future_covariates is not None else []
288
+ out += [static_covariates] if static_covariates is not None else []
289
+ out += [future_target_series] if future_target_series is not None else []
290
+ return tuple(out)
291
+ else:
292
+ return tuple([past_target_series,
293
+ past_covariates,
294
+ future_covariates,
295
+ static_covariates,
296
+ future_target_series])
297
+
298
+ class SamplingDatasetMixed(MixedCovariatesTrainingDataset):
299
+ def __init__(
300
+ self,
301
+ target_series: Union[TimeSeries, Sequence[TimeSeries]],
302
+ past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
303
+ future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
304
+ output_chunk_length: int = 12,
305
+ input_chunk_length: int = 12,
306
+ use_static_covariates: bool = True,
307
+ random_state: Optional[int] = 0,
308
+ max_samples_per_ts: Optional[int] = None,
309
+ remove_nan: bool = False,
310
+ ) -> None:
311
+ """
312
+ Parameters
313
+ ----------
314
+ target_series
315
+ One or a sequence of target `TimeSeries`.
316
+ past_covariates
317
+ Optionally, one or a sequence of `TimeSeries` containing past-observed covariates. If this parameter is set,
318
+ the provided sequence must have the same length as that of `target_series`. Moreover, all
319
+ covariates in the sequence must have a time span large enough to contain all the required slices.
320
+ The joint slicing of the target and covariates is relying on the time axes of both series.
321
+ future_covariates
322
+ Optionally, one or a sequence of `TimeSeries` containing future-known covariates. This has to follow
323
+ the same constraints as `past_covariates`.
324
+ output_chunk_length
325
+ The length of the "output" series emitted by the model
326
+ input_chunk_length
327
+ The length of the "input" series fed to the model
328
+ use_static_covariates
329
+ Whether to use/include static covariate data from input series.
330
+ random_state
331
+ The random state to use for sampling.
332
+ max_samples_per_ts
333
+ The maximum number of samples to be drawn from each time series. If None, all samples will be drawn.
334
+ remove_nan
335
+ Whether to remove None from the output. E.g. if no covariates are provided, the covariates output will be None
336
+ or (optionally) removed from the __getitem__ output.
337
+ """
338
+ super().__init__()
339
+ self.remove_nan = remove_nan
340
+
341
+ self.target_series = (
342
+ [target_series] if isinstance(target_series, TimeSeries) else target_series
343
+ )
344
+ self.past_covariates = (
345
+ [past_covariates] if isinstance(past_covariates, TimeSeries) else past_covariates
346
+ )
347
+ self.future_covariates = (
348
+ [future_covariates] if isinstance(future_covariates, TimeSeries) else future_covariates
349
+ )
350
+
351
+ # checks
352
+ raise_if_not(
353
+ future_covariates is None or len(self.target_series) == len(self.future_covariates),
354
+ "The provided sequence of target series must have the same length as "
355
+ "the provided sequence of covariate series.",
356
+ )
357
+ raise_if_not(
358
+ past_covariates is None or len(self.target_series) == len(self.past_covariates),
359
+ "The provided sequence of target series must have the same length as "
360
+ "the provided sequence of covariate series.",
361
+ )
362
+
363
+ # get valid sampling locations
364
+ self.valid_sampling_locations = get_valid_sampling_locations(target_series,
365
+ output_chunk_length,
366
+ input_chunk_length,
367
+ random_state,
368
+ max_samples_per_ts,)
369
+
370
+ # set parameters
371
+ self.output_chunk_length = output_chunk_length
372
+ self.input_chunk_length = input_chunk_length
373
+ self.total_length = input_chunk_length + output_chunk_length
374
+ self.total_number_samples = sum([len(v) for v in self.valid_sampling_locations.values()])
375
+ self.use_static_covariates = use_static_covariates
376
+
377
+ def __len__(self):
378
+ """
379
+ Returns the total number of possible (input, target) splits.
380
+ """
381
+ return self.total_number_samples
382
+
383
+ def __getitem__(self, idx: int):
384
+ # get idx of target series
385
+ target_idx = 0
386
+ while idx >= len(self.valid_sampling_locations[target_idx]):
387
+ idx -= len(self.valid_sampling_locations[target_idx])
388
+ target_idx += 1
389
+ # get sampling location within the target series
390
+ sampling_location = self.valid_sampling_locations[target_idx][idx]
391
+ # get target series
392
+ target_series = self.target_series[target_idx].values()
393
+ past_target_series = target_series[sampling_location : sampling_location + self.input_chunk_length]
394
+ future_target_series = target_series[sampling_location + self.input_chunk_length : sampling_location + self.total_length]
395
+ # get past covariates
396
+ if self.past_covariates is not None:
397
+ past_covariates = self.past_covariates[target_idx].values()
398
+ past_covariates = past_covariates[sampling_location : sampling_location + self.input_chunk_length]
399
+ else:
400
+ past_covariates = None
401
+ # get future covariates
402
+ if self.future_covariates is not None:
403
+ future_covariates = self.future_covariates[target_idx].values()
404
+ historic_future_covariates = future_covariates[sampling_location : sampling_location + self.input_chunk_length]
405
+ future_covariates = future_covariates[sampling_location + self.input_chunk_length : sampling_location + self.total_length]
406
+ else:
407
+ future_covariates = None
408
+ historic_future_covariates = None
409
+ # get static covariates
410
+ if self.use_static_covariates:
411
+ static_covariates = self.target_series[target_idx].static_covariates_values(copy=True)
412
+ else:
413
+ static_covariates = None
414
+
415
+ # return elements that are not None
416
+ if self.remove_nan:
417
+ out = []
418
+ out += [past_target_series] if past_target_series is not None else []
419
+ out += [past_covariates] if past_covariates is not None else []
420
+ out += [historic_future_covariates] if historic_future_covariates is not None else []
421
+ out += [future_covariates] if future_covariates is not None else []
422
+ out += [static_covariates] if static_covariates is not None else []
423
+ out += [future_target_series] if future_target_series is not None else []
424
+ return tuple(out)
425
+ else:
426
+ return tuple([past_target_series,
427
+ past_covariates,
428
+ historic_future_covariates,
429
+ future_covariates,
430
+ static_covariates,
431
+ future_target_series])
432
+
433
+ class SamplingDatasetInferenceMixed(MixedCovariatesInferenceDataset):
434
+ def __init__(
435
+ self,
436
+ target_series: Union[TimeSeries, Sequence[TimeSeries]],
437
+ past_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
438
+ future_covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
439
+ n: int = 1,
440
+ input_chunk_length: int = 12,
441
+ output_chunk_length: int = 1,
442
+ use_static_covariates: bool = True,
443
+ random_state: Optional[int] = 0,
444
+ max_samples_per_ts: Optional[int] = None,
445
+ array_output_only: bool = False,
446
+ ):
447
+ """
448
+ Parameters
449
+ ----------
450
+ target_series
451
+ One or a sequence of target `TimeSeries`.
452
+ past_covariates
453
+ Optionally, one or a sequence of `TimeSeries` containing past-observed covariates. If this parameter is set,
454
+ the provided sequence must have the same length as that of `target_series`. Moreover, all
455
+ covariates in the sequence must have a time span large enough to contain all the required slices.
456
+ The joint slicing of the target and covariates is relying on the time axes of both series.
457
+ future_covariates
458
+ Optionally, one or a sequence of `TimeSeries` containing future-known covariates. This has to follow
459
+ the same constraints as `past_covariates`.
460
+ n
461
+ Number of predictions into the future, could be greater than the output chunk length, in which case, the model
462
+ will be called autoregressively.
463
+ output_chunk_length
464
+ The length of the "output" series emitted by the model
465
+ input_chunk_length
466
+ The length of the "input" series fed to the model
467
+ use_static_covariates
468
+ Whether to use/include static covariate data from input series.
469
+ random_state
470
+ The random state to use for sampling.
471
+ max_samples_per_ts
472
+ The maximum number of samples to be drawn from each time series. If None, all samples will be drawn.
473
+ array_output_only
474
+ Whether __getitem__ returns only the arrays or adds the full `TimeSeries` object to the output tuple
475
+ This may cause problems with the torch collate and loader functions but works for Darts.
476
+ """
477
+ super().__init__(target_series = target_series,
478
+ past_covariates = past_covariates,
479
+ future_covariates = future_covariates,
480
+ n = n,
481
+ input_chunk_length = input_chunk_length,
482
+ output_chunk_length = output_chunk_length,)
483
+
484
+ self.target_series = (
485
+ [target_series] if isinstance(target_series, TimeSeries) else target_series
486
+ )
487
+ self.past_covariates = (
488
+ [past_covariates] if isinstance(past_covariates, TimeSeries) else past_covariates
489
+ )
490
+ self.future_covariates = (
491
+ [future_covariates] if isinstance(future_covariates, TimeSeries) else future_covariates
492
+ )
493
+
494
+ # checks
495
+ raise_if_not(
496
+ future_covariates is None or len(self.target_series) == len(self.future_covariates),
497
+ "The provided sequence of target series must have the same length as "
498
+ "the provided sequence of covariate series.",
499
+ )
500
+ raise_if_not(
501
+ past_covariates is None or len(self.target_series) == len(self.past_covariates),
502
+ "The provided sequence of target series must have the same length as "
503
+ "the provided sequence of covariate series.",
504
+ )
505
+
506
+ # get valid sampling locations
507
+ self.valid_sampling_locations = get_valid_sampling_locations(target_series,
508
+ output_chunk_length,
509
+ input_chunk_length,
510
+ random_state,
511
+ max_samples_per_ts,)
512
+
513
+ # set parameters
514
+ self.output_chunk_length = output_chunk_length
515
+ self.input_chunk_length = input_chunk_length
516
+ self.total_length = input_chunk_length + output_chunk_length
517
+ self.total_number_samples = sum([len(v) for v in self.valid_sampling_locations.values()])
518
+ self.use_static_covariates = use_static_covariates
519
+ self.array_output_only = array_output_only
520
+
521
+ def __len__(self):
522
+ """
523
+ Returns the total number of possible (input, target) splits.
524
+ """
525
+ return self.total_number_samples
526
+
527
+ def __getitem__(self, idx: int):
528
+ # get idx of target series
529
+ target_idx = 0
530
+ while idx >= len(self.valid_sampling_locations[target_idx]):
531
+ idx -= len(self.valid_sampling_locations[target_idx])
532
+ target_idx += 1
533
+ # get sampling location within the target series
534
+ sampling_location = self.valid_sampling_locations[target_idx][idx]
535
+ # get target series
536
+ target_series = self.target_series[target_idx]
537
+ past_target_series_with_time = target_series[sampling_location : sampling_location + self.input_chunk_length]
538
+ past_end = past_target_series_with_time.time_index[-1]
539
+ target_series = self.target_series[target_idx].values()
540
+ past_target_series = target_series[sampling_location : sampling_location + self.input_chunk_length]
541
+ # get past covariates
542
+ if self.past_covariates is not None:
543
+ past_covariates = self.past_covariates[target_idx].values()
544
+ past_covariates = past_covariates[sampling_location : sampling_location + self.input_chunk_length]
545
+ future_past_covariates = past_covariates[sampling_location + self.input_chunk_length : sampling_location + self.total_length]
546
+ else:
547
+ past_covariates = None
548
+ future_past_covariates = None
549
+ # get future covariates
550
+ if self.future_covariates is not None:
551
+ future_covariates = self.future_covariates[target_idx].values()
552
+ historic_future_covariates = future_covariates[sampling_location : sampling_location + self.input_chunk_length]
553
+ future_covariates = future_covariates[sampling_location + self.input_chunk_length : sampling_location + self.total_length]
554
+ else:
555
+ future_covariates = None
556
+ historic_future_covariates = None
557
+ # get static covariates
558
+ if self.use_static_covariates:
559
+ static_covariates = self.target_series[target_idx].static_covariates_values(copy=True)
560
+ else:
561
+ static_covariates = None
562
+ # whether to remove Timeseries and None and return only arrays
563
+
564
+ if self.array_output_only:
565
+ out = []
566
+ out += [past_target_series] if past_target_series is not None else []
567
+ out += [past_covariates] if past_covariates is not None else []
568
+ out += [historic_future_covariates] if historic_future_covariates is not None else []
569
+ out += [future_covariates] if future_covariates is not None else []
570
+ out += [future_past_covariates] if future_past_covariates is not None else []
571
+ out += [static_covariates] if static_covariates is not None else []
572
+ return tuple(out)
573
+ else:
574
+ return tuple([past_target_series,
575
+ past_covariates,
576
+ historic_future_covariates,
577
+ future_covariates,
578
+ future_past_covariates,
579
+ static_covariates,
580
+ past_target_series_with_time,
581
+ past_end + past_target_series_with_time.freq
582
+ ])
583
+
584
+ def evalsample(
585
+ self, idx: int
586
+ ) -> TimeSeries:
587
+ """
588
+ Returns the future target series at the given index.
589
+ """
590
+ # get idx of target series
591
+ target_idx = 0
592
+ while idx >= len(self.valid_sampling_locations[target_idx]):
593
+ idx -= len(self.valid_sampling_locations[target_idx])
594
+ target_idx += 1
595
+ # get sampling location within the target series
596
+ sampling_location = self.valid_sampling_locations[target_idx][idx]
597
+ # get target series
598
+ target_series = self.target_series[target_idx][sampling_location + self.input_chunk_length : sampling_location + self.total_length]
599
+
600
+ return target_series
601
+
602
+ class SamplingDatasetInferencePast(PastCovariatesInferenceDataset):
603
+ def __init__(
604
+ self,
605
+ target_series: Union[TimeSeries, Sequence[TimeSeries]],
606
+ covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
607
+ n: int = 1,
608
+ input_chunk_length: int = 12,
609
+ output_chunk_length: int = 1,
610
+ use_static_covariates: bool = True,
611
+ random_state: Optional[int] = 0,
612
+ max_samples_per_ts: Optional[int] = None,
613
+ array_output_only: bool = False,
614
+ ):
615
+ """
616
+ Parameters
617
+ ----------
618
+ target_series
619
+ One or a sequence of target `TimeSeries`.
620
+ past_covariates
621
+ Optionally, one or a sequence of `TimeSeries` containing past-observed covariates. If this parameter is set,
622
+ the provided sequence must have the same length as that of `target_series`. Moreover, all
623
+ covariates in the sequence must have a time span large enough to contain all the required slices.
624
+ The joint slicing of the target and covariates is relying on the time axes of both series.
625
+ n
626
+ Number of predictions into the future, could be greater than the output chunk length, in which case, the model
627
+ will be called autoregressively.
628
+ output_chunk_length
629
+ The length of the "output" series emitted by the model
630
+ input_chunk_length
631
+ The length of the "input" series fed to the model
632
+ use_static_covariates
633
+ Whether to use/include static covariate data from input series.
634
+ random_state
635
+ The random state to use for sampling.
636
+ max_samples_per_ts
637
+ The maximum number of samples to be drawn from each time series. If None, all samples will be drawn.
638
+ array_output_only
639
+ Whether __getitem__ returns only the arrays or adds the full `TimeSeries` object to the output tuple
640
+ This may cause problems with the torch collate and loader functions but works for Darts.
641
+ """
642
+ super().__init__(target_series = target_series,
643
+ covariates = covariates,
644
+ n = n,
645
+ input_chunk_length = input_chunk_length,
646
+ output_chunk_length = output_chunk_length,)
647
+
648
+ self.target_series = (
649
+ [target_series] if isinstance(target_series, TimeSeries) else target_series
650
+ )
651
+ self.covariates = (
652
+ [covariates] if isinstance(covariates, TimeSeries) else covariates
653
+ )
654
+
655
+ raise_if_not(
656
+ covariates is None or len(self.target_series) == len(self.covariates),
657
+ "The provided sequence of target series must have the same length as "
658
+ "the provided sequence of covariate series.",
659
+ )
660
+
661
+ # get valid sampling locations
662
+ self.valid_sampling_locations = get_valid_sampling_locations(target_series,
663
+ output_chunk_length,
664
+ input_chunk_length,
665
+ random_state,
666
+ max_samples_per_ts,)
667
+
668
+ # set parameters
669
+ self.output_chunk_length = output_chunk_length
670
+ self.input_chunk_length = input_chunk_length
671
+ self.total_length = input_chunk_length + output_chunk_length
672
+ self.total_number_samples = sum([len(v) for v in self.valid_sampling_locations.values()])
673
+ self.use_static_covariates = use_static_covariates
674
+ self.array_output_only = array_output_only
675
+
676
+ def __len__(self):
677
+ """
678
+ Returns the total number of possible (input, target) splits.
679
+ """
680
+ return self.total_number_samples
681
+
682
+ def __getitem__(self, idx: int):
683
+ # get idx of target series
684
+ target_idx = 0
685
+ while idx >= len(self.valid_sampling_locations[target_idx]):
686
+ idx -= len(self.valid_sampling_locations[target_idx])
687
+ target_idx += 1
688
+ # get sampling location within the target series
689
+ sampling_location = self.valid_sampling_locations[target_idx][idx]
690
+ # get target series
691
+ target_series = self.target_series[target_idx]
692
+ past_target_series_with_time = target_series[sampling_location : sampling_location + self.input_chunk_length]
693
+ past_end = past_target_series_with_time.time_index[-1]
694
+ target_series = self.target_series[target_idx].values()
695
+ past_target_series = target_series[sampling_location : sampling_location + self.input_chunk_length]
696
+ # get past covariates
697
+ if self.covariates is not None:
698
+ past_covariates = self.covariates[target_idx].values()
699
+ past_covariates = past_covariates[sampling_location : sampling_location + self.input_chunk_length]
700
+ future_past_covariates = past_covariates[sampling_location + self.input_chunk_length : sampling_location + self.total_length]
701
+ else:
702
+ past_covariates = None
703
+ future_past_covariates = None
704
+ # get static covariates
705
+ if self.use_static_covariates:
706
+ static_covariates = self.target_series[target_idx].static_covariates_values(copy=True)
707
+ else:
708
+ static_covariates = None
709
+ # return arrays or arrays with TimeSeries
710
+ if self.array_output_only:
711
+ out = []
712
+ out += [past_target_series] if past_target_series is not None else []
713
+ out += [past_covariates] if past_covariates is not None else []
714
+ out += [future_past_covariates] if future_past_covariates is not None else []
715
+ out += [static_covariates] if static_covariates is not None else []
716
+ return tuple(out)
717
+ else:
718
+ return tuple([past_target_series,
719
+ past_covariates,
720
+ future_past_covariates,
721
+ static_covariates,
722
+ past_target_series_with_time,
723
+ past_end + past_target_series_with_time.freq])
724
+
725
+ def evalsample(
726
+ self, idx: int
727
+ ) -> TimeSeries:
728
+ """
729
+ Returns the future target series at the given index.
730
+ """
731
+ # get idx of target series
732
+ target_idx = 0
733
+ while idx >= len(self.valid_sampling_locations[target_idx]):
734
+ idx -= len(self.valid_sampling_locations[target_idx])
735
+ target_idx += 1
736
+ # get sampling location within the target series
737
+ sampling_location = self.valid_sampling_locations[target_idx][idx]
738
+ # get target series
739
+ target_series = self.target_series[target_idx][sampling_location + self.input_chunk_length : sampling_location + self.total_length]
740
+
741
+ return target_series
742
+
743
+ class SamplingDatasetInferenceDual(DualCovariatesInferenceDataset):
744
+ def __init__(
745
+ self,
746
+ target_series: Union[TimeSeries, Sequence[TimeSeries]],
747
+ covariates: Optional[Union[TimeSeries, Sequence[TimeSeries]]] = None,
748
+ n: int = 12,
749
+ input_chunk_length: int = 12,
750
+ output_chunk_length: int = 1,
751
+ use_static_covariates: bool = True,
752
+ random_state: Optional[int] = 0,
753
+ max_samples_per_ts: Optional[int] = None,
754
+ array_output_only: bool = False,
755
+ ):
756
+ """
757
+ Parameters
758
+ ----------
759
+ target_series
760
+ One or a sequence of target `TimeSeries`.
761
+ covariates
762
+ Optionally, some future-known covariates that are used for predictions. This argument is required
763
+ if the model was trained with future-known covariates.
764
+ n
765
+ Number of predictions into the future, could be greater than the output chunk length, in which case, the model
766
+ will be called autoregressively.
767
+ output_chunk_length
768
+ The length of the "output" series emitted by the model
769
+ input_chunk_length
770
+ The length of the "input" series fed to the model
771
+ use_static_covariates
772
+ Whether to use/include static covariate data from input series.
773
+ random_state
774
+ The random state to use for sampling.
775
+ max_samples_per_ts
776
+ The maximum number of samples to be drawn from each time series. If None, all samples will be drawn.
777
+ array_output_only
778
+ Whether __getitem__ returns only the arrays or adds the full `TimeSeries` object to the output tuple.
779
+ This may cause problems with the torch collate and loader functions but works for Darts.
780
+ """
781
+ super().__init__(target_series = target_series,
782
+ covariates = covariates,
783
+ n = n,
784
+ input_chunk_length = input_chunk_length,
785
+ output_chunk_length = output_chunk_length,)
786
+
787
+ self.target_series = (
788
+ [target_series] if isinstance(target_series, TimeSeries) else target_series
789
+ )
790
+ self.covariates = (
791
+ [covariates] if isinstance(covariates, TimeSeries) else covariates
792
+ )
793
+
794
+ raise_if_not(
795
+ covariates is None or len(self.target_series) == len(self.covariates),
796
+ "The provided sequence of target series must have the same length as "
797
+ "the provided sequence of covariate series.",
798
+ )
799
+
800
+ # get valid sampling locations
801
+ self.valid_sampling_locations = get_valid_sampling_locations(target_series,
802
+ output_chunk_length,
803
+ input_chunk_length,
804
+ random_state,
805
+ max_samples_per_ts,)
806
+
807
+ # set parameters
808
+ self.output_chunk_length = output_chunk_length
809
+ self.input_chunk_length = input_chunk_length
810
+ self.total_length = input_chunk_length + output_chunk_length
811
+ self.total_number_samples = sum([len(v) for v in self.valid_sampling_locations.values()])
812
+ self.use_static_covariates = use_static_covariates
813
+ self.array_output_only = array_output_only
814
+
815
+ def __len__(self):
816
+ """
817
+ Returns the total number of possible (input, target) splits.
818
+ """
819
+ return self.total_number_samples
820
+
821
+ def __getitem__(self, idx: int):
822
+ # get idx of target series
823
+ target_idx = 0
824
+ while idx >= len(self.valid_sampling_locations[target_idx]):
825
+ idx -= len(self.valid_sampling_locations[target_idx])
826
+ target_idx += 1
827
+ # get sampling location within the target series
828
+ sampling_location = self.valid_sampling_locations[target_idx][idx]
829
+ # get target series
830
+ target_series = self.target_series[target_idx]
831
+ past_target_series_with_time = target_series[sampling_location : sampling_location + self.input_chunk_length]
832
+ past_end = past_target_series_with_time.time_index[-1]
833
+ target_series = self.target_series[target_idx].values()
834
+ past_target_series = target_series[sampling_location : sampling_location + self.input_chunk_length]
835
+ # get future covariates (historic part and forecast part)
836
+ if self.covariates is not None:
837
+ future_covariates = self.covariates[target_idx].values()
838
+ historic_future_covariates = future_covariates[sampling_location : sampling_location + self.input_chunk_length]
839
+ future_covariates = future_covariates[sampling_location + self.input_chunk_length : sampling_location + self.total_length]
840
+ else:
841
+ historic_future_covariates = None
842
+ future_covariates = None
843
+ # get static covariates
844
+ if self.use_static_covariates:
845
+ static_covariates = self.target_series[target_idx].static_covariates_values(copy=True)
846
+ else:
847
+ static_covariates = None
848
+ # return arrays or arrays with TimeSeries
849
+ if self.array_output_only:
850
+ out = []
851
+ out += [past_target_series] if past_target_series is not None else []
852
+ out += [historic_future_covariates] if historic_future_covariates is not None else []
853
+ out += [future_covariates] if future_covariates is not None else []
854
+ out += [static_covariates] if static_covariates is not None else []
855
+ return tuple(out)
856
+ else:
857
+ return tuple([past_target_series,
858
+ historic_future_covariates,
859
+ future_covariates,
860
+ static_covariates,
861
+ past_target_series_with_time,
862
+ past_end + past_target_series_with_time.freq,])
863
+
864
+ def evalsample(
865
+ self, idx: int
866
+ ) -> TimeSeries:
867
+ """
868
+ Returns the future target series at the given index.
869
+ """
870
+ # get idx of target series
871
+ target_idx = 0
872
+ while idx >= len(self.valid_sampling_locations[target_idx]):
873
+ idx -= len(self.valid_sampling_locations[target_idx])
874
+ target_idx += 1
875
+ # get sampling location within the target series
876
+ sampling_location = self.valid_sampling_locations[target_idx][idx]
877
+ # get target series
878
+ target_series = self.target_series[target_idx][sampling_location + self.input_chunk_length : sampling_location + self.total_length]
879
+
880
+ return target_series
881
+
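
The dual-covariates inference dataset above enumerates every valid (history, horizon) window across a list of series: __getitem__ yields the model inputs for one window, and evalsample yields the matching ground-truth horizon. Below is a minimal usage sketch, assuming the repo root is on the Python path so the class is importable from utils.darts_dataset; the toy series and chunk lengths are invented for illustration.

import numpy as np
import pandas as pd
from darts import TimeSeries

from utils.darts_dataset import SamplingDatasetInferenceDual

# Two toy CGM-like series sampled every 5 minutes (synthetic values).
times = pd.date_range("2021-01-01", periods=300, freq="5min")
series_list = [
    TimeSeries.from_times_and_values(times, 100 + 20 * np.sin(np.arange(300) / 10 + phase))
    for phase in (0.0, 1.5)
]

dataset = SamplingDatasetInferenceDual(
    target_series=series_list,
    covariates=None,
    n=12,                    # forecast horizon requested from the model
    input_chunk_length=96,   # history window fed to the model
    output_chunk_length=12,  # horizon produced per forward pass
    use_static_covariates=False,
    array_output_only=True,
)

print(len(dataset))                   # number of valid sampling locations across both series
past_target = dataset[0][0]           # np.ndarray of shape (96, 1): the history window
future_truth = dataset.evalsample(0)  # TimeSeries with the 12-step ground-truth horizon

With array_output_only=True and no covariates, each item is a one-element tuple holding only the past-target array, which keeps the default torch collate function happy.
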
utils/darts_evaluation.py ADDED
@@ -0,0 +1,280 @@
1
+ import sys
2
+ import os
3
+ import yaml
4
+ import random
5
+ from typing import Any, BinaryIO, Callable, Dict, List, Optional, Sequence, Tuple, Union
6
+
7
+ import numpy as np
8
+ from scipy import stats
9
+ import pandas as pd
10
+ import darts
11
+
12
+ from darts import models
13
+ from darts import metrics
14
+ from darts import TimeSeries
15
+
16
+ # import data formatter
17
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
18
+ from data_formatter.base import *
19
+ from utils.darts_processing import *
20
+
21
+ def _get_values(
22
+ series: TimeSeries, stochastic_quantile: Optional[float] = 0.5
23
+ ) -> np.ndarray:
24
+ """
25
+ Returns the numpy values of a time series.
26
+ For stochastic series, return either all sample values with (stochastic_quantile=None) or the quantile sample value
27
+ with (stochastic_quantile {>=0,<=1})
28
+ """
29
+ if series.is_deterministic:
30
+ series_values = series.univariate_values()
31
+ else: # stochastic
32
+ if stochastic_quantile is None:
33
+ series_values = series.all_values(copy=False)
34
+ else:
35
+ series_values = series.quantile_timeseries(
36
+ quantile=stochastic_quantile
37
+ ).univariate_values()
38
+ return series_values
39
+
40
+ def _get_values_or_raise(
41
+ series_a: TimeSeries,
42
+ series_b: TimeSeries,
43
+ intersect: bool,
44
+ stochastic_quantile: Optional[float] = 0.5,
45
+ remove_nan_union: bool = False,
46
+ ) -> Tuple[np.ndarray, np.ndarray]:
47
+ """Returns the processed numpy values of two time series. Processing can be customized with arguments
48
+ `intersect, stochastic_quantile, remove_nan_union`.
49
+
50
+ Raises a ValueError if the two time series (or their intersection) do not have the same time index.
51
+
52
+ Parameters
53
+ ----------
54
+ series_a
55
+ A univariate deterministic ``TimeSeries`` instance (the actual series).
56
+ series_b
57
+ A univariate (deterministic or stochastic) ``TimeSeries`` instance (the predicted series).
58
+ intersect
59
+ A boolean for whether or not to only consider the time intersection between `series_a` and `series_b`
60
+ stochastic_quantile
61
+ Optionally, for stochastic predicted series, return either all sample values with (`stochastic_quantile=None`)
62
+ or any deterministic quantile sample values by setting `stochastic_quantile=quantile` {>=0,<=1}.
63
+ remove_nan_union
64
+ By setting `remove_non_union` to True, remove all indices from `series_a` and `series_b` which have a NaN value
65
+ in either of the two input series.
66
+ """
67
+ series_a_common = series_a.slice_intersect(series_b) if intersect else series_a
68
+ series_b_common = series_b.slice_intersect(series_a) if intersect else series_b
69
+
70
+ series_a_det = _get_values(series_a_common, stochastic_quantile=stochastic_quantile)
71
+ series_b_det = _get_values(series_b_common, stochastic_quantile=stochastic_quantile)
72
+
73
+ if not remove_nan_union:
74
+ return series_a_det, series_b_det
75
+
76
+ b_is_deterministic = bool(len(series_b_det.shape) == 1)
77
+ if b_is_deterministic:
78
+ isnan_mask = np.logical_or(np.isnan(series_a_det), np.isnan(series_b_det))
79
+ else:
80
+ isnan_mask = np.logical_or(
81
+ np.isnan(series_a_det), np.isnan(series_b_det).any(axis=2).flatten()
82
+ )
83
+ return np.delete(series_a_det, isnan_mask), np.delete(
84
+ series_b_det, isnan_mask, axis=0
85
+ )
86
+
87
+ def rescale_and_backtest(series: Union[TimeSeries,
88
+ Sequence[TimeSeries]],
89
+ forecasts: Union[TimeSeries,
90
+ Sequence[TimeSeries],
91
+ Sequence[Sequence[TimeSeries]]],
92
+ metric: Union[
93
+ Callable[[TimeSeries, TimeSeries], float],
94
+ List[Callable[[TimeSeries, TimeSeries], float]],
95
+ ],
96
+ scaler: Callable[[TimeSeries], TimeSeries] = None,
97
+ reduction: Union[Callable[[np.ndarray], float], None] = np.mean,
98
+ likelihood: str = "GaussianMean",
99
+ cal_thresholds: Optional[np.ndarray] = np.linspace(0, 1, 11),
100
+ ):
101
+ """
102
+ Backtest the historical forecasts (as provided by Darts) on the series.
103
+
104
+ Parameters
105
+ ----------
106
+ series
107
+ The target time series.
108
+ forecasts
109
+ The forecasts.
110
+ scaler
111
+ The scaler used to scale the series.
112
+ metric
113
+ The metric or metrics to use for backtesting.
114
+ reduction
115
+ The reduction to apply to the metric.
116
+ likelihood
117
+ The likelihood to use for evaluating the model.
118
+ cal_thresholds
119
+ The thresholds to use for computing the calibration error.
120
+
121
+ Returns
122
+ -------
123
+ np.ndarray
124
+ Error array. If the reduction is none, array is of shape (n, p)
125
+ where n is the total number of samples (forecasts) and p is the number of metrics.
126
+ If the reduction is not none, array is of shape (k, p), where k is the number of series.
127
+ float
128
+ The estimated log-likelihood of the model on the data.
129
+ np.ndarray
130
+ The ECE for each time point in the forecast.
131
+ """
132
+ series = [series] if isinstance(series, TimeSeries) else series
133
+ forecasts = [forecasts] if isinstance(forecasts, TimeSeries) else forecasts
134
+ metric = [metric] if not isinstance(metric, list) else metric
135
+
136
+ # compute errors: 1) inverse-transform forecasts and true values, 2) compute errors
137
+ backtest_list = []
138
+ for idx in range(len(series)):
139
+ if scaler is not None:
140
+ series[idx] = scaler.inverse_transform(series[idx])
141
+ forecasts[idx] = [scaler.inverse_transform(f) for f in forecasts[idx]]
142
+ errors = [
143
+ [metric_f(series[idx], f) for metric_f in metric]
144
+ if len(metric) > 1
145
+ else metric[0](series[idx], f)
146
+ for f in forecasts[idx]
147
+ ]
148
+ if reduction is None:
149
+ backtest_list.append(np.array(errors))
150
+ else:
151
+ backtest_list.append(reduction(np.array(errors), axis=0))
152
+ backtest_list = np.vstack(backtest_list)
153
+
154
+ if likelihood == "GaussianMean":
155
+ # compute likelihood
156
+ est_var = []
157
+ for idx, target_ts in enumerate(series):
158
+ est_var += [metrics.mse(target_ts, f) for f in forecasts[idx]]
159
+ est_var = np.mean(est_var)
160
+ forecast_len = forecasts[0][0].n_timesteps
161
+ log_likelihood = -0.5*forecast_len - 0.5*np.log(2*np.pi*est_var)
162
+
163
+ # compute calibration error: 1) cdf values 2) compute calibration error
164
+ # compute the cdf values
165
+ cdf_vals = []
166
+ for idx in range(len(series)):
167
+ for forecast in forecasts[idx]:
168
+ y_true, y_pred = _get_values_or_raise(series[idx],
169
+ forecast,
170
+ intersect=True,
171
+ remove_nan_union=True)
172
+ y_true, y_pred = y_true.flatten(), y_pred.flatten()
173
+ cdf_vals.append(stats.norm.cdf(y_true, loc=y_pred, scale=np.sqrt(est_var)))
174
+ cdf_vals = np.vstack(cdf_vals)
175
+ # compute the prediction calibration
176
+ cal_error = np.zeros(forecasts[0][0].n_timesteps)
177
+ for p in cal_thresholds:
178
+ est_p = (cdf_vals <= p).astype(float)
179
+ est_p = np.mean(est_p, axis=0)
180
+ cal_error += (est_p - p) ** 2
181
+
182
+ return backtest_list, log_likelihood, cal_error
183
+
184
+ def rescale_and_test(series: Union[TimeSeries,
185
+ Sequence[TimeSeries]],
186
+ forecasts: Union[TimeSeries,
187
+ Sequence[TimeSeries]],
188
+ metric: Union[
189
+ Callable[[TimeSeries, TimeSeries], float],
190
+ List[Callable[[TimeSeries, TimeSeries], float]],
191
+ ],
192
+ scaler: Callable[[TimeSeries], TimeSeries] = None,
193
+ likelihood: str = "GaussianMean",
194
+ cal_thresholds: Optional[np.ndarray] = np.linspace(0, 1, 11),
195
+ ):
196
+ """
197
+ Test the forecasts on the series.
198
+
199
+ Parameters
200
+ ----------
201
+ series
202
+ The target time series.
203
+ forecasts
204
+ The forecasts.
205
+ scaler
206
+ The scaler used to scale the series.
207
+ metric
208
+ The metric or metrics to use for backtesting.
209
+ reduction
212
+ The likelihood to use for evaluating the likelihood and calibration of model.
213
+ cal_thresholds
214
+ The thresholds to use for computing the calibration error.
215
+
216
+ Returns
217
+ -------
218
+ np.ndarray
219
+ Error array. If the reduction is none, array is of shape (n, p)
220
+ where n is the total number of samples (forecasts) and p is the number of metrics.
221
+ If the reduction is not none, array is of shape (k, p), where k is the number of series.
222
+ float
223
+ The estimated log-likelihood of the model on the data.
224
+ np.ndarray
225
+ The ECE for each time point in the forecast.
226
+ """
227
+ series = [series] if isinstance(series, TimeSeries) else series
228
+ forecasts = [forecasts] if isinstance(forecasts, TimeSeries) else forecasts
229
+ metric = [metric] if not isinstance(metric, list) else metric
230
+
231
+ # compute errors: 1) inverse-transform forecasts and true values, 2) compute errors
232
+ series = scaler.inverse_transform(series)
233
+ forecasts = scaler.inverse_transform(forecasts)
234
+ errors = [
235
+ [metric_f(t, f) for metric_f in metric]
236
+ if len(metric) > 1
237
+ else metric[0](t, f)
238
+ for (t, f) in zip(series, forecasts)
239
+ ]
240
+ errors = np.array(errors)
241
+
242
+ if likelihood == "GaussianMean":
243
+ # compute likelihood
244
+ est_var = [metrics.mse(t, f) for (t, f) in zip(series, forecasts)]
245
+ est_var = np.mean(est_var)
246
+ forecast_len = forecasts[0].n_timesteps
247
+ log_likelihood = -0.5*forecast_len - 0.5*np.log(2*np.pi*est_var)
248
+
249
+ # compute calibration error: 1) cdf values 2) compute calibration error
250
+ # compute the cdf values
251
+ cdf_vals = []
252
+ for t, f in zip(series, forecasts):
253
+ t, f = _get_values_or_raise(t, f, intersect=True, remove_nan_union=True)
254
+ t, f = t.flatten(), f.flatten()
255
+ cdf_vals.append(stats.norm.cdf(t, loc=f, scale=np.sqrt(est_var)))
256
+ cdf_vals = np.vstack(cdf_vals)
257
+ # compute the prediction calibration
258
+ cal_error = np.zeros(forecasts[0].n_timesteps)
259
+ for p in cal_thresholds:
260
+ est_p = (cdf_vals <= p).astype(float)
261
+ est_p = np.mean(est_p, axis=0)
262
+ cal_error += (est_p - p) ** 2
263
+
264
+ if likelihood == "Quantile":
265
+ # no likelihood since we don't have a parametric model
266
+ log_likelihood = 0
267
+
268
+ # compute calibration error: 1) get quantiles 2) compute calibration error
269
+ cal_error = np.zeros(forecasts[0].n_timesteps)
270
+ for p in cal_thresholds:
271
+ est_p = 0
272
+ for t, f in zip(series, forecasts):
273
+ q = f.quantile(p)
274
+ t, q = _get_values_or_raise(t, q, intersect=True, remove_nan_union=True)
275
+ t, q = t.flatten(), q.flatten()
276
+ est_p += (t <= q).astype(float)
277
+ est_p = (est_p / len(series)).flatten()
278
+ cal_error += (est_p - p) ** 2
279
+
280
+ return errors, log_likelihood, cal_error
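
rescale_and_test inverse-scales the series, scores each forecast with the supplied Darts metrics, and, for the "GaussianMean" likelihood, estimates the predictive variance from the MSE to produce a log-likelihood and a per-step calibration error (the sum over cal_thresholds of squared gaps between nominal and empirical coverage). The call below is a minimal sketch with invented toy data and a plain Darts Scaler; it is not the repo's actual evaluation pipeline.

import numpy as np
import pandas as pd
from darts import TimeSeries, metrics
from darts.dataprocessing.transformers import Scaler

from utils.darts_evaluation import rescale_and_test

# One toy 12-step ground-truth horizon and a matching forecast (synthetic values).
times = pd.date_range("2021-01-01", periods=12, freq="5min")
truth = [TimeSeries.from_times_and_values(times, 120 + 5 * np.random.randn(12))]
preds = [TimeSeries.from_times_and_values(times, 120 + 5 * np.random.randn(12))]

scaler = Scaler()
truth = scaler.fit_transform(truth)  # pretend the pipeline worked in scaled space
preds = scaler.transform(preds)

errors, log_likelihood, cal_error = rescale_and_test(
    truth,
    preds,
    metric=[metrics.mse, metrics.mae],
    scaler=scaler,
    likelihood="GaussianMean",
)
print(errors.shape)     # (n_forecasts, n_metrics) -> (1, 2)
print(cal_error.shape)  # one value per forecast step -> (12,)
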
utils/darts_processing.py ADDED
@@ -0,0 +1,367 @@
1
+ import sys
2
+ import os
3
+ import yaml
4
+ import random
5
+ from typing import Any, BinaryIO, Callable, Dict, List, Optional, Sequence, Tuple, Union
6
+ from pathlib import Path
7
+ import numpy as np
8
+ from scipy import stats
9
+ import pandas as pd
10
+ import darts
11
+
12
+ from darts import models
13
+ from darts import metrics
14
+ from darts import TimeSeries
15
+ from darts.dataprocessing.transformers import Scaler
16
+ from pytorch_lightning.callbacks import Callback
17
+ from sympy import pprint
18
+
19
+ # import data formatter
20
+ sys.path.append(os.path.join(os.path.dirname(__file__), '..'))
21
+ from data_formatter.base import *
22
+
23
+ pd.set_option('display.width', None) # Set display width to None to avoid truncation
24
+ pd.set_option('display.max_columns', None) # Display all columns
25
+
26
+ def make_series(data: Dict[str, pd.DataFrame],
27
+ time_col: str,
28
+ group_col: str,
29
+ value_cols: Dict[str, List[str]],
30
+ include_sid: bool = False,
31
+ verbose: bool = False
32
+ ) -> Tuple[Dict[str, Dict[str, Any]], Dict[str, Optional["ScalerCustom"]]]:
33
+ """
34
+ Makes the TimeSeries from the data.
35
+
36
+ Parameters
37
+ ----------
38
+ data
39
+ dict of train, val, test dataframes
40
+ time_col
41
+ name of time column
42
+ group_col
43
+ name of group column
44
+ value_cols
45
+ dict with key specifying the type of covariate and value specifying the list of columns.
46
+ include_sid
47
+ whether to include segment id as static covariate
48
+
49
+ Returns
50
+ -------
51
+ series: Dict[str, Dict[str, darts.TimeSeries]]
52
+ dict of train, val, test splits of target and covariates TimeSeries objects
53
+ scalers: Dict[str, ScalerCustom]
54
+ dict of scalers for target and covariates
55
+ """
56
+ series = {i: {j: None for j in value_cols} for i in data.keys()}
57
+ scalers = {}
58
+ for key, df in data.items():
59
+
60
+ for name, cols in value_cols.items():
61
+ # Adjust display settings
62
+ if verbose:
63
+ print(f"DATAFRAME for key {key} in NAME {name} and COLS {cols} and GROUP_COL {group_col}")
64
+ pprint(df.head(1))
65
+ series[key][name] = TimeSeries.from_group_dataframe(df = df,
66
+ group_cols = group_col,
67
+ time_col = time_col,
68
+ value_cols = cols) if cols is not None else None
69
+ if series[key][name] is not None and include_sid is False:
70
+ for i in range(len(series[key][name])):
71
+ series[key][name][i] = series[key][name][i].with_static_covariates(None)
72
+ if cols is not None:
73
+ if key == 'train':
74
+ scalers[name] = ScalerCustom()
75
+ series[key][name] = scalers[name].fit_transform(series[key][name])
76
+ else:
77
+ series[key][name] = scalers[name].transform(series[key][name])
78
+ else:
79
+ scalers[name] = None
80
+ return series, scalers
81
+
82
+ def load_data(url: str,
83
+ config_path: Path,
84
+ use_covs: bool = False,
85
+ cov_type: str = 'past',
86
+ use_static_covs: bool = False, seed = 0):
87
+ """
88
+ Load data according to the specified config file and convert it to Darts TimeSeries objects.
89
+
90
+ Parameters
91
+ ----------
92
+ seed: int
93
+ Random seed for data splitting.
94
+ url: str
95
+ Path of the input CSV file; it is written into the config as `data_csv_path`.
96
+ config_path: Path
97
+ Path to the YAML configuration file describing the dataset.
98
+ use_covs: bool
99
+ Whether to use covariates.
100
+ cov_type: str
101
+ Type of covariates to use. Can be 'past' or 'mixed' or 'dual'.
102
+ use_static_covs: bool
103
+ Whether to use static covariates.
104
+
105
+ Returns
106
+ -------
107
+ formatter: DataFormatter
108
+ Data formatter object.
109
+ series: Dict[str, Dict[str, TimeSeries]]
110
+ First dictionary specifies the split, second dictionary specifies the type of series (target or covariate).
111
+ scalers: Dict[str, Scaler]
112
+ Dictionary of scalers with key indicating the type of series (target or covariate).
113
+ """
114
+
115
+
116
+ """
117
+ config={
118
+ 'data_csv_path':f'{url}',
119
+ 'drop': None,
120
+ 'ds_name': 'livia_mini',
121
+ 'index_col': -1,
122
+ 'observation_interval': '5min',
123
+ 'column_definition': [
124
+ {'data_type': 'categorical',
125
+ 'input_type':'id',
126
+ 'name':'id'
127
+ },
128
+ {'data_type':'date',
129
+ 'input_type':'time',
130
+ 'name':'time'
131
+ },
132
+ {'data_type':'real_valued',
133
+ 'input_type':'target',
134
+ 'name':'gl'
135
+ }
136
+ ],
137
+ 'encoding_params':{'date':['day','month','year','hour','minute','second']
138
+ },
139
+ 'nan_vals':None,
140
+ 'interpolation_params':{'gap_threshold': 45,
141
+ 'min_drop_length': 240
142
+ },
143
+ 'scaling_params':{'scaler':None
144
+ },
145
+ 'split_params':{'length_segment': 13,
146
+ 'random_state':seed,
147
+ 'test_percent_subjects': 0.1
148
+ },
149
+ 'max_length_input': 192,
150
+ 'length_pred': 12,
151
+ 'params':{
152
+ 'gluformer':{'in_len': 96,
153
+ 'd_model': 512,
154
+ 'n_heads': 10,
155
+ 'd_fcn': 1024,
156
+ 'num_enc_layers': 2,
157
+ 'num_dec_layers': 2,
158
+ 'length_pred': 12
159
+ }
160
+ }
161
+ }
162
+ """
163
+ with config_path.open("r") as f:
164
+ config = yaml.safe_load(f)
165
+ config["data_csv_path"] = url
166
+
167
+ formatter = DataFormatter(config)
168
+ #assert dataset is not None, 'dataset must be specified in the load_data call'
169
+ assert use_covs is not None, 'use_covs must be specified in the load_data call'
170
+
171
+ # convert to series
172
+ time_col = formatter.get_column('time')
173
+ group_col = formatter.get_column('sid')
174
+ target_col = formatter.get_column('target')
175
+ static_cols = formatter.get_column('static_covs')
176
+ static_cols = static_cols + [formatter.get_column('id')] if static_cols is not None else [formatter.get_column('id')]
177
+ dynamic_cols = formatter.get_column('dynamic_covs')
178
+ future_cols = formatter.get_column('future_covs')
179
+
180
+ data = {'train': formatter.train_data,
181
+ 'val': formatter.val_data,
182
+ 'test': formatter.test_data.loc[~formatter.test_data.index.isin(formatter.test_idx_ood)],
183
+ 'test_ood': formatter.test_data.loc[formatter.test_data.index.isin(formatter.test_idx_ood)]}
184
+ value_cols = {'target': target_col,
185
+ 'static': static_cols,
186
+ 'dynamic': dynamic_cols,
187
+ 'future': future_cols}
188
+ # build series
189
+ series, scalers = make_series(data,
190
+ time_col,
191
+ group_col,
192
+ value_cols)
193
+ if not use_covs:
194
+ # set dynamic and future covariates to None
195
+ for split in ['train', 'val', 'test', 'test_ood']:
196
+ for cov in ['dynamic', 'future']:
197
+ series[split][cov] = None
198
+ elif use_covs and cov_type == 'mixed':
199
+ pass # this is the default for make_series()
200
+ elif use_covs and cov_type == 'past':
201
+ # use future covariates as dynamic (past) covariates
202
+ if series['train']['dynamic'] is None:
203
+ for split in ['train', 'val', 'test', 'test_ood']:
204
+ series[split]['dynamic'] = series[split]['future']
205
+ else:
206
+ for split in ['train', 'val', 'test', 'test_ood']:
207
+ for i in range(len(series[split]['future'])):
208
+ series[split]['dynamic'][i] = series[split]['dynamic'][i].concatenate(series[split]['future'][i], axis=1)
209
+ # erase future covariates
210
+ for split in ['train', 'val', 'test', 'test_ood']:
211
+ series[split]['future'] = None
212
+ elif use_covs and cov_type == 'dual':
213
+ # erase dynamic (past) covariates
214
+ for split in ['train', 'val', 'test', 'test_ood']:
215
+ series[split]['dynamic'] = None
216
+
217
+ if use_static_covs:
218
+ # attach static covariates to series
219
+ for split in ['train', 'val', 'test', 'test_ood']:
220
+ for i in range(len(series[split]['target'])):
221
+ static_covs = series[split]['static'][i][0].pd_dataframe()
222
+ series[split]['target'][i] = series[split]['target'][i].with_static_covariates(static_covs)
223
+
224
+ return formatter, series, scalers
225
+
226
+ def reshuffle_data(formatter: DataFormatter,
227
+ seed: int = 0,
228
+ use_covs: bool = None,
229
+ cov_type: str = 'past',
230
+ use_static_covs: bool = False,):
231
+ """
232
+ Reshuffle data according to the seed and convert it to Darts TimeSeries objects.
233
+
234
+ Parameters
235
+ ----------
236
+ formatter: DataFormatter
237
+ Data formatter object containing the data
238
+ seed: int
239
+ Random seed for data splitting.
240
+ use_covs: bool
241
+ Whether to use covariates.
242
+ cov_type: str
243
+ Type of covariates to use. Can be 'past' or 'mixed' or 'dual'.
244
+ use_static_covs: bool
245
+ Whether to use static covariates.
246
+
247
+ Returns
248
+ -------
249
+ formatter: DataFormatter
250
+ Reshuffled data formatter object.
251
+ series: Dict[str, Dict[str, TimeSeries]]
252
+ First dictionary specifies the split, second dictionary specifies the type of series (target or covariate).
253
+ scalers: Dict[str, Scaler]
254
+ Dictionary of scalers with key indicating the type of series (target or covariate).
255
+ """
256
+ # reshuffle
257
+ formatter.reshuffle(seed)
258
+ assert use_covs is not None, 'use_covs must be specified in the reshuffle_data call'
259
+
260
+ # convert to series
261
+ time_col = formatter.get_column('time')
262
+ group_col = formatter.get_column('sid')
263
+ target_col = formatter.get_column('target')
264
+ static_cols = formatter.get_column('static_covs')
265
+ static_cols = static_cols + [formatter.get_column('id')] if static_cols is not None else [formatter.get_column('id')]
266
+ dynamic_cols = formatter.get_column('dynamic_covs')
267
+ future_cols = formatter.get_column('future_covs')
268
+
269
+ # build series
270
+ series, scalers = make_series({'train': formatter.train_data,
271
+ 'val': formatter.val_data,
272
+ 'test': formatter.test_data.loc[~formatter.test_data.index.isin(formatter.test_idx_ood)],
273
+ 'test_ood': formatter.test_data.loc[formatter.test_data.index.isin(formatter.test_idx_ood)]},
274
+ time_col,
275
+ group_col,
276
+ {'target': target_col,
277
+ 'static': static_cols,
278
+ 'dynamic': dynamic_cols,
279
+ 'future': future_cols})
280
+
281
+ if not use_covs:
282
+ # set dynamic and future covariates to None
283
+ for split in ['train', 'val', 'test', 'test_ood']:
284
+ for cov in ['dynamic', 'future']:
285
+ series[split][cov] = None
286
+ elif use_covs and cov_type == 'past':
287
+ # use future covariates as dynamic covariates
288
+ if series['train']['dynamic'] is None:
289
+ for split in ['train', 'val', 'test', 'test_ood']:
290
+ series[split]['dynamic'] = series[split]['future']
291
+ # or attach them to dynamic covariates
292
+ else:
293
+ for split in ['train', 'val', 'test', 'test_ood']:
294
+ for i in range(len(series[split]['future'])):
295
+ series[split]['dynamic'][i] = series[split]['dynamic'][i].concatenate(series[split]['future'][i], axis=1)
296
+ elif use_covs and cov_type == 'dual':
297
+ # set dynamic covariates to None, because they are not supported
298
+ for split in ['train', 'val', 'test', 'test_ood']:
299
+ series[split]['dynamic'] = None
300
+
301
+ if use_static_covs:
302
+ # attach static covariates to series
303
+ for split in ['train', 'val', 'test', 'test_ood']:
304
+ for i in range(len(series[split]['target'])):
305
+ static_covs = series[split]['static'][i][0].pd_dataframe()
306
+ series[split]['target'][i] = series[split]['target'][i].with_static_covariates(static_covs)
307
+
308
+ return formatter, series, scalers
309
+
310
+ class ScalerCustom:
311
+ '''
312
+ Min-max scaler for TimeSeries that fits on all sequences simultaneously.
313
+ Default Darts scaler fits one scaler per sequence in the list.
314
+
315
+ Attributes
316
+ ----------
317
+ scaler: Scaler
318
+ Darts scaler object.
319
+ min_: np.ndarray
320
+ Per feature adjustment for minimum (see Scikit-learn).
321
+ scale_: np.ndarray
322
+ Per feature relative scaling of the data (see Scikit-learn).
323
+ '''
324
+ def __init__(self):
325
+ self.scaler = Scaler()
326
+ self.min_ = None
327
+ self.scale_ = None
328
+
329
+ def fit(self, time_series: Union[List[TimeSeries], TimeSeries]) -> None:
330
+
331
+ if isinstance(time_series, list):
332
+
333
+ # extract series as Pandas dataframe
334
+ df = pd.concat([ts.pd_dataframe() for ts in time_series])
335
+ value_cols = df.columns
336
+ df.reset_index(inplace=True)
337
+ # create new equally spaced time grid
338
+ df['new_time'] = pd.date_range(start=df['time'].min(), periods=len(df), freq='1h')
339
+ # fit scaler
340
+ series = TimeSeries.from_dataframe(df, time_col='new_time', value_cols=value_cols)
341
+ series = self.scaler.fit(series)
342
+ else:
343
+ series = self.scaler.fit(time_series)
344
+ # extract min and scale
345
+ self.min_ = self.scaler._fitted_params[0].min_
346
+ self.scale_ = self.scaler._fitted_params[0].scale_
347
+
348
+ def transform(self, time_series: Union[List[TimeSeries], TimeSeries]) -> Union[List[TimeSeries], TimeSeries]:
349
+ if isinstance(time_series, list):
350
+ # transform one by one
351
+ series = [self.scaler.transform(ts) for ts in time_series]
352
+ else:
353
+ series = self.scaler.transform(time_series)
354
+ return series
355
+
356
+ def inverse_transform(self, time_series: Union[List[TimeSeries], TimeSeries]) -> Union[List[TimeSeries], TimeSeries]:
357
+ if isinstance(time_series, list):
358
+ # transform one by one
359
+ series = [self.scaler.inverse_transform(ts) for ts in time_series]
360
+ else:
361
+ series = self.scaler.inverse_transform(time_series)
362
+ return series
363
+
364
+ def fit_transform(self, time_series: Union[List[TimeSeries], TimeSeries]) -> Union[List[TimeSeries], TimeSeries]:
365
+ self.fit(time_series)
366
+ series = self.transform(time_series)
367
+ return series
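
load_data ties the processing together: it loads the YAML config, lets DataFormatter interpolate, encode and split the CGM table, and returns per-split Darts TimeSeries plus the jointly fitted ScalerCustom objects. The call below is a hypothetical sketch: the CSV filename is a placeholder (for example, a Dexcom export processed by format_dexcom.py), while files/config.yaml is the config shipped in this commit.

from pathlib import Path

from utils.darts_processing import load_data

formatter, series, scalers = load_data(
    url="my_dexcom_export.csv",             # placeholder path to the formatted CGM CSV
    config_path=Path("files/config.yaml"),  # dataset config included in this commit
    use_covs=False,                         # target-only setup: dynamic/future covariates dropped
    use_static_covs=False,
    seed=0,
)

train_targets = series["train"]["target"]  # list of darts.TimeSeries, one per contiguous segment
glucose_scaler = scalers["target"]         # ScalerCustom fitted on all training segments jointly
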
utils/darts_training.py ADDED
@@ -0,0 +1,114 @@
1
+ import sys
2
+ import os
3
+ import yaml
4
+ import random
5
+ from typing import Any, BinaryIO, Callable, Dict, List, Optional, Sequence, Tuple, Union
6
+
7
+ import numpy as np
8
+ from scipy import stats
9
+ import pandas as pd
10
+ import darts
11
+
12
+ from darts import models
13
+ from darts import metrics
14
+ from darts import TimeSeries
15
+ from pytorch_lightning.callbacks import Callback
16
+ from darts.logging import get_logger, raise_if_not
17
+
18
+ # for optuna callback
19
+ import warnings
20
+ import optuna
21
+ from optuna.storages._cached_storage import _CachedStorage
22
+ from optuna.storages._rdb.storage import RDBStorage
23
+ # Define key names of `Trial.system_attrs`.
24
+ _PRUNED_KEY = "ddp_pl:pruned"
25
+ _EPOCH_KEY = "ddp_pl:epoch"
26
+ with optuna._imports.try_import() as _imports:
27
+ import pytorch_lightning as pl
28
+ from pytorch_lightning import LightningModule
29
+ from pytorch_lightning import Trainer
30
+ from pytorch_lightning.callbacks import Callback
31
+ if not _imports.is_successful():
32
+ Callback = object # type: ignore # NOQA
33
+ LightningModule = object # type: ignore # NOQA
34
+ Trainer = object # type: ignore # NOQA
35
+
36
+ def print_callback(study, trial, study_file=None):
37
+ # write output to a file
38
+ with open(study_file, "a") as f:
39
+ f.write(f"Current value: {trial.value}, Current params: {trial.params}\n")
40
+ f.write(f"Best value: {study.best_value}, Best params: {study.best_trial.params}\n")
41
+
42
+ def early_stopping_check(study,
43
+ trial,
44
+ study_file,
45
+ early_stopping_rounds=10):
46
+ """
47
+ Early stopping callback for Optuna.
48
+ This function checks the current trial number and the best trial number.
49
+ """
50
+ current_trial_number = trial.number
51
+ best_trial_number = study.best_trial.number
52
+ should_stop = (current_trial_number - best_trial_number) >= early_stopping_rounds
53
+ if should_stop:
54
+ with open(study_file, 'a') as f:
55
+ f.write('\nEarly stopping at trial {} (best trial: {})'.format(current_trial_number, best_trial_number))
56
+ study.stop()
57
+
58
+ class LossLogger(Callback):
59
+ def __init__(self):
60
+ self.train_loss = []
61
+ self.val_loss = []
62
+
63
+ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
64
+ self.train_loss.append(float(trainer.callback_metrics["train_loss"]))
65
+
66
+ def on_validation_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
67
+ self.val_loss.append(float(trainer.callback_metrics["val_loss"]))
68
+
69
+ class PyTorchLightningPruningCallback(Callback):
70
+ """PyTorch Lightning callback to prune unpromising trials.
71
+ See `the example <https://github.com/optuna/optuna-examples/blob/
72
+ main/pytorch/pytorch_lightning_simple.py>`__
73
+ if you want to add a pruning callback which observes accuracy.
74
+ Args:
75
+ trial:
76
+ A :class:`~optuna.trial.Trial` corresponding to the current evaluation of the
77
+ objective function.
78
+ monitor:
79
+ An evaluation metric for pruning, e.g., ``val_loss`` or
80
+ ``val_acc``. The metrics are obtained from the returned dictionaries from e.g.
81
+ ``pytorch_lightning.LightningModule.training_step`` or
82
+ ``pytorch_lightning.LightningModule.validation_epoch_end`` and the names thus depend on
83
+ how this dictionary is formatted.
84
+ """
85
+
86
+ def __init__(self, trial: optuna.trial.Trial, monitor: str) -> None:
87
+ super().__init__()
88
+
89
+ self._trial = trial
90
+ self.monitor = monitor
91
+
92
+ def on_validation_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
93
+ # When the trainer calls `on_validation_end` for sanity check,
94
+ # do not call `trial.report` to avoid calling `trial.report` multiple times
95
+ # at epoch 0. The related page is
96
+ # https://github.com/PyTorchLightning/pytorch-lightning/issues/1391.
97
+ if trainer.sanity_checking:
98
+ return
99
+
100
+ epoch = pl_module.current_epoch
101
+
102
+ current_score = trainer.callback_metrics.get(self.monitor)
103
+ if current_score is None:
104
+ message = (
105
+ "The metric '{}' is not in the evaluation logs for pruning. "
106
+ "Please make sure you set the correct metric name.".format(self.monitor)
107
+ )
108
+ warnings.warn(message)
109
+ return
110
+
111
+ self._trial.report(current_score, step=epoch)
112
+ if self._trial.should_prune():
113
+ message = "Trial was pruned at epoch {}.".format(epoch)
114
+ raise optuna.TrialPruned(message)
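
LossLogger and PyTorchLightningPruningCallback plug into any Darts torch model through pl_trainer_kwargs, while print_callback and early_stopping_check are Optuna study callbacks. The sketch below shows one way to wire them together; the toy series, the small TCNModel, the search space and the study.log filename are invented for illustration and are not what this repo actually tunes.

from functools import partial

import numpy as np
import optuna
import pandas as pd
from darts import TimeSeries
from darts.models import TCNModel

from utils.darts_training import (
    LossLogger,
    PyTorchLightningPruningCallback,
    early_stopping_check,
    print_callback,
)

# Synthetic 5-minute glucose-like signal, split into train/validation.
times = pd.date_range("2021-01-01", periods=600, freq="5min")
values = 120 + 25 * np.sin(np.arange(600) / 12.0)
train = TimeSeries.from_times_and_values(times[:480], values[:480])
val = TimeSeries.from_times_and_values(times[480:], values[480:])

def objective(trial):
    loss_logger = LossLogger()
    pruner = PyTorchLightningPruningCallback(trial, monitor="val_loss")
    model = TCNModel(
        input_chunk_length=96,
        output_chunk_length=12,
        num_filters=trial.suggest_int("num_filters", 4, 16),
        n_epochs=3,
        pl_trainer_kwargs={"callbacks": [loss_logger, pruner], "enable_progress_bar": False},
        random_state=0,
    )
    model.fit(series=train, val_series=val)
    return loss_logger.val_loss[-1]  # last recorded validation loss

study = optuna.create_study(direction="minimize")
study.optimize(
    objective,
    n_trials=5,
    callbacks=[
        partial(print_callback, study_file="study.log"),
        partial(early_stopping_check, study_file="study.log", early_stopping_rounds=3),
    ],
)
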