Оборачиваем ONNX-модели в классы

Оглавление

Введение

В предыдущей статье мы использовали две ONNX-модели для организации классификатора голосования. При этом весь исходный текст был организван в виде одного MQ5-файла. Да, весь код разбит на функции. Но попробуйте, например, поменять местами модели. А если добавить ещё модель? Исходный текст ещё более распухнет. Попробуем объектно-ориентированный подход.

1. Какие модели мы собираемся использовать

В предыдущем классификаторе голосования мы использовали одну классификационную модель и одну регрессионную модель. В регрессионной модели вместо предсказанного движения цены (вниз, вверх, не изменяется) мы получаем предсказанную цену, на основе которой и вычисляем класс. Однако, в этом случае мы не имеем распределения вероятностей по классам, что не позволяет проводить так называемое "мягкое голосование".

Мы подготовили 3 классификационные модели. Две модели уже использовались в статье "Пример ансамбля ONNX-моделей в MQL5". Первая модель — регрессионная — переделана в классификационную. Обучение проводилось на сериях из 10 цен OHLC. Вторая модель — классификационная. Обучение проводилось на сериях из 63 цен Close.

Наконец, ещё одна модель. Классификационная модель обучалась на сериях из 30 цен Close и сериях простых скользящих средних с периодами усреднения 21 и 34. Мы не делали никаких предположений по поводу пересечения скользящих средних с графиком Close и между собой — все закономерности посчитает и запомнит сеть в виде матриц коэффициентов между слоями.

Все модели обучались на данных сервера MetaQuotes-Demo, EURUSD D1 с 2010.01.01 по 2023.01.01. Тренировочные скрипты всех трёх моделей написаны на Питоне и приложены к данной статье. Мы не будем приводить их исходные коды здесь, чтобы не отвлекать внимание читателя от основной темы нашей статьи.

2. Нужен один базовый класс для всех моделей

Три модели. Каждая отличается от другой размером входных данных, подготовкой входных данных. У всех моделей есть общность. Один и тот же интерфейс. Классы всех моделей должны наследоваться от одного и того же базового класса.

Попробуем представить базовый класс.

//+------------------------------------------------------------------+
//|                                            ModelSymbolPeriod.mqh |
//|                            Авторские права 2023, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+

//--- предсказание движения цены
#define PRICE_UP   0
#define PRICE_SAME 1
#define PRICE_DOWN 2

