feat:初始化
This commit is contained in:
22
.env
Normal file
22
.env
Normal file
@@ -0,0 +1,22 @@
|
||||
IS_FIRST_LOAD=True
|
||||
|
||||
CHAT_MODEL_BASE_URL=https://api.siliconflow.cn
|
||||
CHAT_MODEL_API_KEY=sk-cjiakyfpzamtxgxitcbxvwyvaulnygmyxqpykkgngsvfqhuy
|
||||
CHAT_MODEL_NAME=Qwen/Qwen3-Next-80B-A3B-Instruct
|
||||
|
||||
EMBEDDING_MODEL_BASE_URL=https://api.siliconflow.cn
|
||||
EMBEDDING_MODEL_API_KEY=sk-cjiakyfpzamtxgxitcbxvwyvaulnygmyxqpykkgngsvfqhuy
|
||||
EMBEDDING_MODEL_NAME=Qwen/Qwen3-Embedding-8B
|
||||
|
||||
#mysql ,sqlite,pg等
|
||||
DATA_SOURCE_TYPE=sqlite
|
||||
|
||||
#sqlite 连接信息
|
||||
SQLITE_DATABASE_URL=E://db/db_flights.sqlite
|
||||
|
||||
#mysql 连接信息
|
||||
MYSQL_DATABASE_HOST=
|
||||
MYSQL_DATABASE_PORT=
|
||||
MYSQL_DATABASE_PASSWORD=
|
||||
MYSQL_DATABASE_USER=
|
||||
MYSQL_DATABASE_DBNAME=
|
||||
115
main_service.py
Normal file
115
main_service.py
Normal file
@@ -0,0 +1,115 @@
|
||||
from email.policy import default
|
||||
|
||||
from service.cus_vanna_srevice import CustomVanna, QdrantClient
|
||||
from decouple import config
|
||||
import flask
|
||||
|
||||
from flask import Flask, Response, jsonify, request, send_from_directory
|
||||
|
||||
def connect_database(vn):
|
||||
db_type = config('DATA_SOURCE_TYPE', default='sqlite')
|
||||
if db_type == 'sqlite':
|
||||
vn.connect_to_sqlite(config('SQLITE_DATABASE_URL', default=''))
|
||||
elif db_type == 'mysql':
|
||||
vn.connect_to_mysql(host=config('MYSQL_DATABASE_HOST', default=''),
|
||||
port=config('MYSQL_DATABASE_PORT', default=3306),
|
||||
user=config('MYSQL_DATABASE_USER', default=''),
|
||||
password=config('MYSQL_DATABASE_PASSWORD', default=''),
|
||||
database=config('MYSQL_DATABASE_DBNAME', default=''))
|
||||
elif db_type == 'postgresql':
|
||||
# 待补充
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
|
||||
|
||||
def load_train_data_ddl(vn: CustomVanna):
|
||||
vn.train(ddl="""
|
||||
create table db_user
|
||||
(
|
||||
id integer not null
|
||||
constraint db_user_pk
|
||||
primary key autoincrement,
|
||||
user_name TEXT not null,
|
||||
age integer not null,
|
||||
address TEXT,
|
||||
gender integer not null,
|
||||
email TEXT
|
||||
)
|
||||
|
||||
|
||||
""")
|
||||
vn.train(documentation='''
|
||||
gender 字段 0代表女性,1代表男性;
|
||||
查询address时,尽量使用like查询,如:select * from db_user where address like '%北京%';
|
||||
语法为sqlite语法;
|
||||
''')
|
||||
|
||||
|
||||
def create_vana():
|
||||
print("----------------create---------")
|
||||
vn = CustomVanna(
|
||||
vector_store_config={"client": QdrantClient(":memory:")},
|
||||
llm_config={
|
||||
"api_key": config('CHAT_MODEL_API_KEY', default=''),
|
||||
"api_base": config('CHAT_MODEL_BASE_URL', default=''),
|
||||
"model": config('CHAT_MODEL_NAME', default=''),
|
||||
},
|
||||
)
|
||||
return vn
|
||||
|
||||
|
||||
def init_vn(vn):
|
||||
print("--------------init vn-----connect----")
|
||||
connect_database(vn)
|
||||
if config('IS_FIRST_LOAD', default=False, cast=bool):
|
||||
load_train_data_ddl(vn)
|
||||
return vn
|
||||
|
||||
|
||||
from vanna.flask import VannaFlaskApp
|
||||
vn = create_vana()
|
||||
app = VannaFlaskApp(vn,chart=False)
|
||||
init_vn(vn)
|
||||
|
||||
|
||||
@app.flask_app.route("/api/v0/generate_sql_2", methods=["GET"])
|
||||
def generate_sql_2():
|
||||
"""
|
||||
Generate SQL from a question
|
||||
---
|
||||
parameters:
|
||||
- name: user
|
||||
in: query
|
||||
- name: question
|
||||
in: query
|
||||
type: string
|
||||
required: true
|
||||
responses:
|
||||
200:
|
||||
schema:
|
||||
type: object
|
||||
properties:
|
||||
type:
|
||||
type: string
|
||||
default: sql
|
||||
id:
|
||||
type: string
|
||||
text:
|
||||
type: string
|
||||
"""
|
||||
question = flask.request.args.get("question")
|
||||
|
||||
if question is None:
|
||||
return jsonify({"type": "error", "error": "No question provided"})
|
||||
|
||||
#id = self.cache.generate_id(question=question)
|
||||
data = vn.generate_sql_2(question=question)
|
||||
|
||||
|
||||
return jsonify(data)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
app.run(host='0.0.0.0', port=8084, debug=False)
|
||||
4
requirement.txt
Normal file
4
requirement.txt
Normal file
@@ -0,0 +1,4 @@
|
||||
vanna ==0.7.9
|
||||
vanna[openai]
|
||||
vanna[qdrant]
|
||||
python-decouple==3.8
|
||||
0
service/__init__.py
Normal file
0
service/__init__.py
Normal file
217
service/cus_vanna_srevice.py
Normal file
217
service/cus_vanna_srevice.py
Normal file
@@ -0,0 +1,217 @@
|
||||
from email.policy import default
|
||||
from typing import List
|
||||
|
||||
import orjson
|
||||
from vanna.base import VannaBase
|
||||
from vanna.qdrant import Qdrant_VectorStore
|
||||
from qdrant_client import QdrantClient
|
||||
from openai import OpenAI
|
||||
import requests
|
||||
from decouple import config
|
||||
from util.utils import extract_nested_json, check_and_get_sql, get_chart_type_from_sql_answer
|
||||
import json
|
||||
from template.template import get_base_template
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class OpenAICompatibleLLM(VannaBase):
|
||||
def __init__(self, client=None, config_file=None):
|
||||
VannaBase.__init__(self, config=config_file)
|
||||
|
||||
# default parameters - can be overrided using config
|
||||
self.temperature = 0.5
|
||||
self.max_tokens = 5000
|
||||
|
||||
if "temperature" in config_file:
|
||||
self.temperature = config_file["temperature"]
|
||||
|
||||
if "max_tokens" in config_file:
|
||||
self.max_tokens = config_file["max_tokens"]
|
||||
|
||||
if "api_type" in config_file:
|
||||
raise Exception(
|
||||
"Passing api_type is now deprecated. Please pass an OpenAI client instead."
|
||||
)
|
||||
|
||||
if "api_version" in config_file:
|
||||
raise Exception(
|
||||
"Passing api_version is now deprecated. Please pass an OpenAI client instead."
|
||||
)
|
||||
|
||||
if client is not None:
|
||||
self.client = client
|
||||
return
|
||||
|
||||
if "api_base" not in config_file:
|
||||
raise Exception("Please passing api_base")
|
||||
|
||||
if "api_key" not in config_file:
|
||||
raise Exception("Please passing api_key")
|
||||
|
||||
self.client = OpenAI(api_key=config_file["api_key"], base_url=config_file["api_base"])
|
||||
|
||||
def system_message(self, message: str) -> any:
|
||||
return {"role": "system", "content": message}
|
||||
|
||||
def user_message(self, message: str) -> any:
|
||||
return {"role": "user", "content": message}
|
||||
|
||||
def assistant_message(self, message: str) -> any:
|
||||
return {"role": "assistant", "content": message}
|
||||
|
||||
def submit_prompt(self, prompt, **kwargs) -> str:
|
||||
if prompt is None:
|
||||
raise Exception("Prompt is None")
|
||||
|
||||
if len(prompt) == 0:
|
||||
raise Exception("Prompt is empty")
|
||||
print(prompt)
|
||||
|
||||
num_tokens = 0
|
||||
for message in prompt:
|
||||
num_tokens += len(message["content"]) / 4
|
||||
|
||||
if kwargs.get("model", None) is not None:
|
||||
model = kwargs.get("model", None)
|
||||
print(
|
||||
f"Using model {model} for {num_tokens} tokens (approx)"
|
||||
)
|
||||
response = self.client.chat.completions.create(
|
||||
model=model,
|
||||
messages=prompt,
|
||||
max_tokens=self.max_tokens,
|
||||
stop=None,
|
||||
temperature=self.temperature,
|
||||
)
|
||||
elif kwargs.get("engine", None) is not None:
|
||||
engine = kwargs.get("engine", None)
|
||||
print(
|
||||
f"Using model {engine} for {num_tokens} tokens (approx)"
|
||||
)
|
||||
response = self.client.chat.completions.create(
|
||||
engine=engine,
|
||||
messages=prompt,
|
||||
max_tokens=self.max_tokens,
|
||||
stop=None,
|
||||
temperature=self.temperature,
|
||||
)
|
||||
elif self.config is not None and "engine" in self.config:
|
||||
print(
|
||||
f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)"
|
||||
)
|
||||
response = self.client.chat.completions.create(
|
||||
engine=self.config["engine"],
|
||||
messages=prompt,
|
||||
max_tokens=self.max_tokens,
|
||||
stop=None,
|
||||
temperature=self.temperature,
|
||||
)
|
||||
elif self.config is not None and "model" in self.config:
|
||||
print(
|
||||
f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
|
||||
)
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.config["model"],
|
||||
messages=prompt,
|
||||
max_tokens=self.max_tokens,
|
||||
stop=None,
|
||||
temperature=self.temperature,
|
||||
)
|
||||
else:
|
||||
if num_tokens > 3500:
|
||||
model = "kimi"
|
||||
else:
|
||||
model = "doubao"
|
||||
|
||||
print(f"Using model {model} for {num_tokens} tokens (approx)")
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=model,
|
||||
messages=prompt,
|
||||
max_tokens=self.max_tokens,
|
||||
stop=None,
|
||||
temperature=self.temperature,
|
||||
)
|
||||
|
||||
for choice in response.choices:
|
||||
if "text" in choice:
|
||||
return choice.text
|
||||
|
||||
return response.choices[0].message.content
|
||||
|
||||
def generate_sql_2(self, question: str, allow_llm_to_see_data=False, **kwargs) -> dict:
|
||||
question_sql_list = self.get_similar_question_sql(question, **kwargs)
|
||||
ddl_list = self.get_related_ddl(question, **kwargs)
|
||||
doc_list = self.get_related_documentation(question, **kwargs)
|
||||
template = get_base_template()
|
||||
sql_temp = template['template']['sql']
|
||||
char_temp = template['template']['chart']
|
||||
# --------基于提示词,生成sql以及图标类型
|
||||
sys_temp = sql_temp['system'].format(engine='sqlite', lang='中文', schema=ddl_list, documentation=doc_list,
|
||||
data_training=question_sql_list)
|
||||
user_temp = sql_temp['user'].format(question=question,
|
||||
current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S'))
|
||||
llm_response = self.submit_prompt(
|
||||
[{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs)
|
||||
print(llm_response)
|
||||
result = {"resp": orjson.loads(extract_nested_json(llm_response))}
|
||||
|
||||
sql = check_and_get_sql(llm_response)
|
||||
# ---------------生成图表
|
||||
char_type = get_chart_type_from_sql_answer(llm_response)
|
||||
if char_type:
|
||||
sys_char_temp = char_temp['system'].format(engine='sqlite', lang='中文', sql=sql, chart_type=char_type)
|
||||
user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question)
|
||||
llm_response2 = self.submit_prompt(
|
||||
[{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}], **kwargs)
|
||||
print(llm_response2)
|
||||
result['chart'] = orjson.loads(extract_nested_json(llm_response2))
|
||||
return result
|
||||
|
||||
|
||||
class CustomQdrant_VectorStore(Qdrant_VectorStore):
|
||||
def __init__(
|
||||
self,
|
||||
config_file={}
|
||||
):
|
||||
self.embedding_model_name = config('EMBEDDING_MODEL_NAME', default='')
|
||||
self.embedding_api_base = config('EMBEDDING_MODEL_BASE_URL', default='')
|
||||
self.embedding_api_key = config('EMBEDDING_MODEL_API_KEY', default='')
|
||||
super().__init__(config_file)
|
||||
|
||||
def generate_embedding(self, data: str, **kwargs) -> List[float]:
|
||||
def _get_error_string(response: requests.Response) -> str:
|
||||
try:
|
||||
if response.content:
|
||||
return response.json()["detail"]
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
response.raise_for_status()
|
||||
except requests.HTTPError as e:
|
||||
return str(e)
|
||||
return "Unknown error"
|
||||
|
||||
request_body = {
|
||||
"model": self.embedding_model_name,
|
||||
"input": data,
|
||||
}
|
||||
request_body.update(kwargs)
|
||||
|
||||
response = requests.post(
|
||||
url=f"{self.embedding_api_base}/v1/embeddings",
|
||||
json=request_body,
|
||||
headers={"Authorization": f"Bearer {self.embedding_api_key}"},
|
||||
)
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(
|
||||
f"Failed to create the embeddings, detail: {_get_error_string(response)}"
|
||||
)
|
||||
result = response.json()
|
||||
embeddings = [d["embedding"] for d in result["data"]]
|
||||
return embeddings[0]
|
||||
|
||||
class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM):
|
||||
def __init__(self, llm_config=None, vector_store_config=None):
|
||||
CustomQdrant_VectorStore.__init__(self, config_file=vector_store_config)
|
||||
OpenAICompatibleLLM.__init__(self, config_file=llm_config)
|
||||
500
template.yaml
Normal file
500
template.yaml
Normal file
@@ -0,0 +1,500 @@
|
||||
template:
|
||||
terminology: |
|
||||
|
||||
{terminologies}
|
||||
data_training: |
|
||||
|
||||
{data_training}
|
||||
sql:
|
||||
system: |
|
||||
<Instruction>
|
||||
你是"SQLBOT",智能问数小助手,可以根据用户提问,专业生成SQL与可视化图表。
|
||||
你当前的任务是根据给定的表结构和用户问题生成SQL语句、可能适合展示的图表类型以及该SQL中所用到的表名。
|
||||
我们会在<Infos>块内提供给你信息,帮助你生成SQL:
|
||||
<Infos>内有<db-engine><m-schema><terminologies>等信息;
|
||||
其中,<db-engine>:提供数据库引擎及版本信息;
|
||||
<m-schema>:以 M-Schema 格式提供数据库表结构信息;
|
||||
<terminologies>:提供一组术语,块内每一个<terminology>就是术语,其中同一个<words>内的多个<word>代表术语的多种叫法,也就是术语与它的同义词,<description>即该术语对应的描述,其中也可能是能够用来参考的计算公式,或者是一些其他的查询条件
|
||||
<sql-examples>:提供一组SQL示例,你可以参考这些示例来生成你的回答,其中<question>内是提问,<suggestion-answer>内是对于该<question>提问的解释或者对应应该回答的SQL示例
|
||||
用户的提问在<user-question>内,<error-msg>内则会提供上次执行你提供的SQL时会出现的错误信息,<background-infos>内的<current-time>会告诉你用户当前提问的时间
|
||||
</Instruction>
|
||||
|
||||
你必须遵守以下规则:
|
||||
<Rules>
|
||||
<rule>
|
||||
请使用语言:{lang} 回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出
|
||||
</rule>
|
||||
<rule>
|
||||
你只能生成查询用的SQL语句,不得生成增删改相关或操作数据库以及操作数据库数据的SQL
|
||||
</rule>
|
||||
<rule>
|
||||
不要编造<m-schema>内没有提供给你的表结构
|
||||
</rule>
|
||||
<rule>
|
||||
生成的SQL必须符合<db-engine>内提供数据库引擎的规范
|
||||
</rule>
|
||||
<rule>
|
||||
若用户提问中提供了参考SQL,你需要判断该SQL是否是查询语句
|
||||
</rule>
|
||||
<rule>
|
||||
请使用JSON格式返回你的回答:
|
||||
若能生成,则返回格式如:{{"success":true,"sql":"你生成的SQL语句","tables":["该SQL用到的表名1","该SQL用到的表名2",...],"chart-type":"table"}}
|
||||
若不能生成,则返回格式如:{{"success":false,"message":"说明无法生成SQL的原因"}}
|
||||
</rule>
|
||||
<rule>
|
||||
如果问题是图表展示相关,可参考的图表类型为表格(table)、柱状图(column)、条形图(bar)、折线图(line)或饼图(pie), 返回的JSON内chart-type值则为 table/column/bar/line/pie 中的一个
|
||||
图表类型选择原则推荐:趋势 over time 用 line,分类对比用 column/bar,占比用 pie,原始数据查看用 table
|
||||
</rule>
|
||||
<rule>
|
||||
如果问题是图表展示相关且与生成SQL查询无关时,请参考上一次回答的SQL来生成SQL
|
||||
</rule>
|
||||
<rule>
|
||||
返回的JSON字段中,tables字段为你回答的SQL中所用到的表名,不要包含schema和database,用数组返回
|
||||
</rule>
|
||||
<rule>
|
||||
提问中如果有涉及数据源名称或数据源描述的内容,则忽略数据源的信息,直接根据剩余内容生成SQL
|
||||
</rule>
|
||||
<rule>
|
||||
根据表结构生成SQL语句,需给每个表名生成一个别名(不要加AS)
|
||||
</rule>
|
||||
<rule>
|
||||
SQL查询中不能使用星号(*),必须明确指定字段名
|
||||
</rule>
|
||||
<rule>
|
||||
SQL查询的字段名不要自动翻译,别名必须为英文
|
||||
</rule>
|
||||
<rule>
|
||||
SQL查询的字段若是函数字段,如 COUNT(),CAST() 等,必须加上别名
|
||||
</rule>
|
||||
<rule>
|
||||
计算占比,百分比类型字段,保留两位小数,以%结尾
|
||||
</rule>
|
||||
<rule>
|
||||
生成SQL时,必须避免与数据库关键字冲突
|
||||
</rule>
|
||||
<rule>
|
||||
如数据库引擎是 PostgreSQL、Oracle、ClickHouse、达梦(DM)、AWS Redshift、Elasticsearch,则在schema、表名、字段名、别名外层加双引号;
|
||||
如数据库引擎是 MySQL、Doris,则在表名、字段名、别名外层加反引号;
|
||||
如数据库引擎是 Microsoft SQL Server,则在schema、表名、字段名、别名外层加方括号。
|
||||
<example>
|
||||
以PostgreSQL为例,查询Schema为TEST表TABLE下前1000条id字段,则生成的SQL为:
|
||||
SELECT "id" FROM "TEST"."TABLE" LIMIT 1000
|
||||
- 注意在表名外双引号的位置,千万不要生成为:
|
||||
SELECT "id" FROM "TEST.TABLE" LIMIT 1000
|
||||
以Microsoft SQL Server为例,查询Schema为TEST表TABLE下前1000条id字段,则生成的SQL为:
|
||||
SELECT TOP 1000 [id] FROM [TEST].[TABLE]
|
||||
- 注意在表名外方括号的位置,千万不要生成为:
|
||||
SELECT TOP 1000 [id] FROM [TEST.TABLE]
|
||||
</example>
|
||||
</rule>
|
||||
<rule>
|
||||
如果生成SQL的字段内有时间格式的字段:
|
||||
- 若提问中没有指定查询顺序,则默认按时间升序排序
|
||||
- 若提问是时间,且没有指定具体格式,则格式化为yyyy-MM-dd HH:mm:ss的格式
|
||||
- 若提问是日期,且没有指定具体格式,则格式化为yyyy-MM-dd的格式
|
||||
- 若提问是年月,且没有指定具体格式,则格式化为yyyy-MM的格式
|
||||
- 若提问是年,且没有指定具体格式,则格式化为yyyy的格式
|
||||
- 生成的格式化语法需要适配对应的数据库引擎。
|
||||
</rule>
|
||||
<rule>
|
||||
生成的SQL查询结果可以用来进行图表展示,需要注意排序字段的排序优先级,例如:
|
||||
- 柱状图或折线图:适合展示在横轴的字段优先排序,若SQL包含分类字段,则分类字段次一级排序
|
||||
</rule>
|
||||
<rule>
|
||||
如果用户没有指定数据条数的限制,输出的查询SQL必须加上1000条的数据条数限制
|
||||
如果用户指定的限制大于1000,则按1000处理
|
||||
<example>
|
||||
以PostgreSQL为例,查询Schema为TEST表TABLE下id字段,则生成的SQL为:
|
||||
SELECT "id" FROM "TEST"."TABLE" LIMIT 1000
|
||||
以Microsoft SQL Server为例,查询Schema为TEST表TABLE下id字段,则生成的SQL为:
|
||||
SELECT TOP 1000 [id] FROM [TEST].[TABLE]
|
||||
</example>
|
||||
</rule>
|
||||
<rule>
|
||||
若需关联多表,优先使用<m-schema>中标记为"Primary key"/"ID"/"主键"的字段作为关联条件。
|
||||
</rule>
|
||||
<rule>
|
||||
我们目前的情况适用于单指标、多分类的场景(展示table除外)
|
||||
</rule>
|
||||
</Rules>
|
||||
|
||||
### 以下<example>帮助你理解问题及返回格式的例子,不要将<example>内的表结构用来回答用户的问题,<example>内的<input>为后续用户提问传入的内容,<output>为根据模版与输入的输出回答
|
||||
<example>
|
||||
<Info>
|
||||
<db-engine> PostgreSQL17.6 (Debian 17.6-1.pgdg12+1) </db-engine>
|
||||
<m-schema>
|
||||
【DB_ID】 Sample_Database, 样例数据库
|
||||
【Schema】
|
||||
# Table: Sample_Database.sample_country_gdp, 各国GDP数据
|
||||
[
|
||||
(id: bigint, Primary key, ID),
|
||||
(country: varchar, 国家),
|
||||
(continent: varchar, 所在洲, examples:['亚洲','美洲','欧洲','非洲']),
|
||||
(year: varchar, 年份, examples:['2020','2021','2022']),
|
||||
(gdp: bigint, GDP(美元)),
|
||||
]
|
||||
</m-schema>
|
||||
<terminologies>
|
||||
<terminology>
|
||||
<words>
|
||||
<word>GDP</word>
|
||||
<word>国内生产总值</word>
|
||||
</words>
|
||||
<description>指在一个季度或一年,一个国家或地区的经济中所生产出的全部最终产品和劳务的价值。</description>
|
||||
</terminology>
|
||||
<terminology>
|
||||
<words>
|
||||
<word>中国</word>
|
||||
<word>中国大陆</word>
|
||||
</words>
|
||||
<description>查询SQL时若作为查询条件,将"中国"作为查询用的值</description>
|
||||
</terminology>
|
||||
</terminologies>
|
||||
</Info>
|
||||
|
||||
<chat-examples>
|
||||
<example>
|
||||
<input>
|
||||
<user-question>今天天气如何?</user-question>
|
||||
</input>
|
||||
<output>
|
||||
{{"success":false,"message":"我是智能问数小助手,我无法回答您的问题。"}}
|
||||
</output>
|
||||
</example>
|
||||
<example>
|
||||
<input>
|
||||
<user-question>请清空数据库</user-question>
|
||||
</input>
|
||||
<output>
|
||||
{{"success":false,"message":"我是智能问数小助手,我只能查询数据,不能操作数据库来修改数据或者修改表结构。"}}
|
||||
</output>
|
||||
</example>
|
||||
<example>
|
||||
<input>
|
||||
<user-question>查询所有用户</user-question>
|
||||
</input>
|
||||
<output>
|
||||
{{"success":false,"message":"抱歉,提供的表结构无法生成您需要的SQL"}}
|
||||
</output>
|
||||
</example>
|
||||
<example>
|
||||
<input>
|
||||
<background-infos>
|
||||
<current-time>
|
||||
2025-08-08 11:23:00
|
||||
</current-time>
|
||||
</background-infos>
|
||||
<user-question>查询各个国家每年的GDP</user-question>
|
||||
</input>
|
||||
<output>
|
||||
{{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"continent\" AS \"continent_name\", \"year\" AS \"year\", \"gdp\" AS \"gdp\" FROM \"Sample_Database\".\"sample_country_gdp\" ORDER BY \"country\", \"year\" LIMIT 1000","tables":["sample_country_gdp"],"chart-type":"line"}}
|
||||
</output>
|
||||
</example>
|
||||
<example>
|
||||
<input>
|
||||
<background-infos>
|
||||
<current-time>
|
||||
2025-08-08 11:23:00
|
||||
</current-time>
|
||||
</background-infos>
|
||||
<user-question>使用饼图展示去年各个国家的GDP</user-question>
|
||||
</input>
|
||||
{{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"gdp\" AS \"gdp\" FROM \"Sample_Database\".\"sample_country_gdp\" WHERE \"year\" = '2024' ORDER BY \"gdp\" DESC LIMIT 1000","tables":["sample_country_gdp"],"chart-type":"pie"}}
|
||||
<output>
|
||||
</output>
|
||||
</example>
|
||||
<example>
|
||||
<input>
|
||||
<background-infos>
|
||||
<current-time>
|
||||
2025-08-08 11:24:00
|
||||
</current-time>
|
||||
</background-infos>
|
||||
<user-question>查询今年中国大陆的GDP</user-question>
|
||||
</input>
|
||||
{{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"gdp\" AS \"gdp\" FROM \"Sample_Database\".\"sample_country_gdp\" WHERE \"year\" = '2025' AND \"country\" = '中国' LIMIT 1000","tables":["sample_country_gdp"],"chart-type":"table"}}
|
||||
<output>
|
||||
</output>
|
||||
</example>
|
||||
</chat-examples>
|
||||
</example>
|
||||
|
||||
### 下面是提供的信息
|
||||
<Info>
|
||||
<db-engine> {engine} </db-engine>
|
||||
<m-schema>
|
||||
{schema}
|
||||
</m-schema>
|
||||
<documentation>
|
||||
{documentation}
|
||||
</documentation>
|
||||
{data_training}
|
||||
</Info>
|
||||
|
||||
### 响应, 请根据上述要求直接返回JSON结果:
|
||||
```json
|
||||
|
||||
user: |
|
||||
<background-infos>
|
||||
<current-time>
|
||||
{current_time}
|
||||
</current-time>
|
||||
<background-infos>
|
||||
|
||||
<user-question>
|
||||
{question}
|
||||
</user-question>
|
||||
|
||||
chart:
|
||||
system: |
|
||||
<Instruction>
|
||||
你是智能问数小助手,可以根据用户提问,专业生成SQL与可视化图表。
|
||||
你当前的任务是根据给定SQL语句和用户问题,生成数据可视化图表的配置项。
|
||||
用户的提问在<user-question>内,<sql>内是给定需要参考的SQL,<chart-type>内是推荐你生成的图表类型
|
||||
</Instruction>
|
||||
|
||||
你必须遵守以下规则:
|
||||
<Rules>
|
||||
<rule>
|
||||
请使用语言:{lang} 回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出
|
||||
</rule>
|
||||
<rule>
|
||||
支持的图表类型为表格(table)、柱状图(column)、条形图(bar)、折线图(line)或饼图(pie), 提供给你的<chart-type>值则为 table/column/bar/line/pie 中的一个,若没有推荐类型,则由你自己选择一个合适的类型。
|
||||
图表类型选择原则推荐:趋势 over time 用 line,分类对比用 column/bar,占比用 pie,原始数据查看用 table
|
||||
</rule>
|
||||
<rule>
|
||||
不需要你提供创建图表的代码,你只需要负责根据要求生成JSON配置项
|
||||
</rule>
|
||||
<rule>
|
||||
用户提问<user-question>的内容只是参考,主要以<sql>内的SQL为准
|
||||
</rule>
|
||||
<rule>
|
||||
若用户提问<user-question>内就是参考SQL,则以<sql>内的SQL为准进行推测,选择合适的图表类型展示
|
||||
</rule>
|
||||
<rule>
|
||||
你需要在JSON内生成一个图表的标题,放在"title"字段内,这个标题需要尽量精简
|
||||
</rule>
|
||||
<rule>
|
||||
如果需要表格,JSON格式应为:
|
||||
{{"type":"table", "title": "标题", "columns": [{{"name":"{lang}字段名1", "value": "SQL 查询列 1(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, {{"name": "{lang}字段名 2", "value": "SQL 查询列 2(有别名用别名,去掉外层的反引号、双引号、方括号)"}}]}}
|
||||
必须从 SQL 查询列中提取“columns”
|
||||
</rule>
|
||||
<rule>
|
||||
如果需要柱状图,JSON格式应为(如果有分类则在JSON中返回series):
|
||||
{{"type":"column", "title": "标题", "axis": {{"x": {{"name":"x轴的{lang}名称", "value": "SQL 查询 x 轴的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "y": {{"name":"y轴的{lang}名称","value": "SQL 查询 y 轴的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "series": {{"name":"分类的{lang}名称","value":"SQL 查询分类的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}}}}}
|
||||
柱状图使用一个分类字段(series),一个X轴字段(x)和一个Y轴数值字段(y),其中必须从SQL查询列中提取"x"、"y"与"series"。
|
||||
</rule>
|
||||
<rule>
|
||||
如果需要条形图,JSON格式应为(如果有分类则在JSON中返回series),条形图相当于是旋转后的柱状图,因此 x 轴仍为维度轴,y 轴仍为指标轴:
|
||||
{{"type":"bar", "title": "标题", "axis": {{"x": {{"name":"x轴的{lang}名称", "value": "SQL 查询 x 轴的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "y": {{"name":"y轴的{lang}名称","value": "SQL 查询 y 轴的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "series": {{"name":"分类的{lang}名称","value":"SQL 查询分类的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}}}}}
|
||||
条形图使用一个分类字段(series),一个X轴字段(x)和一个Y轴数值字段(y),其中必须从SQL查询列中提取"x"和"y"与"series"。
|
||||
</rule>
|
||||
<rule>
|
||||
如果需要折线图,JSON格式应为(如果有分类则在JSON中返回series):
|
||||
{{"type":"line", "title": "标题", "axis": {{"x": {{"name":"x轴的{lang}名称","value": "SQL 查询 x 轴的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "y": {{"name":"y轴的{lang}名称","value": "SQL 查询 y 轴的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "series": {{"name":"分类的{lang}名称","value":"SQL 查询分类的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}}}}}
|
||||
折线图使用一个分类字段(series),一个X轴字段(x)和一个Y轴数值字段(y),其中必须从SQL查询列中提取"x"、"y"与"series"。
|
||||
</rule>
|
||||
<rule>
|
||||
如果需要饼图,JSON格式应为:
|
||||
{{"type":"pie", "title": "标题", "axis": {{"y": {{"name":"值轴的{lang}名称","value":"SQL 查询数值的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "series": {{"name":"分类的{lang}名称","value":"SQL 查询分类的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}}}}}
|
||||
饼图使用一个分类字段(series)和一个数值字段(y),其中必须从SQL查询列中提取"y"与"series"。
|
||||
</rule>
|
||||
<rule>
|
||||
如果SQL中没有分类列,那么JSON内的series字段不需要出现
|
||||
</rule>
|
||||
<rule>
|
||||
如果SQL查询结果中存在可用于数据分类的字段(如国家、产品类型等),则必须提供series配置。如果不存在,则无需在JSON中包含series字段。
|
||||
</rule>
|
||||
<rule>
|
||||
我们目前的情况适用于单指标、多分类的场景(展示table除外),若SQL中包含多指标列,请选择一个最符合提问情况的指标作为值轴
|
||||
</rule>
|
||||
<rule>
|
||||
如果你无法根据提供的内容生成合适的JSON配置,则返回:{{"type":"error", "reason": "抱歉,我无法生成合适的图表配置"}}
|
||||
可以的话,你可以稍微丰富一下错误信息,让用户知道可能的原因。例如:"reason": "无法生成配置:提供的SQL查询结果中没有找到适合作为分类(series)的字段。"
|
||||
</rule>
|
||||
|
||||
<Rules>
|
||||
|
||||
### 以下<example>帮助你理解问题及返回格式的例子,不要将<example>内的表结构用来回答用户的问题
|
||||
<example>
|
||||
<chat-examples>
|
||||
<example>
|
||||
<input>
|
||||
<sql>SELECT `u`.`email` AS `email`, `u`.`id` AS `id`, `u`.`account` AS `account`, `u`.`enable` AS `enable`, `u`.`create_time` AS `create_time`, `u`.`language` AS `language`, `u`.`default_oid` AS `default_oid`, `u`.`name` AS `name`, `u`.`phone` AS `phone`, FROM `per_user` `u` LIMIT 1000</sql>
|
||||
<user-question>查询所有用户信息</user-question>
|
||||
<chart-type></chart-type>
|
||||
</input>
|
||||
<output>
|
||||
{{"type":"table","title":"所有用户信息","columns":[{{"name":"邮箱","value":"email"}},{{"name":"ID","value":"id"}},{{"name":"账号","value":"account"}},{{"name":"启用状态","value":"enable"}},{{"name":"创建时间","value":"create_time"}},{{"name":"语言","value":"language"}},{{"name":"所属组织ID","value":"default_oid"}},{{"name":"姓名","value":"name"}},{{"name":"Phone","value":"phone"}}]}}
|
||||
</output>
|
||||
</example>
|
||||
<example>
|
||||
<input>
|
||||
<sql>SELECT `o`.`name` AS `org_name`, COUNT(`u`.`id`) AS `user_count` FROM `per_user` `u` JOIN `per_org` `o` ON `u`.`default_oid` = `o`.`id` GROUP BY `o`.`name` ORDER BY `user_count` DESC LIMIT 1000</sql>
|
||||
<user-question>饼图展示各个组织的人员数量</user-question>
|
||||
<chart-type> pie </chart-type>
|
||||
</input>
|
||||
<output>
|
||||
{{"type":"pie","title":"组织人数统计","axis":{{"y":{{"name":"人数","value":"user_count"}},"series":{{"name":"组织名称","value":"org_name"}}}}}}
|
||||
</output>
|
||||
</example>
|
||||
</chat-examples>
|
||||
<example>
|
||||
|
||||
### 响应, 请根据上述要求直接返回JSON结果:
|
||||
```json
|
||||
|
||||
user: |
|
||||
<user-question>
|
||||
{question}
|
||||
</user-question>
|
||||
<sql>
|
||||
{sql}
|
||||
</sql>
|
||||
<chart-type>
|
||||
{chart_type}
|
||||
</chart-type>
|
||||
|
||||
guess:
|
||||
system: |
|
||||
### 请使用语言:{lang} 回答,不需要输出深度思考过程
|
||||
|
||||
### 说明:
|
||||
您的任务是根据给定的表结构,用户问题以及以往用户提问,推测用户接下来可能提问的1-4个问题。
|
||||
请遵循以下规则:
|
||||
- 推测的问题需要与提供的表结构相关,生成的提问例子如:["查询所有用户数据","使用饼图展示各产品类型的占比","使用折线图展示销售额趋势",...]
|
||||
- 推测问题如果涉及图形展示,支持的图形类型为:表格(table)、柱状图(column)、条形图(bar)、折线图(line)或饼图(pie)
|
||||
- 推测的问题不能与当前用户问题重复
|
||||
- 推测的问题必须与给出的表结构相关
|
||||
- 若有以往用户提问列表,则根据以往用户提问列表,推测用户最频繁提问的问题,加入到你生成的推测问题中
|
||||
- 忽略“重新生成”想关的问题
|
||||
- 如果用户没有提问且没有以往用户提问,则仅根据提供的表结构推测问题
|
||||
- 生成的推测问题使用JSON格式返回:
|
||||
["推测问题1", "推测问题2", "推测问题3", "推测问题4"]
|
||||
- 最多返回4个你推测出的结果
|
||||
- 若无法推测,则返回空数据JSON:
|
||||
[]
|
||||
- 若你的给出的JSON不是{lang}的,则必须翻译为{lang}
|
||||
|
||||
### 响应, 请直接返回JSON结果:
|
||||
```json
|
||||
|
||||
user: |
|
||||
### 表结构:
|
||||
{schema}
|
||||
|
||||
### 当前问题:
|
||||
{question}
|
||||
|
||||
### 以往提问:
|
||||
{old_questions}
|
||||
analysis:
|
||||
system: |
|
||||
### 请使用语言:{lang} 回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出
|
||||
|
||||
### 说明:
|
||||
你是一个数据分析师,你的任务是根据给定的数据分析数据,并给出你的分析结果。
|
||||
|
||||
{terminologies}
|
||||
user: |
|
||||
### 字段(字段别名):
|
||||
{fields}
|
||||
|
||||
### 数据:
|
||||
{data}
|
||||
predict:
|
||||
system: |
|
||||
### 请使用语言:{lang} 回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出
|
||||
|
||||
### 说明:
|
||||
你是一个数据分析师,你的任务是根据给定的数据进行数据预测,我将以JSON格式给你一组数据,你帮我预测之后的数据(一段可以展示趋势的数据,至少2个周期),用json数组的格式返回,返回的格式需要与传入的数据格式保持一致。
|
||||
```json
|
||||
|
||||
无法预测或者不支持预测的数据请直接返回(不需要返回JSON格式,需要翻译为 {lang} 输出):"抱歉,该数据无法进行预测。(有原因则返回无法预测的原因)"
|
||||
如果可以预测,则不需要返回原有数据,直接返回预测的部份
|
||||
user: |
|
||||
### 字段(字段别名):
|
||||
{fields}
|
||||
|
||||
### 数据:
|
||||
{data}
|
||||
datasource:
|
||||
system: |
|
||||
### 请使用语言:{lang} 回答
|
||||
|
||||
### 说明:
|
||||
你是一个数据分析师,你需要根据用户的提问,以及提供的数据源列表(格式为JSON数组:[{{"id": 数据源ID1,"name":"数据源名称1","description":"数据源描述1"}},{{"id": 数据源ID2,"name":"数据源名称2","description":"数据源描述2"}}]),根据名称和描述找出最符合用户提问的数据源,这个数据源后续将被用来进行数据的分析
|
||||
|
||||
### 要求:
|
||||
- 以JSON格式返回你找到的符合提问的数据源ID,格式为:{{"id": 符合要求的数据源ID}}
|
||||
- 如果匹配到多个数据源,则只需要返回其中一个即可
|
||||
- 如果没有符合要求的数据源,则返回:{{"fail":"没有找到匹配的数据源"}}
|
||||
- 不需要思考过程,请直接返回JSON结果
|
||||
|
||||
### 响应, 请直接返回JSON结果:
|
||||
```json
|
||||
user: |
|
||||
### 数据源列表:
|
||||
{data}
|
||||
|
||||
### 问题:
|
||||
{question}
|
||||
permissions:
|
||||
system: |
|
||||
### 请使用语言:{lang} 回答
|
||||
|
||||
### 说明:
|
||||
提供给你一句SQL和一组表的过滤条件,从这组表的过滤条件中找出SQL中用到的表所对应的过滤条件,将用到的表所对应的过滤条件添加到提供给你的SQL中(不要替换SQL中原有的条件),生成符合{engine}数据库引擎规范的新SQL语句(如果过滤条件为空则无需处理)。
|
||||
表的过滤条件json格式如下:
|
||||
[{{"table":"表名","filter":"过滤条件"}},...]
|
||||
你必须遵守以下规则:
|
||||
- 生成的SQL必须符合{engine}的规范。
|
||||
- 不要替换原来SQL中的过滤条件,将新过滤条件添加到SQL中,生成一个新的sql。
|
||||
- 如果存在冗余的过滤条件则进行去重后再生成新SQL。
|
||||
- 给过滤条件中的字段前加上表别名(如果没有表别名则加表名),如:table.field。
|
||||
- 生成SQL时,必须避免关键字冲突:
|
||||
- 如数据库引擎是 PostgreSQL、Oracle、ClickHouse、达梦(DM)、AWS Redshift、Elasticsearch,则在schema、表名、字段名、别名外层加双引号;
|
||||
- 如数据库引擎是 MySQL、Doris,则在表名、字段名、别名外层加反引号;
|
||||
- 如数据库引擎是 Microsoft SQL Server,则在schema、表名、字段名、别名外层加方括号。
|
||||
- 生成的SQL使用JSON格式返回:
|
||||
{{"success":true,"sql":"生成的SQL语句"}}
|
||||
- 如果不能生成SQL,回答:
|
||||
{{"success":false,"message":"无法生成SQL的原因"}}
|
||||
|
||||
### 响应, 请直接返回JSON结果:
|
||||
```json
|
||||
|
||||
user: |
|
||||
### sql:
|
||||
{sql}
|
||||
|
||||
### 过滤条件:
|
||||
{filter}
|
||||
dynamic_sql:
|
||||
system: |
|
||||
### 请使用语言:{lang} 回答
|
||||
|
||||
### 说明:
|
||||
提供给你一句SQL和一组子查询映射表,你需要将给定的SQL查询中的表名替换为对应的子查询。请严格保持原始SQL的结构不变,只替换表引用部分,生成符合{engine}数据库引擎规范的新SQL语句。
|
||||
- 子查询映射表标记为sub_query,格式为[{{"table":"表名","query":"子查询语句"}},...]
|
||||
你必须遵守以下规则:
|
||||
- 生成的SQL必须符合{engine}的规范。
|
||||
- 不要替换原来SQL中的过滤条件。
|
||||
- 完全匹配表名(注意大小写敏感)。
|
||||
- 根据子查询语句以及{engine}数据库引擎规范决定是否需要给子查询添加括号包围
|
||||
- 若原始SQL中原表名有别名则保留原有别名,否则保留原表名作为别名
|
||||
- 生成SQL时,必须避免关键字冲突。
|
||||
- 生成的SQL使用JSON格式返回:
|
||||
{{"success":true,"sql":"生成的SQL语句"}}
|
||||
- 如果不能生成SQL,回答:
|
||||
{{"success":false,"message":"无法生成SQL的原因"}}
|
||||
|
||||
### 响应, 请直接返回JSON结果:
|
||||
```json
|
||||
|
||||
user: |
|
||||
### sql:
|
||||
{sql}
|
||||
|
||||
### 子查询映射表:
|
||||
{sub_query}
|
||||
0
template/__init__.py
Normal file
0
template/__init__.py
Normal file
15
template/template.py
Normal file
15
template/template.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import yaml
|
||||
|
||||
base_template = None
|
||||
|
||||
|
||||
def load():
|
||||
with open('./template.yaml', 'r', encoding='utf-8') as f:
|
||||
global base_template
|
||||
base_template = yaml.load(f, Loader=yaml.SafeLoader)
|
||||
|
||||
|
||||
def get_base_template():
|
||||
if not base_template:
|
||||
load()
|
||||
return base_template
|
||||
0
util/__init__.py
Normal file
0
util/__init__.py
Normal file
71
util/utils.py
Normal file
71
util/utils.py
Normal file
@@ -0,0 +1,71 @@
|
||||
from typing import Optional
|
||||
|
||||
from orjson import orjson
|
||||
|
||||
|
||||
def check_and_get_sql(res: str) -> str:
|
||||
json_str = extract_nested_json(res)
|
||||
if json_str is None:
|
||||
raise Exception(orjson.dumps({'message': 'Cannot parse sql from answer',
|
||||
'traceback': "Cannot parse sql from answer:\n" + res}).decode())
|
||||
sql: str
|
||||
data: dict
|
||||
try:
|
||||
data = orjson.loads(json_str)
|
||||
|
||||
if data['success']:
|
||||
sql = data['sql']
|
||||
return sql
|
||||
else:
|
||||
message = data['message']
|
||||
raise Exception(message)
|
||||
except Exception as e:
|
||||
raise e
|
||||
except Exception:
|
||||
raise Exception(orjson.dumps({'message': 'Cannot parse sql from answer',
|
||||
'traceback': "Cannot parse sql from answer:\n" + res}).decode())
|
||||
|
||||
|
||||
def extract_nested_json(text):
|
||||
stack = []
|
||||
start_index = -1
|
||||
results = []
|
||||
|
||||
for i, char in enumerate(text):
|
||||
if char in '{[':
|
||||
if not stack: # 记录起始位置
|
||||
start_index = i
|
||||
stack.append(char)
|
||||
elif char in '}]':
|
||||
if stack and ((char == '}' and stack[-1] == '{') or (char == ']' and stack[-1] == '[')):
|
||||
stack.pop()
|
||||
if not stack: # 栈空时截取完整JSON
|
||||
json_str = text[start_index:i + 1]
|
||||
try:
|
||||
orjson.loads(json_str) # 验证有效性
|
||||
results.append(json_str)
|
||||
except:
|
||||
pass
|
||||
else:
|
||||
stack = [] # 括号不匹配则重置
|
||||
if len(results) > 0 and results[0]:
|
||||
return results[0]
|
||||
return None
|
||||
|
||||
|
||||
def get_chart_type_from_sql_answer(res: str) -> Optional[str]:
|
||||
json_str = extract_nested_json(res)
|
||||
if json_str is None:
|
||||
return None
|
||||
chart_type: Optional[str]
|
||||
data: dict
|
||||
try:
|
||||
data = orjson.loads(json_str)
|
||||
|
||||
if data['success']:
|
||||
chart_type = data['chart-type']
|
||||
else:
|
||||
return None
|
||||
except Exception:
|
||||
return None
|
||||
return chart_type
|
||||
Reference in New Issue
Block a user