Files
sqlbot_agent/main_service.py

427 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import copy
import logging
import uuid
import time
from functools import wraps
import util.utils
from logging_config import LOGGING_CONFIG
from service.cus_vanna_srevice import CustomVanna, QdrantClient, TTLCacheWrapper
from service.question_feedback_service import save_save_question_async, query_predefined_question_list, \
update_user_feedBack, query_feedBack_question_list
from service.conversation_service import save_conversation, update_conversation, get_qa_by_id, get_latest_question
from decouple import config
import flask
from util import load_ddl_doc
from flask import Flask, Response, jsonify, request
from graph_chat.gen_sql_chart_agent import SqlAgentState, sql_chart_agent
from graph_chat.gen_data_report_agent import result_report_agent, DateReportAgentState
import traceback
logger = logging.getLogger(__name__)
def generate_timestamp_id():
    """Build a question ID from the current wall-clock time.

    Returns:
        str: The letter ``Q`` followed by the current Unix time expressed
        in whole milliseconds.
    """
    # time.time() is in seconds; scale to milliseconds before truncating.
    millis = int(time.time() * 1000)
    return "Q" + str(millis)
def connect_database(vn):
    """Connect ``vn`` to the data source selected by ``DATA_SOURCE_TYPE``.

    The setting defaults to ``sqlite``; ``mysql`` and ``dameng`` are the
    other recognised values. Connection parameters are read from the
    matching ``*_DATABASE_*`` settings.

    Args:
        vn: Object exposing ``connect_to_sqlite``, ``connect_to_mysql`` and
            ``connect_to_dameng``.

    Returns:
        None. For an unrecognised type no connection is made and a warning
        is logged.
    """
    db_type = config('DATA_SOURCE_TYPE', default='sqlite')
    if db_type == 'sqlite':
        vn.connect_to_sqlite(config('SQLITE_DATABASE_URL', default=''))
    elif db_type == 'mysql':
        vn.connect_to_mysql(
            host=config('MYSQL_DATABASE_HOST', default=''),
            port=int(config('MYSQL_DATABASE_PORT', default=3306)),
            user=config('MYSQL_DATABASE_USER', default=''),
            password=config('MYSQL_DATABASE_PASSWORD', default=''),
            database=config('MYSQL_DATABASE_DBNAME', default=''),
        )
    elif db_type == 'dameng':
        # TODO: Dameng support was marked as still to be completed; note that
        # no database/schema name is passed here.
        # NOTE(review): 3306 is the MySQL default port and looks copy-pasted;
        # confirm the intended default for Dameng.
        vn.connect_to_dameng(
            host=config('DAMENG_DATABASE_HOST', default=''),
            # Cast like the MySQL branch so the port has the same type whether
            # it comes from the environment (str) or from the default (int).
            port=int(config('DAMENG_DATABASE_PORT', default=3306)),
            user=config('DAMENG_DATABASE_USER', default=''),
            password=config('DAMENG_DATABASE_PASSWORD', default=''),
        )
    else:
        # Still best-effort (no exception), but no longer silent.
        logger.warning(
            "Unsupported DATA_SOURCE_TYPE %r; no database connection was made",
            db_type,
        )
def load_train_data_ddl(vn: CustomVanna):
    """Run training on the given Vanna instance.

    Invokes ``vn.train()`` with no arguments. Whatever ``train`` returns is
    discarded, so this function always returns ``None``.
    """
    vn.train()
    return None
def create_vana():
    """Create the ``CustomVanna`` instance used by this service.

    The vector store is an in-memory Qdrant client when ``QDRANT_TYPE`` is
    ``memory`` (the default); otherwise a client is created from
    ``QDRANT_DB_HOST`` / ``QDRANT_DB_PORT``. The LLM settings come from the
    ``CHAT_MODEL_*`` settings.

    Returns:
        CustomVanna: The newly constructed instance (not yet connected to a
        database).
    """
    logger.info("----------------create vana ---------")
    if config('QDRANT_TYPE', default='memory') == 'memory':
        q_client = QdrantClient(":memory:")
    else:
        q_client = QdrantClient(
            url=config('QDRANT_DB_HOST', default=''),
            # Values read from the environment are strings; cast so the port
            # is an int regardless of where it came from.
            port=config('QDRANT_DB_PORT', default=6333, cast=int),
        )
    vn = CustomVanna(
        vector_store_config={"client": q_client},
        llm_config={
            "api_key": config('CHAT_MODEL_API_KEY', default=''),
            "api_base": config('CHAT_MODEL_BASE_URL', default=''),
            "model": config('CHAT_MODEL_NAME', default=''),
            'temperature': config('CHAT_MODEL_TEMPERATURE', default=0.7, cast=float),
            # Cast for the same reason as the temperature above: a token
            # limit taken from the environment would otherwise be a string.
            'max_tokens': config('CHAT_MODEL_MAX_TOKEN', default=5000, cast=int),
        },
    )
    return vn
