feat:添加达梦数据库支持以及对接pre环境数据
This commit is contained in:
		| @@ -1,6 +1,6 @@ | ||||
| from email.policy import default | ||||
| from typing import List | ||||
|  | ||||
| from typing import List, Union | ||||
| import  dmPython | ||||
| import orjson | ||||
| import pandas as pd | ||||
| from vanna.base import VannaBase | ||||
| @@ -54,6 +54,46 @@ class OpenAICompatibleLLM(VannaBase): | ||||
|     def system_message(self, message: str) -> any: | ||||
|         return {"role": "system", "content": message} | ||||
|  | ||||
|     def connect_to_dameng( | ||||
|             self, | ||||
|             host: str = None, | ||||
|             dbname: str = None, | ||||
|             user: str = None, | ||||
|             password: str = None, | ||||
|             port: int = None, | ||||
|             **kwargs | ||||
|     ): | ||||
|         conn = None | ||||
|         try: | ||||
|             conn = dmPython.connect(user=user, password=password, server=host, port=port) | ||||
|         except Exception as e: | ||||
|             raise Exception(f"Failed to connect to dameng database: {e}") | ||||
|  | ||||
|         def run_sql_damengsql(sql: str) -> Union[pd.DataFrame, None]: | ||||
|             if conn: | ||||
|                 try: | ||||
|                     # conn.ping(reconnect=True) | ||||
|                     cs = conn.cursor() | ||||
|                     cs.execute(sql) | ||||
|                     results = cs.fetchall() | ||||
|  | ||||
|                     # Create a pandas dataframe from the results | ||||
|                     df = pd.DataFrame( | ||||
|                         results, columns=[desc[0] for desc in cs.description] | ||||
|                     ) | ||||
|  | ||||
|                     return df | ||||
|  | ||||
|  | ||||
|  | ||||
|                 except Exception as e: | ||||
|                     conn.rollback() | ||||
|                     raise e | ||||
|             return None | ||||
|  | ||||
|         self.run_sql_is_set = True | ||||
|         self.run_sql = run_sql_damengsql | ||||
|  | ||||
|     def user_message(self, message: str) -> any: | ||||
|         return {"role": "user", "content": message} | ||||
|  | ||||
| @@ -148,7 +188,7 @@ class OpenAICompatibleLLM(VannaBase): | ||||
|         sql_temp = template['template']['sql'] | ||||
|         char_temp = template['template']['chart'] | ||||
|         # --------基于提示词,生成sql以及图表类型 | ||||
|         sys_temp = sql_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', schema=ddl_list, documentation=doc_list, | ||||
|         sys_temp = sql_temp['system'].format(engine=config("DB_ENGINE",default='mysql'), lang='中文', schema=ddl_list, documentation=doc_list, | ||||
|                                              data_training=question_sql_list) | ||||
|         print("sys_temp", sys_temp) | ||||
|         user_temp = sql_temp['user'].format(question=question, | ||||
| @@ -163,7 +203,7 @@ class OpenAICompatibleLLM(VannaBase): | ||||
|         # ---------------生成图表 | ||||
|         char_type = get_chart_type_from_sql_answer(llm_response) | ||||
|         if char_type: | ||||
|             sys_char_temp = char_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', sql=sql, chart_type=char_type) | ||||
|             sys_char_temp = char_temp['system'].format(engine=config("DB_ENGINE",default='mysql'), lang='中文', sql=sql, chart_type=char_type) | ||||
|             user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question) | ||||
|             llm_response2 = self.submit_prompt( | ||||
|                 [{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}], **kwargs) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 雷雨
					雷雨