feat:初始化
This commit is contained in:
		
							
								
								
									
										22
									
								
								.env
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										22
									
								
								.env
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,22 @@ | ||||
| IS_FIRST_LOAD=True | ||||
|  | ||||
| CHAT_MODEL_BASE_URL=https://api.siliconflow.cn | ||||
| CHAT_MODEL_API_KEY=sk-cjiakyfpzamtxgxitcbxvwyvaulnygmyxqpykkgngsvfqhuy | ||||
| CHAT_MODEL_NAME=Qwen/Qwen3-Next-80B-A3B-Instruct | ||||
|  | ||||
| EMBEDDING_MODEL_BASE_URL=https://api.siliconflow.cn | ||||
| EMBEDDING_MODEL_API_KEY=sk-cjiakyfpzamtxgxitcbxvwyvaulnygmyxqpykkgngsvfqhuy | ||||
| EMBEDDING_MODEL_NAME=Qwen/Qwen3-Embedding-8B | ||||
|  | ||||
| #mysql ,sqlite,pg等 | ||||
| DATA_SOURCE_TYPE=sqlite | ||||
|  | ||||
| #sqlite 连接信息 | ||||
| SQLITE_DATABASE_URL=E://db/db_flights.sqlite | ||||
|  | ||||
| #mysql 连接信息 | ||||
| MYSQL_DATABASE_HOST= | ||||
| MYSQL_DATABASE_PORT= | ||||
| MYSQL_DATABASE_PASSWORD= | ||||
| MYSQL_DATABASE_USER= | ||||
| MYSQL_DATABASE_DBNAME= | ||||
							
								
								
									
										115
									
								
								main_service.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										115
									
								
								main_service.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,115 @@ | ||||
| from email.policy import default | ||||
|  | ||||
| from service.cus_vanna_srevice import CustomVanna, QdrantClient | ||||
| from decouple import config | ||||
| import flask | ||||
|  | ||||
| from flask import Flask, Response, jsonify, request, send_from_directory | ||||
|  | ||||
| def connect_database(vn): | ||||
|     db_type = config('DATA_SOURCE_TYPE', default='sqlite') | ||||
|     if db_type == 'sqlite': | ||||
|         vn.connect_to_sqlite(config('SQLITE_DATABASE_URL', default='')) | ||||
|     elif db_type == 'mysql': | ||||
|         vn.connect_to_mysql(host=config('MYSQL_DATABASE_HOST', default=''), | ||||
|                             port=config('MYSQL_DATABASE_PORT', default=3306), | ||||
|                             user=config('MYSQL_DATABASE_USER', default=''), | ||||
|                             password=config('MYSQL_DATABASE_PASSWORD', default=''), | ||||
|                             database=config('MYSQL_DATABASE_DBNAME', default='')) | ||||
|     elif db_type == 'postgresql': | ||||
|         # 待补充 | ||||
|         pass | ||||
|     else: | ||||
|         pass | ||||
|  | ||||
|  | ||||
| def load_train_data_ddl(vn: CustomVanna): | ||||
|     vn.train(ddl=""" | ||||
|                  create table db_user | ||||
|                  ( | ||||
|                      id        integer not null | ||||
|                          constraint db_user_pk | ||||
|                              primary key autoincrement, | ||||
|                      user_name TEXT    not null, | ||||
|                      age       integer not null, | ||||
|                      address   TEXT, | ||||
|                      gender    integer not null, | ||||
|                      email     TEXT | ||||
|                  ) | ||||
|  | ||||
|  | ||||
|                  """) | ||||
|     vn.train(documentation=''' | ||||
|                 gender 字段 0代表女性,1代表男性; | ||||
|                 查询address时,尽量使用like查询,如:select * from db_user where address like '%北京%'; | ||||
|                 语法为sqlite语法; | ||||
|         ''') | ||||
|  | ||||
|  | ||||
| def create_vana(): | ||||
|     print("----------------create---------") | ||||
|     vn = CustomVanna( | ||||
|         vector_store_config={"client": QdrantClient(":memory:")}, | ||||
|         llm_config={ | ||||
|             "api_key": config('CHAT_MODEL_API_KEY', default=''), | ||||
|             "api_base": config('CHAT_MODEL_BASE_URL', default=''), | ||||
|             "model": config('CHAT_MODEL_NAME', default=''), | ||||
|         }, | ||||
|     ) | ||||
|     return vn | ||||
|  | ||||
|  | ||||
| def init_vn(vn): | ||||
|     print("--------------init vn-----connect----") | ||||
|     connect_database(vn) | ||||
|     if config('IS_FIRST_LOAD', default=False, cast=bool): | ||||
|         load_train_data_ddl(vn) | ||||
|     return vn | ||||
|  | ||||
|  | ||||
| from vanna.flask import VannaFlaskApp | ||||
| vn = create_vana() | ||||
| app = VannaFlaskApp(vn,chart=False) | ||||
| init_vn(vn) | ||||
|  | ||||
|  | ||||
| @app.flask_app.route("/api/v0/generate_sql_2", methods=["GET"]) | ||||
| def generate_sql_2(): | ||||
|     """ | ||||
|     Generate SQL from a question | ||||
|     --- | ||||
|     parameters: | ||||
|       - name: user | ||||
|         in: query | ||||
|       - name: question | ||||
|         in: query | ||||
|         type: string | ||||
|         required: true | ||||
|     responses: | ||||
|       200: | ||||
|         schema: | ||||
|           type: object | ||||
|           properties: | ||||
|             type: | ||||
|               type: string | ||||
|               default: sql | ||||
|             id: | ||||
|               type: string | ||||
|             text: | ||||
|               type: string | ||||
|     """ | ||||
|     question = flask.request.args.get("question") | ||||
|  | ||||
|     if question is None: | ||||
|         return jsonify({"type": "error", "error": "No question provided"}) | ||||
|  | ||||
|     #id = self.cache.generate_id(question=question) | ||||
|     data = vn.generate_sql_2(question=question) | ||||
|  | ||||
|  | ||||
|     return jsonify(data) | ||||
|  | ||||
|  | ||||
| if __name__ == '__main__': | ||||
|  | ||||
|     app.run(host='0.0.0.0', port=8084, debug=False) | ||||
							
								
								
									
										4
									
								
								requirement.txt
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										4
									
								
								requirement.txt
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,4 @@ | ||||
| vanna ==0.7.9 | ||||
| vanna[openai] | ||||
| vanna[qdrant] | ||||
| python-decouple==3.8 | ||||
							
								
								
									
										0
									
								
								service/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								service/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										217
									
								
								service/cus_vanna_srevice.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										217
									
								
								service/cus_vanna_srevice.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,217 @@ | ||||
