diff --git a/logging_config.py b/logging_config.py new file mode 100644 index 0000000..5555099 --- /dev/null +++ b/logging_config.py @@ -0,0 +1,62 @@ +# logging_config.py +import logging +import logging.config +from pathlib import Path + +# 确保 logs 目录存在 +log_dir = Path("logs") +log_dir.mkdir(exist_ok=True) + +LOGGING_CONFIG = { + "version": 1, + "disable_existing_loggers": False, + "formatters": { + "default": { + "format": "%(asctime)s - %(name)s - %(levelname)s - %(filename)s:%(lineno)d - %(message)s", + }, + "detailed": { + "format": "%(asctime)s - %(name)s - %(levelname)s - %(funcName)s - %(message)s", + } + }, + "handlers": { + "console": { + "class": "logging.StreamHandler", + "level": "INFO", + "formatter": "default", + "stream": "ext://sys.stdout" + }, + "file": { + "class": "logging.handlers.RotatingFileHandler", # 自动轮转 + "level": "INFO", + "formatter": "detailed", + "filename": "logs/sqlbot.log", + "maxBytes": 10485760, # 10MB + "backupCount": 5, # 保留5个备份 + "encoding": "utf8" + }, + }, + "root": { + "level": "INFO", + "handlers": ["console", "file"] + }, + "loggers": { + "uvicorn": { + "level": "INFO", + "handlers": ["console", "file"], + "propagate": False + }, + "uvicorn.error": { + "level": "INFO", + "handlers": ["console", "file"], + "propagate": False + }, + "uvicorn.access": { + "level": "WARNING", # 只记录警告以上,避免刷屏 + "handlers": ["file"], # 只写入文件 + "propagate": False + } + } +} + +# 应用配置 +logging.config.dictConfig(LOGGING_CONFIG) \ No newline at end of file diff --git a/main_service.py b/main_service.py index 5c18e13..26a5253 100644 --- a/main_service.py +++ b/main_service.py @@ -1,5 +1,8 @@ from email.policy import default +import dmPython +import logging +from logging_config import LOGGING_CONFIG from service.cus_vanna_srevice import CustomVanna, QdrantClient from decouple import config import flask @@ -7,6 +10,8 @@ from util import load_ddl_doc from flask import Flask, Response, jsonify, request, send_from_directory + +logger = logging.getLogger(__name__) def connect_database(vn): db_type = config('DATA_SOURCE_TYPE', default='sqlite') if db_type == 'sqlite': @@ -17,6 +22,8 @@ def connect_database(vn): user=config('MYSQL_DATABASE_USER', default=''), password=config('MYSQL_DATABASE_PASSWORD', default=''), dbname=config('MYSQL_DATABASE_DBNAME', default='')) + elif db_type == 'dameng': + vn.connect_to_dameng( ) elif db_type == 'postgresql': # 待补充 pass @@ -81,22 +88,78 @@ def generate_sql_2(): text: type: string """ - + logger.info("Start to generate sql in main") question = flask.request.args.get("question") if question is None: return jsonify({"type": "error", "error": "No question provided"}) - id = cache.generate_id(question=question) - data = vn.generate_sql_2(question=question) - data['id'] =id - sql = data["resp"]["sql"] - print("sql:",sql) - cache.set(id=id, field="question", value=question) - cache.set(id=id, field="sql", value=sql) - print("data---------------------------",data) + try: + id = cache.generate_id(question=question) + data = vn.generate_sql_2(question=question) + data['id'] = id + sql = data["resp"]["sql"] + print("sql:", sql) + cache.set(id=id, field="question", value=question) + cache.set(id=id, field="sql", value=sql) + print("data---------------------------", data) + return jsonify(data) + except Exception as e: + return jsonify({"type": "error", "error": str(e)}) - return jsonify(data) +@app.flask_app.route("/api/v0/run_sql_2", methods=["GET"]) +@app.requires_cache(["sql"]) +def run_sql_2(id: str, sql: str): + """ + Run SQL + --- + parameters: + - name: user + in: query + - name: id + in: query|body + type: string + required: true + responses: + 200: + schema: + type: object + properties: + type: + type: string + default: df + id: + type: string + df: + type: object + should_generate_chart: + type: boolean + """ + logger.info("Start to run sql in main") + try: + if not vn.run_sql_is_set: + return jsonify( + { + "type": "error", + "error": "Please connect to a database using vn.connect_to_... in order to run SQL queries.", + } + ) + + df = vn.run_sql(sql=sql) + logger.info("") + app.cache.set(id=id, field="df", value=df) + x = df.head(10).to_dict(orient='records') + logger.info("df ---------------{0} {1}".format(x,type(x))) + return jsonify( + { + "type": "df", + "id": id, + "df": df.head(10).to_dict(orient='records'), + } + ) + + except Exception as e: + return jsonify({"type": "sql_error", "error": str(e)}) if __name__ == '__main__': diff --git a/service/cus_vanna_srevice.py b/service/cus_vanna_srevice.py index 22375cd..569c868 100644 --- a/service/cus_vanna_srevice.py +++ b/service/cus_vanna_srevice.py @@ -1,6 +1,6 @@ from email.policy import default from typing import List - +import dmPython import orjson import pandas as pd from vanna.base import VannaBase @@ -22,7 +22,7 @@ class OpenAICompatibleLLM(VannaBase): # default parameters - can be overrided using config self.temperature = 0.5 self.max_tokens = 5000 - + self.conn = None if "temperature" in config_file: self.temperature = config_file["temperature"] @@ -140,36 +140,57 @@ class OpenAICompatibleLLM(VannaBase): return response.choices[0].message.content + # def connect_to_dameng(self, host, port, username, password, database): + # try: + # self.conn = dmPython.connect( + # user=username, + # password=password, + # server=host, + # port=port, # 达梦默认端口5236 + # autoCommit=True + # ) + # print("达梦数据库连接成功") + # return True + # except Exception as e: + # print(f"连接失败: {e}") + # return False + def generate_sql_2(self, question: str, allow_llm_to_see_data=False, **kwargs) -> dict: - question_sql_list = self.get_similar_question_sql(question, **kwargs) - ddl_list = self.get_related_ddl(question, **kwargs) - doc_list = self.get_related_documentation(question, **kwargs) - template = get_base_template() - sql_temp = template['template']['sql'] - char_temp = template['template']['chart'] - # --------基于提示词,生成sql以及图表类型 - sys_temp = sql_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', schema=ddl_list, documentation=doc_list, - data_training=question_sql_list) - print("sys_temp", sys_temp) - user_temp = sql_temp['user'].format(question=question, - current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S')) - print("user_temp", user_temp) - llm_response = self.submit_prompt( - [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs) - print(llm_response) - result = {"resp": orjson.loads(extract_nested_json(llm_response))} - print("result", result) - sql = check_and_get_sql(llm_response) - # ---------------生成图表 - char_type = get_chart_type_from_sql_answer(llm_response) - if char_type: - sys_char_temp = char_temp['system'].format(engine=config("DATA_SOURCE_TYPE",default='mysql'), lang='中文', sql=sql, chart_type=char_type) - user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question) - llm_response2 = self.submit_prompt( - [{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}], **kwargs) - print(llm_response2) - result['chart'] = orjson.loads(extract_nested_json(llm_response2)) - return result + try: + question_sql_list = self.get_similar_question_sql(question, **kwargs) + ddl_list = self.get_related_ddl(question, **kwargs) + doc_list = self.get_related_documentation(question, **kwargs) + template = get_base_template() + sql_temp = template['template']['sql'] + char_temp = template['template']['chart'] + # --------基于提示词,生成sql以及图表类型 + sys_temp = sql_temp['system'].format(engine=config("DATA_SOURCE_TYPE", default='mysql'), lang='中文', + schema=ddl_list, documentation=doc_list, + data_training=question_sql_list) + print("sys_temp", sys_temp) + user_temp = sql_temp['user'].format(question=question, + current_time=datetime.now().strftime('%Y-%m-%d %H:%M:%S')) + print("user_temp", user_temp) + llm_response = self.submit_prompt( + [{'role': 'system', 'content': sys_temp}, {'role': 'user', 'content': user_temp}], **kwargs) + print(llm_response) + result = {"resp": orjson.loads(extract_nested_json(llm_response))} + print("result", result) + sql = check_and_get_sql(llm_response) + # ---------------生成图表 + char_type = get_chart_type_from_sql_answer(llm_response) + if char_type: + sys_char_temp = char_temp['system'].format(engine=config("DATA_SOURCE_TYPE", default='mysql'), + lang='中文', sql=sql, chart_type=char_type) + user_char_temp = char_temp['user'].format(sql=sql, chart_type=char_type, question=question) + llm_response2 = self.submit_prompt( + [{'role': 'system', 'content': sys_char_temp}, {'role': 'user', 'content': user_char_temp}], + **kwargs) + print(llm_response2) + result['chart'] = orjson.loads(extract_nested_json(llm_response2)) + return result + except Exception as e: + raise e def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str: print("new_question---------------", new_question) diff --git a/template.yaml b/template.yaml index f6f59ac..b006f7d 100644 --- a/template.yaml +++ b/template.yaml @@ -27,6 +27,9 @@ template: 你只能生成查询用的SQL语句,不得生成增删改相关或操作数据库以及操作数据库数据的SQL + + 如果只涉及查询人员信息,但没说具体哪些信息,可以不用查询所有信息,主要查询相关性较强的五个字段即可 + 不要编造内没有提供给你的表结构 @@ -39,6 +42,9 @@ template: 请注意区分'哪些'和'多少'的区别,哪些是指具体信息,多少是指数量,请注意甄别 + + 如遇字符串类型的日期要计算时,务必转化为合理的格式进行计算 + 请使用JSON格式返回你的回答: 若能生成,则返回格式如:{{"success":true,"sql":"你生成的SQL语句","tables":["该SQL用到的表名1","该SQL用到的表名2",...],"chart-type":"table"}} diff --git a/util/load_ddl_doc.py b/util/load_ddl_doc.py index a1e2d5a..fe3727a 100644 --- a/util/load_ddl_doc.py +++ b/util/load_ddl_doc.py @@ -103,6 +103,12 @@ list_documentions = [ """ <人员库表注意事项> + 查询address时,尽量使用like查询,如:select * from 人员库 where address like '%张三%'; + 语法为mysql语法; + 如果涉及下面中的字段需要展示给用户看时请替换成相关代表 + birthday 字段涉及计算时,请转化为合理格式计算 + + person_status 字段 1代表草稿,2代表审批中,3代表制卡中,4代表已入库,5代表停用; gender 字段 1代表男,2代表女 is_internal 字段 0代表否,1代表是 @@ -115,9 +121,7 @@ list_documentions = [ is_subcontractor 字段 0代表否,1代表是 is_sign_confidentiality_agreement 字段 0代表否,1代表是 DHDATASTA 字段 0代表新增 1代表更新 - 查询address时,尽量使用like查询,如:select * from 人员库 where address like '%张三%'; - 语法为mysql语法; - + """, ]