Files
sqlbot_agent/main_service.py
2025-11-28 16:40:39 +08:00

508 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import copy
import json
import logging
import uuid
import time
from functools import wraps
import util.utils
from logging_config import LOGGING_CONFIG
from service.cus_vanna_srevice import CustomVanna, QdrantClient, TTLCacheWrapper
from service.question_feedback_service import save_save_question_async, query_predefined_question_list, \
update_user_feedBack, query_feedBack_question_list
from service.conversation_service import (save_conversation, update_conversation,
get_qa_by_id, get_latest_question,get_all_conversations_by_user)
from decouple import config
import flask
from util import load_ddl_doc
from flask import Flask, Response, jsonify, request
from graph_chat.gen_sql_chart_agent import SqlAgentState, sql_chart_agent
from graph_chat.gen_data_report_agent import result_report_agent, DateReportAgentState
import traceback
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Set the 'sqlalchemy.engine' logger to DEBUG.
# NOTE(review): presumably a troubleshooting aid; DEBUG on this logger is very verbose - confirm it is wanted outside development.
logging.getLogger('sqlalchemy.engine').setLevel(logging.DEBUG)
def generate_timestamp_id():
    """Return an ID made of the letter ``Q`` followed by the current Unix time in whole milliseconds."""
    # time.time() is in seconds; scaling by 1000 before truncating gives milliseconds.
    millis = int(time.time() * 1000)
    return "Q" + str(millis)
def connect_database(vn):
    """Connect ``vn`` to the data source selected by the ``DATA_SOURCE_TYPE`` setting.

    Supported values are ``sqlite`` (the default), ``mysql`` and ``dameng``.
    Connection parameters are read from the matching ``*_DATABASE_*`` settings.
    Any other value leaves ``vn`` unconnected and logs a warning.

    Args:
        vn: Object exposing ``connect_to_sqlite``, ``connect_to_mysql`` and
            ``connect_to_dameng``.
    """
    db_type = config('DATA_SOURCE_TYPE', default='sqlite')
    if db_type == 'sqlite':
        vn.connect_to_sqlite(config('SQLITE_DATABASE_URL', default=''))
    elif db_type == 'mysql':
        vn.connect_to_mysql(
            host=config('MYSQL_DATABASE_HOST', default=''),
            port=config('MYSQL_DATABASE_PORT', default=3306, cast=int),
            user=config('MYSQL_DATABASE_USER', default=''),
            password=config('MYSQL_DATABASE_PASSWORD', default=''),
            database=config('MYSQL_DATABASE_DBNAME', default=''),
        )
    elif db_type == 'dameng':
        # Original note: this branch is still to be completed.
        # NOTE(review): the 3306 default is kept from the original; confirm it is
        # the intended default port for a Dameng server.
        vn.connect_to_dameng(
            host=config('DAMENG_DATABASE_HOST', default=''),
            # Cast to int like the MySQL branch: values read from the
            # environment are strings unless a cast is given.
            port=config('DAMENG_DATABASE_PORT', default=3306, cast=int),
            user=config('DAMENG_DATABASE_USER', default=''),
            password=config('DAMENG_DATABASE_PASSWORD', default=''),
        )
    else:
        # Previously a silent no-op; surface the misconfiguration in the log.
        logger.warning(
            "Unsupported DATA_SOURCE_TYPE %r; no database connection was made", db_type
        )
def load_train_data_ddl(vn: CustomVanna):
    """Call ``vn.train()`` with no arguments."""
    vn.train()
def create_vana():
    """Build the ``CustomVanna`` instance from configuration.

    The vector store is an in-memory Qdrant client when ``QDRANT_TYPE`` is
    ``memory`` (the default); otherwise a client pointed at
    ``QDRANT_DB_HOST`` / ``QDRANT_DB_PORT``. LLM settings come from the
    ``CHAT_MODEL_*`` settings.

    Returns:
        The newly created ``CustomVanna`` object.
    """
    logger.info("----------------create vana ---------")
    if config('QDRANT_TYPE', default='memory') == 'memory':
        q_client = QdrantClient(":memory:")
    else:
        q_client = QdrantClient(
            url=config('QDRANT_DB_HOST', default=''),
            # Cast explicitly: a value supplied through the environment is a string.
            port=config('QDRANT_DB_PORT', default=6333, cast=int),
        )
    vn = CustomVanna(
        vector_store_config={"client": q_client},
        llm_config={
            "api_key": config('CHAT_MODEL_API_KEY', default=''),
            "api_base": config('CHAT_MODEL_BASE_URL', default=''),
            "model": config('CHAT_MODEL_NAME', default=''),
            'temperature': config('CHAT_MODEL_TEMPERATURE', default=0.7, cast=float),
            # Cast like temperature above so an env-provided limit is numeric.
            'max_tokens': config('CHAT_MODEL_MAX_TOKEN', default=5000, cast=int),
        },
    )
    return vn
