提示词优化,日志添加
This commit is contained in:
		| @@ -1,6 +1,6 @@ | ||||
| from email.policy import default | ||||
| from typing import List | ||||
|  | ||||
| import dmPython | ||||
| import orjson | ||||
| import pandas as pd | ||||
| from vanna.base import VannaBase | ||||
| @@ -22,7 +22,7 @@ class OpenAICompatibleLLM(VannaBase): | ||||
|         # default parameters - can be overrided using config | ||||
|         self.temperature = 0.5 | ||||
|         self.max_tokens = 5000 | ||||
|  | ||||
|         self.conn = None | ||||
|         if "temperature" in config_file: | ||||
|             self.temperature = config_file["temperature"] | ||||
|  | ||||
| @@ -140,36 +140,57 @@ class OpenAICompatibleLLM(VannaBase): | ||||
|  | ||||
|         return response.choices[0].message.content | ||||
|  | ||||
|     # def connect_to_dameng(self, host, port, username, password, database): | ||||
|     #     try: | ||||
|     #         self.conn = dmPython.connect( | ||||
|     #             user=username, | ||||
|     #             password=password, | ||||
|     #             server=host, | ||||
|     #             port=port,  # 达梦默认端口5236 | ||||
|     #             autoCommit=True | ||||
|     #         ) | ||||
|     #         print("达梦数据库连接成功") | ||||
|     #         return True | ||||
|     #     except Exception as e: | ||||
|     #         print(f"连接失败: {e}") | ||||
|     #         return False | ||||
|  | ||||
|     def generate_sql_2(self, question: str, allow_llm_to_see_data=False, **kwargs) -> dict: | ||||
|         question_sql_list = self.get_similar_question_sql(question, **kwargs) | ||||
|         ddl_list = self.get_related_ddl(question, **kwargs) | ||||
|         doc_list = self.get_related_documentation(question, **kwargs) | ||||
|         template = get_base_template() | ||||
|         sql_temp = template['template']['sql'] | ||||
|         char_temp = template['template']['chart'] | ||||
|         # --------基于提示词,生成sql以及图表类型 | ||||
|         sys_temp = sql_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', schema=ddl_list, documentation=doc_list, | ||||
|                                              data_training=question_sql_list) | ||||
|         print("sys_temp", sys_temp) | ||||
|         user_temp = sql_temp['user'].format(question=question, | ||||
|                                             current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S')) | ||||
|         print("user_temp", user_temp) | ||||
|         llm_response = self.submit_prompt( | ||||
|             [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs) | ||||
|         print(llm_response) | ||||
|         result = {"resp": orjson.loads(extract_nested_json(llm_response))} | ||||
|         print("result", result) | ||||
|         sql = check_and_get_sql(llm_response) | ||||
|         # ---------------生成图表 | ||||
|         char_type = get_chart_type_from_sql_answer(llm_response) | ||||
|         if char_type: | ||||
|             sys_char_temp = char_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', sql=sql, chart_type=char_type) | ||||
|             user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question) | ||||
|             llm_response2 = self.submit_prompt( | ||||
|                 [{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}], **kwargs) | ||||
|             print(llm_response2) | ||||
|             result['chart'] = orjson.loads(extract_nested_json(llm_response2)) | ||||
|         return result | ||||
|         try: | ||||
|             question_sql_list = self.get_similar_question_sql(question, **kwargs) | ||||
|             ddl_list = self.get_related_ddl(question, **kwargs) | ||||
|             doc_list = self.get_related_documentation(question, **kwargs) | ||||
|             template = get_base_template() | ||||
|             sql_temp = template['template']['sql'] | ||||
|             char_temp = template['template']['chart'] | ||||
|             # --------基于提示词,生成sql以及图表类型 | ||||
|             sys_temp = sql_temp['system'].format(engine=config("DATA_SOURCE_TYPE", default='mysql'), lang='中文', | ||||
|                                                  schema=ddl_list, documentation=doc_list, | ||||
|                                                  data_training=question_sql_list) | ||||
|             print("sys_temp", sys_temp) | ||||
|             user_temp = sql_temp['user'].format(question=question, | ||||
|                                                 current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S')) | ||||
|             print("user_temp", user_temp) | ||||
|             llm_response = self.submit_prompt( | ||||
|                 [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs) | ||||
|             print(llm_response) | ||||
|             result = {"resp": orjson.loads(extract_nested_json(llm_response))} | ||||
|             print("result", result) | ||||
|             sql = check_and_get_sql(llm_response) | ||||
|             # ---------------生成图表 | ||||
|             char_type = get_chart_type_from_sql_answer(llm_response) | ||||
|             if char_type: | ||||
|                 sys_char_temp = char_temp['system'].format(engine=config("DATA_SOURCE_TYPE", default='mysql'), | ||||
|                                                            lang='中文', sql=sql, chart_type=char_type) | ||||
|                 user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question) | ||||
|                 llm_response2 = self.submit_prompt( | ||||
|                     [{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}], | ||||
|                     **kwargs) | ||||
|                 print(llm_response2) | ||||
|                 result['chart'] = orjson.loads(extract_nested_json(llm_response2)) | ||||
|             return result | ||||
|         except Exception as e: | ||||
|             raise e | ||||
|  | ||||
|     def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str: | ||||
|         print("new_question---------------", new_question) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 yujj128
					yujj128