Files
allmende-essen/src/meal_manager/main.py

455 lines
15 KiB
Python

import locale
import os
from contextlib import asynccontextmanager
from datetime import datetime, timedelta
from functools import partial
from typing import Annotated
import starlette.status as status
from fastapi import Depends, FastAPI, HTTPException, Request, Response
from fastapi.responses import FileResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from sqlalchemy import create_engine, select
from sqlalchemy.orm import Session
from meal_manager.grist import sync_with_grist
from meal_manager.models import (
Base,
Event,
Household,
Registration,
Subscription,
TeamRegistration,
)
from meal_manager.pdf import build_dinner_overview_pdf
# SQLite database file; the relative path is resolved against the process's
# current working directory.
sqlite_file_name = "database.db"
sqlite_url = "sqlite:///" + sqlite_file_name
# SQLite connections refuse use from a thread other than their creator by
# default. FastAPI may run sync dependencies/endpoints in a worker thread,
# so that guard is switched off here.
connect_args = dict(check_same_thread=False)
engine = create_engine(sqlite_url, connect_args=connect_args)

# Process-wide German locale (NOTE(review): presumably for German date/weekday
# names in rendered output -- confirm). Raises locale.Error at import time if
# de_DE.UTF-8 is not installed on the host.
locale.setlocale(locale.LC_ALL, "de_DE.UTF-8")
def get_session():
    """Yield a SQLAlchemy session bound to the module-level engine.

    Intended as a FastAPI dependency (wrapped by ``SessionDep``). The session
    is always closed once the consumer is done with it, even if an error
    occurred.
    """
    db_session = Session(engine)
    try:
        yield db_session
    finally:
        db_session.close()
def get_user(request: Request, allow_none: bool = True) -> dict | None:
    """
    Resolve the current user for the incoming request.

    Two sources are checked, in order:

    1. The ``MEAL_MANAGER_FAKE_USER`` environment variable (for local
       development). Any non-empty value logs in as ``fake_user``; the value
       ``"admin"`` additionally grants admin rights.
    2. The ``ynh_user`` header set by ssowat.

    Used in UserDep and StrictUserDep.

    :param request: The incoming HTTP request; only its headers are read.
    :param allow_none: If ``True``, anonymous requests yield ``None``;
        otherwise they are rejected.
    :return: ``{"username": ..., "admin": ...}`` for a known user, or ``None``
        for an anonymous request when ``allow_none`` is ``True``.
    :raises HTTPException: 401 if no user is found and ``allow_none`` is
        ``False``.
    """
    fake_user = os.environ.get("MEAL_MANAGER_FAKE_USER", False)
    if fake_user:
        return {"username": "fake_user", "admin": fake_user == "admin"}

    # NOTE(review): the header is trusted as-is; this assumes the app is only
    # reachable through ssowat, which presumably controls this header -- confirm.
    username = request.headers.get("ynh_user")
    if username is not None:
        return {
            "username": username,
            # TODO: This should obviously be replaced with a role based check
            "admin": username == "niklas.m",
        }

    if not allow_none:
        raise HTTPException(status_code=401, detail="Not logged in")
    return None
def create_db_and_tables():
    """Create every table registered on ``Base.metadata`` that is missing.

    Tables that already exist are left as they are; their schema is not
    altered.
    """
    metadata = Base.metadata
    metadata.create_all(bind=engine)
@asynccontextmanager
async def on_startup(app_: FastAPI):
    """FastAPI lifespan handler.

    Makes sure the database schema exists before requests are served. Code
    after the ``yield`` would run at shutdown; nothing needs cleaning up.
    """
    create_db_and_tables()
    yield
app = FastAPI(lifespan=on_startup)

# Both directories are relative to the working directory, so the server is
# expected to be started from the repository root.
_static_files = StaticFiles(directory="src/meal_manager/static")
app.mount("/static", _static_files, name="static")
templates = Jinja2Templates(directory="src/meal_manager/templates")
SessionDep = Annotated[Session, Depends(get_session)]


