|
- import torch.nn as nn
- import torch
- from datetime import datetime
- from Classes import AdjTable
- from Classes import Klotski
- from Classes import LinkedList
- from Classes import Node
- from Other_Function import SolveMethodTable
- from MCTS_Function import Simulation
- from MCTS_Function import Simulation3
- from MCTS_Function import Expansion
- from MCTS_Function import Selection
- from MCTS_Function import Backpropagation
-
# MCTS phases: selection / expansion / simulation / backpropagation

# Puzzle configuration passed to the Klotski constructor.
LEVEL = 6
NUM_PIECES = 3
ROWS = 5
COLUMN = 4

# Second argument passed to the simulation functions in main().
SIMULATION_TIMES = 20

# Puzzle instance used as the root state of the search.
klotski = Klotski(LEVEL, NUM_PIECES, ROWS, COLUMN)
print('\n', klotski.map)
-
def main(max_iterations=3143655150):
    """Run a Monte Carlo tree search on the module-level ``klotski`` puzzle.

    Every iteration selects a leaf, expands it when it has already been
    visited, simulates from the node the table points to, and backpropagates
    the result. The time spent in each phase is accumulated separately.

    Args:
        max_iterations: Upper bound on the number of search iterations. The
            default equals the iteration count of the previously hard-coded
            ``range(1, 3143655151)``.

    Returns:
        A tuple ``(table, strategy, elapsed)``: the adjacency table after the
        search, the result of ``SolveMethodTable(table)``, and the summed
        duration in seconds of the four phases.
    """
    # Imported locally: the script's ``__main__`` section rebinds the global
    # name ``time`` to a float, so a module-level ``import time`` would be
    # clobbered before a second call to main().
    from time import perf_counter

    ## Initialisation
    time_selection = 0
    time_expansion = 0
    time_simulation = 0
    time_propagation = 0

    # Build the adjacency table.
    table = AdjTable()
    # Create the root node and the linked list that holds it.
    root = Node(klotski)
    root.Order = 1
    root.Layer = 1
    first_layer = LinkedList()
    first_layer.insert(root)
    # Store the linked list in the first slot of the adjacency table.
    table.insert(first_layer)

    for _ in range(max_iterations):
        ## Selection: pick a leaf node.
        tic = perf_counter()
        table = Selection(table)
        time_selection += perf_counter() - tic

        # A node that has already been visited is expanded first.
        if table.index.Data.point.Times != 0:
            ## Expansion
            tic = perf_counter()
            table, cao_move, fre = Expansion(table)
            # Accumulate before branching so the early exits below are
            # included in the reported expansion time.
            time_expansion += perf_counter() - tic

            if cao_move == 100:
                # Finished game state found: backpropagate and stop searching.
                tic = perf_counter()
                table = Backpropagation(table)
                time_propagation += perf_counter() - tic
                break
            if cao_move == -100:
                # The "*2" block cannot move: leave the loop.
                break
            if fre != 1:
                # The current node produced no valid child state; go back to
                # selection. (Original note: the node is deleted, and removal
                # of its parent layer is attempted, by the loop -- presumably
                # inside Selection/Expansion; confirm there.)
                continue

        ## Simulation
        tic = perf_counter()
        if LEVEL <= 5:
            value, _unused = Simulation(table, SIMULATION_TIMES)
        else:
            value = Simulation3(table.index.Data.point.Data, SIMULATION_TIMES)
        # Evaluate and update the value of the node currently pointed to.
        table.index.Data.point.Value = table.index.Data.point.Value + value
        time_simulation += perf_counter() - tic

        ## Backpropagation
        tic = perf_counter()
        table = Backpropagation(table)
        time_propagation += perf_counter() - tic

    strategy = SolveMethodTable(table)
    elapsed = (
        time_selection + time_expansion + time_simulation + time_propagation
    )
    return table, strategy, elapsed
-
if __name__ == "__main__":
    # Run the search `runs` times and print the mean of the reported times.
    runs = 1
    total_seconds = 0
    for run_index in range(runs):
        print('\n', run_index)
        _table, _strategy, elapsed = main()
        total_seconds += elapsed
    mean_seconds = total_seconds / runs
    print('\n', mean_seconds)
|