import locale
import os
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from functools import partial
from typing import Annotated

import starlette.status as status
from fastapi import Depends, FastAPI, HTTPException, Request, Response
from fastapi.responses import FileResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from sqlalchemy import create_engine, select
from sqlalchemy.orm import Session

from meal_manager.grist import sync_with_grist
from meal_manager.models import (
    Base,
    Event,
    Household,
    Registration,
    Subscription,
    TeamRegistration,
)
from meal_manager.pdf import build_dinner_overview_pdf

sqlite_file_name = "database.db"
sqlite_url = f"sqlite:///{sqlite_file_name}"

# Sessions may be created and used on different threads (FastAPI runs sync
# dependencies in a threadpool), so SQLite's same-thread check is disabled.
connect_args = {"check_same_thread": False}
engine = create_engine(sqlite_url, connect_args=connect_args)

# Process-wide German locale, presumably for locale-aware date formatting.
# NOTE: raises locale.Error at import time if de_DE.UTF-8 is not installed.
locale.setlocale(locale.LC_ALL, "de_DE.UTF-8")


def get_session():
    """Yield a database session that is closed when the request is done."""
    with Session(engine) as session:
        yield session


def get_user(request: Request, allow_none: bool = True) -> dict | None:
    """
    Retrieve user information from the incoming request.

    This function attempts to extract user information from the request headers
    set by ssowat. If allow_none is set to `True`, then a `None` value will be
    returned if no user information is found, else an exception will be raised,
    resulting in a 401 Unauthorized response.

    Used (via thin wrappers) in UserDep and StrictUserDep.

    :param request: The incoming HTTP request containing headers and other
        context information.
    :param allow_none: Flag indicating whether returning `None` is permitted
        when user information is not available.
    :return: A dictionary containing the username if found, or `None` if no
        user information is available and `allow_none` is `True`.
    :raises HTTPException: If user information is not found and `allow_none`
        is `False`.
    """
    # Development override: any non-empty value logs in "fake_user";
    # the value "admin" additionally grants admin rights.
    if fake_user := os.environ.get("MEAL_MANAGER_FAKE_USER", False):
        return {"username": "fake_user", "admin": fake_user == "admin"}
    if "ynh_user" in request.headers:
        return {
            "username": request.headers["ynh_user"],
            # TODO: This should obviously be replaced with a role based check
            "admin": request.headers["ynh_user"] == "niklas.m",
        }
    if allow_none:
        return None
    raise HTTPException(status_code=401, detail="Not logged in")


def _optional_user(request: Request) -> dict | None:
    """Dependency: the logged-in user, or ``None`` for anonymous visitors."""
    return get_user(request, allow_none=True)


def _required_user(request: Request) -> dict:
    """Dependency: the logged-in user; anonymous visitors get a 401."""
    return get_user(request, allow_none=False)


def create_db_and_tables():
    """Create all tables declared on ``Base`` that do not exist yet."""
    Base.metadata.create_all(engine)


@asynccontextmanager
async def on_startup(app_: FastAPI):
    """Application lifespan: make sure the database schema exists."""
    create_db_and_tables()
    yield


app = FastAPI(lifespan=on_startup)
app.mount("/static", StaticFiles(directory="src/meal_manager/static"), name="static")
templates = Jinja2Templates(directory="src/meal_manager/templates")

SessionDep = Annotated[Session, Depends(get_session)]
# Dedicated wrappers are used instead of ``Depends(get_user)`` /
# ``Depends(partial(get_user, allow_none=False))``. FastAPI turns every
# parameter of a dependency's signature into a request parameter, so
# ``allow_none`` would become a client-controlled query parameter, and
# ``?allow_none=true`` would bypass the login requirement.
UserDep = Annotated[dict | None, Depends(_optional_user)]
StrictUserDep = Annotated[dict, Depends(_required_user)]

# Subscription day flags, in the order of the "days" form values "1".."7".
_WEEKDAY_FIELDS = (
    "monday",
    "tuesday",
    "wednesday",
    "thursday",
    "friday",
    "saturday",
    "sunday",
)


def _one_or_404(session: Session, statement, detail: str = "Not found"):
    """Return the single object selected by ``statement``.

    :raises HTTPException: 404 if nothing matches, instead of letting
        SQLAlchemy's ``NoResultFound`` surface as a 500 error.
    """
    obj = session.scalars(statement).one_or_none()
    if obj is None:
        raise HTTPException(status_code=404, detail=detail)
    return obj


