aboutsummaryrefslogtreecommitdiff
path: root/api/util
diff options
context:
space:
mode:
Diffstat (limited to 'api/util')
-rw-r--r--api/util/authentication.py88
-rw-r--r--api/util/check_api_key.py21
-rw-r--r--api/util/validate_login_information.py20
3 files changed, 88 insertions, 41 deletions
diff --git a/api/util/authentication.py b/api/util/authentication.py
new file mode 100644
index 0000000..4dfbc77
--- /dev/null
+++ b/api/util/authentication.py
@@ -0,0 +1,88 @@
+import random
+import bcrypt
+from fastapi import Depends, HTTPException, status
+from fastapi.security import OAuth2PasswordBearer
+from jwt.exceptions import InvalidTokenError
+from datetime import datetime, timedelta
+from typing import Annotated
+import jwt
+
+from api.util.db_dependency import get_db
+from api.schemas.auth_schemas import *
+from models import User as UserDB
+
+secret_key = random.randbytes(32)
+algorithm = "HS256"
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+
+"""
+Helper functions for authentication
+"""
+
+
+def verify_password(plain_password, hashed_password):
+ return bcrypt.checkpw(
+ plain_password.encode("utf-8"), hashed_password.encode("utf-8")
+ )
+
+
+def get_user(db, username: str):
+ """
+ Get the user object from the database
+ """
+ user = db.query(UserDB).filter(UserDB.username == username).first()
+ if user:
+ return UserInDB(**user.__dict__)
+
+
+def authenticate_user(db, username: str, password: str):
+ """
+ Determine if the correct username and password were provided
+ If so, return the user object
+ """
+ user = get_user(db, username)
+ print(user)
+ if not user:
+ return False
+ if not verify_password(password, user.hashed_password):
+ return False
+ return user
+
+
+def create_access_token(data: dict, expires_delta: timedelta | None = None):
+ """
+ Return an encoded JWT token with the given data
+ """
+ to_encode = data.copy()
+ if expires_delta:
+ expire = datetime.utcnow() + expires_delta
+ else:
+ expire = datetime.utcnow() + timedelta(minutes=15)
+ to_encode.update({"exp": expire})
+ encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=algorithm)
+ return encoded_jwt
+
+
+async def get_current_user(
+ token: Annotated[str, Depends(oauth2_scheme)], db=Depends(get_db)
+):
+ """
+ Return the current user based on the token, or raise a 401 error
+ """
+ credentials_exception = HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Could not validate credentials",
+ headers={"WWW-Authenticate": "Bearer"},
+ )
+ try:
+ payload = jwt.decode(token, secret_key, algorithms=[algorithm])
+ username: str = payload.get("sub")
+ if username is None:
+ raise credentials_exception
+ token_data = TokenData(username=username)
+ except InvalidTokenError:
+ raise credentials_exception
+ user = get_user(db, username=token_data.username)
+ if user is None:
+ raise credentials_exception
+ return user
diff --git a/api/util/check_api_key.py b/api/util/check_api_key.py
deleted file mode 100644
index 9c4c22e..0000000
--- a/api/util/check_api_key.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from fastapi import Security, HTTPException, Depends, status
-from fastapi.security import APIKeyHeader
-
-from models import User
-from api.util.db_dependency import get_db
-
-"""
-Make sure the provided API key is valid, then return the user's ID
-"""
-api_key_header = APIKeyHeader(name="X-API-Key")
-
-
-def check_api_key(
- api_key_header: str = Security(api_key_header), db=Depends(get_db)
-) -> str:
- response = db.query(User).filter(User.api_key == api_key_header).first()
- if not response:
- raise HTTPException(
- status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
- )
- return {"value": api_key_header, "owner": response.id}
diff --git a/api/util/validate_login_information.py b/api/util/validate_login_information.py
deleted file mode 100644
index 55bbb2e..0000000
--- a/api/util/validate_login_information.py
+++ /dev/null
@@ -1,20 +0,0 @@
-import bcrypt
-from fastapi import Depends
-
-from api.util.db_dependency import get_db
-from models import User
-
-"""
-Validate the login information provided by the user
-"""
-
-
-def validate_login_information(
- username: str, password: str, db=Depends(get_db)
-) -> bool:
- user = db.query(User).filter(User.username == username).first()
- if not user:
- return False
- if bcrypt.checkpw(password.encode("utf-8"), user.password.encode("utf-8")):
- return True
- return False