stuff
This commit is contained in:
parent
d74ae5e116
commit
6f7e810916
10
app/main.py
10
app/main.py
@ -2,9 +2,8 @@ from fastapi import FastAPI, Depends, Request
|
|||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
from fastapi.templating import Jinja2Templates
|
from fastapi.templating import Jinja2Templates
|
||||||
|
from app.routes.auth_routes import router as auth_router
|
||||||
from app.routes.links_routes import router as links_router
|
from app.routes.links_routes import router as links_router
|
||||||
from app.routes.refresh_route import router as refresh_router
|
|
||||||
from app.routes.token_route import router as token_router
|
|
||||||
from app.routes.user_routes import router as user_router
|
from app.routes.user_routes import router as user_router
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from fastapi.exceptions import HTTPException
|
from fastapi.exceptions import HTTPException
|
||||||
@ -35,13 +34,8 @@ app.add_middleware(
|
|||||||
templates = Jinja2Templates(directory="app/templates")
|
templates = Jinja2Templates(directory="app/templates")
|
||||||
|
|
||||||
# Import routes
|
# Import routes
|
||||||
|
app.include_router(auth_router, prefix="/api")
|
||||||
app.include_router(links_router, prefix="/api")
|
app.include_router(links_router, prefix="/api")
|
||||||
# Must not have a prefix... for some reason you can't change
|
|
||||||
# the prefix of the Swagger UI OAuth2 redirect to /api/token
|
|
||||||
# you can only change it to /token, so we have to remove the
|
|
||||||
# prefix in order to keep logging in via Swagger UI working
|
|
||||||
app.include_router(token_router)
|
|
||||||
app.include_router(refresh_router, prefix="/api")
|
|
||||||
app.include_router(user_router, prefix="/api")
|
app.include_router(user_router, prefix="/api")
|
||||||
|
|
||||||
|
|
||||||
|
@ -1,22 +1,22 @@
|
|||||||
from fastapi import APIRouter, status, Depends, HTTPException
|
from fastapi import Depends, APIRouter, status, HTTPException
|
||||||
from fastapi.responses import JSONResponse, Response
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
from typing import Annotated
|
from fastapi.responses import Response
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
from fastapi.security import OAuth2PasswordRequestForm
|
|
||||||
|
|
||||||
from app.util.db_dependency import get_db
|
|
||||||
from app.util.authentication import (
|
from app.util.authentication import (
|
||||||
authenticate_user,
|
|
||||||
create_access_token,
|
create_access_token,
|
||||||
|
authenticate_user,
|
||||||
|
refresh_get_current_user,
|
||||||
)
|
)
|
||||||
from app.schemas.auth_schemas import Token
|
from app.schemas.auth_schemas import Token, User
|
||||||
|
from app.util.db_dependency import get_db
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/token", tags=["token"])
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
@router.post("/")
|
@router.post("/token", summary="Authenticate and get an access token")
|
||||||
async def login_for_access_token(
|
async def login_for_access_token(
|
||||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||||
response: Response,
|
response: Response,
|
||||||
@ -26,6 +26,7 @@ async def login_for_access_token(
|
|||||||
Return an access token for the user, if the given authentication details are correct
|
Return an access token for the user, if the given authentication details are correct
|
||||||
"""
|
"""
|
||||||
user = authenticate_user(db, form_data.username, form_data.password)
|
user = authenticate_user(db, form_data.username, form_data.password)
|
||||||
|
print(user)
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
@ -34,14 +35,14 @@ async def login_for_access_token(
|
|||||||
)
|
)
|
||||||
access_token_expires = timedelta(minutes=15)
|
access_token_expires = timedelta(minutes=15)
|
||||||
access_token = create_access_token(
|
access_token = create_access_token(
|
||||||
data={"sub": user.username, "refresh": False},
|
data={"sub": user.id, "refresh": False},
|
||||||
expires_delta=access_token_expires,
|
expires_delta=access_token_expires,
|
||||||
)
|
)
|
||||||
# Create a refresh token - just an access token with a longer expiry
|
# Create a refresh token - just an access token with a longer expiry
|
||||||
# and more restrictions ("refresh" is True)
|
# and more restrictions ("refresh" is True)
|
||||||
refresh_token_expires = timedelta(days=1)
|
refresh_token_expires = timedelta(days=1)
|
||||||
refresh_token = create_access_token(
|
refresh_token = create_access_token(
|
||||||
data={"sub": user.username, "refresh": True},
|
data={"sub": user.id, "refresh": True},
|
||||||
expires_delta=refresh_token_expires,
|
expires_delta=refresh_token_expires,
|
||||||
)
|
)
|
||||||
# response = JSONResponse(content={"success": True})
|
# response = JSONResponse(content={"success": True})
|
||||||
@ -58,3 +59,23 @@ async def login_for_access_token(
|
|||||||
refresh_token=refresh_token,
|
refresh_token=refresh_token,
|
||||||
token_type="bearer",
|
token_type="bearer",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Full native JWT support is not complete in FastAPI yet :(
|
||||||
|
# Part of that is token refresh, so we must implement it ourselves
|
||||||
|
@router.post("/refresh")
|
||||||
|
async def refresh_access_token(
|
||||||
|
current_user: Annotated[User, Depends(refresh_get_current_user)],
|
||||||
|
) -> Token:
|
||||||
|
"""
|
||||||
|
Return a new access token if the refresh token is valid
|
||||||
|
"""
|
||||||
|
access_token_expires = timedelta(minutes=30)
|
||||||
|
access_token = create_access_token(
|
||||||
|
data={"sub": current_user.id, "refresh": False},
|
||||||
|
expires_delta=access_token_expires,
|
||||||
|
)
|
||||||
|
return Token(
|
||||||
|
access_token=access_token,
|
||||||
|
token_type="bearer",
|
||||||
|
)
|
@ -1,32 +0,0 @@
|
|||||||
from fastapi import Depends, APIRouter
|
|
||||||
from datetime import timedelta
|
|
||||||
from typing import Annotated
|
|
||||||
|
|
||||||
from app.util.authentication import (
|
|
||||||
create_access_token,
|
|
||||||
refresh_get_current_user,
|
|
||||||
)
|
|
||||||
from app.schemas.auth_schemas import Token, User
|
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/refresh", tags=["refresh"])
|
|
||||||
|
|
||||||
|
|
||||||
# Full native JWT support is not complete in FastAPI yet :(
|
|
||||||
# Part of that is token refresh, so we must implement it ourselves
|
|
||||||
@router.post("/")
|
|
||||||
async def refresh_access_token(
|
|
||||||
current_user: Annotated[User, Depends(refresh_get_current_user)],
|
|
||||||
) -> Token:
|
|
||||||
"""
|
|
||||||
Return a new access token if the refresh token is valid
|
|
||||||
"""
|
|
||||||
access_token_expires = timedelta(minutes=30)
|
|
||||||
access_token = create_access_token(
|
|
||||||
data={"sub": current_user.username, "refresh": False},
|
|
||||||
expires_delta=access_token_expires,
|
|
||||||
)
|
|
||||||
return Token(
|
|
||||||
access_token=access_token,
|
|
||||||
token_type="bearer",
|
|
||||||
)
|
|
@ -14,12 +14,77 @@ from models import User as UserModel
|
|||||||
from app.util.authentication import get_current_user_from_token
|
from app.util.authentication import get_current_user_from_token
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/user", tags=["user"])
|
router = APIRouter(prefix="/users", tags=["users"])
|
||||||
|
|
||||||
# In order to help protect some anonymity/privacy, user routes
|
|
||||||
# do not use path parameters, as then people could potentially
|
@router.delete("/{user_id}", summary="Delete your account")
|
||||||
# see if a specific username exists or not. Instead, the user
|
async def delete_user(
|
||||||
# routes will use query parameters to specify the user to act
|
user_id: Annotated[int, Path(title="Link to delete")],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user_from_token)],
|
||||||
|
db=Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Delete the user account associated with the current user
|
||||||
|
"""
|
||||||
|
if user_id != current_user.id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="You can only delete your own account",
|
||||||
|
)
|
||||||
|
user = db.query(UserModel).filter(UserModel.id == current_user.id).first()
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="User not found",
|
||||||
|
)
|
||||||
|
db.delete(user)
|
||||||
|
db.commit()
|
||||||
|
return status.HTTP_204_NO_CONTENT
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{user_id}", summary="Update your account password")
|
||||||
|
async def update_pass(
|
||||||
|
user_id: Annotated[int, Path(title="Link to update")],
|
||||||
|
update_data: UpdatePasswordSchema,
|
||||||
|
current_user: Annotated[User, Depends(get_current_user_from_token)],
|
||||||
|
db=Depends(get_db),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Update the pass of the current user account
|
||||||
|
"""
|
||||||
|
if user_id != current_user.id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="You can only update your own account",
|
||||||
|
)
|
||||||
|
# Make sure the password meets all of the requirements
|
||||||
|
# if len(update_data.new_password) < 8:
|
||||||
|
# raise HTTPException(
|
||||||
|
# status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
# detail="Password must be at least 8 characters",
|
||||||
|
# )
|
||||||
|
# if not any(char.isdigit() for char in update_data.new_password):
|
||||||
|
# raise HTTPException(
|
||||||
|
# status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
# detail="Password must contain at least one digit",
|
||||||
|
# )
|
||||||
|
# if not any(char.isupper() for char in update_data.new_password):
|
||||||
|
# raise HTTPException(
|
||||||
|
# status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
# detail="Password must contain at least one uppercase letter",
|
||||||
|
# )
|
||||||
|
# Get the user and update the password
|
||||||
|
user = db.query(UserModel).filter(UserModel.id == current_user.id).first()
|
||||||
|
if not user:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="User not found",
|
||||||
|
)
|
||||||
|
user.hashed_password = bcrypt.hashpw(
|
||||||
|
update_data.new_password.encode("utf-8"), bcrypt.gensalt()
|
||||||
|
).decode("utf-8")
|
||||||
|
db.commit()
|
||||||
|
return status.HTTP_204_NO_CONTENT
|
||||||
|
|
||||||
|
|
||||||
@router.post("/register", summary="Register a new user")
|
@router.post("/register", summary="Register a new user")
|
||||||
@ -33,6 +98,8 @@ async def get_links(
|
|||||||
"""
|
"""
|
||||||
username = login_data.username
|
username = login_data.username
|
||||||
password = login_data.password
|
password = login_data.password
|
||||||
|
print(username)
|
||||||
|
print(password)
|
||||||
# Make sure the password meets all of the requirements
|
# Make sure the password meets all of the requirements
|
||||||
# if len(password) < 8:
|
# if len(password) < 8:
|
||||||
# raise HTTPException(
|
# raise HTTPException(
|
||||||
@ -70,61 +137,3 @@ async def get_links(
|
|||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
return status.HTTP_201_CREATED
|
return status.HTTP_201_CREATED
|
||||||
|
|
||||||
|
|
||||||
@router.get("/delete", summary="Delete a user - provided it's your own")
|
|
||||||
async def delete_user(
|
|
||||||
current_user: Annotated[User, Depends(get_current_user_from_token)],
|
|
||||||
db=Depends(get_db),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Delete the user account associated with the current user
|
|
||||||
"""
|
|
||||||
user = db.query(UserModel).filter(UserModel.id == current_user.id).first()
|
|
||||||
if not user:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="User not found",
|
|
||||||
)
|
|
||||||
db.delete(user)
|
|
||||||
db.commit()
|
|
||||||
return status.HTTP_204_NO_CONTENT
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/updatepass", summary="Update your account's password")
|
|
||||||
async def update_pass(
|
|
||||||
current_user: Annotated[User, Depends(get_current_user_from_token)],
|
|
||||||
update_data: UpdatePasswordSchema,
|
|
||||||
db=Depends(get_db),
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Update the pass of the current user account
|
|
||||||
"""
|
|
||||||
# Make sure the password meets all of the requirements
|
|
||||||
# if len(update_data.new_password) < 8:
|
|
||||||
# raise HTTPException(
|
|
||||||
# status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
# detail="Password must be at least 8 characters",
|
|
||||||
# )
|
|
||||||
# if not any(char.isdigit() for char in update_data.new_password):
|
|
||||||
# raise HTTPException(
|
|
||||||
# status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
# detail="Password must contain at least one digit",
|
|
||||||
# )
|
|
||||||
# if not any(char.isupper() for char in update_data.new_password):
|
|
||||||
# raise HTTPException(
|
|
||||||
# status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
# detail="Password must contain at least one uppercase letter",
|
|
||||||
# )
|
|
||||||
# Get the user and update the password
|
|
||||||
user = db.query(UserModel).filter(UserModel.id == current_user.id).first()
|
|
||||||
if not user:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="User not found",
|
|
||||||
)
|
|
||||||
user.hashed_password = bcrypt.hashpw(
|
|
||||||
update_data.new_password.encode("utf-8"), bcrypt.gensalt()
|
|
||||||
).decode("utf-8")
|
|
||||||
db.commit()
|
|
||||||
return status.HTTP_204_NO_CONTENT
|
|
||||||
|
@ -7,10 +7,6 @@ class Token(BaseModel):
|
|||||||
token_type: str
|
token_type: str
|
||||||
|
|
||||||
|
|
||||||
class TokenData(BaseModel):
|
|
||||||
username: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class User(BaseModel):
|
class User(BaseModel):
|
||||||
username: str
|
username: str
|
||||||
id: int
|
id: int
|
||||||
|
@ -91,11 +91,13 @@
|
|||||||
|
|
||||||
const formData = new FormData(this);
|
const formData = new FormData(this);
|
||||||
// Send POST request to /token containing form data
|
// Send POST request to /token containing form data
|
||||||
const response = await fetch('/token', {
|
const response = await fetch('/api/auth/token', {
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
body: formData
|
body: formData
|
||||||
});
|
});
|
||||||
|
|
||||||
|
console.log(await response.json());
|
||||||
|
|
||||||
if (response.status != 200) {
|
if (response.status != 200) {
|
||||||
document.getElementById('error').style.display = 'block';
|
document.getElementById('error').style.display = 'block';
|
||||||
} else {
|
} else {
|
||||||
|
@ -5,7 +5,7 @@ from fastapi.security import OAuth2PasswordBearer
|
|||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
from jwt.exceptions import InvalidTokenError
|
from jwt.exceptions import InvalidTokenError
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Annotated, Optional
|
from typing import Annotated
|
||||||
import jwt
|
import jwt
|
||||||
|
|
||||||
from app.util.db_dependency import get_db
|
from app.util.db_dependency import get_db
|
||||||
@ -15,7 +15,7 @@ from models import User as UserModel
|
|||||||
|
|
||||||
secret_key = random.randbytes(32)
|
secret_key = random.randbytes(32)
|
||||||
algorithm = "HS256"
|
algorithm = "HS256"
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="api/auth/token")
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Helper functions for authentication
|
Helper functions for authentication
|
||||||
@ -28,11 +28,11 @@ def verify_password(plain_password, hashed_password):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_user(db, username: str):
|
def get_user(db, id: int):
|
||||||
"""
|
"""
|
||||||
Get the user object from the database
|
Get the user object from the database
|
||||||
"""
|
"""
|
||||||
user = db.query(UserModel).filter(UserModel.username == username).first()
|
user = db.query(UserModel).filter(UserModel.id == id).first()
|
||||||
if user:
|
if user:
|
||||||
return UserInDB(**user.__dict__)
|
return UserInDB(**user.__dict__)
|
||||||
|
|
||||||
@ -120,9 +120,9 @@ async def get_current_user(
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(token, secret_key, algorithms=[algorithm])
|
payload = jwt.decode(token, secret_key, algorithms=[algorithm])
|
||||||
username: str = payload.get("sub")
|
id: int = payload.get("sub")
|
||||||
refresh: bool = payload.get("refresh")
|
refresh: bool = payload.get("refresh")
|
||||||
if username is None:
|
if not id:
|
||||||
return raise_unauthorized()
|
return raise_unauthorized()
|
||||||
# For some reason, an access token was passed when a refresh
|
# For some reason, an access token was passed when a refresh
|
||||||
# token was expected - some likely malicious activity
|
# token was expected - some likely malicious activity
|
||||||
@ -133,11 +133,10 @@ async def get_current_user(
|
|||||||
if refresh and not is_refresh:
|
if refresh and not is_refresh:
|
||||||
return raise_unauthorized()
|
return raise_unauthorized()
|
||||||
|
|
||||||
token_data = TokenData(username=username)
|
|
||||||
except InvalidTokenError:
|
except InvalidTokenError:
|
||||||
return raise_unauthorized()
|
return raise_unauthorized()
|
||||||
|
|
||||||
user = get_user(db, username=token_data.username)
|
user = get_user(db, id)
|
||||||
if user is None:
|
if user is None:
|
||||||
return raise_unauthorized()
|
return raise_unauthorized()
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user