Files
sqlbot_agent/main_service.py
2025-10-15 14:42:48 +08:00

239 lines
7.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import copy
import logging
from functools import wraps
import util.utils
from logging_config import LOGGING_CONFIG
from service.cus_vanna_srevice import CustomVanna, QdrantClient, TTLCacheWrapper
from decouple import config
import flask
from util import load_ddl_doc
from flask import Flask, Response, jsonify, request, send_from_directory
# Module-level logger shared by every function and route in this file.
# NOTE(review): LOGGING_CONFIG is imported above but never applied in this file;
# confirm the logging_config module configures logging on import, otherwise
# INFO-level messages may be dropped by the default logging setup.
logger = logging.getLogger(name=__name__)
def connect_database(vn):
    """Connect ``vn`` to the data source selected by ``DATA_SOURCE_TYPE``.

    Supported values are ``sqlite`` (the default), ``mysql`` and ``dameng``.
    Connection parameters are read from the matching ``*_DATABASE_*``
    settings. An unrecognised type leaves ``vn`` unconnected and logs a
    warning instead of failing silently.

    Args:
        vn: Vanna instance exposing the ``connect_to_*`` methods used below.
    """
    db_type = config('DATA_SOURCE_TYPE', default='sqlite')
    if db_type == 'sqlite':
        vn.connect_to_sqlite(config('SQLITE_DATABASE_URL', default=''))
    elif db_type == 'mysql':
        vn.connect_to_mysql(
            host=config('MYSQL_DATABASE_HOST', default=''),
            # decouple returns environment values as strings; cast to int.
            port=config('MYSQL_DATABASE_PORT', default=3306, cast=int),
            user=config('MYSQL_DATABASE_USER', default=''),
            password=config('MYSQL_DATABASE_PASSWORD', default=''),
            database=config('MYSQL_DATABASE_DBNAME', default=''),
        )
    elif db_type == 'dameng':
        # TODO: Dameng support is marked as still to be completed.
        vn.connect_to_dameng(
            host=config('DAMENG_DATABASE_HOST', default=''),
            # 5236 is Dameng's default listening port (3306 is MySQL's). Cast
            # like the MySQL branch, because env values arrive as strings.
            port=config('DAMENG_DATABASE_PORT', default=5236, cast=int),
            user=config('DAMENG_DATABASE_USER', default=''),
            password=config('DAMENG_DATABASE_PASSWORD', default=''),
        )
    else:
        # Previously ignored silently; the service then starts without a database.
        logger.warning(
            "Unsupported DATA_SOURCE_TYPE %r; no database connected", db_type
        )
def load_train_data_ddl(vn: CustomVanna):
    """Invoke ``vn.train()`` with no arguments and discard its result.

    Called by ``init_vn`` when ``IS_FIRST_LOAD`` is enabled. What an
    argument-less ``train()`` does depends on CustomVanna (not visible here).

    Returns:
        None, always.
    """
    vn.train()
    return None
def create_vana():
    """Build the CustomVanna instance backed by a Qdrant vector store.

    The vector store is an in-memory Qdrant client when ``QDRANT_TYPE`` is
    ``memory`` (the default). Otherwise it connects to ``QDRANT_DB_HOST`` /
    ``QDRANT_DB_PORT``. LLM settings come from the ``CHAT_MODEL_*`` settings.

    Returns:
        The configured CustomVanna instance. It is not connected to a data
        source yet; ``init_vn`` does that.
    """
    logger.info("----------------create vana ---------")
    if config('QDRANT_TYPE', default='memory') == 'memory':
        q_client = QdrantClient(":memory:")
    else:
        q_client = QdrantClient(
            url=config('QDRANT_DB_HOST', default=''),
            # decouple returns environment values as strings; the port must be an int.
            port=config('QDRANT_DB_PORT', default=6333, cast=int),
        )
    vn = CustomVanna(
        vector_store_config={"client": q_client},
        llm_config={
            "api_key": config('CHAT_MODEL_API_KEY', default=''),
            "api_base": config('CHAT_MODEL_BASE_URL', default=''),
            "model": config('CHAT_MODEL_NAME', default=''),
            'temperature': config('CHAT_MODEL_TEMPERATURE', default=0.7, cast=float),
            # Cast to int: an env-provided value would otherwise be a str. The
            # value is presumably forwarded to the chat API, which expects an integer.
            'max_tokens': config('CHAT_MODEL_MAX_TOKEN', default=20000, cast=int),
        },
    )
    return vn