//+------------------------------------------------------------------+
//| Базовый класс для моделей на основе обученного символа и периода |
//+------------------------------------------------------------------+
class CModelSymbolPeriod
  {
protected:
   long              m_handle;           // созданный обработчик сессии модели
   string            m_symbol;           // символ обученных данных
   ENUM_TIMEFRAMES   m_period;           // период обученных данных
   datetime          m_next_bar;         // время следующего бара (мы работаем только в начале бара)
   double            m_class_delta;      // дельта для определения "цена та же самая" в регрессионных моделях

public:
   //+------------------------------------------------------------------+
   //| Конструктор                                                      |
   //+------------------------------------------------------------------+
   CModelSymbolPeriod(const string symbol,const ENUM_TIMEFRAMES period,const double class_delta=0.0001)
     {
      m_handle=INVALID_HANDLE;
      m_symbol=symbol;
      m_period=period;
      m_next_bar=0;
      m_class_delta=class_delta;
     }

   //+------------------------------------------------------------------+
   //| Деструктор                                                       |
   //+------------------------------------------------------------------+
   ~CModelSymbolPeriod(void)
     {
      Shutdown();
     }

   //+------------------------------------------------------------------+
   //| Виртуальный заглушка для Init                                   |
   //+------------------------------------------------------------------+
   virtual bool Init(const string symbol, const ENUM_TIMEFRAMES period)
     {
      return(false);
     }

   //+------------------------------------------------------------------+
   //| Проверка на инициализацию, создание модели                       |
   //+------------------------------------------------------------------+
   bool CheckInit(const string symbol, const ENUM_TIMEFRAMES period,const uchar& model[])
     {
      //--- проверка символа, периода
      if(symbol!=m_symbol || period!=m_period)
        {
         PrintFormat("Модель должна работать с %s,%s",m_symbol,EnumToString(m_period));
         return(false);
        }

      //--- создание модели из статического буфера
      m_handle=OnnxCreateFromBuffer(model,ONNX_DEFAULT);
      if(m_handle==INVALID_HANDLE)
        {
         Print("Ошибка OnnxCreateFromBuffer ",GetLastError());
         return(false);
        }

      //--- успешно
      return(true);
     }

   //+------------------------------------------------------------------+
   //| Освободить сеанс ONNX                                            |
   //+------------------------------------------------------------------+
   void Shutdown(void)
     {
      if(m_handle!=INVALID_HANDLE)
        {
         OnnxRelease(m_handle);
         m_handle=INVALID_HANDLE;
        }
     }

   //+------------------------------------------------------------------+
   //| Проверка на продолжение OnTick                                   |
   //+------------------------------------------------------------------+
   virtual bool CheckOnTick(void)
     {
      //--- проверка на новый бар
      if(TimeCurrent()<m_next_bar)
         return(false);
      //--- установить время следующего бара
      m_next_bar=TimeCurrent();
      m_next_bar-=m_next_bar%PeriodSeconds(m_period);
      m_next_bar+=PeriodSeconds(m_period);

      //--- работать на новом дневном баре
      return(true);
     }

   //+------------------------------------------------------------------+
   //| Виртуальная заглушка для PredictPrice (регрессионная модель)     |
   //+------------------------------------------------------------------+
   virtual double PredictPrice(void)
     {
      return(DBL_MAX);
     }

   //+------------------------------------------------------------------+
   //| Предсказать класс (регрессия -> классификация)                   |
   //+------------------------------------------------------------------+
   virtual int PredictClass(vector& probabilities)
     {
      double predicted_price=PredictPrice();
      if(predicted_price==DBL_MAX)
         return(-1);

      int    predicted_class=-1;
      double last_close=iClose(m_symbol,m_period,1);
      //--- классифицировать предсказанное движение цены
      double delta=last_close-predicted_price;
      if(fabs(delta)<=m_class_delta)
         predicted_class=PRICE_SAME;
      else
        {
         if(delta<0)
            predicted_class=PRICE_UP;
         else
            predicted_class=PRICE_DOWN;
        }

      //--- установить вероятность предсказания как 1.0
      probabilities.Fill(0);
      if(predicted_class<(int)probabilities.Size())
         probabilities[predicted_class]=1;
      //--- и вернуть предсказанный класс
      return(predicted_class);
     }
  };
//+------------------------------------------------------------------+

Базовый класс можно использовать как для моделей регрессии, так и для моделей классификации. Надо будет только реализовать в классе-наследнике соответствующий метод — PredictPrice или PredictClass.

В базовом классе задаётся с каким символом-периодом должна работать модель (на каких данных обучалась модель). В базовом классе проводится проверка, что эксперт, использующий модель, работает на нужном символе-периоде, а также создаётся ONNX-сессия для исполнения модели. В базовом классе обеспечивается работа только в начале нового бара.

3. Класс для первой модели

Наша первая модель называется model.eurusd.D1.10.class.onnx, то есть классификационная модель, тренированная на EURUSD D1 на сериях из 10 цен OHLC.

//+------------------------------------------------------------------+
//|                                        ModelEurusdD1_10Class.mqh |
//|                            Авторские права 2023, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#include "ModelSymbolPeriod.mqh"

#resource "Python/model.eurusd.D1.10.class.onnx" as uchar model_eurusd_D1_10_class[]

