From 4e1c93fc4bc383ecfa714f566ea7583a7321351e Mon Sep 17 00:00:00 2001 From: Jamey Calabrese Date: Tue, 30 Mar 2021 00:56:23 -0400 Subject: [PATCH] Created a script which is rough around the edges to check for work, classify an image against a model, and post the result back to the data service --- .gitignore | 4 ++- requirements.txt | 11 +++++++ src/__init__.py | 84 ++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 98 insertions(+), 1 deletion(-) create mode 100644 src/__init__.py diff --git a/.gitignore b/.gitignore index 82adb58..de75073 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ __pycache__ -venv +env +.ipynb_checkpoints +.vscode diff --git a/requirements.txt b/requirements.txt index 7763fb4..4c15161 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,6 +3,8 @@ astunparse==1.6.3 cachetools==4.1.1 certifi==2020.6.20 chardet==3.0.4 +cycler==0.10.0 +decorator==4.4.2 gast==0.3.3 google-auth==1.22.1 google-auth-oauthlib==0.4.1 @@ -10,22 +12,30 @@ google-pasta==0.2.0 grpcio==1.33.1 h5py==2.10.0 idna==2.10 +imageio==2.9.0 Keras==2.4.3 Keras-Preprocessing==1.1.2 +kiwisolver==1.3.1 Markdown==3.3.2 +matplotlib==3.4.0 +networkx==2.5 numpy==1.18.5 oauthlib==3.1.0 opt-einsum==3.3.0 pandas==1.1.3 +Pillow==8.1.2 protobuf==3.13.0 pyasn1==0.4.8 pyasn1-modules==0.2.8 +pyparsing==2.4.7 python-dateutil==2.8.1 pytz==2020.1 +PyWavelets==1.1.1 PyYAML==5.3.1 requests==2.24.0 requests-oauthlib==1.3.0 rsa==4.6 +scikit-image==0.18.1 scipy==1.5.3 six==1.15.0 tensorboard==2.3.0 @@ -33,6 +43,7 @@ tensorboard-plugin-wit==1.7.0 tensorflow==2.3.1 tensorflow-estimator==2.3.0 termcolor==1.1.0 +tifffile==2021.3.17 urllib3==1.25.11 Werkzeug==1.0.1 wrapt==1.12.1 diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..7ef36a1 --- /dev/null +++ b/src/__init__.py @@ -0,0 +1,84 @@ +import os +import keras +import numpy +import requests +import sys +import tensorflow +import time + +from PIL import Image +from skimage import transform +from io import BytesIO + +def worker_loop(hostname, model): + # First step: Check for work from the data service + work_check = requests.post(hostname + "/requestwork") + work_check = work_check.json() + classification_id = work_check['classification_id'] + + # This is the do nothing and wait case. + if classification_id == None: + return False + + # If we've been assigned work we move onto step 2 + + # Get the image assigned to the classification request + image_response = requests.get(hostname + "/getimage?classification_id={}".format(classification_id)) + image = Image.open(BytesIO(image_response.content)) + image = numpy.array(image).astype('float32')/255 + image = transform.resize(image, (218, 178, 3)) # TODO - parameterize this image resize + image = numpy.expand_dims(image, axis=0) + + # Make a prediction on the image using the model + pred = model.predict(image)[0, 0] + + # Post the prediction back to the data service + response_dict = dict() + response_dict['classification_id'] = classification_id + response_dict['assigned_on'] = work_check['assigned_on'] + response_dict['errors'] = "" # TODO - Figure out where keras will spit out errors + + # Using the result of the prediction assing a classification + # TODO - In the future these possibilities should be fetchable from the data service + if round(pred) == 1: + response_dict['classification'] = "covid-19" + else: + response_dict['classification'] = "not covid-19" + + # Report the predicted classification back to the data service + requests.post(hostname + "/reportclassification", json=response_dict) + + # Return true to denote that there was previously work assigned to the worker instance + # and there may still be more images to classify + return True + +if __name__ == "__main__": + arguments = sys.argv[1:] # ignore __init__.py argument + argc = len(arguments) + + hostname = "http://127.0.0.1:8000" + model = "../covid_fine_tuned.h5" + interval = 15.0 + + for i in range(0,len(arguments),2): + parameter = arguments[i] + if i+1 >= len(arguments): + raise Exception("No value passed for parameter {}", parameter) + + value = arguments[i+1] + + if parameter == "--hostname": + hostname = value + elif parameter == "--model": + model = value + elif parameter == "--retry_interval": + interval = float(value) + + model = keras.models.load_model(model) + + # The loop of the worker instance + while True: + work_check = worker_loop(hostname, model) + if not work_check: + time.sleep(interval) +