feat:修改ddql,增加qa问答
This commit is contained in:
		| @@ -17,7 +17,7 @@ from datetime import datetime | ||||
| import logging | ||||
| from util import  train_ddl | ||||
| logger = logging.getLogger(__name__) | ||||
|  | ||||
| import traceback | ||||
| class OpenAICompatibleLLM(VannaBase): | ||||
|     def __init__(self, client=None, config_file=None): | ||||
|         VannaBase.__init__(self, config=config_file) | ||||
| @@ -186,6 +186,9 @@ class OpenAICompatibleLLM(VannaBase): | ||||
|         try: | ||||
|             logger.info("Start to generate_sql_2 in cus_vanna_srevice") | ||||
|             question_sql_list = self.get_similar_question_sql(question, **kwargs) | ||||
|             if question_sql_list and len(question_sql_list)>2: | ||||
|                 question_sql_list=question_sql_list[:1] | ||||
|  | ||||
|             ddl_list = self.get_related_ddl(question, **kwargs) | ||||
|             #doc_list = self.get_related_documentation(question, **kwargs) | ||||
|             template = get_base_template() | ||||
| @@ -194,7 +197,8 @@ class OpenAICompatibleLLM(VannaBase): | ||||
|             # --------基于提示词,生成sql以及图表类型 | ||||
|             sys_temp = sql_temp['system'].format(engine=config("DB_ENGINE", default='mysql'), lang='中文', | ||||
|                                                  schema=ddl_list, documentation=[train_ddl.train_document], | ||||
|                                                  data_training=question_sql_list) | ||||
|                                                  retrieved_examples_data=question_sql_list, | ||||
|                                                  data_training=question_sql_list,) | ||||
|             logger.info(f"sys_temp:{sys_temp}") | ||||
|             user_temp = sql_temp['user'].format(question=question, | ||||
|                                                 current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S')) | ||||
| @@ -223,7 +227,8 @@ class OpenAICompatibleLLM(VannaBase): | ||||
|             logger.info("Finish to generate_sql_2 in cus_vanna_srevice") | ||||
|             return result | ||||
|         except Exception as e: | ||||
|             logger.info("cus_vanna_srevice failed-------------------") | ||||
|             logger.info("cus_vanna_srevice failed-------------------: ") | ||||
|             traceback.print_exc() | ||||
|             raise e | ||||
|  | ||||
|     def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str: | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 雷雨
					雷雨