def init_vn(vn):
    """Connect ``vn`` to its data source and, when ``IS_FIRST_LOAD`` is set, load DDL, documentation and training data.

    Returns the same ``vn`` object that was passed in.
    """
    logger.info("--------------init vana-----connect to datasouce db----")
    connect_database(vn)
    # NOTE(review): the nesting below was reconstructed from a copy of this file
    # that had lost its indentation. It is assumed that load_train_data_ddl()
    # belongs inside the IS_FIRST_LOAD branch with the two load_ddl_doc calls -
    # TODO confirm against the repository.
    if config('IS_FIRST_LOAD', default=False, cast=bool):
        load_ddl_doc.add_ddl(vn)
        load_ddl_doc.add_documentation(vn)
        load_train_data_ddl(vn)
    return vn
# Imported here, after the helper definitions, rather than with the imports at the top of the file.
from vanna.flask import VannaFlaskApp
# Build the CustomVanna instance that the route handlers below use.
vn = create_vana()
# Wrap it in the Vanna Flask application with chart=False.
app = VannaFlaskApp(vn, chart=False)
# Replace the app's cache with a TTLCacheWrapper around it; the ttl comes from
# TTL_CACHE and defaults to 60 * 60 (presumably seconds - verify in TTLCacheWrapper).
app.cache = TTLCacheWrapper(app.cache, ttl=config('TTL_CACHE', cast=int, default=60 * 60))
# Connect to the data source and run the optional first-load steps (see init_vn).
init_vn(vn)
# Module-level alias for the wrapped cache.
cache = app.cache
@app.flask_app.route("/yj_sqlbot/api/v0/generate_sql_2", methods=["GET"])
def generate_sql_2():
    """
    Generate SQL from a question
    ---
    parameters:
      - name: user_id
        in: query
      - name: cvs_id
        in: query
      - name: question
        in: query
      - name: need_context
        in: query
    responses:
      200:
        schema:
          type: object
          properties:
            type:
              type: string
              default: sql
            id:
              type: string
            text:
              type: string
    """
    # Flow: validate the query parameters, record the conversation under a new
    # timestamp-based id, ask vn.generate_sql_2() for the SQL, then store the SQL
    # on the conversation and in the question-feedback store.
    logger.info("Start to generate sql in main")
    question = request.args.get("question")
    if question is None:
        return jsonify({"type": "error", "error": "No question provided"})
    user_id = request.args.get("user_id")
    cvs_id = request.args.get("cvs_id")
    # bool("false") is True, so the flag is parsed explicitly: an absent or empty
    # value and the usual "off" spellings are False; any other value remains True.
    raw_need_context = (request.args.get("need_context") or "").strip().lower()
    need_context = raw_need_context not in ("", "0", "false", "no", "off", "none", "null")
    if user_id is None or cvs_id is None:
        return jsonify({"type": "error", "error": "No user_id or cvs_id provided"})
    # Named question_id rather than `id` so the builtin is not shadowed.
    question_id = generate_timestamp_id()
    logger.info(f"question_id: {question_id} user_id: {user_id} cvs_id: {cvs_id} question: {question}")
    save_conversation(question_id, user_id, cvs_id, question)
    try:
        logger.info(f"Generate sql for {question}")
        data = vn.generate_sql_2(user_id, cvs_id, question, question_id, need_context)
        logger.info("Generate sql result is {0}".format(data))
        data['id'] = question_id
        sql = data["resp"]["sql"]
        # Formatted rather than concatenated so a non-string sql cannot raise here.
        logger.info(f"generate sql is : {sql}")
        update_conversation(question_id, sql)
        save_save_question_async(question_id, user_id, question, sql)
        data["type"] = "success"
        return jsonify(data)
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"generate sql failed:{e}")
        return jsonify({"type": "error", "error": str(e)})
