|
34 | 34 | "\n",
|
35 | 35 | "## The REINFORCE Algorithm\n",
|
36 | 36 | "\n",
|
37 |
| - "The reinforcement learning objective $J$ is the expected total return, following the policy $\\pi$. If the transition probability is denoted by $p(s'|s,a)$, and the initial state distribution is $p(s_0)$, the probability for a trajectory $\\tau = (s_0,a_0,r_1,s_1,a_1,\\dots,s_{T-1},a_{T-1},r_T,s_T)$ to occur can be written as\n", |
| 37 | + "The reinforcement learning objective $\\mathbf J$ is the expected total return, following the policy $\\pi$. If the transition probability is denoted by $p(s'|s,a)$, and the initial state distribution is $p(s_0)$, the probability for a trajectory $\\tau = (s_0,a_0,r_1,s_1,a_1,\\dots,s_{T-1},a_{T-1},r_T,s_T)$ to occur can be written as\n", |
38 | 38 | "\n",
|
39 | 39 | "$$ \n",
|
40 | 40 | "P_\\pi(\\tau) = p(s_0)\\prod_{t=1}^T \\pi(a_t|s_t)p(s_{t+1}|s_t,a_t). \n",
|
|
685 | 685 | "\n",
|
686 | 686 | "The implementation of the PG algorithm proceeds as follows:\n",
|
687 | 687 | "\n",
|
688 |
| - "1. Define a SoftMax model for the discrete policy $\\pi_\\theta$.\n", |
689 |
| - "2. Define the pseudo loss function to easily compute $\\nabla_\\theta J(\\theta)$.\n", |
| 688 | + "1. Define a model to learn the mean and standard deviation of a continuous Gaussian policy $\\pi_\\theta$.\n", |
| 689 | + "2. Define the pseudo loss function to easily compute $\\nabla_\\theta \\mathbf J(\\theta)$.\n", |
690 | 690 | "3. Define generalized gradient descent optimizer.\n",
|
691 | 691 | "4. Define the PG training loop and train the policy.\n",
|
692 | 692 | "\n",
|
693 |
| - "*Note:* if you are familiar with solving the MNIST problem, you will recognize many of the steps used to construct and train the neural network. What is different here is the training algorithm.\n", |
| 693 | + "### Define a model for the mean and standard deviation of the continuous Gaussian policy $\\pi_\\theta$\n", |
694 | 694 | "\n",
|
695 |
| - "### Define a SoftMax model for the discrete policy $\\pi_\\theta$\n", |
| 695 | + "Use JAX to construct a feed-forward deep neural network architecture; we use a single fully connected input layer that splits into two heads, one for the mean and the standard deviation, respectively. We model each gate angle $\\alpha,\\beta,\\gamma$ by an independent Gaussian policy, so we need the mean and std head to output a vector of three values each (one for each angle). The network architecture thus reads $(M_s, 8, 3/3, 3)$, where there are $8$ ($3$) neurons in the first (second) layer, respectively, and $M_s$ defines the batch size.\n", |
696 | 696 | "\n",
|
697 |
| - "Use JAX to construct a feed-forward fully-connected deep neural network with neuron architecture $(M_s, 512, 256, |\\mathcal{A}|)$, where there are $512$ ($256$) neurons in the first (second) hidden layer, respectively, and $M_s$ and $|\\mathcal{A}|$ define the input and output sizes.\n", |
698 |
| - "\n", |
699 |
| - "The input data into the neural network should have the shape `input_shape = (-1, n_time_steps, M_s)`, where `M_s` is the number of features/components in the RL state $s=(\\theta,\\varphi)$. The output data should have the shape `output_shape = (-1, n_time_steps, abs_A)`, where `abs_A`$=|\\mathcal{A}|$. In this way, we can use the neural network to process simultaneously all time steps and MC samples, generated in a single training iteration. \n", |
| 697 | + "The input data into the neural network should have the shape `input_shape = (-1, n_time_steps, 3*env.n_time_steps)`, where `(n_time_steps, 3*env.n_time_steps)` is the shape of the RL state variable $s$. The output data should have the shape `output_shape = (-1, n_time_steps, 3)`, where $3$ refers to the three angles. In this way, we can use the neural network to process simultaneously all time steps and MC samples, generated in a single training iteration. \n", |
700 | 698 | "\n",
|
701 | 699 | "Check explicitly the output shape and test that the network runs on some fake data (e.g. a small batch of vectors of ones with the appropriate shape). "
|
702 | 700 | ]
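A minimal sketch of such a two-headed network, written directly in `jax.numpy` rather than with a specific neural-network library (the problem sizes, initialization scale, and all names below are illustrative assumptions, not the notebook's actual implementation):

```python
import jax
import jax.numpy as jnp

# Illustrative sizes (stand-ins for the notebook's env attributes)
n_time_steps = 10                  # RL time steps per episode
n_features = 3 * n_time_steps      # size of the flattened RL state s
n_hidden = 8                       # neurons in the shared fully connected layer
n_actions = 3                      # one Gaussian per gate angle (alpha, beta, gamma)

def init_params(rng, scale=0.1):
    """Initialize the shared layer and the two output heads (mean and log-std)."""
    k1, k2, k3 = jax.random.split(rng, 3)
    return {
        'shared':  {'W': scale * jax.random.normal(k1, (n_features, n_hidden)),
                    'b': jnp.zeros(n_hidden)},
        'mean':    {'W': scale * jax.random.normal(k2, (n_hidden, n_actions)),
                    'b': jnp.zeros(n_actions)},
        'log_std': {'W': scale * jax.random.normal(k3, (n_hidden, n_actions)),
                    'b': jnp.zeros(n_actions)},
    }

def policy_net(params, states):
    """Map states of shape (batch, n_time_steps, n_features) to Gaussian
    means and standard deviations, each of shape (batch, n_time_steps, n_actions)."""
    h = jnp.tanh(states @ params['shared']['W'] + params['shared']['b'])
    mean = h @ params['mean']['W'] + params['mean']['b']
    std = jnp.exp(h @ params['log_std']['W'] + params['log_std']['b'])  # log-std keeps std > 0
    return mean, std

# Explicit shape check on fake data, as suggested above
params = init_params(jax.random.PRNGKey(0))
fake_states = jnp.ones((5, n_time_steps, n_features))    # a small batch of 5 MC samples
mean, std = policy_net(params, fake_states)
print(mean.shape, std.shape)                             # expected: (5, 10, 3) (5, 10, 3)
```

Parameterizing the standard deviation through its logarithm (or, alternatively, through a softplus) is one common way to keep it strictly positive.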
|
|
778 | 776 | "id": "343Vdvxv2tlT"
|
779 | 777 | },
|
780 | 778 | "source": [
|
781 |
| - "### Define the pseudo loss function to easily compute $\\nabla_\\theta J(\\theta)$\n", |
| 779 | + "### Define the pseudo loss function to easily compute $\\nabla_\\theta \\mathbf J(\\theta)$\n", |
782 | 780 | "\n",
|
783 |
| - "REINFORCE allows to define a scalar pseudoloss function, whose gradients give $\\nabla_\\theta J(\\theta)$. Note that this pseudoloss does ***NOT*** correspond to the RL objective $J(\\theta)$: the difference stems from the fact that the two operations of taking the derivative and performing the MC approximation are not interchangeable (do you see why?). \n", |
| 781 | + "REINFORCE allows to define a scalar pseudoloss function, whose gradients give $\\nabla_\\theta \\mathbf J(\\theta)$. Note that this pseudoloss does ***NOT*** correspond to the RL objective $\\mathbf J(\\theta)$: the difference stems from the fact that the two operations of taking the derivative and performing the MC approximation are not interchangeable (do you see why?). \n", |
784 | 782 | "\n",
|
785 | 783 | "$$\n",
|
786 |
| - "J_\\mathrm{pseudo}(\\theta) = \n", |
| 784 | + "\\mathbf J_\\mathrm{pseudo}(\\theta) = \n", |
787 | 785 | "\\frac{1}{N}\\sum_{j=1}^N \\sum_{t=1}^T \\log \\pi_\\theta(a^j_t|s^j_t) \\left[\\sum_{t'=t}^T r(a^j_{t'}|s^j_{t'}) - b_t\\right],\\qquad \n",
|
788 |
| - "b_t = \\frac{1}{N}\\sum_{j=1}^N G_t(\\tau_j).\n", |
| 786 | + "b_t = \\frac{1}{N}\\sum_{j=1}^N \\mathbf G_t(\\tau_j).\n", |
789 | 787 | "$$\n",
|
790 |
| - "The baseline is a sample average of the reward-to-go (return) from time step $t$ onwards: $G_t(\\tau_j) = \\sum_{t'=t}^N r(s^j_{t'},s^j_{t'})$ .\n", |
| 788 | + "The baseline is a sample average of the reward-to-go (return) from time step $t$ onwards: $\\mathbf G_t(\\tau_j) = \\sum_{t'=t}^N r(s^j_{t'},s^j_{t'})$ .\n", |
791 | 789 | "\n",
|
792 |
| - "Because we will be doing gradient **a**scent, do **NOT** forget to add an extra minus sign to the output ot the pseudoloss (or else your agent will end up minimizing the return). \n", |
| 790 | + "Because we will be doing gradient **a**scent, do **NOT** forget to add an extra minus sign to the output of the pseudoloss (or else your agent will end up minimizing the return). \n", |
793 | 791 | "\n",
|
794 | 792 | "Below, we also add an L2 regularizer to the pseudoloss function to prevent overfitting. "
|
795 | 793 | ]
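A sketch of what this pseudoloss could look like in JAX for the Gaussian policy sketched above, assuming the `policy_net` from the previous snippet and assuming that the sampled `actions` (shape `(N, T, 3)`) and the rewards-to-go `returns` $=\mathbf G_t(\tau_j)$ (shape `(N, T)`) have already been collected from a batch of $N$ trajectories; the L2 coefficient and all names are illustrative, not the notebook's actual code:

```python
import jax
import jax.numpy as jnp

def gaussian_log_prob(actions, mean, std):
    """Log-density of independent Gaussians, summed over the three gate angles."""
    log_p = -0.5 * ((actions - mean) / std) ** 2 - jnp.log(std) - 0.5 * jnp.log(2.0 * jnp.pi)
    return jnp.sum(log_p, axis=-1)                      # shape (N, T)

def pseudo_loss(params, states, actions, returns, l2_coeff=1e-4):
    """Negative REINFORCE pseudoloss with a baseline and an L2 regularizer.

    returns[j, t] holds the reward-to-go G_t(tau_j); the baseline b_t is its
    sample average over the N trajectories in the batch.
    """
    mean, std = policy_net(params, states)              # the two-headed policy network
    log_pi = gaussian_log_prob(actions, mean, std)      # log pi_theta(a_t^j | s_t^j), shape (N, T)
    baseline = jnp.mean(returns, axis=0)                # b_t, shape (T,)
    weights = returns - baseline                        # G_t(tau_j) - b_t, shape (N, T)
    pg_term = jnp.mean(jnp.sum(log_pi * weights, axis=1))   # (1/N) sum_j sum_t ...
    l2 = sum(jnp.sum(leaf ** 2) for leaf in jax.tree_util.tree_leaves(params))
    # minus sign: the optimizer minimizes, so minimizing -J_pseudo maximizes the return
    return -pg_term + l2_coeff * l2

# The policy gradient is then obtained as, e.g.,
# grads = jax.grad(pseudo_loss)(params, states, actions, returns)
```

Note that `returns` enters the pseudoloss only as fixed sample data, so differentiating it reproduces the REINFORCE estimate of $\nabla_\theta \mathbf J(\theta)$ even though the pseudoloss itself is not an estimate of $\mathbf J(\theta)$.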
|
|