|
- # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- import paddle
- import paddle.nn as nn
- import paddle.nn.functional as F
- import parl
-
-
class RNNModel(parl.Model):
    """Recurrent model: Linear -> ReLU -> GRUCell -> Linear.

    The final linear layer emits ``n_actions`` values per row (named ``q``
    in ``forward``), and the GRU cell state is threaded through explicitly
    by the caller.

    Args:
        input_shape: Size of the last dimension of the input features; it is
            passed directly to ``nn.Linear`` as the input feature count.
        n_actions: Number of output units of the final linear layer.
        rnn_hidden_dim: Width of the first linear layer and of the GRU cell
            (both its input and its hidden state). Defaults to 64.
    """

    def __init__(self, input_shape, n_actions, rnn_hidden_dim=64):
        super().__init__()
        self.rnn_hidden_dim = rnn_hidden_dim

        # Sub-layers are created in the order fc1 -> rnn -> fc2.
        self.fc1 = nn.Linear(input_shape, rnn_hidden_dim)
        self.rnn = nn.GRUCell(
            input_size=rnn_hidden_dim,
            hidden_size=rnn_hidden_dim,
        )
        self.fc2 = nn.Linear(rnn_hidden_dim, n_actions)

    def init_hidden(self):
        """Return a fresh all-zero hidden state.

        Returns:
            A float32 tensor of shape ``(1, rnn_hidden_dim)``.
        """
        state_shape = (1, self.rnn_hidden_dim)
        return paddle.zeros(state_shape, dtype='float32')

    def forward(self, inputs, hidden_state):
        """Run one recurrent step.

        Args:
            inputs: Feature tensor fed to ``fc1``; presumably
                ``(batch_size, input_shape)`` -- confirm against the caller.
            hidden_state: Previous GRU state; it is flattened to
                ``(-1, rnn_hidden_dim)`` before being handed to the cell.

        Returns:
            A ``(q, h)`` pair: ``q`` is the output of the final linear layer,
            ``(batch_size, n_actions)``, and ``h`` is the new GRU state.
        """
        projected = self.fc1(inputs)
        features = F.relu(projected)

        # Collapse any leading dimensions so the cell sees a 2-D state.
        prev_state = hidden_state.reshape(shape=(-1, self.rnn_hidden_dim))

        # The cell returns (outputs, new_states); only the state is used.
        _, next_state = self.rnn(features, prev_state)

        q_values = self.fc2(next_state)  # (batch_size, n_actions)
        return q_values, next_state
|