Merge branch 'dev' of gitlab-devt.yced.com.cn:lei_y601/sqlbot_agent into dev

# Conflicts:
#	main_service.py
This commit is contained in:
yujj128
2025-09-25 16:52:11 +08:00
4 changed files with 36 additions and 5 deletions

10
.env
View File

@@ -1,13 +1,19 @@
IS_FIRST_LOAD=True
CHAT_MODEL_BASE_URL=https://api.siliconflow.cn
CHAT_MODEL_API_KEY=sk-cjiakyfpzamtxgxitcbxvwyvaulnygmyxqpykkgngsvfqhuy
CHAT_MODEL_API_KEY=sk-iyhiltycmrfnhrnbljsgqjrinhbztwdplyvuhfihcdlepole
CHAT_MODEL_NAME=Qwen/Qwen3-Next-80B-A3B-Instruct
EMBEDDING_MODEL_BASE_URL=https://api.siliconflow.cn
EMBEDDING_MODEL_API_KEY=sk-cjiakyfpzamtxgxitcbxvwyvaulnygmyxqpykkgngsvfqhuy
EMBEDDING_MODEL_API_KEY=sk-iyhiltycmrfnhrnbljsgqjrinhbztwdplyvuhfihcdlepole
EMBEDDING_MODEL_NAME=Qwen/Qwen3-Embedding-8B
#向量数据库
#type:memory/remote,如果设置为remote将IS_FIRST_LOAD 设置成false
QDRANT_TYPE=memory
QDRANT_DB_HOST=106.13.42.156
QDRANT_DB_PORT=16000
#mysql ,sqlite,pg等
DATA_SOURCE_TYPE=dameng
#数据库类型

9
Dockerfile Normal file
View File

@@ -0,0 +1,9 @@
FROM docker.m.daocloud.io/python:3.12-slim
WORKDIR /app
COPY . /app
ENV TZ=Asia/Shanghai \
LANG=C.UTF-8
RUN rm -rf logs .git .idea .venv && apt-get update && apt-get install -y vim curl && pip install -r requirement.txt -i https://mirrors.aliyun.com/pypi/simple/
RUN mkdir -p /app/logs && touch /app/logs/sqlbot.log && rm -rf *.whl
EXPOSE 8084
CMD ["python","main_service.py"]

View File

@@ -39,9 +39,11 @@ def load_train_data_ddl(vn: CustomVanna):
def create_vana():
print("----------------create---------")
logger.info("----------------create vana ---------")
q_client = QdrantClient(":memory:") if config('QDRANT_TYPE', default='memory') == 'memory' else QdrantClient(
url=config('QDRANT_DB_HOST', default=''), port=config('QDRANT_DB_PORT', default=6333))
vn = CustomVanna(
vector_store_config={"client": QdrantClient(":memory:")},
vector_store_config={"client": q_client},
llm_config={
"api_key": config('CHAT_MODEL_API_KEY', default=''),
"api_base": config('CHAT_MODEL_BASE_URL', default=''),
@@ -52,7 +54,7 @@ def create_vana():
def init_vn(vn):
print("--------------init vn-----connect----")
logger.info("--------------init vana-----connect to datasouce db----")
connect_database(vn)
load_ddl_doc.add_ddl(vn)
load_ddl_doc.add_documentation(vn)
@@ -103,6 +105,7 @@ def generate_sql_2():
logger.info("Generate sql result is {0}".format(data))
data['id'] = id
sql = data["resp"]["sql"]
logger.info("generate sql is : "+ sql)
cache.set(id=id, field="question", value=question)
cache.set(id=id, field="sql", value=sql)
data["type"]="success"

View File

@@ -78,6 +78,19 @@ template:
<rule>
SQL查询的字段名不要自动翻译别名必须为英文
</rule>
<rule>
生成sql时如果返回字段中有枚举字段,请根据枚举字段选项值生成对应的case when语句
<example>
SELECT
CASE
WHEN gender = 1 THEN '男'
WHEN gender = 2 THEN '女'
ELSE gender
END AS gender,
COUNT(*) AS count
FROM person GROUP BY gender;
</example>
</rule>
<rule>
SQL查询的字段若是函数字段如 COUNT(),CAST() 等,必须加上别名
</rule>