107 lines
3.4 KiB
Python
107 lines
3.4 KiB
Python
from contextlib import asynccontextmanager
from typing import Annotated, Union

import starlette.status as status
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import RedirectResponse
from fastapi.staticfiles import StaticFiles
from fastapi.templating import Jinja2Templates
from sqlmodel import Session, SQLModel, create_engine, select

from models import Event, Household, Registration
|
|
|
|
# SQLite database file; the three-slash URL form makes the path relative to
# the process's working directory.
sqlite_file_name = "database.db"
sqlite_url = f"sqlite:///{sqlite_file_name}"

# Turn off sqlite3's same-thread check so a connection may be used from a
# thread other than the one that opened it.
# NOTE(review): presumably needed because FastAPI may run the sync
# `get_session` dependency on a worker thread — confirm.
connect_args = {"check_same_thread": False}
engine = create_engine(sqlite_url, connect_args=connect_args)
|
|
|
|
|
|
def get_session():
    """Yield a database session bound to the module-level ``engine``.

    Written as a generator so it can be wired in through ``Depends``; the
    session's context manager is exited once the consumer is done with it.
    """
    with Session(engine) as db_session:
        yield db_session
|
|
|
|
|
|
def create_db_and_tables():
    """Create the tables for all models registered on ``SQLModel.metadata``."""
    metadata = SQLModel.metadata
    metadata.create_all(engine)
|
|
|
|
|
|
@asynccontextmanager
async def on_startup(app_: FastAPI):
    """Lifespan context manager: create the database tables, then yield.

    The code before ``yield`` runs when the context is entered; nothing is
    done after it, so there is no teardown work.
    """
    create_db_and_tables()
    yield
|
|
|
|
|
|
# Application instance; `on_startup` is registered as its lifespan handler.
app = FastAPI(lifespan=on_startup)

# Serve the contents of the "static" directory under the /static URL prefix.
app.mount("/static", StaticFiles(directory="static"), name="static")

# HTML templates are loaded from the "templates" directory.
templates = Jinja2Templates(directory="templates")

# Annotation alias: a route parameter typed `SessionDep` receives the session
# yielded by `get_session`.
SessionDep = Annotated[Session, Depends(get_session)]
|
|
|
|
|
|
@app.get("/")
async def read_root(request: Request, session: SessionDep):
    """Render ``index.html`` with every event, ordered by ``event_time``."""
    events_by_time = select(Event).order_by(Event.event_time)
    context = {"events": session.exec(events_by_time).all()}
    return templates.TemplateResponse(
        request=request, name="index.html", context=context
    )
|
|
|
|
|
|
@app.get("/event/{event_id}")
async def read_event(request: Request, event_id: int, session: SessionDep):
    """Render the detail page for a single event.

    Args:
        request: Incoming request, handed to the template response.
        event_id: Identifier of the event, taken from the URL path.
        session: Database session injected through ``SessionDep``.

    Returns:
        The rendered ``event.html`` template, with the event and the full
        list of households in its context.

    Raises:
        HTTPException: 404 when no event has the given ``event_id``.
    """
    # `.first()` yields None for an unknown id; the previous `.one()` raised
    # NoResultFound, which reached the client as a 500.
    event = session.exec(select(Event).where(Event.id == event_id)).first()
    if event is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Event not found"
        )

    households = session.exec(select(Household)).all()
    return templates.TemplateResponse(
        request=request,
        name="event.html",
        context={"event": event, "households": households},
    )
|
|
|
|
|
|
def _parse_meal_count(form_data, field: str) -> int:
    """Return the meal count submitted in *field*; blank or missing means 0.

    Raises:
        HTTPException: 400 when the value is not an integer or is negative.
    """
    raw_value = form_data.get(field)
    if not raw_value:
        return 0
    try:
        count = int(raw_value)
    except (TypeError, ValueError):
        # TypeError covers non-text values such as an uploaded file.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="All number fields must be integers",
        ) from None
    if count < 0:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Number fields must not be negative",
        )
    return count


def _parse_household_id(form_data) -> int:
    """Return the integer household id from the form, or raise a 400."""
    try:
        return int(form_data.get("household"))
    except (TypeError, ValueError):
        # TypeError covers a missing field (None) or a non-text value.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="A valid household must be selected",
        ) from None


@app.post("/event/{event_id}/register")
async def add_registration(request: Request, event_id: int, session: SessionDep):
    """Register a household for an event from the submitted form.

    Reads the form fields ``household``, ``numAdults``, ``numKids`` and
    ``numSmallKids``. A blank or missing count is stored as 0.

    Args:
        request: Incoming request carrying the form body.
        event_id: Identifier of the event, taken from the URL path.
        session: Database session injected through ``SessionDep``.

    Returns:
        A 302 redirect back to the event page.

    Raises:
        HTTPException: 400 when a count is not a non-negative integer or the
            household field is missing or not an integer.
    """
    form_data = await request.form()

    # Validate everything before touching the database so bad input is
    # reported as a client error rather than an unhandled exception.
    household_id = _parse_household_id(form_data)
    num_adult_meals = _parse_meal_count(form_data, "numAdults")
    num_children_meals = _parse_meal_count(form_data, "numKids")
    num_small_children_meals = _parse_meal_count(form_data, "numSmallKids")

    registration = Registration(
        household_id=household_id,
        event_id=event_id,
        num_adult_meals=num_adult_meals,
        num_children_meals=num_children_meals,
        num_small_children_meals=num_small_children_meals,
    )
    session.add(registration)
    session.commit()
    return RedirectResponse(url=f"/event/{event_id}", status_code=status.HTTP_302_FOUND)
|
|
|
|
|
|
@app.get("/event/{event_id}/registration/{household_id}/delete")
async def delete_registration(
    request: Request, event_id: int, household_id: int, session: SessionDep
):
    """Delete a household's registration for an event and return to the event page.

    Args:
        request: Incoming request (unused; kept for the existing signature).
        event_id: Identifier of the event, taken from the URL path.
        household_id: Identifier of the household, taken from the URL path.
        session: Database session injected through ``SessionDep``.

    Returns:
        A 302 redirect back to the event page.

    Raises:
        HTTPException: 404 when the household has no registration for the event.
    """
    # NOTE(review): this destructive action is reachable via GET, so link
    # prefetchers or crawlers can trigger it. Moving it to POST/DELETE needs
    # a matching change in whatever links here — not visible in this file.
    statement = select(Registration).where(
        Registration.household_id == household_id, Registration.event_id == event_id
    )
    registrations = session.exec(statement).all()
    if not registrations:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND, detail="Registration not found"
        )

    # Remove every match: nothing visible here prevents duplicate rows for
    # the same household/event pair, and the previous `.one()` failed on them.
    for registration in registrations:
        session.delete(registration)
    session.commit()
    return RedirectResponse(url=f"/event/{event_id}", status_code=status.HTTP_302_FOUND)
|