Merge branch 'dev' of gitlab-devt.yced.com.cn:lei_y601/sqlbot_agent into dev
# Conflicts: # main_service.py # service/cus_vanna_srevice.py # util/load_ddl_doc.py
This commit is contained in:
		| @@ -1,6 +1,6 @@ | ||||
| from email.policy import default | ||||
| from typing import List | ||||
| import dmPython | ||||
| from typing import List, Union | ||||
| import  dmPython | ||||
| import orjson | ||||
| import pandas as pd | ||||
| from vanna.base import VannaBase | ||||
| @@ -22,7 +22,7 @@ class OpenAICompatibleLLM(VannaBase): | ||||
|         # default parameters - can be overrided using config | ||||
|         self.temperature = 0.5 | ||||
|         self.max_tokens = 5000 | ||||
|         self.conn = None | ||||
|  | ||||
|         if "temperature" in config_file: | ||||
|             self.temperature = config_file["temperature"] | ||||
|  | ||||
| @@ -54,6 +54,46 @@ class OpenAICompatibleLLM(VannaBase): | ||||
|     def system_message(self, message: str) -> any: | ||||
|         return {"role": "system", "content": message} | ||||
|  | ||||
|     def connect_to_dameng( | ||||
|             self, | ||||
|             host: str = None, | ||||
|             dbname: str = None, | ||||
|             user: str = None, | ||||
|             password: str = None, | ||||
|             port: int = None, | ||||
|             **kwargs | ||||
|     ): | ||||
|         conn = None | ||||
|         try: | ||||
|             conn = dmPython.connect(user=user, password=password, server=host, port=port) | ||||
|         except Exception as e: | ||||
|             raise Exception(f"Failed to connect to dameng database: {e}") | ||||
|  | ||||
|         def run_sql_damengsql(sql: str) -> Union[pd.DataFrame, None]: | ||||
|             if conn: | ||||
|                 try: | ||||
|                     # conn.ping(reconnect=True) | ||||
|                     cs = conn.cursor() | ||||
|                     cs.execute(sql) | ||||
|                     results = cs.fetchall() | ||||
|  | ||||
|                     # Create a pandas dataframe from the results | ||||
|                     df = pd.DataFrame( | ||||
|                         results, columns=[desc[0] for desc in cs.description] | ||||
|                     ) | ||||
|  | ||||
|                     return df | ||||
|  | ||||
|  | ||||
|  | ||||
|                 except Exception as e: | ||||
|                     conn.rollback() | ||||
|                     raise e | ||||
|             return None | ||||
|  | ||||
|         self.run_sql_is_set = True | ||||
|         self.run_sql = run_sql_damengsql | ||||
|  | ||||
|     def user_message(self, message: str) -> any: | ||||
|         return {"role": "user", "content": message} | ||||
|  | ||||
| @@ -140,21 +180,6 @@ class OpenAICompatibleLLM(VannaBase): | ||||
|  | ||||
|         return response.choices[0].message.content | ||||
|  | ||||
|     # def connect_to_dameng(self, host, port, username, password, database): | ||||
|     #     try: | ||||
|     #         self.conn = dmPython.connect( | ||||
|     #             user=username, | ||||
|     #             password=password, | ||||
|     #             server=host, | ||||
|     #             port=port,  # 达梦默认端口5236 | ||||
|     #             autoCommit=True | ||||
|     #         ) | ||||
|     #         print("达梦数据库连接成功") | ||||
|     #         return True | ||||
|     #     except Exception as e: | ||||
|     #         print(f"连接失败: {e}") | ||||
|     #         return False | ||||
|  | ||||
|     def generate_sql_2(self, question: str, allow_llm_to_see_data=False, **kwargs) -> dict: | ||||
|         try: | ||||
|             question_sql_list = self.get_similar_question_sql(question, **kwargs) | ||||
| @@ -164,7 +189,7 @@ class OpenAICompatibleLLM(VannaBase): | ||||
|             sql_temp = template['template']['sql'] | ||||
|             char_temp = template['template']['chart'] | ||||
|             # --------基于提示词,生成sql以及图表类型 | ||||
|             sys_temp = sql_temp['system'].format(engine=config("DATA_SOURCE_TYPE", default='mysql'), lang='中文', | ||||
|             sys_temp = sql_temp['system'].format(engine=config("DB_ENGINE", default='mysql'), lang='中文', | ||||
|                                                  schema=ddl_list, documentation=doc_list, | ||||
|                                                  data_training=question_sql_list) | ||||
|             print("sys_temp", sys_temp) | ||||
| @@ -180,7 +205,7 @@ class OpenAICompatibleLLM(VannaBase): | ||||
|             # ---------------生成图表 | ||||
|             char_type = get_chart_type_from_sql_answer(llm_response) | ||||
|             if char_type: | ||||
|                 sys_char_temp = char_temp['system'].format(engine=config("DATA_SOURCE_TYPE", default='mysql'), | ||||
|                 sys_char_temp = char_temp['system'].format(engine=config("DB_ENGINE", default='mysql'), | ||||
|                                                            lang='中文', sql=sql, chart_type=char_type) | ||||
|                 user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question) | ||||
|                 llm_response2 = self.submit_prompt( | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 yujj128
					yujj128