320 lines
11 KiB
Python
320 lines
11 KiB
Python
import copy
|
||
import logging
|
||
import time
|
||
from functools import wraps
|
||
import util.utils
|
||
from logging_config import LOGGING_CONFIG
|
||
from service.cus_vanna_srevice import CustomVanna, QdrantClient, TTLCacheWrapper
|
||
from service.question_feedback_service import save_save_question_async,query_predefined_question_list
|
||
from service.conversation_service import save_conversation,update_conversation,get_sql_by_id
|
||
from decouple import config
|
||
import flask
|
||
from util import load_ddl_doc
|
||
from flask import Flask, Response, jsonify, request
|
||
from graph_chat.gen_sql_chart_agent import SqlAgentState, sql_chart_agent
|
||
import traceback
|
||
logger = logging.getLogger(__name__)
|
||
|
||
def generate_timestamp_id():
    """Build an ID string from the current wall-clock time.

    Returns:
        str: The letter ``Q`` followed by the current Unix time expressed in
        whole milliseconds (``time.time() * 1000`` truncated to an int).
    """
    # time.time() is in seconds; scale to milliseconds before truncating.
    millis = int(time.time() * 1000)
    return "Q" + str(millis)
|
||
|
||
|
||
def connect_database(vn):
    """Connect ``vn`` to the data source selected by ``DATA_SOURCE_TYPE``.

    All connection settings are read through ``config``. The recognised
    values of ``DATA_SOURCE_TYPE`` are ``sqlite`` (the default), ``mysql``
    and ``dameng``. Any other value leaves ``vn`` unconnected and logs a
    warning.

    Args:
        vn: Object exposing ``connect_to_sqlite``, ``connect_to_mysql`` and
            ``connect_to_dameng``.
    """
    db_type = config('DATA_SOURCE_TYPE', default='sqlite')
    if db_type == 'sqlite':
        vn.connect_to_sqlite(config('SQLITE_DATABASE_URL', default=''))
    elif db_type == 'mysql':
        vn.connect_to_mysql(
            host=config('MYSQL_DATABASE_HOST', default=''),
            port=int(config('MYSQL_DATABASE_PORT', default=3306)),
            user=config('MYSQL_DATABASE_USER', default=''),
            password=config('MYSQL_DATABASE_PASSWORD', default=''),
            database=config('MYSQL_DATABASE_DBNAME', default=''),
        )
    elif db_type == 'dameng':
        # The original code marked this branch as "to be completed".
        # NOTE(review): the 3306 default is the same one used for MySQL above;
        # confirm it is the intended default for Dameng.
        vn.connect_to_dameng(
            host=config('DAMENG_DATABASE_HOST', default=''),
            # Values read from the environment arrive as strings; cast to int
            # so the port has the same type as in the MySQL branch.
            port=int(config('DAMENG_DATABASE_PORT', default=3306)),
            user=config('DAMENG_DATABASE_USER', default=''),
            password=config('DAMENG_DATABASE_PASSWORD', default=''),
        )
    else:
        # Make a misconfigured data source visible instead of silently
        # continuing without a connection.
        logger.warning(
            "Unsupported DATA_SOURCE_TYPE %r; no database connection was made",
            db_type,
        )
|
||
|
||
|
||
def load_train_data_ddl(vn: CustomVanna):
    """Trigger training on ``vn`` by calling its ``train()`` method.

    Args:
        vn: The ``CustomVanna`` instance to train.
    """
    vn.train()
|
||
|
||
|
||
def create_vana():
    """Create the ``CustomVanna`` instance used by this service.

    The vector store is a ``QdrantClient``: an in-memory client when
    ``QDRANT_TYPE`` is ``memory`` (the default), otherwise a client built from
    ``QDRANT_DB_HOST`` / ``QDRANT_DB_PORT``. The LLM settings are read from the
    ``CHAT_MODEL_*`` configuration keys.

    Returns:
        CustomVanna: The newly constructed instance.
    """
    logger.info("----------------create vana ---------")

    if config('QDRANT_TYPE', default='memory') == 'memory':
        q_client = QdrantClient(":memory:")
    else:
        q_client = QdrantClient(
            url=config('QDRANT_DB_HOST', default=''),
            # Environment values are strings; cast so the port is always an int.
            port=config('QDRANT_DB_PORT', default=6333, cast=int),
        )

    vn = CustomVanna(
        vector_store_config={"client": q_client},
        llm_config={
            "api_key": config('CHAT_MODEL_API_KEY', default=''),
            "api_base": config('CHAT_MODEL_BASE_URL', default=''),
            "model": config('CHAT_MODEL_NAME', default=''),
            'temperature': config('CHAT_MODEL_TEMPERATURE', default=0.7, cast=float),
            # Cast like temperature above, so a value set in the environment
            # is not passed on as a string.
            'max_tokens': config('CHAT_MODEL_MAX_TOKEN', default=5000, cast=int),
        },
    )

    return vn