def _optional_user(request: Request) -> dict | None:
    """Dependency: the current user, or ``None`` for anonymous visitors."""
    return get_user(request, allow_none=True)


def _required_user(request: Request) -> dict:
    """Dependency: the current user; anonymous visitors get a 401."""
    return get_user(request, allow_none=False)


# Do not hand ``get_user`` (or ``partial(get_user, ...)``) to ``Depends``
# directly. FastAPI derives request parameters from the dependency's
# signature, so ``allow_none`` would become a query parameter. A request like
# ``?allow_none=true`` could then skip the login requirement of
# ``StrictUserDep``. The wrappers above expose only ``request``.
UserDep = Annotated[dict | None, Depends(_optional_user)]
StrictUserDep = Annotated[dict, Depends(_required_user)]
@app.get("/")
async def index(request: Request, session: SessionDep, user: UserDep):
    """Home page: upcoming events plus a button to register new ones.

    Events that started less than one day ago still count as upcoming here.
    """
    now = datetime.now()
    # A rolling 24-hour grace window rather than a calendar-date comparison;
    # /past_events uses the complementary filter.
    cutoff = now - timedelta(days=1)
    upcoming = session.scalars(
        select(Event).where(Event.event_time >= cutoff).order_by(Event.event_time)
    )
    page_context = {
        "events": upcoming,
        "current_page": "home",
        "now": now,
        "user": user,
    }
    return templates.TemplateResponse(
        request=request, name="index.html", context=page_context
    )
@app.get("/robots.txt")
async def robots_txt():
    """Serve the bundled robots.txt at the site root as plain text."""
    return FileResponse(
        path="src/meal_manager/static/robots.txt",
        media_type="text/plain",
    )
@app.get("/past_events")
async def past_events(request: Request, session: SessionDep, user: UserDep):
    """List events that started more than one day ago.

    Renders the same ``index.html`` template as the home page. It therefore
    passes the same context keys, including ``user``. NOTE(review): ``user``
    is presumably needed by the template for login-dependent controls.
    """
    now = datetime.now()
    # Complement of the home page filter (event_time >= now - 1 day).
    statement = (
        select(Event)
        .order_by(Event.event_time)
        .where(Event.event_time < now - timedelta(days=1))
    )
    events = session.scalars(statement)
    return templates.TemplateResponse(
        request=request,
        name="index.html",
        context={
            "events": events,
            "current_page": "past",
            "now": now,
            "user": user,
        },
    )
@app.get("/subscribe")
async def subscribe(request: Request, session: SessionDep, user: UserDep):
    """Show all subscriptions and a form for creating a new one.

    Only households that do not have a subscription yet are offered in the
    form.
    """
    subscriptions = session.scalars(select(Subscription)).all()
    # Collect the subscribed ids once so each membership test is O(1), instead
    # of rebuilding an id list for every household.
    subscribed_ids = {sub.household_id for sub in subscriptions}
    households = [
        h for h in session.scalars(select(Household)) if h.id not in subscribed_ids
    ]
    return templates.TemplateResponse(
        request=request,
        name="subscribe.html",
        context={
            "households": households,
            "subscriptions": subscriptions,
            "user": user,
        },
    )
