diff --git a/src/BaseEvaluator.java b/src/BaseEvaluator.java index c35659a..90af4fc 100644 --- a/src/BaseEvaluator.java +++ b/src/BaseEvaluator.java @@ -23,6 +23,14 @@ public class BaseEvaluator implements Evaluator{ } public double evaluate(CheckersGameState s, int player){ + if(s.isTerminal()){ + if(s.winner() == player){ + return 1000; // what should this be? + } + else{ + return 0; // assuming only positive evalutions + } + } double[] params = s.getFeatures(player); return dot(this.weights, params); } diff --git a/src/CheckersAI.java b/src/CheckersAI.java index 36a5b07..8bcbbb3 100644 --- a/src/CheckersAI.java +++ b/src/CheckersAI.java @@ -9,6 +9,14 @@ public class CheckersAI{ this.player = player; } + public void setPlayer(int player){ + this.player = player; + } + + public int getPlayer(){ + return this.player; + } + private boolean stop(CheckersGameState state, boolean jumped, int depth, int min_ply){ CheckersGameState3 s = (CheckersGameState3) state; if(depth < min_ply){ @@ -56,6 +64,7 @@ public class CheckersAI{ return max; } } + max.setValue(v); return max; } diff --git a/src/CheckersGameState.java b/src/CheckersGameState.java index fd1cca3..8d9f4b7 100644 --- a/src/CheckersGameState.java +++ b/src/CheckersGameState.java @@ -4,6 +4,7 @@ String player (); List < Move > actions (); CheckersGameState result ( Move x ); boolean isTerminal(); +int winner(); void printState (); public double[] getFeatures(int player); } diff --git a/src/CheckersGameState3.java b/src/CheckersGameState3.java index a1a293f..eaff276 100644 --- a/src/CheckersGameState3.java +++ b/src/CheckersGameState3.java @@ -421,7 +421,6 @@ public class CheckersGameState3 implements CheckersGameState{ if(valid_square(index+4)){ if(!empty(board, index+4)) return false; } - System.out.println("loner found at index " + index); return true; } @@ -439,6 +438,18 @@ public class CheckersGameState3 implements CheckersGameState{ return (rat == 0 || rat == Double.POSITIVE_INFINITY); } + public int winner(){ // only call after isTerminal + for(int i = 0; i < board.length; i++){ + if(board[i] == 1 || board[i] == 3){ + return 1; + } + else if(board[i] == 2 || board[i] == 4){ + return 2; + } + } + return 0; + } + public void printState(){ boolean leading = false; int printed = 0; diff --git a/src/LearningEvaluator.java b/src/LearningEvaluator.java index b162954..5531c8e 100644 --- a/src/LearningEvaluator.java +++ b/src/LearningEvaluator.java @@ -1,14 +1,15 @@ import java.util.ArrayList; import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression; +import java.util.Arrays; public class LearningEvaluator extends BaseEvaluator{ ArrayList params; ArrayList values; + // need to download jar and set classpath to import and run OLSMultipleLinearRegression reg; // performs linear regression (ordinary least squares) double alpha; // learning parameter, higher alpha means weights are closer to the regression output // alpha of 1 is directly setting weights to be regression weights - // ideally we start at 1 and lower alpha to get a convergence public LearningEvaluator(String file, double alpha){ super(file); @@ -23,7 +24,7 @@ public class LearningEvaluator extends BaseEvaluator{ alpha = a; } - public void add_data(double[] features, double value){ + public void addData(double[] features, double value){ values.add(value); params.add(features); } @@ -33,23 +34,42 @@ public class LearningEvaluator extends BaseEvaluator{ } public void updateWeights(){ - double[] vals = new double [values.size()]; - for(int i = 0; i < values.size(); i++){ - vals[i] = values.get(i); + // NEED TO CHANGE THIS METHOD + // using least squares might be a bad idea + // get a lot of singular matrices + // we could do samuel's method or come up with another function to modify the coefficients + int curr_in = 0; + while(params.size() - curr_in > 10){ // need to do regression with data sets of size 10, so each iteration of loop uses 10 lines of data + double[] vals = new double [10]; //converting arraylist to array + System.out.println("printing values"); + int j = 0; + for(int i = curr_in; i < curr_in + 10; i++){ + vals[j] = values.get(i); + System.out.println(values.get(i)); + j++; + } + System.out.println(vals); + System.out.println("printing params"); + double[][] pars = new double[10][]; //converting 2d arraylist to array + j=0; + for(int i=curr_in; i < curr_in + 10; i++){ + pars[j] = params.get(i); + System.out.println(Arrays.toString(params.get(i))); + j++; + } + System.out.println(pars); + reg.newSampleData(vals, pars); //add data + reg.setNoIntercept(true); + double[] new_weights = reg.estimateRegressionParameters(); //get parameters + for(int i = 0; i < this.weights.length; i++){ + this.weights[i] = this.weights[i] + alpha * (new_weights[i] - this.weights[i]); + } + commitWeights(this.file); + curr_in += 10; } + values.clear(); - double[][] pars = new double[params.size()][]; - for(int i=0; i < params.size(); i++){ - pars[i] = params.get(i); - } params.clear(); - reg.newSampleData(vals, pars); - reg.setNoIntercept(true); - double[] new_weights = reg.estimateRegressionParameters(); - for(int i = 0; i < this.weights.length; i++){ - this.weights[i] = this.weights[i] + alpha * (new_weights[i] - this.weights[i]); - } - commitWeights(this.file); } diff --git a/src/Move.java b/src/Move.java index fc1add5..aafde66 100644 --- a/src/Move.java +++ b/src/Move.java @@ -9,5 +9,9 @@ public interface Move { boolean isJump(); + public void setValue(double value); + + public double getValue(); + } diff --git a/src/Move3.java b/src/Move3.java index b769fa1..a5e5b77 100644 --- a/src/Move3.java +++ b/src/Move3.java @@ -4,6 +4,7 @@ public class Move3 implements Move{ String[] steps; int[] kills; String check; + double value = 0; public Move3(String steps){ String[] s = steps.split(","); @@ -25,6 +26,14 @@ public class Move3 implements Move{ return k; } + public void setValue(double value){ + this.value = value; + } + + public double getValue(){ + return this.value; + } + public int source(){ return src; } diff --git a/src/weights/alpha.csv b/src/weights/alpha.csv index 16e15a6..6ffa5a6 100644 --- a/src/weights/alpha.csv +++ b/src/weights/alpha.csv @@ -1,3 +1 @@ -10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10 -10.0, 10.0, 10.0, 10.0, 10.0, 75.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0 -10.0, 10.0, 10.0, 10.0, 10.0, 75.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0 +50, 10, 10, 5, 5, 5, 30, 15, 10 diff --git a/src/weights/beta.csv b/src/weights/beta.csv index a8af049..6ffa5a6 100644 --- a/src/weights/beta.csv +++ b/src/weights/beta.csv @@ -1 +1 @@ -10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10 +50, 10, 10, 5, 5, 5, 30, 15, 10