diff --git a/dev-server.sh b/dev-server.sh new file mode 100644 index 0000000..2c8e0c2 --- /dev/null +++ b/dev-server.sh @@ -0,0 +1,4 @@ +if [ -z "${APS_username}" ]; then + export APS_username="testuser" +fi +fastapi dev src/allmende_payment_system/app.py \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 7acb679..e019bce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,9 +22,10 @@ build-backend = "uv_build" [dependency-groups] dev = [ "black>=25.9.0", + "httpx>=0.28.1", "isort>=7.0.0", "pytest>=8.4.2", ] [tool.isort] -profile = "black" \ No newline at end of file +profile = "black" diff --git a/src/allmende_payment_system/app.py b/src/allmende_payment_system/app.py index 946fbac..62135b1 100644 --- a/src/allmende_payment_system/app.py +++ b/src/allmende_payment_system/app.py @@ -1,9 +1,25 @@ -# backend/app/main.py -from fastapi import FastAPI, Request +import os +from typing import Annotated + +from fastapi import Depends, FastAPI, HTTPException, Request from fastapi.staticfiles import StaticFiles from fastapi.templating import Jinja2Templates -app = FastAPI() + +async def get_user(request: Request) -> dict: + + if username := os.environ.get("APS_username", None): + return {"username": username} + if "ynh_user" not in request.headers: + raise HTTPException(status_code=401, detail="Missing ynh_user header") + return {"username": request.headers["ynh_user"]} + + +UserDep = Annotated[dict, Depends(get_user)] + + +app = FastAPI(dependencies=[Depends(get_user)]) + templates = Jinja2Templates(directory="src/allmende_payment_system/templates") app.mount( "/static", diff --git a/test/conftest.py b/test/conftest.py new file mode 100644 index 0000000..15dff24 --- /dev/null +++ b/test/conftest.py @@ -0,0 +1,30 @@ +import pytest +from fastapi import Request +from fastapi.testclient import TestClient + +from allmende_payment_system.app import create_app + + +@pytest.fixture(scope="session") +def client(): + app = create_app() + + async def add_ynh_headers(request: Request, call_next): + username = request.headers.get("APS-TEST-username", "test") + # This seems to work although headers are immutable + # If this ever turns out to be a problem, we can use request.state instead, + # but will have to modify app.get_user + request.headers._list.append((b"ynh_user", username.encode("utf-8"))) + + response = await call_next(request) + return response + + app.middleware("http")(add_ynh_headers) + + return TestClient(app) + + +@pytest.fixture(scope="session") +def unauthorized_client(): + app = create_app() + return TestClient(app) diff --git a/test/test_auth.py b/test/test_auth.py new file mode 100644 index 0000000..dae3e81 --- /dev/null +++ b/test/test_auth.py @@ -0,0 +1,12 @@ +from allmende_payment_system.models import Account + + +def test_unauthorized_access(unauthorized_client): + response = unauthorized_client.get("/") + assert response.status_code == 401 + + +def test_unauthorized_access(unauthorized_client): + response = unauthorized_client.get("/") + print(response.text) + assert response.status_code == 401 diff --git a/test/test_models.py b/test/test_models.py index ea4a391..2431d44 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -1,8 +1,9 @@ -# tests/conftest.py import pytest from sqlalchemy import create_engine -from sqlalchemy.orm import sessionmaker, declarative_base -from allmende_payment_system.models import Base, User +from sqlalchemy.orm import sessionmaker + +from allmende_payment_system.models import Account, Base, User + # Create an in-memory SQLite database @pytest.fixture @@ -19,9 +20,16 @@ def in_memory_db(): db.close() -def test_create_user(in_memory_db): +def test_user_model(in_memory_db): user = User(username="test", display_name="Test User") in_memory_db.add(user) in_memory_db.commit() - assert user.id is not None \ No newline at end of file + assert user.id is not None + + account = Account(name="Test Account") + account.users.append(user) + in_memory_db.add(account) + in_memory_db.commit() + + assert len(user.accounts) == 1 diff --git a/uv.lock b/uv.lock index 22409b6..ec7436f 100644 --- a/uv.lock +++ b/uv.lock @@ -14,6 +14,7 @@ dependencies = [ [package.dev-dependencies] dev = [ { name = "black" }, + { name = "httpx" }, { name = "isort" }, { name = "pytest" }, ] @@ -27,6 +28,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ { name = "black", specifier = ">=25.9.0" }, + { name = "httpx", specifier = ">=0.28.1" }, { name = "isort", specifier = ">=7.0.0" }, { name = "pytest", specifier = ">=8.4.2" }, ]