@app.post("/subscribe")
async def add_subscribe(request: Request, session: SessionDep):
    """Create a standing meal subscription from the submitted form.

    Form fields:
    - ``household`` (required)
    - ``numAdults``, ``numKids``, ``numSmallKids``: optional non-negative
      integers; blank or missing means 0
    - ``days``: repeated; "1" = Monday ... "7" = Sunday

    If no day is selected, the weekday flags stay as the ``Subscription``
    model initialises them.

    :raises HTTPException: 400 if a count is not a non-negative integer. A
        bare ``ValueError`` would surface as an HTTP 500 instead.
    """
    form_data = await request.form()

    def parse_count(field: str) -> int:
        # Blank or missing inputs mean "no meals of this kind".
        raw = form_data.get(field)
        if not raw:
            return 0
        try:
            count = int(raw)
        except ValueError as err:
            raise HTTPException(
                status_code=400, detail="All number fields must be integers"
            ) from err
        if count < 0:
            raise HTTPException(
                status_code=400, detail="Number fields must not be negative"
            )
        return count

    subscription = Subscription(
        household_id=form_data["household"],
        num_adult_meals=parse_count("numAdults"),
        num_children_meals=parse_count("numKids"),
        num_small_children_meals=parse_count("numSmallKids"),
    )
    selected_days = form_data.getlist("days")
    if selected_days:
        subscription.monday = "1" in selected_days
        subscription.tuesday = "2" in selected_days
        subscription.wednesday = "3" in selected_days
        subscription.thursday = "4" in selected_days
        subscription.friday = "5" in selected_days
        subscription.saturday = "6" in selected_days
        subscription.sunday = "7" in selected_days
    session.add(subscription)
    session.commit()
    return RedirectResponse(url="/subscribe", status_code=status.HTTP_302_FOUND)
@app.get("/subscribe/{household_id}/delete")
async def delete_subscription(
    request: Request, session: SessionDep, household_id: int, user: StrictUserDep
):
    """Delete a household's subscription and return to the subscription page.

    ``.one()`` raises if the household has no subscription, or more than one.
    """
    doomed = session.scalars(
        select(Subscription).where(Subscription.household_id == household_id)
    ).one()
    session.delete(doomed)
    session.commit()
    return RedirectResponse(url="/subscribe", status_code=status.HTTP_302_FOUND)
@app.get("/event/add")
async def add_event_form(request: Request, user: StrictUserDep):
    """Render the empty event form.

    ``user`` is not used in the body; it is requested only so the
    StrictUserDep dependency runs.
    """
    template_name = "add_event.html"
    return templates.TemplateResponse(request=request, name=template_name)
@app.get("/event/{event_id}/edit")
async def edit_event_form(
    request: Request, event_id: int, session: SessionDep, user: StrictUserDep
):
    """Render ``add_event.html`` for an existing event.

    Passes the event and ``edit_mode=True``, so the add-event template is
    reused for editing.
    """
    existing = session.scalars(select(Event).where(Event.id == event_id)).one()
    return templates.TemplateResponse(
        request=request,
        name="add_event.html",
        context={"event": existing, "edit_mode": True},
    )
@app.post("/event/{event_id}/edit")
async def edit_event(
    request: Request, event_id: int, session: SessionDep, user: StrictUserDep
):
    """Apply the submitted edit form to an event, then show the event page.

    Reads the same form fields as ``POST /event/add``.
    """
    event = session.scalars(select(Event).where(Event.id == event_id)).one()
    form = await request.form()
    starts_at, deadline = await parse_event_times(form)

    event.title = form["eventName"]
    event.event_time = starts_at
    event.registration_deadline = deadline
    event.description = form.get("eventDescription")
    event.recipe_link = form.get("recipeLink")
    # HTML checkboxes are only submitted when ticked (value "on"), so a
    # missing field clears the flag.
    event.ignore_subscriptions = form.get("ignoreSubscriptions") == "on"
    session.commit()

    return RedirectResponse(url=f"/event/{event.id}", status_code=status.HTTP_302_FOUND)
@app.post("/event/add")
async def add_event(request: Request, session: SessionDep, user: StrictUserDep):
    """Create an event from the submitted form and return to the home page.

    Form fields:
    - ``eventName`` (required)
    - ``eventTime`` (ISO 8601, required)
    - ``registrationDeadline`` (optional ISO 8601; ``parse_event_times``
      supplies a default)
    - ``eventDescription``, ``recipeLink``
    - the ``ignoreSubscriptions`` checkbox
    """
    form = await request.form()
    starts_at, deadline = await parse_event_times(form)
    new_event = Event(
        title=form["eventName"],
        event_time=starts_at,
        registration_deadline=deadline,
        description=form.get("eventDescription"),
        recipe_link=form.get("recipeLink"),
        ignore_subscriptions=form.get("ignoreSubscriptions") == "on",
    )
    session.add(new_event)
    session.commit()
    return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND)
