"""Train and evaluate a DQN agent on a toy environment.

Loads the experiment configuration, builds a vectorized environment,
constructs an MLP-backed (non-dueling) Q-network, runs DQN training,
and reports mean/std return over 20 evaluation episodes.
"""
import torch

from envs.make_env_funcs import make_env_fn
from envs.vec_dummy_env import DummyVecEnv
from common.common_tools import get_config

from torch_agents.agents.dqn_agent import DQN_Agent
from torch_agents.utils.backbones import MLP_Backbone
from torch_agents.policies.qnetwork import DeepQNet

# Load the experiment configuration (hyperparameters, env name, device, ...).
config = get_config("configs/dqn/dqn_toy.yaml")

# Build a vectorized environment: one env instance per parallel worker.
envs = DummyVecEnv([make_env_fn(config.environment, i) for i in range(config.nenvs)])
observation_space = envs.observation_space
action_space = envs.action_space

# Feature extractor shared by the Q-head. xavier_normal_ (in-place) replaces
# the deprecated torch.nn.init.xavier_normal.
backbone = MLP_Backbone(observation_space,
                        hidden_size=(128,),
                        activation=torch.nn.LeakyReLU,
                        initialize=torch.nn.init.xavier_normal_,
                        device=config.device)
policy = DeepQNet(action_space,
                  backbone,
                  dueling=False,
                  hidden_size=(128,),
                  normalize=None,
                  activation=torch.nn.LeakyReLU,
                  initialize=torch.nn.init.xavier_normal_,
                  device=config.device)

# Optimizer and learning-rate schedule: decay LR by 0.995 every 500 steps.
optimizer = torch.optim.Adam(policy.parameters(), lr=config.learning_rate)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=500, gamma=0.995, last_epoch=-1)

# Train (warm-up of init_steps, then train_steps of learning), then evaluate.
agent = DQN_Agent(config, envs, policy, optimizer, scheduler, device=config.device)
agent.train(int(config.init_steps), int(config.train_steps))
mu, std = agent.test(20)
print(f"test_mean={mu},test_std={std}")