diff --git a/app/main.py b/app/main.py index 0ac6d09..4280275 100644 --- a/app/main.py +++ b/app/main.py @@ -5,11 +5,11 @@ from fastapi.templating import Jinja2Templates from app.routes.auth_routes import router as auth_router from app.routes.links_routes import router as links_router from app.routes.user_routes import router as user_router -from typing import Annotated +from typing import Annotated, Union from fastapi.exceptions import HTTPException from starlette.status import HTTP_404_NOT_FOUND -from app.util.authentication import get_current_user_from_cookie +from app.util.authentication import get_current_user from app.util.db_dependency import get_db from app.util.log import log from app.schemas.auth_schemas import User @@ -55,10 +55,8 @@ async def signup(request: Request): @app.get("/dashboard") async def dashboard( - response: Annotated[ - User, RedirectResponse, Depends(get_current_user_from_cookie) - ], request: Request, + response: Union[User, RedirectResponse] = Depends(get_current_user), ): if isinstance(response, RedirectResponse): return response diff --git a/app/routes/auth_routes.py b/app/routes/auth_routes.py index 4d1c25e..ac75228 100644 --- a/app/routes/auth_routes.py +++ b/app/routes/auth_routes.py @@ -1,6 +1,6 @@ from fastapi import Depends, APIRouter, status, HTTPException from fastapi.security import OAuth2PasswordRequestForm -from fastapi.responses import Response +from fastapi.responses import Response, JSONResponse from datetime import timedelta from typing import Annotated @@ -21,7 +21,7 @@ async def login_for_access_token( form_data: Annotated[OAuth2PasswordRequestForm, Depends()], response: Response, db=Depends(get_db), -) -> Token: +): """ Return an access token for the user, if the given authentication details are correct """ @@ -45,20 +45,19 @@ async def login_for_access_token( data={"sub": user.id, "username": user.username, "refresh": True}, expires_delta=refresh_token_expires, ) - # response = JSONResponse(content={"success": True}) - # response.set_cookie( - # key="access_token", value=access_token, httponly=True, samesite="lax" - # ) - # response.set_cookie( - # key="refresh_token", value=refresh_token, httponly=True, samesite="lax" - # ) + response = JSONResponse(content={"success": True}) + response.set_cookie(key="access_token", value=access_token, httponly=True) + response.set_cookie( + key="refresh_token", value=refresh_token, httponly=True + ) + return response # For Swagger UI to work, must return the token - return Token( - access_token=access_token, - refresh_token=refresh_token, - token_type="bearer", - ) + # return Token( + # access_token=access_token, + # refresh_token=refresh_token, + # token_type="bearer", + # ) # Full native JWT support is not complete in FastAPI yet :( diff --git a/app/routes/links_routes.py b/app/routes/links_routes.py index 77811c8..90ca1bd 100644 --- a/app/routes/links_routes.py +++ b/app/routes/links_routes.py @@ -10,7 +10,7 @@ from app.util.db_dependency import get_db from models import Link, Log from app.schemas.links_schemas import URLSchema from app.schemas.auth_schemas import User -from app.util.authentication import get_current_user_from_token +from app.util.authentication import get_current_user router = APIRouter(prefix="/links", tags=["links"]) @@ -18,7 +18,7 @@ router = APIRouter(prefix="/links", tags=["links"]) @router.get("/", summary="Get all of the links associated with your account") async def get_links( - current_user: Annotated[User, Depends(get_current_user_from_token)], + current_user: Annotated[User, Depends(get_current_user)], db=Depends(get_db), ): links = db.query(Link).filter(Link.owner == current_user.id).all() @@ -32,7 +32,7 @@ async def get_links( @router.post("/", summary="Create a new link") async def create_link( url: URLSchema, - current_user: Annotated[User, Depends(get_current_user_from_token)], + current_user: Annotated[User, Depends(get_current_user)], db=Depends(get_db), ): # Check if the URL is valid @@ -51,8 +51,6 @@ async def create_link( link=link_path, owner=current_user.id, redirect_link=url.url, - expire_date=datetime.datetime.now() - + datetime.timedelta(days=30), ) db.add(new_link) db.commit() @@ -60,13 +58,13 @@ async def create_link( except: continue - return new_link + return {"link": link_path, "expire_date": new_link.expire_date} @router.delete("/{link}", summary="Delete a link") async def delete_link( link: Annotated[str, Path(title="Link to delete")], - current_user: Annotated[User, Depends(get_current_user_from_token)], + current_user: Annotated[User, Depends(get_current_user)], db=Depends(get_db), ): """ @@ -99,7 +97,7 @@ async def delete_link( @router.get("/{link}/logs", summary="Get all logs associated with a link") async def get_link_logs( link: Annotated[str, Path(title="Link to get logs for")], - current_user: Annotated[User, Depends(get_current_user_from_token)], + current_user: Annotated[User, Depends(get_current_user)], db=Depends(get_db), ): """ @@ -118,15 +116,20 @@ async def get_link_logs( detail="Link not associated with your account", ) - # Get and return all of the logs - logs = db.query(Log).filter(Log.link == link.link).all() + # Get and return all of the logs - ordered by timestamp + logs = ( + db.query(Log) + .filter(Log.link == link.link) + .order_by(Log.timestamp.desc()) + .all() + ) return logs @router.delete("/{link}/logs", summary="Delete logs associated with a link") async def delete_link_logs( link: Annotated[str, Path(title="Link to delete logs for")], - current_user: Annotated[User, Depends(get_current_user_from_token)], + current_user: Annotated[User, Depends(get_current_user)], db=Depends(get_db), ): """ diff --git a/app/routes/user_routes.py b/app/routes/user_routes.py index 7fcc768..12b2828 100644 --- a/app/routes/user_routes.py +++ b/app/routes/user_routes.py @@ -13,7 +13,7 @@ from app.schemas.user_schemas import * from models import User as UserModel from app.util.authentication import ( verify_password, - get_current_user_from_token, + get_current_user, ) @@ -23,7 +23,7 @@ router = APIRouter(prefix="/users", tags=["users"]) @router.delete("/{user_id}", summary="Delete your account") async def delete_user( user_id: Annotated[int, Path(title="Link to delete")], - current_user: Annotated[User, Depends(get_current_user_from_token)], + current_user: Annotated[User, Depends(get_current_user)], db=Depends(get_db), ): """ @@ -53,7 +53,7 @@ async def delete_user( async def update_pass( user_id: Annotated[int, Path(title="Link to update")], update_data: UpdatePasswordSchema, - current_user: Annotated[User, Depends(get_current_user_from_token)], + current_user: Annotated[User, Depends(get_current_user)], db=Depends(get_db), ): """ diff --git a/app/templates/dashboard.html b/app/templates/dashboard.html index f1c98e3..c2c9ebc 100644 --- a/app/templates/dashboard.html +++ b/app/templates/dashboard.html @@ -7,9 +7,15 @@
- -

