Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Implement infrastructure for learning
  • Loading branch information
sas12028 committed Apr 21, 2017
1 parent 68bd8ae commit 5642506
Show file tree
Hide file tree
Showing 9 changed files with 81 additions and 21 deletions.
8 changes: 8 additions & 0 deletions src/BaseEvaluator.java
Expand Up @@ -23,6 +23,14 @@ public class BaseEvaluator implements Evaluator{
}

public double evaluate(CheckersGameState s, int player){
if(s.isTerminal()){
if(s.winner() == player){
return 1000; // what should this be?
}
else{
return 0; // assuming only positive evalutions
}
}
double[] params = s.getFeatures(player);
return dot(this.weights, params);
}
Expand Down
9 changes: 9 additions & 0 deletions src/CheckersAI.java
Expand Up @@ -9,6 +9,14 @@ public class CheckersAI{
this.player = player;
}

public void setPlayer(int player){
this.player = player;
}

public int getPlayer(){
return this.player;
}

private boolean stop(CheckersGameState state, boolean jumped, int depth, int min_ply){
CheckersGameState3 s = (CheckersGameState3) state;
if(depth < min_ply){
Expand Down Expand Up @@ -56,6 +64,7 @@ public class CheckersAI{
return max;
}
}
max.setValue(v);
return max;
}

Expand Down
1 change: 1 addition & 0 deletions src/CheckersGameState.java
Expand Up @@ -4,6 +4,7 @@ String player ();
List < Move > actions ();
CheckersGameState result ( Move x );
boolean isTerminal();
int winner();
void printState ();
public double[] getFeatures(int player);
}
13 changes: 12 additions & 1 deletion src/CheckersGameState3.java
Expand Up @@ -421,7 +421,6 @@ public class CheckersGameState3 implements CheckersGameState{
if(valid_square(index+4)){
if(!empty(board, index+4)) return false;
}
System.out.println("loner found at index " + index);
return true;
}

Expand All @@ -439,6 +438,18 @@ public class CheckersGameState3 implements CheckersGameState{
return (rat == 0 || rat == Double.POSITIVE_INFINITY);
}

public int winner(){ // only call after isTerminal
for(int i = 0; i < board.length; i++){
if(board[i] == 1 || board[i] == 3){
return 1;
}
else if(board[i] == 2 || board[i] == 4){
return 2;
}
}
return 0;
}

public void printState(){
boolean leading = false;
int printed = 0;
Expand Down
52 changes: 36 additions & 16 deletions src/LearningEvaluator.java
@@ -1,14 +1,15 @@
import java.util.ArrayList;
import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;
import java.util.Arrays;

public class LearningEvaluator extends BaseEvaluator{

ArrayList<double[]> params;
ArrayList<Double> values;
// need to download jar and set classpath to import and run
OLSMultipleLinearRegression reg; // performs linear regression (ordinary least squares)
double alpha; // learning parameter, higher alpha means weights are closer to the regression output
// alpha of 1 is directly setting weights to be regression weights
// ideally we start at 1 and lower alpha to get a convergence

public LearningEvaluator(String file, double alpha){
super(file);
Expand All @@ -23,7 +24,7 @@ public class LearningEvaluator extends BaseEvaluator{
alpha = a;
}

public void add_data(double[] features, double value){
public void addData(double[] features, double value){
values.add(value);
params.add(features);
}
Expand All @@ -33,23 +34,42 @@ public class LearningEvaluator extends BaseEvaluator{
}

public void updateWeights(){
double[] vals = new double [values.size()];
for(int i = 0; i < values.size(); i++){
vals[i] = values.get(i);
// NEED TO CHANGE THIS METHOD
// using least squares might be a bad idea
// get a lot of singular matrices
// we could do samuel's method or come up with another function to modify the coefficients
int curr_in = 0;
while(params.size() - curr_in > 10){ // need to do regression with data sets of size 10, so each iteration of loop uses 10 lines of data
double[] vals = new double [10]; //converting arraylist to array
System.out.println("printing values");
int j = 0;
for(int i = curr_in; i < curr_in + 10; i++){
vals[j] = values.get(i);
System.out.println(values.get(i));
j++;
}
System.out.println(vals);
System.out.println("printing params");
double[][] pars = new double[10][]; //converting 2d arraylist to array
j=0;
for(int i=curr_in; i < curr_in + 10; i++){
pars[j] = params.get(i);
System.out.println(Arrays.toString(params.get(i)));
j++;
}
System.out.println(pars);
reg.newSampleData(vals, pars); //add data
reg.setNoIntercept(true);
double[] new_weights = reg.estimateRegressionParameters(); //get parameters
for(int i = 0; i < this.weights.length; i++){
this.weights[i] = this.weights[i] + alpha * (new_weights[i] - this.weights[i]);
}
commitWeights(this.file);
curr_in += 10;
}

values.clear();
double[][] pars = new double[params.size()][];
for(int i=0; i < params.size(); i++){
pars[i] = params.get(i);
}
params.clear();
reg.newSampleData(vals, pars);
reg.setNoIntercept(true);
double[] new_weights = reg.estimateRegressionParameters();
for(int i = 0; i < this.weights.length; i++){
this.weights[i] = this.weights[i] + alpha * (new_weights[i] - this.weights[i]);
}
commitWeights(this.file);

}

Expand Down
4 changes: 4 additions & 0 deletions src/Move.java
Expand Up @@ -9,5 +9,9 @@ public interface Move {

boolean isJump();

public void setValue(double value);

public double getValue();


}
9 changes: 9 additions & 0 deletions src/Move3.java
Expand Up @@ -4,6 +4,7 @@ public class Move3 implements Move{
String[] steps;
int[] kills;
String check;
double value = 0;

public Move3(String steps){
String[] s = steps.split(",");
Expand All @@ -25,6 +26,14 @@ public class Move3 implements Move{
return k;
}

public void setValue(double value){
this.value = value;
}

public double getValue(){
return this.value;
}

public int source(){
return src;
}
Expand Down
4 changes: 1 addition & 3 deletions src/weights/alpha.csv
@@ -1,3 +1 @@
10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10
10.0, 10.0, 10.0, 10.0, 10.0, 75.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0
10.0, 10.0, 10.0, 10.0, 10.0, 75.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0, 10.0
50, 10, 10, 5, 5, 5, 30, 15, 10
2 changes: 1 addition & 1 deletion src/weights/beta.csv
@@ -1 +1 @@
10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10
50, 10, 10, 5, 5, 5, 30, 15, 10

0 comments on commit 5642506

Please sign in to comment.