2025-10-13 18:18:58 +08:00
from dataclasses import field
2025-09-23 14:49:00 +08:00
from email . policy import default
2025-10-10 16:39:59 +08:00
from typing import List , Union , Any , Optional
import time
import threading
from vanna . flask import Cache , MemoryCache
2025-09-24 14:17:07 +08:00
import dmPython
2025-09-23 14:49:00 +08:00
import orjson
2025-09-23 20:06:32 +08:00
import pandas as pd
2025-09-23 14:49:00 +08:00
from vanna . base import VannaBase
2025-09-23 17:07:47 +08:00
from vanna . flask import MemoryCache
2025-09-23 14:49:00 +08:00
from vanna . qdrant import Qdrant_VectorStore
from qdrant_client import QdrantClient
from openai import OpenAI
import requests
from decouple import config
from util . utils import extract_nested_json , check_and_get_sql , get_chart_type_from_sql_answer
import json
from template . template import get_base_template
from datetime import datetime
2025-09-25 08:35:49 +08:00
import logging
2025-09-26 16:05:07 +08:00
from util import train_ddl
2025-09-25 08:35:49 +08:00
logger = logging . getLogger ( __name__ )
2025-09-28 16:44:58 +08:00
import traceback
2025-09-23 14:49:00 +08:00
class OpenAICompatibleLLM(VannaBase):
    """Vanna LLM backend for an OpenAI-compatible chat-completions endpoint.

    Besides the standard Vanna message helpers it provides a Dameng
    (``dmPython``) SQL runner and a two-stage "generate SQL, then chart"
    flow driven by the project's prompt templates.
    """

    def __init__(self, client=None, config_file=None):
        """Configure sampling parameters and the OpenAI client.

        Args:
            client: Optional pre-built OpenAI client. When given, ``api_base``
                and ``api_key`` are not required.
            config_file: Optional settings dict. Keys read here:
                ``temperature``, ``max_tokens``, ``api_base``, ``api_key``.
                ``None`` is treated as an empty dict.

        Raises:
            Exception: if the deprecated ``api_type``/``api_version`` keys are
                present, or if no client is given and ``api_base``/``api_key``
                is missing.
        """
        VannaBase.__init__(self, config=config_file)
        # Without this, `"x" in None` below raises an unhelpful TypeError
        # instead of the explicit "Please passing ..." errors.
        options = config_file if config_file is not None else {}

        # default parameters - can be overridden using config
        self.temperature = options["temperature"] if "temperature" in options else 0.5
        self.max_tokens = options["max_tokens"] if "max_tokens" in options else 5000

        if "api_type" in options:
            raise Exception(
                "Passing api_type is now deprecated. Please pass an OpenAI client instead."
            )
        if "api_version" in options:
            raise Exception(
                "Passing api_version is now deprecated. Please pass an OpenAI client instead."
            )

        if client is not None:
            self.client = client
            return

        if "api_base" not in options:
            raise Exception("Please passing api_base")
        if "api_key" not in options:
            raise Exception("Please passing api_key")
        self.client = OpenAI(api_key=options["api_key"], base_url=options["api_base"])

    def system_message(self, message: str) -> dict:
        """Wrap ``message`` as a system chat message."""
        return {"role": "system", "content": message}

    def connect_to_dameng(
        self,
        host: str = None,
        dbname: str = None,
        user: str = None,
        password: str = None,
        port: int = None,
        **kwargs
    ):
        """Open a Dameng connection and install ``self.run_sql`` for it.

        ``dbname`` and ``kwargs`` are accepted for interface compatibility but
        are not passed to ``dmPython.connect``.

        The installed ``run_sql(sql)`` checks the connection before each query
        and reconnects (with the same credentials) if it is no longer usable.

        Raises:
            Exception: if the initial connection cannot be established.
        """
        self.conn = None

        def _open():
            # Single place for the connection parameters used by both the
            # initial connect and reconnect().
            return dmPython.connect(user=user, password=password, server=host, port=port)

        try:
            self.conn = _open()
        except Exception as e:
            raise Exception(f"Failed to connect to dameng database: {e}") from e

        def is_connection_alive(conn) -> bool:
            """Return True if ``conn`` can run a trivial query."""
            if conn is None:
                return False
            cursor = None
            try:
                cursor = conn.cursor()
                cursor.execute("SELECT 1 FROM DUAL")
                return True
            except Exception:
                return False
            finally:
                if cursor is not None:
                    try:
                        cursor.close()
                    except Exception:
                        pass  # probe cursor only; nothing useful to report

        def reconnect():
            """Close the stale connection (best-effort) and open a new one."""
            stale, self.conn = self.conn, None
            if stale is not None:
                try:
                    stale.close()
                except Exception:
                    pass  # already broken; closing only releases the handle
            try:
                self.conn = _open()
            except Exception as e:
                raise Exception(f"reconnect failed: {e}") from e

        def run_sql_damengsql(sql: str) -> Union[pd.DataFrame, None]:
            """Execute ``sql`` and return its result set as a DataFrame.

            Re-raises any execution error after a best-effort rollback.
            """
            logger.info("start to run_sql_damengsql")
            cs = None
            try:
                if not is_connection_alive(conn=self.conn):
                    logger.info("connection is not alive, reconnecting..........")
                    reconnect()
                cs = self.conn.cursor()
                cs.execute(sql)
                results = cs.fetchall()
                # Create a pandas dataframe from the results
                return pd.DataFrame(
                    results, columns=[desc[0] for desc in cs.description]
                )
            except Exception as e:
                # A failing rollback (e.g. the connection itself died) must not
                # replace the original error the caller needs to see.
                if self.conn is not None:
                    try:
                        self.conn.rollback()
                    except Exception:
                        logger.warning("rollback failed", exc_info=True)
                logger.error(f"Failed to execute sql query: {e}")
                raise
            finally:
                if cs is not None:
                    try:
                        cs.close()
                    except Exception:
                        pass

        self.run_sql_is_set = True
        self.run_sql = run_sql_damengsql

    def user_message(self, message: str) -> dict:
        """Wrap ``message`` as a user chat message."""
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        """Wrap ``message`` as an assistant chat message."""
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        """Send ``prompt`` to the chat-completions endpoint and return the reply.

        The model is chosen, in order, from ``kwargs['model']``,
        ``kwargs['engine']``, ``self.config['engine']``, ``self.config['model']``;
        otherwise ``"kimi"`` is used for prompts estimated above 3500 tokens
        and ``"doubao"`` for smaller ones.

        Args:
            prompt: Non-empty list of ``{"role": ..., "content": ...}`` dicts.
            **kwargs: May carry ``model`` or ``engine`` overrides.

        Returns:
            The reply text of the first choice.

        Raises:
            Exception: if ``prompt`` is None or empty.
        """
        if prompt is None:
            raise Exception("Prompt is None")
        if len(prompt) == 0:
            raise Exception("Prompt is empty")
        logger.debug(f"prompt: {prompt}")

        # Rough size estimate (~4 characters per token); used for logging and
        # for the fallback model choice.
        num_tokens = sum(len(message["content"]) / 4 for message in prompt)

        if kwargs.get("model", None) is not None:
            param_name, param_value = "model", kwargs["model"]
        elif kwargs.get("engine", None) is not None:
            # NOTE(review): openai>=1.0 `chat.completions.create` may not accept
            # `engine`; confirm the client in use supports it.
            param_name, param_value = "engine", kwargs["engine"]
        elif self.config is not None and "engine" in self.config:
            param_name, param_value = "engine", self.config["engine"]
        elif self.config is not None and "model" in self.config:
            param_name, param_value = "model", self.config["model"]
        else:
            param_name, param_value = "model", ("kimi" if num_tokens > 3500 else "doubao")

        logger.info(f"Using {param_name} {param_value} for {num_tokens} tokens (approx)")
        response = self.client.chat.completions.create(
            messages=prompt,
            max_tokens=self.max_tokens,
            stop=None,
            temperature=self.temperature,
            **{param_name: param_value},
        )

        # Completion-style choices expose `text`; chat choices carry the reply
        # in `message.content`.
        for choice in response.choices:
            if "text" in choice:
                return choice.text
        return response.choices[0].message.content

    def generate_sql_2(self, question: str, cache=None, user_id=None, allow_llm_to_see_data=False, **kwargs) -> dict:
        """Generate SQL (and optionally a chart spec) for ``question``.

        Builds the system prompt from retrieved DDL, at most two similar
        question/SQL examples, the project documentation and, when both
        ``user_id`` and ``cache`` are given, the cached ``"data"`` history.

        ``allow_llm_to_see_data`` is accepted for interface compatibility
        and is not used.

        Returns:
            ``{"resp": <parsed JSON from the SQL reply>}`` plus a ``"chart"``
            key when the reply names a chart type.

        Raises:
            Any exception from retrieval, the LLM, or JSON parsing is logged
            and re-raised. ``ValueError`` if the LLM returns an empty reply.
        """
        try:
            logger.info("Start to generate_sql_2 in cus_vanna_srevice")
            question_sql_list = self.get_similar_question_sql(question, **kwargs)
            # Only the two most similar examples go into the prompt.
            if question_sql_list:
                question_sql_list = question_sql_list[:2]

            ddl_list = self.get_related_ddl(question, **kwargs)
            template = get_base_template()
            sql_temp = template['template']['sql']
            char_temp = template['template']['chart']

            history = None
            if user_id and cache:
                history = cache.get(id=user_id, field="data")

            db_engine = config("DB_ENGINE", default='mysql')

            # -------- Generate SQL and chart type from the prompt templates
            sys_temp = sql_temp['system'].format(
                engine=db_engine,
                lang='中文',
                schema=ddl_list,
                documentation=[train_ddl.train_document],
                history=history,
                retrieved_examples_data=question_sql_list,
                data_training=question_sql_list,
            )
            user_temp = sql_temp['user'].format(
                question=question,
                current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
            )
            logger.info(f"user_temp: {user_temp}")
            logger.info(f"sys_temp: {sys_temp}")

            llm_response = self.submit_prompt(
                [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs)
            # message.content can be None; fail with a clear message instead of
            # AttributeError on .strip().
            if not llm_response:
                raise ValueError("LLM returned an empty response for the SQL prompt")
            llm_response = llm_response.strip()
            logger.info(f"llm_response: {llm_response}")

            result = {"resp": orjson.loads(extract_nested_json(llm_response))}
            sql = check_and_get_sql(llm_response)
            logger.info(f"sql: {sql}")

            # --------------- Generate the chart spec
            char_type = get_chart_type_from_sql_answer(llm_response)
            logger.info(f"chart type: {char_type}")
            if char_type:
                sys_char_temp = char_temp['system'].format(
                    engine=db_engine, lang='中文', sql=sql, chart_type=char_type)
                user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question)
                llm_response2 = self.submit_prompt(
                    [{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}],
                    **kwargs)
                if not llm_response2:
                    raise ValueError("LLM returned an empty response for the chart prompt")
                logger.info(f"chart llm_response: {llm_response2}")
                result['chart'] = orjson.loads(extract_nested_json(llm_response2))

            logger.info(f"chart_response: {result}")
            logger.info("Finish to generate_sql_2 in cus_vanna_srevice")
            return result
        except Exception:
            # logger.exception records the traceback through the logging setup
            # instead of writing it to stderr separately.
            logger.exception("cus_vanna_srevice failed-------------------: ")
            raise

    def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
        """Merge ``new_question`` with ``last_question`` into one standalone question.

        Returns ``new_question`` unchanged when there is no previous question.
        Otherwise both questions are sent to the LLM. The model returns a single
        combined question, or the second question if the two are unrelated.
        """
        logger.info(f"generate_rewritten_question---------------{new_question}")
        if last_question is None:
            return new_question
        prompt = [
            self.system_message(
                "Your goal is to combine a sequence of questions into a singular question if they are related. If the second question does not relate to the first question and is fully self-contained, return the second question. Return just the new combined question with no additional explanations. The question should theoretically be answerable with a single SQL statement."),
            # Both questions must be sent; otherwise the model has nothing to combine.
            self.user_message("First question: " + last_question + "\nSecond question: " + new_question),
        ]
        return self.submit_prompt(prompt=prompt, **kwargs)
