Fix auth and organization/standards
This commit is contained in:
parent
6f7e810916
commit
e944df3d7d
10
app/main.py
10
app/main.py
@ -1,6 +1,6 @@
|
|||||||
from fastapi import FastAPI, Depends, Request
|
from fastapi import FastAPI, Depends, Request
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import RedirectResponse
|
from fastapi.responses import RedirectResponse, JSONResponse
|
||||||
from fastapi.templating import Jinja2Templates
|
from fastapi.templating import Jinja2Templates
|
||||||
from app.routes.auth_routes import router as auth_router
|
from app.routes.auth_routes import router as auth_router
|
||||||
from app.routes.links_routes import router as links_router
|
from app.routes.links_routes import router as links_router
|
||||||
@ -151,3 +151,11 @@ async def redirect_to_docs():
|
|||||||
@app.exception_handler(HTTP_404_NOT_FOUND)
|
@app.exception_handler(HTTP_404_NOT_FOUND)
|
||||||
async def custom_404_handler(request: Request, exc: HTTPException):
|
async def custom_404_handler(request: Request, exc: HTTPException):
|
||||||
return RedirectResponse(url="/login")
|
return RedirectResponse(url="/login")
|
||||||
|
|
||||||
|
|
||||||
|
@app.exception_handler(HTTPException)
|
||||||
|
async def http_exception_handler(request, exc):
|
||||||
|
return JSONResponse(
|
||||||
|
status_code=exc.status_code,
|
||||||
|
content={"detail": f"{exc.detail}"},
|
||||||
|
)
|
||||||
|
@ -26,7 +26,7 @@ async def login_for_access_token(
|
|||||||
Return an access token for the user, if the given authentication details are correct
|
Return an access token for the user, if the given authentication details are correct
|
||||||
"""
|
"""
|
||||||
user = authenticate_user(db, form_data.username, form_data.password)
|
user = authenticate_user(db, form_data.username, form_data.password)
|
||||||
print(user)
|
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
@ -35,14 +35,14 @@ async def login_for_access_token(
|
|||||||
)
|
)
|
||||||
access_token_expires = timedelta(minutes=15)
|
access_token_expires = timedelta(minutes=15)
|
||||||
access_token = create_access_token(
|
access_token = create_access_token(
|
||||||
data={"sub": user.id, "refresh": False},
|
data={"sub": user.id, "username": user.username, "refresh": False},
|
||||||
expires_delta=access_token_expires,
|
expires_delta=access_token_expires,
|
||||||
)
|
)
|
||||||
# Create a refresh token - just an access token with a longer expiry
|
# Create a refresh token - just an access token with a longer expiry
|
||||||
# and more restrictions ("refresh" is True)
|
# and more restrictions ("refresh" is True)
|
||||||
refresh_token_expires = timedelta(days=1)
|
refresh_token_expires = timedelta(days=1)
|
||||||
refresh_token = create_access_token(
|
refresh_token = create_access_token(
|
||||||
data={"sub": user.id, "refresh": True},
|
data={"sub": user.id, "username": user.username, "refresh": True},
|
||||||
expires_delta=refresh_token_expires,
|
expires_delta=refresh_token_expires,
|
||||||
)
|
)
|
||||||
# response = JSONResponse(content={"success": True})
|
# response = JSONResponse(content={"success": True})
|
||||||
|
@ -7,7 +7,7 @@ import datetime
|
|||||||
import validators
|
import validators
|
||||||
|
|
||||||
from app.util.db_dependency import get_db
|
from app.util.db_dependency import get_db
|
||||||
from models import Link, Record
|
from models import Link, Log
|
||||||
from app.schemas.links_schemas import URLSchema
|
from app.schemas.links_schemas import URLSchema
|
||||||
from app.schemas.auth_schemas import User
|
from app.schemas.auth_schemas import User
|
||||||
from app.util.authentication import get_current_user_from_token
|
from app.util.authentication import get_current_user_from_token
|
||||||
@ -69,6 +69,9 @@ async def delete_link(
|
|||||||
current_user: Annotated[User, Depends(get_current_user_from_token)],
|
current_user: Annotated[User, Depends(get_current_user_from_token)],
|
||||||
db=Depends(get_db),
|
db=Depends(get_db),
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Delete a link and all of the logs associated with it
|
||||||
|
"""
|
||||||
link = link.upper()
|
link = link.upper()
|
||||||
# Get the link and check the owner
|
# Get the link and check the owner
|
||||||
link = db.query(Link).filter(Link.link == link).first()
|
link = db.query(Link).filter(Link.link == link).first()
|
||||||
@ -82,10 +85,10 @@ async def delete_link(
|
|||||||
detail="Link not associated with your account",
|
detail="Link not associated with your account",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get and delete all records associated with the link
|
# Get and delete all logsk
|
||||||
records = db.query(Record).filter(Record.link == link.link).all()
|
logs = db.query(Log).filter(Log.link == link.link).all()
|
||||||
for record in records:
|
for log in logs:
|
||||||
db.delete(record)
|
db.delete(log)
|
||||||
# Delete the link
|
# Delete the link
|
||||||
db.delete(link)
|
db.delete(link)
|
||||||
db.commit()
|
db.commit()
|
||||||
@ -93,15 +96,15 @@ async def delete_link(
|
|||||||
return status.HTTP_204_NO_CONTENT
|
return status.HTTP_204_NO_CONTENT
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
@router.get("/{link}/logs", summary="Get all logs associated with a link")
|
||||||
"/{link}/records",
|
async def get_link_logs(
|
||||||
summary="Get all of the IP log records associated with a link",
|
link: Annotated[str, Path(title="Link to get logs for")],
|
||||||
)
|
|
||||||
async def get_link_records(
|
|
||||||
link: Annotated[str, Path(title="Link to get records for")],
|
|
||||||
current_user: Annotated[User, Depends(get_current_user_from_token)],
|
current_user: Annotated[User, Depends(get_current_user_from_token)],
|
||||||
db=Depends(get_db),
|
db=Depends(get_db),
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Get all of the IP logs associated with a link
|
||||||
|
"""
|
||||||
link = link.upper()
|
link = link.upper()
|
||||||
# Get the link and check the owner
|
# Get the link and check the owner
|
||||||
link = db.query(Link).filter(Link.link == link).first()
|
link = db.query(Link).filter(Link.link == link).first()
|
||||||
@ -115,20 +118,20 @@ async def get_link_records(
|
|||||||
detail="Link not associated with your account",
|
detail="Link not associated with your account",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get and return all of the records associated with the link
|
# Get and return all of the logs
|
||||||
records = db.query(Record).filter(Record.link == link.link).all()
|
logs = db.query(Log).filter(Log.link == link.link).all()
|
||||||
return records
|
return logs
|
||||||
|
|
||||||
|
|
||||||
@router.delete(
|
@router.delete("/{link}/logs", summary="Delete logs associated with a link")
|
||||||
"/{link}/records",
|
async def delete_link_logs(
|
||||||
summary="Delete all of the IP log records associated with a link",
|
link: Annotated[str, Path(title="Link to delete logs for")],
|
||||||
)
|
|
||||||
async def delete_link_records(
|
|
||||||
link: Annotated[str, Path(title="Link to delete records for")],
|
|
||||||
current_user: Annotated[User, Depends(get_current_user_from_token)],
|
current_user: Annotated[User, Depends(get_current_user_from_token)],
|
||||||
db=Depends(get_db),
|
db=Depends(get_db),
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Delete all of the IP logs associated with a link
|
||||||
|
"""
|
||||||
link = link.upper()
|
link = link.upper()
|
||||||
# Get the link and check the owner
|
# Get the link and check the owner
|
||||||
link = db.query(Link).filter(Link.link == link).first()
|
link = db.query(Link).filter(Link.link == link).first()
|
||||||
@ -142,10 +145,10 @@ async def delete_link_records(
|
|||||||
detail="Link not associated with your account",
|
detail="Link not associated with your account",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get all of the records associated with the link and delete them
|
# Get all of the logs
|
||||||
records = db.query(Record).filter(Record.link == link.link).all()
|
logs = db.query(Log).filter(Log.link == link.link).all()
|
||||||
for record in records:
|
for log in logs:
|
||||||
db.delete(record)
|
db.delete(log)
|
||||||
db.commit()
|
db.commit()
|
||||||
|
|
||||||
return status.HTTP_204_NO_CONTENT
|
return status.HTTP_204_NO_CONTENT
|
||||||
|
@ -4,14 +4,16 @@ from typing import Annotated
|
|||||||
import string
|
import string
|
||||||
import bcrypt
|
import bcrypt
|
||||||
import random
|
import random
|
||||||
import datetime
|
|
||||||
import validators
|
|
||||||
|
|
||||||
from app.util.db_dependency import get_db
|
from app.util.db_dependency import get_db
|
||||||
|
from app.util.check_password_reqs import check_password_reqs
|
||||||
from app.schemas.auth_schemas import User
|
from app.schemas.auth_schemas import User
|
||||||
from app.schemas.user_schemas import *
|
from app.schemas.user_schemas import *
|
||||||
from models import User as UserModel
|
from models import User as UserModel
|
||||||
from app.util.authentication import get_current_user_from_token
|
from app.util.authentication import (
|
||||||
|
verify_password,
|
||||||
|
get_current_user_from_token,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/users", tags=["users"])
|
router = APIRouter(prefix="/users", tags=["users"])
|
||||||
@ -26,23 +28,27 @@ async def delete_user(
|
|||||||
"""
|
"""
|
||||||
Delete the user account associated with the current user
|
Delete the user account associated with the current user
|
||||||
"""
|
"""
|
||||||
|
# No editing others accounts
|
||||||
if user_id != current_user.id:
|
if user_id != current_user.id:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail="You can only delete your own account",
|
detail="You can only delete your own account",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Get the user and delete them
|
||||||
user = db.query(UserModel).filter(UserModel.id == current_user.id).first()
|
user = db.query(UserModel).filter(UserModel.id == current_user.id).first()
|
||||||
if not user:
|
if not user:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
detail="User not found",
|
detail="User not found",
|
||||||
)
|
)
|
||||||
|
|
||||||
db.delete(user)
|
db.delete(user)
|
||||||
db.commit()
|
db.commit()
|
||||||
return status.HTTP_204_NO_CONTENT
|
return status.HTTP_204_NO_CONTENT
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{user_id}", summary="Update your account password")
|
@router.post("/{user_id}/password", summary="Update your account password")
|
||||||
async def update_pass(
|
async def update_pass(
|
||||||
user_id: Annotated[int, Path(title="Link to update")],
|
user_id: Annotated[int, Path(title="Link to update")],
|
||||||
update_data: UpdatePasswordSchema,
|
update_data: UpdatePasswordSchema,
|
||||||
@ -57,22 +63,19 @@ async def update_pass(
|
|||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail="You can only update your own account",
|
detail="You can only update your own account",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Make sure that they entered the correct current password
|
||||||
|
if not verify_password(
|
||||||
|
update_data.current_password, current_user.hashed_password
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Incorrect current password",
|
||||||
|
)
|
||||||
|
|
||||||
# Make sure the password meets all of the requirements
|
# Make sure the password meets all of the requirements
|
||||||
# if len(update_data.new_password) < 8:
|
check_password_reqs(update_data.new_password)
|
||||||
# raise HTTPException(
|
|
||||||
# status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
# detail="Password must be at least 8 characters",
|
|
||||||
# )
|
|
||||||
# if not any(char.isdigit() for char in update_data.new_password):
|
|
||||||
# raise HTTPException(
|
|
||||||
# status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
# detail="Password must contain at least one digit",
|
|
||||||
# )
|
|
||||||
# if not any(char.isupper() for char in update_data.new_password):
|
|
||||||
# raise HTTPException(
|
|
||||||
# status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
# detail="Password must contain at least one uppercase letter",
|
|
||||||
# )
|
|
||||||
# Get the user and update the password
|
# Get the user and update the password
|
||||||
user = db.query(UserModel).filter(UserModel.id == current_user.id).first()
|
user = db.query(UserModel).filter(UserModel.id == current_user.id).first()
|
||||||
if not user:
|
if not user:
|
||||||
@ -98,24 +101,10 @@ async def get_links(
|
|||||||
"""
|
"""
|
||||||
username = login_data.username
|
username = login_data.username
|
||||||
password = login_data.password
|
password = login_data.password
|
||||||
print(username)
|
|
||||||
print(password)
|
|
||||||
# Make sure the password meets all of the requirements
|
# Make sure the password meets all of the requirements
|
||||||
# if len(password) < 8:
|
check_password_reqs(password)
|
||||||
# raise HTTPException(
|
|
||||||
# status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
# detail="Password must be at least 8 characters",
|
|
||||||
# )
|
|
||||||
# if not any(char.isdigit() for char in password):
|
|
||||||
# raise HTTPException(
|
|
||||||
# status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
# detail="Password must contain at least one digit",
|
|
||||||
# )
|
|
||||||
# if not any(char.isupper() for char in password):
|
|
||||||
# raise HTTPException(
|
|
||||||
# status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
# detail="Password must contain at least one uppercase letter",
|
|
||||||
# )
|
|
||||||
# Make sure the username isn't taken
|
# Make sure the username isn't taken
|
||||||
user = db.query(UserModel).filter(UserModel.username == username).first()
|
user = db.query(UserModel).filter(UserModel.username == username).first()
|
||||||
if user:
|
if user:
|
||||||
|
@ -7,5 +7,5 @@ class LoginDataSchema(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class UpdatePasswordSchema(BaseModel):
|
class UpdatePasswordSchema(BaseModel):
|
||||||
password: str
|
current_password: str
|
||||||
new_password: str
|
new_password: str
|
||||||
|
@ -28,11 +28,11 @@ def verify_password(plain_password, hashed_password):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_user(db, id: int):
|
def get_user(db, username: str):
|
||||||
"""
|
"""
|
||||||
Get the user object from the database
|
Get the user object from the database
|
||||||
"""
|
"""
|
||||||
user = db.query(UserModel).filter(UserModel.id == id).first()
|
user = db.query(UserModel).filter(UserModel.username == username).first()
|
||||||
if user:
|
if user:
|
||||||
return UserInDB(**user.__dict__)
|
return UserInDB(**user.__dict__)
|
||||||
|
|
||||||
@ -46,6 +46,7 @@ def authenticate_user(db, username: str, password: str):
|
|||||||
if not user:
|
if not user:
|
||||||
return False
|
return False
|
||||||
if not verify_password(password, user.hashed_password):
|
if not verify_password(password, user.hashed_password):
|
||||||
|
print("WHY")
|
||||||
return False
|
return False
|
||||||
return user
|
return user
|
||||||
|
|
||||||
@ -121,8 +122,9 @@ async def get_current_user(
|
|||||||
try:
|
try:
|
||||||
payload = jwt.decode(token, secret_key, algorithms=[algorithm])
|
payload = jwt.decode(token, secret_key, algorithms=[algorithm])
|
||||||
id: int = payload.get("sub")
|
id: int = payload.get("sub")
|
||||||
|
username: str = payload.get("username")
|
||||||
refresh: bool = payload.get("refresh")
|
refresh: bool = payload.get("refresh")
|
||||||
if not id:
|
if not id or not username:
|
||||||
return raise_unauthorized()
|
return raise_unauthorized()
|
||||||
# For some reason, an access token was passed when a refresh
|
# For some reason, an access token was passed when a refresh
|
||||||
# token was expected - some likely malicious activity
|
# token was expected - some likely malicious activity
|
||||||
@ -136,7 +138,7 @@ async def get_current_user(
|
|||||||
except InvalidTokenError:
|
except InvalidTokenError:
|
||||||
return raise_unauthorized()
|
return raise_unauthorized()
|
||||||
|
|
||||||
user = get_user(db, id)
|
user = get_user(db, username)
|
||||||
if user is None:
|
if user is None:
|
||||||
return raise_unauthorized()
|
return raise_unauthorized()
|
||||||
|
|
||||||
|
26
app/util/check_password_reqs.py
Normal file
26
app/util/check_password_reqs.py
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
from fastapi import HTTPException, status
|
||||||
|
|
||||||
|
|
||||||
|
def check_password_reqs(password: str):
|
||||||
|
"""
|
||||||
|
Make sure the entered password meets the security requirements:
|
||||||
|
1. At least 8 characters
|
||||||
|
2. At least one digit
|
||||||
|
3. At least one uppercase letter
|
||||||
|
"""
|
||||||
|
if len(password) < 8:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Password must be at least 8 characters",
|
||||||
|
)
|
||||||
|
if not any(char.isdigit() for char in password):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Password must contain at least one digit",
|
||||||
|
)
|
||||||
|
if not any(char.isupper() for char in password):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Password must contain at least one uppercase letter",
|
||||||
|
)
|
||||||
|
return
|
@ -4,10 +4,10 @@ from ua_parser import user_agent_parser
|
|||||||
|
|
||||||
from database import SessionLocal
|
from database import SessionLocal
|
||||||
import config
|
import config
|
||||||
from models import Link, Record
|
from models import Link, Log
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Create a new log record whenever a link is visited
|
Create a new log whenever a link is visited
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
@ -65,8 +65,8 @@ def log(link, ip, user_agent):
|
|||||||
browser = ua_string["user_agent"]["family"]
|
browser = ua_string["user_agent"]["family"]
|
||||||
os = f'{ua_string["os"]["family"]} {ua_string["os"]["major"]}'
|
os = f'{ua_string["os"]["family"]} {ua_string["os"]["major"]}'
|
||||||
|
|
||||||
# Create the log record and commit it to the database
|
# Create the log and commit it to the database
|
||||||
link_record = Record(
|
new_log = Log(
|
||||||
owner=owner,
|
owner=owner,
|
||||||
link=link,
|
link=link,
|
||||||
timestamp=timestamp,
|
timestamp=timestamp,
|
||||||
@ -77,7 +77,7 @@ def log(link, ip, user_agent):
|
|||||||
user_agent=user_agent,
|
user_agent=user_agent,
|
||||||
isp=isp,
|
isp=isp,
|
||||||
)
|
)
|
||||||
db.add(link_record)
|
db.add(new_log)
|
||||||
db.commit()
|
db.commit()
|
||||||
db.close()
|
db.close()
|
||||||
|
|
||||||
|
@ -26,8 +26,8 @@ class Link(Base):
|
|||||||
expire_date = Column(DateTime, nullable=False)
|
expire_date = Column(DateTime, nullable=False)
|
||||||
|
|
||||||
|
|
||||||
class Record(Base):
|
class Log(Base):
|
||||||
__tablename__ = "records"
|
__tablename__ = "logs"
|
||||||
id = Column(Integer, primary_key=True)
|
id = Column(Integer, primary_key=True)
|
||||||
owner = Column(Integer, ForeignKey("users.id"), nullable=False)
|
owner = Column(Integer, ForeignKey("users.id"), nullable=False)
|
||||||
link = Column(String, ForeignKey("links.link"), nullable=False)
|
link = Column(String, ForeignKey("links.link"), nullable=False)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user