diff --git a/makefile b/makefile
new file mode 100644
index 0000000..fc30616
--- /dev/null
+++ b/makefile
@@ -0,0 +1,12 @@
+local-run:
+	python3 src/model_server.py & python3 src/pipeline/app.py
+	bash scripts/kill_model_server.sh
+
+run-pipeline:
+	python3 src/pipeline/app.py
+
+model-server:
+	python3 src/model_server.py
+
+kill-model-server:
+	bash scripts/kill_model_server.sh
diff --git a/scripts/kill_model_server.sh b/scripts/kill_model_server.sh
new file mode 100644
index 0000000..c5435ae
--- /dev/null
+++ b/scripts/kill_model_server.sh
@@ -0,0 +1,16 @@
+echo "Checking for running model_server.py files..."
+ps aux | grep python | grep whale | grep model_server.py
+
+PID=$(ps aux | grep python | grep whale | grep model_server.py | awk '{print $2}')
+if [ -z "$PID" ]; then
+    echo "No model_server.py files running."
+else
+    echo "Killing PID: $PID"
+    kill -9 $PID
+    sleep 2
+
+    echo "Checking for running model_server.py files..."
+    ps aux | grep python | grep whale | grep model_server.py
+fi
+
+echo "Done."
\ No newline at end of file
diff --git a/src/model_server.py b/src/model_server.py
new file mode 100644
index 0000000..1a3de5a
--- /dev/null
+++ b/src/model_server.py
@@ -0,0 +1,54 @@
+from flask import Flask, request, jsonify
+import tensorflow_hub as hub
+import numpy as np
+import tensorflow as tf
+
+import logging
+
+
+# Load the TensorFlow model
+print("Loading model...")
+# model = hub.load("https://www.kaggle.com/models/google/humpback-whale/TensorFlow2/humpback-whale/1")
+model = hub.load("https://tfhub.dev/google/humpback_whale/1")
+score_fn = model.signatures["score"]
+print("Model loaded.")
+
+# Initialize Flask app
+app = Flask(__name__)
+
+# Define the predict endpoint
+@app.route('/predict', methods=['POST'])
+def predict():
+    try:
+        # Parse the request data
+        data = request.json
+        batch = np.array(data['batch'], dtype=np.float32)  # Assuming batch is passed as a list
+        key = data['key']
+        print(f"batch.shape = {batch.shape}")
+
+        # Prepare the input for the model
+        waveform_exp = tf.expand_dims(batch, 0)  # Expanding dimensions to fit model input shape
+        print(f"waveform_exp.shape = {waveform_exp.shape}")
+
+        # Run inference
+        results = score_fn(
+            waveform=waveform_exp,
+            context_step_samples=10_000
+        )["scores"][0]  # NOTE currently only support batch size 1
+        print(f"results.shape = {results.shape}")
+        print("results = ", results)
+
+        # Return the predictions and key as JSON
+        return jsonify({
+            'key': key,
+            'predictions': results.numpy().tolist()
+        })
+
+    except Exception as e:
+        logging.error(f"An error occurred: {str(e)}")
+        print(f"An error occurred: {str(e)}")
+        return jsonify({'error': str(e)}), 500
+
+# Main entry point
+if __name__ == "__main__":
+    app.run(host='0.0.0.0', port=5000, debug=True)
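
For reference, a client request to the new /predict endpoint could look like the sketch below. This is a hypothetical usage example, not part of the diff: it assumes the server is running locally on port 5000 and that the humpback-whale model takes a 10 kHz waveform of shape [num_samples, 1]; the server itself adds the leading batch dimension via tf.expand_dims.

    # Hypothetical client for the /predict endpoint added in src/model_server.py.
    import numpy as np
    import requests

    # One second of placeholder audio at an assumed 10 kHz, shaped [num_samples, 1].
    waveform = np.zeros((10_000, 1), dtype=np.float32)

    response = requests.post(
        "http://localhost:5000/predict",
        json={
            "batch": waveform.tolist(),  # nested lists; parsed back to float32 server-side
            "key": "example-chunk-0",    # opaque identifier echoed back in the response
        },
        timeout=60,
    )
    response.raise_for_status()

    payload = response.json()
    print(payload["key"])               # "example-chunk-0"
    print(len(payload["predictions"]))  # one score per context window

Because the server only supports batch size 1 (see the NOTE in the diff), each request carries a single waveform plus a key, and the key is returned unchanged so the caller can match predictions to the chunk it sent.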