def _parse_meal_count(form_data, field: str) -> int:
    """Parse one meal-count form field; an empty or missing field counts as 0.

    :raises HTTPException: 400 if the value is not a non-negative integer.
    """
    raw_value = form_data.get(field)
    if not raw_value:
        return 0
    try:
        count = int(raw_value)
    except (TypeError, ValueError) as err:
        raise HTTPException(
            status_code=400, detail="All number fields must be integers"
        ) from err
    if count < 0:
        raise HTTPException(status_code=400, detail="Meal counts must not be negative")
    return count


def _parse_meal_counts(form_data) -> tuple[int, int, int]:
    """Return the (adult, children, small children) meal counts of a form."""
    return (
        _parse_meal_count(form_data, "numAdults"),
        _parse_meal_count(form_data, "numKids"),
        _parse_meal_count(form_data, "numSmallKids"),
    )


@app.get("/")
async def index(request: Request, session: SessionDep, user: UserDep):
    """Displays coming events and a button to register new ones"""
    now = datetime.now()
    # TODO: Once we refactored to use SQLAlchemy directly, we can probably do a nicer filtering on the date alone
    statement = (
        select(Event)
        .order_by(Event.event_time)
        .where(Event.event_time >= now - timedelta(days=1))
    )
    events = session.scalars(statement).all()
    return templates.TemplateResponse(
        request=request,
        name="index.html",
        context={"events": events, "current_page": "home", "now": now, "user": user},
    )


@app.get("/robots.txt")
async def robots_txt():
    """Serve the static robots.txt."""
    return FileResponse("src/meal_manager/static/robots.txt", media_type="text/plain")


@app.get("/past_events")
async def past_events(request: Request, session: SessionDep, user: UserDep):
    """Display events that took place more than a day ago."""
    now = datetime.now()
    # TODO: Once we refactored to use SQLAlchemy directly, we can probably do a nicer filtering on the date alone
    statement = (
        select(Event)
        .order_by(Event.event_time)
        .where(Event.event_time < now - timedelta(days=1))
    )
    events = session.scalars(statement).all()
    return templates.TemplateResponse(
        request=request,
        name="index.html",
        # "user" is passed like on every other page so the shared template
        # renders the same user-dependent parts here.
        context={"events": events, "current_page": "past", "now": now, "user": user},
    )


@app.get("/subscribe")
async def subscribe(request: Request, session: SessionDep, user: UserDep):
    """Show all subscriptions and a form for households without one."""
    subscriptions = session.scalars(select(Subscription)).all()
    # Build the lookup set once instead of a list per household.
    subscribed_ids = {sub.household_id for sub in subscriptions}
    households = [
        h for h in session.scalars(select(Household)) if h.id not in subscribed_ids
    ]
    return templates.TemplateResponse(
        request=request,
        name="subscribe.html",
        context={
            "households": households,
            "subscriptions": subscriptions,
            "user": user,
        },
    )


@app.post("/subscribe")
async def add_subscribe(request: Request, session: SessionDep):
    """Create a standing meal subscription for a household.

    :raises HTTPException: 400 if a meal count is not a non-negative integer.
    """
    form_data = await request.form()
    num_adult_meals, num_children_meals, num_small_children_meals = (
        _parse_meal_counts(form_data)
    )
    subscription = Subscription(
        household_id=form_data["household"],
        num_adult_meals=num_adult_meals,
        num_children_meals=num_children_meals,
        num_small_children_meals=num_small_children_meals,
    )

    # Without any selected day the model's own day defaults are kept.
    selected_days = form_data.getlist("days")
    if selected_days:
        for day_number, field_name in enumerate(_WEEKDAY_FIELDS, start=1):
            setattr(subscription, field_name, str(day_number) in selected_days)

    session.add(subscription)
    session.commit()
    return RedirectResponse(url="/subscribe", status_code=status.HTTP_302_FOUND)


@app.get("/subscribe/{household_id}/delete")
async def delete_subscription(
    request: Request, session: SessionDep, household_id: int, user: StrictUserDep
):
    """Delete a household's subscription (404 if it has none)."""
    statement = select(Subscription).where(Subscription.household_id == household_id)
    session.delete(_one_or_404(session, statement, "Subscription not found"))
    session.commit()
    return RedirectResponse(url="/subscribe", status_code=status.HTTP_302_FOUND)


@app.get("/event/add")
async def add_event_form(request: Request, user: StrictUserDep):
    """Render the empty event form."""
    return templates.TemplateResponse(
        request=request, name="add_event.html", context={"user": user}
    )


