# initialize_db.py (4.5 KB)
import os
import sys
import csv
import subprocess
from getpass import getpass
from argparse import ArgumentParser
import transaction
from ziggurat_foundations.models.services.user import UserService
from pyramid.paster import (
    get_appsettings,
    setup_logging,
    )
from pyramid.i18n import (
    Localizer,
    TranslationStringFactory,
    Translations,
    )
from ..models import (
    get_engine,
    get_session_factory,
    get_tm_session,
    )
from ..models.ziggurat import (
    Group,
    GroupPermission,
    UserGroup,
    User,
    )
from ..models.meta import (
    metadata,
    Base,
    )
from ..models.conf import Conf


# Translation domain for this script's messages; MyLocalizer loads the
# message catalogs for this domain and `_` builds TranslationStrings in it.
domain = 'initialize_db'
_ = TranslationStringFactory(domain)

# Module-level state shared between main() and the helper functions.
# main() stores the parsed application settings under 'settings' and the
# active database session under 'dbsession'.
my_registry = {}


class MyLocalizer:
    """Pyramid Localizer wrapper for use outside a web request.

    Loads the message catalogs of this script's translation domain from the
    package's ``locale`` directory for the application's default locale.
    ``my_registry['settings']`` must already hold the application settings
    (``main()`` stores them there) before this class is instantiated.
    """

    def __init__(self):
        settings = my_registry['settings']
        # The .ini file may not declare the default locale; fall back to
        # Pyramid's documented default ('en') instead of raising KeyError.
        locale_name = settings.get('pyramid.default_locale_name', 'en')
        here = os.path.abspath(os.path.dirname(__file__))
        locale_dir = os.path.join(here, '..', 'locale')
        translations = Translations.load(locale_dir, [locale_name], domain)
        self.localizer = Localizer(locale_name, translations)

    def translate(self, ts):
        """Return the TranslationString *ts* translated to the configured locale."""
        return self.localizer.translate(ts)


def read_file(filename):
    """Return the entire contents of *filename* as a string.

    The file is opened in text mode with the platform default encoding. It
    is always closed, even when reading raises.
    """
    with open(filename) as f:
        return f.read()


def alembic_run(ini_file):
    """Run ``alembic -c <ini_file> upgrade head``.

    The ``alembic`` executable is looked up in the same directory as the
    running Python interpreter. If the command returns a non-zero status,
    the process exits with that same status.
    """
    bin_path = os.path.dirname(sys.executable)
    alembic_bin = os.path.join(bin_path, 'alembic')
    command = (alembic_bin, '-c', ini_file, 'upgrade', 'head')
    returncode = subprocess.call(command)
    if returncode != 0:
        # A bare sys.exit() exits with status 0, which would make a failed
        # migration look successful to shells and CI; propagate the code.
        sys.exit(returncode)


def get_file(filename):
    """Open *filename* from the ``data`` directory next to this module.

    The file is opened for reading in text mode. The caller owns the
    returned file object and must close it; the callers in this module use
    it as a context manager.
    """
    module_dir = os.path.dirname(__file__)
    return open(os.path.join(module_dir, 'data', filename))


def ask_password(name):
    """Interactively ask for a new password for *name* and return it.

    The user is prompted twice (input hidden via getpass). An empty first
    entry re-prompts silently. A mismatch between the two entries prints a
    translated error and restarts the exchange. Prompts are translated with
    MyLocalizer.
    """
    localizer = MyLocalizer()
    mapping = {'name': name}
    prompt = localizer.translate(_(
        'ask-password-1', default='Enter new password for ${name}: ',
        mapping=mapping))
    confirm_prompt = localizer.translate(_(
        'ask-password-2', default='Retype new password for ${name}: ',
        mapping=mapping))
    mismatch = _('Sorry, passwords do not match')
    while True:
        first = getpass(prompt)
        if first:
            if getpass(confirm_prompt) == first:
                return first
            print(localizer.translate(mismatch))


def restore_csv(table, filename):
    """Populate an empty *table* from the bundled CSV file *filename*.

    If *table* already has at least one row, nothing is done and ``None`` is
    returned. Otherwise each CSV record becomes a new ``table()`` instance.
    Its attributes are set from the record's non-empty columns, the instance
    is added to the session stored in ``my_registry['dbsession']``, and the
    function returns ``True``.
    """
    session = my_registry['dbsession']
    if session.query(table).first():
        return
    with get_file(filename) as csv_file:
        for record in csv.DictReader(csv_file):
            obj = table()
            for column, value in record.items():
                # Empty cells are skipped so the attribute is left unset
                # rather than assigned an empty string.
                if value:
                    setattr(obj, column, value)
            session.add(obj)
    return True


def append_csv(table, filename, keys):
    """Add CSV records from *filename* that are not yet present in *table*.

    For each record, *table* is queried with the record's values for the
    columns listed in *keys*. If a matching row exists, the record is
    skipped. Otherwise a new ``table()`` instance is built from the
    record's non-empty columns and added to the session stored in
    ``my_registry['dbsession']``.
    """
    session = my_registry['dbsession']
    with get_file(filename) as csv_file:
        for record in csv.DictReader(csv_file):
            criteria = {key: record[key] for key in keys}
            existing = session.query(table).filter_by(**criteria).first()
            if existing:
                continue
            obj = table()
            for column, value in record.items():
                # Empty cells are skipped so the attribute is left unset.
                if value:
                    setattr(obj, column, value)
            session.add(obj)


def setup_models(dbsession):
    """Create missing tables and seed them from the bundled CSV files.

    Configuration rows and groups are merged in by their key columns. Users
    and user/group memberships are only loaded when their tables are empty.
    When users are loaded, the password of the user with ``id == 1`` is
    asked for interactively and set.
    """
    metadata.create_all(dbsession.bind)
    append_csv(Conf, 'conf.csv', ['nama'])
    users_restored = restore_csv(User, 'users.csv')
    if users_restored:
        # Push the newly added users to the database before querying them.
        dbsession.flush()
        seed_user = dbsession.query(User).filter_by(id=1).first()
        UserService.set_password(seed_user, ask_password(seed_user.user_name))
    append_csv(Group, 'groups.csv', ['group_name'])
    restore_csv(UserGroup, 'users_groups.csv')


def parse_args(argv):
    """Parse the command line *argv* (``argv[0]`` is the program name).

    Returns an argparse Namespace with a single ``config_uri`` attribute.
    """
    cli = ArgumentParser()
    cli.add_argument(
        'config_uri', help='Configuration file, e.g., development.ini')
    program_args = argv[1:]
    return cli.parse_args(program_args)


def main(argv=sys.argv):
    """Console entry point: initialize the database described by a config file.

    Steps: set up logging from the config, load the app settings into
    ``my_registry``, create tables, run alembic migrations, and then seed
    the data inside a single transaction.
    """
    config_uri = parse_args(argv).config_uri
    setup_logging(config_uri)
    settings = get_appsettings(config_uri)
    my_registry['settings'] = settings
    engine = get_engine(settings)
    Base.metadata.create_all(engine)
    alembic_run(config_uri)
    session_factory = get_session_factory(engine)
    tm = transaction.manager
    with tm:
        dbsession = get_tm_session(session_factory, tm)
        # Helpers such as restore_csv/append_csv read the session from here.
        my_registry['dbsession'] = dbsession
        setup_models(dbsession)