Fix auth and organization/standards

This commit is contained in:
Parker M. 2024-11-05 20:36:09 -06:00
parent 6f7e810916
commit e944df3d7d
Signed by: parker
GPG Key ID: 505ED36FC12B5D5E
9 changed files with 104 additions and 76 deletions

View File

@ -1,6 +1,6 @@
from fastapi import FastAPI, Depends, Request from fastapi import FastAPI, Depends, Request
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse from fastapi.responses import RedirectResponse, JSONResponse
from fastapi.templating import Jinja2Templates from fastapi.templating import Jinja2Templates
from app.routes.auth_routes import router as auth_router from app.routes.auth_routes import router as auth_router
from app.routes.links_routes import router as links_router from app.routes.links_routes import router as links_router
@ -151,3 +151,11 @@ async def redirect_to_docs():
@app.exception_handler(HTTP_404_NOT_FOUND) @app.exception_handler(HTTP_404_NOT_FOUND)
async def custom_404_handler(request: Request, exc: HTTPException): async def custom_404_handler(request: Request, exc: HTTPException):
return RedirectResponse(url="/login") return RedirectResponse(url="/login")
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
return JSONResponse(
status_code=exc.status_code,
content={"detail": f"{exc.detail}"},
)

View File

@ -26,7 +26,7 @@ async def login_for_access_token(
Return an access token for the user, if the given authentication details are correct Return an access token for the user, if the given authentication details are correct
""" """
user = authenticate_user(db, form_data.username, form_data.password) user = authenticate_user(db, form_data.username, form_data.password)
print(user)
if not user: if not user:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@ -35,14 +35,14 @@ async def login_for_access_token(
) )
access_token_expires = timedelta(minutes=15) access_token_expires = timedelta(minutes=15)
access_token = create_access_token( access_token = create_access_token(
data={"sub": user.id, "refresh": False}, data={"sub": user.id, "username": user.username, "refresh": False},
expires_delta=access_token_expires, expires_delta=access_token_expires,
) )
# Create a refresh token - just an access token with a longer expiry # Create a refresh token - just an access token with a longer expiry
# and more restrictions ("refresh" is True) # and more restrictions ("refresh" is True)
refresh_token_expires = timedelta(days=1) refresh_token_expires = timedelta(days=1)
refresh_token = create_access_token( refresh_token = create_access_token(
data={"sub": user.id, "refresh": True}, data={"sub": user.id, "username": user.username, "refresh": True},
expires_delta=refresh_token_expires, expires_delta=refresh_token_expires,
) )
# response = JSONResponse(content={"success": True}) # response = JSONResponse(content={"success": True})

View File

@ -7,7 +7,7 @@ import datetime
import validators import validators
from app.util.db_dependency import get_db from app.util.db_dependency import get_db
from models import Link, Record from models import Link, Log
from app.schemas.links_schemas import URLSchema from app.schemas.links_schemas import URLSchema
from app.schemas.auth_schemas import User from app.schemas.auth_schemas import User
from app.util.authentication import get_current_user_from_token from app.util.authentication import get_current_user_from_token
@ -69,6 +69,9 @@ async def delete_link(
current_user: Annotated[User, Depends(get_current_user_from_token)], current_user: Annotated[User, Depends(get_current_user_from_token)],
db=Depends(get_db), db=Depends(get_db),
): ):
"""
Delete a link and all of the logs associated with it
"""
link = link.upper() link = link.upper()
# Get the link and check the owner # Get the link and check the owner
link = db.query(Link).filter(Link.link == link).first() link = db.query(Link).filter(Link.link == link).first()
@ -82,10 +85,10 @@ async def delete_link(
detail="Link not associated with your account", detail="Link not associated with your account",
) )
# Get and delete all records associated with the link # Get and delete all logsk
records = db.query(Record).filter(Record.link == link.link).all() logs = db.query(Log).filter(Log.link == link.link).all()
for record in records: for log in logs:
db.delete(record) db.delete(log)
# Delete the link # Delete the link
db.delete(link) db.delete(link)
db.commit() db.commit()
@ -93,15 +96,15 @@ async def delete_link(
return status.HTTP_204_NO_CONTENT return status.HTTP_204_NO_CONTENT
@router.get( @router.get("/{link}/logs", summary="Get all logs associated with a link")
"/{link}/records", async def get_link_logs(
summary="Get all of the IP log records associated with a link", link: Annotated[str, Path(title="Link to get logs for")],
)
async def get_link_records(
link: Annotated[str, Path(title="Link to get records for")],
current_user: Annotated[User, Depends(get_current_user_from_token)], current_user: Annotated[User, Depends(get_current_user_from_token)],
db=Depends(get_db), db=Depends(get_db),
): ):
"""
Get all of the IP logs associated with a link
"""
link = link.upper() link = link.upper()
# Get the link and check the owner # Get the link and check the owner
link = db.query(Link).filter(Link.link == link).first() link = db.query(Link).filter(Link.link == link).first()
@ -115,20 +118,20 @@ async def get_link_records(
detail="Link not associated with your account", detail="Link not associated with your account",
) )
# Get and return all of the records associated with the link # Get and return all of the logs
records = db.query(Record).filter(Record.link == link.link).all() logs = db.query(Log).filter(Log.link == link.link).all()
return records return logs
@router.delete( @router.delete("/{link}/logs", summary="Delete logs associated with a link")
"/{link}/records", async def delete_link_logs(
summary="Delete all of the IP log records associated with a link", link: Annotated[str, Path(title="Link to delete logs for")],
)
async def delete_link_records(
link: Annotated[str, Path(title="Link to delete records for")],
current_user: Annotated[User, Depends(get_current_user_from_token)], current_user: Annotated[User, Depends(get_current_user_from_token)],
db=Depends(get_db), db=Depends(get_db),
): ):
"""
Delete all of the IP logs associated with a link
"""
link = link.upper() link = link.upper()
# Get the link and check the owner # Get the link and check the owner
link = db.query(Link).filter(Link.link == link).first() link = db.query(Link).filter(Link.link == link).first()
@ -142,10 +145,10 @@ async def delete_link_records(
detail="Link not associated with your account", detail="Link not associated with your account",
) )
# Get all of the records associated with the link and delete them # Get all of the logs
records = db.query(Record).filter(Record.link == link.link).all() logs = db.query(Log).filter(Log.link == link.link).all()
for record in records: for log in logs:
db.delete(record) db.delete(log)
db.commit() db.commit()
return status.HTTP_204_NO_CONTENT return status.HTTP_204_NO_CONTENT

