|
- # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- import os
- import gym
- import numpy as np
- import parl
- from parl.utils import logger
- from parl.env import CompatWrapper, is_gym_version_ge
- from cartpole_model import CartpoleModel
- from cartpole_agent import CartpoleAgent
- import argparse
-
# Step size handed to parl.algorithms.PolicyGradient when main() builds the agent.
LEARNING_RATE = 0.001
-
-
- # train an episode
- def run_train_episode(agent, env):
- obs_list, action_list, reward_list = [], [], []
- obs = env.reset()
- while True:
- obs_list.append(obs)
- action = agent.sample(obs)
- action_list.append(action)
-
- obs, reward, done, info = env.step(action)
- reward_list.append(reward)
-
- if done:
- break
- return obs_list, action_list, reward_list
-
-
def run_evaluate_episodes(agent, eval_episodes=5, render=False, env=None):
    """Evaluate the agent's greedy policy and return the mean episode reward.

    Args:
        agent: object exposing ``predict(obs)`` that returns an action.
        eval_episodes: number of episodes to run; must be at least 1.
        render: if True, call ``env.render()`` after every step. When the env
            is created here, ``render_mode="human"`` is also requested on
            gym >= 0.26.
        env: optional, already ``CompatWrapper``-style environment to
            evaluate in. It is left open. When omitted, a fresh
            ``CartPole-v1`` environment is created and wrapped in
            ``CompatWrapper``, then closed before returning.

    Returns:
        Mean of the per-episode reward sums.

    Raises:
        ValueError: if ``eval_episodes`` is less than 1. Without this check
            ``np.mean`` of an empty list would silently return NaN.
    """
    if eval_episodes < 1:
        raise ValueError(
            'eval_episodes must be >= 1, got {}'.format(eval_episodes))

    owns_env = env is None
    if owns_env:
        # Compatible for different versions of gym: on gym >= 0.26 the render
        # mode is passed when the environment is constructed.
        if is_gym_version_ge("0.26.0") and render:
            env = gym.make('CartPole-v1', render_mode="human")
        else:
            env = gym.make('CartPole-v1')
        env = CompatWrapper(env)

    try:
        eval_reward = []
        for _ in range(eval_episodes):
            obs = env.reset()
            episode_reward = 0
            done = False
            while not done:
                action = agent.predict(obs)
                obs, reward, done, _info = env.step(action)
                episode_reward += reward
                if render:
                    env.render()
            eval_reward.append(episode_reward)
        return np.mean(eval_reward)
    finally:
        # Release the environment (and any render window) only if we made it.
        if owns_env:
            env.close()
-
-
def calc_reward_to_go(reward_list, gamma=1.0):
    """Return the discounted reward-to-go for every step of one episode.

    Uses G_i = r_i + gamma * G_{i+1}, accumulated backwards from the last
    step. The computation runs on a copy, so the caller's ``reward_list`` is
    left unchanged. The previous version rewrote it in place.

    Args:
        reward_list: sequence of per-step rewards for a single episode.
        gamma: discount factor applied to future rewards.

    Returns:
        1-D float64 numpy array with the same length as ``reward_list``.
    """
    # np.array copies its input, so the in-place updates below never touch
    # the caller's list or array.
    returns = np.array(reward_list, dtype=np.float64)
    for i in range(len(returns) - 2, -1, -1):
        # G_i = r_i + γ·G_i+1
        returns[i] += gamma * returns[i + 1]
    return returns
-
-
def main(max_episodes=None):
    """Train a policy-gradient agent on CartPole-v1 and save it to ./model.ckpt.

    Every 10 episodes the training reward sum is logged. Every 100 episodes
    the agent is evaluated with ``run_evaluate_episodes``.

    Args:
        max_episodes: number of training episodes. When omitted, falls back
            to ``args.max_episodes`` from the module-level ``args`` parsed in
            the ``__main__`` block. That global must exist in this case.
    """
    if max_episodes is None:
        max_episodes = args.max_episodes

    env = gym.make('CartPole-v1')
    # Compatible for different versions of gym
    env = CompatWrapper(env)
    # NOTE(review): the original hint ``env = env.unwrapped`` ("cancel the
    # minimum score limit") would also strip CompatWrapper and, presumably,
    # gym's episode step limit. Confirm before enabling.
    try:
        obs_dim = env.observation_space.shape[0]
        act_dim = env.action_space.n
        logger.info('obs_dim {}, act_dim {}'.format(obs_dim, act_dim))

        # build an agent
        model = CartpoleModel(obs_dim=obs_dim, act_dim=act_dim)
        alg = parl.algorithms.PolicyGradient(model, lr=LEARNING_RATE)
        agent = CartpoleAgent(alg)

        # To evaluate a previously saved model instead of training:
        #     if os.path.exists('./model.ckpt'):
        #         agent.restore('./model.ckpt')
        #         run_evaluate_episodes(agent, render=True)
        #         return
        # Do not pass ``env`` as the second positional argument of
        # run_evaluate_episodes: that slot is ``eval_episodes``.

        for i in range(max_episodes):
            obs_list, action_list, reward_list = run_train_episode(agent, env)
            if i % 10 == 0:
                logger.info("Episode {}, Reward Sum {}.".format(
                    i, sum(reward_list)))

            batch_obs = np.array(obs_list)
            batch_action = np.array(action_list)
            batch_reward = calc_reward_to_go(reward_list)

            agent.learn(batch_obs, batch_action, batch_reward)
            if (i + 1) % 100 == 0:
                total_reward = run_evaluate_episodes(agent, render=False)
                logger.info('Test reward: {}'.format(total_reward))

        # save the parameters to ./model.ckpt
        agent.save('./model.ckpt')
    finally:
        # The training environment was previously never closed.
        env.close()
-
-
if __name__ == '__main__':
    arg_parser = argparse.ArgumentParser()
    # Environment options
    arg_parser.add_argument('--max_episodes', type=int, default=1000,
                            help='stop condition: number of episodes')
    # Bound at module level on purpose: main() reads the global ``args``.
    args = arg_parser.parse_args()
    main()
|