def init_vn(vn):
"""Connect ``vn`` to the configured data source and load training material.

Calls ``connect_database(vn)``, consults the boolean ``IS_FIRST_LOAD``
setting (default ``False``), and returns the same ``vn`` object.
"""
# NOTE(review): indentation was lost in this copy of the file, so it cannot be
# determined from here which of the three loading calls below belong to the
# ``if IS_FIRST_LOAD`` body and which run on every start-up. Confirm against
# the original before re-indenting this function.
logger.info("--------------init vana-----connect to datasouce db----")
connect_database(vn)
if config('IS_FIRST_LOAD', default=False, cast=bool):
load_ddl_doc.add_ddl(vn)
load_ddl_doc.add_documentation(vn)
load_train_data_ddl(vn)
return vn
# Module-level bootstrap: everything below runs at import time, in this order.
# VannaFlaskApp is imported here rather than in the top-of-file import block.
from vanna.flask import VannaFlaskApp
# Build the Vanna instance (vector store + LLM settings).
vn = create_vana()
# Wrap it in the Vanna Flask application; chart=False is passed through as-is.
app = VannaFlaskApp(vn, chart=False)
# Replace the app's cache with a TTL wrapper around the original cache.
# TTL_CACHE defaults to 60 * 60 (presumably seconds, i.e. one hour; confirm in
# TTLCacheWrapper).
app.cache = TTLCacheWrapper(app.cache, ttl=config('TTL_CACHE', cast=int, default=60 * 60))
# Connect to the data source and load training material.
init_vn(vn)
# Module-level alias for the wrapped cache.
cache = app.cache
@app.flask_app.route("/yj_sqlbot/api/v0/generate_sql_2", methods=["GET"])
def generate_sql_2():
    """
    Generate SQL from a question

    Saves the question as a conversation entry, asks vn.generate_sql_2 for
    SQL, stores the generated SQL, and returns the generator's payload with
    "id" and "type" added. Failures are reported as a JSON body whose "type"
    is "error".
    ---
    parameters:
      - name: user_id
        in: query
        type: string
        required: true
      - name: cvs_id
        in: query
        type: string
        required: true
      - name: question
        in: query
        type: string
        required: true
      - name: need_context
        in: query
        type: string
        required: false
    responses:
      200:
        schema:
          type: object
          properties:
            type:
              type: string
              default: sql
            id:
              type: string
            text:
              type: string
    """
    logger.info("Start to generate sql in main")
    question = flask.request.args.get("question")
    if question is None:
        return jsonify({"type": "error", "error": "No question provided"})
    user_id = request.args.get("user_id")
    cvs_id = request.args.get("cvs_id")
    # bool("false") and bool("0") are both True, so parse the flag explicitly.
    # A missing or empty value still means "no context".
    raw_need_context = request.args.get("need_context", "")
    need_context = raw_need_context.strip().lower() in ("1", "true", "t", "yes", "y", "on")
    if user_id is None or cvs_id is None:
        return jsonify({"type": "error", "error": "No user_id or cvs_id provided"})
    question_id = generate_timestamp_id()
    logger.info(f"question_id: {question_id} user_id: {user_id} cvs_id: {cvs_id} question: {question}")
    save_conversation(question_id, user_id, cvs_id, question)
    try:
        logger.info(f"Generate sql for {question}")
        data = vn.generate_sql_2(user_id, cvs_id, question, question_id, need_context)
        logger.info("Generate sql result is {0}".format(data))
        data['id'] = question_id
        sql = data["resp"]["sql"]
        logger.info("generate sql is : " + sql)
        # Persist the SQL on the conversation row and in the feedback store.
        update_conversation(question_id, sql)
        save_save_question_async(question_id, user_id, question, sql)
        data["type"] = "success"
        return jsonify(data)
    except Exception as e:
        # logger.exception keeps the message and adds the traceback.
        logger.exception(f"generate sql failed:{e}")
        return jsonify({"type": "error", "error": str(e)})
