Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
CSE_SDP_G28_ML/src/__init__.py
Go to file. This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
84 lines (66 sloc)
2.74 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import keras | |
import numpy | |
import requests | |
import sys | |
import tensorflow | |
import time | |
from PIL import Image | |
from skimage import transform | |
from io import BytesIO | |
def worker_loop(hostname, model, timeout=60.0):
    """Request, classify and report at most one image from the data service.

    Args:
        hostname: Base URL of the data service, e.g. "http://127.0.0.1:8000".
        model: Loaded Keras model. It is called as ``model.predict(batch)``
            on a batch of shape (1, 218, 178, 3), and element ``[0, 0]`` of
            the result is used as the score.
        timeout: Seconds to wait for each HTTP request before ``requests``
            gives up (raising ``requests.Timeout``); ``None`` waits forever.

    Returns:
        True if work was assigned, classified and reported (more images may
        be waiting), False if the service had no work for this worker.

    Raises:
        requests.HTTPError: if the work request or the image download answers
            with an error status.
        requests.RequestException: on connection failures or timeouts.
        KeyError: if the work response lacks 'classification_id', or lacks
            'assigned_on' when work was assigned.
    """
    # First step: check for work from the data service.
    work_response = requests.post(hostname + "/requestwork", timeout=timeout)
    work_response.raise_for_status()
    work_check = work_response.json()
    classification_id = work_check['classification_id']
    # This is the do-nothing-and-wait case.
    if classification_id is None:
        return False

    # Step 2: download the image assigned to the classification request.
    # Passing params lets requests URL-encode the id instead of splicing it in raw.
    image_response = requests.get(
        hostname + "/getimage",
        params={"classification_id": classification_id},
        timeout=timeout,
    )
    image_response.raise_for_status()
    # Normalise to 3-channel RGB first. Without this, palette ("P") images
    # become 2-D arrays of colour indices. RGBA images keep 4 channels, which
    # resize() below would interpolate down to 3, mixing alpha into the colours.
    with Image.open(BytesIO(image_response.content)) as raw_image:
        rgb_image = raw_image.convert("RGB")
    image = numpy.array(rgb_image).astype('float32') / 255
    image = transform.resize(image, (218, 178, 3))  # TODO - parameterize this image resize
    image = numpy.expand_dims(image, axis=0)  # add the leading batch axis

    # Make a prediction on the image using the model.
    pred = model.predict(image)[0, 0]

    # Using the result of the prediction, assign a classification.
    # round() rounds half to even, so a score of exactly 0.5 maps to "covid-19".
    # TODO - In the future these possibilities should be fetchable from the data service
    if round(float(pred)) == 0:
        classification = "covid-19"
    else:
        classification = "not covid-19"

    response_dict = {
        'classification_id': classification_id,
        # Echo the assignment timestamp back to the service unchanged.
        'assigned_on': work_check['assigned_on'],
        'errors': "",  # TODO - Figure out where keras will spit out errors
        'classification': classification,
    }
    # Report the predicted classification back to the data service. The status
    # is deliberately not checked. NOTE(review): the service may reject stale
    # assignments (hence the echoed 'assigned_on'); confirm whether failures
    # here should be surfaced.
    requests.post(hostname + "/reportclassification", json=response_dict, timeout=timeout)
    # True means there was work assigned to this worker instance, and there
    # may still be more images to classify.
    return True
def parse_arguments(argv=None):
    """Parse the worker's command-line options.

    Supported flags (unchanged from the original hand-rolled parser):
        --hostname        Base URL of the data service (default http://127.0.0.1:8000).
        --model           Path to the saved Keras model (default ../covid_fine_tuned.h5).
        --retry_interval  Seconds to sleep when no work is available (default 15.0).

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``hostname``, ``model`` and ``retry_interval``.

    Raises:
        SystemExit: on a missing value, an unknown flag, a non-numeric or
            negative ``--retry_interval``, or ``--help``.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description="Poll the data service for images and classify them with a Keras model."
    )
    parser.add_argument(
        "--hostname",
        default="http://127.0.0.1:8000",
        help="Base URL of the data service.",
    )
    # NOTE(review): a relative default is resolved against the current working
    # directory, not this file's location; it assumes the worker runs from src/.
    parser.add_argument(
        "--model",
        default="../covid_fine_tuned.h5",
        help="Path to the saved Keras model.",
    )
    parser.add_argument(
        "--retry_interval",
        type=float,
        default=15.0,
        help="Seconds to wait before polling again when no work is available.",
    )
    args = parser.parse_args(argv)
    # time.sleep() rejects negative durations, so fail fast with a usage error.
    if args.retry_interval < 0:
        parser.error("--retry_interval must be non-negative")
    return args


if __name__ == "__main__":
    options = parse_arguments()
    model = keras.models.load_model(options.model)
    # The loop of the worker instance.
    while True:
        try:
            work_done = worker_loop(options.hostname, model)
        except requests.exceptions.RequestException as error:
            # A transient outage of the data service should not kill a
            # long-running worker: report it and retry after the usual back-off.
            print("Request to {} failed: {}".format(options.hostname, error), file=sys.stderr)
            work_done = False
        if not work_done:
            time.sleep(options.retry_interval)