Skip to content

Commit

Permalink
Tried to run Luis' code. Encountered error again. Going to add checkp…
Browse files Browse the repository at this point in the history
…oint to save progress.
  • Loading branch information
unknown authored and unknown committed Jan 23, 2024
1 parent 9388a46 commit ddbd109
Show file tree
Hide file tree
Showing 11 changed files with 321 additions and 11 deletions.
Binary file not shown.
3 changes: 2 additions & 1 deletion BML_project/active_learning/ss_active_learning.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ def run_minibatch_kmeans(data_loader, n_clusters, device, batch_size=100):
# Iterate through data_loader and fit MiniBatchKMeans
for batch in data_loader:
data = batch['data'].view(batch['data'].size(0), -1).to(device).cpu().numpy()
minibatch_kmeans.partial_fit(data)
# minibatch_kmeans.partial_fit(data)
minibatch_kmeans.fit(data) # Dong, 01/22/2024: Debug

return minibatch_kmeans

Expand Down
272 changes: 272 additions & 0 deletions BML_project/cassey_CS330_torch.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@
name: CS330_torch
channels:
- pytorch
- nvidia
- anaconda
- conda-forge
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- abseil-cpp=20211102.0=h27087fc_1
- absl-py=2.0.0=pyhd8ed1ab_0
- aiohttp=3.8.5=py311h5eee18b_0
- aiosignal=1.3.1=pyhd8ed1ab_0
- asttokens=2.4.0=pyhd8ed1ab_0
- async-timeout=4.0.2=py311h06a4308_0
- attrs=23.1.0=pyh71513ae_1
- backcall=0.2.0=pyh9f0ad1d_0
- backports=1.0=pyhd8ed1ab_3
- backports.functools_lru_cache=1.6.5=pyhd8ed1ab_0
- blas=1.1=openblas
- blinker=1.6.3=pyhd8ed1ab_0
- bottleneck=1.3.5=py311hbed6279_0
- brotli=1.0.9=h9c3ff4c_4
- brotlipy=0.7.0=py311h5eee18b_1002
- bzip2=1.0.8=h7b6447c_0
- c-ares=1.19.1=h5eee18b_0
- ca-certificates=2023.12.12=h06a4308_0
- cachetools=5.3.1=pyhd8ed1ab_0
- cairo=1.16.0=hb05425b_5
- certifi=2023.11.17=py311h06a4308_0
- cffi=1.15.1=py311h5eee18b_3
- charset-normalizer=2.0.4=pyhd3eb1b0_0
- click=8.1.7=unix_pyh707e725_0
- cloudpickle=2.2.1=pyhd8ed1ab_0
- colorama=0.4.6=pyhd8ed1ab_0
- comm=0.1.4=pyhd8ed1ab_0
- contourpy=1.0.5=py311hdb19cb5_0
- cryptography=41.0.3=py311hdda0065_0
- cuda-cudart=11.8.89=0
- cuda-cupti=11.8.87=0
- cuda-libraries=11.8.0=0
- cuda-nvrtc=11.8.89=0
- cuda-nvtx=11.8.86=0
- cuda-runtime=11.8.0=0
- cycler=0.12.1=pyhd8ed1ab_0
- cyrus-sasl=2.1.28=h52b45da_1
- dbus=1.13.18=hb2f20db_0
- debugpy=1.6.7=py311h6a678d5_0
- decorator=5.1.1=pyhd8ed1ab_0
- eigen=3.4.0=h4bd325d_0
- exceptiongroup=1.1.3=pyhd8ed1ab_0
- executing=1.2.0=pyhd8ed1ab_0
- expat=2.5.0=h6a678d5_0
- ffmpeg=4.2.2=h20bf706_0
- filelock=3.9.0=py311h06a4308_0
- fontconfig=2.14.1=h4c34cd2_2
- fonttools=4.25.0=pyhd3eb1b0_0
- freetype=2.12.1=h4a9f257_0
- frozenlist=1.3.3=py311h5eee18b_0
- fsspec=2023.10.0=pyhca7485f_0
- giflib=5.2.1=h5eee18b_3
- glib=2.69.1=he621ea3_2
- gmp=6.2.1=h295c915_3
- gmpy2=2.1.2=py311hc9b5ff0_0
- gnutls=3.6.15=he1e5248_0
- google-auth=2.23.2=pyhca7485f_0
- google-auth-oauthlib=1.0.0=pyhd8ed1ab_1
- googledrivedownloader=0.4=pyhd3deb0d_1
- gpytorch=1.11=pyhd8ed1ab_0
- graphite2=1.3.14=h295c915_1
- grpc-cpp=1.48.2=he1ff14a_1
- grpcio=1.48.2=py311he1ff14a_1
- gst-plugins-base=1.14.1=h6a678d5_1
- gstreamer=1.14.1=h5eee18b_1
- h5py=3.9.0=py311hdd6beaf_0
- harfbuzz=4.3.0=hf52aaf7_1
- hdf5=1.12.1=h2b7332f_3
- icu=58.2=hf484d3e_1000
- idna=3.4=py311h06a4308_0
- imageio=2.31.5=pyh8c1a49c_0
- importlib-metadata=6.8.0=pyha770c72_0
- importlib_metadata=6.8.0=hd8ed1ab_0
- iniconfig=1.1.1=pyhd3eb1b0_0
- intel-openmp=2023.1.0=hdb19cb5_46305
- ipykernel=6.25.2=pyh2140261_0
- ipython=8.16.1=pyh0d859eb_0
- jaxtyping=0.2.25=pyhd8ed1ab_0
- jedi=0.19.1=pyhd8ed1ab_0
- jinja2=3.1.2=py311h06a4308_0
- joblib=1.2.0=py311h06a4308_0
- jpeg=9e=h5eee18b_1
- jupyter_client=8.3.1=pyhd8ed1ab_0
- jupyter_core=4.12.0=py311h38be061_0
- kiwisolver=1.4.4=py311h6a678d5_0
- krb5=1.20.1=h143b758_1
- lame=3.100=h7b6447c_0
- lcms2=2.12=h3be6417_0
- ld_impl_linux-64=2.38=h1181459_1
- lerc=3.0=h295c915_0
- libclang=14.0.6=default_hc6dbbc7_1
- libclang13=14.0.6=default_he11475f_1
- libcublas=11.11.3.6=0
- libcufft=10.9.0.58=0
- libcufile=1.7.2.10=0
- libcups=2.4.2=h2d74bed_1
- libcurand=10.3.3.141=0
- libcurl=7.88.1=h251f7ec_2
- libcusolver=11.4.1.48=0
- libcusparse=11.7.5.86=0
- libdeflate=1.17=h5eee18b_1
- libedit=3.1.20221030=h5eee18b_0
- libev=4.33=h7f8727e_1
- libevent=2.1.12=hdbd6064_1
- libffi=3.4.4=h6a678d5_0
- libgcc-ng=11.2.0=h1234567_1
- libgfortran=3.0.0=1
- libgfortran-ng=11.2.0=h00389a5_1
- libgfortran5=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libiconv=1.16=h7f8727e_2
- libidn2=2.3.4=h5eee18b_0
- libjpeg-turbo=2.0.0=h9bf148f_0
- libllvm14=14.0.6=hdb19cb5_3
- libnghttp2=1.52.0=h2d74bed_1
- libnpp=11.8.0.86=0
- libnvjpeg=11.9.0.86=0
- libopenblas=0.3.21=h043d6bf_0
- libopus=1.3.1=h7f98852_1
- libpng=1.6.39=h5eee18b_0
- libpq=12.15=hdbd6064_1
- libprotobuf=3.20.3=he621ea3_0
- libsodium=1.0.18=h36c2ea0_1
- libssh2=1.10.0=hdbd6064_2
- libstdcxx-ng=11.2.0=h1234567_1
- libtasn1=4.19.0=h5eee18b_0
- libtiff=4.5.1=h6a678d5_0
- libunistring=0.9.10=h27cfd23_0
- libuuid=1.41.5=h5eee18b_0
- libvpx=1.7.0=h439df22_0
- libwebp=1.3.2=h11a3e52_0
- libwebp-base=1.3.2=h5eee18b_0
- libxcb=1.15=h7f8727e_0
- libxkbcommon=1.0.1=h5eee18b_1
- libxml2=2.10.4=hcbfbd50_0
- libxslt=1.1.37=h2085143_0
- linear_operator=0.5.2=pyhd8ed1ab_0
- llvm-openmp=14.0.6=h9e868ea_0
- lockfile=0.12.2=py311h06a4308_0
- lz4-c=1.9.4=h6a678d5_0
- markdown=3.5=pyhd8ed1ab_0
- markupsafe=2.1.1=py311h5eee18b_0
- matplotlib=3.7.2=py311h06a4308_0
- matplotlib-base=3.7.2=py311ha02d727_0
- matplotlib-inline=0.1.6=pyhd8ed1ab_0
- mkl=2023.1.0=h213fc3f_46343
- mkl-service=2.4.0=py311h5eee18b_1
- mpc=1.1.0=h10f8cd9_1
- mpfr=4.0.2=hb69a4c5_1
- mpmath=1.3.0=py311h06a4308_0
- multidict=6.0.2=py311h5eee18b_0
- munkres=1.1.4=pyh9f0ad1d_0
- mysql=5.7.24=h721c034_2
- ncurses=6.4=h6a678d5_0
- nest-asyncio=1.5.6=pyhd8ed1ab_0
- nettle=3.7.3=hbbd107a_1
- networkx=3.1=py311h06a4308_0
- ninja=1.10.2=h06a4308_5
- ninja-base=1.10.2=hd09550d_5
- numexpr=2.8.7=py311h812550d_0
- numpy=1.26.0=py311h24aa872_0
- numpy-base=1.26.0=py311hbfb1bba_0
- oauthlib=3.2.2=pyhd8ed1ab_0
- openblas=0.3.3=ha44fe06_1
- opencv=4.6.0=py311h10ae9b0_5
- openh264=2.1.1=h4ff587b_0
- openjpeg=2.4.0=h3ad879b_0
- openssl=3.0.12=h7f8727e_0
- opt_einsum=3.3.0=pyhc1e730c_2
- packaging=23.2=pyhd8ed1ab_0
- pandas=2.0.3=py311ha02d727_0
- parso=0.8.3=pyhd8ed1ab_0
- pcre=8.45=h9c3ff4c_0
- pexpect=4.8.0=pyh1a96a4e_2
- pickleshare=0.7.5=pyhd3eb1b0_1003
- pillow=10.0.1=py311ha6cbd5a_0
- pip=23.2.1=py311h06a4308_0
- pixman=0.40.0=h7f8727e_1
- pluggy=1.0.0=py311h06a4308_1
- ply=3.11=py_1
- pretty_errors=1.2.25=pyhd8ed1ab_0
- prompt-toolkit=3.0.39=pyha770c72_0
- prompt_toolkit=3.0.39=hd8ed1ab_0
- protobuf=3.20.3=py311h6a678d5_0
- psutil=5.9.0=py311h5eee18b_0
- ptyprocess=0.7.0=pyhd3deb0d_0
- pure_eval=0.2.2=pyhd8ed1ab_0
- pyasn1=0.5.0=pyhd8ed1ab_0
- pyasn1-modules=0.3.0=pyhd8ed1ab_0
- pycparser=2.21=pyhd3eb1b0_0
- pygments=2.16.1=pyhd8ed1ab_0
- pyjwt=2.8.0=pyhd8ed1ab_0
- pyopenssl=23.2.0=py311h06a4308_0
- pyparsing=3.0.9=pyhd8ed1ab_0
- pyqt=5.15.7=py311h6a678d5_0
- pyqt5-sip=12.11.0=py311h6a678d5_0
- pysocks=1.7.1=py311h06a4308_0
- pytest=7.4.0=py311h06a4308_0
- python=3.11.5=h955ad1f_0
- python-dateutil=2.8.2=pyhd8ed1ab_0
- python-tzdata=2023.3=pyhd3eb1b0_0
- python_abi=3.11=2_cp311
- pytorch=2.1.0=cpu_py311h53e38e9_0
- pytorch-cuda=11.8=h7e8668a_5
- pytorch-model-summary=0.1.1=py_0
- pytorch-mutex=1.0=cuda
- pytz=2023.3.post1=py311h06a4308_0
- pyu2f=0.1.5=pyhd8ed1ab_0
- pyyaml=6.0=py311h5eee18b_1
- pyzmq=25.1.0=py311h6a678d5_0
- qt-main=5.15.2=h7358343_9
- qt-webengine=5.15.9=hbbf29b9_6
- qtwebkit=5.212=h3fafdc1_5
- re2=2022.04.01=h27087fc_0
- readline=8.2=h5eee18b_0
- requests=2.31.0=py311h06a4308_0
- requests-oauthlib=1.3.1=pyhd8ed1ab_0
- rsa=4.9=pyhd8ed1ab_0
- scikit-learn=1.2.2=py311h6a678d5_1
- scipy=1.11.3=py311h24aa872_0
- seaborn=0.12.2=py311h06a4308_0
- setuptools=68.0.0=py311h06a4308_0
- sip=6.6.2=py311h6a678d5_0
- six=1.16.0=pyh6c4a22f_0
- sqlite=3.41.2=h5eee18b_0
- stack_data=0.6.2=pyhd8ed1ab_0
- sympy=1.11.1=py311h06a4308_0
- tbb=2021.8.0=hdb19cb5_0
- tensorboard=2.14.1=pyhd8ed1ab_0
- tensorboard-data-server=0.7.0=py311h52d8a92_0
- threadpoolctl=2.2.0=pyh0d69192_0
- tk=8.6.12=h1ccaba5_0
- toml=0.10.2=pyhd8ed1ab_0
- torchaudio=2.1.0=py311_cu118
- torchinfo=1.8.0=pyhd8ed1ab_0
- torchtriton=2.1.0=py311
- torchvision=0.15.2=cpu_py311h6e929fa_0
- tornado=6.3.3=py311h5eee18b_0
- tqdm=4.66.1=pyhd8ed1ab_0
- traitlets=5.11.2=pyhd8ed1ab_0
- typeguard=2.13.3=py311h06a4308_0
- typing-extensions=4.7.1=py311h06a4308_0
- typing_extensions=4.7.1=py311h06a4308_0
- tzdata=2023c=h04d1e81_0
- urllib3=1.26.16=py311h06a4308_0
- wcwidth=0.2.8=pyhd8ed1ab_0
- werkzeug=3.0.0=pyhd8ed1ab_0
- wheel=0.41.2=py311h06a4308_0
- x264=1!157.20191217=h7b6447c_0
- xz=5.4.2=h5eee18b_0
- yaml=0.2.5=h7b6447c_0
- yarl=1.8.1=py311h5eee18b_0
- zeromq=4.3.4=h9c3ff4c_1
- zipp=3.17.0=pyhd8ed1ab_0
- zlib=1.2.13=h5eee18b_0
- zstd=1.5.5=hc292b87_0
- pip:
- beautifulsoup4==4.12.2
- gdown==4.7.1
- soupsieve==2.5
- torchsummary==1.5.1
prefix: /home/doh16101/anaconda3/envs/CS330_torch
Binary file modified BML_project/models/__pycache__/ss_gp_model.cpython-311.pyc
Binary file not shown.
3 changes: 2 additions & 1 deletion BML_project/models/ss_gp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
class MultitaskGPModel(gpytorch.models.ApproximateGP):
def __init__(self):
# Let's use a different set of inducing points for each latent function
inducing_points = torch.rand(num_latents, num_inducing_points, 127 * 128) # Assuming flattened 128x128 images
inducing_points = torch.rand(num_latents, num_inducing_points, 128 * 128) # Assuming flattened 128x128 images
# Dong, 01/22/2024: I will use 128 * 128.