//+------------------------------------------------------------------+
//| Класс-обертка для ONNX-модели                                    |
//+------------------------------------------------------------------+
class CModelEurusdD1_10Class : public CModelSymbolPeriod
  {
private:
   int               m_sample_size;

public:
   //+------------------------------------------------------------------+
   //| Конструктор                                                      |
   //+------------------------------------------------------------------+
   CModelEurusdD1_10Class(void) : CModelSymbolPeriod("EURUSD",PERIOD_D1)
     {
      m_sample_size=10;
     }

   //+------------------------------------------------------------------+
   //| Инициализация ONNX-модели                                        |
   //+------------------------------------------------------------------+
   virtual bool Init(const string symbol, const ENUM_TIMEFRAMES period)
     {
      //--- проверка символа, периода, создание модели
      if(!CModelSymbolPeriod::CheckInit(symbol,period,model_eurusd_D1_10_class))
        {
         Print("model_eurusd_D1_10_class : ошибка инициализации");
         return(false);
        }

      //--- так как не все размеры определены во входном тензоре, их необходимо явно установить
      //--- первый индекс - размер партии, второй индекс - размер серии, третий индекс - количество серий (OHLC)
      const long input_shape[] = {1,m_sample_size,4};
      if(!OnnxSetInputShape(m_handle,0,input_shape))
        {
         Print("model_eurusd_D1_10_class : ошибка установки формы ввода ",GetLastError());
         return(false);
        }
   
      //--- так как не все размеры определены в выходном тензоре, их необходимо явно установить
      //--- первый индекс - размер партии, должен совпадать с размером партии входного тензора
      //--- второй индекс - количество классов (вверх, тот же или вниз)
      const long output_shape[] = {1,3};
      if(!OnnxSetOutputShape(m_handle,0,output_shape))
        {
         Print("model_eurusd_D1_10_class : ошибка установки формы вывода ",GetLastError());
         return(false);
        }
      //--- успешно
      return(true);
     }

   //+------------------------------------------------------------------+
   //| Предсказать класс                                                |
   //+------------------------------------------------------------------+
   virtual int PredictClass(vector& probabilities)
     {
      static matrixf input_data(m_sample_size,4);    // матрица для подготовленных входных данных
      static vectorf output_data(3);                 // вектор для получения результата
      static matrix  mm(m_sample_size,4);            // матрица горизонтальных векторов Mean
      static matrix  ms(m_sample_size,4);            // матрица горизонтальных векторов Std
      static matrix  x_norm(m_sample_size,4);        // матрица для нормализации цен
   
      //--- подготовить входные данные
      matrix rates;
      //--- запросить последние бары
      if(!rates.CopyRates(m_symbol,m_period,COPY_RATES_OHLC,1,m_sample_size))
         return(-1);
      //--- получить Mean серии
      vector m=rates.Mean(1);
      //--- получить Std серии
      vector s=rates.Std(1);
      //--- подготовить матрицы для нормализации цен
      for(int i=0; i<m_sample_size; i++)
        {
         mm.Row(m,i);
         ms.Row(s,i);
        }
      //--- вход модели должен быть набором вертикальных векторов OHLC
      x_norm=rates.Transpose();
      //--- нормализовать цены
      x_norm-=mm;
      x_norm/=ms;
   
      //--- выполнить вывод
      input_data.Assign(x_norm);
      if(!OnnxRun(m_handle,ONNX_NO_CONVERSION,input_data,output_data))
         return(-1);
      //--- оценить предсказание
      probabilities.Assign(output_data);
      return(int(output_data.ArgMax()));
     }
  };
//+------------------------------------------------------------------+

Как было сказано выше: "Три модели. Каждая отличается от другой размером входных данных, подготовкой входных данных". И мы переопределили всего два метода — Init и PredictClass. В двух других классах для двух других моделей будут переопределены эти же самые методы.

В методе Init вызывается метод базового класса CheckInit, где создаётся сессия для нашей ONNX-модели. А также явно выставляются размеры входного и выходного тензоров. Здесь больше комментариев, чем кода.

В методе PredictClass обеспечивается точно такая же подготовка входных данных, что и при обучении модели. На вход подаётся матрица нормализованных цен OHLC.

4. Проверим, как это работает

Для проверки работоспособности нашего класса был создан очень компактный эксперт.

//+------------------------------------------------------------------+
//|                  ONNX.eurusd.D1.Prediction.mq5                   |
//|                    Авторское право 2023, MetaQuotes Ltd.         |
//|                     https://www.mql5.com                         |
//+------------------------------------------------------------------+
#property copyright   "Авторское право 2023, MetaQuotes Ltd."
#property link        "https://www.mql5.com"
#property version     "1.00"

