From 2472e03a69a89ced388cff977e8a057cd6b18ac4 Mon Sep 17 00:00:00 2001 From: Ozzieisaacs Date: Mon, 5 Sep 2022 18:45:24 +0200 Subject: [PATCH] Bugfix ratelimiter kobo --- cps/kobo_auth.py | 9 +++++++-- cps/main.py | 2 ++ cps/opds.py | 2 +- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/cps/kobo_auth.py b/cps/kobo_auth.py index ea9b71b1..43e8c4fe 100644 --- a/cps/kobo_auth.py +++ b/cps/kobo_auth.py @@ -64,11 +64,12 @@ from datetime import datetime from os import urandom from functools import wraps -from flask import g, Blueprint, url_for, abort, request +from flask import g, Blueprint, abort, request from flask_login import login_user, current_user, login_required from flask_babel import gettext as _ +from flask_limiter import RateLimitExceeded -from . import logger, config, calibre_db, db, helper, ub, lm +from . import logger, config, calibre_db, db, helper, ub, lm, limiter from .render_template import render_title_template log = logger.create() @@ -151,6 +152,10 @@ def requires_kobo_auth(f): def inner(*args, **kwargs): auth_token = get_auth_token() if auth_token is not None: + try: + limiter.check() + except RateLimitExceeded: + return abort(429) user = ( ub.session.query(ub.User) .join(ub.RemoteAuthToken) diff --git a/cps/main.py b/cps/main.py index c3d8dd01..99f7391c 100644 --- a/cps/main.py +++ b/cps/main.py @@ -44,6 +44,7 @@ def main(): try: from .kobo import kobo, get_kobo_activated from .kobo_auth import kobo_auth + from flask_limiter.util import get_remote_address kobo_available = get_kobo_activated() except (ImportError, AttributeError): # Catch also error for not installed flask-WTF (missing csrf decorator) kobo_available = False @@ -73,6 +74,7 @@ def main(): if kobo_available: app.register_blueprint(kobo) app.register_blueprint(kobo_auth) + limiter.limit("10/minute", key_func=get_remote_address)(kobo) if oauth_available: app.register_blueprint(oauth) success = web_server.start() diff --git a/cps/opds.py b/cps/opds.py index bf3d77b8..f4d01209 100644 --- a/cps/opds.py +++ b/cps/opds.py @@ -483,7 +483,7 @@ def check_auth(username, password): try: limiter.check() except RateLimitExceeded: - return False + return abort(429) # False try: username = username.encode('windows-1252') except UnicodeEncodeError: