提示词优化,迟到,请假
This commit is contained in:
		@@ -26,7 +26,7 @@ class OpenAICompatibleLLM(VannaBase):
 | 
			
		||||
        VannaBase.__init__(self, config=config_file)
 | 
			
		||||
        # default parameters - can be overrided using config
 | 
			
		||||
        self.temperature = 0.6
 | 
			
		||||
        self.max_tokens = 5000
 | 
			
		||||
        self.max_tokens = 10000
 | 
			
		||||
 | 
			
		||||
        if "temperature" in config_file:
 | 
			
		||||
            self.temperature = config_file["temperature"]
 | 
			
		||||
@@ -125,17 +125,21 @@ class OpenAICompatibleLLM(VannaBase):
 | 
			
		||||
 | 
			
		||||
    def submit_prompt(self, prompt, **kwargs) -> str:
 | 
			
		||||
        if prompt is None:
 | 
			
		||||
            print("test1")
 | 
			
		||||
            raise Exception("Prompt is None")
 | 
			
		||||
 | 
			
		||||
        if len(prompt) == 0:
 | 
			
		||||
            print("test2")
 | 
			
		||||
            raise Exception("Prompt is empty")
 | 
			
		||||
        print(prompt)
 | 
			
		||||
 | 
			
		||||
        num_tokens = 0
 | 
			
		||||
        for message in prompt:
 | 
			
		||||
            num_tokens += len(message["content"]) / 4
 | 
			
		||||
        print("test3 {0}".format(num_tokens))
 | 
			
		||||
 | 
			
		||||
        if kwargs.get("model", None) is not None:
 | 
			
		||||
            print("test4")
 | 
			
		||||
            model = kwargs.get("model", None)
 | 
			
		||||
            print(
 | 
			
		||||
                f"Using model {model} for {num_tokens} tokens (approx)"
 | 
			
		||||
@@ -148,6 +152,7 @@ class OpenAICompatibleLLM(VannaBase):
 | 
			
		||||
                temperature=self.temperature,
 | 
			
		||||
            )
 | 
			
		||||
        elif kwargs.get("engine", None) is not None:
 | 
			
		||||
            print("test5")
 | 
			
		||||
            engine = kwargs.get("engine", None)
 | 
			
		||||
            print(
 | 
			
		||||
                f"Using model {engine} for {num_tokens} tokens (approx)"
 | 
			
		||||
@@ -160,6 +165,7 @@ class OpenAICompatibleLLM(VannaBase):
 | 
			
		||||
                temperature=self.temperature,
 | 
			
		||||
            )
 | 
			
		||||
        elif self.config is not None and "engine" in self.config:
 | 
			
		||||
            print("test6")
 | 
			
		||||
            print(
 | 
			
		||||
                f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)"
 | 
			
		||||
            )
 | 
			
		||||
@@ -171,10 +177,11 @@ class OpenAICompatibleLLM(VannaBase):
 | 
			
		||||
                temperature=self.temperature,
 | 
			
		||||
            )
 | 
			
		||||
        elif self.config is not None and "model" in self.config:
 | 
			
		||||
            print("test7")
 | 
			
		||||
            print(
 | 
			
		||||
                f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
 | 
			
		||||
            )
 | 
			
		||||
            print(self.config)
 | 
			
		||||
            print("config is ",self.config)
 | 
			
		||||
            response = self.client.chat.completions.create(
 | 
			
		||||
                model=self.config["model"],
 | 
			
		||||
                messages=prompt,
 | 
			
		||||
@@ -194,6 +201,7 @@ class OpenAICompatibleLLM(VannaBase):
 | 
			
		||||
            #     json=data
 | 
			
		||||
            # )
 | 
			
		||||
        else:
 | 
			
		||||
            print("test8")
 | 
			
		||||
            if num_tokens > 3500:
 | 
			
		||||
                model = "kimi"
 | 
			
		||||
            else:
 | 
			
		||||
@@ -271,15 +279,26 @@ class OpenAICompatibleLLM(VannaBase):
 | 
			
		||||
            traceback.print_exc()
 | 
			
		||||
            raise e
 | 
			
		||||
 | 
			
		||||
    def run_sql_2(self, sql):
        """Execute *sql* via ``self.run_sql``, logging the failing statement on error.

        Args:
            sql: the SQL statement to execute.

        Returns:
            Whatever ``self.run_sql`` returns for the statement.

        Raises:
            Exception: re-raises anything ``run_sql`` raises, after logging it.
        """
        try:
            return self.run_sql(sql)
        except Exception:
            # logger.exception records the traceback; lazy %-args avoid building
            # the message unless the record is actually emitted.
            logger.exception("run_sql failed %s", sql)
            # Bare raise preserves the original traceback ("raise e" would
            # append this frame and restart the traceback here).
            raise
 | 
			
		||||
 | 
			
		||||
    def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
        """Combine two consecutive user questions into a single question via the LLM.

        The diff residue in this file superimposed the old English system prompt
        on the new Chinese one; this is the reconstructed post-commit version
        (commit message: "提示词优化" / prompt optimization).

        Args:
            last_question: the previous question, or None if there is none.
            new_question: the follow-up question to merge.

        Returns:
            The combined question text produced by the LLM, or *new_question*
            unchanged when there is no previous question.
        """
        logger.info(f"generate_rewritten_question---------------{new_question}")
        # No prior question: nothing to merge, return the new one as-is.
        if last_question is None:
            return new_question

        print("last question {0}".format(last_question))
        print("new question {0}".format(new_question))
        prompt = [
            self.system_message(
                # NOTE(review): "问题域第一个问题" below looks like a typo for
                # "问题与第一个问题" — confirm before changing the prompt text.
                "你的目标是将一系列相关问题合并成一个单一的问题。"
                "合并准则一、如果第二个问题与第一个问题无关且本身是完整独立的,则直接返回第二个问题。"
                "合并准则二、如果第二个问题域第一个问题相关,且要基于第一个问题的前提,请合并两个问题为一个问题,只需返回合并后的新问题,不要添加任何额外解释。"
                "合并准则三、理论上,合并后的问题应该能够通过单个SQL语句来回答"),
            self.user_message("First question: " + last_question + "\nSecond question: " + new_question),
        ]

        return self.submit_prompt(prompt=prompt, **kwargs)
 | 
			
		||||
@@ -312,6 +331,11 @@ class CustomQdrant_VectorStore(Qdrant_VectorStore):
 | 
			
		||||
            "model": self.embedding_model_name,
 | 
			
		||||
            "sentences": [data],
 | 
			
		||||
        }
 | 
			
		||||
        # request_body = {
 | 
			
		||||
        #     "model": self.embedding_model_name,
 | 
			
		||||
        #     "encoding_format": "float",
 | 
			
		||||
        #     "input": [data],
 | 
			
		||||
        # }
 | 
			
		||||
        request_body.update(kwargs)
 | 
			
		||||
 | 
			
		||||
        response = requests.post(
 | 
			
		||||
@@ -325,7 +349,9 @@ class CustomQdrant_VectorStore(Qdrant_VectorStore):
 | 
			
		||||
            )
 | 
			
		||||
        result = response.json()
 | 
			
		||||
        embeddings = result['embeddings']
 | 
			
		||||
        # embeddings = result['data'][0]['embedding']
 | 
			
		||||
        return embeddings[0]
 | 
			
		||||
        # return embeddings
 | 
			
		||||
 | 
			
		||||
