#!/usr/bin/python3
# Copyright (c) 2017-2023, Mudita Sp. z.o.o. All rights reserved.
# For licensing, see https://github.com/mudita/MuditaOS/LICENSE.md
import os
import uuid
import sqlite3
from argparse import ArgumentParser
from pathlib import Path
import sys
import datetime
import json
import shutil
import traceback
import itertools
# Constants
# Per-revision SQL script file names (forward migration, rollback, development-only extras).
up_script = "up.sql"
down_script = "down.sql"
devel_script = "devel.sql"
# Per-revision metadata file name (JSON content; see RevisionMetadata).
meta_file = ".meta"
# File describing the set of databases and their committed versions.
databases_set = "databases.json"
# Environment configuration file name (content is JSON despite the .ini suffix).
env_file = "dbm_env.ini"
# License banner prepended to every generated SQL file.
license_header = f"-- Copyright (c) 2017-{datetime.date.today().year}, Mudita Sp. z.o.o. All rights reserved.\n" \
"-- For licensing, see https://github.com/mudita/MuditaOS/LICENSE.md\n\n"
# Top-level CLI parser; subcommands register themselves via the @subcommand decorator.
cli = ArgumentParser()
subparsers = cli.add_subparsers(dest="subcommand")
def subcommand(args=(), parent=None):
    """Decorator registering the wrapped function as a CLI subcommand.

    A subparser named after the function is created (its docstring becomes
    the description), the given arguments are added, and the function is
    bound as the handler via ``set_defaults(func=...)``.

    Args:
        args: iterable of ``(name_or_flags, kwargs)`` pairs, as produced by
            ``argument()``. Default changed from a mutable ``[]`` to an
            immutable ``()`` to avoid the shared-mutable-default pitfall.
        parent: subparsers action to register with; ``None`` (the default)
            resolves lazily to the module-level ``subparsers`` at call time
            instead of capturing the global at definition time.
    """
    def decorator(func):
        target = subparsers if parent is None else parent
        parser = target.add_parser(func.__name__, description=func.__doc__)
        for flags, kwargs in args:
            parser.add_argument(*flags, **kwargs)
        parser.set_defaults(func=func)
        # Return the function so the decorated module-level name stays bound
        # to it (previously it became None).
        return func
    return decorator
def argument(*name_or_flags, **kwargs):
    """Bundle positional flags and keyword options into the pair format
    consumed by the @subcommand decorator."""
    return list(name_or_flags), kwargs
class RevisionMetadata:
    """JSON-backed metadata describing a single migration revision
    (identifier, creation date, message and parent revision id)."""
    _key_id = "id"
    _key_date = "date"
    _key_message = "message"
    _key_parent = "parent"
    # Name of the metadata file inside a revision directory
    file_name = ".meta"

    def __init__(self, id, date, message, parent):
        # 'id' shadows the builtin, kept for interface compatibility.
        # The id is stored as a string; parent is stored as given
        # (0 for the first revision, a parent's id string otherwise).
        self.set = {RevisionMetadata._key_id: str(id), RevisionMetadata._key_date: date,
                    RevisionMetadata._key_message: message, RevisionMetadata._key_parent: parent}

    def id(self):
        """Return the revision identifier (string)."""
        return self.set[RevisionMetadata._key_id]

    def parent(self):
        """Return the parent revision identifier (0 for the root revision)."""
        return self.set[RevisionMetadata._key_parent]

    def message(self):
        """Return the human-readable revision message."""
        return self.set[RevisionMetadata._key_message]

    @classmethod
    def from_file(cls, path: Path):
        """Load metadata from the JSON file at *path*."""
        with open(path, "r") as f:
            raw = json.load(f)
        return cls(raw[cls._key_id], raw[cls._key_date], raw[cls._key_message], raw[cls._key_parent])

    def dump_to_file(self, path: Path):
        """Write the metadata as JSON to *path*.

        Opens in 'w' mode (was 'a'): appending to an existing file produced
        concatenated JSON documents that from_file could not parse.
        """
        with path.open('w') as file:
            file.write(json.dumps(self.set, indent=1))
