Fix auth and organization/standards

This commit is contained in:
Parker M. 2024-11-05 20:36:09 -06:00
parent 6f7e810916
commit e944df3d7d
Signed by: parker
GPG Key ID: 505ED36FC12B5D5E
9 changed files with 104 additions and 76 deletions

View File

@ -1,6 +1,6 @@
from fastapi import FastAPI, Depends, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import RedirectResponse
from fastapi.responses import RedirectResponse, JSONResponse
from fastapi.templating import Jinja2Templates
from app.routes.auth_routes import router as auth_router
from app.routes.links_routes import router as links_router
@ -151,3 +151,11 @@ async def redirect_to_docs():
@app.exception_handler(HTTP_404_NOT_FOUND)
async def custom_404_handler(request: Request, exc: HTTPException):
return RedirectResponse(url="/login")
@app.exception_handler(HTTPException)
async def http_exception_handler(request, exc):
return JSONResponse(
status_code=exc.status_code,
content={"detail": f"{exc.detail}"},
)

View File

@ -26,7 +26,7 @@ async def login_for_access_token(
Return an access token for the user, if the given authentication details are correct
"""
user = authenticate_user(db, form_data.username, form_data.password)
print(user)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
@ -35,14 +35,14 @@ async def login_for_access_token(
)
access_token_expires = timedelta(minutes=15)
access_token = create_access_token(
data={"sub": user.id, "refresh": False},
data={"sub": user.id, "username": user.username, "refresh": False},
expires_delta=access_token_expires,
)
# Create a refresh token - just an access token with a longer expiry
# and more restrictions ("refresh" is True)
refresh_token_expires = timedelta(days=1)
refresh_token = create_access_token(
data={"sub": user.id, "refresh": True},
data={"sub": user.id, "username": user.username, "refresh": True},
expires_delta=refresh_token_expires,
)
# response = JSONResponse(content={"success": True})

View File

@ -7,7 +7,7 @@ import datetime
import validators
from app.util.db_dependency import get_db
from models import Link, Record
from models import Link, Log
from app.schemas.links_schemas import URLSchema
from app.schemas.auth_schemas import User
from app.util.authentication import get_current_user_from_token
@ -69,6 +69,9 @@ async def delete_link(
current_user: Annotated[User, Depends(get_current_user_from_token)],
db=Depends(get_db),
):
"""
Delete a link and all of the logs associated with it
"""
link = link.upper()
# Get the link and check the owner
link = db.query(Link).filter(Link.link == link).first()
@ -82,10 +85,10 @@ async def delete_link(
detail="Link not associated with your account",
)
# Get and delete all records associated with the link
records = db.query(Record).filter(Record.link == link.link).all()
for record in records:
db.delete(record)
# Get and delete all logsk
logs = db.query(Log).filter(Log.link == link.link).all()
for log in logs:
db.delete(log)
# Delete the link
db.delete(link)
db.commit()
@ -93,15 +96,15 @@ async def delete_link(
return status.HTTP_204_NO_CONTENT
@router.get(
"/{link}/records",
summary="Get all of the IP log records associated with a link",
)
async def get_link_records(
link: Annotated[str, Path(title="Link to get records for")],
@router.get("/{link}/logs", summary="Get all logs associated with a link")
async def get_link_logs(
link: Annotated[str, Path(title="Link to get logs for")],
current_user: Annotated[User, Depends(get_current_user_from_token)],
db=Depends(get_db),
):
"""
Get all of the IP logs associated with a link
"""
link = link.upper()
# Get the link and check the owner
link = db.query(Link).filter(Link.link == link).first()
@ -115,20 +118,20 @@ async def get_link_records(
detail="Link not associated with your account",
)
# Get and return all of the records associated with the link
records = db.query(Record).filter(Record.link == link.link).all()
return records
# Get and return all of the logs
logs = db.query(Log).filter(Log.link == link.link).all()
return logs
@router.delete(
"/{link}/records",
summary="Delete all of the IP log records associated with a link",
)
async def delete_link_records(
link: Annotated[str, Path(title="Link to delete records for")],
@router.delete("/{link}/logs", summary="Delete logs associated with a link")
async def delete_link_logs(
link: Annotated[str, Path(title="Link to delete logs for")],
current_user: Annotated[User, Depends(get_current_user_from_token)],
db=Depends(get_db),
):
"""
Delete all of the IP logs associated with a link
"""
link = link.upper()
# Get the link and check the owner
link = db.query(Link).filter(Link.link == link).first()
@ -142,10 +145,10 @@ async def delete_link_records(
detail="Link not associated with your account",
)
# Get all of the records associated with the link and delete them
records = db.query(Record).filter(Record.link == link.link).all()
for record in records:
db.delete(record)
# Get all of the logs
logs = db.query(Log).filter(Log.link == link.link).all()
for log in logs:
db.delete(log)
db.commit()
return status.HTTP_204_NO_CONTENT

View File

@ -4,14 +4,16 @@ from typing import Annotated
import string
import bcrypt
import random
import datetime
import validators
from app.util.db_dependency import get_db
from app.util.check_password_reqs import check_password_reqs
from app.schemas.auth_schemas import User
from app.schemas.user_schemas import *
from models import User as UserModel
from app.util.authentication import get_current_user_from_token
from app.util.authentication import (
verify_password,
get_current_user_from_token,
)
router = APIRouter(prefix="/users", tags=["users"])
@ -26,23 +28,27 @@ async def delete_user(
"""
Delete the user account associated with the current user
"""
# No editing others accounts
if user_id != current_user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You can only delete your own account",
)
# Get the user and delete them
user = db.query(UserModel).filter(UserModel.id == current_user.id).first()
if not user:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="User not found",
)
db.delete(user)
db.commit()
return status.HTTP_204_NO_CONTENT
@router.post("/{user_id}", summary="Update your account password")
@router.post("/{user_id}/password", summary="Update your account password")
async def update_pass(
user_id: Annotated[int, Path(title="Link to update")],
update_data: UpdatePasswordSchema,
@ -57,22 +63,19 @@ async def update_pass(
status_code=status.HTTP_403_FORBIDDEN,
detail="You can only update your own account",
)
# Make sure that they entered the correct current password
if not verify_password(
update_data.current_password, current_user.hashed_password
):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Incorrect current password",
)
# Make sure the password meets all of the requirements
# if len(update_data.new_password) < 8:
# raise HTTPException(
# status_code=status.HTTP_400_BAD_REQUEST,
# detail="Password must be at least 8 characters",
# )
# if not any(char.isdigit() for char in update_data.new_password):
# raise HTTPException(
# status_code=status.HTTP_400_BAD_REQUEST,
# detail="Password must contain at least one digit",
# )
# if not any(char.isupper() for char in update_data.new_password):
# raise HTTPException(
# status_code=status.HTTP_400_BAD_REQUEST,
# detail="Password must contain at least one uppercase letter",
# )
check_password_reqs(update_data.new_password)
# Get the user and update the password
user = db.query(UserModel).filter(UserModel.id == current_user.id).first()
if not user:
@ -98,24 +101,10 @@ async def get_links(
"""
username = login_data.username
password = login_data.password
print(username)
print(password)
# Make sure the password meets all of the requirements
# if len(password) < 8:
# raise HTTPException(
# status_code=status.HTTP_400_BAD_REQUEST,
# detail="Password must be at least 8 characters",
# )
# if not any(char.isdigit() for char in password):
# raise HTTPException(
# status_code=status.HTTP_400_BAD_REQUEST,
# detail="Password must contain at least one digit",
# )
# if not any(char.isupper() for char in password):
# raise HTTPException(
# status_code=status.HTTP_400_BAD_REQUEST,
# detail="Password must contain at least one uppercase letter",
# )
check_password_reqs(password)
# Make sure the username isn't taken
user = db.query(UserModel).filter(UserModel.username == username).first()
if user:

View File

@ -7,5 +7,5 @@ class LoginDataSchema(BaseModel):
class UpdatePasswordSchema(BaseModel):
password: str
current_password: str
new_password: str

View File

@ -28,11 +28,11 @@ def verify_password(plain_password, hashed_password):
)
def get_user(db, id: int):
def get_user(db, username: str):
"""
Get the user object from the database
"""
user = db.query(UserModel).filter(UserModel.id == id).first()
user = db.query(UserModel).filter(UserModel.username == username).first()
if user:
return UserInDB(**user.__dict__)
@ -46,6 +46,7 @@ def authenticate_user(db, username: str, password: str):
if not user:
return False
if not verify_password(password, user.hashed_password):
print("WHY")
return False
return user
@ -121,8 +122,9 @@ async def get_current_user(
try:
payload = jwt.decode(token, secret_key, algorithms=[algorithm])
id: int = payload.get("sub")
username: str = payload.get("username")
refresh: bool = payload.get("refresh")
if not id:
if not id or not username:
return raise_unauthorized()
# For some reason, an access token was passed when a refresh
# token was expected - some likely malicious activity
@ -136,7 +138,7 @@ async def get_current_user(
except InvalidTokenError:
return raise_unauthorized()
user = get_user(db, id)
user = get_user(db, username)
if user is None:
return raise_unauthorized()

View File

@ -0,0 +1,26 @@
from fastapi import HTTPException, status
def check_password_reqs(password: str):
"""
Make sure the entered password meets the security requirements:
1. At least 8 characters
2. At least one digit
3. At least one uppercase letter
"""
if len(password) < 8:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Password must be at least 8 characters",
)
if not any(char.isdigit() for char in password):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Password must contain at least one digit",
)
if not any(char.isupper() for char in password):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Password must contain at least one uppercase letter",
)
return

View File

@ -4,10 +4,10 @@ from ua_parser import user_agent_parser
from database import SessionLocal
import config
from models import Link, Record
from models import Link, Log
"""
Create a new log record whenever a link is visited
Create a new log whenever a link is visited
"""
@ -65,8 +65,8 @@ def log(link, ip, user_agent):
browser = ua_string["user_agent"]["family"]
os = f'{ua_string["os"]["family"]} {ua_string["os"]["major"]}'
# Create the log record and commit it to the database
link_record = Record(
# Create the log and commit it to the database
new_log = Log(
owner=owner,
link=link,
timestamp=timestamp,
@ -77,7 +77,7 @@ def log(link, ip, user_agent):
user_agent=user_agent,
isp=isp,
)
db.add(link_record)
db.add(new_log)
db.commit()
db.close()

View File

@ -26,8 +26,8 @@ class Link(Base):
expire_date = Column(DateTime, nullable=False)
class Record(Base):
__tablename__ = "records"
class Log(Base):
__tablename__ = "logs"
id = Column(Integer, primary_key=True)
owner = Column(Integer, ForeignKey("users.id"), nullable=False)
link = Column(String, ForeignKey("links.link"), nullable=False)