In [ ]:
#https://tail-island.github.io/programming/2018/06/19/monte-carlo-tree-search.html
#上記記事のコードを実行してみる
In [ ]:
from random import random
from operator import attrgetter
In [ ]:
!pip install funcy
Requirement already satisfied: funcy in /usr/local/lib/python3.7/dist-packages (1.17)
In [ ]:
from funcy  import *


def _popcount(x):
    return bin(x).count('1')  # Pythonだと、コレが手軽で速いらしい。


# Game state.
class State:
    """A tic-tac-toe position, always seen from the side to move.

    The board is stored as two 9-bit integers. Square ``i`` (0-8, row by row
    from the top-left) corresponds to the bit ``0b100000000 >> i``.

    Attributes:
        pieces: bitboard of the player to move.
        enemy_pieces: bitboard of the opponent (the player who just moved).
    """

    # The eight winning lines: three rows, three columns and two diagonals.
    _WIN_MASKS = (0b111000000, 0b000111000, 0b000000111,
                  0b100100100, 0b010010010, 0b001001001,
                  0b100010001, 0b001010100)

    def __init__(self, pieces=0, enemy_pieces=0):
        self.pieces       = pieces
        self.enemy_pieces = enemy_pieces

    @staticmethod
    def _square_bit(square):
        """Return the bitboard mask of ``square`` (0 = top-left, 8 = bottom-right)."""
        return 0b100000000 >> square

    @property
    def lose(self):
        """True if the opponent has completed a line, i.e. the side to move has lost."""
        # A single generator argument works both with the builtin ``any`` and
        # with ``funcy.any`` (which falls back to the builtin when given one
        # argument), so this no longer relies on the star import shadowing it.
        return any((self.enemy_pieces & mask) == mask for mask in self._WIN_MASKS)

    @property
    def draw(self):
        """True if all nine squares are occupied.

        A full board may also contain a winning line, so callers check
        ``lose`` before ``draw``.
        """
        return _popcount(self.pieces) + _popcount(self.enemy_pieces) == 9

    @property
    def end(self):
        """True if the game is over (lost or drawn)."""
        return self.lose or self.draw

    @property
    def legal_actions(self):
        """Tuple of the empty squares, in ascending order."""
        occupied = self.pieces | self.enemy_pieces
        return tuple(square for square in range(9) if not (occupied & self._square_bit(square)))

    def next(self, action):
        """Return the position after the side to move plays on square ``action``.

        The roles swap: the returned state is seen from the opponent's side.
        """
        return State(self.enemy_pieces, self.pieces | self._square_bit(action))

    def __str__(self):
        """Render the board as three lines of 'o' (first player), 'x' (second player) and '-'."""
        # Equal stone counts mean the first player ('o') is the side to move.
        if _popcount(self.pieces) == _popcount(self.enemy_pieces):
            own_mark, enemy_mark = 'o', 'x'
        else:
            own_mark, enemy_mark = 'x', 'o'

        def mark(square):
            bit = self._square_bit(square)
            if self.pieces & bit:
                return own_mark
            if self.enemy_pieces & bit:
                return enemy_mark
            return '-'

        return '\n'.join(''.join(mark(row * 3 + col) for col in range(3)) for row in range(3))
In [ ]:
from random import randint

def random_next_action(state):
    """Return a uniformly random legal action of ``state`` (the random player)."""
    candidates = state.legal_actions
    last_index = len(candidates) - 1
    return candidates[randint(0, last_index)]
In [ ]:
from math import inf


# Alpha-beta search in negamax form ("nega-alpha").
def nega_alpha(state, alpha, beta):
    """Score ``state`` for the side to move using alpha-beta pruning.

    Terminal positions score -1 (the side to move has lost) or 0 (draw).
    Child scores are negated because the side to move alternates, and the
    child's search window is the parent's window negated and swapped.

    Args:
        state: position to evaluate (uses ``lose``, ``draw``,
            ``legal_actions`` and ``next``).
        alpha: lower bound of the search window.
        beta: upper bound of the search window.

    Returns:
        The window-bounded negamax score from the point of view of the side
        to move. The search of the remaining moves stops as soon as the
        score reaches ``beta``.
    """
    if state.lose:
        return -1
    if state.draw:
        return 0

    for action in state.legal_actions:
        alpha = max(alpha, -nega_alpha(state.next(action), -beta, -alpha))
        if alpha >= beta:
            break  # beta cut-off: the opponent will never allow this line

    return alpha


