501 lines
		
	
	
		
			19 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			501 lines
		
	
	
		
			19 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
 | 
						||
from typing import List, Union, Any, Optional
 | 
						||
import time
 | 
						||
import threading
 | 
						||
from vanna.flask import Cache, MemoryCache
 | 
						||
import  dmPython
 | 
						||
import orjson
 | 
						||
import pandas as pd
 | 
						||
from vanna.base import VannaBase
 | 
						||
from vanna.flask import MemoryCache
 | 
						||
from vanna.qdrant import Qdrant_VectorStore
 | 
						||
from qdrant_client import QdrantClient
 | 
						||
from openai import OpenAI
 | 
						||
import requests
 | 
						||
from decouple import config
 | 
						||
from util.utils import extract_nested_json, check_and_get_sql, get_chart_type_from_sql_answer
 | 
						||
import json
 | 
						||
from template.template import get_base_template
 | 
						||
from datetime import datetime
 | 
						||
import logging
 | 
						||
from util import  train_ddl
 | 
						||
logger = logging.getLogger(__name__)
 | 
						||
import traceback
 | 
						||
class OpenAICompatibleLLM(VannaBase):
    """Vanna LLM backend for any endpoint that speaks the OpenAI chat API.

    On top of the message/prompt plumbing it can attach a Dameng database
    connection (``connect_to_dameng``) and offers helpers that ask the model
    for SQL plus chart JSON (``generate_sql_2``) and that merge a follow-up
    question with the previous one (``generate_rewritten_question``).
    """

    def __init__(self, client=None, config_file=None):
        """Configure sampling parameters and the OpenAI client.

        Args:
            client: Ready-made client object. When given it is used as-is and
                ``api_base`` / ``api_key`` are not required.
            config_file: Optional mapping. Recognised keys: ``temperature``,
                ``max_tokens``, ``api_base``, ``api_key``. ``None`` is treated
                as an empty mapping.

        Raises:
            Exception: If the deprecated ``api_type`` / ``api_version`` keys
                are present, or if no client is given and ``api_base`` or
                ``api_key`` is missing.
        """
        VannaBase.__init__(self, config=config_file)
        # The signature defaults to None, so membership tests below must not
        # run against None.
        settings = config_file if config_file is not None else {}

        # Default parameters - can be overridden through the config mapping.
        self.temperature = settings["temperature"] if "temperature" in settings else 0.6
        self.max_tokens = settings["max_tokens"] if "max_tokens" in settings else 10000

        if "api_type" in settings:
            raise Exception(
                "Passing api_type is now deprecated. Please pass an OpenAI client instead."
            )

        if "api_version" in settings:
            raise Exception(
                "Passing api_version is now deprecated. Please pass an OpenAI client instead."
            )

        if client is not None:
            self.client = client
            return

        if "api_base" not in settings:
            raise Exception("Please passing api_base")

        if "api_key" not in settings:
            raise Exception("Please passing api_key")

        self.client = OpenAI(api_key=settings["api_key"], base_url=settings["api_base"])

    def system_message(self, message: str) -> Any:
        """Wrap *message* as a chat message with the ``system`` role."""
        return {"role": "system", "content": message}

    def connect_to_dameng(
            self,
            host: str = None,
            dbname: str = None,
            user: str = None,
            password: str = None,
            port: int = None,
            **kwargs
    ):
        """Connect to a Dameng database and install ``self.run_sql``.

        Args:
            host: Server address, passed to ``dmPython.connect`` as ``server``.
            dbname: Accepted for interface compatibility; not used here.
            user: Login user.
            password: Login password.
            port: Server port.
            **kwargs: Accepted for interface compatibility; not used here.

        Raises:
            Exception: If the initial connection attempt fails.
        """

        def open_connection():
            return dmPython.connect(user=user, password=password, server=host, port=port)

        self.conn = None
        try:
            self.conn = open_connection()
        except Exception as e:
            raise Exception(f"Failed to connect to dameng database: {e}") from e

        def is_connection_alive(conn) -> bool:
            """Return True when a trivial probe query succeeds on *conn*."""
            if conn is None:
                return False
            try:
                cursor = conn.cursor()
                try:
                    cursor.execute("SELECT 1 FROM DUAL")
                finally:
                    # Close the probe cursor even when the probe query fails.
                    cursor.close()
                return True
            except Exception:
                return False

        def reconnect():
            """Replace ``self.conn`` with a freshly opened connection."""
            try:
                self.conn = open_connection()
            except Exception as e:
                raise Exception(f"reconnect failed: {e}") from e

        def run_sql_damengsql(sql: str) -> Union[pd.DataFrame, None]:
            """Execute *sql* and return all fetched rows as a DataFrame."""
            logger.info("start to run_sql_damengsql")
            cursor = None
            try:
                if not is_connection_alive(conn=self.conn):
                    logger.info("connection is not alive, reconnecting..........")
                    reconnect()
                cursor = self.conn.cursor()
                cursor.execute(sql)
                results = cursor.fetchall()

                # Column names come from the cursor description.
                return pd.DataFrame(
                    results, columns=[desc[0] for desc in cursor.description]
                )
            except Exception as e:
                logger.error(f"Failed to execute sql query: {e}")
                # A failing rollback (e.g. on a dead connection) must not hide
                # the error that brought us here.
                try:
                    if self.conn is not None:
                        self.conn.rollback()
                except Exception:
                    logger.exception("rollback after failed sql query also failed")
                raise
            finally:
                if cursor is not None:
                    try:
                        cursor.close()
                    except Exception:
                        logger.debug("failed to close cursor", exc_info=True)

        self.run_sql_is_set = True
        self.run_sql = run_sql_damengsql

    def user_message(self, message: str) -> Any:
        """Wrap *message* as a chat message with the ``user`` role."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> Any:
        """Wrap *message* as a chat message with the ``assistant`` role."""
        return {"role": "assistant", "content": message}

    def _resolve_model(self, num_tokens, **kwargs):
        """Pick the model name for a request.

        Precedence: ``model`` kwarg, ``engine`` kwarg, ``engine`` in
        ``self.config``, ``model`` in ``self.config``, then a size-based
        fallback ("kimi" above 3500 estimated tokens, otherwise "doubao").
        """
        if kwargs.get("model", None) is not None:
            return kwargs["model"]
        if kwargs.get("engine", None) is not None:
            return kwargs["engine"]

        settings = self.config if self.config is not None else {}
        if "engine" in settings:
            return settings["engine"]
        if "model" in settings:
            return settings["model"]

        return "kimi" if num_tokens > 3500 else "doubao"

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send a chat prompt to the model and return the reply text.

        Args:
            prompt: Non-empty list of chat messages (dicts with ``content``).
            **kwargs: ``model`` or ``engine`` may override the configured model.

        Returns:
            The content of the first choice's message.

        Raises:
            Exception: If *prompt* is ``None`` or empty.
        """
        if prompt is None:
            raise Exception("Prompt is None")

        if len(prompt) == 0:
            raise Exception("Prompt is empty")

        # Prompts can carry schema and user data, so keep them at debug level
        # and never dump self.config (it may contain the API key).
        logger.debug("prompt: %s", prompt)

        # Rough size estimate: about four characters per token.
        num_tokens = sum(len(message["content"]) for message in prompt) / 4

        model = self._resolve_model(num_tokens, **kwargs)
        logger.info(f"Using model {model} for {num_tokens} tokens (approx)")

        # An "engine" setting is sent as ``model``: the chat-completions call of
        # the OpenAI v1 client has no ``engine`` parameter.
        response = self.client.chat.completions.create(
            model=model,
            messages=prompt,
            max_tokens=self.max_tokens,
            stop=None,
            temperature=self.temperature,
        )

        # NOTE(review): kept from the original; for the OpenAI client's choice
        # objects this membership test is presumably never true - confirm.
        for choice in response.choices:
            if "text" in choice:
                return choice.text

        return response.choices[0].message.content

    def generate_sql_2(self, question: str, cache=None,user_id=None, allow_llm_to_see_data=False, **kwargs) -> dict:
        """Ask the model for SQL and, when a chart type is found, chart JSON.

        Args:
            question: Natural-language question.
            cache: Optional cache; with *user_id* it supplies the ``data``
                field that is fed to the prompt as history.
            user_id: Cache id under which the history is stored.
            allow_llm_to_see_data: Accepted for interface compatibility; not
                used here.
            **kwargs: Forwarded to the retrieval helpers and ``submit_prompt``.

        Returns:
            ``{"resp": <parsed model JSON>}`` plus a ``"chart"`` entry when a
            chart type was detected in the model's answer.

        Raises:
            Exception: Any failure is logged with its traceback and re-raised.
        """
        try:
            logger.info("Start to generate_sql_2 in cus_vanna_srevice")
            question_sql_list = self.get_similar_question_sql(question, **kwargs)
            # Only the first two retrieved examples go into the prompt.
            if question_sql_list and len(question_sql_list) > 2:
                question_sql_list = question_sql_list[:2]

            ddl_list = self.get_related_ddl(question, **kwargs)
            template = get_base_template()
            sql_temp = template['template']['sql']
            char_temp = template['template']['chart']
            history = None
            if user_id and cache:
                history = cache.get(id=user_id, field="data")

            db_engine = config("DB_ENGINE", default='mysql')

            # Build the prompt that asks for the SQL and the chart type.
            sys_temp = sql_temp['system'].format(engine=db_engine, lang='中文',
                                                 schema=ddl_list, documentation=[train_ddl.train_document],
                                                 history=history, retrieved_examples_data=question_sql_list,
                                                 data_training=question_sql_list,)

            user_temp = sql_temp['user'].format(question=question,
                                                current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
            logger.info(f"user_temp:{user_temp}")
            logger.info(f"sys_temp:{sys_temp}")
            llm_response = self.submit_prompt(
                [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs)
            llm_response = str(llm_response.strip())
            logger.info(f"llm_response:{llm_response}")

            # Work in progress upstream: parse the JSON embedded in the answer.
            result = extract_nested_json(llm_response)
            logger.info(f"result:{result}")
            result = {"resp": orjson.loads(result)}
            sql = check_and_get_sql(llm_response)
            logger.info(f"sql:{sql}")

            # Second round trip: generate the chart definition.
            char_type = get_chart_type_from_sql_answer(llm_response)
            logger.info(f"chart type:{char_type}")
            if char_type:
                sys_char_temp = char_temp['system'].format(engine=db_engine,
                                                           lang='中文', sql=sql, chart_type=char_type)
                user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question)
                llm_response2 = self.submit_prompt(
                    [{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}],
                    **kwargs)
                logger.info(f"llm_response2:{llm_response2}")
                result['chart'] = orjson.loads(extract_nested_json(llm_response2))
                logger.info(f"chart_response:{result}")

            logger.info("Finish to generate_sql_2 in cus_vanna_srevice")
            return result
        except Exception:
            # logger.exception records the traceback through the logging setup.
            logger.exception("cus_vanna_srevice failed-------------------: ")
            raise

    def run_sql_2(self,sql):
        """Run *sql* through ``self.run_sql``; log the statement on failure and re-raise."""
        try:
            return self.run_sql(sql)
        except Exception:
            logger.error("run_sql failed {0}".format(sql))
            raise

    def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
        """Merge a follow-up question with the previous one via the model.

        Args:
            last_question: Previous question, or ``None`` when there is none.
            new_question: The question just asked.
            **kwargs: Forwarded to ``submit_prompt``.

        Returns:
            *new_question* unchanged when *last_question* is ``None``,
            otherwise the model's reply.
        """
        logger.info(f"generate_rewritten_question---------------{new_question}")
        if last_question is None:
            return new_question
        logger.info("last question {0}".format(last_question))
        logger.info("new question {0}".format(new_question))
        prompt = [
            self.system_message(
                "你的目标是将一系列相关问题合并成一个单一的问题。"
                "合并准则一、如果第二个问题与第一个问题无关且本身是完整独立的,则直接返回第二个问题。"
                "合并准则二、如果第二个问题域第一个问题相关,且要基于第一个问题的前提,请合并两个问题为一个问题,只需返回合并后的新问题,不要添加任何额外解释。"
                "合并准则三、理论上,合并后的问题应该能够通过单个SQL语句来回答"),
            self.user_message("First question: " + last_question + "\nSecond question: " + new_question),
        ]

        return self.submit_prompt(prompt=prompt, **kwargs)
 | 
						||
 | 
						||
 | 
						||
class CustomQdrant_VectorStore(Qdrant_VectorStore):
    """Qdrant vector store whose embeddings come from an HTTP embedding service.

    The service URL, model name and API key are read from the
    ``EMBEDDING_MODEL_BASE_URL``, ``EMBEDDING_MODEL_NAME`` and
    ``EMBEDDING_MODEL_API_KEY`` settings.
    """

    # Seconds to wait for the embedding service before giving up; without a
    # timeout a stalled service would block the caller indefinitely.
    embedding_timeout = 60

    def __init__(
            self,
            config_file=None
    ):
        """Read the embedding settings and initialise the parent store.

        Args:
            config_file: Mapping forwarded to ``Qdrant_VectorStore``. ``None``
                is replaced by a fresh empty dict, so no dict object is shared
                between instances.
        """
        self.embedding_model_name = config('EMBEDDING_MODEL_NAME', default='')
        self.embedding_api_base = config('EMBEDDING_MODEL_BASE_URL', default='')
        self.embedding_api_key = config('EMBEDDING_MODEL_API_KEY', default='')
        super().__init__({} if config_file is None else config_file)

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        """Return the embedding vector of *data* from the embedding service.

        Args:
            data: Text to embed.
            **kwargs: Extra fields merged into the JSON request body.

        Returns:
            The first entry of the ``embeddings`` list in the response.

        Raises:
            RuntimeError: If the service answers with a non-200 status.
        """

        def _get_error_string(response: requests.Response) -> str:
            """Extract a readable error description from a failed response."""
            try:
                if response.content:
                    return response.json()["detail"]
            except (ValueError, KeyError, TypeError):
                # Best effort: the body is not JSON or has no "detail" field.
                pass
            try:
                response.raise_for_status()
            except requests.HTTPError as e:
                return str(e)
            return "Unknown error"

        request_body = {
            "model": self.embedding_model_name,
            "sentences": [data],
        }
        request_body.update(kwargs)

        response = requests.post(
            url=f"{self.embedding_api_base}",
            json=request_body,
            headers={"Authorization": f"Bearer {self.embedding_api_key}", 'Content-Type': 'application/json'},
            timeout=self.embedding_timeout,
        )
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to create the embeddings, detail: {_get_error_string(response)}"
            )
        result = response.json()
        # One sentence was sent, so the first embedding is the answer.
        embeddings = result['embeddings']
        return embeddings[0]
 | 
						||
 | 
						||
