Merge branch 'dev' of gitlab-devt.yced.com.cn:lei_y601/sqlbot_agent into dev

# Conflicts:
#	main_service.py
#	service/cus_vanna_srevice.py
#	util/load_ddl_doc.py
This commit is contained in:
yujj128
2025-09-24 14:49:37 +08:00
6 changed files with 195 additions and 151 deletions

View File

@@ -1,6 +1,6 @@
from email.policy import default
from typing import List
import dmPython
from typing import List, Union
import dmPython
import orjson
import pandas as pd
from vanna.base import VannaBase
@@ -22,7 +22,7 @@ class OpenAICompatibleLLM(VannaBase):
# default parameters - can be overrided using config
self.temperature = 0.5
self.max_tokens = 5000
self.conn = None
if "temperature" in config_file:
self.temperature = config_file["temperature"]
@@ -54,6 +54,46 @@ class OpenAICompatibleLLM(VannaBase):
def system_message(self, message: str) -> any:
return {"role": "system", "content": message}
def connect_to_dameng(
self,
host: str = None,
dbname: str = None,
user: str = None,
password: str = None,
port: int = None,
**kwargs
):
conn = None
try:
conn = dmPython.connect(user=user, password=password, server=host, port=port)
except Exception as e:
raise Exception(f"Failed to connect to dameng database: {e}")
def run_sql_damengsql(sql: str) -> Union[pd.DataFrame, None]:
if conn:
try:
# conn.ping(reconnect=True)
cs = conn.cursor()
cs.execute(sql)
results = cs.fetchall()
# Create a pandas dataframe from the results
df = pd.DataFrame(
results, columns=[desc[0] for desc in cs.description]
)
return df
except Exception as e:
conn.rollback()
raise e
return None
self.run_sql_is_set = True
self.run_sql = run_sql_damengsql
def user_message(self, message: str) -> any:
return {"role": "user", "content": message}
@@ -140,21 +180,6 @@ class OpenAICompatibleLLM(VannaBase):
return response.choices[0].message.content
# def connect_to_dameng(self, host, port, username, password, database):
# try:
# self.conn = dmPython.connect(
# user=username,
# password=password,
# server=host,
# port=port, # 达梦默认端口5236
# autoCommit=True
# )
# print("达梦数据库连接成功")
# return True
# except Exception as e:
# print(f"连接失败: {e}")
# return False
def generate_sql_2(self, question: str, allow_llm_to_see_data=False, **kwargs) -> dict:
try:
question_sql_list = self.get_similar_question_sql(question, **kwargs)
@@ -164,7 +189,7 @@ class OpenAICompatibleLLM(VannaBase):
sql_temp = template['template']['sql']
char_temp = template['template']['chart']
# --------基于提示词生成sql以及图表类型
sys_temp = sql_temp['system'].format(engine=config("DATA_SOURCE_TYPE", default='mysql'), lang='中文',
sys_temp = sql_temp['system'].format(engine=config("DB_ENGINE", default='mysql'), lang='中文',
schema=ddl_list, documentation=doc_list,
data_training=question_sql_list)
print("sys_temp", sys_temp)
@@ -180,7 +205,7 @@ class OpenAICompatibleLLM(VannaBase):
# ---------------生成图表
char_type = get_chart_type_from_sql_answer(llm_response)
if char_type:
sys_char_temp = char_temp['system'].format(engine=config("DATA_SOURCE_TYPE", default='mysql'),
sys_char_temp = char_temp['system'].format(engine=config("DB_ENGINE", default='mysql'),
lang='中文', sql=sql, chart_type=char_type)
user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question)
llm_response2 = self.submit_prompt(