Remove API Keys -> Authenticate with JWT
This commit is contained in:
parent
65fef62741
commit
8ae8c5c454
50
api/main.py
50
api/main.py
@ -1,13 +1,20 @@
|
||||
from fastapi import FastAPI, Depends, HTTPException, Security
|
||||
from fastapi import FastAPI, Depends, HTTPException, Security, status
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import RedirectResponse
|
||||
from datetime import timedelta
|
||||
from typing import Annotated
|
||||
from fastapi.security import OAuth2PasswordRequestForm
|
||||
import string
|
||||
import random
|
||||
|
||||
from api.util.authentication import (
|
||||
authenticate_user,
|
||||
create_access_token,
|
||||
get_current_user,
|
||||
)
|
||||
from api.routes.links_route import router as links_router
|
||||
from api.util.db_dependency import get_db
|
||||
from api.util.check_api_key import check_api_key
|
||||
from models import User
|
||||
from api.schemas.auth_schemas import User, Token
|
||||
|
||||
|
||||
metadata_tags = [
|
||||
@ -37,20 +44,33 @@ app.add_middleware(
|
||||
# Import routes
|
||||
app.include_router(links_router)
|
||||
|
||||
# Regenerate the API key for the user
|
||||
@app.post("/regenerate")
|
||||
async def regenerate(api_key: str = Security(check_api_key), db = Depends(get_db)):
|
||||
"""Regenerate the API key for the user. Requires the current API key."""
|
||||
user = db.query(User).filter(User.api_key == api_key['value']).first()
|
||||
|
||||
"""
|
||||
Authentication
|
||||
"""
|
||||
|
||||
|
||||
@app.post("/token")
|
||||
async def login_for_access_token(
|
||||
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||
db=Depends(get_db),
|
||||
) -> Token:
|
||||
"""
|
||||
Return an access token for the user, if the given authentication details are correct
|
||||
"""
|
||||
user = authenticate_user(db, form_data.username, form_data.password)
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect username or password",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
access_token_expires = timedelta(minutes=30)
|
||||
access_token = create_access_token(
|
||||
data={"sub": user.username}, expires_delta=access_token_expires
|
||||
)
|
||||
return Token(access_token=access_token, token_type="bearer")
|
||||
|
||||
# Generate a new API key
|
||||
new_api_key = ''.join(random.choices(string.ascii_letters + string.digits, k=20))
|
||||
user.api_key = new_api_key
|
||||
db.commit()
|
||||
|
||||
return {"status": "success", "new_api_key": new_api_key}
|
||||
|
||||
# Redirect /api -> /api/docs
|
||||
@app.get("/")
|
||||
|
@ -7,9 +7,10 @@ import datetime
|
||||
import validators
|
||||
|
||||
from api.util.db_dependency import get_db
|
||||
from api.util.check_api_key import check_api_key
|
||||
from models import Link, Record
|
||||
from api.schemas.links_schemas import URLSchema
|
||||
from api.schemas.auth_schemas import User
|
||||
from api.util.authentication import get_current_user
|
||||
|
||||
|
||||
router = APIRouter(prefix="/links", tags=["links"])
|
||||
@ -17,10 +18,10 @@ router = APIRouter(prefix="/links", tags=["links"])
|
||||
|
||||
@router.get("/", summary="Get all of the links associated with your account")
|
||||
async def get_links(
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
db=Depends(get_db),
|
||||
api_key: str = Security(check_api_key),
|
||||
):
|
||||
links = db.query(Link).filter(Link.owner == api_key["owner"]).all()
|
||||
links = db.query(Link).filter(Link.owner == current_user.id).all()
|
||||
if not links:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="No links found"
|
||||
@ -31,8 +32,8 @@ async def get_links(
|
||||
@router.post("/", summary="Create a new link")
|
||||
async def create_link(
|
||||
url: URLSchema,
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
db=Depends(get_db),
|
||||
api_key: str = Security(check_api_key),
|
||||
):
|
||||
# Check if the URL is valid
|
||||
if not validators.url(url.url):
|
||||
@ -48,7 +49,7 @@ async def create_link(
|
||||
).upper()
|
||||
new_link = Link(
|
||||
link=link_path,
|
||||
owner=api_key["owner"],
|
||||
owner=current_user.id,
|
||||
redirect_link=url.url,
|
||||
expire_date=datetime.datetime.now()
|
||||
+ datetime.timedelta(days=30),
|
||||
@ -69,8 +70,8 @@ async def create_link(
|
||||
@router.delete("/{link}", summary="Delete a link")
|
||||
async def delete_link(
|
||||
link: Annotated[str, Path(title="Link to delete")],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
db=Depends(get_db),
|
||||
api_key: str = Security(check_api_key),
|
||||
):
|
||||
link = link.upper()
|
||||
# Get the link and check the owner
|
||||
@ -79,7 +80,7 @@ async def delete_link(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Link not found"
|
||||
)
|
||||
if link.owner != api_key["owner"]:
|
||||
if link.owner != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Link not associated with your account",
|
||||
@ -102,8 +103,8 @@ async def delete_link(
|
||||
)
|
||||
async def get_link_records(
|
||||
link: Annotated[str, Path(title="Link to get records for")],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
db=Depends(get_db),
|
||||
api_key: str = Security(check_api_key),
|
||||
):
|
||||
link = link.upper()
|
||||
# Get the link and check the owner
|
||||
@ -112,7 +113,7 @@ async def get_link_records(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Link not found"
|
||||
)
|
||||
if link.owner != api_key["owner"]:
|
||||
if link.owner != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Link not associated with your account",
|
||||
@ -129,8 +130,8 @@ async def get_link_records(
|
||||
)
|
||||
async def delete_link_records(
|
||||
link: Annotated[str, Path(title="Link to delete records for")],
|
||||
current_user: Annotated[User, Depends(get_current_user)],
|
||||
db=Depends(get_db),
|
||||
api_key: str = Security(check_api_key),
|
||||
):
|
||||
link = link.upper()
|
||||
# Get the link and check the owner
|
||||
@ -139,7 +140,7 @@ async def delete_link_records(
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND, detail="Link not found"
|
||||
)
|
||||
if link.owner != api_key["owner"]:
|
||||
if link.owner != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Link not associated with your account",
|
||||
|
19
api/schemas/auth_schemas.py
Normal file
19
api/schemas/auth_schemas.py
Normal file
@ -0,0 +1,19 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class Token(BaseModel):
|
||||
access_token: str
|
||||
token_type: str
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
username: str | None = None
|
||||
|
||||
|
||||
class User(BaseModel):
|
||||
username: str
|
||||
id: int
|
||||
|
||||
|
||||
class UserInDB(User):
|
||||
hashed_password: str
|
88
api/util/authentication.py
Normal file
88
api/util/authentication.py
Normal file
@ -0,0 +1,88 @@
|
||||
import random
|
||||
import bcrypt
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jwt.exceptions import InvalidTokenError
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Annotated
|
||||
import jwt
|
||||
|
||||
from api.util.db_dependency import get_db
|
||||
from api.schemas.auth_schemas import *
|
||||
from models import User as UserDB
|
||||
|
||||
secret_key = random.randbytes(32)
|
||||
algorithm = "HS256"
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
|
||||
"""
|
||||
Helper functions for authentication
|
||||
"""
|
||||
|
||||
|
||||
def verify_password(plain_password, hashed_password):
|
||||
return bcrypt.checkpw(
|
||||
plain_password.encode("utf-8"), hashed_password.encode("utf-8")
|
||||
)
|
||||
|
||||
|
||||
def get_user(db, username: str):
|
||||
"""
|
||||
Get the user object from the database
|
||||
"""
|
||||
user = db.query(UserDB).filter(UserDB.username == username).first()
|
||||
if user:
|
||||
return UserInDB(**user.__dict__)
|
||||
|
||||
|
||||
def authenticate_user(db, username: str, password: str):
|
||||
"""
|
||||
Determine if the correct username and password were provided
|
||||
If so, return the user object
|
||||
"""
|
||||
user = get_user(db, username)
|
||||
print(user)
|
||||
if not user:
|
||||
return False
|
||||
if not verify_password(password, user.hashed_password):
|
||||
return False
|
||||
return user
|
||||
|
||||
|
||||
def create_access_token(data: dict, expires_delta: timedelta | None = None):
|
||||
"""
|
||||
Return an encoded JWT token with the given data
|
||||
"""
|
||||
to_encode = data.copy()
|
||||
if expires_delta:
|
||||
expire = datetime.utcnow() + expires_delta
|
||||
else:
|
||||
expire = datetime.utcnow() + timedelta(minutes=15)
|
||||
to_encode.update({"exp": expire})
|
||||
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=algorithm)
|
||||
return encoded_jwt
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: Annotated[str, Depends(oauth2_scheme)], db=Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Return the current user based on the token, or raise a 401 error
|
||||
"""
|
||||
credentials_exception = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(token, secret_key, algorithms=[algorithm])
|
||||
username: str = payload.get("sub")
|
||||
if username is None:
|
||||
raise credentials_exception
|
||||
token_data = TokenData(username=username)
|
||||
except InvalidTokenError:
|
||||
raise credentials_exception
|
||||
user = get_user(db, username=token_data.username)
|
||||
if user is None:
|
||||
raise credentials_exception
|
||||
return user
|
@ -1,21 +0,0 @@
|
||||
from fastapi import Security, HTTPException, Depends, status
|
||||
from fastapi.security import APIKeyHeader
|
||||
|
||||
from models import User
|
||||
from api.util.db_dependency import get_db
|
||||
|
||||
"""
|
||||
Make sure the provided API key is valid, then return the user's ID
|
||||
"""
|
||||
api_key_header = APIKeyHeader(name="X-API-Key")
|
||||
|
||||
|
||||
def check_api_key(
|
||||
api_key_header: str = Security(api_key_header), db=Depends(get_db)
|
||||
) -> str:
|
||||
response = db.query(User).filter(User.api_key == api_key_header).first()
|
||||
if not response:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
|
||||
)
|
||||
return {"value": api_key_header, "owner": response.id}
|
@ -1,20 +0,0 @@
|
||||
import bcrypt
|
||||
from fastapi import Depends
|
||||
|
||||
from api.util.db_dependency import get_db
|
||||
from models import User
|
||||
|
||||
"""
|
||||
Validate the login information provided by the user
|
||||
"""
|
||||
|
||||
|
||||
def validate_login_information(
|
||||
username: str, password: str, db=Depends(get_db)
|
||||
) -> bool:
|
||||
user = db.query(User).filter(User.username == username).first()
|
||||
if not user:
|
||||
return False
|
||||
if bcrypt.checkpw(password.encode("utf-8"), user.password.encode("utf-8")):
|
||||
return True
|
||||
return False
|
@ -14,7 +14,7 @@ class User(Base):
|
||||
__tablename__ = "users"
|
||||
id = Column(Integer, primary_key=True)
|
||||
username = Column(String, unique=True, nullable=False)
|
||||
password = Column(Text, nullable=False)
|
||||
hashed_password = Column(Text, nullable=False)
|
||||
api_key = Column(String(20), unique=True, nullable=False)
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user