//#include "ModelEurusdD1_10Class.mqh"
#include "ModelEurusdD1_63Class.mqh"
//#include "ModelEurusdD1_30Class.mqh"
#include <Trade\Trade.mqh>

input double InpLots = 1.0;    // Объем лотов для открытия позиции

//CModelEurusdD1_10Class ExtModel;
CModelEurusdD1_63Class ExtModel;
//CModelEurusdD1_30Class ExtModel;
CTrade                 ExtTrade;

//+------------------------------------------------------------------+
//| Функция инициализации эксперта                                   |
//+------------------------------------------------------------------+
int OnInit()
  {
   if(!ExtModel.Init(_Symbol,_Period))
      return(INIT_FAILED);
//---
   return(INIT_SUCCEEDED);
  }
//+------------------------------------------------------------------+
//| Функция деинициализации эксперта                                 |
//+------------------------------------------------------------------+
void OnDeinit(const int reason)
  {
   ExtModel.Shutdown();
  }
//+------------------------------------------------------------------+
//| Функция тика эксперта                                            |
//+------------------------------------------------------------------+
void OnTick()
  {
   if(!ExtModel.CheckOnTick())
      return;

//--- предсказать следующее движение цены
   vector prob(3);
   int predicted_class=ExtModel.PredictClass(prob);
//--- проверить торговлю в соответствии с прогнозом

Comment("Ответ модели =",predicted_class);
   if(predicted_class>=0)
      if(PositionSelect(_Symbol))
         CheckForClose(predicted_class);
      else
         CheckForOpen(predicted_class);
  }
//+------------------------------------------------------------------+
//| Проверить условия открытия позиции                               |
//+------------------------------------------------------------------+
void CheckForOpen(const int predicted_class)
  {
   ENUM_ORDER_TYPE signal=WRONG_VALUE;
//--- проверить сигналы
   if(predicted_class==PRICE_DOWN)
      signal=ORDER_TYPE_SELL;    // условие на продажу
   else
     {
      if(predicted_class==PRICE_UP)
         signal=ORDER_TYPE_BUY;  // условие на покупку
     }

//--- открыть позицию, если это возможно согласно сигналу
   if(signal!=WRONG_VALUE && TerminalInfoInteger(TERMINAL_TRADE_ALLOWED))
     {
      double price=SymbolInfoDouble(_Symbol,(signal==ORDER_TYPE_SELL) ? SYMBOL_BID : SYMBOL_ASK);
      ExtTrade.PositionOpen(_Symbol,signal,InpLots,price,0,0);
     }
  }
//+------------------------------------------------------------------+
//| Проверить условия закрытия позиции                               |
//+------------------------------------------------------------------+
void CheckForClose(const int predicted_class)
  {
   bool bsignal=false;
//--- позиция уже выбрана ранее
   long type=PositionGetInteger(POSITION_TYPE);
//--- проверить сигналы
   if(type==POSITION_TYPE_BUY && predicted_class==PRICE_DOWN)
      bsignal=true;
   if(type==POSITION_TYPE_SELL && predicted_class==PRICE_UP)
      bsignal=true;

//--- закрыть позицию, если это возможно
   if(bsignal && TerminalInfoInteger(TERMINAL_TRADE_ALLOWED))
     {
      ExtTrade.PositionClose(_Symbol,3);
      //--- открыть противоположную
      CheckForOpen(predicted_class);
     }
  }
//+------------------------------------------------------------------+

Так как модель обучалась на ценовых данных до 2023 года, запустим тестирование с 1 января 2023 года.

Настройки тестирования

И получим следующий результат:

Результаты тестирования

Как видим, вполне работоспособная модель.

5. Класс для второй модели

Вторая модель называется model.eurusd.D1.30.class.onnx. Классификационная модель, тренированная на EURUSD D1 на сериях из 30 цен Close и двух простых скользящих средних с периодами усреднения 21 и 34.

//+------------------------------------------------------------------+
//|                                        ModelEurusdD1_30Class.mqh |
//|                            Авторские права 2023, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#include "ModelSymbolPeriod.mqh"

#resource "Python/model.eurusd.D1.30.class.onnx" as uchar model_eurusd_D1_30_class[]

