diff --git a/main_service.py b/main_service.py
index c6c5e4d..956eaca 100644
--- a/main_service.py
+++ b/main_service.py
@@ -1,6 +1,7 @@
from email.policy import default
import logging
+import util.utils
from logging_config import LOGGING_CONFIG
from service.cus_vanna_srevice import CustomVanna, QdrantClient
from decouple import config
@@ -97,18 +98,20 @@ def generate_sql_2():
return jsonify({"type": "error", "error": "No question provided"})
try:
id = cache.generate_id(question=question)
+ logger.info(f"Generate sql for {question}")
data = vn.generate_sql_2(question=question)
+ logger.info("Generate sql result is {0}".format(data))
data['id'] = id
sql = data["resp"]["sql"]
- print("sql:", sql)
cache.set(id=id, field="question", value=question)
cache.set(id=id, field="sql", value=sql)
- print("data---------------------------", data)
+ data["type"]="success"
return jsonify(data)
except Exception as e:
return jsonify({"type": "error", "error": str(e)})
+
@app.flask_app.route("/api/v0/run_sql_2", methods=["GET"])
@app.requires_cache(["sql"])
def run_sql_2(id: str, sql: str):
@@ -150,13 +153,15 @@ def run_sql_2(id: str, sql: str):
df = vn.run_sql(sql=sql)
logger.info("")
app.cache.set(id=id, field="df", value=df)
- x = df.head(10).to_dict(orient='records')
+ x = df.to_dict(orient='records')
logger.info("df ---------------{0} {1}".format(x,type(x)))
+ result = util.utils.deal_result(data=x)
+
return jsonify(
{
- "type": "df",
+ "type": "success",
"id": id,
- "df": df.head(10).to_dict(orient='records'),
+ "df": result,
}
)
diff --git a/service/cus_vanna_srevice.py b/service/cus_vanna_srevice.py
index 9585194..94c4e92 100644
--- a/service/cus_vanna_srevice.py
+++ b/service/cus_vanna_srevice.py
@@ -184,6 +184,7 @@ class OpenAICompatibleLLM(VannaBase):
def generate_sql_2(self, question: str, allow_llm_to_see_data=False, **kwargs) -> dict:
try:
+ logger.info("Start to generate_sql_2 in cus_vanna_srevice")
question_sql_list = self.get_similar_question_sql(question, **kwargs)
ddl_list = self.get_related_ddl(question, **kwargs)
doc_list = self.get_related_documentation(question, **kwargs)
@@ -202,9 +203,12 @@ class OpenAICompatibleLLM(VannaBase):
[{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs)
logger.info(f"llm_response:{llm_response}")
result = {"resp": orjson.loads(extract_nested_json(llm_response))}
+ logger.info(f"llm_response:{llm_response}")
sql = check_and_get_sql(llm_response)
+ logger.info(f"sql:{sql}")
# ---------------生成图表
char_type = get_chart_type_from_sql_answer(llm_response)
+ logger.info(f"chart type:{char_type}")
if char_type:
sys_char_temp = char_temp['system'].format(engine=config("DB_ENGINE", default='mysql'),
lang='中文', sql=sql, chart_type=char_type)
@@ -215,8 +219,11 @@ class OpenAICompatibleLLM(VannaBase):
print(llm_response2)
result['chart'] = orjson.loads(extract_nested_json(llm_response2))
logger.info(f"chart_response:{result}")
+
+ logger.info("Finish to generate_sql_2 in cus_vanna_srevice")
return result
except Exception as e:
+ logger.info("cus_vanna_srevice failed-------------------")
raise e
def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
diff --git a/template.yaml b/template.yaml
index b006f7d..9ec0f2b 100644
--- a/template.yaml
+++ b/template.yaml
@@ -28,11 +28,17 @@ template:
你只能生成查询用的SQL语句,不得生成增删改相关或操作数据库以及操作数据库数据的SQL
- 如果只涉及查询人员信息,但没说具体哪些信息,可以不用查询所有信息,主要查询相关性较强的五个字段即可
+ 如果因为客观原因无法生成sql,请合理分析无法生成的原因并反馈给用户
+
+
+ 涉及查询人员信息时,如果用户没明确指出要查询哪些字段,主要查询相关性较强的10个字段即可,如果指定要查询所有信息,请返回所有字段信息
不要编造内没有提供给你的表结构
+
+ 当需要计算的字段类型为varchar或者text时,请根据需求转换为合理的类型格式进行计算
+
生成的SQL必须符合内提供数据库引擎的规范
@@ -43,7 +49,7 @@ template:
请注意区分'哪些'和'多少'的区别,哪些是指具体信息,多少是指数量,请注意甄别
- 如遇字符串类型的日期要计算时,务必转化为合理的格式进行计算
+ 如遇字符串类型的日期要计算时,务必转化为合理的日期格式进行计算
请使用JSON格式返回你的回答:
@@ -75,6 +81,9 @@ template:
SQL查询的字段若是函数字段,如 COUNT(),CAST() 等,必须加上别名
+
+ SQL查询的如果用了聚合函数,如SUM(),COUNT()等,必须配合GROUP BY使用
+
计算占比,百分比类型字段,保留两位小数,以%结尾
@@ -167,7 +176,15 @@ template:
今天天气如何?
+
+
+
+ 张三的年龄是多大
+
+
diff --git a/util/utils.py b/util/utils.py
index 0cb66ab..b01a0f9 100644
--- a/util/utils.py
+++ b/util/utils.py
@@ -2,6 +2,15 @@ from typing import Optional
from orjson import orjson
+keywords = {
+ "gender":{"1":"男","2":"女"},
+ "person_status":{"1":"草稿","2":"审批中","3":"制卡中","4":"已入库","5":"停用"},
+ "pass_type":{"1":"集团公司员工","2":"借调人员","3":"借用人员","4":"外部监管人员","5":"外协服务人员","6":"工勤人员","7":"来访人员"},
+ "person_type": {"YG":"正式员工","PQ":"劳务派遣人员","QT":"其他柔性引进人员","WHZ":"合作单位","WLS":"临时访客","WQT":"其他外部人员"},
+ "id_card_type":{"1":"身份证","2":"护照","3":"港澳通行证"},
+ "highest_education": {"1":"初中","2":"高中","3":"中专","4":"技校","5":"职高","6":"大专","7":"本科","8":"硕士","9":"博士"},
+ "highest_degree":{"1":"学士学位","2":"硕士学位","3":"博士学位","4":"无"},
+}
def check_and_get_sql(res: str) -> str:
json_str = extract_nested_json(res)
@@ -11,17 +20,21 @@ def check_and_get_sql(res: str) -> str:
sql: str
data: dict
try:
+ print("check_and_get_sql1----------------------------")
data = orjson.loads(json_str)
if data['success']:
sql = data['sql']
return sql
else:
+ print("check_and_get_sql2----------------------------")
message = data['message']
raise Exception(message)
except Exception as e:
+ print("check_and_get_sql3----------------------------")
raise e
except Exception:
+ print("check_and_get_sql4----------------------------")
raise Exception(orjson.dumps({'message': 'Cannot parse sql from answer',
'traceback': "Cannot parse sql from answer:\n" + res}).decode())
@@ -69,3 +82,18 @@ def get_chart_type_from_sql_answer(res: str) -> Optional[str]:
except Exception:
return None
return chart_type
+
+
+def deal_result(data: list) -> list:
+ try:
+ for item in data:
+ for key, map_value in keywords.items():
+ if key in item:
+ new_key = item.get(key)
+ item[key] = map_value[new_key]
+ print("data----------{0}".format(data))
+ return data
+ except Exception as e:
+ raise Exception(f"sql执行结果处理失败:{str(e)}")
+
+