feat:修改ddql,增加qa问答

This commit is contained in:
雷雨
2025-09-28 16:44:58 +08:00
parent dc16bfaa8e
commit 6064710f4f
7 changed files with 396 additions and 641 deletions

View File

@@ -17,7 +17,7 @@ from datetime import datetime
import logging
from util import train_ddl
logger = logging.getLogger(__name__)
import traceback
class OpenAICompatibleLLM(VannaBase):
def __init__(self, client=None, config_file=None):
VannaBase.__init__(self, config=config_file)
@@ -186,6 +186,9 @@ class OpenAICompatibleLLM(VannaBase):
try:
logger.info("Start to generate_sql_2 in cus_vanna_srevice")
question_sql_list = self.get_similar_question_sql(question, **kwargs)
if question_sql_list and len(question_sql_list)>2:
question_sql_list=question_sql_list[:1]
ddl_list = self.get_related_ddl(question, **kwargs)
#doc_list = self.get_related_documentation(question, **kwargs)
template = get_base_template()
@@ -194,7 +197,8 @@ class OpenAICompatibleLLM(VannaBase):
# --------基于提示词生成sql以及图表类型
sys_temp = sql_temp['system'].format(engine=config("DB_ENGINE", default='mysql'), lang='中文',
schema=ddl_list, documentation=[train_ddl.train_document],
data_training=question_sql_list)
retrieved_examples_data=question_sql_list,
data_training=question_sql_list,)
logger.info(f"sys_temp:{sys_temp}")
user_temp = sql_temp['user'].format(question=question,
current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
@@ -223,7 +227,8 @@ class OpenAICompatibleLLM(VannaBase):
logger.info("Finish to generate_sql_2 in cus_vanna_srevice")
return result
except Exception as e:
logger.info("cus_vanna_srevice failed-------------------")
logger.info("cus_vanna_srevice failed-------------------: ")
traceback.print_exc()
raise e
def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str: