This commit is contained in:
Parker M. 2024-11-05 15:02:21 -06:00
parent d74ae5e116
commit 6f7e810916
7 changed files with 116 additions and 127 deletions

View File

@ -2,9 +2,8 @@ from fastapi import FastAPI, Depends, Request
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from app.routes.auth_routes import router as auth_router
from app.routes.links_routes import router as links_router from app.routes.links_routes import router as links_router
from app.routes.refresh_route import router as refresh_router
from app.routes.token_route import router as token_router
from app.routes.user_routes import router as user_router from app.routes.user_routes import router as user_router
from typing import Annotated from typing import Annotated
from fastapi.exceptions import HTTPException from fastapi.exceptions import HTTPException
@ -35,13 +34,8 @@ app.add_middleware(
templates = Jinja2Templates(directory="app/templates") templates = Jinja2Templates(directory="app/templates")
# Import routes # Import routes
app.include_router(auth_router, prefix="/api")
app.include_router(links_router, prefix="/api") app.include_router(links_router, prefix="/api")
# Must not have a prefix... for some reason you can't change
# the prefix of the Swagger UI OAuth2 redirect to /api/token
# you can only change it to /token, so we have to remove the
# prefix in order to keep logging in via Swagger UI working
app.include_router(token_router)
app.include_router(refresh_router, prefix="/api")
app.include_router(user_router, prefix="/api") app.include_router(user_router, prefix="/api")

View File

@ -1,22 +1,22 @@
from fastapi import APIRouter, status, Depends, HTTPException from fastapi import Depends, APIRouter, status, HTTPException
from fastapi.responses import JSONResponse, Response from fastapi.security import OAuth2PasswordRequestForm
from typing import Annotated from fastapi.responses import Response
from datetime import timedelta from datetime import timedelta
from typing import Annotated from typing import Annotated
from fastapi.security import OAuth2PasswordRequestForm
from app.util.db_dependency import get_db
from app.util.authentication import ( from app.util.authentication import (
authenticate_user,
create_access_token, create_access_token,
authenticate_user,
refresh_get_current_user,
) )
from app.schemas.auth_schemas import Token from app.schemas.auth_schemas import Token, User
from app.util.db_dependency import get_db
router = APIRouter(prefix="/token", tags=["token"]) router = APIRouter(prefix="/auth", tags=["auth"])
@router.post("/") @router.post("/token", summary="Authenticate and get an access token")
async def login_for_access_token( async def login_for_access_token(
form_data: Annotated[OAuth2PasswordRequestForm, Depends()], form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
response: Response, response: Response,
@ -26,6 +26,7 @@ async def login_for_access_token(
Return an access token for the user, if the given authentication details are correct Return an access token for the user, if the given authentication details are correct
""" """
user = authenticate_user(db, form_data.username, form_data.password) user = authenticate_user(db, form_data.username, form_data.password)
print(user)
if not user: if not user:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -34,14 +35,14 @@ async def login_for_access_token(
) )
access_token_expires = timedelta(minutes=15) access_token_expires = timedelta(minutes=15)
access_token = create_access_token( access_token = create_access_token(
data={"sub": user.username, "refresh": False}, data={"sub": user.id, "refresh": False},
expires_delta=access_token_expires, expires_delta=access_token_expires,
) )
# Create a refresh token - just an access token with a longer expiry # Create a refresh token - just an access token with a longer expiry
# and more restrictions ("refresh" is True) # and more restrictions ("refresh" is True)
refresh_token_expires = timedelta(days=1) refresh_token_expires = timedelta(days=1)
refresh_token = create_access_token( refresh_token = create_access_token(
data={"sub": user.username, "refresh": True}, data={"sub": user.id, "refresh": True},
expires_delta=refresh_token_expires, expires_delta=refresh_token_expires,
) )
# response = JSONResponse(content={"success": True}) # response = JSONResponse(content={"success": True})
@ -58,3 +59,23 @@ async def login_for_access_token(
refresh_token=refresh_token, refresh_token=refresh_token,
token_type="bearer", token_type="bearer",
) )
# Full native JWT support is not complete in FastAPI yet :(
# Part of that is token refresh, so we must implement it ourselves
@router.post("/refresh")
async def refresh_access_token(
current_user: Annotated[User, Depends(refresh_get_current_user)],
) -> Token:
"""
Return a new access token if the refresh token is valid
"""
access_token_expires = timedelta(minutes=30)
access_token = create_access_token(
data={"sub": current_user.id, "refresh": False},
expires_delta=access_token_expires,
)
return Token(
access_token=access_token,
token_type="bearer",
)

View File

@ -1,32 +0,0 @@
from fastapi import Depends, APIRouter
from datetime import timedelta
from typing import Annotated
from app.util.authentication import (
create_access_token,
refresh_get_current_user,
)
from app.schemas.auth_schemas import Token, User
router = APIRouter(prefix="/refresh", tags=["refresh"])
# Full native JWT support is not complete in FastAPI yet :(
# Part of that is token refresh, so we must implement it ourselves
@router.post("/")
async def refresh_access_token(
current_user: Annotated[User, Depends(refresh_get_current_user)],
) -> Token:
"""
Return a new access token if the refresh token is valid
"""
access_token_expires = timedelta(minutes=30)
access_token = create_access_token(
data={"sub": current_user.username, "refresh": False},
expires_delta=access_token_expires,
)
return Token(
access_token=access_token,
token_type="bearer",
)

View File

@ -14,12 +14,77 @@ from models import User as UserModel
from app.util.authentication import get_current_user_from_token from app.util.authentication import get_current_user_from_token
router = APIRouter(prefix="/user", tags=["user"]) router = APIRouter(prefix="/users", tags=["users"])
# In order to help protect some anonymity/privacy, user routes
# do not use path parameters, as then people could potentially @router.delete("/{user_id}", summary="Delete your account")
# see if a specific username exists or not. Instead, the user async def delete_user(
# routes will use query parameters to specify the user to act user_id: Annotated[int, Path(title="Link to delete")],
current_user: Annotated[User, Depends(get_current_user_from_token)],
db=Depends(get_db),
):
"""
Delete the user account associated with the current user
"""
if user_id != current_user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You can only delete your own account",
)
user = db.query(UserModel).filter(UserModel.id == current_user.id).first()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found",
)
db.delete(user)
db.commit()
return status.HTTP_204_NO_CONTENT
@router.post("/{user_id}", summary="Update your account password")
async def update_pass(
user_id: Annotated[int, Path(title="Link to update")],
update_data: UpdatePasswordSchema,
current_user: Annotated[User, Depends(get_current_user_from_token)],
db=Depends(get_db),
):
"""
Update the pass of the current user account
"""
if user_id != current_user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You can only update your own account",
)
# Make sure the password meets all of the requirements
# if len(update_data.new_password) < 8:
# raise HTTPException(
# status_code=status.HTTP_400_BAD_REQUEST,
# detail="Password must be at least 8 characters",
# )
# if not any(char.isdigit() for char in update_data.new_password):
# raise HTTPException(
# status_code=status.HTTP_400_BAD_REQUEST,
# detail="Password must contain at least one digit",
# )
# if not any(char.isupper() for char in update_data.new_password):
# raise HTTPException(
# status_code=status.HTTP_400_BAD_REQUEST,
# detail="Password must contain at least one uppercase letter",
# )
# Get the user and update the password
user = db.query(UserModel).filter(UserModel.id == current_user.id).first()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found",
)
user.hashed_password = bcrypt.hashpw(
update_data.new_password.encode("utf-8"), bcrypt.gensalt()
).decode("utf-8")
db.commit()
return status.HTTP_204_NO_CONTENT
@router.post("/register", summary="Register a new user") @router.post("/register", summary="Register a new user")
@ -33,6 +98,8 @@ async def get_links(
""" """
username = login_data.username username = login_data.username
password = login_data.password password = login_data.password
print(username)
print(password)
# Make sure the password meets all of the requirements # Make sure the password meets all of the requirements
# if len(password) < 8: # if len(password) < 8:
# raise HTTPException( # raise HTTPException(
@ -70,61 +137,3 @@ async def get_links(
db.commit() db.commit()
return status.HTTP_201_CREATED return status.HTTP_201_CREATED
@router.get("/delete", summary="Delete a user - provided it's your own")
async def delete_user(
current_user: Annotated[User, Depends(get_current_user_from_token)],
db=Depends(get_db),
):
"""
Delete the user account associated with the current user
"""
user = db.query(UserModel).filter(UserModel.id == current_user.id).first()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found",
)
db.delete(user)
db.commit()
return status.HTTP_204_NO_CONTENT
@router.put("/updatepass", summary="Update your account's password")
async def update_pass(
current_user: Annotated[User, Depends(get_current_user_from_token)],
update_data: UpdatePasswordSchema,
db=Depends(get_db),
):
"""
Update the pass of the current user account
"""
# Make sure the password meets all of the requirements
# if len(update_data.new_password) < 8:
# raise HTTPException(
# status_code=status.HTTP_400_BAD_REQUEST,
# detail="Password must be at least 8 characters",
# )
# if not any(char.isdigit() for char in update_data.new_password):
# raise HTTPException(
# status_code=status.HTTP_400_BAD_REQUEST,
# detail="Password must contain at least one digit",
# )
# if not any(char.isupper() for char in update_data.new_password):
# raise HTTPException(
# status_code=status.HTTP_400_BAD_REQUEST,
# detail="Password must contain at least one uppercase letter",
# )
# Get the user and update the password
user = db.query(UserModel).filter(UserModel.id == current_user.id).first()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found",
)
user.hashed_password = bcrypt.hashpw(
update_data.new_password.encode("utf-8"), bcrypt.gensalt()
).decode("utf-8")
db.commit()
return status.HTTP_204_NO_CONTENT