def init_vn(vn):
    """Prepare ``vn`` for serving requests.

    Connects the configured data source and registers DDL and documentation
    through ``load_ddl_doc``. When ``IS_FIRST_LOAD`` is truthy it also runs
    ``load_train_data_ddl``.

    Returns:
        The same ``vn`` object that was passed in.
    """
    logger.info("--------------init vana-----connect to datasouce db----")
    connect_database(vn)
    load_ddl_doc.add_ddl(vn)
    load_ddl_doc.add_documentation(vn)
    if not config('IS_FIRST_LOAD', default=False, cast=bool):
        return vn
    # First load only: run the extra training step.
    load_train_data_ddl(vn)
    return vn
# --- Application wiring (executed at import time) ---------------------------
from vanna.flask import VannaFlaskApp

vn = create_vana()
app = VannaFlaskApp(vn, chart=False)
# Wrap the app's default cache in TTLCacheWrapper. The TTL comes from
# TTL_CACHE (default 60 * 60, presumably seconds).
app.cache = TTLCacheWrapper(
    app.cache,
    ttl=config('TTL_CACHE', default=60 * 60, cast=int),
)
init_vn(vn)
# Module-level alias read by the route handlers and the session decorator.
cache = app.cache
@app.flask_app.route("/yj_sqlbot/api/v0/generate_sql_2", methods=["GET"])
def generate_sql_2():
    """
    Generate SQL from a question
    ---
    parameters:
      - name: user_id
        in: query
        type: string
      - name: question
        in: query
        type: string
        required: true
    responses:
      200:
        description: On failure the payload is type "error" with an error message.
        schema:
          type: object
          properties:
            type:
              type: string
              default: success
            id:
              type: string
            resp:
              type: object
    """
    logger.info("Start to generate sql in main")
    question = request.args.get("question")
    # A blank question cannot yield meaningful SQL; treat it like a missing one.
    if question is None or not question.strip():
        return jsonify({"type": "error", "error": "No question provided"})
    try:
        query_id = cache.generate_id(question=question)
        user_id = request.args.get("user_id")
        logger.info("Generate sql for %s", question)
        data = vn.generate_sql_2(question=question, cache=cache, user_id=user_id)
        logger.info("Generate sql result is %s", data)
        data['id'] = query_id
        sql = data["resp"]["sql"]
        if not isinstance(sql, str):
            # The original only failed here by accident (str + None concatenation).
            # Fail explicitly so a missing SQL is never cached as a success.
            raise TypeError(f"generated sql is not a string: {sql!r}")
        logger.info("generate sql is : %s", sql)
        # run_sql_2 later resolves this id through the cache.
        cache.set(id=query_id, field="question", value=question)
        cache.set(id=query_id, field="sql", value=sql)
        data["type"] = "success"
        return jsonify(data)
    except Exception as e:
        # API boundary: report the failure to the client, keep the traceback in logs.
        logger.exception("generate sql failed:%s", e)
        return jsonify({"type": "error", "error": str(e)})