class ConstRevisionEntry:
    """Read-only view of an existing revision directory on disk."""

    def __init__(self, dir: Path):
        self.dir = dir
        self.metadata = RevisionMetadata.from_file(dir / RevisionMetadata.file_name)

    def read_sql(self):
        """Return the (up, down, devel) SQL bodies of this revision with the
        license banner stripped; *devel* is None when no devel.sql exists."""
        skip = license_header.count('\n')

        def load_without_banner(path):
            # Drop as many leading lines as the license banner occupies.
            with open(path) as f:
                return ''.join(f.readlines()[skip:])

        up = load_without_banner(self.dir / up_script)
        down = load_without_banner(self.dir / down_script)
        try:
            devel = load_without_banner(self.dir / devel_script)
        except OSError:
            devel = None
        return up, down, devel
class RevisionEntry:
    """A freshly created revision: generates the id, directory name,
    metadata and SQL template files."""

    def __init__(self, base_dir: Path, message: str):
        self.id = uuid.uuid4()
        self.base_dir = base_dir
        short_id = str(self.id)[:8]
        slug = message.replace(' ', '_')
        self.dir = base_dir / f"{short_id}_{slug}"
        self.date = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        self.message = message
        # Link this revision to the newest existing one (0 when none exist).
        previous = get_latest_revision(self.base_dir)
        parent_id = previous.metadata.id() if previous else 0
        self.metadata = RevisionMetadata(self.id, self.date, self.message, parent_id)

    def spawn(self):
        """Materialize the revision on disk: directory, metadata and templates."""
        self.dir.mkdir(exist_ok=True, parents=True)
        self.metadata.dump_to_file(self.dir / RevisionMetadata.file_name)
        self._build_sql_template()

    def _build_sql_template(self):
        # Seed each script with the license banner plus the header stub.
        for script_name in (up_script, down_script, devel_script):
            (self.dir / script_name).write_text(license_header + self._sql_header())

    def _sql_header(self):
        return f'-- Message: {self.message}\n' \
               f'-- Revision: {self.id}\n' \
               f'-- Create Date: {self.date}\n\n' \
               f'-- Insert SQL here\n'
class DatabaseSet:
    """Wrapper around the JSON file describing the product's databases and
    their committed versions."""

    def __init__(self, path: Path):
        self.key_db_version = "version"
        self.key_db_name = "name"
        self.key_db_array = "databases"
        with open(path, "r") as f:
            self.set = json.load(f)
        self.path = path
        # The file is expected to contain a single top-level product key.
        self.product = list(self.set.keys())[0]

    def get_database_version(self, db_name: str):
        """Return the committed version of *db_name* as an int.

        Raises StopIteration when the database is not present in the set.
        """
        v = next(
            d[self.key_db_version] for d in self.db_array() if
            d[self.key_db_name] == db_name)
        return int(v)

    def db_array(self):
        """Return the raw list of database entries for the product."""
        return self.set[self.product][self.key_db_array]

    def list_databases_by_name(self):
        """Return the set of database names.

        Uses the key constant instead of the previously hard-coded "name"
        literal, for consistency with the rest of the class.
        """
        return {database[self.key_db_name] for database in self.db_array()}

    def modify_database_version(self, db_name: str, version: int):
        """Set *db_name*'s version and persist the whole set back to disk.

        The version is stored as a string, mirroring the on-disk format.
        """
        entry = next(d for d in self.db_array() if d[self.key_db_name] == db_name)
        entry[self.key_db_version] = str(version)
        with open(self.path, 'w') as file:
            file.write(json.dumps(self.set, indent=1))
