aboutsummaryrefslogtreecommitdiff
path: root/app
diff options
context:
space:
mode:
Diffstat (limited to 'app')
-rw-r--r--app/main.py331
-rw-r--r--app/routes/links_route.py155
-rw-r--r--app/routes/refresh_route.py33
-rw-r--r--app/routes/token_route.py54
-rw-r--r--app/schemas/auth_schemas.py20
-rw-r--r--app/schemas/links_schemas.py5
-rw-r--r--app/templates/dashboard.html2
-rw-r--r--app/templates/login.html10
-rw-r--r--app/util/authentication.py142
-rw-r--r--app/util/db_dependency.py9
10 files changed, 569 insertions, 192 deletions
diff --git a/app/main.py b/app/main.py
index 78a65c8..c36d64a 100644
--- a/app/main.py
+++ b/app/main.py
@@ -1,187 +1,150 @@
-from flask_login import (
- current_user,
- login_user,
- login_required,
- logout_user,
- LoginManager,
- UserMixin,
+from fastapi import FastAPI, Path, Depends, Request
+from fastapi.middleware.cors import CORSMiddleware
+from fastapi.responses import RedirectResponse
+from fastapi.templating import Jinja2Templates
+from app.routes.links_route import router as links_router
+from app.routes.refresh_route import router as refresh_router
+from app.routes.token_route import router as token_router
+from typing import Annotated
+from fastapi.exceptions import HTTPException
+from starlette.status import HTTP_404_NOT_FOUND
+
+from app.util.authentication import get_current_user_from_cookie
+from app.schemas.auth_schemas import User
+
+app = FastAPI(
+ title="LinkLogger API",
+ version="1.0",
+ summary="Public API for a combined link shortener and IP logger",
+ license_info={
+ "name": "The Unlicense",
+ "identifier": "Unlicense",
+ "url": "https://unlicense.org",
+ },
)
-from flask import Flask, redirect, render_template, request, url_for
-import bcrypt
-import os
-import string
-import random
-
-from models import User, Link
-from database import *
-from app.util.log import log
-
-
-class FlaskUser(UserMixin):
- pass
-
-
-app = Flask(__name__)
-app.config["SECRET_KEY"] = os.urandom(24)
-
-login_manager = LoginManager()
-login_manager.init_app(app)
-
-
-@login_manager.user_loader
-def user_loader(username):
- user = FlaskUser()
- user.id = username
- return user
-
-
-"""
-Handle login requests from the web UI
-"""
-
-
-@app.route("/login", methods=["GET", "POST"])
-def login():
- if request.method == "POST":
- username = request.form["username"]
- password = request.form["password"]
-
- # Get database session
- db = SessionLocal()
-
- user = db.query(User).filter(User.username == username).first()
- db.close()
- if not user:
- return {"status": "Invalid username or password"}
-
- if bcrypt.checkpw(
- password.encode("utf-8"), user.password.encode("utf-8")
- ):
- flask_user = FlaskUser()
- flask_user.id = username
- login_user(flask_user)
- return {"status": "success"}
-
- return {"status": "Invalid username or password"}
- return render_template("login.html")
-
-
-"""
-Handle signup requests from the web UI
-"""
-
-
-@app.route("/signup", methods=["GET", "POST"])
-def signup():
- if request.method == "POST":
- username = request.form["username"]
- password = request.form["password"]
-
- # Verify the password meets requirements
- if len(password) < 8:
- return {"status": "Password must be at least 8 characters"}
- if not any(char.isdigit() for char in password):
- return {"status": "Password must contain at least one digit"}
- if not any(char.isupper() for char in password):
- return {
- "status": "Password must contain at least one uppercase letter"
- }
-
- # Get database session
- db = SessionLocal()
-
- user = db.query(User).filter(User.username == username).first()
- if user:
- db.close()
- return {"status": "Username not available"}
- # Add information to the database
- hashed_password = bcrypt.hashpw(
- password.encode("utf-8"), bcrypt.gensalt()
- ).decode("utf-8")
- api_key = "".join(
- random.choices(string.ascii_letters + string.digits, k=20)
- )
- new_user = User(
- username=username, password=hashed_password, api_key=api_key
- )
- db.add(new_user)
- db.commit()
- db.close()
- # Log in the newly created user
- flask_user = FlaskUser()
- flask_user.id = username
- login_user(flask_user)
-
- return {"status": "success"}
- return render_template("signup.html")
+app.add_middleware(
+ CORSMiddleware,
+ allow_origins=["*"],
+ allow_methods=["*"],
+ allow_headers=["*"],
+ allow_credentials=True,
+)
-"""
-Load the 'dashboard' page for logged in users
-"""
-
-
-@app.route("/dashboard", methods=["GET"])
-@login_required
-def dashboard():
- # Get database session
- db = SessionLocal()
-
- # Get the API key for the current user
- user = db.query(User).filter(User.username == current_user.id).first()
- db.close()
- api_key = user.api_key
-
- return render_template("dashboard.html", api_key=api_key)
-
-
-"""
-Log users out of their account
-"""
-
-
-@app.route("/logout", methods=["GET"])
-@login_required
-def logout():
- logout_user()
- return redirect(url_for("login"))
-
-
-"""
-Log all records for visits to shortened links
-"""
-
-
-@app.route("/<link>", methods=["GET"])
-def log_redirect(link: str):
- link = link.upper()
- # If `link` is not exactly 5 characters, return redirect to base url
- if len(link) != 5:
- return redirect(url_for("login"))
-
- # Make sure the link exists in the database
- db = SessionLocal()
- link_record = db.query(Link).filter(Link.link == link).first()
- if not link_record:
- db.close()
- return redirect(url_for("login"))
- else:
- # Log the visit
- if request.headers.get("X-Real-IP"):
- ip = request.headers.get("X-Real-IP").split(",")[0]
- else:
- ip = request.remote_addr
- user_agent = request.headers.get("User-Agent")
- log(link, ip, user_agent)
- db.close()
- return redirect(link_record.redirect_link)
-
-
-@app.errorhandler(401)
-def unauthorized(e):
- return redirect(url_for("login"))
-
-
-@app.errorhandler(404)
-def not_found(e):
- return redirect(url_for("login"))
+templates = Jinja2Templates(directory="app/templates")
+
+# Import routes
+app.include_router(links_router, prefix="/api")
+# Must not have a prefix... for some reason you can't change
+# the prefix of the Swagger UI OAuth2 redirect to /api/token
+# you can only change it to /token, so we have to remove the
+# prefix in order to keep logging in via Swagger UI working
+app.include_router(token_router)
+app.include_router(refresh_router, prefix="/api")
+
+
+@app.get("/login")
+async def login(request: Request):
+ return templates.TemplateResponse("login.html", {"request": request})
+
+
+# Handle login requests through Swagger UI
+
+
+# @app.route("/signup", methods=["GET", "POST"])
+# def signup():
+# if request.method == "POST":
+# username = request.form["username"]
+# password = request.form["password"]
+
+# # Verify the password meets requirements
+# if len(password) < 8:
+# return {"status": "Password must be at least 8 characters"}
+# if not any(char.isdigit() for char in password):
+# return {"status": "Password must contain at least one digit"}
+# if not any(char.isupper() for char in password):
+# return {
+# "status": "Password must contain at least one uppercase letter"
+# }
+
+# # Get database session
+# db = SessionLocal()
+
+# user = db.query(User).filter(User.username == username).first()
+# if user:
+# db.close()
+# return {"status": "Username not available"}
+# # Add information to the database
+# hashed_password = bcrypt.hashpw(
+# password.encode("utf-8"), bcrypt.gensalt()
+# ).decode("utf-8")
+# api_key = "".join(
+# random.choices(string.ascii_letters + string.digits, k=20)
+# )
+# new_user = User(
+# username=username, password=hashed_password, api_key=api_key
+# )
+# db.add(new_user)
+# db.commit()
+# db.close()
+# # Log in the newly created user
+# flask_user = FlaskUser()
+# flask_user.id = username
+# login_user(flask_user)
+
+# return {"status": "success"}
+# return render_template("signup.html")
+
+
+@app.get("/dashboard")
+async def dashboard(
+ response: Annotated[
+ User, RedirectResponse, Depends(get_current_user_from_cookie)
+ ],
+ request: Request,
+):
+ if isinstance(response, RedirectResponse):
+ return response
+ return templates.TemplateResponse(
+ "dashboard.html", {"request": request, "user": response.username}
+ )
+
+
+# @app.get("/{link}")
+# async def log_redirect(
+# link: Annotated[str, Path(title="Redirect link")],
+# request: Request,
+# db=Depends(get_db),
+# ):
+# link = link.upper()
+# # If `link` is not exactly 5 characters, return redirect to base url
+# if len(link) != 5:
+# return RedirectResponse(url="/login")
+
+# # Make sure the link exists in the database
+# link_record: Link = db.query(Link).filter(Link.link == link).first()
+# if not link_record:
+# db.close()
+# return RedirectResponse(url="/login")
+# else:
+# # Log the visit
+# if request.headers.get("X-Real-IP"):
+# ip = request.headers.get("X-Real-IP").split(",")[0]
+# else:
+# ip = request.client.host
+# user_agent = request.headers.get("User-Agent")
+# log(link, ip, user_agent)
+# db.close()
+# return RedirectResponse(url=link_record.redirect_link)
+
+
+# Redirect /api -> /api/docs
+@app.get("/api")
+async def redirect_to_docs():
+ return RedirectResponse(url="/docs")
+
+
+# Custom handler for 404 errors
+@app.exception_handler(HTTP_404_NOT_FOUND)
+async def custom_404_handler(request: Request, exc: HTTPException):
+ return RedirectResponse(url="/login")
diff --git a/app/routes/links_route.py b/app/routes/links_route.py
new file mode 100644
index 0000000..054508a
--- /dev/null
+++ b/app/routes/links_route.py
@@ -0,0 +1,155 @@
+from fastapi import APIRouter, status, Path, Depends
+from fastapi.exception_handlers import HTTPException
+from typing import Annotated
+import string
+import random
+import datetime
+import validators
+
+from app.util.db_dependency import get_db
+from models import Link, Record
+from app.schemas.links_schemas import URLSchema
+from app.schemas.auth_schemas import User
+from app.util.authentication import get_current_user_from_token
+
+
+router = APIRouter(prefix="/links", tags=["links"])
+
+
+@router.get("/", summary="Get all of the links associated with your account")
+async def get_links(
+ current_user: Annotated[User, Depends(get_current_user_from_token)],
+ db=Depends(get_db),
+):
+ links = db.query(Link).filter(Link.owner == current_user.id).all()
+ if not links:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="No links found"
+ )
+ return links
+
+
+@router.post("/", summary="Create a new link")
+async def create_link(
+ url: URLSchema,
+ current_user: Annotated[User, Depends(get_current_user_from_token)],
+ db=Depends(get_db),
+):
+ # Check if the URL is valid
+ if not validators.url(url.url):
+ raise HTTPException(
+ status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+ detail="Invalid URL",
+ )
+ # Create the new link and add it to the database
+ while True:
+ try:
+ link_path = "".join(
+ random.choices(string.ascii_uppercase + "1234567890", k=5)
+ ).upper()
+ new_link = Link(
+ link=link_path,
+ owner=current_user.id,
+ redirect_link=url.url,
+ expire_date=datetime.datetime.now()
+ + datetime.timedelta(days=30),
+ )
+ db.add(new_link)
+ db.commit()
+ break
+ except:
+ continue
+
+ return {
+ "response": "Link successfully created",
+ "expire_date": new_link.expire_date,
+ "link": new_link.link,
+ }
+
+
+@router.delete("/{link}", summary="Delete a link")
+async def delete_link(
+ link: Annotated[str, Path(title="Link to delete")],
+ current_user: Annotated[User, Depends(get_current_user_from_token)],
+ db=Depends(get_db),
+):
+ link = link.upper()
+ # Get the link and check the owner
+ link = db.query(Link).filter(Link.link == link).first()
+ if not link:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Link not found"
+ )
+ if link.owner != current_user.id:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Link not associated with your account",
+ )
+
+ # Get and delete all records associated with the link
+ records = db.query(Record).filter(Record.link == link.link).all()
+ for record in records:
+ db.delete(record)
+ # Delete the link
+ db.delete(link)
+ db.commit()
+
+ return {"response": "Link successfully deleted", "link": link.link}
+
+
+@router.get(
+ "/{link}/records",
+ summary="Get all of the IP log records associated with a link",
+)
+async def get_link_records(
+ link: Annotated[str, Path(title="Link to get records for")],
+ current_user: Annotated[User, Depends(get_current_user_from_token)],
+ db=Depends(get_db),
+):
+ link = link.upper()
+ # Get the link and check the owner
+ link = db.query(Link).filter(Link.link == link).first()
+ if not link:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Link not found"
+ )
+ if link.owner != current_user.id:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Link not associated with your account",
+ )
+
+ # Get and return all of the records associated with the link
+ records = db.query(Record).filter(Record.link == link.link).all()
+ return records
+
+
+@router.delete(
+ "/{link}/records",
+ summary="Delete all of the IP log records associated with a link",
+)
+async def delete_link_records(
+ link: Annotated[str, Path(title="Link to delete records for")],
+ current_user: Annotated[User, Depends(get_current_user_from_token)],
+ db=Depends(get_db),
+):
+ link = link.upper()
+ # Get the link and check the owner
+ link = db.query(Link).filter(Link.link == link).first()
+ if not link:
+ raise HTTPException(
+ status_code=status.HTTP_404_NOT_FOUND, detail="Link not found"
+ )
+ if link.owner != current_user.id:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Link not associated with your account",
+ )
+
+ # Get all of the records associated with the link and delete them
+ records = db.query(Record).filter(Record.link == link.link).all()
+ for record in records:
+ db.delete(record)
+ db.commit()
+
+ return {"response": "Records successfully deleted", "link": link.link}
diff --git a/app/routes/refresh_route.py b/app/routes/refresh_route.py
new file mode 100644
index 0000000..6bc8797
--- /dev/null
+++ b/app/routes/refresh_route.py
@@ -0,0 +1,33 @@
+from fastapi import Depends, APIRouter
+from fastapi.responses import RedirectResponse
+from datetime import timedelta
+from typing import Annotated
+
+from app.util.authentication import (
+ create_access_token,
+ refresh_get_current_user,
+)
+from app.schemas.auth_schemas import Token, User
+
+
+router = APIRouter(prefix="/refresh", tags=["refresh"])
+
+
+# Full native JWT support is not complete in FastAPI yet :(
+# Part of that is token refresh, so we must implement it ourselves
+@router.post("/")
+async def refresh_access_token(
+ current_user: Annotated[User, Depends(refresh_get_current_user)],
+) -> Token:
+ """
+ Return a new access token if the refresh token is valid
+ """
+ access_token_expires = timedelta(minutes=30)
+ access_token = create_access_token(
+ data={"sub": current_user.username, "refresh": False},
+ expires_delta=access_token_expires,
+ )
+ return Token(
+ access_token=access_token,
+ token_type="bearer",
+ )
diff --git a/app/routes/token_route.py b/app/routes/token_route.py
new file mode 100644
index 0000000..8000616
--- /dev/null
+++ b/app/routes/token_route.py
@@ -0,0 +1,54 @@
+from fastapi import APIRouter, status, Depends, HTTPException
+from fastapi.responses import JSONResponse, Response
+from typing import Annotated
+from datetime import timedelta
+from typing import Annotated
+from fastapi.security import OAuth2PasswordRequestForm
+
+from app.util.db_dependency import get_db
+from app.util.authentication import (
+ authenticate_user,
+ create_access_token,
+)
+from app.schemas.auth_schemas import Token
+
+
+router = APIRouter(prefix="/token", tags=["token"])
+
+
+@router.post("/")
+async def login_for_access_token(
+ form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
+ response: Response,
+ db=Depends(get_db),
+) -> Token:
+ """
+ Return an access token for the user, if the given authentication details are correct
+ """
+ user = authenticate_user(db, form_data.username, form_data.password)
+ if not user:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Incorrect username or password",
+ headers={"WWW-Authenticate": "Bearer"},
+ )
+ access_token_expires = timedelta(minutes=15)
+ access_token = create_access_token(
+ data={"sub": user.username, "refresh": False},
+ expires_delta=access_token_expires,
+ )
+ # Create a refresh token - just an access token with a longer expiry
+ # and more restrictions ("refresh" is True)
+ refresh_token_expires = timedelta(days=1)
+ refresh_token = create_access_token(
+ data={"sub": user.username, "refresh": True},
+ expires_delta=refresh_token_expires,
+ )
+ response = JSONResponse(content={"success": True})
+ response.set_cookie(
+ key="access_token", value=access_token, httponly=True, samesite="lax"
+ )
+ response.set_cookie(
+ key="refresh_token", value=refresh_token, httponly=True, samesite="lax"
+ )
+ return response
diff --git a/app/schemas/auth_schemas.py b/app/schemas/auth_schemas.py
new file mode 100644
index 0000000..006a7c8
--- /dev/null
+++ b/app/schemas/auth_schemas.py
@@ -0,0 +1,20 @@
+from pydantic import BaseModel
+
+
+class Token(BaseModel):
+ access_token: str
+ refresh_token: str | None = None
+ token_type: str
+
+
+class TokenData(BaseModel):
+ username: str | None = None
+
+
+class User(BaseModel):
+ username: str
+ id: int
+
+
+class UserInDB(User):
+ hashed_password: str
diff --git a/app/schemas/links_schemas.py b/app/schemas/links_schemas.py
new file mode 100644
index 0000000..e2812fb
--- /dev/null
+++ b/app/schemas/links_schemas.py
@@ -0,0 +1,5 @@
+from pydantic import BaseModel
+
+
+class URLSchema(BaseModel):
+ url: str
diff --git a/app/templates/dashboard.html b/app/templates/dashboard.html
index 2118fbf..f1c98e3 100644
--- a/app/templates/dashboard.html
+++ b/app/templates/dashboard.html
@@ -8,7 +8,7 @@
<body>
<div>
<!-- Create a small box that will hold the text for the users api key, next to the box should be a regenerate button -->
- <p>Your API Key: <span id="api-key">{{ api_key }}</span></p>
+ <p>Your Username: <span id="api-key">{{ user }}</span></p>
<button onclick="window.location.href='logout'">Logout</button>
</div>
</body>
diff --git a/app/templates/login.html b/app/templates/login.html
index b41d15c..1061699 100644
--- a/app/templates/login.html
+++ b/app/templates/login.html
@@ -90,19 +90,15 @@
event.preventDefault();
const formData = new FormData(this);
- // Send POST request to /api/token containing form data
- const response = await fetch('/api/token', {
+ // Send POST request to /token containing form data
+ const response = await fetch('/token', {
method: 'POST',
body: formData
});
- const data = await response.json();
- if (data.response != 200) {
+ if (response.status != 200) {
document.getElementById('error').style.display = 'block';
} else {
- // Save the tokens in localStorage
- window.localStorage.token = data.token;
- window.localStorage.refreshToken = data.refreshToken;
window.location.href = '/dashboard';
}
});
diff --git a/app/util/authentication.py b/app/util/authentication.py
new file mode 100644
index 0000000..b94b1c6
--- /dev/null
+++ b/app/util/authentication.py
@@ -0,0 +1,142 @@
+import random
+import bcrypt
+from fastapi import Depends, HTTPException, status, Cookie
+from fastapi.security import OAuth2PasswordBearer
+from fastapi.responses import RedirectResponse
+from jwt.exceptions import InvalidTokenError
+from datetime import datetime, timedelta
+from typing import Annotated, Optional
+import jwt
+
+from app.util.db_dependency import get_db
+from sqlalchemy.orm import sessionmaker
+from app.schemas.auth_schemas import *
+from models import User as UserDB
+
+secret_key = random.randbytes(32)
+algorithm = "HS256"
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+
+"""
+Helper functions for authentication
+"""
+
+
+def verify_password(plain_password, hashed_password):
+ return bcrypt.checkpw(
+ plain_password.encode("utf-8"), hashed_password.encode("utf-8")
+ )
+
+
+def get_user(db, username: str):
+ """
+ Get the user object from the database
+ """
+ user = db.query(UserDB).filter(UserDB.username == username).first()
+ if user:
+ return UserInDB(**user.__dict__)
+
+
+def authenticate_user(db, username: str, password: str):
+ """
+ Determine if the correct username and password were provided
+ If so, return the user object
+ """
+ user = get_user(db, username)
+ if not user:
+ return False
+ if not verify_password(password, user.hashed_password):
+ return False
+ return user
+
+
+def create_access_token(data: dict, expires_delta: timedelta):
+ """
+ Return an encoded JWT token with the given data
+ """
+ to_encode = data.copy()
+ expire = datetime.utcnow() + expires_delta
+ to_encode.update({"exp": expire})
+ encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=algorithm)
+ return encoded_jwt
+
+
+async def get_current_user_from_cookie(
+ access_token: str = Cookie(None), db=Depends(get_db)
+):
+ """
+ Return the user based on the access token in the cookie
+
+ Used for authentication into UI pages - so if no cookie
+ exists, redirect to login page rather than returning a 401
+
+ Also pass is_ui=True to alert get_current_user that we need
+ to use RedirectResponse rather than raising an HTTPException
+ """
+ if access_token:
+ return await get_current_user(access_token, is_ui=True, db=db)
+ return RedirectResponse(url="/login")
+
+
+async def get_current_user_from_token(
+ token: Annotated[str, Depends(oauth2_scheme)], db=Depends(get_db)
+):
+ return await get_current_user(token, db=db)
+
+
+# Backwards kinda of way to get refresh token support
+# `refresh_get_current_user` is only called from /refresh
+# and alerts `get_current_user` that it should expect a refresh token
+async def refresh_get_current_user(
+ token: Annotated[str, Depends(oauth2_scheme)], db=Depends(get_db)
+):
+ return await get_current_user(token, is_refresh=True, db=db)
+
+
+async def get_current_user(
+ token: str,
+ is_refresh: bool = False,
+ is_ui: bool = False,
+ db: Optional[sessionmaker] = None,
+):
+ """
+ Return the current user based on the token
+
+ OR on error -
+ If is_ui=True, the request is from a UI page and we should redirect to login
+ Otherwise, the request is from an API and we should return a 401
+ """
+
+ def raise_unauthorized():
+ if is_ui:
+ return RedirectResponse(url="/login")
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Could not validate credentials",
+ headers={"WWW-Authenticate": "Bearer"},
+ )
+
+ try:
+ payload = jwt.decode(token, secret_key, algorithms=[algorithm])
+ username: str = payload.get("sub")
+ refresh: bool = payload.get("refresh")
+ if username is None:
+ return raise_unauthorized()
+ # For some reason, an access token was passed when a refresh
+ # token was expected - some likely malicious activity
+ if not refresh and is_refresh:
+ return raise_unauthorized()
+ # If the token passed is a refresh token and the function
+ # is not expecting a refresh token, raise an error
+ if refresh and not is_refresh:
+ return raise_unauthorized()
+
+ token_data = TokenData(username=username)
+ except InvalidTokenError:
+ return raise_unauthorized()
+
+ user = get_user(db, username=token_data.username)
+ if user is None:
+ return raise_unauthorized()
+
+ return user
diff --git a/app/util/db_dependency.py b/app/util/db_dependency.py
new file mode 100644
index 0000000..a6734ea
--- /dev/null
+++ b/app/util/db_dependency.py
@@ -0,0 +1,9 @@
+from database import SessionLocal
+
+
+def get_db():
+ db = SessionLocal()
+ try:
+ yield db
+ finally:
+ db.close()