Compare commits

3 Commits

Author SHA1 Message Date
yujj128
30bb1a6839 修复 2025-11-06 16:32:32 +08:00
yujj128
a7043e4ec3 dev2->dev graph 2025-11-06 16:23:51 +08:00
雷雨
e34a8b839f feat:图式节点-支持问数 2025-11-06 12:23:07 +08:00
10 changed files with 758 additions and 98 deletions

21
.env
View File

@@ -1,4 +1,4 @@
IS_FIRST_LOAD=True
IS_FIRST_LOAD=False
#CHAT_MODEL_BASE_URL=https://api.siliconflow.cn
#CHAT_MODEL_API_KEY=sk-iyhiltycmrfnhrnbljsgqjrinhbztwdplyvuhfihcdlepole
@@ -9,19 +9,26 @@ IS_FIRST_LOAD=True
#CHAT_MODEL_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1
#CHAT_MODEL_API_KEY=sk-72575159d3ec43a68c6e222a15719bed
#CHAT_MODEL_NAME=qwen-plus
CHAT_MODEL_BASE_URL=https://api.siliconflow.cn
CHAT_MODEL_API_KEY=sk-tnmkzvzbipohjfbqxhictewzdgrrxoghbmicrfjgxbgdkjfq
CHAT_MODEL_NAME=Qwen/Qwen3-32B
CHAT_MODEL_BASE_URL=https://api.siliconflow.cn/v1
CHAT_MODEL_API_KEY=sk-jixyuwdltfawojgwywogkkdoqwpxblabprxybltlacbpqnip
CHAT_MODEL_NAME=zai-org/GLM-4.6
SQL_MODEL_BASE_URL=https://api.siliconflow.cn/v1
SQL_MODEL_API_KEY=sk-jixyuwdltfawojgwywogkkdoqwpxblabprxybltlacbpqnip
SQL_MODEL_NAME=Qwen/Qwen3-Coder-30B-A3B-Instruct
SQL_MODEL_TEMPERATURE=0.5
#使用ai中台的模型
EMBEDDING_MODEL_BASE_URL=http://10.225.128.2:13206/member1/small-model/bge/encode
EMBEDDING_MODEL_BASE_URL=https://api.siliconflow.cn
EMBEDDING_MODEL_API_KEY=sk-iyhiltycmrfnhrnbljsgqjrinhbztwdplyvuhfihcdlepole
EMBEDDING_MODEL_NAME=BAAI/bge-m3
#向量数据库
#type:memory/remote,如果设置为remote将IS_FIRST_LOAD 设置成false
QDRANT_TYPE=memory
QDRANT_TYPE=remote
QDRANT_DB_HOST=106.13.42.156
QDRANT_DB_PORT=16000
QDRANT_DB_PORT=33088
#mysql ,sqlite,pg等
DATA_SOURCE_TYPE=dameng

View File

@@ -29,6 +29,17 @@ class QuestionFeedBack(Base):
is_process = Column(Boolean, nullable=False, default=False)
class Conversation(Base):
__tablename__ = 'db_conversation'
id = Column(String(255), primary_key=True)
create_time = Column(DateTime, nullable=False, )
question = Column(String(500), nullable=False)
sql = Column(String(500), nullable=True)
user_id = Column(String(100), nullable=False)
cvs_id = Column(String(100), nullable=False)
meta = Column(Text, nullable=True)
class PredefinedQuestion(Base):
# 定义表名,预制问题表
__tablename__ = 'db_predefined_question'

0
graph_chat/__init__.py Normal file
View File

View File

@@ -0,0 +1,182 @@
from langchain_openai import ChatOpenAI
from typing import TypedDict, Annotated, List, Optional, Union
import pandas as pd
from langgraph.types import Command, interrupt
from langgraph.graph import StateGraph, END, START
from langgraph.checkpoint.memory import MemorySaver
import orjson, logging
logger = logging.getLogger(__name__)
from decouple import config
from template.template import get_base_template
from util.utils import extract_nested_json, check_and_get_sql, get_chart_type_from_sql_answer
from datetime import datetime
from util import train_ddl
# Chat LLM used for history-aware question rewriting and chart-spec prompts.
# Endpoint, model name and key are read from the environment via python-decouple.
gen_history_llm = ChatOpenAI(
    model=config('CHAT_MODEL_NAME', default=''),
    base_url=config('CHAT_MODEL_BASE_URL', default=''),
    # NOTE(review): temperature is taken from SQL_MODEL_TEMPERATURE rather than
    # a chat-specific setting — confirm this reuse is intentional.
    temperature=config('SQL_MODEL_TEMPERATURE', default=0.5, cast=float),
    max_tokens=8000,
    max_retries=2,
    api_key=config('CHAT_MODEL_API_KEY', default='empty'),
)
# LLM dedicated to SQL generation (configured separately, typically a code model).
gen_sql_llm = ChatOpenAI(
    model=config('SQL_MODEL_NAME', default=''),
    base_url=config('SQL_MODEL_BASE_URL', default=''),
    temperature=config('SQL_MODEL_TEMPERATURE', default=0.5, cast=float),
    max_tokens=8000,
    max_retries=2,
    api_key=config('SQL_MODEL_API_KEY', default='empty'),
)
'''
用于生成sql,生成图表的agent上下文
'''
# Shared state flowing through the LangGraph agent that generates SQL and a
# chart spec for a user question.  Each node returns a partial update of it.
class SqlAgentState(TypedDict):
    user_question: str  # raw question exactly as the user asked it
    rewritten_user_question: Optional[str]  # question after history-aware rewriting
    session_id: str  # conversation/session identifier
    user_id: str  # id of the asking user (used in node log lines)
    history: Optional[list]  # prior turns fed to the rewrite prompt
    gen_sql_result: Optional[dict]  # parsed LLM answer; nodes read 'sql' and 'char_type'
    gen_chart_result: Optional[dict]  # parsed chart spec; router checks its 'title'
    gen_sql_error: Optional[str]  # last SQL-generation error message, if any
    gen_chart_error: Optional[str]  # last chart-generation error message, if any
    final_error_msg: Optional[str]  # terminal error surfaced to the caller
    sql_retry_count: int  # attempts made by the _gen_sql node (max 2)
    chart_retry_count: int  # attempts made by the _gen_chart node (max 2)