//+------------------------------------------------------------------+
//| Класс-обертка для ONNX-модели                                    |
//+------------------------------------------------------------------+
class CModelEurusdD1_30Class : public CModelSymbolPeriod
  {
private:
   int               m_sample_size;
   int               m_fast_period;
   int               m_slow_period;
   int               m_sma_fast;
   int               m_sma_slow;

public:
   //+------------------------------------------------------------------+
   //| Конструктор                                                      |
   //+------------------------------------------------------------------+
   CModelEurusdD1_30Class(void) : CModelSymbolPeriod("EURUSD",PERIOD_D1)
     {
      m_sample_size=30;
      m_fast_period=21;
      m_slow_period=34;
      m_sma_fast=INVALID_HANDLE;
      m_sma_slow=INVALID_HANDLE;
     }

   //+------------------------------------------------------------------+
   //| Инициализация ONNX-модели                                        |
   //+------------------------------------------------------------------+
   virtual bool Init(const string symbol, const ENUM_TIMEFRAMES period)
     {
      //--- проверка символа, периода, создание модели
      if(!CModelSymbolPeriod::CheckInit(symbol,period,model_eurusd_D1_30_class))
        {
         Print("model_eurusd_D1_30_class : ошибка инициализации");
         return(false);
        }

      //--- так как не все размеры определены во входном тензоре, их необходимо явно установить
      //--- первый индекс - размер партии, второй индекс - размер серии, третий индекс - количество серий (Close, MA fast, MA slow)
      const long input_shape[] = {1,m_sample_size,3};
      if(!OnnxSetInputShape(m_handle,0,input_shape))
        {
         Print("model_eurusd_D1_30_class : ошибка установки формы ввода ",GetLastError());
         return(false);
        }
   
      //--- так как не все размеры определены в выходном тензоре, их необходимо явно установить
      //--- первый индекс - размер партии, должен совпадать с размером партии входного тензора
      //--- второй индекс - количество классов (вверх, тот же или вниз)
      const long output_shape[] = {1,3};
      if(!OnnxSetOutputShape(m_handle,0,output_shape))
        {
         Print("model_eurusd_D1_30_class : ошибка установки формы вывода ",GetLastError());
         return(false);
        }
      //--- индикаторы
      m_sma_fast=iMA(m_symbol,m_period,m_fast_period,0,MODE_SMA,PRICE_CLOSE);
      m_sma_slow=iMA(m_symbol,m_period,m_slow_period,0,MODE_SMA,PRICE_CLOSE);
      if(m_sma_fast==INVALID_HANDLE || m_sma_slow==INVALID_HANDLE)
        {
         Print("model_eurusd_D1_30_class : не удается создать индикатор");
         return(false);
        }
      //--- успешно
      return(true);
     }

   //+------------------------------------------------------------------+
   //| Предсказать класс                                                |
   //+------------------------------------------------------------------+
   virtual int PredictClass(vector& probabilities)
     {
      static matrixf input_data(m_sample_size,3);    // матрица для подготовленных входных данных
      static vectorf output_data(3);                 // вектор для получения результата
      static matrix  x_norm(m_sample_size,3);        // матрица для нормализации цен
      static vector  vtemp(m_sample_size);
      static double  ma_buffer[];
   
      //--- запросить последние бары
      if(!vtemp.CopyRates(m_symbol,m_period,COPY_RATES_CLOSE,1,m_sample_size))
         return(-1);
      //--- получить Mean серии
      double m=vtemp.Mean();
      //--- получить Std серии
      double s=vtemp.Std();
      //--- нормализовать
      vtemp-=m;
      vtemp/=s;
      x_norm.Col(vtemp,0);
      //--- быстрое среднее
      if(CopyBuffer(m_sma_fast,0,1,m_sample_size,ma_buffer)!=m_sample_size)
         return(-1);
      vtemp.Assign(ma_buffer);
      m=vtemp.Mean();
      s=vtemp.Std();
      vtemp-=m;
      vtemp/=s;
      x_norm.Col(vtemp,1);
      //--- медленное среднее
      if(CopyBuffer(m_sma_slow,0,1,m_sample_size,ma_buffer)!=m_sample_size)
         return(-1);
      vtemp.Assign(ma_buffer);
      m=vtemp.Mean();
      s=vtemp.Std();
      vtemp-=m;
      vtemp/=s;
      x_norm.Col(vtemp,2);
   
      //--- выполнить вывод
      input_data.Assign(x_norm);
      if(!OnnxRun(m_handle,ONNX_NO_CONVERSION,input_data,output_data))
         return(-1);
      //--- оценить предсказание
      probabilities.Assign(output_data);
      return(int(output_data.ArgMax()));
     }
  };
