feat:初始化

This commit is contained in:
雷雨
2025-09-23 14:49:00 +08:00
parent e70a28be7d
commit 3ace3e5348
10 changed files with 944 additions and 0 deletions

0
util/__init__.py Normal file
View File

71
util/utils.py Normal file
View File

@@ -0,0 +1,71 @@
from typing import Optional
from orjson import orjson
def check_and_get_sql(res: str) -> str:
json_str = extract_nested_json(res)
if json_str is None:
raise Exception(orjson.dumps({'message': 'Cannot parse sql from answer',
'traceback': "Cannot parse sql from answer:\n" + res}).decode())
sql: str
data: dict
try:
data = orjson.loads(json_str)
if data['success']:
sql = data['sql']
return sql
else:
message = data['message']
raise Exception(message)
except Exception as e:
raise e
except Exception:
raise Exception(orjson.dumps({'message': 'Cannot parse sql from answer',
'traceback': "Cannot parse sql from answer:\n" + res}).decode())
def extract_nested_json(text):
stack = []
start_index = -1
results = []
for i, char in enumerate(text):
if char in '{[':
if not stack: # 记录起始位置
start_index = i
stack.append(char)
elif char in '}]':
if stack and ((char == '}' and stack[-1] == '{') or (char == ']' and stack[-1] == '[')):
stack.pop()
if not stack: # 栈空时截取完整JSON
json_str = text[start_index:i + 1]
try:
orjson.loads(json_str) # 验证有效性
results.append(json_str)
except:
pass
else:
stack = [] # 括号不匹配则重置
if len(results) > 0 and results[0]:
return results[0]
return None
def get_chart_type_from_sql_answer(res: str) -> Optional[str]:
json_str = extract_nested_json(res)
if json_str is None:
return None
chart_type: Optional[str]
data: dict
try:
data = orjson.loads(json_str)
if data['success']:
chart_type = data['chart-type']
else:
return None
except Exception:
return None
return chart_type