|
- from BaiduUAV.Final.component.Qnet import *
- from BaiduUAV.Final.component.Replay_buffer import *
- from BaiduUAV.Final.component.UAVenv import *
-
# ---- Training hyperparameters ------------------------------------------
learning_rate = 1e-4    # Adam step size for the online Q-network
gamma = 0.98            # discount factor applied to the bootstrap term
buffer_limit = 100000   # capacity handed to ReplayBuffer
batch_size = 128        # transitions drawn from the buffer per training step
tau = 0.01              # weight of the online net in the target soft update

# Directory (with trailing slash) where checkpoints and the run summary go.
save_file = './trained_model/env_9_9_0.5_100/'
def soft_update(net, net_target, tau=None):
    """Blend ``net``'s parameters into ``net_target`` in place (Polyak averaging).

    For each parameter pair: ``target <- target * (1 - tau) + source * tau``.

    Args:
        net: source (online) network; its parameters are only read.
        net_target: network whose parameters are overwritten in place.
        tau: blend weight for the source parameters. When ``None`` (the
            default, matching the original two-argument call) the module-level
            ``tau`` hyperparameter is used.
    """
    if tau is None:
        # The parameter shadows the module constant, so look it up explicitly.
        tau = globals()['tau']
    # In-place updates on parameters must not be recorded by autograd.
    with torch.no_grad():
        for param_target, param in zip(net_target.parameters(), net.parameters()):
            param_target.mul_(1.0 - tau).add_(param, alpha=tau)
-
def train(q, q_target, memory, optimizer):
    """Run one Double-DQN update on a sampled minibatch, then soft-update the target.

    Draws ``batch_size`` transitions from ``memory`` and regresses Q(s, a) of
    the online network ``q`` toward the Double-DQN target
    ``r + gamma * done_mask * Q_target(s', argmax_a' Q(s', a'))`` using the
    smooth-L1 (Huber) loss. It then steps ``optimizer`` and blends ``q`` into
    ``q_target`` via ``soft_update``.

    Args:
        q: online Q-network. Its output is indexed by action along dim 2, so it
            presumably returns (uav_num, batch_size, n_actions) — confirm
            against Qnet.
        q_target: target Q-network; only evaluated here and then soft-updated.
        memory: object whose ``sample(n)`` returns the tensors
            ``(s, a, r, s_prime, done_mask)``.
        optimizer: optimizer over ``q``'s parameters.

    Returns:
        float: the loss value of this step (callers may ignore it).
    """
    s, a, r, s_prime, done_mask = memory.sample(batch_size)
    # Shapes per the original author's notes (not checked here):
    #   s, s_prime: (uav_num, batch_size, 2); a, r: (uav_num, batch_size, 1)
    # Move the batch to wherever the online network lives instead of assuming CUDA.
    device = next(q.parameters()).device
    s, a, r, s_prime, done_mask = (
        t.to(device) for t in (s, a, r, s_prime, done_mask)
    )

    q_out = q(s)  # (uav_num, batch_size, n_actions)
    q_a = q_out.gather(2, a)  # Q of the taken actions: (uav_num, batch_size, 1)

    # The bootstrap target is a constant for this update. Without no_grad,
    # backward() would also push gradients into q_target's parameters, and
    # those grads would pile up since nothing ever zeroes them.
    with torch.no_grad():
        # Double DQN: the online net picks a', the target net evaluates it.
        # (Nature DQN would instead use q_target(s_prime).max(2)[0].unsqueeze(2).)
        best_a_prime = q(s_prime).argmax(2).unsqueeze(2)
        max_q_prime = q_target(s_prime).gather(2, best_a_prime)  # (uav_num, batch_size, 1)
        # done_mask == 0 marks terminal transitions, which get no bootstrap term.
        target = r + gamma * max_q_prime * done_mask.to(max_q_prime.dtype)

    loss = F.smooth_l1_loss(q_a, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    soft_update(q, q_target)  # soft (Polyak) update of the target network
    return loss.item()
-
def main():
    """Train the Double-DQN UAV agent and write checkpoints under ``save_file``.

    Training loop:
      * 5001 episodes of 501 environment steps each.
      * Exploration epsilon is annealed linearly from 8% down to a 1% floor.
      * Once the replay buffer holds more than 2000 transitions, one ``train``
        step runs every ``train_interval`` steps.
      * Every ``print_interval`` episodes, progress is printed and the
        accumulated score is recorded.
      * Every ``save_interval`` episodes, ``q_target`` is saved.

    At the end it saves:
      * ``all_parameters.pt``: hyperparameters, score history, the final
        target network and some environment settings.
      * The final target network, as a separate file.
    """
    import os

    # torch.save does not create missing directories; create it up front so a
    # bad path fails immediately instead of at the first checkpoint.
    os.makedirs(save_file, exist_ok=True)

    n_uav = 4  # passed to UAVEnv and used as done_mask's first dimension
    env = UAVEnv(n_uav)
    memory = ReplayBuffer(buffer_limit)
    q = Qnet().cuda()
    q_target = Qnet().cuda()
    q_target.load_state_dict(q.state_dict())  # target starts as a copy of q
    optimizer = optim.Adam(q.parameters(), lr=learning_rate)

    print_interval = 20   # episodes between progress reports
    train_interval = 50   # environment steps between training steps
    save_interval = 100   # episodes between checkpoints
    score = 0.0
    train_times = 0
    all_score = []
    all_train_times = []
    save_parameters = {'learning_rate': learning_rate, 'gamma': gamma,
                       'buffer_limit': buffer_limit, 'batch_size': batch_size,
                       'tau': tau}

    for n_epi in range(5001):
        epsilon = max(0.01, 0.08 - 0.01 * (n_epi / 200))  # Linear annealing from 8% to 1%
        s = env.reset()
        for i in range(501):
            a = q.sample_action(torch.from_numpy(s).float().unsqueeze(1).cuda(), epsilon)
            s_prime, r, done = env.step(a)
            # env.render()
            # Episodes run for a fixed number of steps: ``done`` is ignored and
            # every transition is stored as non-terminal (mask of ones).
            # Use ``0.0 if done else 1.0`` to treat terminal steps as terminal.
            done_mask = np.ones([n_uav, 1])
            memory.put((s, a, r, s_prime, done_mask))
            s = s_prime
            score += sum(r[0])  # only r[0] contributes to the reported score
            if i % train_interval == 0 and i != 0 and memory.size() > 2000:
                train(q, q_target, memory, optimizer)
                train_times += 1
        if n_epi % print_interval == 0 and n_epi != 0:
            print("n_episode :{},train_times :{}, score : {:.1f}, n_buffer : {}, eps : {:.1f}%".format(
                n_epi, train_times, score / print_interval, memory.size(), epsilon * 100))
            all_score.append(score)
            all_train_times.append(train_times)
            score = 0.0
        if n_epi % save_interval == 0 and n_epi != 0:
            torch.save(q_target, save_file + 'net_v4_{}.pt'.format(n_epi))
    env.close()

    # Save the run summary: hyperparameters, training curves, the final target
    # network and the environment settings it was trained on.
    save_parameters.update({'all_score': all_score,
                            'all_train_times': all_train_times,
                            'trained_model': q_target,
                            'region_length': env.region_length,
                            'region_width': env.region_width,
                            'uav_move_distance': env.uav_move_instance})
    torch.save(save_parameters, save_file + 'all_parameters.pt')
    torch.save(q_target, save_file + 'net_v4_{}.pt'.format('end'))


if __name__ == '__main__':
    main()
|