diff --git a/src/BaseEvaluator.java b/src/BaseEvaluator.java index dc79bf2..c35659a 100644 --- a/src/BaseEvaluator.java +++ b/src/BaseEvaluator.java @@ -5,7 +5,7 @@ public class BaseEvaluator implements Evaluator{ // when beta should use alpha's weights, have alpha commit to beta.csv and then call refreshWeights() protected WeightsParser wp; - String file; + protected String file; protected double[] weights; public BaseEvaluator(String file){ diff --git a/src/LearningEvaluator.java b/src/LearningEvaluator.java index 8f31cac..b162954 100644 --- a/src/LearningEvaluator.java +++ b/src/LearningEvaluator.java @@ -5,6 +5,7 @@ public class LearningEvaluator extends BaseEvaluator{ ArrayList params; ArrayList values; + OLSMultipleLinearRegression reg; // performs linear regression (ordinary least squares) double alpha; // learning parameter, higher alpha means weights are closer to the regression output // alpha of 1 is directly setting weights to be regression weights // ideally we start at 1 and lower alpha to get a convergence @@ -13,6 +14,7 @@ public LearningEvaluator(String file, double alpha){ super(file); params = new ArrayList(); values = new ArrayList(); + reg = new OLSMultipleLinearRegression(); this.alpha = alpha; } @@ -30,6 +32,26 @@ public void commitWeights(String path){ this.wp.writeWeights(path, this.weights); // method to commit weights to beta. provide path to beta csv } + public void updateWeights(){ + double[] vals = new double [values.size()]; + for(int i = 0; i < values.size(); i++){ + vals[i] = values.get(i); + } + values.clear(); + double[][] pars = new double[params.size()][]; + for(int i=0; i < params.size(); i++){ + pars[i] = params.get(i); + } + params.clear(); + reg.newSampleData(vals, pars); + reg.setNoIntercept(true); + double[] new_weights = reg.estimateRegressionParameters(); + for(int i = 0; i < this.weights.length; i++){ + this.weights[i] = this.weights[i] + alpha * (new_weights[i] - this.weights[i]); + } + commitWeights(this.file); + + }