缓存上下文,表结构添加
This commit is contained in:
		| @@ -1,3 +1,4 @@ | ||||
| from dataclasses import field | ||||
| from email.policy import default | ||||
| from typing import List, Union, Any, Optional | ||||
| import time | ||||
| @@ -76,10 +77,10 @@ class OpenAICompatibleLLM(VannaBase): | ||||
|  | ||||
|         def run_sql_damengsql(sql: str) -> Union[pd.DataFrame, None]: | ||||
|             logger.info(f"start to run_sql_damengsql") | ||||
|             if not is_connection_alive(conn=self.conn): | ||||
|                 logger.info("connection is not alive, reconnecting..........") | ||||
|                 reconnect() | ||||
|             try: | ||||
|                 if not is_connection_alive(conn=self.conn): | ||||
|                     logger.info("connection is not alive, reconnecting..........") | ||||
|                     reconnect() | ||||
|                 # conn.ping(reconnect=True) | ||||
|                 cs = self.conn.cursor() | ||||
|                 cs.execute(sql) | ||||
| @@ -203,7 +204,7 @@ class OpenAICompatibleLLM(VannaBase): | ||||
|  | ||||
|         return response.choices[0].message.content | ||||
|  | ||||
|     def generate_sql_2(self, question: str, allow_llm_to_see_data=False, **kwargs) -> dict: | ||||
|     def generate_sql_2(self, question: str, cache=None,user_id=None, allow_llm_to_see_data=False, **kwargs) -> dict: | ||||
|         try: | ||||
|             logger.info("Start to generate_sql_2 in cus_vanna_srevice") | ||||
|             question_sql_list = self.get_similar_question_sql(question, **kwargs) | ||||
| @@ -215,15 +216,19 @@ class OpenAICompatibleLLM(VannaBase): | ||||
|             template = get_base_template() | ||||
|             sql_temp = template['template']['sql'] | ||||
|             char_temp = template['template']['chart'] | ||||
|             history = None | ||||
|             if user_id and cache: | ||||
|                 history = cache.get(id=user_id, field="data") | ||||
|             # --------基于提示词,生成sql以及图表类型 | ||||
|             sys_temp = sql_temp['system'].format(engine=config("DB_ENGINE", default='mysql'), lang='中文', | ||||
|                                                  schema=ddl_list, documentation=[train_ddl.train_document], | ||||
|                                                  retrieved_examples_data=question_sql_list, | ||||
|                                                  history=history,retrieved_examples_data=question_sql_list, | ||||
|                                                  data_training=question_sql_list,) | ||||
|             logger.info(f"sys_temp:{sys_temp}") | ||||
|  | ||||
|             user_temp = sql_temp['user'].format(question=question, | ||||
|                                                 current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S')) | ||||
|             logger.info(f"user_temp:{user_temp}") | ||||
|             logger.info(f"sys_temp:{sys_temp}") | ||||
|             llm_response = self.submit_prompt( | ||||
|                 [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs) | ||||
|             logger.info(f"llm_response:{llm_response}") | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 yujj128
					yujj128