The following script assembles a DQN agent for CartPole step by step: it creates (vectorized) environments, builds an MLP representation and a Q-network policy on top of it, sets up an optimizer with a linear learning-rate decay, and hands everything to the agent for training.

```python
import argparse

import torch

from xuance_torch.representations import Basic_MLP
from xuance_torch.policies import BasicQnetwork
from xuance_torch.agents import DQN_Agent
from environment import make_envs
from common import get_config, space2shape


def get_args():
    """Parse command-line arguments for the DQN example."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--env_id", type=str, default="CartPole-v0")
    parser.add_argument("--seed", type=int, default=2910)
    parser.add_argument("--vectorize", type=str, default="Dummy")
    parser.add_argument("--config_path", type=str, default="./configs/dqn/toy.yaml")
    return parser.parse_known_args()[0]


if __name__ == "__main__":
    args = get_args()
    config = get_config(args.config_path)

    # Create the (vectorized) training environments.
    envs = make_envs(args.env_id, args.seed, args.vectorize, config)
    observation_space = envs.observation_space
    action_space = envs.action_space

    # MLP feature extractor mapping observations to a latent representation.
    representation = Basic_MLP(space2shape(observation_space),
                               config.representation_hidden_size,
                               None,
                               torch.nn.init.orthogonal_,
                               torch.nn.Tanh,
                               config.device)

    # Q-network head built on top of the shared representation.
    policy = BasicQnetwork(action_space,
                           representation,
                           config.q_hidden_size,
                           normalize=None,
                           initialize=torch.nn.init.orthogonal_,
                           activation=torch.nn.Tanh,
                           device=config.device)

    # Adam optimizer, with the learning rate decayed linearly from 1.0x to
    # 0.5x over the expected number of training updates.
    optimizer = torch.optim.Adam(policy.parameters(), config.learning_rate, eps=1e-5)
    lr_scheduler = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        start_factor=1.0,
        end_factor=0.5,
        total_iters=int(config.training_steps / config.training_frequency))

    # Assemble the DQN agent and train it.
    agent = DQN_Agent(config, envs, policy, optimizer, lr_scheduler, config.device)
    agent.train(config.training_steps)
```
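The contents of `./configs/dqn/toy.yaml` are not shown here. Judging from the attributes the script reads off `config`, a minimal file might look like the sketch below. The field names come from the code above, but all values are illustrative assumptions, and `DQN_Agent` will typically expect additional fields (replay buffer size, discount factor, exploration schedule, and so on) that are omitted.

```yaml
# Minimal sketch of ./configs/dqn/toy.yaml. Field names match the attributes
# accessed in the script above; all values are illustrative assumptions.
representation_hidden_size: [64]   # hidden layer sizes of the Basic_MLP representation
q_hidden_size: [64]                # hidden layer sizes of the Q-network head
learning_rate: 0.001               # Adam learning rate
training_steps: 200000             # total environment steps to train for
training_frequency: 1              # environment steps between gradient updates
device: "cpu"                      # or e.g. "cuda:0"
```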
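Because every flag has a default, the script trains DQN on `CartPole-v0` with seed 2910 out of the box. Assuming it is saved as, say, `main.py` (the filename is not given in the source), it can be run as `python main.py`, and any default can be overridden on the command line, e.g. `--env_id` for a different environment or `--config_path` to point at another YAML file.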