Continue JWT implementation - add refresh token

This commit is contained in:
Parker M. 2024-11-04 21:00:42 -06:00
parent 8ae8c5c454
commit d4280d1fda
Signed by: parker
GPG Key ID: 505ED36FC12B5D5E
5 changed files with 86 additions and 23 deletions

View File

@ -1,11 +1,10 @@
from fastapi import FastAPI, Depends, HTTPException, Security, status import random
from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse
from datetime import timedelta from datetime import timedelta
from typing import Annotated from typing import Annotated
from fastapi.security import OAuth2PasswordRequestForm from fastapi.security import OAuth2PasswordRequestForm, OAuth2PasswordBearer
import string
import random
from api.util.authentication import ( from api.util.authentication import (
authenticate_user, authenticate_user,
@ -14,7 +13,7 @@ from api.util.authentication import (
) )
from api.routes.links_route import router as links_router from api.routes.links_route import router as links_router
from api.util.db_dependency import get_db from api.util.db_dependency import get_db
from api.schemas.auth_schemas import User, Token from api.schemas.auth_schemas import Token, User
metadata_tags = [ metadata_tags = [
@ -41,6 +40,10 @@ app.add_middleware(
allow_credentials=True, allow_credentials=True,
) )
secret_key = random.randbytes(32)
algorithm = "HS256"
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# Import routes # Import routes
app.include_router(links_router) app.include_router(links_router)
@ -65,11 +68,42 @@ async def login_for_access_token(
detail="Incorrect username or password", detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"}, headers={"WWW-Authenticate": "Bearer"},
) )
access_token_expires = timedelta(minutes=15)
access_token = create_access_token(
data={"sub": user.username, "refresh": False},
expires_delta=access_token_expires,
)
# Create a refresh token - just an access token with a longer expiry
# and more restrictions ("refresh" is True)
refresh_token_expire = timedelta(days=1)
refresh_token = create_access_token(
data={"sub": user.username, "refresh": True},
expire_delta=refresh_token_expire,
)
return Token(
access_token=access_token,
refresh_token=refresh_token,
token_type="bearer",
)
# Full native JWT support is not complete in FastAPI yet :(
# Part of that is token refresh, so we must implement it ourselves
@app.post("/refresh")
async def refresh_access_token(
current_user: Annotated[User, Depends(get_current_user, refresh=True)],
):
"""
Return a new access token if the refresh token is valid
"""
access_token_expires = timedelta(minutes=30) access_token_expires = timedelta(minutes=30)
access_token = create_access_token( access_token = create_access_token(
data={"sub": user.username}, expires_delta=access_token_expires data={"sub": current_user.username}, expires_delta=access_token_expires
)
return Token(
access_token=access_token,
token_type="bearer",
) )
return Token(access_token=access_token, token_type="bearer")
# Redirect /api -> /api/docs # Redirect /api -> /api/docs

View File

@ -3,6 +3,7 @@ from pydantic import BaseModel
class Token(BaseModel): class Token(BaseModel):
access_token: str access_token: str
refresh_token: str | None = None
token_type: str token_type: str

View File

@ -41,7 +41,6 @@ def authenticate_user(db, username: str, password: str):
If so, return the user object If so, return the user object
""" """
user = get_user(db, username) user = get_user(db, username)
print(user)
if not user: if not user:
return False return False
if not verify_password(password, user.hashed_password): if not verify_password(password, user.hashed_password):
@ -49,22 +48,21 @@ def authenticate_user(db, username: str, password: str):
return user return user
def create_access_token(data: dict, expires_delta: timedelta | None = None): def create_access_token(data: dict, expires_delta: timedelta):
""" """
Return an encoded JWT token with the given data Return an encoded JWT token with the given data
""" """
to_encode = data.copy() to_encode = data.copy()
if expires_delta: expire = datetime.utcnow() + expires_delta
expire = datetime.utcnow() + expires_delta
else:
expire = datetime.utcnow() + timedelta(minutes=15)
to_encode.update({"exp": expire}) to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=algorithm) encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=algorithm)
return encoded_jwt return encoded_jwt
async def get_current_user( async def get_current_user(
token: Annotated[str, Depends(oauth2_scheme)], db=Depends(get_db) token: Annotated[str, Depends(oauth2_scheme)],
is_refresh: bool = False,
db=Depends(get_db),
): ):
""" """
Return the current user based on the token, or raise a 401 error Return the current user based on the token, or raise a 401 error
@ -77,8 +75,18 @@ async def get_current_user(
try: try:
payload = jwt.decode(token, secret_key, algorithms=[algorithm]) payload = jwt.decode(token, secret_key, algorithms=[algorithm])
username: str = payload.get("sub") username: str = payload.get("sub")
refresh: bool = payload.get("refresh")
if username is None: if username is None:
raise credentials_exception raise credentials_exception
# For some reason, an access token was passed when a refresh
# token was expected - some likely malicious activity
if not refresh and is_refresh:
raise credentials_exception
# If the token passed is a refresh token and the function
# is not expecting a refresh token, raise an error
if refresh and not is_refresh:
raise credentials_exception
token_data = TokenData(username=username) token_data = TokenData(username=username)
except InvalidTokenError: except InvalidTokenError:
raise credentials_exception raise credentials_exception

22
app/js/jwt.js Normal file
View File

@ -0,0 +1,22 @@
function parseJwt (token) {
var base64Url = token.split('.')[1];
var base64 = base64Url.replace(/-/g, '+').replace(/_/g, '/');
var jsonPayload = decodeURIComponent(window.atob(base64).split('').map(function(c) {
return '%' + ('00' + c.charCodeAt(0).toString(16)).slice(-2);
}).join(''));
return JSON.parse(jsonPayload);
}
function isJwtExpired (token) {
var jwt = parseJwt(token);
return jwt.exp < Date.now() / 1000;
}
async function refreshAccessToken (refreshToken) {
const data = await fetch('/api/refresh', {
method: 'POST',
headers: {'Authorization': 'Bearer ' + refreshToken}
});
return data.access_token;
}

View File

@ -89,22 +89,20 @@
// Prevent default form submission // Prevent default form submission
event.preventDefault(); event.preventDefault();
// Get form data
const formData = new FormData(this); const formData = new FormData(this);
// Send POST request to /api/token containing form data
console.log(formData) const response = await fetch('/api/token', {
// Send POST request to /login containing form data
const response = await fetch('/login', {
method: 'POST', method: 'POST',
body: formData body: formData
}); });
const data = await response.json();
data = await response.json() if (data.response != 200) {
if (data.status != "success") {
document.getElementById('error').style.display = 'block'; document.getElementById('error').style.display = 'block';
} else { } else {
// Save the tokens in localStorage
window.localStorage.token = data.token;
window.localStorage.refreshToken = data.refreshToken;
window.location.href = '/dashboard'; window.location.href = '/dashboard';
} }
}); });