from datetime import datetime
from typing import List

import orjson
import requests
from decouple import config
from openai import OpenAI
from qdrant_client import QdrantClient
from vanna.base import VannaBase
from vanna.qdrant import Qdrant_VectorStore

from template.template import get_base_template
from util.utils import extract_nested_json, check_and_get_sql, get_chart_type_from_sql_answer

class OpenAICompatibleLLM(VannaBase):
    def __init__(self, client=None, config_file=None):
        # Guard against config_file=None so the membership checks below don't fail.
        if config_file is None:
            config_file = {}

        VannaBase.__init__(self, config=config_file)

        # Default parameters - these can be overridden via the config.
        self.temperature = 0.5
        self.max_tokens = 5000

        if "temperature" in config_file:
            self.temperature = config_file["temperature"]

        if "max_tokens" in config_file:
            self.max_tokens = config_file["max_tokens"]

        if "api_type" in config_file:
            raise Exception(
                "Passing api_type is now deprecated. Please pass an OpenAI client instead."
            )

        if "api_version" in config_file:
            raise Exception(
                "Passing api_version is now deprecated. Please pass an OpenAI client instead."
            )

        if client is not None:
            self.client = client
            return

        if "api_base" not in config_file:
            raise Exception("Please pass api_base in the config")

        if "api_key" not in config_file:
            raise Exception("Please pass api_key in the config")

        self.client = OpenAI(api_key=config_file["api_key"], base_url=config_file["api_base"])

    def system_message(self, message: str) -> dict:
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> dict:
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> dict:
        return {"role": "assistant", "content": message}

    def submit_prompt(self, prompt, **kwargs) -> str:
        if prompt is None:
            raise Exception("Prompt is None")

        if len(prompt) == 0:
            raise Exception("Prompt is empty")

        print(prompt)

        # Rough token estimate: ~4 characters per token.
        num_tokens = 0
        for message in prompt:
            num_tokens += len(message["content"]) / 4

        if kwargs.get("model", None) is not None:
            model = kwargs.get("model", None)
            print(f"Using model {model} for {num_tokens} tokens (approx)")
            response = self.client.chat.completions.create(
                model=model,
                messages=prompt,
                max_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )
        elif kwargs.get("engine", None) is not None:
            engine = kwargs.get("engine", None)
            print(f"Using engine {engine} for {num_tokens} tokens (approx)")
            response = self.client.chat.completions.create(
                engine=engine,
                messages=prompt,
                max_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )
        elif self.config is not None and "engine" in self.config:
            print(f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)")
            response = self.client.chat.completions.create(
                engine=self.config["engine"],
                messages=prompt,
                max_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )
        elif self.config is not None and "model" in self.config:
            print(f"Using model {self.config['model']} for {num_tokens} tokens (approx)")
            response = self.client.chat.completions.create(
                model=self.config["model"],
                messages=prompt,
                max_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )
        else:
            # No explicit model anywhere: fall back to a larger-context model
            # for long prompts, a cheaper one otherwise.
            if num_tokens > 3500:
                model = "kimi"
            else:
                model = "doubao"
            print(f"Using model {model} for {num_tokens} tokens (approx)")
            response = self.client.chat.completions.create(
                model=model,
                messages=prompt,
                max_tokens=self.max_tokens,
                stop=None,
                temperature=self.temperature,
            )

        # Some OpenAI-compatible servers return completion-style choices with a
        # `text` field; prefer it when present, otherwise use the chat message.
        # (`"text" in choice` raises TypeError on the v1 client's pydantic
        # objects, so use hasattr instead.)
        for choice in response.choices:
            if hasattr(choice, "text") and choice.text:
                return choice.text

        return response.choices[0].message.content
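
    # Illustrative call (assuming a configured instance `vn`); when no model is
    # passed, selection falls back to self.config and then the kimi/doubao
    # heuristic above:
    #   reply = vn.submit_prompt([
    #       vn.system_message("You are a helpful SQL assistant."),
    #       vn.user_message("How many orders were placed last week?"),
    #   ], model="doubao")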
    def generate_sql_2(self, question: str, allow_llm_to_see_data=False, **kwargs) -> dict:
        question_sql_list = self.get_similar_question_sql(question, **kwargs)
        ddl_list = self.get_related_ddl(question, **kwargs)
        doc_list = self.get_related_documentation(question, **kwargs)

        template = get_base_template()
        sql_temp = template['template']['sql']
        chart_temp = template['template']['chart']

        # Generate the SQL (and a chart type) from the prompt templates.
        sys_temp = sql_temp['system'].format(
            engine=config("DATA_SOURCE_TYPE", default='mysql'),
            lang='中文',
            schema=ddl_list,
            documentation=doc_list,
            data_training=question_sql_list,
        )
        print("sys_temp", sys_temp)

        user_temp = sql_temp['user'].format(
            question=question,
            current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
        )
        print("user_temp", user_temp)

        llm_response = self.submit_prompt(
            [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs)
        print(llm_response)

        result = {"resp": orjson.loads(extract_nested_json(llm_response))}
        print("result", result)

        sql = check_and_get_sql(llm_response)

        # Generate the chart configuration when the SQL answer names a chart type.
        chart_type = get_chart_type_from_sql_answer(llm_response)
        if chart_type:
            sys_chart_temp = chart_temp['system'].format(
                engine=config("DATA_SOURCE_TYPE", default='mysql'),
                lang='中文',
                sql=sql,
                chart_type=chart_type,
            )
            user_chart_temp = chart_temp['user'].format(sql=sql, chart_type=chart_type, question=question)
            llm_response2 = self.submit_prompt(
                [{'role': 'system', 'content': sys_chart_temp}, {'role': 'user', 'content': user_chart_temp}], **kwargs)
            print(llm_response2)
            result['chart'] = orjson.loads(extract_nested_json(llm_response2))

        return result
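
    # Returned shape (illustrative): {"resp": <parsed SQL answer JSON>}, plus a
    # "chart" key holding the parsed chart config when a chart type was detected.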

    def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
        print("new_question---------------", new_question)
        if last_question is None:
            return new_question

        prompt = [
            self.system_message(
                "Your goal is to combine a sequence of questions into a singular question if they are related. If the second question does not relate to the first question and is fully self-contained, return the second question. Return just the new combined question with no additional explanations. The question should theoretically be answerable with a single SQL statement."),
            # Both questions must be in the prompt, or the model has nothing to combine.
            self.user_message("First question: " + last_question + "\nSecond question: " + new_question),
        ]
        return self.submit_prompt(prompt=prompt, **kwargs)
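
    # Illustrative rewrite (hypothetical input/output, not from the source):
    #   generate_rewritten_question("Top 10 customers by revenue?", "And for last year?")
    #   -> "Who were the top 10 customers by revenue last year?"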

class CustomQdrant_VectorStore(Qdrant_VectorStore):
    def __init__(self, config_file=None):
        # Use None instead of a mutable default argument; fall back to {}.
        self.embedding_model_name = config('EMBEDDING_MODEL_NAME', default='')
        self.embedding_api_base = config('EMBEDDING_MODEL_BASE_URL', default='')
        self.embedding_api_key = config('EMBEDDING_MODEL_API_KEY', default='')
        super().__init__(config_file or {})

    def generate_embedding(self, data: str, **kwargs) -> List[float]:
        def _get_error_string(response: requests.Response) -> str:
            try:
                if response.content:
                    return response.json()["detail"]
            except Exception:
                pass
            try:
                response.raise_for_status()
            except requests.HTTPError as e:
                return str(e)
            return "Unknown error"

        request_body = {
            "model": self.embedding_model_name,
            "input": data,
        }
        request_body.update(kwargs)

        response = requests.post(
            url=f"{self.embedding_api_base}/v1/embeddings",
            json=request_body,
            headers={"Authorization": f"Bearer {self.embedding_api_key}"},
        )
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to create the embeddings, detail: {_get_error_string(response)}"
            )

        # The endpoint is expected to be OpenAI-compatible, returning JSON of
        # the form {"data": [{"embedding": [...]}, ...]}; only the first
        # embedding is returned since a single string was sent.
        result = response.json()
        embeddings = [d["embedding"] for d in result["data"]]
        return embeddings[0]

class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM):
    def __init__(self, llm_config=None, vector_store_config=None):
        CustomQdrant_VectorStore.__init__(self, config_file=vector_store_config)
        OpenAICompatibleLLM.__init__(self, config_file=llm_config)
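
# Minimal usage sketch (illustrative only, not part of the module's API). The
# config keys accepted by Qdrant_VectorStore depend on the installed vanna
# version; passing a QdrantClient under the "client" key is an assumption here,
# as are the endpoint URLs and the model name.
if __name__ == "__main__":
    vn = CustomVanna(
        llm_config={
            "api_base": "http://localhost:8000/v1",  # any OpenAI-compatible server
            "api_key": "sk-...",                     # placeholder key
            "model": "doubao",
            "temperature": 0.5,
        },
        vector_store_config={"client": QdrantClient(url="http://localhost:6333")},  # assumed key
    )
    result = vn.generate_sql_2("What were last month's sales by region?")
    print(result)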