枚举结果处理
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
from email.policy import default
|
from email.policy import default
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
import util.utils
|
||||||
from logging_config import LOGGING_CONFIG
|
from logging_config import LOGGING_CONFIG
|
||||||
from service.cus_vanna_srevice import CustomVanna, QdrantClient
|
from service.cus_vanna_srevice import CustomVanna, QdrantClient
|
||||||
from decouple import config
|
from decouple import config
|
||||||
@@ -97,18 +98,20 @@ def generate_sql_2():
|
|||||||
return jsonify({"type": "error", "error": "No question provided"})
|
return jsonify({"type": "error", "error": "No question provided"})
|
||||||
try:
|
try:
|
||||||
id = cache.generate_id(question=question)
|
id = cache.generate_id(question=question)
|
||||||
|
logger.info(f"Generate sql for {question}")
|
||||||
data = vn.generate_sql_2(question=question)
|
data = vn.generate_sql_2(question=question)
|
||||||
|
logger.info("Generate sql result is {0}".format(data))
|
||||||
data['id'] = id
|
data['id'] = id
|
||||||
sql = data["resp"]["sql"]
|
sql = data["resp"]["sql"]
|
||||||
print("sql:", sql)
|
|
||||||
cache.set(id=id, field="question", value=question)
|
cache.set(id=id, field="question", value=question)
|
||||||
cache.set(id=id, field="sql", value=sql)
|
cache.set(id=id, field="sql", value=sql)
|
||||||
print("data---------------------------", data)
|
data["type"]="success"
|
||||||
return jsonify(data)
|
return jsonify(data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
return jsonify({"type": "error", "error": str(e)})
|
return jsonify({"type": "error", "error": str(e)})
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@app.flask_app.route("/api/v0/run_sql_2", methods=["GET"])
|
@app.flask_app.route("/api/v0/run_sql_2", methods=["GET"])
|
||||||
@app.requires_cache(["sql"])
|
@app.requires_cache(["sql"])
|
||||||
def run_sql_2(id: str, sql: str):
|
def run_sql_2(id: str, sql: str):
|
||||||
@@ -150,13 +153,15 @@ def run_sql_2(id: str, sql: str):
|
|||||||
df = vn.run_sql(sql=sql)
|
df = vn.run_sql(sql=sql)
|
||||||
logger.info("")
|
logger.info("")
|
||||||
app.cache.set(id=id, field="df", value=df)
|
app.cache.set(id=id, field="df", value=df)
|
||||||
x = df.head(10).to_dict(orient='records')
|
x = df.to_dict(orient='records')
|
||||||
logger.info("df ---------------{0} {1}".format(x,type(x)))
|
logger.info("df ---------------{0} {1}".format(x,type(x)))
|
||||||
|
result = util.utils.deal_result(data=x)
|
||||||
|
|
||||||
return jsonify(
|
return jsonify(
|
||||||
{
|
{
|
||||||
"type": "df",
|
"type": "success",
|
||||||
"id": id,
|
"id": id,
|
||||||
"df": df.head(10).to_dict(orient='records'),
|
"df": result,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -184,6 +184,7 @@ class OpenAICompatibleLLM(VannaBase):
|
|||||||
|
|
||||||
def generate_sql_2(self, question: str, allow_llm_to_see_data=False, **kwargs) -> dict:
|
def generate_sql_2(self, question: str, allow_llm_to_see_data=False, **kwargs) -> dict:
|
||||||
try:
|
try:
|
||||||
|
logger.info("Start to generate_sql_2 in cus_vanna_srevice")
|
||||||
question_sql_list = self.get_similar_question_sql(question, **kwargs)
|
question_sql_list = self.get_similar_question_sql(question, **kwargs)
|
||||||
ddl_list = self.get_related_ddl(question, **kwargs)
|
ddl_list = self.get_related_ddl(question, **kwargs)
|
||||||
doc_list = self.get_related_documentation(question, **kwargs)
|
doc_list = self.get_related_documentation(question, **kwargs)
|
||||||
@@ -202,9 +203,12 @@ class OpenAICompatibleLLM(VannaBase):
|
|||||||
[{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs)
|
[{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs)
|
||||||
logger.info(f"llm_response:{llm_response}")
|
logger.info(f"llm_response:{llm_response}")
|
||||||
result = {"resp": orjson.loads(extract_nested_json(llm_response))}
|
result = {"resp": orjson.loads(extract_nested_json(llm_response))}
|
||||||
|
logger.info(f"llm_response:{llm_response}")
|
||||||
sql = check_and_get_sql(llm_response)
|
sql = check_and_get_sql(llm_response)
|
||||||
|
logger.info(f"sql:{sql}")
|
||||||
# ---------------生成图表
|
# ---------------生成图表
|
||||||
char_type = get_chart_type_from_sql_answer(llm_response)
|
char_type = get_chart_type_from_sql_answer(llm_response)
|
||||||
|
logger.info(f"chart type:{char_type}")
|
||||||
if char_type:
|
if char_type:
|
||||||
sys_char_temp = char_temp['system'].format(engine=config("DB_ENGINE", default='mysql'),
|
sys_char_temp = char_temp['system'].format(engine=config("DB_ENGINE", default='mysql'),
|
||||||
lang='中文', sql=sql, chart_type=char_type)
|
lang='中文', sql=sql, chart_type=char_type)
|
||||||
@@ -215,8 +219,11 @@ class OpenAICompatibleLLM(VannaBase):
|
|||||||
print(llm_response2)
|
print(llm_response2)
|
||||||
result['chart'] = orjson.loads(extract_nested_json(llm_response2))
|
result['chart'] = orjson.loads(extract_nested_json(llm_response2))
|
||||||
logger.info(f"chart_response:{result}")
|
logger.info(f"chart_response:{result}")
|
||||||
|
|
||||||
|
logger.info("Finish to generate_sql_2 in cus_vanna_srevice")
|
||||||
return result
|
return result
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
logger.info("cus_vanna_srevice failed-------------------")
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
|
def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
|
||||||
|
|||||||
@@ -28,11 +28,17 @@ template:
|
|||||||
你只能生成查询用的SQL语句,不得生成增删改相关或操作数据库以及操作数据库数据的SQL
|
你只能生成查询用的SQL语句,不得生成增删改相关或操作数据库以及操作数据库数据的SQL
|
||||||
</rule>
|
</rule>
|
||||||
<rule>
|
<rule>
|
||||||
如果只涉及查询人员信息,但没说具体哪些信息,可以不用查询所有信息,主要查询相关性较强的五个字段即可
|
如果因为客观原因无法生成sql,请合理分析无法生成的原因并反馈给用户
|
||||||
|
</rule>
|
||||||
|
<rule>
|
||||||
|
涉及查询人员信息时,如果用户没明确指出要查询哪些字段,主要查询相关性较强的10个字段即可,如果指定要查询所有信息,请返回所有字段信息
|
||||||
</rule>
|
</rule>
|
||||||
<rule>
|
<rule>
|
||||||
不要编造<m-schema>内没有提供给你的表结构
|
不要编造<m-schema>内没有提供给你的表结构
|
||||||
</rule>
|
</rule>
|
||||||
|
<rule>
|
||||||
|
当需要计算的字段类型为varchar或者text时,请根据需求转换为合理的类型格式进行计算
|
||||||
|
</rule>
|
||||||
<rule>
|
<rule>
|
||||||
生成的SQL必须符合<db-engine>内提供数据库引擎的规范
|
生成的SQL必须符合<db-engine>内提供数据库引擎的规范
|
||||||
</rule>
|
</rule>
|
||||||
@@ -43,7 +49,7 @@ template:
|
|||||||
请注意区分'哪些'和'多少'的区别,哪些是指具体信息,多少是指数量,请注意甄别
|
请注意区分'哪些'和'多少'的区别,哪些是指具体信息,多少是指数量,请注意甄别
|
||||||
</rule>
|
</rule>
|
||||||
<rule>
|
<rule>
|
||||||
如遇字符串类型的日期要计算时,务必转化为合理的格式进行计算
|
如遇字符串类型的日期要计算时,务必转化为合理的日期格式进行计算
|
||||||
</rule>
|
</rule>
|
||||||
<rule>
|
<rule>
|
||||||
请使用JSON格式返回你的回答:
|
请使用JSON格式返回你的回答:
|
||||||
@@ -75,6 +81,9 @@ template:
|
|||||||
<rule>
|
<rule>
|
||||||
SQL查询的字段若是函数字段,如 COUNT(),CAST() 等,必须加上别名
|
SQL查询的字段若是函数字段,如 COUNT(),CAST() 等,必须加上别名
|
||||||
</rule>
|
</rule>
|
||||||
|
<rule>
|
||||||
|
SQL查询的如果用了聚合函数,如SUM(),COUNT()等,必须配合GROUP BY使用
|
||||||
|
</rule>
|
||||||
<rule>
|
<rule>
|
||||||
计算占比,百分比类型字段,保留两位小数,以%结尾
|
计算占比,百分比类型字段,保留两位小数,以%结尾
|
||||||
</rule>
|
</rule>
|
||||||
@@ -167,7 +176,15 @@ template:
|
|||||||
<user-question>今天天气如何?</user-question>
|
<user-question>今天天气如何?</user-question>
|
||||||
</input>
|
</input>
|
||||||
<output>
|
<output>
|
||||||
{{"success":false,"message":"我是智能问数小助手,我无法回答您的问题。"}}
|
{{"success":false,"message":"我是智能问数小助手,我无法回答您的问题,该问题与当前数据库问数不相关,数据库中无天气等信息。","status":0}}
|
||||||
|
</output>
|
||||||
|
</example>
|
||||||
|
<example>
|
||||||
|
<input>
|
||||||
|
<user-question>张三的年龄是多大</user-question>
|
||||||
|
</input>
|
||||||
|
<output>
|
||||||
|
{{"success":true,"sql":"SELECT name, FLOOR(MONTHS_BETWEEN(SYSDATE, birthday) / 12) AS age FROM YJOA_APPSERVICE_DB.t_pr3rl2oj_yj_person_database","tables":["t_pr3rl2oj_yj_person_database"],"chart-type":"columns"}}
|
||||||
</output>
|
</output>
|
||||||
</example>
|
</example>
|
||||||
<example>
|
<example>
|
||||||
|
|||||||
@@ -2,6 +2,15 @@ from typing import Optional
|
|||||||
|
|
||||||
from orjson import orjson
|
from orjson import orjson
|
||||||
|
|
||||||
|
keywords = {
|
||||||
|
"gender":{"1":"男","2":"女"},
|
||||||
|
"person_status":{"1":"草稿","2":"审批中","3":"制卡中","4":"已入库","5":"停用"},
|
||||||
|
"pass_type":{"1":"集团公司员工","2":"借调人员","3":"借用人员","4":"外部监管人员","5":"外协服务人员","6":"工勤人员","7":"来访人员"},
|
||||||
|
"person_type": {"YG":"正式员工","PQ":"劳务派遣人员","QT":"其他柔性引进人员","WHZ":"合作单位","WLS":"临时访客","WQT":"其他外部人员"},
|
||||||
|
"id_card_type":{"1":"身份证","2":"护照","3":"港澳通行证"},
|
||||||
|
"highest_education": {"1":"初中","2":"高中","3":"中专","4":"技校","5":"职高","6":"大专","7":"本科","8":"硕士","9":"博士"},
|
||||||
|
"highest_degree":{"1":"学士学位","2":"硕士学位","3":"博士学位","4":"无"},
|
||||||
|
}
|
||||||
|
|
||||||
def check_and_get_sql(res: str) -> str:
|
def check_and_get_sql(res: str) -> str:
|
||||||
json_str = extract_nested_json(res)
|
json_str = extract_nested_json(res)
|
||||||
@@ -11,17 +20,21 @@ def check_and_get_sql(res: str) -> str:
|
|||||||
sql: str
|
sql: str
|
||||||
data: dict
|
data: dict
|
||||||
try:
|
try:
|
||||||
|
print("check_and_get_sql1----------------------------")
|
||||||
data = orjson.loads(json_str)
|
data = orjson.loads(json_str)
|
||||||
|
|
||||||
if data['success']:
|
if data['success']:
|
||||||
sql = data['sql']
|
sql = data['sql']
|
||||||
return sql
|
return sql
|
||||||
else:
|
else:
|
||||||
|
print("check_and_get_sql2----------------------------")
|
||||||
message = data['message']
|
message = data['message']
|
||||||
raise Exception(message)
|
raise Exception(message)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
print("check_and_get_sql3----------------------------")
|
||||||
raise e
|
raise e
|
||||||
except Exception:
|
except Exception:
|
||||||
|
print("check_and_get_sql4----------------------------")
|
||||||
raise Exception(orjson.dumps({'message': 'Cannot parse sql from answer',
|
raise Exception(orjson.dumps({'message': 'Cannot parse sql from answer',
|
||||||
'traceback': "Cannot parse sql from answer:\n" + res}).decode())
|
'traceback': "Cannot parse sql from answer:\n" + res}).decode())
|
||||||
|
|
||||||
@@ -69,3 +82,18 @@ def get_chart_type_from_sql_answer(res: str) -> Optional[str]:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return None
|
return None
|
||||||
return chart_type
|
return chart_type
|
||||||
|
|
||||||
|
|
||||||
|
def deal_result(data: list) -> list:
|
||||||
|
try:
|
||||||
|
for item in data:
|
||||||
|
for key, map_value in keywords.items():
|
||||||
|
if key in item:
|
||||||
|
new_key = item.get(key)
|
||||||
|
item[key] = map_value[new_key]
|
||||||
|
print("data----------{0}".format(data))
|
||||||
|
return data
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"sql执行结果处理失败:{str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user