diff options
Diffstat (limited to 'api/util')
-rw-r--r-- | api/util/authentication.py | 88 | ||||
-rw-r--r-- | api/util/check_api_key.py | 21 | ||||
-rw-r--r-- | api/util/validate_login_information.py | 20 |
3 files changed, 88 insertions, 41 deletions
diff --git a/api/util/authentication.py b/api/util/authentication.py new file mode 100644 index 0000000..4dfbc77 --- /dev/null +++ b/api/util/authentication.py @@ -0,0 +1,88 @@ +import random +import bcrypt +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer +from jwt.exceptions import InvalidTokenError +from datetime import datetime, timedelta +from typing import Annotated +import jwt + +from api.util.db_dependency import get_db +from api.schemas.auth_schemas import * +from models import User as UserDB + +secret_key = random.randbytes(32) +algorithm = "HS256" +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + +""" +Helper functions for authentication +""" + + +def verify_password(plain_password, hashed_password): + return bcrypt.checkpw( + plain_password.encode("utf-8"), hashed_password.encode("utf-8") + ) + + +def get_user(db, username: str): + """ + Get the user object from the database + """ + user = db.query(UserDB).filter(UserDB.username == username).first() + if user: + return UserInDB(**user.__dict__) + + +def authenticate_user(db, username: str, password: str): + """ + Determine if the correct username and password were provided + If so, return the user object + """ + user = get_user(db, username) + print(user) + if not user: + return False + if not verify_password(password, user.hashed_password): + return False + return user + + +def create_access_token(data: dict, expires_delta: timedelta | None = None): + """ + Return an encoded JWT token with the given data + """ + to_encode = data.copy() + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta(minutes=15) + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=algorithm) + return encoded_jwt + + +async def get_current_user( + token: Annotated[str, Depends(oauth2_scheme)], db=Depends(get_db) +): + """ + Return the current user based on the token, or raise a 401 error + """ + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + payload = jwt.decode(token, secret_key, algorithms=[algorithm]) + username: str = payload.get("sub") + if username is None: + raise credentials_exception + token_data = TokenData(username=username) + except InvalidTokenError: + raise credentials_exception + user = get_user(db, username=token_data.username) + if user is None: + raise credentials_exception + return user diff --git a/api/util/check_api_key.py b/api/util/check_api_key.py deleted file mode 100644 index 9c4c22e..0000000 --- a/api/util/check_api_key.py +++ /dev/null @@ -1,21 +0,0 @@ -from fastapi import Security, HTTPException, Depends, status -from fastapi.security import APIKeyHeader - -from models import User -from api.util.db_dependency import get_db - -""" -Make sure the provided API key is valid, then return the user's ID -""" -api_key_header = APIKeyHeader(name="X-API-Key") - - -def check_api_key( - api_key_header: str = Security(api_key_header), db=Depends(get_db) -) -> str: - response = db.query(User).filter(User.api_key == api_key_header).first() - if not response: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key" - ) - return {"value": api_key_header, "owner": response.id} diff --git a/api/util/validate_login_information.py b/api/util/validate_login_information.py deleted file mode 100644 index 55bbb2e..0000000 --- a/api/util/validate_login_information.py +++ /dev/null @@ -1,20 +0,0 @@ -import bcrypt -from fastapi import Depends - -from api.util.db_dependency import get_db -from models import User - -""" -Validate the login information provided by the user -""" - - -def validate_login_information( - username: str, password: str, db=Depends(get_db) -) -> bool: - user = db.query(User).filter(User.username == username).first() - if not user: - return False - if bcrypt.checkpw(password.encode("utf-8"), user.password.encode("utf-8")): - return True - return False |