class Migration:
    """Implements the migration commands (upgrade, install, commit, revision,
    revert, redo) operating on the environment configured via 'init'."""
    env_var = "DB_MIGRATION_ENV"
    # Directory holding the not-yet-committed revision chain of a database
    _rev_base_dir = "current"

    def _get_env(self, path: Path):
        """Tries to fetch environment settings from the given file"""
        with open(path / env_file) as f:
            data = json.load(f)
            data["output_dir"] = Path(data["output_dir"])
            data["dirs"][:] = [Path(e) for e in data["dirs"]]
            data["db_set_dir"] = Path(data["db_set_dir"])
            return data

    def _get_db_set(self):
        return DatabaseSet(self._env["db_set_dir"])

    def _get_database_path(self, db_name):
        # First configured migration dir containing a subdirectory named after the DB
        return next(d / db_name for d in self._env["dirs"] if (Path(d) / db_name).exists())

    def _invoke_sql(self, db_name, script_name):
        """Run *script_name* from the latest 'current' revision against the output DB."""
        base_dir = self._get_database_path(db_name) / Migration._rev_base_dir
        if not base_dir.exists():
            print("Nothing to invoke")
            return
        rev = get_latest_revision(base_dir)
        execute_db_script(self._env["output_dir"] / f"{db_name}.db", rev.dir / script_name)

    def __init__(self, env_path: Path):
        # Fall back to the DB_MIGRATION_ENV environment variable when no path given
        self._env = self._get_env(env_path) if env_path else self._get_env(Path(os.environ.get(Migration.env_var)))
        self.db_names = list(itertools.chain.from_iterable(os.listdir(d) for d in self._env["dirs"]))

    def upgrade(self, db_name, rev, devel):
        """Rebuild *db_name* in the output dir: first up to the committed
        version from the database set, then through the revisions in
        'current' up to *rev* (all of them when *rev* is None). With
        *devel*, devel.sql scripts are also applied."""
        print(f"Upgrading '{db_name}', devel features: {devel}")
        db_path = self._get_database_path(db_name)
        Path.mkdir(self._env["output_dir"], exist_ok=True, parents=True)
        # Remove old database, if exists
        Path.unlink(self._env["output_dir"] / f"{db_name}.db", missing_ok=True)
        # First, migrate using already committed db version from database set file
        version = self._get_db_set().get_database_version(db_name)
        print(f"-> Upgrading to committed version: {version}")
        migrate_database_up(db_name, db_path, self._env["output_dir"], version, devel)
        # Check if 'current' directory exists and apply current revision list
        current_path = db_path / Migration._rev_base_dir
        if not current_path.exists():
            return
        revisions = build_revision_entries(current_path)
        if not revisions:
            # 'current' exists but holds no revisions; nothing more to apply
            # (previously revisions[-1] would raise IndexError here).
            return
        if rev is None:
            print(f"-> Upgrading to the newest available revision: {revisions[-1].metadata.id()}")
            revisions_range = revisions[:]
        else:
            # Upgrade up to the specified revision
            revisions_range = build_revision_entries_up_to(revisions, rev)
            if not revisions_range:
                print(f"-> revision: {rev} does not exist")
                return
            print(f"-> Upgrading to the revision: {rev}")
        for revision in revisions_range:
            meta = revision.metadata
            print(f" -> Running upgrade from {meta.parent()} to {meta.id()}")
            execute_db_script(self._env["output_dir"] / f"{db_name}.db", revision.dir / up_script)
            if devel and os.path.exists(revision.dir / devel_script):
                execute_db_script(self._env["output_dir"] / f"{db_name}.db", revision.dir / devel_script)

    def install(self, devel):
        """Regenerate the whole database set into the output directory and
        copy the migration scripts next to it (devel.sql excluded)."""
        shutil.rmtree(self._env["output_dir"], ignore_errors=True)
        Path.mkdir(self._env["output_dir"], exist_ok=True, parents=True)
        databases_to_migrate = self._get_db_set().list_databases_by_name().intersection(self.db_names)
        print(f"Database set to be upgraded and installed: {databases_to_migrate}")
        for db_name in databases_to_migrate:
            self.upgrade(db_name, None, devel)
        # Populate output dir with migration scripts, skip 'devel.sql' scripts
        for d in self._env["dirs"]:
            shutil.copytree(d, self._env["output_dir"] / "migration", dirs_exist_ok=True,
                            ignore=shutil.ignore_patterns(devel_script))

    def commit(self, db_name):
        """Merge the 'current' revision chain of *db_name* into a new
        numbered version directory and bump the committed version."""
        db_path = self._get_database_path(db_name)
        current_path = db_path / Migration._rev_base_dir
        upgrade_version = self._get_db_set().get_database_version(db_name) + 1
        print(f"Committing database '{db_name}':")
        if not current_path.exists():
            print("->Nothing to commit")
            return
        # Prepare new version directory structure.
        # Fix: use version_path directly - it is already rooted at db_path;
        # the previous 'db_path / version_path' doubled the relative prefix.
        version_path = db_path / str(upgrade_version)
        Path.mkdir(version_path, exist_ok=True, parents=True)
        merge_sql_from_dir(current_path, version_path)
        self._get_db_set().modify_database_version(db_name, upgrade_version)
        shutil.rmtree(current_path)
        print(f"->New version generated from commit: {upgrade_version}")

    def commit_all(self):
        """Commit every database listed in the database set."""
        for db_name in self._get_db_set().list_databases_by_name():
            self.commit(db_name)

    def revision(self, db_name, message):
        """Create a new empty revision under *db_name*'s 'current' directory."""
        base_dir = self._get_database_path(db_name) / Migration._rev_base_dir
        Path.mkdir(base_dir, exist_ok=True, parents=True)
        entry = RevisionEntry(base_dir, message)
        entry.spawn()
        print(f"Added new revision: {entry.metadata.id()}")

    def revert(self, db_name):
        """Run the latest revision's down.sql against the output database."""
        self._invoke_sql(db_name, down_script)

    def redo(self, db_name):
        """Run the latest revision's down.sql and then its up.sql."""
        self._invoke_sql(db_name, down_script)
        self._invoke_sql(db_name, up_script)