# Return the best move found by nega-alpha search (nega_alpha returns a score,
# so this root-level variant is needed to recover the move itself).
def nega_alpha_next_action(state):
    """Return the best action for the side to move in ``state``.

    Mirrors ``nega_alpha`` at the root, but remembers which action produced
    the best score. Ties go to the earliest action in ``state.legal_actions``.

    Raises:
        ValueError: if ``state`` has no legal actions (e.g. the board is full).
    """
    legal_actions = state.legal_actions
    if not legal_actions:
        raise ValueError('no legal actions: the game is already over')

    best_action = legal_actions[0]
    alpha = -inf

    for action in legal_actions:
        # Searching the child with window (-inf, -alpha) lets it prune lines
        # that cannot beat the best score found so far.
        score = -nega_alpha(state.next(action), -inf, -alpha)
        if score > alpha:
            best_action = action
            alpha = score

    return best_action
In [ ]:
# Random playout.
def playout(state):
    """Play random moves from ``state`` until the game ends.

    Returns the result from the point of view of the side to move in
    ``state``: -1 for a loss, 0 for a draw, 1 for a win.
    """
    # ``sign`` flips every ply so the terminal result is reported from the
    # perspective of the player to move in the starting ``state``.
    sign = 1
    while not state.lose:
        if state.draw:
            return 0
        state = state.next(random_next_action(state))
        sign = -sign
    return -sign


# Return the index of the maximum element.
def argmax(collection, key=None):
    """Return the position of the first maximal element of ``collection``.

    Args:
        collection: any iterable (list, tuple, generator, ...).
        key: optional one-argument function used to compare elements, as in
            the builtin ``max``.

    Returns:
        The index of the first element whose (keyed) value is maximal.

    Raises:
        ValueError: if ``collection`` is empty (propagated from ``max``).
    """
    # Track positions directly instead of ``collection.index(max(...))``:
    # ``index`` matches by ``==`` and can return an earlier element that only
    # compares equal to the maximum; this is also a single pass and accepts
    # non-sequence iterables.
    if key is None:
        def score(indexed):
            return indexed[1]
    else:
        def score(indexed):
            return key(indexed[1])

    best_index, _ = max(enumerate(collection), key=score)
    return best_index


# Flat (pure) Monte Carlo search.
def monte_carlo_search_next_action(state, playout_count=10):
    """Choose an action by flat Monte Carlo search.

    Every legal action is scored by the sum of ``playout_count`` random
    playouts from the resulting position. ``playout`` reports the result for
    the player to move after the action (the opponent), so each result is
    negated. The action with the highest total is returned; ties go to the
    earliest action.

    Args:
        state: position to move from.
        playout_count: number of playouts per legal action (default 10).
    """
    legal_actions = state.legal_actions
    totals = [sum(-playout(state.next(action)) for _ in range(playout_count))
              for action in legal_actions]
    return legal_actions[argmax(totals)]
