Files
sqlbot_agent/util/utils.py
2025-09-26 17:35:23 +08:00

102 lines
3.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import Optional
from orjson import orjson
# Display labels for coded values in SQL result rows, consumed by deal_result().
# Outer key: result column name.  Inner dict: raw code -> human-readable label.
# Translation of the "gender" column ({"1": "男", "2": "女"}) is currently disabled.
keywords = {
    "person_status": {
        "1": "草稿", "2": "审批中", "3": "制卡中", "4": "已入库", "5": "停用",
    },
    "pass_type": {
        "1": "集团公司员工", "2": "借调人员", "3": "借用人员", "4": "外部监管人员",
        "5": "外协服务人员", "6": "工勤人员", "7": "来访人员",
    },
    "person_type": {
        "YG": "正式员工", "PQ": "劳务派遣人员", "QT": "其他柔性引进人员",
        "WHZ": "合作单位", "WLS": "临时访客", "WQT": "其他外部人员",
    },
    "id_card_type": {
        "1": "身份证", "2": "护照", "3": "港澳通行证",
    },
    "highest_education": {
        "1": "初中", "2": "高中", "3": "中专", "4": "技校", "5": "职高",
        "6": "大专", "7": "本科", "8": "硕士", "9": "博士",
    },
    # NOTE(review): code "4" maps to an empty label -- confirm this is intended.
    "highest_degree": {
        "1": "学士学位", "2": "硕士学位", "3": "博士学位", "4": "",
    },
}
def _sql_parse_error(res: str) -> Exception:
    """Build the exception used when no usable SQL payload can be read from *res*.

    The message is a JSON-encoded object with ``message`` and ``traceback`` keys.
    """
    return Exception(orjson.dumps({'message': 'Cannot parse sql from answer',
                                   'traceback': "Cannot parse sql from answer:\n" + res}).decode())


def check_and_get_sql(res: str) -> str:
    """Extract the SQL statement from a model answer.

    The first JSON value embedded in *res* (located by extract_nested_json) is
    expected to be an object with a ``success`` flag; when it is truthy the
    ``sql`` field is returned, otherwise the object's ``message`` is raised.

    Args:
        res: Raw answer text that should contain the JSON payload.

    Returns:
        The value of the payload's ``sql`` field.

    Raises:
        Exception: with the JSON-encoded "Cannot parse sql from answer" message
            when no JSON is found, or the payload is not an object or lacks
            the ``success`` / ``sql`` / ``message`` field it needs.
        Exception: with the payload's own ``message`` when ``success`` is falsy.
    """
    json_str = extract_nested_json(res)
    if json_str is None:
        raise _sql_parse_error(res)
    try:
        data = orjson.loads(json_str)
        if data['success']:
            return data['sql']
        message = data['message']
    except (KeyError, TypeError, ValueError) as err:
        # Malformed payload (e.g. a JSON array, or a missing key): report it as an
        # unparseable answer instead of leaking a bare KeyError/TypeError.
        raise _sql_parse_error(res) from err
    # The answer explicitly reported failure; surface its message unchanged.
    raise Exception(message)
def extract_nested_json(text: str) -> Optional[str]:
    """Return the first top-level ``{...}`` or ``[...]`` span of *text* that is valid JSON.

    Scans *text* once, tracking bracket nesting.  Each time a top-level bracket
    group closes, the enclosed span is validated with ``orjson.loads``; the
    first span that parses is returned.  Inside a candidate, double-quoted
    JSON strings (with backslash escapes) are skipped, so braces/brackets in
    string values -- e.g. SQL literals like ``'{'`` -- do not break nesting.

    Args:
        text: Arbitrary text, e.g. a model answer with prose around the JSON.

    Returns:
        The JSON substring, or ``None`` if no valid top-level span is found.
    """
    openers = {'}': '{', ']': '['}
    stack = []
    start_index = -1
    in_string = False
    escaped = False
    for i, char in enumerate(text):
        if in_string:
            # Brackets inside a JSON string are data, not structure.
            if escaped:
                escaped = False
            elif char == '\\':
                escaped = True
            elif char == '"':
                in_string = False
            continue
        if char == '"':
            # Quotes outside any candidate belong to surrounding prose; ignore them.
            if stack:
                in_string = True
        elif char in '{[':
            if not stack:  # record where a new top-level candidate starts
                start_index = i
            stack.append(char)
        elif char in '}]':
            if stack and stack[-1] == openers[char]:
                stack.pop()
                if not stack:  # stack empty: a complete top-level span just closed
                    candidate = text[start_index:i + 1]
                    try:
                        orjson.loads(candidate)  # validate
                    except ValueError:  # orjson.JSONDecodeError subclasses ValueError
                        continue
                    return candidate
            else:
                stack = []  # mismatched bracket: discard the current candidate
    return None
def get_chart_type_from_sql_answer(res: str) -> Optional[str]:
    """Read the ``chart-type`` field from the JSON payload embedded in *res*.

    Best-effort: returns ``None`` when no JSON is found, when its ``success``
    flag is falsy, or when any error occurs while reading the payload.
    """
    payload = extract_nested_json(res)
    if payload is None:
        return None
    try:
        parsed = orjson.loads(payload)
        # 'success' is checked first; 'chart-type' is only read when it is truthy.
        return parsed['chart-type'] if parsed['success'] else None
    except Exception:
        return None
def deal_result(data: list) -> list:
    """Replace coded column values in SQL result rows with display labels.

    For each row and each column named in the module-level ``keywords`` table,
    a value present in that column's mapping is replaced by its label; values
    not in the mapping are left as they are.  Rows are modified in place.

    Args:
        data: Result rows, accessed with ``in``, ``.get`` and item assignment
            by column name (presumably dicts -- confirm with callers).

    Returns:
        The same list object that was passed in.

    Raises:
        Exception: "sql执行结果处理失败..." wrapping any error raised while
            processing (e.g. a non-mapping row or an unhashable value), with
            the original error chained as ``__cause__``.
    """
    # Rows are deliberately not echoed to stdout: they may carry personal data
    # (ID-card type, personnel status, ...).
    try:
        for row in data:
            for column, labels in keywords.items():
                if column not in row:
                    continue
                raw_value = row.get(column)
                if raw_value in labels:
                    row[column] = labels[raw_value]
        return data
    except Exception as e:
        raise Exception(f"sql执行结果处理失败{str(e)}") from e