Image url support (#3)
* Update main.py

* Update ci.yml

* Update README.md

* Update requirements.txt

* Create models.py

* Update main.py

* Update test_main.py

* Update README.md

* Update models.py

* Update test_main.py

* Update main.py

* Update main.py

* Update models.py

* Update main.py
steelcityamir authored May 31, 2024
1 parent 75b9cae commit 4ce040a
Showing 6 changed files with 225 additions and 13 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/ci.yml
@@ -1,6 +1,10 @@
name: Python CI

-on: [push]
+on:
+  push:
+    branches:
+      - main
+  pull_request:

concurrency:
  group: ${{ github.workflow }}-${{ github.ref }}
50 changes: 50 additions & 0 deletions README.md
@@ -98,6 +98,7 @@ This endpoint allows users to upload an image file, which is then processed to determine if the content is NSFW
"is_nsfw": "boolean",
"confidence_percentage": "number"
}
```

#### Curl

@@ -106,8 +107,57 @@ curl -X POST "http://127.0.0.1:8000/v1/detect" \
  -H "Content-Type: multipart/form-data" \
  -F "file=@/path/to/your/image.jpeg"
```
### POST /v1/detect/urls

This endpoint allows users to provide image URLs, which are then processed to determine if the content is NSFW (Not Safe For Work). The response includes whether each image is considered NSFW and the confidence level of the prediction.

#### Request

- **URL**: `/v1/detect/urls`
- **Method**: `POST`
- **Content-Type**: `application/json`
- **Body**:
```json
{
"urls": [
"https://example.com/image1.jpg",
"https://example.com/image2.jpg",
"https://example.com/image3.jpg",
"https://example.com/image4.jpg",
"https://example.com/image5.jpg"
]
}
```

#### Response

- **Content-Type**: `application/json`
- **Body**:
```json
[
  {
    "url": "string",
    "is_nsfw": "boolean",
    "confidence_percentage": "number"
  }
]
```

#### Curl

```bash
curl -X POST "http://127.0.0.1:8000/v1/detect/urls" \
-H "Content-Type: application/json" \
-d '{
"urls": [
"https://example.com/image1.jpg",
"https://example.com/image2.jpg",
"https://example.com/image3.jpg",
"https://example.com/image4.jpg",
"https://example.com/image5.jpg"
]
}'
```
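
For a rough illustration, the same request can be made from Python with the `requests` library. This is a minimal sketch, assuming the API is running locally on port 8000 and that `requests` is installed:

```python
# Minimal client sketch: POST a list of image URLs to the detection endpoint.
# Assumes the API is running locally on port 8000 and `requests` is installed.
import requests

payload = {
    "urls": [
        "https://example.com/image1.jpg",
        "https://example.com/image2.jpg",
    ]
}

response = requests.post(
    "http://127.0.0.1:8000/v1/detect/urls",
    json=payload,
    timeout=30,
)
response.raise_for_status()

# Each entry mirrors the response schema above: url, is_nsfw, confidence_percentage.
for result in response.json():
    print(result["url"], result["is_nsfw"], result["confidence_percentage"])
```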

## 📄 License

108 changes: 99 additions & 9 deletions main.py
@@ -3,13 +3,19 @@
import io
import hashlib
import logging
-from fastapi import FastAPI, File, UploadFile
+import aiohttp
+from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.responses import JSONResponse
from transformers import pipeline
from transformers.pipelines import PipelineException
from PIL import Image
from cachetools import Cache
import tensorflow as tf
from models import (
    FileImageDetectionResponse,
    UrlImageDetectionResponse,
    ImageUrlsRequest,
)


app = FastAPI()
@@ -25,31 +31,54 @@
model = pipeline("image-classification", model="falconsai/nsfw_image_detection")

# Detect the device used by TensorFlow
DEVICE = "GPU" if tf.config.list_physical_devices('GPU') else "CPU"
DEVICE = "GPU" if tf.config.list_physical_devices("GPU") else "CPU"
logging.info("TensorFlow version: %s", tf.__version__)
logging.info("Model is using: %s", DEVICE)

if DEVICE == "GPU":
logging.info("GPUs available: %d", len(tf.config.list_physical_devices("GPU")))


async def download_image(image_url: str) -> bytes:
"""Download an image from a URL."""
async with aiohttp.ClientSession() as session:
async with session.get(image_url) as response:
if response.status != 200:
raise HTTPException(
status_code=response.status, detail="Image could not be retrieved."
)
return await response.read()


def hash_data(data):
"""Function for hashing image data."""
return hashlib.sha256(data).hexdigest()


@app.post("/v1/detect")
async def classify_image(file: UploadFile = File(...)):
@app.post("/v1/detect", response_model=FileImageDetectionResponse)
async def classify_image(file: UploadFile = File(None)):
"""Function analyzing image."""
if file is None:
raise HTTPException(
status_code=400,
detail="An image file must be provided.",
)

try:
logging.info("Processing %s", file.filename)

# Read the image file
image_data = await file.read()
image_hash = hash_data(image_data)

if image_hash in cache:
# Return cached entry
logging.info("Returning cached entry for %s", file.filename)
return JSONResponse(status_code=200, content=cache[image_hash])

cached_response = cache[image_hash]
response_data = {**cached_response, "file_name": file.filename}

return FileImageDetectionResponse(**response_data)

image = Image.open(io.BytesIO(image_data))

@@ -64,18 +93,79 @@ async def classify_image(file: UploadFile = File(...)):

        # Prepare the custom response data
        response_data = {
-            "file_name": file.filename,
            "is_nsfw": best_prediction["label"] == "nsfw",
            "confidence_percentage": confidence_percentage,
        }

        # Populate hash
-        cache[image_hash] = response_data
+        cache[image_hash] = response_data.copy()
+
+        # Add file_name to the API response
+        response_data["file_name"] = file.filename

-        return JSONResponse(status_code=200, content=response_data)
+        return FileImageDetectionResponse(**response_data)

    except PipelineException as e:
-        return JSONResponse(status_code=500, content={"message": str(e)})
+        logging.error("Error processing image: %s", str(e))
+        raise HTTPException(
+            status_code=500, detail=f"Error processing image: {str(e)}"
+        ) from e


@app.post("/v1/detect/urls", response_model=list[UrlImageDetectionResponse])
async def classify_images(request: ImageUrlsRequest):
"""Function analyzing images from URLs."""
response_data = []

for image_url in request.urls:
try:
logging.info("Downloading image from URL: %s", image_url)
image_data = await download_image(image_url)
image_hash = hash_data(image_data)

if image_hash in cache:
# Return cached entry
logging.info("Returning cached entry for %s", image_url)

cached_response = cache[image_hash]
response = {**cached_response, "url": image_url}

response_data.append(response)
continue

image = Image.open(io.BytesIO(image_data))

# Use the model to classify the image
results = model(image)

# Find the prediction with the highest confidence using the max() function
best_prediction = max(results, key=lambda x: x["score"])

# Calculate the confidence score, rounded to the nearest tenth and as a percentage
confidence_percentage = round(best_prediction["score"] * 100, 1)

# Prepare the custom response data
detection_result = {
"is_nsfw": best_prediction["label"] == "nsfw",
"confidence_percentage": confidence_percentage,
}

# Populate hash
cache[image_hash] = detection_result.copy()

# Add url to the API response
detection_result["url"] = image_url

response_data.append(detection_result)

except PipelineException as e:
logging.error("Error processing image from %s: %s", image_url, str(e))
raise HTTPException(
status_code=500,
detail=f"Error processing image from {image_url}: {str(e)}",
) from e

return JSONResponse(status_code=200, content=response_data)


if __name__ == "__main__":
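
Both endpoints lean on the same caching idea: hash the raw image bytes with SHA-256, use the digest as the cache key, and store a copy of the model's verdict so that per-request fields (`file_name`, `url`) go into the response but never into the cached entry. Below is a condensed sketch of that pattern; the `Cache(maxsize=...)` construction and `run_model` are stand-ins, since the real setup sits in the collapsed portion of main.py above.

```python
# Condensed sketch of the caching pattern shared by /v1/detect and /v1/detect/urls.
# Cache(maxsize=1000) and run_model are placeholders, not the values from main.py.
import hashlib

from cachetools import Cache

cache = Cache(maxsize=1000)


def classify_bytes(image_data: bytes, run_model) -> dict:
    """Return the classification verdict for raw image bytes, using the cache when possible."""
    image_hash = hashlib.sha256(image_data).hexdigest()

    if image_hash in cache:
        return dict(cache[image_hash])

    verdict = run_model(image_data)  # e.g. {"is_nsfw": False, "confidence_percentage": 97.3}
    cache[image_hash] = verdict.copy()  # cache a copy so callers can add file_name/url safely
    return verdict
```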
49 changes: 49 additions & 0 deletions models.py
@@ -0,0 +1,49 @@
"""Module providing base models."""

from pydantic import BaseModel


class ImageUrlsRequest(BaseModel):
"""
Model representing the request body for the /v1/detect/urls endpoint.
Attributes:
urls (list[str]): List of image URLs to be processed.
"""

urls: list[str]


class ImageDetectionResponse(BaseModel):
"""
Base model representing the response body for image detection.
Attributes:
is_nsfw (bool): Whether the image is classified as NSFW.
confidence_percentage (float): Confidence level of the NSFW classification.
"""

is_nsfw: bool
confidence_percentage: float


class FileImageDetectionResponse(ImageDetectionResponse):
"""
Model extending ImageDetectionResponse with a file attribute.
Attributes:
file (str): The name of the file that was processed.
"""

file_name: str


class UrlImageDetectionResponse(ImageDetectionResponse):
"""
Model extending ImageDetectionResponse with a URL attribute.
Attributes:
url (str): The URL of the image that was processed.
"""

url: str
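
For a quick sense of how these models behave on their own, here is an illustrative sketch, assuming pydantic 2.x as pinned in requirements.txt:

```python
# Illustrative use of the request/response models outside FastAPI.
from models import ImageUrlsRequest, UrlImageDetectionResponse

# Request body validation: a plain dict is coerced and checked by pydantic.
request = ImageUrlsRequest(urls=["https://example.com/image1.jpg"])
print(request.urls)

# Response serialization: roughly what FastAPI does when response_model is set.
result = UrlImageDetectionResponse(
    url="https://example.com/image1.jpg",
    is_nsfw=False,
    confidence_percentage=97.3,
)
print(result.model_dump())
# {'is_nsfw': False, 'confidence_percentage': 97.3, 'url': 'https://example.com/image1.jpg'}
```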
1 change: 1 addition & 0 deletions requirements.txt
@@ -7,3 +7,4 @@ python-multipart==0.0.9
tensorflow==2.16.1
tf-keras==2.16.0
cachetools===5.3.3
pydantic===2.7.2
24 changes: 21 additions & 3 deletions test_main.py
@@ -36,16 +36,20 @@ def test_read_main():


def test_invalid_input():
"""Tests that POST /v1/detect returns 422 with empty request body"""
"""Tests that POST /v1/detect returns 400 with empty request body"""
response = client.post("/v1/detect", files={})
assert response.status_code == 422
assert response.status_code == 400


def test_cache_hit():
"""Tests that the endpoint returns a cached response when an image hash matches"""
# Compute the hash of the test file
image_hash = compute_file_hash(FILE_NAME)
cached_response = {"message": "Cached data"}
cached_response = {
"file_name": FILE_NAME,
"is_nsfw": False,
"confidence_percentage": 100.0,
}

    with patch.dict(cache, {image_hash: cached_response}), patch(
        "main.logging.info"
@@ -63,3 +67,17 @@ def test_cache_hit():

        # Ensure logging was called correctly
        mock_logging.assert_called_with("Returning cached entry for %s", FILE_NAME)


def test_detect_urls():
    """Tests that POST /v1/detect/urls returns 200 OK with valid request body"""
    urls = [
        "https://raw.githubusercontent.com/steelcityamir/safe-content-ai/main/sunflower.jpg",
    ]
    response = client.post("/v1/detect/urls", json={"urls": urls})
    assert response.status_code == 200
    assert len(response.json()) == len(urls)
    assert response.json()[0] == {
        "url": urls[0],
        "is_nsfw": False,
        "confidence_percentage": 100.0,
    }
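
To run the suite locally, something like the sketch below should work, assuming pytest and the project requirements are installed. Note that test_detect_urls downloads a real image from GitHub, so it needs network access.

```python
# Equivalent to running `pytest -v test_main.py` from the shell.
import pytest

if __name__ == "__main__":
    raise SystemExit(pytest.main(["-v", "test_main.py"]))
```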
