From 699c1e28066b4832be591297453cd0adf568fe2e Mon Sep 17 00:00:00 2001 From: kaiyou <pierre@jaury.eu> Date: Wed, 13 Nov 2019 17:17:15 +0100 Subject: [PATCH] Store authorization codes in redis --- hiboo/__init__.py | 1 + hiboo/sso/oidc.py | 56 +++++++++++++++++++---------------------------- hiboo/utils.py | 24 ++++++++++++++++++++ requirements.txt | 1 + 4 files changed, 48 insertions(+), 34 deletions(-) diff --git a/hiboo/__init__.py b/hiboo/__init__.py index 8f13aa4c..abe89ee0 100644 --- a/hiboo/__init__.py +++ b/hiboo/__init__.py @@ -16,6 +16,7 @@ def create_app_from_config(config): utils.login.init_app(app) utils.login.user_loader(models.User.get) utils.migrate.init_app(app, models.db) + utils.redis.init_app(app) # Initialize debugging tools if app.config.get("DEBUG"): diff --git a/hiboo/sso/oidc.py b/hiboo/sso/oidc.py index 1a74ef38..fe7f234d 100644 --- a/hiboo/sso/oidc.py +++ b/hiboo/sso/oidc.py @@ -7,6 +7,7 @@ from hiboo.sso import forms, blueprint from hiboo import models, utils, profile import flask +import time class Config(object): @@ -17,7 +18,7 @@ class Config(object): def derive_form(cls, form): """ Add required fields to a form. """ - return type('NewForm', (forms.OIDCForm, form), {}) + return tmodelsype('NewForm', (forms.OIDCForm, form), {}) @classmethod def populate_service(cls, form, service): @@ -54,22 +55,18 @@ class Config(object): ) -class Client(models.db.Model, sqla_oauth2.OAuth2ClientMixin): +class Client(sqla_oauth2.OAuth2ClientMixin): """ OIDC client that only supports authorization code, implicit and hybrid flows. """ scope = "openid" - expire_after = 3600 def __init__(self, service): self.service = service - super(Client, self).__init__( - client_id=service.config["client_id"], - client_secret=service.config["client_secret"], - client_metadata=service.config - ) - # The authorization server is specific to a client + self.client_id = service.config["client_id"] + self.client_secret = service.config["client_secret"] + self.client_metadata = service.config self.authorization = flask_oauth2.AuthorizationServer( query_client=self.query_client, save_token=self.save_token, @@ -85,7 +82,6 @@ class Client(models.db.Model, sqla_oauth2.OAuth2ClientMixin): return self if client_id == self.client_id else None def save_token(self, token, request): - # TODO: atm we do not save any token pass def get_jwt_config(self): @@ -93,7 +89,7 @@ class Client(models.db.Model, sqla_oauth2.OAuth2ClientMixin): 'key': self.service.config["jwt_key"], 'alg': self.service.config["jwt_alg"], 'iss': flask.url_for("sso.oidc_token", service_uuid=self.service.uuid, _external=True), - 'exp': self.expire_after, + 'exp': 3600, } @classmethod @@ -110,43 +106,35 @@ class Client(models.db.Model, sqla_oauth2.OAuth2ClientMixin): @classmethod def exists_nonce(cls, nonce, request): - return bool(AuthorizationCode.query.filter_by( - nonce=nonce, client_id=request.client_id).first() - ) + return bool(utils.redis.get("nonce:{}".format(nonce))) -class AuthorizationCode(models.db.Model, sqla_oauth2.OAuth2AuthorizationCodeMixin): +class AuthorizationCode(utils.SerializableObj, sqla_oauth2.OAuth2AuthorizationCodeMixin): """ Authorization code object for storage """ - __tablename__ = "oidc_authorization_code" - - user_id = models.db.Column(models.db.Text()) @classmethod def create(cls, client, grant_user, request): - code = gen_salt(48) # TODO - authorization_code = AuthorizationCode( - code=code, nonce=request.data.get('nonce'), - client_id=client.client_id, - redirect_uri=request.redirect_uri, - scope=request.scope, - user_id=grant_user.uuid + obj = cls( + code=gen_salt(48), nonce=request.data.get("nonce") or "", + client_id=client.client_id, redirect_uri=request.redirect_uri, + scope=request.scope, user_id=grant_user.uuid, + auth_time=int(time.time()) ) - models.db.session.add(authorization_code) - models.db.session.commit() - return code + utils.redis.hmset("code:{}".format(obj.code), obj.serialize()) + if obj.nonce: + utils.redis.set("nonce:{}".format(obj.nonce), obj.code) + return obj.code @classmethod def get(cls, code, client): - return AuthorizationCode.query.filter_by( - client_id=client.client_id, - code=code - ).first() + obj = cls.unserialize(utils.redis.hgetall("code:{}".format(code))) + if obj and obj.client_id == client.client_id: + return obj @classmethod def delete(cls, authorization_code): - models.db.session.delete(authorization_code) - models.db.session.commit() + utils.redis.delete("code:{}".format(authorization_code)) class AuthorizationCodeGrant(oauth2.grants.AuthorizationCodeGrant): diff --git a/hiboo/utils.py b/hiboo/utils.py index d5d319fd..f24b72b1 100644 --- a/hiboo/utils.py +++ b/hiboo/utils.py @@ -3,6 +3,7 @@ import flask_login import flask_migrate import flask_babel import flask_limiter +import flask_redis from werkzeug.contrib import fixers from werkzeug import routing @@ -75,6 +76,25 @@ def display_help(identifier): return result +class SerializableObj(object): + def __init__(self, **kwargs): + self.__dict__.update(**kwargs) + self.__keys__ = list(kwargs.keys()) + + @classmethod + def unserialize(cls, kwargs): + return cls(**{ + key.decode("utf8"): value.decode("utf8") if type(value) is bytes else value + for key, value in kwargs.items() + }) if kwargs else None + + def serialize(self): + return { + key.encode("utf8"): value.encode("utf8") if type(value) is str else value + for key, value in self.__dict__.items() if key in self.__keys__ + } + + # Request rate limitation limiter = flask_limiter.Limiter(key_func=lambda: current_user.id) @@ -90,3 +110,7 @@ def get_locale(): # Data migrate migrate = flask_migrate.Migrate() + + +# Redis storage +redis = flask_redis.FlaskRedis() diff --git a/requirements.txt b/requirements.txt index c7f6436d..67bec66e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,7 @@ Flask-script Flask-wtf Flask-debugtoolbar Flask-limiter +Flask-redis WTForms-Components passlib gunicorn -- GitLab