View File

@ -4,14 +4,16 @@ from typing import Annotated
import string import string
import bcrypt import bcrypt
import random import random
import datetime
import validators
from app.util.db_dependency import get_db from app.util.db_dependency import get_db
from app.util.check_password_reqs import check_password_reqs
from app.schemas.auth_schemas import User from app.schemas.auth_schemas import User
from app.schemas.user_schemas import * from app.schemas.user_schemas import *
from models import User as UserModel from models import User as UserModel
from app.util.authentication import get_current_user_from_token from app.util.authentication import (
verify_password,
get_current_user_from_token,
)
router = APIRouter(prefix="/users", tags=["users"]) router = APIRouter(prefix="/users", tags=["users"])
@ -26,23 +28,27 @@ async def delete_user(
""" """
Delete the user account associated with the current user Delete the user account associated with the current user
""" """
# No editing others accounts
if user_id != current_user.id: if user_id != current_user.id:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="You can only delete your own account", detail="You can only delete your own account",
) )
# Get the user and delete them
user = db.query(UserModel).filter(UserModel.id == current_user.id).first() user = db.query(UserModel).filter(UserModel.id == current_user.id).first()
if not user: if not user:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_404_NOT_FOUND,
detail="User not found", detail="User not found",
) )
db.delete(user) db.delete(user)
db.commit() db.commit()
return status.HTTP_204_NO_CONTENT return status.HTTP_204_NO_CONTENT
@router.post("/{user_id}", summary="Update your account password") @router.post("/{user_id}/password", summary="Update your account password")
async def update_pass( async def update_pass(
user_id: Annotated[int, Path(title="Link to update")], user_id: Annotated[int, Path(title="Link to update")],
update_data: UpdatePasswordSchema, update_data: UpdatePasswordSchema,
@ -57,22 +63,19 @@ async def update_pass(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail="You can only update your own account", detail="You can only update your own account",
) )
# Make sure that they entered the correct current password
if not verify_password(
update_data.current_password, current_user.hashed_password
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect current password",
)
# Make sure the password meets all of the requirements # Make sure the password meets all of the requirements
# if len(update_data.new_password) < 8: check_password_reqs(update_data.new_password)
# raise HTTPException(
# status_code=status.HTTP_400_BAD_REQUEST,
# detail="Password must be at least 8 characters",
# )
# if not any(char.isdigit() for char in update_data.new_password):
# raise HTTPException(
# status_code=status.HTTP_400_BAD_REQUEST,
# detail="Password must contain at least one digit",
# )
# if not any(char.isupper() for char in update_data.new_password):
# raise HTTPException(
# status_code=status.HTTP_400_BAD_REQUEST,
# detail="Password must contain at least one uppercase letter",
# )
# Get the user and update the password # Get the user and update the password
user = db.query(UserModel).filter(UserModel.id == current_user.id).first() user = db.query(UserModel).filter(UserModel.id == current_user.id).first()
if not user: if not user:
@ -98,24 +101,10 @@ async def get_links(
""" """
username = login_data.username username = login_data.username
password = login_data.password password = login_data.password
print(username)
print(password)
# Make sure the password meets all of the requirements # Make sure the password meets all of the requirements
# if len(password) < 8: check_password_reqs(password)
# raise HTTPException(
# status_code=status.HTTP_400_BAD_REQUEST,
# detail="Password must be at least 8 characters",
# )
# if not any(char.isdigit() for char in password):
# raise HTTPException(
# status_code=status.HTTP_400_BAD_REQUEST,
# detail="Password must contain at least one digit",
# )
# if not any(char.isupper() for char in password):
# raise HTTPException(
# status_code=status.HTTP_400_BAD_REQUEST,
# detail="Password must contain at least one uppercase letter",
# )
# Make sure the username isn't taken # Make sure the username isn't taken
user = db.query(UserModel).filter(UserModel.username == username).first() user = db.query(UserModel).filter(UserModel.username == username).first()
if user: if user:

View File

@ -7,5 +7,5 @@ class LoginDataSchema(BaseModel):
class UpdatePasswordSchema(BaseModel): class UpdatePasswordSchema(BaseModel):
password: str current_password: str
new_password: str new_password: str

View File

@ -28,11 +28,11 @@ def verify_password(plain_password, hashed_password):
) )
def get_user(db, id: int): def get_user(db, username: str):
""" """
Get the user object from the database Get the user object from the database
""" """
user = db.query(UserModel).filter(UserModel.id == id).first() user = db.query(UserModel).filter(UserModel.username == username).first()
if user: if user:
return UserInDB(**user.__dict__) return UserInDB(**user.__dict__)
@ -46,6 +46,7 @@ def authenticate_user(db, username: str, password: str):
if not user: if not user:
return False return False
if not verify_password(password, user.hashed_password): if not verify_password(password, user.hashed_password):
print("WHY")
return False return False
return user return user
@ -121,8 +122,9 @@ async def get_current_user(
try: try:
payload = jwt.decode(token, secret_key, algorithms=[algorithm]) payload = jwt.decode(token, secret_key, algorithms=[algorithm])
id: int = payload.get("sub") id: int = payload.get("sub")
username: str = payload.get("username")
refresh: bool = payload.get("refresh") refresh: bool = payload.get("refresh")
if not id: if not id or not username:
return raise_unauthorized() return raise_unauthorized()
# For some reason, an access token was passed when a refresh # For some reason, an access token was passed when a refresh
# token was expected - some likely malicious activity # token was expected - some likely malicious activity
@ -136,7 +138,7 @@ async def get_current_user(
except InvalidTokenError: except InvalidTokenError:
return raise_unauthorized() return raise_unauthorized()
user = get_user(db, id) user = get_user(db, username)
if user is None: if user is None:
return raise_unauthorized() return raise_unauthorized()

View File

@ -0,0 +1,26 @@
from fastapi import HTTPException, status
def check_password_reqs(password: str):
"""
Make sure the entered password meets the security requirements:
1. At least 8 characters
2. At least one digit
3. At least one uppercase letter
"""
if len(password) < 8:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Password must be at least 8 characters",
)
if not any(char.isdigit() for char in password):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Password must contain at least one digit",
)
if not any(char.isupper() for char in password):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Password must contain at least one uppercase letter",
)
return

