from contextlib import asynccontextmanager
from typing import Annotated, Union

import starlette.status as status
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from sqlmodel import Session, SQLModel, create_engine, select

from models import Event, Household, Registration

sqlite_file_name = "database.db"
sqlite_url = f"sqlite:///{sqlite_file_name}"

# sqlite3 connections reject use from a thread other than the creating one by
# default; FastAPI runs sync dependencies and sync endpoints in a threadpool,
# so that check is disabled here.
connect_args = {"check_same_thread": False}
engine = create_engine(sqlite_url, connect_args=connect_args)


def get_session():
    """Yield a database session, closing it when the request is finished."""
    with Session(engine) as session:
        yield session


def create_db_and_tables():
    """Create every table registered on ``SQLModel.metadata`` that is missing."""
    SQLModel.metadata.create_all(engine)


@asynccontextmanager
async def on_startup(app_: FastAPI):
    """Application lifespan hook: ensure the schema exists before serving."""
    create_db_and_tables()
    yield


app = FastAPI(lifespan=on_startup)
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")

SessionDep = Annotated[Session, Depends(get_session)]


def _get_event_or_404(session: Session, event_id: int) -> Event:
    """Return the event whose primary key is *event_id*, or raise HTTP 404."""
    event = session.get(Event, event_id)
    if event is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Event not found"
        )
    return event


def _parse_meal_count(form_data, field: str) -> int:
    """Read a non-negative meal count from the submitted form.

    A missing or empty field counts as 0 (matching the form's optional
    inputs).

    Raises:
        HTTPException: 400 if the value is not an integer or is negative.
    """
    raw = form_data.get(field)
    if raw is None or raw == "":
        return 0
    try:
        # TypeError covers file uploads submitted under a numeric field name.
        count = int(raw)
    except (TypeError, ValueError):
        count = -1
    if count < 0:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"{field} must be a non-negative whole number",
        )
    return count


def _parse_household_id(form_data) -> int:
    """Read the selected household id from the form, or raise HTTP 400."""
    try:
        return int(form_data.get("household"))
    except (TypeError, ValueError):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="A valid household must be selected",
        ) from None


# Endpoints that only do blocking database work are plain ``def`` so FastAPI
# runs them in its threadpool instead of stalling the event loop.
@app.get("/")
def read_root(request: Request, session: SessionDep):
    """Render the home page listing all events ordered by ``event_time``."""
    events = session.exec(select(Event).order_by(Event.event_time)).all()
    return templates.TemplateResponse(
        request=request, name="index.html", context={"events": events}
    )


@app.get("/event/{event_id}")
def read_event(request: Request, event_id: int, session: SessionDep):
    """Render one event's page together with every household.

    Raises:
        HTTPException: 404 if no event has id *event_id*.
    """
    event = _get_event_or_404(session, event_id)
    households = session.exec(select(Household)).all()
    return templates.TemplateResponse(
        request=request,
        name="event.html",
        context={"event": event, "households": households},
    )


@app.post("/event/{event_id}/register")
async def add_registration(request: Request, event_id: int, session: SessionDep):
    """Register a household's meal counts for an event, then redirect back.

    Form fields: ``household`` (required id), and ``numAdults``, ``numKids``,
    ``numSmallKids`` (optional; empty or missing means 0).

    Raises:
        HTTPException: 404 if the event does not exist; 400 if the household
            is missing/unknown or any meal count is not a non-negative integer.
    """
    form_data = await request.form()

    _get_event_or_404(session, event_id)
    household_id = _parse_household_id(form_data)
    if session.get(Household, household_id) is None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="Unknown household"
        )

    registration = Registration(
        household_id=household_id,
        event_id=event_id,
        num_adult_meals=_parse_meal_count(form_data, "numAdults"),
        num_children_meals=_parse_meal_count(form_data, "numKids"),
        num_small_children_meals=_parse_meal_count(form_data, "numSmallKids"),
    )
    session.add(registration)
    session.commit()
    return RedirectResponse(url=f"/event/{event_id}", status_code=status.HTTP_302_FOUND)


# NOTE(review): a state-changing GET can be triggered by link prefetchers or
# crawlers; kept as GET so existing template links keep working, but it should
# move to POST/DELETE.
@app.get("/event/{event_id}/registration/{household_id}/delete")
def delete_registration(
    request: Request, event_id: int, household_id: int, session: SessionDep
):
    """Delete a household's registration for an event and redirect back.

    Nothing prevents a household from registering twice, so every matching
    row is removed rather than failing on duplicates.

    Raises:
        HTTPException: 404 if the household has no registration for the event.
    """
    statement = select(Registration).where(
        Registration.household_id == household_id, Registration.event_id == event_id
    )
    registrations = session.exec(statement).all()
    if not registrations:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Registration not found"
        )
    for registration in registrations:
        session.delete(registration)
    session.commit()
    return RedirectResponse(url=f"/event/{event_id}", status_code=status.HTTP_302_FOUND)