Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Created a script which is rough around the edges to check for work, c…
…lassify an image against a model, and post the result back to the data service
  • Loading branch information
jrc16107 committed Mar 30, 2021
1 parent 75c807b commit 4e1c93f
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 1 deletion.
4 changes: 3 additions & 1 deletion .gitignore
@@ -1,2 +1,4 @@
__pycache__
venv
env
.ipynb_checkpoints
.vscode
11 changes: 11 additions & 0 deletions requirements.txt
Expand Up @@ -3,36 +3,47 @@ astunparse==1.6.3
cachetools==4.1.1
certifi==2020.6.20
chardet==3.0.4
cycler==0.10.0
decorator==4.4.2
gast==0.3.3
google-auth==1.22.1
google-auth-oauthlib==0.4.1
google-pasta==0.2.0
grpcio==1.33.1
h5py==2.10.0
idna==2.10
imageio==2.9.0
Keras==2.4.3
Keras-Preprocessing==1.1.2
kiwisolver==1.3.1
Markdown==3.3.2
matplotlib==3.4.0
networkx==2.5
numpy==1.18.5
oauthlib==3.1.0
opt-einsum==3.3.0
pandas==1.1.3
Pillow==8.1.2
protobuf==3.13.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pyparsing==2.4.7
python-dateutil==2.8.1
pytz==2020.1
PyWavelets==1.1.1
PyYAML==5.3.1
requests==2.24.0
requests-oauthlib==1.3.0
rsa==4.6
scikit-image==0.18.1
scipy==1.5.3
six==1.15.0
tensorboard==2.3.0
tensorboard-plugin-wit==1.7.0
tensorflow==2.3.1
tensorflow-estimator==2.3.0
termcolor==1.1.0
tifffile==2021.3.17
urllib3==1.25.11
Werkzeug==1.0.1
wrapt==1.12.1
84 changes: 84 additions & 0 deletions src/__init__.py
@@ -0,0 +1,84 @@
import os
import keras
import numpy
import requests
import sys
import tensorflow
import time

from PIL import Image
from skimage import transform
from io import BytesIO

def worker_loop(hostname, model):
# First step: Check for work from the data service
work_check = requests.post(hostname + "/requestwork")
work_check = work_check.json()
classification_id = work_check['classification_id']

# This is the do nothing and wait case.
if classification_id == None:
return False

# If we've been assigned work we move onto step 2

# Get the image assigned to the classification request
image_response = requests.get(hostname + "/getimage?classification_id={}".format(classification_id))
image = Image.open(BytesIO(image_response.content))
image = numpy.array(image).astype('float32')/255
image = transform.resize(image, (218, 178, 3)) # TODO - parameterize this image resize
image = numpy.expand_dims(image, axis=0)

# Make a prediction on the image using the model
pred = model.predict(image)[0, 0]

# Post the prediction back to the data service
response_dict = dict()
response_dict['classification_id'] = classification_id
response_dict['assigned_on'] = work_check['assigned_on']
response_dict['errors'] = "" # TODO - Figure out where keras will spit out errors

# Using the result of the prediction assing a classification
# TODO - In the future these possibilities should be fetchable from the data service
if round(pred) == 1:
response_dict['classification'] = "covid-19"
else:
response_dict['classification'] = "not covid-19"

# Report the predicted classification back to the data service
requests.post(hostname + "/reportclassification", json=response_dict)

# Return true to denote that there was previously work assigned to the worker instance
# and there may still be more images to classify
return True

if __name__ == "__main__":
arguments = sys.argv[1:] # ignore __init__.py argument
argc = len(arguments)

hostname = "http://127.0.0.1:8000"
model = "../covid_fine_tuned.h5"
interval = 15.0

for i in range(0,len(arguments),2):
parameter = arguments[i]
if i+1 >= len(arguments):
raise Exception("No value passed for parameter {}", parameter)

value = arguments[i+1]

if parameter == "--hostname":
hostname = value
elif parameter == "--model":
model = value
elif parameter == "--retry_interval":
interval = float(value)

model = keras.models.load_model(model)

# The loop of the worker instance
while True:
work_check = worker_loop(hostname, model)
if not work_check:
time.sleep(interval)

0 comments on commit 4e1c93f

Please sign in to comment.