Files
sqlbot_agent/main_service.py
2025-09-23 14:49:00 +08:00

116 lines
3.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from email.policy import default
from service.cus_vanna_srevice import CustomVanna, QdrantClient
from decouple import config
import flask
from flask import Flask, Response, jsonify, request, send_from_directory
def connect_database(vn):
    """Connect *vn* to the database selected by the DATA_SOURCE_TYPE setting.

    Settings are read with ``decouple.config`` (environment variables or a
    ``.env`` file):

    * ``DATA_SOURCE_TYPE`` -- ``sqlite`` (default), ``mysql`` or
      ``postgresql``; matched case-insensitively, ignoring surrounding
      whitespace.
    * sqlite: ``SQLITE_DATABASE_URL`` is handed to ``vn.connect_to_sqlite``.
    * mysql: ``MYSQL_DATABASE_HOST``, ``MYSQL_DATABASE_PORT`` (default 3306),
      ``MYSQL_DATABASE_USER``, ``MYSQL_DATABASE_PASSWORD`` and
      ``MYSQL_DATABASE_DBNAME`` are handed to ``vn.connect_to_mysql``.

    PostgreSQL is not implemented yet, and any other value is unrecognised.
    In both cases a warning is printed and *vn* is left without a database
    connection instead of aborting start-up.
    """
    db_type = config('DATA_SOURCE_TYPE', default='sqlite').strip().lower()
    if db_type == 'sqlite':
        vn.connect_to_sqlite(config('SQLITE_DATABASE_URL', default=''))
    elif db_type == 'mysql':
        vn.connect_to_mysql(
            host=config('MYSQL_DATABASE_HOST', default=''),
            # decouple returns raw strings for values found in the
            # environment; cast so the port is always an int, as
            # connect_to_mysql's signature declares.
            port=config('MYSQL_DATABASE_PORT', default=3306, cast=int),
            user=config('MYSQL_DATABASE_USER', default=''),
            password=config('MYSQL_DATABASE_PASSWORD', default=''),
            # Vanna's connect_to_mysql names this parameter `dbname`. A
            # `database=` keyword would leave dbname unset and be forwarded
            # to the driver as an extra kwarg instead.
            dbname=config('MYSQL_DATABASE_DBNAME', default=''),
        )
    elif db_type == 'postgresql':
        # TODO: PostgreSQL support is still to be added.
        print("WARNING: DATA_SOURCE_TYPE 'postgresql' is not implemented yet; "
              "no database connection was made")
    else:
        print(f"WARNING: unsupported DATA_SOURCE_TYPE {db_type!r}; "
              "no database connection was made")
def load_train_data_ddl(vn: CustomVanna):
    """Seed *vn* with the schema and usage notes for the ``db_user`` table.

    Two items are trained, in this order:

    1. the table's CREATE TABLE statement, passed as ``ddl``;
    2. a short free-text note, passed as ``documentation``. In English it
       says: gender 0 = female, 1 = male; prefer LIKE when querying address
       (e.g. ``address like '%北京%'``); the SQL dialect is SQLite.

    The note is handed to the model verbatim, so it keeps its original
    wording.
    """
    # Both literals stay flush-left on purpose: their exact text, including
    # newlines, is what gets trained.
    table_ddl = """
create table db_user
(
id integer not null
constraint db_user_pk
primary key autoincrement,
user_name TEXT not null,
age integer not null,
address TEXT,
gender integer not null,
email TEXT
)
"""
    usage_notes = '''
gender 字段 0代表女性1代表男性;
查询address时,尽量使用like查询如:select * from db_user where address like '%北京%';
语法为sqlite语法;
'''
    for training_kwargs in ({"ddl": table_ddl}, {"documentation": usage_notes}):
        vn.train(**training_kwargs)
def create_vana():
    """Build and return a CustomVanna backed by an in-memory Qdrant store.

    The chat model's API key, base URL and model name come from the
    CHAT_MODEL_API_KEY, CHAT_MODEL_BASE_URL and CHAT_MODEL_NAME settings.
    Each falls back to an empty string when unset.
    """
    print("----------------create---------")
    vector_store = {"client": QdrantClient(":memory:")}
    # (llm_config key, setting name) pairs, in the order they are sent.
    setting_names = (
        ("api_key", "CHAT_MODEL_API_KEY"),
        ("api_base", "CHAT_MODEL_BASE_URL"),
        ("model", "CHAT_MODEL_NAME"),
    )
    llm_settings = {
        field: config(env_name, default='')
        for field, env_name in setting_names
    }
    return CustomVanna(vector_store_config=vector_store, llm_config=llm_settings)
def init_vn(vn):
    """Connect *vn* to the configured database and optionally train it.

    Training data from load_train_data_ddl is loaded only when the
    IS_FIRST_LOAD setting is true (default False). Returns *vn* itself.

    NOTE(review): create_vana() uses QdrantClient(":memory:"), so training
    data does not survive a restart. With IS_FIRST_LOAD false the vector
    store therefore starts empty; confirm this is intended.
    """
    banner = "--------------init vn-----connect----"
    print(banner)
    connect_database(vn)
    first_load = config("IS_FIRST_LOAD", default=False, cast=bool)
    if not first_load:
        return vn
    load_train_data_ddl(vn)
    return vn
from vanna.flask import VannaFlaskApp
# Module-level wiring, executed on import:
#   1. build the Vanna instance;
#   2. wrap it in Vanna's Flask app;
#   3. connect it to the database and (optionally) train it.
# The globals `vn` and `app` are used by the custom route and by the
# __main__ entry point below.
# NOTE(review): chart=False presumably switches off VannaFlaskApp's chart
# feature; confirm against vanna.flask.VannaFlaskApp.
vn = create_vana()
app = VannaFlaskApp(vn,chart=False)
init_vn(vn)
@app.flask_app.route("/api/v0/generate_sql_2", methods=["GET"])
def generate_sql_2():
    """
    Generate SQL from a question
    ---
    parameters:
      - name: user
        in: query
        type: string
        description: Currently ignored by this endpoint.
      - name: question
        in: query
        type: string
        required: true
    responses:
      200:
        description: >
          The result of vn.generate_sql_2 serialized as JSON, or
          {"type": "error", "error": "No question provided"} when the
          question is missing or blank.
        schema:
          type: object
          properties:
            type:
              type: string
              default: sql
            id:
              type: string
            text:
              type: string
    """
    raw_question = flask.request.args.get("question")
    # A missing, empty or whitespace-only question counts as "not provided":
    # a blank prompt cannot yield meaningful SQL.
    question = raw_question.strip() if raw_question else ""
    if not question:
        return jsonify({"type": "error", "error": "No question provided"})
    # NOTE(review): the response shape is whatever CustomVanna.generate_sql_2
    # returns; confirm it matches the schema documented above.
    data = vn.generate_sql_2(question=question)
    return jsonify(data)
if __name__ == '__main__':
    # Serve on every interface (0.0.0.0) at port 8084 with debug=False.
    # Binding 0.0.0.0 exposes the service to the whole network.
    server_options = dict(host='0.0.0.0', port=8084, debug=False)
    app.run(**server_options)