Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
CheckersAI/src/LearningEvaluator.java
Go to file. This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
78 lines (65 sloc)
2.8 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Objects;
import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression;
public class LearningEvaluator extends BaseEvaluator{ | |
ArrayList<double[]> params; | |
ArrayList<Double> values; | |
// need to download jar and set classpath to import and run | |
OLSMultipleLinearRegression reg; // performs linear regression (ordinary least squares) | |
double alpha; // learning parameter, higher alpha means weights are closer to the regression output | |
// alpha of 1 is directly setting weights to be regression weights | |
public LearningEvaluator(String file, double alpha){ | |
super(file); | |
params = new ArrayList<double[]>(); | |
values = new ArrayList<Double>(); | |
reg = new OLSMultipleLinearRegression(); | |
this.alpha = alpha; | |
} | |
public void setAlpha(double a){ | |
alpha = a; | |
} | |
public void addData(double[] features, double value){ | |
values.add(value); | |
params.add(features); | |
} | |
public void commitWeights(String path){ | |
this.wp.writeWeights(path, this.weights); // method to commit weights to beta. provide path to beta csv | |
} | |
public void updateWeights(){ | |
// NEED TO CHANGE THIS METHOD | |
// using least squares might be a bad idea | |
// get a lot of singular matrices | |
// we could do samuel's method or come up with another function to modify the coefficients | |
int curr_in = 0; | |
while(params.size() - curr_in > 10){ // need to do regression with data sets of size 10, so each iteration of loop uses 10 lines of data | |
double[] vals = new double [10]; //converting arraylist to array | |
System.out.println("printing values"); | |
int j = 0; | |
for(int i = curr_in; i < curr_in + 10; i++){ | |
vals[j] = values.get(i); | |
System.out.println(values.get(i)); | |
j++; | |
} | |
System.out.println(vals); | |
System.out.println("printing params"); | |
double[][] pars = new double[10][]; //converting 2d arraylist to array | |
j=0; | |
for(int i=curr_in; i < curr_in + 10; i++){ | |
pars[j] = params.get(i); | |
System.out.println(Arrays.toString(params.get(i))); | |
j++; | |
} | |
System.out.println(pars); | |
reg.newSampleData(vals, pars); //add data | |
reg.setNoIntercept(true); | |
double[] new_weights = reg.estimateRegressionParameters(); //get parameters | |
for(int i = 0; i < this.weights.length; i++){ | |
this.weights[i] = this.weights[i] + alpha * (new_weights[i] - this.weights[i]); | |
} | |
commitWeights(this.file); | |
curr_in += 10; | |
} | |
values.clear(); | |
params.clear(); | |
} | |
} |