View File

@ -4,10 +4,10 @@ from ua_parser import user_agent_parser
from database import SessionLocal from database import SessionLocal
import config import config
from models import Link, Record from models import Link, Log
""" """
Create a new log record whenever a link is visited Create a new log whenever a link is visited
""" """
@ -65,8 +65,8 @@ def log(link, ip, user_agent):
browser = ua_string["user_agent"]["family"] browser = ua_string["user_agent"]["family"]
os = f'{ua_string["os"]["family"]} {ua_string["os"]["major"]}' os = f'{ua_string["os"]["family"]} {ua_string["os"]["major"]}'
# Create the log record and commit it to the database # Create the log and commit it to the database
link_record = Record( new_log = Log(
owner=owner, owner=owner,
link=link, link=link,
timestamp=timestamp, timestamp=timestamp,
@ -77,7 +77,7 @@ def log(link, ip, user_agent):
user_agent=user_agent, user_agent=user_agent,
isp=isp, isp=isp,
) )
db.add(link_record) db.add(new_log)
db.commit() db.commit()
db.close() db.close()

View File

@ -26,8 +26,8 @@ class Link(Base):
expire_date = Column(DateTime, nullable=False) expire_date = Column(DateTime, nullable=False)
class Record(Base): class Log(Base):
__tablename__ = "records" __tablename__ = "logs"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
owner = Column(Integer, ForeignKey("users.id"), nullable=False) owner = Column(Integer, ForeignKey("users.id"), nullable=False)
link = Column(String, ForeignKey("links.link"), nullable=False) link = Column(String, ForeignKey("links.link"), nullable=False)