
Commit 072161f
clean 1/n
1 parent cf2bc23 commit 072161f
14 files changed: +74 / -377 lines

CPP_backend.cpp

Lines changed: 14 additions & 9 deletions

@@ -1,3 +1,8 @@
+/*
+MCTS implementation in C++ for chess.
+Interfaced with Python using Boost.Python.
+Check the README for installation and compilation instructions.
+*/
 #include <boost/python.hpp>
 #include <Python.h>
 #include <boost/multiprecision/cpp_dec_float.hpp>
@@ -13,7 +18,8 @@
 #include <unordered_map>
 long long int cnt = 0;
 
-#define ll long long int
+typedef long long int ll;
+
 #define AMOUNT_OF_PLANES 73
 #define BOARD_SIZE 8
 #define DIRICHLET_NOISE 0.3
@@ -25,6 +31,7 @@ class C_Edge;
 class C_MCTS;
 class C_Edge;
 
+// One Node in the MCTS tree
 class C_Node{
 public:
     std::string state;
@@ -44,6 +51,7 @@ class C_Node{
     uint64_t get_edge(boost::python::object action);
 };
 
+// One Edge in the MCTS tree, which corresponds to one action
 class C_Edge{
 public:
     C_Node* input_node;
@@ -63,9 +71,6 @@ class C_Edge{
     uint64_t get_N();
 };
 
-// class C_Action{
-
-// }
 
 class C_MCTS{
 public:
@@ -94,7 +99,7 @@ class C_MCTS{
     ~C_MCTS();
 };
 
-
+// Recursively delete the MCTS tree beginning at node
 void delete_mcts_tree(C_Node* node){
 
     if(! node) return;
@@ -131,11 +136,9 @@ std::string C_Node::step(boost::python::object action){
 }
 
 bool C_Node::is_game_over(){
-    // Py_Initialize();
     boost::python::object chess_module = boost::python::import("chess");
    boost::python::object board = chess_module.attr("Board")(this->state);
     bool is_game_over = boost::python::extract<bool>(board.attr("is_game_over")());
-    // Py_Finalize();
     return is_game_over;
 }
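
The removed Py_Initialize()/Py_Finalize() pair is no longer needed here: the interpreter is started once inside BOOST_PYTHON_MODULE (see the last hunk of this file). For reference, the python-chess call that this Boost.Python wrapper reproduces is simply:

```python
import chess

# What C_Node::is_game_over does on the C++ side: build a Board from the
# node's FEN string and ask python-chess whether the game has ended.
board = chess.Board("rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1")
print(board.is_game_over())  # False for the starting position
```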

@@ -314,8 +317,7 @@ void C_MCTS::map_valid_move(boost::python::object move){
 
 }
 
-// TODO : Need suggestions of how to implement probabilities to actions
-
+// Map the network's output probabilities to the legal actions of the given board
 std::unordered_map<std::string, long double> C_MCTS::probabilities_to_actions(boost::python::object probabilities, std::string bord){
     std::unordered_map <std::string, long double> actions;
     boost::python::object chess = boost::python::import("chess");
@@ -484,6 +486,9 @@ std::string C_MCTS::get_edge_uci(uint64_t edge){
     return boost::python::extract<std::string>(((C_Edge*)edge)->action.attr("uci")());
 }
 
+// Python interface for the MCTS class
+// The names of the exported objects are the string arguments
+
 BOOST_PYTHON_MODULE(CPP_backend)
 {
     Py_Initialize();
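
The BOOST_PYTHON_MODULE block turns the C++ searcher into an importable extension; the string arguments passed to the class_ and def wrappers inside it become the Python-visible names. Based only on how agent.py (below) consumes the module, a minimal sketch of driving it from Python could look like this. The MCTS(agent, fen, flag) constructor and run_simulations(n) come from agent.py in this commit; everything else, including the idea that the backend calls back into the agent for network predictions, is an assumption:

```python
# Sketch of driving the compiled CPP_backend extension; only the MCTS
# constructor signature and run_simulations() are taken from agent.py.
from chess import STARTING_FEN
from CPP_backend import MCTS
from agent import Agent

# Agent.__init__ already builds its own MCTS; constructing one directly here
# only illustrates the exported interface.
agent = Agent(local_preds=True, model_path="./best_model/best-model.pth")
tree = MCTS(agent, STARTING_FEN, True)   # (agent, FEN string, flag) as used in agent.py
tree.run_simulations(50)                 # grow the search tree with 50 playouts
```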

agent.py

Lines changed: 10 additions & 23 deletions

@@ -1,18 +1,14 @@
-import numpy as np
+# Contains the Agent class, which is used to play chess moves on the environment.
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
 from neural_network import AgentNetwork
-# from mcts import MCTS # missing
-from CPP_backend import *
-import utils
-import chess
-import time
+from CPP_backend import MCTS
+
+from chess import STARTING_FEN
 import config
 import datetime
 
 class Agent:
-    def __init__(self,local_preds:bool = False, model_path:str|None = None,state:str = chess.STARTING_FEN, device=None):
+    def __init__(self,local_preds:bool = False, model_path:str|None = None,state:str = STARTING_FEN, device=None)->None:
         """
         An agent is an object that can play chess moves on the environment.
         Based on the parameters, it can play with a local model, or send its input to a server.
@@ -32,14 +28,13 @@ def __init__(self,local_preds:bool = False, model_path:str|None = None,state:str
         if model_path is not None:
             self.model.load_state_dict(torch.load(model_path))
         else :
-            raise NotImplementedError("Server predictions not implemented yet")
+            raise NotImplementedError("Server predictions not implemented")
 
         self.state = state
         self.mcts = MCTS(self, state, True)
 
-    def run_simulations(self,n:int=1):
+    def run_simulations(self,n:int=1)->None:
         with torch.no_grad():
-
             self.mcts.run_simulations(n)
 
     def save_model(self,timestamped:bool = False)->str:
@@ -53,24 +48,16 @@ def save_model(self,timestamped:bool = False)->str:
 
         return model_path
 
-    def predict(self, data:torch.Tensor):
+    def predict(self, data:torch.Tensor)->torch.Tensor:
         data = torch.Tensor(data).to(torch.float32).unsqueeze(0).to(self.device)
-        # print(data.shape)
-        # print("in agent predict")
         if self.local_preds:
-            # print('local')
             return self.predict_local(data)
         return self.predict_server(data)
 
-    def predict_local(self,data:torch.Tensor):
-        # self.model.eval()
-
+    def predict_local(self,data:torch.Tensor)->(torch.Tensor,float):
         with torch.no_grad():
             v, p = self.model(data)
         return p.cpu(), v.cpu().item()
 
     def predict_server(self,data:torch.Tensor):
-        raise NotImplementedError("Server predictions not implemented yet")
-
-if __name__=="__main__":
-    pass
+        raise NotImplementedError("Server predictions not implemented")

chessEnv.py

Lines changed: 3 additions & 79 deletions

@@ -1,81 +1,4 @@
-# ### This is the Chess Environment that will interact with the Agent
-
-# ### TODO: make this as c as possible
-
-# import numpy as np
-# import chess
-# import torch
-# from copy import deepcopy
-
-# btoi = lambda x: (1 if x else -1)
-# strenc = np.array(['r', 'n', 'b', 'q', 'k','p','P','R', 'N', 'B', 'Q', 'K']).reshape(1,12,1,1)
-
-# ### TODO:docs
-# class ChessEnv:
-#     def __init__(self,board:str|list,batch_size:int,board_size:int=8,torch_device:str='cuda'):
-#         self.batch_size = batch_size
-#         self.board_size = board_size
-#         self.torch_device = torch_device
-#         self.num_piecetype = 12
-
-#         self.__init_board_frm_str(board)
-
-
-
-#     ### Interface functions ###
-#     def get_embedding(self)->torch.Tensor:
-#         self.__update_embedding()
-#         ## expand out (_,7) tensor to (_,7,8,8) tensor
-#         self.board_states_embedding = self.board_states.unsqueeze(2).unsqueeze(3).repeat(1,1,self.board_size,self.board_size)
-#         return torch.cat([self.board_embedding,self.board_states_embedding],dim=1)
-
-
-#     ### Functions to convert between different representations of the board ###
-#     def __init_board_frm_str(self,board:str|list)->None:
-#         if type(board) == str:
-#             self.board_init = (chess.Board(board),)
-#         elif type(board) == list:
-#             self.board_init = [chess.Board(b) for b in board]
-#         self.board = self.board_init.deepcopy()
-
-#         ### this will be changed !!
-#         self.movenum = 0
-#         ### load the initial embedding ### -> assuming initially history is repeated rather than empty
-#         self.board_embedding = self.__board_to_tensor(self.board).repeat(1,self.board_size,1,1)
-
-#         self.reps = torch.zeros((self.batch_size,1),device=self.torch_device)
-
-
-
-#     ## remember devices
-
-#     ## (turn, 4 castling rights,movenum)
-#     def __get_board_states_single(self,board:chess.Board)->None:
-#         return torch.tensor([[btoi(board.turn),btoi(board.has_kingside_castling_rights(chess.WHITE)),btoi(board.has_kingside_castling_rights(chess.BLACK)),btoi(board.has_queenside_castling_rights(chess.WHITE)),btoi(board.has_queenside_castling_rights(chess.BLACK)),self.movenum]],device=self.torch_device)
-#     ## add reps too
-#     def __get_board_states(self)->torch.Tensor:
-#         sixvars = torch.cat([self.__get_board_states_single(b) for b in self.board],dim=0)
-#         return torch.cat((sixvars,self.reps),dim=1)
-#     # f me #
-#     def __board_to_tensor(self,boards:list|tuple)->torch.Tensor:
-#         arr = (np.array([b.__str__().split() for b in boards]).reshape(-1,1,self.board_size,self.board_size)==strenc)*1
-#         return torch.Tensor(arr, device=self.torch_device)
-
-#     def __update_embedding(self)->None: ### maybe this works correctly
-#         self.board_embedding = torch.cat([self.board_embedding[:,self.num_piecetype:,:,:],self.__board_to_tensor(self.board)],dim=0)
-#         self.reps = (self.reps + 1)*torch.all(torch.all(torch.all(self.board_embedding[:,-2*self.num_piecetype:-self.num_piecetype,:,:]==self.board_embedding[:,-self.num_piecetype:,:,:],dim=3),dim=2),dim=1,keepdim=True)
-#         self.board_states = self.__get_board_states()
-#         return
-
-#     ## convert moves to mask
-#     def __moves_to_mask(self,moves)->torch.Tensor:
-#         pass
-
-#     ## taking the one-hot encoding of the move chosen and updates the board based on it
-#     def __make_move(self,movetensor)->list:
-#         pass
-
-
+### This is the Chess Environment that will interact with the Agent and the MCTS
 #---------------------------- A simpler implementation for now ---------------
 
 import config
@@ -139,11 +62,12 @@ def state_to_input(fen: str):
     chess.KING: 0
 }
 
+### ensure stockfish is installed and the path is correct
 stockfish = Stockfish(os.path.expanduser(config.STOCKFISH))
 
 def estimate_winner(board: chess.Board) -> int:
     """
-    Estimate the winner of the current node.
+    Estimate the winner of the current node. A piece counting heuristic is used.
     Pawn = 1, Bishop = 3, Rook = 5, Queen = 9
     Positive score = white wins, negative score = black wins
     """

collect.py

Lines changed: 5 additions & 0 deletions

@@ -1,3 +1,4 @@
+# Description: This file is used to collect experience data from the lichess database.
 import pandas as pd
 import config
 import threading
@@ -12,9 +13,12 @@
 import os
 import numpy as np
 
+# Location of the puzzles.csv file. Ensure the file is present at this location
 CSV_FILE = "puzzles.csv"
+# Number of games to play at once
 N = 5
 
+# Play a puzzle starting from the given fen and moves. This is still self-play
 def play_puzzle(fen, moves):
     model_path = None if len(os.listdir(config.BEST_MODEL)) == 0 else f"{config.BEST_MODEL}best-model.pth"
     white = Agent(local_preds=True, model_path=model_path)
@@ -27,6 +31,7 @@ def play_puzzle(fen, moves):
 
     game.game()
 
+# Play a full game from the start
 def play_normal():
     model_path = None if len(os.listdir(config.BEST_MODEL)) == 0 else f"{config.BEST_MODEL}best-model.pth"
     white=Agent(local_preds=True, model_path=model_path)
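
collect.py reads its work from the puzzles.csv dump and plays N games at a time via play_puzzle(fen, moves). A rough sketch of that dispatch loop is below; the column names "FEN" and "Moves" follow the Lichess puzzle export format and are an assumption, since the actual loop is not part of these hunks:

```python
# Hypothetical batched dispatch for play_puzzle(); the CSV column names
# ("FEN", "Moves") are assumed from the Lichess puzzle export.
import threading
import pandas as pd
from collect import play_puzzle, CSV_FILE, N

puzzles = pd.read_csv(CSV_FILE)
for start in range(0, len(puzzles), N):
    batch = puzzles.iloc[start:start + N]
    threads = [
        threading.Thread(target=play_puzzle, args=(row["FEN"], row["Moves"]))
        for _, row in batch.iterrows()
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()   # finish this batch of N games before starting the next
```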

config.py

Lines changed: 3 additions & 0 deletions

@@ -1,3 +1,5 @@
+# Parameters for all the files #
+
 BOARD_SIZE = 8
 MAX_MOVES = 50
 PREVIOUS_MOVES = 8
@@ -21,6 +23,7 @@
 MEMORY = "./memory/"
 PUZZLE = "./puzzles/"
 BEST_MODEL = "./best_model/"
+PGN = "./pgn/"
 
 #----------Executable Locations--------------------
 STOCKFISH = "~/stockfish/stockfish-ubuntu-x86-64-modern"
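
config.py stays a flat module of constants, with the new PGN directory sitting next to the other output paths. A small sketch of how these constants are typically consumed; the expanduser call mirrors chessEnv.py, while the directory-creation step is an assumption rather than part of the commit:

```python
import os
import config

# Resolve the Stockfish binary the same way chessEnv.py does.
stockfish_path = os.path.expanduser(config.STOCKFISH)

# Make sure the output directories exist before any games are played
# (setup convenience, not something this commit adds).
for directory in (config.MEMORY, config.PUZZLE, config.BEST_MODEL, config.PGN):
    os.makedirs(directory, exist_ok=True)
```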