Your Username: {{ user }}

- + + + + + + + + + @@ -31,17 +37,55 @@ font-size: 25px; color: #ccc; } + - button { - display: block; - margin: 10px auto; - width: 200px; - border-radius: 5px; - padding: 15px; - color: #ccc; - background-color: #415eac; - border: none; - font-size: 17px; - cursor: pointer; + \ No newline at end of file diff --git a/app/templates/signup.html b/app/templates/signup.html index 446aaeb..32962b7 100644 --- a/app/templates/signup.html +++ b/app/templates/signup.html @@ -91,21 +91,22 @@ // Prevent default form submission event.preventDefault(); + // Get form data const formData = new FormData(this); - // Send POST request to /signup containing form data + + // Send POST request const response = await fetch('/api/users/register', { method: 'POST', body: formData }); if (response.status != 200) { - const data = await response.json(); - + const data = await response.json() + document.getElementById('error').style.display = 'block'; document.getElementById('error').innerText = data.detail; - } - else { - window.location.href = '/login'; + } else { + window.location.href = '/dashboard'; } }); \ No newline at end of file diff --git a/app/util/authentication.py b/app/util/authentication.py index b270c6d..0bc7e09 100644 --- a/app/util/authentication.py +++ b/app/util/authentication.py @@ -1,15 +1,15 @@ import random import bcrypt -from fastapi import Depends, HTTPException, status, Cookie +from fastapi import Depends, HTTPException, status, Request, Cookie from fastapi.security import OAuth2PasswordBearer from fastapi.responses import RedirectResponse from jwt.exceptions import InvalidTokenError from datetime import datetime, timedelta -from typing import Annotated +from typing import Annotated, Optional import jwt from app.util.db_dependency import get_db -from sqlalchemy.orm import sessionmaker +from sqlalchemy.orm import Session from app.schemas.auth_schemas import * from models import User as UserModel @@ -62,30 +62,6 @@ def create_access_token(data: dict, expires_delta: timedelta): return encoded_jwt -async def get_current_user_from_cookie( - access_token: str = Cookie(None), db=Depends(get_db) -): - """ - Return the user based on the access token in the cookie - - Used for authentication into UI pages - so if no cookie - exists, redirect to login page rather than returning a 401 - - Also pass is_ui=True to alert get_current_user that we need - to use RedirectResponse rather than raising an HTTPException - """ - if access_token: - return await get_current_user(access_token, is_ui=True, db=db) - return RedirectResponse(url="/login") - - -async def get_current_user_from_token( - token: Annotated[str, Depends(oauth2_scheme)], - db=Depends(get_db), -): - return await get_current_user(token, db=db) - - # Backwards kind of way to get refresh token support # `refresh_get_current_user` is only called from /refresh # and alerts `get_current_user` that it should expect a refresh token @@ -97,10 +73,8 @@ async def refresh_get_current_user( async def get_current_user( - token: str, - is_refresh: bool = False, - is_ui: bool = False, - db: sessionmaker = None, + request: Request, + db=Depends(get_db), ): """ Return the current user based on the token @@ -110,9 +84,16 @@ async def get_current_user( Otherwise, the request is from an API and we should return a 401 """ + # If the request is from /api/auth/refresh, it is a request to get + # a new access token using a refresh token + if request.url.path == "/api/auth/refresh": + token = request.cookies.get("refresh_token") + is_refresh = True + else: + token = request.cookies.get("access_token") + is_refresh = False + def raise_unauthorized(): - if is_ui: - return RedirectResponse(url="/login") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Could not validate credentials", @@ -126,12 +107,8 @@ async def get_current_user( refresh: bool = payload.get("refresh") if not id or not username: return raise_unauthorized() - # For some reason, an access token was passed when a refresh - # token was expected - some likely malicious activity - if not refresh and is_refresh: - return raise_unauthorized() - # If the token passed is a refresh token and the function - # is not expecting a refresh token, raise an error + + # Make sure that a refresh token was not passed to any other endpoint if refresh and not is_refresh: return raise_unauthorized() diff --git a/app/util/log.py b/app/util/log.py index b84c8a0..1d21445 100644 --- a/app/util/log.py +++ b/app/util/log.py @@ -60,7 +60,6 @@ def log(link, ip, user_agent): # Get the location and ISP of the user location, isp = ip_to_location(ip) - timestamp = datetime.datetime.now() ua_string = user_agent_parser.Parse(user_agent) browser = ua_string["user_agent"]["family"] os = f'{ua_string["os"]["family"]} {ua_string["os"]["major"]}' @@ -69,7 +68,6 @@ def log(link, ip, user_agent): new_log = Log( owner=owner, link=link, - timestamp=timestamp, ip=ip, location=location, browser=browser, diff --git a/models.py b/models.py index 6061661..27f1436 100644 --- a/models.py +++ b/models.py @@ -6,6 +6,7 @@ from sqlalchemy import ( Text, DateTime, ) +import datetime from database import Base @@ -23,7 +24,10 @@ class Link(Base): link = Column(String, primary_key=True) owner = Column(Integer, ForeignKey("users.id"), nullable=False) redirect_link = Column(String, nullable=False) - expire_date = Column(DateTime, nullable=False) + expire_date = Column( + DateTime, + default=datetime.datetime.utcnow() + datetime.timedelta(days=30), + ) class Log(Base): @@ -31,7 +35,7 @@ class Log(Base): id = Column(Integer, primary_key=True) owner = Column(Integer, ForeignKey("users.id"), nullable=False) link = Column(String, ForeignKey("links.link"), nullable=False) - timestamp = Column(DateTime, nullable=False) + timestamp = Column(DateTime, default=datetime.datetime.utcnow()) ip = Column(String, nullable=False) location = Column(String, nullable=False) browser = Column(String, nullable=False)
IDTimestampIPLocationISP