# def requires_cache_2(required_keys):
# def decorator(f):
# @wraps(f)
# def decorated(*args, **kwargs):
# id = request.args.get("id")
# user_id = request.args.get("user_id")
# if user_id is None:
# user_id = request.json.get("user_id")
# if user_id is None:
# return jsonify({"type": "error", "error": "No user_id provided"})
# if id is None:
# id = request.json.get("id")
# if id is None:
# return jsonify({"type": "error", "error": "No id provided"})
# all_v = cache.items()
# logger.info(f"all values {all_v}")
# logger.info(f"user {user_id} id {id}")
# qa_list = cache.get(id=user_id, field="qa_list")
# if qa_list is None:
# return jsonify({"type": "error", "error": f"No qa_list found"})
# logger.info(f"qa_list {qa_list}")
# q_a = list(filter(lambda x: x["id"] == id, qa_list))
# logger.info(f"q_a {q_a}")
# for key in required_keys:
# if q_a[0][key] is None:
# return jsonify({"type": "error", "error": f"No {key} found for id:{id}"})
# values = {key:q_a[0][key] for key in required_keys}
# values["id"] = id
# logger.info("cache values {0}".format(values))
#
# return f(*args, **values, **kwargs)
#
# return decorated
#
# return decorator
# def session_save(func):
# @wraps(func)
# def wrapper(*args, **kwargs):
# id = request.args.get("id")
# user_id = request.args.get("user_id")
# logger.info(f" id: {id},user_id: {user_id}")
# result = func(*args, **kwargs)
#
# datas = []
# session_len = int(config("SESSION_LENGTH", default=2))
# if cache.exists(id=user_id, field="qa_list"):
# datas = copy.deepcopy(cache.get(id=user_id, field="qa_list"))
# logger.info("datas is {0}".format(datas))
# if len(datas) > session_len and session_len > 0:
# logger.info(f"开始裁剪-------------------------------------")
# datas=datas[-session_len:]
# # 删除id对应的所有缓存值,因为已经run_sql完毕改用user_id保存为上下文
# # cache.delete(id=id, field="question")
# print("datas---------------------{0}".format(datas))
# cache.set(id=user_id, field="qa_list", value=copy.deepcopy(datas))
# logger.info(f" user data {cache.get(user_id, field='qa_list')}")
# return result
#
# return wrapper
@app.flask_app.route("/yj_sqlbot/api/v0/run_sql_2", methods=["GET"])
# @session_save
# @requires_cache_2(required_keys=["sql"])
def run_sql_2():
    """
    Run SQL
    ---
    parameters:
      - name: user_id
        in: query
        required: true
      - name: id
        in: query
        type: string
        required: true
    responses:
      200:
        schema:
          type: object
          properties:
            type:
              type: string
              default: df
            id:
              type: string
            df:
              type: object
            should_generate_chart:
              type: boolean
    """
    # Flow: look up the stored conversation by the `id` query parameter, run its
    # SQL through vn.run_sql_2(), and return the rows as a list of records.
    logger.info("Start to run sql in main")
    try:
        # Named qa_id rather than `id` so the builtin is not shadowed.
        qa_id = request.args.get("id")
        if qa_id is None:
            return jsonify({"type": "error", "error": "No id provided"})
        qa = get_qa_by_id(qa_id)
        # Without this guard a missing record surfaced as
        # "'NoneType' object is not subscriptable".
        if not qa:
            return jsonify({"type": "error", "error": f"No conversation found for id: {qa_id}"})
        sql = qa["sql"]
        logger.info(f"sql is {sql}")
        if not sql:
            return jsonify({"type": "error", "error": f"No sql found for id: {qa_id}"})
        if not vn.run_sql_is_set:
            return jsonify(
                {
                    "type": "error",
                    "error": "Please connect to a database using vn.connect_to_... in order to run SQL queries.",
                }
            )
        df = vn.run_sql_2(sql=sql)
        result = df.to_dict(orient='records')
        logger.info("df ---------------{0} {1}".format(result, type(result)))
        return jsonify(
            {
                "type": "success",
                "id": qa_id,
                "df": result,
            }
        )
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"run sql failed:{e}")
        return jsonify({"type": "sql_error", "error": str(e)})
