|
- from xuance_torch.representations import Basic_MLP,Basic_Identical
- from xuance_torch.policies import TD3Policy
- from xuance_torch.agents import TD3_Agent
- from environment import make_envs
- from common import get_config,space2shape
- import argparse
- import torch
- import itertools
def get_args(argv=None):
    """Parse command-line arguments for the TD3 training script.

    Parameters
    ----------
    argv : list[str] | None
        Argument list to parse. ``None`` (the default) falls back to
        ``sys.argv``, preserving the original call-site behavior.

    Returns
    -------
    argparse.Namespace
        Parsed known arguments. Unrecognized options are silently ignored
        (``parse_known_args``) so the script tolerates extra flags injected
        by external launchers.
    """
    parser = argparse.ArgumentParser(description="TD3 training entry point.")
    parser.add_argument("--env_id", type=str, default="InvertedPendulum-v2")
    parser.add_argument("--seed", type=int, default=2910)
    parser.add_argument("--vectorize", type=str, default="Dummy")
    parser.add_argument("--config_path", type=str, default="./configs/td3/mujoco.yaml")
    # Keep only the recognized arguments; discard anything unknown.
    args = parser.parse_known_args(argv)[0]
    return args
-
if __name__ == "__main__":
    # Build environments, policy, optimizers, and schedulers, then train a
    # TD3 agent for the configured number of steps.
    args = get_args()
    config = get_config(args.config_path)
    envs = make_envs(args.env_id, args.seed, args.vectorize, config)
    observation_space = envs.observation_space
    action_space = envs.action_space
    # Identity representation: raw observations are passed straight to the
    # actor/critic networks (no learned feature extractor).
    representation = Basic_Identical(space2shape(observation_space), config.device)
    policy = TD3Policy(action_space,
                       representation,
                       config.actor_hidden_size,
                       config.critic_hidden_size,
                       initialize=None,
                       activation=torch.nn.LeakyReLU,
                       device=config.device)
    # Separate optimizers: one for the actor, one shared by both twin
    # critics (TD3 updates actor and critics on different schedules).
    actor_optimizer = torch.optim.Adam(policy.actor.parameters(),
                                       config.actor_learning_rate)
    critic_optimizer = torch.optim.Adam(
        itertools.chain(policy.criticA.parameters(), policy.criticB.parameters()),
        config.critic_learning_rate)
    # Linearly decay both learning rates to 25% of their initial value over
    # the whole training run; hoist the shared iteration count.
    total_iters = int(config.training_steps / config.training_frequency)
    actor_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
        actor_optimizer, start_factor=1.0, end_factor=0.25, total_iters=total_iters)
    critic_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
        critic_optimizer, start_factor=1.0, end_factor=0.25, total_iters=total_iters)
    agent = TD3_Agent(config, envs, policy,
                      [actor_optimizer, critic_optimizer],
                      [actor_lr_scheduler, critic_lr_scheduler],
                      config.device)
    agent.train(config.training_steps)
-
-
|