diff options
author | Parker <contact@pkrm.dev> | 2024-11-04 23:01:13 -0600 |
---|---|---|
committer | Parker <contact@pkrm.dev> | 2024-11-04 23:01:13 -0600 |
commit | 3f8e39cc86ca22c3e94f52d693c90553ef1dfd57 (patch) | |
tree | 0bf2ef55e3250d059f1bdaf8546f2c1f2773ad52 | |
parent | 5a0777033f6733c33fbd6119ade812e0c749be44 (diff) |
Major consolidation and upgrades
-rw-r--r-- | api/main.py | 112 | ||||
-rw-r--r-- | app/main.py | 331 | ||||
-rw-r--r-- | app/routes/links_route.py (renamed from api/routes/links_route.py) | 20 | ||||
-rw-r--r-- | app/routes/refresh_route.py | 33 | ||||
-rw-r--r-- | app/routes/token_route.py | 54 | ||||
-rw-r--r-- | app/schemas/auth_schemas.py (renamed from api/schemas/auth_schemas.py) | 0 | ||||
-rw-r--r-- | app/schemas/links_schemas.py (renamed from api/schemas/links_schemas.py) | 0 | ||||
-rw-r--r-- | app/templates/dashboard.html | 2 | ||||
-rw-r--r-- | app/templates/login.html | 10 | ||||
-rw-r--r-- | app/util/authentication.py (renamed from api/util/authentication.py) | 81 | ||||
-rw-r--r-- | app/util/db_dependency.py (renamed from api/util/db_dependency.py) | 0 | ||||
-rw-r--r-- | linklogger.py | 18 |
12 files changed, 307 insertions, 354 deletions
diff --git a/api/main.py b/api/main.py deleted file mode 100644 index fbe8805..0000000 --- a/api/main.py +++ /dev/null @@ -1,112 +0,0 @@ -import random -from fastapi import FastAPI, Depends, HTTPException, status -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import RedirectResponse -from datetime import timedelta -from typing import Annotated -from fastapi.security import OAuth2PasswordRequestForm, OAuth2PasswordBearer - -from api.util.authentication import ( - authenticate_user, - create_access_token, - refresh_get_current_user, -) -from api.routes.links_route import router as links_router -from api.util.db_dependency import get_db -from api.schemas.auth_schemas import Token, User - - -metadata_tags = [ - {"name": "links", "description": "Operations for managing links"}, -] - -app = FastAPI( - title="LinkLogger API", - version="1.0", - summary="Public API for a combined link shortener and IP logger", - license_info={ - "name": "The Unlicense", - "identifier": "Unlicense", - "url": "https://unlicense.org", - }, - openapi_tags=metadata_tags, -) - -app.add_middleware( - CORSMiddleware, - allow_origins=["*"], - allow_methods=["*"], - allow_headers=["*"], - allow_credentials=True, -) - -secret_key = random.randbytes(32) -algorithm = "HS256" -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") - -# Import routes -app.include_router(links_router) - - -""" -Authentication -""" - - -@app.post("/token") -async def login_for_access_token( - form_data: Annotated[OAuth2PasswordRequestForm, Depends()], - db=Depends(get_db), -) -> Token: - """ - Return an access token for the user, if the given authentication details are correct - """ - user = authenticate_user(db, form_data.username, form_data.password) - if not user: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Incorrect username or password", - headers={"WWW-Authenticate": "Bearer"}, - ) - access_token_expires = timedelta(minutes=15) - access_token = create_access_token( - data={"sub": user.username, "refresh": False}, - expires_delta=access_token_expires, - ) - # Create a refresh token - just an access token with a longer expiry - # and more restrictions ("refresh" is True) - refresh_token_expires = timedelta(days=1) - refresh_token = create_access_token( - data={"sub": user.username, "refresh": True}, - expires_delta=refresh_token_expires, - ) - return Token( - access_token=access_token, - refresh_token=refresh_token, - token_type="bearer", - ) - - -# Full native JWT support is not complete in FastAPI yet :( -# Part of that is token refresh, so we must implement it ourselves -@app.post("/refresh") -async def refresh_access_token( - current_user: Annotated[User, Depends(refresh_get_current_user)], -) -> Token: - """ - Return a new access token if the refresh token is valid - """ - access_token_expires = timedelta(minutes=30) - access_token = create_access_token( - data={"sub": current_user.username}, expires_delta=access_token_expires - ) - return Token( - access_token=access_token, - token_type="bearer", - ) - - -# Redirect /api -> /api/docs -@app.get("/") -async def redirect_to_docs(): - return RedirectResponse(url="/api/docs") diff --git a/app/main.py b/app/main.py index 78a65c8..c36d64a 100644 --- a/app/main.py +++ b/app/main.py @@ -1,187 +1,150 @@ -from flask_login import ( - current_user, - login_user, - login_required, - logout_user, - LoginManager, - UserMixin, +from fastapi import FastAPI, Path, Depends, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import RedirectResponse +from fastapi.templating import Jinja2Templates +from app.routes.links_route import router as links_router +from app.routes.refresh_route import router as refresh_router +from app.routes.token_route import router as token_router +from typing import Annotated +from fastapi.exceptions import HTTPException +from starlette.status import HTTP_404_NOT_FOUND + +from app.util.authentication import get_current_user_from_cookie +from app.schemas.auth_schemas import User + +app = FastAPI( + title="LinkLogger API", + version="1.0", + summary="Public API for a combined link shortener and IP logger", + license_info={ + "name": "The Unlicense", + "identifier": "Unlicense", + "url": "https://unlicense.org", + }, ) -from flask import Flask, redirect, render_template, request, url_for -import bcrypt -import os -import string -import random - -from models import User, Link -from database import * -from app.util.log import log - - -class FlaskUser(UserMixin): - pass - - -app = Flask(__name__) -app.config["SECRET_KEY"] = os.urandom(24) - -login_manager = LoginManager() -login_manager.init_app(app) - - -@login_manager.user_loader -def user_loader(username): - user = FlaskUser() - user.id = username - return user - - -""" -Handle login requests from the web UI -""" - - -@app.route("/login", methods=["GET", "POST"]) -def login(): - if request.method == "POST": - username = request.form["username"] - password = request.form["password"] - - # Get database session - db = SessionLocal() - - user = db.query(User).filter(User.username == username).first() - db.close() - if not user: - return {"status": "Invalid username or password"} - - if bcrypt.checkpw( - password.encode("utf-8"), user.password.encode("utf-8") - ): - flask_user = FlaskUser() - flask_user.id = username - login_user(flask_user) - return {"status": "success"} - - return {"status": "Invalid username or password"} - return render_template("login.html") - - -""" -Handle signup requests from the web UI -""" - - -@app.route("/signup", methods=["GET", "POST"]) -def signup(): - if request.method == "POST": - username = request.form["username"] - password = request.form["password"] - - # Verify the password meets requirements - if len(password) < 8: - return {"status": "Password must be at least 8 characters"} - if not any(char.isdigit() for char in password): - return {"status": "Password must contain at least one digit"} - if not any(char.isupper() for char in password): - return { - "status": "Password must contain at least one uppercase letter" - } - - # Get database session - db = SessionLocal() - - user = db.query(User).filter(User.username == username).first() - if user: - db.close() - return {"status": "Username not available"} - # Add information to the database - hashed_password = bcrypt.hashpw( - password.encode("utf-8"), bcrypt.gensalt() - ).decode("utf-8") - api_key = "".join( - random.choices(string.ascii_letters + string.digits, k=20) - ) - new_user = User( - username=username, password=hashed_password, api_key=api_key - ) - db.add(new_user) - db.commit() - db.close() - # Log in the newly created user - flask_user = FlaskUser() - flask_user.id = username - login_user(flask_user) - - return {"status": "success"} - return render_template("signup.html") +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_methods=["*"], + allow_headers=["*"], + allow_credentials=True, +) -""" -Load the 'dashboard' page for logged in users -""" - - -@app.route("/dashboard", methods=["GET"]) -@login_required -def dashboard(): - # Get database session - db = SessionLocal() - - # Get the API key for the current user - user = db.query(User).filter(User.username == current_user.id).first() - db.close() - api_key = user.api_key - - return render_template("dashboard.html", api_key=api_key) - - -""" -Log users out of their account -""" - - -@app.route("/logout", methods=["GET"]) -@login_required -def logout(): - logout_user() - return redirect(url_for("login")) - - -""" -Log all records for visits to shortened links -""" - - -@app.route("/<link>", methods=["GET"]) -def log_redirect(link: str): - link = link.upper() - # If `link` is not exactly 5 characters, return redirect to base url - if len(link) != 5: - return redirect(url_for("login")) - - # Make sure the link exists in the database - db = SessionLocal() - link_record = db.query(Link).filter(Link.link == link).first() - if not link_record: - db.close() - return redirect(url_for("login")) - else: - # Log the visit - if request.headers.get("X-Real-IP"): - ip = request.headers.get("X-Real-IP").split(",")[0] - else: - ip = request.remote_addr - user_agent = request.headers.get("User-Agent") - log(link, ip, user_agent) - db.close() - return redirect(link_record.redirect_link) - - -@app.errorhandler(401) -def unauthorized(e): - return redirect(url_for("login")) - - -@app.errorhandler(404) -def not_found(e): - return redirect(url_for("login")) +templates = Jinja2Templates(directory="app/templates") + +# Import routes +app.include_router(links_router, prefix="/api") +# Must not have a prefix... for some reason you can't change +# the prefix of the Swagger UI OAuth2 redirect to /api/token +# you can only change it to /token, so we have to remove the +# prefix in order to keep logging in via Swagger UI working +app.include_router(token_router) +app.include_router(refresh_router, prefix="/api") + + +@app.get("/login") +async def login(request: Request): + return templates.TemplateResponse("login.html", {"request": request}) + + +# Handle login requests through Swagger UI + + +# @app.route("/signup", methods=["GET", "POST"]) +# def signup(): +# if request.method == "POST": +# username = request.form["username"] +# password = request.form["password"] + +# # Verify the password meets requirements +# if len(password) < 8: +# return {"status": "Password must be at least 8 characters"} +# if not any(char.isdigit() for char in password): +# return {"status": "Password must contain at least one digit"} +# if not any(char.isupper() for char in password): +# return { +# "status": "Password must contain at least one uppercase letter" +# } + +# # Get database session +# db = SessionLocal() + +# user = db.query(User).filter(User.username == username).first() +# if user: +# db.close() +# return {"status": "Username not available"} +# # Add information to the database +# hashed_password = bcrypt.hashpw( +# password.encode("utf-8"), bcrypt.gensalt() +# ).decode("utf-8") +# api_key = "".join( +# random.choices(string.ascii_letters + string.digits, k=20) +# ) +# new_user = User( +# username=username, password=hashed_password, api_key=api_key +# ) +# db.add(new_user) +# db.commit() +# db.close() +# # Log in the newly created user +# flask_user = FlaskUser() +# flask_user.id = username +# login_user(flask_user) + +# return {"status": "success"} +# return render_template("signup.html") + + +@app.get("/dashboard") +async def dashboard( + response: Annotated[ + User, RedirectResponse, Depends(get_current_user_from_cookie) + ], + request: Request, +): + if isinstance(response, RedirectResponse): + return response + return templates.TemplateResponse( + "dashboard.html", {"request": request, "user": response.username} + ) + + +# @app.get("/{link}") +# async def log_redirect( +# link: Annotated[str, Path(title="Redirect link")], +# request: Request, +# db=Depends(get_db), +# ): +# link = link.upper() +# # If `link` is not exactly 5 characters, return redirect to base url +# if len(link) != 5: +# return RedirectResponse(url="/login") + +# # Make sure the link exists in the database +# link_record: Link = db.query(Link).filter(Link.link == link).first() +# if not link_record: +# db.close() +# return RedirectResponse(url="/login") +# else: +# # Log the visit +# if request.headers.get("X-Real-IP"): +# ip = request.headers.get("X-Real-IP").split(",")[0] +# else: +# ip = request.client.host +# user_agent = request.headers.get("User-Agent") +# log(link, ip, user_agent) +# db.close() +# return RedirectResponse(url=link_record.redirect_link) + + +# Redirect /api -> /api/docs +@app.get("/api") +async def redirect_to_docs(): + return RedirectResponse(url="/docs") + + +# Custom handler for 404 errors +@app.exception_handler(HTTP_404_NOT_FOUND) +async def custom_404_handler(request: Request, exc: HTTPException): + return RedirectResponse(url="/login") diff --git a/api/routes/links_route.py b/app/routes/links_route.py index 08e7690..054508a 100644 --- a/api/routes/links_route.py +++ b/app/routes/links_route.py @@ -1,4 +1,4 @@ -from fastapi import APIRouter, status, Path, Depends, Security, Request +from fastapi import APIRouter, status, Path, Depends from fastapi.exception_handlers import HTTPException from typing import Annotated import string @@ -6,11 +6,11 @@ import random import datetime import validators -from api.util.db_dependency import get_db +from app.util.db_dependency import get_db from models import Link, Record -from api.schemas.links_schemas import URLSchema -from api.schemas.auth_schemas import User -from api.util.authentication import get_current_user +from app.schemas.links_schemas import URLSchema +from app.schemas.auth_schemas import User +from app.util.authentication import get_current_user_from_token router = APIRouter(prefix="/links", tags=["links"]) @@ -18,7 +18,7 @@ router = APIRouter(prefix="/links", tags=["links"]) @router.get("/", summary="Get all of the links associated with your account") async def get_links( - current_user: Annotated[User, Depends(get_current_user)], + current_user: Annotated[User, Depends(get_current_user_from_token)], db=Depends(get_db), ): links = db.query(Link).filter(Link.owner == current_user.id).all() @@ -32,7 +32,7 @@ async def get_links( @router.post("/", summary="Create a new link") async def create_link( url: URLSchema, - current_user: Annotated[User, Depends(get_current_user)], + current_user: Annotated[User, Depends(get_current_user_from_token)], db=Depends(get_db), ): # Check if the URL is valid @@ -70,7 +70,7 @@ async def create_link( @router.delete("/{link}", summary="Delete a link") async def delete_link( link: Annotated[str, Path(title="Link to delete")], - current_user: Annotated[User, Depends(get_current_user)], + current_user: Annotated[User, Depends(get_current_user_from_token)], db=Depends(get_db), ): link = link.upper() @@ -103,7 +103,7 @@ async def delete_link( ) async def get_link_records( link: Annotated[str, Path(title="Link to get records for")], - current_user: Annotated[User, Depends(get_current_user)], + current_user: Annotated[User, Depends(get_current_user_from_token)], db=Depends(get_db), ): link = link.upper() @@ -130,7 +130,7 @@ async def get_link_records( ) async def delete_link_records( link: Annotated[str, Path(title="Link to delete records for")], - current_user: Annotated[User, Depends(get_current_user)], + current_user: Annotated[User, Depends(get_current_user_from_token)], db=Depends(get_db), ): link = link.upper() diff --git a/app/routes/refresh_route.py b/app/routes/refresh_route.py new file mode 100644 index 0000000..6bc8797 --- /dev/null +++ b/app/routes/refresh_route.py @@ -0,0 +1,33 @@ +from fastapi import Depends, APIRouter +from fastapi.responses import RedirectResponse +from datetime import timedelta +from typing import Annotated + +from app.util.authentication import ( + create_access_token, + refresh_get_current_user, +) +from app.schemas.auth_schemas import Token, User + + +router = APIRouter(prefix="/refresh", tags=["refresh"]) + + +# Full native JWT support is not complete in FastAPI yet :( +# Part of that is token refresh, so we must implement it ourselves +@router.post("/") +async def refresh_access_token( + current_user: Annotated[User, Depends(refresh_get_current_user)], +) -> Token: + """ + Return a new access token if the refresh token is valid + """ + access_token_expires = timedelta(minutes=30) + access_token = create_access_token( + data={"sub": current_user.username, "refresh": False}, + expires_delta=access_token_expires, + ) + return Token( + access_token=access_token, + token_type="bearer", + ) diff --git a/app/routes/token_route.py b/app/routes/token_route.py new file mode 100644 index 0000000..8000616 --- /dev/null +++ b/app/routes/token_route.py @@ -0,0 +1,54 @@ +from fastapi import APIRouter, status, Depends, HTTPException +from fastapi.responses import JSONResponse, Response +from typing import Annotated +from datetime import timedelta +from typing import Annotated +from fastapi.security import OAuth2PasswordRequestForm + +from app.util.db_dependency import get_db +from app.util.authentication import ( + authenticate_user, + create_access_token, +) +from app.schemas.auth_schemas import Token + + +router = APIRouter(prefix="/token", tags=["token"]) + + +@router.post("/") +async def login_for_access_token( + form_data: Annotated[OAuth2PasswordRequestForm, Depends()], + response: Response, + db=Depends(get_db), +) -> Token: + """ + Return an access token for the user, if the given authentication details are correct + """ + user = authenticate_user(db, form_data.username, form_data.password) + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Bearer"}, + ) + access_token_expires = timedelta(minutes=15) + access_token = create_access_token( + data={"sub": user.username, "refresh": False}, + expires_delta=access_token_expires, + ) + # Create a refresh token - just an access token with a longer expiry + # and more restrictions ("refresh" is True) + refresh_token_expires = timedelta(days=1) + refresh_token = create_access_token( + data={"sub": user.username, "refresh": True}, + expires_delta=refresh_token_expires, + ) + response = JSONResponse(content={"success": True}) + response.set_cookie( + key="access_token", value=access_token, httponly=True, samesite="lax" + ) + response.set_cookie( + key="refresh_token", value=refresh_token, httponly=True, samesite="lax" + ) + return response diff --git a/api/schemas/auth_schemas.py b/app/schemas/auth_schemas.py index 006a7c8..006a7c8 100644 --- a/api/schemas/auth_schemas.py +++ b/app/schemas/auth_schemas.py diff --git a/api/schemas/links_schemas.py b/app/schemas/links_schemas.py index e2812fb..e2812fb 100644 --- a/api/schemas/links_schemas.py +++ b/app/schemas/links_schemas.py diff --git a/app/templates/dashboard.html b/app/templates/dashboard.html index 2118fbf..f1c98e3 100644 --- a/app/templates/dashboard.html +++ b/app/templates/dashboard.html @@ -8,7 +8,7 @@ <body> <div> <!-- Create a small box that will hold the text for the users api key, next to the box should be a regenerate button --> - <p>Your API Key: <span id="api-key">{{ api_key }}</span></p> + <p>Your Username: <span id="api-key">{{ user }}</span></p> <button onclick="window.location.href='logout'">Logout</button> </div> </body> diff --git a/app/templates/login.html b/app/templates/login.html index b41d15c..1061699 100644 --- a/app/templates/login.html +++ b/app/templates/login.html @@ -90,19 +90,15 @@ event.preventDefault(); const formData = new FormData(this); - // Send POST request to /api/token containing form data - const response = await fetch('/api/token', { + // Send POST request to /token containing form data + const response = await fetch('/token', { method: 'POST', body: formData }); - const data = await response.json(); - if (data.response != 200) { + if (response.status != 200) { document.getElementById('error').style.display = 'block'; } else { - // Save the tokens in localStorage - window.localStorage.token = data.token; - window.localStorage.refreshToken = data.refreshToken; window.location.href = '/dashboard'; } }); diff --git a/api/util/authentication.py b/app/util/authentication.py index b8ac6a6..b94b1c6 100644 --- a/api/util/authentication.py +++ b/app/util/authentication.py @@ -1,14 +1,16 @@ import random import bcrypt -from fastapi import Depends, HTTPException, status +from fastapi import Depends, HTTPException, status, Cookie from fastapi.security import OAuth2PasswordBearer +from fastapi.responses import RedirectResponse from jwt.exceptions import InvalidTokenError from datetime import datetime, timedelta from typing import Annotated, Optional import jwt -from api.util.db_dependency import get_db -from api.schemas.auth_schemas import * +from app.util.db_dependency import get_db +from sqlalchemy.orm import sessionmaker +from app.schemas.auth_schemas import * from models import User as UserDB secret_key = random.randbytes(32) @@ -59,53 +61,82 @@ def create_access_token(data: dict, expires_delta: timedelta): return encoded_jwt -# Backwards kinda of way to get refresh token support -# 'refresh_get_current_user' is only called from /refresh -# and alerts 'current_user' that it should expect a refresh token -async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]): - user = await current_user(token) - return user +async def get_current_user_from_cookie( + access_token: str = Cookie(None), db=Depends(get_db) +): + """ + Return the user based on the access token in the cookie + + Used for authentication into UI pages - so if no cookie + exists, redirect to login page rather than returning a 401 + + Also pass is_ui=True to alert get_current_user that we need + to use RedirectResponse rather than raising an HTTPException + """ + if access_token: + return await get_current_user(access_token, is_ui=True, db=db) + return RedirectResponse(url="/login") + + +async def get_current_user_from_token( + token: Annotated[str, Depends(oauth2_scheme)], db=Depends(get_db) +): + return await get_current_user(token, db=db) +# Backwards kinda of way to get refresh token support +# `refresh_get_current_user` is only called from /refresh +# and alerts `get_current_user` that it should expect a refresh token async def refresh_get_current_user( - token: Annotated[str, Depends(oauth2_scheme)], + token: Annotated[str, Depends(oauth2_scheme)], db=Depends(get_db) ): - user = await current_user(token, is_refresh=True) - return user + return await get_current_user(token, is_refresh=True, db=db) -async def current_user( +async def get_current_user( token: str, is_refresh: bool = False, - db=Depends(get_db), + is_ui: bool = False, + db: Optional[sessionmaker] = None, ): """ - Return the current user based on the token, or raise a 401 error + Return the current user based on the token + + OR on error - + If is_ui=True, the request is from a UI page and we should redirect to login + Otherwise, the request is from an API and we should return a 401 """ - credentials_exception = HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, - ) + + def raise_unauthorized(): + if is_ui: + return RedirectResponse(url="/login") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + try: payload = jwt.decode(token, secret_key, algorithms=[algorithm]) username: str = payload.get("sub") refresh: bool = payload.get("refresh") if username is None: - raise credentials_exception + return raise_unauthorized() # For some reason, an access token was passed when a refresh # token was expected - some likely malicious activity if not refresh and is_refresh: - raise credentials_exception + return raise_unauthorized() # If the token passed is a refresh token and the function # is not expecting a refresh token, raise an error if refresh and not is_refresh: - raise credentials_exception + return raise_unauthorized() token_data = TokenData(username=username) except InvalidTokenError: - raise credentials_exception + return raise_unauthorized() + user = get_user(db, username=token_data.username) if user is None: - raise credentials_exception + return raise_unauthorized() + return user diff --git a/api/util/db_dependency.py b/app/util/db_dependency.py index a6734ea..a6734ea 100644 --- a/api/util/db_dependency.py +++ b/app/util/db_dependency.py diff --git a/linklogger.py b/linklogger.py index 90348e3..0e3f011 100644 --- a/linklogger.py +++ b/linklogger.py @@ -1,21 +1,9 @@ -from werkzeug.middleware.dispatcher import DispatcherMiddleware -from a2wsgi import ASGIMiddleware +import uvicorn import config -from app.main import app as flask_app -from api.main import app as fastapi_app -from database import Base, engine +from app.main import app -Base.metadata.create_all(bind=engine) - -flask_app.wsgi_app = DispatcherMiddleware( - flask_app.wsgi_app, - { - "/": flask_app, - "/api": ASGIMiddleware(fastapi_app), - }, -) if __name__ == "__main__": config.load_config() - flask_app.run(port=5252) + uvicorn.run(app, port=5252) |