@app.get("/event/{event_id}/edit")
async def edit_event_form(
    request: Request, event_id: int, session: SessionDep, user: StrictUserDep
):
    """Render the event form pre-filled with an existing event."""
    event = _one_or_404(session, select(Event).where(Event.id == event_id), "Event not found")
    return templates.TemplateResponse(
        request=request,
        context={"event": event, "edit_mode": True, "user": user},
        name="add_event.html",
    )


@app.post("/event/{event_id}/edit")
async def edit_event(
    request: Request, event_id: int, session: SessionDep, user: StrictUserDep
):
    """Update an event from the submitted form and show it."""
    event = _one_or_404(session, select(Event).where(Event.id == event_id), "Event not found")
    form_data = await request.form()
    event_time, registration_deadline = await parse_event_times(form_data)

    event.title = form_data["eventName"]
    event.event_time = event_time
    event.registration_deadline = registration_deadline
    event.description = form_data.get("eventDescription")
    event.recipe_link = form_data.get("recipeLink")
    session.commit()
    return RedirectResponse(url=f"/event/{event.id}", status_code=status.HTTP_302_FOUND)


@app.post("/event/add")
async def add_event(request: Request, session: SessionDep, user: StrictUserDep):
    """Create a new event from the submitted form."""
    form_data = await request.form()
    event_time, registration_deadline = await parse_event_times(form_data)
    event = Event(
        title=form_data["eventName"],
        event_time=event_time,
        registration_deadline=registration_deadline,
        description=form_data.get("eventDescription"),
        recipe_link=form_data.get("recipeLink"),
    )
    session.add(event)
    session.commit()
    return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND)


async def parse_event_times(form_data):
    """Parse the event time and registration deadline of an event form.

    ``eventTime`` and, when filled in, ``registrationDeadline`` are parsed
    with ``datetime.fromisoformat``. Without an explicit deadline,
    registration closes at 19:30 on the most recent Sunday on or before the
    event day.

    :return: ``(event_time, registration_deadline)``.
    :raises HTTPException: 400 if a date is missing or cannot be parsed.
    """
    try:
        event_time = datetime.fromisoformat(form_data["eventTime"])
        raw_deadline = form_data.get("registrationDeadline")
        registration_deadline = (
            datetime.fromisoformat(raw_deadline) if raw_deadline else None
        )
    except (KeyError, TypeError, ValueError) as err:
        raise HTTPException(status_code=400, detail="Invalid or missing date") from err

    if registration_deadline is None:
        # weekday(): Monday == 0 ... Sunday == 6. Subtracting a timedelta also
        # works across month boundaries, unlike replace(day=day - 1), which
        # raised ValueError for events early in a month.
        # NOTE(review): for an event on a Sunday this yields 19:30 on the event
        # day itself, possibly after the event starts; confirm that is intended.
        days_since_sunday = (event_time.weekday() - 6) % 7
        deadline = event_time - timedelta(days=days_since_sunday)
        registration_deadline = deadline.replace(
            hour=19, minute=30, second=0, microsecond=0
        )
    return event_time, registration_deadline


@app.get("/event/{event_id}/delete")
async def delete_event(
    request: Request, session: SessionDep, event_id: int, user: StrictUserDep
):
    """Delete an event; only admins may do this (403 otherwise)."""
    if not user["admin"]:
        raise HTTPException(status_code=403, detail="Not authorized")
    event = _one_or_404(session, select(Event).where(Event.id == event_id), "Event not found")
    session.delete(event)
    session.commit()
    return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND)


@app.get("/event/{event_id}")
async def read_event(
    request: Request,
    event_id: int,
    session: SessionDep,
    user: UserDep,
    message: str | None = None,
):
    """Show an event with a registration form for not-yet-registered households.

    :param message: Optional notice shown on the page (set by redirects).
    """
    event = _one_or_404(session, select(Event).where(Event.id == event_id), "Event not found")
    # Build the lookup set once instead of a list per household.
    registered_ids = {reg.household_id for reg in event.registrations}
    households = [
        h for h in session.scalars(select(Household)) if h.id not in registered_ids
    ]
    return templates.TemplateResponse(
        request=request,
        name="event.html",
        context={
            "event": event,
            "households": households,
            "now": datetime.now(),
            "user": user,
            "message": message,
        },
    )


