Fix auth and organization/standards
This commit is contained in:
parent
6f7e810916
commit
e944df3d7d
10
app/main.py
10
app/main.py
@ -1,6 +1,6 @@
|
||||
from fastapi import FastAPI, Depends, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.responses import RedirectResponse, JSONResponse
|
||||
from fastapi.templating import Jinja2Templates
|
||||
from app.routes.auth_routes import router as auth_router
|
||||
from app.routes.links_routes import router as links_router
|
||||
@ -151,3 +151,11 @@ async def redirect_to_docs():
|
||||
@app.exception_handler(HTTP_404_NOT_FOUND)
|
||||
async def custom_404_handler(request: Request, exc: HTTPException):
|
||||
return RedirectResponse(url="/login")
|
||||
|
||||
|
||||
@app.exception_handler(HTTPException)
|
||||
async def http_exception_handler(request, exc):
|
||||
return JSONResponse(
|
||||
status_code=exc.status_code,
|
||||
content={"detail": f"{exc.detail}"},
|
||||
)
|
||||
|
@ -26,7 +26,7 @@ async def login_for_access_token(
|
||||
Return an access token for the user, if the given authentication details are correct
|
||||
"""
|
||||
user = authenticate_user(db, form_data.username, form_data.password)
|
||||
print(user)
|
||||
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
@ -35,14 +35,14 @@ async def login_for_access_token(
|
||||
)
|
||||
access_token_expires = timedelta(minutes=15)
|
||||
access_token = create_access_token(
|
||||
data={"sub": user.id, "refresh": False},
|
||||
data={"sub": user.id, "username": user.username, "refresh": False},
|
||||
expires_delta=access_token_expires,
|
||||
)
|
||||
# Create a refresh token - just an access token with a longer expiry
|
||||
# and more restrictions ("refresh" is True)
|
||||
refresh_token_expires = timedelta(days=1)
|
||||
refresh_token = create_access_token(
|
||||
data={"sub": user.id, "refresh": True},
|
||||
data={"sub": user.id, "username": user.username, "refresh": True},
|
||||
expires_delta=refresh_token_expires,
|
||||
)
|
||||
# response = JSONResponse(content={"success": True})
|
||||
|
@ -7,7 +7,7 @@ import datetime
|
||||
import validators
|
||||
|
||||
from app.util.db_dependency import get_db
|
||||
from models import Link, Record
|
||||
from models import Link, Log
|
||||
from app.schemas.links_schemas import URLSchema
|
||||
from app.schemas.auth_schemas import User
|
||||
from app.util.authentication import get_current_user_from_token
|
||||
@ -69,6 +69,9 @@ async def delete_link(
|
||||
current_user: Annotated[User, Depends(get_current_user_from_token)],
|
||||
db=Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Delete a link and all of the logs associated with it
|
||||
"""
|
||||
link = link.upper()
|
||||
# Get the link and check the owner
|
||||
link = db.query(Link).filter(Link.link == link).first()
|
||||
@ -82,10 +85,10 @@ async def delete_link(
|
||||
detail="Link not associated with your account",
|
||||
)
|
||||
|
||||
# Get and delete all records associated with the link
|
||||
records = db.query(Record).filter(Record.link == link.link).all()
|
||||
for record in records:
|
||||
db.delete(record)
|
||||
# Get and delete all logsk
|
||||
logs = db.query(Log).filter(Log.link == link.link).all()
|
||||
for log in logs:
|
||||
db.delete(log)
|
||||
# Delete the link
|
||||
db.delete(link)
|
||||
db.commit()
|
||||
@ -93,15 +96,15 @@ async def delete_link(
|
||||
return status.HTTP_204_NO_CONTENT
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{link}/records",
|
||||
summary="Get all of the IP log records associated with a link",
|
||||
)
|
||||
async def get_link_records(
|
||||
link: Annotated[str, Path(title="Link to get records for")],
|
||||
@router.get("/{link}/logs", summary="Get all logs associated with a link")
|
||||
async def get_link_logs(
|
||||
link: Annotated[str, Path(title="Link to get logs for")],
|
||||
current_user: Annotated[User, Depends(get_current_user_from_token)],
|
||||
db=Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get all of the IP logs associated with a link
|
||||
"""
|
||||
link = link.upper()
|
||||
# Get the link and check the owner
|
||||
link = db.query(Link).filter(Link.link == link).first()
|
||||
@ -115,20 +118,20 @@ async def get_link_records(
|
||||
detail="Link not associated with your account",
|
||||
)
|
||||
|
||||
# Get and return all of the records associated with the link
|
||||
records = db.query(Record).filter(Record.link == link.link).all()
|
||||
return records
|
||||
# Get and return all of the logs
|
||||
logs = db.query(Log).filter(Log.link == link.link).all()
|
||||
return logs
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{link}/records",
|
||||
summary="Delete all of the IP log records associated with a link",
|
||||
)
|
||||
async def delete_link_records(
|
||||
link: Annotated[str, Path(title="Link to delete records for")],
|
||||
@router.delete("/{link}/logs", summary="Delete logs associated with a link")
|
||||
async def delete_link_logs(
|
||||
link: Annotated[str, Path(title="Link to delete logs for")],
|
||||
current_user: Annotated[User, Depends(get_current_user_from_token)],
|
||||
db=Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Delete all of the IP logs associated with a link
|
||||
"""
|
||||
link = link.upper()
|
||||
# Get the link and check the owner
|
||||
link = db.query(Link).filter(Link.link == link).first()
|
||||
@ -142,10 +145,10 @@ async def delete_link_records(
|
||||
detail="Link not associated with your account",
|
||||
)
|
||||
|
||||
# Get all of the records associated with the link and delete them
|
||||
records = db.query(Record).filter(Record.link == link.link).all()
|
||||
for record in records:
|
||||
db.delete(record)
|
||||
# Get all of the logs
|
||||
logs = db.query(Log).filter(Log.link == link.link).all()
|
||||
for log in logs:
|
||||
db.delete(log)
|
||||
db.commit()
|
||||
|
||||
return status.HTTP_204_NO_CONTENT
|
||||
|
@ -4,14 +4,16 @@ from typing import Annotated
|
||||
import string
|
||||
import bcrypt
|
||||
import random
|
||||
import datetime
|
||||
import validators
|
||||
|
||||
from app.util.db_dependency import get_db
|
||||
from app.util.check_password_reqs import check_password_reqs
|
||||
from app.schemas.auth_schemas import User
|
||||
from app.schemas.user_schemas import *
|
||||
from models import User as UserModel
|
||||
from app.util.authentication import get_current_user_from_token
|
||||
from app.util.authentication import (
|
||||
verify_password,
|
||||
get_current_user_from_token,
|
||||
)
|
||||
|
||||
|
||||
router = APIRouter(prefix="/users", tags=["users"])
|
||||
@ -26,23 +28,27 @@ async def delete_user(
|
||||
"""
|
||||
Delete the user account associated with the current user
|
||||
"""
|
||||
# No editing others accounts
|
||||
if user_id != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You can only delete your own account",
|
||||
)
|
||||
|
||||
# Get the user and delete them
|
||||
user = db.query(UserModel).filter(UserModel.id == current_user.id).first()
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="User not found",
|
||||
)
|
||||
|
||||
db.delete(user)
|
||||
db.commit()
|
||||
return status.HTTP_204_NO_CONTENT
|
||||
|
||||
|
||||
@router.post("/{user_id}", summary="Update your account password")
|
||||
@router.post("/{user_id}/password", summary="Update your account password")
|
||||
async def update_pass(
|
||||
user_id: Annotated[int, Path(title="Link to update")],
|
||||
update_data: UpdatePasswordSchema,
|
||||
@ -57,22 +63,19 @@ async def update_pass(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You can only update your own account",
|
||||
)
|
||||
|
||||
# Make sure that they entered the correct current password
|
||||
if not verify_password(
|
||||
update_data.current_password, current_user.hashed_password
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Incorrect current password",
|
||||
)
|
||||
|
||||
# Make sure the password meets all of the requirements
|
||||
# if len(update_data.new_password) < 8:
|
||||
# raise HTTPException(
|
||||
# status_code=status.HTTP_400_BAD_REQUEST,
|
||||
# detail="Password must be at least 8 characters",
|
||||
# )
|
||||
# if not any(char.isdigit() for char in update_data.new_password):
|
||||
# raise HTTPException(
|
||||
# status_code=status.HTTP_400_BAD_REQUEST,
|
||||
# detail="Password must contain at least one digit",
|
||||
# )
|
||||
# if not any(char.isupper() for char in update_data.new_password):
|
||||
# raise HTTPException(
|
||||
# status_code=status.HTTP_400_BAD_REQUEST,
|
||||
# detail="Password must contain at least one uppercase letter",
|
||||
# )
|
||||
check_password_reqs(update_data.new_password)
|
||||
|
||||
# Get the user and update the password
|
||||
user = db.query(UserModel).filter(UserModel.id == current_user.id).first()
|
||||
if not user:
|
||||
@ -98,24 +101,10 @@ async def get_links(
|
||||
"""
|
||||
username = login_data.username
|
||||
password = login_data.password
|
||||
print(username)
|
||||
print(password)
|
||||
|
||||
# Make sure the password meets all of the requirements
|
||||
# if len(password) < 8:
|
||||
# raise HTTPException(
|
||||
# status_code=status.HTTP_400_BAD_REQUEST,
|
||||
# detail="Password must be at least 8 characters",
|
||||
# )
|
||||
# if not any(char.isdigit() for char in password):
|
||||
# raise HTTPException(
|
||||
# status_code=status.HTTP_400_BAD_REQUEST,
|
||||
# detail="Password must contain at least one digit",
|
||||
# )
|
||||
# if not any(char.isupper() for char in password):
|
||||
# raise HTTPException(
|
||||
# status_code=status.HTTP_400_BAD_REQUEST,
|
||||
# detail="Password must contain at least one uppercase letter",
|
||||
# )
|
||||
check_password_reqs(password)
|
||||
|
||||
# Make sure the username isn't taken
|
||||
user = db.query(UserModel).filter(UserModel.username == username).first()
|
||||
if user:
|
||||
|
@ -7,5 +7,5 @@ class LoginDataSchema(BaseModel):
|
||||
|
||||
|
||||
class UpdatePasswordSchema(BaseModel):
|
||||
password: str
|
||||
current_password: str
|
||||
new_password: str
|
||||
|
@ -28,11 +28,11 @@ def verify_password(plain_password, hashed_password):
|
||||
)
|
||||
|
||||
|
||||
def get_user(db, id: int):
|
||||
def get_user(db, username: str):
|
||||
"""
|
||||
Get the user object from the database
|
||||
"""
|
||||
user = db.query(UserModel).filter(UserModel.id == id).first()
|
||||
user = db.query(UserModel).filter(UserModel.username == username).first()
|
||||
if user:
|
||||
return UserInDB(**user.__dict__)
|
||||
|
||||
@ -46,6 +46,7 @@ def authenticate_user(db, username: str, password: str):
|
||||
if not user:
|
||||
return False
|
||||
if not verify_password(password, user.hashed_password):
|
||||
print("WHY")
|
||||
return False
|
||||
return user
|
||||
|
||||
@ -121,8 +122,9 @@ async def get_current_user(
|
||||
try:
|
||||
payload = jwt.decode(token, secret_key, algorithms=[algorithm])
|
||||
id: int = payload.get("sub")
|
||||
username: str = payload.get("username")
|
||||
refresh: bool = payload.get("refresh")
|
||||
if not id:
|
||||
if not id or not username:
|
||||
return raise_unauthorized()
|
||||
# For some reason, an access token was passed when a refresh
|
||||
# token was expected - some likely malicious activity
|
||||
@ -136,7 +138,7 @@ async def get_current_user(
|
||||
except InvalidTokenError:
|
||||
return raise_unauthorized()
|
||||
|
||||
user = get_user(db, id)
|
||||
user = get_user(db, username)
|
||||
if user is None:
|
||||
return raise_unauthorized()
|
||||
|
||||
|
26
app/util/check_password_reqs.py
Normal file
26
app/util/check_password_reqs.py
Normal file
@ -0,0 +1,26 @@
|
||||
from fastapi import HTTPException, status
|
||||
|
||||
|
||||
def check_password_reqs(password: str):
|
||||
"""
|
||||
Make sure the entered password meets the security requirements:
|
||||
1. At least 8 characters
|
||||
2. At least one digit
|
||||
3. At least one uppercase letter
|
||||
"""
|
||||
if len(password) < 8:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Password must be at least 8 characters",
|
||||
)
|
||||
if not any(char.isdigit() for char in password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Password must contain at least one digit",
|
||||
)
|
||||
if not any(char.isupper() for char in password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Password must contain at least one uppercase letter",
|
||||
)
|
||||
return
|
@ -4,10 +4,10 @@ from ua_parser import user_agent_parser
|
||||
|
||||
from database import SessionLocal
|
||||
import config
|
||||
from models import Link, Record
|
||||
from models import Link, Log
|
||||
|
||||
"""
|
||||
Create a new log record whenever a link is visited
|
||||
Create a new log whenever a link is visited
|
||||
"""
|
||||
|
||||
|
||||
@ -65,8 +65,8 @@ def log(link, ip, user_agent):
|
||||
browser = ua_string["user_agent"]["family"]
|
||||
os = f'{ua_string["os"]["family"]} {ua_string["os"]["major"]}'
|
||||
|
||||
# Create the log record and commit it to the database
|
||||
link_record = Record(
|
||||
# Create the log and commit it to the database
|
||||
new_log = Log(
|
||||
owner=owner,
|
||||
link=link,
|
||||
timestamp=timestamp,
|
||||
@ -77,7 +77,7 @@ def log(link, ip, user_agent):
|
||||
user_agent=user_agent,
|
||||
isp=isp,
|
||||
)
|
||||
db.add(link_record)
|
||||
db.add(new_log)
|
||||
db.commit()
|
||||
db.close()
|
||||
|
||||
|
@ -26,8 +26,8 @@ class Link(Base):
|
||||
expire_date = Column(DateTime, nullable=False)
|
||||
|
||||
|
||||
class Record(Base):
|
||||
__tablename__ = "records"
|
||||
class Log(Base):
|
||||
__tablename__ = "logs"
|
||||
id = Column(Integer, primary_key=True)
|
||||
owner = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||
link = Column(String, ForeignKey("links.link"), nullable=False)
|
||||
|
Loading…
x
Reference in New Issue
Block a user