# port_to_maria.py
# -*- coding: utf-8 -*-
# Copyright 2015 OpenMarket Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import itertools
import logging
import types

import yaml
from twisted.enterprise import adbapi
from twisted.internet import defer, reactor

from synapse.storage._base import LoggingTransaction, SQLBaseStore
from synapse.storage.engines import create_engine


# Module-wide logger; every log line emitted by this script goes through it.
logger = logging.getLogger("port_to_maria")


# Maps table name -> columns treated as binary data.  handle_table() looks
# these up per table and, for any value in such a column that is a unicode
# string, UTF-8 encodes it and wraps it in a ``buffer`` before inserting.
BINARY_COLUMNS = {
    "event_content_hashes": ["hash"],
    "event_reference_hashes": ["hash"],
    "event_signatures": ["signature"],
    "event_edge_hashes": ["hash"],
    "events": ["content", "unrecognized_keys"],
    "event_json": ["internal_metadata", "json"],
    "application_services_txns": ["event_ids"],
    "received_transactions": ["response_json"],
    "sent_transactions": ["response_json"],
    "server_tls_certificates": ["tls_certificate"],
    "server_signature_keys": ["verify_key"],
    "pushers": ["pushkey", "data"],
    "user_filters": ["filter_json"],
}

# Maps table name -> columns whose values handle_table() passes through the
# SQLite engine's ``load_unicode`` before any other conversion is applied.
UNICODE_COLUMNS = {
    "events": ["content", "unrecognized_keys"],
    "event_json": ["internal_metadata", "json"],
    "users": ["password_hash"],
}


# Maps table name -> columns that handle_table() coerces with ``bool()``
# before inserting into the target database.
BOOLEAN_COLUMNS = {
    "events": ["processed", "outlier"],
    "rooms": ["is_public"],
    "event_edges": ["is_state"],
    "presence_list": ["accepted"],
}


# Tables for which handle_table() resumes copying from the rowid recorded in
# ``port_from_sqlite3``.  Any table NOT listed here is truncated in the target
# database and copied again from rowid 0 on every run.
APPEND_ONLY_TABLES = [
    "event_content_hashes",
    "event_reference_hashes",
    "event_signatures",
    "event_edge_hashes",
    "events",
    "event_json",
    "state_events",
    "room_memberships",
    "feedback",
    "topics",
    "room_names",
    "rooms",
    "local_media_repository",
    "local_media_repository_thumbnails",
    "remote_media_cache",
    "remote_media_cache_thumbnails",
    "redactions",
    "event_edges",
    "event_auth",
    "received_transactions",
    "sent_transactions",
    "transaction_id_to_pdu",
    "users",
    "state_groups",
    "state_groups_state",
    "event_to_state_groups",
    "rejections",
]


class Store(object):
    """Lightweight store used by the porting script.

    Pairs a connection pool with a database engine and reuses a handful of
    ``SQLBaseStore`` helper methods without inheriting from it.
    """

    def __init__(self, db_pool, engine):
        """Remember the connection pool and the database engine to use.

        Args:
            db_pool: Pool object exposing ``runWithConnection``.
            engine: Database engine exposing ``module`` and ``is_deadlock``.
        """
        self.db_pool = db_pool
        self.database_engine = engine

    # The helpers are taken straight from the class ``__dict__`` so they are
    # plain functions that bind to ``Store`` instances; on Python 2, normal
    # attribute access would yield unbound methods tied to SQLBaseStore.
    _simple_insert_txn = SQLBaseStore.__dict__["_simple_insert_txn"]
    _simple_insert = SQLBaseStore.__dict__["_simple_insert"]

    _simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"]
    _simple_select_onecol = SQLBaseStore.__dict__["_simple_select_onecol"]
    _simple_select_one_onecol = SQLBaseStore.__dict__["_simple_select_one_onecol"]
    _simple_select_one_onecol_txn = SQLBaseStore.__dict__["_simple_select_one_onecol_txn"]

    _simple_update_one = SQLBaseStore.__dict__["_simple_update_one"]
    _simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"]

    _execute_and_decode = SQLBaseStore.__dict__["_execute_and_decode"]

    def runInteraction(self, desc, func, *args, **kwargs):
        """Run ``func`` on a pooled connection, retrying on deadlock.

        Args:
            desc: Short label used in log messages and passed to
                ``LoggingTransaction``.
            func: Callable invoked as ``func(txn, *args, **kwargs)`` where
                ``txn`` is a ``LoggingTransaction`` wrapping a fresh cursor.
            *args: Extra positional arguments forwarded to ``func``.
            **kwargs: Extra keyword arguments forwarded to ``func``.

        Returns:
            Whatever ``self.db_pool.runWithConnection`` returns for the
            wrapped call.

        Raises:
            Any exception raised by ``func``; a deadlock error is re-raised
            only after the retry budget is exhausted.
        """
        def r(conn):
            try:
                attempt = 0
                max_retries = 5
                while True:
                    try:
                        txn = conn.cursor()
                        return func(
                            LoggingTransaction(txn, desc, self.database_engine),
                            *args, **kwargs
                        )
                    except self.database_engine.module.DatabaseError as e:
                        if self.database_engine.is_deadlock(e):
                            logger.warning(
                                "[TXN DEADLOCK] {%s} %d/%d",
                                desc, attempt, max_retries,
                            )
                            if attempt < max_retries:
                                attempt += 1
                                # Discard the failed transaction before
                                # trying again on the same connection.
                                conn.rollback()
                                continue
                        raise
            except Exception as e:
                logger.debug("[TXN FAIL] {%s} %s", desc, e)
                raise

        return self.db_pool.runWithConnection(r)

    def insert_many_txn(self, txn, table, headers, rows):
        """Bulk-insert ``rows`` into ``table`` using ``txn.executemany``.

        Args:
            txn: Cursor-like object providing ``executemany``.
            table: Name of the destination table.
            headers: Column names, in the same order as each row's values.
            rows: Sequence of value tuples, one per row.

        Raises:
            Whatever ``txn.executemany`` raises; the failure is logged with
            the table name before being re-raised.
        """
        sql = "INSERT INTO %s (%s) VALUES (%s)" % (
            table,
            ", ".join(headers),
            ", ".join(["%s"] * len(headers)),
        )

        try:
            txn.executemany(sql, rows)
        except Exception:
            logger.exception(
                "Failed to insert: %s",
                table,
            )
            raise