|
||
|
||
|
||
def init_vn(vn):
    """Optionally load DDL, documentation and training data into ``vn``.

    The loading steps run only when the ``IS_FIRST_LOAD`` configuration flag
    (parsed as a bool, default ``False``) is set.

    Args:
        vn: The Vanna instance to initialise.

    Returns:
        The same ``vn`` object that was passed in.
    """
    logger.info("--------------init vana-----connect to datasouce db----")
    # The call to connect_database(vn) is currently disabled here.
    first_load = config('IS_FIRST_LOAD', default=False, cast=bool)
    if first_load:
        load_ddl_doc.add_ddl(vn)
        load_ddl_doc.add_documentation(vn)
        # NOTE(review): the training call is assumed to belong to the
        # first-load branch together with the two loaders above - confirm.
        load_train_data_ddl(vn)
    return vn
|
||
|
||
|
||
from vanna.flask import VannaFlaskApp

# Build the Vanna instance and wrap it in the Vanna Flask application with
# chart support switched off.
vn = create_vana()
app = VannaFlaskApp(vn, chart=False)

# Wrap the application's cache in a TTL wrapper. TTL_CACHE is read as an int
# and defaults to 60 * 60 (presumably seconds - confirm against
# TTLCacheWrapper).
app.cache = TTLCacheWrapper(
    app.cache,
    ttl=config('TTL_CACHE', cast=int, default=60 * 60),
)

# Run the optional first-load initialisation, then expose the wrapped cache
# under a module-level name.
init_vn(vn)
cache = app.cache
|
||
|
||
|
||
@app.flask_app.route("/yj_sqlbot/api/v0/generate_sql_2", methods=["GET"])
def generate_sql_2():
    """
    Generate SQL from a question
    ---
    parameters:
      - name: user_id
        in: query
        required: true
      - name: cvs_id
        in: query
        required: true
      - name: question
        in: query
        required: true
      - name: need_context
        in: query
        description: Treated as false when absent, empty, or one of 0/false/no/off/none/null (case-insensitive); any other value is true.
    responses:
      200:
        schema:
          type: object
          properties:
            type:
              type: string
              default: success
            id:
              type: string
            resp:
              type: object
            error:
              type: string
    """
    logger.info("Start to generate sql in main")
    question = request.args.get("question")
    if question is None:
        return jsonify({"type": "error", "error": "No question provided"})

    user_id = request.args.get("user_id")
    cvs_id = request.args.get("cvs_id")
    if user_id is None or cvs_id is None:
        return jsonify({"type": "error", "error": "No user_id or cvs_id provided"})

    # bool() on a query-string value is True for every non-empty string,
    # including "false" and "0". Only explicit falsy spellings disable the
    # context; every other non-empty value still enables it, as before.
    raw_need_context = request.args.get("need_context", "")
    need_context = raw_need_context.strip().lower() not in (
        "", "0", "false", "no", "off", "none", "null",
    )

    question_id = generate_timestamp_id()
    logger.info(
        "question_id: %s user_id: %s cvs_id: %s question: %s",
        question_id, user_id, cvs_id, question,
    )
    try:
        # Kept inside the try block so a persistence failure is reported as a
        # JSON error instead of an unhandled server error.
        save_conversation(question_id, user_id, cvs_id, question)
        logger.info("Generate sql for %s", question)
        data = vn.generate_sql_2(user_id, cvs_id, question, question_id, need_context)
        logger.info("Generate sql result is %s", data)
        data['id'] = question_id
        sql = data["resp"]["sql"]
        # Lazy formatting: string concatenation would raise if sql is not a str.
        logger.info("generate sql is : %s", sql)
        update_conversation(cvs_id, question_id, sql)
        save_save_question_async(question_id, user_id, question, sql)
        data["type"] = "success"
        return jsonify(data)
    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception("generate sql failed:%s", e)
        return jsonify({"type": "error", "error": str(e)})
