
update on step2 and step3 #7

Merged
merged 1 commit on Jul 30, 2024
4 changes: 4 additions & 0 deletions .gitignore
@@ -43,3 +43,7 @@ checkpoint.json
checkpoint.json
checkpoint.json
*.txt
output/embeddings/gene_based/gene_based_embeddings_part0.pt
output/embeddings/gene_based/gene_based_embeddings_part1.pt
output/embeddings/gene_based/gene_based_embeddings_part2.pt
output/gene_based_preprocessed/processed_genes.json
6 changes: 3 additions & 3 deletions code/step2_data_preprocessing_Luis_new.py
@@ -28,7 +28,7 @@ def safe_read(file_path):
print(f"Unexpected error reading file {file_path}: {e}")
return None

def process_gene_docs_in_chunks(file_path, chunk_size=1000):
def process_gene_docs_in_chunks(file_path, chunk_size=10000):
gene2doc = {}
with open(file_path, 'r', encoding='utf-8') as f:
current_gene = None
@@ -75,7 +75,7 @@ def making_doc_data(self, gene_list, name, dic, mode='w'):

else:
for i in range(len(gene_list)):
if counting == 1000:
if counting == 10000:
print(i, '/', len(gene_list))
counting = 0
data = dic[gene_list[i]]
@@ -93,7 +93,7 @@ def making_doc_data(self, gene_list, name, dic, mode='w'):
gene_based_dir = os.path.join(batch_dir, 'results', 'gene_based_records')
baseline_doc_dir = os.path.join(batch_dir, 'results', 'baseline_doc')
comb_dir = os.path.join(output, 'arranged')
preprocessed_dir = os.path.join(output, 'preprocessed')
preprocessed_dir = os.path.join(output, 'gene_based_preprocessed')

print(f"Checking directories...")
print(f"batch_dir: {batch_dir}")
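
Note on the step2 diff above: the only functional changes are the larger chunk size (1000 → 10000), the matching progress-print interval in making_doc_data, and the renamed output directory (preprocessed → gene_based_preprocessed). Since the hunk elides most of process_gene_docs_in_chunks, the following is only a minimal sketch of how a chunked gene-document parser of this shape could work; the "#GENE" record-header convention and the helper name are assumptions for illustration, not the repository's actual file format.

# Sketch only: chunked parsing of a gene -> document text file.
# Assumes each record starts with a "#GENE <symbol>" header line followed by
# free-text lines; the real format produced by step1/step2 may differ.
def process_gene_docs_in_chunks_sketch(file_path, chunk_size=10000):
    gene2doc = {}
    current_gene, buffer = None, []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, 1):
            line = line.rstrip('\n')
            if line.startswith('#GENE'):
                # Flush the previous gene's accumulated document text
                if current_gene is not None:
                    gene2doc[current_gene] = '\n'.join(buffer)
                current_gene = line.split(maxsplit=1)[-1]
                buffer = []
            else:
                buffer.append(line)
            if line_no % chunk_size == 0:
                print(f"{line_no} lines read, {len(gene2doc)} genes collected")
    if current_gene is not None:
        gene2doc[current_gene] = '\n'.join(buffer)
    return gene2doc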
146 changes: 109 additions & 37 deletions code/step3_literature_embedding_training_Luis.py
@@ -1,20 +1,68 @@
import os
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import gc

# Import necessary libraries
import os # For file and directory operations
import torch # PyTorch library for deep learning
from torch.utils.data import Dataset, DataLoader # For creating custom datasets and data loaders
from tqdm import tqdm # For progress bars
import gc # For garbage collection
import json # For working with JSON data

# Try to import various transformer models and tokenizers
try:
from transformers import AlbertTokenizer, AlbertModel
MODEL_NAME = 'albert-base-v2'
print("Using ALBERT model")
except ImportError:
from transformers import BertTokenizer, BertModel
MODEL_NAME = 'bert-base-uncased'
print("ALBERT not available, using BERT model instead")

from transformers import (
AlbertTokenizer, AlbertModel,
BertTokenizer, BertModel,
RobertaTokenizer, RobertaModel,
DistilBertTokenizer, DistilBertModel,
XLNetTokenizer, XLNetModel
)
# Define a dictionary of available models with their respective tokenizers and model classes
MODELS = {
'albert': (AlbertTokenizer, AlbertModel, 'albert-base-v2'),
'bert': (BertTokenizer, BertModel, 'bert-base-uncased'),
'roberta': (RobertaTokenizer, RobertaModel, 'roberta-base'),
'distilbert': (DistilBertTokenizer, DistilBertModel, 'distilbert-base-uncased'),
'xlnet': (XLNetTokenizer, XLNetModel, 'xlnet-base-cased')
}
print("All models available")
except ImportError as e:
# If some models are not available, print an error message
print(f"Some models might not be available: {e}")
MODELS = {}

# Class for saving embeddings
class EmbeddingSaver:
def __init__(self, base_path):
# Initialize the saver with a base path for saving embeddings
self.base_path = base_path
# Create directories for saving embeddings
os.makedirs(base_path, exist_ok=True)
self.gene_based_path = os.path.join(base_path, 'gene_based')
self.baseline_path = os.path.join(base_path, 'baseline')
os.makedirs(self.gene_based_path, exist_ok=True)
os.makedirs(self.baseline_path, exist_ok=True)
# Initialize metadata dictionary
self.metadata = {'gene_based': {}, 'baseline': {}}

def save_gene_based(self, embeddings, part):
# Save gene-based embeddings
file_path = os.path.join(self.gene_based_path, f'gene_based_embeddings_part{part}.pt')
torch.save(embeddings, file_path)
self.metadata['gene_based'][f'part{part}'] = file_path

def save_baseline(self, embeddings, file_name):
# Save baseline embeddings
file_path = os.path.join(self.baseline_path, f'{file_name}_embeddings.pt')
torch.save(embeddings, file_path)
self.metadata['baseline'][file_name] = file_path

def save_metadata(self):
# Save metadata to a JSON file
with open(os.path.join(self.base_path, 'embedding_metadata.json'), 'w') as f:
json.dump(self.metadata, f, indent=2)
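
As a quick orientation for reviewers, this is how the new EmbeddingSaver is meant to be driven (a sketch that reuses the torch import at the top of this file; the base path and tensor below are placeholders, not values from the script):

# Hypothetical usage of EmbeddingSaver; paths and tensor are examples only.
saver = EmbeddingSaver('output/embeddings')
dummy = torch.zeros(4, 768)            # stand-in for a batch of [CLS] embeddings
saver.save_gene_based(dummy, part=0)   # -> output/embeddings/gene_based/gene_based_embeddings_part0.pt
saver.save_baseline(dummy, 'docs')     # -> output/embeddings/baseline/docs_embeddings.pt
saver.save_metadata()                  # -> output/embeddings/embedding_metadata.json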

# Custom Dataset class for processing literature in chunks
class ChunkedLiteratureDataset(Dataset):
def __init__(self, file_path, tokenizer, max_length=512, chunk_size=1000):
def __init__(self, file_path, tokenizer, max_length=512, chunk_size=100000):
self.file_path = file_path
self.tokenizer = tokenizer
self.max_length = max_length
@@ -23,6 +71,7 @@ def __init__(self, file_path, tokenizer, max_length=512, chunk_size=1000):
self.file = open(self.file_path, 'r', encoding='utf-8')

def load_chunk(self):
# Load a chunk of data from the file
self.current_chunk = []
for _ in range(self.chunk_size):
line = self.file.readline()
@@ -32,9 +81,11 @@ def load_chunk(self):
return len(self.current_chunk) > 0

def __len__(self):
return self.chunk_size
# Return the length of the current chunk
return len(self.current_chunk)

def __getitem__(self, idx):
# Get and tokenize an item from the current chunk
if idx >= len(self.current_chunk):
raise IndexError("Index out of bounds")
text = self.current_chunk[idx]
@@ -51,67 +102,88 @@ def __getitem__(self, idx):
'attention_mask': encoding['attention_mask'].flatten()
}

def process_in_batches(file_path, model, tokenizer, batch_size=32, max_length=512, chunk_size=1000):
# Function to process data in batches
def process_in_batches(file_path, model, tokenizer, embedding_saver, is_gene_based=True, batch_size=300, max_length=512, chunk_size=100000):
# Create a dataset
dataset = ChunkedLiteratureDataset(file_path, tokenizer, max_length, chunk_size)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

all_embeddings = []
chunk_count = 0
model.eval()
model.eval() # Set model to evaluation mode

with torch.no_grad():
while dataset.load_chunk():
with torch.no_grad(): # Disable gradient calculation
while dataset.load_chunk(): # Process data chunk by chunk
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
chunk_embeddings = []
for batch in tqdm(dataloader, desc=f"Processing chunk {chunk_count}"):
# Move batch to device and get model outputs
input_ids = batch['input_ids'].to(model.device)
attention_mask = batch['attention_mask'].to(model.device)
outputs = model(input_ids, attention_mask=attention_mask)
chunk_embeddings.append(outputs.last_hidden_state[:, 0, :].cpu())

# Concatenate embeddings for the current chunk
chunk_embeddings = torch.cat(chunk_embeddings)
all_embeddings.append(chunk_embeddings)

# Save embeddings periodically
if len(all_embeddings) % 10 == 0:
torch.save(torch.cat(all_embeddings), f'gene_based_embeddings_part{chunk_count//10}.pt')
embeddings_to_save = torch.cat(all_embeddings)
if is_gene_based:
embedding_saver.save_gene_based(embeddings_to_save, chunk_count//10)
else:
embedding_saver.save_baseline(embeddings_to_save, os.path.basename(file_path))
all_embeddings = [] # Clear the list to free up memory
gc.collect() # Force garbage collection

chunk_count += 1

# Save any remaining embeddings
if all_embeddings:
torch.save(torch.cat(all_embeddings), f'gene_based_embeddings_part{chunk_count//10}.pt')
embeddings_to_save = torch.cat(all_embeddings)
if is_gene_based:
embedding_saver.save_gene_based(embeddings_to_save, chunk_count//10)
else:
embedding_saver.save_baseline(embeddings_to_save, os.path.basename(file_path))

print(f"Processed {chunk_count} chunks from {file_path}")

# Set up paths
gene_based_path = r"D:\ZIP\arranged\gene_based_consolidated.txt"
gene_based_path = r"C:\Users\lrm22005\OneDrive - University of Connecticut\Research\ZIP11_Bioinformatic\capsule-3642152\ZIP11\output\arranged\consolidated_gene_docs.txt"
baseline_doc_dir = r"C:\Users\lrm22005\OneDrive - University of Connecticut\Research\ZIP11_Bioinformatic\capsule-3642152\ZIP11\output\preprocessed"
output_dir = r"C:\Users\lrm22005\OneDrive - University of Connecticut\Research\ZIP11_Bioinformatic\capsule-3642152\ZIP11\output\embeddings"

# Choose model
model_choice = 'albert' # Can be changed to 'bert', 'roberta', 'distilbert', or 'xlnet'
if model_choice not in MODELS:
raise ValueError(f"Model {model_choice} not available. Choose from: {list(MODELS.keys())}")

# Get the appropriate classes and model name for the chosen model
TokenizerClass, ModelClass, model_name = MODELS[model_choice]
print(f"Using {model_choice.upper()} model")

# Initialize model and tokenizer
tokenizer = AlbertTokenizer.from_pretrained(MODEL_NAME) if MODEL_NAME.startswith('albert') else BertTokenizer.from_pretrained(MODEL_NAME)
model = AlbertModel.from_pretrained(MODEL_NAME) if MODEL_NAME.startswith('albert') else BertModel.from_pretrained(MODEL_NAME)
tokenizer = TokenizerClass.from_pretrained(model_name)
model = ModelClass.from_pretrained(model_name)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.to(device) # Move model to GPU if available

# Initialize embedding saver
embedding_saver = EmbeddingSaver(output_dir)

# Process gene-based data
print("Processing gene-based data...")
process_in_batches(gene_based_path, model, tokenizer)
process_in_batches(gene_based_path, model, tokenizer, embedding_saver, is_gene_based=True)

# Process baseline documents (unchanged)
# Process baseline documents
print("Processing baseline documents...")
baseline_embeddings = []
for file in os.listdir(baseline_doc_dir):
if file.endswith('.txt'):
file_path = os.path.join(baseline_doc_dir, file)
embeddings = process_in_batches(file_path, model, tokenizer)
baseline_embeddings.append(embeddings)

# Combine baseline embeddings
baseline_embeddings = torch.cat(baseline_embeddings)
print(f"Processing {file}...")
process_in_batches(file_path, model, tokenizer, embedding_saver, is_gene_based=False, chunk_size=100000)

# Save baseline embeddings
torch.save(baseline_embeddings, 'baseline_embeddings.pt')
# Save metadata
embedding_saver.save_metadata()

print("Embedding process completed.")