From 1f0a27f3af2cc4971b74f0abe010ee65d9099cbc Mon Sep 17 00:00:00 2001 From: Niklas Meinzer Date: Tue, 14 Oct 2025 21:13:31 +0200 Subject: [PATCH] Fix subscription bug and add auth for event creation --- src/meal_manager/main.py | 34 +++++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/src/meal_manager/main.py b/src/meal_manager/main.py index 8df3280..9b987a1 100644 --- a/src/meal_manager/main.py +++ b/src/meal_manager/main.py @@ -2,6 +2,7 @@ import locale import os from contextlib import asynccontextmanager from datetime import datetime, timedelta +from functools import partial from typing import Annotated import starlette.status as status @@ -34,15 +35,35 @@ def get_session(): with Session(engine) as session: yield session -def get_user(request: Request) -> dict | None: +def get_user(request: Request, allow_none: bool=True) -> dict | None: + """ + Retrieve user information from the incoming request. + + This function attempts to extract user information from the request headers set by ssowat. + If allow_none is set to `True`, then a `None` value will be returned if no user information is found, else + an exception will be raised, resulting in a 401 Unauthorized response. + + Used in UserDep and StrictUserDep. + + :param request: The incoming HTTP request containing headers and other context + information. + :param allow_none: Flag indicating whether returning `None` is permitted when + user information is not available. + :return: A dictionary containing the username if found, or `None` if no user + information is available and `allow_none` is `True`. + :raises HTTPException: If user information is not found and `allow_none` is + `False`. + """ if os.environ.get("MEAL_MANAGER_FAKE_USER", False): return {"username": "fake_user"} if "ynh_user" in request.headers: return { "username": request.headers["ynh_user"], } - else: + if allow_none: return None + else: + raise HTTPException(status_code=401, detail="Not logged in") def create_db_and_tables(): Base.metadata.create_all(engine) @@ -62,6 +83,7 @@ templates = Jinja2Templates(directory="src/meal_manager/templates") SessionDep = Annotated[Session, Depends(get_session)] UserDep = Annotated[dict, Depends(get_user)] +StrictUserDep = Annotated[dict, Depends(partial(get_user, allow_none=False))] @app.get("/") async def index(request: Request, session: SessionDep, user : UserDep): """Displays coming events and a button to register new ones""" @@ -102,7 +124,7 @@ async def subscribe(request: Request, session: SessionDep): statement = select(Household) households = session.scalars(statement) - subscriptions = session.scalars(select(Subscription)) + subscriptions = session.scalars(select(Subscription)).all() # filter out households with existing subscriptions households = [ @@ -165,14 +187,12 @@ async def delete_subscription(request: Request, session: SessionDep, household_i @app.get("/event/add") -async def add_event_form(request: Request, user: UserDep): - if not user: - raise HTTPException(status_code=401, detail="Only allowed for logged in users") +async def add_event_form(request: Request, user: StrictUserDep): return templates.TemplateResponse(request=request, name="add_event.html") @app.post("/event/add") -async def add_event(request: Request, session: SessionDep): +async def add_event(request: Request, session: SessionDep, user: StrictUserDep): form_data = await request.form() event_time = datetime.fromisoformat(form_data["eventTime"])