aboutsummaryrefslogtreecommitdiff
path: root/api
diff options
context:
space:
mode:
authorParker <contact@pkrm.dev>2024-11-04 00:12:36 -0600
committerParker <contact@pkrm.dev>2024-11-04 00:12:36 -0600
commit8ae8c5c454ba42e8f56f415d33bbaaac7d1a37ec (patch)
treed56704d87f63b79681530ab729d9f54d24f73c80 /api
parent65fef6274166678f59d6d81c9da68465a7c374bc (diff)
Remove API Keys -> Authenticate with JWT
Diffstat (limited to 'api')
-rw-r--r--api/main.py52
-rw-r--r--api/routes/links_route.py23
-rw-r--r--api/schemas/auth_schemas.py19
-rw-r--r--api/util/authentication.py88
-rw-r--r--api/util/check_api_key.py21
-rw-r--r--api/util/validate_login_information.py20
6 files changed, 155 insertions, 68 deletions
diff --git a/api/main.py b/api/main.py
index 6ede8ba..ac7b927 100644
--- a/api/main.py
+++ b/api/main.py
@@ -1,13 +1,20 @@
-from fastapi import FastAPI, Depends, HTTPException, Security
+from fastapi import FastAPI, Depends, HTTPException, Security, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
+from datetime import timedelta
+from typing import Annotated
+from fastapi.security import OAuth2PasswordRequestForm
import string
import random
+from api.util.authentication import (
+ authenticate_user,
+ create_access_token,
+ get_current_user,
+)
from api.routes.links_route import router as links_router
from api.util.db_dependency import get_db
-from api.util.check_api_key import check_api_key
-from models import User
+from api.schemas.auth_schemas import User, Token
metadata_tags = [
@@ -37,22 +44,35 @@ app.add_middleware(
# Import routes
app.include_router(links_router)
-# Regenerate the API key for the user
-@app.post("/regenerate")
-async def regenerate(api_key: str = Security(check_api_key), db = Depends(get_db)):
- """Regenerate the API key for the user. Requires the current API key."""
- user = db.query(User).filter(User.api_key == api_key['value']).first()
- if not user:
- raise HTTPException(status_code=401, detail="Invalid API key")
- # Generate a new API key
- new_api_key = ''.join(random.choices(string.ascii_letters + string.digits, k=20))
- user.api_key = new_api_key
- db.commit()
+"""
+Authentication
+"""
+
+
+@app.post("/token")
+async def login_for_access_token(
+ form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
+ db=Depends(get_db),
+) -> Token:
+ """
+ Return an access token for the user, if the given authentication details are correct
+ """
+ user = authenticate_user(db, form_data.username, form_data.password)
+ if not user:
+ raise HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Incorrect username or password",
+ headers={"WWW-Authenticate": "Bearer"},
+ )
+ access_token_expires = timedelta(minutes=30)
+ access_token = create_access_token(
+ data={"sub": user.username}, expires_delta=access_token_expires
+ )
+ return Token(access_token=access_token, token_type="bearer")
- return {"status": "success", "new_api_key": new_api_key}
# Redirect /api -> /api/docs
@app.get("/")
async def redirect_to_docs():
- return RedirectResponse(url="/api/docs") \ No newline at end of file
+ return RedirectResponse(url="/api/docs")
diff --git a/api/routes/links_route.py b/api/routes/links_route.py
index 4385712..08e7690 100644
--- a/api/routes/links_route.py
+++ b/api/routes/links_route.py
@@ -7,9 +7,10 @@ import datetime
import validators
from api.util.db_dependency import get_db
-from api.util.check_api_key import check_api_key
from models import Link, Record
from api.schemas.links_schemas import URLSchema
+from api.schemas.auth_schemas import User
+from api.util.authentication import get_current_user
router = APIRouter(prefix="/links", tags=["links"])
@@ -17,10 +18,10 @@ router = APIRouter(prefix="/links", tags=["links"])
@router.get("/", summary="Get all of the links associated with your account")
async def get_links(
+ current_user: Annotated[User, Depends(get_current_user)],
db=Depends(get_db),
- api_key: str = Security(check_api_key),
):
- links = db.query(Link).filter(Link.owner == api_key["owner"]).all()
+ links = db.query(Link).filter(Link.owner == current_user.id).all()
if not links:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="No links found"
@@ -31,8 +32,8 @@ async def get_links(
@router.post("/", summary="Create a new link")
async def create_link(
url: URLSchema,
+ current_user: Annotated[User, Depends(get_current_user)],
db=Depends(get_db),
- api_key: str = Security(check_api_key),
):
# Check if the URL is valid
if not validators.url(url.url):
@@ -48,7 +49,7 @@ async def create_link(
).upper()
new_link = Link(
link=link_path,
- owner=api_key["owner"],
+ owner=current_user.id,
redirect_link=url.url,
expire_date=datetime.datetime.now()
+ datetime.timedelta(days=30),
@@ -69,8 +70,8 @@ async def create_link(
@router.delete("/{link}", summary="Delete a link")
async def delete_link(
link: Annotated[str, Path(title="Link to delete")],
+ current_user: Annotated[User, Depends(get_current_user)],
db=Depends(get_db),
- api_key: str = Security(check_api_key),
):
link = link.upper()
# Get the link and check the owner
@@ -79,7 +80,7 @@ async def delete_link(
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Link not found"
)
- if link.owner != api_key["owner"]:
+ if link.owner != current_user.id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Link not associated with your account",
@@ -102,8 +103,8 @@ async def delete_link(
)
async def get_link_records(
link: Annotated[str, Path(title="Link to get records for")],
+ current_user: Annotated[User, Depends(get_current_user)],
db=Depends(get_db),
- api_key: str = Security(check_api_key),
):
link = link.upper()
# Get the link and check the owner
@@ -112,7 +113,7 @@ async def get_link_records(
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Link not found"
)
- if link.owner != api_key["owner"]:
+ if link.owner != current_user.id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Link not associated with your account",
@@ -129,8 +130,8 @@ async def get_link_records(
)
async def delete_link_records(
link: Annotated[str, Path(title="Link to delete records for")],
+ current_user: Annotated[User, Depends(get_current_user)],
db=Depends(get_db),
- api_key: str = Security(check_api_key),
):
link = link.upper()
# Get the link and check the owner
@@ -139,7 +140,7 @@ async def delete_link_records(
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, detail="Link not found"
)
- if link.owner != api_key["owner"]:
+ if link.owner != current_user.id:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Link not associated with your account",
diff --git a/api/schemas/auth_schemas.py b/api/schemas/auth_schemas.py
new file mode 100644
index 0000000..d5b9a88
--- /dev/null
+++ b/api/schemas/auth_schemas.py
@@ -0,0 +1,19 @@
+from pydantic import BaseModel
+
+
+class Token(BaseModel):
+ access_token: str
+ token_type: str
+
+
+class TokenData(BaseModel):
+ username: str | None = None
+
+
+class User(BaseModel):
+ username: str
+ id: int
+
+
+class UserInDB(User):
+ hashed_password: str
diff --git a/api/util/authentication.py b/api/util/authentication.py
new file mode 100644
index 0000000..4dfbc77
--- /dev/null
+++ b/api/util/authentication.py
@@ -0,0 +1,88 @@
+import random
+import bcrypt
+from fastapi import Depends, HTTPException, status
+from fastapi.security import OAuth2PasswordBearer
+from jwt.exceptions import InvalidTokenError
+from datetime import datetime, timedelta
+from typing import Annotated
+import jwt
+
+from api.util.db_dependency import get_db
+from api.schemas.auth_schemas import *
+from models import User as UserDB
+
+secret_key = random.randbytes(32)
+algorithm = "HS256"
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+
+"""
+Helper functions for authentication
+"""
+
+
+def verify_password(plain_password, hashed_password):
+ return bcrypt.checkpw(
+ plain_password.encode("utf-8"), hashed_password.encode("utf-8")
+ )
+
+
+def get_user(db, username: str):
+ """
+ Get the user object from the database
+ """
+ user = db.query(UserDB).filter(UserDB.username == username).first()
+ if user:
+ return UserInDB(**user.__dict__)
+
+
+def authenticate_user(db, username: str, password: str):
+ """
+ Determine if the correct username and password were provided
+ If so, return the user object
+ """
+ user = get_user(db, username)
+ print(user)
+ if not user:
+ return False
+ if not verify_password(password, user.hashed_password):
+ return False
+ return user
+
+
+def create_access_token(data: dict, expires_delta: timedelta | None = None):
+ """
+ Return an encoded JWT token with the given data
+ """
+ to_encode = data.copy()
+ if expires_delta:
+ expire = datetime.utcnow() + expires_delta
+ else:
+ expire = datetime.utcnow() + timedelta(minutes=15)
+ to_encode.update({"exp": expire})
+ encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=algorithm)
+ return encoded_jwt
+
+
+async def get_current_user(
+ token: Annotated[str, Depends(oauth2_scheme)], db=Depends(get_db)
+):
+ """
+ Return the current user based on the token, or raise a 401 error
+ """
+ credentials_exception = HTTPException(
+ status_code=status.HTTP_401_UNAUTHORIZED,
+ detail="Could not validate credentials",
+ headers={"WWW-Authenticate": "Bearer"},
+ )
+ try:
+ payload = jwt.decode(token, secret_key, algorithms=[algorithm])
+ username: str = payload.get("sub")
+ if username is None:
+ raise credentials_exception
+ token_data = TokenData(username=username)
+ except InvalidTokenError:
+ raise credentials_exception
+ user = get_user(db, username=token_data.username)
+ if user is None:
+ raise credentials_exception
+ return user
diff --git a/api/util/check_api_key.py b/api/util/check_api_key.py
deleted file mode 100644
index 9c4c22e..0000000
--- a/api/util/check_api_key.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from fastapi import Security, HTTPException, Depends, status
-from fastapi.security import APIKeyHeader
-
-from models import User
-from api.util.db_dependency import get_db
-
-"""
-Make sure the provided API key is valid, then return the user's ID
-"""
-api_key_header = APIKeyHeader(name="X-API-Key")
-
-
-def check_api_key(
- api_key_header: str = Security(api_key_header), db=Depends(get_db)
-) -> str:
- response = db.query(User).filter(User.api_key == api_key_header).first()
- if not response:
- raise HTTPException(
- status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
- )
- return {"value": api_key_header, "owner": response.id}
diff --git a/api/util/validate_login_information.py b/api/util/validate_login_information.py
deleted file mode 100644
index 55bbb2e..0000000
--- a/api/util/validate_login_information.py
+++ /dev/null
@@ -1,20 +0,0 @@
-import bcrypt
-from fastapi import Depends
-
-from api.util.db_dependency import get_db
-from models import User
-
-"""
-Validate the login information provided by the user
-"""
-
-
-def validate_login_information(
- username: str, password: str, db=Depends(get_db)
-) -> bool:
- user = db.query(User).filter(User.username == username).first()
- if not user:
- return False
- if bcrypt.checkpw(password.encode("utf-8"), user.password.encode("utf-8")):
- return True
- return False