def build_revision_entries(base: Path):
    """ Builds the list of ConstRevisionEntry entries where each child is placed after its parent.
    Revision_1(id=1,parent=0) -> Revision_2(id=2,parent=1) -> Revision_n(id=n,parent=2)
    """
    entries = [ConstRevisionEntry(child) for child in base.iterdir()]
    ordered = []
    wanted_parent = 0  # the root revision has parent id 0
    for _ in entries:
        # Find the entry whose parent is the one we linked last.
        match = next((e for e in entries if e.metadata.parent() == wanted_parent), None)
        if match is None:
            break  # chain is broken or exhausted
        ordered.append(match)
        wanted_parent = match.metadata.id()
    return ordered
def build_revision_entries_up_to(revisions, rev):
    """ Try to build the list of ConstRevisionEntry entries from the already existing list of revisions up to the
    specified revision. For instance, Revision_1(id=1,parent=0) -> Revision_2(id=2,parent=1) -> Revision_n(id=rev,
    parent=2)

    Returns the prefix of *revisions* up to and including the entry whose id
    equals *rev*, or None when no such revision exists.
    """
    # The previous existence pre-check used next(..., [None]), whose default
    # is a non-empty (truthy) list, so the else-branch returning None was
    # unreachable and the not-found case only worked via an implicit None
    # fall-through. This states that behavior explicitly.
    revisions_range = []
    for r in revisions:
        revisions_range.append(r)
        if r.metadata.id() == rev:
            return revisions_range
    return None
def get_latest_revision(base: Path):
    """Obtain the latest ConstRevisionEntry, or None when *base* holds no revisions."""
    chain = build_revision_entries(base)
    return chain[-1] if chain else None
def merge_sql_from_dir(directory: Path, out: Path):
    """Merge every revision's scripts from *directory* into single up/down
    (and, when present, devel) scripts inside *out*. Down scripts are merged
    in reverse order so rollbacks undo the newest changes first."""
    revisions = build_revision_entries(directory)
    devel_path = out / devel_script
    with open(out / up_script, 'w') as up_file, open(out / down_script, 'w') as down_file:
        up_file.write(license_header)
        down_file.write(license_header)
        for rev in revisions:
            print(f"->Merging revision: {rev.metadata.id()}")
            sql_up, _, sql_devel = rev.read_sql()
            up_file.write(sql_up + '\n')
            if sql_devel:
                # Append mode creates the file if needed; write the license
                # banner only on first creation.
                needs_banner = not devel_path.exists()
                with open(devel_path, 'a') as devel_file:
                    if needs_banner:
                        devel_file.write(license_header)
                    devel_file.write(sql_devel + '\n')
        # Down scripts need to be merged in reversed order
        for rev in reversed(revisions):
            _, sql_down, _ = rev.read_sql()
            down_file.write(sql_down + '\n')
def execute_db_script(db_path: Path, script: Path, version: int = None):
    """Execute the SQL *script* against the SQLite database at *db_path*.

    When *version* is given, PRAGMA user_version is set to it afterwards.
    The connection is always closed, even when the script fails, and the
    version check uses 'is not None' so that version 0 is honored too.
    """
    connection = sqlite3.connect(db_path)
    try:
        with open(script) as ms:
            connection.executescript(ms.read())
        connection.commit()
        if version is not None:
            connection.execute(f"PRAGMA user_version = {version};")
            connection.commit()
    finally:
        connection.close()