In [ ]:
from math import log
class node:
    """A node of the Monte Carlo search tree.

    ``w`` and ``n`` are accumulated from the point of view of the player to
    move in ``state``. A leaf is estimated with random playouts and gets
    child nodes once it has been visited ``EXPAND_THRESHOLD`` times; after
    that, iterations descend into a child chosen by UCB1.
    """

    # Number of visits after which a leaf is expanded into child nodes.
    EXPAND_THRESHOLD = 10
    # Weight of the exploration term in the UCB1 score.
    EXPLORATION_WEIGHT = 2

    def __init__(self, state):
        self.state       = state
        self.w           = 0     # cumulative value (side to move in ``state``)
        self.n           = 0     # visit count
        self.child_nodes = None  # tuple of child nodes once expanded

    def evaluate(self):
        """Run one search iteration through this node.

        Returns:
            The value obtained (-1, 0 or 1) from the point of view of the side
            to move in this node's state; it is also added to ``w``.
        """
        if self.state.end:
            # Terminal position: the exact result, no playout needed.
            return self._record(-1 if self.state.lose else 0)

        if not self.child_nodes:
            # Leaf: estimate the position with a random playout.
            value = self._record(playout(self.state))
            if self.n == self.EXPAND_THRESHOLD:
                # Must be a call: a bare ``self.expand`` is a silent no-op and
                # would keep the tree from ever growing below the root.
                self.expand()
            return value

        # Internal node: the child's value is from the opponent's point of
        # view, hence the negation.
        return self._record(-self.next_child_node().evaluate())

    def _record(self, value):
        """Add ``value`` to this node's statistics and return it."""
        self.w += value
        self.n += 1
        return value

    def expand(self):
        """Create one child node per legal action, in ``legal_actions`` order."""
        self.child_nodes = tuple(node(self.state.next(action)) for action in self.state.legal_actions)

    def next_child_node(self):
        """Select the child to descend into.

        Unvisited children are returned first (in order). Once all have been
        visited, the child with the highest UCB1 score is chosen; its mean
        value is negated because child statistics are kept from the
        opponent's point of view.
        """
        for child_node in self.child_nodes:
            if child_node.n == 0:
                return child_node

        total_visits = sum(map(attrgetter('n'), self.child_nodes))
        ucb1_scores = tuple(
            -child_node.w / child_node.n
            + self.EXPLORATION_WEIGHT * (2 * log(total_visits) / child_node.n) ** 0.5
            for child_node in self.child_nodes)

        return self.child_nodes[argmax(ucb1_scores)]
In [ ]:
def monte_carlo_tree_search_next_action(state, evaluate_count=100):
    """Choose an action with Monte Carlo tree search.

    The root is expanded immediately, ``evaluate_count`` search iterations
    are run, and the action whose child node was visited most often is
    returned.

    Args:
        state: position to move from.
        evaluate_count: number of search iterations (default 100).
    """
    root_node = node(state)
    root_node.expand()

    for _ in range(evaluate_count):
        root_node.evaluate()

    # Child nodes are created in ``legal_actions`` order, so the indices line up.
    return state.legal_actions[argmax(root_node.child_nodes, key=attrgetter('n'))]
In [ ]:
def main(game_count=100):
    """Benchmark Monte Carlo tree search playing first.

    Plays ``game_count`` games of MCTS against the random player, then
    ``game_count`` games against nega-alpha search, and prints the first
    player's average score for each matchup (win = 1, draw = 0.5, loss = 0).

    Raises:
        ValueError: if ``game_count`` is less than 1.
    """
    from itertools import cycle  # stdlib equivalent of funcy's cat(repeat(...))

    if game_count < 1:
        raise ValueError('game_count must be at least 1')

    def first_player_point(ended_state):
        """Return the first player's score for a finished game."""
        if ended_state.lose:
            # The player who made the last move won; an odd number of stones
            # on the board means that player was the first player.
            stone_count = _popcount(ended_state.pieces) + _popcount(ended_state.enemy_pieces)
            return 1 if stone_count % 2 == 1 else 0

        return 0.5

    def test_algorithm(next_actions):
        """Play ``game_count`` games with the move functions taking turns; return the mean first-player score."""
        total_point = 0

        for _ in range(game_count):
            state = State()

            for next_action in cycle(next_actions):
                if state.end:
                    break

                state = state.next(next_action(state))

            total_point += first_player_point(state)

        return total_point / game_count

    print("モンテカルロ木vsランダムの勝率")
    print(test_algorithm((monte_carlo_tree_search_next_action, random_next_action)))
    print("モンテカルロ木vsαβ法の勝率")
    print(test_algorithm((monte_carlo_tree_search_next_action, nega_alpha_next_action)))
In [ ]:
# Run the benchmark. Results vary from run to run: playouts and the random
# player draw from the ``random`` module, and nothing above seeds it.
main()
モンテカルロ木vsランダムの勝率
1.0
モンテカルロ木vsαβ法の勝率
0.475