批量添加ddl到向量数据库,添加documention,重写generate_rewritten_question
This commit is contained in:
		| @@ -146,21 +146,23 @@ class OpenAICompatibleLLM(VannaBase): | ||||
|         template = get_base_template() | ||||
|         sql_temp = template['template']['sql'] | ||||
|         char_temp = template['template']['chart'] | ||||
|         # --------基于提示词,生成sql以及图标类型 | ||||
|         sys_temp = sql_temp['system'].format(engine='sqlite', lang='中文', schema=ddl_list, documentation=doc_list, | ||||
|         # --------基于提示词,生成sql以及图表类型 | ||||
|         sys_temp = sql_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', schema=ddl_list, documentation=doc_list, | ||||
|                                              data_training=question_sql_list) | ||||
|         print("sys_temp", sys_temp) | ||||
|         user_temp = sql_temp['user'].format(question=question, | ||||
|                                             current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S')) | ||||
|         print("user_temp", user_temp) | ||||
|         llm_response = self.submit_prompt( | ||||
|             [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs) | ||||
|         print(llm_response) | ||||
|         result = {"resp": orjson.loads(extract_nested_json(llm_response))} | ||||
|  | ||||
|         print("result", result) | ||||
|         sql = check_and_get_sql(llm_response) | ||||
|         # ---------------生成图表 | ||||
|         char_type = get_chart_type_from_sql_answer(llm_response) | ||||
|         if char_type: | ||||
|             sys_char_temp = char_temp['system'].format(engine='sqlite', lang='中文', sql=sql, chart_type=char_type) | ||||
|             sys_char_temp = char_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', sql=sql, chart_type=char_type) | ||||
|             user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question) | ||||
|             llm_response2 = self.submit_prompt( | ||||
|                 [{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}], **kwargs) | ||||
| @@ -168,6 +170,19 @@ class OpenAICompatibleLLM(VannaBase): | ||||
|             result['chart'] = orjson.loads(extract_nested_json(llm_response2)) | ||||
|         return result | ||||
|  | ||||
|     def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str: | ||||
|         print("new_question---------------", new_question) | ||||
|         if last_question is None: | ||||
|             return new_question | ||||
|  | ||||
|         prompt = [ | ||||
|             self.system_message( | ||||
|                 "Your goal is to combine a sequence of questions into a singular question if they are related. If the second question does not relate to the first question and is fully self-contained, return the second question. Return just the new combined question with no additional explanations. The question should theoretically be answerable with a single SQL statement."), | ||||
|             self.user_message(new_question), | ||||
|         ] | ||||
|  | ||||
|         return self.submit_prompt(prompt=prompt, **kwargs) | ||||
|  | ||||
|  | ||||
| class CustomQdrant_VectorStore(Qdrant_VectorStore): | ||||
|     def __init__( | ||||
|   | ||||
		Reference in New Issue
	
	Block a user
	 yujj128
					yujj128