feat:初始化
This commit is contained in:
		
							
								
								
									
										0
									
								
								util/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										0
									
								
								util/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
								
								
									
										71
									
								
								util/utils.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										71
									
								
								util/utils.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,71 @@ | ||||
| from typing import Optional | ||||
|  | ||||
| from orjson import orjson | ||||
|  | ||||
|  | ||||
| def check_and_get_sql(res: str) -> str: | ||||
|     json_str = extract_nested_json(res) | ||||
|     if json_str is None: | ||||
|         raise Exception(orjson.dumps({'message': 'Cannot parse sql from answer', | ||||
|                                       'traceback': "Cannot parse sql from answer:\n" + res}).decode()) | ||||
|     sql: str | ||||
|     data: dict | ||||
|     try: | ||||
|         data = orjson.loads(json_str) | ||||
|  | ||||
|         if data['success']: | ||||
|             sql = data['sql'] | ||||
|             return sql | ||||
|         else: | ||||
|             message = data['message'] | ||||
|             raise Exception(message) | ||||
|     except Exception as e: | ||||
|         raise e | ||||
|     except Exception: | ||||
|         raise Exception(orjson.dumps({'message': 'Cannot parse sql from answer', | ||||
|                                       'traceback': "Cannot parse sql from answer:\n" + res}).decode()) | ||||
|  | ||||
|  | ||||
| def extract_nested_json(text): | ||||
|     stack = [] | ||||
|     start_index = -1 | ||||
|     results = [] | ||||
|  | ||||
|     for i, char in enumerate(text): | ||||
|         if char in '{[': | ||||
|             if not stack:  # 记录起始位置 | ||||
|                 start_index = i | ||||
|             stack.append(char) | ||||
|         elif char in '}]': | ||||
|             if stack and ((char == '}' and stack[-1] == '{') or (char == ']' and stack[-1] == '[')): | ||||
|                 stack.pop() | ||||
|                 if not stack:  # 栈空时截取完整JSON | ||||
|                     json_str = text[start_index:i + 1] | ||||
|                     try: | ||||
|                         orjson.loads(json_str)  # 验证有效性 | ||||
|                         results.append(json_str) | ||||
|                     except: | ||||
|                         pass | ||||
|             else: | ||||
|                 stack = []  # 括号不匹配则重置 | ||||
|     if len(results) > 0 and results[0]: | ||||
|         return results[0] | ||||
|     return None | ||||
|  | ||||
|  | ||||
| def get_chart_type_from_sql_answer(res: str) -> Optional[str]: | ||||
|     json_str = extract_nested_json(res) | ||||
|     if json_str is None: | ||||
|         return None | ||||
|     chart_type: Optional[str] | ||||
|     data: dict | ||||
|     try: | ||||
|         data = orjson.loads(json_str) | ||||
|  | ||||
|         if data['success']: | ||||
|             chart_type = data['chart-type'] | ||||
|         else: | ||||
|             return None | ||||
|     except Exception: | ||||
|         return None | ||||
|     return chart_type | ||||
		Reference in New Issue
	
	Block a user
	 雷雨
					雷雨