diff --git a/src/CheckersAI.java b/src/CheckersAI.java index 8bcbbb3..0166553 100644 --- a/src/CheckersAI.java +++ b/src/CheckersAI.java @@ -52,7 +52,7 @@ public class CheckersAI{ double v = Double.NEGATIVE_INFINITY; double check; Move max = null; - // System.out.println(s.actions().size()); + System.out.println(s.actions().size()); for(Move a: s.actions()){ check = minValue(s.result(a), alpha, beta, depth + 1, a.isJump(), min_ply); if(check > v){ diff --git a/src/CheckersGameState3.java b/src/CheckersGameState3.java index eaff276..303dd4b 100644 --- a/src/CheckersGameState3.java +++ b/src/CheckersGameState3.java @@ -374,13 +374,16 @@ public class CheckersGameState3 implements CheckersGameState{ } /* computes feature vector: - piece-ratio, loners, safes, pawns, - moveable pawns, aggregate distance, - kings, moveable kings, promotion line - opening + [piece-ratio, + loners, + safes, + pawns, + moveable pawns, + aggregate distance, + promotion line opening] */ public double[] getFeatures(int player){ - double[] features = new double[9]; + double[] features = new double[7]; double total = 0.0; double mypieces = 0.0; for(int i = 0; i= 7){ // if alpha wins 7 of every ten games, make beta use alpha's new evaluator - le.commitWeights("weights/beta.csv"); + le.commitWeights("../src/weights/beta.csv"); be.refreshWeights(); } played = 0; @@ -58,12 +58,14 @@ public class Learn{ moves++; } while(!current.isTerminal() && moves <= 50){ + System.out.println("alphas moves:"); + System.out.println(current.actions()); Move next = alpha.minimax(current, 7); // get alpha's move le.addData(current.getFeatures(alpha.getPlayer()), next.getValue()); // add this moves data to the data set (the value of the state is stored in the move. there is probably a better way to do this) current = current.result(next); // make the move current.printState(); moves++; - if(current.isTerminal()){ // if alpha won, then brea + if(current.isTerminal()){ // if alpha won, then break break; } current = current.result(beta.minimax(current, 7)); // beta's move @@ -74,5 +76,3 @@ public class Learn{ return current.winner(); } } - - diff --git a/src/LearningEvaluator.java b/src/LearningEvaluator.java index 5531c8e..91a6556 100644 --- a/src/LearningEvaluator.java +++ b/src/LearningEvaluator.java @@ -1,12 +1,13 @@ import java.util.ArrayList; import org.apache.commons.math3.stat.regression.OLSMultipleLinearRegression; +import org.apache.commons.math3.linear.SingularMatrixException; import java.util.Arrays; public class LearningEvaluator extends BaseEvaluator{ ArrayList params; ArrayList values; - // need to download jar and set classpath to import and run + // need to download jar and set classpath to import and run OLSMultipleLinearRegression reg; // performs linear regression (ordinary least squares) double alpha; // learning parameter, higher alpha means weights are closer to the regression output // alpha of 1 is directly setting weights to be regression weights @@ -39,20 +40,21 @@ public class LearningEvaluator extends BaseEvaluator{ // get a lot of singular matrices // we could do samuel's method or come up with another function to modify the coefficients int curr_in = 0; - while(params.size() - curr_in > 10){ // need to do regression with data sets of size 10, so each iteration of loop uses 10 lines of data - double[] vals = new double [10]; //converting arraylist to array + int data_sz = 8; // need to do regression with data sets of size 10, so each iteration of loop uses 10 lines of data + while(params.size() - curr_in > data_sz){ + double[] vals = new double [data_sz]; //converting arraylist to array System.out.println("printing values"); int j = 0; - for(int i = curr_in; i < curr_in + 10; i++){ + for(int i = curr_in; i < curr_in + data_sz; i++){ vals[j] = values.get(i); System.out.println(values.get(i)); j++; } System.out.println(vals); System.out.println("printing params"); - double[][] pars = new double[10][]; //converting 2d arraylist to array + double[][] pars = new double[data_sz][]; //converting 2d arraylist to array j=0; - for(int i=curr_in; i < curr_in + 10; i++){ + for(int i=curr_in; i < curr_in + data_sz; i++){ pars[j] = params.get(i); System.out.println(Arrays.toString(params.get(i))); j++; @@ -60,12 +62,16 @@ public class LearningEvaluator extends BaseEvaluator{ System.out.println(pars); reg.newSampleData(vals, pars); //add data reg.setNoIntercept(true); - double[] new_weights = reg.estimateRegressionParameters(); //get parameters - for(int i = 0; i < this.weights.length; i++){ - this.weights[i] = this.weights[i] + alpha * (new_weights[i] - this.weights[i]); + try { + double[] new_weights = reg.estimateRegressionParameters(); //get parameters + for(int i = 0; i < this.weights.length; i++){ + this.weights[i] = this.weights[i] + alpha * (new_weights[i] - this.weights[i]); + } + commitWeights(this.file); + } catch(SingularMatrixException e) { + System.out.println("Matrix was singular, not updating weights"); } - commitWeights(this.file); - curr_in += 10; + curr_in += data_sz; } values.clear(); diff --git a/src/weights/alpha.csv b/src/weights/alpha.csv index 6ffa5a6..5895be1 100644 --- a/src/weights/alpha.csv +++ b/src/weights/alpha.csv @@ -1 +1,30 @@ -50, 10, 10, 5, 5, 5, 30, 15, 10 +50, 10, 10, 5, 5, 5 +-9.3685899134567718E17, 7.794026190160416E16, -6.049540124358089E14, -6.932782101834882E14, 7.463767693764232E16, 1.1512441608154629 +-2.8738797198358354E18, 5.3535252570681231E18, 3.3336822929004636E16, 3.433162512891778E16, -4.0433107676171347E17, 7.819856042285374E15 +1.817671195817515E32, -5.438334910319976E30, 3.7970448614125307E30, 5.999603288154183E30, 4.2657208866270885E30, -3.420682188700641E30 +1.7206054178773225E32, -1.2372315866627587E31, 4.135155496534389E30, 4.954032483591002E30, 5.800958050966378E30, -3.0930635559914327E30 +2.2189863047375575E32, -1.1200084316563596E31, 2.795880090367318E30, 5.808933256464445E30, 2.5411383173703824E30, -3.0954239523221725E30 +1.4570594680388968E32, -8.866159458937068E30, -4.184354587177827E45, 4.184354587177812E45, -4.184354587177801E45, -9.203700236115977E29 +4.828745193602742E46, -2.117283421111761E47, -3.431170761485821E45, 3.849606220203588E45, 1.8411160183577473E45, -2.510612752306685E44 +2.6886269228363255E46, -1.9092532105319082E47, -2.5877150925926487E45, 5.4585742201173775E45, 1.9360832367781485E45, -4.176681793472993E44 +9.68056166753188E47, -1.822791958582696E47, 3.2799103164625757E46, 6.982451188800286E46, -1.2759822630112968E46, -2.1251547085596302E46 +6.607747772113344E47, -1.5775834620045355E47, 1.9082744084210428E46, 4.943658822412271E46, -1.0928504000796732E46, -1.268366597892222E46 +-6.640413286411508E63, 6.191527539777473E61, 8.048985801710929E61, 5.262798408810957E62, -1.8574582619332455E62, -6.12794061383161E45 +-9.975387123273426E76, 1.9950774246545522E77, -4.750184344415589E75, 5.842609164879715E62, -1.0648576764909763E62, -2.3269824336997295E61 +-8.977848410946235E76, 1.795569682189097E77, -4.275165909974025E75, 6.109277826946604E62, -8.213410456570652E61, -2.5089001685635713E61 +-1.395983161991158E77, 1.6160127139701874E77, 9.721046181162213E75, 5.25986567097873E62, -9.42642845969764E61, -4.2967917201625334E61 +-2.783659649264689E77, 7.993733678169044E77, 2.3936964084631586E76, -1.9442092362320297E75, -9.098576084558526E76, 1.3443043569096783E76 +1.1278410857435816E78, 7.301806896008975E77, 2.500785282308765E76, 4.311516318358936E75, -1.0440491512512438E77, 8.264007014739985E75 +4.3825255215879347E77, 6.46912413373827E77, 3.3368415505215776E77, 5.44692780818499E76, -8.380612099682645E76, -8.224821683358007E75 +6.536493354945876E76, 5.90016774220315E77, 3.033792609868331E77, 1.9930898294085956E77, -1.0093881611537739E77, -1.996258945193199E76 +-6.484760054376507E77, 5.693556733979261E77, 3.52941302481836E77, 4.0495478599741917E77, -1.2805312734164537E77, -4.6965237247001146E76 +-2.9395502408946725E78, 5.165544365100268E77, 9.744982653143925E77, -1.2426575673168948E77, 4.5261748871156896E77, -6.567082327835312E76 +-8.200718081043308E78, 5.423910896161437E77, 1.2845513616993203E78, -1.375451023113857E77, 4.2791673573848214E77, -3.2178738680038086E76 +8.373699459108173E78, 4.596571219901508E77, 1.11227296311433E78, -1.3219280792060426E77, 9.137064349203897E76, -2.2307875336690484E76 +5.722730436152141E78, 3.266170201871765E77, 1.1356081942713786E78, -2.924947663746416E77, 3.799050720906769E77, -7.641060717501331E75 +3.8630225969823536E78, 2.8181958605463597E77, 1.709325061397922E78, 7.731448514518302E76, 3.2094710923023195E77, -7.463884758456114E76 +3.621525694225168E78, 2.577641971114625E77, 1.5178164811966045E78, 2.33441167062367E77, 2.810616951734669E77, -7.504402810380263E76 +-3.9246295090853187E80, 2.5992773231027646E77, -1.3579310902003569E93, -5.248570840331341E93, 5.248570840331358E93, -3.2362800212151406E76 +-2.1067562907318428E94, -4.468385728885715E92, -2.5123998504920498E93, -9.180443867939546E93, 6.345104944836925E93, 7.980054086800604E92 +-7.81284835659837E94, -1.6166995865070447E93, -5.128336274216761E93, -1.9307596009148744E94, 9.615561351884508E93, 2.7628625498248885E93 +-2.258434065448233E95, -4.623187033028576E93, -1.148760484999136E94, -4.5123076330934E94, 1.837401892192431E94, 7.685804031253954E93 diff --git a/src/weights/beta.csv b/src/weights/beta.csv index 6ffa5a6..4192e5a 100644 --- a/src/weights/beta.csv +++ b/src/weights/beta.csv @@ -1 +1 @@ -50, 10, 10, 5, 5, 5, 30, 15, 10 +50, 10, 10, 5, 5, 5