# def requires_cache_2(required_keys):
# def decorator(f):
# @wraps(f)
# def decorated(*args, **kwargs):
# id = request.args.get("id")
# user_id = request.args.get("user_id")
# if user_id is None:
# user_id = request.json.get("user_id")
# if user_id is None:
# return jsonify({"type": "error", "error": "No user_id provided"})
# if id is None:
# id = request.json.get("id")
# if id is None:
# return jsonify({"type": "error", "error": "No id provided"})
# all_v = cache.items()
# logger.info(f"all values {all_v}")
# logger.info(f"user {user_id} id {id}")
# qa_list = cache.get(id=user_id, field="qa_list")
# if qa_list is None:
# return jsonify({"type": "error", "error": f"No qa_list found"})
# logger.info(f"qa_list {qa_list}")
# q_a = list(filter(lambda x: x["id"] == id, qa_list))
# logger.info(f"q_a {q_a}")
# for key in required_keys:
# if q_a[0][key] is None:
# return jsonify({"type": "error", "error": f"No {key} found for id:{id}"})
# values = {key:q_a[0][key] for key in required_keys}
# values["id"] = id
# logger.info("cache values {0}".format(values))
#
# return f(*args, **values, **kwargs)
#
# return decorated
#
# return decorator
# def session_save(func):
# @wraps(func)
# def wrapper(*args, **kwargs):
# id = request.args.get("id")
# user_id = request.args.get("user_id")
# logger.info(f" id: {id},user_id: {user_id}")
# result = func(*args, **kwargs)
#
# datas = []
# session_len = int(config("SESSION_LENGTH", default=2))
# if cache.exists(id=user_id, field="qa_list"):
# datas = copy.deepcopy(cache.get(id=user_id, field="qa_list"))
# logger.info("datas is {0}".format(datas))
# if len(datas) > session_len and session_len > 0:
# logger.info(f"开始裁剪-------------------------------------")
# datas=datas[-session_len:]
# # 删除id对应的所有缓存值,因为已经run_sql完毕改用user_id保存为上下文
# # cache.delete(id=id, field="question")
# print("datas---------------------{0}".format(datas))
# cache.set(id=user_id, field="qa_list", value=copy.deepcopy(datas))
# logger.info(f" user data {cache.get(user_id, field='qa_list')}")
# return result
#
# return wrapper
@app.flask_app.route("/yj_sqlbot/api/v0/run_sql_2", methods=["GET"])
# @session_save
# @requires_cache_2(required_keys=["sql"])
def run_sql_2():
    """
    Run SQL

    Looks up the stored question/SQL record for the given id, runs its SQL
    through vn.run_sql_2 and returns the rows as a list of records.
    ---
    parameters:
      - name: user_id
        in: query
        required: true
      - name: id
        in: query
        type: string
        required: true
    responses:
      200:
        schema:
          type: object
          properties:
            type:
              type: string
              default: success
            id:
              type: string
            df:
              type: array
              items:
                type: object
    """
    logger.info("Start to run sql in main")
    question_id = request.args.get("id")
    if question_id is None:
        # Report a missing id directly instead of a lookup failure message.
        return jsonify({"type": "error", "error": "No id provided"})
    try:
        qa = get_qa_by_id(question_id)
        if qa is None:
            # Unknown id: say so, rather than failing on qa["sql"] below.
            return jsonify({"type": "error", "error": f"No record found for id:{question_id}"})
        sql = qa["sql"]
        logger.info(f"sql is {sql}")
        if not vn.run_sql_is_set:
            return jsonify(
                {
                    "type": "error",
                    "error": "Please connect to a database using vn.connect_to_... in order to run SQL queries.",
                }
            )
        df = vn.run_sql_2(sql=sql)
        # One dict per row so the result is JSON-serialisable as a list.
        result = df.to_dict(orient='records')
        logger.info("df ---------------{0} {1}".format(result, type(result)))
        return jsonify(
            {
                "type": "success",
                "id": question_id,
                "df": result,
            }
        )
    except Exception as e:
        # logger.exception keeps the message and adds the traceback.
        logger.exception(f"run sql failed:{e}")
        return jsonify({"type": "sql_error", "error": str(e)})