2025-09-23 14:49:00 +08:00
class CustomQdrant_VectorStore(Qdrant_VectorStore):
    """Qdrant vector store whose embeddings come from an HTTP embedding service.

    The service URL, model name and API key are read from the
    ``EMBEDDING_MODEL_BASE_URL``, ``EMBEDDING_MODEL_NAME`` and
    ``EMBEDDING_MODEL_API_KEY`` settings. The request timeout is read from
    ``EMBEDDING_TIMEOUT`` (seconds, default 60). All are read via ``decouple.config``.
    """

    def __init__(
        self,
        config_file=None
    ):
        """Read embedding settings and initialise the Qdrant store.

        Args:
            config_file: Settings dict forwarded to ``Qdrant_VectorStore``.
                ``None`` (the default) is treated as a fresh empty dict.
        """
        self.embedding_model_name = config('EMBEDDING_MODEL_NAME', default='')
        self.embedding_api_base = config('EMBEDDING_MODEL_BASE_URL', default='')
        self.embedding_api_key = config('EMBEDDING_MODEL_API_KEY', default='')
        # Without a timeout, a stalled embedding service would block the
        # calling thread indefinitely.
        self.embedding_timeout = config('EMBEDDING_TIMEOUT', default=60, cast=float)
        # A new dict per instance, instead of one `{}` default shared by every call.
        super().__init__(config_file if config_file is not None else {})

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        """Return the embedding vector for ``data`` from the embedding service.

        The request body is ``{"model": ..., "sentences": [data]}`` merged with
        ``kwargs``. The first entry of the response's ``embeddings`` list is
        returned.

        Raises:
            RuntimeError: on a non-200 response, or when the response carries
                no embeddings.
            requests.Timeout: if the service does not answer within
                ``self.embedding_timeout`` seconds.
        """
        def _get_error_string(response: requests.Response) -> str:
            # Prefer the service's own "detail" message, then the HTTP error text.
            try:
                if response.content:
                    return response.json()["detail"]
            except Exception:
                pass
            try:
                response.raise_for_status()
            except requests.HTTPError as e:
                return str(e)
            return "Unknown error"

        request_body = {
            "model": self.embedding_model_name,
            "sentences": [data],
        }
        request_body.update(kwargs)
        response = requests.post(
            url=f"{self.embedding_api_base}",
            json=request_body,
            headers={"Authorization": f"Bearer {self.embedding_api_key}", 'Content-Type': 'application/json'},
            timeout=self.embedding_timeout,
        )
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to create the embeddings, detail: {_get_error_string(response)}"
            )
        result = response.json()
        embeddings = result.get('embeddings') if isinstance(result, dict) else None
        if not embeddings:
            raise RuntimeError(
                "Failed to create the embeddings, detail: response contained no embeddings"
            )
        return embeddings[0]
