|
- # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- import parl
- import paddle
- import paddle.nn as nn
- import paddle.nn.functional as F
-
-
class MujocoModel(parl.Model):
    """Bundle an ``Actor`` and a ``Critic`` behind one model object.

    Args:
        obs_dim: Observation feature size, passed unchanged to both
            sub-networks.
        action_dim: Action feature size, passed unchanged to both
            sub-networks.
    """

    def __init__(self, obs_dim, action_dim):
        super().__init__()
        # The actor is created first, then the critic.
        self.actor_model = Actor(obs_dim, action_dim)
        self.critic_model = Critic(obs_dim, action_dim)

    def policy(self, obs):
        """Return the result of calling the actor network on ``obs``."""
        actor = self.actor_model
        return actor(obs)

    def value(self, obs, action):
        """Return the result of calling the critic on ``obs`` and ``action``.

        The critic's ``forward`` produces a ``(q1, q2)`` pair.
        """
        critic = self.critic_model
        return critic(obs, action)

    def Q1(self, obs, action):
        """Return only the critic's first Q estimate for the given pair."""
        critic = self.critic_model
        return critic.Q1(obs, action)

    def get_actor_params(self):
        """Return ``parameters()`` of the actor sub-network."""
        actor = self.actor_model
        return actor.parameters()

    def get_critic_params(self):
        """Return ``parameters()`` of the critic sub-network."""
        critic = self.critic_model
        return critic.parameters()
-
-
class Actor(parl.Model):
    """Three linear layers mapping an observation to a tanh-squashed action.

    Args:
        obs_dim: Input feature size of the first linear layer.
        action_dim: Output feature size of the last linear layer.
    """

    def __init__(self, obs_dim, action_dim):
        super().__init__()

        # Two hidden layers of width 256 followed by the output layer.
        self.l1 = nn.Linear(obs_dim, 256)
        self.l2 = nn.Linear(256, 256)
        self.l3 = nn.Linear(256, action_dim)

    def forward(self, obs):
        """Compute the action for ``obs``.

        ReLU follows each of the first two layers; tanh is applied to the
        output of the last one, so every component lies in ``[-1, 1]``.
        """
        hidden = self.l1(obs)
        hidden = F.relu(hidden)
        hidden = self.l2(hidden)
        hidden = F.relu(hidden)
        logits = self.l3(hidden)
        return paddle.tanh(logits)
-
-
class Critic(parl.Model):
    """Two separate Q-value networks over concatenated (obs, action) input.

    Layers ``l1``-``l3`` form the first estimator and ``l4``-``l6`` the
    second; each ends in a single output unit.

    Args:
        obs_dim: Observation feature size.
        action_dim: Action feature size.
    """

    def __init__(self, obs_dim, action_dim):
        super().__init__()

        # Both estimators take the observation and action joined together.
        joint_dim = obs_dim + action_dim

        # First estimator (Q1).
        self.l1 = nn.Linear(joint_dim, 256)
        self.l2 = nn.Linear(256, 256)
        self.l3 = nn.Linear(256, 1)

        # Second estimator (Q2).
        self.l4 = nn.Linear(joint_dim, 256)
        self.l5 = nn.Linear(256, 256)
        self.l6 = nn.Linear(256, 1)

    def _q1_from(self, obs_action):
        """Run the already-concatenated input through layers l1-l3."""
        hidden = F.relu(self.l1(obs_action))
        hidden = F.relu(self.l2(hidden))
        return self.l3(hidden)

    def _q2_from(self, obs_action):
        """Run the already-concatenated input through layers l4-l6."""
        hidden = F.relu(self.l4(obs_action))
        hidden = F.relu(self.l5(hidden))
        return self.l6(hidden)

    def forward(self, obs, action):
        """Return ``(q1, q2)`` for the given observation/action pair.

        ``obs`` and ``action`` are concatenated along axis 1.
        NOTE(review): this assumes both are 2-D (batch, feature) tensors;
        confirm against the caller.
        """
        obs_action = paddle.concat([obs, action], 1)
        first = self._q1_from(obs_action)
        second = self._q2_from(obs_action)
        return first, second

    def Q1(self, obs, action):
        """Return only the first estimator's output (layers l1-l3)."""
        obs_action = paddle.concat([obs, action], 1)
        return self._q1_from(obs_action)
|