-
Notifications
You must be signed in to change notification settings - Fork 0
/
scalable_go_client.cpp
164 lines (130 loc) · 5 KB
/
scalable_go_client.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
// Copyright [2015, 2016] <[email protected]>
#include <array>
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <limits>
#include <stdexcept>
#include <string>
#include <vector>
#include "gogame.h"
#include "gogamenn.h"
#include "gogameab.h"
#include "gohelpers.h"
// Depth argument passed to scalable_go_ab_prune when scoring each candidate
// black move. A typed constant instead of a macro so it is scoped, visible to
// the debugger, and cannot silently rewrite other tokens named DEPTH.
constexpr int DEPTH = 1;
// Raised when the client's command-line arguments are unusable
// (see the argument handling in main). what() == "ClientArgumentError".
class ClientArgumentError : public std::runtime_error {
 public:
  ClientArgumentError() : runtime_error{"ClientArgumentError"} {}
};
// Raised when the network weights file cannot be opened for import.
// what() == "ClientImportError".
class ClientImportError : public std::runtime_error {
 public:
  ClientImportError() : runtime_error{"ClientImportError"} {}
};
std::array<uint8_t, 2> play_game(GoGameNN &i_network, const uint8_t board_size) {
// Array of Vectors to hold win counts for networks
std::array<uint8_t , 2> scores = {0, 0};
// Currently, networks plays as black. Player plays as white.
// GoGame instance used for matches
GoGame game(board_size);
GoMove best_move(game.get_board());
std::string x_in, y_in;
// Bool to determine if game should continue
bool continue_match = true;
// Value of best move
double best_move_value, temp_best_move_value = 0;
while (continue_match) {
std::cout << "Black taking move... \n";
// Generate and take black move
game.generate_moves(0);
best_move_value = -std::numeric_limits<double>::infinity();
// For each possible move, calculate Alpha Beta
for (const GoMove &element : game.get_move_list()) {
GoGame temp_game(game);
temp_game.make_move(element, 0);
temp_best_move_value = scalable_go_ab_prune(i_network, temp_game, DEPTH,
-std::numeric_limits<double>::infinity(),
std::numeric_limits<double>::infinity(), 1, false, 0);
if (temp_best_move_value > best_move_value) {
best_move_value = temp_best_move_value;
best_move = element;
}
}
// Make White move
game.make_move(best_move, 0);
// Check if passed
if (best_move.check_pass()) {
std::cout << "Black passed\n";
} else {
std::cout << "Black made move: (" << int(best_move.get_piece().x) << ", " << int(best_move.get_piece().y) <<
")\n";
}
// White is player move
// Get player move
while (true) {
std::cout << "Current board state: \n";
render_board(game.get_board());
std::cout << "Enter coordinates for where you would like to move, or pass.\n";
std::cout << "Enter X or pass: \n";
std::cin >> x_in;
std::cout << "Received " << x_in << std::endl;
if (x_in == "pass") {
best_move = GoMove(game.get_board());
} else {
std::cout << "Enter Y: \n";
std::cin >> y_in;
std::cout << "Received " << y_in << std::endl;
try {
best_move = GoMove(game.get_board(), XYCoordinate(stoul(x_in), stoul(y_in)));
} catch(...) {
std::cout << "Error creating move. Please try again.\n";
continue;
}
}
// Attempt to make move.
try {
best_move.check_move(1);
game.make_move(best_move, 1);
} catch (...) {
std::cout << "Invalid move, please try again.\n";
continue;
}
break;
}
// Game end detection
std::vector<GoMove> history(game.get_move_history());
// Check if the last 2 moves were passes. If so, end
if (history[history.size() - 1].check_pass() && history[history.size() - 2].check_pass()) {
scores = game.calculate_scores();
// Else, draw... assign no scores.
continue_match = false;
}
}
return scores;
}
// Entry point.
// Usage: <program> <board_size> <network_file> <network_uniform>
//   board_size      - whole number in 1..255 (the game stores it as uint8_t)
//   network_file    - weights file read with GoGameNN::import_weights_stream
//   network_uniform - integer; any non-zero value is passed to GoGameNN as true
// Throws ClientArgumentError on a malformed command line and ClientImportError
// if the weights file cannot be opened. Prints the final scores on completion.
int main(int argc, char* argv[]) {
  if (argc != 4) {
    std::cerr << "Usage: " << (argc > 0 ? argv[0] : "scalable_go_client")
              << " <board_size> <network_file> <network_uniform>\n";
    throw ClientArgumentError();
  }
  // Parse the board size strictly: atoi() + a uint8_t cast would silently turn
  // "abc" into 0 and "300" into 44.
  const std::string size_arg = argv[1];
  unsigned long parsed_size = 0;
  std::size_t parsed_chars = 0;
  try {
    parsed_size = std::stoul(size_arg, &parsed_chars);
  } catch (const std::exception &) {
    parsed_chars = std::string::npos;  // mark as unparseable
  }
  if (parsed_chars != size_arg.size() || parsed_size == 0 ||
      parsed_size > std::numeric_limits<uint8_t>::max()) {
    std::cerr << "Invalid board size: " << size_arg << "\n";
    throw ClientArgumentError();
  }
  const uint8_t board_size = static_cast<uint8_t>(parsed_size);
  const std::string network_file_path = argv[2];
  const bool network_uniform = std::atoi(argv[3]) != 0;
  GoGameNN client_network(board_size, network_uniform);
  // Import network
  std::ifstream network_in(network_file_path);
  if (!network_in.is_open()) {
    std::cout << "Network file failed to open. Exiting. \n";
    throw ClientImportError();
  }
  client_network.import_weights_stream(network_in);
  std::cout << "Game start, Network goes first: " << std::endl;
  const std::array<uint8_t, 2> scores = play_game(client_network, board_size);
  // Who won and score.
  std::cout << "Game over, scores: \n"
            << "Black Score: " << int(scores[0]) << std::endl
            << "White Score: " << int(scores[1]) << std::endl;
  return 0;
}