Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Implement learning from draws
  • Loading branch information
sas12028 committed Apr 26, 2017
1 parent 0604ecc commit b1825e5
Show file tree
Hide file tree
Showing 7 changed files with 214 additions and 128 deletions.
20 changes: 12 additions & 8 deletions src/BaseEvaluator.java
Expand Up @@ -23,21 +23,25 @@ public class BaseEvaluator implements Evaluator{
}

// Heuristic evaluation of state s from the perspective of `player`:
// the dot product of this evaluator's weight vector with the state's
// feature vector. Higher values are better for `player`.
// NOTE(review): this diff hunk shows BOTH an active terminal-state check and
// a commented-out variant of it below — confirm which is intended and remove
// the other.
public double evaluate(CheckersGameState s, int player){
if(s.isTerminal()){
if(s.winner() == player){
return 200; // what should this be?
}
else{
return 0; // assuming only positive evaluations
}
}
//if(s.isTerminal()){
// if(s.winner() == player){
// return 1000; // what should this be?
// }
// else{
// return 0; // assuming only positive evaluations
// }
//}
double[] params = s.getFeatures(player);
return dot(this.weights, params);
}

// Reloads this evaluator's weight vector from its backing weight file,
// picking up any weights committed there by another evaluator.
public void refreshWeights(){
this.weights = this.wp.getWeights(this.file);
}
// Writes this evaluator's current weight vector to the CSV at `path`,
// e.g. to promote the learned weights to the opponent's weight file.
public void commitWeights(String path){
this.wp.writeWeights(path, this.weights); // method to commit weights to beta. provide path to beta csv
}



}
Expand Down
3 changes: 2 additions & 1 deletion src/CheckersAI.java
@@ -1,3 +1,5 @@
import java.util.Arrays;