|
||
|
||
|
||
# def requires_cache_2(required_keys):
|
||
# def decorator(f):
|
||
# @wraps(f)
|
||
# def decorated(*args, **kwargs):
|
||
# id = request.args.get("id")
|
||
# user_id = request.args.get("user_id")
|
||
# if user_id is None:
|
||
# user_id = request.json.get("user_id")
|
||
# if user_id is None:
|
||
# return jsonify({"type": "error", "error": "No user_id provided"})
|
||
# if id is None:
|
||
# id = request.json.get("id")
|
||
# if id is None:
|
||
# return jsonify({"type": "error", "error": "No id provided"})
|
||
# all_v = cache.items()
|
||
# logger.info(f"all values {all_v}")
|
||
# logger.info(f"user {user_id} id {id}")
|
||
# qa_list = cache.get(id=user_id, field="qa_list")
|
||
# if qa_list is None:
|
||
# return jsonify({"type": "error", "error": f"No qa_list found"})
|
||
# logger.info(f"qa_list {qa_list}")
|
||
# q_a = list(filter(lambda x: x["id"] == id, qa_list))
|
||
# logger.info(f"q_a {q_a}")
|
||
# for key in required_keys:
|
||
# if q_a[0][key] is None:
|
||
# return jsonify({"type": "error", "error": f"No {key} found for id:{id}"})
|
||
# values = {key:q_a[0][key] for key in required_keys}
|
||
# values["id"] = id
|
||
# logger.info("cache values {0}".format(values))
|
||
#
|
||
# return f(*args, **values, **kwargs)
|
||
#
|
||
# return decorated
|
||
#
|
||
# return decorator
|
||
|
||
|
||
|
||
|
||
# def session_save(func):
|
||
# @wraps(func)
|
||
# def wrapper(*args, **kwargs):
|
||
# id = request.args.get("id")
|
||
# user_id = request.args.get("user_id")
|
||
# logger.info(f" id: {id},user_id: {user_id}")
|
||
# result = func(*args, **kwargs)
|
||
#
|
||
# datas = []
|
||
# session_len = int(config("SESSION_LENGTH", default=2))
|
||
# if cache.exists(id=user_id, field="qa_list"):
|
||
# datas = copy.deepcopy(cache.get(id=user_id, field="qa_list"))
|
||
# logger.info("datas is {0}".format(datas))
|
||
# if len(datas) > session_len and session_len > 0:
|
||
# logger.info(f"开始裁剪-------------------------------------")
|
||
# datas=datas[-session_len:]
|
||
# # 删除id对应的所有缓存值,因为已经run_sql完毕,改用user_id保存为上下文
|
||
# # cache.delete(id=id, field="question")
|
||
# print("datas---------------------{0}".format(datas))
|
||
# cache.set(id=user_id, field="qa_list", value=copy.deepcopy(datas))
|
||
# logger.info(f" user data {cache.get(user_id, field='qa_list')}")
|
||
# return result
|
||
#
|
||
# return wrapper
|
||
|
||
|
||
|
||
@app.flask_app.route("/yj_sqlbot/api/v0/run_sql_2", methods=["GET"])
# @session_save
# @requires_cache_2(required_keys=["sql"])
def run_sql_2():
    """
    Run SQL
    ---
    parameters:
      - name: user_id
        in: query
        required: false
        description: Accepted for compatibility; not read by this handler.
      - name: id
        in: query
        type: string
        required: true
        description: Identifier used to look up the stored SQL.
    responses:
      200:
        schema:
          type: object
          properties:
            type:
              type: string
              default: success
            id:
              type: string
            df:
              type: array
              description: Query result as a list of row objects.
            error:
              type: string
    """
    logger.info("Start to run sql in main")
    try:
        question_id = request.args.get("id")
        if not question_id:
            # Reject the request up front instead of looking up SQL for None.
            return jsonify({"type": "error", "error": "No id provided"})

        if not vn.run_sql_is_set:
            return jsonify(
                {
                    "type": "error",
                    "error": "Please connect to a database using vn.connect_to_... in order to run SQL queries.",
                }
            )

        sql = get_sql_by_id(question_id)
        logger.info("sql is %s", sql)
        if not sql:
            # Nothing stored for this id: report it rather than handing an
            # empty statement to the database layer.
            return jsonify(
                {"type": "sql_error", "error": f"No sql found for id:{question_id}"}
            )

        df = vn.run_sql_2(sql=sql)
        result = df.to_dict(orient='records')
        # Log only the size; the full result set can be arbitrarily large.
        logger.info("run sql returned %d rows", len(result))
        return jsonify(
            {
                "type": "success",
                "id": question_id,
                "df": result,
            }
        )

    except Exception as e:
        # logger.exception records the traceback along with the message.
        logger.exception("run sql failed:%s", e)
        return jsonify({"type": "sql_error", "error": str(e)})
