Fix subscription bug and add auth for event creation

This commit is contained in:
2025-10-14 21:13:31 +02:00
parent 7980a112a3
commit 1f0a27f3af

View File

@@ -2,6 +2,7 @@ import locale
import os import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import datetime, timedelta from datetime import datetime, timedelta
from functools import partial
from typing import Annotated from typing import Annotated
import starlette.status as status import starlette.status as status
@@ -34,15 +35,35 @@ def get_session():
with Session(engine) as session: with Session(engine) as session:
yield session yield session
def get_user(request: Request) -> dict | None: def get_user(request: Request, allow_none: bool=True) -> dict | None:
"""
Retrieve user information from the incoming request.
This function attempts to extract user information from the request headers set by ssowat.
If allow_none is set to `True`, then a `None` value will be returned if no user information is found, else
an exception will be raised, resulting in a 401 Unauthorized response.
Used in UserDep and StrictUserDep.
:param request: The incoming HTTP request containing headers and other context
information.
:param allow_none: Flag indicating whether returning `None` is permitted when
user information is not available.
:return: A dictionary containing the username if found, or `None` if no user
information is available and `allow_none` is `True`.
:raises HTTPException: If user information is not found and `allow_none` is
`False`.
"""
if os.environ.get("MEAL_MANAGER_FAKE_USER", False): if os.environ.get("MEAL_MANAGER_FAKE_USER", False):
return {"username": "fake_user"} return {"username": "fake_user"}
if "ynh_user" in request.headers: if "ynh_user" in request.headers:
return { return {
"username": request.headers["ynh_user"], "username": request.headers["ynh_user"],
} }
else: if allow_none:
return None return None
else:
raise HTTPException(status_code=401, detail="Not logged in")
def create_db_and_tables(): def create_db_and_tables():
Base.metadata.create_all(engine) Base.metadata.create_all(engine)
@@ -62,6 +83,7 @@ templates = Jinja2Templates(directory="src/meal_manager/templates")
SessionDep = Annotated[Session, Depends(get_session)] SessionDep = Annotated[Session, Depends(get_session)]
UserDep = Annotated[dict, Depends(get_user)] UserDep = Annotated[dict, Depends(get_user)]
StrictUserDep = Annotated[dict, Depends(partial(get_user, allow_none=False))]
@app.get("/") @app.get("/")
async def index(request: Request, session: SessionDep, user : UserDep): async def index(request: Request, session: SessionDep, user : UserDep):
"""Displays coming events and a button to register new ones""" """Displays coming events and a button to register new ones"""
@@ -102,7 +124,7 @@ async def subscribe(request: Request, session: SessionDep):
statement = select(Household) statement = select(Household)
households = session.scalars(statement) households = session.scalars(statement)
subscriptions = session.scalars(select(Subscription)) subscriptions = session.scalars(select(Subscription)).all()
# filter out households with existing subscriptions # filter out households with existing subscriptions
households = [ households = [
@@ -165,14 +187,12 @@ async def delete_subscription(request: Request, session: SessionDep, household_i
@app.get("/event/add") @app.get("/event/add")
async def add_event_form(request: Request, user: UserDep): async def add_event_form(request: Request, user: StrictUserDep):
if not user:
raise HTTPException(status_code=401, detail="Only allowed for logged in users")
return templates.TemplateResponse(request=request, name="add_event.html") return templates.TemplateResponse(request=request, name="add_event.html")
@app.post("/event/add") @app.post("/event/add")
async def add_event(request: Request, session: SessionDep): async def add_event(request: Request, session: SessionDep, user: StrictUserDep):
form_data = await request.form() form_data = await request.form()
event_time = datetime.fromisoformat(form_data["eventTime"]) event_time = datetime.fromisoformat(form_data["eventTime"])