def migrate_database_up(database: str, migration_path: Path, dst_directory: Path, dst_version: int, devel: bool):
    """Rebuild *database* from scratch by applying committed versions 0..dst_version.

    For each version directory under *migration_path*, the up.sql script is
    executed against the destination database and PRAGMA user_version is set
    to that version number. With *devel*, a devel.sql script is also applied
    when present. (Annotations fixed: 'os.path' is a module, not a type.)
    """
    db_name_full = f"{database}.db"
    dst_db_path = dst_directory / db_name_full
    # Start from a clean database file
    Path(dst_db_path).unlink(missing_ok=True)
    for i in range(dst_version + 1):
        migration_script = migration_path / str(i) / up_script
        devel_script_path = migration_path / str(i) / devel_script
        execute_db_script(dst_db_path, migration_script, i)
        if devel and os.path.exists(devel_script_path):
            execute_db_script(dst_db_path, devel_script_path, i)
@subcommand([argument("-e", "--env", help="where to store environment configuration", required=True, type=Path),
             argument("--dbset", help="location of the file describing database set", required=True, type=Path),
             argument("-o", "--out", help="where to store generated databases", required=True, type=Path),
             argument("--dirs",
                      help="list of migration base directories. It's important to pass product-specific directory as "
                           "a first element on the list",
                      action='append',
                      nargs='*',
                      required=True,
                      type=Path)])
def init(args):
    """Initializes migration environment"""
    # Each --dirs occurrence yields a one-element list (action='append' + nargs='*').
    env = {
        "db_set_dir": args.dbset.as_posix(),
        "output_dir": args.out.as_posix(),
        "dirs": [entry[0].as_posix() for entry in args.dirs],
    }
    with open(args.env / env_file, 'w') as f:
        json.dump(env, f, indent=1)
@subcommand([argument("-e", "--env", help="environment location", type=Path),
             argument("--db", help="database name", required=True, type=str),
             argument("-m", "--message", help="revision message", required=True, type=str)])
def revision(args):
    """Creates a new database migration revision"""
    migration = Migration(args.env)
    migration.revision(args.db, args.message)
@subcommand([argument("-e", "--env", help="environment location", type=Path),
             argument("--db", help="database name", type=str)])
def commit(args):
    """Commits current set of SQL statements and updates database version number"""
    migration = Migration(args.env)
    # Without an explicit database name, commit every database in the set.
    if args.db:
        migration.commit(args.db)
    else:
        migration.commit_all()
@subcommand(
    [argument("-e", "--env", help="environment location", type=Path),
     argument("-d", "--devel", help="with development schema", default=False)])
def install(args):
    """ Generates database set and then installs it in the specific output directory. It also populates output
    directory with corresponding migration scripts"""
    migration = Migration(args.env)
    migration.install(args.devel)
@subcommand(
    [argument("-e", "--env", help="environment location", type=Path),
     argument("--db", help="database name", type=str, required=True),
     argument("-r", "--revision", help="target revision", type=str),
     argument("-d", "--devel", help="with development schema", default=False)])
def upgrade(args):
    """ Upgrades database to the specific revision(or the newest one if --revision parameter omitted)"""
    migration = Migration(args.env)
    migration.upgrade(args.db, args.revision, args.devel)
@subcommand([argument("-e", "--env", help="environment location", type=Path),
             argument("--db", help="database name", type=str, required=True)])
def revert(args):
    """ Runs the (down.sql) for the specified database for the most recent migration"""
    migration = Migration(args.env)
    migration.revert(args.db)
@subcommand([argument("-e", "--env", help="environment location", type=Path),
             argument("--db", help="database name", type=str, required=True)])
def redo(args):
    """ Runs the (down.sql) and then the (up.sql) for the most recent migration"""
    migration = Migration(args.env)
    migration.redo(args.db)
def main() -> int:
    """CLI entry point: dispatch the selected subcommand.

    Returns the process exit code (0 on success, 1 on failure or when no
    subcommand was given).
    """
    args = cli.parse_args()
    if args.subcommand is None:
        cli.print_help()
        return 1
    try:
        args.func(args)
    except Exception:
        # Narrowed from a bare 'except:' so KeyboardInterrupt and SystemExit
        # propagate instead of being swallowed.
        print(traceback.format_exc())
        return 1
    return 0
if __name__ == "__main__":
sys.exit(main())