连续对话,上下文管理
This commit is contained in:
@@ -29,6 +29,17 @@ class QuestionFeedBack(Base):
|
||||
is_process = Column(Boolean, nullable=False, default=False)
|
||||
|
||||
|
||||
class Conversation(Base):
|
||||
__tablename__ = 'db_conversation'
|
||||
id = Column(String(255), primary_key=True)
|
||||
create_time = Column(DateTime, nullable=False, )
|
||||
question = Column(String(500), nullable=False)
|
||||
sql = Column(String(500), nullable=True)
|
||||
user_id = Column(String(100), nullable=False)
|
||||
cvs_id = Column(String(100), nullable=False)
|
||||
meta = Column(Text, nullable=True)
|
||||
|
||||
|
||||
class PredefinedQuestion(Base):
|
||||
# 定义表名,预制问题表
|
||||
__tablename__ = 'db_predefined_question'
|
||||
|
||||
144
main_service.py
144
main_service.py
@@ -1,10 +1,12 @@
|
||||
import copy
|
||||
import logging
|
||||
import time
|
||||
from functools import wraps
|
||||
import util.utils
|
||||
from logging_config import LOGGING_CONFIG
|
||||
from service.cus_vanna_srevice import CustomVanna, QdrantClient, TTLCacheWrapper
|
||||
from service.question_feedback_service import save_save_question_async,query_predefined_question_list
|
||||
from service.conversation_service import save_conversation,update_conversation,get_sql_by_id
|
||||
from decouple import config
|
||||
import flask
|
||||
from util import load_ddl_doc
|
||||
@@ -12,6 +14,12 @@ from flask import Flask, Response, jsonify, request
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def generate_timestamp_id():
|
||||
"""生成基于时间戳的ID"""
|
||||
# 获取当前时间戳(秒级)
|
||||
timestamp = int(time.time() * 1000)
|
||||
return f"Q{timestamp}"
|
||||
|
||||
|
||||
def connect_database(vn):
|
||||
db_type = config('DATA_SOURCE_TYPE', default='sqlite')
|
||||
@@ -82,12 +90,11 @@ def generate_sql_2():
|
||||
Generate SQL from a question
|
||||
---
|
||||
parameters:
|
||||
- name: user
|
||||
- name: user_id
|
||||
in: query
|
||||
- name: question
|
||||
in: query
|
||||
type: string
|
||||
required: true
|
||||
- name: question_id
|
||||
responses:
|
||||
200:
|
||||
schema:
|
||||
@@ -105,59 +112,101 @@ def generate_sql_2():
|
||||
question = flask.request.args.get("question")
|
||||
if question is None:
|
||||
return jsonify({"type": "error", "error": "No question provided"})
|
||||
|
||||
user_id = request.args.get("user_id")
|
||||
cvs_id = request.args.get("cvs_id")
|
||||
need_context = bool(request.args.get("need_context"))
|
||||
if user_id is None or cvs_id is None:
|
||||
return jsonify({"type": "error", "error": "No user_id or cvs_id provided"})
|
||||
id = generate_timestamp_id()
|
||||
logger.info(f"question_id: {id} user_id: {user_id} cvs_id: {cvs_id} question: {question}")
|
||||
save_conversation(id,user_id,cvs_id,question)
|
||||
try:
|
||||
id = cache.generate_id(question=question)
|
||||
user_id = request.args.get("user_id")
|
||||
logger.info(f"Generate sql for {question}")
|
||||
data = vn.generate_sql_2(question=question, cache=cache, user_id=user_id)
|
||||
data = vn.generate_sql_2(user_id,cvs_id,question,id,need_context)
|
||||
logger.info("Generate sql result is {0}".format(data))
|
||||
data['id'] = id
|
||||
sql = data["resp"]["sql"]
|
||||
logger.info("generate sql is : " + sql)
|
||||
cache.set(id=id, field="question", value=question)
|
||||
cache.set(id=id, field="sql", value=sql)
|
||||
data["type"] = "success"
|
||||
logger.info("generate sql is : "+ sql)
|
||||
update_conversation(cvs_id, id, sql)
|
||||
save_save_question_async(id, user_id, question, sql)
|
||||
data["type"]="success"
|
||||
return jsonify(data)
|
||||
except Exception as e:
|
||||
logger.error(f"generate sql failed:{e}")
|
||||
return jsonify({"type": "error", "error": str(e)})
|
||||
|
||||
|
||||
def session_save(func):
|
||||
@wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
id = request.args.get("id")
|
||||
user_id = request.args.get("user_id")
|
||||
logger.info(f" id: {id},user_id: {user_id}")
|
||||
result = func(*args, **kwargs)
|
||||
# def requires_cache_2(required_keys):
|
||||
# def decorator(f):
|
||||
# @wraps(f)
|
||||
# def decorated(*args, **kwargs):
|
||||
# id = request.args.get("id")
|
||||
# user_id = request.args.get("user_id")
|
||||
# if user_id is None:
|
||||
# user_id = request.json.get("user_id")
|
||||
# if user_id is None:
|
||||
# return jsonify({"type": "error", "error": "No user_id provided"})
|
||||
# if id is None:
|
||||
# id = request.json.get("id")
|
||||
# if id is None:
|
||||
# return jsonify({"type": "error", "error": "No id provided"})
|
||||
# all_v = cache.items()
|
||||
# logger.info(f"all values {all_v}")
|
||||
# logger.info(f"user {user_id} id {id}")
|
||||
# qa_list = cache.get(id=user_id, field="qa_list")
|
||||
# if qa_list is None:
|
||||
# return jsonify({"type": "error", "error": f"No qa_list found"})
|
||||
# logger.info(f"qa_list {qa_list}")
|
||||
# q_a = list(filter(lambda x: x["id"] == id, qa_list))
|
||||
# logger.info(f"q_a {q_a}")
|
||||
# for key in required_keys:
|
||||
# if q_a[0][key] is None:
|
||||
# return jsonify({"type": "error", "error": f"No {key} found for id:{id}"})
|
||||
# values = {key:q_a[0][key] for key in required_keys}
|
||||
# values["id"] = id
|
||||
# logger.info("cache values {0}".format(values))
|
||||
#
|
||||
# return f(*args, **values, **kwargs)
|
||||
#
|
||||
# return decorated
|
||||
#
|
||||
# return decorator
|
||||
|
||||
datas = []
|
||||
session_len = int(config("SESSION_LENGTH", default=2))
|
||||
if cache.exists(id=user_id, field="data"):
|
||||
datas = copy.deepcopy(cache.get(id=user_id, field="data"))
|
||||
data = {
|
||||
"id": id,
|
||||
"question": cache.get(id=id, field="question"),
|
||||
"sql": cache.get(id=id, field="sql")
|
||||
}
|
||||
datas.append(data)
|
||||
logger.info("datas is {0}".format(datas))
|
||||
if len(datas) > session_len and session_len > 0:
|
||||
datas = datas[-session_len:]
|
||||
# 删除id对应的所有缓存值,因为已经run_sql完毕,改用user_id保存为上下文
|
||||
cache.delete(id=id, field="question")
|
||||
cache.set(id=user_id, field="data", value=copy.deepcopy(datas))
|
||||
logger.info(f" user data {cache.get(user_id, field='data')}")
|
||||
return result
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
# def session_save(func):
|
||||
# @wraps(func)
|
||||
# def wrapper(*args, **kwargs):
|
||||
# id = request.args.get("id")
|
||||
# user_id = request.args.get("user_id")
|
||||
# logger.info(f" id: {id},user_id: {user_id}")
|
||||
# result = func(*args, **kwargs)
|
||||
#
|
||||
# datas = []
|
||||
# session_len = int(config("SESSION_LENGTH", default=2))
|
||||
# if cache.exists(id=user_id, field="qa_list"):
|
||||
# datas = copy.deepcopy(cache.get(id=user_id, field="qa_list"))
|
||||
# logger.info("datas is {0}".format(datas))
|
||||
# if len(datas) > session_len and session_len > 0:
|
||||
# logger.info(f"开始裁剪-------------------------------------")
|
||||
# datas=datas[-session_len:]
|
||||
# # 删除id对应的所有缓存值,因为已经run_sql完毕,改用user_id保存为上下文
|
||||
# # cache.delete(id=id, field="question")
|
||||
# print("datas---------------------{0}".format(datas))
|
||||
# cache.set(id=user_id, field="qa_list", value=copy.deepcopy(datas))
|
||||
# logger.info(f" user data {cache.get(user_id, field='qa_list')}")
|
||||
# return result
|
||||
#
|
||||
# return wrapper
|
||||
|
||||
|
||||
|
||||
@app.flask_app.route("/yj_sqlbot/api/v0/run_sql_2", methods=["GET"])
|
||||
@session_save
|
||||
@app.requires_cache(["sql"])
|
||||
def run_sql_2(id: str, sql: str):
|
||||
# @session_save
|
||||
# @requires_cache_2(required_keys=["sql"])
|
||||
def run_sql_2():
|
||||
"""
|
||||
Run SQL
|
||||
---
|
||||
@@ -169,10 +218,6 @@ def run_sql_2(id: str, sql: str):
|
||||
in: query|body
|
||||
type: string
|
||||
required: true
|
||||
- name: page_size
|
||||
in: query
|
||||
-name: page_num
|
||||
in: query
|
||||
responses:
|
||||
200:
|
||||
schema:
|
||||
@@ -190,7 +235,9 @@ def run_sql_2(id: str, sql: str):
|
||||
"""
|
||||
logger.info("Start to run sql in main")
|
||||
try:
|
||||
user_id = request.args.get("user_id")
|
||||
id = request.args.get("id")
|
||||
sql = get_sql_by_id(id)
|
||||
logger.info(f"sql is {sql}")
|
||||
if not vn.run_sql_is_set:
|
||||
return jsonify(
|
||||
{
|
||||
@@ -199,11 +246,6 @@ def run_sql_2(id: str, sql: str):
|
||||
}
|
||||
)
|
||||
|
||||
# count_sql = f"SELECT COUNT(*) AS total_count FROM ({sql}) AS subquery"
|
||||
# df_count = vn.run_sql(count_sql)
|
||||
# print(df_count,"is type",type(df_count))
|
||||
# total_count = df_count.to_dict(orient="records")[0]["total_count"]
|
||||
# logger.info("Total count is {0}".format(total_count))
|
||||
df = vn.run_sql_2(sql=sql)
|
||||
result = df.to_dict(orient='records')
|
||||
logger.info("df ---------------{0} {1}".format(result, type(result)))
|
||||
@@ -219,6 +261,10 @@ def run_sql_2(id: str, sql: str):
|
||||
logger.error(f"run sql failed:{e}")
|
||||
return jsonify({"type": "sql_error", "error": str(e)})
|
||||
|
||||
@app.flask_app.route("/yj_sqlbot/api/v0/conversations", methods=["GET"])
|
||||
def conversations():
|
||||
user_id = request.args.get("user_id")
|
||||
|
||||
|
||||
@app.flask_app.route("/yj_sqlbot/api/v0/verify", methods=["GET"])
|
||||
def verify_user():
|
||||
|
||||
85
service/conversation_service.py
Normal file
85
service/conversation_service.py
Normal file
@@ -0,0 +1,85 @@
|
||||
from datetime import datetime
|
||||
|
||||
from db_util.db_main import Conversation, SqliteSqlalchemy
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def save_conversation(id, user_id, cvs_id, question):
|
||||
cvs = Conversation(id=id, user_id=user_id, cvs_id=cvs_id, question=question, create_time = datetime.now())
|
||||
session = SqliteSqlalchemy().session
|
||||
try:
|
||||
session.add(cvs)
|
||||
session.commit()
|
||||
except:
|
||||
session.rollback()
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
def get_conversation(cvs_id: str):
|
||||
session = SqliteSqlalchemy().session
|
||||
try:
|
||||
results = session.query(Conversation).filter(Conversation.id == cvs_id)
|
||||
logger.info(f"conversation {cvs_id} results is {results}")
|
||||
return results.all()
|
||||
except Exception as e:
|
||||
logger.info(f"get conversation with id {cvs_id} error {e}")
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
# def get_all_conversations_by_user(user_id):
|
||||
# session = SqliteSqlalchemy().session
|
||||
# user_cvs = []
|
||||
# try:
|
||||
# results = session.query(Conversation).filter(Conversation.user_id == user_id).all()
|
||||
# logger.info(f"conversation {user_id} results is {results}")
|
||||
# cvs = {}
|
||||
# for rs in results:
|
||||
# cvs[rs.cvs_id] = {
|
||||
#
|
||||
# }
|
||||
|
||||
|
||||
def update_conversation(cvs_id: str, id: str, sql=None, meta=None):
|
||||
"""更新sql到对应question"""
|
||||
session = SqliteSqlalchemy().session
|
||||
try:
|
||||
if sql:
|
||||
session.query(Conversation).filter(Conversation.cvs_id == cvs_id, Conversation.id == id).update({Conversation.sql: sql})
|
||||
if meta:
|
||||
session.query(Conversation).filter(Conversation.cvs_id == cvs_id, Conversation.id == id).update({Conversation.meta: meta})
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
def get_latest_question(cvs_id, user_id, limit_count):
|
||||
"""获取指定会话的最新问题"""
|
||||
session = SqliteSqlalchemy().session
|
||||
try:
|
||||
latest_conversation = session.query(Conversation).filter_by(
|
||||
cvs_id=cvs_id,
|
||||
user_id=user_id
|
||||
).order_by(Conversation.create_time.desc()).limit(limit_count).all()
|
||||
last_question = [cs.question for cs in latest_conversation]
|
||||
return last_question
|
||||
except Exception as e:
|
||||
logger.error(f"get_latest_question error {e}")
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
def get_sql_by_id(id: str):
|
||||
session = SqliteSqlalchemy().session
|
||||
try:
|
||||
result = session.query(Conversation).filter_by(id=id).first()
|
||||
if result:
|
||||
return result.sql
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"get_sql_by_id error {e}")
|
||||
finally:
|
||||
session.close()
|
||||
@@ -21,6 +21,10 @@ import logging
|
||||
from util import train_ddl
|
||||
logger = logging.getLogger(__name__)
|
||||
import traceback
|
||||
from service.conversation_service import get_latest_question,update_conversation
|
||||
|
||||
limit_count = 3
|
||||
|
||||
class OpenAICompatibleLLM(VannaBase):
|
||||
def __init__(self, client=None, config_file=None):
|
||||
VannaBase.__init__(self, config=config_file)
|
||||
@@ -124,24 +128,25 @@ class OpenAICompatibleLLM(VannaBase):
|
||||
return {"role": "assistant", "content": message}
|
||||
|
||||
def submit_prompt(self, prompt, **kwargs) -> str:
|
||||
logger.info(f"submit prompt: {prompt}")
|
||||
if prompt is None:
|
||||
print("test1")
|
||||
logger.info("test1")
|
||||
raise Exception("Prompt is None")
|
||||
|
||||
if len(prompt) == 0:
|
||||
print("test2")
|
||||
logger.info("test2")
|
||||
raise Exception("Prompt is empty")
|
||||
print(prompt)
|
||||
|
||||
num_tokens = 0
|
||||
for message in prompt:
|
||||
num_tokens += len(message["content"]) / 4
|
||||
print("test3 {0}".format(num_tokens))
|
||||
logger.info("test3 {0}".format(num_tokens))
|
||||
|
||||
if kwargs.get("model", None) is not None:
|
||||
print("test4")
|
||||
logger.info("test4")
|
||||
model = kwargs.get("model", None)
|
||||
print(
|
||||
logger.info(
|
||||
f"Using model {model} for {num_tokens} tokens (approx)"
|
||||
)
|
||||
response = self.client.chat.completions.create(
|
||||
@@ -152,9 +157,9 @@ class OpenAICompatibleLLM(VannaBase):
|
||||
temperature=self.temperature,
|
||||
)
|
||||
elif kwargs.get("engine", None) is not None:
|
||||
print("test5")
|
||||
logger.info("test5")
|
||||
engine = kwargs.get("engine", None)
|
||||
print(
|
||||
logger.info(
|
||||
f"Using model {engine} for {num_tokens} tokens (approx)"
|
||||
)
|
||||
response = self.client.chat.completions.create(
|
||||
@@ -165,8 +170,8 @@ class OpenAICompatibleLLM(VannaBase):
|
||||
temperature=self.temperature,
|
||||
)
|
||||
elif self.config is not None and "engine" in self.config:
|
||||
print("test6")
|
||||
print(
|
||||
logger.info("test6")
|
||||
logger.info(
|
||||
f"Using engine {self.config['engine']} for {num_tokens} tokens (approx)"
|
||||
)
|
||||
response = self.client.chat.completions.create(
|
||||
@@ -177,11 +182,11 @@ class OpenAICompatibleLLM(VannaBase):
|
||||
temperature=self.temperature,
|
||||
)
|
||||
elif self.config is not None and "model" in self.config:
|
||||
print("test7")
|
||||
print(
|
||||
logger.info("test7")
|
||||
logger.info(
|
||||
f"Using model {self.config['model']} for {num_tokens} tokens (approx)"
|
||||
)
|
||||
print("config is ",self.config)
|
||||
logger.info("config is ",self.config)
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.config["model"],
|
||||
messages=prompt,
|
||||
@@ -201,13 +206,13 @@ class OpenAICompatibleLLM(VannaBase):
|
||||
# json=data
|
||||
# )
|
||||
else:
|
||||
print("test8")
|
||||
logger.info("test8")
|
||||
if num_tokens > 3500:
|
||||
model = "kimi"
|
||||
else:
|
||||
model = "doubao"
|
||||
|
||||
print(f"5.Using model {model} for {num_tokens} tokens (approx)")
|
||||
logger.info(f"5.Using model {model} for {num_tokens} tokens (approx)")
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=model,
|
||||
@@ -222,21 +227,30 @@ class OpenAICompatibleLLM(VannaBase):
|
||||
|
||||
return response.choices[0].message.content
|
||||
|
||||
def generate_sql_2(self, question: str, cache=None,user_id=None, allow_llm_to_see_data=False, **kwargs) -> dict:
|
||||
def generate_sql_2(self, user_id: str, cvs_id: str, question: str,id: str, need_context: bool, allow_llm_to_see_data=False, **kwargs) -> dict:
|
||||
try:
|
||||
logger.info("Start to generate_sql_2 in cus_vanna_srevice")
|
||||
question_sql_list = self.get_similar_question_sql(question, **kwargs)
|
||||
if question_sql_list and len(question_sql_list)>2:
|
||||
question_sql_list=question_sql_list[:2]
|
||||
|
||||
ddl_list = self.get_related_ddl(question, **kwargs)
|
||||
#doc_list = self.get_related_documentation(question, **kwargs)
|
||||
template = get_base_template()
|
||||
sql_temp = template['template']['sql']
|
||||
char_temp = template['template']['chart']
|
||||
history = None
|
||||
if user_id and cache:
|
||||
history = cache.get(id=user_id, field="data")
|
||||
if need_context:
|
||||
questions = get_latest_question(cvs_id, user_id,limit_count)
|
||||
logger.info(f"latest_questions is {questions}")
|
||||
if questions[0] != question:
|
||||
raise Exception(f"上下文不匹配 {question} {questions[0]}")
|
||||
new_question = self.generate_rewritten_question(questions,**kwargs)
|
||||
logger.info(f"new_question is {new_question}")
|
||||
question = new_question if new_question else question
|
||||
update_conversation(cvs_id, id, meta=question)
|
||||
|
||||
# if user_id and cache:
|
||||
# history = cache.get(id=user_id, field="data")
|
||||
# --------基于提示词,生成sql以及图表类型
|
||||
sys_temp = sql_temp['system'].format(engine=config("DB_ENGINE", default='mysql'), lang='中文',
|
||||
schema=ddl_list, documentation=[train_ddl.train_document],
|
||||
@@ -275,6 +289,7 @@ class OpenAICompatibleLLM(VannaBase):
|
||||
logger.info("Finish to generate_sql_2 in cus_vanna_srevice")
|
||||
return result
|
||||
except Exception as e:
|
||||
|
||||
logger.info("cus_vanna_srevice failed-------------------: ")
|
||||
traceback.print_exc()
|
||||
raise e
|
||||
@@ -286,19 +301,166 @@ class OpenAICompatibleLLM(VannaBase):
|
||||
logger.error("run_sql failed {0}".format(sql))
|
||||
raise e
|
||||
|
||||
def generate_rewritten_question(self, last_question: str, new_question: str, **kwargs) -> str:
|
||||
logger.info(f"generate_rewritten_question---------------{new_question}")
|
||||
if last_question is None:
|
||||
def generate_rewritten_question(self, questions: str, **kwargs) -> str:
|
||||
logger.info(f"generate_rewritten_question---------------{questions}")
|
||||
new_question = questions[0]
|
||||
context_question = questions[1:]
|
||||
|
||||
if not context_question:
|
||||
return new_question
|
||||
print("last question {0}".format(last_question))
|
||||
print("last question {0}".format(context_question))
|
||||
print("new question {0}".format(new_question))
|
||||
# sys_info = '''
|
||||
# 你是一个问题补全助手,先判断问题1是否存在信息不完整的情况,如果不完整则根据上下文(问题2,问题3)来补全问题1
|
||||
# (按时间顺序从新到旧:问题1、问题2、问题3),问题1是用户当前提出的问题
|
||||
#
|
||||
# 【准则一】独立性优先
|
||||
# 如果问题1本身含义完整,不依赖其他问题的上下文也能被理解,则直接返回问题1,禁止强行合并。
|
||||
# 【准则二】最新问题优先
|
||||
# 问题1始终作为核心,只判断它是否需要利用前序问题补充自身信息,当它含义完整时,不再考虑合并。
|
||||
# 合并时只能用较旧的问题(问题2、问题3)的信息来补全较新的问题(问题1),不能反向操作。
|
||||
# 要以问题1的中心思想为准,禁止合并后该表问题1的中心思想
|
||||
# 【准则三】单向合并限制
|
||||
# 只有问题1能与其他问题合并,问题2和问题3之间不能单独合并。
|
||||
# 【准则四】顺序依赖判断
|
||||
# 只有当问题1明确依赖问题2或问题3的结果或上下文时,才考虑合并。依赖特征包括:
|
||||
# - 问题1中包含"其中"、"这些"、"这个"、"那些"、"他"、"他们"等指代性词语
|
||||
# - 问题1是在问题2或问题3基础上的细节追问
|
||||
# - 问题1具有明确的顺序逻辑关系
|
||||
# - 问题1缺少必要的主语或时间范围或具体查询信息等上下文信息
|
||||
# 【准则五】合并范围选择
|
||||
# - 如果问题1只依赖问题2,则合并问题2+问题1
|
||||
# - 如果问题1只依赖问题3,则合并问题3+问题1
|
||||
# - 如果问题1依赖问题2,可得出结果,依赖问题3也可得出结果,则就近原则,合并问题2+问题1
|
||||
# - 如果问题1同时依赖问题2和问题3才能得出结果,则合并问题3+问题2+问题1
|
||||
# - 如果问题1不依赖任何其他问题,则直接返回问题1
|
||||
#
|
||||
# 【准则六】合并执行原则
|
||||
# 将选择的问题自然衔接成一个完整问题,不要添加任何解释性文字。
|
||||
#
|
||||
# 【准则七】SQL可行性验证
|
||||
# 合并后的问题应该能够通过一条SQL查询语句来回答。
|
||||
#
|
||||
# 【准则八】兜底措施
|
||||
# 当你无法判断问题一是否完整,也无法判断问题一是否依赖其他问题才能补全信息时,请直接向用户询问细节
|
||||
#
|
||||
#
|
||||
# 示例:
|
||||
# 输入:问题1="早退多少天",问题2="其中迟到多少天",问题3="张三九月工作了多少天"
|
||||
# 输出:"张三九月早退多少天"
|
||||
#
|
||||
# 输入:问题1="这些天里张三是否有迟到,问题2="李四考勤","问题3="张三九月的考勤"
|
||||
# 输出:"张三九月的迟到情况"
|
||||
#
|
||||
# 输入:问题1="最近一个月是否有迟到",问题2="李四考勤",问题3="张三九月的考勤"
|
||||
# 输出:"张三最近一个月是否有迟到"
|
||||
#
|
||||
# 输入:问题1="迟到了多少天",问题2="他哪几天迟到了",问题3="张三九月在林芝是否早退"
|
||||
# 输出:"张三九月在林芝迟到了多少天"
|
||||
#
|
||||
# 输入:问题1="张三九月休息了多少天" ,问题2="张三九月迟到了多少天",问题3="张三其中迟到多少天"
|
||||
# 输出:"张三九月休息了多少天"
|
||||
#
|
||||
# 输入:问题1="张三九月考勤情况" ,问题2="张三九月迟到了多少天",问题3="李四迟到多少天"
|
||||
# 输出:"张三九月考勤情况",
|
||||
# '''
|
||||
# sys_info2 = '''
|
||||
# 你是一个问题补全助手,任务是判断用户当前提出的问题(问题1)是否信息完整。
|
||||
#
|
||||
# 若问题1语义完整且可独立理解,则直接返回原问题1;
|
||||
#
|
||||
# 若问题1信息缺失或存在指代依赖,则根据其前序上下文(问题2 和 问题3,按时间倒序排列:问题1 最新,问题2 次新,问题3 最早)进行最小必要补全,生成一个语义完整、忠实于原意、且能通过一条 SQL 查询回答的问题。
|
||||
#
|
||||
# 请严格遵循以下准则:
|
||||
# <>
|
||||
# <rule_title>独立性优先</rule_title>
|
||||
# <rule>若问题1 本身语义完整(即不依赖问题2 或问题3 也能被准确理解),则直接返回问题1,禁止强行合并或改写。,也不再受下面的规则约束</rule>
|
||||
#
|
||||
# <rule_title>以最新问题为核心</rule_title>
|
||||
#
|
||||
# <rule>问题1 始终是查询意图的唯一来源</rule>
|
||||
# <rule>补全时只能借用问题2 或问题3 中的信息(如主语、时间、地点等),补全问题1的缺失,不得改变问题1 的核心要素。</rule>
|
||||
# <rule>合并后的问题必须完全保留问题1 的原始意图、时间范围、查询对象和动作。</rule>
|
||||
# <rule>如果问题1已有明确的时间、地点、人物等信息,禁止用前序问题的不同信息进行覆盖替换。</rule>
|
||||
#
|
||||
# <rule_title>单向合并限制</rule_title>
|
||||
# <rule>仅允许将问题1 与问题2 或问题3 合并。</rule>
|
||||
# <rule>禁止问题2 与问题3 直接合并,也禁止忽略问题1 进行其他组合。</rule>
|
||||
#
|
||||
# <rule_title>依赖判断标准</rule_title>
|
||||
# <rule>仅当问题1 明确依赖前序问题的上下文时,才触发合并。</rule>
|
||||
# <rule>包含指代词:如“这个”“这些”“其中”“他”“他们”等;</rule>
|
||||
# <rule>是对前一个问题的细节追问(如追问数量、时间、条件等);</rule>
|
||||
# <rule>存在顺序逻辑(如“然后呢?”“接下来怎么样?”);<rule>
|
||||
# <rule>缺失关键要素:如主语、时间范围、地点、对象等,需从前序问题中补全。</rule>
|
||||
#
|
||||
# <rule_title>合并范围选择规则</rule_title>
|
||||
# 根据依赖关系,按以下优先级确定合并方式:
|
||||
# <rule>仅依赖问题2 → 合并为:问题1 + 问题2</rule>
|
||||
# <rule>仅依赖问题3 → 合并为:问题1 + 问题3</rule>
|
||||
# <rule>问题2 和问题3 均可独立支撑问题1 → 采用就近原则,合并问题1 + 问题2</rule>
|
||||
# <rule>必须同时依赖问题2 和问题3 才能完整理解问题1 → 合并为:问题1 + 问题2 + 问题3</rule>
|
||||
# <rule>不依赖任何前序问题 → 直接返回问题1</rule>
|
||||
# <rule_title>自然衔接,无额外内容</rule_title>
|
||||
# <rule>合并后的问题必须是一个语法通顺、语义连贯的完整问句,不得添加解释、连接词或说明性文字(如“根据前面的问题”“结合上下文”等)。</rule>
|
||||
#
|
||||
# <rule_title>SQL 可执行性</rule_title>
|
||||
# <rule>合并后的问题必须能通过一条 SQL 查询直接回答。</rule>
|
||||
# <rule>若合并后的问题模糊、多义、或无法映射到具体数据库字段,则不应合并</rule>
|
||||
#
|
||||
# <rule_title>兜底策略</rule_title>
|
||||
# <rule>若无法明确判断问题1 是否完整,或无法确定其是否依赖前序问题,请不要猜测,而是主动向用户请求澄清或补充细节。</rule>
|
||||
#
|
||||
# <example>
|
||||
# 输入:问题1="早退多少天",问题2="其中迟到多少天",问题3="张三九月工作了多少天"
|
||||
# 输出:"张三九月早退多少天"
|
||||
#
|
||||
# 输入:问题1="这些天里张三是否有迟到,问题2="李四考勤","问题3="张三九月的考勤"
|
||||
# 正确输出:"张三九月的迟到情况"
|
||||
# 错误输出:
|
||||
#
|
||||
# 输入:问题1="最近一个月是否有迟到",问题2="李四考勤",问题3="张三九月的考勤"
|
||||
# 输出:"张三最近一个月是否有迟到"
|
||||
#
|
||||
# 输入:问题1="迟到了多少天",问题2="他哪几天迟到了",问题3="张三九月在林芝是否早退"
|
||||
# 输出:"张三九月在林芝迟到了多少天"
|
||||
#
|
||||
# 输入:问题1="张三九月休息了多少天" ,问题2="张三九月迟到了多少天",问题3="张三其中迟到多少天"
|
||||
# 输出:"张三九月休息了多少天"
|
||||
#
|
||||
# 输入:问题1="张三九月考勤情况" ,问题2="张三九月迟到了多少天",问题3="李四迟到多少天"
|
||||
# 输出:"张三九月考勤情况",
|
||||
#
|
||||
# 输入:问题1="张三9月在林芝上班多少天" ,问题2="9月29的考勤",问题3="9月29的考勤"
|
||||
# 输出:"张三9月在林芝上班多少天",
|
||||
# </example>
|
||||
# '''
|
||||
# sys_info3 = '''
|
||||
|
||||
# '''
|
||||
sys_info = '''
|
||||
你是一个问题补全助手,任务是判断用户当前提出的问题(问题1)是否信息完整。
|
||||
|
||||
处理流程:
|
||||
先判断问题1是否语义完整:
|
||||
如果问题1 自身含义清晰、包含必要要素(如主语、时间范围、具体查询目标等),指代明确,不依赖任何上下文也能被准确理解,则直接返回问题1原文,禁止任何形式的改写或合并。
|
||||
仅当问题1信息缺失时,才使用上下文补全:
|
||||
上下文包括前两个历史问题:问题2(较近)、问题3(较远)。
|
||||
补全时遵循就近优先原则:优先使用问题2 的信息;仅当问题2 无法提供所需信息,且问题3 可补全时,才使用问题3。
|
||||
若需同时依赖问题2 和问题3 才能补全,则按 问题3 + 问题2 + 问题1 的顺序融合。
|
||||
若问题1 中包含“他”“她”“它”等代词,无需无需区分性别或语义类别,(他不一定代表男,她也不一定代表女),采用就近原则从上下文中找出最近的具有人名或明确实体的主语进行替换。
|
||||
主语选择必须严格遵循时间顺序:仅当问题2 中无有效主语时,才考虑问题3。只要问题2 包含明确人名,就必须使用问题2 的主语
|
||||
补全要求:
|
||||
合并后的问题必须是一个语法通顺、语义完整的自然问句。
|
||||
不得添加任何解释性、连接性或说明性文字(如“根据前面的问题”“结合上下文”等)。
|
||||
补全后的问题必须能通过一条 SQL 查询语句直接回答(即具备明确的查询对象、条件和指标)。
|
||||
兜底策略:
|
||||
如果你无法确定问题1 是否完整,或无法判断是否依赖上下文,或即使合并上下文仍无法形成完整、可执行的问题,请不要猜测或强行输出,而是直接向用户请求补充细节。
|
||||
'''
|
||||
prompt = [
|
||||
self.system_message(
|
||||
"你的目标是将一系列相关问题合并成一个单一的问题。"
|
||||
"合并准则一、如果第二个问题与第一个问题无关且本身是完整独立的,则直接返回第二个问题。"
|
||||
"合并准则二、如果第二个问题域第一个问题相关,且要基于第一个问题的前提,请合并两个问题为一个问题,只需返回合并后的新问题,不要添加任何额外解释。"
|
||||
"合并准则三、理论上,合并后的问题应该能够通过单个SQL语句来回答"),
|
||||
self.user_message("First question: " + last_question + "\nSecond question: " + new_question),
|
||||
self.system_message(sys_info),
|
||||
# self.user_message("First question: " + last_question + "\nSecond question: " + new_question),
|
||||
self.user_message("问题1: " + new_question + "\n上下文: " +str(context_question))
|
||||
]
|
||||
|
||||
return self.submit_prompt(prompt=prompt, **kwargs)
|
||||
@@ -327,15 +489,16 @@ class CustomQdrant_VectorStore(Qdrant_VectorStore):
|
||||
return str(e)
|
||||
return "Unknown error"
|
||||
|
||||
request_body = {
|
||||
"model": self.embedding_model_name,
|
||||
"sentences": [data],
|
||||
}
|
||||
# request_body = {
|
||||
# "model": self.embedding_model_name,
|
||||
# "encoding_format": "float",
|
||||
# "input": [data],
|
||||
# }
|
||||
request_body = {
|
||||
"model": self.embedding_model_name,
|
||||
"encoding_format": "float",
|
||||
"input": [{"type":"text","text":data}],
|
||||
}
|
||||
request_body.update(kwargs)
|
||||
|
||||
response = requests.post(
|
||||
@@ -348,10 +511,10 @@ class CustomQdrant_VectorStore(Qdrant_VectorStore):
|
||||
f"Failed to create the embeddings, detail: {_get_error_string(response)}"
|
||||
)
|
||||
result = response.json()
|
||||
embeddings = result['embeddings']
|
||||
# embeddings = result['data'][0]['embedding']
|
||||
return embeddings[0]
|
||||
# return embeddings
|
||||
# embeddings = result['embeddings']
|
||||
embeddings = result['data']['embedding']
|
||||
# return embeddings[0]
|
||||
return embeddings
|
||||
|
||||
class CustomVanna(CustomQdrant_VectorStore, OpenAICompatibleLLM):
|
||||
def __init__(self, llm_config=None, vector_store_config=None):
|
||||
|
||||
@@ -46,5 +46,6 @@ def query_predefined_question_list() -> list:
|
||||
session = SqliteSqlalchemy().session
|
||||
all = session.query(PredefinedQuestion).filter(PredefinedQuestion.enable == True).all()
|
||||
all = [a.to_dict() for a in all] if all else []
|
||||
logger.info(f"all is :{all}")
|
||||
session.close()
|
||||
return all
|
||||
|
||||
@@ -701,7 +701,7 @@ question_and_answer = [
|
||||
"category": "外部单位统计"
|
||||
},
|
||||
{
|
||||
"question": "XX中心员工在林芝工作的天数",
|
||||
"question": "XX中心员工在林芝工作的天数排行",
|
||||
"answer": '''
|
||||
SELECT p."code" AS "工号",
|
||||
p."name" AS "姓名",
|
||||
@@ -727,6 +727,105 @@ question_and_answer = [
|
||||
"tags": ["员工", "部门", "考勤", "工作地", "区域", "工作天数统计"],
|
||||
"category": "工作地考勤统计分析"
|
||||
|
||||
},
|
||||
{
|
||||
"question": "XX中心张XX十月在林芝工作了多长时间",
|
||||
"answer": '''
|
||||
SELECT
|
||||
p."code" AS "工号",
|
||||
p."name" AS "姓名",
|
||||
SUM(ps."work_time") AS "在林芝工作时间"
|
||||
FROM "YJOA_APPSERVICE_DB"."t_pr3rl2oj_yj_person_database" p
|
||||
INNER JOIN "YJOA_APPSERVICE_DB"."t_yj_person_status" ps
|
||||
ON p."code" = ps."person_id"
|
||||
WHERE p."dr" = 0
|
||||
AND ps."dr" = 0
|
||||
AND p."name" = '张XX'
|
||||
AND ps."date_value" BETWEEN '2025-10-01' AND '2025-10-31'
|
||||
AND (p."code", ps."date_value") IN (
|
||||
SELECT
|
||||
a."person_id",
|
||||
TO_CHAR(a."attendance_time", 'yyyy-MM-dd')
|
||||
FROM "YJOA_APPSERVICE_DB"."t_yj_person_attendance" a
|
||||
INNER JOIN "YJOA_APPSERVICE_DB"."t_yj_person_ac_area" ac
|
||||
ON a."access_control_point" = ac."ac_point"
|
||||
WHERE a."dr" = 0
|
||||
AND ac."region" = 5
|
||||
)
|
||||
AND p."internal_dept" IN (
|
||||
SELECT "id"
|
||||
FROM "IUAP_APDOC_BASEDOC"."org_orgs"
|
||||
START WITH ("name" LIKE '%XX中心%' OR "shortname" LIKE '%XX中心%')
|
||||
AND "dr" = 0 AND "enable" = 1 AND "code" LIKE '%CYJ%'
|
||||
CONNECT BY PRIOR "id" = "parentid"
|
||||
)
|
||||
GROUP BY p."code", p."name"
|
||||
ORDER BY "在林芝工作时间" DESC
|
||||
LIMIT 1000;
|
||||
''',
|
||||
"tags": ["员工", "部门", "考勤", "工作地", "区域", "工作时长统计"],
|
||||
"category": "工作地考勤统计分析"
|
||||
|
||||
},
|
||||
{
|
||||
"question": "XX中心张三在林芝工作了多少天,迟到了多少天",
|
||||
"answer": '''
|
||||
SELECT p."name" AS "姓名",
|
||||
COUNT(DISTINCT TO_CHAR(a."attendance_time", 'yyyy-MM-dd')) AS "在林芝工作天数",
|
||||
COUNT(DISTINCT CASE WHEN ps."status" IN ('1006','1009','6002','6004') THEN ps."date_value" END) AS "在林芝迟到的天数"
|
||||
FROM "YJOA_APPSERVICE_DB"."t_pr3rl2oj_yj_person_database" p
|
||||
LEFT JOIN "YJOA_APPSERVICE_DB"."t_yj_person_attendance" a ON p."code" = a."person_id"
|
||||
LEFT JOIN "YJOA_APPSERVICE_DB"."t_yj_person_ac_area" ac ON a."access_control_point" = ac."ac_point"
|
||||
LEFT JOIN "YJOA_APPSERVICE_DB"."t_yj_person_status" ps ON p."code" = ps."person_id"
|
||||
AND TO_CHAR(a."attendance_time", 'yyyy-MM-dd') = ps."date_value"
|
||||
WHERE p."dr" = 0
|
||||
AND a."dr" = 0
|
||||
AND ps."dr" = 0
|
||||
AND ac."region" = 5
|
||||
AND p."name" = '张三'
|
||||
AND a."attendance_time" LIKE '2025-09%'
|
||||
AND ps."date_value" LIKE '2025-09%'
|
||||
AND p."internal_dept" IN (
|
||||
SELECT "id"
|
||||
FROM "IUAP_APDOC_BASEDOC"."org_orgs"
|
||||
START WITH ("name" LIKE '%XX中心%' OR "shortname" LIKE '%XX中心%')
|
||||
AND "dr" = 0 AND "enable" = 1 AND "code" LIKE '%CYJ%'
|
||||
CONNECT BY PRIOR "id" = "parentid"
|
||||
)
|
||||
GROUP BY p."name"
|
||||
''',
|
||||
"tags": ["员工", "部门", "考勤", "工作地", "区域", "工作天数统计", "迟到天数统计"],
|
||||
"category": "工作地考勤统计分析"
|
||||
|
||||
},
|
||||
|
||||
{
|
||||
"question": "张三9月在林芝上班期间,有多少天早退了?",
|
||||
"answer": '''
|
||||
SELECT p."name" AS "姓名",
|
||||
COUNT(DISTINCT ps."date_value") AS "早退天数"
|
||||
FROM "YJOA_APPSERVICE_DB"."t_yj_person_status" ps
|
||||
INNER JOIN "YJOA_APPSERVICE_DB"."t_pr3rl2oj_yj_person_database" p
|
||||
ON p."code" = ps."person_id"
|
||||
INNER JOIN "YJOA_APPSERVICE_DB"."t_yj_person_attendance" pa
|
||||
ON pa."person_id" = ps."person_id"
|
||||
AND pa."attendance_time" LIKE '2025-09%'
|
||||
INNER JOIN "YJOA_APPSERVICE_DB"."t_yj_person_ac_position" acp
|
||||
ON acp."ac_point" = pa."access_control_point"
|
||||
INNER JOIN "YJOA_APPSERVICE_DB"."t_yj_person_ac_area" aca
|
||||
ON aca."ac_point" = acp."ac_point"
|
||||
WHERE p."name" = '张三'
|
||||
AND ps."date_value" LIKE '2025-09%'
|
||||
AND ps."dr" = 0
|
||||
AND p."dr" = 0
|
||||
AND aca."region" = '5'
|
||||
AND ps."status" IN ('1006','6001','4006')
|
||||
GROUP BY p."name"
|
||||
LIMIT 1000
|
||||
''',
|
||||
"tags": ["员工", "部门", "考勤", "工作地", "区域","早退天数统计"],
|
||||
"category": "工作地考勤统计分析"
|
||||
|
||||
},
|
||||
{
|
||||
"question": "XX中心员工在成都工作的天数",
|
||||
|
||||
@@ -588,35 +588,6 @@ person_status_ddl='''
|
||||
"role": "dimension",
|
||||
"tags": ["时间信息", "日期记录"]
|
||||
},
|
||||
{
|
||||
"name": "practical_attendance",
|
||||
"type": "Int",
|
||||
"comment": "实际出勤",
|
||||
"value":{
|
||||
"1":"已出勤",
|
||||
"0":"未出勤",
|
||||
},
|
||||
"role": "dimension",
|
||||
"tags": ["出勤信息", "状态记录","枚举信息"]
|
||||
},
|
||||
{
|
||||
"name": "is_ought_attendance",
|
||||
"type": "Int",
|
||||
"comment": "是否应出勤",
|
||||
"value":{
|
||||
"1":"是",
|
||||
"0":"否",
|
||||
},
|
||||
"role": "dimension",
|
||||
"tags": ["是否应出勤", "枚举信息"]
|
||||
},
|
||||
{
|
||||
"name": "work_area",
|
||||
"type": "VARCHAR(50)",
|
||||
"comment": "工作地区",
|
||||
"role": "dimension",
|
||||
"tags": ["工作地区", "地域信息"]
|
||||
},
|
||||
{
|
||||
"name": "work_time",
|
||||
"type": "Int",
|
||||
|
||||
Reference in New Issue
Block a user