diff --git a/cps/web.py b/cps/web.py index 7aa921e4..e88d423d 100644 --- a/cps/web.py +++ b/cps/web.py @@ -116,14 +116,35 @@ web = Blueprint('web', __name__) log = logger.create() # ################################### Login logic and rights management ############################################### +def _fetch_user_by_name(username): + return ub.session.query(ub.User).filter(func.lower(ub.User.nickname) == username.lower()).first() @lm.user_loader def load_user(user_id): return ub.session.query(ub.User).filter(ub.User.id == int(user_id)).first() -@lm.header_loader -def load_user_from_header(header_val): +@lm.request_loader +def load_user_from_request(request): + auth_header = request.headers.get("Authorization") + if auth_header: + user = load_user_from_auth_header(auth_header) + if user: + return user + + if config.config_allow_reverse_proxy_header_login: + rp_header_name = config.config_reverse_proxy_login_header_name + if rp_header_name: + rp_header = request.headers.get(rp_header_name) + if rp_header_username: + user = _fetch_user_by_name(rp_header_username) + if user: + return user + + return + + +def load_user_from_auth_header(header_val): if header_val.startswith('Basic '): header_val = header_val.replace('Basic ', '', 1) basic_username = basic_password = '' @@ -133,7 +154,7 @@ def load_user_from_header(header_val): basic_password = header_val.split(':')[1] except TypeError: pass - user = ub.session.query(ub.User).filter(func.lower(ub.User.nickname) == basic_username.lower()).first() + user = _fetch_user_by_name(basic_username) if user and check_password_hash(str(user.password), basic_password): return user return