Add shopping cart and related models

This commit is contained in:
2025-11-11 12:08:06 +01:00
parent b3166811e5
commit f4618f4d05
8 changed files with 148 additions and 41 deletions

View File

@@ -10,8 +10,7 @@ templates = get_jinja_renderer()
@root_router.get("/") @root_router.get("/")
async def landing_page(request: Request, user_info: UserDep, session: SessionDep): async def landing_page(request: Request, user: UserDep, session: SessionDep):
user = ensure_user(user_info, session)
print(f"User {user.username} ({user.display_name}) accessed landing page") print(f"User {user.username} ({user.display_name}) accessed landing page")
transactions = [] transactions = []
for account in user.accounts: for account in user.accounts:

View File

@@ -4,22 +4,8 @@ from typing import Annotated
from fastapi import Depends, HTTPException, Request from fastapi import Depends, HTTPException, Request
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from allmende_payment_system.database import SessionLocal from allmende_payment_system.database import SessionLocal, ensure_user
from allmende_payment_system.models import User
async def get_user(request: Request) -> dict:
if username := os.environ.get("APS_username", None):
return {
"username": username,
"display_name": os.environ.get("APS_display_name", "Missing Display Name"),
}
if "ynh_user" not in request.headers:
raise HTTPException(status_code=401, detail="Missing ynh_user header")
return {"username": request.headers["ynh_user"]}
UserDep = Annotated[dict, Depends(get_user)]
def get_session() -> Session: def get_session() -> Session:
@@ -31,3 +17,27 @@ def get_session() -> Session:
SessionDep = Annotated[Session, Depends(get_session)] SessionDep = Annotated[Session, Depends(get_session)]
async def get_user(request: Request) -> dict:
if username := os.environ.get("APS_username", None):
return {
"username": username,
"display_name": os.environ.get("APS_display_name", "Missing Display Name"),
}
if "ynh_user" not in request.headers:
raise HTTPException(status_code=401, detail="Missing ynh_user header")
return {"username": request.headers["ynh_user"]}
async def get_user_object(request: Request, session: SessionDep) -> User:
user_info = await get_user(request)
user = ensure_user(user_info, session)
request.state.user = user
return user
UserDep = Annotated[dict, Depends(get_user_object)]

View File

