Fix subscription bug and add auth for event creation

This commit is contained in:
2025-10-14 21:13:31 +02:00
parent 7980a112a3
commit 1f0a27f3af

View File

@@ -2,6 +2,7 @@ import locale
import os
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from functools import partial
from typing import Annotated
import starlette.status as status
@@ -34,15 +35,35 @@ def get_session():
with Session(engine) as session:
yield session
def get_user(request: Request) -> dict | None:
def get_user(request: Request, allow_none: bool=True) -> dict | None:
"""
Retrieve user information from the incoming request.
This function attempts to extract user information from the request headers set by ssowat.
If allow_none is set to `True`, then a `None` value will be returned if no user information is found, else
an exception will be raised, resulting in a 401 Unauthorized response.
Used in UserDep and StrictUserDep.
:param request: The incoming HTTP request containing headers and other context
information.
:param allow_none: Flag indicating whether returning `None` is permitted when
user information is not available.
:return: A dictionary containing the username if found, or `None` if no user
information is available and `allow_none` is `True`.
:raises HTTPException: If user information is not found and `allow_none` is
`False`.
"""
if os.environ.get("MEAL_MANAGER_FAKE_USER", False):
return {"username": "fake_user"}
if "ynh_user" in request.headers:
return {
"username": request.headers["ynh_user"],
}
else:
if allow_none:
return None
else:
raise HTTPException(status_code=401, detail="Not logged in")
def create_db_and_tables():
Base.metadata.create_all(engine)
@@ -62,6 +83,7 @@ templates = Jinja2Templates(directory="src/meal_manager/templates")
SessionDep = Annotated[Session, Depends(get_session)]
UserDep = Annotated[dict, Depends(get_user)]
StrictUserDep = Annotated[dict, Depends(partial(get_user, allow_none=False))]
@app.get("/")
async def index(request: Request, session: SessionDep, user : UserDep):
"""Displays coming events and a button to register new ones"""
@@ -102,7 +124,7 @@ async def subscribe(request: Request, session: SessionDep):
statement = select(Household)
households = session.scalars(statement)
subscriptions = session.scalars(select(Subscription))
subscriptions = session.scalars(select(Subscription)).all()
# filter out households with existing subscriptions
households = [
@@ -165,14 +187,12 @@ async def delete_subscription(request: Request, session: SessionDep, household_i
@app.get("/event/add")
async def add_event_form(request: Request, user: UserDep):
if not user:
raise HTTPException(status_code=401, detail="Only allowed for logged in users")
async def add_event_form(request: Request, user: StrictUserDep):
return templates.TemplateResponse(request=request, name="add_event.html")
@app.post("/event/add")
async def add_event(request: Request, session: SessionDep):
async def add_event(request: Request, session: SessionDep, user: StrictUserDep):
form_data = await request.form()
event_time = datetime.fromisoformat(form_data["eventTime"])