|
- # Third party code
- #
- # The following code is copied or modified from:
- # https://github.com/suragnair/alpha-zero-general
-
- import math
- import numpy as np
-
- EPS = 1e-8  # added to Ns[s] under the square root in MCTS.search so edges not yet in Qsa keep a non-zero, prior-weighted score while Ns[s] is 0
-
-
class MCTS:
    """
    Monte Carlo Tree Search guided by a policy/value predictor.

    All statistics live in dictionaries keyed either by ``s`` (the string
    representation of a canonical board) or by ``(s, a)`` edge tuples.
    ``args`` must expose ``numMCTSSims``, ``cpuct`` and ``dirichletAlpha``.
    """

    def __init__(self, game, nn_agent, args, dirichlet_noise=False):
        """
        Args:
            game: object providing stringRepresentation, getGameEnded,
                getValidMoves, getActionSize, getNextState and
                getCanonicalForm.
            nn_agent: object whose ``predict(board)`` returns ``(policy, v)``.
            args: configuration object (see class docstring).
            dirichlet_noise: if True, getActionProb mixes Dirichlet noise
                into the root prior on its first simulation.
        """
        self.game = game
        self.nn_agent = nn_agent
        self.args = args
        self.dirichlet_noise = dirichlet_noise
        self.Qsa = {}  # stores Q values for s,a (as defined in the paper)
        self.Nsa = {}  # stores #times edge s,a was visited
        self.Ns = {}  # stores #times board s was visited
        self.Ps = {}  # stores initial policy (returned by neural net)

        self.Es = {}  # stores game.getGameEnded result for board s
        self.Vs = {}  # stores game.getValidMoves for board s

    def getActionProb(self, canonicalBoard, temp=1):
        """
        Run ``args.numMCTSSims`` simulations from canonicalBoard and turn the
        root visit counts into a policy.

        Args:
            canonicalBoard: root position, in canonical form.
            temp: temperature; 0 selects (one of) the most visited action(s)
                deterministically up to random tie-breaking.

        Returns:
            probs: a list where the probability of the ith action is
                   proportional to Nsa[(s,a)]**(1./temp); one-hot if temp is 0.

        Raises:
            ZeroDivisionError: if temp != 0 and no edge out of the root has
                been visited (all counts are zero).
        """
        for i in range(self.args.numMCTSSims):
            # Noise is injected only once per call, on the first simulation.
            dir_noise = (i == 0 and self.dirichlet_noise)
            self.search(canonicalBoard, dirichlet_noise=dir_noise)

        s = self.game.stringRepresentation(canonicalBoard)
        counts = [
            self.Nsa.get((s, a), 0)
            for a in range(self.game.getActionSize())
        ]

        if temp == 0:
            # Break ties between equally visited actions at random.
            bestAs = np.flatnonzero(np.asarray(counts) == max(counts))
            bestA = np.random.choice(bestAs)
            probs = [0] * len(counts)
            probs[bestA] = 1
            return probs

        counts = [x**(1. / temp) for x in counts]
        counts_sum = float(sum(counts))
        return [x / counts_sum for x in counts]

    def search(self, canonicalBoard, dirichlet_noise=False):
        """
        Perform one iteration of MCTS, recursing until a leaf or terminal
        node is reached. At each expanded node the action with the maximum
        upper confidence bound is followed.

        At a leaf the predictor supplies the prior P and value v, and v is
        propagated up the search path; at a terminal node the game outcome is
        propagated instead. Ns, Nsa and Qsa are updated on the way back.

        NOTE: the return value is the negative of the value of the current
        state: v is from the point of view of the player to move, so it is
        -v for the other player.

        Args:
            canonicalBoard: position to search from, in canonical form.
            dirichlet_noise: if True, mix Dirichlet noise into this node's
                prior (not propagated to child calls).

        Returns:
            v: the negative of the value of the current canonicalBoard
        """
        s = self.game.stringRepresentation(canonicalBoard)

        if s not in self.Es:
            self.Es[s] = self.game.getGameEnded(canonicalBoard, 1)
        if self.Es[s] != 0:
            # terminal node
            return -self.Es[s]

        if s not in self.Ps:
            # leaf node
            return -self._expand(s, canonicalBoard, dirichlet_noise)

        valids = self.Vs[s]
        if dirichlet_noise:
            self._add_root_noise(s, valids)

        a = self._select_action(s, valids)
        next_s, next_player = self.game.getNextState(canonicalBoard, 1, a)
        next_s = self.game.getCanonicalForm(next_s, next_player)

        v = self.search(next_s)

        edge = (s, a)
        if edge in self.Qsa:
            # Incremental running mean of the backed-up values.
            visits = self.Nsa[edge]
            self.Qsa[edge] = (visits * self.Qsa[edge] + v) / (visits + 1)
            self.Nsa[edge] = visits + 1
        else:
            self.Qsa[edge] = v
            self.Nsa[edge] = 1

        self.Ns[s] += 1
        return -v

    def _expand(self, s, canonicalBoard, dirichlet_noise):
        """Query the predictor for a new node, store its prior, return v."""
        policy, v = self.nn_agent.predict(canonicalBoard)
        valids = self.game.getValidMoves(canonicalBoard, 1)

        policy = policy * valids  # masking invalid moves
        total = np.sum(policy)
        if total > 0:
            policy = policy / total  # renormalize
        else:
            # If all valid moves were masked, make all valid moves equally
            # probable.
            #
            # NB! All valid moves may be masked if either your NNet
            # architecture is insufficient or you've got overfitting or
            # something else. If you get dozens or hundreds of these messages
            # you should pay attention to your NNet and/or training process.
            print("All valid moves were masked, doing a workaround.")
            policy = policy + valids
            policy = policy / np.sum(policy)
        self.Ps[s] = policy

        # Noise is mixed in only after normalisation so that its weight is
        # exactly 0.25 regardless of how much mass the mask removed.
        if dirichlet_noise:
            self._add_root_noise(s, valids)

        self.Vs[s] = valids
        self.Ns[s] = 0
        return v

    def _add_root_noise(self, s, valids):
        """Apply Dirichlet noise to Ps[s] and renormalise against drift."""
        self.applyDirNoise(s, valids)
        total = np.sum(self.Ps[s])
        if total > 0:
            self.Ps[s] /= total

    def _select_action(self, s, valids):
        """Return the valid action with the highest upper confidence bound.

        Ties go to the lowest action index; -1 is returned if no action is
        valid.
        """
        cpuct = self.args.cpuct
        priors = self.Ps[s]
        node_visits = self.Ns[s]
        cur_best = -float('inf')
        best_act = -1

        for a in range(self.game.getActionSize()):
            if not valids[a]:
                continue
            edge = (s, a)
            if edge in self.Qsa:
                u = self.Qsa[edge] + cpuct * priors[a] * math.sqrt(
                    node_visits) / (1 + self.Nsa[edge])
            else:
                # Unvisited edge: Q is taken as 0; EPS keeps the prior
                # relevant while the node itself has no visits yet.
                u = cpuct * priors[a] * math.sqrt(node_visits + EPS)

            if u > cur_best:
                cur_best = u
                best_act = a

        return best_act

    def applyDirNoise(self, s, valids):
        """
        Mix Dirichlet noise into the stored prior of state s, in place:
        ``P[a] = 0.75 * P[a] + 0.25 * noise[a]`` for every valid action a.

        One noise component is drawn per valid move and assigned by the
        validity mask, so valid moves whose prior is exactly zero are
        perturbed too and a normalised prior stays normalised.

        Args:
            s: string representation of the state whose prior is perturbed.
            valids: validity mask over actions (non-zero means valid).
        """
        n_valid = np.count_nonzero(valids)
        if n_valid == 0:
            return  # nothing to perturb

        dir_values = iter(
            np.random.dirichlet([self.args.dirichletAlpha] * n_valid))
        priors = self.Ps[s]
        for idx, is_valid in enumerate(valids):
            if is_valid:
                priors[idx] = 0.75 * priors[idx] + 0.25 * next(dir_values)
|