Refactor database session handling in model methods for consistency and flexibility

1 parent 34da7c42
...@@ -95,12 +95,16 @@ class DefaultModel(CommonModel): ...@@ -95,12 +95,16 @@ class DefaultModel(CommonModel):
return row return row
@classmethod @classmethod
def count(cls): def count(cls, db_session=None):
return cls.db_session.query(func.count('id')).scalar() if not db_session:
db_session = cls.db_session
return db_session.query(func.count('id')).scalar()
@classmethod @classmethod
def query(cls, filters=None): def query(cls, db_session=None, filters=None):
query = cls.db_session.query(cls) if not db_session:
db_session = cls.db_session
query = db_session.query(cls)
if filters: if filters:
filter_expressions = [] filter_expressions = []
for d in filters: for d in filters:
...@@ -113,24 +117,32 @@ class DefaultModel(CommonModel): ...@@ -113,24 +117,32 @@ class DefaultModel(CommonModel):
return query return query
@classmethod @classmethod
def query_from(cls, columns=[], filters=None): def query_from(cls, db_session=None, columns=[], filters=None):
query = cls.db_session.query().select_from(cls) if not db_session:
db_session = cls.db_session
query = db_session.query().select_from(cls)
for c in columns: for c in columns:
query = query.add_columns(c) query = query.add_columns(c)
return query return query
@classmethod @classmethod
def query_id(cls, row_id): def query_id(cls, row_id, db_session=None):
return cls.query().filter_by(id=row_id) if not db_session:
db_session = cls.db_session
return cls.query(db_session).filter_by(id=row_id)
@classmethod @classmethod
def delete(cls, row_id): def delete(cls, row_id, db_session=None):
cls.query_id(row_id).delete() if not db_session:
db_session = cls.db_session
cls.query_id(row_id, db_session).delete()
@classmethod @classmethod
def flush(cls, row): def flush(cls, row, db_session=None):
cls.db_session.add(row) if not db_session:
cls.db_session.flush() db_session = cls.db_session
db_session.add(row)
db_session.flush()
class StandarModel(DefaultModel): class StandarModel(DefaultModel):
...@@ -142,85 +154,124 @@ class StandarModel(DefaultModel): ...@@ -142,85 +154,124 @@ class StandarModel(DefaultModel):
# New Method # New Method
@classmethod @classmethod
def query_status(cls, status=0): def query_status(cls, status=0, db_session=None):
return cls.query().filter_by(status=status) if not db_session:
db_session = cls.db_session
return cls.query(db_session).filter_by(status=status)
@classmethod @classmethod
def disabled(cls): def disabled(cls, db_session=None):
return cls.query_status(status=0) if not db_session:
db_session = cls.db_session
return cls.query_status(status=0, db_session=db_session)
@classmethod @classmethod
def active(cls): def active(cls, db_session=None):
return cls.query_status(status=1) if not db_session:
db_session = cls.db_session
return cls.query_status(status=1, db_session=db_session)
@classmethod @classmethod
def draft(cls): def draft(cls, db_session=None):
return cls.disabled() if not db_session:
db_session = cls.db_session
return cls.disabled(db_session=db_session)
@classmethod @classmethod
def processed(cls): def processed(cls, db_session=None):
return cls.query_status(status=1) if not db_session:
db_session = cls.db_session
return cls.query_status(status=1, db_session=db_session)
@classmethod @classmethod
def canceled(cls): def canceled(cls, db_session=None):
return cls.query_status(status=9) if not db_session:
db_session = cls.db_session
return cls.query_status(status=9, db_session=db_session)
@classmethod @classmethod
def get_active(cls): def get_active(cls, db_session=None):
return cls.query_status(status=1).all() if not db_session:
db_session = cls.db_session
return cls.query_status(status=1, db_session=db_session).all()
@classmethod @classmethod
def get_disabled(cls): def get_disabled(cls, db_session=None):
return cls.query_status(status=0).all() if not db_session:
db_session = cls.db_session
return cls.query_status(status=0, db_session=db_session).all()
@classmethod @classmethod
def get_archived(cls): def get_archived(cls, db_session=None):
return cls.query_status(status=0).all() if not db_session:
db_session = cls.db_session
return cls.query_status(status=0, db_session=db_session).all()
class KodeModel(StandarModel): class KodeModel(StandarModel):
kode = Column(String(32), nullable=False) kode = Column(String(32), nullable=False)
@classmethod @classmethod
def query_kode(cls, kode): def query_kode(cls, kode, db_session=None):
return cls.query().filter_by(kode=kode) if not db_session:
db_session = cls.db_session
return cls.query(db_session).filter_by(kode=kode)
@classmethod @classmethod
def get_by_kode(cls, kode): def get_by_kode(cls, kode, db_session=None):
return cls.query_kode(kode).first() if not db_session:
db_session = cls.db_session
return cls.query_kode(kode, db_session=db_session).first()
class UraianModel(StandarModel): class UraianModel(StandarModel):
nama = Column(String(128)) nama = Column(String(128))
@classmethod @classmethod
def query_nama(cls, nama): def query_nama(cls, nama, db_session=None):
return cls.query().filter_by(nama=nama) if not db_session:
db_session = cls.db_session
return cls.query(db_session).filter_by(nama=nama)
@classmethod @classmethod
def get_by_nama(cls, nama): def get_by_nama(cls, nama, db_session=None):
return cls.query_nama(nama).first() if not db_session:
db_session = cls.db_session
return cls.query_nama(nama, db_session=db_session).first()
@classmethod @classmethod
def get_list(cls): def get_list(cls, db_session=None):
return cls.db_session.query(cls.id, cls.nama).order_by(cls.nama).all() if not db_session:
db_session = cls.db_session
return db_session.query(cls.id, cls.nama).order_by(cls.nama).all()
class NamaModel(KodeModel): class NamaModel(KodeModel):
nama = Column(String(128), nullable=False) nama = Column(String(128), nullable=False)
@classmethod @classmethod
def query_nama(cls, nama): def query_nama(cls, nama, db_session=None):
return cls.query().filter_by(nama=nama) if not db_session:
db_session = cls.db_session
return cls.query(db_session).filter_by(nama=nama)
@classmethod @classmethod
def get_by_nama(cls, nama): def get_by_nama(cls, nama, db_session=None):
return cls.query_nama(nama).first() if not db_session:
db_session = cls.db_session
return cls.query_nama(nama, db_session=db_session).first()
@classmethod @classmethod
def query_list(cls): def query_list(cls, db_session=None):
return cls.db_session.query(cls.id, cls.nama).order_by(cls.nama) if not db_session:
db_session = cls.db_session
return db_session.query(cls.id, cls.nama).order_by(cls.nama)
@classmethod @classmethod
def get_list(cls):
return cls.query_list().all()
\ No newline at end of file \ No newline at end of file
def get_list(cls, db_session=None):
if not db_session:
db_session = cls.db_session
return cls.query_list(db_session=db_session).all()
\ No newline at end of file \ No newline at end of file
Markdown is supported
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!