Remove API Keys -> Authenticate with JWT
This commit is contained in:
parent
65fef62741
commit
8ae8c5c454
50
api/main.py
50
api/main.py
@ -1,13 +1,20 @@
|
|||||||
from fastapi import FastAPI, Depends, HTTPException, Security
|
from fastapi import FastAPI, Depends, HTTPException, Security, status
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse
|
||||||
|
from datetime import timedelta
|
||||||
|
from typing import Annotated
|
||||||
|
from fastapi.security import OAuth2PasswordRequestForm
|
||||||
import string
|
import string
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
from api.util.authentication import (
|
||||||
|
authenticate_user,
|
||||||
|
create_access_token,
|
||||||
|
get_current_user,
|
||||||
|
)
|
||||||
from api.routes.links_route import router as links_router
|
from api.routes.links_route import router as links_router
|
||||||
from api.util.db_dependency import get_db
|
from api.util.db_dependency import get_db
|
||||||
from api.util.check_api_key import check_api_key
|
from api.schemas.auth_schemas import User, Token
|
||||||
from models import User
|
|
||||||
|
|
||||||
|
|
||||||
metadata_tags = [
|
metadata_tags = [
|
||||||
@ -37,20 +44,33 @@ app.add_middleware(
|
|||||||
# Import routes
|
# Import routes
|
||||||
app.include_router(links_router)
|
app.include_router(links_router)
|
||||||
|
|
||||||
# Regenerate the API key for the user
|
|
||||||
@app.post("/regenerate")
|
"""
|
||||||
async def regenerate(api_key: str = Security(check_api_key), db = Depends(get_db)):
|
Authentication
|
||||||
"""Regenerate the API key for the user. Requires the current API key."""
|
"""
|
||||||
user = db.query(User).filter(User.api_key == api_key['value']).first()
|
|
||||||
|
|
||||||
|
@app.post("/token")
|
||||||
|
async def login_for_access_token(
|
||||||
|
form_data: Annotated[OAuth2PasswordRequestForm, Depends()],
|
||||||
|
db=Depends(get_db),
|
||||||
|
) -> Token:
|
||||||
|
"""
|
||||||
|
Return an access token for the user, if the given authentication details are correct
|
||||||
|
"""
|
||||||
|
user = authenticate_user(db, form_data.username, form_data.password)
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(status_code=401, detail="Invalid API key")
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Incorrect username or password",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
access_token_expires = timedelta(minutes=30)
|
||||||
|
access_token = create_access_token(
|
||||||
|
data={"sub": user.username}, expires_delta=access_token_expires
|
||||||
|
)
|
||||||
|
return Token(access_token=access_token, token_type="bearer")
|
||||||
|
|
||||||
# Generate a new API key
|
|
||||||
new_api_key = ''.join(random.choices(string.ascii_letters + string.digits, k=20))
|
|
||||||
user.api_key = new_api_key
|
|
||||||
db.commit()
|
|
||||||
|
|
||||||
return {"status": "success", "new_api_key": new_api_key}
|
|
||||||
|
|
||||||
# Redirect /api -> /api/docs
|
# Redirect /api -> /api/docs
|
||||||
@app.get("/")
|
@app.get("/")
|
||||||
|
@ -7,9 +7,10 @@ import datetime
|
|||||||
import validators
|
import validators
|
||||||
|
|
||||||
from api.util.db_dependency import get_db
|
from api.util.db_dependency import get_db
|
||||||
from api.util.check_api_key import check_api_key
|
|
||||||
from models import Link, Record
|
from models import Link, Record
|
||||||
from api.schemas.links_schemas import URLSchema
|
from api.schemas.links_schemas import URLSchema
|
||||||
|
from api.schemas.auth_schemas import User
|
||||||
|
from api.util.authentication import get_current_user
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/links", tags=["links"])
|
router = APIRouter(prefix="/links", tags=["links"])
|
||||||
@ -17,10 +18,10 @@ router = APIRouter(prefix="/links", tags=["links"])
|
|||||||
|
|
||||||
@router.get("/", summary="Get all of the links associated with your account")
|
@router.get("/", summary="Get all of the links associated with your account")
|
||||||
async def get_links(
|
async def get_links(
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
db=Depends(get_db),
|
db=Depends(get_db),
|
||||||
api_key: str = Security(check_api_key),
|
|
||||||
):
|
):
|
||||||
links = db.query(Link).filter(Link.owner == api_key["owner"]).all()
|
links = db.query(Link).filter(Link.owner == current_user.id).all()
|
||||||
if not links:
|
if not links:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail="No links found"
|
status_code=status.HTTP_404_NOT_FOUND, detail="No links found"
|
||||||
@ -31,8 +32,8 @@ async def get_links(
|
|||||||
@router.post("/", summary="Create a new link")
|
@router.post("/", summary="Create a new link")
|
||||||
async def create_link(
|
async def create_link(
|
||||||
url: URLSchema,
|
url: URLSchema,
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
db=Depends(get_db),
|
db=Depends(get_db),
|
||||||
api_key: str = Security(check_api_key),
|
|
||||||
):
|
):
|
||||||
# Check if the URL is valid
|
# Check if the URL is valid
|
||||||
if not validators.url(url.url):
|
if not validators.url(url.url):
|
||||||
@ -48,7 +49,7 @@ async def create_link(
|
|||||||
).upper()
|
).upper()
|
||||||
new_link = Link(
|
new_link = Link(
|
||||||
link=link_path,
|
link=link_path,
|
||||||
owner=api_key["owner"],
|
owner=current_user.id,
|
||||||
redirect_link=url.url,
|
redirect_link=url.url,
|
||||||
expire_date=datetime.datetime.now()
|
expire_date=datetime.datetime.now()
|
||||||
+ datetime.timedelta(days=30),
|
+ datetime.timedelta(days=30),
|
||||||
@ -69,8 +70,8 @@ async def create_link(
|
|||||||
@router.delete("/{link}", summary="Delete a link")
|
@router.delete("/{link}", summary="Delete a link")
|
||||||
async def delete_link(
|
async def delete_link(
|
||||||
link: Annotated[str, Path(title="Link to delete")],
|
link: Annotated[str, Path(title="Link to delete")],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
db=Depends(get_db),
|
db=Depends(get_db),
|
||||||
api_key: str = Security(check_api_key),
|
|
||||||
):
|
):
|
||||||
link = link.upper()
|
link = link.upper()
|
||||||
# Get the link and check the owner
|
# Get the link and check the owner
|
||||||
@ -79,7 +80,7 @@ async def delete_link(
|
|||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail="Link not found"
|
status_code=status.HTTP_404_NOT_FOUND, detail="Link not found"
|
||||||
)
|
)
|
||||||
if link.owner != api_key["owner"]:
|
if link.owner != current_user.id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Link not associated with your account",
|
detail="Link not associated with your account",
|
||||||
@ -102,8 +103,8 @@ async def delete_link(
|
|||||||
)
|
)
|
||||||
async def get_link_records(
|
async def get_link_records(
|
||||||
link: Annotated[str, Path(title="Link to get records for")],
|
link: Annotated[str, Path(title="Link to get records for")],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
db=Depends(get_db),
|
db=Depends(get_db),
|
||||||
api_key: str = Security(check_api_key),
|
|
||||||
):
|
):
|
||||||
link = link.upper()
|
link = link.upper()
|
||||||
# Get the link and check the owner
|
# Get the link and check the owner
|
||||||
@ -112,7 +113,7 @@ async def get_link_records(
|
|||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail="Link not found"
|
status_code=status.HTTP_404_NOT_FOUND, detail="Link not found"
|
||||||
)
|
)
|
||||||
if link.owner != api_key["owner"]:
|
if link.owner != current_user.id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Link not associated with your account",
|
detail="Link not associated with your account",
|
||||||
@ -129,8 +130,8 @@ async def get_link_records(
|
|||||||
)
|
)
|
||||||
async def delete_link_records(
|
async def delete_link_records(
|
||||||
link: Annotated[str, Path(title="Link to delete records for")],
|
link: Annotated[str, Path(title="Link to delete records for")],
|
||||||
|
current_user: Annotated[User, Depends(get_current_user)],
|
||||||
db=Depends(get_db),
|
db=Depends(get_db),
|
||||||
api_key: str = Security(check_api_key),
|
|
||||||
):
|
):
|
||||||
link = link.upper()
|
link = link.upper()
|
||||||
# Get the link and check the owner
|
# Get the link and check the owner
|
||||||
@ -139,7 +140,7 @@ async def delete_link_records(
|
|||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND, detail="Link not found"
|
status_code=status.HTTP_404_NOT_FOUND, detail="Link not found"
|
||||||
)
|
)
|
||||||
if link.owner != api_key["owner"]:
|
if link.owner != current_user.id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail="Link not associated with your account",
|
detail="Link not associated with your account",
|
||||||
|
19
api/schemas/auth_schemas.py
Normal file
19
api/schemas/auth_schemas.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class Token(BaseModel):
|
||||||
|
access_token: str
|
||||||
|
token_type: str
|
||||||
|
|
||||||
|
|
||||||
|
class TokenData(BaseModel):
|
||||||
|
username: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class User(BaseModel):
|
||||||
|
username: str
|
||||||
|
id: int
|
||||||
|
|
||||||
|
|
||||||
|
class UserInDB(User):
|
||||||
|
hashed_password: str
|
88
api/util/authentication.py
Normal file
88
api/util/authentication.py
Normal file
@ -0,0 +1,88 @@
|
|||||||
|
import random
|
||||||
|
import bcrypt
|
||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from jwt.exceptions import InvalidTokenError
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Annotated
|
||||||
|
import jwt
|
||||||
|
|
||||||
|
from api.util.db_dependency import get_db
|
||||||
|
from api.schemas.auth_schemas import *
|
||||||
|
from models import User as UserDB
|
||||||
|
|
||||||
|
secret_key = random.randbytes(32)
|
||||||
|
algorithm = "HS256"
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||||
|
|
||||||
|
"""
|
||||||
|
Helper functions for authentication
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def verify_password(plain_password, hashed_password):
|
||||||
|
return bcrypt.checkpw(
|
||||||
|
plain_password.encode("utf-8"), hashed_password.encode("utf-8")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_user(db, username: str):
|
||||||
|
"""
|
||||||
|
Get the user object from the database
|
||||||
|
"""
|
||||||
|
user = db.query(UserDB).filter(UserDB.username == username).first()
|
||||||
|
if user:
|
||||||
|
return UserInDB(**user.__dict__)
|
||||||
|
|
||||||
|
|
||||||
|
def authenticate_user(db, username: str, password: str):
|
||||||
|
"""
|
||||||
|
Determine if the correct username and password were provided
|
||||||
|
If so, return the user object
|
||||||
|
"""
|
||||||
|
user = get_user(db, username)
|
||||||
|
print(user)
|
||||||
|
if not user:
|
||||||
|
return False
|
||||||
|
if not verify_password(password, user.hashed_password):
|
||||||
|
return False
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def create_access_token(data: dict, expires_delta: timedelta | None = None):
|
||||||
|
"""
|
||||||
|
Return an encoded JWT token with the given data
|
||||||
|
"""
|
||||||
|
to_encode = data.copy()
|
||||||
|
if expires_delta:
|
||||||
|
expire = datetime.utcnow() + expires_delta
|
||||||
|
else:
|
||||||
|
expire = datetime.utcnow() + timedelta(minutes=15)
|
||||||
|
to_encode.update({"exp": expire})
|
||||||
|
encoded_jwt = jwt.encode(to_encode, secret_key, algorithm=algorithm)
|
||||||
|
return encoded_jwt
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_user(
|
||||||
|
token: Annotated[str, Depends(oauth2_scheme)], db=Depends(get_db)
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Return the current user based on the token, or raise a 401 error
|
||||||
|
"""
|
||||||
|
credentials_exception = HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Could not validate credentials",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token, secret_key, algorithms=[algorithm])
|
||||||
|
username: str = payload.get("sub")
|
||||||
|
if username is None:
|
||||||
|
raise credentials_exception
|
||||||
|
token_data = TokenData(username=username)
|
||||||
|
except InvalidTokenError:
|
||||||
|
raise credentials_exception
|
||||||
|
user = get_user(db, username=token_data.username)
|
||||||
|
if user is None:
|
||||||
|
raise credentials_exception
|
||||||
|
return user
|
@ -1,21 +0,0 @@
|
|||||||
from fastapi import Security, HTTPException, Depends, status
|
|
||||||
from fastapi.security import APIKeyHeader
|
|
||||||
|
|
||||||
from models import User
|
|
||||||
from api.util.db_dependency import get_db
|
|
||||||
|
|
||||||
"""
|
|
||||||
Make sure the provided API key is valid, then return the user's ID
|
|
||||||
"""
|
|
||||||
api_key_header = APIKeyHeader(name="X-API-Key")
|
|
||||||
|
|
||||||
|
|
||||||
def check_api_key(
|
|
||||||
api_key_header: str = Security(api_key_header), db=Depends(get_db)
|
|
||||||
) -> str:
|
|
||||||
response = db.query(User).filter(User.api_key == api_key_header).first()
|
|
||||||
if not response:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key"
|
|
||||||
)
|
|
||||||
return {"value": api_key_header, "owner": response.id}
|
|
@ -1,20 +0,0 @@
|
|||||||
import bcrypt
|
|
||||||
from fastapi import Depends
|
|
||||||
|
|
||||||
from api.util.db_dependency import get_db
|
|
||||||
from models import User
|
|
||||||
|
|
||||||
"""
|
|
||||||
Validate the login information provided by the user
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def validate_login_information(
|
|
||||||
username: str, password: str, db=Depends(get_db)
|
|
||||||
) -> bool:
|
|
||||||
user = db.query(User).filter(User.username == username).first()
|
|
||||||
if not user:
|
|
||||||
return False
|
|
||||||
if bcrypt.checkpw(password.encode("utf-8"), user.password.encode("utf-8")):
|
|
||||||
return True
|
|
||||||
return False
|
|
@ -14,7 +14,7 @@ class User(Base):
|
|||||||
__tablename__ = "users"
|
__tablename__ = "users"
|
||||||
id = Column(Integer, primary_key=True)
|
id = Column(Integer, primary_key=True)
|
||||||
username = Column(String, unique=True, nullable=False)
|
username = Column(String, unique=True, nullable=False)
|
||||||
password = Column(Text, nullable=False)
|
hashed_password = Column(Text, nullable=False)
|
||||||
api_key = Column(String(20), unique=True, nullable=False)
|
api_key = Column(String(20), unique=True, nullable=False)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user