@app.flask_app.route("/yj_sqlbot/api/v0/verify", methods=["GET"])
def verify_user():
    """Report whether the ``user_id`` query parameter is an allowed user.

    ``ALLOWED_USERS`` is a comma-separated list of user IDs. The response is
    ``{"type": "success", "verify": <bool>}``; an unexpected failure yields
    ``{"type": "error", "error": <message>}``.
    """
    try:
        user_id = request.args.get("user_id")
        raw_users = config('ALLOWED_USERS', default='')
        # Strip whitespace and drop empty entries: ''.split(',') is [''], which
        # would otherwise let an empty user_id pass verification.
        users = [name.strip() for name in raw_users.split(',') if name.strip()]
        logger.info(f"allowed users {users}")
        # Membership test checks every entry, not just the first one.
        return jsonify({"type": "success", "verify": user_id in users})
    except Exception as e:
        # logger.exception keeps the message and adds the traceback.
        logger.exception(f"verify user failed:{e}")
        return jsonify({"type": "error", "error": str(e)})
@app.flask_app.route("/yj_sqlbot/api/v0/query_present_question", methods=["GET"])
def query_present_question():
    """Return the predefined question list as a JSON response.

    On success the body is ``{"type": "success", "data": <list>}``; any
    exception is logged and reported as ``{"type": "error", "error": ...}``.
    """
    try:
        questions = query_predefined_question_list()
        reply = {"type": "success", "data": questions}
        return jsonify(reply)
    except Exception as exc:
        logger.error(f"查询预制问题失败 failed:{exc}")
        failure = {"type": "error", "error": f'查询预制问题失败:{str(exc)}'}
        return jsonify(failure)
@app.flask_app.route("/yj_sqlbot/api/v0/query_feedback_question", methods=["POST"])
def query_feedback_question():
    """Look up feedback questions for the IDs supplied in the JSON body.

    Reads ``id_list`` (default: empty list) from the request JSON and passes
    it to ``query_feedBack_question_list``. Success yields
    ``{"type": "success", "data": ...}``; a failure of the lookup is logged
    and reported as ``{"type": "error", "error": ...}``.
    """
    # Read outside the try block: a problem reading the body is not turned
    # into the JSON error reply below.
    requested_ids = request.json.get("id_list", [])
    try:
        matches = query_feedBack_question_list(requested_ids)
        reply = {"type": "success", "data": matches}
        return jsonify(reply)
    except Exception as exc:
        logger.error(f"查询用户反馈问题失败 failed:{exc}")
        failure = {"type": "error", "error": f'查询用户反馈问题失败:{str(exc)}'}
        return jsonify(failure)
@app.flask_app.route("/yj_sqlbot/api/v0/question_feed_back", methods=["PUT"])
def update_question_feed_back():
    """Store a user's feedback for a previously asked question.

    Expects a JSON body with ``id`` and ``user_feedback``; both must be
    non-empty. Returns ``{"type": "success"}`` on success and
    ``{"type": "error", "error": ...}`` for missing input or a failed update.
    """
    # silent=True gives None instead of raising for a missing/non-JSON body,
    # so that case is answered with the same JSON error as empty fields.
    # Anything that is not a JSON object is treated the same way.
    payload = request.get_json(silent=True)
    if not isinstance(payload, dict):
        payload = {}
    question_id = payload.get("id")
    user_feedback = payload.get("user_feedback")
    if not question_id or not user_feedback:
        return jsonify({"type": "error", "error": "id 或者用户反馈为空"})
    try:
        update_user_feedBack(question_id, '', user_feedback)
        return jsonify({"type": "success"})
    except Exception as e:
        # The messages previously said "querying predefined questions failed",
        # copied from another route; they now describe this operation.
        logger.exception(f"更新用户反馈失败 failed:{e}")
        return jsonify({"type": "error", "error": f'更新用户反馈失败:{str(e)}'})