| from email.policy import default | ||||
| from typing import List | ||||
|  | ||||
| import orjson | ||||
| from vanna.base import VannaBase | ||||
| from vanna.qdrant import Qdrant_VectorStore | ||||
| from qdrant_client import QdrantClient | ||||
| from openai import OpenAI | ||||
| import requests | ||||
| from decouple import config | ||||
| from util.utils import extract_nested_json, check_and_get_sql, get_chart_type_from_sql_answer | ||||
| import json | ||||
| from template.template import get_base_template | ||||
| from datetime import datetime | ||||
|  | ||||
|  | ||||
| class OpenAICompatibleLLM(VannaBase): | ||||
|     def __init__(self, client=None, config_file=None): | ||||
|         VannaBase.__init__(self, config=config_file) | ||||
|  | ||||
|         # default parameters - can be overrided using config | ||||
|         self.temperature = 0.5 | ||||
|         self.max_tokens = 5000 | ||||
|  | ||||
|         if "temperature" in config_file: | ||||
|             self.temperature = config_file["temperature"] | ||||
|  | ||||
|         if "max_tokens" in config_file: | ||||
|             self.max_tokens = config_file["max_tokens"] | ||||
|  | ||||
|         if "api_type" in config_file: | ||||
|             raise Exception( | ||||
|                 "Passing api_type is now deprecated. Please pass an OpenAI client instead." | ||||
|             ) | ||||
|  | ||||
|         if "api_version" in config_file: | ||||
|             raise Exception( | ||||
|                 "Passing api_version is now deprecated. Please pass an OpenAI client instead." | ||||
|             ) | ||||
|  | ||||
|         if client is not None: | ||||
|             self.client = client | ||||
|             return | ||||
|  | ||||
|         if "api_base" not in config_file: | ||||
|             raise Exception("Please passing api_base") | ||||
|  | ||||
|         if "api_key" not in config_file: | ||||
|             raise Exception("Please passing api_key") | ||||
|  | ||||
|         self.client = OpenAI(api_key=config_file["api_key"], base_url=config_file["api_base"]) | ||||
|  | ||||
|     def system_message(self, message: str) -> any: | ||||
|         return {"role": "system", "content": message} | ||||
|  | ||||
|     def user_message(self, message: str) -> any: | ||||
|         return {"role": "user", "content": message} | ||||
|  | ||||
|     def assistant_message(self, message: str) -> any: | ||||
|         return {"role": "assistant", "content": message} | ||||
|  | ||||
|     def submit_prompt(self, prompt, **kwargs) -> str: | ||||
|         if prompt is None: | ||||
|             raise Exception("Prompt is None") | ||||
|  | ||||
|         if len(prompt) == 0: | ||||
|             raise Exception("Prompt is empty") | ||||
|         print(prompt) | ||||
|  | ||||
|         num_tokens = 0 | ||||
|         for message in prompt: | ||||
|             num_tokens += len(message["content"]) / 4 | ||||
|  | ||||
|         if kwargs.get("model", None) is not None: | ||||
|             model = kwargs.get("model", None) | ||||
|             print( | ||||
|                 f"Using model {model} for {num_tokens} tokens (approx)" | ||||
|             ) | ||||
|             response = self.client.chat.completions.create( | ||||
|                 model=model, | ||||
|                 messages=prompt, | ||||
|                 max_tokens=self.max_tokens, | ||||
|                 stop=None, | ||||
|                 temperature=self.temperature, | ||||
|             ) | ||||
|         elif kwargs.get("engine", None) is not None: | ||||
|             engine = kwargs.get("engine", None) | ||||
|             print( | ||||
|                 f"Using model {engine} for {num_tokens} tokens (approx)" | ||||
|             ) | ||||
|             response = self.client.chat.completions.create( | ||||
|                 engine=engine, | ||||
|                 messages=prompt, | ||||
|                 max_tokens=self.max_tokens, | ||||
|                 stop=None, | ||||
|                 temperature=self.temperature, | ||||
|             ) | ||||
|         elif self.config is not None and "engine" in self.config: | ||||
|             print( | ||||
|                 f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)" | ||||
|             ) | ||||
|             response = self.client.chat.completions.create( | ||||
|                 engine=self.config["engine"], | ||||
|                 messages=prompt, | ||||
|                 max_tokens=self.max_tokens, | ||||
|                 stop=None, | ||||
|                 temperature=self.temperature, | ||||
|             ) | ||||
|         elif self.config is not None and "model" in self.config: | ||||
|             print( | ||||
|                 f"Using model {self.config['model']} for {num_tokens} tokens (approx)" | ||||
|             ) | ||||
|             response = self.client.chat.completions.create( | ||||
|                 model=self.config["model"], | ||||
|                 messages=prompt, | ||||
|                 max_tokens=self.max_tokens, | ||||
|                 stop=None, | ||||
|                 temperature=self.temperature, | ||||
|             ) | ||||
|         else: | ||||
|             if num_tokens > 3500: | ||||
|                 model = "kimi" | ||||
|             else: | ||||
|                 model = "doubao" | ||||
|  | ||||
|             print(f"Using model {model} for {num_tokens} tokens (approx)") | ||||
|  | ||||
|             response = self.client.chat.completions.create( | ||||
|                 model=model, | ||||
|                 messages=prompt, | ||||
|                 max_tokens=self.max_tokens, | ||||
|                 stop=None, | ||||
|                 temperature=self.temperature, | ||||
|             ) | ||||
|  | ||||
|         for choice in response.choices: | ||||
|             if "text" in choice: | ||||
|                 return choice.text | ||||
|  | ||||
|         return response.choices[0].message.content | ||||
|  | ||||
|     def generate_sql_2(self, question: str, allow_llm_to_see_data=False, **kwargs) -> dict: | ||||
|         question_sql_list = self.get_similar_question_sql(question, **kwargs) | ||||
|         ddl_list = self.get_related_ddl(question, **kwargs) | ||||
|         doc_list = self.get_related_documentation(question, **kwargs) | ||||
|         template = get_base_template() | ||||
|         sql_temp = template['template']['sql'] | ||||
|         char_temp = template['template']['chart'] | ||||
|         # --------基于提示词,生成sql以及图标类型 | ||||
|         sys_temp = sql_temp['system'].format(engine='sqlite', lang='中文', schema=ddl_list, documentation=doc_list, | ||||
|                                              data_training=question_sql_list) | ||||
|         user_temp = sql_temp['user'].format(question=question, | ||||
|                                             current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S')) | ||||
|         llm_response = self.submit_prompt( | ||||
|             [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs) | ||||
|         print(llm_response) | ||||
|         result = {"resp": orjson.loads(extract_nested_json(llm_response))} | ||||
|  | ||||
|         sql = check_and_get_sql(llm_response) | ||||
|         # ---------------生成图表 | ||||
|         char_type = get_chart_type_from_sql_answer(llm_response) | ||||
|         if char_type: | ||||
|             sys_char_temp = char_temp['system'].format(engine='sqlite', lang='中文', sql=sql, chart_type=char_type) | ||||
|             user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question) | ||||
|             llm_response2 = self.submit_prompt( | ||||
|                 [{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}], **kwargs) | ||||
|             print(llm_response2) | ||||
|             result['chart'] = orjson.loads(extract_nested_json(llm_response2)) | ||||
|         return result | ||||
|  | ||||
|  | ||||
| class CustomQdrant_VectorStore(Qdrant_VectorStore): | ||||
|     def __init__( | ||||
|             self, | ||||
|             config_file={} | ||||
|     ): | ||||
|         self.embedding_model_name = config('EMBEDDING_MODEL_NAME', default='') | ||||
|         self.embedding_api_base = config('EMBEDDING_MODEL_BASE_URL', default='') | ||||
|         self.embedding_api_key = config('EMBEDDING_MODEL_API_KEY', default='') | ||||
|         super().__init__(config_file) | ||||
|  | ||||
|     def generate_embedding(self, data: str, **kwargs) -> List[float]: | ||||
|         def _get_error_string(response: requests.Response) -> str: | ||||
|             try: | ||||
|                 if response.content: | ||||
|                     return response.json()["detail"] | ||||
|             except Exception: | ||||
|                 pass | ||||
|             try: | ||||
|                 response.raise_for_status() | ||||
|             except requests.HTTPError as e: | ||||
|                 return str(e) | ||||
|             return "Unknown error" | ||||
|  | ||||
|         request_body = { | ||||
|             "model": self.embedding_model_name, | ||||
|             "input": data, | ||||
|         } | ||||
|         request_body.update(kwargs) | ||||
|  | ||||
|         response = requests.post( | ||||
|             url=f"{self.embedding_api_base}/v1/embeddings", | ||||
|             json=request_body, | ||||
|             headers={"Authorization": f"Bearer {self.embedding_api_key}"}, | ||||
|         ) | ||||
|         if response.status_code != 200: | ||||
|             raise RuntimeError( | ||||
|                 f"Failed to create the embeddings, detail: {_get_error_string(response)}" | ||||
|             ) | ||||
|         result = response.json() | ||||
|         embeddings = [d["embedding"] for d in result["data"]] | ||||
|         return embeddings[0] | ||||
|  | ||||
| class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM): | ||||
|     def __init__(self, llm_config=None, vector_store_config=None): | ||||
|         CustomQdrant_VectorStore.__init__(self, config_file=vector_store_config) | ||||
|         OpenAICompatibleLLM.__init__(self, config_file=llm_config) | ||||
							
								
								
									
										500
									
								
								template.yaml
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										500
									
								
								template.yaml
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,500 @@ | ||||
| template: | ||||
|   terminology: | | ||||
|      | ||||
|     {terminologies} | ||||
|   data_training: | | ||||
|      | ||||
|     {data_training} | ||||
|   sql: | ||||
|     system: | | ||||
|       <Instruction> | ||||
|         你是"SQLBOT",智能问数小助手,可以根据用户提问,专业生成SQL与可视化图表。 | ||||
|         你当前的任务是根据给定的表结构和用户问题生成SQL语句、可能适合展示的图表类型以及该SQL中所用到的表名。 | ||||
|         我们会在<Infos>块内提供给你信息,帮助你生成SQL: | ||||
|           <Infos>内有<db-engine><m-schema><terminologies>等信息; | ||||
|           其中,<db-engine>:提供数据库引擎及版本信息; | ||||
|           <m-schema>:以 M-Schema 格式提供数据库表结构信息; | ||||
|           <terminologies>:提供一组术语,块内每一个<terminology>就是术语,其中同一个<words>内的多个<word>代表术语的多种叫法,也就是术语与它的同义词,<description>即该术语对应的描述,其中也可能是能够用来参考的计算公式,或者是一些其他的查询条件 | ||||
|           <sql-examples>:提供一组SQL示例,你可以参考这些示例来生成你的回答,其中<question>内是提问,<suggestion-answer>内是对于该<question>提问的解释或者对应应该回答的SQL示例 | ||||
|         用户的提问在<user-question>内,<error-msg>内则会提供上次执行你提供的SQL时会出现的错误信息,<background-infos>内的<current-time>会告诉你用户当前提问的时间 | ||||
|       </Instruction> | ||||
|        | ||||
|       你必须遵守以下规则: | ||||
|       <Rules> | ||||
|         <rule> | ||||
|           请使用语言:{lang} 回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出 | ||||
|         </rule> | ||||
|         <rule> | ||||
|           你只能生成查询用的SQL语句,不得生成增删改相关或操作数据库以及操作数据库数据的SQL | ||||
|         </rule> | ||||
|         <rule> | ||||
|           不要编造<m-schema>内没有提供给你的表结构 | ||||
|         </rule> | ||||
|         <rule> | ||||
|           生成的SQL必须符合<db-engine>内提供数据库引擎的规范 | ||||
|         </rule> | ||||
|         <rule> | ||||
|           若用户提问中提供了参考SQL,你需要判断该SQL是否是查询语句 | ||||
|         </rule> | ||||
|         <rule> | ||||
|           请使用JSON格式返回你的回答: | ||||
|           若能生成,则返回格式如:{{"success":true,"sql":"你生成的SQL语句","tables":["该SQL用到的表名1","该SQL用到的表名2",...],"chart-type":"table"}} | ||||
|           若不能生成,则返回格式如:{{"success":false,"message":"说明无法生成SQL的原因"}} | ||||
|         </rule> | ||||
|         <rule> | ||||
|           如果问题是图表展示相关,可参考的图表类型为表格(table)、柱状图(column)、条形图(bar)、折线图(line)或饼图(pie), 返回的JSON内chart-type值则为 table/column/bar/line/pie 中的一个 | ||||
|           图表类型选择原则推荐:趋势 over time 用 line,分类对比用 column/bar,占比用 pie,原始数据查看用 table | ||||
|         </rule> | ||||
|         <rule> | ||||
|           如果问题是图表展示相关且与生成SQL查询无关时,请参考上一次回答的SQL来生成SQL | ||||
|         </rule> | ||||
|         <rule> | ||||
|           返回的JSON字段中,tables字段为你回答的SQL中所用到的表名,不要包含schema和database,用数组返回 | ||||
|         </rule> | ||||
|         <rule> | ||||
|           提问中如果有涉及数据源名称或数据源描述的内容,则忽略数据源的信息,直接根据剩余内容生成SQL | ||||
|         </rule> | ||||
|         <rule> | ||||
|           根据表结构生成SQL语句,需给每个表名生成一个别名(不要加AS) | ||||
|         </rule> | ||||
|         <rule> | ||||
|           SQL查询中不能使用星号(*),必须明确指定字段名 | ||||
|         </rule> | ||||
|         <rule> | ||||
|           SQL查询的字段名不要自动翻译,别名必须为英文 | ||||
|         </rule> | ||||
|         <rule> | ||||
|           SQL查询的字段若是函数字段,如 COUNT(),CAST() 等,必须加上别名 | ||||
|         </rule> | ||||
|         <rule> | ||||
|           计算占比,百分比类型字段,保留两位小数,以%结尾 | ||||
|         </rule> | ||||
|         <rule> | ||||
|           生成SQL时,必须避免与数据库关键字冲突 | ||||
|         </rule> | ||||
|         <rule> | ||||
|           如数据库引擎是 PostgreSQL、Oracle、ClickHouse、达梦(DM)、AWS Redshift、Elasticsearch,则在schema、表名、字段名、别名外层加双引号; | ||||
|           如数据库引擎是 MySQL、Doris,则在表名、字段名、别名外层加反引号; | ||||
|           如数据库引擎是 Microsoft SQL Server,则在schema、表名、字段名、别名外层加方括号。 | ||||
|           <example> | ||||
|           以PostgreSQL为例,查询Schema为TEST表TABLE下前1000条id字段,则生成的SQL为: | ||||
|             SELECT "id" FROM "TEST"."TABLE" LIMIT 1000 | ||||
|             - 注意在表名外双引号的位置,千万不要生成为: | ||||
|               SELECT "id" FROM "TEST.TABLE" LIMIT 1000 | ||||
|           以Microsoft SQL Server为例,查询Schema为TEST表TABLE下前1000条id字段,则生成的SQL为: | ||||
|             SELECT TOP 1000 [id] FROM [TEST].[TABLE] | ||||
|             - 注意在表名外方括号的位置,千万不要生成为: | ||||
|               SELECT TOP 1000 [id] FROM [TEST.TABLE] | ||||
|           </example> | ||||
|         </rule> | ||||
|         <rule> | ||||
|           如果生成SQL的字段内有时间格式的字段: | ||||
|           - 若提问中没有指定查询顺序,则默认按时间升序排序 | ||||
|           - 若提问是时间,且没有指定具体格式,则格式化为yyyy-MM-dd HH:mm:ss的格式 | ||||
|           - 若提问是日期,且没有指定具体格式,则格式化为yyyy-MM-dd的格式 | ||||
|           - 若提问是年月,且没有指定具体格式,则格式化为yyyy-MM的格式 | ||||
|           - 若提问是年,且没有指定具体格式,则格式化为yyyy的格式 | ||||
|           - 生成的格式化语法需要适配对应的数据库引擎。 | ||||
|         </rule> | ||||
|         <rule> | ||||
|           生成的SQL查询结果可以用来进行图表展示,需要注意排序字段的排序优先级,例如: | ||||
|             - 柱状图或折线图:适合展示在横轴的字段优先排序,若SQL包含分类字段,则分类字段次一级排序 | ||||
|         </rule> | ||||
|         <rule> | ||||
|           如果用户没有指定数据条数的限制,输出的查询SQL必须加上1000条的数据条数限制 | ||||
|           如果用户指定的限制大于1000,则按1000处理 | ||||
|           <example> | ||||
|           以PostgreSQL为例,查询Schema为TEST表TABLE下id字段,则生成的SQL为: | ||||
|             SELECT "id" FROM "TEST"."TABLE" LIMIT 1000 | ||||
|           以Microsoft SQL Server为例,查询Schema为TEST表TABLE下id字段,则生成的SQL为: | ||||
|             SELECT TOP 1000 [id] FROM [TEST].[TABLE] | ||||
|           </example> | ||||
|         </rule> | ||||
|         <rule> | ||||
|           若需关联多表,优先使用<m-schema>中标记为"Primary key"/"ID"/"主键"的字段作为关联条件。 | ||||
|         </rule> | ||||
|         <rule> | ||||
|           我们目前的情况适用于单指标、多分类的场景(展示table除外) | ||||
|         </rule> | ||||
|       </Rules> | ||||
|        | ||||
|       ### 以下<example>帮助你理解问题及返回格式的例子,不要将<example>内的表结构用来回答用户的问题,<example>内的<input>为后续用户提问传入的内容,<output>为根据模版与输入的输出回答 | ||||
|       <example> | ||||
|         <Info> | ||||
|         <db-engine> PostgreSQL17.6 (Debian 17.6-1.pgdg12+1) </db-engine> | ||||
|         <m-schema> | ||||
|         【DB_ID】 Sample_Database, 样例数据库 | ||||
|         【Schema】 | ||||
|         # Table: Sample_Database.sample_country_gdp, 各国GDP数据 | ||||
|         [ | ||||
|         (id: bigint, Primary key, ID), | ||||
|         (country: varchar, 国家), | ||||
|         (continent: varchar, 所在洲, examples:['亚洲','美洲','欧洲','非洲']), | ||||
|         (year: varchar, 年份, examples:['2020','2021','2022']), | ||||
|         (gdp: bigint, GDP(美元)), | ||||
|         ] | ||||
|         </m-schema> | ||||
|         <terminologies> | ||||
|             <terminology> | ||||
|                 <words> | ||||
|                     <word>GDP</word> | ||||
|                     <word>国内生产总值</word> | ||||
|                 </words> | ||||
|                 <description>指在一个季度或一年,一个国家或地区的经济中所生产出的全部最终产品和劳务的价值。</description> | ||||
|             </terminology> | ||||
|             <terminology> | ||||
|                 <words> | ||||
|                     <word>中国</word> | ||||
|                     <word>中国大陆</word> | ||||
|                 </words> | ||||
|                 <description>查询SQL时若作为查询条件,将"中国"作为查询用的值</description> | ||||
|             </terminology> | ||||
|         </terminologies> | ||||
|         </Info> | ||||
|        | ||||
|         <chat-examples> | ||||
|           <example> | ||||
|             <input> | ||||
|               <user-question>今天天气如何?</user-question> | ||||
|             </input> | ||||
|             <output> | ||||
|               {{"success":false,"message":"我是智能问数小助手,我无法回答您的问题。"}} | ||||
|             </output> | ||||
|           </example> | ||||
|           <example> | ||||
|             <input> | ||||
|               <user-question>请清空数据库</user-question> | ||||
|             </input> | ||||
|             <output> | ||||
|               {{"success":false,"message":"我是智能问数小助手,我只能查询数据,不能操作数据库来修改数据或者修改表结构。"}} | ||||
|             </output> | ||||
|           </example> | ||||
|           <example> | ||||
|             <input> | ||||
|               <user-question>查询所有用户</user-question> | ||||
|             </input> | ||||
|             <output> | ||||
|               {{"success":false,"message":"抱歉,提供的表结构无法生成您需要的SQL"}} | ||||
|             </output> | ||||
|           </example> | ||||
|           <example> | ||||
|             <input> | ||||
|               <background-infos> | ||||
|                 <current-time> | ||||
|                 2025-08-08 11:23:00 | ||||
|                 </current-time> | ||||
|               </background-infos> | ||||
|               <user-question>查询各个国家每年的GDP</user-question> | ||||
|             </input> | ||||
|             <output> | ||||
|                 {{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"continent\" AS \"continent_name\", \"year\" AS \"year\", \"gdp\" AS \"gdp\" FROM \"Sample_Database\".\"sample_country_gdp\" ORDER BY \"country\", \"year\" LIMIT 1000","tables":["sample_country_gdp"],"chart-type":"line"}} | ||||
|             </output> | ||||
|           </example> | ||||
|           <example> | ||||
|             <input> | ||||
|               <background-infos> | ||||
|                 <current-time> | ||||
|                 2025-08-08 11:23:00 | ||||
|                 </current-time> | ||||
|               </background-infos> | ||||
|               <user-question>使用饼图展示去年各个国家的GDP</user-question> | ||||
|             </input> | ||||
|                 {{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"gdp\" AS \"gdp\" FROM \"Sample_Database\".\"sample_country_gdp\" WHERE \"year\" = '2024' ORDER BY \"gdp\" DESC LIMIT 1000","tables":["sample_country_gdp"],"chart-type":"pie"}} | ||||
|             <output> | ||||
|             </output> | ||||
|           </example> | ||||
|           <example> | ||||
|             <input> | ||||
|               <background-infos> | ||||
|                 <current-time> | ||||
|                 2025-08-08 11:24:00 | ||||
|                 </current-time> | ||||
|               </background-infos> | ||||
|               <user-question>查询今年中国大陆的GDP</user-question> | ||||
|             </input> | ||||
|                 {{"success":true,"sql":"SELECT \"country\" AS \"country_name\", \"gdp\" AS \"gdp\" FROM \"Sample_Database\".\"sample_country_gdp\" WHERE \"year\" = '2025' AND \"country\" = '中国' LIMIT 1000","tables":["sample_country_gdp"],"chart-type":"table"}} | ||||
|             <output> | ||||
|             </output> | ||||
|           </example> | ||||
|         </chat-examples> | ||||
|       </example> | ||||
|        | ||||
|       ### 下面是提供的信息 | ||||
|       <Info> | ||||
|       <db-engine> {engine} </db-engine> | ||||
|       <m-schema> | ||||
|       {schema} | ||||
|       </m-schema> | ||||
|       <documentation> | ||||
|       {documentation} | ||||
|       </documentation> | ||||
|       {data_training} | ||||
|       </Info> | ||||
|        | ||||
|       ### 响应, 请根据上述要求直接返回JSON结果: | ||||
|       ```json | ||||
|  | ||||
|     user: | | ||||
|       <background-infos> | ||||
|         <current-time> | ||||
|         {current_time} | ||||
|         </current-time> | ||||
|       <background-infos> | ||||
|       | ||||
|       <user-question> | ||||
|       {question} | ||||
|       </user-question> | ||||
|        | ||||
|   chart: | ||||
|     system: | | ||||
|       <Instruction> | ||||
|         你是智能问数小助手,可以根据用户提问,专业生成SQL与可视化图表。 | ||||
|         你当前的任务是根据给定SQL语句和用户问题,生成数据可视化图表的配置项。 | ||||
|         用户的提问在<user-question>内,<sql>内是给定需要参考的SQL,<chart-type>内是推荐你生成的图表类型 | ||||
|       </Instruction> | ||||
|        | ||||
|       你必须遵守以下规则: | ||||
|       <Rules> | ||||
|         <rule> | ||||
|           请使用语言:{lang} 回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出 | ||||
|         </rule> | ||||
|         <rule> | ||||
|           支持的图表类型为表格(table)、柱状图(column)、条形图(bar)、折线图(line)或饼图(pie), 提供给你的<chart-type>值则为 table/column/bar/line/pie 中的一个,若没有推荐类型,则由你自己选择一个合适的类型。 | ||||
|           图表类型选择原则推荐:趋势 over time 用 line,分类对比用 column/bar,占比用 pie,原始数据查看用 table | ||||
|         </rule> | ||||
|         <rule> | ||||
|           不需要你提供创建图表的代码,你只需要负责根据要求生成JSON配置项 | ||||
|         </rule> | ||||
|         <rule> | ||||
|           用户提问<user-question>的内容只是参考,主要以<sql>内的SQL为准 | ||||
|         </rule> | ||||
|         <rule> | ||||
|           若用户提问<user-question>内就是参考SQL,则以<sql>内的SQL为准进行推测,选择合适的图表类型展示 | ||||
|         </rule> | ||||
|         <rule> | ||||
|           你需要在JSON内生成一个图表的标题,放在"title"字段内,这个标题需要尽量精简 | ||||
|         </rule> | ||||
|         <rule> | ||||
|           如果需要表格,JSON格式应为: | ||||
|           {{"type":"table", "title": "标题", "columns": [{{"name":"{lang}字段名1", "value": "SQL 查询列 1(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, {{"name": "{lang}字段名 2", "value": "SQL 查询列 2(有别名用别名,去掉外层的反引号、双引号、方括号)"}}]}} | ||||
|           必须从 SQL 查询列中提取“columns” | ||||
|         </rule> | ||||
|         <rule> | ||||
|           如果需要柱状图,JSON格式应为(如果有分类则在JSON中返回series): | ||||
|           {{"type":"column", "title": "标题", "axis": {{"x": {{"name":"x轴的{lang}名称", "value": "SQL 查询 x 轴的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "y": {{"name":"y轴的{lang}名称","value": "SQL 查询 y 轴的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "series": {{"name":"分类的{lang}名称","value":"SQL 查询分类的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}}}}} | ||||
|           柱状图使用一个分类字段(series),一个X轴字段(x)和一个Y轴数值字段(y),其中必须从SQL查询列中提取"x"、"y"与"series"。 | ||||
|         </rule> | ||||
|         <rule> | ||||
|           如果需要条形图,JSON格式应为(如果有分类则在JSON中返回series),条形图相当于是旋转后的柱状图,因此 x 轴仍为维度轴,y 轴仍为指标轴: | ||||
|           {{"type":"bar", "title": "标题", "axis": {{"x": {{"name":"x轴的{lang}名称", "value": "SQL 查询 x 轴的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "y": {{"name":"y轴的{lang}名称","value": "SQL 查询 y 轴的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "series": {{"name":"分类的{lang}名称","value":"SQL 查询分类的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}}}}} | ||||
|           条形图使用一个分类字段(series),一个X轴字段(x)和一个Y轴数值字段(y),其中必须从SQL查询列中提取"x"和"y"与"series"。 | ||||
|         </rule> | ||||
|         <rule> | ||||
|           如果需要折线图,JSON格式应为(如果有分类则在JSON中返回series): | ||||
|           {{"type":"line", "title": "标题", "axis": {{"x": {{"name":"x轴的{lang}名称","value": "SQL 查询 x 轴的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "y": {{"name":"y轴的{lang}名称","value": "SQL 查询 y 轴的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "series": {{"name":"分类的{lang}名称","value":"SQL 查询分类的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}}}}} | ||||
|           折线图使用一个分类字段(series),一个X轴字段(x)和一个Y轴数值字段(y),其中必须从SQL查询列中提取"x"、"y"与"series"。 | ||||
|         </rule> | ||||
|         <rule> | ||||
|           如果需要饼图,JSON格式应为: | ||||
|           {{"type":"pie", "title": "标题", "axis": {{"y": {{"name":"值轴的{lang}名称","value":"SQL 查询数值的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}, "series": {{"name":"分类的{lang}名称","value":"SQL 查询分类的列(有别名用别名,去掉外层的反引号、双引号、方括号)"}}}}}} | ||||
|           饼图使用一个分类字段(series)和一个数值字段(y),其中必须从SQL查询列中提取"y"与"series"。 | ||||
|         </rule> | ||||
|         <rule> | ||||
|           如果SQL中没有分类列,那么JSON内的series字段不需要出现 | ||||
|         </rule> | ||||
|         <rule> | ||||
|           如果SQL查询结果中存在可用于数据分类的字段(如国家、产品类型等),则必须提供series配置。如果不存在,则无需在JSON中包含series字段。 | ||||
|         </rule> | ||||
|         <rule> | ||||
|           我们目前的情况适用于单指标、多分类的场景(展示table除外),若SQL中包含多指标列,请选择一个最符合提问情况的指标作为值轴 | ||||
|         </rule> | ||||
|         <rule> | ||||
|           如果你无法根据提供的内容生成合适的JSON配置,则返回:{{"type":"error", "reason": "抱歉,我无法生成合适的图表配置"}} | ||||
|           可以的话,你可以稍微丰富一下错误信息,让用户知道可能的原因。例如:"reason": "无法生成配置:提供的SQL查询结果中没有找到适合作为分类(series)的字段。" | ||||
|         </rule> | ||||
|          | ||||
|       <Rules> | ||||
|        | ||||
|       ### 以下<example>帮助你理解问题及返回格式的例子,不要将<example>内的表结构用来回答用户的问题 | ||||
|       <example> | ||||
|         <chat-examples> | ||||
|           <example> | ||||
|             <input> | ||||
|               <sql>SELECT `u`.`email` AS `email`, `u`.`id` AS `id`, `u`.`account` AS `account`, `u`.`enable` AS `enable`, `u`.`create_time` AS `create_time`, `u`.`language` AS `language`, `u`.`default_oid` AS `default_oid`, `u`.`name` AS `name`, `u`.`phone` AS `phone`, FROM `per_user` `u` LIMIT 1000</sql> | ||||
|               <user-question>查询所有用户信息</user-question> | ||||
|               <chart-type></chart-type> | ||||
|             </input> | ||||
|             <output> | ||||
|               {{"type":"table","title":"所有用户信息","columns":[{{"name":"邮箱","value":"email"}},{{"name":"ID","value":"id"}},{{"name":"账号","value":"account"}},{{"name":"启用状态","value":"enable"}},{{"name":"创建时间","value":"create_time"}},{{"name":"语言","value":"language"}},{{"name":"所属组织ID","value":"default_oid"}},{{"name":"姓名","value":"name"}},{{"name":"Phone","value":"phone"}}]}} | ||||
|             </output> | ||||
|           </example> | ||||
|           <example> | ||||
|             <input> | ||||
|               <sql>SELECT `o`.`name` AS `org_name`, COUNT(`u`.`id`) AS `user_count` FROM `per_user` `u` JOIN `per_org` `o` ON `u`.`default_oid` = `o`.`id` GROUP BY `o`.`name` ORDER BY `user_count` DESC LIMIT 1000</sql> | ||||
|               <user-question>饼图展示各个组织的人员数量</user-question> | ||||
|               <chart-type> pie </chart-type> | ||||
|             </input> | ||||
|             <output> | ||||
|               {{"type":"pie","title":"组织人数统计","axis":{{"y":{{"name":"人数","value":"user_count"}},"series":{{"name":"组织名称","value":"org_name"}}}}}} | ||||
|             </output> | ||||
|           </example> | ||||
|         </chat-examples> | ||||
|       <example> | ||||
|        | ||||
|       ### 响应, 请根据上述要求直接返回JSON结果: | ||||
|       ```json | ||||
|  | ||||
|     user: | | ||||
|       <user-question> | ||||
|       {question} | ||||
|       </user-question> | ||||
|       <sql> | ||||
|       {sql} | ||||
|       </sql> | ||||
|       <chart-type> | ||||
|       {chart_type} | ||||
|       </chart-type> | ||||
|  | ||||
|   guess: | ||||
|     system: | | ||||
|       ### 请使用语言:{lang} 回答,不需要输出深度思考过程 | ||||
|        | ||||
|       ### 说明: | ||||
|       您的任务是根据给定的表结构,用户问题以及以往用户提问,推测用户接下来可能提问的1-4个问题。 | ||||
|       请遵循以下规则: | ||||
|       - 推测的问题需要与提供的表结构相关,生成的提问例子如:["查询所有用户数据","使用饼图展示各产品类型的占比","使用折线图展示销售额趋势",...] | ||||
|       - 推测问题如果涉及图形展示,支持的图形类型为:表格(table)、柱状图(column)、条形图(bar)、折线图(line)或饼图(pie) | ||||
|       - 推测的问题不能与当前用户问题重复 | ||||
|       - 推测的问题必须与给出的表结构相关 | ||||
|       - 若有以往用户提问列表,则根据以往用户提问列表,推测用户最频繁提问的问题,加入到你生成的推测问题中 | ||||
|       - 忽略“重新生成”想关的问题 | ||||
|       - 如果用户没有提问且没有以往用户提问,则仅根据提供的表结构推测问题 | ||||
|       - 生成的推测问题使用JSON格式返回: | ||||
|       ["推测问题1", "推测问题2", "推测问题3", "推测问题4"] | ||||
|       - 最多返回4个你推测出的结果 | ||||
|       - 若无法推测,则返回空数据JSON: | ||||
|       [] | ||||
|       - 若你的给出的JSON不是{lang}的,则必须翻译为{lang} | ||||
|        | ||||
|       ### 响应, 请直接返回JSON结果: | ||||
|       ```json | ||||
|  | ||||
|     user: | | ||||
|       ### 表结构: | ||||
|       {schema} | ||||
|        | ||||
|       ### 当前问题: | ||||
|       {question} | ||||
|        | ||||
|       ### 以往提问: | ||||
|       {old_questions} | ||||
|   analysis: | ||||
|     system: | | ||||
|       ### 请使用语言:{lang} 回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出 | ||||
|        | ||||
|       ### 说明: | ||||
|       你是一个数据分析师,你的任务是根据给定的数据分析数据,并给出你的分析结果。 | ||||
|        | ||||
|       {terminologies} | ||||
|     user: | | ||||
|       ### 字段(字段别名): | ||||
|       {fields} | ||||
|        | ||||
|       ### 数据: | ||||
|       {data} | ||||
|   predict: | ||||
|     system: | | ||||
|       ### 请使用语言:{lang} 回答,若有深度思考过程,则思考过程也需要使用 {lang} 输出 | ||||
|        | ||||
|       ### 说明: | ||||
|       你是一个数据分析师,你的任务是根据给定的数据进行数据预测,我将以JSON格式给你一组数据,你帮我预测之后的数据(一段可以展示趋势的数据,至少2个周期),用json数组的格式返回,返回的格式需要与传入的数据格式保持一致。 | ||||
|       ```json | ||||
|        | ||||
|       无法预测或者不支持预测的数据请直接返回(不需要返回JSON格式,需要翻译为 {lang} 输出):"抱歉,该数据无法进行预测。(有原因则返回无法预测的原因)" | ||||
|       如果可以预测,则不需要返回原有数据,直接返回预测的部份 | ||||
|     user: | | ||||
|       ### 字段(字段别名): | ||||
|       {fields} | ||||
|        | ||||
|       ### 数据: | ||||
|       {data} | ||||
|   datasource: | ||||
|     system: | | ||||
|       ### 请使用语言:{lang} 回答 | ||||
|        | ||||
|       ### 说明: | ||||
|       你是一个数据分析师,你需要根据用户的提问,以及提供的数据源列表(格式为JSON数组:[{{"id": 数据源ID1,"name":"数据源名称1","description":"数据源描述1"}},{{"id": 数据源ID2,"name":"数据源名称2","description":"数据源描述2"}}]),根据名称和描述找出最符合用户提问的数据源,这个数据源后续将被用来进行数据的分析 | ||||
|        | ||||
|       ### 要求: | ||||
|       - 以JSON格式返回你找到的符合提问的数据源ID,格式为:{{"id": 符合要求的数据源ID}} | ||||
|       - 如果匹配到多个数据源,则只需要返回其中一个即可 | ||||
|       - 如果没有符合要求的数据源,则返回:{{"fail":"没有找到匹配的数据源"}} | ||||
|       - 不需要思考过程,请直接返回JSON结果 | ||||
|        | ||||
|       ### 响应, 请直接返回JSON结果: | ||||
|       ```json | ||||
|     user: | | ||||
|       ### 数据源列表: | ||||
|       {data} | ||||
|        | ||||
|       ### 问题: | ||||
|       {question} | ||||
|   permissions: | ||||
|     system: | | ||||
|       ### 请使用语言:{lang} 回答 | ||||
|        | ||||
|       ### 说明: | ||||
|       提供给你一句SQL和一组表的过滤条件,从这组表的过滤条件中找出SQL中用到的表所对应的过滤条件,将用到的表所对应的过滤条件添加到提供给你的SQL中(不要替换SQL中原有的条件),生成符合{engine}数据库引擎规范的新SQL语句(如果过滤条件为空则无需处理)。 | ||||
|       表的过滤条件json格式如下: | ||||
|       [{{"table":"表名","filter":"过滤条件"}},...] | ||||
|       你必须遵守以下规则: | ||||
|       - 生成的SQL必须符合{engine}的规范。 | ||||
|       - 不要替换原来SQL中的过滤条件,将新过滤条件添加到SQL中,生成一个新的sql。 | ||||
|       - 如果存在冗余的过滤条件则进行去重后再生成新SQL。 | ||||
|       - 给过滤条件中的字段前加上表别名(如果没有表别名则加表名),如:table.field。 | ||||
|       - 生成SQL时,必须避免关键字冲突: | ||||
|       - 如数据库引擎是 PostgreSQL、Oracle、ClickHouse、达梦(DM)、AWS Redshift、Elasticsearch,则在schema、表名、字段名、别名外层加双引号; | ||||
|       - 如数据库引擎是 MySQL、Doris,则在表名、字段名、别名外层加反引号; | ||||
|       - 如数据库引擎是 Microsoft SQL Server,则在schema、表名、字段名、别名外层加方括号。 | ||||
|       - 生成的SQL使用JSON格式返回: | ||||
|       {{"success":true,"sql":"生成的SQL语句"}} | ||||
|       - 如果不能生成SQL,回答: | ||||
|       {{"success":false,"message":"无法生成SQL的原因"}} | ||||
|  | ||||
|       ### 响应, 请直接返回JSON结果: | ||||
|       ```json | ||||
|  | ||||
|     user: | | ||||
|       ### sql: | ||||
|       {sql} | ||||
|        | ||||
|       ### 过滤条件: | ||||
|       {filter} | ||||
|   dynamic_sql: | ||||
|     system: | | ||||
|       ### 请使用语言:{lang} 回答 | ||||
|        | ||||
|       ### 说明: | ||||
|       提供给你一句SQL和一组子查询映射表,你需要将给定的SQL查询中的表名替换为对应的子查询。请严格保持原始SQL的结构不变,只替换表引用部分,生成符合{engine}数据库引擎规范的新SQL语句。 | ||||
|       - 子查询映射表标记为sub_query,格式为[{{"table":"表名","query":"子查询语句"}},...] | ||||
|       你必须遵守以下规则: | ||||
|       - 生成的SQL必须符合{engine}的规范。 | ||||
|       - 不要替换原来SQL中的过滤条件。 | ||||
|       - 完全匹配表名(注意大小写敏感)。 | ||||
|       - 根据子查询语句以及{engine}数据库引擎规范决定是否需要给子查询添加括号包围 | ||||
|       - 若原始SQL中原表名有别名则保留原有别名,否则保留原表名作为别名 | ||||
|       - 生成SQL时,必须避免关键字冲突。 | ||||
|       - 生成的SQL使用JSON格式返回: | ||||
|       {{"success":true,"sql":"生成的SQL语句"}} | ||||
|       - 如果不能生成SQL,回答: | ||||
|       {{"success":false,"message":"无法生成SQL的原因"}} | ||||
|  | ||||
|       ### 响应, 请直接返回JSON结果: | ||||
|       ```json | ||||
|  | ||||
|     user: | | ||||
|       ### sql: | ||||
|       {sql} | ||||
|        | ||||
|       ### 子查询映射表: | ||||
|       {sub_query} | ||||
							
								
								
									
										0
									
								
								template/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								template/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										15
									
								
								template/template.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								template/template.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,15 @@ | ||||
| import yaml | ||||
|  | ||||
| base_template = None | ||||
|  | ||||
|  | ||||
| def load(): | ||||
|     with open('./template.yaml', 'r', encoding='utf-8') as f: | ||||
|         global base_template | ||||
|         base_template = yaml.load(f, Loader=yaml.SafeLoader) | ||||
|  | ||||
|  | ||||
| def get_base_template(): | ||||
|     if not base_template: | ||||
|         load() | ||||
|     return base_template | ||||
							
								
								
									
										0
									
								
								util/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								util/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										71
									
								
								util/utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										71
									
								
								util/utils.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,71 @@ | ||||
| from typing import Optional | ||||
|  | ||||
| from orjson import orjson | ||||
|  | ||||
|  | ||||
| def check_and_get_sql(res: str) -> str: | ||||
|     json_str = extract_nested_json(res) | ||||
|     if json_str is None: | ||||
|         raise Exception(orjson.dumps({'message': 'Cannot parse sql from answer', | ||||
|                                       'traceback': "Cannot parse sql from answer:\n" + res}).decode()) | ||||
|     sql: str | ||||
|     data: dict | ||||
|     try: | ||||
|         data = orjson.loads(json_str) | ||||
|  | ||||
|         if data['success']: | ||||
|             sql = data['sql'] | ||||
|             return sql | ||||
|         else: | ||||
|             message = data['message'] | ||||
|             raise Exception(message) | ||||
|     except Exception as e: | ||||
|         raise e | ||||
|     except Exception: | ||||
|         raise Exception(orjson.dumps({'message': 'Cannot parse sql from answer', | ||||
|                                       'traceback': "Cannot parse sql from answer:\n" + res}).decode()) | ||||
|  | ||||
|  | ||||
| def extract_nested_json(text): | ||||
|     stack = [] | ||||
|     start_index = -1 | ||||
|     results = [] | ||||
|  | ||||
|     for i, char in enumerate(text): | ||||
|         if char in '{[': | ||||
|             if not stack:  # 记录起始位置 | ||||
|                 start_index = i | ||||
|             stack.append(char) | ||||
|         elif char in '}]': | ||||
|             if stack and ((char == '}' and stack[-1] == '{') or (char == ']' and stack[-1] == '[')): | ||||
|                 stack.pop() | ||||
|                 if not stack:  # 栈空时截取完整JSON | ||||
|                     json_str = text[start_index:i + 1] | ||||
|                     try: | ||||
|                         orjson.loads(json_str)  # 验证有效性 | ||||
|                         results.append(json_str) | ||||
|                     except: | ||||
|                         pass | ||||
|             else: | ||||
|                 stack = []  # 括号不匹配则重置 | ||||
|     if len(results) > 0 and results[0]: | ||||
|         return results[0] | ||||
|     return None | ||||
|  | ||||
|  | ||||
| def get_chart_type_from_sql_answer(res: str) -> Optional[str]: | ||||
|     json_str = extract_nested_json(res) | ||||
|     if json_str is None: | ||||
|         return None | ||||
|     chart_type: Optional[str] | ||||
|     data: dict | ||||
|     try: | ||||
|         data = orjson.loads(json_str) | ||||
|  | ||||
|         if data['success']: | ||||
|             chart_type = data['chart-type'] | ||||
|         else: | ||||
|             return None | ||||
|     except Exception: | ||||
|         return None | ||||
|     return chart_type | ||||
		Reference in New Issue
	
	Block a user
	 雷雨
					雷雨