72 lines
		
	
	
		
			2.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			72 lines
		
	
	
		
			2.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from typing import Optional
 | |
| 
 | |
| from orjson import orjson
 | |
| 
 | |
| 
 | |
def check_and_get_sql(res: str) -> str:
    """Extract the SQL statement from an answer that embeds a JSON payload.

    The first valid JSON object/array located by ``extract_nested_json`` is
    loaded. If its ``success`` field is truthy, its ``sql`` field is returned.

    Args:
        res: Raw answer text expected to contain a JSON object with a
            ``success`` key plus ``sql`` (when successful) or ``message``
            (when not).

    Returns:
        The value of the payload's ``sql`` field.

    Raises:
        Exception: if no JSON can be found in ``res``, or the JSON found is
            not an object carrying the expected keys. The exception message
            is a JSON document ``{"message": ..., "traceback": ...}`` whose
            ``traceback`` includes the full ``res`` text.
        Exception: if ``success`` is falsy. The exception message is the
            payload's ``message`` value, passed through unchanged.
    """

    def _parse_error() -> Exception:
        # Builds the "cannot parse" error; its message is a JSON document
        # embedding the raw answer text.
        return Exception(orjson.dumps({'message': 'Cannot parse sql from answer',
                                       'traceback': "Cannot parse sql from answer:\n" + res}).decode())

    json_str = extract_nested_json(res)
    if json_str is None:
        raise _parse_error()

    try:
        data = orjson.loads(json_str)
        if data['success']:
            return data['sql']
        message = data['message']
    except (orjson.JSONDecodeError, KeyError, TypeError) as err:
        # KeyError: a required key is missing; TypeError: the JSON is an
        # array rather than an object. Both mean the payload is unusable.
        raise _parse_error() from err

    # Raised outside the try block so the payload's own failure message is
    # propagated as-is instead of being re-wrapped as a parse error.
    raise Exception(message)
 | |
| 
 | |
| 
 | |
def extract_nested_json(text: str) -> Optional[str]:
    """Return the first balanced, valid JSON object or array embedded in ``text``.

    ``text`` is scanned left to right while ``{``/``[`` openers are pushed on
    a stack. When the outermost opener is closed, the enclosed span is
    validated with ``orjson.loads``, and the first span that parses is
    returned. A closing bracket that does not match the innermost opener
    discards the current candidate and resets the scan state.

    Brackets are matched character by character without tracking JSON
    string literals. An unbalanced bracket inside a string value
    (e.g. ``{"a": "}"}``) can therefore cause a valid object to be missed.

    Args:
        text: Free-form text that may contain a JSON value.

    Returns:
        The substring of ``text`` holding the first valid JSON object/array,
        or ``None`` if no balanced, parseable span is found.
    """
    # Maps each closing bracket to the opener it must match.
    closers = {'}': '{', ']': '['}
    stack = []
    start_index = -1

    for i, char in enumerate(text):
        if char in '{[':
            if not stack:  # outermost opener: remember where the candidate starts
                start_index = i
            stack.append(char)
        elif char in closers:
            if stack and stack[-1] == closers[char]:
                stack.pop()
                if not stack:  # outermost bracket closed: candidate is complete
                    candidate = text[start_index:i + 1]
                    try:
                        orjson.loads(candidate)  # validate only; the parsed value is discarded
                    except orjson.JSONDecodeError:
                        continue  # balanced but not valid JSON; keep scanning
                    return candidate
            else:
                stack = []  # mismatched bracket: drop the current candidate
    return None
 | |
| 
 | |
| 
 | |
def get_chart_type_from_sql_answer(res: str) -> Optional[str]:
    """Best-effort lookup of the ``chart-type`` field in an answer's JSON payload.

    Args:
        res: Raw answer text that may embed a JSON object.

    Returns:
        The payload's ``chart-type`` value when the payload's ``success``
        field is truthy. Returns ``None`` when no JSON is found, ``success``
        is falsy, or reading the payload raises any ``Exception``
        (e.g. a missing key).
    """
    payload = extract_nested_json(res)
    if payload is None:
        return None

    try:
        parsed = orjson.loads(payload)
        succeeded = parsed['success']
        if not succeeded:
            return None
        return parsed['chart-type']
    except Exception:
        # Deliberately lenient: any problem reading the payload means
        # "no chart type available".
        return None
 | 
