diff options
author | Parker <contact@pkrm.dev> | 2024-11-04 23:01:13 -0600 |
---|---|---|
committer | Parker <contact@pkrm.dev> | 2024-11-04 23:01:13 -0600 |
commit | 3f8e39cc86ca22c3e94f52d693c90553ef1dfd57 (patch) | |
tree | 0bf2ef55e3250d059f1bdaf8546f2c1f2773ad52 /app/main.py | |
parent | 5a0777033f6733c33fbd6119ade812e0c749be44 (diff) |
Major consolidation and upgrades
Diffstat (limited to 'app/main.py')
-rw-r--r-- | app/main.py | 331 |
1 files changed, 147 insertions, 184 deletions
diff --git a/app/main.py b/app/main.py index 78a65c8..c36d64a 100644 --- a/app/main.py +++ b/app/main.py @@ -1,187 +1,150 @@ -from flask_login import ( - current_user, - login_user, - login_required, - logout_user, - LoginManager, - UserMixin, +from fastapi import FastAPI, Path, Depends, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import RedirectResponse +from fastapi.templating import Jinja2Templates +from app.routes.links_route import router as links_router +from app.routes.refresh_route import router as refresh_router +from app.routes.token_route import router as token_router +from typing import Annotated +from fastapi.exceptions import HTTPException +from starlette.status import HTTP_404_NOT_FOUND + +from app.util.authentication import get_current_user_from_cookie +from app.schemas.auth_schemas import User + +app = FastAPI( + title="LinkLogger API", + version="1.0", + summary="Public API for a combined link shortener and IP logger", + license_info={ + "name": "The Unlicense", + "identifier": "Unlicense", + "url": "https://unlicense.org", + }, ) -from flask import Flask, redirect, render_template, request, url_for -import bcrypt -import os -import string -import random - -from models import User, Link -from database import * -from app.util.log import log - - -class FlaskUser(UserMixin): - pass - - -app = Flask(__name__) -app.config["SECRET_KEY"] = os.urandom(24) - -login_manager = LoginManager() -login_manager.init_app(app) - - -@login_manager.user_loader -def user_loader(username): - user = FlaskUser() - user.id = username - return user - - -""" -Handle login requests from the web UI -""" - - -@app.route("/login", methods=["GET", "POST"]) -def login(): - if request.method == "POST": - username = request.form["username"] - password = request.form["password"] - - # Get database session - db = SessionLocal() - - user = db.query(User).filter(User.username == username).first() - db.close() - if not user: - return {"status": "Invalid username or password"} - - if bcrypt.checkpw( - password.encode("utf-8"), user.password.encode("utf-8") - ): - flask_user = FlaskUser() - flask_user.id = username - login_user(flask_user) - return {"status": "success"} - - return {"status": "Invalid username or password"} - return render_template("login.html") - - -""" -Handle signup requests from the web UI -""" - - -@app.route("/signup", methods=["GET", "POST"]) -def signup(): - if request.method == "POST": - username = request.form["username"] - password = request.form["password"] - - # Verify the password meets requirements - if len(password) < 8: - return {"status": "Password must be at least 8 characters"} - if not any(char.isdigit() for char in password): - return {"status": "Password must contain at least one digit"} - if not any(char.isupper() for char in password): - return { - "status": "Password must contain at least one uppercase letter" - } - - # Get database session - db = SessionLocal() - - user = db.query(User).filter(User.username == username).first() - if user: - db.close() - return {"status": "Username not available"} - # Add information to the database - hashed_password = bcrypt.hashpw( - password.encode("utf-8"), bcrypt.gensalt() - ).decode("utf-8") - api_key = "".join( - random.choices(string.ascii_letters + string.digits, k=20) - ) - new_user = User( - username=username, password=hashed_password, api_key=api_key - ) - db.add(new_user) - db.commit() - db.close() - # Log in the newly created user - flask_user = FlaskUser() - flask_user.id = username - login_user(flask_user) - - return {"status": "success"} - return render_template("signup.html") +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_methods=["*"], + allow_headers=["*"], + allow_credentials=True, +) -""" -Load the 'dashboard' page for logged in users -""" - - -@app.route("/dashboard", methods=["GET"]) -@login_required -def dashboard(): - # Get database session - db = SessionLocal() - - # Get the API key for the current user - user = db.query(User).filter(User.username == current_user.id).first() - db.close() - api_key = user.api_key - - return render_template("dashboard.html", api_key=api_key) - - -""" -Log users out of their account -""" - - -@app.route("/logout", methods=["GET"]) -@login_required -def logout(): - logout_user() - return redirect(url_for("login")) - - -""" -Log all records for visits to shortened links -""" - - -@app.route("/<link>", methods=["GET"]) -def log_redirect(link: str): - link = link.upper() - # If `link` is not exactly 5 characters, return redirect to base url - if len(link) != 5: - return redirect(url_for("login")) - - # Make sure the link exists in the database - db = SessionLocal() - link_record = db.query(Link).filter(Link.link == link).first() - if not link_record: - db.close() - return redirect(url_for("login")) - else: - # Log the visit - if request.headers.get("X-Real-IP"): - ip = request.headers.get("X-Real-IP").split(",")[0] - else: - ip = request.remote_addr - user_agent = request.headers.get("User-Agent") - log(link, ip, user_agent) - db.close() - return redirect(link_record.redirect_link) - - -@app.errorhandler(401) -def unauthorized(e): - return redirect(url_for("login")) - - -@app.errorhandler(404) -def not_found(e): - return redirect(url_for("login")) +templates = Jinja2Templates(directory="app/templates") + +# Import routes +app.include_router(links_router, prefix="/api") +# Must not have a prefix... for some reason you can't change +# the prefix of the Swagger UI OAuth2 redirect to /api/token +# you can only change it to /token, so we have to remove the +# prefix in order to keep logging in via Swagger UI working +app.include_router(token_router) +app.include_router(refresh_router, prefix="/api") + + +@app.get("/login") +async def login(request: Request): + return templates.TemplateResponse("login.html", {"request": request}) + + +# Handle login requests through Swagger UI + + +# @app.route("/signup", methods=["GET", "POST"]) +# def signup(): +# if request.method == "POST": +# username = request.form["username"] +# password = request.form["password"] + +# # Verify the password meets requirements +# if len(password) < 8: +# return {"status": "Password must be at least 8 characters"} +# if not any(char.isdigit() for char in password): +# return {"status": "Password must contain at least one digit"} +# if not any(char.isupper() for char in password): +# return { +# "status": "Password must contain at least one uppercase letter" +# } + +# # Get database session +# db = SessionLocal() + +# user = db.query(User).filter(User.username == username).first() +# if user: +# db.close() +# return {"status": "Username not available"} +# # Add information to the database +# hashed_password = bcrypt.hashpw( +# password.encode("utf-8"), bcrypt.gensalt() +# ).decode("utf-8") +# api_key = "".join( +# random.choices(string.ascii_letters + string.digits, k=20) +# ) +# new_user = User( +# username=username, password=hashed_password, api_key=api_key +# ) +# db.add(new_user) +# db.commit() +# db.close() +# # Log in the newly created user +# flask_user = FlaskUser() +# flask_user.id = username +# login_user(flask_user) + +# return {"status": "success"} +# return render_template("signup.html") + + +@app.get("/dashboard") +async def dashboard( + response: Annotated[ + User, RedirectResponse, Depends(get_current_user_from_cookie) + ], + request: Request, +): + if isinstance(response, RedirectResponse): + return response + return templates.TemplateResponse( + "dashboard.html", {"request": request, "user": response.username} + ) + + +# @app.get("/{link}") +# async def log_redirect( +# link: Annotated[str, Path(title="Redirect link")], +# request: Request, +# db=Depends(get_db), +# ): +# link = link.upper() +# # If `link` is not exactly 5 characters, return redirect to base url +# if len(link) != 5: +# return RedirectResponse(url="/login") + +# # Make sure the link exists in the database +# link_record: Link = db.query(Link).filter(Link.link == link).first() +# if not link_record: +# db.close() +# return RedirectResponse(url="/login") +# else: +# # Log the visit +# if request.headers.get("X-Real-IP"): +# ip = request.headers.get("X-Real-IP").split(",")[0] +# else: +# ip = request.client.host +# user_agent = request.headers.get("User-Agent") +# log(link, ip, user_agent) +# db.close() +# return RedirectResponse(url=link_record.redirect_link) + + +# Redirect /api -> /api/docs +@app.get("/api") +async def redirect_to_docs(): + return RedirectResponse(url="/docs") + + +# Custom handler for 404 errors +@app.exception_handler(HTTP_404_NOT_FOUND) +async def custom_404_handler(request: Request, exc: HTTPException): + return RedirectResponse(url="/login") |