提示词优化,迟到,请假

This commit is contained in:
yujj128
2025-11-01 10:16:06 +08:00
parent e72b24e7f7
commit 78fb9f4d54
7 changed files with 498 additions and 111 deletions

View File

@@ -26,7 +26,7 @@ class OpenAICompatibleLLM(VannaBase):
VannaBase.__init__(self, config=config_file)
# default parameters - can be overrided using config
self.temperature = 0.6
self.max_tokens = 5000
self.max_tokens = 10000
if "temperature" in config_file:
self.temperature = config_file["temperature"]
@@ -125,17 +125,21 @@ class OpenAICompatibleLLM(VannaBase):
def submit_prompt(self, prompt, **kwargs) -> str:
if prompt is None:
print("test1")
raise Exception("Prompt is None")
if len(prompt) == 0:
print("test2")
raise Exception("Prompt is empty")
print(prompt)
num_tokens = 0
for message in prompt:
num_tokens += len(message["content"]) / 4
print("test3 {0}".format(num_tokens))
if kwargs.get("model", None) is not None:
print("test4")
model = kwargs.get("model", None)
print(
f"Using model {model} for {num_tokens} tokens (approx)"
@@ -148,6 +152,7 @@ class OpenAICompatibleLLM(VannaBase):
temperature=self.temperature,
)
elif kwargs.get("engine", None) is not None:
print("test5")
engine = kwargs.get("engine", None)
print(
f"Using model {engine} for {num_tokens} tokens (approx)"
@@ -160,6 +165,7 @@ class OpenAICompatibleLLM(VannaBase):
temperature=self.temperature,
)
elif self.config is not None and "engine" in self.config:
print("test6")
print(
f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)"
)
@@ -171,10 +177,11 @@ class OpenAICompatibleLLM(VannaBase):
temperature=self.temperature,
)
elif self.config is not None and "model" in self.config:
print("test7")
print(
f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
)
print(self.config)
print("config is ",self.config)
response = self.client.chat.completions.create(
model=self.config["model"],
messages=prompt,
@@ -194,6 +201,7 @@ class OpenAICompatibleLLM(VannaBase):
# json=data
# )
else:
print("test8")
if num_tokens > 3500:
model = "kimi"
else:
@@ -271,15 +279,26 @@ class OpenAICompatibleLLM(VannaBase):
traceback.print_exc()
raise e
def run_sql_2(self,sql):
try:
return self.run_sql(sql)
except Exception as e:
logger.error("run_sql failed {0}".format(sql))
raise e
def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
logger.info(f"generate_rewritten_question---------------{new_question}")
if last_question is None:
return new_question
print("last question {0}".format(last_question))
print("new question {0}".format(new_question))
prompt = [
self.system_message(
"Your goal is to combine a sequence of questions into a singular question if they are related. If the second question does not relate to the first question and is fully self-contained, return the second question. Return just the new combined question with no additional explanations. The question should theoretically be answerable with a single SQL statement."),
self.user_message(new_question),
"你的目标是将一系列相关问题合并成一个单一的问题。"
"合并准则一、如果第二个问题与第一个问题无关且本身是完整独立的,则直接返回第二个问题。"
"合并准则二、如果第二个问题域第一个问题相关,且要基于第一个问题的前提,请合并两个问题为一个问题,只需返回合并后的新问题,不要添加任何额外解释。"
"合并准则三、理论上合并后的问题应该能够通过单个SQL语句来回答"),
self.user_message("First question: " + last_question + "\nSecond question: " + new_question),
]
return self.submit_prompt(prompt=prompt, **kwargs)
@@ -312,6 +331,11 @@ class CustomQdrant_VectorStore(Qdrant_VectorStore):
"model": self.embedding_model_name,
"sentences": [data],
}
# request_body = {
# "model": self.embedding_model_name,
# "encoding_format": "float",
# "input": [data],
# }
request_body.update(kwargs)
response = requests.post(
@@ -325,7 +349,9 @@ class CustomQdrant_VectorStore(Qdrant_VectorStore):
)
result = response.json()
embeddings = result['embeddings']
# embeddings = result['data'][0]['embedding']
return embeddings[0]
# return embeddings
class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM):
def __init__(self, llm_config=None, vector_store_config=None):
@@ -412,6 +438,64 @@ class TTLCacheWrapper:
return False
return self.get(id=id, field=field) is not None
def items(self):
"""遍历所有未过期的缓存键值对"""
current_time = time.time()
items = []
for (id, field), expiry in self._expiry_times.items():
# 检查是否过期
if current_time <= expiry:
value = self.get(id=id, field=field)
if value is not None:
items.append({
'id': id,
'field': field,
'value': value,
'expires_at': expiry,
'time_left': expiry - current_time
})
return items
def get_latest_by_id(self, id: str, limit: int = 1, field_filter: str = None):
"""
获取指定ID下时间最近的缓存项
Args:
id: 要查询的ID
limit: 返回最近几条记录默认1条
field_filter: 可选的字段过滤,如只获取特定字段
Returns:
按时间倒序排列的缓存项列表
"""
current_time = time.time()
matched_items = []
# 找出该ID下所有未过期的缓存项
for (cache_id, field), expiry in self._expiry_times.items():
if cache_id == id and current_time <= expiry:
# 字段过滤
if field_filter and field != field_filter:
continue
value = self.get(id=id, field=field)
if value is not None:
matched_items.append({
'id': id,
'field': field,
'value': value,
'expires_at': expiry,
'created_time': expiry - self.ttl, # 估算创建时间
'time_left': expiry - current_time
})
# 按过期时间倒序排列(最近创建的排在前面)
matched_items.sort(key=lambda x: x['expires_at'], reverse=True)
return matched_items[:limit]
# 代理其他方法到原始cache
def __getattr__(self, name):
return getattr(self.cache, name)