Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Implement basic learning evaluator
  • Loading branch information
sas12028 committed Apr 20, 2017
1 parent c6d1f00 commit 68bd8ae
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
2 changes: 1 addition & 1 deletion src/BaseEvaluator.java
Expand Up @@ -5,7 +5,7 @@ public class BaseEvaluator implements Evaluator{
// when beta should use alpha's weights, have alpha commit to beta.csv and then call refreshWeights()

protected WeightsParser wp;
String file;
protected String file;
protected double[] weights;

public BaseEvaluator(String file){
Expand Down
22 changes: 22 additions & 0 deletions src/LearningEvaluator.java
Expand Up @@ -5,6 +5,7 @@ public class LearningEvaluator extends BaseEvaluator{

ArrayList<double[]> params;
ArrayList<Double> values;
OLSMultipleLinearRegression reg; // performs linear regression (ordinary least squares)
double alpha; // learning parameter, higher alpha means weights are closer to the regression output
// alpha of 1 is directly setting weights to be regression weights
// ideally we start at 1 and lower alpha to get a convergence
Expand All @@ -13,6 +14,7 @@ public class LearningEvaluator extends BaseEvaluator{
super(file);
params = new ArrayList<double[]>();
values = new ArrayList<Double>();
reg = new OLSMultipleLinearRegression();
this.alpha = alpha;

}
Expand All @@ -30,6 +32,26 @@ public class LearningEvaluator extends BaseEvaluator{
this.wp.writeWeights(path, this.weights); // method to commit weights to beta. provide path to beta csv
}

public void updateWeights(){
double[] vals = new double [values.size()];
for(int i = 0; i < values.size(); i++){
vals[i] = values.get(i);
}
values.clear();
double[][] pars = new double[params.size()][];
for(int i=0; i < params.size(); i++){
pars[i] = params.get(i);
}
params.clear();
reg.newSampleData(vals, pars);
reg.setNoIntercept(true);
double[] new_weights = reg.estimateRegressionParameters();
for(int i = 0; i < this.weights.length; i++){
this.weights[i] = this.weights[i] + alpha * (new_weights[i] - this.weights[i]);
}
commitWeights(this.file);

}



Expand Down

0 comments on commit 68bd8ae

Please sign in to comment.