diff options
Diffstat (limited to 'app/util')
-rw-r--r-- | app/util/authentication.py | 127 | ||||
-rw-r--r-- | app/util/check_password_reqs.py | 26 | ||||
-rw-r--r-- | app/util/db_dependency.py | 9 | ||||
-rw-r--r-- | app/util/log.py | 84 |
4 files changed, 0 insertions, 246 deletions
diff --git a/app/util/authentication.py b/app/util/authentication.py deleted file mode 100644 index a8f7aff..0000000 --- a/app/util/authentication.py +++ /dev/null @@ -1,127 +0,0 @@ -import random -import bcrypt -from fastapi import Depends, HTTPException, status, Request, Cookie -from fastapi.security import OAuth2PasswordBearer -from fastapi.responses import RedirectResponse -from jwt.exceptions import InvalidTokenError -from datetime import datetime, timedelta -from typing import Annotated, Optional -import jwt - -from app.util.db_dependency import get_db -from sqlalchemy.orm import Session -from app.schemas.auth_schemas import * -from models import User as UserModel - -secret_key = random.randbytes(32) -algorithm = "HS256" -oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/auth/token") - -""" -Helper functions for authentication -""" - - -def verify_password(plain_password, hashed_password): - return bcrypt.checkpw( - plain_password.encode("utf-8"), hashed_password.encode("utf-8") - ) - - -def get_user(db, username: str): - """ - Get the user object from the database - """ - user = db.query(UserModel).filter(UserModel.username == username).first() - if user: - return UserInDB(**user.__dict__) - - -def authenticate_user(db, username: str, password: str): - """ - Determine if the correct username and password were provided - If so, return the user object - """ - user = get_user(db, username) - if not user: - return False - if not verify_password(password, user.hashed_password): - print("WHY") - return False - return user - - -def create_access_token(data: dict, expires_delta: timedelta): - """ - Return an encoded JWT token with the given data - """ - to_encode = data.copy() - expire = datetime.utcnow() + expires_delta - to_encode.update({"exp": expire}) - encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=algorithm) - return encoded_jwt - - -# Backwards kind of way to get refresh token support -# `refresh_get_current_user` is only called from /refresh -# and alerts `get_current_user` that it should expect a refresh token -async def refresh_get_current_user( - token: Annotated[str, Depends(oauth2_scheme)], - db=Depends(get_db), -): - return await get_current_user(token, is_refresh=True, db=db) - - -def process_refresh_token(token: str, db: Session): - return False - - -async def get_current_user( - request: Request, - db=Depends(get_db), -): - """ - Return the current user based on the token - - OR on error - - If is_ui=True, the request is from a UI page and we should redirect to login - Otherwise, the request is from an API and we should return a 401 - """ - - def raise_unauthorized(): - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not validate credentials", - headers={"WWW-Authenticate": "Bearer"}, - ) - - # If the request is from /api/auth/refresh, it is a request to get - # a new access token using a refresh token - if request.url.path == "/api/auth/refresh": - token = request.cookies.get("refresh_token") - user = process_refresh_token(token, db) - if user is None: - raise_unauthorized() - else: - token = request.cookies.get("access_token") - - try: - payload = jwt.decode(token, secret_key, algorithms=[algorithm]) - id: int = payload.get("sub") - username: str = payload.get("username") - refresh: bool = payload.get("refresh") - if not id or not username: - return raise_unauthorized() - - # Make sure that a refresh token was not passed to any other endpoint - if refresh and not is_refresh: - return raise_unauthorized() - - except InvalidTokenError: - return raise_unauthorized() - - user = get_user(db, username) - if user is None: - return raise_unauthorized() - - return user diff --git a/app/util/check_password_reqs.py b/app/util/check_password_reqs.py deleted file mode 100644 index dcb9bf8..0000000 --- a/app/util/check_password_reqs.py +++ /dev/null @@ -1,26 +0,0 @@ -from fastapi import HTTPException, status - - -def check_password_reqs(password: str): - """ - Make sure the entered password meets the security requirements: - 1. At least 8 characters - 2. At least one digit - 3. At least one uppercase letter - """ - if len(password) < 8: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Password must be at least 8 characters", - ) - if not any(char.isdigit() for char in password): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Password must contain at least one digit", - ) - if not any(char.isupper() for char in password): - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Password must contain at least one uppercase letter", - ) - return diff --git a/app/util/db_dependency.py b/app/util/db_dependency.py deleted file mode 100644 index a6734ea..0000000 --- a/app/util/db_dependency.py +++ /dev/null @@ -1,9 +0,0 @@ -from database import SessionLocal - - -def get_db(): - db = SessionLocal() - try: - yield db - finally: - db.close() diff --git a/app/util/log.py b/app/util/log.py deleted file mode 100644 index 58a56f9..0000000 --- a/app/util/log.py +++ /dev/null @@ -1,84 +0,0 @@ -import requests -import datetime -from ua_parser import user_agent_parser - -from database import SessionLocal -import config -from models import Link, Log - -""" -Create a new log whenever a link is visited -""" - - -def ip_to_location(ip): - if not config.IP_TO_LOCATION: - return "-, -", "-" - - url = f"https://api.ip2location.io/?key={config.API_KEY}&ip={ip}" - response = requests.get(url) - data = response.json() - - if response.status_code != 200: - config.LOG.error( - "Error with IP2Location API. Perhaps the API is down." - ) - return "-, -", "-" - - if "error" in data: - config.LOG.error( - "Error with IP2Location API. Likely wrong API key or insufficient" - " funds." - ) - return "-, -", "-" - - location = "" - # Sometimes a certain name may not be present, so always check - if "city_name" in data: - location += data["city_name"] - - if "region_name" in data: - location += f', {data["region_name"]}' - - if "country_name" in data: - location += f', {data["country_name"]}' - - isp = data["as"] - return location, isp - - -def log(link, ip, user_agent): - db = SessionLocal() - - # Get the redirect link and owner of the link - redirect_link, owner = ( - db.query(Link.redirect_link, Link.owner) - .filter(Link.link == link) - .first() - ) - - # Get the location and ISP of the user - location, isp = ip_to_location(ip) - - ua_string = user_agent_parser.Parse(user_agent) - browser = ua_string["user_agent"]["family"] - os = f'{ua_string["os"]["family"]} {ua_string["os"]["major"]}' - - # Create the log and commit it to the database - new_log = Log( - owner=owner, - link=link, - timestamp=datetime.datetime.utcnow(), - ip=ip, - location=location, - browser=browser, - os=os, - user_agent=user_agent, - isp=isp, - ) - db.add(new_log) - db.commit() - db.close() - - # Return the redirect link in order to properly redirect the user - return redirect_link |