@app.flask_app.route("/yj_sqlbot/api/v0/verify", methods=["GET"])
def verify_user():
    """Report whether the ``user_id`` query parameter is listed in ``ALLOWED_USERS``.

    ``ALLOWED_USERS`` is a comma-separated setting. Surrounding whitespace is
    ignored and empty entries are dropped, so an unset or empty setting
    authorises nobody.

    Returns:
        JSON ``{"type": "success", "verify": <bool>}``, or
        ``{"type": "error", "error": <message>}`` if anything raises.
    """
    try:
        user_id = request.args.get("user_id")
        raw_users = config('ALLOWED_USERS', default='')
        # ''.split(',') yields [''], which would match an empty user_id; filter
        # blanks so that only real names can verify.
        users = [name.strip() for name in raw_users.split(',') if name.strip()]
        logger.info(f"allowed users {users}")
        # A membership test checks every entry, not just the first one.
        return jsonify({"type": "success", "verify": user_id in users})
    except Exception as e:
        logger.exception(f"verify user failed:{e}")
        return jsonify({"type": "error", "error": str(e)})
@app.flask_app.route("/yj_sqlbot/api/v0/query_present_question", methods=["GET"])
def query_present_question():
    """Return the result of ``query_predefined_question_list()`` in a JSON success envelope.

    Any exception is logged and reported as ``{"type": "error", "error": ...}``.
    """
    try:
        payload = {"type": "success", "data": query_predefined_question_list()}
        return jsonify(payload)
    except Exception as e:
        logger.error(f"查询预制问题失败 failed:{e}")
        failure = {"type": "error", "error": f'查询预制问题失败:{str(e)}'}
        return jsonify(failure)
@app.flask_app.route("/yj_sqlbot/api/v0/query_feedback_question", methods=["POST"])
def query_feedback_question():
    """Look up feedback records for the ids given in the JSON body.

    Reads ``id_list`` (default: empty list) from the request JSON, passes it to
    ``query_feedBack_question_list`` and returns the result in a JSON success
    envelope. Exceptions raised by the lookup are logged and reported as
    ``{"type": "error", "error": ...}``.
    """
    body = request.json
    id_list = body.get("id_list", [])
    try:
        payload = {"type": "success", "data": query_feedBack_question_list(id_list)}
        return jsonify(payload)
    except Exception as e:
        logger.error(f"查询用户反馈问题失败 failed:{e}")
        failure = {"type": "error", "error": f'查询用户反馈问题失败:{str(e)}'}
        return jsonify(failure)
@app.flask_app.route("/yj_sqlbot/api/v0/question_feed_back", methods=["PUT"])
def update_question_feed_back():
    """Store the user's feedback for one question.

    Expects a JSON body with ``id`` and ``user_feedback``. A missing, empty or
    non-object body is treated like missing fields and answered with the usual
    JSON error envelope.

    Returns:
        JSON ``{"type": "success"}`` on success, otherwise
        ``{"type": "error", "error": <message>}``.
    """
    # silent=True gives None instead of an HTTP error when the body is not JSON.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    # Named question_id rather than `id` so the builtin is not shadowed.
    question_id = payload.get("id")
    user_feedback = payload.get("user_feedback")
    if not question_id or not user_feedback:
        return jsonify({"type": "error", "error": "id 或者用户反馈为空"})
    try:
        update_user_feedBack(question_id, '', user_feedback)
        return jsonify({"type": "success"})
    except Exception as e:
        # The previous messages were copied from the predefined-question endpoint
        # and misreported what had failed.
        logger.exception(f"更新用户反馈失败 failed:{e}")
        return jsonify({"type": "error", "error": f'更新用户反馈失败:{str(e)}'})