//+------------------------------------------------------------------+

Как и в предыдущем классе, в методе Init вызывается метод базового класса CheckInit, где создаётся сессия для ONNX-модели и явно выставляются размеры входного и выходного тензоров

В методе PredictClass обеспечиваются серии из 30 предыдущих Close и рассчитанные скользящие средние. Данные нормируются тем же способом, что и при обучении.

Проверим, как работает эта модель. Для этого изменим всего две строки проверочного эксперта

#include "ModelEurusdD1_30Class.mqh"
#include <Trade\Trade.mqh>

input double InpLots = 1.0;    // Lots amount to open position

CModelEurusdD1_30Class ExtModel;
CTrade                 ExtTrade;

Параметры тестирования те же самые.

Результаты тестирования второй модели

Видим, что модель работает.

6. Класс для третьей модели

Последняя модель называется model.eurusd.D1.63.class.onnx. Классификационная модель, тренированная на EURUSD D1 на сериях из 63 цен Close.

//+------------------------------------------------------------------+
//|                                             ModelEurusdD1_63.mqh |
//|                            Авторские права 2023, MetaQuotes Ltd. |
//|                                             https://www.mql5.com |
//+------------------------------------------------------------------+
#include "ModelSymbolPeriod.mqh"

#resource "Python/model.eurusd.D1.63.class.onnx" as uchar model_eurusd_D1_63_class[]

//+------------------------------------------------------------------+
//| Класс-обертка для ONNX-модели                                    |
//+------------------------------------------------------------------+
class CModelEurusdD1_63Class : public CModelSymbolPeriod
  {
private:
   int               m_sample_size;

public:
   //+------------------------------------------------------------------+
   //| Конструктор                                                      |
   //+------------------------------------------------------------------+
   CModelEurusdD1_63Class(void) : CModelSymbolPeriod("EURUSD",PERIOD_D1)
     {
      m_sample_size=63;
     }

   //+------------------------------------------------------------------+
   //| Инициализация ONNX-модели                                        |
   //+------------------------------------------------------------------+
   virtual bool Init(const string symbol, const ENUM_TIMEFRAMES period)
     {
      //--- проверка символа, периода, создание модели
      if(!CModelSymbolPeriod::CheckInit(symbol,period,model_eurusd_D1_63_class))
        {
         Print("model_eurusd_D1_63_class : ошибка инициализации");
         return(false);
        }

      //--- так как не все размеры определены во входном тензоре, их необходимо явно установить
      //--- первый индекс - размер партии, второй индекс - размер серии
      const long input_shape[] = {1,m_sample_size};
      if(!OnnxSetInputShape(m_handle,0,input_shape))
        {
         Print("model_eurusd_D1_63_class : ошибка установки формы ввода ",GetLastError());
         return(false);
        }
   
      //--- так как не все размеры определены в выходном тензоре, их необходимо явно установить
      //--- первый индекс - размер партии, должен совпадать с размером партии входного тензора
      //--- второй индекс - количество классов (вверх, тот же или вниз)
      const long output_shape[] = {1,3};
      if(!OnnxSetOutputShape(m_handle,0,output_shape))
        {
         Print("model_eurusd_D1_63_class : ошибка установки формы вывода ",GetLastError());
         return(false);
        }
      //--- успешно
      return(true);
     }

   //+------------------------------------------------------------------+
   //| Предсказать класс                                                |
   //+------------------------------------------------------------------+
   virtual int PredictClass(vector& probabilities)
     {
      static vectorf input_data(m_sample_size);  // вектор для подготовленных входных данных
      static vectorf output_data(3);             // вектор для получения результата
   
      //--- запросить последние бары
      if(!input_data.CopyRates(m_symbol,m_period,COPY_RATES_CLOSE,1,m_sample_size))
         return(-1);
      //--- получить Mean серии
      float m=input_data.Mean();
      //--- получить Std серии
      float s=input_data.Std();
      //--- нормализовать цены
      input_data-=m;
      input_data/=s;
   
      //--- выполнить вывод
      if(!OnnxRun(m_handle,ONNX_NO_CONVERSION,input_data,output_data))
         return(-1);
      //--- оценить предсказание
      probabilities.Assign(output_data);
      return(int(output_data.ArgMax()));
     }
  };