def session_save(func):
    """Decorator that records a finished query in the caller's session history.

    After ``func`` runs, it takes the question and SQL cached under the
    request's ``id`` query parameter and appends them to the history list
    stored under ``user_id`` (cache field ``data``). The list is trimmed to
    the most recent ``SESSION_LENGTH`` entries; there is no trimming when
    ``SESSION_LENGTH`` <= 0. The per-query ``question`` entry is then deleted,
    because the context now lives under ``user_id``.

    The history is only updated when the request carries a ``user_id`` and the
    cache still holds both the question and the SQL for ``id``. In every case
    ``func``'s result is returned unchanged.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        query_id = request.args.get("id")
        user_id = request.args.get("user_id")
        logger.info(f" id: {query_id},user_id: {user_id}")
        result = func(*args, **kwargs)

        if not user_id:
            # Without a user id every anonymous caller would share one history
            # stored under the key None.
            logger.warning("no user_id given; session history not saved")
            return result

        question = cache.get(id=query_id, field="question")
        sql = cache.get(id=query_id, field="sql")
        if question is None or sql is None:
            # Unknown or expired id, or a repeated run whose question was already
            # consumed: don't pollute the history with None entries.
            logger.warning("no cached question/sql for id %s; session history not saved", query_id)
            return result

        history = []
        if cache.exists(id=user_id, field="data"):
            # Work on a copy so the cached list is not mutated in place.
            history = copy.deepcopy(cache.get(id=user_id, field="data"))
        history.append({"id": query_id, "question": question, "sql": sql})
        logger.info("datas is {0}".format(history))
        session_len = config("SESSION_LENGTH", default=2, cast=int)
        if session_len > 0 and len(history) > session_len:
            history = history[-session_len:]
        # run_sql is done: drop the per-id question; the context is kept under user_id instead.
        cache.delete(id=query_id, field="question")
        cache.set(id=user_id, field="data", value=copy.deepcopy(history))
        logger.info(f" user data {cache.get(user_id, field='data')}")
        return result
    return wrapper
@app.flask_app.route("/yj_sqlbot/api/v0/run_sql_2", methods=["GET"])
@session_save
@app.requires_cache(["sql"])
def run_sql_2(id: str, sql: str):
    """
    Run SQL
    ---
    parameters:
      - name: user_id
        in: query
        type: string
        required: true
      - name: id
        in: query|body
        type: string
        required: true
      - name: page_size
        in: query
        description: Currently not applied by this endpoint.
      - name: page_num
        in: query
        description: Currently not applied by this endpoint.
    responses:
      200:
        description: On failure type is "error" or "sql_error" with an error message.
        schema:
          type: object
          properties:
            type:
              type: string
              default: success
            id:
              type: string
            df:
              type: object
    """
    # ``id`` and ``sql`` are supplied by app.requires_cache, so the parameter
    # names cannot change even though ``id`` shadows the builtin.
    logger.info("Start to run sql in main")
    try:
        if not vn.run_sql_is_set:
            return jsonify(
                {
                    "type": "error",
                    "error": "Please connect to a database using vn.connect_to_... in order to run SQL queries.",
                }
            )
        df = vn.run_sql(sql=sql)
        result = df.to_dict(orient='records')
        # Log only the row count at INFO. The full result set can be large and
        # may contain business data; it is kept at DEBUG only.
        logger.info("run sql returned %d rows", len(result))
        logger.debug("run sql result: %s", result)
        return jsonify(
            {
                "type": "success",
                "id": id,
                "df": result,
            }
        )
    except Exception as e:
        # API boundary: report to the client, keep the traceback in the logs.
        logger.exception("run sql failed:%s", e)
        return jsonify({"type": "sql_error", "error": str(e)})
@app.flask_app.route("/yj_sqlbot/api/v0/verify", methods=["GET"])
def verify_user():
    """Report whether the ``user_id`` query parameter is an allowed user.

    ``ALLOWED_USERS`` is a comma-separated list. Whitespace around names and
    empty entries are ignored, so a missing or empty ``user_id`` never
    verifies, even when ``ALLOWED_USERS`` is unset or has a trailing comma.

    Returns:
        JSON ``{"type": "success", "verify": bool}``, or
        ``{"type": "error", "error": str}`` if anything raises.
    """
    try:
        user_id = request.args.get("user_id")
        raw_users = config('ALLOWED_USERS', default='')
        # A set membership test replaces the for/else scan, which was easy to
        # mis-bind to the ``if`` and so stop after the first allowed user.
        allowed_users = {user.strip() for user in raw_users.split(',') if user.strip()}
        logger.info(f"allowed users {sorted(allowed_users)}")
        return jsonify({"type": "success", "verify": user_id in allowed_users})
    except Exception as e:
        logger.exception(f"verify user failed:{e}")
        return jsonify({"type": "error", "error": str(e)})
if __name__ == '__main__':
    # Direct-run entry point: listen on all interfaces on port 8084, debug off.
    serve_options = {"host": '0.0.0.0', "port": 8084, "debug": False}
    app.run(**serve_options)