public class CheckersAI{


Expand Down Expand Up @@ -52,7 +54,6 @@ public class CheckersAI{
double v = Double.NEGATIVE_INFINITY;
double check;
Move max = null;
System.out.println(s.actions().size());
for(Move a: s.actions()){
check = minValue(s.result(a), alpha, beta, depth + 1, a.isJump(), min_ply);
if(check > v){
Expand Down
15 changes: 15 additions & 0 deletions src/CheckersGameState3.java
Expand Up @@ -535,6 +535,21 @@ public class CheckersGameState3 implements CheckersGameState{
}
}

// Weighted piece count for `player` on this board: each regular man
// (board value == player) counts 1, each king (board value == player + 2)
// counts 2.
public int numPieces(int player){
    int total = 0;
    int kingCode = player + 2; // kings are encoded as player + 2
    for(int square : this.board){
        if(square == kingCode){
            total += 2;
        } else if(square == player){
            total += 1;
        }
    }
    return total;
}



public void printState(){
boolean leading = false;
int printed = 0;
Expand Down
102 changes: 65 additions & 37 deletions src/Learn.java
Expand Up @@ -3,75 +3,103 @@ import java.util.Random;
public class Learn{

public static void main(String[] args){

final int num_games = 30;
LearningEvaluator le = new LearningEvaluator("../src/weights/alpha.csv", .1);
LearningEvaluator le = new LearningEvaluator("../src/weights/alpha.csv");
BaseEvaluator be = new BaseEvaluator("../src/weights/beta.csv");
CheckersAI alpha = new CheckersAI(le, 1);
CheckersAI beta = new CheckersAI(be, 2);
CheckersGameState s;
learn(alpha, beta, le, be);

int played = 0;
int won = 0;
int winner;
}

for(int i = 0; i < num_games; i++){ // play num_games amount of games
s = new CheckersGameState3();
winner = play(s, alpha, beta, le); // alpha and beta play a game
System.out.println("winner " + winner);
le.updateWeights(); // get new weights using data from game
played++;
if(winner == alpha.getPlayer()){
won++;
// need to decide what to do if we are going on the wrong track
// samuel resets one of the weights to be zero


// for draws: add a function is_improved that counts a draw as a win when the piece-count lead is at least 4 (a king is worth 2)
// for the learning rate: first 30 games with .1, next 30 with .05, then the final 30 with .01, and see what happens

// Main training loop: for each of `iterations` rounds, alpha plays
// `num_games` learning games (updating its weights after each game with
// learning rate .1), then is evaluated against beta via faceBeta.
public static void learn(CheckersAI alpha, CheckersAI beta, LearningEvaluator le, BaseEvaluator be){
final int num_games = 30;
final int iterations = 3;

for(int j = 0; j < iterations; j++){
for(int i = 1; i <= num_games; i++){ // play num_games games per iteration
System.out.println("playing game " + i);
play(alpha, beta, le, true); // alpha and beta play a game
le.updateWeights(.1); // get new weights using data from game
}
// NOTE(review): the block below references `played`/`won`, which are not
// declared in this method — it looks like stale residue from an earlier
// revision of this loop; confirm and delete.
if(played == 10){
if(won >= 7){ // if alpha wins 7 of every ten games, make beta use alpha's new evaluator
le.commitWeights("../src/weights/beta.csv");
be.refreshWeights();
}
played = 0;
won = 0;
faceBeta(alpha, beta, le, be);
}
}

// Evaluation round: alpha (learning evaluator) plays 10 non-learning games
// against beta (baseline evaluator). If alpha wins at least 7, beta adopts
// alpha's weights; otherwise alpha's weight file is overwritten with beta's
// weights and reloaded, rolling back the regression.
// Fix: removed the unused locals `CheckersGameState s` / `boolean w` and the
// dead `new CheckersGameState3()` allocation per game — play() constructs
// its own fresh game state.
public static void faceBeta(CheckersAI alpha, CheckersAI beta, LearningEvaluator le, BaseEvaluator be){
    int won = 0;
    System.out.println("facing beta");
    for(int i = 0; i < 10; i++){
        // learning flag is false: no training data is collected in these games
        if(play(alpha, beta, le, false)){
            won++;
        }
    }
    System.out.println("alpha won " + won + " times");
    if(won >= 7){
        System.out.println("updating beta");
        le.commitWeights("../src/weights/beta.csv");
        be.refreshWeights();
    }
    else{
        // alpha did not improve: reset alpha's weights back to beta's
        be.commitWeights("../src/weights/alpha.csv");
        le.refreshWeights();
    }
}

// need to decide what to do if we are going on the wrong track
// samuel resets one of the weights to be zero

}



public static int play(CheckersGameState s, CheckersAI alpha, CheckersAI beta, LearningEvaluator le){
CheckersGameState current = s;
int moves = 0; // draw after 200 moves
// Plays one full game between alpha and beta, with alpha randomly assigned
// player 1 or 2. When `learning` is true, every alpha move's (features,
// minimax value) pair is recorded via le.addData for later regression.
// A game also ends when alpha repeats the same move 4 times (same_moves > 3),
// treating the position as a draw. Returns true when alpha won outright or,
// per improved(), ended a draw materially ahead.
// NOTE(review): this hunk interleaves old and new diff lines — the `moves++`
// statements, the `while(... moves <= 100)` header, and the duplicate
// printState/println/return at the bottom reference a `moves` variable that
// is no longer declared; confirm they are deleted residue.
public static boolean play(CheckersAI alpha, CheckersAI beta, LearningEvaluator le, boolean learning){
CheckersGameState current = new CheckersGameState3();
Random rand = new Random();
int player = rand.nextInt(2) + 1; // choose which player alpha plays as
int other = 1 - (player - 1) + 1; // the other player id (1 <-> 2)
//System.out.println("playing as " + player);
alpha.setPlayer(player);
beta.setPlayer(other);
current.printState();
if(other == 1){ // if beta goes first, make a move
current = current.result(beta.minimax(current, 7));
current.printState();
moves++;
}
while(!current.isTerminal() && moves <= 100){
int same_moves = 0;
Move lastmove = null;
Move secondlast = null;
// stop when the game is decided or alpha has repeated a move 4 times
while(!current.isTerminal() && same_moves <= 3){
Move next = alpha.minimax(current, 7); // get alpha's move
le.addData(current.getFeatures(alpha.getPlayer()), next.getValue()); // add this moves data to the data set (the value of the state is stored in the move. there is probably a better way to do this)
// detect repetition: alpha playing the same move it played two plies ago
if(secondlast != null && next.toString().equals(secondlast.toString())){
same_moves++;
}
secondlast = lastmove;
lastmove = next;
if(learning){
le.addData(current.getFeatures(alpha.getPlayer()), next.getValue()); // add this moves data to the data set (the value of the state is stored in the move. there is probably a better way to do this)
}
current = current.result(next); // make the move
//current.printState();
moves++;
if(current.isTerminal()){ // if alpha won, then break
break;
}
current = current.result(beta.minimax(current, 7)); // beta's move
moves++;

}
current.printState();
System.out.println("playing as " + alpha.getPlayer());
return current.winner();
return (current.winner() == alpha.getPlayer() || improved(current, alpha.getPlayer()));
}

// Scores a drawn game for `player`: the draw counts as a win when the
// player's weighted piece count (kings count double, see numPieces) leads
// the opponent's by at least 4.
public static boolean improved(CheckersGameState i, int player){
    CheckersGameState3 state = (CheckersGameState3) i;
    int opponent = 3 - player; // maps 1 <-> 2, same as 1 + (1 - (player - 1))
    int lead = state.numPieces(player) - state.numPieces(opponent);
    return lead >= 4;
}

}
81 changes: 31 additions & 50 deletions src/LearningEvaluator.java
Expand Up @@ -12,71 +12,52 @@ public class LearningEvaluator extends BaseEvaluator{
double alpha; // learning parameter, higher alpha means weights are closer to the regression output
// alpha of 1 is directly setting weights to be regression weights

// Constructs a learning evaluator backed by the weight file at `file`,
// with empty training-data buffers and a fresh OLS regression instance.
// NOTE(review): this hunk shows both the old (file, alpha) and the new
// (file) constructor signatures plus the old `this.alpha = alpha;` line —
// diff residue; the learning rate is now passed per call to
// updateWeights(double alpha). Confirm and delete the stale lines.
public LearningEvaluator(String file, double alpha){
public LearningEvaluator(String file){
super(file);
params = new ArrayList<double[]>();
values = new ArrayList<Double>();
reg = new OLSMultipleLinearRegression();
this.alpha = alpha;

}

// Replaces this evaluator's learning-rate parameter.
public void setAlpha(double a){
    this.alpha = a;
}

// Appends one training example to the buffered data set: the state's
// feature vector and the minimax value assigned to it.
public void addData(double[] features, double value){
    params.add(features);
    values.add(value);
}

// Writes the current weight vector to the CSV at `path`.
// NOTE(review): shown as a deleted line in this diff — this method now
// lives on BaseEvaluator; confirm this copy is gone from LearningEvaluator.
public void commitWeights(String path){
this.wp.writeWeights(path, this.weights); // method to commit weights to beta. provide path to beta csv
}

public void updateWeights(){
// NEED TO CHANGE THIS METHOD
// using least squares might be a bad idea
// get a lot of singular matrices
// we could do samuel's method or come up with another function to modify the coefficients
// int curr_in = 0;
// int data_sz = params.get(0).length + 1; // need to do regression with data sets of size 10, so each iteration of loop uses 10 lines of data
//while(params.size() - curr_in > data_sz){
double[] vals = new double [values.size()]; //converting arraylist to array
System.out.println("printing values");
int j = 0;
for(int i = 0; i < values.size(); i++){
vals[i] = values.get(i);
System.out.println(values.get(i));
j++;
}
System.out.println(vals);
public void updateWeights(double alpha){
double[] vals = new double [values.size()]; //converting arraylist to array
System.out.println("printing values");
for(int i = 0; i < values.size(); i++){
vals[i] = values.get(i);
System.out.println(values.get(i));
}
// System.out.println(vals);
System.out.println("printing params");
double[][] pars = new double[params.size()][]; //converting 2d arraylist to array
j=0;
for(int i=0; i < params.size(); i++){
pars[i] = params.get(i);
System.out.println(Arrays.toString(params.get(i)));
j++;
double[][] pars = new double[params.size()][]; //converting 2d arraylist to array
for(int i=0; i < params.size(); i++){
pars[i] = params.get(i);
System.out.println(Arrays.toString(params.get(i)));
}
//System.out.println(pars);
reg.newSampleData(vals, pars); //add data
reg.setNoIntercept(true);
try {
double[] new_weights = reg.estimateRegressionParameters(); //get parameters
for(double x: new_weights){
if(Math.abs(x) > 10000){
System.out.println("bad data, not updating");
return;
}
}
System.out.println(pars);
reg.newSampleData(vals, pars); //add data
reg.setNoIntercept(true);
try {
double[] new_weights = reg.estimateRegressionParameters(); //get parameters
for(int i = 0; i < this.weights.length; i++){
this.weights[i] = this.weights[i] + alpha * (new_weights[i] - this.weights[i]);
}
System.out.println("updated weights " + Arrays.toString(this.weights));
commitWeights(this.file);
} catch(SingularMatrixException e) {
System.out.println("Matrix was singular, not updating weights");
for(int i = 0; i < this.weights.length; i++){
this.weights[i] = this.weights[i] + alpha * (new_weights[i] - this.weights[i]);
}
//curr_in += data_sz;
//}

//values = new ArrayList<Double>(values.subList(curr_in, values.size()));
//params = new ArrayList<double[]>(params.subList(curr_in, params.size()));
System.out.println("updated weights " + Arrays.toString(this.weights));
commitWeights(this.file);
} catch(SingularMatrixException e) {
System.out.println("Matrix was singular, not updating weights");
}
values.clear();
params.clear();

Expand Down

0 comments on commit b1825e5

Please sign in to comment.