From d4280d1fda2f4809274793e7bd49f484f57a883e Mon Sep 17 00:00:00 2001 From: Parker Date: Mon, 4 Nov 2024 21:00:42 -0600 Subject: [PATCH] Continue JWT implementation - add refresh token --- api/main.py | 48 +++++++++++++++++++++++++++++++------ api/schemas/auth_schemas.py | 1 + api/util/authentication.py | 22 +++++++++++------ app/js/jwt.js | 22 +++++++++++++++++ app/templates/login.html | 16 ++++++------- 5 files changed, 86 insertions(+), 23 deletions(-) create mode 100644 app/js/jwt.js diff --git a/api/main.py b/api/main.py index ac7b927..54d9f5e 100644 --- a/api/main.py +++ b/api/main.py @@ -1,11 +1,10 @@ -from fastapi import FastAPI, Depends, HTTPException, Security, status +import random +from fastapi import FastAPI, Depends, HTTPException, status from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import RedirectResponse from datetime import timedelta from typing import Annotated -from fastapi.security import OAuth2PasswordRequestForm -import string -import random +from fastapi.security import OAuth2PasswordRequestForm, OAuth2PasswordBearer from api.util.authentication import ( authenticate_user, @@ -14,7 +13,7 @@ from api.util.authentication import ( ) from api.routes.links_route import router as links_router from api.util.db_dependency import get_db -from api.schemas.auth_schemas import User, Token +from api.schemas.auth_schemas import Token, User metadata_tags = [ @@ -41,6 +40,10 @@ app.add_middleware( allow_credentials=True, ) +secret_key = random.randbytes(32) +algorithm = "HS256" +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") + # Import routes app.include_router(links_router) @@ -65,11 +68,42 @@ async def login_for_access_token( detail="Incorrect username or password", headers={"WWW-Authenticate": "Bearer"}, ) + access_token_expires = timedelta(minutes=15) + access_token = create_access_token( + data={"sub": user.username, "refresh": False}, + expires_delta=access_token_expires, + ) + # Create a refresh token - just an access token with a longer expiry + # and more restrictions ("refresh" is True) + refresh_token_expire = timedelta(days=1) + refresh_token = create_access_token( + data={"sub": user.username, "refresh": True}, + expire_delta=refresh_token_expire, + ) + return Token( + access_token=access_token, + refresh_token=refresh_token, + token_type="bearer", + ) + + +# Full native JWT support is not complete in FastAPI yet :( +# Part of that is token refresh, so we must implement it ourselves +@app.post("/refresh") +async def refresh_access_token( + current_user: Annotated[User, Depends(get_current_user, refresh=True)], +): + """ + Return a new access token if the refresh token is valid + """ access_token_expires = timedelta(minutes=30) access_token = create_access_token( - data={"sub": user.username}, expires_delta=access_token_expires + data={"sub": current_user.username}, expires_delta=access_token_expires + ) + return Token( + access_token=access_token, + token_type="bearer", ) - return Token(access_token=access_token, token_type="bearer") # Redirect /api -> /api/docs diff --git a/api/schemas/auth_schemas.py b/api/schemas/auth_schemas.py index d5b9a88..006a7c8 100644 --- a/api/schemas/auth_schemas.py +++ b/api/schemas/auth_schemas.py @@ -3,6 +3,7 @@ from pydantic import BaseModel class Token(BaseModel): access_token: str + refresh_token: str | None = None token_type: str diff --git a/api/util/authentication.py b/api/util/authentication.py index 4dfbc77..507b806 100644 --- a/api/util/authentication.py +++ b/api/util/authentication.py @@ -41,7 +41,6 @@ def authenticate_user(db, username: str, password: str): If so, return the user object """ user = get_user(db, username) - print(user) if not user: return False if not verify_password(password, user.hashed_password): @@ -49,22 +48,21 @@ def authenticate_user(db, username: str, password: str): return user -def create_access_token(data: dict, expires_delta: timedelta | None = None): +def create_access_token(data: dict, expires_delta: timedelta): """ Return an encoded JWT token with the given data """ to_encode = data.copy() - if expires_delta: - expire = datetime.utcnow() + expires_delta - else: - expire = datetime.utcnow() + timedelta(minutes=15) + expire = datetime.utcnow() + expires_delta to_encode.update({"exp": expire}) encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=algorithm) return encoded_jwt async def get_current_user( - token: Annotated[str, Depends(oauth2_scheme)], db=Depends(get_db) + token: Annotated[str, Depends(oauth2_scheme)], + is_refresh: bool = False, + db=Depends(get_db), ): """ Return the current user based on the token, or raise a 401 error @@ -77,8 +75,18 @@ async def get_current_user( try: payload = jwt.decode(token, secret_key, algorithms=[algorithm]) username: str = payload.get("sub") + refresh: bool = payload.get("refresh") if username is None: raise credentials_exception + # For some reason, an access token was passed when a refresh + # token was expected - some likely malicious activity + if not refresh and is_refresh: + raise credentials_exception + # If the token passed is a refresh token and the function + # is not expecting a refresh token, raise an error + if refresh and not is_refresh: + raise credentials_exception + token_data = TokenData(username=username) except InvalidTokenError: raise credentials_exception diff --git a/app/js/jwt.js b/app/js/jwt.js new file mode 100644 index 0000000..43c2e6c --- /dev/null +++ b/app/js/jwt.js @@ -0,0 +1,22 @@ +function parseJwt (token) { + var base64Url = token.split('.')[1]; + var base64 = base64Url.replace(/-/g, '+').replace(/_/g, '/'); + var jsonPayload = decodeURIComponent(window.atob(base64).split('').map(function(c) { + return '%' + ('00' + c.charCodeAt(0).toString(16)).slice(-2); + }).join('')); + + return JSON.parse(jsonPayload); +} + +function isJwtExpired (token) { + var jwt = parseJwt(token); + return jwt.exp < Date.now() / 1000; +} + +async function refreshAccessToken (refreshToken) { + const data = await fetch('/api/refresh', { + method: 'POST', + headers: {'Authorization': 'Bearer ' + refreshToken} + }); + return data.access_token; +} \ No newline at end of file diff --git a/app/templates/login.html b/app/templates/login.html index 25ce3b6..b41d15c 100644 --- a/app/templates/login.html +++ b/app/templates/login.html @@ -89,22 +89,20 @@ // Prevent default form submission event.preventDefault(); - // Get form data const formData = new FormData(this); - - console.log(formData) - - // Send POST request to /login containing form data - const response = await fetch('/login', { + // Send POST request to /api/token containing form data + const response = await fetch('/api/token', { method: 'POST', body: formData }); + const data = await response.json(); - data = await response.json() - - if (data.status != "success") { + if (data.response != 200) { document.getElementById('error').style.display = 'block'; } else { + // Save the tokens in localStorage + window.localStorage.token = data.token; + window.localStorage.refreshToken = data.refreshToken; window.location.href = '/dashboard'; } });