# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

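"""Train a DQN agent on gym's CartPole environment with PARL.

Transitions are collected into a replay memory; the Q-network is updated
every LEARN_FREQ steps once the memory holds MEMORY_WARMUP_SIZE samples.
The greedy policy is evaluated periodically, and a checkpoint plus an
inference model are saved at the end of training.
"""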
import gym
import numpy as np
import argparse
from parl.utils import logger, ReplayMemory
from cartpole_model import CartpoleModel
from cartpole_agent import CartpoleAgent
from parl.env import CompatWrapper, is_gym_version_ge
from parl.algorithms import DQN

LEARN_FREQ = 5  # update the network every 5 steps taken in the environment
MEMORY_SIZE = 200000  # capacity of the replay memory
MEMORY_WARMUP_SIZE = 200  # transitions to collect before learning starts
BATCH_SIZE = 64
LEARNING_RATE = 0.0005
GAMMA = 0.99  # discount factor


# run one training episode and return its total reward
def run_train_episode(agent, env, rpm):
    total_reward = 0
    obs = env.reset()
    step = 0
    while True:
        step += 1
        # epsilon-greedy action selection for exploration
        action = agent.sample(obs)
        next_obs, reward, done, _ = env.step(action)
        rpm.append(obs, action, reward, next_obs, done)

        # train the model every LEARN_FREQ steps once the memory is warmed up
        if (len(rpm) > MEMORY_WARMUP_SIZE) and (step % LEARN_FREQ == 0):
            # sample a batch of (s, a, r, s', done) transitions
            (batch_obs, batch_action, batch_reward, batch_next_obs,
             batch_done) = rpm.sample_batch(BATCH_SIZE)
            train_loss = agent.learn(batch_obs, batch_action, batch_reward,
                                     batch_next_obs, batch_done)

        total_reward += reward
        obs = next_obs
        if done:
            break
    return total_reward


# evaluate the agent over several episodes (5 by default)
def run_evaluate_episodes(agent, eval_episodes=5, render=False):
    # Compatible with different versions of gym
    if is_gym_version_ge("0.26.0") and render:  # gym version >= 0.26.0
        env = gym.make('CartPole-v1', render_mode="human")
    else:
        env = gym.make('CartPole-v1')
    env = CompatWrapper(env)

    eval_reward = []
    for i in range(eval_episodes):
        obs = env.reset()
        episode_reward = 0
        while True:
            # greedy action selection, no exploration during evaluation
            action = agent.predict(obs)
            obs, reward, done, _ = env.step(action)
            episode_reward += reward
            if render:
                env.render()
            if done:
                break
        eval_reward.append(episode_reward)
    return np.mean(eval_reward)


def main():
    env = gym.make('CartPole-v0')
    # Compatible with different versions of gym
    env = CompatWrapper(env)
    obs_dim = env.observation_space.shape[0]
    act_dim = env.action_space.n
    logger.info('obs_dim {}, act_dim {}'.format(obs_dim, act_dim))

    # pass act_dim=0 to ReplayMemory in discrete control environments,
    # where each action is stored as a scalar
    rpm = ReplayMemory(MEMORY_SIZE, obs_dim, 0)

    # build an agent
    model = CartpoleModel(obs_dim=obs_dim, act_dim=act_dim)
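    # DQN regresses Q(s, a) toward the target r + GAMMA * max_a' Q(s', a'),
    # with the bootstrap term dropped on terminal transitions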
    alg = DQN(model, gamma=GAMMA, lr=LEARNING_RATE)
    # exploration starts at e_greed and decays by e_greed_decrement per step
    agent = CartpoleAgent(
        alg, act_dim=act_dim, e_greed=0.1, e_greed_decrement=1e-6)

    # warm up the replay memory before training begins
    while len(rpm) < MEMORY_WARMUP_SIZE:
        run_train_episode(agent, env, rpm)

    max_episode = args.max_episode

    # start training
    episode = 0
    while episode < max_episode:
        # train part
        for i in range(50):
            total_reward = run_train_episode(agent, env, rpm)
            episode += 1

        # test part
        eval_reward = run_evaluate_episodes(agent, render=False)
        logger.info('episode:{} e_greed:{} Test reward:{}'.format(
            episode, agent.e_greed, eval_reward))

    # save the parameters to ./model.ckpt
    save_path = './model.ckpt'
    agent.save(save_path)
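    # the checkpoint can be reloaded later with agent.restore(save_path)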

    # save the model and parameters of the policy network for inference
    save_inference_path = './inference_model'
    input_shapes = [[None, env.observation_space.shape[0]]]
    input_dtypes = ['float32']
    agent.save_inference_model(save_inference_path, input_shapes, input_dtypes)

if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--max_episode',
        type=int,
        default=800,
        help='stop condition: maximum number of training episodes')
    args = parser.parse_args()

    main()