From 3ace3e5348544b0b6a556eb330fba47e0e64398f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=9B=B7=E9=9B=A8?= Date: Tue, 23 Sep 2025 14:49:00 +0800 Subject: [PATCH] =?UTF-8?q?feat:=E5=88=9D=E5=A7=8B=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .env | 22 ++ main_service.py | 115 ++++++++ requirement.txt | 4 + service/__init__.py | 0 service/cus_vanna_srevice.py | 217 +++++++++++++++ template.yaml | 500 +++++++++++++++++++++++++++++++++++ template/__init__.py | 0 template/template.py | 15 ++ util/__init__.py | 0 util/utils.py | 71 +++++ 10 files changed, 944 insertions(+) create mode 100644 .env create mode 100644 main_service.py create mode 100644 requirement.txt create mode 100644 service/__init__.py create mode 100644 service/cus_vanna_srevice.py create mode 100644 template.yaml create mode 100644 template/__init__.py create mode 100644 template/template.py create mode 100644 util/__init__.py create mode 100644 util/utils.py diff --git a/.env b/.env new file mode 100644 index 0000000..a17a301 --- /dev/null +++ b/.env @@ -0,0 +1,22 @@ +IS_FIRST_LOAD=True + +CHAT_MODEL_BASE_URL=https://api.siliconflow.cn +CHAT_MODEL_API_KEY=sk-cjiakyfpzamtxgxitcbxvwyvaulnygmyxqpykkgngsvfqhuy +CHAT_MODEL_NAME=Qwen/Qwen3-Next-80B-A3B-Instruct + +EMBEDDING_MODEL_BASE_URL=https://api.siliconflow.cn +EMBEDDING_MODEL_API_KEY=sk-cjiakyfpzamtxgxitcbxvwyvaulnygmyxqpykkgngsvfqhuy +EMBEDDING_MODEL_NAME=Qwen/Qwen3-Embedding-8B + +#mysql ,sqlite,pg等 +DATA_SOURCE_TYPE=sqlite + +#sqlite 连接信息 +SQLITE_DATABASE_URL=E://db/db_flights.sqlite + +#mysql 连接信息 +MYSQL_DATABASE_HOST= +MYSQL_DATABASE_PORT= +MYSQL_DATABASE_PASSWORD= +MYSQL_DATABASE_USER= +MYSQL_DATABASE_DBNAME= diff --git a/main_service.py b/main_service.py new file mode 100644 index 0000000..42eb408 --- /dev/null +++ b/main_service.py @@ -0,0 +1,115 @@ +from email.policy import default + +from service.cus_vanna_srevice import CustomVanna, QdrantClient +from decouple import config +import flask + +from flask import Flask, Response, jsonify, request, send_from_directory + +def connect_database(vn): + db_type = config('DATA_SOURCE_TYPE', default='sqlite') + if db_type == 'sqlite': + vn.connect_to_sqlite(config('SQLITE_DATABASE_URL', default='')) + elif db_type == 'mysql': + vn.connect_to_mysql(host=config('MYSQL_DATABASE_HOST', default=''), + port=config('MYSQL_DATABASE_PORT', default=3306), + user=config('MYSQL_DATABASE_USER', default=''), + password=config('MYSQL_DATABASE_PASSWORD', default=''), + database=config('MYSQL_DATABASE_DBNAME', default='')) + elif db_type == 'postgresql': + # 待补充 + pass + else: + pass + + +def load_train_data_ddl(vn: CustomVanna): + vn.train(ddl=""" + create table db_user + ( + id integer not null + constraint db_user_pk + primary key autoincrement, + user_name TEXT not null, + age integer not null, + address TEXT, + gender integer not null, + email TEXT + ) + + + """) + vn.train(documentation=''' + gender 字段 0代表女性,1代表男性; + 查询address时,尽量使用like查询,如:select * from db_user where address like '%北京%'; + 语法为sqlite语法; + ''') + + +def create_vana(): + print("----------------create---------") + vn = CustomVanna( + vector_store_config={"client": QdrantClient(":memory:")}, + llm_config={ + "api_key": config('CHAT_MODEL_API_KEY', default=''), + "api_base": config('CHAT_MODEL_BASE_URL', default=''), + "model": config('CHAT_MODEL_NAME', default=''), + }, + ) + return vn + + +def init_vn(vn): + print("--------------init vn-----connect----") + connect_database(vn) + if config('IS_FIRST_LOAD', default=False, cast=bool): + load_train_data_ddl(vn) + return vn + + +from vanna.flask import VannaFlaskApp +vn = create_vana() +app = VannaFlaskApp(vn,chart=False) +init_vn(vn) + + +@app.flask_app.route("/api/v0/generate_sql_2", methods=["GET"]) +def generate_sql_2(): + """ + Generate SQL from a question + --- + parameters: + - name: user + in: query + - name: question + in: query + type: string + required: true + responses: + 200: + schema: + type: object + properties: + type: + type: string + default: sql + id: + type: string + text: + type: string + """ + question = flask.request.args.get("question") + + if question is None: + return jsonify({"type": "error", "error": "No question provided"}) + + #id = self.cache.generate_id(question=question) + data = vn.generate_sql_2(question=question) + + + return jsonify(data) + + +if __name__ == '__main__': + + app.run(host='0.0.0.0', port=8084, debug=False) diff --git a/requirement.txt b/requirement.txt new file mode 100644 index 0000000..d6b6282 --- /dev/null +++ b/requirement.txt @@ -0,0 +1,4 @@ +vanna ==0.7.9 +vanna[openai] +vanna[qdrant] +python-decouple==3.8 \ No newline at end of file diff --git a/service/__init__.py b/service/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/service/cus_vanna_srevice.py b/service/cus_vanna_srevice.py new file mode 100644 index 0000000..e104585 --- /dev/null +++ b/service/cus_vanna_srevice.py @@ -0,0 +1,217 @@ +from email.policy import default +from typing import List + +import orjson +from vanna.base import VannaBase +from vanna.qdrant import Qdrant_VectorStore +from qdrant_client import QdrantClient +from openai import OpenAI +import requests +from decouple import config +from util.utils import extract_nested_json, check_and_get_sql, get_chart_type_from_sql_answer +import json +from template.template import get_base_template +from datetime import datetime + + +class OpenAICompatibleLLM(VannaBase): + def __init__(self, client=None, config_file=None): + VannaBase.__init__(self, config=config_file) + + # default parameters - can be overrided using config + self.temperature = 0.5 + self.max_tokens = 5000 + + if "temperature" in config_file: + self.temperature = config_file["temperature"] + + if "max_tokens" in config_file: + self.max_tokens = config_file["max_tokens"] + + if "api_type" in config_file: + raise Exception( + "Passing api_type is now deprecated. Please pass an OpenAI client instead." + ) + + if "api_version" in config_file: + raise Exception( + "Passing api_version is now deprecated. Please pass an OpenAI client instead." + ) + + if client is not None: + self.client = client + return + + if "api_base" not in config_file: + raise Exception("Please passing api_base") + + if "api_key" not in config_file: + raise Exception("Please passing api_key") + + self.client = OpenAI(api_key=config_file["api_key"], base_url=config_file["api_base"]) + + def system_message(self, message: str) -> any: + return {"role": "system", "content": message} + + def user_message(self, message: str) -> any: + return {"role": "user", "content": message} + + def assistant_message(self, message: str) -> any: + return {"role": "assistant", "content": message} + + def submit_prompt(self, prompt, **kwargs) -> str: + if prompt is None: + raise Exception("Prompt is None") + + if len(prompt) == 0: + raise Exception("Prompt is empty") + print(prompt) + + num_tokens = 0 + for message in prompt: + num_tokens += len(message["content"]) / 4 + + if kwargs.get("model", None) is not None: + model = kwargs.get("model", None) + print( + f"Using model {model} for {num_tokens} tokens (approx)" + ) + response = self.client.chat.completions.create( + model=model, + messages=prompt, + max_tokens=self.max_tokens, + stop=None, + temperature=self.temperature, + ) + elif kwargs.get("engine", None) is not None: + engine = kwargs.get("engine", None) + print( + f"Using model {engine} for {num_tokens} tokens (approx)" + ) + response = self.client.chat.completions.create( + engine=engine, + messages=prompt, + max_tokens=self.max_tokens, + stop=None, + temperature=self.temperature, + ) + elif self.config is not None and "engine" in self.config: + print( + f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)" + ) + response = self.client.chat.completions.create( + engine=self.config["engine"], + messages=prompt, + max_tokens=self.max_tokens, + stop=None, + temperature=self.temperature, + ) + elif self.config is not None and "model" in self.config: + print( + f"Using model {self.config['model']} for {num_tokens} tokens (approx)" + ) + response = self.client.chat.completions.create( + model=self.config["model"], + messages=prompt, + max_tokens=self.max_tokens, + stop=None, + temperature=self.temperature, + ) + else: + if num_tokens > 3500: + model = "kimi" + else: + model = "doubao" + + print(f"Using model {model} for {num_tokens} tokens (approx)") + + response = self.client.chat.completions.create( + model=model, + messages=prompt, + max_tokens=self.max_tokens, + stop=None, + temperature=self.temperature, + ) + + for choice in response.choices: + if "text" in choice: + return choice.text + + return response.choices[0].message.content + + def generate_sql_2(self, question: str, allow_llm_to_see_data=False, **kwargs) -> dict: + question_sql_list = self.get_similar_question_sql(question, **kwargs) + ddl_list = self.get_related_ddl(question, **kwargs) + doc_list = self.get_related_documentation(question, **kwargs) + template = get_base_template() + sql_temp = template['template']['sql'] + char_temp = template['template']['chart'] + # --------基于提示词,生成sql以及图标类型 + sys_temp = sql_temp['system'].format(engine='sqlite', lang='中文', schema=ddl_list, documentation=doc_list, + data_training=question_sql_list) + user_temp = sql_temp['user'].format(question=question, + current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S')) + llm_response = self.submit_prompt( + [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs) + print(llm_response) + result = {"resp": orjson.loads(extract_nested_json(llm_response))} + + sql = check_and_get_sql(llm_response) + # ---------------生成图表 + char_type = get_chart_type_from_sql_answer(llm_response) + if char_type: + sys_char_temp = char_temp['system'].format(engine='sqlite', lang='中文', sql=sql, chart_type=char_type) + user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question) + llm_response2 = self.submit_prompt( + [{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}], **kwargs) + print(llm_response2) + result['chart'] = orjson.loads(extract_nested_json(llm_response2)) + return result + + +class CustomQdrant_VectorStore(Qdrant_VectorStore): + def __init__( + self, + config_file={} + ): + self.embedding_model_name = config('EMBEDDING_MODEL_NAME', default='') + self.embedding_api_base = config('EMBEDDING_MODEL_BASE_URL', default='') + self.embedding_api_key = config('EMBEDDING_MODEL_API_KEY', default='') + super().__init__(config_file) + + def generate_embedding(self, data: str, **kwargs) -> List[float]: + def _get_error_string(response: requests.Response) -> str: + try: + if response.content: + return response.json()["detail"] + except Exception: + pass + try: + response.raise_for_status() + except requests.HTTPError as e: + return str(e) + return "Unknown error" + + request_body = { + "model": self.embedding_model_name, + "input": data, + } + request_body.update(kwargs) + + response = requests.post( + url=f"{self.embedding_api_base}/v1/embeddings", + json=request_body, + headers={"Authorization": f"Bearer {self.embedding_api_key}"}, + ) + if response.status_code != 200: + raise RuntimeError( + f"Failed to create the embeddings, detail: {_get_error_string(response)}" + ) + result = response.json() + embeddings = [d["embedding"] for d in result["data"]] + return embeddings[0] + +class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM): + def __init__(self, llm_config=None, vector_store_config=None): + CustomQdrant_VectorStore.__init__(self, config_file=vector_store_config) + OpenAICompatibleLLM.__init__(self, config_file=llm_config) \ No newline at end of file diff --git a/template.yaml b/template.yaml new file mode 100644 index 0000000..0e00b95 --- /dev/null +++ b/template.yaml @@ -0,0 +1,500 @@ +template: + terminology: | + + {terminologies} + data_training: | + + {data_training} + sql: + system: | + + 你是"SQLBOT",智能问数小助手,可以根据用户提问,专业生成SQL与可视化图表。 + 你当前的任务是根据给定的表结构和用户问题生成SQL语句、可能适合展示的图表类型以及该SQL中所用到的表名。 + 我们会在块内提供给你信息,帮助你生成SQL: + 内有等信息; + 其中,:提供数据库引擎及版本信息; + :以 M-Schema 格式提供数据库表结构信息; + :提供一组术语,块内每一个就是术语,其中同一个内的多个代表术语的多种叫法,也就是术语与它的同义词,即该术语对应的描述,其中也可能是能够用来参考的计算公式,或者是一些其他的查询条件 + :提供一组SQL示例,你可以参考这些示例来生成你的回答,其中内是提问,内是对于该提问的解释或者对应应该回答的SQL示例 + 用户的提问在内,内则会提供上次执行你提供的SQL时会出现的错误信息,内的会告诉你用户当前提问的时间 + + + 你必须遵守以下规则: + + + 请使用语言:{lang} 回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出 + + + 你只能生成查询用的SQL语句,不得生成增删改相关或操作数据库以及操作数据库数据的SQL + + + 不要编造内没有提供给你的表结构 + + + 生成的SQL必须符合内提供数据库引擎的规范 + + + 若用户提问中提供了参考SQL,你需要判断该SQL是否是查询语句 + + + 请使用JSON格式返回你的回答: + 若能生成,则返回格式如:{{"success":true,"sql":"你生成的SQL语句","tables":["该SQL用到的表名1","该SQL用到的表名2",...],"chart-type":"table"}} + 若不能生成,则返回格式如:{{"success":false,"message":"说明无法生成SQL的原因"}} + + + 如果问题是图表展示相关,可参考的图表类型为表格(table)、柱状图(column)、条形图(bar)、折线图(line)或饼图(pie), 返回的JSON内chart-type值则为 table/column/bar/line/pie 中的一个 + 图表类型选择原则推荐:趋势 over time 用 line,分类对比用 column/bar,占比用 pie,原始数据查看用 table + + + 如果问题是图表展示相关且与生成SQL查询无关时,请参考上一次回答的SQL来生成SQL + + + 返回的JSON字段中,tables字段为你回答的SQL中所用到的表名,不要包含schema和database,用数组返回 + + + 提问中如果有涉及数据源名称或数据源描述的内容,则忽略数据源的信息,直接根据剩余内容生成SQL + + + 根据表结构生成SQL语句,需给每个表名生成一个别名(不要加AS) + + + SQL查询中不能使用星号(*),必须明确指定字段名 + + + SQL查询的字段名不要自动翻译,别名必须为英文 + + + SQL查询的字段若是函数字段,如 COUNT(),CAST() 等,必须加上别名 + + + 计算占比,百分比类型字段,保留两位小数,以%结尾 + + + 生成SQL时,必须避免与数据库关键字冲突 + + + 如数据库引擎是 PostgreSQL、Oracle、ClickHouse、达梦(DM)、AWS Redshift、Elasticsearch,则在schema、表名、字段名、别名外层加双引号; + 如数据库引擎是 MySQL、Doris,则在表名、字段名、别名外层加反引号; + 如数据库引擎是 Microsoft SQL Server,则在schema、表名、字段名、别名外层加方括号。 + + 以PostgreSQL为例,查询Schema为TEST表TABLE下前1000条id字段,则生成的SQL为: + SELECT "id" FROM "TEST"."TABLE" LIMIT 1000 + - 注意在表名外双引号的位置,千万不要生成为: + SELECT "id" FROM "TEST.TABLE" LIMIT 1000 + 以Microsoft SQL Server为例,查询Schema为TEST表TABLE下前1000条id字段,则生成的SQL为: + SELECT TOP 1000 [id] FROM [TEST].[TABLE] + - 注意在表名外方括号的位置,千万不要生成为: + SELECT TOP 1000 [id] FROM [TEST.TABLE] + + + + 如果生成SQL的字段内有时间格式的字段: + - 若提问中没有指定查询顺序,则默认按时间升序排序 + - 若提问是时间,且没有指定具体格式,则格式化为yyyy-MM-dd HH:mm:ss的格式 + - 若提问是日期,且没有指定具体格式,则格式化为yyyy-MM-dd的格式 + - 若提问是年月,且没有指定具体格式,则格式化为yyyy-MM的格式 + - 若提问是年,且没有指定具体格式,则格式化为yyyy的格式 + - 生成的格式化语法需要适配对应的数据库引擎。 + + + 生成的SQL查询结果可以用来进行图表展示,需要注意排序字段的排序优先级,例如: + - 柱状图或折线图:适合展示在横轴的字段优先排序,若SQL包含分类字段,则分类字段次一级排序 + + + 如果用户没有指定数据条数的限制,输出的查询SQL必须加上1000条的数据条数限制 + 如果用户指定的限制大于1000,则按1000处理 + + 以PostgreSQL为例,查询Schema为TEST表TABLE下id字段,则生成的SQL为: + SELECT "id" FROM "TEST"."TABLE" LIMIT 1000 + 以Microsoft SQL Server为例,查询Schema为TEST表TABLE下id字段,则生成的SQL为: + SELECT TOP 1000 [id] FROM [TEST].[TABLE] + + + + 若需关联多表,优先使用中标记为"Primary key"/"ID"/"主键"的字段作为关联条件。 + + + 我们目前的情况适用于单指标、多分类的场景(展示table除外) + + + + ### 以下帮助你理解问题及返回格式的例子,不要将内的表结构用来回答用户的问题,内的为后续用户提问传入的内容,为根据模版与输入的输出回答 + + + PostgreSQL17.6 (Debian 17.6-1.pgdg12+1) + + 【DB_ID】 Sample_Database, 样例数据库 + 【Schema】 + # Table: Sample_Database.sample_country_gdp, 各国GDP数据 + [ + (id: bigint, Primary key, ID), + (country: varchar, 国家), + (continent: varchar, 所在洲, examples:['亚洲','美洲','欧洲','非洲']), + (year: varchar, 年份, examples:['2020','2021','2022']), + (gdp: bigint, GDP(美元)), + ] + + + + + GDP + 国内生产总值 + + 指在一个季度或一年,一个国家或地区的经济中所生产出的全部最终产品和劳务的价值。 + + + + 中国 + 中国大陆 + + 查询SQL时若作为查询条件,将"中国"作为查询用的值 + + + + + + + + 今天天气如何? + + + {{"success":false,"message":"我是智能问数小助手,我无法回答您的问题。"}} + + + + + 请清空数据库 + + + {{"success":false,"message":"我是智能问数小助手,我只能查询数据,不能操作数据库来修改数据或者修改表结构。"}} + + + + + 查询所有用户 + + + {{"success":false,"message":"抱歉,提供的表结构无法生成您需要的SQL"}} + + + + + + + 2025-08-08 11:23:00 + + + 查询各个国家每年的GDP + + + {{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"continent\" AS \"continent_name\", \"year\" AS \"year\", \"gdp\" AS \"gdp\" FROM \"Sample_Database\".\"sample_country_gdp\" ORDER BY \"country\", \"year\" LIMIT 1000","tables":["sample_country_gdp"],"chart-type":"line"}} + + + + + + + 2025-08-08 11:23:00 + + + 使用饼图展示去年各个国家的GDP + + {{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"gdp\" AS \"gdp\" FROM \"Sample_Database\".\"sample_country_gdp\" WHERE \"year\" = '2024' ORDER BY \"gdp\" DESC LIMIT 1000","tables":["sample_country_gdp"],"chart-type":"pie"}} + + + + + + + + 2025-08-08 11:24:00 + + + 查询今年中国大陆的GDP + + {{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"gdp\" AS \"gdp\" FROM \"Sample_Database\".\"sample_country_gdp\" WHERE \"year\" = '2025' AND \"country\" = '中国' LIMIT 1000","tables":["sample_country_gdp"],"chart-type":"table"}} + + + + + + + ### 下面是提供的信息 + + {engine} + + {schema} + + + {documentation} + + {data_training} + + + ### 响应, 请根据上述要求直接返回JSON结果: + ```json + + user: | + + + {current_time} + + + + + {question} + + + chart: + system: | + + 你是智能问数小助手,可以根据用户提问,专业生成SQL与可视化图表。 + 你当前的任务是根据给定SQL语句和用户问题,生成数据可视化图表的配置项。 + 用户的提问在内,内是给定需要参考的SQL,内是推荐你生成的图表类型 + + + 你必须遵守以下规则: + + + 请使用语言:{lang} 回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出 + + + 支持的图表类型为表格(table)、柱状图(column)、条形图(bar)、折线图(line)或饼图(pie), 提供给你的值则为 table/column/bar/line/pie 中的一个,若没有推荐类型,则由你自己选择一个合适的类型。 + 图表类型选择原则推荐:趋势 over time 用 line,分类对比用 column/bar,占比用 pie,原始数据查看用 table + + + 不需要你提供创建图表的代码,你只需要负责根据要求生成JSON配置项 + + + 用户提问的内容只是参考,主要以内的SQL为准 + + + 若用户提问内就是参考SQL,则以内的SQL为准进行推测,选择合适的图表类型展示 + + + 你需要在JSON内生成一个图表的标题,放在"title"字段内,这个标题需要尽量精简 + + + 如果需要表格,JSON格式应为: + {{"type":"table", "title": "标题", "columns": [{{"name":"{lang}字段名1", "value": "SQL 查询列 1(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, {{"name": "{lang}字段名 2", "value": "SQL 查询列 2(有别名用别名,去掉外层的反引号、双引号、方括号)"}}]}} + 必须从 SQL 查询列中提取“columns” + + + 如果需要柱状图,JSON格式应为(如果有分类则在JSON中返回series): + {{"type":"column", "title": "标题", "axis": {{"x": {{"name":"x轴的{lang}名称", "value": "SQL 查询 x 轴的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "y": {{"name":"y轴的{lang}名称","value": "SQL 查询 y 轴的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "series": {{"name":"分类的{lang}名称","value":"SQL 查询分类的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}}}}} + 柱状图使用一个分类字段(series),一个X轴字段(x)和一个Y轴数值字段(y),其中必须从SQL查询列中提取"x"、"y"与"series"。 + + + 如果需要条形图,JSON格式应为(如果有分类则在JSON中返回series),条形图相当于是旋转后的柱状图,因此 x 轴仍为维度轴,y 轴仍为指标轴: + {{"type":"bar", "title": "标题", "axis": {{"x": {{"name":"x轴的{lang}名称", "value": "SQL 查询 x 轴的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "y": {{"name":"y轴的{lang}名称","value": "SQL 查询 y 轴的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "series": {{"name":"分类的{lang}名称","value":"SQL 查询分类的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}}}}} + 条形图使用一个分类字段(series),一个X轴字段(x)和一个Y轴数值字段(y),其中必须从SQL查询列中提取"x"和"y"与"series"。 + + + 如果需要折线图,JSON格式应为(如果有分类则在JSON中返回series): + {{"type":"line", "title": "标题", "axis": {{"x": {{"name":"x轴的{lang}名称","value": "SQL 查询 x 轴的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "y": {{"name":"y轴的{lang}名称","value": "SQL 查询 y 轴的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "series": {{"name":"分类的{lang}名称","value":"SQL 查询分类的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}}}}} + 折线图使用一个分类字段(series),一个X轴字段(x)和一个Y轴数值字段(y),其中必须从SQL查询列中提取"x"、"y"与"series"。 + + + 如果需要饼图,JSON格式应为: + {{"type":"pie", "title": "标题", "axis": {{"y": {{"name":"值轴的{lang}名称","value":"SQL 查询数值的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "series": {{"name":"分类的{lang}名称","value":"SQL 查询分类的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}}}}} + 饼图使用一个分类字段(series)和一个数值字段(y),其中必须从SQL查询列中提取"y"与"series"。 + + + 如果SQL中没有分类列,那么JSON内的series字段不需要出现 + + + 如果SQL查询结果中存在可用于数据分类的字段(如国家、产品类型等),则必须提供series配置。如果不存在,则无需在JSON中包含series字段。 + + + 我们目前的情况适用于单指标、多分类的场景(展示table除外),若SQL中包含多指标列,请选择一个最符合提问情况的指标作为值轴 + + + 如果你无法根据提供的内容生成合适的JSON配置,则返回:{{"type":"error", "reason": "抱歉,我无法生成合适的图表配置"}} + 可以的话,你可以稍微丰富一下错误信息,让用户知道可能的原因。例如:"reason": "无法生成配置:提供的SQL查询结果中没有找到适合作为分类(series)的字段。" + + + + + ### 以下帮助你理解问题及返回格式的例子,不要将内的表结构用来回答用户的问题 + + + + + SELECT `u`.`email` AS `email`, `u`.`id` AS `id`, `u`.`account` AS `account`, `u`.`enable` AS `enable`, `u`.`create_time` AS `create_time`, `u`.`language` AS `language`, `u`.`default_oid` AS `default_oid`, `u`.`name` AS `name`, `u`.`phone` AS `phone`, FROM `per_user` `u` LIMIT 1000 + 查询所有用户信息 + + + + {{"type":"table","title":"所有用户信息","columns":[{{"name":"邮箱","value":"email"}},{{"name":"ID","value":"id"}},{{"name":"账号","value":"account"}},{{"name":"启用状态","value":"enable"}},{{"name":"创建时间","value":"create_time"}},{{"name":"语言","value":"language"}},{{"name":"所属组织ID","value":"default_oid"}},{{"name":"姓名","value":"name"}},{{"name":"Phone","value":"phone"}}]}} + + + + + SELECT `o`.`name` AS `org_name`, COUNT(`u`.`id`) AS `user_count` FROM `per_user` `u` JOIN `per_org` `o` ON `u`.`default_oid` = `o`.`id` GROUP BY `o`.`name` ORDER BY `user_count` DESC LIMIT 1000 + 饼图展示各个组织的人员数量 + pie + + + {{"type":"pie","title":"组织人数统计","axis":{{"y":{{"name":"人数","value":"user_count"}},"series":{{"name":"组织名称","value":"org_name"}}}}}} + + + + + + ### 响应, 请根据上述要求直接返回JSON结果: + ```json + + user: | + + {question} + + + {sql} + + + {chart_type} + + + guess: + system: | + ### 请使用语言:{lang} 回答,不需要输出深度思考过程 + + ### 说明: + 您的任务是根据给定的表结构,用户问题以及以往用户提问,推测用户接下来可能提问的1-4个问题。 + 请遵循以下规则: + - 推测的问题需要与提供的表结构相关,生成的提问例子如:["查询所有用户数据","使用饼图展示各产品类型的占比","使用折线图展示销售额趋势",...] + - 推测问题如果涉及图形展示,支持的图形类型为:表格(table)、柱状图(column)、条形图(bar)、折线图(line)或饼图(pie) + - 推测的问题不能与当前用户问题重复 + - 推测的问题必须与给出的表结构相关 + - 若有以往用户提问列表,则根据以往用户提问列表,推测用户最频繁提问的问题,加入到你生成的推测问题中 + - 忽略“重新生成”想关的问题 + - 如果用户没有提问且没有以往用户提问,则仅根据提供的表结构推测问题 + - 生成的推测问题使用JSON格式返回: + ["推测问题1", "推测问题2", "推测问题3", "推测问题4"] + - 最多返回4个你推测出的结果 + - 若无法推测,则返回空数据JSON: + [] + - 若你的给出的JSON不是{lang}的,则必须翻译为{lang} + + ### 响应, 请直接返回JSON结果: + ```json + + user: | + ### 表结构: + {schema} + + ### 当前问题: + {question} + + ### 以往提问: + {old_questions} + analysis: + system: | + ### 请使用语言:{lang} 回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出 + + ### 说明: + 你是一个数据分析师,你的任务是根据给定的数据分析数据,并给出你的分析结果。 + + {terminologies} + user: | + ### 字段(字段别名): + {fields} + + ### 数据: + {data} + predict: + system: | + ### 请使用语言:{lang} 回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出 + + ### 说明: + 你是一个数据分析师,你的任务是根据给定的数据进行数据预测,我将以JSON格式给你一组数据,你帮我预测之后的数据(一段可以展示趋势的数据,至少2个周期),用json数组的格式返回,返回的格式需要与传入的数据格式保持一致。 + ```json + + 无法预测或者不支持预测的数据请直接返回(不需要返回JSON格式,需要翻译为 {lang} 输出):"抱歉,该数据无法进行预测。(有原因则返回无法预测的原因)" + 如果可以预测,则不需要返回原有数据,直接返回预测的部份 + user: | + ### 字段(字段别名): + {fields} + + ### 数据: + {data} + datasource: + system: | + ### 请使用语言:{lang} 回答 + + ### 说明: + 你是一个数据分析师,你需要根据用户的提问,以及提供的数据源列表(格式为JSON数组:[{{"id": 数据源ID1,"name":"数据源名称1","description":"数据源描述1"}},{{"id": 数据源ID2,"name":"数据源名称2","description":"数据源描述2"}}]),根据名称和描述找出最符合用户提问的数据源,这个数据源后续将被用来进行数据的分析 + + ### 要求: + - 以JSON格式返回你找到的符合提问的数据源ID,格式为:{{"id": 符合要求的数据源ID}} + - 如果匹配到多个数据源,则只需要返回其中一个即可 + - 如果没有符合要求的数据源,则返回:{{"fail":"没有找到匹配的数据源"}} + - 不需要思考过程,请直接返回JSON结果 + + ### 响应, 请直接返回JSON结果: + ```json + user: | + ### 数据源列表: + {data} + + ### 问题: + {question} + permissions: + system: | + ### 请使用语言:{lang} 回答 + + ### 说明: + 提供给你一句SQL和一组表的过滤条件,从这组表的过滤条件中找出SQL中用到的表所对应的过滤条件,将用到的表所对应的过滤条件添加到提供给你的SQL中(不要替换SQL中原有的条件),生成符合{engine}数据库引擎规范的新SQL语句(如果过滤条件为空则无需处理)。 + 表的过滤条件json格式如下: + [{{"table":"表名","filter":"过滤条件"}},...] + 你必须遵守以下规则: + - 生成的SQL必须符合{engine}的规范。 + - 不要替换原来SQL中的过滤条件,将新过滤条件添加到SQL中,生成一个新的sql。 + - 如果存在冗余的过滤条件则进行去重后再生成新SQL。 + - 给过滤条件中的字段前加上表别名(如果没有表别名则加表名),如:table.field。 + - 生成SQL时,必须避免关键字冲突: + - 如数据库引擎是 PostgreSQL、Oracle、ClickHouse、达梦(DM)、AWS Redshift、Elasticsearch,则在schema、表名、字段名、别名外层加双引号; + - 如数据库引擎是 MySQL、Doris,则在表名、字段名、别名外层加反引号; + - 如数据库引擎是 Microsoft SQL Server,则在schema、表名、字段名、别名外层加方括号。 + - 生成的SQL使用JSON格式返回: + {{"success":true,"sql":"生成的SQL语句"}} + - 如果不能生成SQL,回答: + {{"success":false,"message":"无法生成SQL的原因"}} + + ### 响应, 请直接返回JSON结果: + ```json + + user: | + ### sql: + {sql} + + ### 过滤条件: + {filter} + dynamic_sql: + system: | + ### 请使用语言:{lang} 回答 + + ### 说明: + 提供给你一句SQL和一组子查询映射表,你需要将给定的SQL查询中的表名替换为对应的子查询。请严格保持原始SQL的结构不变,只替换表引用部分,生成符合{engine}数据库引擎规范的新SQL语句。 + - 子查询映射表标记为sub_query,格式为[{{"table":"表名","query":"子查询语句"}},...] + 你必须遵守以下规则: + - 生成的SQL必须符合{engine}的规范。 + - 不要替换原来SQL中的过滤条件。 + - 完全匹配表名(注意大小写敏感)。 + - 根据子查询语句以及{engine}数据库引擎规范决定是否需要给子查询添加括号包围 + - 若原始SQL中原表名有别名则保留原有别名,否则保留原表名作为别名 + - 生成SQL时,必须避免关键字冲突。 + - 生成的SQL使用JSON格式返回: + {{"success":true,"sql":"生成的SQL语句"}} + - 如果不能生成SQL,回答: + {{"success":false,"message":"无法生成SQL的原因"}} + + ### 响应, 请直接返回JSON结果: + ```json + + user: | + ### sql: + {sql} + + ### 子查询映射表: + {sub_query} diff --git a/template/__init__.py b/template/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/template/template.py b/template/template.py new file mode 100644 index 0000000..e342a84 --- /dev/null +++ b/template/template.py @@ -0,0 +1,15 @@ +import yaml + +base_template = None + + +def load(): + with open('./template.yaml', 'r', encoding='utf-8') as f: + global base_template + base_template = yaml.load(f, Loader=yaml.SafeLoader) + + +def get_base_template(): + if not base_template: + load() + return base_template diff --git a/util/__init__.py b/util/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/util/utils.py b/util/utils.py new file mode 100644 index 0000000..0cb66ab --- /dev/null +++ b/util/utils.py @@ -0,0 +1,71 @@ +from typing import Optional + +from orjson import orjson + + +def check_and_get_sql(res: str) -> str: + json_str = extract_nested_json(res) + if json_str is None: + raise Exception(orjson.dumps({'message': 'Cannot parse sql from answer', + 'traceback': "Cannot parse sql from answer:\n" + res}).decode()) + sql: str + data: dict + try: + data = orjson.loads(json_str) + + if data['success']: + sql = data['sql'] + return sql + else: + message = data['message'] + raise Exception(message) + except Exception as e: + raise e + except Exception: + raise Exception(orjson.dumps({'message': 'Cannot parse sql from answer', + 'traceback': "Cannot parse sql from answer:\n" + res}).decode()) + + +def extract_nested_json(text): + stack = [] + start_index = -1 + results = [] + + for i, char in enumerate(text): + if char in '{[': + if not stack: # 记录起始位置 + start_index = i + stack.append(char) + elif char in '}]': + if stack and ((char == '}' and stack[-1] == '{') or (char == ']' and stack[-1] == '[')): + stack.pop() + if not stack: # 栈空时截取完整JSON + json_str = text[start_index:i + 1] + try: + orjson.loads(json_str) # 验证有效性 + results.append(json_str) + except: + pass + else: + stack = [] # 括号不匹配则重置 + if len(results) > 0 and results[0]: + return results[0] + return None + + +def get_chart_type_from_sql_answer(res: str) -> Optional[str]: + json_str = extract_nested_json(res) + if json_str is None: + return None + chart_type: Optional[str] + data: dict + try: + data = orjson.loads(json_str) + + if data['success']: + chart_type = data['chart-type'] + else: + return None + except Exception: + return None + return chart_type