@@ -4,11 +4,11 @@ from fastapi import Depends, FastAPI
from fastapi.staticfiles import StaticFiles from fastapi.staticfiles import StaticFiles
from allmende_payment_system.api import root_router from allmende_payment_system.api import root_router
from allmende_payment_system.api.dependencies import get_user from allmende_payment_system.api.dependencies import get_user_object
from allmende_payment_system.api.shop import shop_router from allmende_payment_system.api.shop import shop_router
locale.setlocale(locale.LC_ALL, "de_DE.UTF-8") locale.setlocale(locale.LC_ALL, "de_DE.UTF-8")
app = FastAPI(dependencies=[Depends(get_user)]) app = FastAPI(dependencies=[Depends(get_user_object)])
app.mount( app.mount(

View File

@@ -2,8 +2,15 @@ import datetime
import decimal import decimal
import typing import typing
from sqlalchemy import Column, ForeignKey, Numeric, Table from sqlalchemy import Column, ForeignKey, Numeric, Table, select
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship from sqlalchemy.exc import NoResultFound
from sqlalchemy.orm import (
DeclarativeBase,
Mapped,
mapped_column,
object_session,
relationship,
)
TABLE_PREFIX = "aps_" TABLE_PREFIX = "aps_"
@@ -46,6 +53,20 @@ class User(Base):
accounts: Mapped[list["Account"]] = relationship( accounts: Mapped[list["Account"]] = relationship(
"Account", secondary=user_account_association, back_populates="users" "Account", secondary=user_account_association, back_populates="users"
) )
orders: Mapped[list["Order"]] = relationship("Order", back_populates="user")
@property
def shopping_cart(self):
for order in self.orders:
if order.account_id is None:
cart = order
break
else:
cart = Order(user=self)
session = object_session(self)
session.add(cart)
return cart
class Area(Base): class Area(Base):
@@ -81,6 +102,40 @@ class Product(Base):
image_path: Mapped[str] = mapped_column(nullable=True) image_path: Mapped[str] = mapped_column(nullable=True)
class Order(Base):
__tablename__ = TABLE_PREFIX + "order"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
user_id: Mapped[int] = mapped_column(ForeignKey(TABLE_PREFIX + "user.id"))
user: Mapped[User] = relationship("User", back_populates="orders")
account_id: Mapped[int] = mapped_column(
ForeignKey(TABLE_PREFIX + "account.id"), nullable=True
)
account: Mapped[Account | None] = relationship("Account")
items: Mapped[list["OrderItem"]] = relationship(
"OrderItem", cascade="all, delete-orphan", back_populates="order"
)
@property
def is_in_shopping_cart(self):
return self.account is None
class OrderItem(Base):
__tablename__ = TABLE_PREFIX + "order_item"
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
order_id: Mapped[int] = mapped_column(ForeignKey(TABLE_PREFIX + "order.id"))
order: Mapped[Order] = relationship("Order", back_populates="items")
product_id: Mapped[int] = mapped_column(ForeignKey(TABLE_PREFIX + "product.id"))
product: Mapped[Product] = relationship("Product")
quantity: Mapped[int] = mapped_column(nullable=False)
total_amount: Mapped[decimal.Decimal] = mapped_column(
Numeric(10, 2), nullable=False
)
TransactionTypes = typing.Literal[ TransactionTypes = typing.Literal[
"product", "product",
"deposit", "deposit",

View File

@@ -40,6 +40,21 @@
</a> </a>
</li> </li>
</ul> </ul>
<!-- Shopping Cart at Bottom -->
<div class="mt-auto pt-3 border-top">
<a href="/cart" class="btn btn-primary w-100 position-relative">
<svg xmlns="http://www.w3.org/2000/svg" width="20" height="20" fill="currentColor" class="bi bi-cart3 me-2" viewBox="0 0 16 16">
<path d="M0 1.5A.5.5 0 0 1 .5 1H2a.5.5 0 0 1 .485.379L2.89 3H14.5a.5.5 0 0 1 .49.598l-1 5a.5.5 0 0 1-.465.401l-9.397.472L4.415 11H13a.5.5 0 0 1 0 1H4a.5.5 0 0 1-.491-.408L2.01 3.607 1.61 2H.5a.5.5 0 0 1-.5-.5M3.102 4l.84 4.479 9.144-.459L13.89 4zM5 12a2 2 0 1 0 0 4 2 2 0 0 0 0-4m7 0a2 2 0 1 0 0 4 2 2 0 0 0 0-4m-7 1a1 1 0 1 1 0 2 1 1 0 0 1 0-2m7 0a1 1 0 1 1 0 2 1 1 0 0 1 0-2"/>
</svg>
Warenkorb
<span class="position-absolute top-0 start-100 translate-middle badge rounded-pill bg-danger">
{{ request.state.user.shopping_cart.items|length }}
<span class="visually-hidden">Artikel im Warenkorb</span>
</span>
</a>
</div>
</div> </div>
</div> </div>

View File

@@ -1,14 +1,38 @@
import os import os
from unittest import mock
import pytest import pytest
from fastapi.testclient import TestClient from fastapi.testclient import TestClient
from sqlalchemy import create_engine from sqlalchemy import StaticPool, create_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from allmende_payment_system.api.dependencies import get_session
from allmende_payment_system.app import app from allmende_payment_system.app import app
from allmende_payment_system.models import Base from allmende_payment_system.models import Base
def make_db():
engine = create_engine(
"sqlite:///:memory:",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
Base.metadata.create_all(bind=engine) # Create tables
return sessionmaker(autocommit=False, autoflush=False, bind=engine)
def make_in_memory_session():
db = make_db()
session = db()
try:
yield session
finally:
session.close()
app.dependency_overrides[get_session] = make_in_memory_session
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def client(): def client():
os.environ["APS_username"] = "test" os.environ["APS_username"] = "test"
@@ -23,13 +47,5 @@ def unauthorized_client():
@pytest.fixture @pytest.fixture
def test_db(): def test_db():
engine = create_engine("sqlite:///:memory:") db = make_db()
Base.metadata.create_all(bind=engine) # Create tables return db()
TestingSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Provide a session and the engine
db = TestingSessionLocal()
try:
yield db
finally:
db.close()

View File

@@ -1,11 +1,11 @@
from allmende_payment_system.models import Account from allmende_payment_system.models import Account
def test_unauthorized_access(unauthorized_client): def test_unauthorized_access(unauthorized_client, test_db):
response = unauthorized_client.get("/") response = unauthorized_client.get("/")
assert response.status_code == 401 assert response.status_code == 401
def test_authorized_access(client): def test_authorized_access(client, test_db):
response = client.get("/") response = client.get("/")
assert response.status_code == 200 assert response.status_code == 200

View File

@@ -1,16 +1,28 @@
import pytest
from allmende_payment_system.models import Account, User from allmende_payment_system.models import Account, User
def test_user_model(test_db): @pytest.fixture(scope="function")
def test_user(test_db):
user = User(username="test", display_name="Test User") user = User(username="test", display_name="Test User")
test_db.add(user) test_db.add(user)
test_db.commit() test_db.flush()
return user
assert user.id is not None
def test_user_model(test_db, test_user):
assert test_user.id is not None
account = Account(name="Test Account") account = Account(name="Test Account")
account.users.append(user) account.users.append(test_user)
test_db.add(account) test_db.add(account)
test_db.commit() test_db.flush()
assert len(user.accounts) == 1 assert len(test_user.accounts) == 1
def test_user_shopping_cart_new(test_db, test_user):
cart = test_user.shopping_cart
assert len(cart.items) == 0