View File

@ -7,10 +7,6 @@ class Token(BaseModel):
token_type: str token_type: str
class TokenData(BaseModel):
username: str | None = None
class User(BaseModel): class User(BaseModel):
username: str username: str
id: int id: int

View File

@ -91,11 +91,13 @@
const formData = new FormData(this); const formData = new FormData(this);
// Send POST request to /token containing form data // Send POST request to /token containing form data
const response = await fetch('/token', { const response = await fetch('/api/auth/token', {
method: 'POST', method: 'POST',
body: formData body: formData
}); });
console.log(await response.json());
if (response.status != 200) { if (response.status != 200) {
document.getElementById('error').style.display = 'block'; document.getElementById('error').style.display = 'block';
} else { } else {

View File

@ -5,7 +5,7 @@ from fastapi.security import OAuth2PasswordBearer
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from jwt.exceptions import InvalidTokenError from jwt.exceptions import InvalidTokenError
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Annotated, Optional from typing import Annotated
import jwt import jwt
from app.util.db_dependency import get_db from app.util.db_dependency import get_db
@ -15,7 +15,7 @@ from models import User as UserModel
secret_key = random.randbytes(32) secret_key = random.randbytes(32)
algorithm = "HS256" algorithm = "HS256"
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/auth/token")
""" """
Helper functions for authentication Helper functions for authentication
@ -28,11 +28,11 @@ def verify_password(plain_password, hashed_password):
) )
def get_user(db, username: str): def get_user(db, id: int):
""" """
Get the user object from the database Get the user object from the database
""" """
user = db.query(UserModel).filter(UserModel.username == username).first() user = db.query(UserModel).filter(UserModel.id == id).first()
if user: if user:
return UserInDB(**user.__dict__) return UserInDB(**user.__dict__)
@ -120,9 +120,9 @@ async def get_current_user(
try: try:
payload = jwt.decode(token, secret_key, algorithms=[algorithm]) payload = jwt.decode(token, secret_key, algorithms=[algorithm])
username: str = payload.get("sub") id: int = payload.get("sub")
refresh: bool = payload.get("refresh") refresh: bool = payload.get("refresh")
if username is None: if not id:
return raise_unauthorized() return raise_unauthorized()
# For some reason, an access token was passed when a refresh # For some reason, an access token was passed when a refresh
# token was expected - some likely malicious activity # token was expected - some likely malicious activity
@ -133,11 +133,10 @@ async def get_current_user(
if refresh and not is_refresh: if refresh and not is_refresh:
return raise_unauthorized() return raise_unauthorized()
token_data = TokenData(username=username)
except InvalidTokenError: except InvalidTokenError:
return raise_unauthorized() return raise_unauthorized()
user = get_user(db, username=token_data.username) user = get_user(db, id)
if user is None: if user is None:
return raise_unauthorized() return raise_unauthorized()