Commit
lrm22005 committed Jul 30, 2024
2 parents e916610 + a2e8540 commit c61aafe
Showing 6 changed files with 116 additions and 40 deletions.
4 changes: 4 additions & 0 deletions .gitignore
@@ -47,3 +47,7 @@ output/embeddings/gene_based/gene_based_embeddings_part1.pt
output/embeddings/gene_based/gene_based_embeddings_part0.pt
output/embeddings/gene_based/gene_based_embeddings_part2.pt
output/gene_based_preprocessed/processed_genes.json
output/embeddings/gene_based/gene_based_embeddings_part0.pt
output/embeddings/gene_based/gene_based_embeddings_part1.pt
output/embeddings/gene_based/gene_based_embeddings_part2.pt
output/gene_based_preprocessed/processed_genes.json
Binary file not shown.
6 changes: 3 additions & 3 deletions code/step2_data_preprocessing_Luis_new.py
@@ -28,7 +28,7 @@ def safe_read(file_path):
print(f"Unexpected error reading file {file_path}: {e}")
return None

def process_gene_docs_in_chunks(file_path, chunk_size=1000):
def process_gene_docs_in_chunks(file_path, chunk_size=10000):
gene2doc = {}
with open(file_path, 'r', encoding='utf-8') as f:
current_gene = None
@@ -75,7 +75,7 @@ def making_doc_data(self, gene_list, name, dic, mode='w'):

else:
for i in range(len(gene_list)):
if counting == 1000:
if counting == 10000:
print(i, '/', len(gene_list))
counting = 0
data = dic[gene_list[i]]
@@ -93,7 +93,7 @@ def making_doc_data(self, gene_list, name, dic, mode='w'):
gene_based_dir = os.path.join(batch_dir, 'results', 'gene_based_records')
baseline_doc_dir = os.path.join(batch_dir, 'results', 'baseline_doc')
comb_dir = os.path.join(output, 'arranged')
preprocessed_dir = os.path.join(output, 'preprocessed')
preprocessed_dir = os.path.join(output, 'gene_based_preprocessed')

print(f"Checking directories...")
print(f"batch_dir: {batch_dir}")
146 changes: 109 additions & 37 deletions code/step3_literature_embedding_training_Luis.py
@@ -1,20 +1,68 @@
import os
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import gc

# Import necessary libraries
import os # For file and directory operations
import torch # PyTorch library for deep learning
from torch.utils.data import Dataset, DataLoader # For creating custom datasets and data loaders
from tqdm import tqdm # For progress bars
import gc # For garbage collection
import json # For working with JSON data

# Try to import various transformer models and tokenizers
try:
from transformers import AlbertTokenizer, AlbertModel
MODEL_NAME = 'albert-base-v2'
print("Using ALBERT model")
except ImportError:
from transformers import BertTokenizer, BertModel
MODEL_NAME = 'bert-base-uncased'
print("ALBERT not available, using BERT model instead")

from transformers import (
AlbertTokenizer, AlbertModel,
BertTokenizer, BertModel,
RobertaTokenizer, RobertaModel,
DistilBertTokenizer, DistilBertModel,
XLNetTokenizer, XLNetModel
)
# Define a dictionary of available models with their respective tokenizers and model classes
MODELS = {
'albert': (AlbertTokenizer, AlbertModel, 'albert-base-v2'),
'bert': (BertTokenizer, BertModel, 'bert-base-uncased'),
'roberta': (RobertaTokenizer, RobertaModel, 'roberta-base'),
'distilbert': (DistilBertTokenizer, DistilBertModel, 'distilbert-base-uncased'),
'xlnet': (XLNetTokenizer, XLNetModel, 'xlnet-base-cased')
}
print("All models available")
except ImportError as e:
# If some models are not available, print an error message
print(f"Some models might not be available: {e}")
MODELS = {}

# Class for saving embeddings
class EmbeddingSaver:
def __init__(self, base_path):
# Initialize the saver with a base path for saving embeddings
self.base_path = base_path
# Create directories for saving embeddings
os.makedirs(base_path, exist_ok=True)
self.gene_based_path = os.path.join(base_path, 'gene_based')
self.baseline_path = os.path.join(base_path, 'baseline')
os.makedirs(self.gene_based_path, exist_ok=True)
os.makedirs(self.baseline_path, exist_ok=True)
# Initialize metadata dictionary
self.metadata = {'gene_based': {}, 'baseline': {}}

def save_gene_based(self, embeddings, part):
# Save gene-based embeddings
file_path = os.path.join(self.gene_based_path, f'gene_based_embeddings_part{part}.pt')
torch.save(embeddings, file_path)
self.metadata['gene_based'][f'part{part}'] = file_path

def save_baseline(self, embeddings, file_name):
# Save baseline embeddings
file_path = os.path.join(self.baseline_path, f'{file_name}_embeddings.pt')
torch.save(embeddings, file_path)
self.metadata['baseline'][file_name] = file_path

def save_metadata(self):
# Save metadata to a JSON file
with open(os.path.join(self.base_path, 'embedding_metadata.json'), 'w') as f:
json.dump(self.metadata, f, indent=2)

# Custom Dataset class for processing literature in chunks
class ChunkedLiteratureDataset(Dataset):
def __init__(self, file_path, tokenizer, max_length=512, chunk_size=1000):
def __init__(self, file_path, tokenizer, max_length=512, chunk_size=100000):
self.file_path = file_path
self.tokenizer = tokenizer
self.max_length = max_length
@@ -23,6 +71,7 @@ def __init__(self, file_path, tokenizer, max_length=512, chunk_size=1000):
self.file = open(self.file_path, 'r', encoding='utf-8')

def load_chunk(self):
# Load a chunk of data from the file
self.current_chunk = []
for _ in range(self.chunk_size):
line = self.file.readline()
@@ -32,9 +81,11 @@ def load_chunk(self):
return len(self.current_chunk) > 0

def __len__(self):
return self.chunk_size
# Return the length of the current chunk
return len(self.current_chunk)

def __getitem__(self, idx):
# Get and tokenize an item from the current chunk
if idx >= len(self.current_chunk):
raise IndexError("Index out of bounds")
text = self.current_chunk[idx]
@@ -51,67 +102,88 @@ def __getitem__(self, idx):
'attention_mask': encoding['attention_mask'].flatten()
}

def process_in_batches(file_path, model, tokenizer, batch_size=32, max_length=512, chunk_size=1000):
# Function to process data in batches
def process_in_batches(file_path, model, tokenizer, embedding_saver, is_gene_based=True, batch_size=300, max_length=512, chunk_size=100000):
# Create a dataset
dataset = ChunkedLiteratureDataset(file_path, tokenizer, max_length, chunk_size)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

all_embeddings = []
chunk_count = 0
model.eval()
model.eval() # Set model to evaluation mode

with torch.no_grad():
while dataset.load_chunk():
with torch.no_grad(): # Disable gradient calculation
while dataset.load_chunk(): # Process data chunk by chunk
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
chunk_embeddings = []
for batch in tqdm(dataloader, desc=f"Processing chunk {chunk_count}"):
# Move batch to device and get model outputs
input_ids = batch['input_ids'].to(model.device)
attention_mask = batch['attention_mask'].to(model.device)
outputs = model(input_ids, attention_mask=attention_mask)
chunk_embeddings.append(outputs.last_hidden_state[:, 0, :].cpu())

# Concatenate embeddings for the current chunk
chunk_embeddings = torch.cat(chunk_embeddings)
all_embeddings.append(chunk_embeddings)

# Save embeddings periodically
if len(all_embeddings) % 10 == 0:
torch.save(torch.cat(all_embeddings), f'gene_based_embeddings_part{chunk_count//10}.pt')
embeddings_to_save = torch.cat(all_embeddings)
if is_gene_based:
embedding_saver.save_gene_based(embeddings_to_save, chunk_count//10)
else:
embedding_saver.save_baseline(embeddings_to_save, os.path.basename(file_path))
all_embeddings = [] # Clear the list to free up memory
gc.collect() # Force garbage collection

chunk_count += 1

# Save any remaining embeddings
if all_embeddings:
torch.save(torch.cat(all_embeddings), f'gene_based_embeddings_part{chunk_count//10}.pt')
embeddings_to_save = torch.cat(all_embeddings)
if is_gene_based:
embedding_saver.save_gene_based(embeddings_to_save, chunk_count//10)
else:
embedding_saver.save_baseline(embeddings_to_save, os.path.basename(file_path))

print(f"Processed {chunk_count} chunks from {file_path}")

# Set up paths
gene_based_path = r"D:\ZIP\arranged\gene_based_consolidated.txt"
gene_based_path = r"C:\Users\lrm22005\OneDrive - University of Connecticut\Research\ZIP11_Bioinformatic\capsule-3642152\ZIP11\output\arranged\consolidated_gene_docs.txt"
baseline_doc_dir = r"C:\Users\lrm22005\OneDrive - University of Connecticut\Research\ZIP11_Bioinformatic\capsule-3642152\ZIP11\output\preprocessed"
output_dir = r"C:\Users\lrm22005\OneDrive - University of Connecticut\Research\ZIP11_Bioinformatic\capsule-3642152\ZIP11\output\embeddings"

# Choose model
model_choice = 'albert' # Can be changed to 'bert', 'roberta', 'distilbert', or 'xlnet'
if model_choice not in MODELS:
raise ValueError(f"Model {model_choice} not available. Choose from: {list(MODELS.keys())}")

# Get the appropriate classes and model name for the chosen model
TokenizerClass, ModelClass, model_name = MODELS[model_choice]
print(f"Using {model_choice.upper()} model")

# Initialize model and tokenizer
tokenizer = AlbertTokenizer.from_pretrained(MODEL_NAME) if MODEL_NAME.startswith('albert') else BertTokenizer.from_pretrained(MODEL_NAME)
model = AlbertModel.from_pretrained(MODEL_NAME) if MODEL_NAME.startswith('albert') else BertModel.from_pretrained(MODEL_NAME)
tokenizer = TokenizerClass.from_pretrained(model_name)
model = ModelClass.from_pretrained(model_name)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
model.to(device) # Move model to GPU if available

# Initialize embedding saver
embedding_saver = EmbeddingSaver(output_dir)

# Process gene-based data
print("Processing gene-based data...")
process_in_batches(gene_based_path, model, tokenizer)
process_in_batches(gene_based_path, model, tokenizer, embedding_saver, is_gene_based=True)

# Process baseline documents (unchanged)
# Process baseline documents
print("Processing baseline documents...")
baseline_embeddings = []
for file in os.listdir(baseline_doc_dir):
if file.endswith('.txt'):
file_path = os.path.join(baseline_doc_dir, file)
embeddings = process_in_batches(file_path, model, tokenizer)
baseline_embeddings.append(embeddings)

# Combine baseline embeddings
baseline_embeddings = torch.cat(baseline_embeddings)
print(f"Processing {file}...")
process_in_batches(file_path, model, tokenizer, embedding_saver, is_gene_based=False, chunk_size=100000)

# Save baseline embeddings
torch.save(baseline_embeddings, 'baseline_embeddings.pt')
# Save metadata
embedding_saver.save_metadata()

print("Embedding process completed.")
Binary file not shown.
Binary file not shown.
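
For downstream use, a minimal sketch (not part of this commit) of how the embeddings written by EmbeddingSaver could be loaded back. It relies only on the layout visible in the step3 diff above: embedding_metadata.json with 'gene_based' and 'baseline' sections mapping to the saved .pt files. The helper name load_all_embeddings and the output_dir placeholder in the usage comment are illustrative assumptions.

import json
import os
import torch

def load_all_embeddings(base_path):
    # embedding_metadata.json maps 'gene_based' part names and 'baseline' file names
    # to the .pt paths written by EmbeddingSaver.save_metadata().
    with open(os.path.join(base_path, 'embedding_metadata.json')) as f:
        metadata = json.load(f)

    # Concatenate gene-based parts in numeric order (part0, part1, ...).
    part_keys = sorted(metadata['gene_based'], key=lambda k: int(k.replace('part', '')))
    gene_parts = [torch.load(metadata['gene_based'][k]) for k in part_keys]
    gene_embeddings = torch.cat(gene_parts) if gene_parts else None

    # Baseline embeddings stay keyed by their source file name.
    baseline = {name: torch.load(path) for name, path in metadata['baseline'].items()}
    return gene_embeddings, baseline

# Example usage, assuming the same output_dir as in the step3 script above:
# gene_embeddings, baseline = load_all_embeddings(output_dir)
# print(gene_embeddings.shape)  # (num_documents, hidden_size); 768 for albert-base-v2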