//+------------------------------------------------------------------+

Это — самая простая модель из трёх. Поэтому код метода PredictClass получился таким компактным.

Опять изменим две строки в эксперте

#include "ModelEurusdD1_63Class.mqh"
#include <Trade\Trade.mqh>

input double InpLots = 1.0;    // Lots amount to open position

CModelEurusdD1_63Class ExtModel;
CTrade                 ExtTrade;

И запустим тестирование с теми же самыми настройками.

Результаты тестирования третьей модели

Модель работает

7. Собираем все модели в одном эксперте. Жёсткое голосование

Все три модели показали свою работоспособность. Теперь попробуем объединить их усилия. Устроим голосование моделей.

Предварительные объявления и определения

#include "ModelEurusdD1_10Class.mqh"
#include "ModelEurusdD1_30Class.mqh"
#include "ModelEurusdD1_63Class.mqh"
#include <Trade\Trade.mqh>

input double  InpLots  = 1.0;    // Количество лотов для открытия позиции

CModelSymbolPeriod *ExtModels[3];
CTrade              ExtTrade;

Функция OnInit

int OnInit()
  {
   ExtModels[0]=new CModelEurusdD1_10Class;
   ExtModels[1]=new CModelEurusdD1_30Class;
   ExtModels[2]=new CModelEurusdD1_63Class;

   for(long i=0; i<ExtModels.Size(); i++)
      if(!ExtModels[i].Init(_Symbol,_Period))
         return(INIT_FAILED);
//---
   return(INIT_SUCCEEDED);
  }

Функция OnTick

void OnTick()
  {
   for(long i=0; i<ArraySize(ExtModels); i++)
      if(!ExtModels[i].CheckOnTick())
         return;

//--- предсказать следующее движение цены
   int    returned[3]={0,0,0};
   vector soft=vector::Zeros(3);
//--- собрать возвращенные классы
   for(long i=0; i<ArraySize(ExtModels); i++)
     {
      vector prob(3);
      int    pred=ExtModels[i].PredictClass(prob);
      if(pred>=0)
        {
         returned[pred]++;
         soft+=prob;
        }
     }
//--- получить одно предсказание для всех моделей
   int predicted_class=-1;
//--- мягкое или жесткое голосование
   if(InpVotes==Soft)
      predicted_class=(int)soft.ArgMax();
   else
     {
      //--- подсчитать голоса за предсказания
      for(int n=0; n<3; n++)
        {
         if(returned[n]>=InpVotes)
           {
            predicted_class=n;
            break;
           }
        }
     }

//--- проверить торговлю в соответствии с предсказанием
   if(predicted_class>=0)
      if(PositionSelect(_Symbol))
         CheckForClose(predicted_class);
      else
         CheckForOpen(predicted_class);
  }

Большинство голосов считается по формуле <общее количество голосов>/2 + 1. Для общего числа голосов 3 большинством являются 2 голоса. Это - так называемое "жёсткое голосование"

Результат тестирования всё с теми же самыми настройками.

Вспомним работу всех трёх моделей по отдельности, а именно количество прибыльных и убыточных трейдов.

Модельколичество прибыльных и убыточных трейдов.

Первая модель

11 : 3

Вторая модель

6 : 1

Третья модель

16 : 10

Похоже, при помощи жёсткого голосования мы улучшили результат — 16 : 4. Но, конечно же, необходимо смотреть полные отчёты и графики тестирования.

8. Мягкое голосование

Мягкое голосование отличается от жёсткого тем, что учитывается не количество голосов, а считается сумма вероятностей всех трёх классов от всех трёх моделей. И уже по самой высокой вероятности выбирается класс.

