diff --git a/app/main.py b/app/main.py index 0d48123..2556aeb 100644 --- a/app/main.py +++ b/app/main.py @@ -12,10 +12,10 @@ import os import string import random -from models import User +from models import User, Link from database import * from app.util.log import log -from var import BASE_URL +from var import BASE_URL class FlaskUser(UserMixin): @@ -139,18 +139,23 @@ Log all records for visits to shortened links @app.route("/", methods=["GET"]) def log_redirect(link): - # If the `link` is more than 5 characters, ignore - if len(link) > 5: + # If `link` is not exactly 5 characters, return redirect to base url + if len(link) != 5: return redirect(BASE_URL) - # If the `link` is one of the registered routes, ignore - if link in ["login", "signup", "dashboard", "logout", "api"]: - return - - ip = request.remote_addr - user_agent = request.headers.get("user-agent") - redirect_link = log(link, ip, user_agent) - return redirect(redirect_link) + # Make sure the link exists in the database + db = SessionLocal() + link_record = db.query(Link).filter(Link.link == link).first() + if not link_record: + db.close() + return redirect(BASE_URL) + else: + # Log the visit + ip = request.remote_addr + user_agent = request.headers.get("User-Agent") + log(link, ip, user_agent) + db.close() + return redirect(link_record.redirect_link) @app.errorhandler(401)