|
| 1 | +from fastapi import Depends |
| 2 | +from fastapi.testclient import TestClient |
| 3 | +from sqlalchemy import create_engine |
| 4 | +from sqlalchemy.orm import sessionmaker |
| 5 | +from sqlalchemy.pool import StaticPool |
| 6 | + |
| 7 | +from sqlalchemy.orm.session import Session |
| 8 | + |
| 9 | +import db.models as db_models |
| 10 | +from repository.finding import FindingRepository, get_finding_repository |
| 11 | +from repository.recommendation import ( |
| 12 | + RecommendationRepository, |
| 13 | + get_recommendation_repository, |
| 14 | +) |
| 15 | +from repository.task import TaskRepository, get_task_repository |
| 16 | +from app import app |
| 17 | + |
| 18 | +SQLALCHEMY_DATABASE_URL = "sqlite://" |
| 19 | + |
| 20 | +engine = create_engine( |
| 21 | + SQLALCHEMY_DATABASE_URL, |
| 22 | + connect_args={"check_same_thread": False}, |
| 23 | + poolclass=StaticPool, |
| 24 | +) |
| 25 | +TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) |
| 26 | + |
| 27 | + |
| 28 | +db_models.BaseModel.metadata.create_all(bind=engine) |
| 29 | + |
| 30 | + |
| 31 | +def override_get_db(): |
| 32 | + try: |
| 33 | + db = TestingSessionLocal() |
| 34 | + yield db |
| 35 | + finally: |
| 36 | + db.close() |
| 37 | + |
| 38 | + |
| 39 | +def override_get_task_repository(session: Session = Depends(override_get_db)): |
| 40 | + return TaskRepository(session) |
| 41 | + |
| 42 | + |
| 43 | +def override_get_finding_repository(session: Session = Depends(override_get_db)): |
| 44 | + return FindingRepository(session) |
| 45 | + |
| 46 | + |
| 47 | +def override_get_recommendation_repository(session: Session = Depends(override_get_db)): |
| 48 | + return RecommendationRepository(session) |
| 49 | + |
| 50 | + |
| 51 | +app.dependency_overrides[get_task_repository] = override_get_task_repository |
| 52 | +app.dependency_overrides[get_finding_repository] = override_get_finding_repository |
| 53 | +app.dependency_overrides[get_recommendation_repository] = ( |
| 54 | + override_get_recommendation_repository |
| 55 | +) |
| 56 | +client = TestClient(app) |
| 57 | + |
| 58 | + |
| 59 | +def test_create_get_task_integration(): |
| 60 | + |
| 61 | + with TestingSessionLocal() as session: |
| 62 | + task_repo = TaskRepository(session=session) |
| 63 | + task = task_repo.create_task() |
| 64 | + task_repo.get_task_by_id |
| 65 | + |
| 66 | + response = client.get( |
| 67 | + "tasks/", |
| 68 | + ) |
| 69 | + assert response.status_code == 200 |
| 70 | + assert len(response.json()) == 1 |
0 commit comments