class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM):
    """Vanna application object: Qdrant retrieval plus an OpenAI-compatible LLM."""

    def __init__(self, llm_config=None, vector_store_config=None):
        """Initialise both parent components with their own config dicts.

        ``None`` for either config is replaced by an empty dict, because both
        parent initialisers look keys up in the config they receive.

        The LLM initialiser runs second, so its ``VannaBase.__init__`` call is
        the last one made on this instance. Presumably this makes
        ``llm_config`` the effective ``self.config``; verify against VannaBase.
        """
        CustomQdrant_VectorStore.__init__(
            self, config_file=vector_store_config if vector_store_config is not None else {}
        )
        OpenAICompatibleLLM.__init__(
            self, config_file=llm_config if llm_config is not None else {}
        )
class TTLCacheWrapper:
    """Wrap a vanna cache (``MemoryCache()`` by default) with per-entry TTL expiry.

    The wrapper exists to prevent unbounded memory growth.

    - ``get``/``exists`` expire entries lazily.
    - A daemon thread also sweeps expired entries every ``cleanup_interval`` seconds.
    - The wrapped cache is only asked to ``delete(id=...)``, so eviction
      happens per id: expiring one field evicts every field under that id.
    - Attributes not defined here are delegated to the wrapped cache.
    """

    def __init__(self, cache: Optional["Cache"] = None, ttl: int = 3600, cleanup_interval: float = 180):
        """
        Args:
            cache: Cache to wrap; a new ``MemoryCache()`` when None.
            ttl: Default time-to-live in seconds for ``set``.
            cleanup_interval: Seconds between background sweeps.
        """
        self.cache = cache if cache is not None else MemoryCache()
        self.ttl = ttl
        self.cleanup_interval = cleanup_interval
        # (id, field) -> monotonic deadline; guarded by _lock because request
        # threads and the sweeper mutate it concurrently.
        self._expiry_times = {}
        self._lock = threading.RLock()
        self._stop_event = threading.Event()
        self._cleanup_thread = None
        self._start_cleanup()

    def _start_cleanup(self):
        """Start the daemon thread that periodically purges expired entries."""
        def cleanup():
            # Event.wait is an interruptible sleep: it returns True (ending the
            # loop) as soon as close() is called.
            while not self._stop_event.wait(self.cleanup_interval):
                try:
                    self._purge_expired()
                except Exception:
                    # A dead sweeper would silently bring the leak back.
                    logging.getLogger(__name__).exception("TTL cache cleanup failed")

        self._cleanup_thread = threading.Thread(target=cleanup, daemon=True)
        self._cleanup_thread.start()

    def _evict(self, id: str, field: str) -> None:
        """Remove ``(id, field)`` from the wrapped cache and the expiry table.

        Caller must hold ``self._lock``.
        """
        if hasattr(self.cache, 'delete'):
            self.cache.delete(id=id)
            # delete() drops the whole id, so expiry records of sibling fields
            # are now stale. Left in place, one would later delete data written
            # under this id after this eviction.
            for key_info in [k for k in self._expiry_times if k[0] == id]:
                del self._expiry_times[key_info]
        else:
            self._expiry_times.pop((id, field), None)

    def _purge_expired(self) -> None:
        """Evict every entry whose deadline has passed."""
        with self._lock:
            now = time.monotonic()
            expired = [key_info for key_info, expiry in self._expiry_times.items() if expiry <= now]
            for cache_id, cache_field in expired:
                # A sibling eviction earlier in this loop may already have removed it.
                if (cache_id, cache_field) in self._expiry_times:
                    self._evict(cache_id, cache_field)

    def set(self, id: str, field: str, value: Any, ttl: Optional[int] = None):
        """Store ``value`` under ``(id, field)`` for ``ttl`` seconds (default ``self.ttl``)."""
        actual_ttl = ttl if ttl is not None else self.ttl
        with self._lock:
            self.cache.set(id=id, field=field, value=value)
            self._expiry_times[(id, field)] = time.monotonic() + actual_ttl

    def get(self, id: str, field: str) -> Any:
        """Return the cached value, or None if it is missing or expired.

        An expired entry is evicted as a side effect.
        """
        with self._lock:
            expiry = self._expiry_times.get((id, field))
            # Same boundary as _purge_expired: expired once now >= deadline.
            if expiry is not None and time.monotonic() >= expiry:
                self._evict(id, field)
                return None
            return self.cache.get(id=id, field=field)

    def delete(self, id: str, field: str):
        """Delete ``(id, field)``; with an id-level ``delete`` this drops the whole id."""
        with self._lock:
            self._evict(id, field)

    def exists(self, id: str, field: str) -> bool:
        """Return True if ``(id, field)`` holds a non-None, unexpired value."""
        return self.get(id=id, field=field) is not None

    def close(self) -> None:
        """Stop the background sweeper; lazy expiry in ``get`` keeps working."""
        self._stop_event.set()

    # Delegate any other attribute to the wrapped cache.
    def __getattr__(self, name):
        # Without this guard, looking up `cache` before __init__ set it (for
        # example during unpickling) would recurse forever.
        if name == "cache":
            raise AttributeError(name)
        return getattr(self.cache, name)