提示词优化,日志添加

This commit is contained in:
yujj128
2025-09-24 14:39:42 +08:00
parent 23916fc328
commit 2533126e36
5 changed files with 200 additions and 44 deletions

View File

@@ -1,6 +1,6 @@
from email.policy import default
from typing import List
import dmPython
import orjson
import pandas as pd
from vanna.base import VannaBase
@@ -22,7 +22,7 @@ class OpenAICompatibleLLM(VannaBase):
# default parameters - can be overrided using config
self.temperature = 0.5
self.max_tokens = 5000
self.conn = None
if "temperature" in config_file:
self.temperature = config_file["temperature"]
@@ -140,36 +140,57 @@ class OpenAICompatibleLLM(VannaBase):
return response.choices[0].message.content
# def connect_to_dameng(self, host, port, username, password, database):
# try:
# self.conn = dmPython.connect(
# user=username,
# password=password,
# server=host,
# port=port, # 达梦默认端口5236
# autoCommit=True
# )
# print("达梦数据库连接成功")
# return True
# except Exception as e:
# print(f"连接失败: {e}")
# return False
def generate_sql_2(self, question: str, allow_llm_to_see_data=False, **kwargs) -> dict:
question_sql_list = self.get_similar_question_sql(question, **kwargs)
ddl_list = self.get_related_ddl(question, **kwargs)
doc_list = self.get_related_documentation(question, **kwargs)
template = get_base_template()
sql_temp = template['template']['sql']
char_temp = template['template']['chart']
# --------基于提示词生成sql以及图表类型
sys_temp = sql_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', schema=ddl_list, documentation=doc_list,
data_training=question_sql_list)
print("sys_temp", sys_temp)
user_temp = sql_temp['user'].format(question=question,
current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
print("user_temp", user_temp)
llm_response = self.submit_prompt(
[{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs)
print(llm_response)
result = {"resp": orjson.loads(extract_nested_json(llm_response))}
print("result", result)
sql = check_and_get_sql(llm_response)
# ---------------生成图表
char_type = get_chart_type_from_sql_answer(llm_response)
if char_type:
sys_char_temp = char_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', sql=sql, chart_type=char_type)
user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question)
llm_response2 = self.submit_prompt(
[{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}], **kwargs)
print(llm_response2)
result['chart'] = orjson.loads(extract_nested_json(llm_response2))
return result
try:
question_sql_list = self.get_similar_question_sql(question, **kwargs)
ddl_list = self.get_related_ddl(question, **kwargs)
doc_list = self.get_related_documentation(question, **kwargs)
template = get_base_template()
sql_temp = template['template']['sql']
char_temp = template['template']['chart']
# --------基于提示词生成sql以及图表类型
sys_temp = sql_temp['system'].format(engine=config("DATA_SOURCE_TYPE", default='mysql'), lang='中文',
schema=ddl_list, documentation=doc_list,
data_training=question_sql_list)
print("sys_temp", sys_temp)
user_temp = sql_temp['user'].format(question=question,
current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
print("user_temp", user_temp)
llm_response = self.submit_prompt(
[{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs)
print(llm_response)
result = {"resp": orjson.loads(extract_nested_json(llm_response))}
print("result", result)
sql = check_and_get_sql(llm_response)
# ---------------生成图表
char_type = get_chart_type_from_sql_answer(llm_response)
if char_type:
sys_char_temp = char_temp['system'].format(engine=config("DATA_SOURCE_TYPE", default='mysql'),
lang='中文', sql=sql, chart_type=char_type)
user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question)
llm_response2 = self.submit_prompt(
[{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}],
**kwargs)
print(llm_response2)
result['chart'] = orjson.loads(extract_nested_json(llm_response2))
return result
except Exception as e:
raise e
def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
print("new_question---------------", new_question)