@app.flask_app.route("/yj_sqlbot/api/v0/gen_graph_question", methods=["GET"])
def gen_graph_question():
    """Generate SQL and a chart configuration for a question via ``sql_chart_agent``.

    Query parameters read: ``user_id``, ``cvs_id``, ``question``.

    Steps:
      1. Fetch up to two recent questions for the conversation as history.
      2. Save the conversation under a fresh UUID.
      3. Invoke ``sql_chart_agent`` and persist the rewritten question, the SQL
         (or the SQL-generation error) and the chart configuration.

    Returns:
        JSON with ``id``, ``sql``, ``chart``, ``gen_sql_error`` and
        ``gen_chart_error``; or ``{"type": "error", "error": <message>}`` when
        the question is missing or anything raises.
    """
    try:
        user_id = request.args.get("user_id")
        logger.info(f"Into gen_graph_question => user {user_id}")
        cvs_id = request.args.get("cvs_id")
        question = request.args.get("question")
        if question is None:
            return jsonify({"type": "error", "error": "No question provided"})
        # Named run_config so the module-level decouple `config` is not shadowed.
        run_config = {"configurable": {"thread_id": str(uuid.uuid4())}}
        logger.info("Start to get question context")
        question_context = get_latest_question(cvs_id, user_id, limit_count=2) or []
        # The first returned question is flagged as the latest one.
        history = [
            {"role": "user", 'content': q, 'order': order, 'is_latest': order == 0}
            for order, q in enumerate(question_context)
        ]
        logger.info(f"question context is {history}")
        initial_state: SqlAgentState = {
            "user_id": user_id,
            "user_question": question,
            "history": history,
            "sql_retry_count": 0,
            "chart_retry_count": 0
        }
        # Named conversation_id rather than `id` so the builtin is not shadowed.
        conversation_id = str(uuid.uuid4())
        logger.info("Start to save conversation info (id,user_id,cvs_id,question)")
        save_conversation(conversation_id, user_id, cvs_id, question)
        logger.info("Enter the graph node=>gen_sql, gen_chart")
        result = sql_chart_agent.invoke(initial_state, config=run_config)
        logger.info("End====>gen_sql, gen_chart")
        new_question = result.get('rewritten_user_question', question)
        if new_question:
            logger.info(f"new_question is {new_question}")
            update_conversation(id=conversation_id, meta={'new_question': new_question})
        gen_sql_result = result.get("gen_sql_result", {})
        logger.info("gen_sql_result => {0}".format(gen_sql_result))
        data = {
            'id': conversation_id,
            'sql': gen_sql_result,
            'chart': result.get("gen_chart_result", {}),
            'gen_sql_error': result.get("gen_sql_error", None),
            'gen_chart_error': result.get("gen_chart_error", None),
        }
        # Read the SQL fields from a dict even if the agent returned something
        # else (e.g. None); that case is then handled as a failed generation.
        sql_info = gen_sql_result if isinstance(gen_sql_result, dict) else {}
        sql = sql_info.get('sql', '')
        if not sql_info.get('success', False):
            logger.info("SQL generation failed.save error info to table")
            error_msg = sql_info.get('message', '')
            update_conversation(id=conversation_id, answer={'type_error': 'error', 'error': error_msg})
        chart_cfg = data.get('chart', {})
        update_conversation(id=conversation_id, sql=sql, chart_cfg=chart_cfg)
        # Fall back to the original question if the rewritten one is empty.
        save_save_question_async(conversation_id, user_id, new_question or question, sql)
        return jsonify(data)
    except Exception as e:
        # The previous messages were copied from the predefined-question endpoint;
        # logger.exception also records the traceback in the log.
        logger.exception(f"gen graph question failed:{e}")
        return jsonify({"type": "error", "error": f'生成SQL和图表失败:{str(e)}'})
@app.flask_app.route("/yj_sqlbot/api/v0/run_sql_3", methods=["GET"])
def run_sql_3():
    """Run the stored SQL of a conversation through ``result_report_agent``.

    Query parameters read: ``id``, ``user_id``, ``cvs_id``.

    If the conversation has no SQL, the error stored in its ``answer`` field is
    returned when it can be read, otherwise a generic message. Otherwise the
    agent is invoked with the SQL, the question and up to two recent questions
    as history, and its result is persisted on the conversation.

    Returns:
        JSON with ``data``, ``summary`` and ``run_sql_error`` on the normal
        path; ``{"type": "error", ...}`` for lookup or setup problems;
        ``{"type": "sql_error", ...}`` if anything raises while running.
    """
    # Named qa_id rather than `id` so the builtin is not shadowed.
    qa_id = request.args.get("id")
    qa = get_qa_by_id(qa_id)
    if not qa:
        return jsonify({"type": "error", "error": f'获取对话失败:{qa_id}'})
    logger.info(f"Start to run_sql_3 => {qa}")
    sql = qa["sql"]
    if not sql:
        error_info = qa["answer"]
        if error_info:
            # The stored answer is expected to be JSON text; if it cannot be read,
            # fall through to the generic message instead of failing the request.
            try:
                stored_error = json.loads(error_info).get("error", "")
            except (TypeError, ValueError, AttributeError):
                logger.exception(f"stored answer for {qa_id} could not be parsed")
            else:
                return jsonify({"type": "error", "error": stored_error})
        return jsonify({"type": "error", "error": f'sql 生成失败,请联系管理员'})
    question = qa["question"]
    logger.info(f"in main sql {sql} question {question}")
    logger.info("Start to run sql in main")
    try:
        user_id = request.args.get("user_id")
        cvs_id = request.args.get("cvs_id")
        if not vn.run_sql_is_set:
            return jsonify(
                {
                    "type": "error",
                    "error": "Please connect to a database using "
                             "vn.connect_to_... in order to run SQL queries.",
                }
            )
        # `or []` mirrors gen_graph_question: no history must not break the loop.
        question_context = get_latest_question(cvs_id, user_id, limit_count=2) or []
        # The first returned question is flagged as the latest one.
        history = [
            {"role": "user", 'content': q, 'order': order, 'is_latest': order == 0}
            for order, q in enumerate(question_context)
        ]
        initial_state: DateReportAgentState = {
            "id": qa_id,
            "user_id": user_id,
            "sql": sql,
            "question": question,
            "retry_count": 0,
            "history": history,
        }
        # Named run_config so the module-level decouple `config` is not shadowed.
        run_config = {"configurable": {"thread_id": str(uuid.uuid4())}}
        rr = result_report_agent.invoke(initial_state, run_config)
        data = rr.get('data', {})
        summary = rr.get('summary', '')
        run_sql_error = rr.get('run_sql_error', '')
        logger.debug(f"data type is {type(data)} data is {data} summary is {summary}")
        if data and not run_sql_error:
            update_conversation(id=qa_id, answer={'data': data, 'summary': summary})
        elif run_sql_error:
            update_conversation(id=qa_id, answer={'type_error': 'error', 'error': "sql执行失败"})
        logger.info(f"run_Sql finish run_sql_error => {run_sql_error}")
        return jsonify(
            {
                'data': data,
                'summary': summary,
                # The agent's own error text is not exposed; the client gets a fixed message.
                'run_sql_error': "sql执行失败" if run_sql_error else None
            }
        )
    except Exception as e:
        # logger.exception keeps the traceback alongside the message.
        logger.exception(f"run sql failed:{e}")
        return jsonify({"type": "sql_error", "error": str(e)})