class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM):
    """Vanna assembly: custom Qdrant vector store plus OpenAI-compatible LLM."""

    def __init__(self, llm_config=None, vector_store_config=None):
        """Initialise both bases with their own configuration.

        Args:
            llm_config: Mapping for ``OpenAICompatibleLLM``.
            vector_store_config: Mapping for ``CustomQdrant_VectorStore``.

        ``None`` is replaced by an empty dict for either argument, because
        the base initialisers run membership checks on the mapping they get.
        """
        store_settings = {} if vector_store_config is None else vector_store_config
        llm_settings = {} if llm_config is None else llm_config
        CustomQdrant_VectorStore.__init__(self, config_file=store_settings)
        OpenAICompatibleLLM.__init__(self, config_file=llm_settings)
 | 
						||
 | 
						||
class TTLCacheWrapper:
    """Add time-to-live expiry to a cache object to keep memory bounded.

    The wrapped cache must offer ``set(id, field, value)`` and
    ``get(id, field)``; ``delete(id)`` is used when available. The wrapped
    cache is only ever asked to delete by id, so when one field of an id
    expires (or is deleted) the whole id is evicted and the expiry records of
    all its fields are dropped with it.

    A daemon thread removes expired entries periodically. Bookkeeping is
    guarded by a lock because that thread shares it with the callers.
    Attributes not defined here are looked up on the wrapped cache.
    """

    def __init__(self, cache: Optional["Cache"] = None, ttl: int = 3600,
                 cleanup_interval: float = 180):
        """Wrap *cache* and start the background cleanup thread.

        Args:
            cache: Cache to wrap; a new ``MemoryCache`` when ``None``.
            ttl: Default lifetime of an entry in seconds.
            cleanup_interval: Seconds between two cleanup passes.
        """
        self.cache = cache if cache is not None else MemoryCache()
        self.ttl = ttl
        self.cleanup_interval = cleanup_interval
        # (id, field) -> absolute expiry timestamp
        self._expiry_times = {}
        # (id, field) -> timestamp of the last set()
        self._created_times = {}
        # Re-entrant so that public methods may call each other.
        self._lock = threading.RLock()
        self._cleanup_thread = None
        self._start_cleanup()

    def _evict(self, key_info):
        """Remove *key_info* from the wrapped cache and the bookkeeping.

        Must be called with ``self._lock`` held.
        """
        id, _field = key_info
        if hasattr(self.cache, 'delete'):
            self.cache.delete(id=id)
            # The delete call was by id only, so forget every record of that
            # id; a stale record would later evict freshly written data.
            doomed = [key for key in self._expiry_times if key[0] == id]
        else:
            doomed = [key_info]
        for key in doomed:
            self._expiry_times.pop(key, None)
            self._created_times.pop(key, None)

    def _purge_expired(self):
        """Evict every entry whose expiry time has been reached."""
        now = time.time()
        with self._lock:
            expired_keys = [
                key_info
                for key_info, expiry in self._expiry_times.items()
                if expiry <= now
            ]
            for key_info in expired_keys:
                # A sibling field's eviction may already have removed it.
                if key_info in self._expiry_times:
                    self._evict(key_info)

    def _start_cleanup(self):
        """Start the background cleanup thread."""

        def cleanup():
            while True:
                try:
                    self._purge_expired()
                except Exception:
                    # Keep the thread alive; a dead cleanup thread would bring
                    # back the unbounded growth this class exists to prevent.
                    logging.getLogger(__name__).exception("TTL cache cleanup failed")
                time.sleep(self.cleanup_interval)

        self._cleanup_thread = threading.Thread(target=cleanup, daemon=True)
        self._cleanup_thread.start()

    def set(self, id: str, field: str, value: Any, ttl: Optional[int] = None):
        """Store *value* under ``(id, field)``.

        Args:
            id: Cache id.
            field: Field name below the id.
            value: Value to store.
            ttl: Lifetime in seconds; the wrapper's default when ``None``.
        """
        actual_ttl = ttl if ttl is not None else self.ttl
        now = time.time()
        key_info = (id, field)
        with self._lock:
            self.cache.set(id=id, field=field, value=value)
            self._created_times[key_info] = now
            self._expiry_times[key_info] = now + actual_ttl

    def get(self, id: str, field: str) -> Any:
        """Return the cached value, or ``None`` when it has expired.

        An expired entry is evicted as a side effect.
        """
        key_info = (id, field)
        with self._lock:
            expiry = self._expiry_times.get(key_info)
            if expiry is not None and time.time() > expiry:
                self._evict(key_info)
                return None
            return self.cache.get(id=id, field=field)

    def delete(self, id: str, field: str):
        """Delete the entry; see the class docstring for per-id eviction."""
        with self._lock:
            self._evict((id, field))

    def exists(self, id: str, field: str) -> bool:
        """Return True when a non-``None``, unexpired value is cached."""
        # get() already evicts an expired entry and then returns None.
        return self.get(id=id, field=field) is not None

    def items(self):
        """Return a list of dicts describing every unexpired entry.

        Each dict has the keys ``id``, ``field``, ``value``, ``expires_at``
        and ``time_left``. Entries whose value is ``None`` are skipped.
        """
        current_time = time.time()
        # Iterate over a snapshot: get() may evict entries while we loop.
        with self._lock:
            snapshot = list(self._expiry_times.items())

        items = []
        for (id, field), expiry in snapshot:
            if current_time > expiry:
                continue
            value = self.get(id=id, field=field)
            if value is not None:
                items.append({
                    'id': id,
                    'field': field,
                    'value': value,
                    'expires_at': expiry,
                    'time_left': expiry - current_time
                })

        return items

    def get_latest_by_id(self, id: str, limit: int = 1, field_filter: str = None):
        """Return the most recently written unexpired entries of one id.

        Args:
            id: Cache id to look up.
            limit: Maximum number of entries to return (default 1).
            field_filter: When set, only this field is considered.

        Returns:
            List of dicts (keys ``id``, ``field``, ``value``, ``expires_at``,
            ``created_time``, ``time_left``), newest first.
        """
        current_time = time.time()
        with self._lock:
            snapshot = [
                (field, expiry, self._created_times.get((cache_id, field), expiry - self.ttl))
                for (cache_id, field), expiry in self._expiry_times.items()
                if cache_id == id
            ]

        matched_items = []
        for field, expiry, created in snapshot:
            if current_time > expiry:
                continue
            if field_filter and field != field_filter:
                continue

            value = self.get(id=id, field=field)
            if value is not None:
                matched_items.append({
                    'id': id,
                    'field': field,
                    'value': value,
                    'expires_at': expiry,
                    # Recorded at set(); correct even with a per-entry ttl.
                    'created_time': created,
                    'time_left': expiry - current_time
                })

        # Newest first. Sorting by expiry would misorder entries that were
        # stored with different ttl values.
        matched_items.sort(key=lambda item: item['created_time'], reverse=True)

        return matched_items[:limit]

    def __getattr__(self, name):
        """Delegate unknown attributes to the wrapped cache."""
        # Read "cache" straight from the instance dict: going through
        # self.cache would re-enter __getattr__ forever when it is not set.
        try:
            wrapped = self.__dict__["cache"]
        except KeyError:
            raise AttributeError(name) from None
        return getattr(wrapped, name)