diff options
author | Parker <contact@pkrm.dev> | 2024-11-04 00:12:36 -0600 |
---|---|---|
committer | Parker <contact@pkrm.dev> | 2024-11-04 00:12:36 -0600 |
commit | 8ae8c5c454ba42e8f56f415d33bbaaac7d1a37ec (patch) | |
tree | d56704d87f63b79681530ab729d9f54d24f73c80 /api/util/authentication.py | |
parent | 65fef6274166678f59d6d81c9da68465a7c374bc (diff) |
Remove API Keys -> Authenticate with JWT
Diffstat (limited to 'api/util/authentication.py')
-rw-r--r-- | api/util/authentication.py | 88 |
1 files changed, 88 insertions, 0 deletions
diff --git a/api/util/authentication.py b/api/util/authentication.py new file mode 100644 index 0000000..4dfbc77 --- /dev/null +++ b/api/util/authentication.py @@ -0,0 +1,88 @@ +import random +import bcrypt +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer +from jwt.exceptions import InvalidTokenError +from datetime import datetime, timedelta +from typing import Annotated +import jwt + +from api.util.db_dependency import get_db +from api.schemas.auth_schemas import * +from models import User as UserDB + +secret_key = random.randbytes(32) +algorithm = "HS256" +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + +""" +Helper functions for authentication +""" + + +def verify_password(plain_password, hashed_password): + return bcrypt.checkpw( + plain_password.encode("utf-8"), hashed_password.encode("utf-8") + ) + + +def get_user(db, username: str): + """ + Get the user object from the database + """ + user = db.query(UserDB).filter(UserDB.username == username).first() + if user: + return UserInDB(**user.__dict__) + + +def authenticate_user(db, username: str, password: str): + """ + Determine if the correct username and password were provided + If so, return the user object + """ + user = get_user(db, username) + print(user) + if not user: + return False + if not verify_password(password, user.hashed_password): + return False + return user + + +def create_access_token(data: dict, expires_delta: timedelta | None = None): + """ + Return an encoded JWT token with the given data + """ + to_encode = data.copy() + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta(minutes=15) + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=algorithm) + return encoded_jwt + + +async def get_current_user( + token: Annotated[str, Depends(oauth2_scheme)], db=Depends(get_db) +): + """ + Return the current user based on the token, or raise a 401 error + """ + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + payload = jwt.decode(token, secret_key, algorithms=[algorithm]) + username: str = payload.get("sub") + if username is None: + raise credentials_exception + token_data = TokenData(username=username) + except InvalidTokenError: + raise credentials_exception + user = get_user(db, username=token_data.username) + if user is None: + raise credentials_exception + return user |