|
||
|
||
|
||
@app.flask_app.route("/yj_sqlbot/api/v0/verify", methods=["GET"])
def verify_user():
    """Report whether the ``user_id`` query argument is an allowed user.

    ``ALLOWED_USERS`` is read from configuration as a comma-separated list.
    Surrounding whitespace is stripped from each entry and empty entries are
    ignored, so an unset list allows nobody.

    Returns:
        A JSON response ``{"type": "success", "verify": <bool>}``, or
        ``{"type": "error", "error": <message>}`` if an exception occurs.
    """
    try:
        user_id = request.args.get("user_id")
        raw_users = config('ALLOWED_USERS', default='')
        allowed = [name.strip() for name in raw_users.split(',') if name.strip()]
        logger.info("allowed users %s", allowed)
        # Check the whole list: returning from inside a loop on the first
        # mismatch would mean only the first configured user could verify.
        return jsonify({"type": "success", "verify": user_id in allowed})
    except Exception as e:
        logger.exception("verify user failed:%s", e)
        return jsonify({"type": "error", "error": str(e)})
|
||
|
||
|
||
@app.flask_app.route("/yj_sqlbot/api/v0/query_present_question", methods=["GET"])
def query_present_question():
    """Return the predefined question list as JSON.

    Returns:
        ``{"type": "success", "data": <result of
        query_predefined_question_list()>}`` on success, or
        ``{"type": "error", "error": <message>}`` if an exception occurs.
    """
    try:
        payload = {"type": "success", "data": query_predefined_question_list()}
        return jsonify(payload)
    except Exception as e:
        logger.error(f"查询预制问题失败 failed:{e}")
        failure = {"type": "error", "error": f'查询预制问题失败:{str(e)}'}
        return jsonify(failure)
|
||
|
||
|
||
@app.flask_app.route("/yj_sqlbot/api/v0/gen_graph_question", methods=["GET"])
def gen_graph_question():
    """Run the SQL/chart agent for a question and return its results as JSON.

    Query arguments:
        question: The user question (required).
        thread_id: Optional value placed in the agent's
            ``configurable.thread_id``; defaults to ``'1233'``.

    Returns:
        A JSON object with the keys ``sql``, ``chart``, ``gen_sql_error`` and
        ``gen_chart_error`` taken from the agent result, or
        ``{"type": "error", "error": <message>}`` when the question is missing
        or an exception occurs.
    """
    try:
        question = request.args.get("question")
        if question is None:
            return jsonify({"type": "error", "error": "No question provided"})

        # NOTE(review): every caller that omits thread_id shares the default
        # thread '1233' - confirm whether the agent keeps per-thread state.
        thread_id = request.args.get("thread_id") or '1233'
        # Named run_config so the module-level ``config`` function imported
        # from decouple is not shadowed inside this handler.
        run_config = {"configurable": {"thread_id": thread_id}}

        initial_state: SqlAgentState = {
            "user_question": question,
            # NOTE(review): this history is hard-coded sample data sent with
            # every request - confirm whether it should come from the caller.
            "history": [{"role":"user","content":"宋亚澜9月在林芝工作多少天"},{"role":"user","content":"余佳佳9月在林芝工作多少天"}],
            "sql_retry_count": 0,
            "chart_retry_count": 0
        }
        result = sql_chart_agent.invoke(initial_state, config=run_config)
        data = {
            'sql': result.get("gen_sql_result", {}),
            'chart': result.get("gen_chart_result", {}),
            'gen_sql_error': result.get("gen_sql_error", None),
            'gen_chart_error': result.get("gen_chart_error", None),
        }
        return jsonify(data)
    except Exception as e:
        # logger.exception writes the traceback to the log, replacing the
        # separate traceback.print_exc() call.
        logger.exception("gen graph question failed:%s", e)
        return jsonify({"type": "error", "error": f'生成SQL图表失败:{str(e)}'})
|
||
|
||
|
||
if __name__ == "__main__":
    # Start the application on all interfaces, port 8084, with debug off.
    app.run(host="0.0.0.0", port=8084, debug=False)
|