def chunks(n):
    """Yield consecutive ranges of width ``n`` forever.

    The first range starts at 0, the next at ``n``, then ``2 * n``, and so
    on; the generator never terminates on its own.
    """
    start = 0
    while True:
        stop = start + n
        yield range(start, stop)
        start = stop


@defer.inlineCallbacks
def handle_table(table, sqlite_store, mysql_store, batch_size=5000):
    """Copy the rows of ``table`` from the SQLite store to the target store.

    Progress is tracked in the target's ``port_from_sqlite3`` table, which
    holds the next SQLite rowid to copy for each table.  Tables listed in
    ``APPEND_ONLY_TABLES`` resume from that rowid; every other table is
    truncated in the target and copied again from rowid 0.

    Args:
        table: Name of the table to copy.
        sqlite_store: Store to read rows from.
        mysql_store: Store to write rows to.
        batch_size: Maximum number of rows read and inserted per batch.
            The original code referenced an undefined name here, so this
            default is a chosen value rather than a recovered one.

    Returns:
        A Deferred (via ``inlineCallbacks``) that fires once a batch comes
        back empty.
    """
    if table in APPEND_ONLY_TABLES:
        # It's safe to just carry on inserting.
        next_chunk = yield mysql_store._simple_select_one_onecol(
            table="port_from_sqlite3",
            keyvalues={"table_name": table},
            retcol="rowid",
            allow_none=True,
        )

        if next_chunk is None:
            # First time this table is seen: start tracking it from rowid 0.
            yield mysql_store._simple_insert(
                table="port_from_sqlite3",
                values={"table_name": table, "rowid": 0}
            )

            next_chunk = 0
    else:
        def delete_all(txn):
            # Reset both the progress marker and the table contents in one
            # transaction so they cannot get out of step.
            txn.execute(
                "DELETE FROM port_from_sqlite3 WHERE table_name = %s",
                (table,)
            )
            txn.execute("TRUNCATE %s CASCADE" % (table,))
            mysql_store._simple_insert_txn(
                txn,
                table="port_from_sqlite3",
                values={"table_name": table, "rowid": 0}
            )

        yield mysql_store.runInteraction(
            "delete_non_append_only", delete_all
        )

        next_chunk = 0

    logger.info("next_chunk for %s: %d", table, next_chunk)

    select = "SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?" % (table,)

    uni_col_names = UNICODE_COLUMNS.get(table, [])
    bool_col_names = BOOLEAN_COLUMNS.get(table, [])
    bin_col_names = BINARY_COLUMNS.get(table, [])

    while True:
        def r(txn):
            txn.execute(select, (next_chunk, batch_size,))
            rows = txn.fetchall()
            headers = [column[0] for column in txn.description]

            return headers, rows

        headers, rows = yield sqlite_store.runInteraction("select", r)

        logger.info("Got %d rows for %s", len(rows), table)

        if not rows:
            # Nothing left to copy for this table.
            return

        # Positions (within a fetched row) of the columns needing conversion.
        uni_cols = set(i for i, h in enumerate(headers) if h in uni_col_names)
        bool_cols = set(i for i, h in enumerate(headers) if h in bool_col_names)
        bin_cols = set(i for i, h in enumerate(headers) if h in bin_col_names)

        # Column 0 is the selected rowid; the next batch starts after the
        # last rowid seen in this one.
        next_chunk = rows[-1][0] + 1

        def conv(j, col):
            """Convert one SQLite value for column index ``j``."""
            if j in uni_cols:
                col = sqlite_store.database_engine.load_unicode(col)
            if j in bool_cols:
                return bool(col)

            if j in bin_cols:
                if isinstance(col, types.UnicodeType):
                    col = buffer(col.encode("utf8"))

            return col

        for i, row in enumerate(rows):
            # Drop the leading rowid (j == 0) and convert everything else.
            rows[i] = tuple(
                mysql_store.database_engine.encode_parameter(
                    conv(j, col)
                )
                for j, col in enumerate(row)
                if j > 0
            )

        def ins(txn):
            # Insert the batch and advance the progress marker in the same
            # transaction.
            mysql_store.insert_many_txn(txn, table, headers[1:], rows)
            mysql_store._simple_update_one_txn(
                txn,
                table="port_from_sqlite3",
                keyvalues={"table_name": table},
                updatevalues={"rowid": next_chunk},
            )

        yield mysql_store.runInteraction("insert_many", ins)


