diff --git a/src/allmende_payment_system/api/__init__.py b/src/allmende_payment_system/api/__init__.py index 6fca777..3f0a390 100644 --- a/src/allmende_payment_system/api/__init__.py +++ b/src/allmende_payment_system/api/__init__.py @@ -10,8 +10,7 @@ templates = get_jinja_renderer() @root_router.get("/") -async def landing_page(request: Request, user_info: UserDep, session: SessionDep): - user = ensure_user(user_info, session) +async def landing_page(request: Request, user: UserDep, session: SessionDep): print(f"User {user.username} ({user.display_name}) accessed landing page") transactions = [] for account in user.accounts: diff --git a/src/allmende_payment_system/api/dependencies.py b/src/allmende_payment_system/api/dependencies.py index 80a3b12..67cf694 100644 --- a/src/allmende_payment_system/api/dependencies.py +++ b/src/allmende_payment_system/api/dependencies.py @@ -4,22 +4,8 @@ from typing import Annotated from fastapi import Depends, HTTPException, Request from sqlalchemy.orm import Session -from allmende_payment_system.database import SessionLocal - - -async def get_user(request: Request) -> dict: - - if username := os.environ.get("APS_username", None): - return { - "username": username, - "display_name": os.environ.get("APS_display_name", "Missing Display Name"), - } - if "ynh_user" not in request.headers: - raise HTTPException(status_code=401, detail="Missing ynh_user header") - return {"username": request.headers["ynh_user"]} - - -UserDep = Annotated[dict, Depends(get_user)] +from allmende_payment_system.database import SessionLocal, ensure_user +from allmende_payment_system.models import User def get_session() -> Session: @@ -31,3 +17,27 @@ def get_session() -> Session: SessionDep = Annotated[Session, Depends(get_session)] + + +async def get_user(request: Request) -> dict: + + if username := os.environ.get("APS_username", None): + return { + "username": username, + "display_name": os.environ.get("APS_display_name", "Missing Display Name"), + } + if "ynh_user" not in request.headers: + raise HTTPException(status_code=401, detail="Missing ynh_user header") + + return {"username": request.headers["ynh_user"]} + + +async def get_user_object(request: Request, session: SessionDep) -> User: + + user_info = await get_user(request) + user = ensure_user(user_info, session) + request.state.user = user + return user + + +UserDep = Annotated[dict, Depends(get_user_object)] diff --git a/src/allmende_payment_system/app.py b/src/allmende_payment_system/app.py index d7ae837..adcbbff 100644 --- a/src/allmende_payment_system/app.py +++ b/src/allmende_payment_system/app.py @@ -4,11 +4,11 @@ from fastapi import Depends, FastAPI from fastapi.staticfiles import StaticFiles from allmende_payment_system.api import root_router -from allmende_payment_system.api.dependencies import get_user +from allmende_payment_system.api.dependencies import get_user_object from allmende_payment_system.api.shop import shop_router locale.setlocale(locale.LC_ALL, "de_DE.UTF-8") -app = FastAPI(dependencies=[Depends(get_user)]) +app = FastAPI(dependencies=[Depends(get_user_object)]) app.mount( diff --git a/src/allmende_payment_system/models.py b/src/allmende_payment_system/models.py index 9400be1..3d19a8b 100644 --- a/src/allmende_payment_system/models.py +++ b/src/allmende_payment_system/models.py @@ -2,8 +2,15 @@ import datetime import decimal import typing -from sqlalchemy import Column, ForeignKey, Numeric, Table -from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship +from sqlalchemy import Column, ForeignKey, Numeric, Table, select +from sqlalchemy.exc import NoResultFound +from sqlalchemy.orm import ( + DeclarativeBase, + Mapped, + mapped_column, + object_session, + relationship, +) TABLE_PREFIX = "aps_" @@ -46,6 +53,20 @@ class User(Base): accounts: Mapped[list["Account"]] = relationship( "Account", secondary=user_account_association, back_populates="users" ) + orders: Mapped[list["Order"]] = relationship("Order", back_populates="user") + + @property + def shopping_cart(self): + for order in self.orders: + if order.account_id is None: + cart = order + break + else: + cart = Order(user=self) + session = object_session(self) + session.add(cart) + + return cart class Area(Base): @@ -81,6 +102,40 @@ class Product(Base): image_path: Mapped[str] = mapped_column(nullable=True) +class Order(Base): + __tablename__ = TABLE_PREFIX + "order" + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + user_id: Mapped[int] = mapped_column(ForeignKey(TABLE_PREFIX + "user.id")) + user: Mapped[User] = relationship("User", back_populates="orders") + + account_id: Mapped[int] = mapped_column( + ForeignKey(TABLE_PREFIX + "account.id"), nullable=True + ) + account: Mapped[Account | None] = relationship("Account") + + items: Mapped[list["OrderItem"]] = relationship( + "OrderItem", cascade="all, delete-orphan", back_populates="order" + ) + + @property + def is_in_shopping_cart(self): + return self.account is None + + +class OrderItem(Base): + __tablename__ = TABLE_PREFIX + "order_item" + id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) + order_id: Mapped[int] = mapped_column(ForeignKey(TABLE_PREFIX + "order.id")) + order: Mapped[Order] = relationship("Order", back_populates="items") + + product_id: Mapped[int] = mapped_column(ForeignKey(TABLE_PREFIX + "product.id")) + product: Mapped[Product] = relationship("Product") + quantity: Mapped[int] = mapped_column(nullable=False) + total_amount: Mapped[decimal.Decimal] = mapped_column( + Numeric(10, 2), nullable=False + ) + + TransactionTypes = typing.Literal[ "product", "deposit", diff --git a/src/allmende_payment_system/templates/base.html.jinja b/src/allmende_payment_system/templates/base.html.jinja index 2cf7801..466d7ac 100644 --- a/src/allmende_payment_system/templates/base.html.jinja +++ b/src/allmende_payment_system/templates/base.html.jinja @@ -40,6 +40,21 @@ + + +
+ diff --git a/test/conftest.py b/test/conftest.py index 95d8be6..71b3596 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -1,14 +1,38 @@ import os +from unittest import mock import pytest from fastapi.testclient import TestClient -from sqlalchemy import create_engine +from sqlalchemy import StaticPool, create_engine from sqlalchemy.orm import sessionmaker +from allmende_payment_system.api.dependencies import get_session from allmende_payment_system.app import app from allmende_payment_system.models import Base +def make_db(): + engine = create_engine( + "sqlite:///:memory:", + connect_args={"check_same_thread": False}, + poolclass=StaticPool, + ) + Base.metadata.create_all(bind=engine) # Create tables + return sessionmaker(autocommit=False, autoflush=False, bind=engine) + + +def make_in_memory_session(): + db = make_db() + session = db() + try: + yield session + finally: + session.close() + + +app.dependency_overrides[get_session] = make_in_memory_session + + @pytest.fixture(scope="session") def client(): os.environ["APS_username"] = "test" @@ -23,13 +47,5 @@ def unauthorized_client(): @pytest.fixture def test_db(): - engine = create_engine("sqlite:///:memory:") - Base.metadata.create_all(bind=engine) # Create tables - TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) - - # Provide a session and the engine - db = TestingSessionLocal() - try: - yield db - finally: - db.close() + db = make_db() + return db() diff --git a/test/test_auth.py b/test/test_auth.py index b5d9625..eac0b3d 100644 --- a/test/test_auth.py +++ b/test/test_auth.py @@ -1,11 +1,11 @@ from allmende_payment_system.models import Account -def test_unauthorized_access(unauthorized_client): +def test_unauthorized_access(unauthorized_client, test_db): response = unauthorized_client.get("/") assert response.status_code == 401 -def test_authorized_access(client): +def test_authorized_access(client, test_db): response = client.get("/") assert response.status_code == 200 diff --git a/test/test_models.py b/test/test_models.py index d5825e2..246986b 100644 --- a/test/test_models.py +++ b/test/test_models.py @@ -1,16 +1,28 @@ +import pytest + from allmende_payment_system.models import Account, User -def test_user_model(test_db): +@pytest.fixture(scope="function") +def test_user(test_db): user = User(username="test", display_name="Test User") test_db.add(user) - test_db.commit() + test_db.flush() + return user - assert user.id is not None + +def test_user_model(test_db, test_user): + assert test_user.id is not None account = Account(name="Test Account") - account.users.append(user) + account.users.append(test_user) test_db.add(account) - test_db.commit() + test_db.flush() - assert len(user.accounts) == 1 + assert len(test_user.accounts) == 1 + + +def test_user_shopping_cart_new(test_db, test_user): + cart = test_user.shopping_cart + + assert len(cart.items) == 0