From 8ae8c5c454ba42e8f56f415d33bbaaac7d1a37ec Mon Sep 17 00:00:00 2001 From: Parker Date: Mon, 4 Nov 2024 00:12:36 -0600 Subject: [PATCH] Remove API Keys -> Authenticate with JWT --- api/main.py | 52 ++++++++++----- api/routes/links_route.py | 23 +++---- api/schemas/auth_schemas.py | 19 ++++++ api/util/authentication.py | 88 ++++++++++++++++++++++++++ api/util/check_api_key.py | 21 ------ api/util/validate_login_information.py | 20 ------ models.py | 2 +- 7 files changed, 156 insertions(+), 69 deletions(-) create mode 100644 api/schemas/auth_schemas.py create mode 100644 api/util/authentication.py delete mode 100644 api/util/check_api_key.py delete mode 100644 api/util/validate_login_information.py diff --git a/api/main.py b/api/main.py index 6ede8ba..ac7b927 100644 --- a/api/main.py +++ b/api/main.py @@ -1,13 +1,20 @@ -from fastapi import FastAPI, Depends, HTTPException, Security +from fastapi import FastAPI, Depends, HTTPException, Security, status from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import RedirectResponse +from datetime import timedelta +from typing import Annotated +from fastapi.security import OAuth2PasswordRequestForm import string import random +from api.util.authentication import ( + authenticate_user, + create_access_token, + get_current_user, +) from api.routes.links_route import router as links_router from api.util.db_dependency import get_db -from api.util.check_api_key import check_api_key -from models import User +from api.schemas.auth_schemas import User, Token metadata_tags = [ @@ -37,22 +44,35 @@ app.add_middleware( # Import routes app.include_router(links_router) -# Regenerate the API key for the user -@app.post("/regenerate") -async def regenerate(api_key: str = Security(check_api_key), db = Depends(get_db)): - """Regenerate the API key for the user. Requires the current API key.""" - user = db.query(User).filter(User.api_key == api_key['value']).first() + +""" +Authentication +""" + + +@app.post("/token") +async def login_for_access_token( + form_data: Annotated[OAuth2PasswordRequestForm, Depends()], + db=Depends(get_db), +) -> Token: + """ + Return an access token for the user, if the given authentication details are correct + """ + user = authenticate_user(db, form_data.username, form_data.password) if not user: - raise HTTPException(status_code=401, detail="Invalid API key") + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect username or password", + headers={"WWW-Authenticate": "Bearer"}, + ) + access_token_expires = timedelta(minutes=30) + access_token = create_access_token( + data={"sub": user.username}, expires_delta=access_token_expires + ) + return Token(access_token=access_token, token_type="bearer") - # Generate a new API key - new_api_key = ''.join(random.choices(string.ascii_letters + string.digits, k=20)) - user.api_key = new_api_key - db.commit() - - return {"status": "success", "new_api_key": new_api_key} # Redirect /api -> /api/docs @app.get("/") async def redirect_to_docs(): - return RedirectResponse(url="/api/docs") \ No newline at end of file + return RedirectResponse(url="/api/docs") diff --git a/api/routes/links_route.py b/api/routes/links_route.py index 4385712..08e7690 100644 --- a/api/routes/links_route.py +++ b/api/routes/links_route.py @@ -7,9 +7,10 @@ import datetime import validators from api.util.db_dependency import get_db -from api.util.check_api_key import check_api_key from models import Link, Record from api.schemas.links_schemas import URLSchema +from api.schemas.auth_schemas import User +from api.util.authentication import get_current_user router = APIRouter(prefix="/links", tags=["links"]) @@ -17,10 +18,10 @@ router = APIRouter(prefix="/links", tags=["links"]) @router.get("/", summary="Get all of the links associated with your account") async def get_links( + current_user: Annotated[User, Depends(get_current_user)], db=Depends(get_db), - api_key: str = Security(check_api_key), ): - links = db.query(Link).filter(Link.owner == api_key["owner"]).all() + links = db.query(Link).filter(Link.owner == current_user.id).all() if not links: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="No links found" @@ -31,8 +32,8 @@ async def get_links( @router.post("/", summary="Create a new link") async def create_link( url: URLSchema, + current_user: Annotated[User, Depends(get_current_user)], db=Depends(get_db), - api_key: str = Security(check_api_key), ): # Check if the URL is valid if not validators.url(url.url): @@ -48,7 +49,7 @@ async def create_link( ).upper() new_link = Link( link=link_path, - owner=api_key["owner"], + owner=current_user.id, redirect_link=url.url, expire_date=datetime.datetime.now() + datetime.timedelta(days=30), @@ -69,8 +70,8 @@ async def create_link( @router.delete("/{link}", summary="Delete a link") async def delete_link( link: Annotated[str, Path(title="Link to delete")], + current_user: Annotated[User, Depends(get_current_user)], db=Depends(get_db), - api_key: str = Security(check_api_key), ): link = link.upper() # Get the link and check the owner @@ -79,7 +80,7 @@ async def delete_link( raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Link not found" ) - if link.owner != api_key["owner"]: + if link.owner != current_user.id: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Link not associated with your account", @@ -102,8 +103,8 @@ async def delete_link( ) async def get_link_records( link: Annotated[str, Path(title="Link to get records for")], + current_user: Annotated[User, Depends(get_current_user)], db=Depends(get_db), - api_key: str = Security(check_api_key), ): link = link.upper() # Get the link and check the owner @@ -112,7 +113,7 @@ async def get_link_records( raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Link not found" ) - if link.owner != api_key["owner"]: + if link.owner != current_user.id: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Link not associated with your account", @@ -129,8 +130,8 @@ async def get_link_records( ) async def delete_link_records( link: Annotated[str, Path(title="Link to delete records for")], + current_user: Annotated[User, Depends(get_current_user)], db=Depends(get_db), - api_key: str = Security(check_api_key), ): link = link.upper() # Get the link and check the owner @@ -139,7 +140,7 @@ async def delete_link_records( raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="Link not found" ) - if link.owner != api_key["owner"]: + if link.owner != current_user.id: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Link not associated with your account", diff --git a/api/schemas/auth_schemas.py b/api/schemas/auth_schemas.py new file mode 100644 index 0000000..d5b9a88 --- /dev/null +++ b/api/schemas/auth_schemas.py @@ -0,0 +1,19 @@ +from pydantic import BaseModel + + +class Token(BaseModel): + access_token: str + token_type: str + + +class TokenData(BaseModel): + username: str | None = None + + +class User(BaseModel): + username: str + id: int + + +class UserInDB(User): + hashed_password: str diff --git a/api/util/authentication.py b/api/util/authentication.py new file mode 100644 index 0000000..4dfbc77 --- /dev/null +++ b/api/util/authentication.py @@ -0,0 +1,88 @@ +import random +import bcrypt +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer +from jwt.exceptions import InvalidTokenError +from datetime import datetime, timedelta +from typing import Annotated +import jwt + +from api.util.db_dependency import get_db +from api.schemas.auth_schemas import * +from models import User as UserDB + +secret_key = random.randbytes(32) +algorithm = "HS256" +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + +""" +Helper functions for authentication +""" + + +def verify_password(plain_password, hashed_password): + return bcrypt.checkpw( + plain_password.encode("utf-8"), hashed_password.encode("utf-8") + ) + + +def get_user(db, username: str): + """ + Get the user object from the database + """ + user = db.query(UserDB).filter(UserDB.username == username).first() + if user: + return UserInDB(**user.__dict__) + + +def authenticate_user(db, username: str, password: str): + """ + Determine if the correct username and password were provided + If so, return the user object + """ + user = get_user(db, username) + print(user) + if not user: + return False + if not verify_password(password, user.hashed_password): + return False + return user + + +def create_access_token(data: dict, expires_delta: timedelta | None = None): + """ + Return an encoded JWT token with the given data + """ + to_encode = data.copy() + if expires_delta: + expire = datetime.utcnow() + expires_delta + else: + expire = datetime.utcnow() + timedelta(minutes=15) + to_encode.update({"exp": expire}) + encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=algorithm) + return encoded_jwt + + +async def get_current_user( + token: Annotated[str, Depends(oauth2_scheme)], db=Depends(get_db) +): + """ + Return the current user based on the token, or raise a 401 error + """ + credentials_exception = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + payload = jwt.decode(token, secret_key, algorithms=[algorithm]) + username: str = payload.get("sub") + if username is None: + raise credentials_exception + token_data = TokenData(username=username) + except InvalidTokenError: + raise credentials_exception + user = get_user(db, username=token_data.username) + if user is None: + raise credentials_exception + return user diff --git a/api/util/check_api_key.py b/api/util/check_api_key.py deleted file mode 100644 index 9c4c22e..0000000 --- a/api/util/check_api_key.py +++ /dev/null @@ -1,21 +0,0 @@ -from fastapi import Security, HTTPException, Depends, status -from fastapi.security import APIKeyHeader - -from models import User -from api.util.db_dependency import get_db - -""" -Make sure the provided API key is valid, then return the user's ID -""" -api_key_header = APIKeyHeader(name="X-API-Key") - - -def check_api_key( - api_key_header: str = Security(api_key_header), db=Depends(get_db) -) -> str: - response = db.query(User).filter(User.api_key == api_key_header).first() - if not response: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key" - ) - return {"value": api_key_header, "owner": response.id} diff --git a/api/util/validate_login_information.py b/api/util/validate_login_information.py deleted file mode 100644 index 55bbb2e..0000000 --- a/api/util/validate_login_information.py +++ /dev/null @@ -1,20 +0,0 @@ -import bcrypt -from fastapi import Depends - -from api.util.db_dependency import get_db -from models import User - -""" -Validate the login information provided by the user -""" - - -def validate_login_information( - username: str, password: str, db=Depends(get_db) -) -> bool: - user = db.query(User).filter(User.username == username).first() - if not user: - return False - if bcrypt.checkpw(password.encode("utf-8"), user.password.encode("utf-8")): - return True - return False diff --git a/models.py b/models.py index dad81a0..605b668 100644 --- a/models.py +++ b/models.py @@ -14,7 +14,7 @@ class User(Base): __tablename__ = "users" id = Column(Integer, primary_key=True) username = Column(String, unique=True, nullable=False) - password = Column(Text, nullable=False) + hashed_password = Column(Text, nullable=False) api_key = Column(String(20), unique=True, nullable=False)