class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM):
 | 
			
		||||
    def __init__(self, llm_config=None, vector_store_config=None):
 | 
			
		||||
@@ -412,6 +438,64 @@ class TTLCacheWrapper:
 | 
			
		||||
                return False
 | 
			
		||||
        return self.get(id=id, field=field) is not None
 | 
			
		||||
 | 
			
		||||
    def items(self):
        """Return a snapshot of all non-expired cache entries.

        Returns:
            A list of dicts, one per live entry, each with keys ``id``,
            ``field``, ``value``, ``expires_at`` (absolute epoch seconds) and
            ``time_left`` (seconds remaining until expiry).
        """
        current_time = time.time()
        entries = []

        # Iterate the expiry bookkeeping map so stale entries can be filtered
        # without touching the underlying cache. ``entry_id`` instead of the
        # original ``id`` to avoid shadowing the builtin.
        for (entry_id, field), expiry in self._expiry_times.items():
            if current_time > expiry:
                continue  # already expired
            value = self.get(id=entry_id, field=field)
            # get() may still return None (e.g. evicted out-of-band); skip those.
            if value is None:
                continue
            entries.append({
                'id': entry_id,
                'field': field,
                'value': value,
                'expires_at': expiry,
                'time_left': expiry - current_time,
            })

        return entries
 | 
			
		||||
 | 
			
		||||
    def get_latest_by_id(self, id: str, limit: int = 1, field_filter: str = None):
        """Return the most recently created, non-expired cache entries for *id*.

        Args:
            id: the cache id to look up.
            limit: maximum number of entries to return (default 1).
            field_filter: optional; when given, only entries for this field
                are considered.

        Returns:
            A list of entry dicts (``id``, ``field``, ``value``,
            ``expires_at``, ``created_time``, ``time_left``) sorted by expiry
            time, newest first.
        """
        now = time.time()
        candidates = []

        # Collect every live entry under this id, honouring the field filter.
        for (entry_id, entry_field), expires_at in self._expiry_times.items():
            if entry_id != id or now > expires_at:
                continue
            if field_filter and entry_field != field_filter:
                continue

            cached = self.get(id=id, field=entry_field)
            if cached is None:
                continue
            candidates.append({
                'id': id,
                'field': entry_field,
                'value': cached,
                'expires_at': expires_at,
                'created_time': expires_at - self.ttl,  # approximate: expiry minus TTL
                'time_left': expires_at - now,
            })

        # Newest first: later expiry implies later creation for a fixed TTL.
        candidates.sort(key=lambda entry: entry['expires_at'], reverse=True)

        return candidates[:limit]
 | 
			
		||||
 | 
			
		||||
    # Anything not defined on the wrapper falls through to the wrapped cache.
    def __getattr__(self, name):
        """Delegate unknown attribute lookups to the underlying cache object."""
        wrapped = self.cache
        return getattr(wrapped, name)
 | 
			
		||||
		Reference in New Issue
	
	Block a user