@app.post("/event/{event_id}/register")
async def add_registration(request: Request, event_id: int, session: SessionDep):
    """Register a household's meals for an event.

    :raises HTTPException: 400 if a meal count is not a non-negative integer.
    """
    form_data = await request.form()
    num_adult_meals, num_children_meals, num_small_children_meals = (
        _parse_meal_counts(form_data)
    )
    registration = Registration(
        household_id=form_data["household"],
        event_id=event_id,
        num_adult_meals=num_adult_meals,
        num_children_meals=num_children_meals,
        num_small_children_meals=num_small_children_meals,
        comment=form_data.get("comment", ""),
    )
    session.add(registration)
    session.commit()
    return RedirectResponse(url=f"/event/{event_id}", status_code=status.HTTP_302_FOUND)


@app.get("/event/{event_id}/registration/{household_id}/delete")
async def delete_registration(
    request: Request,
    event_id: int,
    household_id: int,
    session: SessionDep,
    user: StrictUserDep,
):
    """
    Deletes a registration record for a specific household at a given event.

    This endpoint handles the removal of the registration, commits the change to
    the database, and redirects the user to the event page. Responds with 404
    if the household is not registered for the event.
    """
    statement = select(Registration).where(
        Registration.household_id == household_id, Registration.event_id == event_id
    )
    session.delete(_one_or_404(session, statement, "Registration not found"))
    session.commit()
    return RedirectResponse(url=f"/event/{event_id}", status_code=status.HTTP_302_FOUND)


@app.post("/event/{event_id}/register_team")
async def add_team_registration(request: Request, event_id: int, session: SessionDep):
    """Sign a person up for a work type at an event.

    Blank names and duplicate sign-ups (same person, work type and event)
    are ignored; in every case the user is sent back to the event page.
    """
    form_data = await request.form()
    person = form_data["personName"].strip()
    work_type = form_data["workType"]

    if person:
        statement = select(TeamRegistration).where(
            TeamRegistration.person_name == person,
            TeamRegistration.work_type == work_type,
            TeamRegistration.event_id == event_id,
        )
        # first() instead of one_or_none(): pre-existing duplicates must not
        # turn every further sign-up attempt into a 500 error.
        if session.scalars(statement).first() is None:
            session.add(
                TeamRegistration(
                    person_name=person,
                    event_id=event_id,
                    work_type=work_type,
                )
            )
            session.commit()
    return RedirectResponse(url=f"/event/{event_id}", status_code=status.HTTP_302_FOUND)
@app.get("/event/{event_id}/register_team/{entry_id}/delete")
async def delete_team_registration(
    request: Request,
    event_id: int,
    entry_id: int,
    session: SessionDep,
    user: StrictUserDep,
):
    """Delete one team sign-up of an event and return to the event page.

    The entry must belong to ``event_id``. Otherwise, or if it does not exist,
    a 404 is returned. This prevents a crafted URL from deleting another
    event's entry.
    """
    statement = select(TeamRegistration).where(
        TeamRegistration.id == entry_id,
        TeamRegistration.event_id == event_id,
    )
    entry = session.scalars(statement).one_or_none()
    if entry is None:
        raise HTTPException(status_code=404, detail="Team registration not found")
    session.delete(entry)
    session.commit()
    return RedirectResponse(url=f"/event/{event_id}", status_code=status.HTTP_302_FOUND)


@app.get("/event/{event_id}/pdf")
def get_event_attendance_pdf(event_id: int, session: SessionDep):
    """Render the event's dinner overview as an inline PDF (404 if unknown)."""
    event = session.scalars(select(Event).where(Event.id == event_id)).one_or_none()
    if event is None:
        raise HTTPException(status_code=404, detail="Event not found")
    pdf_buffer = build_dinner_overview_pdf(event)
    headers = {
        "Content-Disposition": f"inline; filename=attendance_event_{event_id}.pdf"
    }
    return Response(
        content=pdf_buffer.getvalue(), media_type="application/pdf", headers=headers
    )


@app.get("/event/{event_id}/sync_with_grist")
def sync_with_grist_route(event_id: int, session: SessionDep, user: StrictUserDep):
    """Push the event to Grist, then show it with a success message.

    Responds with 404 if the event does not exist.
    """
    event = session.scalars(select(Event).where(Event.id == event_id)).one_or_none()
    if event is None:
        raise HTTPException(status_code=404, detail="Event not found")
    sync_with_grist(event)
    return RedirectResponse(
        url=f"/event/{event_id}?message=Erfolgreich%20an%20Abrechnung%20%C3%BCbertragen",
        status_code=status.HTTP_302_FOUND,
    )