def setup_db(db_config, database_engine):
    """Open a one-off connection and let the engine prepare the database.

    Args:
        db_config: Mapping with an optional ``"args"`` dict of connection
            keyword arguments.  Keys starting with ``cp_`` (the pool options
            such as ``cp_min``/``cp_max`` used elsewhere in this script) are
            not passed to ``connect``.
        database_engine: Engine exposing ``module.connect`` and
            ``prepare_database``.

    Raises:
        Whatever ``connect``, ``prepare_database`` or ``commit`` raise.
    """
    connect_args = {
        k: v for k, v in db_config.get("args", {}).items()
        if not k.startswith("cp_")
    }
    db_conn = database_engine.module.connect(**connect_args)

    try:
        database_engine.prepare_database(db_conn)

        db_conn.commit()
    finally:
        # This connection exists only for preparation; release it so it
        # does not leak for the lifetime of the process.
        db_conn.close()


@defer.inlineCallbacks
def main(sqlite_config, mysql_config):
    """Port every table from the SQLite database into the target database.

    Args:
        sqlite_config: Mapping with ``"name"`` (DB-API module name) and
            ``"args"`` (connection pool keyword arguments) for the source.
        mysql_config: Mapping of the same shape for the target.  Despite the
            "mysql" naming, the target engine is created with
            ``create_engine("psycopg2")``.

    Returns:
        A Deferred (via ``inlineCallbacks``).  Failures are logged rather
        than propagated, and the reactor is stopped in every case.
    """
    try:
        sqlite_db_pool = adbapi.ConnectionPool(
            sqlite_config["name"],
            **sqlite_config["args"]
        )

        mysql_db_pool = adbapi.ConnectionPool(
            mysql_config["name"],
            **mysql_config["args"]
        )

        sqlite_engine = create_engine("sqlite3")
        mysql_engine = create_engine("psycopg2")

        sqlite_store = Store(sqlite_db_pool, sqlite_engine)
        mysql_store = Store(mysql_db_pool, mysql_engine)

        # Step 1. Let each engine prepare its database (source, then target).
        logger.info("Preparing sqlite database...")
        setup_db(sqlite_config, sqlite_engine)

        logger.info("Preparing mysql database...")
        setup_db(mysql_config, mysql_engine)

        # Step 2. Get tables.
        logger.info("Fetching tables...")
        tables = yield sqlite_store._simple_select_onecol(
            table="sqlite_master",
            keyvalues={
                "type": "table",
            },
            retcol="name",
        )

        logger.info("Found %d tables", len(tables))

        def create_port_table(txn):
            txn.execute(
                "CREATE TABLE port_from_sqlite3 ("
                " table_name varchar(100) NOT NULL UNIQUE,"
                " rowid bigint NOT NULL"
                ")"
            )

        # Best effort: creation fails when the table is left over from an
        # earlier run, which is fine, so the error is only logged.
        try:
            yield mysql_store.runInteraction(
                "create_port_table", create_port_table
            )
        except Exception as e:
            logger.info("Failed to create port table: %s", e)

        # Step 3. Process tables, skipping schema bookkeeping tables and
        # SQLite's own internal tables.
        yield defer.gatherResults(
            [
                handle_table(table, sqlite_store, mysql_store)
                for table in tables
                if table not in ["schema_version", "applied_schema_deltas"]
                and not table.startswith("sqlite_")
            ],
            consumeErrors=True,
        )
    except Exception:
        logger.exception("Porting failed")
    finally:
        reactor.stop()


if __name__ == "__main__":
    # Both options are mandatory: without them the script would only fail
    # later with an obscure error, so let argparse reject the invocation.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--sqlite-database", required=True,
        help="Path of the SQLite database to read from.",
    )
    parser.add_argument(
        "--mysql-config", type=argparse.FileType('r'), required=True,
        help="YAML file with 'name' and 'args' keys for the target database.",
    )

    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)

    sqlite_config = {
        "name": "sqlite3",
        "args": {
            "database": args.sqlite_database,
            "cp_min": 1,
            "cp_max": 1,
            "check_same_thread": False,
        },
    }

    # argparse opened the file for us; make sure it gets closed again.
    try:
        mysql_config = yaml.safe_load(args.mysql_config)
    finally:
        args.mysql_config.close()

    reactor.callWhenRunning(
        main,
        sqlite_config=sqlite_config,
        mysql_config=mysql_config,
    )

    reactor.run()