async def parse_event_times(form_data):
    """Parse the event time and registration deadline from submitted form data.

    :param form_data: Mapping supporting ``[]`` and ``.get``; read keys are
        ``eventTime`` (required, ISO 8601) and ``registrationDeadline``
        (optional, ISO 8601).
    :return: ``(event_time, registration_deadline)`` as datetimes. Without an
        explicit deadline, it defaults to 19:30 on the most recent Sunday on
        or before the event's date. An event on a Sunday gets that same
        Sunday.
    :raises ValueError: If a supplied timestamp is not valid ISO 8601.
    """
    event_time = datetime.fromisoformat(form_data["eventTime"])
    raw_deadline = form_data.get("registrationDeadline")
    if raw_deadline:
        registration_deadline = datetime.fromisoformat(raw_deadline)
    else:
        # weekday(): Monday == 0 ... Sunday == 6, so this is 0 on a Sunday.
        # Subtracting a timedelta (rather than decrementing ``day``) stays
        # correct across month and year boundaries.
        days_since_sunday = (event_time.weekday() + 1) % 7
        sunday = event_time - timedelta(days=days_since_sunday)
        registration_deadline = sunday.replace(
            hour=19, minute=30, second=0, microsecond=0
        )
    return event_time, registration_deadline
@app.get("/event/{event_id}/delete")
async def delete_event(
    request: Request, session: SessionDep, event_id: int, user: StrictUserDep
):
    """Delete an event and return to the home page.

    Only admins may delete events; everyone else gets a 403.
    """
    is_admin = user["admin"]
    if not is_admin:
        raise HTTPException(status_code=403, detail="Not authorized")
    doomed = session.scalars(select(Event).where(Event.id == event_id)).one()
    session.delete(doomed)
    session.commit()
    return RedirectResponse(url="/", status_code=status.HTTP_302_FOUND)
@app.get("/event/{event_id}")
async def read_event(
    request: Request,
    event_id: int,
    session: SessionDep,
    user: UserDep,
    message: str | None = None,
):
    """Show a single event with its registrations.

    Only households that are not yet registered for this event are offered in
    the registration form. ``message`` is an optional notice from the query
    string (``sync_with_grist_route`` uses it to report success). It is passed
    on to the template.
    """
    event = session.scalars(select(Event).where(Event.id == event_id)).one()
    # Collect the registered ids once so each membership test is O(1), instead
    # of rebuilding an id list for every household.
    registered_ids = {reg.household_id for reg in event.registrations}
    households = [
        h for h in session.scalars(select(Household)) if h.id not in registered_ids
    ]
    return templates.TemplateResponse(
        request=request,
        name="event.html",
        context={
            "event": event,
            "households": households,
            "now": datetime.now(),
            "user": user,
            "message": message,
        },
    )
@app.post("/event/{event_id}/register")
async def add_registration(request: Request, event_id: int, session: SessionDep):
    """Register a household's meals for an event, then return to the event page.

    Form fields:
    - ``household`` (required)
    - ``numAdults``, ``numKids``, ``numSmallKids``: optional non-negative
      integers; blank or missing means 0
    - ``comment`` (required)

    :raises HTTPException: 400 if a count is not a non-negative integer. A
        bare ``ValueError`` would surface as an HTTP 500 instead.
    """
    form_data = await request.form()

    def parse_count(field: str) -> int:
        # Blank or missing inputs mean "no meals of this kind".
        raw = form_data.get(field)
        if not raw:
            return 0
        try:
            count = int(raw)
        except ValueError as err:
            raise HTTPException(
                status_code=400, detail="All number fields must be integers"
            ) from err
        if count < 0:
            raise HTTPException(
                status_code=400, detail="Number fields must not be negative"
            )
        return count

    registration = Registration(
        household_id=form_data["household"],
        event_id=event_id,
        num_adult_meals=parse_count("numAdults"),
        num_children_meals=parse_count("numKids"),
        num_small_children_meals=parse_count("numSmallKids"),
        comment=form_data["comment"],
    )
    session.add(registration)
    session.commit()
    return RedirectResponse(url=f"/event/{event_id}", status_code=status.HTTP_302_FOUND)
