diff --git a/.env b/.env
index 018084d..8a8988e 100644
--- a/.env
+++ b/.env
@@ -9,8 +9,9 @@ EMBEDDING_MODEL_API_KEY=sk-cjiakyfpzamtxgxitcbxvwyvaulnygmyxqpykkgngsvfqhuy
EMBEDDING_MODEL_NAME=Qwen/Qwen3-Embedding-8B
#mysql ,sqlite,pg等
-DATA_SOURCE_TYPE=mysql
-
+DATA_SOURCE_TYPE=dameng
+#数据库类型
+DB_ENGINE=达梦数据库
#sqlite 连接信息
SQLITE_DATABASE_URL=E://db/db_flights.sqlite
@@ -20,3 +21,10 @@ MYSQL_DATABASE_PORT=3306
MYSQL_DATABASE_PASSWORD=Admin1234!
MYSQL_DATABASE_USER=yu
MYSQL_DATABASE_DBNAME=test
+
+
+#达梦数据库
+DAMENG_DATABASE_HOST=10.254.192.191
+DAMENG_DATABASE_PORT=5236
+DAMENG_DATABASE_PASSWORD=SYSDBA
+DAMENG_DATABASE_USER=SYSDBA
diff --git a/main_service.py b/main_service.py
index 5c18e13..ce36856 100644
--- a/main_service.py
+++ b/main_service.py
@@ -13,17 +13,23 @@ def connect_database(vn):
vn.connect_to_sqlite(config('SQLITE_DATABASE_URL', default=''))
elif db_type == 'mysql':
vn.connect_to_mysql(host=config('MYSQL_DATABASE_HOST', default=''),
- port=int(config('MYSQL_DATABASE_PORT', default=3306)),
+ port=config('MYSQL_DATABASE_PORT', default=3306),
user=config('MYSQL_DATABASE_USER', default=''),
password=config('MYSQL_DATABASE_PASSWORD', default=''),
- dbname=config('MYSQL_DATABASE_DBNAME', default=''))
- elif db_type == 'postgresql':
+ database=config('MYSQL_DATABASE_DBNAME', default=''))
+ elif db_type == 'dameng':
# 待补充
- pass
+ vn.connect_to_dameng(
+ host=config('DAMENG_DATABASE_HOST', default=''),
+ port=config('DAMENG_DATABASE_PORT', default=3306),
+ user=config('DAMENG_DATABASE_USER', default=''),
+ password=config('DAMENG_DATABASE_PASSWORD', default=''),
+ )
else:
pass
+
def load_train_data_ddl(vn: CustomVanna):
vn.train()
diff --git a/service/cus_vanna_srevice.py b/service/cus_vanna_srevice.py
index 22375cd..19c7eea 100644
--- a/service/cus_vanna_srevice.py
+++ b/service/cus_vanna_srevice.py
@@ -1,6 +1,6 @@
from email.policy import default
-from typing import List
-
+from typing import List, Union
+import dmPython
import orjson
import pandas as pd
from vanna.base import VannaBase
@@ -54,6 +54,46 @@ class OpenAICompatibleLLM(VannaBase):
def system_message(self, message: str) -> any:
return {"role": "system", "content": message}
+ def connect_to_dameng(
+ self,
+ host: str = None,
+ dbname: str = None,
+ user: str = None,
+ password: str = None,
+ port: int = None,
+ **kwargs
+ ):
+ conn = None
+ try:
+ conn = dmPython.connect(user=user, password=password, server=host, port=port)
+ except Exception as e:
+ raise Exception(f"Failed to connect to dameng database: {e}")
+
+ def run_sql_damengsql(sql: str) -> Union[pd.DataFrame, None]:
+ if conn:
+ try:
+ # conn.ping(reconnect=True)
+ cs = conn.cursor()
+ cs.execute(sql)
+ results = cs.fetchall()
+
+ # Create a pandas dataframe from the results
+ df = pd.DataFrame(
+ results, columns=[desc[0] for desc in cs.description]
+ )
+
+ return df
+
+
+
+ except Exception as e:
+ conn.rollback()
+ raise e
+ return None
+
+ self.run_sql_is_set = True
+ self.run_sql = run_sql_damengsql
+
def user_message(self, message: str) -> any:
return {"role": "user", "content": message}
@@ -148,7 +188,7 @@ class OpenAICompatibleLLM(VannaBase):
sql_temp = template['template']['sql']
char_temp = template['template']['chart']
# --------基于提示词,生成sql以及图表类型
- sys_temp = sql_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', schema=ddl_list, documentation=doc_list,
+ sys_temp = sql_temp['system'].format(engine=config("DB_ENGINE",default='mysql'), lang='中文', schema=ddl_list, documentation=doc_list,
data_training=question_sql_list)
print("sys_temp", sys_temp)
user_temp = sql_temp['user'].format(question=question,
@@ -163,7 +203,7 @@ class OpenAICompatibleLLM(VannaBase):
# ---------------生成图表
char_type = get_chart_type_from_sql_answer(llm_response)
if char_type:
- sys_char_temp = char_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', sql=sql, chart_type=char_type)
+ sys_char_temp = char_temp['system'].format(engine=config("DB_ENGINE",default='mysql'), lang='中文', sql=sql, chart_type=char_type)
user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question)
llm_response2 = self.submit_prompt(
[{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}], **kwargs)
diff --git a/util/load_ddl_doc.py b/util/load_ddl_doc.py
index a1e2d5a..1ac1fed 100644
--- a/util/load_ddl_doc.py
+++ b/util/load_ddl_doc.py
@@ -1,131 +1,19 @@
from service.cus_vanna_srevice import CustomVanna
-# table_ddls = [
-# """
-# create table db_user
-# (
-# id integer not null
-# constraint db_user_pk
-# primary key autoincrement,
-# user_name TEXT not null,
-# age integer not null,
-# address TEXT,
-# gender integer not null,
-# email TEXT
-# )
-# """,
-# ]
-# list_documentions = [
-# """
-# gender 字段 0代表女性,1代表男性;
-# 查询address时,尽量使用like查询,如:select * from db_user where address like '%北京%';
-# 语法为sqlite语法;
-# """,
-# ]
+from util import train_ddl
+
table_ddls = [
- """
- CREATE TABLE 人员库表 (
- id VARCHAR(22) PRIMARY KEY COMMENT '主键',
- name VARCHAR(600) DEFAULT NULL COMMENT '姓名',
- gender VARCHAR(108) DEFAULT NULL COMMENT '性别',
- id_card_type VARCHAR(108) DEFAULT NULL COMMENT '身份证件类型',
- id_card VARCHAR(600) DEFAULT NULL COMMENT '身份证号码',
- birthday VARCHAR(30) DEFAULT NULL COMMENT '出生日期',
- native_place TEXT DEFAULT NULL COMMENT '籍贯',
- nation TEXT DEFAULT NULL COMMENT '民族',
- country TEXT DEFAULT NULL COMMENT '国籍',
- residence_address TEXT DEFAULT NULL COMMENT '户籍地址',
- highest_education VARCHAR(108) DEFAULT NULL COMMENT '最高学历',
- highest_degree TEXT DEFAULT NULL COMMENT '最高学位',
- graduate_school TEXT DEFAULT NULL COMMENT '毕业院校',
- political_status TEXT DEFAULT NULL COMMENT '政治面貌',
- phone_number TEXT DEFAULT NULL COMMENT '手机号',
- email VARCHAR(600) DEFAULT NULL COMMENT '电子邮箱',
- worker_id VARCHAR(200) DEFAULT NULL COMMENT '工号',
- post TEXT DEFAULT NULL COMMENT '职务',
- engage_post TEXT DEFAULT NULL COMMENT '现从事岗位',
- work_unit TEXT DEFAULT NULL COMMENT '工作单位全称',
- work_content TEXT DEFAULT NULL COMMENT '工作内容',
- engage_contract_no VARCHAR(600) DEFAULT NULL COMMENT '从事项目合同编号',
- engage_contract_name VARCHAR(600) DEFAULT NULL COMMENT '从事项目合同名称',
- is_subcontractor VARCHAR(108) DEFAULT NULL COMMENT '是否分包商',
- general_contractor_unit VARCHAR(600) DEFAULT NULL COMMENT '总包单位全称',
- office_city TEXT DEFAULT NULL COMMENT '办公城市',
- office_address TEXT DEFAULT NULL COMMENT '办公地点',
- person_type TEXT DEFAULT NULL COMMENT '人员类型',
- person_status VARCHAR(108) DEFAULT NULL COMMENT '人员状态',
- is_internal VARCHAR(108) DEFAULT NULL COMMENT '是否内部员工',
- internal_unit VARCHAR(108) DEFAULT NULL COMMENT '内部单位',
- internal_dept VARCHAR(108) DEFAULT NULL COMMENT '内部部门',
- external_unit VARCHAR(600) DEFAULT NULL COMMENT '外部单位',
- external_dept VARCHAR(600) DEFAULT NULL COMMENT '外部部门',
- to_dept VARCHAR(600) DEFAULT NULL COMMENT '所属处室',
- pass_type VARCHAR(108) DEFAULT NULL COMMENT '通行证类型',
- entry_date VARCHAR(30) DEFAULT NULL COMMENT '入场日期',
- expected_departure_date VARCHAR(30) DEFAULT NULL COMMENT '预计离场日期',
- expire_time DATETIME DEFAULT NULL COMMENT '失效时间',
- verifystate INT DEFAULT NULL COMMENT '单据状态',
- auditor VARCHAR(180) DEFAULT NULL COMMENT '终审审批人',
- auditor1 VARCHAR(36) DEFAULT NULL COMMENT '处室负责人',
- auditnote VARCHAR(200) DEFAULT NULL COMMENT '当前审批人',
- procinst_id VARCHAR(36) DEFAULT NULL COMMENT '流程实例ID',
- bizflow_id VARCHAR(36) DEFAULT NULL COMMENT '业务流id',
- bizflowname VARCHAR(200) DEFAULT NULL COMMENT '流程名称',
- bizflow_makebillcode VARCHAR(200) DEFAULT NULL COMMENT '单据转换规则编码',
- bizflowinstance_id VARCHAR(36) DEFAULT NULL COMMENT '业务流实例id',
- sourcegrand_id VARCHAR(108) DEFAULT NULL COMMENT '来源孙表id',
- first_id VARCHAR(108) DEFAULT NULL COMMENT '来源单据主表id',
- firstchild_id VARCHAR(108) DEFAULT NULL COMMENT '来源单据子表id',
- firstbusiobj VARCHAR(108) DEFAULT NULL COMMENT '来源业务对象',
- firstcode TEXT DEFAULT NULL COMMENT '来源单据号',
- source_id VARCHAR(36) DEFAULT NULL COMMENT '上游单据主表id',
- sourcechild_id VARCHAR(36) DEFAULT NULL COMMENT '上游单据子表id',
- sourcebusiobj VARCHAR(36) DEFAULT NULL COMMENT '上游业务对象',
- sourcecode VARCHAR(200) DEFAULT NULL COMMENT '上游单据号',
- code TEXT DEFAULT NULL COMMENT '编码',
- ytenant_id VARCHAR(64) DEFAULT NULL COMMENT '租户id',
- photo TEXT DEFAULT NULL COMMENT '照片',
- input_time DATETIME DEFAULT NULL COMMENT '录入时间',
- create_time DATETIME DEFAULT NULL COMMENT '创建时间',
- modify_time DATETIME DEFAULT NULL COMMENT '修改时间',
- audit_time DATETIME DEFAULT NULL COMMENT '审批日期',
- input_user VARCHAR(108) DEFAULT NULL COMMENT '录入人',
- input_dept VARCHAR(108) DEFAULT NULL COMMENT '录入部门',
- creator VARCHAR(60) DEFAULT NULL COMMENT '创建人',
- modifier VARCHAR(60) DEFAULT NULL COMMENT '修改人',
- sort INT DEFAULT NULL COMMENT '排序',
- dr INT DEFAULT 0 COMMENT '逻辑删除:0-未删除,1-已删除',
- DHDATASTA INT DEFAULT NULL COMMENT '推送状态',
- pubts DATETIME DEFAULT NULL COMMENT '发布时间戳(或其他时间戳)'
- ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci COMMENT='人员信息表';
- """,
+ train_ddl.ddl_sql,
]
list_documentions = [
- """
- <人员库表注意事项>
-