Для обеспечения мягкого голосования необходимо внести некоторые изменения.

В базовом классе:

 virtual int PredictClass(vector& probabilities)
     {
      double predicted_price=PredictPrice();
      if(predicted_price==DBL_MAX)
         return(-1);

      int    predicted_class=-1;
      double last_close=iClose(m_symbol,m_period,1);
      //--- классифицировать предсказанное движение цены
      double delta=last_close-predicted_price;
      if(fabs(delta)<=m_class_delta)
         predicted_class=PRICE_SAME;
      else
        {
         if(delta<0)
            predicted_class=PRICE_UP;
         else
            predicted_class=PRICE_DOWN;
        }

      //--- установить вероятность предсказания как 1.0
      probabilities.Fill(0);
      if(predicted_class<(int)probabilities.Size())
         probabilities[predicted_class]=1;
      //--- и вернуть предсказанный класс
      return(predicted_class);
     }
  };

В классах наследниках:

   //+------------------------------------------------------------------+
   //| Predict class                                                    |
   //+------------------------------------------------------------------+
   virtual int PredictClass(vector& probabilities)
     {
...
      //--- evaluate prediction
      probabilities.Assign(output_data);
      return(int(output_data.ArgMax()));
     }

В эксперте:

#include "ModelEurusdD1_10Class.mqh"
#include "ModelEurusdD1_30Class.mqh"
#include "ModelEurusdD1_63Class.mqh"
#include <Trade\Trade.mqh>

enum EnVotes
  {
   Two=2,    // Два голоса
   Three=3,  // Три голоса
   Soft=4    // Мягкое голосование
  };

input double  InpLots  = 1.0;    // Количество лотов для открытия позиции
input EnVotes InpVotes = Two;    // Голоса для принятия торгового решения

CModelSymbolPeriod *ExtModels[3];
CTrade              ExtTrade;
void OnTick()
  {
   for(long i=0; i<ArraySize(ExtModels); i++)
      if(!ExtModels[i].CheckOnTick())
         return;

//--- предсказать следующее движение цены
   int    returned[3]={0,0,0};
   vector soft=vector::Zeros(3);
//--- собрать возвращенные классы
   for(long i=0; i<ArraySize(ExtModels); i++)
     {
      vector prob(3);
      int    pred=ExtModels[i].PredictClass(prob);
      if(pred>=0)
        {
         returned[pred]++;
         soft+=prob;
        }
     }
//--- получить одно предсказание для всех моделей
   int predicted_class=-1;
//--- мягкое или жесткое голосование
   if(InpVotes==Soft)
      predicted_class=(int)soft.ArgMax();
   else
     {
      //--- подсчитать голоса за предсказания
      for(int n=0; n<3; n++)
        {
         if(returned[n]>=InpVotes)
           {
            predicted_class=n;
            break;
           }
        }
     }

//--- проверить торговлю в соответствии с предсказанием
   if(predicted_class>=0)
      if(PositionSelect(_Symbol))
         CheckForClose(predicted_class);
      else
         CheckForOpen(predicted_class);
  }

Тестируем всё с теми же настройками. Во входных параметрах выбираем Soft.

Получим результат.

Прибыльных трейдов — 15, Убыточных трейдов — 3. В денежном выражении жёсткое голосование тоже оказалось лучше, чем мягкое.

9. Единогласное голосоване

Интересно посмотреть на результат единогласного голосования, то есть при количестве голосов 3.

Очень консервативная торговля. При этом единственная убыточная сделка была закрыта при окончании тестирования (возможно, она и не убыточная на самом деле).

Важно: обращаем ваше внимание, что использованные в статье модели представлены только в целях демонстрации работы с ONNX-моделями средствами языка MQL5. Советник не предназначен для торговли на реальных счетах.

Заключение

Мы показали, как объектно-ориентированное программирование позволяет упростить написание программ. Все сложности моделей (а модели могут быть гораздо более сложными, чем представленные нами в качестве примера) прячутся в своих классах. А остальная "сложность" уместилась в 45 строках функции OnTick.

Прикрепленные файлы

Last updated