From 857977dd4a68c887b876c1a3088a84521b8652d5 Mon Sep 17 00:00:00 2001 From: aymeric75 Date: Wed, 8 Jan 2025 00:50:24 +0100 Subject: [PATCH] Added a scrapper for Yahoo Finance Data in the YahooFinanceProcessor class (#1305) * last * last * added yahoo scrapper * reformatted with isort and black * last * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * last * last * last * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- examples/Stock_NeurIPS2018_SB3.ipynb | 7824 ++++------------- .../data_processors/processor_yahoofinance.py | 123 + 2 files changed, 1857 insertions(+), 6090 deletions(-) diff --git a/examples/Stock_NeurIPS2018_SB3.ipynb b/examples/Stock_NeurIPS2018_SB3.ipynb index 842260512..d595d5b7e 100644 --- a/examples/Stock_NeurIPS2018_SB3.ipynb +++ b/examples/Stock_NeurIPS2018_SB3.ipynb @@ -1,6099 +1,1743 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "yfv52r2G33jY" - }, - "source": [ - "\"Open" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gXaoZs2lh1hi" - }, - "source": [ - "# Deep Reinforcement Learning for Stock Trading from Scratch: Multiple Stock Trading\n", - "\n", - "* **Pytorch Version** \n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "lGunVt8oLCVS" - }, - "source": [ - "# Content" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HOzAKQ-SLGX6" - }, - "source": [ - "* [1. Task Description](#0)\n", - "* [2. Install Python packages](#1)\n", - " * [2.1. Install Packages](#1.1) \n", - " * [2.2. A List of Python Packages](#1.2)\n", - " * [2.3. Import Packages](#1.3)\n", - " * [2.4. Create Folders](#1.4)\n", - "* [3. Download and Preprocess Data](#2)\n", - "* [4. Preprocess Data](#3) \n", - " * [4.1. Technical Indicators](#3.1)\n", - " * [4.2. 
Perform Feature Engineering](#3.2)\n", - "* [5. Build Market Environment in OpenAI Gym-style](#4) \n", - " * [5.1. Data Split](#4.1) \n", - " * [5.3. Environment for Training](#4.2) \n", - "* [6. Train DRL Agents](#5)\n", - "* [7. Backtesting Performance](#6) \n", - " * [7.1. BackTestStats](#6.1)\n", - " * [7.2. BackTestPlot](#6.2) \n", - " " - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "sApkDlD9LIZv" - }, - "source": [ - "\n", - "# Part 1. Task Discription" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HjLD2TZSLKZ-" - }, - "source": [ - "We train a DRL agent for stock trading. This task is modeled as a Markov Decision Process (MDP), and the objective function is maximizing (expected) cumulative return.\n", - "\n", - "We specify the state-action-reward as follows:\n", - "\n", - "* **State s**: The state space represents an agent's perception of the market environment. Just like a human trader analyzing various information, here our agent passively observes many features and learns by interacting with the market environment (usually by replaying historical data).\n", - "\n", - "* **Action a**: The action space includes allowed actions that an agent can take at each state. For example, a ∈ {−1, 0, 1}, where −1, 0, 1 represent\n", - "selling, holding, and buying. When an action operates multiple shares, a ∈{−k, ..., −1, 0, 1, ..., k}, e.g.. \"Buy\n", - "10 shares of AAPL\" or \"Sell 10 shares of AAPL\" are 10 or −10, respectively\n", - "\n", - "* **Reward function r(s, a, s′)**: Reward is an incentive for an agent to learn a better policy. For example, it can be the change of the portfolio value when taking a at state s and arriving at new state s', i.e., r(s, a, s′) = v′ − v, where v′ and v represent the portfolio values at state s′ and s, respectively\n", - "\n", - "\n", - "**Market environment**: 30 consituent stocks of Dow Jones Industrial Average (DJIA) index. 
Accessed at the starting date of the testing period.\n", - "\n", - "\n", - "The data for this case study is obtained from Yahoo Finance API. The data contains Open-High-Low-Close price and volume.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Ffsre789LY08" - }, - "source": [ - "\n", - "# Part 2. Install Python Packages" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Uy5_PTmOh1hj" - }, - "source": [ - "\n", - "## 2.1. Install packages\n" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "mPT0ipYE28wL", - "outputId": "6dad74d2-c37f-4b86-c584-2436d2ef5bae" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", - "Collecting swig\n", - " Using cached swig-4.1.1-py2.py3-none-manylinux_2_5_x86_64.manylinux1_x86_64.whl (1.8 MB)\n", - "Installing collected packages: swig\n", - "Successfully installed swig-4.1.1\n", - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0mLooking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", - "Requirement already satisfied: wrds in /usr/local/lib/python3.9/site-packages (3.1.6)\n", - "Requirement already satisfied: psycopg2-binary in /usr/local/lib/python3.9/site-packages (from wrds) (2.9.6)\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.9/site-packages (from wrds) (1.24.2)\n", - "Requirement already satisfied: scipy in /usr/local/lib/python3.9/site-packages (from wrds) (1.10.1)\n", - "Requirement already satisfied: sqlalchemy<2 in /usr/local/lib/python3.9/site-packages (from wrds) (1.4.47)\n", - "Requirement already satisfied: pandas in /usr/local/lib/python3.9/site-packages (from wrds) (2.0.0)\n", - "Requirement already satisfied: greenlet!=0.4.17 in /usr/local/lib/python3.9/site-packages (from sqlalchemy<2->wrds) (2.0.2)\n", - "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.9/site-packages (from pandas->wrds) (2.8.2)\n", - "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.9/site-packages (from pandas->wrds) (2023.3)\n", - "Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.9/site-packages (from pandas->wrds) (2023.3)\n", - "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.9/site-packages (from python-dateutil>=2.8.2->pandas->wrds) (1.16.0)\n", - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0mLooking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", - "Collecting pyportfolioopt\n", - " Using cached pyportfolioopt-1.5.4-py3-none-any.whl (61 kB)\n", - "Requirement already satisfied: pandas>=0.19 in /usr/local/lib/python3.9/site-packages (from pyportfolioopt) (2.0.0)\n", - "Requirement already satisfied: scipy<2.0,>=1.3 in /usr/local/lib/python3.9/site-packages (from pyportfolioopt) (1.10.1)\n", - "Collecting cvxpy<2.0.0,>=1.1.10\n", - " Downloading cvxpy-1.3.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.0 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.0/4.0 MB\u001b[0m \u001b[31m29.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: numpy<2.0.0,>=1.22.4 in /usr/local/lib/python3.9/site-packages (from pyportfolioopt) (1.24.2)\n", - "Requirement already satisfied: setuptools>65.5.1 in /usr/local/lib/python3.9/site-packages (from cvxpy<2.0.0,>=1.1.10->pyportfolioopt) (65.6.3)\n", - "Collecting ecos>=2\n", - " Downloading ecos-2.0.12-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (220 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m220.1/220.1 kB\u001b[0m \u001b[31m18.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting scs>=1.1.6\n", - " Downloading scs-3.2.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (10.7 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m10.7/10.7 MB\u001b[0m \u001b[31m41.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hCollecting osqp>=0.4.1\n", - " Downloading osqp-0.6.2.post8-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (298 kB)\n", - "\u001b[2K 
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m298.2/298.2 kB\u001b[0m \u001b[31m22.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.9/site-packages (from pandas>=0.19->pyportfolioopt) (2.8.2)\n", - "Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.9/site-packages (from pandas>=0.19->pyportfolioopt) (2023.3)\n", - "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.9/site-packages (from pandas>=0.19->pyportfolioopt) (2023.3)\n", - "Collecting qdldl\n", - " Downloading qdldl-0.1.7-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.0 MB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.0/1.0 MB\u001b[0m \u001b[31m45.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hRequirement already satisfied: six>=1.5 in /usr/local/lib/python3.9/site-packages (from python-dateutil>=2.8.2->pandas>=0.19->pyportfolioopt) (1.16.0)\n", - "Installing collected packages: scs, qdldl, ecos, osqp, cvxpy, pyportfolioopt\n", - "Successfully installed cvxpy-1.3.1 ecos-2.0.12 osqp-0.6.2.post8 pyportfolioopt-1.5.4 qdldl-0.1.7 scs-3.2.3\n", - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0m\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. 
It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0m✨🍰✨ Everything looks OK!\n", - "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n", - "Collecting git+https://github.com/AI4Finance-Foundation/FinRL.git\n", - " Cloning https://github.com/AI4Finance-Foundation/FinRL.git to /tmp/pip-req-build-cseb5t6p\n", - " Running command git clone --filter=blob:none --quiet https://github.com/AI4Finance-Foundation/FinRL.git /tmp/pip-req-build-cseb5t6p\n", - " Resolved https://github.com/AI4Finance-Foundation/FinRL.git to commit 2ed2207c926608d559789624df26ef26682d2e14\n", - " Installing build dependencies ... \u001b[?25l\u001b[?25hdone\n", - " Getting requirements to build wheel ... \u001b[?25l\u001b[?25hdone\n", - " Preparing metadata (pyproject.toml) ... \u001b[?25l\u001b[?25hdone\n", - "Collecting elegantrl@ git+https://github.com/AI4Finance-Foundation/ElegantRL.git#egg=elegantrl\n", - " Cloning https://github.com/AI4Finance-Foundation/ElegantRL.git to /tmp/pip-install-_lm24vb_/elegantrl_c12ae7d2917240d88d343d99b571685b\n", - " Running command git clone --filter=blob:none --quiet https://github.com/AI4Finance-Foundation/ElegantRL.git /tmp/pip-install-_lm24vb_/elegantrl_c12ae7d2917240d88d343d99b571685b\n", - " Resolved https://github.com/AI4Finance-Foundation/ElegantRL.git to commit c22e8402e4778f15475bade34d1b9e37b557d97d\n", - " Preparing metadata (setup.py) ... 
\u001b[?25l\u001b[?25hdone\n", - "Collecting pyfolio@ git+https://github.com/quantopian/pyfolio.git#egg=pyfolio-0.9.2\n", - " Cloning https://github.com/quantopian/pyfolio.git to /tmp/pip-install-_lm24vb_/pyfolio_8da514f26bb94b29bbca93fedbe6536b\n", - " Running command git clone --filter=blob:none --quiet https://github.com/quantopian/pyfolio.git /tmp/pip-install-_lm24vb_/pyfolio_8da514f26bb94b29bbca93fedbe6536b\n", - " Resolved https://github.com/quantopian/pyfolio.git to commit 4b901f6d73aa02ceb6d04b7d83502e5c6f2e81aa\n", - " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", - "Requirement already satisfied: lz4 in /usr/local/lib/python3.9/site-packages (from finrl==0.3.5) (4.3.2)\n", - "Requirement already satisfied: stable-baselines3<2.0.0,>=1.6.2 in /usr/local/lib/python3.9/site-packages (from finrl==0.3.5) (1.8.0)\n", - "Requirement already satisfied: tensorboardX in /usr/local/lib/python3.9/site-packages (from finrl==0.3.5) (2.6)\n", - "Requirement already satisfied: yfinance in /usr/local/lib/python3.9/site-packages (from finrl==0.3.5) (0.2.17)\n", - "Requirement already satisfied: exchange_calendars==3.6.3 in /usr/local/lib/python3.9/site-packages (from finrl==0.3.5) (3.6.3)\n", - "Requirement already satisfied: gputil in /usr/local/lib/python3.9/site-packages (from finrl==0.3.5) (1.4.0)\n", - "Requirement already satisfied: alpaca_trade_api>=2.1.0 in /usr/local/lib/python3.9/site-packages (from finrl==0.3.5) (3.0.0)\n", - "Requirement already satisfied: wrds>=3.1.6 in /usr/local/lib/python3.9/site-packages (from finrl==0.3.5) (3.1.6)\n", - "Requirement already satisfied: importlib-metadata==4.13.0 in /usr/local/lib/python3.9/site-packages (from finrl==0.3.5) (4.13.0)\n", - "Requirement already satisfied: stockstats>=0.4.0 in /usr/local/lib/python3.9/site-packages (from finrl==0.3.5) (0.5.2)\n", - "Requirement already satisfied: ccxt>=1.66.32 in /usr/local/lib/python3.9/site-packages (from finrl==0.3.5) (3.0.63)\n", - "Requirement already 
satisfied: numpy>=1.17.3 in /usr/local/lib/python3.9/site-packages (from finrl==0.3.5) (1.24.2)\n", - "Requirement already satisfied: scikit-learn>=0.21.0 in /usr/local/lib/python3.9/site-packages (from finrl==0.3.5) (1.2.2)\n", - "Requirement already satisfied: pandas>=1.1.5 in /usr/local/lib/python3.9/site-packages (from finrl==0.3.5) (2.0.0)\n", - "Requirement already satisfied: ray[default,tune]>=2.0.0 in /usr/local/lib/python3.9/site-packages (from finrl==0.3.5) (2.3.1)\n", - "Requirement already satisfied: gym>=0.17 in /usr/local/lib/python3.9/site-packages (from finrl==0.3.5) (0.21.0)\n", - "Requirement already satisfied: matplotlib in /usr/local/lib/python3.9/site-packages (from finrl==0.3.5) (3.7.1)\n", - "Requirement already satisfied: jqdatasdk in /usr/local/lib/python3.9/site-packages (from finrl==0.3.5) (1.8.11)\n", - "Requirement already satisfied: toolz in /usr/local/lib/python3.9/site-packages (from exchange_calendars==3.6.3->finrl==0.3.5) (0.12.0)\n", - "Requirement already satisfied: pyluach in /usr/local/lib/python3.9/site-packages (from exchange_calendars==3.6.3->finrl==0.3.5) (2.2.0)\n", - "Requirement already satisfied: korean-lunar-calendar in /usr/local/lib/python3.9/site-packages (from exchange_calendars==3.6.3->finrl==0.3.5) (0.3.1)\n", - "Requirement already satisfied: python-dateutil in /usr/local/lib/python3.9/site-packages (from exchange_calendars==3.6.3->finrl==0.3.5) (2.8.2)\n", - "Requirement already satisfied: pytz in /usr/local/lib/python3.9/site-packages (from exchange_calendars==3.6.3->finrl==0.3.5) (2023.3)\n", - "Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.9/site-packages (from importlib-metadata==4.13.0->finrl==0.3.5) (3.15.0)\n", - "Requirement already satisfied: deprecation==2.1.0 in /usr/local/lib/python3.9/site-packages (from alpaca_trade_api>=2.1.0->finrl==0.3.5) (2.1.0)\n", - "Requirement already satisfied: msgpack==1.0.3 in /usr/local/lib/python3.9/site-packages (from 
alpaca_trade_api>=2.1.0->finrl==0.3.5) (1.0.3)\n", - "Requirement already satisfied: requests<3,>2 in /usr/local/lib/python3.9/site-packages (from alpaca_trade_api>=2.1.0->finrl==0.3.5) (2.28.2)\n", - "Requirement already satisfied: websocket-client<2,>=0.56.0 in /usr/local/lib/python3.9/site-packages (from alpaca_trade_api>=2.1.0->finrl==0.3.5) (1.5.1)\n", - "Requirement already satisfied: websockets<11,>=9.0 in /usr/local/lib/python3.9/site-packages (from alpaca_trade_api>=2.1.0->finrl==0.3.5) (10.4)\n", - "Requirement already satisfied: PyYAML==6.0 in /usr/local/lib/python3.9/site-packages (from alpaca_trade_api>=2.1.0->finrl==0.3.5) (6.0)\n", - "Requirement already satisfied: urllib3<2,>1.24 in /usr/local/lib/python3.9/site-packages (from alpaca_trade_api>=2.1.0->finrl==0.3.5) (1.26.15)\n", - "Requirement already satisfied: aiohttp==3.8.1 in /usr/local/lib/python3.9/site-packages (from alpaca_trade_api>=2.1.0->finrl==0.3.5) (3.8.1)\n", - "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.9/site-packages (from aiohttp==3.8.1->alpaca_trade_api>=2.1.0->finrl==0.3.5) (22.2.0)\n", - "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.9/site-packages (from aiohttp==3.8.1->alpaca_trade_api>=2.1.0->finrl==0.3.5) (4.0.2)\n", - "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.9/site-packages (from aiohttp==3.8.1->alpaca_trade_api>=2.1.0->finrl==0.3.5) (6.0.4)\n", - "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.9/site-packages (from aiohttp==3.8.1->alpaca_trade_api>=2.1.0->finrl==0.3.5) (1.8.2)\n", - "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.9/site-packages (from aiohttp==3.8.1->alpaca_trade_api>=2.1.0->finrl==0.3.5) (1.3.1)\n", - "Requirement already satisfied: charset-normalizer<3.0,>=2.0 in /usr/local/lib/python3.9/site-packages (from aiohttp==3.8.1->alpaca_trade_api>=2.1.0->finrl==0.3.5) (2.1.1)\n", - "Requirement already 
satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.9/site-packages (from aiohttp==3.8.1->alpaca_trade_api>=2.1.0->finrl==0.3.5) (1.3.3)\n", - "Requirement already satisfied: packaging in /usr/local/lib/python3.9/site-packages (from deprecation==2.1.0->alpaca_trade_api>=2.1.0->finrl==0.3.5) (23.1)\n", - "Requirement already satisfied: setuptools>=60.9.0 in /usr/local/lib/python3.9/site-packages (from ccxt>=1.66.32->finrl==0.3.5) (65.6.3)\n", - "Requirement already satisfied: certifi>=2018.1.18 in /usr/local/lib/python3.9/site-packages (from ccxt>=1.66.32->finrl==0.3.5) (2022.12.7)\n", - "Requirement already satisfied: cryptography>=2.6.1 in /usr/local/lib/python3.9/site-packages (from ccxt>=1.66.32->finrl==0.3.5) (39.0.2)\n", - "Requirement already satisfied: aiodns>=1.1.1 in /usr/local/lib/python3.9/site-packages (from ccxt>=1.66.32->finrl==0.3.5) (3.0.0)\n", - "Requirement already satisfied: cloudpickle>=1.2.0 in /usr/local/lib/python3.9/site-packages (from gym>=0.17->finrl==0.3.5) (2.2.1)\n", - "Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.9/site-packages (from pandas>=1.1.5->finrl==0.3.5) (2023.3)\n", - "Requirement already satisfied: protobuf!=3.19.5,>=3.15.3 in /usr/local/lib/python3.9/site-packages (from ray[default,tune]>=2.0.0->finrl==0.3.5) (3.20.3)\n", - "Requirement already satisfied: grpcio>=1.32.0 in /usr/local/lib/python3.9/site-packages (from ray[default,tune]>=2.0.0->finrl==0.3.5) (1.53.0)\n", - "Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.9/site-packages (from ray[default,tune]>=2.0.0->finrl==0.3.5) (8.1.3)\n", - "Requirement already satisfied: filelock in /usr/local/lib/python3.9/site-packages (from ray[default,tune]>=2.0.0->finrl==0.3.5) (3.11.0)\n", - "Requirement already satisfied: jsonschema in /usr/local/lib/python3.9/site-packages (from ray[default,tune]>=2.0.0->finrl==0.3.5) (4.17.3)\n", - "Requirement already satisfied: virtualenv>=20.0.24 in /usr/local/lib/python3.9/site-packages 
(from ray[default,tune]>=2.0.0->finrl==0.3.5) (20.21.0)\n", - "Requirement already satisfied: tabulate in /usr/local/lib/python3.9/site-packages (from ray[default,tune]>=2.0.0->finrl==0.3.5) (0.9.0)\n", - "Requirement already satisfied: colorful in /usr/local/lib/python3.9/site-packages (from ray[default,tune]>=2.0.0->finrl==0.3.5) (0.5.5)\n", - "Requirement already satisfied: opencensus in /usr/local/lib/python3.9/site-packages (from ray[default,tune]>=2.0.0->finrl==0.3.5) (0.11.2)\n", - "Requirement already satisfied: gpustat>=1.0.0 in /usr/local/lib/python3.9/site-packages (from ray[default,tune]>=2.0.0->finrl==0.3.5) (1.1)\n", - "Requirement already satisfied: aiohttp-cors in /usr/local/lib/python3.9/site-packages (from ray[default,tune]>=2.0.0->finrl==0.3.5) (0.7.0)\n", - "Requirement already satisfied: py-spy>=0.2.0 in /usr/local/lib/python3.9/site-packages (from ray[default,tune]>=2.0.0->finrl==0.3.5) (0.3.14)\n", - "Requirement already satisfied: pydantic in /usr/local/lib/python3.9/site-packages (from ray[default,tune]>=2.0.0->finrl==0.3.5) (1.10.7)\n", - "Requirement already satisfied: prometheus-client>=0.7.1 in /usr/local/lib/python3.9/site-packages (from ray[default,tune]>=2.0.0->finrl==0.3.5) (0.16.0)\n", - "Requirement already satisfied: smart-open in /usr/local/lib/python3.9/site-packages (from ray[default,tune]>=2.0.0->finrl==0.3.5) (6.3.0)\n", - "Requirement already satisfied: joblib>=1.1.1 in /usr/local/lib/python3.9/site-packages (from scikit-learn>=0.21.0->finrl==0.3.5) (1.2.0)\n", - "Requirement already satisfied: scipy>=1.3.2 in /usr/local/lib/python3.9/site-packages (from scikit-learn>=0.21.0->finrl==0.3.5) (1.10.1)\n", - "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.9/site-packages (from scikit-learn>=0.21.0->finrl==0.3.5) (3.1.0)\n", - "Requirement already satisfied: torch>=1.11 in /usr/local/lib/python3.9/site-packages (from stable-baselines3<2.0.0,>=1.6.2->finrl==0.3.5) (2.0.0)\n", - "Requirement already 
satisfied: psycopg2-binary in /usr/local/lib/python3.9/site-packages (from wrds>=3.1.6->finrl==0.3.5) (2.9.6)\n", - "Requirement already satisfied: sqlalchemy<2 in /usr/local/lib/python3.9/site-packages (from wrds>=3.1.6->finrl==0.3.5) (1.4.47)\n", - "Requirement already satisfied: pymysql>=0.7.6 in /usr/local/lib/python3.9/site-packages (from jqdatasdk->finrl==0.3.5) (1.0.3)\n", - "Requirement already satisfied: thriftpy2>=0.3.9 in /usr/local/lib/python3.9/site-packages (from jqdatasdk->finrl==0.3.5) (0.4.16)\n", - "Requirement already satisfied: six in /usr/local/lib/python3.9/site-packages (from jqdatasdk->finrl==0.3.5) (1.16.0)\n", - "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.9/site-packages (from matplotlib->finrl==0.3.5) (3.0.9)\n", - "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.9/site-packages (from matplotlib->finrl==0.3.5) (1.4.4)\n", - "Requirement already satisfied: importlib-resources>=3.2.0 in /usr/local/lib/python3.9/site-packages (from matplotlib->finrl==0.3.5) (5.12.0)\n", - "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.9/site-packages (from matplotlib->finrl==0.3.5) (4.39.3)\n", - "Requirement already satisfied: pillow>=6.2.0 in /usr/local/lib/python3.9/site-packages (from matplotlib->finrl==0.3.5) (9.5.0)\n", - "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.9/site-packages (from matplotlib->finrl==0.3.5) (1.0.7)\n", - "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.9/site-packages (from matplotlib->finrl==0.3.5) (0.11.0)\n", - "Requirement already satisfied: ipython>=3.2.3 in /usr/local/lib/python3.9/site-packages (from pyfolio@ git+https://github.com/quantopian/pyfolio.git#egg=pyfolio-0.9.2->finrl==0.3.5) (8.12.0)\n", - "Requirement already satisfied: seaborn>=0.7.1 in /usr/local/lib/python3.9/site-packages (from pyfolio@ git+https://github.com/quantopian/pyfolio.git#egg=pyfolio-0.9.2->finrl==0.3.5) 
(0.12.2)\n", - "Requirement already satisfied: empyrical>=0.5.0 in /usr/local/lib/python3.9/site-packages (from pyfolio@ git+https://github.com/quantopian/pyfolio.git#egg=pyfolio-0.9.2->finrl==0.3.5) (0.5.5)\n", - "Requirement already satisfied: multitasking>=0.0.7 in /usr/local/lib/python3.9/site-packages (from yfinance->finrl==0.3.5) (0.0.11)\n", - "Requirement already satisfied: beautifulsoup4>=4.11.1 in /usr/local/lib/python3.9/site-packages (from yfinance->finrl==0.3.5) (4.12.2)\n", - "Requirement already satisfied: appdirs>=1.4.4 in /usr/local/lib/python3.9/site-packages (from yfinance->finrl==0.3.5) (1.4.4)\n", - "Requirement already satisfied: frozendict>=2.3.4 in /usr/local/lib/python3.9/site-packages (from yfinance->finrl==0.3.5) (2.3.7)\n", - "Requirement already satisfied: html5lib>=1.1 in /usr/local/lib/python3.9/site-packages (from yfinance->finrl==0.3.5) (1.1)\n", - "Requirement already satisfied: lxml>=4.9.1 in /usr/local/lib/python3.9/site-packages (from yfinance->finrl==0.3.5) (4.9.2)\n", - "Requirement already satisfied: pycares>=4.0.0 in /usr/local/lib/python3.9/site-packages (from aiodns>=1.1.1->ccxt>=1.66.32->finrl==0.3.5) (4.3.0)\n", - "Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.9/site-packages (from beautifulsoup4>=4.11.1->yfinance->finrl==0.3.5) (2.4)\n", - "Requirement already satisfied: cffi>=1.12 in /usr/local/lib/python3.9/site-packages (from cryptography>=2.6.1->ccxt>=1.66.32->finrl==0.3.5) (1.15.1)\n", - "Requirement already satisfied: pandas-datareader>=0.2 in /usr/local/lib/python3.9/site-packages (from empyrical>=0.5.0->pyfolio@ git+https://github.com/quantopian/pyfolio.git#egg=pyfolio-0.9.2->finrl==0.3.5) (0.10.0)\n", - "Requirement already satisfied: nvidia-ml-py>=11.450.129 in /usr/local/lib/python3.9/site-packages (from gpustat>=1.0.0->ray[default,tune]>=2.0.0->finrl==0.3.5) (11.525.112)\n", - "Requirement already satisfied: blessed>=1.17.1 in /usr/local/lib/python3.9/site-packages (from 
gpustat>=1.0.0->ray[default,tune]>=2.0.0->finrl==0.3.5) (1.20.0)\n", - "Requirement already satisfied: psutil>=5.6.0 in /usr/local/lib/python3.9/site-packages (from gpustat>=1.0.0->ray[default,tune]>=2.0.0->finrl==0.3.5) (5.9.4)\n", - "Requirement already satisfied: webencodings in /usr/local/lib/python3.9/site-packages (from html5lib>=1.1->yfinance->finrl==0.3.5) (0.5.1)\n", - "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.9/site-packages (from ipython>=3.2.3->pyfolio@ git+https://github.com/quantopian/pyfolio.git#egg=pyfolio-0.9.2->finrl==0.3.5) (4.5.0)\n", - "Requirement already satisfied: stack-data in /usr/local/lib/python3.9/site-packages (from ipython>=3.2.3->pyfolio@ git+https://github.com/quantopian/pyfolio.git#egg=pyfolio-0.9.2->finrl==0.3.5) (0.6.2)\n", - "Requirement already satisfied: jedi>=0.16 in /usr/local/lib/python3.9/site-packages (from ipython>=3.2.3->pyfolio@ git+https://github.com/quantopian/pyfolio.git#egg=pyfolio-0.9.2->finrl==0.3.5) (0.18.2)\n", - "Requirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /usr/local/lib/python3.9/site-packages (from ipython>=3.2.3->pyfolio@ git+https://github.com/quantopian/pyfolio.git#egg=pyfolio-0.9.2->finrl==0.3.5) (3.0.38)\n", - "Requirement already satisfied: backcall in /usr/local/lib/python3.9/site-packages (from ipython>=3.2.3->pyfolio@ git+https://github.com/quantopian/pyfolio.git#egg=pyfolio-0.9.2->finrl==0.3.5) (0.2.0)\n", - "Requirement already satisfied: decorator in /usr/local/lib/python3.9/site-packages (from ipython>=3.2.3->pyfolio@ git+https://github.com/quantopian/pyfolio.git#egg=pyfolio-0.9.2->finrl==0.3.5) (5.1.1)\n", - "Requirement already satisfied: pickleshare in /usr/local/lib/python3.9/site-packages (from ipython>=3.2.3->pyfolio@ git+https://github.com/quantopian/pyfolio.git#egg=pyfolio-0.9.2->finrl==0.3.5) (0.7.5)\n", - "Requirement already satisfied: traitlets>=5 in /usr/local/lib/python3.9/site-packages (from ipython>=3.2.3->pyfolio@ 
git+https://github.com/quantopian/pyfolio.git#egg=pyfolio-0.9.2->finrl==0.3.5) (5.9.0)\n", - "Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.9/site-packages (from ipython>=3.2.3->pyfolio@ git+https://github.com/quantopian/pyfolio.git#egg=pyfolio-0.9.2->finrl==0.3.5) (4.8.0)\n", - "Requirement already satisfied: matplotlib-inline in /usr/local/lib/python3.9/site-packages (from ipython>=3.2.3->pyfolio@ git+https://github.com/quantopian/pyfolio.git#egg=pyfolio-0.9.2->finrl==0.3.5) (0.1.6)\n", - "Requirement already satisfied: pygments>=2.4.0 in /usr/local/lib/python3.9/site-packages (from ipython>=3.2.3->pyfolio@ git+https://github.com/quantopian/pyfolio.git#egg=pyfolio-0.9.2->finrl==0.3.5) (2.15.0)\n", - "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.9/site-packages (from requests<3,>2->alpaca_trade_api>=2.1.0->finrl==0.3.5) (3.4)\n", - "Requirement already satisfied: greenlet!=0.4.17 in /usr/local/lib/python3.9/site-packages (from sqlalchemy<2->wrds>=3.1.6->finrl==0.3.5) (2.0.2)\n", - "Requirement already satisfied: ply<4.0,>=3.4 in /usr/local/lib/python3.9/site-packages (from thriftpy2>=0.3.9->jqdatasdk->finrl==0.3.5) (3.11)\n", - "Requirement already satisfied: nvidia-cuda-cupti-cu11==11.7.101 in /usr/local/lib/python3.9/site-packages (from torch>=1.11->stable-baselines3<2.0.0,>=1.6.2->finrl==0.3.5) (11.7.101)\n", - "Requirement already satisfied: jinja2 in /usr/local/lib/python3.9/site-packages (from torch>=1.11->stable-baselines3<2.0.0,>=1.6.2->finrl==0.3.5) (3.1.2)\n", - "Requirement already satisfied: nvidia-cuda-runtime-cu11==11.7.99 in /usr/local/lib/python3.9/site-packages (from torch>=1.11->stable-baselines3<2.0.0,>=1.6.2->finrl==0.3.5) (11.7.99)\n", - "Requirement already satisfied: nvidia-nvtx-cu11==11.7.91 in /usr/local/lib/python3.9/site-packages (from torch>=1.11->stable-baselines3<2.0.0,>=1.6.2->finrl==0.3.5) (11.7.91)\n", - "Requirement already satisfied: networkx in 
/usr/local/lib/python3.9/site-packages (from torch>=1.11->stable-baselines3<2.0.0,>=1.6.2->finrl==0.3.5) (3.1)\n", - "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.9/site-packages (from torch>=1.11->stable-baselines3<2.0.0,>=1.6.2->finrl==0.3.5) (2.0.0)\n", - "Requirement already satisfied: sympy in /usr/local/lib/python3.9/site-packages (from torch>=1.11->stable-baselines3<2.0.0,>=1.6.2->finrl==0.3.5) (1.11.1)\n", - "Requirement already satisfied: nvidia-cuda-nvrtc-cu11==11.7.99 in /usr/local/lib/python3.9/site-packages (from torch>=1.11->stable-baselines3<2.0.0,>=1.6.2->finrl==0.3.5) (11.7.99)\n", - "Requirement already satisfied: nvidia-cudnn-cu11==8.5.0.96 in /usr/local/lib/python3.9/site-packages (from torch>=1.11->stable-baselines3<2.0.0,>=1.6.2->finrl==0.3.5) (8.5.0.96)\n", - "Requirement already satisfied: nvidia-nccl-cu11==2.14.3 in /usr/local/lib/python3.9/site-packages (from torch>=1.11->stable-baselines3<2.0.0,>=1.6.2->finrl==0.3.5) (2.14.3)\n", - "Requirement already satisfied: nvidia-cusolver-cu11==11.4.0.1 in /usr/local/lib/python3.9/site-packages (from torch>=1.11->stable-baselines3<2.0.0,>=1.6.2->finrl==0.3.5) (11.4.0.1)\n", - "Requirement already satisfied: nvidia-cufft-cu11==10.9.0.58 in /usr/local/lib/python3.9/site-packages (from torch>=1.11->stable-baselines3<2.0.0,>=1.6.2->finrl==0.3.5) (10.9.0.58)\n", - "Requirement already satisfied: nvidia-curand-cu11==10.2.10.91 in /usr/local/lib/python3.9/site-packages (from torch>=1.11->stable-baselines3<2.0.0,>=1.6.2->finrl==0.3.5) (10.2.10.91)\n", - "Requirement already satisfied: nvidia-cublas-cu11==11.10.3.66 in /usr/local/lib/python3.9/site-packages (from torch>=1.11->stable-baselines3<2.0.0,>=1.6.2->finrl==0.3.5) (11.10.3.66)\n", - "Requirement already satisfied: nvidia-cusparse-cu11==11.7.4.91 in /usr/local/lib/python3.9/site-packages (from torch>=1.11->stable-baselines3<2.0.0,>=1.6.2->finrl==0.3.5) (11.7.4.91)\n", - "Requirement already satisfied: wheel in 
/usr/local/lib/python3.9/site-packages (from nvidia-cublas-cu11==11.10.3.66->torch>=1.11->stable-baselines3<2.0.0,>=1.6.2->finrl==0.3.5) (0.38.4)\n", - "Requirement already satisfied: cmake in /usr/local/lib/python3.9/site-packages (from triton==2.0.0->torch>=1.11->stable-baselines3<2.0.0,>=1.6.2->finrl==0.3.5) (3.26.3)\n", - "Requirement already satisfied: lit in /usr/local/lib/python3.9/site-packages (from triton==2.0.0->torch>=1.11->stable-baselines3<2.0.0,>=1.6.2->finrl==0.3.5) (16.0.1)\n", - "Requirement already satisfied: distlib<1,>=0.3.6 in /usr/local/lib/python3.9/site-packages (from virtualenv>=20.0.24->ray[default,tune]>=2.0.0->finrl==0.3.5) (0.3.6)\n", - "Requirement already satisfied: platformdirs<4,>=2.4 in /usr/local/lib/python3.9/site-packages (from virtualenv>=20.0.24->ray[default,tune]>=2.0.0->finrl==0.3.5) (3.2.0)\n", - "Requirement already satisfied: box2d-py==2.3.5 in /usr/local/lib/python3.9/site-packages (from gym>=0.17->finrl==0.3.5) (2.3.5)\n", - "Requirement already satisfied: pyglet>=1.4.0 in /usr/local/lib/python3.9/site-packages (from gym>=0.17->finrl==0.3.5) (2.0.5)\n", - "Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /usr/local/lib/python3.9/site-packages (from jsonschema->ray[default,tune]>=2.0.0->finrl==0.3.5) (0.19.3)\n", - "Requirement already satisfied: google-api-core<3.0.0,>=1.0.0 in /usr/local/lib/python3.9/site-packages (from opencensus->ray[default,tune]>=2.0.0->finrl==0.3.5) (2.11.0)\n", - "Requirement already satisfied: opencensus-context>=0.1.3 in /usr/local/lib/python3.9/site-packages (from opencensus->ray[default,tune]>=2.0.0->finrl==0.3.5) (0.1.3)\n", - "Requirement already satisfied: wcwidth>=0.1.4 in /usr/local/lib/python3.9/site-packages (from blessed>=1.17.1->gpustat>=1.0.0->ray[default,tune]>=2.0.0->finrl==0.3.5) (0.2.6)\n", - "Requirement already satisfied: pycparser in /usr/local/lib/python3.9/site-packages (from cffi>=1.12->cryptography>=2.6.1->ccxt>=1.66.32->finrl==0.3.5) 
(2.21)\n", - "Requirement already satisfied: google-auth<3.0dev,>=2.14.1 in /usr/local/lib/python3.9/site-packages (from google-api-core<3.0.0,>=1.0.0->opencensus->ray[default,tune]>=2.0.0->finrl==0.3.5) (2.17.3)\n", - "Requirement already satisfied: googleapis-common-protos<2.0dev,>=1.56.2 in /usr/local/lib/python3.9/site-packages (from google-api-core<3.0.0,>=1.0.0->opencensus->ray[default,tune]>=2.0.0->finrl==0.3.5) (1.59.0)\n", - "Requirement already satisfied: parso<0.9.0,>=0.8.0 in /usr/local/lib/python3.9/site-packages (from jedi>=0.16->ipython>=3.2.3->pyfolio@ git+https://github.com/quantopian/pyfolio.git#egg=pyfolio-0.9.2->finrl==0.3.5) (0.8.3)\n", - "Requirement already satisfied: ptyprocess>=0.5 in /usr/local/lib/python3.9/site-packages (from pexpect>4.3->ipython>=3.2.3->pyfolio@ git+https://github.com/quantopian/pyfolio.git#egg=pyfolio-0.9.2->finrl==0.3.5) (0.7.0)\n", - "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.9/site-packages (from jinja2->torch>=1.11->stable-baselines3<2.0.0,>=1.6.2->finrl==0.3.5) (2.1.2)\n", - "Requirement already satisfied: asttokens>=2.1.0 in /usr/local/lib/python3.9/site-packages (from stack-data->ipython>=3.2.3->pyfolio@ git+https://github.com/quantopian/pyfolio.git#egg=pyfolio-0.9.2->finrl==0.3.5) (2.2.1)\n", - "Requirement already satisfied: pure-eval in /usr/local/lib/python3.9/site-packages (from stack-data->ipython>=3.2.3->pyfolio@ git+https://github.com/quantopian/pyfolio.git#egg=pyfolio-0.9.2->finrl==0.3.5) (0.2.2)\n", - "Requirement already satisfied: executing>=1.2.0 in /usr/local/lib/python3.9/site-packages (from stack-data->ipython>=3.2.3->pyfolio@ git+https://github.com/quantopian/pyfolio.git#egg=pyfolio-0.9.2->finrl==0.3.5) (1.2.0)\n", - "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.9/site-packages (from sympy->torch>=1.11->stable-baselines3<2.0.0,>=1.6.2->finrl==0.3.5) (1.3.0)\n", - "Requirement already satisfied: cachetools<6.0,>=2.0.0 in 
/usr/local/lib/python3.9/site-packages (from google-auth<3.0dev,>=2.14.1->google-api-core<3.0.0,>=1.0.0->opencensus->ray[default,tune]>=2.0.0->finrl==0.3.5) (5.3.0)\n", - "Requirement already satisfied: rsa<5,>=3.1.4 in /usr/local/lib/python3.9/site-packages (from google-auth<3.0dev,>=2.14.1->google-api-core<3.0.0,>=1.0.0->opencensus->ray[default,tune]>=2.0.0->finrl==0.3.5) (4.9)\n", - "Requirement already satisfied: pyasn1-modules>=0.2.1 in /usr/local/lib/python3.9/site-packages (from google-auth<3.0dev,>=2.14.1->google-api-core<3.0.0,>=1.0.0->opencensus->ray[default,tune]>=2.0.0->finrl==0.3.5) (0.2.8)\n", - "Requirement already satisfied: pyasn1<0.5.0,>=0.4.6 in /usr/local/lib/python3.9/site-packages (from pyasn1-modules>=0.2.1->google-auth<3.0dev,>=2.14.1->google-api-core<3.0.0,>=1.0.0->opencensus->ray[default,tune]>=2.0.0->finrl==0.3.5) (0.4.8)\n", - "\u001b[33mWARNING: Running pip as the 'root' user can result in broken permissions and conflicting behaviour with the system package manager. It is recommended to use a virtual environment instead: https://pip.pypa.io/warnings/venv\u001b[0m\u001b[33m\n", - "\u001b[0m" - ] - } - ], - "source": [ - "## install required packages\n", - "!pip install swig\n", - "!pip install wrds\n", - "!pip install pyportfolioopt\n", - "## install finrl library\n", - "!pip install -q condacolab\n", - "import condacolab\n", - "condacolab.install()\n", - "!apt-get update -y -qq && apt-get install -y -qq cmake libopenmpi-dev python3-dev zlib1g-dev libgl1-mesa-glx swig\n", - "!pip install git+https://github.com/AI4Finance-Foundation/FinRL.git" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "osBHhVysOEzi" - }, - "source": [ - "\n", - "\n", - "## 2.2. 
A list of Python packages \n", - "* Yahoo Finance API\n", - "* pandas\n", - "* numpy\n", - "* matplotlib\n", - "* stockstats\n", - "* OpenAI gym\n", - "* stable-baselines\n", - "* tensorflow\n", - "* pyfolio" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "nGv01K8Sh1hn" - }, - "source": [ - "\n", - "## 2.3. Import Packages" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "lPqeTTwoh1hn", - "outputId": "e55033fc-48ae-4696-ae45-08b8bef664d5" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stderr", - "text": [ - "/usr/local/lib/python3.9/site-packages/pyfolio/pos.py:26: UserWarning: Module \"zipline.assets\" not found; multipliers will not be applied to position notionals.\n", - " warnings.warn(\n" - ] - } - ], - "source": [ - "import pandas as pd\n", - "import numpy as np\n", - "import matplotlib\n", - "import matplotlib.pyplot as plt\n", - "# matplotlib.use('Agg')\n", - "import datetime\n", - "\n", - "%matplotlib inline\n", - "from finrl.meta.preprocessor.yahoodownloader import YahooDownloader\n", - "from finrl.meta.preprocessor.preprocessors import FeatureEngineer, data_split\n", - "from finrl.meta.env_stock_trading.env_stocktrading import StockTradingEnv\n", - "from finrl.agents.stablebaselines3.models import DRLAgent\n", - "from stable_baselines3.common.logger import configure\n", - "from finrl.meta.data_processor import DataProcessor\n", - "\n", - "from finrl.plot import backtest_stats, backtest_plot, get_daily_return, get_baseline\n", - "from pprint import pprint\n", - "\n", - "import sys\n", - "sys.path.append(\"../FinRL\")\n", - "\n", - "import itertools" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "T2owTj985RW4" - }, - "source": [ - "\n", - "## 2.4. 
Create Folders" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": { - "id": "RtUc_ofKmpdy" - }, - "outputs": [], - "source": [ - "from finrl import config\n", - "from finrl import config_tickers\n", - "import os\n", - "from finrl.main import check_and_make_directories\n", - "from finrl.config import (\n", - " DATA_SAVE_DIR,\n", - " TRAINED_MODEL_DIR,\n", - " TENSORBOARD_LOG_DIR,\n", - " RESULTS_DIR,\n", - " INDICATORS,\n", - " TRAIN_START_DATE,\n", - " TRAIN_END_DATE,\n", - " TEST_START_DATE,\n", - " TEST_END_DATE,\n", - " TRADE_START_DATE,\n", - " TRADE_END_DATE,\n", - ")\n", - "check_and_make_directories([DATA_SAVE_DIR, TRAINED_MODEL_DIR, TENSORBOARD_LOG_DIR, RESULTS_DIR])\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "A289rQWMh1hq" - }, - "source": [ - "\n", - "# Part 3. Download Data\n", - "Yahoo Finance provides stock data, financial news, financial reports, etc. Yahoo Finance is free.\n", - "* FinRL uses a class **YahooDownloader** in FinRL-Meta to fetch data via Yahoo Finance API\n", - "* Call Limit: Using the Public API (without authentication), you are limited to 2,000 requests per hour per IP (or up to a total of 48,000 requests a day)." 
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "NPeQ7iS-LoMm" - }, - "source": [ - "\n", - "\n", - "-----\n", - "class YahooDownloader:\n", - " Retrieving daily stock data from\n", - " Yahoo Finance API\n", - "\n", - " Attributes\n", - " ----------\n", - " start_date : str\n", - " start date of the data (modified from config.py)\n", - " end_date : str\n", - " end date of the data (modified from config.py)\n", - " ticker_list : list\n", - " a list of stock tickers (modified from config.py)\n", - "\n", - " Methods\n", - " -------\n", - " fetch_data()\n" - ] - }, - { - "cell_type": "code", - "execution_count": 87, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 0 - }, - "id": "h3XJnvrbLp-C", - "outputId": "a03772b5-9cad-463f-e1d6-58d91a70a594" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "'2021-10-01'" - ], - "application/vnd.google.colaboratory.intrinsic+json": { - "type": "string" - } - }, - "metadata": {}, - "execution_count": 87 - } - ], - "source": [ - "# from config.py, TRAIN_START_DATE is a string\n", - "TRAIN_START_DATE\n", - "# from config.py, TRAIN_END_DATE is a string\n", - "TRAIN_END_DATE" - ] - }, - { - "cell_type": "code", - "execution_count": 88, - "metadata": { - "id": "FUnY8WEfLq3C" - }, - "outputs": [], - "source": [ - "TRAIN_START_DATE = '2010-01-01'\n", - "TRAIN_END_DATE = '2021-10-01'\n", - "TRADE_START_DATE = '2021-10-01'\n", - "TRADE_END_DATE = '2023-03-01'" - ] - }, - { - "cell_type": "code", - "execution_count": 89, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "yCKm4om-s9kE", - "outputId": "fd758d58-8946-42ee-e2e3-16f4ac74add2" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "[*********************100%***********************] 1 of 1 completed\n", - "[*********************100%***********************] 1 of 1 completed\n", - "[*********************100%***********************] 1 
of 1 completed\n", - "[*********************100%***********************] 1 of 1 completed\n", - "[*********************100%***********************] 1 of 1 completed\n", - "[*********************100%***********************] 1 of 1 completed\n", - "[*********************100%***********************] 1 of 1 completed\n", - "[*********************100%***********************] 1 of 1 completed\n", - "[*********************100%***********************] 1 of 1 completed\n", - "[*********************100%***********************] 1 of 1 completed\n", - "[*********************100%***********************] 1 of 1 completed\n", - "[*********************100%***********************] 1 of 1 completed\n", - "[*********************100%***********************] 1 of 1 completed\n", - "[*********************100%***********************] 1 of 1 completed\n", - "[*********************100%***********************] 1 of 1 completed\n", - "[*********************100%***********************] 1 of 1 completed\n", - "[*********************100%***********************] 1 of 1 completed\n", - "[*********************100%***********************] 1 of 1 completed\n", - "[*********************100%***********************] 1 of 1 completed\n", - "[*********************100%***********************] 1 of 1 completed\n", - "[*********************100%***********************] 1 of 1 completed\n", - "[*********************100%***********************] 1 of 1 completed\n", - "[*********************100%***********************] 1 of 1 completed\n", - "[*********************100%***********************] 1 of 1 completed\n", - "[*********************100%***********************] 1 of 1 completed\n", - "[*********************100%***********************] 1 of 1 completed\n", - "[*********************100%***********************] 1 of 1 completed\n", - "[*********************100%***********************] 1 of 1 completed\n", - "[*********************100%***********************] 1 of 1 completed\n", - 
"[*********************100%***********************] 1 of 1 completed\n", - "Shape of DataFrame: (97013, 8)\n" - ] - } - ], - "source": [ - "df = YahooDownloader(start_date = TRAIN_START_DATE,\n", - " end_date = TRADE_END_DATE,\n", - " ticker_list = config_tickers.DOW_30_TICKER).fetch_data()" - ] - }, - { - "cell_type": "code", - "execution_count": 90, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "JzqRRTOX6aFu", - "outputId": "58a21ede-016a-4eaf-db9f-aeb190b3f939" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "['AXP', 'AMGN', 'AAPL', 'BA', 'CAT', 'CSCO', 'CVX', 'GS', 'HD', 'HON', 'IBM', 'INTC', 'JNJ', 'KO', 'JPM', 'MCD', 'MMM', 'MRK', 'MSFT', 'NKE', 'PG', 'TRV', 'UNH', 'CRM', 'VZ', 'V', 'WBA', 'WMT', 'DIS', 'DOW']\n" - ] - } - ], - "source": [ - "print(config_tickers.DOW_30_TICKER)" - ] - }, - { - "cell_type": "code", - "execution_count": 91, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "CV3HrZHLh1hy", - "outputId": "c2cf4956-210b-4811-be12-0c7fd18b923c" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(97013, 8)" - ] - }, - "metadata": {}, - "execution_count": 91 - } - ], - "source": [ - "df.shape" - ] - }, - { - "cell_type": "code", - "execution_count": 92, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 0 - }, - "id": "4hYkeaPiICHS", - "outputId": "6d7a1c0d-15dc-4adc-b776-f1020e173a5c" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - " date open high low close volume tic \n", - "0 2010-01-04 7.622500 7.660714 7.585000 6.505281 493729600 AAPL \\\n", - "1 2010-01-04 56.630001 57.869999 56.560001 42.888958 5277400 AMGN \n", - "2 2010-01-04 40.810001 41.099998 40.389999 33.551674 6894300 AXP \n", - "3 2010-01-04 55.720001 56.389999 54.799999 43.777550 6186700 BA \n", - "4 2010-01-04 57.650002 59.189999 57.509998 41.156906 7325600 CAT \n", 
- "\n", - " day \n", - "0 0 \n", - "1 0 \n", - "2 0 \n", - "3 0 \n", - "4 0 " - ], - "text/html": [ - "\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
dateopenhighlowclosevolumeticday
02010-01-047.6225007.6607147.5850006.505281493729600AAPL0
12010-01-0456.63000157.86999956.56000142.8889585277400AMGN0
22010-01-0440.81000141.09999840.38999933.5516746894300AXP0
32010-01-0455.72000156.38999954.79999943.7775506186700BA0
42010-01-0457.65000259.18999957.50999841.1569067325600CAT0
\n", - "
\n", - " \n", - " \n", - " \n", - "\n", - " \n", - "
\n", - "
\n", - " " - ] - }, - "metadata": {}, - "execution_count": 92 - } - ], - "source": [ - "df.sort_values(['date','tic'],ignore_index=True).head()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "uqC6c40Zh1iH" - }, - "source": [ - "# Part 4: Preprocess Data\n", - "We need to check for missing data and do feature engineering to convert the data point into a state.\n", - "* **Adding technical indicators**. In practical trading, various information needs to be taken into account, such as historical prices, current holding shares, technical indicators, etc. Here, we demonstrate two trend-following technical indicators: MACD and RSI.\n", - "* **Adding turbulence index**. Risk-aversion reflects whether an investor prefers to protect the capital. It also influences one's trading strategy when facing different market volatility level. To control the risk in a worst-case scenario, such as financial crisis of 2007–2008, FinRL employs the turbulence index that measures extreme fluctuation of asset price." 
- ] - }, - { - "cell_type": "code", - "execution_count": 93, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "PmKP-1ii3RLS", - "outputId": "22fecb54-5555-4ec4-cb32-0a54f443e54e" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Successfully added technical indicators\n", - "[*********************100%***********************] 1 of 1 completed\n", - "Shape of DataFrame: (3310, 8)\n", - "Successfully added vix\n", - "Successfully added turbulence index\n" - ] - } - ], - "source": [ - "fe = FeatureEngineer(\n", - " use_technical_indicator=True,\n", - " tech_indicator_list = INDICATORS,\n", - " use_vix=True,\n", - " use_turbulence=True,\n", - " user_defined_feature = False)\n", - "\n", - "processed = fe.preprocess_data(df)" - ] - }, - { - "cell_type": "code", - "execution_count": 94, - "metadata": { - "id": "Kixon2tR3RLT" - }, - "outputs": [], - "source": [ - "list_ticker = processed[\"tic\"].unique().tolist()\n", - "list_date = list(pd.date_range(processed['date'].min(),processed['date'].max()).astype(str))\n", - "combination = list(itertools.product(list_date,list_ticker))\n", - "\n", - "processed_full = pd.DataFrame(combination,columns=[\"date\",\"tic\"]).merge(processed,on=[\"date\",\"tic\"],how=\"left\")\n", - "processed_full = processed_full[processed_full['date'].isin(processed['date'])]\n", - "processed_full = processed_full.sort_values(['date','tic'])\n", - "\n", - "processed_full = processed_full.fillna(0)" - ] - }, - { - "cell_type": "code", - "execution_count": 95, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 0 - }, - "id": "grvhGJJII3Xn", - "outputId": "2af27938-0df3-4fea-e86d-7a361e71d2e2" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - " date tic open high low close \n", - "0 2010-01-04 AAPL 7.622500 7.660714 7.585000 6.505281 \\\n", - "1 2010-01-04 AMGN 56.630001 57.869999 56.560001 42.888958 \n", - "2 
2010-01-04 AXP 40.810001 41.099998 40.389999 33.551674 \n", - "3 2010-01-04 BA 55.720001 56.389999 54.799999 43.777550 \n", - "4 2010-01-04 CAT 57.650002 59.189999 57.509998 41.156906 \n", - "5 2010-01-04 CRM 18.652500 18.882500 18.547501 18.705000 \n", - "6 2010-01-04 CSCO 24.110001 24.840000 24.010000 17.264450 \n", - "7 2010-01-04 CVX 78.199997 79.199997 78.160004 46.851871 \n", - "8 2010-01-04 DIS 32.500000 32.750000 31.870001 27.933922 \n", - "9 2010-01-04 GS 170.050003 174.250000 169.509995 139.862900 \n", - "\n", - " volume day macd boll_ub boll_lb rsi_30 cci_30 dx_30 \n", - "0 493729600.0 0.0 0.0 6.526808 6.495 100.0 66.666667 100.0 \\\n", - "1 5277400.0 0.0 0.0 6.526808 6.495 100.0 66.666667 100.0 \n", - "2 6894300.0 0.0 0.0 6.526808 6.495 100.0 66.666667 100.0 \n", - "3 6186700.0 0.0 0.0 6.526808 6.495 100.0 66.666667 100.0 \n", - "4 7325600.0 0.0 0.0 6.526808 6.495 100.0 66.666667 100.0 \n", - "5 7906000.0 0.0 0.0 6.526808 6.495 100.0 66.666667 100.0 \n", - "6 59853700.0 0.0 0.0 6.526808 6.495 100.0 66.666667 100.0 \n", - "7 10173800.0 0.0 0.0 6.526808 6.495 100.0 66.666667 100.0 \n", - "8 13700400.0 0.0 0.0 6.526808 6.495 100.0 66.666667 100.0 \n", - "9 9135000.0 0.0 0.0 6.526808 6.495 100.0 66.666667 100.0 \n", - "\n", - " close_30_sma close_60_sma vix turbulence \n", - "0 6.505281 6.505281 20.040001 0.0 \n", - "1 42.888958 42.888958 20.040001 0.0 \n", - "2 33.551674 33.551674 20.040001 0.0 \n", - "3 43.777550 43.777550 20.040001 0.0 \n", - "4 41.156906 41.156906 20.040001 0.0 \n", - "5 18.705000 18.705000 20.040001 0.0 \n", - "6 17.264450 17.264450 20.040001 0.0 \n", - "7 46.851871 46.851871 20.040001 0.0 \n", - "8 27.933922 27.933922 20.040001 0.0 \n", - "9 139.862900 139.862900 20.040001 0.0 " - ], - "text/html": [ - "\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " 
\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
dateticopenhighlowclosevolumedaymacdboll_ubboll_lbrsi_30cci_30dx_30close_30_smaclose_60_smavixturbulence
02010-01-04AAPL7.6225007.6607147.5850006.505281493729600.00.00.06.5268086.495100.066.666667100.06.5052816.50528120.0400010.0
12010-01-04AMGN56.63000157.86999956.56000142.8889585277400.00.00.06.5268086.495100.066.666667100.042.88895842.88895820.0400010.0
22010-01-04AXP40.81000141.09999840.38999933.5516746894300.00.00.06.5268086.495100.066.666667100.033.55167433.55167420.0400010.0
32010-01-04BA55.72000156.38999954.79999943.7775506186700.00.00.06.5268086.495100.066.666667100.043.77755043.77755020.0400010.0
42010-01-04CAT57.65000259.18999957.50999841.1569067325600.00.00.06.5268086.495100.066.666667100.041.15690641.15690620.0400010.0
52010-01-04CRM18.65250018.88250018.54750118.7050007906000.00.00.06.5268086.495100.066.666667100.018.70500018.70500020.0400010.0
62010-01-04CSCO24.11000124.84000024.01000017.26445059853700.00.00.06.5268086.495100.066.666667100.017.26445017.26445020.0400010.0
72010-01-04CVX78.19999779.19999778.16000446.85187110173800.00.00.06.5268086.495100.066.666667100.046.85187146.85187120.0400010.0
82010-01-04DIS32.50000032.75000031.87000127.93392213700400.00.00.06.5268086.495100.066.666667100.027.93392227.93392220.0400010.0
92010-01-04GS170.050003174.250000169.509995139.8629009135000.00.00.06.5268086.495100.066.666667100.0139.862900139.86290020.0400010.0
\n", - "
\n", - " \n", - " \n", - " \n", - "\n", - " \n", - "
\n", - "
\n", - " " - ] - }, - "metadata": {}, - "execution_count": 95 - } - ], - "source": [ - "processed_full.sort_values(['date','tic'],ignore_index=True).head(10)" - ] - }, - { - "cell_type": "code", - "execution_count": 96, - "metadata": { - "id": "5vdORQ384Qx-" - }, - "outputs": [], - "source": [ - "mvo_df = processed_full.sort_values(['date','tic'],ignore_index=True)[['date','tic','close']]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-QsYaY0Dh1iw" - }, - "source": [ - "\n", - "# Part 5. Build A Market Environment in OpenAI Gym-style\n", - "The training process involves observing stock price change, taking an action and reward's calculation. By interacting with the market environment, the agent will eventually derive a trading strategy that may maximize (expected) rewards.\n", - "\n", - "Our market environment, based on OpenAI Gym, simulates stock markets with historical market data." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5TOhcryx44bb" - }, - "source": [ - "## Data Split\n", - "We split the data into training set and testing set as follows:\n", - "\n", - "Training data period: 2009-01-01 to 2020-07-01\n", - "\n", - "Trading data period: 2020-07-01 to 2021-10-31\n" - ] - }, - { - "cell_type": "code", - "execution_count": 97, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "W0qaVGjLtgbI", - "outputId": "4f16484e-811e-46cd-efee-54c6b309f5a5" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "85753\n", - "10237\n" - ] - } - ], - "source": [ - "train = data_split(processed_full, TRAIN_START_DATE,TRAIN_END_DATE)\n", - "trade = data_split(processed_full, TRADE_START_DATE,TRADE_END_DATE)\n", - "train_length = len(train)\n", - "trade_length = len(trade)\n", - "print(train_length)\n", - "print(trade_length)" - ] - }, - { - "cell_type": "code", - "execution_count": 98, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 0 - }, - 
"id": "p52zNCOhTtLR", - "outputId": "d708401b-129f-495b-e691-7ab8666d6847" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - " date tic open high low close \n", - "2956 2021-09-30 UNH 401.489990 403.489990 390.459991 383.180176 \\\n", - "2956 2021-09-30 V 227.580002 228.789993 222.630005 220.257767 \n", - "2956 2021-09-30 VZ 54.500000 54.509998 54.000000 49.011551 \n", - "2956 2021-09-30 WBA 48.790001 48.930000 46.919998 43.957283 \n", - "2956 2021-09-30 WMT 140.639999 141.729996 139.250000 136.114777 \n", - "\n", - " volume day macd boll_ub boll_lb rsi_30 \n", - "2956 3779900.0 3.0 -4.349256 419.212270 386.863774 40.895395 \\\n", - "2956 7128500.0 3.0 -1.538728 228.639270 216.529608 44.078996 \n", - "2956 18736600.0 3.0 -0.233667 50.131163 48.744038 41.824917 \n", - "2956 6449400.0 3.0 -0.253540 48.531810 43.597246 44.613711 \n", - "2956 7485900.0 3.0 -1.554834 146.155111 135.633136 40.165878 \n", - "\n", - " cci_30 dx_30 close_30_sma close_60_sma vix turbulence \n", - "2956 -222.938238 41.980385 405.947466 405.438423 23.139999 24.872328 \n", - "2956 -54.614579 19.569853 224.777295 231.479431 23.139999 24.872328 \n", - "2956 -102.798842 21.682953 49.605026 50.130592 23.139999 24.872328 \n", - "2956 -107.390223 0.941150 45.886550 44.857786 23.139999 24.872328 \n", - "2956 -151.542656 45.466733 142.345313 141.246727 23.139999 24.872328 " - ], - "text/html": [ - "\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
dateticopenhighlowclosevolumedaymacdboll_ubboll_lbrsi_30cci_30dx_30close_30_smaclose_60_smavixturbulence
29562021-09-30UNH401.489990403.489990390.459991383.1801763779900.03.0-4.349256419.212270386.86377440.895395-222.93823841.980385405.947466405.43842323.13999924.872328
29562021-09-30V227.580002228.789993222.630005220.2577677128500.03.0-1.538728228.639270216.52960844.078996-54.61457919.569853224.777295231.47943123.13999924.872328
29562021-09-30VZ54.50000054.50999854.00000049.01155118736600.03.0-0.23366750.13116348.74403841.824917-102.79884221.68295349.60502650.13059223.13999924.872328
29562021-09-30WBA48.79000148.93000046.91999843.9572836449400.03.0-0.25354048.53181043.59724644.613711-107.3902230.94115045.88655044.85778623.13999924.872328
29562021-09-30WMT140.639999141.729996139.250000136.1147777485900.03.0-1.554834146.155111135.63313640.165878-151.54265645.466733142.345313141.24672723.13999924.872328
\n", - "
\n", - " \n", - " \n", - " \n", - "\n", - " \n", - "
\n", - "
\n", - " " - ] - }, - "metadata": {}, - "execution_count": 98 - } - ], - "source": [ - "train.tail()" - ] - }, - { - "cell_type": "code", - "execution_count": 99, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 0 - }, - "id": "k9zU9YaTTvFq", - "outputId": "9080799c-a150-4414-c2de-a68c5e7c3a85" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - " date tic open high low close \n", - "0 2021-10-01 AAPL 141.899994 142.919998 139.110001 141.404266 \\\n", - "0 2021-10-01 AMGN 213.589996 214.610001 210.800003 203.845871 \n", - "0 2021-10-01 AXP 168.500000 175.119995 168.479996 170.065353 \n", - "0 2021-10-01 BA 222.850006 226.720001 220.600006 226.000000 \n", - "0 2021-10-01 CAT 192.899994 195.869995 191.240005 187.928040 \n", - "\n", - " volume day macd boll_ub boll_lb rsi_30 cci_30 \n", - "0 94639600.0 4.0 -1.703488 155.382846 137.132193 46.927735 -142.190202 \\\n", - "0 2629400.0 4.0 -3.097330 212.767980 199.379595 40.408533 -96.757039 \n", - "0 3956000.0 4.0 2.273329 174.218856 149.232889 56.265093 117.538402 \n", - "0 9113600.0 4.0 0.730320 226.909442 205.727561 51.614047 116.649440 \n", - "0 3695500.0 4.0 -3.640324 205.735919 181.432783 41.999435 -112.087765 \n", - "\n", - " dx_30 close_30_sma close_60_sma vix turbulence \n", - "0 41.749873 147.171798 146.269415 21.1 120.122978 \n", - "0 36.189244 208.480832 217.103342 21.1 120.122978 \n", - "0 15.667511 161.215661 163.458888 21.1 120.122978 \n", - "0 2.027170 217.175334 221.968500 21.1 120.122978 \n", - "0 36.203176 196.993869 200.522109 21.1 120.122978 " - ], - "text/html": [ - "\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
dateticopenhighlowclosevolumedaymacdboll_ubboll_lbrsi_30cci_30dx_30close_30_smaclose_60_smavixturbulence
02021-10-01AAPL141.899994142.919998139.110001141.40426694639600.04.0-1.703488155.382846137.13219346.927735-142.19020241.749873147.171798146.26941521.1120.122978
02021-10-01AMGN213.589996214.610001210.800003203.8458712629400.04.0-3.097330212.767980199.37959540.408533-96.75703936.189244208.480832217.10334221.1120.122978
02021-10-01AXP168.500000175.119995168.479996170.0653533956000.04.02.273329174.218856149.23288956.265093117.53840215.667511161.215661163.45888821.1120.122978
02021-10-01BA222.850006226.720001220.600006226.0000009113600.04.00.730320226.909442205.72756151.614047116.6494402.027170217.175334221.96850021.1120.122978
02021-10-01CAT192.899994195.869995191.240005187.9280403695500.04.0-3.640324205.735919181.43278341.999435-112.08776536.203176196.993869200.52210921.1120.122978
\n", - "
\n", - " \n", - " \n", - " \n", - "\n", - " \n", - "
\n", - "
\n", - " " - ] - }, - "metadata": {}, - "execution_count": 99 - } - ], - "source": [ - "trade.head()" - ] - }, - { - "cell_type": "code", - "execution_count": 100, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "zYN573SOHhxG", - "outputId": "f5dcfc60-af90-4aa0-8849-11848b3ef619" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "['macd',\n", - " 'boll_ub',\n", - " 'boll_lb',\n", - " 'rsi_30',\n", - " 'cci_30',\n", - " 'dx_30',\n", - " 'close_30_sma',\n", - " 'close_60_sma']" - ] - }, - "metadata": {}, - "execution_count": 100 - } - ], - "source": [ - "INDICATORS" - ] - }, - { - "cell_type": "code", - "execution_count": 101, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Q2zqII8rMIqn", - "outputId": "b6f16ea3-8f52-44c7-ceb1-f58dabe3d1be" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "Stock Dimension: 29, State Space: 291\n" - ] - } - ], - "source": [ - "stock_dimension = len(train.tic.unique())\n", - "state_space = 1 + 2*stock_dimension + len(INDICATORS)*stock_dimension\n", - "print(f\"Stock Dimension: {stock_dimension}, State Space: {state_space}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 102, - "metadata": { - "id": "AWyp84Ltto19" - }, - "outputs": [], - "source": [ - "buy_cost_list = sell_cost_list = [0.001] * stock_dimension\n", - "num_stock_shares = [0] * stock_dimension\n", - "\n", - "env_kwargs = {\n", - " \"hmax\": 100,\n", - " \"initial_amount\": 1000000,\n", - " \"num_stock_shares\": num_stock_shares,\n", - " \"buy_cost_pct\": buy_cost_list,\n", - " \"sell_cost_pct\": sell_cost_list,\n", - " \"state_space\": state_space,\n", - " \"stock_dim\": stock_dimension,\n", - " \"tech_indicator_list\": INDICATORS,\n", - " \"action_space\": stock_dimension,\n", - " \"reward_scaling\": 1e-4\n", - "}\n", - "\n", - "\n", - "e_train_gym = StockTradingEnv(df = train, **env_kwargs)" - ] - }, - { - 
"cell_type": "markdown", - "metadata": { - "id": "64EoqOrQjiVf" - }, - "source": [ - "## Environment for Training\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": 103, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "xwSvvPjutpqS", - "outputId": "e8fc8f68-b8c9-47a8-e7d2-a6ed0715d216" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "\n" - ] - } - ], - "source": [ - "env_train, _ = e_train_gym.get_sb_env()\n", - "print(type(env_train))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HMNR5nHjh1iz" - }, - "source": [ - "\n", - "# Part 6: Train DRL Agents\n", - "* The DRL algorithms are from **Stable Baselines 3**. Users are also encouraged to try **ElegantRL** and **Ray RLlib**.\n", - "* FinRL includes fine-tuned standard DRL algorithms, such as DQN, DDPG, Multi-Agent DDPG, PPO, SAC, A2C and TD3. We also allow users to\n", - "design their own DRL algorithms by adapting these DRL algorithms." 
- ] - }, - { - "cell_type": "code", - "execution_count": 22, - "metadata": { - "id": "364PsqckttcQ" - }, - "outputs": [], - "source": [ - "agent = DRLAgent(env = env_train)\n", - "\n", - "if_using_a2c = True\n", - "if_using_ddpg = True\n", - "if_using_ppo = True\n", - "if_using_td3 = True\n", - "if_using_sac = True\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "YDmqOyF9h1iz" - }, - "source": [ - "### Agent Training: 5 algorithms (A2C, DDPG, PPO, TD3, SAC)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "uijiWgkuh1jB" - }, - "source": [ - "### Agent 1: A2C\n" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "GUCnkn-HIbmj", - "outputId": "7112ce2a-0f62-4a9c-c8be-4443779b4ba0" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "{'n_steps': 5, 'ent_coef': 0.01, 'learning_rate': 0.0007}\n", - "Using cpu device\n", - "Logging to results/a2c\n" - ] - } - ], - "source": [ - "agent = DRLAgent(env = env_train)\n", - "model_a2c = agent.get_model(\"a2c\")\n", - "\n", - "if if_using_a2c:\n", - " # set up logger\n", - " tmp_path = RESULTS_DIR + '/a2c'\n", - " new_logger_a2c = configure(tmp_path, [\"stdout\", \"csv\", \"tensorboard\"])\n", - " # Set new logger\n", - " model_a2c.set_logger(new_logger_a2c)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "0GVpkWGqH4-D", - "outputId": "d00d9ef6-7489-4126-f53f-376612f48466" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "----------------------------------------\n", - "| time/ | |\n", - "| fps | 86 |\n", - "| iterations | 100 |\n", - "| time_elapsed | 5 |\n", - "| total_timesteps | 500 |\n", - "| train/ | |\n", - "| entropy_loss | -41.4 |\n", - "| explained_variance | -0.753 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 99 
|\n", - "| policy_loss | -34.1 |\n", - "| reward | -0.101721555 |\n", - "| std | 1.01 |\n", - "| value_loss | 3.66 |\n", - "----------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 74 |\n", - "| iterations | 200 |\n", - "| time_elapsed | 13 |\n", - "| total_timesteps | 1000 |\n", - "| train/ | |\n", - "| entropy_loss | -41.5 |\n", - "| explained_variance | -0.0231 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 199 |\n", - "| policy_loss | -71 |\n", - "| reward | 0.90710044 |\n", - "| std | 1.01 |\n", - "| value_loss | 6.44 |\n", - "--------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 65 |\n", - "| iterations | 300 |\n", - "| time_elapsed | 22 |\n", - "| total_timesteps | 1500 |\n", - "| train/ | |\n", - "| entropy_loss | -41.5 |\n", - "| explained_variance | 5.96e-08 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 299 |\n", - "| policy_loss | 16.1 |\n", - "| reward | -2.4830217 |\n", - "| std | 1.01 |\n", - "| value_loss | 3.89 |\n", - "--------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 70 |\n", - "| iterations | 400 |\n", - "| time_elapsed | 28 |\n", - "| total_timesteps | 2000 |\n", - "| train/ | |\n", - "| entropy_loss | -41.5 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 399 |\n", - "| policy_loss | -38.5 |\n", - "| reward | -2.3722556 |\n", - "| std | 1.01 |\n", - "| value_loss | 6.41 |\n", - "--------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 70 |\n", - "| iterations | 500 |\n", - "| time_elapsed | 35 |\n", - "| total_timesteps | 2500 |\n", - "| train/ | |\n", - "| entropy_loss | -41.5 |\n", - "| explained_variance | -0.00813 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 499 |\n", - "| policy_loss | -74.9 |\n", - "| reward | 
-1.3848053 |\n", - "| std | 1.01 |\n", - "| value_loss | 5.89 |\n", - "--------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 600 |\n", - "| time_elapsed | 45 |\n", - "| total_timesteps | 3000 |\n", - "| train/ | |\n", - "| entropy_loss | -41.4 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 599 |\n", - "| policy_loss | 9.33 |\n", - "| reward | 0.40211454 |\n", - "| std | 1.01 |\n", - "| value_loss | 0.104 |\n", - "--------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 68 |\n", - "| iterations | 700 |\n", - "| time_elapsed | 50 |\n", - "| total_timesteps | 3500 |\n", - "| train/ | |\n", - "| entropy_loss | -41.5 |\n", - "| explained_variance | -1.19e-07 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 699 |\n", - "| policy_loss | 24.8 |\n", - "| reward | 0.5960596 |\n", - "| std | 1.01 |\n", - "| value_loss | 0.787 |\n", - "-------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 69 |\n", - "| iterations | 800 |\n", - "| time_elapsed | 57 |\n", - "| total_timesteps | 4000 |\n", - "| train/ | |\n", - "| entropy_loss | -41.5 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 799 |\n", - "| policy_loss | -113 |\n", - "| reward | 0.29731598 |\n", - "| std | 1.01 |\n", - "| value_loss | 7.73 |\n", - "--------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 900 |\n", - "| time_elapsed | 67 |\n", - "| total_timesteps | 4500 |\n", - "| train/ | |\n", - "| entropy_loss | -41.5 |\n", - "| explained_variance | -1.19e-07 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 899 |\n", - "| policy_loss | 65.6 |\n", - "| reward | 1.8696517 |\n", - "| std | 1.01 |\n", - "| value_loss | 3.23 
|\n", - "-------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 68 |\n", - "| iterations | 1000 |\n", - "| time_elapsed | 73 |\n", - "| total_timesteps | 5000 |\n", - "| train/ | |\n", - "| entropy_loss | -41.5 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 999 |\n", - "| policy_loss | -343 |\n", - "| reward | 2.8599794 |\n", - "| std | 1.01 |\n", - "| value_loss | 161 |\n", - "-------------------------------------\n", - "---------------------------------------\n", - "| time/ | |\n", - "| fps | 68 |\n", - "| iterations | 1100 |\n", - "| time_elapsed | 79 |\n", - "| total_timesteps | 5500 |\n", - "| train/ | |\n", - "| entropy_loss | -41.5 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 1099 |\n", - "| policy_loss | 149 |\n", - "| reward | -0.89934605 |\n", - "| std | 1.01 |\n", - "| value_loss | 39.1 |\n", - "---------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 1200 |\n", - "| time_elapsed | 89 |\n", - "| total_timesteps | 6000 |\n", - "| train/ | |\n", - "| entropy_loss | -41.5 |\n", - "| explained_variance | -0.109 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 1199 |\n", - "| policy_loss | -80.9 |\n", - "| reward | -1.2479767 |\n", - "| std | 1.01 |\n", - "| value_loss | 4.9 |\n", - "--------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 1300 |\n", - "| time_elapsed | 96 |\n", - "| total_timesteps | 6500 |\n", - "| train/ | |\n", - "| entropy_loss | -41.6 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 1299 |\n", - "| policy_loss | -77.1 |\n", - "| reward | 0.3952496 |\n", - "| std | 1.02 |\n", - "| value_loss | 4.22 |\n", - "-------------------------------------\n", - 
"--------------------------------------\n", - "| time/ | |\n", - "| fps | 68 |\n", - "| iterations | 1400 |\n", - "| time_elapsed | 102 |\n", - "| total_timesteps | 7000 |\n", - "| train/ | |\n", - "| entropy_loss | -41.6 |\n", - "| explained_variance | -0.0301 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 1399 |\n", - "| policy_loss | 127 |\n", - "| reward | 0.37600544 |\n", - "| std | 1.02 |\n", - "| value_loss | 11.7 |\n", - "--------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 1500 |\n", - "| time_elapsed | 112 |\n", - "| total_timesteps | 7500 |\n", - "| train/ | |\n", - "| entropy_loss | -41.6 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 1499 |\n", - "| policy_loss | -9.18 |\n", - "| reward | 1.1150984 |\n", - "| std | 1.02 |\n", - "| value_loss | 0.346 |\n", - "-------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 1600 |\n", - "| time_elapsed | 118 |\n", - "| total_timesteps | 8000 |\n", - "| train/ | |\n", - "| entropy_loss | -41.6 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 1599 |\n", - "| policy_loss | -6.38 |\n", - "| reward | -0.5224333 |\n", - "| std | 1.02 |\n", - "| value_loss | 4.36 |\n", - "--------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 1700 |\n", - "| time_elapsed | 125 |\n", - "| total_timesteps | 8500 |\n", - "| train/ | |\n", - "| entropy_loss | -41.7 |\n", - "| explained_variance | -1.19e-07 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 1699 |\n", - "| policy_loss | 388 |\n", - "| reward | 9.236983 |\n", - "| std | 1.02 |\n", - "| value_loss | 140 |\n", - "-------------------------------------\n", - "------------------------------------\n", - "| time/ | 
|\n", - "| fps | 66 |\n", - "| iterations | 1800 |\n", - "| time_elapsed | 135 |\n", - "| total_timesteps | 9000 |\n", - "| train/ | |\n", - "| entropy_loss | -41.7 |\n", - "| explained_variance | 0.0156 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 1799 |\n", - "| policy_loss | -140 |\n", - "| reward | 0.933411 |\n", - "| std | 1.02 |\n", - "| value_loss | 13.9 |\n", - "------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 1900 |\n", - "| time_elapsed | 141 |\n", - "| total_timesteps | 9500 |\n", - "| train/ | |\n", - "| entropy_loss | -41.7 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 1899 |\n", - "| policy_loss | -81.9 |\n", - "| reward | 1.4432659 |\n", - "| std | 1.02 |\n", - "| value_loss | 4.5 |\n", - "-------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 2000 |\n", - "| time_elapsed | 148 |\n", - "| total_timesteps | 10000 |\n", - "| train/ | |\n", - "| entropy_loss | -41.8 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 1999 |\n", - "| policy_loss | -5.75 |\n", - "| reward | 0.3948126 |\n", - "| std | 1.02 |\n", - "| value_loss | 0.328 |\n", - "-------------------------------------\n", - "------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 2100 |\n", - "| time_elapsed | 158 |\n", - "| total_timesteps | 10500 |\n", - "| train/ | |\n", - "| entropy_loss | -41.7 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 2099 |\n", - "| policy_loss | -117 |\n", - "| reward | 3.777475 |\n", - "| std | 1.02 |\n", - "| value_loss | 16.7 |\n", - "------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 2200 |\n", - "| 
time_elapsed | 164 |\n", - "| total_timesteps | 11000 |\n", - "| train/ | |\n", - "| entropy_loss | -41.8 |\n", - "| explained_variance | 1.19e-07 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 2199 |\n", - "| policy_loss | 54.7 |\n", - "| reward | -2.081327 |\n", - "| std | 1.02 |\n", - "| value_loss | 5.43 |\n", - "-------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 2300 |\n", - "| time_elapsed | 171 |\n", - "| total_timesteps | 11500 |\n", - "| train/ | |\n", - "| entropy_loss | -41.7 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 2299 |\n", - "| policy_loss | 491 |\n", - "| reward | 0.36665618 |\n", - "| std | 1.02 |\n", - "| value_loss | 178 |\n", - "--------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 2400 |\n", - "| time_elapsed | 180 |\n", - "| total_timesteps | 12000 |\n", - "| train/ | |\n", - "| entropy_loss | -41.8 |\n", - "| explained_variance | 0.104 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 2399 |\n", - "| policy_loss | 38.8 |\n", - "| reward | -0.6101983 |\n", - "| std | 1.02 |\n", - "| value_loss | 1.19 |\n", - "--------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 2500 |\n", - "| time_elapsed | 186 |\n", - "| total_timesteps | 12500 |\n", - "| train/ | |\n", - "| entropy_loss | -41.7 |\n", - "| explained_variance | -0.32 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 2499 |\n", - "| policy_loss | -84.1 |\n", - "| reward | 1.4439964 |\n", - "| std | 1.02 |\n", - "| value_loss | 5.86 |\n", - "-------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 2600 |\n", - "| time_elapsed | 193 |\n", - "| total_timesteps | 
13000 |\n", - "| train/ | |\n", - "| entropy_loss | -41.8 |\n", - "| explained_variance | 1.19e-07 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 2599 |\n", - "| policy_loss | 85.7 |\n", - "| reward | -1.2166104 |\n", - "| std | 1.02 |\n", - "| value_loss | 6.91 |\n", - "--------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 2700 |\n", - "| time_elapsed | 203 |\n", - "| total_timesteps | 13500 |\n", - "| train/ | |\n", - "| entropy_loss | -41.8 |\n", - "| explained_variance | 1.19e-07 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 2699 |\n", - "| policy_loss | 58.4 |\n", - "| reward | -1.4893322 |\n", - "| std | 1.02 |\n", - "| value_loss | 3.64 |\n", - "--------------------------------------\n", - "---------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 2800 |\n", - "| time_elapsed | 209 |\n", - "| total_timesteps | 14000 |\n", - "| train/ | |\n", - "| entropy_loss | -41.8 |\n", - "| explained_variance | 1.19e-07 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 2799 |\n", - "| policy_loss | -85.7 |\n", - "| reward | -0.53823775 |\n", - "| std | 1.02 |\n", - "| value_loss | 6.8 |\n", - "---------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 2900 |\n", - "| time_elapsed | 215 |\n", - "| total_timesteps | 14500 |\n", - "| train/ | |\n", - "| entropy_loss | -41.8 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 2899 |\n", - "| policy_loss | 246 |\n", - "| reward | -1.935958 |\n", - "| std | 1.02 |\n", - "| value_loss | 82.6 |\n", - "-------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 3000 |\n", - "| time_elapsed | 225 |\n", - "| total_timesteps | 15000 |\n", - "| train/ | |\n", - "| 
entropy_loss | -41.8 |\n", - "| explained_variance | 1.79e-07 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 2999 |\n", - "| policy_loss | -1.51 |\n", - "| reward | -0.4458434 |\n", - "| std | 1.02 |\n", - "| value_loss | 0.0488 |\n", - "--------------------------------------\n", - "---------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 3100 |\n", - "| time_elapsed | 231 |\n", - "| total_timesteps | 15500 |\n", - "| train/ | |\n", - "| entropy_loss | -41.8 |\n", - "| explained_variance | -1.19e-07 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 3099 |\n", - "| policy_loss | -66.7 |\n", - "| reward | -0.29501697 |\n", - "| std | 1.02 |\n", - "| value_loss | 5.04 |\n", - "---------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 3200 |\n", - "| time_elapsed | 237 |\n", - "| total_timesteps | 16000 |\n", - "| train/ | |\n", - "| entropy_loss | -41.8 |\n", - "| explained_variance | 1.19e-07 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 3199 |\n", - "| policy_loss | 519 |\n", - "| reward | 3.2372885 |\n", - "| std | 1.02 |\n", - "| value_loss | 169 |\n", - "-------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 3300 |\n", - "| time_elapsed | 247 |\n", - "| total_timesteps | 16500 |\n", - "| train/ | |\n", - "| entropy_loss | -41.8 |\n", - "| explained_variance | 1.19e-07 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 3299 |\n", - "| policy_loss | -35.2 |\n", - "| reward | -1.0215844 |\n", - "| std | 1.02 |\n", - "| value_loss | 1.25 |\n", - "--------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 3400 |\n", - "| time_elapsed | 254 |\n", - "| total_timesteps | 17000 |\n", - "| train/ | |\n", - "| entropy_loss | -41.9 |\n", 
- "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 3399 |\n", - "| policy_loss | -261 |\n", - "| reward | -1.4371744 |\n", - "| std | 1.03 |\n", - "| value_loss | 126 |\n", - "--------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 3500 |\n", - "| time_elapsed | 259 |\n", - "| total_timesteps | 17500 |\n", - "| train/ | |\n", - "| entropy_loss | -41.9 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 3499 |\n", - "| policy_loss | -772 |\n", - "| reward | -4.4645677 |\n", - "| std | 1.03 |\n", - "| value_loss | 364 |\n", - "--------------------------------------\n", - "---------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 3600 |\n", - "| time_elapsed | 269 |\n", - "| total_timesteps | 18000 |\n", - "| train/ | |\n", - "| entropy_loss | -41.9 |\n", - "| explained_variance | -0.0184 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 3599 |\n", - "| policy_loss | 29.8 |\n", - "| reward | -0.22187884 |\n", - "| std | 1.03 |\n", - "| value_loss | 0.795 |\n", - "---------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 3700 |\n", - "| time_elapsed | 276 |\n", - "| total_timesteps | 18500 |\n", - "| train/ | |\n", - "| entropy_loss | -41.9 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 3699 |\n", - "| policy_loss | 21.5 |\n", - "| reward | 0.2606225 |\n", - "| std | 1.03 |\n", - "| value_loss | 2.37 |\n", - "-------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 3800 |\n", - "| time_elapsed | 281 |\n", - "| total_timesteps | 19000 |\n", - "| train/ | |\n", - "| entropy_loss | -41.9 |\n", - "| explained_variance | -1.19e-07 |\n", - "| 
learning_rate | 0.0007 |\n", - "| n_updates | 3799 |\n", - "| policy_loss | 61.2 |\n", - "| reward | 0.2569313 |\n", - "| std | 1.03 |\n", - "| value_loss | 8.23 |\n", - "-------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 3900 |\n", - "| time_elapsed | 291 |\n", - "| total_timesteps | 19500 |\n", - "| train/ | |\n", - "| entropy_loss | -41.9 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 3899 |\n", - "| policy_loss | -82.6 |\n", - "| reward | 0.6337838 |\n", - "| std | 1.03 |\n", - "| value_loss | 5.46 |\n", - "-------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 4000 |\n", - "| time_elapsed | 298 |\n", - "| total_timesteps | 20000 |\n", - "| train/ | |\n", - "| entropy_loss | -41.9 |\n", - "| explained_variance | 0.00329 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 3999 |\n", - "| policy_loss | -282 |\n", - "| reward | -14.585105 |\n", - "| std | 1.03 |\n", - "| value_loss | 57.7 |\n", - "--------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 4100 |\n", - "| time_elapsed | 304 |\n", - "| total_timesteps | 20500 |\n", - "| train/ | |\n", - "| entropy_loss | -42 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 4099 |\n", - "| policy_loss | -300 |\n", - "| reward | 0.90241796 |\n", - "| std | 1.03 |\n", - "| value_loss | 50.1 |\n", - "--------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 4200 |\n", - "| time_elapsed | 314 |\n", - "| total_timesteps | 21000 |\n", - "| train/ | |\n", - "| entropy_loss | -42 |\n", - "| explained_variance | -1.19e-07 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 4199 
|\n", - "| policy_loss | 37.2 |\n", - "| reward | -1.0533274 |\n", - "| std | 1.03 |\n", - "| value_loss | 1.52 |\n", - "--------------------------------------\n", - "---------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 4300 |\n", - "| time_elapsed | 321 |\n", - "| total_timesteps | 21500 |\n", - "| train/ | |\n", - "| entropy_loss | -42.1 |\n", - "| explained_variance | 0.153 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 4299 |\n", - "| policy_loss | 81.8 |\n", - "| reward | -0.45561686 |\n", - "| std | 1.03 |\n", - "| value_loss | 4.86 |\n", - "---------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 4400 |\n", - "| time_elapsed | 327 |\n", - "| total_timesteps | 22000 |\n", - "| train/ | |\n", - "| entropy_loss | -42.1 |\n", - "| explained_variance | -0.0523 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 4399 |\n", - "| policy_loss | -14.6 |\n", - "| reward | -2.289123 |\n", - "| std | 1.03 |\n", - "| value_loss | 1.85 |\n", - "-------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 4500 |\n", - "| time_elapsed | 336 |\n", - "| total_timesteps | 22500 |\n", - "| train/ | |\n", - "| entropy_loss | -42.1 |\n", - "| explained_variance | 1.19e-07 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 4499 |\n", - "| policy_loss | 139 |\n", - "| reward | 5.3293443 |\n", - "| std | 1.03 |\n", - "| value_loss | 15.6 |\n", - "-------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 4600 |\n", - "| time_elapsed | 343 |\n", - "| total_timesteps | 23000 |\n", - "| train/ | |\n", - "| entropy_loss | -42.1 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 4599 |\n", - "| policy_loss | 219 |\n", - "| reward 
| 1.2114522 |\n", - "| std | 1.04 |\n", - "| value_loss | 46.6 |\n", - "-------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 4700 |\n", - "| time_elapsed | 349 |\n", - "| total_timesteps | 23500 |\n", - "| train/ | |\n", - "| entropy_loss | -42.1 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 4699 |\n", - "| policy_loss | 43.9 |\n", - "| reward | 2.9557362 |\n", - "| std | 1.03 |\n", - "| value_loss | 2.94 |\n", - "-------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 4800 |\n", - "| time_elapsed | 358 |\n", - "| total_timesteps | 24000 |\n", - "| train/ | |\n", - "| entropy_loss | -42.1 |\n", - "| explained_variance | 1.19e-07 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 4799 |\n", - "| policy_loss | -18.3 |\n", - "| reward | -0.8173089 |\n", - "| std | 1.03 |\n", - "| value_loss | 1.11 |\n", - "--------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 4900 |\n", - "| time_elapsed | 366 |\n", - "| total_timesteps | 24500 |\n", - "| train/ | |\n", - "| entropy_loss | -42.1 |\n", - "| explained_variance | 4.77e-07 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 4899 |\n", - "| policy_loss | 140 |\n", - "| reward | 0.12697145 |\n", - "| std | 1.03 |\n", - "| value_loss | 13.9 |\n", - "--------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 5000 |\n", - "| time_elapsed | 372 |\n", - "| total_timesteps | 25000 |\n", - "| train/ | |\n", - "| entropy_loss | -42.1 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 4999 |\n", - "| policy_loss | -184 |\n", - "| reward | 1.2172059 |\n", - "| std | 1.03 |\n", - "| 
value_loss | 22.1 |\n", - "-------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 5100 |\n", - "| time_elapsed | 380 |\n", - "| total_timesteps | 25500 |\n", - "| train/ | |\n", - "| entropy_loss | -42.1 |\n", - "| explained_variance | -1.19e-07 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 5099 |\n", - "| policy_loss | 311 |\n", - "| reward | 1.1731412 |\n", - "| std | 1.03 |\n", - "| value_loss | 64.2 |\n", - "-------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 5200 |\n", - "| time_elapsed | 388 |\n", - "| total_timesteps | 26000 |\n", - "| train/ | |\n", - "| entropy_loss | -42.1 |\n", - "| explained_variance | -0.00065 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 5199 |\n", - "| policy_loss | 109 |\n", - "| reward | 1.1864622 |\n", - "| std | 1.03 |\n", - "| value_loss | 9.56 |\n", - "-------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 5300 |\n", - "| time_elapsed | 394 |\n", - "| total_timesteps | 26500 |\n", - "| train/ | |\n", - "| entropy_loss | -42.1 |\n", - "| explained_variance | 1.19e-07 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 5299 |\n", - "| policy_loss | 306 |\n", - "| reward | -5.989423 |\n", - "| std | 1.04 |\n", - "| value_loss | 59.3 |\n", - "-------------------------------------\n", - "day: 2956, episode: 10\n", - "begin_total_asset: 1000000.00\n", - "end_total_asset: 6683165.62\n", - "total_reward: 5683165.62\n", - "total_cost: 23114.39\n", - "total_trades: 56020\n", - "Sharpe: 0.863\n", - "=================================\n", - "---------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 5400 |\n", - "| time_elapsed | 403 |\n", - "| total_timesteps | 27000 |\n", - "| train/ | |\n", - "| 
entropy_loss | -42.2 |\n", - "| explained_variance | 1.19e-07 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 5399 |\n", - "| policy_loss | -88.2 |\n", - "| reward | -0.79377145 |\n", - "| std | 1.04 |\n", - "| value_loss | 4.92 |\n", - "---------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 5500 |\n", - "| time_elapsed | 411 |\n", - "| total_timesteps | 27500 |\n", - "| train/ | |\n", - "| entropy_loss | -42.2 |\n", - "| explained_variance | -0.49 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 5499 |\n", - "| policy_loss | 30 |\n", - "| reward | 0.04138078 |\n", - "| std | 1.04 |\n", - "| value_loss | 0.815 |\n", - "--------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 5600 |\n", - "| time_elapsed | 416 |\n", - "| total_timesteps | 28000 |\n", - "| train/ | |\n", - "| entropy_loss | -42.2 |\n", - "| explained_variance | -1.19e-07 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 5599 |\n", - "| policy_loss | -84.1 |\n", - "| reward | 0.9021898 |\n", - "| std | 1.04 |\n", - "| value_loss | 9.85 |\n", - "-------------------------------------\n", - "---------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 5700 |\n", - "| time_elapsed | 424 |\n", - "| total_timesteps | 28500 |\n", - "| train/ | |\n", - "| entropy_loss | -42.3 |\n", - "| explained_variance | 5.96e-08 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 5699 |\n", - "| policy_loss | -34.2 |\n", - "| reward | -0.18503402 |\n", - "| std | 1.04 |\n", - "| value_loss | 2.41 |\n", - "---------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 5800 |\n", - "| time_elapsed | 433 |\n", - "| total_timesteps | 29000 |\n", - "| train/ | |\n", - "| entropy_loss | -42.3 |\n", - 
"| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 5799 |\n", - "| policy_loss | 111 |\n", - "| reward | 0.7673725 |\n", - "| std | 1.04 |\n", - "| value_loss | 15 |\n", - "-------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 5900 |\n", - "| time_elapsed | 438 |\n", - "| total_timesteps | 29500 |\n", - "| train/ | |\n", - "| entropy_loss | -42.4 |\n", - "| explained_variance | -0.000187 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 5899 |\n", - "| policy_loss | -75.2 |\n", - "| reward | -1.0595187 |\n", - "| std | 1.05 |\n", - "| value_loss | 8.22 |\n", - "--------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 6000 |\n", - "| time_elapsed | 446 |\n", - "| total_timesteps | 30000 |\n", - "| train/ | |\n", - "| entropy_loss | -42.4 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 5999 |\n", - "| policy_loss | -148 |\n", - "| reward | 0.9425562 |\n", - "| std | 1.05 |\n", - "| value_loss | 15.3 |\n", - "-------------------------------------\n", - "---------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 6100 |\n", - "| time_elapsed | 455 |\n", - "| total_timesteps | 30500 |\n", - "| train/ | |\n", - "| entropy_loss | -42.4 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 6099 |\n", - "| policy_loss | 54.9 |\n", - "| reward | -0.46677366 |\n", - "| std | 1.04 |\n", - "| value_loss | 5.95 |\n", - "---------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 6200 |\n", - "| time_elapsed | 462 |\n", - "| total_timesteps | 31000 |\n", - "| train/ | |\n", - "| entropy_loss | -42.4 |\n", - "| explained_variance | -2.38e-07 |\n", - "| 
learning_rate | 0.0007 |\n", - "| n_updates | 6199 |\n", - "| policy_loss | -212 |\n", - "| reward | -3.7030914 |\n", - "| std | 1.05 |\n", - "| value_loss | 57.7 |\n", - "--------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 6300 |\n", - "| time_elapsed | 472 |\n", - "| total_timesteps | 31500 |\n", - "| train/ | |\n", - "| entropy_loss | -42.4 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 6299 |\n", - "| policy_loss | 83 |\n", - "| reward | 0.8240016 |\n", - "| std | 1.05 |\n", - "| value_loss | 6.22 |\n", - "-------------------------------------\n", - "------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 6400 |\n", - "| time_elapsed | 480 |\n", - "| total_timesteps | 32000 |\n", - "| train/ | |\n", - "| entropy_loss | -42.5 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 6399 |\n", - "| policy_loss | 228 |\n", - "| reward | 8.043127 |\n", - "| std | 1.05 |\n", - "| value_loss | 44.7 |\n", - "------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 6500 |\n", - "| time_elapsed | 486 |\n", - "| total_timesteps | 32500 |\n", - "| train/ | |\n", - "| entropy_loss | -42.5 |\n", - "| explained_variance | 5.96e-08 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 6499 |\n", - "| policy_loss | -34.8 |\n", - "| reward | 0.5315314 |\n", - "| std | 1.05 |\n", - "| value_loss | 45.5 |\n", - "-------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 6600 |\n", - "| time_elapsed | 494 |\n", - "| total_timesteps | 33000 |\n", - "| train/ | |\n", - "| entropy_loss | -42.5 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 6599 |\n", - "| 
policy_loss | -165 |\n", - "| reward | -2.345529 |\n", - "| std | 1.05 |\n", - "| value_loss | 16.7 |\n", - "-------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 6700 |\n", - "| time_elapsed | 503 |\n", - "| total_timesteps | 33500 |\n", - "| train/ | |\n", - "| entropy_loss | -42.5 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 6699 |\n", - "| policy_loss | 15.2 |\n", - "| reward | 1.4328392 |\n", - "| std | 1.05 |\n", - "| value_loss | 1.09 |\n", - "-------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 6800 |\n", - "| time_elapsed | 509 |\n", - "| total_timesteps | 34000 |\n", - "| train/ | |\n", - "| entropy_loss | -42.5 |\n", - "| explained_variance | 0.0389 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 6799 |\n", - "| policy_loss | 17.9 |\n", - "| reward | -3.0146182 |\n", - "| std | 1.05 |\n", - "| value_loss | 6.76 |\n", - "--------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 6900 |\n", - "| time_elapsed | 518 |\n", - "| total_timesteps | 34500 |\n", - "| train/ | |\n", - "| entropy_loss | -42.5 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 6899 |\n", - "| policy_loss | 78.2 |\n", - "| reward | 1.1767033 |\n", - "| std | 1.05 |\n", - "| value_loss | 6.61 |\n", - "-------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 7000 |\n", - "| time_elapsed | 526 |\n", - "| total_timesteps | 35000 |\n", - "| train/ | |\n", - "| entropy_loss | -42.5 |\n", - "| explained_variance | -0.000428 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 6999 |\n", - "| policy_loss | 150 |\n", - "| reward | 1.8981657 |\n", - 
"| std | 1.05 |\n", - "| value_loss | 13.5 |\n", - "-------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 7100 |\n", - "| time_elapsed | 531 |\n", - "| total_timesteps | 35500 |\n", - "| train/ | |\n", - "| entropy_loss | -42.5 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 7099 |\n", - "| policy_loss | 5.28 |\n", - "| reward | 0.79177964 |\n", - "| std | 1.05 |\n", - "| value_loss | 0.0603 |\n", - "--------------------------------------\n", - "---------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 7200 |\n", - "| time_elapsed | 540 |\n", - "| total_timesteps | 36000 |\n", - "| train/ | |\n", - "| entropy_loss | -42.5 |\n", - "| explained_variance | -1.73e-05 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 7199 |\n", - "| policy_loss | 40.5 |\n", - "| reward | -0.41675776 |\n", - "| std | 1.05 |\n", - "| value_loss | 2.05 |\n", - "---------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 7300 |\n", - "| time_elapsed | 548 |\n", - "| total_timesteps | 36500 |\n", - "| train/ | |\n", - "| entropy_loss | -42.6 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 7299 |\n", - "| policy_loss | 109 |\n", - "| reward | -0.614702 |\n", - "| std | 1.05 |\n", - "| value_loss | 8.67 |\n", - "-------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 7400 |\n", - "| time_elapsed | 554 |\n", - "| total_timesteps | 37000 |\n", - "| train/ | |\n", - "| entropy_loss | -42.6 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 7399 |\n", - "| policy_loss | -406 |\n", - "| reward | 2.9344776 |\n", - "| std | 1.05 |\n", - "| value_loss | 94.3 |\n", - 
"-------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 7500 |\n", - "| time_elapsed | 562 |\n", - "| total_timesteps | 37500 |\n", - "| train/ | |\n", - "| entropy_loss | -42.6 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 7499 |\n", - "| policy_loss | -46.1 |\n", - "| reward | 3.3778825 |\n", - "| std | 1.05 |\n", - "| value_loss | 3.56 |\n", - "-------------------------------------\n", - "------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 7600 |\n", - "| time_elapsed | 571 |\n", - "| total_timesteps | 38000 |\n", - "| train/ | |\n", - "| entropy_loss | -42.7 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 7599 |\n", - "| policy_loss | 68.1 |\n", - "| reward | 8.8951 |\n", - "| std | 1.06 |\n", - "| value_loss | 6.43 |\n", - "------------------------------------\n", - "----------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 7700 |\n", - "| time_elapsed | 576 |\n", - "| total_timesteps | 38500 |\n", - "| train/ | |\n", - "| entropy_loss | -42.7 |\n", - "| explained_variance | 0.0107 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 7699 |\n", - "| policy_loss | 19.1 |\n", - "| reward | -0.079490714 |\n", - "| std | 1.06 |\n", - "| value_loss | 0.259 |\n", - "----------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 7800 |\n", - "| time_elapsed | 585 |\n", - "| total_timesteps | 39000 |\n", - "| train/ | |\n", - "| entropy_loss | -42.7 |\n", - "| explained_variance | -1.19e-07 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 7799 |\n", - "| policy_loss | 38.6 |\n", - "| reward | -0.7558417 |\n", - "| std | 1.06 |\n", - "| value_loss | 1.37 |\n", - "--------------------------------------\n", - 
"--------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 7900 |\n", - "| time_elapsed | 593 |\n", - "| total_timesteps | 39500 |\n", - "| train/ | |\n", - "| entropy_loss | -42.8 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 7899 |\n", - "| policy_loss | -73.1 |\n", - "| reward | 0.98963994 |\n", - "| std | 1.06 |\n", - "| value_loss | 3.09 |\n", - "--------------------------------------\n", - "---------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 8000 |\n", - "| time_elapsed | 599 |\n", - "| total_timesteps | 40000 |\n", - "| train/ | |\n", - "| entropy_loss | -42.9 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 7999 |\n", - "| policy_loss | -9.09 |\n", - "| reward | -0.19837299 |\n", - "| std | 1.06 |\n", - "| value_loss | 0.497 |\n", - "---------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 8100 |\n", - "| time_elapsed | 608 |\n", - "| total_timesteps | 40500 |\n", - "| train/ | |\n", - "| entropy_loss | -43 |\n", - "| explained_variance | -1.19e-07 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 8099 |\n", - "| policy_loss | -54 |\n", - "| reward | 4.5389633 |\n", - "| std | 1.07 |\n", - "| value_loss | 3.58 |\n", - "-------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 8200 |\n", - "| time_elapsed | 616 |\n", - "| total_timesteps | 41000 |\n", - "| train/ | |\n", - "| entropy_loss | -43 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 8199 |\n", - "| policy_loss | -1.03e+03 |\n", - "| reward | -7.644335 |\n", - "| std | 1.07 |\n", - "| value_loss | 644 |\n", - "-------------------------------------\n", - "--------------------------------------\n", - "| 
time/ | |\n", - "| fps | 66 |\n", - "| iterations | 8300 |\n", - "| time_elapsed | 622 |\n", - "| total_timesteps | 41500 |\n", - "| train/ | |\n", - "| entropy_loss | -43 |\n", - "| explained_variance | -0.0583 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 8299 |\n", - "| policy_loss | -117 |\n", - "| reward | -1.0602136 |\n", - "| std | 1.07 |\n", - "| value_loss | 13.6 |\n", - "--------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 8400 |\n", - "| time_elapsed | 630 |\n", - "| total_timesteps | 42000 |\n", - "| train/ | |\n", - "| entropy_loss | -43 |\n", - "| explained_variance | 0.000235 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 8399 |\n", - "| policy_loss | -94.3 |\n", - "| reward | 0.6174811 |\n", - "| std | 1.07 |\n", - "| value_loss | 7.66 |\n", - "-------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 8500 |\n", - "| time_elapsed | 638 |\n", - "| total_timesteps | 42500 |\n", - "| train/ | |\n", - "| entropy_loss | -43.1 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 8499 |\n", - "| policy_loss | -36.7 |\n", - "| reward | -1.7684387 |\n", - "| std | 1.07 |\n", - "| value_loss | 3.43 |\n", - "--------------------------------------\n", - "------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 8600 |\n", - "| time_elapsed | 644 |\n", - "| total_timesteps | 43000 |\n", - "| train/ | |\n", - "| entropy_loss | -43.1 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 8599 |\n", - "| policy_loss | 5.35 |\n", - "| reward | 3.532196 |\n", - "| std | 1.07 |\n", - "| value_loss | 4.08 |\n", - "------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 8700 
|\n", - "| time_elapsed | 652 |\n", - "| total_timesteps | 43500 |\n", - "| train/ | |\n", - "| entropy_loss | -43.1 |\n", - "| explained_variance | -1.19e-07 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 8699 |\n", - "| policy_loss | -87.6 |\n", - "| reward | 2.3305223 |\n", - "| std | 1.07 |\n", - "| value_loss | 4.95 |\n", - "-------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 8800 |\n", - "| time_elapsed | 661 |\n", - "| total_timesteps | 44000 |\n", - "| train/ | |\n", - "| entropy_loss | -43.1 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 8799 |\n", - "| policy_loss | 932 |\n", - "| reward | -1.1599989 |\n", - "| std | 1.07 |\n", - "| value_loss | 523 |\n", - "--------------------------------------\n", - "------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 8900 |\n", - "| time_elapsed | 667 |\n", - "| total_timesteps | 44500 |\n", - "| train/ | |\n", - "| entropy_loss | -43.1 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 8899 |\n", - "| policy_loss | 95.4 |\n", - "| reward | 1.23851 |\n", - "| std | 1.07 |\n", - "| value_loss | 5.59 |\n", - "------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 9000 |\n", - "| time_elapsed | 674 |\n", - "| total_timesteps | 45000 |\n", - "| train/ | |\n", - "| entropy_loss | -43.1 |\n", - "| explained_variance | 1.79e-07 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 8999 |\n", - "| policy_loss | 78 |\n", - "| reward | 0.81774557 |\n", - "| std | 1.07 |\n", - "| value_loss | 3.88 |\n", - "--------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 9100 |\n", - "| time_elapsed | 683 |\n", - "| total_timesteps 
| 45500 |\n", - "| train/ | |\n", - "| entropy_loss | -43.1 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 9099 |\n", - "| policy_loss | -16 |\n", - "| reward | -1.3974568 |\n", - "| std | 1.07 |\n", - "| value_loss | 0.808 |\n", - "--------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 9200 |\n", - "| time_elapsed | 689 |\n", - "| total_timesteps | 46000 |\n", - "| train/ | |\n", - "| entropy_loss | -43.1 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 9199 |\n", - "| policy_loss | 200 |\n", - "| reward | 0.9757047 |\n", - "| std | 1.07 |\n", - "| value_loss | 25.5 |\n", - "-------------------------------------\n", - "------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 9300 |\n", - "| time_elapsed | 697 |\n", - "| total_timesteps | 46500 |\n", - "| train/ | |\n", - "| entropy_loss | -43.1 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 9299 |\n", - "| policy_loss | 126 |\n", - "| reward | 6.874766 |\n", - "| std | 1.07 |\n", - "| value_loss | 11.4 |\n", - "------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 9400 |\n", - "| time_elapsed | 706 |\n", - "| total_timesteps | 47000 |\n", - "| train/ | |\n", - "| entropy_loss | -43.1 |\n", - "| explained_variance | -2.38e-07 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 9399 |\n", - "| policy_loss | 111 |\n", - "| reward | 4.1356363 |\n", - "| std | 1.07 |\n", - "| value_loss | 80.4 |\n", - "-------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 9500 |\n", - "| time_elapsed | 712 |\n", - "| total_timesteps | 47500 |\n", - "| train/ | |\n", - "| entropy_loss | -43.1 |\n", 
- "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 9499 |\n", - "| policy_loss | 41.3 |\n", - "| reward | 0.15961184 |\n", - "| std | 1.07 |\n", - "| value_loss | 1.22 |\n", - "--------------------------------------\n", - "----------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 9600 |\n", - "| time_elapsed | 719 |\n", - "| total_timesteps | 48000 |\n", - "| train/ | |\n", - "| entropy_loss | -43.2 |\n", - "| explained_variance | 0.203 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 9599 |\n", - "| policy_loss | -9.32 |\n", - "| reward | -0.113297194 |\n", - "| std | 1.07 |\n", - "| value_loss | 0.747 |\n", - "----------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 9700 |\n", - "| time_elapsed | 728 |\n", - "| total_timesteps | 48500 |\n", - "| train/ | |\n", - "| entropy_loss | -43.1 |\n", - "| explained_variance | -1.19e-07 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 9699 |\n", - "| policy_loss | -11.3 |\n", - "| reward | -0.942293 |\n", - "| std | 1.07 |\n", - "| value_loss | 1.77 |\n", - "-------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 9800 |\n", - "| time_elapsed | 734 |\n", - "| total_timesteps | 49000 |\n", - "| train/ | |\n", - "| entropy_loss | -43 |\n", - "| explained_variance | 0 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 9799 |\n", - "| policy_loss | -251 |\n", - "| reward | -1.0976613 |\n", - "| std | 1.07 |\n", - "| value_loss | 32.3 |\n", - "--------------------------------------\n", - "-------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 9900 |\n", - "| time_elapsed | 742 |\n", - "| total_timesteps | 49500 |\n", - "| train/ | |\n", - "| entropy_loss | -42.9 |\n", - "| explained_variance | -1.19e-07 |\n", - 
"| learning_rate | 0.0007 |\n", - "| n_updates | 9899 |\n", - "| policy_loss | -193 |\n", - "| reward | 6.7274637 |\n", - "| std | 1.07 |\n", - "| value_loss | 27.1 |\n", - "-------------------------------------\n", - "--------------------------------------\n", - "| time/ | |\n", - "| fps | 66 |\n", - "| iterations | 10000 |\n", - "| time_elapsed | 751 |\n", - "| total_timesteps | 50000 |\n", - "| train/ | |\n", - "| entropy_loss | -42.9 |\n", - "| explained_variance | -1.19e-07 |\n", - "| learning_rate | 0.0007 |\n", - "| n_updates | 9999 |\n", - "| policy_loss | 335 |\n", - "| reward | -16.778671 |\n", - "| std | 1.06 |\n", - "| value_loss | 330 |\n", - "--------------------------------------\n" - ] - } - ], - "source": [ - "trained_a2c = agent.train_model(model=model_a2c, \n", - " tb_log_name='a2c',\n", - " total_timesteps=50000) if if_using_a2c else None" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MRiOtrywfAo1" - }, - "source": [ - "### Agent 2: DDPG" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "M2YadjfnLwgt", - "outputId": "8c8b5e98-763c-453c-a280-1b4f3ac13510" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "{'batch_size': 128, 'buffer_size': 50000, 'learning_rate': 0.001}\n", - "Using cpu device\n", - "Logging to results/ddpg\n" - ] - } - ], - "source": [ - "agent = DRLAgent(env = env_train)\n", - "model_ddpg = agent.get_model(\"ddpg\")\n", - "\n", - "if if_using_ddpg:\n", - " # set up logger\n", - " tmp_path = RESULTS_DIR + '/ddpg'\n", - " new_logger_ddpg = configure(tmp_path, [\"stdout\", \"csv\", \"tensorboard\"])\n", - " # Set new logger\n", - " model_ddpg.set_logger(new_logger_ddpg)" - ] - }, - { - "cell_type": "code", - "execution_count": 26, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "tCDa78rqfO_a", - "outputId": "35589661-85de-42ca-b9f1-52cde7ded447" - }, - 
"outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "day: 2956, episode: 20\n", - "begin_total_asset: 1000000.00\n", - "end_total_asset: 2896568.03\n", - "total_reward: 1896568.03\n", - "total_cost: 1065.27\n", - "total_trades: 47321\n", - "Sharpe: 0.607\n", - "=================================\n", - "-----------------------------------\n", - "| time/ | |\n", - "| episodes | 4 |\n", - "| fps | 24 |\n", - "| time_elapsed | 482 |\n", - "| total_timesteps | 11828 |\n", - "| train/ | |\n", - "| actor_loss | 3.26 |\n", - "| critic_loss | 41.2 |\n", - "| learning_rate | 0.001 |\n", - "| n_updates | 8871 |\n", - "| reward | -4.6516004 |\n", - "-----------------------------------\n", - "-----------------------------------\n", - "| time/ | |\n", - "| episodes | 8 |\n", - "| fps | 21 |\n", - "| time_elapsed | 1088 |\n", - "| total_timesteps | 23656 |\n", - "| train/ | |\n", - "| actor_loss | -2.37 |\n", - "| critic_loss | 2.02 |\n", - "| learning_rate | 0.001 |\n", - "| n_updates | 20699 |\n", - "| reward | -4.6516004 |\n", - "-----------------------------------\n", - "day: 2956, episode: 30\n", - "begin_total_asset: 1000000.00\n", - "end_total_asset: 2898382.52\n", - "total_reward: 1898382.52\n", - "total_cost: 1065.67\n", - "total_trades: 50276\n", - "Sharpe: 0.598\n", - "=================================\n", - "-----------------------------------\n", - "| time/ | |\n", - "| episodes | 12 |\n", - "| fps | 20 |\n", - "| time_elapsed | 1709 |\n", - "| total_timesteps | 35484 |\n", - "| train/ | |\n", - "| actor_loss | -3.95 |\n", - "| critic_loss | 1.36 |\n", - "| learning_rate | 0.001 |\n", - "| n_updates | 32527 |\n", - "| reward | -4.6516004 |\n", - "-----------------------------------\n", - "-----------------------------------\n", - "| time/ | |\n", - "| episodes | 16 |\n", - "| fps | 20 |\n", - "| time_elapsed | 2344 |\n", - "| total_timesteps | 47312 |\n", - "| train/ | |\n", - "| actor_loss | -5.3 |\n", - "| critic_loss | 0.938 |\n", - "| 
learning_rate | 0.001 |\n", - "| n_updates | 44355 |\n", - "| reward | -4.6516004 |\n", - "-----------------------------------\n" - ] - } - ], - "source": [ - "trained_ddpg = agent.train_model(model=model_ddpg, \n", - " tb_log_name='ddpg',\n", - " total_timesteps=50000) if if_using_ddpg else None" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_gDkU-j-fCmZ" - }, - "source": [ - "### Agent 3: PPO" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "y5D5PFUhMzSV", - "outputId": "2abd06c0-deca-457b-819b-3059c3f17645" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "{'n_steps': 2048, 'ent_coef': 0.01, 'learning_rate': 0.00025, 'batch_size': 128}\n", - "Using cpu device\n", - "Logging to results/ppo\n" - ] - } - ], - "source": [ - "agent = DRLAgent(env = env_train)\n", - "PPO_PARAMS = {\n", - " \"n_steps\": 2048,\n", - " \"ent_coef\": 0.01,\n", - " \"learning_rate\": 0.00025,\n", - " \"batch_size\": 128,\n", - "}\n", - "model_ppo = agent.get_model(\"ppo\",model_kwargs = PPO_PARAMS)\n", - "\n", - "if if_using_ppo:\n", - " # set up logger\n", - " tmp_path = RESULTS_DIR + '/ppo'\n", - " new_logger_ppo = configure(tmp_path, [\"stdout\", \"csv\", \"tensorboard\"])\n", - " # Set new logger\n", - " model_ppo.set_logger(new_logger_ppo)" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Gt8eIQKYM4G3", - "outputId": "26365c9a-f608-4dd4-9695-018b98d1036a" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "---------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 1 |\n", - "| time_elapsed | 30 |\n", - "| total_timesteps | 2048 |\n", - "| train/ | |\n", - "| reward | 2.736287 |\n", - "---------------------------------\n", - "-----------------------------------------\n", - "| 
time/ | |\n", - "| fps | 69 |\n", - "| iterations | 2 |\n", - "| time_elapsed | 59 |\n", - "| total_timesteps | 4096 |\n", - "| train/ | |\n", - "| approx_kl | 0.012204561 |\n", - "| clip_fraction | 0.187 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -41.2 |\n", - "| explained_variance | -0.0229 |\n", - "| learning_rate | 0.00025 |\n", - "| loss | 9.56 |\n", - "| n_updates | 10 |\n", - "| policy_gradient_loss | -0.0219 |\n", - "| reward | -1.1999706 |\n", - "| std | 1 |\n", - "| value_loss | 19.6 |\n", - "-----------------------------------------\n", - "---------------------------------------\n", - "| time/ | |\n", - "| fps | 68 |\n", - "| iterations | 3 |\n", - "| time_elapsed | 89 |\n", - "| total_timesteps | 6144 |\n", - "| train/ | |\n", - "| approx_kl | 0.0233445 |\n", - "| clip_fraction | 0.193 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -41.2 |\n", - "| explained_variance | 0.00147 |\n", - "| learning_rate | 0.00025 |\n", - "| loss | 40.7 |\n", - "| n_updates | 20 |\n", - "| policy_gradient_loss | -0.0196 |\n", - "| reward | 1.6823713 |\n", - "| std | 1 |\n", - "| value_loss | 95.8 |\n", - "---------------------------------------\n", - "----------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 4 |\n", - "| time_elapsed | 121 |\n", - "| total_timesteps | 8192 |\n", - "| train/ | |\n", - "| approx_kl | 0.01568921 |\n", - "| clip_fraction | 0.155 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -41.3 |\n", - "| explained_variance | -0.0312 |\n", - "| learning_rate | 0.00025 |\n", - "| loss | 13.1 |\n", - "| n_updates | 30 |\n", - "| policy_gradient_loss | -0.0172 |\n", - "| reward | 0.4094886 |\n", - "| std | 1 |\n", - "| value_loss | 38.3 |\n", - "----------------------------------------\n", - "----------------------------------------\n", - "| time/ | |\n", - "| fps | 68 |\n", - "| iterations | 5 |\n", - "| time_elapsed | 149 |\n", - "| total_timesteps | 10240 |\n", - "| train/ | |\n", - "| 
approx_kl | 0.01757015 |\n", - "| clip_fraction | 0.222 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -41.3 |\n", - "| explained_variance | -0.0286 |\n", - "| learning_rate | 0.00025 |\n", - "| loss | 6.71 |\n", - "| n_updates | 40 |\n", - "| policy_gradient_loss | -0.024 |\n", - "| reward | -1.0190102 |\n", - "| std | 1.01 |\n", - "| value_loss | 19.7 |\n", - "----------------------------------------\n", - "day: 2956, episode: 40\n", - "begin_total_asset: 1000000.00\n", - "end_total_asset: 3710303.65\n", - "total_reward: 2710303.65\n", - "total_cost: 383322.44\n", - "total_trades: 80523\n", - "Sharpe: 0.663\n", - "=================================\n", - "---------------------------------------\n", - "| time/ | |\n", - "| fps | 68 |\n", - "| iterations | 6 |\n", - "| time_elapsed | 178 |\n", - "| total_timesteps | 12288 |\n", - "| train/ | |\n", - "| approx_kl | 0.0189244 |\n", - "| clip_fraction | 0.2 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -41.4 |\n", - "| explained_variance | 0.000586 |\n", - "| learning_rate | 0.00025 |\n", - "| loss | 54.5 |\n", - "| n_updates | 50 |\n", - "| policy_gradient_loss | -0.0164 |\n", - "| reward | 0.229152 |\n", - "| std | 1.01 |\n", - "| value_loss | 75.9 |\n", - "---------------------------------------\n", - "-----------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 7 |\n", - "| time_elapsed | 211 |\n", - "| total_timesteps | 14336 |\n", - "| train/ | |\n", - "| approx_kl | 0.024691764 |\n", - "| clip_fraction | 0.224 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -41.5 |\n", - "| explained_variance | -0.00182 |\n", - "| learning_rate | 0.00025 |\n", - "| loss | 28.3 |\n", - "| n_updates | 60 |\n", - "| policy_gradient_loss | -0.0165 |\n", - "| reward | 2.8490999 |\n", - "| std | 1.01 |\n", - "| value_loss | 109 |\n", - "-----------------------------------------\n", - "-----------------------------------------\n", - "| time/ | |\n", - "| fps | 68 |\n", - "| 
iterations | 8 |\n", - "| time_elapsed | 239 |\n", - "| total_timesteps | 16384 |\n", - "| train/ | |\n", - "| approx_kl | 0.019562809 |\n", - "| clip_fraction | 0.204 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -41.5 |\n", - "| explained_variance | 0.0154 |\n", - "| learning_rate | 0.00025 |\n", - "| loss | 38.4 |\n", - "| n_updates | 70 |\n", - "| policy_gradient_loss | -0.0196 |\n", - "| reward | -1.1492106 |\n", - "| std | 1.01 |\n", - "| value_loss | 80.2 |\n", - "-----------------------------------------\n", - "-----------------------------------------\n", - "| time/ | |\n", - "| fps | 68 |\n", - "| iterations | 9 |\n", - "| time_elapsed | 268 |\n", - "| total_timesteps | 18432 |\n", - "| train/ | |\n", - "| approx_kl | 0.017293174 |\n", - "| clip_fraction | 0.214 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -41.6 |\n", - "| explained_variance | 0.000796 |\n", - "| learning_rate | 0.00025 |\n", - "| loss | 149 |\n", - "| n_updates | 80 |\n", - "| policy_gradient_loss | -0.0199 |\n", - "| reward | 0.9604615 |\n", - "| std | 1.02 |\n", - "| value_loss | 177 |\n", - "-----------------------------------------\n", - "-----------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 10 |\n", - "| time_elapsed | 301 |\n", - "| total_timesteps | 20480 |\n", - "| train/ | |\n", - "| approx_kl | 0.018084986 |\n", - "| clip_fraction | 0.187 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -41.6 |\n", - "| explained_variance | 0.0248 |\n", - "| learning_rate | 0.00025 |\n", - "| loss | 49 |\n", - "| n_updates | 90 |\n", - "| policy_gradient_loss | -0.0167 |\n", - "| reward | 4.2871847 |\n", - "| std | 1.02 |\n", - "| value_loss | 116 |\n", - "-----------------------------------------\n", - "-----------------------------------------\n", - "| time/ | |\n", - "| fps | 68 |\n", - "| iterations | 11 |\n", - "| time_elapsed | 330 |\n", - "| total_timesteps | 22528 |\n", - "| train/ | |\n", - "| approx_kl | 0.02062156 
|\n", - "| clip_fraction | 0.183 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -41.7 |\n", - "| explained_variance | 0.00724 |\n", - "| learning_rate | 0.00025 |\n", - "| loss | 132 |\n", - "| n_updates | 100 |\n", - "| policy_gradient_loss | -0.0171 |\n", - "| reward | -0.59720796 |\n", - "| std | 1.02 |\n", - "| value_loss | 206 |\n", - "-----------------------------------------\n", - "-----------------------------------------\n", - "| time/ | |\n", - "| fps | 68 |\n", - "| iterations | 12 |\n", - "| time_elapsed | 360 |\n", - "| total_timesteps | 24576 |\n", - "| train/ | |\n", - "| approx_kl | 0.019513914 |\n", - "| clip_fraction | 0.249 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -41.8 |\n", - "| explained_variance | 0.062 |\n", - "| learning_rate | 0.00025 |\n", - "| loss | 12.1 |\n", - "| n_updates | 110 |\n", - "| policy_gradient_loss | -0.0151 |\n", - "| reward | 0.7724313 |\n", - "| std | 1.02 |\n", - "| value_loss | 23.5 |\n", - "-----------------------------------------\n", - "-----------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 13 |\n", - "| time_elapsed | 392 |\n", - "| total_timesteps | 26624 |\n", - "| train/ | |\n", - "| approx_kl | 0.015009267 |\n", - "| clip_fraction | 0.132 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -41.8 |\n", - "| explained_variance | 0.00529 |\n", - "| learning_rate | 0.00025 |\n", - "| loss | 97.5 |\n", - "| n_updates | 120 |\n", - "| policy_gradient_loss | -0.0182 |\n", - "| reward | -0.09052762 |\n", - "| std | 1.02 |\n", - "| value_loss | 192 |\n", - "-----------------------------------------\n", - "-----------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 14 |\n", - "| time_elapsed | 421 |\n", - "| total_timesteps | 28672 |\n", - "| train/ | |\n", - "| approx_kl | 0.019826401 |\n", - "| clip_fraction | 0.22 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -41.8 |\n", - "| explained_variance | 
0.00322 |\n", - "| learning_rate | 0.00025 |\n", - "| loss | 76.1 |\n", - "| n_updates | 130 |\n", - "| policy_gradient_loss | -0.0105 |\n", - "| reward | 4.809871 |\n", - "| std | 1.02 |\n", - "| value_loss | 138 |\n", - "-----------------------------------------\n", - "-----------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 15 |\n", - "| time_elapsed | 454 |\n", - "| total_timesteps | 30720 |\n", - "| train/ | |\n", - "| approx_kl | 0.025237966 |\n", - "| clip_fraction | 0.256 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -41.9 |\n", - "| explained_variance | -0.063 |\n", - "| learning_rate | 0.00025 |\n", - "| loss | 3.86 |\n", - "| n_updates | 140 |\n", - "| policy_gradient_loss | -0.0151 |\n", - "| reward | 0.7681675 |\n", - "| std | 1.03 |\n", - "| value_loss | 12.4 |\n", - "-----------------------------------------\n", - "-----------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 16 |\n", - "| time_elapsed | 485 |\n", - "| total_timesteps | 32768 |\n", - "| train/ | |\n", - "| approx_kl | 0.019492429 |\n", - "| clip_fraction | 0.215 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -41.9 |\n", - "| explained_variance | 0.021 |\n", - "| learning_rate | 0.00025 |\n", - "| loss | 54.9 |\n", - "| n_updates | 150 |\n", - "| policy_gradient_loss | -0.0112 |\n", - "| reward | 0.03731302 |\n", - "| std | 1.03 |\n", - "| value_loss | 83.3 |\n", - "-----------------------------------------\n", - "----------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 17 |\n", - "| time_elapsed | 514 |\n", - "| total_timesteps | 34816 |\n", - "| train/ | |\n", - "| approx_kl | 0.02347028 |\n", - "| clip_fraction | 0.222 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -42 |\n", - "| explained_variance | 0.00434 |\n", - "| learning_rate | 0.00025 |\n", - "| loss | 27.4 |\n", - "| n_updates | 160 |\n", - "| policy_gradient_loss | 
-0.0183 |\n", - "| reward | -1.6830779 |\n", - "| std | 1.03 |\n", - "| value_loss | 77.9 |\n", - "----------------------------------------\n", - "-----------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 18 |\n", - "| time_elapsed | 544 |\n", - "| total_timesteps | 36864 |\n", - "| train/ | |\n", - "| approx_kl | 0.016928129 |\n", - "| clip_fraction | 0.187 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -42.1 |\n", - "| explained_variance | 0.0164 |\n", - "| learning_rate | 0.00025 |\n", - "| loss | 3.95 |\n", - "| n_updates | 170 |\n", - "| policy_gradient_loss | -0.019 |\n", - "| reward | -4.010391 |\n", - "| std | 1.03 |\n", - "| value_loss | 10.3 |\n", - "-----------------------------------------\n", - "-----------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 19 |\n", - "| time_elapsed | 576 |\n", - "| total_timesteps | 38912 |\n", - "| train/ | |\n", - "| approx_kl | 0.017712107 |\n", - "| clip_fraction | 0.194 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -42.1 |\n", - "| explained_variance | 0.000695 |\n", - "| learning_rate | 0.00025 |\n", - "| loss | 10.5 |\n", - "| n_updates | 180 |\n", - "| policy_gradient_loss | -0.0233 |\n", - "| reward | -0.7916964 |\n", - "| std | 1.04 |\n", - "| value_loss | 33.8 |\n", - "-----------------------------------------\n", - "-----------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 20 |\n", - "| time_elapsed | 605 |\n", - "| total_timesteps | 40960 |\n", - "| train/ | |\n", - "| approx_kl | 0.035940316 |\n", - "| clip_fraction | 0.27 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -42.2 |\n", - "| explained_variance | 0.0138 |\n", - "| learning_rate | 0.00025 |\n", - "| loss | 19.5 |\n", - "| n_updates | 190 |\n", - "| policy_gradient_loss | -0.0122 |\n", - "| reward | -3.1133146 |\n", - "| std | 1.04 |\n", - "| value_loss | 73.5 |\n", - 
"-----------------------------------------\n", - "day: 2956, episode: 50\n", - "begin_total_asset: 1000000.00\n", - "end_total_asset: 4489520.08\n", - "total_reward: 3489520.08\n", - "total_cost: 373083.24\n", - "total_trades: 79300\n", - "Sharpe: 0.752\n", - "=================================\n", - "------------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 21 |\n", - "| time_elapsed | 635 |\n", - "| total_timesteps | 43008 |\n", - "| train/ | |\n", - "| approx_kl | 0.028285751 |\n", - "| clip_fraction | 0.23 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -42.2 |\n", - "| explained_variance | -0.0141 |\n", - "| learning_rate | 0.00025 |\n", - "| loss | 12.6 |\n", - "| n_updates | 200 |\n", - "| policy_gradient_loss | -0.0122 |\n", - "| reward | -0.055241805 |\n", - "| std | 1.04 |\n", - "| value_loss | 26.9 |\n", - "------------------------------------------\n", - "-------------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 22 |\n", - "| time_elapsed | 667 |\n", - "| total_timesteps | 45056 |\n", - "| train/ | |\n", - "| approx_kl | 0.015426388 |\n", - "| clip_fraction | 0.218 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -42.3 |\n", - "| explained_variance | -0.01 |\n", - "| learning_rate | 0.00025 |\n", - "| loss | 47.4 |\n", - "| n_updates | 210 |\n", - "| policy_gradient_loss | -0.0143 |\n", - "| reward | -0.0147438925 |\n", - "| std | 1.04 |\n", - "| value_loss | 68.7 |\n", - "-------------------------------------------\n", - "-----------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 23 |\n", - "| time_elapsed | 695 |\n", - "| total_timesteps | 47104 |\n", - "| train/ | |\n", - "| approx_kl | 0.042087153 |\n", - "| clip_fraction | 0.314 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -42.4 |\n", - "| explained_variance | 0.00911 |\n", - "| learning_rate | 0.00025 |\n", - "| loss | 18.8 |\n", - "| n_updates 
| 220 |\n", - "| policy_gradient_loss | -0.0173 |\n", - "| reward | 2.4495711 |\n", - "| std | 1.04 |\n", - "| value_loss | 52.7 |\n", - "-----------------------------------------\n", - "-----------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 24 |\n", - "| time_elapsed | 725 |\n", - "| total_timesteps | 49152 |\n", - "| train/ | |\n", - "| approx_kl | 0.04404234 |\n", - "| clip_fraction | 0.306 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -42.4 |\n", - "| explained_variance | 0.00196 |\n", - "| learning_rate | 0.00025 |\n", - "| loss | 30 |\n", - "| n_updates | 230 |\n", - "| policy_gradient_loss | -0.0121 |\n", - "| reward | -0.55505633 |\n", - "| std | 1.05 |\n", - "| value_loss | 68.1 |\n", - "-----------------------------------------\n", - "-----------------------------------------\n", - "| time/ | |\n", - "| fps | 67 |\n", - "| iterations | 25 |\n", - "| time_elapsed | 757 |\n", - "| total_timesteps | 51200 |\n", - "| train/ | |\n", - "| approx_kl | 0.018749528 |\n", - "| clip_fraction | 0.19 |\n", - "| clip_range | 0.2 |\n", - "| entropy_loss | -42.5 |\n", - "| explained_variance | 0.0504 |\n", - "| learning_rate | 0.00025 |\n", - "| loss | 6.77 |\n", - "| n_updates | 240 |\n", - "| policy_gradient_loss | -0.0183 |\n", - "| reward | 0.8350754 |\n", - "| std | 1.05 |\n", - "| value_loss | 22.2 |\n", - "-----------------------------------------\n" - ] - } - ], - "source": [ - "trained_ppo = agent.train_model(model=model_ppo, \n", - " tb_log_name='ppo',\n", - " total_timesteps=50000) if if_using_ppo else None" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3Zpv4S0-fDBv" - }, - "source": [ - "### Agent 4: TD3" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "JSAHhV4Xc-bh", - "outputId": "db147b9a-163a-4d03-dd6c-9e89f0e8f421" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - 
"text": [ - "{'batch_size': 100, 'buffer_size': 1000000, 'learning_rate': 0.001}\n", - "Using cpu device\n", - "Logging to results/td3\n" - ] - } - ], - "source": [ - "agent = DRLAgent(env = env_train)\n", - "TD3_PARAMS = {\"batch_size\": 100, \n", - " \"buffer_size\": 1000000, \n", - " \"learning_rate\": 0.001}\n", - "\n", - "model_td3 = agent.get_model(\"td3\",model_kwargs = TD3_PARAMS)\n", - "\n", - "if if_using_td3:\n", - " # set up logger\n", - " tmp_path = RESULTS_DIR + '/td3'\n", - " new_logger_td3 = configure(tmp_path, [\"stdout\", \"csv\", \"tensorboard\"])\n", - " # Set new logger\n", - " model_td3.set_logger(new_logger_td3)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "OSRxNYAxdKpU", - "outputId": "1d85d74c-54cf-4682-a34b-481a5aafe5d4" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "-----------------------------------\n", - "| time/ | |\n", - "| episodes | 4 |\n", - "| fps | 25 |\n", - "| time_elapsed | 464 |\n", - "| total_timesteps | 11828 |\n", - "| train/ | |\n", - "| actor_loss | 85.6 |\n", - "| critic_loss | 2.26e+03 |\n", - "| learning_rate | 0.001 |\n", - "| n_updates | 8871 |\n", - "| reward | -5.8078027 |\n", - "-----------------------------------\n", - "day: 2956, episode: 60\n", - "begin_total_asset: 1000000.00\n", - "end_total_asset: 3391778.54\n", - "total_reward: 2391778.54\n", - "total_cost: 999.00\n", - "total_trades: 50252\n", - "Sharpe: 0.630\n", - "=================================\n", - "-----------------------------------\n", - "| time/ | |\n", - "| episodes | 8 |\n", - "| fps | 22 |\n", - "| time_elapsed | 1070 |\n", - "| total_timesteps | 23656 |\n", - "| train/ | |\n", - "| actor_loss | 59.1 |\n", - "| critic_loss | 438 |\n", - "| learning_rate | 0.001 |\n", - "| n_updates | 20699 |\n", - "| reward | -5.8078027 |\n", - "-----------------------------------\n", - 
"-----------------------------------\n", - "| time/ | |\n", - "| episodes | 12 |\n", - "| fps | 21 |\n", - "| time_elapsed | 1681 |\n", - "| total_timesteps | 35484 |\n", - "| train/ | |\n", - "| actor_loss | 43.3 |\n", - "| critic_loss | 82.1 |\n", - "| learning_rate | 0.001 |\n", - "| n_updates | 32527 |\n", - "| reward | -5.8078027 |\n", - "-----------------------------------\n" - ] - } - ], - "source": [ - "trained_td3 = agent.train_model(model=model_td3, \n", - " tb_log_name='td3',\n", - " total_timesteps=50000) if if_using_td3 else None" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Dr49PotrfG01" - }, - "source": [ - "### Agent 5: SAC" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "xwOhVjqRkCdM", - "outputId": "9018f9ed-0dff-4b75-c0b2-7566784c52cf" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "{'batch_size': 128, 'buffer_size': 100000, 'learning_rate': 0.0001, 'learning_starts': 100, 'ent_coef': 'auto_0.1'}\n", - "Using cpu device\n", - "Logging to results/sac\n" - ] - } - ], - "source": [ - "agent = DRLAgent(env = env_train)\n", - "SAC_PARAMS = {\n", - " \"batch_size\": 128,\n", - " \"buffer_size\": 100000,\n", - " \"learning_rate\": 0.0001,\n", - " \"learning_starts\": 100,\n", - " \"ent_coef\": \"auto_0.1\",\n", - "}\n", - "\n", - "model_sac = agent.get_model(\"sac\",model_kwargs = SAC_PARAMS)\n", - "\n", - "if if_using_sac:\n", - " # set up logger\n", - " tmp_path = RESULTS_DIR + '/sac'\n", - " new_logger_sac = configure(tmp_path, [\"stdout\", \"csv\", \"tensorboard\"])\n", - " # Set new logger\n", - " model_sac.set_logger(new_logger_sac)" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "K8RSdKCckJyH", - "outputId": "bfa91496-f7e6-4d0f-fb77-bc9dd1797e81" - }, - "outputs": [ - { - "output_type": "stream", - 
"name": "stdout", - "text": [ - "day: 2956, episode: 80\n", - "begin_total_asset: 1000000.00\n", - "end_total_asset: 2788059.56\n", - "total_reward: 1788059.56\n", - "total_cost: 247103.77\n", - "total_trades: 76956\n", - "Sharpe: 0.561\n", - "=================================\n", - "----------------------------------\n", - "| time/ | |\n", - "| episodes | 4 |\n", - "| fps | 19 |\n", - "| time_elapsed | 622 |\n", - "| total_timesteps | 11828 |\n", - "| train/ | |\n", - "| actor_loss | 258 |\n", - "| critic_loss | 22.5 |\n", - "| ent_coef | 0.0849 |\n", - "| ent_coef_loss | -113 |\n", - "| learning_rate | 0.0001 |\n", - "| n_updates | 11727 |\n", - "| reward | -9.282894 |\n", - "----------------------------------\n", - "-----------------------------------\n", - "| time/ | |\n", - "| episodes | 8 |\n", - "| fps | 19 |\n", - "| time_elapsed | 1235 |\n", - "| total_timesteps | 23656 |\n", - "| train/ | |\n", - "| actor_loss | 107 |\n", - "| critic_loss | 22.7 |\n", - "| ent_coef | 0.0261 |\n", - "| ent_coef_loss | -166 |\n", - "| learning_rate | 0.0001 |\n", - "| n_updates | 23555 |\n", - "| reward | -12.521359 |\n", - "-----------------------------------\n", - "-----------------------------------\n", - "| time/ | |\n", - "| episodes | 12 |\n", - "| fps | 19 |\n", - "| time_elapsed | 1864 |\n", - "| total_timesteps | 35484 |\n", - "| train/ | |\n", - "| actor_loss | 56.2 |\n", - "| critic_loss | 16.1 |\n", - "| ent_coef | 0.00811 |\n", - "| ent_coef_loss | -181 |\n", - "| learning_rate | 0.0001 |\n", - "| n_updates | 35383 |\n", - "| reward | -13.549046 |\n", - "-----------------------------------\n", - "day: 2956, episode: 90\n", - "begin_total_asset: 1000000.00\n", - "end_total_asset: 5929965.89\n", - "total_reward: 4929965.89\n", - "total_cost: 6822.53\n", - "total_trades: 53683\n", - "Sharpe: 0.843\n", - "=================================\n", - "-----------------------------------\n", - "| time/ | |\n", - "| episodes | 16 |\n", - "| fps | 18 |\n", - "| time_elapsed 
| 2490 |\n", - "| total_timesteps | 47312 |\n", - "| train/ | |\n", - "| actor_loss | 32.3 |\n", - "| critic_loss | 4.64 |\n", - "| ent_coef | 0.00272 |\n", - "| ent_coef_loss | -90.9 |\n", - "| learning_rate | 0.0001 |\n", - "| n_updates | 47211 |\n", - "| reward | -7.4855604 |\n", - "-----------------------------------\n" - ] - } - ], - "source": [ - "trained_sac = agent.train_model(model=model_sac, \n", - " tb_log_name='sac',\n", - " total_timesteps=50000) if if_using_sac else None" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "f2wZgkQXh1jE" - }, - "source": [ - "## In-sample Performance\n", - "\n", - "Assume that the initial capital is $1,000,000." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bEv5KGC8h1jE" - }, - "source": [ - "### Set turbulence threshold\n", - "Set the turbulence threshold to be greater than the maximum of insample turbulence data. If current turbulence index is greater than the threshold, then we assume that the current market is volatile" - ] - }, - { - "cell_type": "code", - "execution_count": 78, - "metadata": { - "id": "efwBi84ch1jE" - }, - "outputs": [], - "source": [ - "data_risk_indicator = processed_full[(processed_full.date=TRAIN_START_DATE)]\n", - "insample_risk_indicator = data_risk_indicator.drop_duplicates(subset=['date'])" - ] - }, - { - "cell_type": "code", - "execution_count": 79, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "VHZMBpSqh1jG", - "outputId": "3164bf6e-3b83-4bbf-ecd4-7688c6309e8c" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "count 2957.000000\n", - "mean 18.105293\n", - "std 7.272476\n", - "min 9.140000\n", - "25% 13.370000\n", - "50% 16.209999\n", - "75% 20.629999\n", - "max 82.690002\n", - "Name: vix, dtype: float64" - ] - }, - "metadata": {}, - "execution_count": 79 - } - ], - "source": [ - "insample_risk_indicator.vix.describe()" - ] - }, - { - "cell_type": "code", - "execution_count": 80, 
- "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "BDkszkMloRWT", - "outputId": "7e36e119-63e2-4379-f110-490836222522" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "57.212001831054636" - ] - }, - "metadata": {}, - "execution_count": 80 - } - ], - "source": [ - "insample_risk_indicator.vix.quantile(0.996)" - ] - }, - { - "cell_type": "code", - "execution_count": 81, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "AL7hs7svnNWT", - "outputId": "13abfde5-de24-40b7-921e-385dd435b3e8" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "count 2957.000000\n", - "mean 34.139585\n", - "std 43.879110\n", - "min 0.000000\n", - "25% 14.613506\n", - "50% 23.644663\n", - "75% 38.292580\n", - "max 652.504902\n", - "Name: turbulence, dtype: float64" - ] - }, - "metadata": {}, - "execution_count": 81 - } - ], - "source": [ - "insample_risk_indicator.turbulence.describe()" - ] - }, - { - "cell_type": "code", - "execution_count": 82, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "N78hfHckoqJ9", - "outputId": "b5f650e9-cf0a-4481-b519-b77c8a0b1b2a" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "291.72619195879605" - ] - }, - "metadata": {}, - "execution_count": 82 - } - ], - "source": [ - "insample_risk_indicator.turbulence.quantile(0.996)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "U5mmgQF_h1jQ" - }, - "source": [ - "### Trading (Out-of-sample Performance)\n", - "\n", - "We update periodically in order to take full advantage of the data, e.g., retrain quarterly, monthly or weekly. We also tune the parameters along the way, in this notebook we use the in-sample data from 2009-01 to 2020-07 to tune the parameters once, so there is some alpha decay here as the length of trade date extends. \n", - "\n", - "Numerous hyperparameters – e.g. 
the learning rate, the total number of samples to train on – influence the learning process and are usually determined by testing some variations." - ] - }, - { - "cell_type": "code", - "execution_count": 105, - "metadata": { - "id": "cIqoV0GSI52v" - }, - "outputs": [], - "source": [ - "e_trade_gym = StockTradingEnv(df = trade, turbulence_threshold = 70,risk_indicator_col='vix', **env_kwargs)\n", - "# env_trade, obs_trade = e_trade_gym.get_sb_env()" - ] - }, - { - "cell_type": "code", - "execution_count": 106, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 357 - }, - "id": "W_XNgGsBMeVw", - "outputId": "13588f5a-daef-4a7b-c116-c737bf61e994" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - " date tic open high low close \n", - "0 2021-10-01 AAPL 141.899994 142.919998 139.110001 141.404266 \\\n", - "0 2021-10-01 AMGN 213.589996 214.610001 210.800003 203.845871 \n", - "0 2021-10-01 AXP 168.500000 175.119995 168.479996 170.065353 \n", - "0 2021-10-01 BA 222.850006 226.720001 220.600006 226.000000 \n", - "0 2021-10-01 CAT 192.899994 195.869995 191.240005 187.928040 \n", - "\n", - " volume day macd boll_ub boll_lb rsi_30 cci_30 \n", - "0 94639600.0 4.0 -1.703488 155.382846 137.132193 46.927735 -142.190202 \\\n", - "0 2629400.0 4.0 -3.097330 212.767980 199.379595 40.408533 -96.757039 \n", - "0 3956000.0 4.0 2.273329 174.218856 149.232889 56.265093 117.538402 \n", - "0 9113600.0 4.0 0.730320 226.909442 205.727561 51.614047 116.649440 \n", - "0 3695500.0 4.0 -3.640324 205.735919 181.432783 41.999435 -112.087765 \n", - "\n", - " dx_30 close_30_sma close_60_sma vix turbulence \n", - "0 41.749873 147.171798 146.269415 21.1 120.122978 \n", - "0 36.189244 208.480832 217.103342 21.1 120.122978 \n", - "0 15.667511 161.215661 163.458888 21.1 120.122978 \n", - "0 2.027170 217.175334 221.968500 21.1 120.122978 \n", - "0 36.203176 196.993869 200.522109 21.1 120.122978 " - ], - "text/html": [ - "\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
dateticopenhighlowclosevolumedaymacdboll_ubboll_lbrsi_30cci_30dx_30close_30_smaclose_60_smavixturbulence
02021-10-01AAPL141.899994142.919998139.110001141.40426694639600.04.0-1.703488155.382846137.13219346.927735-142.19020241.749873147.171798146.26941521.1120.122978
02021-10-01AMGN213.589996214.610001210.800003203.8458712629400.04.0-3.097330212.767980199.37959540.408533-96.75703936.189244208.480832217.10334221.1120.122978
02021-10-01AXP168.500000175.119995168.479996170.0653533956000.04.02.273329174.218856149.23288956.265093117.53840215.667511161.215661163.45888821.1120.122978
02021-10-01BA222.850006226.720001220.600006226.0000009113600.04.00.730320226.909442205.72756151.614047116.6494402.027170217.175334221.96850021.1120.122978
02021-10-01CAT192.899994195.869995191.240005187.9280403695500.04.0-3.640324205.735919181.43278341.999435-112.08776536.203176196.993869200.52210921.1120.122978
\n", - "
\n", - " \n", - " \n", - " \n", - "\n", - " \n", - "
\n", - "
\n", - " " - ] - }, - "metadata": {}, - "execution_count": 106 - } - ], - "source": [ - "trade.head()" - ] - }, - { - "cell_type": "code", - "execution_count": 107, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "lbFchno5j3xs", - "outputId": "5df880d8-ff14-4104-a2f8-a2d1a417cc1c" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "hit end!\n" - ] - } - ], - "source": [ - "trained_moedl = trained_a2c\n", - "df_account_value_a2c, df_actions_a2c = DRLAgent.DRL_prediction(\n", - " model=trained_moedl, \n", - " environment = e_trade_gym)" - ] - }, - { - "cell_type": "code", - "execution_count": 108, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "JbYljWGjj3pH", - "outputId": "2fb2632a-dd77-40f2-eeff-e4b3385727f2" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "hit end!\n" - ] - } - ], - "source": [ - "trained_moedl = trained_ddpg\n", - "df_account_value_ddpg, df_actions_ddpg = DRLAgent.DRL_prediction(\n", - " model=trained_moedl, \n", - " environment = e_trade_gym)" - ] - }, - { - "cell_type": "code", - "execution_count": 109, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "74jNP2DBj3hb", - "outputId": "9659e354-3d56-4fe3-b6bb-81777d179c51" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "hit end!\n" - ] - } - ], - "source": [ - "trained_moedl = trained_ppo\n", - "df_account_value_ppo, df_actions_ppo = DRLAgent.DRL_prediction(\n", - " model=trained_moedl, \n", - " environment = e_trade_gym)" - ] - }, - { - "cell_type": "code", - "execution_count": 110, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "S7VyGGJPj3SH", - "outputId": "a65b52c5-aba0-4e48-b111-481b514fcce2" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "hit end!\n" - ] - } - ], - "source": [ - "trained_moedl = trained_td3\n", - 
"df_account_value_td3, df_actions_td3 = DRLAgent.DRL_prediction(\n", - " model=trained_moedl, \n", - " environment = e_trade_gym)" - ] - }, - { - "cell_type": "code", - "execution_count": 111, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "eLOnL5eYh1jR", - "outputId": "3d9bf94b-2bb5-4091-dc7f-bfe2851dc0be" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "hit end!\n" - ] - } - ], - "source": [ - "trained_moedl = trained_sac\n", - "df_account_value_sac, df_actions_sac = DRLAgent.DRL_prediction(\n", - " model=trained_moedl, \n", - " environment = e_trade_gym)" - ] - }, - { - "cell_type": "code", - "execution_count": 112, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "ERxw3KqLkcP4", - "outputId": "219b1298-4a18-41a3-8390-788739158dd7" - }, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "(353, 2)" - ] - }, - "metadata": {}, - "execution_count": 112 - } - ], - "source": [ - "df_account_value_a2c.shape" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "GcE-t08w6DaW" - }, - "source": [ - "\n", - "# Part 6.5: Mean Variance Optimization" - ] - }, - { - "cell_type": "markdown", - "source": [ - "Mean Variance optimization is a very classic strategy in portfolio management. Here, we go through the whole process to do the mean variance optimization and add it as a baseline to compare.\n", - "\n", - "First, process dataframe to the form for MVO weight calculation." 
- ], - "metadata": { - "id": "GzyHU-RokTaj" - } - }, - { - "cell_type": "code", - "source": [ - "def process_df_for_mvo(df):\n", - " df = df.sort_values(['date','tic'],ignore_index=True)[['date','tic','close']]\n", - " fst = df\n", - " fst = fst.iloc[0:stock_dimension, :]\n", - " tic = fst['tic'].tolist()\n", - "\n", - " mvo = pd.DataFrame()\n", - "\n", - " for k in range(len(tic)):\n", - " mvo[tic[k]] = 0\n", - "\n", - " for i in range(df.shape[0]//stock_dimension):\n", - " n = df\n", - " n = n.iloc[i * stock_dimension:(i+1) * stock_dimension, :]\n", - " date = n['date'][i*stock_dimension]\n", - " mvo.loc[date] = n['close'].tolist()\n", - " \n", - " return mvo" - ], - "metadata": { - "id": "ZaxdYAdRcA67" - }, - "execution_count": 134, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "### Helper functions for mean returns and variance-covariance matrix" - ], - "metadata": { - "id": "tcHDZ7hFkdyL" - } - }, - { - "cell_type": "code", - "source": [ - "# Codes in this section partially refer to Dr G A Vijayalakshmi Pai\n", - "\n", - "# https://www.kaggle.com/code/vijipai/lesson-5-mean-variance-optimization-of-portfolios/notebook\n", - "\n", - "def StockReturnsComputing(StockPrice, Rows, Columns): \n", - " import numpy as np \n", - " StockReturn = np.zeros([Rows-1, Columns]) \n", - " for j in range(Columns): # j: Assets \n", - " for i in range(Rows-1): # i: Daily Prices \n", - " StockReturn[i,j]=((StockPrice[i+1, j]-StockPrice[i,j])/StockPrice[i,j])* 100 \n", - " \n", - " return StockReturn" - ], - "metadata": { - "id": "gKjY9bvYcEkb" - }, - "execution_count": 135, - "outputs": [] - }, - { - "cell_type": "markdown", - "source": [ - "### Calculate the weights for mean-variance" - ], - "metadata": { - "id": "CPnMNonxkj-I" - } - }, - { - "cell_type": "code", - "source": [ - "train_mvo = data_split(processed_full, TRAIN_START_DATE,TRAIN_END_DATE).reset_index()\n", - "trade_mvo = data_split(processed_full, TRADE_START_DATE,TRADE_END_DATE).reset_index()" - 
], - "metadata": { - "id": "wdF2erPNcVd3" - }, - "execution_count": 138, - "outputs": [] - }, - { - "cell_type": "code", - "source": [ - "StockData = process_df_for_mvo(train_mvo)\n", - "TradeData = process_df_for_mvo(trade_mvo)\n", - "\n", - "TradeData.to_numpy()" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "9-64xYTOcJ36", - "outputId": "5cf98bac-c467-4ef1-e98c-2bb858a848c2" - }, - "execution_count": 139, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "array([[141.404, 203.846, 170.065, ..., 49.275, 43.724, 133.839],\n", - " [137.925, 201.483, 168.814, ..., 49.456, 43.621, 132.55 ],\n", - " [139.878, 201.883, 170.867, ..., 49.465, 43.995, 133.419],\n", - " ...,\n", - " [149.4 , 237.62 , 174.494, ..., 37.903, 36.21 , 141.51 ],\n", - " [146.71 , 233.66 , 173.607, ..., 38.109, 35.8 , 141.888],\n", - " [147.92 , 234.45 , 172.66 , ..., 38.247, 35.39 , 140.863]])" - ] - }, - "metadata": {}, - "execution_count": 139 - } - ] - }, - { - "cell_type": "code", - "source": [ - "#compute asset returns\n", - "arStockPrices = np.asarray(StockData)\n", - "[Rows, Cols]=arStockPrices.shape\n", - "arReturns = StockReturnsComputing(arStockPrices, Rows, Cols)\n", - "\n", - "#compute mean returns and variance covariance matrix of returns\n", - "meanReturns = np.mean(arReturns, axis = 0)\n", - "covReturns = np.cov(arReturns, rowvar=False)\n", - " \n", - "#set precision for printing results\n", - "np.set_printoptions(precision=3, suppress = True)\n", - "\n", - "#display mean returns and variance-covariance matrix of returns\n", - "print('Mean returns of assets in k-portfolio 1\\n', meanReturns)\n", - "print('Variance-Covariance matrix of returns\\n', covReturns)" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "28q2-ebfcfbu", - "outputId": "3a51ec82-f586-4462-f5d1-604017ffa1fe" - }, - "execution_count": 65, - "outputs": [ - { - "output_type": "stream", - "name": 
"stdout", - "text": [ - "Mean returns of assets in k-portfolio 1\n", - " [0.12 0.065 0.07 0.08 0.068 0.116 0.051 0.039 0.073 0.049 0.102 0.078\n", - " 0.024 0.06 0.048 0.072 0.039 0.064 0.045 0.048 0.096 0.092 0.046 0.058\n", - " 0.104 0.094 0.043 0.034 0.048]\n", - "Variance-Covariance matrix of returns\n", - " [[3.14 1.017 1.294 1.529 1.374 1.696 1.379 1.13 1.092 1.377 1.165 1.224\n", - " 1.035 1.526 0.693 1.272 0.667 0.819 1.032 0.737 1.579 1.145 0.672 0.863\n", - " 1.146 1.34 0.526 0.909 0.635]\n", - " [1.017 2.406 1.076 1.016 1.067 1.218 1.027 0.967 0.938 1.128 0.984 1.032\n", - " 0.851 1.155 0.892 1.143 0.643 0.668 0.909 1.074 1.092 0.879 0.727 0.855\n", - " 1.147 1.058 0.629 1.027 0.648]\n", - " [1.294 1.076 3.307 2.486 1.912 1.575 1.435 1.926 1.721 2.24 1.382 1.859\n", - " 1.342 1.534 0.866 2.37 0.994 1.082 1.384 0.959 1.362 1.424 0.73 1.509\n", - " 1.416 1.802 0.75 1.223 0.586]\n", - " [1.529 1.016 2.486 4.984 2.129 1.698 1.518 2.153 1.845 2.214 1.537 2.132\n", - " 1.536 1.706 0.911 2.308 1.123 1.201 1.47 0.947 1.444 1.583 0.752 1.613\n", - " 1.483 1.722 0.738 1.374 0.609]\n", - " [1.374 1.067 1.912 2.129 3.34 1.621 1.522 1.893 1.473 2.005 1.242 1.874\n", - " 1.355 1.567 0.827 2.044 0.851 0.86 1.591 0.916 1.336 1.299 0.684 1.274\n", - " 1.219 1.408 0.738 1.228 0.614]\n", - " [1.696 1.218 1.575 1.698 1.621 5.063 1.658 1.263 1.357 1.61 1.363 1.447\n", - " 1.191 1.648 0.76 1.573 0.741 0.99 1.155 0.915 1.859 1.499 0.659 1.048\n", - " 1.209 1.728 0.567 1.02 0.624]\n", - " [1.379 1.027 1.435 1.518 1.522 1.658 2.806 1.307 1.267 1.508 1.137 1.36\n", - " 1.247 1.584 0.846 1.518 0.731 0.826 1.22 0.881 1.486 1.151 0.744 1.006\n", - " 1.1 1.314 0.656 1.064 0.689]\n", - " [1.13 0.967 1.926 2.153 1.893 1.263 1.307 2.928 1.393 1.833 1.207 1.647\n", - " 1.274 1.395 0.839 1.948 0.922 0.999 1.268 0.958 1.208 1.141 0.69 1.367\n", - " 1.299 1.384 0.743 1.066 0.519]\n", - " [1.092 0.938 1.721 1.845 1.473 1.357 1.267 1.393 2.459 1.621 1.156 1.467\n", - " 1.059 1.226 0.703 1.685 
0.832 0.866 1.12 0.767 1.151 1.201 0.665 1.1\n", - " 1.059 1.31 0.661 0.997 0.579]\n", - " [1.377 1.128 2.24 2.214 2.005 1.61 1.508 1.833 1.621 3.381 1.354 1.732\n", - " 1.324 1.576 0.833 2.733 0.83 0.944 1.415 0.968 1.42 1.322 0.696 1.504\n", - " 1.384 1.531 0.721 1.24 0.616]\n", - " [1.165 0.984 1.382 1.537 1.242 1.363 1.137 1.207 1.156 1.354 2.063 1.244\n", - " 1.002 1.247 0.693 1.369 0.735 0.915 1.036 0.766 1.208 1.205 0.695 1.083\n", - " 1.14 1.21 0.632 0.916 0.723]\n", - " [1.224 1.032 1.859 2.132 1.874 1.447 1.36 1.647 1.467 1.732 1.244 2.184\n", - " 1.214 1.351 0.83 1.831 0.914 0.928 1.418 0.898 1.251 1.307 0.698 1.257\n", - " 1.259 1.411 0.694 1.145 0.579]\n", - " [1.035 0.851 1.342 1.536 1.355 1.191 1.247 1.274 1.059 1.324 1.002 1.214\n", - " 1.982 1.287 0.709 1.371 0.771 0.754 1.084 0.777 1.13 0.96 0.664 1.004\n", - " 1.011 1.106 0.644 0.966 0.548]\n", - " [1.526 1.155 1.534 1.706 1.567 1.648 1.584 1.395 1.226 1.576 1.247 1.351\n", - " 1.287 3.246 0.788 1.572 0.79 0.796 1.18 0.858 1.697 1.132 0.729 1.131\n", - " 1.153 1.333 0.67 1.081 0.705]\n", - " [0.693 0.892 0.866 0.911 0.827 0.76 0.846 0.839 0.703 0.833 0.693 0.83\n", - " 0.709 0.788 1.123 0.881 0.614 0.564 0.791 0.781 0.787 0.688 0.641 0.695\n", - " 0.835 0.796 0.567 0.783 0.522]\n", - " [1.272 1.143 2.37 2.308 2.044 1.573 1.518 1.948 1.685 2.733 1.369 1.831\n", - " 1.371 1.572 0.881 3.241 0.933 0.966 1.436 1.022 1.381 1.351 0.74 1.627\n", - " 1.402 1.561 0.807 1.291 0.637]\n", - " [0.667 0.643 0.994 1.123 0.851 0.741 0.731 0.922 0.832 0.83 0.735 0.914\n", - " 0.771 0.79 0.614 0.933 1.2 0.656 0.763 0.635 0.747 0.738 0.692 0.834\n", - " 0.737 0.81 0.593 0.683 0.492]\n", - " [0.819 0.668 1.082 1.201 0.86 0.99 0.826 0.999 0.866 0.944 0.915 0.928\n", - " 0.754 0.796 0.564 0.966 0.656 1.426 0.753 0.644 0.877 0.912 0.546 0.893\n", - " 0.813 0.947 0.483 0.605 0.471]\n", - " [1.032 0.909 1.384 1.47 1.591 1.155 1.22 1.268 1.12 1.415 1.036 1.418\n", - " 1.084 1.18 0.791 1.436 0.763 0.753 1.83 0.758 1.03 
1.041 0.652 1.036\n", - " 1.007 1.122 0.632 1.027 0.56 ]\n", - " [0.737 1.074 0.959 0.947 0.916 0.915 0.881 0.958 0.767 0.968 0.766 0.898\n", - " 0.777 0.858 0.781 1.022 0.635 0.644 0.758 1.666 0.855 0.756 0.626 0.822\n", - " 0.925 0.884 0.627 0.79 0.525]\n", - " [1.579 1.092 1.362 1.444 1.336 1.859 1.486 1.208 1.151 1.42 1.208 1.251\n", - " 1.13 1.697 0.787 1.381 0.747 0.877 1.03 0.855 2.5 1.163 0.754 0.993\n", - " 1.191 1.409 0.638 0.964 0.707]\n", - " [1.145 0.879 1.424 1.583 1.299 1.499 1.151 1.141 1.201 1.322 1.205 1.307\n", - " 0.96 1.132 0.688 1.351 0.738 0.912 1.041 0.756 1.163 2.694 0.631 0.985\n", - " 1.087 1.28 0.565 0.881 0.585]\n", - " [0.672 0.727 0.73 0.752 0.684 0.659 0.744 0.69 0.665 0.696 0.695 0.698\n", - " 0.664 0.729 0.641 0.74 0.692 0.546 0.652 0.626 0.754 0.631 1.15 0.678\n", - " 0.666 0.693 0.579 0.689 0.604]\n", - " [0.863 0.855 1.509 1.613 1.274 1.048 1.006 1.367 1.1 1.504 1.083 1.257\n", - " 1.004 1.131 0.695 1.627 0.834 0.893 1.036 0.822 0.993 0.985 0.678 1.978\n", - " 1.09 1.12 0.702 0.924 0.589]\n", - " [1.146 1.147 1.416 1.483 1.219 1.209 1.1 1.299 1.059 1.384 1.14 1.259\n", - " 1.011 1.153 0.835 1.402 0.737 0.813 1.007 0.925 1.191 1.087 0.666 1.09\n", - " 2.629 1.186 0.61 0.998 0.617]\n", - " [1.34 1.058 1.802 1.722 1.408 1.728 1.314 1.384 1.31 1.531 1.21 1.411\n", - " 1.106 1.333 0.796 1.561 0.81 0.947 1.122 0.884 1.409 1.28 0.693 1.12\n", - " 1.186 2.58 0.607 0.93 0.574]\n", - " [0.526 0.629 0.75 0.738 0.738 0.567 0.656 0.743 0.661 0.721 0.632 0.694\n", - " 0.644 0.67 0.567 0.807 0.593 0.483 0.632 0.627 0.638 0.565 0.579 0.702\n", - " 0.61 0.607 1.195 0.663 0.512]\n", - " [0.909 1.027 1.223 1.374 1.228 1.02 1.064 1.066 0.997 1.24 0.916 1.145\n", - " 0.966 1.081 0.783 1.291 0.683 0.605 1.027 0.79 0.964 0.881 0.689 0.924\n", - " 0.998 0.93 0.663 3.131 0.71 ]\n", - " [0.635 0.648 0.586 0.609 0.614 0.624 0.689 0.519 0.579 0.616 0.723 0.579\n", - " 0.548 0.705 0.522 0.637 0.492 0.471 0.56 0.525 0.707 0.585 0.604 0.589\n", - " 0.617 
0.574 0.512 0.71 1.408]]\n" - ] - } - ] - }, - { - "cell_type": "markdown", - "source": [ - "### Use PyPortfolioOpt" - ], - "metadata": { - "id": "Ei3f_NxDkpOx" - } - }, - { - "cell_type": "code", - "source": [ - "from pypfopt.efficient_frontier import EfficientFrontier\n", - "\n", - "ef_mean = EfficientFrontier(meanReturns, covReturns, weight_bounds=(0, 0.5))\n", - "raw_weights_mean = ef_mean.max_sharpe()\n", - "cleaned_weights_mean = ef_mean.clean_weights()\n", - "mvo_weights = np.array([1000000 * cleaned_weights_mean[i] for i in range(29)])\n", - "mvo_weights" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "bHc3FC3Hckay", - "outputId": "6585f4b7-fda4-4d83-c3cc-38c5ed750aea" - }, - "execution_count": 66, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "array([288130., 0., 0., 0., 0., 69330., 0.,\n", - " 0., 0., 0., 316160., 0., 0., 0.,\n", - " 0., 0., 0., 0., 0., 0., 0.,\n", - " 66290., 0., 0., 237410., 22690., 0., 0.,\n", - " 0.])" - ] - }, - "metadata": {}, - "execution_count": 66 - } - ] - }, - { - "cell_type": "code", - "source": [ - "LastPrice = np.array([1/p for p in StockData.tail(1).to_numpy()[0]])\n", - "Initial_Portfolio = np.multiply(mvo_weights, LastPrice)\n", - "Initial_Portfolio" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "iFiwDj29ck9s", - "outputId": "1e4c7967-c5af-43de-a858-beadfef5116c" - }, - "execution_count": 67, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - "array([2054.193, 0. , 0. , 0. , 0. , 255.623,\n", - " 0. , 0. , 0. , 0. , 998.634, 0. ,\n", - " 0. , 0. , 0. , 0. , 0. , 0. ,\n", - " 0. , 0. , 0. , 463.5 , 0. , 0. ,\n", - " 619.578, 103.016, 0. , 0. , 0. 
])" - ] - }, - "metadata": {}, - "execution_count": 67 - } - ] - }, - { - "cell_type": "code", - "source": [ - "Portfolio_Assets = TradeData @ Initial_Portfolio\n", - "MVO_result = pd.DataFrame(Portfolio_Assets, columns=[\"Mean Var\"])\n", - "# MVO_result" - ], - "metadata": { - "id": "wbcVsNYfcn2B" - }, - "execution_count": 68, - "outputs": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "W6vvNSC6h1jZ" - }, - "source": [ - "\n", - "# Part 7: Backtesting Results\n", - "Backtesting plays a key role in evaluating the performance of a trading strategy. Automated backtesting tool is preferred because it reduces the human error. We usually use the Quantopian pyfolio package to backtest our trading strategies. It is easy to use and consists of various individual plots that provide a comprehensive image of the performance of a trading strategy." - ] - }, - { - "cell_type": "code", - "execution_count": 145, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "KeDeGAc9VrEg", - "outputId": "fe8802d9-e883-48fb-ed8d-36a8236322f7" - }, - "outputs": [ - { - "output_type": "stream", - "name": "stdout", - "text": [ - "==============Get Baseline Stats===========\n", - "\r[*********************100%***********************] 1 of 1 completed\n", - "Shape of DataFrame: (354, 8)\n", - "Annual return -0.034876\n", - "Cumulative returns -0.048644\n", - "Annual volatility 0.181612\n", - "Sharpe ratio -0.105351\n", - "Calmar ratio -0.158953\n", - "Stability 0.280983\n", - "Max drawdown -0.219408\n", - "Omega ratio 0.982546\n", - "Sortino ratio -0.146974\n", - "Skew NaN\n", - "Kurtosis NaN\n", - "Tail ratio 0.970602\n", - "Daily value at risk -0.022957\n", - "dtype: float64\n", - " a2c ddpg td3 ppo \n", - "date \n", - "2021-10-01 1.000000e+06 1.000000e+06 1.000000e+06 1.000000e+06 \\\n", - "2021-10-04 9.971647e+05 9.977126e+05 9.961645e+05 9.995875e+05 \n", - "2021-10-05 1.000508e+06 1.003342e+06 1.002158e+06 1.000310e+06 \n", - "2021-10-06 
1.002597e+06 1.004220e+06 1.006038e+06 1.000651e+06 \n", - "2021-10-07 1.012727e+06 1.010823e+06 1.016102e+06 1.002144e+06 \n", - "\n", - " sac Mean Var \n", - "date \n", - "2021-10-01 1.000000e+06 1.007573e+06 \n", - "2021-10-04 9.978017e+05 9.921956e+05 \n", - "2021-10-05 1.002751e+06 1.004253e+06 \n", - "2021-10-06 1.004620e+06 1.008231e+06 \n", - "2021-10-07 1.012706e+06 1.025692e+06 \n" - ] - } - ], - "source": [ - "df_result_a2c = df_account_value_a2c.set_index(df_account_value_a2c.columns[0])\n", - "df_result_a2c.rename(columns = {'account_value':'a2c'}, inplace = True)\n", - "df_result_ddpg = df_account_value_ddpg.set_index(df_account_value_ddpg.columns[0])\n", - "df_result_ddpg.rename(columns = {'account_value':'ddpg'}, inplace = True)\n", - "df_result_td3 = df_account_value_td3.set_index(df_account_value_td3.columns[0])\n", - "df_result_td3.rename(columns = {'account_value':'td3'}, inplace = True)\n", - "df_result_ppo = df_account_value_ppo.set_index(df_account_value_ppo.columns[0])\n", - "df_result_ppo.rename(columns = {'account_value':'ppo'}, inplace = True)\n", - "df_result_sac = df_account_value_sac.set_index(df_account_value_sac.columns[0])\n", - "df_result_sac.rename(columns = {'account_value':'sac'}, inplace = True)\n", - "df_account_value_a2c.to_csv(\"df_account_value_a2c.csv\")\n", - "#baseline stats\n", - "print(\"==============Get Baseline Stats===========\")\n", - "df_dji_ = get_baseline(\n", - " ticker=\"^DJI\", \n", - " start = TRADE_START_DATE,\n", - " end = TRADE_END_DATE)\n", - "stats = backtest_stats(df_dji_, value_col_name = 'close')\n", - "df_dji = pd.DataFrame()\n", - "df_dji['date'] = df_account_value_a2c['date']\n", - "df_dji['account_value'] = df_dji_['close'] / df_dji_['close'][0] * env_kwargs[\"initial_amount\"]\n", - "df_dji.to_csv(\"df_dji.csv\")\n", - "df_dji = df_dji.set_index(df_dji.columns[0])\n", - "df_dji.to_csv(\"df_dji+.csv\")\n", - "\n", - "result = pd.DataFrame()\n", - "result = pd.merge(result, df_result_a2c, 
how='outer', left_index=True, right_index=True)\n", - "result = pd.merge(result, df_result_ddpg, how='outer', left_index=True, right_index=True)\n", - "result = pd.merge(result, df_result_td3, how='outer', left_index=True, right_index=True)\n", - "result = pd.merge(result, df_result_ppo, how='outer', left_index=True, right_index=True)\n", - "result = pd.merge(result, df_result_sac, how='outer', left_index=True, right_index=True)\n", - "result = pd.merge(result, MVO_result, how='outer', left_index=True, right_index=True)\n", - "print(result.head())\n", - "result = pd.merge(result, df_dji, how='outer', left_index=True, right_index=True)\n", - "# result.columns = ['a2c', 'ddpg', 'td3', 'ppo', 'sac', 'mean var', 'dji']\n", - "\n", - "# print(\"result: \", result)\n", - "result.to_csv(\"result.csv\")" - ] - }, - { - "cell_type": "code", - "source": [ - "df_result_ddpg" - ], - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 455 - }, - "id": "WLapAJTri_7B", - "outputId": "d9625b21-8814-4ec5-bc6e-3a331be40856" - }, - "execution_count": 141, - "outputs": [ - { - "output_type": "execute_result", - "data": { - "text/plain": [ - " ddpg\n", - "date \n", - "2021-10-01 1.000000e+06\n", - "2021-10-04 9.977126e+05\n", - "2021-10-05 1.003342e+06\n", - "2021-10-06 1.004220e+06\n", - "2021-10-07 1.010823e+06\n", - "... ...\n", - "2023-02-21 9.815355e+05\n", - "2023-02-22 9.794530e+05\n", - "2023-02-23 9.820209e+05\n", - "2023-02-24 9.739605e+05\n", - "2023-02-27 9.752743e+05\n", - "\n", - "[353 rows x 1 columns]" - ], - "text/html": [ - "\n", - "
\n", - "
\n", - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
ddpg
date
2021-10-011.000000e+06
2021-10-049.977126e+05
2021-10-051.003342e+06
2021-10-061.004220e+06
2021-10-071.010823e+06
......
2023-02-219.815355e+05
2023-02-229.794530e+05
2023-02-239.820209e+05
2023-02-249.739605e+05
2023-02-279.752743e+05
\n", - "

353 rows × 1 columns

\n", - "
\n", - " \n", - " \n", - " \n", - "\n", - " \n", - "
\n", - "
\n", - " " - ] - }, - "metadata": {}, - "execution_count": 141 - } - ] + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "yfv52r2G33jY" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gXaoZs2lh1hi" + }, + "source": [ + "# Deep Reinforcement Learning for Stock Trading from Scratch: Multiple Stock Trading\n", + "\n", + "* **Pytorch Version** \n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lGunVt8oLCVS" + }, + "source": [ + "# Content" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HOzAKQ-SLGX6" + }, + "source": [ + "* [1. Task Description](#0)\n", + "* [2. Install Python packages](#1)\n", + " * [2.1. Install Packages](#1.1) \n", + " * [2.2. A List of Python Packages](#1.2)\n", + " * [2.3. Import Packages](#1.3)\n", + " * [2.4. Create Folders](#1.4)\n", + "* [3. Download and Preprocess Data](#2)\n", + "* [4. Preprocess Data](#3) \n", + " * [4.1. Technical Indicators](#3.1)\n", + " * [4.2. Perform Feature Engineering](#3.2)\n", + "* [5. Build Market Environment in OpenAI Gym-style](#4) \n", + " * [5.1. Data Split](#4.1) \n", + " * [5.3. Environment for Training](#4.2) \n", + "* [6. Train DRL Agents](#5)\n", + "* [7. Backtesting Performance](#6) \n", + " * [7.1. BackTestStats](#6.1)\n", + " * [7.2. BackTestPlot](#6.2) \n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "sApkDlD9LIZv" + }, + "source": [ + "\n", + "# Part 1. Task Discription" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HjLD2TZSLKZ-" + }, + "source": [ + "We train a DRL agent for stock trading. This task is modeled as a Markov Decision Process (MDP), and the objective function is maximizing (expected) cumulative return.\n", + "\n", + "We specify the state-action-reward as follows:\n", + "\n", + "* **State s**: The state space represents an agent's perception of the market environment. 
Just like a human trader analyzing various information, here our agent passively observes many features and learns by interacting with the market environment (usually by replaying historical data).\n", + "\n", + "* **Action a**: The action space includes allowed actions that an agent can take at each state. For example, a ∈ {−1, 0, 1}, where −1, 0, 1 represent\n", + "selling, holding, and buying. When an action operates multiple shares, a ∈{−k, ..., −1, 0, 1, ..., k}, e.g.. \"Buy\n", + "10 shares of AAPL\" or \"Sell 10 shares of AAPL\" are 10 or −10, respectively\n", + "\n", + "* **Reward function r(s, a, s′)**: Reward is an incentive for an agent to learn a better policy. For example, it can be the change of the portfolio value when taking a at state s and arriving at new state s', i.e., r(s, a, s′) = v′ − v, where v′ and v represent the portfolio values at state s′ and s, respectively\n", + "\n", + "\n", + "**Market environment**: 30 consituent stocks of Dow Jones Industrial Average (DJIA) index. Accessed at the starting date of the testing period.\n", + "\n", + "\n", + "The data for this case study is obtained from Yahoo Finance API. The data contains Open-High-Low-Close price and volume.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ffsre789LY08" + }, + "source": [ + "\n", + "# Part 2. Install Python Packages" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Uy5_PTmOh1hj" + }, + "source": [ + "\n", + "## 2.1. 
Install packages\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "mPT0ipYE28wL", + "outputId": "6dad74d2-c37f-4b86-c584-2436d2ef5bae" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: swig in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (4.3.0)\n", + "Requirement already satisfied: wrds in /home/random/.local/lib/python3.12/site-packages (3.2.0)\n", + "Requirement already satisfied: numpy<1.27,>=1.26 in /home/random/.local/lib/python3.12/site-packages (from wrds) (1.26.4)\n", + "Requirement already satisfied: packaging<23.3 in /home/random/.local/lib/python3.12/site-packages (from wrds) (23.2)\n", + "Requirement already satisfied: pandas<2.3,>=2.2 in /home/random/.local/lib/python3.12/site-packages (from wrds) (2.2.3)\n", + "Requirement already satisfied: psycopg2-binary<2.10,>=2.9 in /home/random/.local/lib/python3.12/site-packages (from wrds) (2.9.10)\n", + "Requirement already satisfied: scipy<1.13,>=1.12 in /home/random/.local/lib/python3.12/site-packages (from wrds) (1.12.0)\n", + "Requirement already satisfied: sqlalchemy<2.1,>=2 in /home/random/.local/lib/python3.12/site-packages (from wrds) (2.0.36)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /home/random/.local/lib/python3.12/site-packages (from pandas<2.3,>=2.2->wrds) (2.9.0.post0)\n", + "Requirement already satisfied: pytz>=2020.1 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from pandas<2.3,>=2.2->wrds) (2024.1)\n", + "Requirement already satisfied: tzdata>=2022.7 in /home/random/.local/lib/python3.12/site-packages (from pandas<2.3,>=2.2->wrds) (2024.2)\n", + "Requirement already satisfied: typing-extensions>=4.6.0 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from sqlalchemy<2.1,>=2->wrds) (4.12.2)\n", + "Requirement already satisfied: greenlet!=0.4.17 in 
/home/random/.local/lib/python3.12/site-packages (from sqlalchemy<2.1,>=2->wrds) (3.1.1)\n", + "Requirement already satisfied: six>=1.5 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from python-dateutil>=2.8.2->pandas<2.3,>=2.2->wrds) (1.16.0)\n", + "Requirement already satisfied: pyportfolioopt in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (1.5.6)\n", + "Requirement already satisfied: cvxpy>=1.1.19 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from pyportfolioopt) (1.6.0)\n", + "Requirement already satisfied: ecos<3.0.0,>=2.0.14 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from pyportfolioopt) (2.0.14)\n", + "Requirement already satisfied: numpy>=1.26.0 in /home/random/.local/lib/python3.12/site-packages (from pyportfolioopt) (1.26.4)\n", + "Requirement already satisfied: pandas>=0.19 in /home/random/.local/lib/python3.12/site-packages (from pyportfolioopt) (2.2.3)\n", + "Requirement already satisfied: plotly<6.0.0,>=5.0.0 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from pyportfolioopt) (5.24.1)\n", + "Requirement already satisfied: scipy>=1.3 in /home/random/.local/lib/python3.12/site-packages (from pyportfolioopt) (1.12.0)\n", + "Requirement already satisfied: osqp>=0.6.2 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from cvxpy>=1.1.19->pyportfolioopt) (0.6.7.post3)\n", + "Requirement already satisfied: clarabel>=0.5.0 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from cvxpy>=1.1.19->pyportfolioopt) (0.9.0)\n", + "Requirement already satisfied: scs>=3.2.4.post1 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from cvxpy>=1.1.19->pyportfolioopt) (3.2.7)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /home/random/.local/lib/python3.12/site-packages (from pandas>=0.19->pyportfolioopt) (2.9.0.post0)\n", + "Requirement already satisfied: pytz>=2020.1 in 
/home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from pandas>=0.19->pyportfolioopt) (2024.1)\n", + "Requirement already satisfied: tzdata>=2022.7 in /home/random/.local/lib/python3.12/site-packages (from pandas>=0.19->pyportfolioopt) (2024.2)\n", + "Requirement already satisfied: tenacity>=6.2.0 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from plotly<6.0.0,>=5.0.0->pyportfolioopt) (9.0.0)\n", + "Requirement already satisfied: packaging in /home/random/.local/lib/python3.12/site-packages (from plotly<6.0.0,>=5.0.0->pyportfolioopt) (23.2)\n", + "Requirement already satisfied: qdldl in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from osqp>=0.6.2->cvxpy>=1.1.19->pyportfolioopt) (0.1.7.post4)\n", + "Requirement already satisfied: six>=1.5 in /home/random/anaconda3/envs/finrl/lib/python3.12/site-packages (from python-dateutil>=2.8.2->pandas>=0.19->pyportfolioopt) (1.16.0)\n", + "[sudo] password for random: cmake is already the newest version (3.28.3-1build7).\n", + "libopenmpi-dev is already the newest version (4.1.6-7ubuntu2).\n", + "python3-dev is already the newest version (3.12.3-0ubuntu2).\n", + "zlib1g-dev is already the newest version (1:1.3.dfsg-3.1ubuntu2.1).\n", + "libgl1-mesa-glx is already the newest version (23.0.4-0ubuntu1~22.04.1).\n", + "swig is already the newest version (4.2.0-2ubuntu1).\n", + "0 upgraded, 0 newly installed, 0 to remove and 39 not upgraded.\n", + "Collecting git+https://github.com/AI4Finance-Foundation/FinRL.git\n", + " Cloning https://github.com/AI4Finance-Foundation/FinRL.git to /tmp/pip-req-build-flt95p98\n", + " Running command git clone --filter=blob:none --quiet https://github.com/AI4Finance-Foundation/FinRL.git /tmp/pip-req-build-flt95p98\n", + " Resolved https://github.com/AI4Finance-Foundation/FinRL.git to commit ef471fcea1f3667442f5ecbf7b4c214610a5dd55\n", + " Installing build dependencies ... \u001b[?25ldone\n", + "\u001b[?25h Getting requirements to build wheel ... 
\u001b[?25ldone\n", + "\u001b[?25h Preparing metadata (pyproject.toml) ... \u001b[?25ldone\n", + "\u001b[?25hCollecting elegantrl@ git+https://github.com/AI4Finance-Foundation/ElegantRL.git (from finrl==0.3.6)\n", + " Cloning https://github.com/AI4Finance-Foundation/ElegantRL.git to /tmp/pip-install-u43l6ss9/elegantrl_36782baa6d82461e89b600dda61820c8\n", + " Running command git clone --filter=blob:none --quiet https://github.com/AI4Finance-Foundation/ElegantRL.git /tmp/pip-install-u43l6ss9/elegantrl_36782baa6d82461e89b600dda61820c8\n", + " Resolved https://github.com/AI4Finance-Foundation/ElegantRL.git to commit 59d9a33e2b3ba2d77c052c2810bb61059736d88c\n", + " Preparing metadata (setup.py) ... \u001b[?25ldone\n", + "\u001b[?25hRequirement already satisfied: alpaca-trade-api<4,>=3 in /home/random/.local/lib/python3.12/site-packages (from finrl==0.3.6) (3.2.0)\n", + "Collecting ccxt<4,>=3 (from finrl==0.3.6)\n", + " Using cached ccxt-3.1.60-py2.py3-none-any.whl.metadata (108 kB)\n", + "Requirement already satisfied: exchange-calendars<5,>=4 in /home/random/.local/lib/python3.12/site-packages (from finrl==0.3.6) (4.6)\n", + "Collecting jqdatasdk<2,>=1 (from finrl==0.3.6)\n", + " Using cached jqdatasdk-1.9.7-py3-none-any.whl.metadata (5.8 kB)\n", + "Collecting pyfolio<0.10,>=0.9 (from finrl==0.3.6)\n", + " Using cached pyfolio-0.9.2.tar.gz (91 kB)\n", + " Preparing metadata (setup.py) ... 
\u001b[?25lerror\n", + " \u001b[1;31merror\u001b[0m: \u001b[1msubprocess-exited-with-error\u001b[0m\n", + " \n", + " \u001b[31m×\u001b[0m \u001b[32mpython setup.py egg_info\u001b[0m did not run successfully.\n", + " \u001b[31m│\u001b[0m exit code: \u001b[1;36m1\u001b[0m\n", + " \u001b[31m╰─>\u001b[0m \u001b[31m[18 lines of output]\u001b[0m\n", + " \u001b[31m \u001b[0m /tmp/pip-install-u43l6ss9/pyfolio_f61a15f976d345b4a7050d0999ff9c7b/versioneer.py:468: SyntaxWarning: invalid escape sequence '\\s'\n", + " \u001b[31m \u001b[0m LONG_VERSION_PY['git'] = '''\n", + " \u001b[31m \u001b[0m Traceback (most recent call last):\n", + " \u001b[31m \u001b[0m File \"\", line 2, in \n", + " \u001b[31m \u001b[0m File \"\", line 34, in \n", + " \u001b[31m \u001b[0m File \"/tmp/pip-install-u43l6ss9/pyfolio_f61a15f976d345b4a7050d0999ff9c7b/setup.py\", line 71, in \n", + " \u001b[31m \u001b[0m version=versioneer.get_version(),\n", + " \u001b[31m \u001b[0m ^^^^^^^^^^^^^^^^^^^^^^^^\n", + " \u001b[31m \u001b[0m File \"/tmp/pip-install-u43l6ss9/pyfolio_f61a15f976d345b4a7050d0999ff9c7b/versioneer.py\", line 1407, in get_version\n", + " \u001b[31m \u001b[0m return get_versions()[\"version\"]\n", + " \u001b[31m \u001b[0m ^^^^^^^^^^^^^^\n", + " \u001b[31m \u001b[0m File \"/tmp/pip-install-u43l6ss9/pyfolio_f61a15f976d345b4a7050d0999ff9c7b/versioneer.py\", line 1341, in get_versions\n", + " \u001b[31m \u001b[0m cfg = get_config_from_root(root)\n", + " \u001b[31m \u001b[0m ^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " \u001b[31m \u001b[0m File \"/tmp/pip-install-u43l6ss9/pyfolio_f61a15f976d345b4a7050d0999ff9c7b/versioneer.py\", line 399, in get_config_from_root\n", + " \u001b[31m \u001b[0m parser = configparser.SafeConfigParser()\n", + " \u001b[31m \u001b[0m ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", + " \u001b[31m \u001b[0m AttributeError: module 'configparser' has no attribute 'SafeConfigParser'. 
Did you mean: 'RawConfigParser'?\n", + " \u001b[31m \u001b[0m \u001b[31m[end of output]\u001b[0m\n", + " \n", + " \u001b[1;35mnote\u001b[0m: This error originates from a subprocess, and is likely not a problem with pip.\n", + "\u001b[1;31merror\u001b[0m: \u001b[1mmetadata-generation-failed\u001b[0m\n", + "\n", + "\u001b[31m×\u001b[0m Encountered error while generating package metadata.\n", + "\u001b[31m╰─>\u001b[0m See above for output.\n", + "\n", + "\u001b[1;35mnote\u001b[0m: This is an issue with the package mentioned above, not pip.\n", + "\u001b[1;36mhint\u001b[0m: See above for details.\n", + "\u001b[?25h" + ] + } + ], + "source": [ + "## install required packages\n", + "!pip install swig\n", + "!pip install wrds\n", + "!pip install pyportfolioopt\n", + "## install finrl library\n", + "!pip install -q condacolab\n", + "import condacolab\n", + "condacolab.install()\n", + "!apt-get update -y -qq && apt-get install -y -qq cmake libopenmpi-dev python3-dev zlib1g-dev libgl1-mesa-glx swig\n", + "!pip install git+https://github.com/AI4Finance-Foundation/FinRL.git" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "osBHhVysOEzi" + }, + "source": [ + "\n", + "\n", + "## 2.2. A list of Python packages \n", + "* Yahoo Finance API\n", + "* pandas\n", + "* numpy\n", + "* matplotlib\n", + "* stockstats\n", + "* OpenAI gym\n", + "* stable-baselines\n", + "* tensorflow\n", + "* pyfolio" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nGv01K8Sh1hn" + }, + "source": [ + "\n", + "## 2.3. Import Packages" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "lPqeTTwoh1hn", + "outputId": "e55033fc-48ae-4696-ae45-08b8bef664d5" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2025-01-04 15:29:19.697527: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. 
You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.\n", + "2025-01-04 15:29:19.724461: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered\n", + "WARNING: All log messages before absl::InitializeLog() is called are written to STDERR\n", + "E0000 00:00:1736000959.745993 24692 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered\n", + "E0000 00:00:1736000959.755250 24692 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered\n", + "2025-01-04 15:29:19.798332: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.\n", + "To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.\n", + "/home/random/anaconda3/envs/finrl/lib/python3.12/site-packages/pyfolio/pos.py:25: UserWarning: Module \"zipline.assets\" not found; multipliers will not be applied to position notionals.\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import matplotlib\n", + "import matplotlib.pyplot as plt\n", + "# matplotlib.use('Agg')\n", + "import datetime\n", + "\n", + "%matplotlib inline\n", + "\n", + "from finrl.meta.preprocessor.yahoodownloader import YahooDownloader\n", + "from finrl.meta.preprocessor.preprocessors import FeatureEngineer, data_split\n", + "from finrl.meta.env_stock_trading.env_stocktrading import StockTradingEnv\n", + "from finrl.agents.stablebaselines3.models import DRLAgent\n", + "from 
stable_baselines3.common.logger import configure\n", + "from finrl.meta.data_processor import DataProcessor\n", + "from finrl.meta.data_processors.processor_yahoofinance import YahooFinanceProcessor\n", + "from finrl.plot import backtest_stats, backtest_plot, get_daily_return, get_baseline\n", + "from pprint import pprint\n", + "\n", + "import sys\n", + "sys.path.append(\"../FinRL\")\n", + "\n", + "import itertools" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "T2owTj985RW4" + }, + "source": [ + "\n", + "## 2.4. Create Folders" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "id": "RtUc_ofKmpdy" + }, + "outputs": [], + "source": [ + "from finrl import config\n", + "from finrl import config_tickers\n", + "import os\n", + "from finrl.main import check_and_make_directories\n", + "from finrl.config import (\n", + " DATA_SAVE_DIR,\n", + " TRAINED_MODEL_DIR,\n", + " TENSORBOARD_LOG_DIR,\n", + " RESULTS_DIR,\n", + " INDICATORS,\n", + " TRAIN_START_DATE,\n", + " TRAIN_END_DATE,\n", + " TEST_START_DATE,\n", + " TEST_END_DATE,\n", + " TRADE_START_DATE,\n", + " TRADE_END_DATE,\n", + ")\n", + "check_and_make_directories([DATA_SAVE_DIR, TRAINED_MODEL_DIR, TENSORBOARD_LOG_DIR, RESULTS_DIR])\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "A289rQWMh1hq" + }, + "source": [ + "\n", + "# Part 3. Download Data\n", + "Yahoo Finance provides stock data, financial news, financial reports, etc. Yahoo Finance is free.\n", + "* FinRL uses a class **YahooDownloader** in FinRL-Meta to fetch data via Yahoo Finance API\n", + "* Call Limit: Using the Public API (without authentication), you are limited to 2,000 requests per hour per IP (or up to a total of 48,000 requests a day)." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "NPeQ7iS-LoMm" + }, + "source": [ + "\n", + "\n", + "-----\n", + "class YahooDownloader:\n", + " Retrieving daily stock data from\n", + " Yahoo Finance API\n", + "\n", + " Attributes\n", + " ----------\n", + " start_date : str\n", + " start date of the data (modified from config.py)\n", + " end_date : str\n", + " end date of the data (modified from config.py)\n", + " ticker_list : list\n", + " a list of stock tickers (modified from config.py)\n", + "\n", + " Methods\n", + " -------\n", + " fetch_data()\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 0 + }, + "id": "h3XJnvrbLp-C", + "outputId": "a03772b5-9cad-463f-e1d6-58d91a70a594" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'2020-07-31'" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# from config.py, TRAIN_START_DATE is a string\n", + "TRAIN_START_DATE\n", + "# from config.py, TRAIN_END_DATE is a string\n", + "TRAIN_END_DATE" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "FUnY8WEfLq3C" + }, + "outputs": [], + "source": [ + "TRAIN_START_DATE = '2010-01-01'\n", + "TRAIN_END_DATE = '2021-10-01'\n", + "TRADE_START_DATE = '2021-10-01'\n", + "TRADE_END_DATE = '2023-03-01'" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "yCKm4om-s9kE", + "outputId": "fd758d58-8946-42ee-e2e3-16f4ac74add2" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Processing AXP (1/3)... 33.33% complete.\n", + "Processing AMGN (2/3)... 66.67% complete.\n", + "Processing AAPL (3/3)... 
100.00% complete.\n", + " Date Open High Low Close Adj Close Volume tick day\n", + "0 2010-01-04 7.62 7.66 7.59 7.64 6.45 493729600 AAPL 3\n", + "1 2010-01-04 56.63 57.87 56.56 57.72 40.92 5277400 AMGN 3\n", + "2 2010-01-04 40.81 41.10 40.39 40.92 32.83 6894300 AXP 3\n", + "3 2010-01-05 7.66 7.70 7.62 7.66 6.46 601904800 AAPL 4\n", + "4 2010-01-05 57.33 57.69 56.27 57.22 40.56 7882800 AMGN 4\n", + "5 2010-01-05 40.83 41.23 40.37 40.83 32.76 10641200 AXP 4\n", + "6 2010-01-06 7.66 7.69 7.53 7.53 6.36 552160000 AAPL 5\n", + "7 2010-01-06 56.94 57.39 56.50 56.79 40.26 6015100 AMGN 5\n", + "8 2010-01-06 41.23 41.67 41.17 41.49 33.29 8399400 AXP 5\n", + "9 2010-01-07 7.56 7.57 7.47 7.52 6.34 477131200 AAPL 6\n", + "10 2010-01-07 56.41 56.53 54.65 56.27 39.89 10371600 AMGN 6\n", + "11 2010-01-07 41.26 42.24 41.11 41.98 33.83 8981700 AXP 6\n", + "12 2010-01-08 7.51 7.57 7.47 7.57 6.39 447610800 AAPL 7\n", + "13 2010-01-08 56.07 56.83 55.64 56.77 40.24 6576000 AMGN 7\n", + "14 2010-01-08 41.76 42.48 41.40 41.95 33.80 7907700 AXP 7\n", + "15 2010-01-11 7.60 7.61 7.44 7.50 6.33 462229600 AAPL 10\n", + "16 2010-01-11 56.93 57.36 56.62 57.02 40.42 4062700 AMGN 10\n", + "17 2010-01-11 41.74 41.96 41.25 41.47 33.42 7396000 AXP 10\n", + "18 2010-01-12 7.47 7.49 7.37 7.42 6.26 594459600 AAPL 11\n", + "19 2010-01-12 57.14 57.42 54.82 56.03 39.72 11268300 AMGN 11\n", + "20 2010-01-12 41.27 42.35 41.25 42.02 33.86 12657300 AXP 11\n", + "21 2010-01-13 7.42 7.53 7.29 7.52 6.35 605892000 AAPL 12\n", + "22 2010-01-13 56.35 56.75 55.96 56.53 40.07 5056200 AMGN 12\n", + "23 2010-01-13 41.85 42.24 41.57 42.15 33.96 10137200 AXP 12\n", + "24 2010-01-14 7.50 7.52 7.47 7.48 6.31 432894000 AAPL 13\n", + "25 2010-01-14 56.35 56.53 55.91 56.16 39.81 4668900 AMGN 13\n", + "26 2010-01-14 42.04 42.74 42.02 42.68 34.39 8238400 AXP 13\n", + "27 2010-01-15 7.53 7.56 7.35 7.35 6.20 594067600 AAPL 14\n", + "28 2010-01-15 56.03 56.51 55.65 56.25 39.87 7240000 AMGN 14\n", + "29 2010-01-15 42.52 42.84 42.02 
42.39 34.16 13629000 AXP 14\n", + "30 2010-01-19 7.44 7.69 7.40 7.68 6.48 730007600 AAPL 18\n", + "31 2010-01-19 56.41 57.75 56.24 57.55 40.80 8570100 AMGN 18\n", + "32 2010-01-19 42.24 43.05 42.11 42.96 34.62 9533800 AXP 18\n", + "33 2010-01-20 7.68 7.70 7.48 7.56 6.38 612152800 AAPL 19\n", + "34 2010-01-20 57.62 57.62 56.41 57.20 40.55 6625700 AMGN 19\n", + "35 2010-01-20 42.93 43.25 42.26 42.98 34.63 11643000 AXP 19\n", + "36 2010-01-21 7.57 7.62 7.40 7.43 6.27 608154400 AAPL 20\n", + "37 2010-01-21 57.43 57.56 56.31 56.63 40.14 5833700 AMGN 20\n", + "38 2010-01-21 42.99 43.10 41.53 42.16 33.97 16974300 AXP 20\n", + "39 2010-01-22 7.39 7.41 7.04 7.06 5.96 881767600 AAPL 21\n", + "40 2010-01-22 56.67 57.30 56.53 56.60 40.12 5967600 AMGN 21\n", + "41 2010-01-22 41.36 41.49 38.19 38.59 31.09 26170800 AXP 21\n", + "42 2010-01-25 7.23 7.31 7.15 7.25 6.12 1065699600 AAPL 24\n", + "43 2010-01-25 56.72 56.79 55.55 55.71 39.49 6719400 AMGN 24\n", + "44 2010-01-25 39.10 39.29 37.50 37.79 30.45 17587600 AXP 24\n", + "45 2010-01-26 7.36 7.63 7.24 7.36 6.20 1867110000 AAPL 25\n", + "46 2010-01-26 56.20 56.87 55.70 56.58 40.11 14880300 AMGN 25\n", + "47 2010-01-26 37.54 39.23 37.52 38.10 30.70 15709900 AXP 25\n", + "48 2010-01-27 7.39 7.52 7.13 7.42 6.26 1722568400 AAPL 26\n", + "49 2010-01-27 56.35 57.88 56.35 57.74 40.93 9695000 AMGN 26\n", + "50 2010-01-27 37.96 38.84 37.83 38.67 31.16 12908300 AXP 26\n", + "51 2010-01-28 7.32 7.34 7.10 7.12 6.00 1173502400 AAPL 27\n", + "52 2010-01-28 57.87 58.78 57.56 58.08 41.17 11638200 AMGN 27\n", + "53 2010-01-28 38.67 38.67 36.83 37.43 30.16 14148600 AXP 27\n", + "54 2010-01-29 7.18 7.22 6.79 6.86 5.79 1245952400 AAPL 28\n", + "55 2010-01-29 58.35 58.93 58.16 58.48 41.45 9465700 AMGN 28\n", + "56 2010-01-29 37.60 38.77 37.36 37.66 30.35 14219900 AXP 28\n" + ] + } + ], + "source": [ + "#df = YahooDownloader(start_date = TRAIN_START_DATE,\n", + "# end_date = TRADE_END_DATE,\n", + "# ticker_list = 
config_tickers.DOW_30_TICKER).fetch_data()\n", + "yfp = YahooFinanceProcessor()\n", + "df = yfp.scrap_data(['AXP', 'AMGN', 'AAPL'], '2010-01-01', '2010-02-01')\n", + "print(df)" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "JzqRRTOX6aFu", + "outputId": "58a21ede-016a-4eaf-db9f-aeb190b3f939" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['AXP', 'AMGN', 'AAPL', 'BA', 'CAT', 'CSCO', 'CVX', 'GS', 'HD', 'HON', 'IBM', 'INTC', 'JNJ', 'KO', 'JPM', 'MCD', 'MMM', 'MRK', 'MSFT', 'NKE', 'PG', 'TRV', 'UNH', 'CRM', 'VZ', 'V', 'WBA', 'WMT', 'DIS', 'DOW']\n" + ] + } + ], + "source": [ + "print(config_tickers.DOW_30_TICKER)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" }, + "id": "CV3HrZHLh1hy", + "outputId": "c2cf4956-210b-4811-be12-0c7fd18b923c" + }, + "outputs": [ { - "cell_type": "code", - "execution_count": 144, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 498 - }, - "id": "6xRfrqK4RVfq", - "outputId": "81bdf0b6-6471-4997-8ea0-a97ec5772d39" - }, - "outputs": [ - { - "output_type": "display_data", - "data": { - "text/plain": [ - "
" - ] - }, - "metadata": {} - }, - { - "output_type": "display_data", - "data": { - "text/plain": [ - "
" - ], - "image/png": "\n" - }, - "metadata": {} - } - ], - "source": [ - "%matplotlib inline\n", - "plt.rcParams[\"figure.figsize\"] = (15,5)\n", - "plt.figure();\n", - "result.plot();" + "data": { + "text/plain": [ + "(57, 9)" ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" } - ], - "metadata": { + ], + "source": [ + "df.shape" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { "colab": { - "collapsed_sections": [ - "Uy5_PTmOh1hj", - "A289rQWMh1hq", - "uqC6c40Zh1iH", - "-QsYaY0Dh1iw", - "uijiWgkuh1jB", - "MRiOtrywfAo1", - "_gDkU-j-fCmZ", - "3Zpv4S0-fDBv" - ], - "provenance": [] - }, - "kernelspec": { - "display_name": "base", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.5 (default, Sep 4 2020, 02:22:02) \n[Clang 10.0.0 ]" - }, - "vscode": { - "interpreter": { - "hash": "54cefccbf0f07c9750f12aa115c023dfa5ed4acecf9e7ad3bc9391869be60d0c" - } + "base_uri": "https://localhost:8080/", + "height": 0 + }, + "id": "4hYkeaPiICHS", + "outputId": "6d7a1c0d-15dc-4adc-b776-f1020e173a5c" + }, + "outputs": [ + { + "ename": "KeyError", + "evalue": "'date'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_24692/1255811168.py\u001b[0m in \u001b[0;36m?\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m 
\u001b[0mdf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msort_values\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'date'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m'tic'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mignore_index\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhead\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/.local/lib/python3.12/site-packages/pandas/core/frame.py\u001b[0m in \u001b[0;36m?\u001b[0;34m(self, by, axis, ascending, inplace, kind, na_position, ignore_index, key)\u001b[0m\n\u001b[1;32m 7168\u001b[0m \u001b[0;34mf\"\u001b[0m\u001b[0;34mLength of ascending (\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mascending\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\"\u001b[0m \u001b[0;31m# type: ignore[arg-type]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7169\u001b[0m \u001b[0;34mf\"\u001b[0m\u001b[0;34m != length of by (\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mby\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7170\u001b[0m \u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7171\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mby\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 7172\u001b[0;31m \u001b[0mkeys\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_label_or_level_values\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m 
\u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mby\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 7173\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7174\u001b[0m \u001b[0;31m# need to rewrap columns in Series to apply key function\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7175\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mkey\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/.local/lib/python3.12/site-packages/pandas/core/generic.py\u001b[0m in \u001b[0;36m?\u001b[0;34m(self, key, axis)\u001b[0m\n\u001b[1;32m 1907\u001b[0m \u001b[0mvalues\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mxs\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mother_axes\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_values\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1908\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_is_level_reference\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1909\u001b[0m \u001b[0mvalues\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maxes\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_level_values\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_values\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1910\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1911\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mKeyError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1912\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1913\u001b[0m \u001b[0;31m# Check for duplicates\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1914\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mvalues\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mndim\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mKeyError\u001b[0m: 'date'" + ] } + ], + "source": [ + "df.sort_values(['date','tic'],ignore_index=True).head()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uqC6c40Zh1iH" + }, + "source": [ + "# Part 4: Preprocess Data\n", + "We need to check for missing data and do feature engineering to convert the data point into a state.\n", + "* **Adding technical indicators**. In practical trading, various information needs to be taken into account, such as historical prices, current holding shares, technical indicators, etc. Here, we demonstrate two trend-following technical indicators: MACD and RSI.\n", + "* **Adding turbulence index**. Risk-aversion reflects whether an investor prefers to protect the capital. It also influences one's trading strategy when facing different market volatility level. 
To control the risk in a worst-case scenario, such as financial crisis of 2007–2008, FinRL employs the turbulence index that measures extreme fluctuation of asset price." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "PmKP-1ii3RLS", + "outputId": "22fecb54-5555-4ec4-cb32-0a54f443e54e" + }, + "outputs": [], + "source": [ + "fe = FeatureEngineer(\n", + " use_technical_indicator=True,\n", + " tech_indicator_list = INDICATORS,\n", + " use_vix=True,\n", + " use_turbulence=True,\n", + " user_defined_feature = False)\n", + "\n", + "processed = fe.preprocess_data(df)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Kixon2tR3RLT" + }, + "outputs": [], + "source": [ + "list_ticker = processed[\"tic\"].unique().tolist()\n", + "list_date = list(pd.date_range(processed['date'].min(),processed['date'].max()).astype(str))\n", + "combination = list(itertools.product(list_date,list_ticker))\n", + "\n", + "processed_full = pd.DataFrame(combination,columns=[\"date\",\"tic\"]).merge(processed,on=[\"date\",\"tic\"],how=\"left\")\n", + "processed_full = processed_full[processed_full['date'].isin(processed['date'])]\n", + "processed_full = processed_full.sort_values(['date','tic'])\n", + "\n", + "processed_full = processed_full.fillna(0)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 0 + }, + "id": "grvhGJJII3Xn", + "outputId": "2af27938-0df3-4fea-e86d-7a361e71d2e2" + }, + "outputs": [], + "source": [ + "processed_full.sort_values(['date','tic'],ignore_index=True).head(10)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5vdORQ384Qx-" + }, + "outputs": [], + "source": [ + "mvo_df = processed_full.sort_values(['date','tic'],ignore_index=True)[['date','tic','close']]" + ] + }, + { + "cell_type": "markdown", + 
"metadata": { + "id": "-QsYaY0Dh1iw" + }, + "source": [ + "\n", + "# Part 5. Build A Market Environment in OpenAI Gym-style\n", + "The training process involves observing stock price change, taking an action and reward's calculation. By interacting with the market environment, the agent will eventually derive a trading strategy that may maximize (expected) rewards.\n", + "\n", + "Our market environment, based on OpenAI Gym, simulates stock markets with historical market data." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5TOhcryx44bb" + }, + "source": [ + "## Data Split\n", + "We split the data into training set and testing set as follows:\n", + "\n", + "Training data period: 2009-01-01 to 2020-07-01\n", + "\n", + "Trading data period: 2020-07-01 to 2021-10-31\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "W0qaVGjLtgbI", + "outputId": "4f16484e-811e-46cd-efee-54c6b309f5a5" + }, + "outputs": [], + "source": [ + "train = data_split(processed_full, TRAIN_START_DATE,TRAIN_END_DATE)\n", + "trade = data_split(processed_full, TRADE_START_DATE,TRADE_END_DATE)\n", + "train_length = len(train)\n", + "trade_length = len(trade)\n", + "print(train_length)\n", + "print(trade_length)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 0 + }, + "id": "p52zNCOhTtLR", + "outputId": "d708401b-129f-495b-e691-7ab8666d6847" + }, + "outputs": [], + "source": [ + "train.tail()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 0 + }, + "id": "k9zU9YaTTvFq", + "outputId": "9080799c-a150-4414-c2de-a68c5e7c3a85" + }, + "outputs": [], + "source": [ + "trade.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": 
"https://localhost:8080/" + }, + "id": "zYN573SOHhxG", + "outputId": "f5dcfc60-af90-4aa0-8849-11848b3ef619" + }, + "outputs": [], + "source": [ + "INDICATORS" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Q2zqII8rMIqn", + "outputId": "b6f16ea3-8f52-44c7-ceb1-f58dabe3d1be" + }, + "outputs": [], + "source": [ + "stock_dimension = len(train.tic.unique())\n", + "state_space = 1 + 2*stock_dimension + len(INDICATORS)*stock_dimension\n", + "print(f\"Stock Dimension: {stock_dimension}, State Space: {state_space}\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "AWyp84Ltto19" + }, + "outputs": [], + "source": [ + "buy_cost_list = sell_cost_list = [0.001] * stock_dimension\n", + "num_stock_shares = [0] * stock_dimension\n", + "\n", + "env_kwargs = {\n", + " \"hmax\": 100,\n", + " \"initial_amount\": 1000000,\n", + " \"num_stock_shares\": num_stock_shares,\n", + " \"buy_cost_pct\": buy_cost_list,\n", + " \"sell_cost_pct\": sell_cost_list,\n", + " \"state_space\": state_space,\n", + " \"stock_dim\": stock_dimension,\n", + " \"tech_indicator_list\": INDICATORS,\n", + " \"action_space\": stock_dimension,\n", + " \"reward_scaling\": 1e-4\n", + "}\n", + "\n", + "\n", + "e_train_gym = StockTradingEnv(df = train, **env_kwargs)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "64EoqOrQjiVf" + }, + "source": [ + "## Environment for Training\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "xwSvvPjutpqS", + "outputId": "e8fc8f68-b8c9-47a8-e7d2-a6ed0715d216" + }, + "outputs": [], + "source": [ + "env_train, _ = e_train_gym.get_sb_env()\n", + "print(type(env_train))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HMNR5nHjh1iz" + }, + "source": [ + "\n", + "# Part 6: Train DRL Agents\n", + "* The DRL algorithms 
are from **Stable Baselines 3**. Users are also encouraged to try **ElegantRL** and **Ray RLlib**.\n", + "* FinRL includes fine-tuned standard DRL algorithms, such as DQN, DDPG, Multi-Agent DDPG, PPO, SAC, A2C and TD3. We also allow users to\n", + "design their own DRL algorithms by adapting these DRL algorithms." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "364PsqckttcQ" + }, + "outputs": [], + "source": [ + "agent = DRLAgent(env = env_train)\n", + "\n", + "if_using_a2c = True\n", + "if_using_ddpg = True\n", + "if_using_ppo = True\n", + "if_using_td3 = True\n", + "if_using_sac = True\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YDmqOyF9h1iz" + }, + "source": [ + "### Agent Training: 5 algorithms (A2C, DDPG, PPO, TD3, SAC)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uijiWgkuh1jB" + }, + "source": [ + "### Agent 1: A2C\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "GUCnkn-HIbmj", + "outputId": "7112ce2a-0f62-4a9c-c8be-4443779b4ba0" + }, + "outputs": [], + "source": [ + "agent = DRLAgent(env = env_train)\n", + "model_a2c = agent.get_model(\"a2c\")\n", + "\n", + "if if_using_a2c:\n", + " # set up logger\n", + " tmp_path = RESULTS_DIR + '/a2c'\n", + " new_logger_a2c = configure(tmp_path, [\"stdout\", \"csv\", \"tensorboard\"])\n", + " # Set new logger\n", + " model_a2c.set_logger(new_logger_a2c)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "0GVpkWGqH4-D", + "outputId": "d00d9ef6-7489-4126-f53f-376612f48466" + }, + "outputs": [], + "source": [ + "trained_a2c = agent.train_model(model=model_a2c, \n", + " tb_log_name='a2c',\n", + " total_timesteps=50000) if if_using_a2c else None" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MRiOtrywfAo1" + }, + "source": [ + 
"### Agent 2: DDPG" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "M2YadjfnLwgt", + "outputId": "8c8b5e98-763c-453c-a280-1b4f3ac13510" + }, + "outputs": [], + "source": [ + "agent = DRLAgent(env = env_train)\n", + "model_ddpg = agent.get_model(\"ddpg\")\n", + "\n", + "if if_using_ddpg:\n", + " # set up logger\n", + " tmp_path = RESULTS_DIR + '/ddpg'\n", + " new_logger_ddpg = configure(tmp_path, [\"stdout\", \"csv\", \"tensorboard\"])\n", + " # Set new logger\n", + " model_ddpg.set_logger(new_logger_ddpg)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "tCDa78rqfO_a", + "outputId": "35589661-85de-42ca-b9f1-52cde7ded447" + }, + "outputs": [], + "source": [ + "trained_ddpg = agent.train_model(model=model_ddpg, \n", + " tb_log_name='ddpg',\n", + " total_timesteps=50000) if if_using_ddpg else None" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_gDkU-j-fCmZ" + }, + "source": [ + "### Agent 3: PPO" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "y5D5PFUhMzSV", + "outputId": "2abd06c0-deca-457b-819b-3059c3f17645" + }, + "outputs": [], + "source": [ + "agent = DRLAgent(env = env_train)\n", + "PPO_PARAMS = {\n", + " \"n_steps\": 2048,\n", + " \"ent_coef\": 0.01,\n", + " \"learning_rate\": 0.00025,\n", + " \"batch_size\": 128,\n", + "}\n", + "model_ppo = agent.get_model(\"ppo\",model_kwargs = PPO_PARAMS)\n", + "\n", + "if if_using_ppo:\n", + " # set up logger\n", + " tmp_path = RESULTS_DIR + '/ppo'\n", + " new_logger_ppo = configure(tmp_path, [\"stdout\", \"csv\", \"tensorboard\"])\n", + " # Set new logger\n", + " model_ppo.set_logger(new_logger_ppo)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": 
"https://localhost:8080/" + }, + "id": "Gt8eIQKYM4G3", + "outputId": "26365c9a-f608-4dd4-9695-018b98d1036a" + }, + "outputs": [], + "source": [ + "trained_ppo = agent.train_model(model=model_ppo, \n", + " tb_log_name='ppo',\n", + " total_timesteps=50000) if if_using_ppo else None" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3Zpv4S0-fDBv" + }, + "source": [ + "### Agent 4: TD3" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "JSAHhV4Xc-bh", + "outputId": "db147b9a-163a-4d03-dd6c-9e89f0e8f421" + }, + "outputs": [], + "source": [ + "agent = DRLAgent(env = env_train)\n", + "TD3_PARAMS = {\"batch_size\": 100, \n", + " \"buffer_size\": 1000000, \n", + " \"learning_rate\": 0.001}\n", + "\n", + "model_td3 = agent.get_model(\"td3\",model_kwargs = TD3_PARAMS)\n", + "\n", + "if if_using_td3:\n", + " # set up logger\n", + " tmp_path = RESULTS_DIR + '/td3'\n", + " new_logger_td3 = configure(tmp_path, [\"stdout\", \"csv\", \"tensorboard\"])\n", + " # Set new logger\n", + " model_td3.set_logger(new_logger_td3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "OSRxNYAxdKpU", + "outputId": "1d85d74c-54cf-4682-a34b-481a5aafe5d4" + }, + "outputs": [], + "source": [ + "trained_td3 = agent.train_model(model=model_td3, \n", + " tb_log_name='td3',\n", + " total_timesteps=50000) if if_using_td3 else None" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Dr49PotrfG01" + }, + "source": [ + "### Agent 5: SAC" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "xwOhVjqRkCdM", + "outputId": "9018f9ed-0dff-4b75-c0b2-7566784c52cf" + }, + "outputs": [], + "source": [ + "agent = DRLAgent(env = env_train)\n", + "SAC_PARAMS = {\n", + " \"batch_size\": 128,\n", + " \"buffer_size\": 
100000,\n", + " \"learning_rate\": 0.0001,\n", + " \"learning_starts\": 100,\n", + " \"ent_coef\": \"auto_0.1\",\n", + "}\n", + "\n", + "model_sac = agent.get_model(\"sac\",model_kwargs = SAC_PARAMS)\n", + "\n", + "if if_using_sac:\n", + " # set up logger\n", + " tmp_path = RESULTS_DIR + '/sac'\n", + " new_logger_sac = configure(tmp_path, [\"stdout\", \"csv\", \"tensorboard\"])\n", + " # Set new logger\n", + " model_sac.set_logger(new_logger_sac)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "K8RSdKCckJyH", + "outputId": "bfa91496-f7e6-4d0f-fb77-bc9dd1797e81" + }, + "outputs": [], + "source": [ + "trained_sac = agent.train_model(model=model_sac, \n", + " tb_log_name='sac',\n", + " total_timesteps=50000) if if_using_sac else None" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "f2wZgkQXh1jE" + }, + "source": [ + "## In-sample Performance\n", + "\n", + "Assume that the initial capital is $1,000,000." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "bEv5KGC8h1jE" + }, + "source": [ + "### Set turbulence threshold\n", + "Set the turbulence threshold to be greater than the maximum of insample turbulence data. 
If current turbulence index is greater than the threshold, then we assume that the current market is volatile" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "efwBi84ch1jE" + }, + "outputs": [], + "source": [ + "data_risk_indicator = processed_full[(processed_full.date<TRAIN_END_DATE) & (processed_full.date>=TRAIN_START_DATE)]\n", + "insample_risk_indicator = data_risk_indicator.drop_duplicates(subset=['date'])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "VHZMBpSqh1jG", + "outputId": "3164bf6e-3b83-4bbf-ecd4-7688c6309e8c" + }, + "outputs": [], + "source": [ + "insample_risk_indicator.vix.describe()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "BDkszkMloRWT", + "outputId": "7e36e119-63e2-4379-f110-490836222522" + }, + "outputs": [], + "source": [ + "insample_risk_indicator.vix.quantile(0.996)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "AL7hs7svnNWT", + "outputId": "13abfde5-de24-40b7-921e-385dd435b3e8" + }, + "outputs": [], + "source": [ + "insample_risk_indicator.turbulence.describe()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "N78hfHckoqJ9", + "outputId": "b5f650e9-cf0a-4481-b519-b77c8a0b1b2a" + }, + "outputs": [], + "source": [ + "insample_risk_indicator.turbulence.quantile(0.996)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "U5mmgQF_h1jQ" + }, + "source": [ + "### Trading (Out-of-sample Performance)\n", + "\n", + "We update periodically in order to take full advantage of the data, e.g., retrain quarterly, monthly or weekly. 
We also tune the parameters along the way, in this notebook we use the in-sample data from 2009-01 to 2020-07 to tune the parameters once, so there is some alpha decay here as the length of trade date extends. \n", + "\n", + "Numerous hyperparameters – e.g. the learning rate, the total number of samples to train on – influence the learning process and are usually determined by testing some variations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cIqoV0GSI52v" + }, + "outputs": [], + "source": [ + "e_trade_gym = StockTradingEnv(df = trade, turbulence_threshold = 70,risk_indicator_col='vix', **env_kwargs)\n", + "# env_trade, obs_trade = e_trade_gym.get_sb_env()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 357 + }, + "id": "W_XNgGsBMeVw", + "outputId": "13588f5a-daef-4a7b-c116-c737bf61e994" + }, + "outputs": [], + "source": [ + "trade.head()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "lbFchno5j3xs", + "outputId": "5df880d8-ff14-4104-a2f8-a2d1a417cc1c" + }, + "outputs": [], + "source": [ + "trained_moedl = trained_a2c\n", + "df_account_value_a2c, df_actions_a2c = DRLAgent.DRL_prediction(\n", + " model=trained_moedl, \n", + " environment = e_trade_gym)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "JbYljWGjj3pH", + "outputId": "2fb2632a-dd77-40f2-eeff-e4b3385727f2" + }, + "outputs": [], + "source": [ + "trained_moedl = trained_ddpg\n", + "df_account_value_ddpg, df_actions_ddpg = DRLAgent.DRL_prediction(\n", + " model=trained_moedl, \n", + " environment = e_trade_gym)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "74jNP2DBj3hb", + 
"outputId": "9659e354-3d56-4fe3-b6bb-81777d179c51" + }, + "outputs": [], + "source": [ + "trained_moedl = trained_ppo\n", + "df_account_value_ppo, df_actions_ppo = DRLAgent.DRL_prediction(\n", + " model=trained_moedl, \n", + " environment = e_trade_gym)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "S7VyGGJPj3SH", + "outputId": "a65b52c5-aba0-4e48-b111-481b514fcce2" + }, + "outputs": [], + "source": [ + "trained_moedl = trained_td3\n", + "df_account_value_td3, df_actions_td3 = DRLAgent.DRL_prediction(\n", + " model=trained_moedl, \n", + " environment = e_trade_gym)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "eLOnL5eYh1jR", + "outputId": "3d9bf94b-2bb5-4091-dc7f-bfe2851dc0be" + }, + "outputs": [], + "source": [ + "trained_moedl = trained_sac\n", + "df_account_value_sac, df_actions_sac = DRLAgent.DRL_prediction(\n", + " model=trained_moedl, \n", + " environment = e_trade_gym)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "ERxw3KqLkcP4", + "outputId": "219b1298-4a18-41a3-8390-788739158dd7" + }, + "outputs": [], + "source": [ + "df_account_value_a2c.shape" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GcE-t08w6DaW" + }, + "source": [ + "\n", + "# Part 6.5: Mean Variance Optimization" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GzyHU-RokTaj" + }, + "source": [ + "Mean Variance optimization is a very classic strategy in portfolio management. Here, we go through the whole process to do the mean variance optimization and add it as a baseline to compare.\n", + "\n", + "First, process dataframe to the form for MVO weight calculation." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ZaxdYAdRcA67" + }, + "outputs": [], + "source": [ + "def process_df_for_mvo(df):\n", + " df = df.sort_values(['date','tic'],ignore_index=True)[['date','tic','close']]\n", + " fst = df\n", + " fst = fst.iloc[0:stock_dimension, :]\n", + " tic = fst['tic'].tolist()\n", + "\n", + " mvo = pd.DataFrame()\n", + "\n", + " for k in range(len(tic)):\n", + " mvo[tic[k]] = 0\n", + "\n", + " for i in range(df.shape[0]//stock_dimension):\n", + " n = df\n", + " n = n.iloc[i * stock_dimension:(i+1) * stock_dimension, :]\n", + " date = n['date'][i*stock_dimension]\n", + " mvo.loc[date] = n['close'].tolist()\n", + " \n", + " return mvo" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tcHDZ7hFkdyL" + }, + "source": [ + "### Helper functions for mean returns and variance-covariance matrix" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gKjY9bvYcEkb" + }, + "outputs": [], + "source": [ + "# Codes in this section partially refer to Dr G A Vijayalakshmi Pai\n", + "\n", + "# https://www.kaggle.com/code/vijipai/lesson-5-mean-variance-optimization-of-portfolios/notebook\n", + "\n", + "def StockReturnsComputing(StockPrice, Rows, Columns): \n", + " import numpy as np \n", + " StockReturn = np.zeros([Rows-1, Columns]) \n", + " for j in range(Columns): # j: Assets \n", + " for i in range(Rows-1): # i: Daily Prices \n", + " StockReturn[i,j]=((StockPrice[i+1, j]-StockPrice[i,j])/StockPrice[i,j])* 100 \n", + " \n", + " return StockReturn" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "CPnMNonxkj-I" + }, + "source": [ + "### Calculate the weights for mean-variance" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wdF2erPNcVd3" + }, + "outputs": [], + "source": [ + "train_mvo = data_split(processed_full, TRAIN_START_DATE,TRAIN_END_DATE).reset_index()\n", + "trade_mvo = data_split(processed_full, 
TRADE_START_DATE,TRADE_END_DATE).reset_index()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "9-64xYTOcJ36", + "outputId": "5cf98bac-c467-4ef1-e98c-2bb858a848c2" + }, + "outputs": [], + "source": [ + "StockData = process_df_for_mvo(train_mvo)\n", + "TradeData = process_df_for_mvo(trade_mvo)\n", + "\n", + "TradeData.to_numpy()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "28q2-ebfcfbu", + "outputId": "3a51ec82-f586-4462-f5d1-604017ffa1fe" + }, + "outputs": [], + "source": [ + "#compute asset returns\n", + "arStockPrices = np.asarray(StockData)\n", + "[Rows, Cols]=arStockPrices.shape\n", + "arReturns = StockReturnsComputing(arStockPrices, Rows, Cols)\n", + "\n", + "#compute mean returns and variance covariance matrix of returns\n", + "meanReturns = np.mean(arReturns, axis = 0)\n", + "covReturns = np.cov(arReturns, rowvar=False)\n", + " \n", + "#set precision for printing results\n", + "np.set_printoptions(precision=3, suppress = True)\n", + "\n", + "#display mean returns and variance-covariance matrix of returns\n", + "print('Mean returns of assets in k-portfolio 1\\n', meanReturns)\n", + "print('Variance-Covariance matrix of returns\\n', covReturns)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Ei3f_NxDkpOx" + }, + "source": [ + "### Use PyPortfolioOpt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "bHc3FC3Hckay", + "outputId": "6585f4b7-fda4-4d83-c3cc-38c5ed750aea" + }, + "outputs": [], + "source": [ + "from pypfopt.efficient_frontier import EfficientFrontier\n", + "\n", + "ef_mean = EfficientFrontier(meanReturns, covReturns, weight_bounds=(0, 0.5))\n", + "raw_weights_mean = ef_mean.max_sharpe()\n", + "cleaned_weights_mean = 
ef_mean.clean_weights()\n", + "mvo_weights = np.array([1000000 * cleaned_weights_mean[i] for i in range(29)])\n", + "mvo_weights" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "iFiwDj29ck9s", + "outputId": "1e4c7967-c5af-43de-a858-beadfef5116c" + }, + "outputs": [], + "source": [ + "LastPrice = np.array([1/p for p in StockData.tail(1).to_numpy()[0]])\n", + "Initial_Portfolio = np.multiply(mvo_weights, LastPrice)\n", + "Initial_Portfolio" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wbcVsNYfcn2B" + }, + "outputs": [], + "source": [ + "Portfolio_Assets = TradeData @ Initial_Portfolio\n", + "MVO_result = pd.DataFrame(Portfolio_Assets, columns=[\"Mean Var\"])\n", + "# MVO_result" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "W6vvNSC6h1jZ" + }, + "source": [ + "\n", + "# Part 7: Backtesting Results\n", + "Backtesting plays a key role in evaluating the performance of a trading strategy. Automated backtesting tool is preferred because it reduces the human error. We usually use the Quantopian pyfolio package to backtest our trading strategies. It is easy to use and consists of various individual plots that provide a comprehensive image of the performance of a trading strategy." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "KeDeGAc9VrEg", + "outputId": "fe8802d9-e883-48fb-ed8d-36a8236322f7" + }, + "outputs": [], + "source": [ + "df_result_a2c = df_account_value_a2c.set_index(df_account_value_a2c.columns[0])\n", + "df_result_a2c.rename(columns = {'account_value':'a2c'}, inplace = True)\n", + "df_result_ddpg = df_account_value_ddpg.set_index(df_account_value_ddpg.columns[0])\n", + "df_result_ddpg.rename(columns = {'account_value':'ddpg'}, inplace = True)\n", + "df_result_td3 = df_account_value_td3.set_index(df_account_value_td3.columns[0])\n", + "df_result_td3.rename(columns = {'account_value':'td3'}, inplace = True)\n", + "df_result_ppo = df_account_value_ppo.set_index(df_account_value_ppo.columns[0])\n", + "df_result_ppo.rename(columns = {'account_value':'ppo'}, inplace = True)\n", + "df_result_sac = df_account_value_sac.set_index(df_account_value_sac.columns[0])\n", + "df_result_sac.rename(columns = {'account_value':'sac'}, inplace = True)\n", + "df_account_value_a2c.to_csv(\"df_account_value_a2c.csv\")\n", + "#baseline stats\n", + "print(\"==============Get Baseline Stats===========\")\n", + "df_dji_ = get_baseline(\n", + " ticker=\"^DJI\", \n", + " start = TRADE_START_DATE,\n", + " end = TRADE_END_DATE)\n", + "stats = backtest_stats(df_dji_, value_col_name = 'close')\n", + "df_dji = pd.DataFrame()\n", + "df_dji['date'] = df_account_value_a2c['date']\n", + "df_dji['account_value'] = df_dji_['close'] / df_dji_['close'][0] * env_kwargs[\"initial_amount\"]\n", + "df_dji.to_csv(\"df_dji.csv\")\n", + "df_dji = df_dji.set_index(df_dji.columns[0])\n", + "df_dji.to_csv(\"df_dji+.csv\")\n", + "\n", + "result = pd.DataFrame()\n", + "result = pd.merge(result, df_result_a2c, how='outer', left_index=True, right_index=True)\n", + "result = pd.merge(result, df_result_ddpg, how='outer', left_index=True, right_index=True)\n", + "result = pd.merge(result, 
df_result_td3, how='outer', left_index=True, right_index=True)\n", + "result = pd.merge(result, df_result_ppo, how='outer', left_index=True, right_index=True)\n", + "result = pd.merge(result, df_result_sac, how='outer', left_index=True, right_index=True)\n", + "result = pd.merge(result, MVO_result, how='outer', left_index=True, right_index=True)\n", + "print(result.head())\n", + "result = pd.merge(result, df_dji, how='outer', left_index=True, right_index=True)\n", + "# result.columns = ['a2c', 'ddpg', 'td3', 'ppo', 'sac', 'mean var', 'dji']\n", + "\n", + "# print(\"result: \", result)\n", + "result.to_csv(\"result.csv\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 455 + }, + "id": "WLapAJTri_7B", + "outputId": "d9625b21-8814-4ec5-bc6e-3a331be40856" + }, + "outputs": [], + "source": [ + "df_result_ddpg" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 498 + }, + "id": "6xRfrqK4RVfq", + "outputId": "81bdf0b6-6471-4997-8ea0-a97ec5772d39" + }, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "plt.rcParams[\"figure.figsize\"] = (15,5)\n", + "plt.figure();\n", + "result.plot();" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [ + "Uy5_PTmOh1hj", + "A289rQWMh1hq", + "uqC6c40Zh1iH", + "-QsYaY0Dh1iw", + "uijiWgkuh1jB", + "MRiOtrywfAo1", + "_gDkU-j-fCmZ", + "3Zpv4S0-fDBv" + ], + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.8" }, - "nbformat": 4, - "nbformat_minor": 0 + "vscode": { + "interpreter": { + "hash": 
"54cefccbf0f07c9750f12aa115c023dfa5ed4acecf9e7ad3bc9391869be60d0c" + } + } + }, + "nbformat": 4, + "nbformat_minor": 4 } diff --git a/finrl/meta/data_processors/processor_yahoofinance.py b/finrl/meta/data_processors/processor_yahoofinance.py index 11c4c2da0..6c05d588e 100644 --- a/finrl/meta/data_processors/processor_yahoofinance.py +++ b/finrl/meta/data_processors/processor_yahoofinance.py @@ -3,6 +3,7 @@ from __future__ import annotations import datetime +import time from datetime import date from datetime import timedelta from sqlite3 import Timestamp @@ -19,7 +20,16 @@ import pandas as pd import pytz import yfinance as yf +from bs4 import BeautifulSoup +from selenium import webdriver +from selenium.webdriver.chrome.options import Options +from selenium.webdriver.chrome.service import Service +from selenium.webdriver.common.action_chains import ActionChains +from selenium.webdriver.common.by import By from stockstats import StockDataFrame as Sdf +from webdriver_manager.chrome import ChromeDriverManager + +### Added by aymeric75 for scrap_data function class YahooFinanceProcessor: @@ -56,6 +66,119 @@ def __init__(self): ... 
""" + ######## ADDED BY aymeric75 ################### + + def date_to_unix(self, date_str) -> int: + """Convert a date string in yyyy-mm-dd format to Unix timestamp.""" + dt = datetime.datetime.strptime(date_str, "%Y-%m-%d") + return int(dt.timestamp()) + + def fetch_stock_data(self, stock_name, period1, period2) -> pd.DataFrame: + # Base URL + url = f"https://finance.yahoo.com/quote/{stock_name}/history/?period1={period1}&period2={period2}&filter=history" + + # Selenium WebDriver Setup + options = Options() + options.add_argument("--headless") # Headless for performance + options.add_argument("--disable-gpu") # Disable GPU for compatibility + driver = webdriver.Chrome( + service=Service(ChromeDriverManager().install()), options=options + ) + + # Navigate to the URL + driver.get(url) + driver.maximize_window() + time.sleep(5) # Wait for redirection and page load + + # Handle potential popup + try: + RejectAll = driver.find_element( + By.XPATH, '//button[@class="btn secondary reject-all"]' + ) + action = ActionChains(driver) + action.click(on_element=RejectAll) + action.perform() + time.sleep(5) + + except Exception as e: + print("Popup not found or handled:", e) + + # Parse the page for the table + soup = BeautifulSoup(driver.page_source, "html.parser") + table = soup.find("table") + if not table: + raise Exception("No table found after handling redirection and popup.") + + # Extract headers + headers = [th.text.strip() for th in table.find_all("th")] + headers[4] = "Close" + headers[5] = "Adj Close" + headers = ["date", "open", "high", "low", "close", "adjcp", "volume"] + # , 'tic', 'day' + + # Extract rows + rows = [] + for tr in table.find_all("tr")[1:]: # Skip header row + cells = [td.text.strip() for td in tr.find_all("td")] + if len(cells) == len(headers): # Only add rows with correct column count + rows.append(cells) + + # Create DataFrame + df = pd.DataFrame(rows, columns=headers) + + # Convert columns to appropriate data types + def safe_convert(value, 
dtype): + try: + return dtype(value.replace(",", "")) + except ValueError: + return value + + df["open"] = df["open"].apply(lambda x: safe_convert(x, float)) + df["high"] = df["high"].apply(lambda x: safe_convert(x, float)) + df["low"] = df["low"].apply(lambda x: safe_convert(x, float)) + df["close"] = df["close"].apply(lambda x: safe_convert(x, float)) + df["adjcp"] = df["adjcp"].apply(lambda x: safe_convert(x, float)) + df["volume"] = df["volume"].apply(lambda x: safe_convert(x, int)) + + # Add 'tic' column + df["tic"] = stock_name + + # Add 'day' column + start_date = datetime.datetime.fromtimestamp(period1) + df["date"] = pd.to_datetime(df["date"]) + df["day"] = (df["date"] - start_date).dt.days + df = df[df["day"] >= 0] # Exclude rows with days before the start date + + # Reverse the DataFrame rows + df = df.iloc[::-1].reset_index(drop=True) + + return df + + def scrap_data(self, stock_names, start_date, end_date) -> pd.DataFrame: + """Fetch and combine stock data for multiple stock names.""" + period1 = self.date_to_unix(start_date) + period2 = self.date_to_unix(end_date) + + all_dataframes = [] + total_stocks = len(stock_names) + + for i, stock_name in enumerate(stock_names): + try: + print( + f"Processing {stock_name} ({i + 1}/{total_stocks})... {(i + 1) / total_stocks * 100:.2f}% complete." + ) + df = self.fetch_stock_data(stock_name, period1, period2) + all_dataframes.append(df) + except Exception as e: + print(f"Error fetching data for {stock_name}: {e}") + + combined_df = pd.concat(all_dataframes, ignore_index=True) + combined_df = combined_df.sort_values(by=["day", "tick"]).reset_index(drop=True) + + return combined_df + + ######## END ADDED BY aymeric75 ################### + def convert_interval(self, time_interval: str) -> str: # Convert FinRL 'standardised' time periods to Yahoo format: 1m, 2m, 5m, 15m, 30m, 60m, 90m, 1h, 1d, 5d, 1wk, 1mo, 3mo if time_interval in [