From 5a0777033f6733c33fbd6119ade812e0c749be44 Mon Sep 17 00:00:00 2001 From: Parker Date: Mon, 4 Nov 2024 21:14:18 -0600 Subject: [PATCH] Work on refresh tokens --- api/main.py | 10 +++++----- api/util/authentication.py | 19 +++++++++++++++++-- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/api/main.py b/api/main.py index 54d9f5e..fbe8805 100644 --- a/api/main.py +++ b/api/main.py @@ -9,7 +9,7 @@ from fastapi.security import OAuth2PasswordRequestForm, OAuth2PasswordBearer from api.util.authentication import ( authenticate_user, create_access_token, - get_current_user, + refresh_get_current_user, ) from api.routes.links_route import router as links_router from api.util.db_dependency import get_db @@ -75,10 +75,10 @@ async def login_for_access_token( ) # Create a refresh token - just an access token with a longer expiry # and more restrictions ("refresh" is True) - refresh_token_expire = timedelta(days=1) + refresh_token_expires = timedelta(days=1) refresh_token = create_access_token( data={"sub": user.username, "refresh": True}, - expire_delta=refresh_token_expire, + expires_delta=refresh_token_expires, ) return Token( access_token=access_token, @@ -91,8 +91,8 @@ async def login_for_access_token( # Part of that is token refresh, so we must implement it ourselves @app.post("/refresh") async def refresh_access_token( - current_user: Annotated[User, Depends(get_current_user, refresh=True)], -): + current_user: Annotated[User, Depends(refresh_get_current_user)], +) -> Token: """ Return a new access token if the refresh token is valid """ diff --git a/api/util/authentication.py b/api/util/authentication.py index 507b806..b8ac6a6 100644 --- a/api/util/authentication.py +++ b/api/util/authentication.py @@ -4,7 +4,7 @@ from fastapi import Depends, HTTPException, status from fastapi.security import OAuth2PasswordBearer from jwt.exceptions import InvalidTokenError from datetime import datetime, timedelta -from typing import Annotated +from typing import Annotated, Optional import jwt from api.util.db_dependency import get_db @@ -59,8 +59,23 @@ def create_access_token(data: dict, expires_delta: timedelta): return encoded_jwt -async def get_current_user( +# Backwards kinda of way to get refresh token support +# 'refresh_get_current_user' is only called from /refresh +# and alerts 'current_user' that it should expect a refresh token +async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]): + user = await current_user(token) + return user + + +async def refresh_get_current_user( token: Annotated[str, Depends(oauth2_scheme)], +): + user = await current_user(token, is_refresh=True) + return user + + +async def current_user( + token: str, is_refresh: bool = False, db=Depends(get_db), ):