Skip to content

Latest commit

 

History

History
78 lines (64 loc) · 3.58 KB

sqlalchemy.md

File metadata and controls

78 lines (64 loc) · 3.58 KB

# Flask-SQLAlchemy 的使用

from json import loads as json_loads
from flask import Flask
from flask_sqlalchemy import SQLAlchemy
from datetime import datetime
from model_flow.settings import DBConfig, JobStatus, DoStatus

app = Flask(__name__)
# Connection URI comes from the project's DBConfig; modification tracking
# (Flask-SQLAlchemy's object-change event system) is switched off.
app.config.update(
    SQLALCHEMY_DATABASE_URI=DBConfig.URI,
    SQLALCHEMY_TRACK_MODIFICATIONS=False,
)

db = SQLAlchemy(app)
# Prefix used to build the __tablename__ of the models in this module.
tb_prefix = DBConfig.TB_PREFIX


class JobInfo(db.Model):
    """Basic information about a job, one row per job.

    Stores the identifying IDs (job, user, task, node), the algorithm, the
    feature/target field names, the overall/train/test progress enums, the
    dataset configuration (JSON text) and training / bookkeeping timestamps.
    """
    __tablename__ = tb_prefix + 'job_info'

    job_id = db.Column(db.CHAR(16), primary_key=True, comment='唯一ID')
    user_id = db.Column(db.CHAR(16), nullable=False, index=True, comment='用户ID')
    task_id = db.Column(db.CHAR(16), nullable=False, index=True, comment='任务ID')
    node_id = db.Column(db.CHAR(16), nullable=False, index=True, comment='节点ID')
    # Algorithm type (distinct from algo_name, the algorithm's name).
    algo_type = db.Column(db.String(20), index=True, comment='算法类型')
    algo_name = db.Column(db.String(20), index=True, comment='算法名称')
    # Comma-separated feature field names.
    features = db.Column(db.String(1023), nullable=False, comment='特征字段名列表,英文逗号分隔')
    # Comma-separated target field names.
    target = db.Column(db.String(50), index=True, comment='目标字段名列表,英文逗号分隔')
    # Column defaults are applied on INSERT, so these enums (and created_at /
    # updated_at) are None on an instance that has not been flushed yet.
    status = db.Column(db.Enum(JobStatus), default=JobStatus.INIT, comment='作业状态')
    train_status = db.Column(db.Enum(DoStatus), default=DoStatus.INIT, comment='训练进度状态')
    test_status = db.Column(db.Enum(DoStatus), default=DoStatus.INIT, comment='测试进度状态')
    # JSON-encoded dataset configuration; nullable.
    dataset_conf = db.Column(db.String(1023), comment='数据配置, json格式')
    train_begin = db.Column(db.DateTime, comment='训练开始时间')
    train_end = db.Column(db.DateTime, comment='训练结束时间')
    # Naive local-time timestamps (datetime.now); updated_at is refreshed on UPDATE.
    created_at = db.Column(db.DateTime, default=datetime.now, comment='创建时间')
    updated_at = db.Column(db.DateTime, index=True, default=datetime.now, onupdate=datetime.now, comment='最后更新时间')

    @staticmethod
    def _fmt_dt(value):
        """Format a datetime with DBConfig.DATETIME_FORMAT; pass None through."""
        if value is None:
            return None
        return value.strftime(DBConfig.DATETIME_FORMAT)

    @staticmethod
    def _enum_name(value):
        """Return an enum member's ``name``; pass None through."""
        return None if value is None else value.name

    def to_dict(self):
        """Serialize this job into a plain dict.

        Enum columns are rendered by member name, datetimes with
        ``DBConfig.DATETIME_FORMAT``, and ``dataset_conf`` is decoded from
        JSON. Unset values (None, or an empty ``dataset_conf``) become None
        instead of raising. ``train_begin`` / ``train_end`` keys are present
        only when those columns are set.

        Returns:
            dict: the serialized job.
        """
        row = {
            'job_id': self.job_id,
            'algo_name': self.algo_name,
            'features': self.features,
            'target': self.target,
            'status': self._enum_name(self.status),
            'train_status': self._enum_name(self.train_status),
            'test_status': self._enum_name(self.test_status),
            # dataset_conf is nullable; json_loads(None) / json_loads('') would raise.
            'dataset_conf': json_loads(self.dataset_conf) if self.dataset_conf else None,
            'created_at': self._fmt_dt(self.created_at),
            'updated_at': self._fmt_dt(self.updated_at),
        }
        if self.train_begin is not None:
            row['train_begin'] = self._fmt_dt(self.train_begin)
        if self.train_end is not None:
            row['train_end'] = self._fmt_dt(self.train_end)

        return row


class Metric(db.Model):
    """Model evaluation metrics: one value per (job, metric key) pair.

    The primary key is the composite (job_id, key); job_id references
    JobInfo.job_id. The user/task/node IDs are stored on each metric row,
    and two composite indexes cover (task_id, key, value) and
    (node_id, key, value).
    """
    __tablename__ = tb_prefix + 'metric'
    __table_args__ = (
        db.PrimaryKeyConstraint('job_id', 'key'),
        # Composite secondary indexes, declared with the table itself.
        db.Index('ix_metric_task_key_val', 'task_id', 'key', 'value'),
        db.Index('ix_metric_node_key_val', 'node_id', 'key', 'value'),
    )

    # Owning job (first half of the composite primary key).
    job_id = db.Column(db.CHAR(16), db.ForeignKey(JobInfo.job_id), comment='job唯一ID')
    # Metric name (second half of the composite primary key).
    key = db.Column(db.String(20), nullable=False, comment='指标key')
    user_id = db.Column(db.CHAR(16), nullable=False, index=True, comment='用户ID')
    task_id = db.Column(db.CHAR(16), nullable=False, comment='任务ID')
    node_id = db.Column(db.CHAR(16), nullable=False, comment='节点ID')
    value = db.Column(db.Float, nullable=False, comment='指标value')
    created_at = db.Column(db.DateTime, default=datetime.now, comment='创建时间')