diff --git a/python/valis/routes/query.py b/python/valis/routes/query.py index 4459245..5680657 100644 --- a/python/valis/routes/query.py +++ b/python/valis/routes/query.py @@ -4,9 +4,11 @@ from enum import Enum from typing import List, Union, Dict, Annotated, Optional -from fastapi import APIRouter, Depends, Query, HTTPException +from fastapi import APIRouter, Depends, Query, HTTPException, UploadFile, File from fastapi_restful.cbv import cbv from pydantic import BaseModel, Field, BeforeValidator +import csv +import io from valis.cache import valis_cache from valis.routes.base import Base @@ -260,3 +262,50 @@ async def get_target_list_by_mapper(self, """ Return an ordered and paged list of targets based on the mapper.""" targets = get_paged_target_list_by_mapper(mapper, page_number, items_per_page) return list(targets) + + def validate_file_content(self, content: str): + reader = csv.reader(io.StringIO(content)) + target_ids = [] + invalid_ids = [] + coordinates = [] + + for row in reader: + try: + if len(row) == 1: + target_id = int(row[0]) + target_ids.append(target_id) + elif len(row) == 2: + ra, dec = float(row[0]), float(row[1]) + coordinates.append((ra, dec)) + else: + invalid_ids.append(row) + except ValueError: + invalid_ids.append(row) + + if invalid_ids: + raise HTTPException(status_code=400, detail=f'Invalid target IDs or coordinates: {", ".join(map(str, invalid_ids))}') + + return target_ids, coordinates + + @router.post('/upload', summary='Upload a file with a list of target IDs or coordinates', + dependencies=[Depends(get_pw_db)], + response_model=MainSearchResponse) + async def upload_file(self, file: UploadFile = File(...)): + """ Upload a file with a list of target IDs or coordinates """ + + if file.content_type not in ['text/csv', 'text/plain']: + raise HTTPException(status_code=400, detail='Unsupported file format. Please upload a CSV or TXT file.') + + content = await file.read() + content = content.decode('utf-8') + + target_ids, coordinates = self.validate_file_content(content) + + results = [] + if target_ids: + results.extend(list(get_targets_by_sdss_id(target_ids))) + if coordinates: + for ra, dec in coordinates: + results.extend(list(cone_search(ra, dec, 0.02, units='degree'))) + + return {'status': 'success', 'data': results, 'msg': 'File processed successfully'} diff --git a/tests/test_queries.py b/tests/test_queries.py index 316cc84..7a5e611 100644 --- a/tests/test_queries.py +++ b/tests/test_queries.py @@ -2,8 +2,11 @@ # import pytest +from fastapi.testclient import TestClient +from valis.main import app from valis.db.queries import convert_coords +client = TestClient(app) @pytest.mark.parametrize('ra, dec, exp', [('315.01417', '35.299', (315.01417, 35.299)), @@ -19,6 +22,41 @@ def test_convert_coords(ra, dec, exp): coord = convert_coords(ra, dec) assert coord == exp +def test_upload_file_csv(): + with open("tests/data/valid_targets.csv", "rb") as file: + response = client.post("/query/upload", files={"file": ("valid_targets.csv", file, "text/csv")}) + assert response.status_code == 200 + data = response.json() + assert data["status"] == "success" + assert len(data["data"]) > 0 +def test_upload_file_txt(): + with open("tests/data/valid_targets.txt", "rb") as file: + response = client.post("/query/upload", files={"file": ("valid_targets.txt", file, "text/plain")}) + assert response.status_code == 200 + data = response.json() + assert data["status"] == "success" + assert len(data["data"]) > 0 +def test_upload_file_invalid_ids(): + with open("tests/data/invalid_targets.csv", "rb") as file: + response = client.post("/query/upload", files={"file": ("invalid_targets.csv", file, "text/csv")}) + assert response.status_code == 400 + data = response.json() + assert "Invalid target IDs" in data["detail"] +def test_upload_file_with_coordinates(): + with open("tests/data/valid_coordinates.csv", "rb") as file: + response = client.post("/query/upload", files={"file": ("valid_coordinates.csv", file, "text/csv")}) + assert response.status_code == 200 + data = response.json() + assert data["status"] == "success" + assert len(data["data"]) > 0 + +def test_upload_file_with_mixed_data(): + with open("tests/data/mixed_data.csv", "rb") as file: + response = client.post("/query/upload", files={"file": ("mixed_data.csv", file, "text/csv")}) + assert response.status_code == 200 + data = response.json() + assert data["status"] == "success" + assert len(data["data"]) > 0