diff options
Diffstat (limited to 'app')
-rw-r--r-- | app/main.py | 331 | ||||
-rw-r--r-- | app/routes/links_route.py | 155 | ||||
-rw-r--r-- | app/routes/refresh_route.py | 33 | ||||
-rw-r--r-- | app/routes/token_route.py | 54 | ||||
-rw-r--r-- | app/schemas/auth_schemas.py | 20 | ||||
-rw-r--r-- | app/schemas/links_schemas.py | 5 | ||||
-rw-r--r-- | app/templates/dashboard.html | 2 | ||||
-rw-r--r-- | app/templates/login.html | 10 | ||||
-rw-r--r-- | app/util/authentication.py | 142 | ||||
-rw-r--r-- | app/util/db_dependency.py | 9 |
10 files changed, 569 insertions, 192 deletions
diff --git a/app/main.py b/app/main.py index 78a65c8..c36d64a 100644 --- a/app/main.py +++ b/app/main.py @@ -1,187 +1,150 @@ -from flask_login import ( - current_user, - login_user, - login_required, - logout_user, - LoginManager, - UserMixin, +from fastapi import FastAPI, Path, Depends, Request +from fastapi.middleware.cors import CORSMiddleware +from fastapi.responses import RedirectResponse +from fastapi.templating import Jinja2Templates +from app.routes.links_route import router as links_router +from app.routes.refresh_route import router as refresh_router +from app.routes.token_route import router as token_router +from typing import Annotated +from fastapi.exceptions import HTTPException +from starlette.status import HTTP_404_NOT_FOUND + +from app.util.authentication import get_current_user_from_cookie +from app.schemas.auth_schemas import User + +app = FastAPI( + title="LinkLogger API", + version="1.0", + summary="Public API for a combined link shortener and IP logger", + license_info={ + "name": "The Unlicense", + "identifier": "Unlicense", + "url": "https://unlicense.org", + }, ) -from flask import Flask, redirect, render_template, request, url_for -import bcrypt -import os -import string -import random - -from models import User, Link -from database import * -from app.util.log import log - - -class FlaskUser(UserMixin): - pass - - -app = Flask(__name__) -app.config["SECRET_KEY"] = os.urandom(24) - -login_manager = LoginManager() -login_manager.init_app(app) - - -@login_manager.user_loader -def user_loader(username): - user = FlaskUser() - user.id = username - return user - - -""" -Handle login requests from the web UI -""" - - -@app.route("/login", methods=["GET", "POST"]) -def login(): - if request.method == "POST": - username = request.form["username"] - password = request.form["password"] - - # Get database session - db = SessionLocal() - - user = db.query(User).filter(User.username == username).first() - db.close() - if not user: - return {"status": "Invalid username or password"} - - if bcrypt.checkpw( - password.encode("utf-8"), user.password.encode("utf-8") - ): - flask_user = FlaskUser() - flask_user.id = username - login_user(flask_user) - return {"status": "success"} - - return {"status": "Invalid username or password"} - return render_template("login.html") - - -""" -Handle signup requests from the web UI -""" - - -@app.route("/signup", methods=["GET", "POST"]) -def signup(): - if request.method == "POST": - username = request.form["username"] - password = request.form["password"] - - # Verify the password meets requirements - if len(password) < 8: - return {"status": "Password must be at least 8 characters"} - if not any(char.isdigit() for char in password): - return {"status": "Password must contain at least one digit"} - if not any(char.isupper() for char in password): - return { - "status": "Password must contain at least one uppercase letter" - } - - # Get database session - db = SessionLocal() - - user = db.query(User).filter(User.username == username).first() - if user: - db.close() - return {"status": "Username not available"} - # Add information to the database - hashed_password = bcrypt.hashpw( - password.encode("utf-8"), bcrypt.gensalt() - ).decode("utf-8") - api_key = "".join( - random.choices(string.ascii_letters + string.digits, k=20) - ) - new_user = User( - username=username, password=hashed_password, api_key=api_key - ) - db.add(new_user) - db.commit() - db.close() - # Log in the newly created user - flask_user = FlaskUser() - flask_user.id = username - login_user(flask_user) - - return {"status": "success"} - return render_template("signup.html") +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_methods=["*"], + allow_headers=["*"], + allow_credentials=True, +) -""" -Load the 'dashboard' page for logged in users -""" - - -@app.route("/dashboard", methods=["GET"]) -@login_required -def dashboard(): - # Get database session - db = SessionLocal() - - # Get the API key for the current user - user = db.query(User).filter(User.username == current_user.id).first() - db.close() - api_key = user.api_key - - return render_template("dashboard.html", api_key=api_key) - - -""" -Log users out of their account -""" - - -@app.route("/logout", methods=["GET"]) -@login_required -def logout(): - logout_user() - return redirect(url_for("login")) - - -""" -Log all records for visits to shortened links -""" - - -@app.route("/<link>", methods=["GET"]) -def log_redirect(link: str): - link = link.upper() - # If `link` is not exactly 5 characters, return redirect to base url - if len(link) != 5: - return redirect(url_for("login")) - - # Make sure the link exists in the database - db = SessionLocal() - link_record = db.query(Link).filter(Link.link == link).first() - if not link_record: - db.close() - return redirect(url_for("login")) - else: - # Log the visit - if request.headers.get("X-Real-IP"): - ip = request.headers.get("X-Real-IP").split(",")[0] - else: - ip = request.remote_addr - user_agent = request.headers.get("User-Agent") - log(link, ip, user_agent) - db.close() - return redirect(link_record.redirect_link) - - -@app.errorhandler(401) -def unauthorized(e): - return redirect(url_for("login")) - - -@app.errorhandler(404) -def not_found(e): - return redirect(url_for("login")) +templates = Jinja2Templates(directory="app/templates") + +# Import routes +app.include_router(links_router, prefix="/api") +# Must not have a prefix... for some reason you can't change +# the prefix of the Swagger UI OAuth2 redirect to /api/token +# you can only change it to /token, so we have to remove the +# prefix in order to keep logging in via Swagger UI working +app.include_router(token_router) +app.include_router(refresh_router, prefix="/api") + + +@app.get("/login") +async def login(request: Request): + return templates.TemplateResponse("login.html", {"request": request}) + + +# Handle login requests through Swagger UI + + +# @app.route("/signup", methods=["GET", "POST"]) +# def signup(): +# if request.method == "POST": +# username = request.form["username"] +# password = request.form["password"] + +# # Verify the password meets requirements +# if len(password) < 8: +# return {"status": "Password must be at least 8 characters"} +# if not any(char.isdigit() for char in password): +# return {"status": "Password must contain at least one digit"} +# if not any(char.isupper() for char in password): +# return { +# "status": "Password must contain at least one uppercase letter" +# } + +# # Get database session +# db = SessionLocal() + +# user = db.query(User).filter(User.username == username).first() +# if user: +# db.close() +# return {"status": "Username not available"} +# # Add information to the database +# hashed_password = bcrypt.hashpw( +# password.encode("utf-8"), bcrypt.gensalt() +# ).decode("utf-8") +# api_key = "".join( +# random.choices(string.ascii_letters + string.digits, k=20) +# ) +# new_user = User( +# username=username, password=hashed_password, api_key=api_key +# ) +# db.add(new_user) +# db.commit() +# db.close() +# # Log in the newly created user +# flask_user = FlaskUser() +# flask_user.id = username +# login_user(flask_user) + +# return {"status": "success"} +# return render_template("signup.html") + + +@app.get("/dashboard") +async def dashboard( + response: Annotated[ + User, RedirectResponse, Depends(get_current_user_from_cookie) + ], + request: Request, +): + if isinstance(response, RedirectResponse): + return response + return templates.TemplateResponse( + "dashboard.html", {"request": request, "user": response.username} + ) + + +# @app.get("/{link}") +# async def log_redirect( +# link: Annotated[str, Path(title="Redirect link")], +# request: Request, +# db=Depends(get_db), +# ): +# link = link.upper() +# # If `link` is not exactly 5 characters, return redirect to base url +# if len(link) != 5: +# return RedirectResponse(url="/login") + +# # Make sure the link exists in the database +# link_record: Link = db.query(Link).filter(Link.link == link).first() +# if not link_record: +# db.close() +# return RedirectResponse(url="/login") +# else: +# # Log the visit +# if request.headers.get("X-Real-IP"): +# ip = request.headers.get("X-Real-IP").split(",")[0] +# else: +# ip = request.client.host +# user_agent = request.headers.get("User-Agent") +# log(link, ip, user_agent) +# db.close() +# return RedirectResponse(url=link_record.redirect_link) + + +# Redirect /api -> /api/docs +@app.get("/api") +async def redirect_to_docs(): + return RedirectResponse(url="/docs") + + +# Custom handler for 404 errors +@app.exception_handler(HTTP_404_NOT_FOUND) +async def custom_404_handler(request: Request, exc: HTTPException): + return RedirectResponse(url="/login") diff --git a/app/routes/links_route.py b/app/routes/links_route.py new file mode 100644 index 0000000..054508a --- /dev/null +++ b/app/routes/links_route.py @@ -0,0 +1,155 @@ +from fastapi import APIRouter, status, Path, Depends +from fastapi.exception_handlers import HTTPException +from typing import Annotated +import string +import random +import datetime +import validators + +from app.util.db_dependency import get_db +from models import Link, Record +from app.schemas.links_schemas import URLSchema +from app.schemas.auth_schemas import User +from app.util.authentication import get_current_user_from_token + + +router = APIRouter(prefix="/links", tags=["links"]) + + +@router.get("/", summary="Get all of the links associated with your account") +async def get_links( + current_user: Annotated[User, Depends(get_current_user_from_token)], + db=Depends(get_db), +): + links = db.query(Link).filter(Link.owner == current_user.id).all() + if not links: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="No links found" + ) + return links + + +@router.post("/", summary="Create a new link") +async def create_link( + url: URLSchema, + current_user: Annotated[User, Depends(get_current_user_from_token)], + db=Depends(get_db), +): + # Check if the URL is valid + if not validators.url(url.url): + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail="Invalid URL", + ) + # Create the new link and add it to the database + while True: + try: + link_path = "".join( + random.choices(string.ascii_uppercase + "1234567890", k=5) + ).upper() + new_link = Link( + link=link_path, + owner=current_user.id, + redirect_link=url.url, + expire_date=datetime.datetime.now() + + datetime.timedelta(days=30), + ) + db.add(new_link) + db.commit() + break + except: + continue + + return { + "response": "Link successfully created", + "expire_date": new_link.expire_date, + "link": new_link.link, + } + + +@router.delete("/{link}", summary="Delete a link") +async def delete_link( + link: Annotated[str, Path(title="Link to delete")], + current_user: Annotated[User, Depends(get_current_user_from_token)], + db=Depends(get_db), +): + link = link.upper() + # Get the link and check the owner + link = db.query(Link).filter(Link.link == link).first() + if not link: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Link not found" + ) + if link.owner != current_user.id: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Link not associated with your account", + ) + + # Get and delete all records associated with the link + records = db.query(Record).filter(Record.link == link.link).all() + for record in records: + db.delete(record) + # Delete the link + db.delete(link) + db.commit() + + return {"response": "Link successfully deleted", "link": link.link} + + +@router.get( + "/{link}/records", + summary="Get all of the IP log records associated with a link", +) +async def get_link_records( + link: Annotated[str, Path(title="Link to get records for")], + current_user: Annotated[User, Depends(get_current_user_from_token)], + db=Depends(get_db), +): + link = link.upper() + # Get the link and check the owner + link = db.query(Link).filter(Link.link == link).first() + if not link: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Link not found" + ) + if link.owner != current_user.id: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Link not associated with your account", + ) + + # Get and return all of the records associated with the link + records = db.query(Record).filter(Record.link == link.link).all() + return records + + +@router.delete( + "/{link}/records", + summary="Delete all of the IP log records associated with a link", +) +async def delete_link_records( + link: Annotated[str, Path(title="Link to delete records for")], + current_user: Annotated[User, Depends(get_current_user_from_token)], + db=Depends(get_db), +): + link = link.upper() + # Get the link and check the owner + link = db.query(Link).filter(Link.link == link).first() + if not link: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, detail="Link not found" + ) + if link.owner != current_user.id: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Link not associated with your account", + ) + + # Get all of the records associated with the link and delete them + records = db.query(Record).filter(Record.link == link.link).all() + for record in records: + db.delete(record) + db.commit() + + return {"response": "Records successfully deleted", "link": link.link} diff --git a/app/routes/refresh_route.py b/app/routes/refresh_route.py new file mode 100644 index 0000000..6bc8797 --- /dev/null +++ b/app/routes/refresh_route.py @@ -0,0 +1,33 @@ +from fastapi import Depends, APIRouter +from fastapi.responses import RedirectResponse +from datetime import timedelta +from typing import Annotated + +from app.util.authentication import ( + create_access_token, + refresh_get_current_user, +) +from app.schemas.auth_schemas import Token, User + + +router = APIRouter(prefix="/refresh", tags=["refresh"]) + + +# Full native JWT support is not complete in FastAPI yet :( +# Part of that is token refresh, so we must implement it ourselves +@router.post("/") +async def refresh_access_token( + current_user: Annotated[User, Depends(refresh_get_current_user)], +) -> Token: + """ + Return a new access token if the refresh token is valid + """ + access_token_expires = timedelta(minutes=30) + access_token = create_access_token( + data={"sub": current_user.username, "refresh": False}, + expires_delta=access_token_expires, + ) + return Token( + access_token=access_token, + token_type="bearer", + ) diff --git a/app/routes/token_route.py b/app/routes/token_route.py new file mode 100644 index 0000000..8000616 --- /dev/null +++ b/app/routes/token_route.py @@ -0,0 +1,54 @@ +from fastapi import APIRouter, status, Depends, HTTPException +from fastapi.responses import JSONResponse, Response +from typing import Annotated +from datetime import timedelta +from typing import Annotated +from fastapi.security import OAuth2PasswordRequestForm + +from app.util.db_dependency import get_db +from app.util.authentication import ( + authenticate_user, + create_access_token, +) +from app.schemas.auth_schemas import Token + + +router = APIRouter(prefix="/token", tags=["token"]) + + +@router.post("/") +async def login_for_access_token( + form_data: Annotated[OAuth2PasswordRequestForm, Depends()], + response: Response, + db=Depends(get_db), +) -> Token: + """ + Return an access token for the user, if the given authentication details are correct + """ + user = authenticate_user(db, form_data.username, form_data.password) + if not user: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Bearer"}, + ) + access_token_expires = timedelta(minutes=15) + access_token = create_access_token( + data={"sub": user.username, "refresh": False}, + expires_delta=access_token_expires, + ) + # Create a refresh token - just an access token with a longer expiry + # and more restrictions ("refresh" is True) + refresh_token_expires = timedelta(days=1) + refresh_token = create_access_token( + data={"sub": user.username, "refresh": True}, + expires_delta=refresh_token_expires, + ) + response = JSONResponse(content={"success": True}) + response.set_cookie( + key="access_token", value=access_token, httponly=True, samesite="lax" + ) + response.set_cookie( + key="refresh_token", value=refresh_token, httponly=True, samesite="lax" + ) + return response diff --git a/app/schemas/auth_schemas.py b/app/schemas/auth_schemas.py new file mode 100644 index 0000000..006a7c8 --- /dev/null +++ b/app/schemas/auth_schemas.py @@ -0,0 +1,20 @@ +from pydantic import BaseModel + + +class Token(BaseModel): + access_token: str + refresh_token: str | None = None + token_type: str + + +class TokenData(BaseModel): + username: str | None = None + + +class User(BaseModel): + username: str + id: int + + +class UserInDB(User): + hashed_password: str diff --git a/app/schemas/links_schemas.py b/app/schemas/links_schemas.py new file mode 100644 index 0000000..e2812fb --- /dev/null +++ b/app/schemas/links_schemas.py @@ -0,0 +1,5 @@ +from pydantic import BaseModel + + +class URLSchema(BaseModel): + url: str diff --git a/app/templates/dashboard.html b/app/templates/dashboard.html index 2118fbf..f1c98e3 100644 --- a/app/templates/dashboard.html +++ b/app/templates/dashboard.html @@ -8,7 +8,7 @@ <body> <div> <!-- Create a small box that will hold the text for the users api key, next to the box should be a regenerate button --> - <p>Your API Key: <span id="api-key">{{ api_key }}</span></p> + <p>Your Username: <span id="api-key">{{ user }}</span></p> <button onclick="window.location.href='logout'">Logout</button> </div> </body> diff --git a/app/templates/login.html b/app/templates/login.html index b41d15c..1061699 100644 --- a/app/templates/login.html +++ b/app/templates/login.html @@ -90,19 +90,15 @@ event.preventDefault(); const formData = new FormData(this); - // Send POST request to /api/token containing form data - const response = await fetch('/api/token', { + // Send POST request to /token containing form data + const response = await fetch('/token', { method: 'POST', body: formData }); - const data = await response.json(); - if (data.response != 200) { + if (response.status != 200) { document.getElementById('error').style.display = 'block'; } else { - // Save the tokens in localStorage - window.localStorage.token = data.token; - window.localStorage.refreshToken = data.refreshToken; window.location.href = '/dashboard'; } }); diff --git a/app/util/authentication.py b/app/util/authentication.py new file mode 100644 index 0000000..b94b1c6 --- /dev/null +++ b/app/util/authentication.py @@ -0,0 +1,142 @@ +import random +import bcrypt +from fastapi import Depends, HTTPException, status, Cookie +from fastapi.security import OAuth2PasswordBearer +from fastapi.responses import RedirectResponse +from jwt.exceptions import InvalidTokenError +from datetime import datetime, timedelta +from typing import Annotated, Optional +import jwt + +from app.util.db_dependency import get_db +from sqlalchemy.orm import sessionmaker +from app.schemas.auth_schemas import * +from models import User as UserDB + +secret_key = random.randbytes(32) +algorithm = "HS256" +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + +""" +Helper functions for authentication +""" + + +def verify_password(plain_password, hashed_password): + return bcrypt.checkpw( + plain_password.encode("utf-8"), hashed_password.encode("utf-8") + ) + + +def get_user(db, username: str): + """ + Get the user object from the database + """ + user = db.query(UserDB).filter(UserDB.username == username).first() + if user: + return UserInDB(**user.__dict__) + + +def authenticate_user(db, username: str, password: str): + """ + Determine if the correct username and password were provided + If so, return the user object + """ + user = get_user(db, username) + if not user: + return False + if not verify_password(password, user.hashed_password): + return False + return user + + +def create_access_token(data: dict, expires_delta: timedelta): + """ + Return an encoded JWT token with the given data + """ + to_encode = data.copy() + expire = datetime.utcnow() + expires_delta + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=algorithm) + return encoded_jwt + + +async def get_current_user_from_cookie( + access_token: str = Cookie(None), db=Depends(get_db) +): + """ + Return the user based on the access token in the cookie + + Used for authentication into UI pages - so if no cookie + exists, redirect to login page rather than returning a 401 + + Also pass is_ui=True to alert get_current_user that we need + to use RedirectResponse rather than raising an HTTPException + """ + if access_token: + return await get_current_user(access_token, is_ui=True, db=db) + return RedirectResponse(url="/login") + + +async def get_current_user_from_token( + token: Annotated[str, Depends(oauth2_scheme)], db=Depends(get_db) +): + return await get_current_user(token, db=db) + + +# Backwards kinda of way to get refresh token support +# `refresh_get_current_user` is only called from /refresh +# and alerts `get_current_user` that it should expect a refresh token +async def refresh_get_current_user( + token: Annotated[str, Depends(oauth2_scheme)], db=Depends(get_db) +): + return await get_current_user(token, is_refresh=True, db=db) + + +async def get_current_user( + token: str, + is_refresh: bool = False, + is_ui: bool = False, + db: Optional[sessionmaker] = None, +): + """ + Return the current user based on the token + + OR on error - + If is_ui=True, the request is from a UI page and we should redirect to login + Otherwise, the request is from an API and we should return a 401 + """ + + def raise_unauthorized(): + if is_ui: + return RedirectResponse(url="/login") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + + try: + payload = jwt.decode(token, secret_key, algorithms=[algorithm]) + username: str = payload.get("sub") + refresh: bool = payload.get("refresh") + if username is None: + return raise_unauthorized() + # For some reason, an access token was passed when a refresh + # token was expected - some likely malicious activity + if not refresh and is_refresh: + return raise_unauthorized() + # If the token passed is a refresh token and the function + # is not expecting a refresh token, raise an error + if refresh and not is_refresh: + return raise_unauthorized() + + token_data = TokenData(username=username) + except InvalidTokenError: + return raise_unauthorized() + + user = get_user(db, username=token_data.username) + if user is None: + return raise_unauthorized() + + return user diff --git a/app/util/db_dependency.py b/app/util/db_dependency.py new file mode 100644 index 0000000..a6734ea --- /dev/null +++ b/app/util/db_dependency.py @@ -0,0 +1,9 @@ +from database import SessionLocal + + +def get_db(): + db = SessionLocal() + try: + yield db + finally: + db.close() |