# We have to mark the CholeskyVariationalDistribution as batch
# so that we learn a variational distribution for each task
Expand Down
7 changes: 4 additions & 3 deletions BML_project/ss_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
@author: lrm22005
"""
import tqdm
from tqdm import tqdm
import torch
from utils.data_loader import preprocess_data, split_uids, update_train_loader_with_uncertain_samples
from utils_gp.data_loader import preprocess_data, split_uids, update_train_loader_with_uncertain_samples
from models.ss_gp_model import MultitaskGPModel, train_gp_model
from utils_gp.ss_evaluation import stochastic_evaluation, evaluate_model_on_all_data
from active_learning.ss_active_learning import stochastic_uncertainty_sampling, run_minibatch_kmeans, stochastic_compare_kmeans_gp_predictions
from utils.visualization import plot_comparative_results, plot_training_performance, plot_results
from utils_gp.visualization import plot_comparative_results, plot_training_performance, plot_results

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Expand All @@ -22,6 +22,7 @@ def main():
data_format = 'pt'
# Preprocess data
train_loader, val_loader, test_loader = preprocess_data(data_format, clinical_trial_train, clinical_trial_test, clinical_trial_unlabeled, batch_size)
print('Debug: len(train_loader)',len(train_loader))

kmeans_model = run_minibatch_kmeans(train_loader, n_clusters=n_classes, device=device)

Expand Down
Binary file modified BML_project/utils_gp/__pycache__/data_loader.cpython-311.pyc
Binary file not shown.
Binary file modified BML_project/utils_gp/__pycache__/ss_evaluation.cpython-311.pyc
Binary file not shown.
Binary file modified BML_project/utils_gp/__pycache__/visualization.cpython-311.pyc
Binary file not shown.
28 changes: 22 additions & 6 deletions BML_project/utils_gp/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,20 @@
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler
from torchvision.transforms import ToTensor
import socket

def split_uids():
# ====== Load the per subject arrythmia summary ======
df_summary = pd.read_csv(r'\\grove.ad.uconn.edu\research\ENGR_Chon\NIH_Pulsewatch_Database\Adjudication_UConn\final_attemp_4_1_Dong_Ohm_summary_20231025.csv')
your_computer_name = socket.gethostname()
if your_computer_name == 'localhost.localdomain':
# Dong, 12/09/2023: I am so sick of changing the path every time on different computer.
# This is Cassey's Luis server name.
df_summary = pd.read_csv(r'/mnt/r/ENGR_Chon/NIH_Pulsewatch_Database/Adjudication_UConn/final_attemp_4_1_Dong_Ohm_summary_20231025.csv')
elif your_computer_name == 'Darren_computer_name':
# Darren, you can put your computer name in the elif condition to separate it from Luis's computer.
df_summary = pd.read_csv(r'R:\ENGR_Chon\NIH_Pulsewatch_Database\Adjudication_UConn\final_attemp_4_1_Dong_Ohm_summary_20231025.csv')
elif your_computer_name == 'Luis_computer_name':
df_summary = pd.read_csv(r'\\grove.ad.uconn.edu\research\ENGR_Chon\NIH_Pulsewatch_Database\Adjudication_UConn\final_attemp_4_1_Dong_Ohm_summary_20231025.csv')
df_summary['UID'] = df_summary['UID'].astype(str).str.zfill(3)

df_summary['sample_nonAF'] = df_summary['NSR'] + df_summary['PACPVC'] + df_summary['SVT']
Expand Down Expand Up @@ -130,7 +140,6 @@ def __getitem__(self, idx):
else:
# Load data on-the-fly based on the segment_name
time_freq_tensor = self.load_data(segment_name)

return {'data': time_freq_tensor, 'label': label, 'segment_name': segment_name}

def add_data_label_pair(self, data, label):
Expand All @@ -155,6 +164,7 @@ def extract_segment_names_and_labels(self):
for UID in self.UIDs:
label_file = os.path.join(self.labels_path, UID + "_final_attemp_4_1_Dong.csv")
if os.path.exists(label_file):
print('Debug: this file exists',label_file)
label_data = pd.read_csv(label_file, sep=',', header=0, names=['segment', 'label'])
label_segment_names = label_data['segment'].apply(lambda x: x.split('.')[0])
for idx, segment_name in enumerate(label_segment_names):
Expand Down Expand Up @@ -207,20 +217,26 @@ def load_data_split_batched(data_path, labels_path, UIDs, batch_size, standardiz
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=drop_last, num_workers=num_workers, prefetch_factor=2)
return dataloader

def get_data_paths(data_format, is_linux=False, is_hpc=False):
if is_linux:
def get_data_paths(data_format):
your_computer_name = socket.gethostname()
print('Debug: your_computer_name',your_computer_name)
if your_computer_name == 'localhost.localdomain':
base_path = "/mnt/r/ENGR_Chon/Dong/MATLAB_generate_results/NIH_PulseWatch"
labels_base_path = "/mnt/r/ENGR_Chon/NIH_Pulsewatch_Database/Adjudication_UConn"
saving_base_path = "/mnt/r/ENGR_Chon/Luis/Research/Casseys_case/Project_1_analysis"
elif is_hpc:
elif your_computer_name == 'HPC_computer_name':
base_path = "/gpfs/scratchfs1/kic14002/doh16101"
labels_base_path = "/gpfs/scratchfs1/hfp14002/lrm22005"
saving_base_path = "/gpfs/scratchfs1/hfp14002/lrm22005/Casseys_case/Project_1_analysis"
else:
elif your_computer_name == 'Darren_computer_name':
# R:\ENGR_Chon\Dong\MATLAB_generate_results\NIH_PulseWatch
base_path = "R:\ENGR_Chon\Dong\MATLAB_generate_results\\NIH_PulseWatch"
labels_base_path = "R:\ENGR_Chon\\NIH_Pulsewatch_Database\Adjudication_UConn"
saving_base_path = r"\\grove.ad.uconn.edu\research\ENGR_Chon\Luis\Research\Casseys_case"
else:
print('ERROR! YOUR DID NOT GET THE PATH.')
raise ValueError

if data_format == 'csv':
data_path = os.path.join(base_path, "TFS_csv")
labels_path = os.path.join(labels_base_path, "final_attemp_4_1_Dong_Ohm")
Expand Down
Loading

0 comments on commit ddbd109

Please sign in to comment.