72 lines
		
	
	
		
			2.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			72 lines
		
	
	
		
			2.1 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|   | from typing import Optional | ||
|  | 
 | ||
|  | from orjson import orjson | ||
|  | 
 | ||
|  | 
 | ||
def _cannot_parse_sql_error(res: str) -> Exception:
    """Build the error raised when no usable SQL payload can be read from *res*."""
    return Exception(orjson.dumps({'message': 'Cannot parse sql from answer',
                                   'traceback': "Cannot parse sql from answer:\n" + res}).decode())


def check_and_get_sql(res: str) -> str:
    """Extract the SQL statement from an answer that embeds a JSON payload.

    The first valid JSON object/array in *res* (as located by
    ``extract_nested_json``) is parsed. If its ``success`` field is truthy,
    its ``sql`` field is returned; otherwise its ``message`` field is raised.

    Args:
        res: Raw answer text, expected to contain something like
            ``{"success": true, "sql": "..."}`` or
            ``{"success": false, "message": "..."}``.

    Returns:
        The ``sql`` field of the payload.

    Raises:
        Exception: whose argument is a JSON string with ``message`` and
            ``traceback`` keys, when no JSON is found in *res* or the payload
            lacks the expected fields (or is an array rather than an object).
        Exception: carrying the payload's own ``message`` when ``success``
            is falsy.
    """
    json_str = extract_nested_json(res)
    if json_str is None:
        raise _cannot_parse_sql_error(res)
    try:
        data = orjson.loads(json_str)
        if data['success']:
            return data['sql']
        message = data['message']
    except (ValueError, KeyError, TypeError) as err:
        # ValueError covers orjson.JSONDecodeError; KeyError a missing field;
        # TypeError a top-level JSON array being indexed with a string key.
        raise _cannot_parse_sql_error(res) from err
    # Raised outside the try so the answer's own failure message reaches the
    # caller unchanged instead of being replaced by the generic parse error.
    raise Exception(message)
|  | 
 | ||
|  | 
 | ||
def extract_nested_json(text: str) -> Optional[str]:
    """Return the first balanced, valid JSON object or array embedded in *text*.

    Scans *text* once, tracking ``{``/``[`` nesting with a stack. Each time the
    stack becomes empty again, the enclosed span is validated with
    ``orjson.loads``. The first span that parses is returned verbatim, as a
    substring of *text*.

    Limitations of this bracket scanner:
    * brackets inside JSON string literals are not treated specially, so an
      unbalanced bracket inside a string value can hide an otherwise valid
      object;
    * when an outer span fails to parse, JSON nested inside that span is not
      tried on its own (scanning resumes after the span).

    Args:
        text: Free-form text (e.g. a model answer) that may embed JSON.

    Returns:
        The JSON substring, or ``None`` if no valid JSON object/array is found.
    """
    stack = []
    start_index = -1

    for i, char in enumerate(text):
        if char in '{[':
            if not stack:  # remember where a top-level candidate starts
                start_index = i
            stack.append(char)
        elif char in '}]':
            if stack and ((char == '}' and stack[-1] == '{') or (char == ']' and stack[-1] == '[')):
                stack.pop()
                if not stack:  # brackets balanced again: a complete candidate
                    candidate = text[start_index:i + 1]
                    try:
                        orjson.loads(candidate)  # validate the candidate
                    except ValueError:  # orjson.JSONDecodeError subclasses ValueError
                        continue  # e.g. prose in braces; keep scanning
                    return candidate
            else:
                # Mismatched or stray closing bracket: abandon the current candidate.
                stack = []
    return None
|  | 
 | ||
|  | 
 | ||
def get_chart_type_from_sql_answer(res: str) -> Optional[str]:
    """Best-effort lookup of the ``chart-type`` field in an answer's JSON payload.

    Args:
        res: Raw answer text that may embed a JSON object.

    Returns:
        The payload's ``chart-type`` value when the first embedded JSON
        payload has a truthy ``success`` field and contains that key;
        ``None`` when no JSON is found, ``success`` is falsy, or reading the
        payload fails for any reason.
    """
    payload = extract_nested_json(res)
    if payload is None:
        return None
    try:
        parsed = orjson.loads(payload)
        if not parsed['success']:
            return None
        return parsed['chart-type']
    except Exception:
        # Deliberately best-effort: any parse or lookup problem means "no chart type".
        return None