@app.flask_app.route("/yj_sqlbot/api/v0/gen_graph_question", methods=["GET"])
def gen_graph_question():
    """Generate SQL and a chart description for a question via the agent graph.

    Reads ``user_id``, ``cvs_id`` and ``question`` from the query string,
    builds a short history from the latest questions of the conversation,
    invokes ``sql_chart_agent`` and stores the (possibly rewritten) question
    together with the generated SQL under a fresh UUID.

    Returns a JSON body with ``id``, ``sql``, ``chart``, ``gen_sql_error``
    and ``gen_chart_error``; on failure ``{"type": "error", "error": ...}``.
    """
    try:
        user_id = request.args.get("user_id")
        cvs_id = request.args.get("cvs_id")
        question = flask.request.args.get("question")
        if question is None:
            # Same reply as generate_sql_2 gives for a missing question.
            return jsonify({"type": "error", "error": "No question provided"})
        # Named run_config so it does not shadow the module-level decouple
        # ``config`` inside this function.
        run_config = {"configurable": {"thread_id": str(uuid.uuid4())}}
        question_context = get_latest_question(cvs_id, user_id, limit_count=2)
        # The first returned question is flagged as the latest one.
        history = [
            {"role": "user", 'content': q, 'order': i, 'is_latest': i == 0}
            for i, q in enumerate(question_context)
        ]
        logger.info(f"history is {history}")
        initial_state: SqlAgentState = {
            "user_id": user_id,
            "user_question": question,
            "history": history,
            "sql_retry_count": 0,
            "chart_retry_count": 0
        }
        result = sql_chart_agent.invoke(initial_state, config=run_config)
        new_question = result.get('rewritten_user_question', question)
        conversation_id = str(uuid.uuid4())
        save_conversation(conversation_id, user_id, cvs_id, new_question)
        data = {
            'id': conversation_id,
            'sql': result.get("gen_sql_result", {}),
            'chart': result.get("gen_chart_result", {}),
            'gen_sql_error': result.get("gen_sql_error", None),
            'gen_chart_error': result.get("gen_chart_error", None),
        }
        # gen_sql_result may be present but None when generation failed; fall
        # back to an empty SQL string so the gen_sql_error in ``data`` still
        # reaches the caller instead of an AttributeError.
        sql = (data['sql'] or {}).get('sql', '')
        update_conversation(conversation_id, sql)
        save_save_question_async(conversation_id, user_id, new_question, sql)
        return jsonify(data)
    except Exception as e:
        # logger.exception records the traceback through the logging setup.
        # The messages previously said "querying predefined questions failed",
        # copied from another route; they now describe this operation.
        logger.exception(f"生成SQL和图表失败 failed:{e}")
        return jsonify({"type": "error", "error": f'生成SQL和图表失败:{str(e)}'})
@app.flask_app.route("/yj_sqlbot/api/v0/run_sql_3", methods=["GET"])
# @session_save
# @app.requires_cache(["sql",'question'])
def run_sql_3():
    """Run the stored SQL for a question and summarise it via the report agent.

    Reads ``id``, ``user_id`` and ``cvs_id`` from the query string, loads the
    stored question/SQL record for ``id``, builds a short history from the
    latest questions of the conversation and invokes ``result_report_agent``.

    Returns a JSON body with ``data``, ``summary`` and ``run_sql_error``.
    Missing input or a missing database connection yields
    ``{"type": "error", ...}``; any other failure yields
    ``{"type": "sql_error", ...}``.
    """
    logger.info("Start to run sql in main")
    question_id = request.args.get("id")
    if question_id is None:
        return jsonify({"type": "error", "error": "No id provided"})
    try:
        # The lookup now happens inside the try block so a failure is reported
        # as JSON like every other error of this route.
        qa = get_qa_by_id(question_id)
        if qa is None:
            return jsonify({"type": "error", "error": f"No record found for id:{question_id}"})
        sql = qa["sql"]
        question = qa["question"]
        logger.info(f"in main sql {sql} question {question}")
        user_id = request.args.get("user_id")
        cvs_id = request.args.get("cvs_id")
        if not vn.run_sql_is_set:
            return jsonify(
                {
                    "type": "error",
                    "error": "Please connect to a database using vn.connect_to_... in order to run SQL queries.",
                }
            )
        question_context = get_latest_question(cvs_id, user_id, limit_count=2)
        # The first returned question is flagged as the latest one.
        history = [
            {"role": "user", 'content': q, 'order': i, 'is_latest': i == 0}
            for i, q in enumerate(question_context)
        ]
        initial_state: DateReportAgentState = {
            "id": question_id,
            "user_id": user_id,
            "sql": sql,
            "question": question,
            "retry_count": 0,
            "history": history,
        }
        # Named run_config so it does not shadow the module-level decouple
        # ``config`` inside this function.
        run_config = {"configurable": {"thread_id": str(uuid.uuid4())}}
        rr = result_report_agent.invoke(initial_state, run_config)
        logger.info(f"rr.data is {rr.get('data', {})}")
        return jsonify(
            {
                'data': rr.get('data', {}),
                'summary': rr.get('summary', ''),
                'run_sql_error': rr.get('run_sql_error', '')
            }
        )
    except Exception as e:
        # logger.exception keeps the message and adds the traceback.
        logger.exception(f"run sql failed:{e}")
        return jsonify({"type": "sql_error", "error": str(e)})
if __name__ == '__main__':
    # Script entry point: serve on all interfaces, port 8084, debug disabled.
    bind_host, bind_port = '0.0.0.0', 8084
    app.run(host=bind_host, port=bind_port, debug=False)