From ddbd109b5bef61f52486a6c11ec82455a7abd050 Mon Sep 17 00:00:00 2001 From: unknown Date: Tue, 23 Jan 2024 16:05:12 -0500 Subject: [PATCH] Tried to run Luis' code. Encountered error again. Going to add checkpoint to save progress. --- .../ss_active_learning.cpython-311.pyc | Bin 6942 -> 6909 bytes .../active_learning/ss_active_learning.py | 3 +- BML_project/cassey_CS330_torch.yml | 272 ++++++++++++++++++ .../__pycache__/ss_gp_model.cpython-311.pyc | Bin 12179 -> 12155 bytes BML_project/models/ss_gp_model.py | 3 +- BML_project/ss_main.py | 7 +- .../__pycache__/data_loader.cpython-311.pyc | Bin 17950 -> 18912 bytes .../__pycache__/ss_evaluation.cpython-311.pyc | Bin 9772 -> 9750 bytes .../__pycache__/visualization.cpython-311.pyc | Bin 5734 -> 5677 bytes BML_project/utils_gp/data_loader.py | 28 +- transfer_data/tar_PT_files.sh | 19 ++ 11 files changed, 321 insertions(+), 11 deletions(-) create mode 100644 BML_project/cassey_CS330_torch.yml create mode 100644 transfer_data/tar_PT_files.sh diff --git a/BML_project/active_learning/__pycache__/ss_active_learning.cpython-311.pyc b/BML_project/active_learning/__pycache__/ss_active_learning.cpython-311.pyc index 45c9875deaf1bad595378c59b1e08a1296ac1ec3..9504984498b04e14db8bf2c7a9373efccf469c21 100644 GIT binary patch delta 204 zcmbPd_SckmIWI340}yyDU6<;yk(YxhGFLx0uSCB{-__4ODBd|EKTqE!KQCQBpt1zS zc8d4)i4Q2s&q_@$(e)|KEJhL0cLK^5KxFk3lS?woQsZ+{6N~aPfhKHjXPU&p$UIqz zdm;xjtMUg1ATfD2cd+;a4vFg=GM6}HW`tbiP`bjQbODSutMa(AFnVv!7P!sE7`@qD Uyn|VUi&5qS11e!K`LUEW0EaU`U;qFB delta 260 zcmexsI?s%EIWI340}!Zs%t__g$jiZ0-{x!;6Iz^FR2-92lxt*UU|<^KpO@-Vlv$Rl zpsNs?2c(ONGfOHJ^3xQY^YijjlS?v_OG{#0@{>z*Q}arSW85=KGD?$ToZ@|b;sc8E zvw$jMoPg{CC_53TqbxN(CpEDsFEcMarnopBA+))gsg;A1yCAWsBr`E5eljEXL=Gla z?GFq;I^mffF`Yx7z5 zGA2gT$u3+wwLO6L{%Bx;!4FJwtXv-$un-Q6lMT7WH`{UB@-upFu8}I@Vw|@5y$TOI m&zkJ1?*#x_p%j$T1!ks{DFRg%7wtJp0WIW})$ zFJodfnXJvVQ_CGF@uPtO20t*#vT}W3z(Uw<=I0LKXY|;dB2~o2ICb-N6&`lRRhyq^ h%Q5R4GD?47z$7MAOsT&jZt@W%^#viHG1*Yx3jhUmUONB) diff --git a/BML_project/models/ss_gp_model.py b/BML_project/models/ss_gp_model.py index c18f06f..e364cec 100644 --- a/BML_project/models/ss_gp_model.py +++ b/BML_project/models/ss_gp_model.py @@ -20,7 +20,8 @@ class MultitaskGPModel(gpytorch.models.ApproximateGP): def __init__(self): # Let's use a different set of inducing points for each latent function - inducing_points = torch.rand(num_latents, num_inducing_points, 127 * 128) # Assuming flattened 128x128 images + inducing_points = torch.rand(num_latents, num_inducing_points, 128 * 128) # Assuming flattened 128x128 images + # Dong, 01/22/2024: I will use 128 * 128. # We have to mark the CholeskyVariationalDistribution as batch # so that we learn a variational distribution for each task diff --git a/BML_project/ss_main.py b/BML_project/ss_main.py index a610684..2716179 100644 --- a/BML_project/ss_main.py +++ b/BML_project/ss_main.py @@ -4,13 +4,13 @@ @author: lrm22005 """ -import tqdm +from tqdm import tqdm import torch -from utils.data_loader import preprocess_data, split_uids, update_train_loader_with_uncertain_samples +from utils_gp.data_loader import preprocess_data, split_uids, update_train_loader_with_uncertain_samples from models.ss_gp_model import MultitaskGPModel, train_gp_model from utils_gp.ss_evaluation import stochastic_evaluation, evaluate_model_on_all_data from active_learning.ss_active_learning import stochastic_uncertainty_sampling, run_minibatch_kmeans, stochastic_compare_kmeans_gp_predictions -from utils.visualization import plot_comparative_results, plot_training_performance, plot_results +from utils_gp.visualization import plot_comparative_results, plot_training_performance, plot_results device = torch.device("cuda" if torch.cuda.is_available() else "cpu") @@ -22,6 +22,7 @@ def main(): data_format = 'pt' # Preprocess data train_loader, val_loader, test_loader = preprocess_data(data_format, clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled, batch_size) + print('Debug: len(train_loader)',len(train_loader)) kmeans_model = run_minibatch_kmeans(train_loader, n_clusters=n_classes, device=device) diff --git a/BML_project/utils_gp/__pycache__/data_loader.cpython-311.pyc b/BML_project/utils_gp/__pycache__/data_loader.cpython-311.pyc index 0b1a7ebdb4b25cf36cb492dfb1b7d20074e1bb08..e131c154219f7dc3144ea923de24d050f059f45a 100644 GIT binary patch delta 5497 zcmcH+ZE#f8_1*p6&2Bay`@NfF^9f5Jn*<^xfiMY6h%qD~gb}g0ta+P|xY>>Oz3>s= zHb7B*U@=-4_7WrnzwDYk0W{=rd3rBmB;-peN0U^{lE zy~&>U?)kdsoO|xM=bb!Ct}>E)+F~(rFiySrZ1fLnkLS9`3&$7TWg)yOmbWdB*B|7z zQ(fJ6Po92s6-bBb`>8Tct9 zN#Y)m(#}ZPvh`-Tc-=xNiST-NhRjWAIF5Zl3N#`Fh5f0@VH%mWk4;uNCk@&1D7_k_ z;c1-H0yERxY(k;u+=MaDNLuD7AelRZHD_a$Sv^_`O!J7Sej-Q4T4zM(JEUiu%50Bn zazjZKpU3A9sgmj;jxTtcBsK6Xgr^pE)(&``q)zH3P0}D~B^_!_8pV8EPZ_*JEMyP2 zd2~;3yi?M%F@r7dCi*02dGdQHZd@VJnYL>cckc0(IDX~4;l{pLfk1}JTuFjv>);TRmE3*F;&G$1yZ48lkEJW zvx`r`!8rpzWh5P9wfweRhLVHl{=J>qCh-BbL$|EaIip}k{=bm^tu9#Zl3bEWPGL9X zkwByE(xw zdFM3Ot;$NaC0V=U%=Q}6uv&87e5OE%^%{!EQWiHjvJX_YKJQKXmPuaFXW#CepRhwD z>6APRFdoSZSygktYA&%MBETnk2h~u@gWWgmvBW?mwlg7!HS!yu7>x|aCq`JLp_v3( zaHX5AG&$^Afh1cy&#Dk$kd??tvSH$dlDqnXd)X!TvXsU#t#_x2R$SDtNaGFzAGkZI8j-|~ybhV^1P=lbppAgtkq7Nyh0aP6V~x(Qz{Y&jIY{=ie>hw7 zk02wq1Ab)>53?Ru8QI0CYYFh3aJ^L>0%Cy33k1{@b0MvU_)u9=jgtP%{*Vlv2x5X)Dvu!1|n$~}un2fNKvNha96o-MG@c~3{tek7~tMtIPN5Ih1P zV;fuQEhh(9uh$2@k9x~Ulg#oku%z`nBSJ(JsUrFWqKvUbgb(u(F%rB#Y7pDxobCxv80x3Jwq`)sHQ@|J?;am&^5aqW4-Uto{^%a46~Ya&lwz3DdLcSA z8jXvv_wKR-W~^niiFLZHhdjeRDJ$ubb570h64EN@mOXV`kqeuhV@O8BVeb<;6O= zh#q6t%a`k>04wydnhJ+5-?-*OE}wN)G}x?`HCOWNDM#=*(}ld|OL@(yyk>Tw;_brE zu3CNH8cbO$QickOZVaf{^NqH1$0{2&g_AI!&|sk87@vg+?P6~P20GD2`f>&YJoJb7 z{vv{x5F9}890DvUX#*T#VfQe&t%{GOwQyxc_GUbNU`Ys$(rZf=yJi1xM}F-@lqQ6L zQE~euGgmJoPqU`#igMY{Dd3rKt{WG`#Apc4sSw3Wa^}$f}5Qq6oS>UlMM8)=eiUYpd!Yk{JhrRafT0E!VP?!=R_ShDP{~zqx1@_p#tUZqZr9F`^jE_~s?qg=~3;K50P54-r^m3o< z>Yi}V)bE=cJ8A?KthC{$h3Wi`-gOJQzBe0=YN-i1cHZ)%%8AnTeSJNBrT#m5w)FW! zw}$-PJ^lU->-+uv9qawQt^FM}aPmu6uFn4zh@_4E?VH16u>=(DU$gki=2|`N13cY~ zKrUmy$G7aF7%(XP6%QnKd1dXSTqN;urgMbIuHpDl7;gpv^9iqf9CF1-=R)@@HzW!J zh4zqp6Mn+?A?^2ZSF(G{D;i_URGF$;bwSg3Nz<6pG|ulfr)_uewD-=G z!3F-dkI2$JV5&K=!YF`JC z)7CK>h3}4ms36FZp`QVk5LVIrGBkPbHeV+JEW2=PVXv2IKFKhB7p4N%uz9Vgre!Iy zvAUKeI=QZfSfs^okr`FUxC5H6hg#ga&tYOD><29YSmJVvN9EiE82SQ&s{mjX6r!h@ zp|zg8#hP2o$O(2=Yn6PRq!+WUb#8ksa(6Tq<-irFXv~CPOvm%0ggQwNKw{}qs zC_T?ALwEGb2SLKQd=cRmE1QiwL}hQ#XNra_xEgO5R*xq5@mRElT7e4x*vPyZI_jV1 mEc@iYYnrWu^uUFg^$=NXeMepy30~u7K__9atnbsHm;VBY?yEY zaPHP7kYb3d1Pv{OG%d*wM+$Be6{)S1Rt;&SqE?mMDpk=EkXxlcqCVm%P4j4)^qhNb z<6Y>Fj=bN@J#)^PGjqLTPY<+Ce&<{VLoDEJAl4p-@9BjLwP6~bX1-!4>*i@IqL zxU?jMob8|idcA*16MPwNeT`}J#@i^S1xfR*8IZIrP)&tY^(&$-0Zhv#QT|Hr3mWZW;LQWVYqU|M`(?M^CUN1VtGN2n2A!k%{Ig_GS45%z;mTl@f zDu_N;!th za!z2h=E~*T#bpSQ9k<3xd7-3OJ!T-Jkd|}Fb+;r=1q)|SkB25hho~4nVLlb?BGbTm zX^A`Mo(#gQAg;LctUwf}Qm(iZw^%b%JFN#<_-QTYksB7r5TYKD+<6lqpp^4gV-5G3 zf)Q__(49q=BotF+ZG}>ySk&P4L0tJ29|TYkLdCAoQNlcmUvohGHrGsrfIpy6;Qxxs z?}Zpx3vzzwn~QwD7z-!7Q3BdJ+iqWHLqS~D75L9v=naV2dc}(uY0dJBR4C)W zO{PKd-*)*xhQDU6)HNyrxOwhg#k)MX2Wif$_=_lhB>;gJd20Ec3n;*+1P=O80ds?c40gsCsR^% zFgqZ|$Xnu>k+dA`>)*q=;9J6& z42LHsSU0lt0Pr)mFzw+N>>Ft>|AYMr>gIGC)9wDJF5p>cSBnLvQ4DuotGPz;_#5u;2lHGq{E_ZV$DDPzkW3xp`yah@41^?6^id!5V15hS<(Cdds(`LpZZcGdmPQUA81 ze%4&idN03JSwlCflBa+tZyJh8F;9v!4ki&cGEso6)`R|bU(`Dn(R3QbzalYzJr8w z0C{6N@pu%@4HuIjo5Q(51BW%O*hDA!(2CPR_2ic-5nwUH50Ci6Qf{jW3Xgomn`)lv zo>`uF4x{(grszdM)QNi0FeuCzi?X&sF)XiVE}HR;@`_N66=t2BQFcLb5V0*QcMewR$RA;?n46w29#poi|B|K};s#o2OSkS&B(uDIWZFyXg0M zvaWik8ftn~g6#m7_fVE$?fKr9&>}W13Wbc6*Joq$P+osHlT54E2}2X^y2N&&$Zr0} zx~dRfDTW@6Sk(~GVo~+3g*)rlLRAC|=lS;fDx;csPV!X!T{Y-qsRNuhL9Y1%leTM4 z*Q9aIRdM#7c~WM%7l2=>U$3*Y)9#`Jgs zaNqF7W;=MSDZUyv#a>^8NPzti$FCy1hM;DLS8z1J-*5687dB&W?xN$ozq!%B59js( zsqrM0#6DnGaO5i2^VFQVJaj`A>{FXnqu73bmiQ|zhHa!gR`@^>&KF<+_GLs0kyC^{2uFDzH4n+Wp9w-K&7=dCG zP4H!gFALQ4M6_xJDvR-8W3x*qmP~`DOYqZL&ZyG$VrdcGg!B~#(N;t^r*D8(dy!U0 z5!e5(briIgE!WB`?{@Q0W0UKD*tYdfwmD?e;?4;Cqf(fro}`LVF)3!XesT5j{Tp8A z4fjm(pWgEm0mAI;#&+k=Kss;P+tVM-rZRBFf6o83aZ`)F`11S;*x)`)Ah#JF^;9xF zawKnr(NH$d@88r~zx4R#OQhInGCdf@EFxJz0>WKukj7u!f%8Lblf6c|1wJlPq6WqD^nsOSR?gX-b^?w7~aZNGzpdQA`)&@zmApDpa)K**QlE3DZPg{m$I z=b)Wf;W@Map!72cUjh_68#S#O$wDEKh{`OMOh>t=x5}Uco5FlUZ~udM`y&P|!SwVn z+*(BF>50&uh)$#Kjrj7WXjII^qfsp;4*(PU1*$e?p)!qWuToyOd&m6@69>D(Z}jfJ zZxn}Wp28r|!U28A)D7dYFtnp6bAD!SAI^v)sl;Yxg<1F`(gXaP+ovk#^}8tD1!=UP ZJILF1YUb&Imn>{E|KR2&LzewNJ&pjyKIU_$$-z7gUT|c0*1jKfV s_w|VnD9X=DO)k;(Da|ZK5z%)7$`(Lm^-D`KbBg2B3-mYdW6zTV0DkBtHUIzs delta 153 zcmbQ{v&M&eIWI340}#YZPT$Dw&0b&bY!wq)oLW>IlT(yyWMp7q8sndr>Qa;fpev?Md9IHtHbKD8_{r!=u7Ge1wSpz@YzaY=r1#^(L(;c@_*lsBsY diff --git a/BML_project/utils_gp/__pycache__/visualization.cpython-311.pyc b/BML_project/utils_gp/__pycache__/visualization.cpython-311.pyc index a8c3afd30ad9aefea43425eec7756a9d74ec4254..fe8ddd586cc0009554b2a6bedcddf59cc949fd24 100644 GIT binary patch delta 108 zcmaE+vsQUb&Hb0&7sBer{fgev!VbpLX(*e<`l=L7wB(RV{;Gy0GlZ!wEzGB delta 166 zcmZ3h^Gt_(IWI340}yyzY1qg;fwg{?vsFxJacWUI& Lj@kT{)j