'''
基于上下文,历史消息,重写用户问题
'''
def _rewrite_user_question(state: SqlAgentState) -> dict:
    """Rewrite the current question into a self-contained one using history.

    Fills the 'rewrite_question' prompt template with the raw question and the
    conversation history, asks the chat LLM, and parses the JSON it returns.
    Falls back to the original question when the JSON lacks a 'question' key.
    """
    logger.info(f"user:{state.get('user_id', '1')} 进入 _rewrite_user_question 节点")
    original_question = state['user_question']
    rewrite_conf = get_base_template()['template']['rewrite_question']
    sys_prompt = rewrite_conf['system'].format(
        current_question=original_question,
        history=state.get('history', []),
    )
    new_question = gen_history_llm.invoke(sys_prompt).text
    logger.info("new_question:{0}".format(new_question))
    raw_json = extract_nested_json(new_question)
    logger.info(f"result:{raw_json}")
    parsed = orjson.loads(raw_json)
    return {"rewritten_user_question": parsed.get('question', original_question)}
def _gen_sql(state: SqlAgentState) -> dict:
    """Generate SQL for the (possibly rewritten) user question.

    Pulls few-shot question/SQL examples and related DDL from the vanna
    service, fills the SQL prompt template, calls the SQL LLM and parses its
    JSON answer.

    Returns a partial state update: either ``gen_sql_result`` (success) or
    ``gen_sql_error`` (failure), always with the incremented
    ``sql_retry_count``.
    """
    from main_service import vn  # local import avoids a circular dependency
    logger.info(f"user:{state.get('user_id', '1')} 进入 _gen_sql 节点")
    # fall back to the raw question when rewriting did not run or returned None
    question = state.get('rewritten_user_question') or state['user_question']
    # counter is identical on both branches, so compute it once up front
    retry = state.get("sql_retry_count", 0) + 1
    question_sql_list = vn.get_similar_question_sql(question)
    # cap retrieved examples to bound prompt size
    if question_sql_list and len(question_sql_list) > 3:
        question_sql_list = question_sql_list[:3]
    ddl_list = vn.get_related_ddl(question)
    sql_temp = get_base_template()['template']['sql']
    history = []
    try:
        sys_temp = sql_temp['system'].format(
            engine=config("DB_ENGINE", default='mysql'), lang='中文',
            schema=ddl_list, documentation=[train_ddl.train_document],
            history=history, retrieved_examples_data=question_sql_list,
            data_training=question_sql_list, )
        user_temp = sql_temp['user'].format(
            question=question,
            current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
        rr = gen_sql_llm.invoke(
            [{'role': 'system', 'content': sys_temp},
             {'role': 'user', 'content': user_temp}]).text
        result = extract_nested_json(rr)
        logger.info(f"gensql result: {result}")
        return {'gen_sql_result': orjson.loads(result), "sql_retry_count": retry}
    except Exception as e:
        # logger.exception records the full traceback in the log stream
        # (replaces traceback.print_exc to stderr plus a misleading
        # info-level "cus_vanna_srevice failed" message)
        logger.exception("_gen_sql node failed")
        return {'gen_sql_error': str(e), "sql_retry_count": retry}
def _gen_chart(state: SqlAgentState) -> dict:
    """Generate a chart spec for the SQL produced by _gen_sql.

    Fills the 'chart' prompt template with the generated SQL and chart type,
    asks the chat LLM and parses the JSON it returns.

    Returns a partial state update: ``gen_chart_result`` on success or
    ``gen_chart_error`` on failure, always with the incremented
    ``chart_retry_count``.
    """
    logger.info(f"user:{state.get('user_id', '1')} 进入 _gen_chart 节点")
    gen_sql_result = state.get('gen_sql_result', {})
    sql = gen_sql_result.get('sql', '')
    char_type = gen_sql_result.get('char_type', '')
    # fall back to the raw question when rewriting did not run or returned None
    question = state.get('rewritten_user_question') or state['user_question']
    char_temp = get_base_template()['template']['chart']
    # counter is identical on both branches, so compute it once up front
    retry = state.get("chart_retry_count", 0) + 1
    try:
        sys_char_temp = char_temp['system'].format(
            engine=config("DB_ENGINE", default='mysql'),
            lang='中文', sql=sql, chart_type=char_type)
        user_char_temp = char_temp['user'].format(
            sql=sql, chart_type=char_type, question=question)
        rr = gen_history_llm.invoke(
            [{'role': 'system', 'content': sys_char_temp},
             {'role': 'user', 'content': user_char_temp}]).text
        return {"chart_retry_count": retry,
                'gen_chart_result': orjson.loads(extract_nested_json(rr))}
    except Exception as e:
        # full traceback via the logger instead of traceback.print_exc to stderr
        logger.exception("_gen_chart node failed")
        return {'gen_chart_error': str(e), "chart_retry_count": retry}
# Router after _gen_sql: retry on failure or empty SQL (at most 2 attempts),
# move on to chart generation on success, otherwise give up via END.
def gen_sql_handler(state: SqlAgentState) -> str:
    """Return the name of the next node to run after _gen_sql."""
    attempts = state.get('sql_retry_count', 0)
    generation_failed = bool(state.get('gen_sql_error', ''))
    sql_text = state.get('gen_sql_result', {}).get('sql', '')
    if not generation_failed and len(sql_text) > 0:
        return '_gen_chart'
    return '_gen_sql' if attempts < 2 else END
def gen_chart_handler(state: SqlAgentState) -> str:
    """Route after _gen_chart: finish when a titled chart exists, otherwise
    retry up to 2 attempts before giving up via END."""
    # note: the original local was misleadingly named sql_retry_count even
    # though it holds the chart counter
    attempts = state.get('chart_retry_count', 0)
    chart_failed = bool(state.get('gen_chart_error', ''))
    chart_title = state.get('gen_chart_result', {}).get('title', '')
    if not chart_failed and len(chart_title) > 0:
        return END
    return '_gen_chart' if attempts < 2 else END
# ---- Build the SQL/chart agent graph --------------------------------------
# rewrite question -> generate SQL (retry x2) -> generate chart (retry x2)
workflow = StateGraph(SqlAgentState)
workflow.add_node("_rewrite_user_question", _rewrite_user_question)
workflow.add_node("_gen_sql", _gen_sql)
workflow.add_node("_gen_chart", _gen_chart)
workflow.add_edge(START, "_rewrite_user_question")
workflow.add_edge("_rewrite_user_question", "_gen_sql")
workflow.add_conditional_edges('_gen_sql', gen_sql_handler, ['_gen_sql', '_gen_chart', END])
workflow.add_conditional_edges('_gen_chart', gen_chart_handler, ['_gen_chart', END])
# In-memory checkpointer: per-thread state is lost on process restart.
memory = MemorySaver()
sql_chart_agent = workflow.compile(checkpointer=memory)
# Debug helper: render the graph as a PNG.  draw_mermaid_png() contacts an
# external rendering service, so it must NOT run at import time — keep it
# commented out (the hard-coded "D://graph.png" path was debug-only anyway).
# png_data = sql_chart_agent.get_graph().draw_mermaid_png()
# with open("D://graph.png", "wb") as f:
#     f.write(png_data)

View File

@@ -1,17 +1,26 @@
import copy
import logging
import time
from functools import wraps
import util.utils
from logging_config import LOGGING_CONFIG
from service.cus_vanna_srevice import CustomVanna, QdrantClient, TTLCacheWrapper
from service.question_feedback_service import save_save_question_async,query_predefined_question_list
from service.conversation_service import save_conversation,update_conversation,get_sql_by_id
from decouple import config
import flask
from util import load_ddl_doc
from flask import Flask, Response, jsonify, request
from graph_chat.gen_sql_chart_agent import SqlAgentState, sql_chart_agent
import traceback
logger = logging.getLogger(__name__)
def generate_timestamp_id():
    """Build a question id from the current epoch time in milliseconds.

    Returns a string like ``"Q1730880000000"``.  NOTE(review): two calls within
    the same millisecond collide — confirm this is acceptable for a value used
    as a primary key.
    """
    millis = int(time.time() * 1000)
    return f"Q{millis}"
def connect_database(vn):
db_type = config('DATA_SOURCE_TYPE', default='sqlite')
@@ -59,7 +68,7 @@ def create_vana():
def init_vn(vn):
logger.info("--------------init vana-----connect to datasouce db----")
connect_database(vn)
#connect_database(vn)
if config('IS_FIRST_LOAD', default=False, cast=bool):
load_ddl_doc.add_ddl(vn)
load_ddl_doc.add_documentation(vn)
@@ -82,12 +91,11 @@ def generate_sql_2():
Generate SQL from a question
---
parameters:
- name: user
- name: user_id
in: query
- name: question
in: query
type: string
required: true
- name: question_id
responses:
200:
schema:
@@ -105,59 +113,101 @@ def generate_sql_2():
question = flask.request.args.get("question")
if question is None:
return jsonify({"type": "error", "error": "No question provided"})
user_id = request.args.get("user_id")
cvs_id = request.args.get("cvs_id")
need_context = bool(request.args.get("need_context"))
if user_id is None or cvs_id is None:
return jsonify({"type": "error", "error": "No user_id or cvs_id provided"})
id = generate_timestamp_id()
logger.info(f"question_id: {id} user_id: {user_id} cvs_id: {cvs_id} question: {question}")
save_conversation(id,user_id,cvs_id,question)
try:
id = cache.generate_id(question=question)
user_id = request.args.get("user_id")
logger.info(f"Generate sql for {question}")
data = vn.generate_sql_2(question=question, cache=cache, user_id=user_id)
data = vn.generate_sql_2(user_id,cvs_id,question,id,need_context)
logger.info("Generate sql result is {0}".format(data))
data['id'] = id
sql = data["resp"]["sql"]
logger.info("generate sql is : " + sql)
cache.set(id=id, field="question", value=question)
cache.set(id=id, field="sql", value=sql)
data["type"] = "success"
logger.info("generate sql is : "+ sql)
update_conversation(cvs_id, id, sql)
save_save_question_async(id, user_id, question, sql)
data["type"]="success"
return jsonify(data)
except Exception as e:
logger.error(f"generate sql failed:{e}")
return jsonify({"type": "error", "error": str(e)})
def session_save(func):
@wraps(func)
def wrapper(*args, **kwargs):
id = request.args.get("id")
user_id = request.args.get("user_id")
logger.info(f" id: {id},user_id: {user_id}")
result = func(*args, **kwargs)
# def requires_cache_2(required_keys):
# def decorator(f):
# @wraps(f)
# def decorated(*args, **kwargs):
# id = request.args.get("id")
# user_id = request.args.get("user_id")
# if user_id is None:
# user_id = request.json.get("user_id")
# if user_id is None:
# return jsonify({"type": "error", "error": "No user_id provided"})
# if id is None:
# id = request.json.get("id")
# if id is None:
# return jsonify({"type": "error", "error": "No id provided"})
# all_v = cache.items()
# logger.info(f"all values {all_v}")
# logger.info(f"user {user_id} id {id}")
# qa_list = cache.get(id=user_id, field="qa_list")
# if qa_list is None:
# return jsonify({"type": "error", "error": f"No qa_list found"})
# logger.info(f"qa_list {qa_list}")
# q_a = list(filter(lambda x: x["id"] == id, qa_list))
# logger.info(f"q_a {q_a}")
# for key in required_keys:
# if q_a[0][key] is None:
# return jsonify({"type": "error", "error": f"No {key} found for id:{id}"})
# values = {key:q_a[0][key] for key in required_keys}
# values["id"] = id
# logger.info("cache values {0}".format(values))
#
# return f(*args, **values, **kwargs)
#
# return decorated
#
# return decorator
datas = []
session_len = int(config("SESSION_LENGTH", default=2))
if cache.exists(id=user_id, field="data"):
datas = copy.deepcopy(cache.get(id=user_id, field="data"))
data = {
"id": id,
"question": cache.get(id=id, field="question"),
"sql": cache.get(id=id, field="sql")
}
datas.append(data)
logger.info("datas is {0}".format(datas))
if len(datas) > session_len and session_len > 0:
datas = datas[-session_len:]
# 删除id对应的所有缓存值,因为已经run_sql完毕改用user_id保存为上下文
cache.delete(id=id, field="question")
cache.set(id=user_id, field="data", value=copy.deepcopy(datas))
logger.info(f" user data {cache.get(user_id, field='data')}")
return result
return wrapper
# def session_save(func):
# @wraps(func)
# def wrapper(*args, **kwargs):
# id = request.args.get("id")
# user_id = request.args.get("user_id")
# logger.info(f" id: {id},user_id: {user_id}")
# result = func(*args, **kwargs)
#
# datas = []
# session_len = int(config("SESSION_LENGTH", default=2))
# if cache.exists(id=user_id, field="qa_list"):
# datas = copy.deepcopy(cache.get(id=user_id, field="qa_list"))
# logger.info("datas is {0}".format(datas))
# if len(datas) > session_len and session_len > 0:
# logger.info(f"开始裁剪-------------------------------------")
# datas=datas[-session_len:]
# # 删除id对应的所有缓存值,因为已经run_sql完毕改用user_id保存为上下文
# # cache.delete(id=id, field="question")
# print("datas---------------------{0}".format(datas))
# cache.set(id=user_id, field="qa_list", value=copy.deepcopy(datas))
# logger.info(f" user data {cache.get(user_id, field='qa_list')}")
# return result
#
# return wrapper
@app.flask_app.route("/yj_sqlbot/api/v0/run_sql_2", methods=["GET"])
@session_save
@app.requires_cache(["sql"])
def run_sql_2(id: str, sql: str):
# @session_save
# @requires_cache_2(required_keys=["sql"])
def run_sql_2():
"""
Run SQL
---
@@ -169,10 +219,6 @@ def run_sql_2(id: str, sql: str):
in: query|body
type: string
required: true
- name: page_size
in: query
-name: page_num
in: query
responses:
200:
schema:
@@ -190,7 +236,9 @@ def run_sql_2(id: str, sql: str):
"""
logger.info("Start to run sql in main")
try:
user_id = request.args.get("user_id")
id = request.args.get("id")
sql = get_sql_by_id(id)
logger.info(f"sql is {sql}")
if not vn.run_sql_is_set:
return jsonify(
{
@@ -199,11 +247,6 @@ def run_sql_2(id: str, sql: str):
}
)
# count_sql = f"SELECT COUNT(*) AS total_count FROM ({sql}) AS subquery"
# df_count = vn.run_sql(count_sql)
# print(df_count,"is type",type(df_count))
# total_count = df_count.to_dict(orient="records")[0]["total_count"]
# logger.info("Total count is {0}".format(total_count))
df = vn.run_sql_2(sql=sql)
result = df.to_dict(orient='records')
logger.info("df ---------------{0} {1}".format(result, type(result)))
@@ -240,11 +283,36 @@ def verify_user():
@app.flask_app.route("/yj_sqlbot/api/v0/query_present_question", methods=["GET"])
def query_present_question():
try:
data=query_predefined_question_list()
return jsonify({"type": "success", "data": data})
data = query_predefined_question_list()
return jsonify({"type": "success", "data": data})
except Exception as e:
logger.error(f"查询预制问题失败 failed:{e}")
return jsonify({"type": "error", "error":f'查询预制问题失败:{str(e)}'})
return jsonify({"type": "error", "error": f'查询预制问题失败:{str(e)}'})
@app.flask_app.route("/yj_sqlbot/api/v0/gen_graph_question", methods=["GET"])
def gen_graph_question():
try:
config = {"configurable": {"thread_id": '1233'}}
question = flask.request.args.get("question")
initial_state: SqlAgentState = {
"user_question": question,
"history": [{"role":"user","content":"宋亚澜9月在林芝工作多少天"},{"role":"user","content":"余佳佳9月在林芝工作多少天"}],
"sql_retry_count": 0,
"chart_retry_count": 0
}
result=sql_chart_agent.invoke(initial_state, config=config)
data={
'sql':result.get("gen_sql_result",{}),
'chart':result.get("gen_chart_result",{}),
'gen_sql_error':result.get("gen_sql_error",None),
'gen_chart_error':result.get("gen_chart_error",None),
}
return jsonify(data)
except Exception as e:
traceback.print_exc()
logger.error(f"查询预制问题失败 failed:{e}")
return jsonify({"type": "error", "error": f'查询预制问题失败:{str(e)}'})
if __name__ == '__main__':

View File

@@ -2,4 +2,6 @@ vanna ==0.7.9
vanna[openai]
vanna[qdrant]
python-decouple==3.8
dmPython==2.5.22
dmPython==2.5.22
langgraph
langchain-openai

View File

@@ -0,0 +1,85 @@
from datetime import datetime
from db_util.db_main import Conversation, SqliteSqlalchemy
import logging
logger = logging.getLogger(__name__)
def save_conversation(id, user_id, cvs_id, question):
    """Persist a new conversation row; best-effort (rolls back on failure).

    Args:
        id: primary key of the new row (e.g. a timestamp-based question id).
        user_id: id of the asking user.
        cvs_id: id of the conversation/session the question belongs to.
        question: the question text.
    """
    cvs = Conversation(id=id, user_id=user_id, cvs_id=cvs_id,
                       question=question, create_time=datetime.now())
    session = SqliteSqlalchemy().session
    try:
        session.add(cvs)
        session.commit()
    except Exception:
        # was a bare `except:` that also swallowed KeyboardInterrupt/SystemExit
        # and gave no trace of the failure — keep best-effort, but log it
        logger.exception(f"save_conversation failed for id {id}")
        session.rollback()
    finally:
        session.close()
def get_conversation(cvs_id: str):
    """Return all Conversation rows matching the given id, or None on error.

    NOTE(review): despite the parameter name, this filters on Conversation.id
    (the row primary key), not Conversation.cvs_id — confirm which field was
    intended.  Returns None implicitly when the query raises.
    """
    session = SqliteSqlalchemy().session
    try:
        results = session.query(Conversation).filter(Conversation.id == cvs_id)
        logger.info(f"conversation {cvs_id} results is {results}")
        return results.all()
    except Exception as e:
        # error is only logged; caller receives None
        logger.info(f"get conversation with id {cvs_id} error {e}")
    finally:
        session.close()
# def get_all_conversations_by_user(user_id):
# session = SqliteSqlalchemy().session
# user_cvs = []
# try:
# results = session.query(Conversation).filter(Conversation.user_id == user_id).all()
# logger.info(f"conversation {user_id} results is {results}")
# cvs = {}
# for rs in results:
# cvs[rs.cvs_id] = {
#
# }
def update_conversation(cvs_id: str, id: str, sql=None, meta=None):
    """Attach generated SQL and/or meta info to an existing conversation row.

    Args:
        cvs_id: conversation/session id the row belongs to.
        id: primary key of the row to update.
        sql: generated SQL text to store (skipped when falsy).
        meta: meta information (e.g. the rewritten question) to store
            (skipped when falsy).

    Best-effort: failures are logged and rolled back, never raised.
    """
    session = SqliteSqlalchemy().session
    try:
        values = {}
        if sql:
            values[Conversation.sql] = sql
        if meta:
            values[Conversation.meta] = meta
        if values:
            # single UPDATE instead of the original two separate round-trips
            session.query(Conversation).filter(
                Conversation.cvs_id == cvs_id,
                Conversation.id == id).update(values)
            session.commit()
    except Exception:
        # the original caught `as e` but never logged it — persistence
        # failures were completely invisible
        logger.exception(f"update_conversation failed for cvs_id {cvs_id} id {id}")
        session.rollback()
    finally:
        session.close()
def get_latest_question(cvs_id, user_id, limit_count):
    """Return up to *limit_count* most recent question strings of a session.

    Ordered newest first.  NOTE: returns None implicitly when the query
    raises — callers must cope with that.
    """
    session = SqliteSqlalchemy().session
    try:
        rows = (session.query(Conversation)
                .filter_by(cvs_id=cvs_id, user_id=user_id)
                .order_by(Conversation.create_time.desc())
                .limit(limit_count)
                .all())
        return [row.question for row in rows]
    except Exception as e:
        logger.error(f"get_latest_question error {e}")
    finally:
        session.close()
def get_sql_by_id(id: str):
    """Look up the SQL stored for conversation row *id*.

    Returns the SQL string, or None when the row is missing or the query
    raises (errors are logged, not propagated).
    """
    session = SqliteSqlalchemy().session
    try:
        row = session.query(Conversation).filter_by(id=id).first()
        return row.sql if row else None
    except Exception as e:
        logger.error(f"get_sql_by_id error {e}")
    finally:
        session.close()

View File

@@ -21,6 +21,10 @@ import logging
from util import train_ddl
logger = logging.getLogger(__name__)
import traceback
from service.conversation_service import get_latest_question,update_conversation
limit_count = 3
class OpenAICompatibleLLM(VannaBase):
def __init__(self, client=None, config_file=None):
VannaBase.__init__(self, config=config_file)
@@ -124,24 +128,25 @@ class OpenAICompatibleLLM(VannaBase):
return {"role": "assistant", "content": message}
def submit_prompt(self, prompt, **kwargs) -> str:
logger.info(f"submit prompt: {prompt}")
if prompt is None:
print("test1")
logger.info("test1")
raise Exception("Prompt is None")
if len(prompt) == 0:
print("test2")
logger.info("test2")
raise Exception("Prompt is empty")
print(prompt)
num_tokens = 0
for message in prompt:
num_tokens += len(message["content"]) / 4
print("test3 {0}".format(num_tokens))
logger.info("test3 {0}".format(num_tokens))
if kwargs.get("model", None) is not None:
print("test4")
logger.info("test4")
model = kwargs.get("model", None)
print(
logger.info(
f"Using model {model} for {num_tokens} tokens (approx)"
)
response = self.client.chat.completions.create(
@@ -152,9 +157,9 @@ class OpenAICompatibleLLM(VannaBase):
temperature=self.temperature,
)
elif kwargs.get("engine", None) is not None:
print("test5")
logger.info("test5")
engine = kwargs.get("engine", None)
print(
logger.info(
f"Using model {engine} for {num_tokens} tokens (approx)"
)
response = self.client.chat.completions.create(
@@ -165,8 +170,8 @@ class OpenAICompatibleLLM(VannaBase):
temperature=self.temperature,
)
elif self.config is not None and "engine" in self.config:
print("test6")
print(
logger.info("test6")
logger.info(
f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)"
)
response = self.client.chat.completions.create(
@@ -177,11 +182,11 @@ class OpenAICompatibleLLM(VannaBase):
temperature=self.temperature,
)
elif self.config is not None and "model" in self.config:
print("test7")
print(
logger.info("test7")
logger.info(
f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
)
print("config is ",self.config)
logger.info("config is ",self.config)
response = self.client.chat.completions.create(
model=self.config["model"],
messages=prompt,
@@ -201,13 +206,13 @@ class OpenAICompatibleLLM(VannaBase):
# json=data
# )
else:
print("test8")
logger.info("test8")
if num_tokens > 3500:
model = "kimi"
else:
model = "doubao"
print(f"5.Using model {model} for {num_tokens} tokens (approx)")
logger.info(f"5.Using model {model} for {num_tokens} tokens (approx)")
response = self.client.chat.completions.create(
model=model,
@@ -222,21 +227,30 @@ class OpenAICompatibleLLM(VannaBase):
return response.choices[0].message.content
def generate_sql_2(self, question: str, cache=None,user_id=None, allow_llm_to_see_data=False, **kwargs) -> dict:
def generate_sql_2(self, user_id: str, cvs_id: str, question: str,id: str, need_context: bool, allow_llm_to_see_data=False, **kwargs) -> dict:
try:
logger.info("Start to generate_sql_2 in cus_vanna_srevice")
question_sql_list = self.get_similar_question_sql(question, **kwargs)
if question_sql_list and len(question_sql_list)>2:
question_sql_list=question_sql_list[:2]
ddl_list = self.get_related_ddl(question, **kwargs)
#doc_list = self.get_related_documentation(question, **kwargs)
template = get_base_template()
sql_temp = template['template']['sql']
char_temp = template['template']['chart']
history = None
if user_id and cache:
history = cache.get(id=user_id, field="data")
if need_context:
questions = get_latest_question(cvs_id, user_id,limit_count)
logger.info(f"latest_questions is {questions}")
if questions[0] != question:
raise Exception(f"上下文不匹配 {question} {questions[0]}")
new_question = self.generate_rewritten_question(questions,**kwargs)
logger.info(f"new_question is {new_question}")
question = new_question if new_question else question
update_conversation(cvs_id, id, meta=question)
# if user_id and cache:
# history = cache.get(id=user_id, field="data")
# --------基于提示词生成sql以及图表类型
sys_temp = sql_temp['system'].format(engine=config("DB_ENGINE", default='mysql'), lang='中文',
schema=ddl_list, documentation=[train_ddl.train_document],
@@ -275,6 +289,7 @@ class OpenAICompatibleLLM(VannaBase):
logger.info("Finish to generate_sql_2 in cus_vanna_srevice")
return result
except Exception as e:
logger.info("cus_vanna_srevice failed-------------------: ")
traceback.print_exc()
raise e
@@ -286,19 +301,166 @@ class OpenAICompatibleLLM(VannaBase):
logger.error("run_sql failed {0}".format(sql))
raise e
def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
logger.info(f"generate_rewritten_question---------------{new_question}")
if last_question is None:
def generate_rewritten_question(self, questions: str, **kwargs) -> str:
logger.info(f"generate_rewritten_question---------------{questions}")
new_question = questions[0]
context_question = questions[1:]
if not context_question:
return new_question
print("last question {0}".format(last_question))
print("last question {0}".format(context_question))
print("new question {0}".format(new_question))
# sys_info = '''
# 你是一个问题补全助手先判断问题1是否存在信息不完整的情况如果不完整则根据上下文问题2问题3来补全问题1
# 按时间顺序从新到旧问题1、问题2、问题3问题1是用户当前提出的问题
#
# 【准则一】独立性优先
# 如果问题1本身含义完整不依赖其他问题的上下文也能被理解则直接返回问题1,禁止强行合并。
# 【准则二】最新问题优先
# 问题1始终作为核心只判断它是否需要利用前序问题补充自身信息当它含义完整时不再考虑合并。
# 合并时只能用较旧的问题问题2、问题3的信息来补全较新的问题问题1不能反向操作。
# 要以问题1的中心思想为准禁止合并后该表问题1的中心思想
# 【准则三】单向合并限制
# 只有问题1能与其他问题合并问题2和问题3之间不能单独合并。
# 【准则四】顺序依赖判断
# 只有当问题1明确依赖问题2或问题3的结果或上下文时才考虑合并。依赖特征包括
# - 问题1中包含"其中"、"这些"、"这个"、"那些"、"他"、"他们"等指代性词语
# - 问题1是在问题2或问题3基础上的细节追问
# - 问题1具有明确的顺序逻辑关系
# - 问题1缺少必要的主语或时间范围或具体查询信息等上下文信息
# 【准则五】合并范围选择
# - 如果问题1只依赖问题2则合并问题2+问题1
# - 如果问题1只依赖问题3则合并问题3+问题1
# - 如果问题1依赖问题2可得出结果依赖问题3也可得出结果则就近原则合并问题2+问题1
# - 如果问题1同时依赖问题2和问题3才能得出结果则合并问题3+问题2+问题1
# - 如果问题1不依赖任何其他问题则直接返回问题1
#
# 【准则六】合并执行原则
# 将选择的问题自然衔接成一个完整问题,不要添加任何解释性文字。
#
# 【准则七】SQL可行性验证
# 合并后的问题应该能够通过一条SQL查询语句来回答。
#
# 【准则八】兜底措施
# 当你无法判断问题一是否完整,也无法判断问题一是否依赖其他问题才能补全信息时,请直接向用户询问细节
#
#
# 示例:
# 输入问题1="早退多少天"问题2="其中迟到多少天"问题3="张三九月工作了多少天"
# 输出:"张三九月早退多少天"
#
# 输入问题1="这些天里张三是否有迟到问题2="李四考勤""问题3="张三九月的考勤"
# 输出:"张三九月的迟到情况"
#
# 输入问题1="最近一个月是否有迟到"问题2="李四考勤"问题3="张三九月的考勤"
# 输出:"张三最近一个月是否有迟到"
#
# 输入问题1="迟到了多少天"问题2="他哪几天迟到了"问题3="张三九月在林芝是否早退"
# 输出:"张三九月在林芝迟到了多少天"
#
# 输入问题1="张三九月休息了多少天" 问题2="张三九月迟到了多少天"问题3="张三其中迟到多少天"
# 输出:"张三九月休息了多少天"
#
# 输入问题1="张三九月考勤情况" 问题2="张三九月迟到了多少天"问题3="李四迟到多少天"
# 输出:"张三九月考勤情况"
# '''
# sys_info2 = '''
# 你是一个问题补全助手任务是判断用户当前提出的问题问题1是否信息完整。
#
# 若问题1语义完整且可独立理解则直接返回原问题1
#
# 若问题1信息缺失或存在指代依赖则根据其前序上下文问题2 和 问题3按时间倒序排列问题1 最新问题2 次新问题3 最早)进行最小必要补全,生成一个语义完整、忠实于原意、且能通过一条 SQL 查询回答的问题。
#
# 请严格遵循以下准则:
# <>
# <rule_title>独立性优先</rule_title>
# <rule>若问题1 本身语义完整即不依赖问题2 或问题3 也能被准确理解则直接返回问题1禁止强行合并或改写。也不再受下面的规则约束</rule>
#
# <rule_title>以最新问题为核心</rule_title>
#
# <rule>问题1 始终是查询意图的唯一来源</rule>
# <rule>补全时只能借用问题2 或问题3 中的信息如主语、时间、地点等补全问题1的缺失不得改变问题1 的核心要素。</rule>
# <rule>合并后的问题必须完全保留问题1 的原始意图、时间范围、查询对象和动作。</rule>
# <rule>如果问题1已有明确的时间、地点、人物等信息禁止用前序问题的不同信息进行覆盖替换。</rule>
#
# <rule_title>单向合并限制</rule_title>
# <rule>仅允许将问题1 与问题2 或问题3 合并。</rule>
# <rule>禁止问题2 与问题3 直接合并也禁止忽略问题1 进行其他组合。</rule>
#
# <rule_title>依赖判断标准</rule_title>
# <rule>仅当问题1 明确依赖前序问题的上下文时,才触发合并。</rule>
# <rule>包含指代词:如“这个”“这些”“其中”“他”“他们”等;</rule>
# <rule>是对前一个问题的细节追问(如追问数量、时间、条件等);</rule>
# <rule>存在顺序逻辑(如“然后呢?”“接下来怎么样?”);<rule>
# <rule>缺失关键要素:如主语、时间范围、地点、对象等,需从前序问题中补全。</rule>
#
# <rule_title>合并范围选择规则</rule_title>
# 根据依赖关系,按以下优先级确定合并方式:
# <rule>仅依赖问题2 → 合并为问题1 + 问题2</rule>
# <rule>仅依赖问题3 → 合并为问题1 + 问题3</rule>
# <rule>问题2 和问题3 均可独立支撑问题1 → 采用就近原则合并问题1 + 问题2</rule>
# <rule>必须同时依赖问题2 和问题3 才能完整理解问题1 → 合并为问题1 + 问题2 + 问题3</rule>
# <rule>不依赖任何前序问题 → 直接返回问题1</rule>
# <rule_title>自然衔接,无额外内容</rule_title>
# <rule>合并后的问题必须是一个语法通顺、语义连贯的完整问句,不得添加解释、连接词或说明性文字(如“根据前面的问题”“结合上下文”等)。</rule>
#
# <rule_title>SQL 可执行性</rule_title>
# <rule>合并后的问题必须能通过一条 SQL 查询直接回答。</rule>
# <rule>若合并后的问题模糊、多义、或无法映射到具体数据库字段,则不应合并</rule>
#
# <rule_title>兜底策略</rule_title>
# <rule>若无法明确判断问题1 是否完整,或无法确定其是否依赖前序问题,请不要猜测,而是主动向用户请求澄清或补充细节。</rule>
#
# <example>
# 输入问题1="早退多少天"问题2="其中迟到多少天"问题3="张三九月工作了多少天"
# 输出:"张三九月早退多少天"
#
# 输入问题1="这些天里张三是否有迟到问题2="李四考勤""问题3="张三九月的考勤"
# 正确输出:"张三九月的迟到情况"
# 错误输出:
#
# 输入问题1="最近一个月是否有迟到"问题2="李四考勤"问题3="张三九月的考勤"
# 输出:"张三最近一个月是否有迟到"
#
# 输入问题1="迟到了多少天"问题2="他哪几天迟到了"问题3="张三九月在林芝是否早退"
# 输出:"张三九月在林芝迟到了多少天"
#
# 输入问题1="张三九月休息了多少天" 问题2="张三九月迟到了多少天"问题3="张三其中迟到多少天"
# 输出:"张三九月休息了多少天"
#
# 输入问题1="张三九月考勤情况" 问题2="张三九月迟到了多少天"问题3="李四迟到多少天"
# 输出:"张三九月考勤情况"
#
# 输入问题1="张三9月在林芝上班多少天" 问题2="9月29的考勤"问题3="9月29的考勤"
# 输出:"张三9月在林芝上班多少天"
# </example>
# '''
# sys_info3 = '''
# '''
sys_info = '''
你是一个问题补全助手任务是判断用户当前提出的问题问题1是否信息完整。
处理流程:
先判断问题1是否语义完整
如果问题1 自身含义清晰、包含必要要素如主语、时间范围、具体查询目标等指代明确不依赖任何上下文也能被准确理解则直接返回问题1原文禁止任何形式的改写或合并。
仅当问题1信息缺失时才使用上下文补全
上下文包括前两个历史问题问题2较近、问题3较远
补全时遵循就近优先原则优先使用问题2 的信息仅当问题2 无法提供所需信息且问题3 可补全时才使用问题3。
若需同时依赖问题2 和问题3 才能补全,则按 问题3 + 问题2 + 问题1 的顺序融合。
若问题1 中包含“他”“她”“它”等代词,无需无需区分性别或语义类别,(他不一定代表男,她也不一定代表女),采用就近原则从上下文中找出最近的具有人名或明确实体的主语进行替换。
主语选择必须严格遵循时间顺序仅当问题2 中无有效主语时才考虑问题3。只要问题2 包含明确人名就必须使用问题2 的主语
补全要求:
合并后的问题必须是一个语法通顺、语义完整的自然问句。
不得添加任何解释性、连接性或说明性文字(如“根据前面的问题”“结合上下文”等)。
补全后的问题必须能通过一条 SQL 查询语句直接回答(即具备明确的查询对象、条件和指标)。
兜底策略:
如果你无法确定问题1 是否完整,或无法判断是否依赖上下文,或即使合并上下文仍无法形成完整、可执行的问题,请不要猜测或强行输出,而是直接向用户请求补充细节。
'''
prompt = [
self.system_message(
"你的目标是将一系列相关问题合并成一个单一的问题。"
"合并准则一、如果第二个问题与第一个问题无关且本身是完整独立的,则直接返回第二个问题。"
"合并准则二、如果第二个问题域第一个问题相关,且要基于第一个问题的前提,请合并两个问题为一个问题,只需返回合并后的新问题,不要添加任何额外解释。"
"合并准则三、理论上合并后的问题应该能够通过单个SQL语句来回答"),
self.user_message("First question: " + last_question + "\nSecond question: " + new_question),
self.system_message(sys_info),
# self.user_message("First question: " + last_question + "\nSecond question: " + new_question),
self.user_message("问题1: " + new_question + "\n上下文: " +str(context_question))
]
return self.submit_prompt(prompt=prompt, **kwargs)
@@ -329,17 +491,18 @@ class CustomQdrant_VectorStore(Qdrant_VectorStore):
request_body = {
"model": self.embedding_model_name,
"sentences": [data],
"encoding_format": "float",
"input": [data],
}
# request_body = {
# "model": self.embedding_model_name,
# "encoding_format": "float",
# "input": [data],
# "input": [{"type":"text","text":data}],
# }
request_body.update(kwargs)
response = requests.post(
url=f"{self.embedding_api_base}",
url=f"{self.embedding_api_base}/v1/embeddings",
json=request_body,
headers={"Authorization": f"Bearer {self.embedding_api_key}", 'Content-Type': 'application/json'},
)
@@ -348,10 +511,10 @@ class CustomQdrant_VectorStore(Qdrant_VectorStore):
f"Failed to create the embeddings, detail: {_get_error_string(response)}"
)
result = response.json()
embeddings = result['embeddings']
# embeddings = result['data'][0]['embedding']
return embeddings[0]
# return embeddings
# embeddings = result['embeddings']
embeddings = result['data'][0]['embedding']
# return embeddings[0]
return embeddings
class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM):
def __init__(self, llm_config=None, vector_store_config=None):

View File

@@ -538,4 +538,47 @@ template:
{sub_query}
rewrite_question:
system: |
# 角色
你是一个专业的智能数据分析助手的对话上下文理解模块。
# 任务
你的核心任务是判断用户的当前问题是否依赖于之前的对话历史,并据此生成一个独立的、可供数据分析系统直接执行的问题。
# 核心规则
1. **识别关联性**
* 分析当前问题是否包含指代词(如“它”、“这个”、“那个”)、简称(如“去年同期”、“此产品”)、或省略的主语/宾语。
* 如果存在上述特征,且这些信息能在历史对话中找到明确的对应项,则判定为**关联**。
* 如果当前问题是一个全新的、独立的查询,与历史话题无关,则判定为**不关联**。
2. **生成最终问题**
* **如果判定为【关联】**:你必须将历史对话中的相关上下文信息**融合**到当前问题中,形成一个**完整清晰、无任何指代或歧义的新问题**。
* **如果判定为【不关联】**:你只需**原样输出**当前问题。
# 输出格式要求
- **你的唯一输出必须是一个JSON对象**。
- **严禁**在JSON前后添加任何解释性文字、代码块标记如 ```json或任何其他内容。
- JSON的键固定为 `question`,值为处理后的最终问题字符串。
**最终输出格式示例**
{{"question": "A产品上个月的成本是多少"}}
<examples>
## 示例1
**对话历史:**
Q: B产品上月的销量是 A: 200件
**当前用户问题:**
那利润率呢?
**输出:**
{{"question": "B产品上月的利润率是多少"}}
## 示例2
**对话历史:**
Q: 列出所有华东区的销售员 A: 张三、李四、王五
**当前用户问题:**
帮我预定一张明天去上海的机票
**输出:**
{{"question": "帮我预定一张明天去上海的机票"}}
</examples>
####
# 上下文信息
**对话历史:**
{history}
**当前用户问题:**
{current_question}
Resources:

View File

@@ -701,7 +701,7 @@ question_and_answer = [
"category": "外部单位统计"
},
{
"question": "XX中心员工在林芝工作的天数",
"question": "XX中心员工在林芝工作的天数排行",
"answer": '''
SELECT p."code" AS "工号",
p."name" AS "姓名",
@@ -727,6 +727,105 @@ question_and_answer = [
"tags": ["员工", "部门", "考勤", "工作地", "区域", "工作天数统计"],
"category": "工作地考勤统计分析"
},
{
"question": "XX中心张XX十月在林芝工作了多长时间",
"answer": '''
SELECT
p."code" AS "工号",
p."name" AS "姓名",
SUM(ps."work_time") AS "在林芝工作时间"
FROM "YJOA_APPSERVICE_DB"."t_pr3rl2oj_yj_person_database" p
INNER JOIN "YJOA_APPSERVICE_DB"."t_yj_person_status" ps
ON p."code" = ps."person_id"
WHERE p."dr" = 0
AND ps."dr" = 0
AND p."name" = '张XX'
AND ps."date_value" BETWEEN '2025-10-01' AND '2025-10-31'
AND (p."code", ps."date_value") IN (
SELECT
a."person_id",
TO_CHAR(a."attendance_time", 'yyyy-MM-dd')
FROM "YJOA_APPSERVICE_DB"."t_yj_person_attendance" a
INNER JOIN "YJOA_APPSERVICE_DB"."t_yj_person_ac_area" ac
ON a."access_control_point" = ac."ac_point"
WHERE a."dr" = 0
AND ac."region" = 5
)
AND p."internal_dept" IN (
SELECT "id"
FROM "IUAP_APDOC_BASEDOC"."org_orgs"
START WITH ("name" LIKE '%XX中心%' OR "shortname" LIKE '%XX中心%')
AND "dr" = 0 AND "enable" = 1 AND "code" LIKE '%CYJ%'
CONNECT BY PRIOR "id" = "parentid"
)
GROUP BY p."code", p."name"
ORDER BY "在林芝工作时间" DESC
LIMIT 1000;
''',
"tags": ["员工", "部门", "考勤", "工作地", "区域", "工作时长统计"],
"category": "工作地考勤统计分析"
},
{
"question": "XX中心张三在林芝工作了多少天迟到了多少天",
"answer": '''
SELECT p."name" AS "姓名",
COUNT(DISTINCT TO_CHAR(a."attendance_time", 'yyyy-MM-dd')) AS "在林芝工作天数",
COUNT(DISTINCT CASE WHEN ps."status" IN ('1006','1009','6002','6004') THEN ps."date_value" END) AS "在林芝迟到的天数"
FROM "YJOA_APPSERVICE_DB"."t_pr3rl2oj_yj_person_database" p
LEFT JOIN "YJOA_APPSERVICE_DB"."t_yj_person_attendance" a ON p."code" = a."person_id"
LEFT JOIN "YJOA_APPSERVICE_DB"."t_yj_person_ac_area" ac ON a."access_control_point" = ac."ac_point"
LEFT JOIN "YJOA_APPSERVICE_DB"."t_yj_person_status" ps ON p."code" = ps."person_id"
AND TO_CHAR(a."attendance_time", 'yyyy-MM-dd') = ps."date_value"
WHERE p."dr" = 0
AND a."dr" = 0
AND ps."dr" = 0
AND ac."region" = 5
AND p."name" = '张三'
AND a."attendance_time" LIKE '2025-09%'
AND ps."date_value" LIKE '2025-09%'
AND p."internal_dept" IN (
SELECT "id"
FROM "IUAP_APDOC_BASEDOC"."org_orgs"
START WITH ("name" LIKE '%XX中心%' OR "shortname" LIKE '%XX中心%')
AND "dr" = 0 AND "enable" = 1 AND "code" LIKE '%CYJ%'
CONNECT BY PRIOR "id" = "parentid"
)
GROUP BY p."name"
''',
"tags": ["员工", "部门", "考勤", "工作地", "区域", "工作天数统计", "迟到天数统计"],
"category": "工作地考勤统计分析"
},
{
"question": "张三9月在林芝上班期间有多少天早退了",
"answer": '''
SELECT p."name" AS "姓名",
COUNT(DISTINCT ps."date_value") AS "早退天数"
FROM "YJOA_APPSERVICE_DB"."t_yj_person_status" ps
INNER JOIN "YJOA_APPSERVICE_DB"."t_pr3rl2oj_yj_person_database" p
ON p."code" = ps."person_id"
INNER JOIN "YJOA_APPSERVICE_DB"."t_yj_person_attendance" pa
ON pa."person_id" = ps."person_id"
AND pa."attendance_time" LIKE '2025-09%'
INNER JOIN "YJOA_APPSERVICE_DB"."t_yj_person_ac_position" acp
ON acp."ac_point" = pa."access_control_point"
INNER JOIN "YJOA_APPSERVICE_DB"."t_yj_person_ac_area" aca
ON aca."ac_point" = acp."ac_point"
WHERE p."name" = '张三'
AND ps."date_value" LIKE '2025-09%'
AND ps."dr" = 0
AND p."dr" = 0
AND aca."region" = '5'
AND ps."status" IN ('1006','6001','4006')
GROUP BY p."name"
LIMIT 1000
''',
"tags": ["员工", "部门", "考勤", "工作地", "区域","早退天数统计"],
"category": "工作地考勤统计分析"
},
{
"question": "XX中心员工在成都工作的天数",