diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ab73d27..0fa8f37 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,6 +1,10 @@ name: Python CI -on: [push] +on: + push: + branches: + - main + pull_request: concurrency: group: ${{ github.workflow }}-${{ github.ref }} diff --git a/README.md b/README.md index 7cec949..2d062bb 100644 --- a/README.md +++ b/README.md @@ -98,6 +98,7 @@ This endpoint allows users to upload an image file, which is then processed to d "is_nsfw": "boolean", "confidence_percentage": "number" } + ``` #### Curl @@ -106,8 +107,57 @@ curl -X POST "http://127.0.0.1:8000/v1/detect" \ -H "Content-Type: multipart/form-data" \ -F "file=@/path/to/your/image.jpeg" ``` +### POST /v1/detect/urls +This endpoint allows users to provide image URLs, which are then processed to determine if the content is NSFW (Not Safe For Work). The response includes whether each image is considered NSFW and the confidence level of the prediction. +#### Request + +- **URL**: `/v1/detect/urls` +- **Method**: `POST` +- **Content-Type**: `application/json` +- **Body**: + ```json + { + "urls": [ + "https://example.com/image1.jpg", + "https://example.com/image2.jpg", + "https://example.com/image3.jpg", + "https://example.com/image4.jpg", + "https://example.com/image5.jpg" + ] + } + ``` + +#### Response + +- **Content-Type**: `application/json` +- **Body**: + ```json + [ + { + "url": "string", + "is_nsfw": "boolean", + "confidence_percentage": "number" + } + ] + ``` + +#### Curl + +```bash +curl -X POST "http://127.0.0.1:8000/v1/detect/urls" \ + -H "Content-Type: application/json" \ + -d '{ + "urls": [ + "https://example.com/image1.jpg", + "https://example.com/image2.jpg", + "https://example.com/image3.jpg", + "https://example.com/image4.jpg", + "https://example.com/image5.jpg" + ] + }' +``` ## 📄 License diff --git a/main.py b/main.py index d4e68bc..33afea1 100644 --- a/main.py +++ b/main.py @@ -3,13 +3,19 @@ import io import hashlib import logging -from fastapi import FastAPI, File, UploadFile +import aiohttp +from fastapi import FastAPI, File, UploadFile, HTTPException from fastapi.responses import JSONResponse from transformers import pipeline from transformers.pipelines import PipelineException from PIL import Image from cachetools import Cache import tensorflow as tf +from models import ( + FileImageDetectionResponse, + UrlImageDetectionResponse, + ImageUrlsRequest, +) app = FastAPI() @@ -25,23 +31,42 @@ model = pipeline("image-classification", model="falconsai/nsfw_image_detection") # Detect the device used by TensorFlow -DEVICE = "GPU" if tf.config.list_physical_devices('GPU') else "CPU" +DEVICE = "GPU" if tf.config.list_physical_devices("GPU") else "CPU" logging.info("TensorFlow version: %s", tf.__version__) logging.info("Model is using: %s", DEVICE) if DEVICE == "GPU": logging.info("GPUs available: %d", len(tf.config.list_physical_devices("GPU"))) + +async def download_image(image_url: str) -> bytes: + """Download an image from a URL.""" + async with aiohttp.ClientSession() as session: + async with session.get(image_url) as response: + if response.status != 200: + raise HTTPException( + status_code=response.status, detail="Image could not be retrieved." + ) + return await response.read() + + def hash_data(data): """Function for hashing image data.""" return hashlib.sha256(data).hexdigest() -@app.post("/v1/detect") -async def classify_image(file: UploadFile = File(...)): +@app.post("/v1/detect", response_model=FileImageDetectionResponse) +async def classify_image(file: UploadFile = File(None)): """Function analyzing image.""" + if file is None: + raise HTTPException( + status_code=400, + detail="An image file must be provided.", + ) + try: logging.info("Processing %s", file.filename) + # Read the image file image_data = await file.read() image_hash = hash_data(image_data) @@ -49,7 +74,11 @@ async def classify_image(file: UploadFile = File(...)): if image_hash in cache: # Return cached entry logging.info("Returning cached entry for %s", file.filename) - return JSONResponse(status_code=200, content=cache[image_hash]) + + cached_response = cache[image_hash] + response_data = {**cached_response, "file_name": file.filename} + + return FileImageDetectionResponse(**response_data) image = Image.open(io.BytesIO(image_data)) @@ -64,18 +93,79 @@ async def classify_image(file: UploadFile = File(...)): # Prepare the custom response data response_data = { - "file_name": file.filename, "is_nsfw": best_prediction["label"] == "nsfw", "confidence_percentage": confidence_percentage, } # Populate hash - cache[image_hash] = response_data + cache[image_hash] = response_data.copy() + + # Add file_name to the API response + response_data["file_name"] = file.filename - return JSONResponse(status_code=200, content=response_data) + return FileImageDetectionResponse(**response_data) except PipelineException as e: - return JSONResponse(status_code=500, content={"message": str(e)}) + logging.error("Error processing image: %s", str(e)) + raise HTTPException( + status_code=500, detail=f"Error processing image: {str(e)}" + ) from e + + +@app.post("/v1/detect/urls", response_model=list[UrlImageDetectionResponse]) +async def classify_images(request: ImageUrlsRequest): + """Function analyzing images from URLs.""" + response_data = [] + + for image_url in request.urls: + try: + logging.info("Downloading image from URL: %s", image_url) + image_data = await download_image(image_url) + image_hash = hash_data(image_data) + + if image_hash in cache: + # Return cached entry + logging.info("Returning cached entry for %s", image_url) + + cached_response = cache[image_hash] + response = {**cached_response, "url": image_url} + + response_data.append(response) + continue + + image = Image.open(io.BytesIO(image_data)) + + # Use the model to classify the image + results = model(image) + + # Find the prediction with the highest confidence using the max() function + best_prediction = max(results, key=lambda x: x["score"]) + + # Calculate the confidence score, rounded to the nearest tenth and as a percentage + confidence_percentage = round(best_prediction["score"] * 100, 1) + + # Prepare the custom response data + detection_result = { + "is_nsfw": best_prediction["label"] == "nsfw", + "confidence_percentage": confidence_percentage, + } + + # Populate hash + cache[image_hash] = detection_result.copy() + + # Add url to the API response + detection_result["url"] = image_url + + response_data.append(detection_result) + + except PipelineException as e: + logging.error("Error processing image from %s: %s", image_url, str(e)) + raise HTTPException( + status_code=500, + detail=f"Error processing image from {image_url}: {str(e)}", + ) from e + + return JSONResponse(status_code=200, content=response_data) if __name__ == "__main__": diff --git a/models.py b/models.py new file mode 100644 index 0000000..ed3a20f --- /dev/null +++ b/models.py @@ -0,0 +1,49 @@ +"""Module providing base models.""" + +from pydantic import BaseModel + + +class ImageUrlsRequest(BaseModel): + """ + Model representing the request body for the /v1/detect/urls endpoint. + + Attributes: + urls (list[str]): List of image URLs to be processed. + """ + + urls: list[str] + + +class ImageDetectionResponse(BaseModel): + """ + Base model representing the response body for image detection. + + Attributes: + is_nsfw (bool): Whether the image is classified as NSFW. + confidence_percentage (float): Confidence level of the NSFW classification. + """ + + is_nsfw: bool + confidence_percentage: float + + +class FileImageDetectionResponse(ImageDetectionResponse): + """ + Model extending ImageDetectionResponse with a file attribute. + + Attributes: + file (str): The name of the file that was processed. + """ + + file_name: str + + +class UrlImageDetectionResponse(ImageDetectionResponse): + """ + Model extending ImageDetectionResponse with a URL attribute. + + Attributes: + url (str): The URL of the image that was processed. + """ + + url: str diff --git a/requirements.txt b/requirements.txt index c26f0a8..fb306eb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,3 +7,4 @@ python-multipart==0.0.9 tensorflow==2.16.1 tf-keras==2.16.0 cachetools===5.3.3 +pydantic===2.7.2 diff --git a/test_main.py b/test_main.py index 3a5f49d..c15908d 100644 --- a/test_main.py +++ b/test_main.py @@ -36,16 +36,20 @@ def test_read_main(): def test_invalid_input(): - """Tests that POST /v1/detect returns 422 with empty request body""" + """Tests that POST /v1/detect returns 400 with empty request body""" response = client.post("/v1/detect", files={}) - assert response.status_code == 422 + assert response.status_code == 400 def test_cache_hit(): """Tests that the endpoint returns a cached response when an image hash matches""" # Compute the hash of the test file image_hash = compute_file_hash(FILE_NAME) - cached_response = {"message": "Cached data"} + cached_response = { + "file_name": FILE_NAME, + "is_nsfw": False, + "confidence_percentage": 100.0, + } with patch.dict(cache, {image_hash: cached_response}), patch( "main.logging.info" @@ -63,3 +67,17 @@ def test_cache_hit(): # Ensure logging was called correctly mock_logging.assert_called_with("Returning cached entry for %s", FILE_NAME) + +def test_detect_urls(): + """Tests that POST /v1/detect/urls returns 200 OK with valid request body""" + urls = [ + "https://raw.githubusercontent.com/steelcityamir/safe-content-ai/main/sunflower.jpg", + ] + response = client.post("/v1/detect/urls", json={"urls": urls}) + assert response.status_code == 200 + assert len(response.json()) == len(urls) + assert response.json()[0] == { + "url": urls[0], + "is_nsfw": False, + "confidence_percentage": 100.0, + }