From e944df3d7d431b5bd88c2c235501a355ea1ba6ab Mon Sep 17 00:00:00 2001 From: Parker Date: Tue, 5 Nov 2024 20:36:09 -0600 Subject: [PATCH] Fix auth and organization/standards --- app/main.py | 10 +++++- app/routes/auth_routes.py | 6 ++-- app/routes/links_routes.py | 51 ++++++++++++++------------- app/routes/user_routes.py | 61 ++++++++++++++------------------- app/schemas/user_schemas.py | 2 +- app/util/authentication.py | 10 +++--- app/util/check_password_reqs.py | 26 ++++++++++++++ app/util/log.py | 10 +++--- models.py | 4 +-- 9 files changed, 104 insertions(+), 76 deletions(-) create mode 100644 app/util/check_password_reqs.py diff --git a/app/main.py b/app/main.py index 3ef89f2..90b3104 100644 --- a/app/main.py +++ b/app/main.py @@ -1,6 +1,6 @@ from fastapi import FastAPI, Depends, Request from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import RedirectResponse +from fastapi.responses import RedirectResponse, JSONResponse from fastapi.templating import Jinja2Templates from app.routes.auth_routes import router as auth_router from app.routes.links_routes import router as links_router @@ -151,3 +151,11 @@ async def redirect_to_docs(): @app.exception_handler(HTTP_404_NOT_FOUND) async def custom_404_handler(request: Request, exc: HTTPException): return RedirectResponse(url="/login") + + +@app.exception_handler(HTTPException) +async def http_exception_handler(request, exc): + return JSONResponse( + status_code=exc.status_code, + content={"detail": f"{exc.detail}"}, + ) diff --git a/app/routes/auth_routes.py b/app/routes/auth_routes.py index a28ec63..4d1c25e 100644 --- a/app/routes/auth_routes.py +++ b/app/routes/auth_routes.py @@ -26,7 +26,7 @@ async def login_for_access_token( Return an access token for the user, if the given authentication details are correct """ user = authenticate_user(db, form_data.username, form_data.password) - print(user) + if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -35,14 +35,14 @@ async def login_for_access_token( ) access_token_expires = timedelta(minutes=15) access_token = create_access_token( - data={"sub": user.id, "refresh": False}, + data={"sub": user.id, "username": user.username, "refresh": False}, expires_delta=access_token_expires, ) # Create a refresh token - just an access token with a longer expiry # and more restrictions ("refresh" is True) refresh_token_expires = timedelta(days=1) refresh_token = create_access_token( - data={"sub": user.id, "refresh": True}, + data={"sub": user.id, "username": user.username, "refresh": True}, expires_delta=refresh_token_expires, ) # response = JSONResponse(content={"success": True}) diff --git a/app/routes/links_routes.py b/app/routes/links_routes.py index 848c677..77811c8 100644 --- a/app/routes/links_routes.py +++ b/app/routes/links_routes.py @@ -7,7 +7,7 @@ import datetime import validators from app.util.db_dependency import get_db -from models import Link, Record +from models import Link, Log from app.schemas.links_schemas import URLSchema from app.schemas.auth_schemas import User from app.util.authentication import get_current_user_from_token @@ -69,6 +69,9 @@ async def delete_link( current_user: Annotated[User, Depends(get_current_user_from_token)], db=Depends(get_db), ): + """ + Delete a link and all of the logs associated with it + """ link = link.upper() # Get the link and check the owner link = db.query(Link).filter(Link.link == link).first() @@ -82,10 +85,10 @@ async def delete_link( detail="Link not associated with your account", ) - # Get and delete all records associated with the link - records = db.query(Record).filter(Record.link == link.link).all() - for record in records: - db.delete(record) + # Get and delete all logsk + logs = db.query(Log).filter(Log.link == link.link).all() + for log in logs: + db.delete(log) # Delete the link db.delete(link) db.commit() @@ -93,15 +96,15 @@ async def delete_link( return status.HTTP_204_NO_CONTENT -@router.get( - "/{link}/records", - summary="Get all of the IP log records associated with a link", -) -async def get_link_records( - link: Annotated[str, Path(title="Link to get records for")], +@router.get("/{link}/logs", summary="Get all logs associated with a link") +async def get_link_logs( + link: Annotated[str, Path(title="Link to get logs for")], current_user: Annotated[User, Depends(get_current_user_from_token)], db=Depends(get_db), ): + """ + Get all of the IP logs associated with a link + """ link = link.upper() # Get the link and check the owner link = db.query(Link).filter(Link.link == link).first() @@ -115,20 +118,20 @@ async def get_link_records( detail="Link not associated with your account", ) - # Get and return all of the records associated with the link - records = db.query(Record).filter(Record.link == link.link).all() - return records + # Get and return all of the logs + logs = db.query(Log).filter(Log.link == link.link).all() + return logs -@router.delete( - "/{link}/records", - summary="Delete all of the IP log records associated with a link", -) -async def delete_link_records( - link: Annotated[str, Path(title="Link to delete records for")], +@router.delete("/{link}/logs", summary="Delete logs associated with a link") +async def delete_link_logs( + link: Annotated[str, Path(title="Link to delete logs for")], current_user: Annotated[User, Depends(get_current_user_from_token)], db=Depends(get_db), ): + """ + Delete all of the IP logs associated with a link + """ link = link.upper() # Get the link and check the owner link = db.query(Link).filter(Link.link == link).first() @@ -142,10 +145,10 @@ async def delete_link_records( detail="Link not associated with your account", ) - # Get all of the records associated with the link and delete them - records = db.query(Record).filter(Record.link == link.link).all() - for record in records: - db.delete(record) + # Get all of the logs + logs = db.query(Log).filter(Log.link == link.link).all() + for log in logs: + db.delete(log) db.commit() return status.HTTP_204_NO_CONTENT diff --git a/app/routes/user_routes.py b/app/routes/user_routes.py index 1c6c61e..c356104 100644 --- a/app/routes/user_routes.py +++ b/app/routes/user_routes.py @@ -4,14 +4,16 @@ from typing import Annotated import string import bcrypt import random -import datetime -import validators from app.util.db_dependency import get_db +from app.util.check_password_reqs import check_password_reqs from app.schemas.auth_schemas import User from app.schemas.user_schemas import * from models import User as UserModel -from app.util.authentication import get_current_user_from_token +from app.util.authentication import ( + verify_password, + get_current_user_from_token, +) router = APIRouter(prefix="/users", tags=["users"]) @@ -26,23 +28,27 @@ async def delete_user( """ Delete the user account associated with the current user """ + # No editing others accounts if user_id != current_user.id: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="You can only delete your own account", ) + + # Get the user and delete them user = db.query(UserModel).filter(UserModel.id == current_user.id).first() if not user: raise HTTPException( status_code=status.HTTP_404_NOT_FOUND, detail="User not found", ) + db.delete(user) db.commit() return status.HTTP_204_NO_CONTENT -@router.post("/{user_id}", summary="Update your account password") +@router.post("/{user_id}/password", summary="Update your account password") async def update_pass( user_id: Annotated[int, Path(title="Link to update")], update_data: UpdatePasswordSchema, @@ -57,22 +63,19 @@ async def update_pass( status_code=status.HTTP_403_FORBIDDEN, detail="You can only update your own account", ) + + # Make sure that they entered the correct current password + if not verify_password( + update_data.current_password, current_user.hashed_password + ): + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Incorrect current password", + ) + # Make sure the password meets all of the requirements - # if len(update_data.new_password) < 8: - # raise HTTPException( - # status_code=status.HTTP_400_BAD_REQUEST, - # detail="Password must be at least 8 characters", - # ) - # if not any(char.isdigit() for char in update_data.new_password): - # raise HTTPException( - # status_code=status.HTTP_400_BAD_REQUEST, - # detail="Password must contain at least one digit", - # ) - # if not any(char.isupper() for char in update_data.new_password): - # raise HTTPException( - # status_code=status.HTTP_400_BAD_REQUEST, - # detail="Password must contain at least one uppercase letter", - # ) + check_password_reqs(update_data.new_password) + # Get the user and update the password user = db.query(UserModel).filter(UserModel.id == current_user.id).first() if not user: @@ -98,24 +101,10 @@ async def get_links( """ username = login_data.username password = login_data.password - print(username) - print(password) + # Make sure the password meets all of the requirements - # if len(password) < 8: - # raise HTTPException( - # status_code=status.HTTP_400_BAD_REQUEST, - # detail="Password must be at least 8 characters", - # ) - # if not any(char.isdigit() for char in password): - # raise HTTPException( - # status_code=status.HTTP_400_BAD_REQUEST, - # detail="Password must contain at least one digit", - # ) - # if not any(char.isupper() for char in password): - # raise HTTPException( - # status_code=status.HTTP_400_BAD_REQUEST, - # detail="Password must contain at least one uppercase letter", - # ) + check_password_reqs(password) + # Make sure the username isn't taken user = db.query(UserModel).filter(UserModel.username == username).first() if user: diff --git a/app/schemas/user_schemas.py b/app/schemas/user_schemas.py index 70613ac..949b9a5 100644 --- a/app/schemas/user_schemas.py +++ b/app/schemas/user_schemas.py @@ -7,5 +7,5 @@ class LoginDataSchema(BaseModel): class UpdatePasswordSchema(BaseModel): - password: str + current_password: str new_password: str diff --git a/app/util/authentication.py b/app/util/authentication.py index 99f8b47..1127451 100644 --- a/app/util/authentication.py +++ b/app/util/authentication.py @@ -28,11 +28,11 @@ def verify_password(plain_password, hashed_password): ) -def get_user(db, id: int): +def get_user(db, username: str): """ Get the user object from the database """ - user = db.query(UserModel).filter(UserModel.id == id).first() + user = db.query(UserModel).filter(UserModel.username == username).first() if user: return UserInDB(**user.__dict__) @@ -46,6 +46,7 @@ def authenticate_user(db, username: str, password: str): if not user: return False if not verify_password(password, user.hashed_password): + print("WHY") return False return user @@ -121,8 +122,9 @@ async def get_current_user( try: payload = jwt.decode(token, secret_key, algorithms=[algorithm]) id: int = payload.get("sub") + username: str = payload.get("username") refresh: bool = payload.get("refresh") - if not id: + if not id or not username: return raise_unauthorized() # For some reason, an access token was passed when a refresh # token was expected - some likely malicious activity @@ -136,7 +138,7 @@ async def get_current_user( except InvalidTokenError: return raise_unauthorized() - user = get_user(db, id) + user = get_user(db, username) if user is None: return raise_unauthorized() diff --git a/app/util/check_password_reqs.py b/app/util/check_password_reqs.py new file mode 100644 index 0000000..dcb9bf8 --- /dev/null +++ b/app/util/check_password_reqs.py @@ -0,0 +1,26 @@ +from fastapi import HTTPException, status + + +def check_password_reqs(password: str): + """ + Make sure the entered password meets the security requirements: + 1. At least 8 characters + 2. At least one digit + 3. At least one uppercase letter + """ + if len(password) < 8: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Password must be at least 8 characters", + ) + if not any(char.isdigit() for char in password): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Password must contain at least one digit", + ) + if not any(char.isupper() for char in password): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Password must contain at least one uppercase letter", + ) + return diff --git a/app/util/log.py b/app/util/log.py index de60f23..b84c8a0 100644 --- a/app/util/log.py +++ b/app/util/log.py @@ -4,10 +4,10 @@ from ua_parser import user_agent_parser from database import SessionLocal import config -from models import Link, Record +from models import Link, Log """ -Create a new log record whenever a link is visited +Create a new log whenever a link is visited """ @@ -65,8 +65,8 @@ def log(link, ip, user_agent): browser = ua_string["user_agent"]["family"] os = f'{ua_string["os"]["family"]} {ua_string["os"]["major"]}' - # Create the log record and commit it to the database - link_record = Record( + # Create the log and commit it to the database + new_log = Log( owner=owner, link=link, timestamp=timestamp, @@ -77,7 +77,7 @@ def log(link, ip, user_agent): user_agent=user_agent, isp=isp, ) - db.add(link_record) + db.add(new_log) db.commit() db.close() diff --git a/models.py b/models.py index 605b668..6061661 100644 --- a/models.py +++ b/models.py @@ -26,8 +26,8 @@ class Link(Base): expire_date = Column(DateTime, nullable=False) -class Record(Base): - __tablename__ = "records" +class Log(Base): + __tablename__ = "logs" id = Column(Integer, primary_key=True) owner = Column(Integer, ForeignKey("users.id"), nullable=False) link = Column(String, ForeignKey("links.link"), nullable=False)