@app.get("/event/{event_id}/registration/{household_id}/delete")
async def delete_registration(
    request: Request,
    event_id: int,
    household_id: int,
    session: SessionDep,
    user: StrictUserDep,
):
    """
    Remove one household's registration for one event and go back to the event.

    The change is committed immediately. ``.one()`` raises if no matching
    registration (or more than one) exists.
    """
    query = (
        select(Registration)
        .where(Registration.household_id == household_id)
        .where(Registration.event_id == event_id)
    )
    registration = session.scalars(query).one()
    session.delete(registration)
    session.commit()
    return RedirectResponse(url=f"/event/{event_id}", status_code=status.HTTP_302_FOUND)
@app.post("/event/{event_id}/register_team")
async def add_team_registration(request: Request, event_id: int, session: SessionDep):
    """Sign a person up for a work type at an event, then show the event page.

    Form fields: ``personName`` (surrounding whitespace is stripped) and
    ``workType``.
    - A blank name is ignored.
    - Signing up twice for the same work type is ignored.
    """
    form_data = await request.form()
    person = form_data["personName"].strip()
    work_type = form_data["workType"]
    # A whitespace-only name would create an anonymous, undeletable-looking entry.
    if person:
        statement = select(TeamRegistration).where(
            TeamRegistration.person_name == person,
            TeamRegistration.work_type == work_type,
            TeamRegistration.event_id == event_id,
        )
        # first() rather than one_or_none(): already-existing duplicates must
        # not turn into a MultipleResultsFound server error.
        if session.scalars(statement).first() is None:
            session.add(
                TeamRegistration(
                    person_name=person,
                    event_id=event_id,
                    work_type=work_type,
                )
            )
            session.commit()
    return RedirectResponse(url=f"/event/{event_id}", status_code=status.HTTP_302_FOUND)
@app.get("/event/{event_id}/register_team/{entry_id}/delete")
async def delete_team_registration(
    request: Request,
    event_id: int,
    entry_id: int,
    session: SessionDep,
    user: StrictUserDep,
):
    """Delete a team sign-up by its id and return to the event page.

    The lookup uses ``entry_id`` only; ``event_id`` just determines where to
    redirect.
    """
    entry = session.scalars(
        select(TeamRegistration).where(TeamRegistration.id == entry_id)
    ).one()
    session.delete(entry)
    session.commit()
    return RedirectResponse(url=f"/event/{event_id}", status_code=status.HTTP_302_FOUND)
@app.get("/event/{event_id}/pdf")
def get_event_attendance_pdf(event_id: int, session: SessionDep):
    """Return the event's attendance overview as an inline PDF.

    Declared as a plain ``def``, so FastAPI runs it in a worker thread rather
    than on the event loop.
    """
    event = session.scalars(select(Event).where(Event.id == event_id)).one()
    pdf_bytes = build_dinner_overview_pdf(event).getvalue()
    disposition = f"inline; filename=attendance_event_{event_id}.pdf"
    return Response(
        content=pdf_bytes,
        media_type="application/pdf",
        headers={"Content-Disposition": disposition},
    )
@app.get("/event/{event_id}/sync_with_grist")
def sync_with_grist_route(event_id: int, session: SessionDep, user: StrictUserDep):
    """
    Push an event to Grist, then redirect back to the event page.

    The redirect carries a URL-encoded German success notice ("Erfolgreich an
    Abrechnung übertragen") in the ``message`` query parameter.

    TODO: Error handling. Any exception raised by ``sync_with_grist``
    currently propagates and ends in a server error.
    """
    event = session.scalars(select(Event).where(Event.id == event_id)).one()
    sync_with_grist(event)
    target = f"/event/{event_id}?message=Erfolgreich%20an%20Abrechnung%20%C3%BCbertragen"
    return RedirectResponse(url=target, status_code=status.HTTP_302_FOUND)