Визуализация архитектуры сети
Понимание архитектуры модели может быть очень полезным как для отладки вашей сети, так и для понимания её поведения.
В этом документе мы построим глобальную модель, используя данные о часовой нагрузке региона ERCOT. В качестве референса мы используем записную книжку ./global_modeling.ipynb
Наконец, мы визуализируем архитектуру сети.
Сначала мы установим Graphviz. Для Windows перейдите по ссылке https://www.graphviz.org/download/. Для Mac/Linux выполните команду ниже.
try:
# зависимости уже установлены
from torchsummary import summary
from torchviz import make_dot
except:
# установка graphviz в системе
import platform
if "Darwin" == platform.system():
!brew install graphviz
elif "Linux" == platform.system():
!sudo apt install graphviz
else:
print("перейдите по ссылке https://www.graphviz.org/download/")
# Затем нам нужно установить следующие зависимости:
!pip install torchsummary
!pip install torch-summary
!pip install torchviz
!pip install graphviz
# импорт
from torchsummary import summary
from torchviz import make_dot
try:
from neuralprophet import NeuralProphet
except:
# если NeuralProphet еще не установлен:
!pip install git+https://github.com/ourownstory/neural_prophet.git
from neuralprophet import NeuralProphet
import pandas as pd
from neuralprophet import set_log_level
set_log_level("ERROR")
Сначала загружаем данные:
data_location = "https://raw.githubusercontent.com/ourownstory/neuralprophet-data/main/datasets/"
df_ercot = pd.read_csv(data_location + "multivariate/load_ercot_regions.csv")
df_ercot.head(3)
0
2004-01-01 01:00:00
7225.09
877.79
1044.89
745.79
7124.21
1660.45
3639.12
654.61
1
2004-01-01 02:00:00
6994.25
850.75
1032.04
721.34
6854.58
1603.52
3495.16
639.88
2
2004-01-01 03:00:00
6717.42
831.63
1021.10
699.70
6639.48
1527.99
3322.70
623.42
Извлекаем названия регионов, которые впоследствии будут использоваться при создании модели.
regions = list(df_ercot)[1:]
Глобальные модели могут быть активированы, когда входные данные df
функции содержат дополнительную колонку ID
, которая идентифицирует различные временные ряды (помимо типичной колонки ds
, содержащей временные метки, и колонки y
, содержащей наблюдаемые значения временного ряда). В нашем примере мы выбираем данные за трехлетний интервал (с 2004 по 2007 год).
df_global = pd.DataFrame()
for col in regions:
aux = df_ercot[["ds", col]].copy(deep=True) # select column associated with region
aux = aux.iloc[:26301, :].copy(deep=True) # selects data up to 26301 row (2004 to 2007 time stamps)
aux = aux.rename(columns={col: "y"}) # rename column of data to 'y' which is compatible with Neural Prophet
aux["ID"] = col
df_global = pd.concat((df_global, aux))
df_global.head(3)
0
2004-01-01 01:00:00
7225.09
COAST
1
2004-01-01 02:00:00
6994.25
COAST
2
2004-01-01 03:00:00
6717.42
COAST
Когда входными данными для функции split_df
является pd.DataFrame с колонкой ‘ID’, обучающие и валидационные данные предоставляются в аналогичном формате. Для глобальных моделей входные данные обычно разделяются в соответствии с долей времени, охватывающей все временные ряды (по умолчанию, когда есть более одного ‘ID’ и когда local_split=False
). Если пользователь хочет разделить каждый временной ряд локально, параметр local_split
должен быть установлен в значение True. В этом примере мы разделим наши данные на обучающую и тестовую выборки (с долей теста 33% - 2 года обучения и 1 год теста).
Глобальное моделирование позволяет нам тренировать нашу модель, основываясь либо на глобальной, либо на локальной нормализации. В последнем случае каждый временной ряд нормализуется локально (у каждого временного ряда есть свои соответствующие параметры данных). В первом случае у нас есть уникальные параметры данных, которые будут использоваться во всех рассматриваемых временных рядах.
Глобальное моделирование — локальная нормализация
m = NeuralProphet(n_lags=24, epochs=2, learning_rate=0.1)
df_train, df_test = m.split_df(df_global, valid_p=0.33, local_split=True)
Процесс стандартной подгонки глобальных моделей основан на локализованной нормализации данных. Каждый временной ряд будет иметь параметры нормализации данных, связанные с каждым предоставленным идентификатором ('ID'). Мы собираемся определить модель, которая прогнозирует следующий час на основе данных последних 24 часов.
После создания объекта NeuralProphet
, модель можно создать, вызвав функцию fit
metrics = m.fit(df_train, freq="H")
metrics.tail(1)
1
0.025526
0.033126
0.000416
0.0
1
1. Сводная информация о сети
От https://pypi.org/project/torch-summary/ :
Torch-summary предоставляет информацию, дополняющую то, что предоставляет print(your_model)
в PyTorch, аналогично API model.summary()
в Tensorflow для визуализации модели, что помогает при отладке вашей сети. В этом проекте мы реализуем аналогичный функционал в PyTorch и создаем чистый, простой интерфейс для использования в ваших проектах.
display(summary(m.model))
=================================================================
Layer (type:depth-idx) Param #
=================================================================
├─MetricCollection: 1-1 --
| └─MeanAbsoluteError: 2-1 --
| └─MeanSquaredError: 2-2 --
├─MetricCollection: 1-2 --
| └─MeanAbsoluteError: 2-3 --
| └─MeanSquaredError: 2-4 --
├─GlobalPiecewiseLinearTrend: 1-3 13
├─GlobalFourierSeasonality: 1-4 --
| └─ParameterDict: 2-5 30
├─ModuleList: 1-5 --
| └─Linear: 2-6 24
=================================================================
Total params: 67
Trainable params: 67
Non-trainable params: 0
=================================================================
=================================================================
Layer (type:depth-idx) Param #
=================================================================
├─MetricCollection: 1-1 --
| └─MeanAbsoluteError: 2-1 --
| └─MeanSquaredError: 2-2 --
├─MetricCollection: 1-2 --
| └─MeanAbsoluteError: 2-3 --
| └─MeanSquaredError: 2-4 --
├─GlobalPiecewiseLinearTrend: 1-3 13
├─GlobalFourierSeasonality: 1-4 --
| └─ParameterDict: 2-5 30
├─ModuleList: 1-5 --
| └─Linear: 2-6 24
=================================================================
Total params: 67
Trainable params: 67
Non-trainable params: 0
=================================================================
2. Сетевая визуализация
От https://github.com/szagoruyko/pytorchviz :
Небольшой пакет для создания визуализаций графиков и трассировок выполнения PyTorch.
fig = make_dot(m.model.train_epoch_prediction, params=dict(m.model.named_parameters()))
# fig_glob.render(filename='img/fig_glob')
display(fig)
Last updated