枚举结果处理
This commit is contained in:
@@ -1,6 +1,7 @@
|
||||
from email.policy import default
|
||||
import logging
|
||||
|
||||
import util.utils
|
||||
from logging_config import LOGGING_CONFIG
|
||||
from service.cus_vanna_srevice import CustomVanna, QdrantClient
|
||||
from decouple import config
|
||||
@@ -97,18 +98,20 @@ def generate_sql_2():
|
||||
return jsonify({"type": "error", "error": "No question provided"})
|
||||
try:
|
||||
id = cache.generate_id(question=question)
|
||||
logger.info(f"Generate sql for {question}")
|
||||
data = vn.generate_sql_2(question=question)
|
||||
logger.info("Generate sql result is {0}".format(data))
|
||||
data['id'] = id
|
||||
sql = data["resp"]["sql"]
|
||||
print("sql:", sql)
|
||||
cache.set(id=id, field="question", value=question)
|
||||
cache.set(id=id, field="sql", value=sql)
|
||||
print("data---------------------------", data)
|
||||
data["type"]="success"
|
||||
return jsonify(data)
|
||||
except Exception as e:
|
||||
return jsonify({"type": "error", "error": str(e)})
|
||||
|
||||
|
||||
|
||||
@app.flask_app.route("/api/v0/run_sql_2", methods=["GET"])
|
||||
@app.requires_cache(["sql"])
|
||||
def run_sql_2(id: str, sql: str):
|
||||
@@ -150,13 +153,15 @@ def run_sql_2(id: str, sql: str):
|
||||
df = vn.run_sql(sql=sql)
|
||||
logger.info("")
|
||||
app.cache.set(id=id, field="df", value=df)
|
||||
x = df.head(10).to_dict(orient='records')
|
||||
x = df.to_dict(orient='records')
|
||||
logger.info("df ---------------{0} {1}".format(x,type(x)))
|
||||
result = util.utils.deal_result(data=x)
|
||||
|
||||
return jsonify(
|
||||
{
|
||||
"type": "df",
|
||||
"type": "success",
|
||||
"id": id,
|
||||
"df": df.head(10).to_dict(orient='records'),
|
||||
"df": result,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
@@ -184,6 +184,7 @@ class OpenAICompatibleLLM(VannaBase):
|
||||
|
||||
def generate_sql_2(self, question: str, allow_llm_to_see_data=False, **kwargs) -> dict:
|
||||
try:
|
||||
logger.info("Start to generate_sql_2 in cus_vanna_srevice")
|
||||
question_sql_list = self.get_similar_question_sql(question, **kwargs)
|
||||
ddl_list = self.get_related_ddl(question, **kwargs)
|
||||
doc_list = self.get_related_documentation(question, **kwargs)
|
||||
@@ -202,9 +203,12 @@ class OpenAICompatibleLLM(VannaBase):
|
||||
[{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs)
|
||||
logger.info(f"llm_response:{llm_response}")
|
||||
result = {"resp": orjson.loads(extract_nested_json(llm_response))}
|
||||
logger.info(f"llm_response:{llm_response}")
|
||||
sql = check_and_get_sql(llm_response)
|
||||
logger.info(f"sql:{sql}")
|
||||
# ---------------生成图表
|
||||
char_type = get_chart_type_from_sql_answer(llm_response)
|
||||
logger.info(f"chart type:{char_type}")
|
||||
if char_type:
|
||||
sys_char_temp = char_temp['system'].format(engine=config("DB_ENGINE", default='mysql'),
|
||||
lang='中文', sql=sql, chart_type=char_type)
|
||||
@@ -215,8 +219,11 @@ class OpenAICompatibleLLM(VannaBase):
|
||||
print(llm_response2)
|
||||
result['chart'] = orjson.loads(extract_nested_json(llm_response2))
|
||||
logger.info(f"chart_response:{result}")
|
||||
|
||||
logger.info("Finish to generate_sql_2 in cus_vanna_srevice")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.info("cus_vanna_srevice failed-------------------")
|
||||
raise e
|
||||
|
||||
def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
|
||||
|
||||
@@ -28,11 +28,17 @@ template:
|
||||
你只能生成查询用的SQL语句,不得生成增删改相关或操作数据库以及操作数据库数据的SQL
|
||||
</rule>
|
||||
<rule>
|
||||
如果只涉及查询人员信息,但没说具体哪些信息,可以不用查询所有信息,主要查询相关性较强的五个字段即可
|
||||
如果因为客观原因无法生成sql,请合理分析无法生成的原因并反馈给用户
|
||||
</rule>
|
||||
<rule>
|
||||
涉及查询人员信息时,如果用户没明确指出要查询哪些字段,主要查询相关性较强的10个字段即可,如果指定要查询所有信息,请返回所有字段信息
|
||||
</rule>
|
||||
<rule>
|
||||
不要编造<m-schema>内没有提供给你的表结构
|
||||
</rule>
|
||||
<rule>
|
||||
当需要计算的字段类型为varchar或者text时,请根据需求转换为合理的类型格式进行计算
|
||||
</rule>
|
||||
<rule>
|
||||
生成的SQL必须符合<db-engine>内提供数据库引擎的规范
|
||||
</rule>
|
||||
@@ -43,7 +49,7 @@ template:
|
||||
请注意区分'哪些'和'多少'的区别,哪些是指具体信息,多少是指数量,请注意甄别
|
||||
</rule>
|
||||
<rule>
|
||||
如遇字符串类型的日期要计算时,务必转化为合理的格式进行计算
|
||||
如遇字符串类型的日期要计算时,务必转化为合理的日期格式进行计算
|
||||
</rule>
|
||||
<rule>
|
||||
请使用JSON格式返回你的回答:
|
||||
@@ -75,6 +81,9 @@ template:
|
||||
<rule>
|
||||
SQL查询的字段若是函数字段,如 COUNT(),CAST() 等,必须加上别名
|
||||
</rule>
|
||||
<rule>
|
||||
SQL查询的如果用了聚合函数,如SUM(),COUNT()等,必须配合GROUP BY使用
|
||||
</rule>
|
||||
<rule>
|
||||
计算占比,百分比类型字段,保留两位小数,以%结尾
|
||||
</rule>
|
||||
@@ -167,7 +176,15 @@ template:
|
||||
<user-question>今天天气如何?</user-question>
|
||||
</input>
|
||||
<output>
|
||||
{{"success":false,"message":"我是智能问数小助手,我无法回答您的问题。"}}
|
||||
{{"success":false,"message":"我是智能问数小助手,我无法回答您的问题,该问题与当前数据库问数不相关,数据库中无天气等信息。","status":0}}
|
||||
</output>
|
||||
</example>
|
||||
<example>
|
||||
<input>
|
||||
<user-question>张三的年龄是多大</user-question>
|
||||
</input>
|
||||
<output>
|
||||
{{"success":true,"sql":"SELECT name, FLOOR(MONTHS_BETWEEN(SYSDATE, birthday) / 12) AS age FROM YJOA_APPSERVICE_DB.t_pr3rl2oj_yj_person_database","tables":["t_pr3rl2oj_yj_person_database"],"chart-type":"columns"}}
|
||||
</output>
|
||||
</example>
|
||||
<example>
|
||||
|
||||
@@ -2,6 +2,15 @@ from typing import Optional
|
||||
|
||||
from orjson import orjson
|
||||
|
||||
keywords = {
|
||||
"gender":{"1":"男","2":"女"},
|
||||
"person_status":{"1":"草稿","2":"审批中","3":"制卡中","4":"已入库","5":"停用"},
|
||||
"pass_type":{"1":"集团公司员工","2":"借调人员","3":"借用人员","4":"外部监管人员","5":"外协服务人员","6":"工勤人员","7":"来访人员"},
|
||||
"person_type": {"YG":"正式员工","PQ":"劳务派遣人员","QT":"其他柔性引进人员","WHZ":"合作单位","WLS":"临时访客","WQT":"其他外部人员"},
|
||||
"id_card_type":{"1":"身份证","2":"护照","3":"港澳通行证"},
|
||||
"highest_education": {"1":"初中","2":"高中","3":"中专","4":"技校","5":"职高","6":"大专","7":"本科","8":"硕士","9":"博士"},
|
||||
"highest_degree":{"1":"学士学位","2":"硕士学位","3":"博士学位","4":"无"},
|
||||
}
|
||||
|
||||
def check_and_get_sql(res: str) -> str:
|
||||
json_str = extract_nested_json(res)
|
||||
@@ -11,17 +20,21 @@ def check_and_get_sql(res: str) -> str:
|
||||
sql: str
|
||||
data: dict
|
||||
try:
|
||||
print("check_and_get_sql1----------------------------")
|
||||
data = orjson.loads(json_str)
|
||||
|
||||
if data['success']:
|
||||
sql = data['sql']
|
||||
return sql
|
||||
else:
|
||||
print("check_and_get_sql2----------------------------")
|
||||
message = data['message']
|
||||
raise Exception(message)
|
||||
except Exception as e:
|
||||
print("check_and_get_sql3----------------------------")
|
||||
raise e
|
||||
except Exception:
|
||||
print("check_and_get_sql4----------------------------")
|
||||
raise Exception(orjson.dumps({'message': 'Cannot parse sql from answer',
|
||||
'traceback': "Cannot parse sql from answer:\n" + res}).decode())
|
||||
|
||||
@@ -69,3 +82,18 @@ def get_chart_type_from_sql_answer(res: str) -> Optional[str]:
|
||||
except Exception:
|
||||
return None
|
||||
return chart_type
|
||||
|
||||
|
||||
def deal_result(data: list) -> list:
|
||||
try:
|
||||
for item in data:
|
||||
for key, map_value in keywords.items():
|
||||
if key in item:
|
||||
new_key = item.get(key)
|
||||
item[key] = map_value[new_key]
|
||||
print("data----------{0}".format(data))
|
||||
return data
|
||||
except Exception as e:
|
||||
raise Exception(f"sql执行结果处理失败:{str(e)}")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user