缓存上下文,表结构添加

This commit is contained in:
yujj128
2025-10-13 18:18:58 +08:00
parent 73cbc55d74
commit be0bc661e2
4 changed files with 371 additions and 24 deletions

View File

@@ -1,3 +1,4 @@
from dataclasses import field
from email.policy import default
from typing import List, Union, Any, Optional
import time
@@ -76,10 +77,10 @@ class OpenAICompatibleLLM(VannaBase):
def run_sql_damengsql(sql: str) -> Union[pd.DataFrame, None]:
logger.info(f"start to run_sql_damengsql")
if not is_connection_alive(conn=self.conn):
logger.info("connection is not alive, reconnecting..........")
reconnect()
try:
if not is_connection_alive(conn=self.conn):
logger.info("connection is not alive, reconnecting..........")
reconnect()
# conn.ping(reconnect=True)
cs = self.conn.cursor()
cs.execute(sql)
@@ -203,7 +204,7 @@ class OpenAICompatibleLLM(VannaBase):
return response.choices[0].message.content
def generate_sql_2(self, question: str, allow_llm_to_see_data=False, **kwargs) -> dict:
def generate_sql_2(self, question: str, cache=None,user_id=None, allow_llm_to_see_data=False, **kwargs) -> dict:
try:
logger.info("Start to generate_sql_2 in cus_vanna_srevice")
question_sql_list = self.get_similar_question_sql(question, **kwargs)
@@ -215,15 +216,19 @@ class OpenAICompatibleLLM(VannaBase):
template = get_base_template()
sql_temp = template['template']['sql']
char_temp = template['template']['chart']
history = None
if user_id and cache:
history = cache.get(id=user_id, field="data")
# --------基于提示词生成sql以及图表类型
sys_temp = sql_temp['system'].format(engine=config("DB_ENGINE", default='mysql'), lang='中文',
schema=ddl_list, documentation=[train_ddl.train_document],
retrieved_examples_data=question_sql_list,
history=history,retrieved_examples_data=question_sql_list,
data_training=question_sql_list,)
logger.info(f"sys_temp:{sys_temp}")
user_temp = sql_temp['user'].format(question=question,
current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
logger.info(f"user_temp:{user_temp}")
logger.info(f"sys_temp:{sys_temp}")
llm_response = self.submit_prompt(
[{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs)
logger.info(f"llm_response:{llm_response}")