From 699c1e28066b4832be591297453cd0adf568fe2e Mon Sep 17 00:00:00 2001
From: kaiyou <pierre@jaury.eu>
Date: Wed, 13 Nov 2019 17:17:15 +0100
Subject: [PATCH] Store authorization codes in redis

---
 hiboo/__init__.py |  1 +
 hiboo/sso/oidc.py | 56 +++++++++++++++++++----------------------------
 hiboo/utils.py    | 24 ++++++++++++++++++++
 requirements.txt  |  1 +
 4 files changed, 48 insertions(+), 34 deletions(-)

diff --git a/hiboo/__init__.py b/hiboo/__init__.py
index 8f13aa4c..abe89ee0 100644
--- a/hiboo/__init__.py
+++ b/hiboo/__init__.py
@@ -16,6 +16,7 @@ def create_app_from_config(config):
     utils.login.init_app(app)
     utils.login.user_loader(models.User.get)
     utils.migrate.init_app(app, models.db)
+    utils.redis.init_app(app)
 
     # Initialize debugging tools
     if app.config.get("DEBUG"):
diff --git a/hiboo/sso/oidc.py b/hiboo/sso/oidc.py
index 1a74ef38..fe7f234d 100644
--- a/hiboo/sso/oidc.py
+++ b/hiboo/sso/oidc.py
@@ -7,6 +7,7 @@ from hiboo.sso import forms, blueprint
 from hiboo import models, utils, profile
 
 import flask
+import time
 
 
 class Config(object):
@@ -17,7 +18,7 @@ class Config(object):
     def derive_form(cls, form):
         """ Add required fields to a form.
         """
-        return type('NewForm', (forms.OIDCForm, form), {})
+        return tmodelsype('NewForm', (forms.OIDCForm, form), {})
 
     @classmethod
     def populate_service(cls, form, service):
@@ -54,22 +55,18 @@ class Config(object):
             )
 
 
-class Client(models.db.Model, sqla_oauth2.OAuth2ClientMixin):
+class Client(sqla_oauth2.OAuth2ClientMixin):
     """ OIDC client that only supports authorization code, implicit and
     hybrid flows.
     """
 
     scope = "openid"
-    expire_after = 3600
 
     def __init__(self, service):
         self.service = service
-        super(Client, self).__init__(
-            client_id=service.config["client_id"],
-            client_secret=service.config["client_secret"],
-            client_metadata=service.config
-        )
-        # The authorization server is specific to a client
+        self.client_id = service.config["client_id"]
+        self.client_secret = service.config["client_secret"]
+        self.client_metadata = service.config
         self.authorization = flask_oauth2.AuthorizationServer(
             query_client=self.query_client,
             save_token=self.save_token,
@@ -85,7 +82,6 @@ class Client(models.db.Model, sqla_oauth2.OAuth2ClientMixin):
         return self if client_id == self.client_id else None
 
     def save_token(self, token, request):
-        # TODO: atm we do not save any token
         pass
 
     def get_jwt_config(self):
@@ -93,7 +89,7 @@ class Client(models.db.Model, sqla_oauth2.OAuth2ClientMixin):
             'key': self.service.config["jwt_key"],
             'alg': self.service.config["jwt_alg"],
             'iss': flask.url_for("sso.oidc_token", service_uuid=self.service.uuid, _external=True),
-            'exp': self.expire_after,
+            'exp': 3600,
         }
 
     @classmethod
@@ -110,43 +106,35 @@ class Client(models.db.Model, sqla_oauth2.OAuth2ClientMixin):
 
     @classmethod
     def exists_nonce(cls, nonce, request):
-        return bool(AuthorizationCode.query.filter_by(
-            nonce=nonce, client_id=request.client_id).first()
-        )
+        return bool(utils.redis.get("nonce:{}".format(nonce)))
 
 
-class AuthorizationCode(models.db.Model, sqla_oauth2.OAuth2AuthorizationCodeMixin):
+class AuthorizationCode(utils.SerializableObj, sqla_oauth2.OAuth2AuthorizationCodeMixin):
     """ Authorization code object for storage
     """
-    __tablename__ = "oidc_authorization_code"
-
-    user_id = models.db.Column(models.db.Text())
 
     @classmethod
     def create(cls, client, grant_user, request):
-        code = gen_salt(48)  # TODO
-        authorization_code = AuthorizationCode(
-            code=code, nonce=request.data.get('nonce'),
-            client_id=client.client_id,
-            redirect_uri=request.redirect_uri,
-            scope=request.scope,
-            user_id=grant_user.uuid
+        obj = cls(
+            code=gen_salt(48), nonce=request.data.get("nonce") or "",
+            client_id=client.client_id, redirect_uri=request.redirect_uri,
+            scope=request.scope, user_id=grant_user.uuid,
+            auth_time=int(time.time())
         )
-        models.db.session.add(authorization_code)
-        models.db.session.commit()
-        return code
+        utils.redis.hmset("code:{}".format(obj.code), obj.serialize())
+        if obj.nonce:
+            utils.redis.set("nonce:{}".format(obj.nonce), obj.code)
+        return obj.code
 
     @classmethod
     def get(cls, code, client):
-        return AuthorizationCode.query.filter_by(
-            client_id=client.client_id,
-            code=code
-        ).first()
+        obj = cls.unserialize(utils.redis.hgetall("code:{}".format(code)))
+        if obj and obj.client_id == client.client_id:
+            return obj
 
     @classmethod
     def delete(cls, authorization_code):
-        models.db.session.delete(authorization_code)
-        models.db.session.commit()
+        utils.redis.delete("code:{}".format(authorization_code))
 
 
 class AuthorizationCodeGrant(oauth2.grants.AuthorizationCodeGrant):
diff --git a/hiboo/utils.py b/hiboo/utils.py
index d5d319fd..f24b72b1 100644
--- a/hiboo/utils.py
+++ b/hiboo/utils.py
@@ -3,6 +3,7 @@ import flask_login
 import flask_migrate
 import flask_babel
 import flask_limiter
+import flask_redis
 
 from werkzeug.contrib import fixers
 from werkzeug import routing
@@ -75,6 +76,25 @@ def display_help(identifier):
     return result
 
 
+class SerializableObj(object):
+    def __init__(self, **kwargs):
+        self.__dict__.update(**kwargs)
+        self.__keys__ = list(kwargs.keys())
+
+    @classmethod
+    def unserialize(cls, kwargs):
+        return cls(**{
+            key.decode("utf8"): value.decode("utf8") if type(value) is bytes else value
+            for key, value in kwargs.items()
+        }) if kwargs else None
+
+    def serialize(self):
+        return {
+            key.encode("utf8"): value.encode("utf8") if type(value) is str else value
+            for key, value in self.__dict__.items() if key in self.__keys__
+        }
+
+
 # Request rate limitation
 limiter = flask_limiter.Limiter(key_func=lambda: current_user.id)
 
@@ -90,3 +110,7 @@ def get_locale():
 
 # Data migrate
 migrate = flask_migrate.Migrate()
+
+
+# Redis storage
+redis = flask_redis.FlaskRedis()
diff --git a/requirements.txt b/requirements.txt
index c7f6436d..67bec66e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -7,6 +7,7 @@ Flask-script
 Flask-wtf
 Flask-debugtoolbar
 Flask-limiter
+Flask-redis
 WTForms-Components
 passlib
 gunicorn
-- 
GitLab