Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
CoinHMM/main.py
Go to fileThis commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
103 lines (80 sloc)
2.54 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Import libraries | |
# `from __future__ import division` gives Python 3 semantics for `/`: integer division returns a float, so 3/2 == 1.5, not 1
from __future__ import division | |
import time | |
def read_in_sequences(f_name):
    '''
    Read coin-flip sequences from a text file.

    The file is expected to hold one sequence per line, written with
    "H" for heads and "T" for tails, i.e.

        THT
        THH
        HHH
        TTT
        etc

    Args:
        f_name: Path of the file to read.

    Returns:
        A list of strings, one per non-blank line, with surrounding
        whitespace stripped, every "H" replaced by "1" and every "T"
        replaced by "0". Any other character is left unchanged.
    '''
    sequences = []
    # The context manager closes the file even if reading fails part-way.
    with open(f_name, 'r') as f_in:
        for line in f_in:
            flips = line.strip()
            # Blank lines (e.g. a stray newline at the end of the file) are
            # not observations; keeping them would add zero-length sequences.
            if not flips:
                continue
            # Replace heads with 1, tails with 0
            sequences.append(flips.replace("H", "1").replace("T", "0"))
    return sequences
def update_parameters(sequences, L, p1, p2, N, M):
    '''
    Perform one EM update for a mixture of two biased coins.

    Each sequence is treated as N flips of a single coin: coin 1 is chosen
    with probability L (lambda) and lands heads with probability p1, coin 2
    is chosen with probability 1 - L and lands heads with probability p2.

    Args:
        sequences: Iterable of strings in which '1' marks heads and '0'
            marks tails.
        L: Current estimate of the probability of choosing coin 1.
        p1: Current estimate of the heads probability of coin 1.
        p2: Current estimate of the heads probability of coin 2.
        N: Number of flips per sequence (all sequences are assumed to have
            this length).
        M: Number of sequences. Unused; kept so existing callers keep
            working.

    Returns:
        A tuple (L_new, p1_new, p2_new) with the updated estimates. A
        parameter whose update is undefined (no sequences, or no weight
        assigned to that coin) is returned unchanged.
    '''
    n1 = 0.0   # Expected number of times coin 1 was chosen
    n2 = 0.0   # Expected number of times coin 2 was chosen
    e1h = 0.0  # Expected number of heads thrown with coin 1
    e2h = 0.0  # Expected number of heads thrown with coin 2

    # E-step: posterior weight of each coin for every sequence
    for s in sequences:
        heads = s.count('1')
        tails = s.count('0')
        P1 = L * (p1 ** heads * (1 - p1) ** tails)
        P2 = (1 - L) * (p2 ** heads * (1 - p2) ** tails)
        total = P1 + P2
        if total > 0:
            r1 = P1 / total
            r2 = P2 / total
        else:
            # Neither coin can explain this sequence (or the likelihood
            # underflowed to zero); fall back to the prior weights so the
            # update stays defined instead of dividing by zero.
            r1 = L
            r2 = 1 - L
        n1 += r1
        n2 += r2
        # Heads must be credited to each coin in proportion to its posterior
        # weight for *this* sequence; scaling the grand total of heads by
        # the prior L can push the new probabilities above 1.
        e1h += r1 * heads
        e2h += r2 * heads

    # M-step: re-estimate the parameters from the expected counts
    n_total = n1 + n2
    L_new = n1 / n_total if n_total > 0 else L
    flips1 = n1 * N  # Expected number of flips made with coin 1
    flips2 = n2 * N  # Expected number of flips made with coin 2
    p1_new = e1h / flips1 if flips1 > 0 else p1
    p2_new = e2h / flips2 if flips2 > 0 else p2
    return L_new, p1_new, p2_new
def _relative_change(new, old):
    '''Return |new - old| relative to |old| (absolute change if old is 0).'''
    if old == 0:
        return abs(new - old)
    return abs((new - old) / old)


def main(inputs="Input/input3.txt", max_its=1000000, tol=1e-12):
    '''
    Estimate lambda, p1 and p2 for the two-coin model and print them.

    Reads the sequences from `inputs`, then repeatedly applies
    update_parameters starting from fixed initial guesses until every
    parameter changes by less than `tol` (relative) or `max_its` updates
    have been made.

    Args:
        inputs: Path of the file holding one H/T sequence per line.
        max_its: Upper bound on the number of parameter updates.
        tol: Relative change below which all three parameters are
            considered converged.

    Returns:
        The final estimates as a tuple (lambda, p1, p2).

    Raises:
        ValueError: If the input file contains no sequences.
    '''
    sequences = read_in_sequences(inputs)
    if not sequences:
        raise ValueError("No sequences found in %s" % inputs)

    # Get M, N
    M = len(sequences)     # Number of sequences
    N = len(sequences[0])  # Flips per sequence; all sequences are assumed
                           # to have the same length in a given file

    # Initial guesses
    L_o = 0.3
    p1_o = 0.2
    p2_o = 0.3

    i = 0
    # Strict '<' so that at most max_its updates run (not max_its + 1).
    while i < max_its:
        L_new, p1_new, p2_new = update_parameters(sequences, L_o, p1_o, p2_o, N, M)
        converged = (
            _relative_change(L_new, L_o) < tol
            and _relative_change(p1_new, p1_o) < tol
            and _relative_change(p2_new, p2_o) < tol
        )
        L_o = L_new
        p1_o = p1_new
        p2_o = p2_new
        i += 1
        # Stop as soon as the estimates have settled instead of always
        # burning through the full iteration budget.
        if converged:
            break

    # Single-argument print(...) emits the same text on Python 2 and 3.
    print("Lambda: %s" % L_o)
    print("p1: %s" % p1_o)
    print("p2: %s" % p2_o)
    return L_o, p1_o, p2_o
# Run the estimation only when this file is executed as a script, so that
# importing the module has no side effects.
if __name__ == "__main__":
    main()