aboutsummaryrefslogtreecommitdiff
path: root/api
diff options
context:
space:
mode:
Diffstat (limited to 'api')
-rw-r--r--api/main.py48
-rw-r--r--api/schemas/auth_schemas.py1
-rw-r--r--api/util/authentication.py22
3 files changed, 57 insertions, 14 deletions
diff --git a/api/main.py b/api/main.py
index ac7b927..54d9f5e 100644
--- a/api/main.py
+++ b/api/main.py
@@ -1,11 +1,10 @@
-from fastapi import FastAPI, Depends, HTTPException, Security, status
+import random
+from fastapi import FastAPI, Depends, HTTPException, status
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from datetime import timedelta
from typing import Annotated
-from fastapi.security import OAuth2PasswordRequestForm
-import string
-import random
+from fastapi.security import OAuth2PasswordRequestForm, OAuth2PasswordBearer
from api.util.authentication import (
authenticate_user,
@@ -14,7 +13,7 @@ from api.util.authentication import (
)
from api.routes.links_route import router as links_router
from api.util.db_dependency import get_db
-from api.schemas.auth_schemas import User, Token
+from api.schemas.auth_schemas import Token, User
metadata_tags = [
@@ -41,6 +40,10 @@ app.add_middleware(
allow_credentials=True,
)
+secret_key = random.randbytes(32)
+algorithm = "HS256"
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
+
# Import routes
app.include_router(links_router)
@@ -65,11 +68,42 @@ async def login_for_access_token(
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
+ access_token_expires = timedelta(minutes=15)
+ access_token = create_access_token(
+ data={"sub": user.username, "refresh": False},
+ expires_delta=access_token_expires,
+ )
+ # Create a refresh token - just an access token with a longer expiry
+ # and more restrictions ("refresh" is True)
+ refresh_token_expire = timedelta(days=1)
+ refresh_token = create_access_token(
+ data={"sub": user.username, "refresh": True},
+ expire_delta=refresh_token_expire,
+ )
+ return Token(
+ access_token=access_token,
+ refresh_token=refresh_token,
+ token_type="bearer",
+ )
+
+
+# Full native JWT support is not complete in FastAPI yet :(
+# Part of that is token refresh, so we must implement it ourselves
+@app.post("/refresh")
+async def refresh_access_token(
+ current_user: Annotated[User, Depends(get_current_user, refresh=True)],
+):
+ """
+ Return a new access token if the refresh token is valid
+ """
access_token_expires = timedelta(minutes=30)
access_token = create_access_token(
- data={"sub": user.username}, expires_delta=access_token_expires
+ data={"sub": current_user.username}, expires_delta=access_token_expires
+ )
+ return Token(
+ access_token=access_token,
+ token_type="bearer",
)
- return Token(access_token=access_token, token_type="bearer")
# Redirect /api -> /api/docs
diff --git a/api/schemas/auth_schemas.py b/api/schemas/auth_schemas.py
index d5b9a88..006a7c8 100644
--- a/api/schemas/auth_schemas.py
+++ b/api/schemas/auth_schemas.py
@@ -3,6 +3,7 @@ from pydantic import BaseModel
class Token(BaseModel):
access_token: str
+ refresh_token: str | None = None
token_type: str
diff --git a/api/util/authentication.py b/api/util/authentication.py
index 4dfbc77..507b806 100644
--- a/api/util/authentication.py
+++ b/api/util/authentication.py
@@ -41,7 +41,6 @@ def authenticate_user(db, username: str, password: str):
If so, return the user object
"""
user = get_user(db, username)
- print(user)
if not user:
return False
if not verify_password(password, user.hashed_password):
@@ -49,22 +48,21 @@ def authenticate_user(db, username: str, password: str):
return user
-def create_access_token(data: dict, expires_delta: timedelta | None = None):
+def create_access_token(data: dict, expires_delta: timedelta):
"""
Return an encoded JWT token with the given data
"""
to_encode = data.copy()
- if expires_delta:
- expire = datetime.utcnow() + expires_delta
- else:
- expire = datetime.utcnow() + timedelta(minutes=15)
+ expire = datetime.utcnow() + expires_delta
to_encode.update({"exp": expire})
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=algorithm)
return encoded_jwt
async def get_current_user(
- token: Annotated[str, Depends(oauth2_scheme)], db=Depends(get_db)
+ token: Annotated[str, Depends(oauth2_scheme)],
+ is_refresh: bool = False,
+ db=Depends(get_db),
):
"""
Return the current user based on the token, or raise a 401 error
@@ -77,8 +75,18 @@ async def get_current_user(
try:
payload = jwt.decode(token, secret_key, algorithms=[algorithm])
username: str = payload.get("sub")
+ refresh: bool = payload.get("refresh")
if username is None:
raise credentials_exception
+ # For some reason, an access token was passed when a refresh
+ # token was expected - some likely malicious activity
+ if not refresh and is_refresh:
+ raise credentials_exception
+ # If the token passed is a refresh token and the function
+ # is not expecting a refresh token, raise an error
+ if refresh and not is_refresh:
+ raise credentials_exception
+
token_data = TokenData(username=username)
except InvalidTokenError:
raise credentials_exception