Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add endpoint for uploading file with target IDs or coordinates #62

Draft
wants to merge 3 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 50 additions & 1 deletion python/valis/routes/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@

from enum import Enum
from typing import List, Union, Dict, Annotated, Optional
from fastapi import APIRouter, Depends, Query, HTTPException
from fastapi import APIRouter, Depends, Query, HTTPException, UploadFile, File
from fastapi_restful.cbv import cbv
from pydantic import BaseModel, Field, BeforeValidator
import csv
import io

from valis.cache import valis_cache
from valis.routes.base import Base
Expand Down Expand Up @@ -260,3 +262,50 @@ async def get_target_list_by_mapper(self,
""" Return an ordered and paged list of targets based on the mapper."""
targets = get_paged_target_list_by_mapper(mapper, page_number, items_per_page)
return list(targets)

def validate_file_content(self, content: str):
reader = csv.reader(io.StringIO(content))
target_ids = []
invalid_ids = []
coordinates = []

for row in reader:
try:
if len(row) == 1:
target_id = int(row[0])
target_ids.append(target_id)
elif len(row) == 2:
ra, dec = float(row[0]), float(row[1])
coordinates.append((ra, dec))
else:
invalid_ids.append(row)
except ValueError:
invalid_ids.append(row)

if invalid_ids:
raise HTTPException(status_code=400, detail=f'Invalid target IDs or coordinates: {", ".join(map(str, invalid_ids))}')

return target_ids, coordinates

@router.post('/upload', summary='Upload a file with a list of target IDs or coordinates',
dependencies=[Depends(get_pw_db)],
response_model=MainSearchResponse)
async def upload_file(self, file: UploadFile = File(...)):
""" Upload a file with a list of target IDs or coordinates """

if file.content_type not in ['text/csv', 'text/plain']:
raise HTTPException(status_code=400, detail='Unsupported file format. Please upload a CSV or TXT file.')

content = await file.read()
content = content.decode('utf-8')

target_ids, coordinates = self.validate_file_content(content)

results = []
if target_ids:
results.extend(list(get_targets_by_sdss_id(target_ids)))
if coordinates:
for ra, dec in coordinates:
results.extend(list(cone_search(ra, dec, 0.02, units='degree')))

return {'status': 'success', 'data': results, 'msg': 'File processed successfully'}
38 changes: 38 additions & 0 deletions tests/test_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,11 @@
#

import pytest
from fastapi.testclient import TestClient
from valis.main import app
from valis.db.queries import convert_coords

client = TestClient(app)

@pytest.mark.parametrize('ra, dec, exp',
[('315.01417', '35.299', (315.01417, 35.299)),
Expand All @@ -19,6 +22,41 @@ def test_convert_coords(ra, dec, exp):
coord = convert_coords(ra, dec)
assert coord == exp

def test_upload_file_csv():
with open("tests/data/valid_targets.csv", "rb") as file:
response = client.post("/query/upload", files={"file": ("valid_targets.csv", file, "text/csv")})
assert response.status_code == 200
data = response.json()
assert data["status"] == "success"
assert len(data["data"]) > 0

def test_upload_file_txt():
with open("tests/data/valid_targets.txt", "rb") as file:
response = client.post("/query/upload", files={"file": ("valid_targets.txt", file, "text/plain")})
assert response.status_code == 200
data = response.json()
assert data["status"] == "success"
assert len(data["data"]) > 0

def test_upload_file_invalid_ids():
with open("tests/data/invalid_targets.csv", "rb") as file:
response = client.post("/query/upload", files={"file": ("invalid_targets.csv", file, "text/csv")})
assert response.status_code == 400
data = response.json()
assert "Invalid target IDs" in data["detail"]

def test_upload_file_with_coordinates():
with open("tests/data/valid_coordinates.csv", "rb") as file:
response = client.post("/query/upload", files={"file": ("valid_coordinates.csv", file, "text/csv")})
assert response.status_code == 200
data = response.json()
assert data["status"] == "success"
assert len(data["data"]) > 0

def test_upload_file_with_mixed_data():
with open("tests/data/mixed_data.csv", "rb") as file:
response = client.post("/query/upload", files={"file": ("mixed_data.csv", file, "text/csv")})
assert response.status_code == 200
data = response.json()
assert data["status"] == "success"
assert len(data["data"]) > 0
Loading