@app.flask_app.route("/yj_sqlbot/api/v0/get_history", methods=["GET"])
def get_history():
    """Return one page of a user's conversations together with their feedback rating.

    Query parameters read: ``user_id``, ``page_num`` (default 1, values below 1
    become 1) and ``page_size`` (default 10; values outside 1..100 become 10).

    Returns:
        JSON with ``total_pages``, ``total_count`` and ``history`` (one entry
        per conversation: ``id``, ``question``, ``answer``, ``chart_cfg``,
        ``user_praise``); or ``{"type": "get history error", "error": ...}``
        if anything raises.
    """
    logger.info(f"in main get history {request.args}")
    try:
        user_id = request.args.get("user_id")
        page_num = max(int(request.args.get("page_num",1)), 1)
        page_size = int(request.args.get("page_size",10))
        if not 1 <= page_size <= 100:
            page_size = 10
        offset = page_size * (page_num - 1)
        result = get_all_conversations_by_user(user_id, offset, page_size)
        data = result.get('data', {})
        total_count = result.get('total_count', 0)
        # Ceiling division: number of pages needed to hold total_count rows.
        total_pages = -(-total_count // page_size)
        logger.info(f"offset {offset} pagesize {page_size} total_count {total_count} total_pages {total_pages}")
        info = {
            "total_pages": total_pages,
            "total_count": total_count,
        }
        # Fetch the feedback rows for this page in one call, keyed by conversation id.
        feedback_questions = query_feedBack_question_list([row.id for row in data])
        praise_by_id = {fb["id"]: fb["user_praise"] for fb in feedback_questions}
        logger.info(f"map_list {praise_by_id}")
        history = []
        for row in data:
            # Earlier, disabled handling of error answers, kept for reference:
            # answer_info = item.answer
            # info_type = answer_info.get('type','')
            # if info_type and info_type=="error":
            # cvs[info_type] = answer_info.get("error","信息获取失败")
            history.append({
                'id': row.id,
                'question': row.question,
                # answer and chart_cfg are decoded with json.loads when non-empty.
                'answer': json.loads(row.answer) if row.answer else None,
                'chart_cfg': json.loads(row.chart_cfg) if row.chart_cfg else None,
                # Conversations without a feedback row get 0.
                'user_praise': praise_by_id.get(row.id,0),
            })
        info["history"] = history
        return jsonify(info)
    except Exception as e:
        logger.error(f"get history failed:{e}")
        return jsonify({"type": "get history error", "error": str(e)})
# Start the application only when this file is executed as a script.
if __name__ == '__main__':
    # Bind to all network interfaces on port 8084 with debug turned off.
    app.run(host='0.0.0.0', port=8084, debug=False)