Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Browse files
Browse the repository at this point in the history
Created a script which is rough around the edges to check for work, c…
…lassify an image against a model, and post the result back to the data service
- Loading branch information
Showing
3 changed files
with
98 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,4 @@ | ||
__pycache__ | ||
venv | ||
env | ||
.ipynb_checkpoints | ||
.vscode |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import os | ||
import keras | ||
import numpy | ||
import requests | ||
import sys | ||
import tensorflow | ||
import time | ||
|
||
from PIL import Image | ||
from skimage import transform | ||
from io import BytesIO | ||
|
||
def worker_loop(hostname, model): | ||
# First step: Check for work from the data service | ||
work_check = requests.post(hostname + "/requestwork") | ||
work_check = work_check.json() | ||
classification_id = work_check['classification_id'] | ||
|
||
# This is the do nothing and wait case. | ||
if classification_id == None: | ||
return False | ||
|
||
# If we've been assigned work we move onto step 2 | ||
|
||
# Get the image assigned to the classification request | ||
image_response = requests.get(hostname + "/getimage?classification_id={}".format(classification_id)) | ||
image = Image.open(BytesIO(image_response.content)) | ||
image = numpy.array(image).astype('float32')/255 | ||
image = transform.resize(image, (218, 178, 3)) # TODO - parameterize this image resize | ||
image = numpy.expand_dims(image, axis=0) | ||
|
||
# Make a prediction on the image using the model | ||
pred = model.predict(image)[0, 0] | ||
|
||
# Post the prediction back to the data service | ||
response_dict = dict() | ||
response_dict['classification_id'] = classification_id | ||
response_dict['assigned_on'] = work_check['assigned_on'] | ||
response_dict['errors'] = "" # TODO - Figure out where keras will spit out errors | ||
|
||
# Using the result of the prediction assing a classification | ||
# TODO - In the future these possibilities should be fetchable from the data service | ||
if round(pred) == 1: | ||
response_dict['classification'] = "covid-19" | ||
else: | ||
response_dict['classification'] = "not covid-19" | ||
|
||
# Report the predicted classification back to the data service | ||
requests.post(hostname + "/reportclassification", json=response_dict) | ||
|
||
# Return true to denote that there was previously work assigned to the worker instance | ||
# and there may still be more images to classify | ||
return True | ||
|
||
if __name__ == "__main__": | ||
arguments = sys.argv[1:] # ignore __init__.py argument | ||
argc = len(arguments) | ||
|
||
hostname = "http://127.0.0.1:8000" | ||
model = "../covid_fine_tuned.h5" | ||
interval = 15.0 | ||
|
||
for i in range(0,len(arguments),2): | ||
parameter = arguments[i] | ||
if i+1 >= len(arguments): | ||
raise Exception("No value passed for parameter {}", parameter) | ||
|
||
value = arguments[i+1] | ||
|
||
if parameter == "--hostname": | ||
hostname = value | ||
elif parameter == "--model": | ||
model = value | ||
elif parameter == "--retry_interval": | ||
interval = float(value) | ||
|
||
model = keras.models.load_model(model) | ||
|
||
# The loop of the worker instance | ||
while True: | ||
work_check = worker_loop(hostname, model) | ||
if not work_check: | ||
time.sleep(interval) | ||
|