up to v0.17.2_supple (#7)
This commit is contained in:
parent
0f9b87898f
commit
4624f89cc1
|
@ -59,6 +59,7 @@ RUN --mount=type=cache,id=ragflow_apt,target=/var/cache/apt,sharing=locked \
|
|||
apt install -y default-jdk && \
|
||||
apt install -y libatk-bridge2.0-0 && \
|
||||
apt install -y libpython3-dev libgtk-4-1 libnss3 xdg-utils libgbm-dev && \
|
||||
apt install -y libjemalloc-dev && \
|
||||
apt install -y python3-pip pipx nginx unzip curl wget git vim less
|
||||
|
||||
RUN if [ "$NEED_MIRROR" == "1" ]; then \
|
||||
|
|
|
@ -28,7 +28,7 @@ git clone https://github.com/zstar1003/ragflow-plus.git
|
|||
2. 打包web文件
|
||||
```bash
|
||||
cd web
|
||||
npm run build
|
||||
pnpm run build
|
||||
```
|
||||
|
||||
3. 进入到容器,删除容器中已有的/ragflow/web/dist文件
|
||||
|
@ -126,4 +126,8 @@ This repository is available under the [Ragflow
|
|||
|
||||
- [ragflow](https://github.com/infiniflow/ragflow)
|
||||
|
||||
- [v3-admin-vite](https://github.com/un-pany/v3-admin-vite)
|
||||
- [v3-admin-vite](https://github.com/un-pany/v3-admin-vite)
|
||||
|
||||
## Star History
|
||||
|
||||

|
|
@ -216,6 +216,8 @@ class Generate(ComponentBase):
|
|||
return
|
||||
|
||||
msg = self._canvas.get_history(self._param.message_history_window_size)
|
||||
if msg and msg[0]['role'] == 'assistant':
|
||||
msg.pop(0)
|
||||
if len(msg) < 1:
|
||||
msg.append({"role": "user", "content": "Output: "})
|
||||
_, msg = message_fit_in([{"role": "system", "content": prompt}, *msg], int(chat_mdl.max_length * 0.97))
|
||||
|
|
|
@ -38,6 +38,10 @@ class IterationItem(ComponentBase, ABC):
|
|||
ans = parent.get_input()
|
||||
ans = parent._param.delimiter.join(ans["content"]) if "content" in ans else ""
|
||||
ans = [a.strip() for a in ans.split(parent._param.delimiter)]
|
||||
if not ans:
|
||||
self._idx = -1
|
||||
return pd.DataFrame()
|
||||
|
||||
df = pd.DataFrame([{"content": ans[self._idx]}])
|
||||
self._idx += 1
|
||||
if self._idx >= len(ans):
|
||||
|
|
|
@ -24,6 +24,7 @@ from api.db.services.llm_service import LLMBundle
|
|||
from api import settings
|
||||
from agent.component.base import ComponentBase, ComponentParamBase
|
||||
from rag.app.tag import label_question
|
||||
from rag.utils.tavily_conn import Tavily
|
||||
|
||||
|
||||
class RetrievalParam(ComponentParamBase):
|
||||
|
@ -40,6 +41,8 @@ class RetrievalParam(ComponentParamBase):
|
|||
self.kb_ids = []
|
||||
self.rerank_id = ""
|
||||
self.empty_response = ""
|
||||
self.tavily_api_key = ""
|
||||
self.use_kg = False
|
||||
|
||||
def check(self):
|
||||
self.check_decimal_float(self.similarity_threshold, "[Retrieval] Similarity threshold")
|
||||
|
@ -53,7 +56,9 @@ class Retrieval(ComponentBase, ABC):
|
|||
def _run(self, history, **kwargs):
|
||||
query = self.get_input()
|
||||
query = str(query["content"][0]) if "content" in query else ""
|
||||
|
||||
lines = query.split('\n')
|
||||
user_queries = [line.split("USER:", 1)[1] for line in lines if line.startswith("USER:")]
|
||||
query = user_queries[-1] if user_queries else ""
|
||||
kbs = KnowledgebaseService.get_by_ids(self._param.kb_ids)
|
||||
if not kbs:
|
||||
return Retrieval.be_output("")
|
||||
|
@ -73,6 +78,20 @@ class Retrieval(ComponentBase, ABC):
|
|||
self._param.similarity_threshold, 1 - self._param.keywords_similarity_weight,
|
||||
aggs=False, rerank_mdl=rerank_mdl,
|
||||
rank_feature=label_question(query, kbs))
|
||||
if self._param.use_kg:
|
||||
ck = settings.kg_retrievaler.retrieval(query,
|
||||
[kbs[0].tenant_id],
|
||||
self._param.kb_ids,
|
||||
embd_mdl,
|
||||
LLMBundle(kbs[0].tenant_id, LLMType.CHAT))
|
||||
if ck["content_with_weight"]:
|
||||
kbinfos["chunks"].insert(0, ck)
|
||||
|
||||
if self._param.tavily_api_key:
|
||||
tav = Tavily(self._param.tavily_api_key)
|
||||
tav_res = tav.retrieve_chunks(query)
|
||||
kbinfos["chunks"].extend(tav_res["chunks"])
|
||||
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
|
||||
|
||||
if not kbinfos["chunks"]:
|
||||
df = Retrieval.be_output("")
|
||||
|
|
|
@ -14,9 +14,8 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
from abc import ABC
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from agent.component import GenerateParam, Generate
|
||||
from rag.prompts import full_question
|
||||
|
||||
|
||||
class RewriteQuestionParam(GenerateParam):
|
||||
|
@ -33,48 +32,6 @@ class RewriteQuestionParam(GenerateParam):
|
|||
def check(self):
|
||||
super().check()
|
||||
|
||||
def get_prompt(self, conv, language, query):
|
||||
prompt = """
|
||||
Role: A helpful assistant
|
||||
Task: Generate a full user question that would follow the conversation.
|
||||
Requirements & Restrictions:
|
||||
- Text generated MUST be in the same language of the original user's question.
|
||||
- If the user's latest question is completely, don't do anything, just return the original question.
|
||||
- DON'T generate anything except a refined question."""
|
||||
|
||||
if language:
|
||||
prompt += f"""
|
||||
- Text generated MUST be in {language}"""
|
||||
|
||||
prompt += f"""
|
||||
######################
|
||||
-Examples-
|
||||
######################
|
||||
# Example 1
|
||||
## Conversation
|
||||
USER: What is the name of Donald Trump's father?
|
||||
ASSISTANT: Fred Trump.
|
||||
USER: And his mother?
|
||||
###############
|
||||
Output: What's the name of Donald Trump's mother?
|
||||
------------
|
||||
# Example 2
|
||||
## Conversation
|
||||
USER: What is the name of Donald Trump's father?
|
||||
ASSISTANT: Fred Trump.
|
||||
USER: And his mother?
|
||||
ASSISTANT: Mary Trump.
|
||||
USER: What's her full name?
|
||||
###############
|
||||
Output: What's the full name of Donald Trump's mother Mary Trump?
|
||||
######################
|
||||
# Real Data
|
||||
## Conversation
|
||||
{conv}
|
||||
###############
|
||||
"""
|
||||
return prompt
|
||||
|
||||
|
||||
class RewriteQuestion(Generate, ABC):
|
||||
component_name = "RewriteQuestion"
|
||||
|
@ -83,15 +40,10 @@ class RewriteQuestion(Generate, ABC):
|
|||
hist = self._canvas.get_history(self._param.message_history_window_size)
|
||||
query = self.get_input()
|
||||
query = str(query["content"][0]) if "content" in query else ""
|
||||
conv = []
|
||||
for m in hist:
|
||||
if m["role"] not in ["user", "assistant"]:
|
||||
continue
|
||||
conv.append("{}: {}".format(m["role"].upper(), m["content"]))
|
||||
conv = "\n".join(conv)
|
||||
chat_mdl = LLMBundle(self._canvas.get_tenant_id(), LLMType.CHAT, self._param.llm_id)
|
||||
ans = chat_mdl.chat(self._param.get_prompt(conv, self.gen_lang(self._param.language), query),
|
||||
[{"role": "user", "content": "Output: "}], self._param.gen_conf())
|
||||
messages = [h for h in hist if h["role"]!="system"]
|
||||
if messages[-1]["role"] != "user":
|
||||
messages.append({"role": "user", "content": query})
|
||||
ans = full_question(self._canvas.get_tenant_id(), self._param.llm_id, messages, self.gen_lang(self._param.language))
|
||||
self._canvas.history.pop()
|
||||
self._canvas.history.append(("user", ans))
|
||||
return RewriteQuestion.be_output(ans)
|
||||
|
|
|
@ -36,132 +36,188 @@ class DeepResearcher:
|
|||
self._kb_retrieve = kb_retrieve
|
||||
self._kg_retrieve = kg_retrieve
|
||||
|
||||
@staticmethod
|
||||
def _remove_query_tags(text):
|
||||
"""Remove query tags from text"""
|
||||
pattern = re.escape(BEGIN_SEARCH_QUERY) + r"(.*?)" + re.escape(END_SEARCH_QUERY)
|
||||
return re.sub(pattern, "", text)
|
||||
|
||||
@staticmethod
|
||||
def _remove_result_tags(text):
|
||||
"""Remove result tags from text"""
|
||||
pattern = re.escape(BEGIN_SEARCH_RESULT) + r"(.*?)" + re.escape(END_SEARCH_RESULT)
|
||||
return re.sub(pattern, "", text)
|
||||
|
||||
def _generate_reasoning(self, msg_history):
|
||||
"""Generate reasoning steps"""
|
||||
query_think = ""
|
||||
if msg_history[-1]["role"] != "user":
|
||||
msg_history.append({"role": "user", "content": "Continues reasoning with the new information.\n"})
|
||||
else:
|
||||
msg_history[-1]["content"] += "\n\nContinues reasoning with the new information.\n"
|
||||
|
||||
for ans in self.chat_mdl.chat_streamly(REASON_PROMPT, msg_history, {"temperature": 0.7}):
|
||||
ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
|
||||
if not ans:
|
||||
continue
|
||||
query_think = ans
|
||||
yield query_think
|
||||
return query_think
|
||||
|
||||
def _extract_search_queries(self, query_think, question, step_index):
|
||||
"""Extract search queries from thinking"""
|
||||
queries = extract_between(query_think, BEGIN_SEARCH_QUERY, END_SEARCH_QUERY)
|
||||
if not queries and step_index == 0:
|
||||
# If this is the first step and no queries are found, use the original question as the query
|
||||
queries = [question]
|
||||
return queries
|
||||
|
||||
def _truncate_previous_reasoning(self, all_reasoning_steps):
|
||||
"""Truncate previous reasoning steps to maintain a reasonable length"""
|
||||
truncated_prev_reasoning = ""
|
||||
for i, step in enumerate(all_reasoning_steps):
|
||||
truncated_prev_reasoning += f"Step {i + 1}: {step}\n\n"
|
||||
|
||||
prev_steps = truncated_prev_reasoning.split('\n\n')
|
||||
if len(prev_steps) <= 5:
|
||||
truncated_prev_reasoning = '\n\n'.join(prev_steps)
|
||||
else:
|
||||
truncated_prev_reasoning = ''
|
||||
for i, step in enumerate(prev_steps):
|
||||
if i == 0 or i >= len(prev_steps) - 4 or BEGIN_SEARCH_QUERY in step or BEGIN_SEARCH_RESULT in step:
|
||||
truncated_prev_reasoning += step + '\n\n'
|
||||
else:
|
||||
if truncated_prev_reasoning[-len('\n\n...\n\n'):] != '\n\n...\n\n':
|
||||
truncated_prev_reasoning += '...\n\n'
|
||||
|
||||
return truncated_prev_reasoning.strip('\n')
|
||||
|
||||
def _retrieve_information(self, search_query):
|
||||
"""Retrieve information from different sources"""
|
||||
# 1. Knowledge base retrieval
|
||||
kbinfos = self._kb_retrieve(question=search_query) if self._kb_retrieve else {"chunks": [], "doc_aggs": []}
|
||||
|
||||
# 2. Web retrieval (if Tavily API is configured)
|
||||
if self.prompt_config.get("tavily_api_key"):
|
||||
tav = Tavily(self.prompt_config["tavily_api_key"])
|
||||
tav_res = tav.retrieve_chunks(search_query)
|
||||
kbinfos["chunks"].extend(tav_res["chunks"])
|
||||
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
|
||||
|
||||
# 3. Knowledge graph retrieval (if configured)
|
||||
if self.prompt_config.get("use_kg") and self._kg_retrieve:
|
||||
ck = self._kg_retrieve(question=search_query)
|
||||
if ck["content_with_weight"]:
|
||||
kbinfos["chunks"].insert(0, ck)
|
||||
|
||||
return kbinfos
|
||||
|
||||
def _update_chunk_info(self, chunk_info, kbinfos):
|
||||
"""Update chunk information for citations"""
|
||||
if not chunk_info["chunks"]:
|
||||
# If this is the first retrieval, use the retrieval results directly
|
||||
for k in chunk_info.keys():
|
||||
chunk_info[k] = kbinfos[k]
|
||||
else:
|
||||
# Merge newly retrieved information, avoiding duplicates
|
||||
cids = [c["chunk_id"] for c in chunk_info["chunks"]]
|
||||
for c in kbinfos["chunks"]:
|
||||
if c["chunk_id"] not in cids:
|
||||
chunk_info["chunks"].append(c)
|
||||
|
||||
dids = [d["doc_id"] for d in chunk_info["doc_aggs"]]
|
||||
for d in kbinfos["doc_aggs"]:
|
||||
if d["doc_id"] not in dids:
|
||||
chunk_info["doc_aggs"].append(d)
|
||||
|
||||
def _extract_relevant_info(self, truncated_prev_reasoning, search_query, kbinfos):
|
||||
"""Extract and summarize relevant information"""
|
||||
summary_think = ""
|
||||
for ans in self.chat_mdl.chat_streamly(
|
||||
RELEVANT_EXTRACTION_PROMPT.format(
|
||||
prev_reasoning=truncated_prev_reasoning,
|
||||
search_query=search_query,
|
||||
document="\n".join(kb_prompt(kbinfos, 4096))
|
||||
),
|
||||
[{"role": "user",
|
||||
"content": f'Now you should analyze each web page and find helpful information based on the current search query "{search_query}" and previous reasoning steps.'}],
|
||||
{"temperature": 0.7}):
|
||||
ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
|
||||
if not ans:
|
||||
continue
|
||||
summary_think = ans
|
||||
yield summary_think
|
||||
|
||||
return summary_think
|
||||
|
||||
def thinking(self, chunk_info: dict, question: str):
|
||||
def rm_query_tags(line):
|
||||
pattern = re.escape(BEGIN_SEARCH_QUERY) + r"(.*?)" + re.escape(END_SEARCH_QUERY)
|
||||
return re.sub(pattern, "", line)
|
||||
|
||||
def rm_result_tags(line):
|
||||
pattern = re.escape(BEGIN_SEARCH_RESULT) + r"(.*?)" + re.escape(END_SEARCH_RESULT)
|
||||
return re.sub(pattern, "", line)
|
||||
|
||||
executed_search_queries = []
|
||||
msg_hisotry = [{"role": "user", "content": f'Question:\"{question}\"\n'}]
|
||||
msg_history = [{"role": "user", "content": f'Question:\"{question}\"\n'}]
|
||||
all_reasoning_steps = []
|
||||
think = "<think>"
|
||||
for ii in range(MAX_SEARCH_LIMIT + 1):
|
||||
if ii == MAX_SEARCH_LIMIT - 1:
|
||||
|
||||
for step_index in range(MAX_SEARCH_LIMIT + 1):
|
||||
# Check if the maximum search limit has been reached
|
||||
if step_index == MAX_SEARCH_LIMIT - 1:
|
||||
summary_think = f"\n{BEGIN_SEARCH_RESULT}\nThe maximum search limit is exceeded. You are not allowed to search.\n{END_SEARCH_RESULT}\n"
|
||||
yield {"answer": think + summary_think + "</think>", "reference": {}, "audio_binary": None}
|
||||
all_reasoning_steps.append(summary_think)
|
||||
msg_hisotry.append({"role": "assistant", "content": summary_think})
|
||||
msg_history.append({"role": "assistant", "content": summary_think})
|
||||
break
|
||||
|
||||
# Step 1: Generate reasoning
|
||||
query_think = ""
|
||||
if msg_hisotry[-1]["role"] != "user":
|
||||
msg_hisotry.append({"role": "user", "content": "Continues reasoning with the new information.\n"})
|
||||
else:
|
||||
msg_hisotry[-1]["content"] += "\n\nContinues reasoning with the new information.\n"
|
||||
for ans in self.chat_mdl.chat_streamly(REASON_PROMPT, msg_hisotry, {"temperature": 0.7}):
|
||||
ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
|
||||
if not ans:
|
||||
continue
|
||||
for ans in self._generate_reasoning(msg_history):
|
||||
query_think = ans
|
||||
yield {"answer": think + rm_query_tags(query_think) + "</think>", "reference": {}, "audio_binary": None}
|
||||
yield {"answer": think + self._remove_query_tags(query_think) + "</think>", "reference": {}, "audio_binary": None}
|
||||
|
||||
think += rm_query_tags(query_think)
|
||||
think += self._remove_query_tags(query_think)
|
||||
all_reasoning_steps.append(query_think)
|
||||
queries = extract_between(query_think, BEGIN_SEARCH_QUERY, END_SEARCH_QUERY)
|
||||
if not queries:
|
||||
if ii > 0:
|
||||
break
|
||||
queries = [question]
|
||||
|
||||
# Step 2: Extract search queries
|
||||
queries = self._extract_search_queries(query_think, question, step_index)
|
||||
if not queries and step_index > 0:
|
||||
# If not the first step and no queries, end the search process
|
||||
break
|
||||
|
||||
# Process each search query
|
||||
for search_query in queries:
|
||||
logging.info(f"[THINK]Query: {ii}. {search_query}")
|
||||
msg_hisotry.append({"role": "assistant", "content": search_query})
|
||||
think += f"\n\n> {ii +1}. {search_query}\n\n"
|
||||
logging.info(f"[THINK]Query: {step_index}. {search_query}")
|
||||
msg_history.append({"role": "assistant", "content": search_query})
|
||||
think += f"\n\n> {step_index + 1}. {search_query}\n\n"
|
||||
yield {"answer": think + "</think>", "reference": {}, "audio_binary": None}
|
||||
|
||||
summary_think = ""
|
||||
# The search query has been searched in previous steps.
|
||||
# Check if the query has already been executed
|
||||
if search_query in executed_search_queries:
|
||||
summary_think = f"\n{BEGIN_SEARCH_RESULT}\nYou have searched this query. Please refer to previous results.\n{END_SEARCH_RESULT}\n"
|
||||
yield {"answer": think + summary_think + "</think>", "reference": {}, "audio_binary": None}
|
||||
all_reasoning_steps.append(summary_think)
|
||||
msg_hisotry.append({"role": "user", "content": summary_think})
|
||||
msg_history.append({"role": "user", "content": summary_think})
|
||||
think += summary_think
|
||||
continue
|
||||
|
||||
truncated_prev_reasoning = ""
|
||||
for i, step in enumerate(all_reasoning_steps):
|
||||
truncated_prev_reasoning += f"Step {i + 1}: {step}\n\n"
|
||||
|
||||
prev_steps = truncated_prev_reasoning.split('\n\n')
|
||||
if len(prev_steps) <= 5:
|
||||
truncated_prev_reasoning = '\n\n'.join(prev_steps)
|
||||
else:
|
||||
truncated_prev_reasoning = ''
|
||||
for i, step in enumerate(prev_steps):
|
||||
if i == 0 or i >= len(prev_steps) - 4 or BEGIN_SEARCH_QUERY in step or BEGIN_SEARCH_RESULT in step:
|
||||
truncated_prev_reasoning += step + '\n\n'
|
||||
else:
|
||||
if truncated_prev_reasoning[-len('\n\n...\n\n'):] != '\n\n...\n\n':
|
||||
truncated_prev_reasoning += '...\n\n'
|
||||
truncated_prev_reasoning = truncated_prev_reasoning.strip('\n')
|
||||
|
||||
# Retrieval procedure:
|
||||
# 1. KB search
|
||||
# 2. Web search (optional)
|
||||
# 3. KG search (optional)
|
||||
kbinfos = self._kb_retrieve(question=search_query) if self._kb_retrieve else {"chunks": [], "doc_aggs": []}
|
||||
|
||||
if self.prompt_config.get("tavily_api_key"):
|
||||
tav = Tavily(self.prompt_config["tavily_api_key"])
|
||||
tav_res = tav.retrieve_chunks(" ".join(search_query))
|
||||
kbinfos["chunks"].extend(tav_res["chunks"])
|
||||
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
|
||||
if self.prompt_config.get("use_kg") and self._kg_retrieve:
|
||||
ck = self._kg_retrieve(question=search_query)
|
||||
if ck["content_with_weight"]:
|
||||
kbinfos["chunks"].insert(0, ck)
|
||||
|
||||
# Merge chunk info for citations
|
||||
if not chunk_info["chunks"]:
|
||||
for k in chunk_info.keys():
|
||||
chunk_info[k] = kbinfos[k]
|
||||
else:
|
||||
cids = [c["chunk_id"] for c in chunk_info["chunks"]]
|
||||
for c in kbinfos["chunks"]:
|
||||
if c["chunk_id"] in cids:
|
||||
continue
|
||||
chunk_info["chunks"].append(c)
|
||||
dids = [d["doc_id"] for d in chunk_info["doc_aggs"]]
|
||||
for d in kbinfos["doc_aggs"]:
|
||||
if d["doc_id"] in dids:
|
||||
continue
|
||||
chunk_info["doc_aggs"].append(d)
|
||||
|
||||
|
||||
executed_search_queries.append(search_query)
|
||||
|
||||
# Step 3: Truncate previous reasoning steps
|
||||
truncated_prev_reasoning = self._truncate_previous_reasoning(all_reasoning_steps)
|
||||
|
||||
# Step 4: Retrieve information
|
||||
kbinfos = self._retrieve_information(search_query)
|
||||
|
||||
# Step 5: Update chunk information
|
||||
self._update_chunk_info(chunk_info, kbinfos)
|
||||
|
||||
# Step 6: Extract relevant information
|
||||
think += "\n\n"
|
||||
for ans in self.chat_mdl.chat_streamly(
|
||||
RELEVANT_EXTRACTION_PROMPT.format(
|
||||
prev_reasoning=truncated_prev_reasoning,
|
||||
search_query=search_query,
|
||||
document="\n".join(kb_prompt(kbinfos, 4096))
|
||||
),
|
||||
[{"role": "user",
|
||||
"content": f'Now you should analyze each web page and find helpful information based on the current search query "{search_query}" and previous reasoning steps.'}],
|
||||
{"temperature": 0.7}):
|
||||
ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
|
||||
if not ans:
|
||||
continue
|
||||
summary_think = ""
|
||||
for ans in self._extract_relevant_info(truncated_prev_reasoning, search_query, kbinfos):
|
||||
summary_think = ans
|
||||
yield {"answer": think + rm_result_tags(summary_think) + "</think>", "reference": {}, "audio_binary": None}
|
||||
yield {"answer": think + self._remove_result_tags(summary_think) + "</think>", "reference": {}, "audio_binary": None}
|
||||
|
||||
all_reasoning_steps.append(summary_think)
|
||||
msg_hisotry.append(
|
||||
msg_history.append(
|
||||
{"role": "user", "content": f"\n\n{BEGIN_SEARCH_RESULT}{summary_think}{END_SEARCH_RESULT}\n\n"})
|
||||
think += rm_result_tags(summary_think)
|
||||
logging.info(f"[THINK]Summary: {ii}. {summary_think}")
|
||||
think += self._remove_result_tags(summary_think)
|
||||
logging.info(f"[THINK]Summary: {step_index}. {summary_think}")
|
||||
|
||||
yield think + "</think>"
|
||||
|
|
|
@ -68,6 +68,7 @@ REASON_PROMPT = (
|
|||
f"- You have a dataset to search, so you just provide a proper search query.\n"
|
||||
f"- Use {BEGIN_SEARCH_QUERY} to request a dataset search and end with {END_SEARCH_QUERY}.\n"
|
||||
"- The language of query MUST be as the same as 'Question' or 'search result'.\n"
|
||||
"- If no helpful information can be found, rewrite the search query to be less and precise keywords.\n"
|
||||
"- When done searching, continue your reasoning.\n\n"
|
||||
'Please answer the following question. You should think step by step to solve it.\n\n'
|
||||
)
|
||||
|
|
|
@ -17,6 +17,7 @@ import json
|
|||
import re
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
import trio
|
||||
from api.db.db_models import APIToken
|
||||
|
||||
from api.db.services.conversation_service import ConversationService, structure_answer
|
||||
|
@ -386,7 +387,8 @@ def mindmap():
|
|||
rank_feature=label_question(question, [kb])
|
||||
)
|
||||
mindmap = MindMapExtractor(chat_mdl)
|
||||
mind_map = mindmap([c["content_with_weight"] for c in ranks["chunks"]]).output
|
||||
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in ranks["chunks"]])
|
||||
mind_map = mind_map.output
|
||||
if "error" in mind_map:
|
||||
return server_error_response(Exception(mind_map["error"]))
|
||||
return get_json_result(data=mind_map)
|
||||
|
|
|
@ -71,11 +71,13 @@ def upload():
|
|||
if not e:
|
||||
raise LookupError("Can't find this knowledgebase!")
|
||||
|
||||
err, _ = FileService.upload_document(kb, file_objs, current_user.id)
|
||||
err, files = FileService.upload_document(kb, file_objs, current_user.id)
|
||||
files = [f[0] for f in files] # remove the blob
|
||||
|
||||
if err:
|
||||
return get_json_result(
|
||||
data=False, message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
|
||||
return get_json_result(data=True)
|
||||
data=files, message="\n".join(err), code=settings.RetCode.SERVER_ERROR)
|
||||
return get_json_result(data=files)
|
||||
|
||||
|
||||
@manager.route('/web_crawl', methods=['POST']) # noqa: F821
|
||||
|
@ -345,7 +347,7 @@ def rm():
|
|||
@manager.route('/run', methods=['POST']) # noqa: F821
|
||||
@login_required
|
||||
@validate_request("doc_ids", "run")
|
||||
def run():
|
||||
def run():
|
||||
req = request.json
|
||||
for doc_id in req["doc_ids"]:
|
||||
if not DocumentService.accessible(doc_id, current_user.id):
|
||||
|
|
|
@ -135,6 +135,8 @@ def set_api_key():
|
|||
def add_llm():
|
||||
req = request.json
|
||||
factory = req["llm_factory"]
|
||||
api_key = req.get("api_key", "x")
|
||||
llm_name = req["llm_name"]
|
||||
|
||||
def apikey_json(keys):
|
||||
nonlocal req
|
||||
|
@ -143,7 +145,6 @@ def add_llm():
|
|||
if factory == "VolcEngine":
|
||||
# For VolcEngine, due to its special authentication method
|
||||
# Assemble ark_api_key endpoint_id into api_key
|
||||
llm_name = req["llm_name"]
|
||||
api_key = apikey_json(["ark_api_key", "endpoint_id"])
|
||||
|
||||
elif factory == "Tencent Hunyuan":
|
||||
|
@ -157,52 +158,38 @@ def add_llm():
|
|||
elif factory == "Bedrock":
|
||||
# For Bedrock, due to its special authentication method
|
||||
# Assemble bedrock_ak, bedrock_sk, bedrock_region
|
||||
llm_name = req["llm_name"]
|
||||
api_key = apikey_json(["bedrock_ak", "bedrock_sk", "bedrock_region"])
|
||||
|
||||
elif factory == "LocalAI":
|
||||
llm_name = req["llm_name"] + "___LocalAI"
|
||||
api_key = "xxxxxxxxxxxxxxx"
|
||||
llm_name += "___LocalAI"
|
||||
|
||||
elif factory == "HuggingFace":
|
||||
llm_name = req["llm_name"] + "___HuggingFace"
|
||||
api_key = "xxxxxxxxxxxxxxx"
|
||||
llm_name += "___HuggingFace"
|
||||
|
||||
elif factory == "OpenAI-API-Compatible":
|
||||
llm_name = req["llm_name"] + "___OpenAI-API"
|
||||
api_key = req.get("api_key", "xxxxxxxxxxxxxxx")
|
||||
llm_name += "___OpenAI-API"
|
||||
|
||||
elif factory == "VLLM":
|
||||
llm_name = req["llm_name"] + "___VLLM"
|
||||
api_key = req.get("api_key", "xxxxxxxxxxxxxxx")
|
||||
llm_name += "___VLLM"
|
||||
|
||||
elif factory == "XunFei Spark":
|
||||
llm_name = req["llm_name"]
|
||||
if req["model_type"] == "chat":
|
||||
api_key = req.get("spark_api_password", "xxxxxxxxxxxxxxx")
|
||||
api_key = req.get("spark_api_password", "")
|
||||
elif req["model_type"] == "tts":
|
||||
api_key = apikey_json(["spark_app_id", "spark_api_secret", "spark_api_key"])
|
||||
|
||||
elif factory == "BaiduYiyan":
|
||||
llm_name = req["llm_name"]
|
||||
api_key = apikey_json(["yiyan_ak", "yiyan_sk"])
|
||||
|
||||
elif factory == "Fish Audio":
|
||||
llm_name = req["llm_name"]
|
||||
api_key = apikey_json(["fish_audio_ak", "fish_audio_refid"])
|
||||
|
||||
elif factory == "Google Cloud":
|
||||
llm_name = req["llm_name"]
|
||||
api_key = apikey_json(["google_project_id", "google_region", "google_service_account_key"])
|
||||
|
||||
elif factory == "Azure-OpenAI":
|
||||
llm_name = req["llm_name"]
|
||||
api_key = apikey_json(["api_key", "api_version"])
|
||||
|
||||
else:
|
||||
llm_name = req["llm_name"]
|
||||
api_key = req.get("api_key", "xxxxxxxxxxxxxxx")
|
||||
|
||||
llm = {
|
||||
"tenant_id": current_user.id,
|
||||
"llm_factory": factory,
|
||||
|
@ -351,8 +338,6 @@ def list_app():
|
|||
|
||||
llm_set = set([m["llm_name"] + "@" + m["fid"] for m in llms])
|
||||
for o in objs:
|
||||
if not o.api_key:
|
||||
continue
|
||||
if o.llm_name + "@" + o.llm_factory in llm_set:
|
||||
continue
|
||||
llms.append({"llm_name": o.llm_name, "model_type": o.model_type, "fid": o.llm_factory, "available": True})
|
||||
|
|
|
@ -31,9 +31,7 @@ from api.utils.api_utils import get_result
|
|||
@token_required
|
||||
def create(tenant_id):
|
||||
req = request.json
|
||||
ids = req.get("dataset_ids")
|
||||
if not ids:
|
||||
return get_error_data_result(message="`dataset_ids` is required")
|
||||
ids = [i for i in req.get("dataset_ids", []) if i]
|
||||
for kb_id in ids:
|
||||
kbs = KnowledgebaseService.accessible(kb_id=kb_id, user_id=tenant_id)
|
||||
if not kbs:
|
||||
|
@ -42,10 +40,16 @@ def create(tenant_id):
|
|||
kb = kbs[0]
|
||||
if kb.chunk_num == 0:
|
||||
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
|
||||
kbs = KnowledgebaseService.get_by_ids(ids)
|
||||
|
||||
# Check if all documents in the knowledge base have been parsed
|
||||
is_done, error_msg = KnowledgebaseService.is_parsed_done(kb_id)
|
||||
if not is_done:
|
||||
return get_error_data_result(error_msg)
|
||||
|
||||
kbs = KnowledgebaseService.get_by_ids(ids) if ids else []
|
||||
embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison
|
||||
embd_count = list(set(embd_ids))
|
||||
if len(embd_count) != 1:
|
||||
if len(embd_count) > 1:
|
||||
return get_result(message='Datasets use different embedding models."',
|
||||
code=settings.RetCode.AUTHENTICATION_ERROR)
|
||||
req["kb_ids"] = ids
|
||||
|
@ -178,6 +182,12 @@ def update(tenant_id, chat_id):
|
|||
kb = kbs[0]
|
||||
if kb.chunk_num == 0:
|
||||
return get_error_data_result(f"The dataset {kb_id} doesn't own parsed file")
|
||||
|
||||
# Check if all documents in the knowledge base have been parsed
|
||||
is_done, error_msg = KnowledgebaseService.is_parsed_done(kb_id)
|
||||
if not is_done:
|
||||
return get_error_data_result(error_msg)
|
||||
|
||||
kbs = KnowledgebaseService.get_by_ids(ids)
|
||||
embd_ids = [TenantLLMService.split_model_name_and_factory(kb.embd_id)[0] for kb in kbs] # remove vendor suffix for comparison
|
||||
embd_count = list(set(embd_ids))
|
||||
|
@ -320,7 +330,7 @@ def list_chat(tenant_id):
|
|||
for kb_id in res["kb_ids"]:
|
||||
kb = KnowledgebaseService.query(id=kb_id)
|
||||
if not kb:
|
||||
logging.WARN(f"Don't exist the kb {kb_id}")
|
||||
logging.warning(f"The kb {kb_id} does not exist.")
|
||||
continue
|
||||
kb_list.append(kb[0].to_json())
|
||||
del res["kb_ids"]
|
||||
|
|
|
@ -30,7 +30,7 @@ from api.utils.api_utils import (
|
|||
token_required,
|
||||
get_error_data_result,
|
||||
valid,
|
||||
get_parser_config,
|
||||
get_parser_config, valid_parser_config,
|
||||
)
|
||||
|
||||
|
||||
|
@ -66,10 +66,6 @@ def create(tenant_id):
|
|||
type: string
|
||||
enum: ['me', 'team']
|
||||
description: Dataset permission.
|
||||
language:
|
||||
type: string
|
||||
enum: ['Chinese', 'English']
|
||||
description: Language of the dataset.
|
||||
chunk_method:
|
||||
type: string
|
||||
enum: ["naive", "manual", "qa", "table", "paper", "book", "laws",
|
||||
|
@ -91,11 +87,10 @@ def create(tenant_id):
|
|||
req = request.json
|
||||
e, t = TenantService.get_by_id(tenant_id)
|
||||
permission = req.get("permission")
|
||||
language = req.get("language")
|
||||
chunk_method = req.get("chunk_method")
|
||||
parser_config = req.get("parser_config")
|
||||
valid_parser_config(parser_config)
|
||||
valid_permission = ["me", "team"]
|
||||
valid_language = ["Chinese", "English"]
|
||||
valid_chunk_method = [
|
||||
"naive",
|
||||
"manual",
|
||||
|
@ -114,8 +109,6 @@ def create(tenant_id):
|
|||
check_validation = valid(
|
||||
permission,
|
||||
valid_permission,
|
||||
language,
|
||||
valid_language,
|
||||
chunk_method,
|
||||
valid_chunk_method,
|
||||
)
|
||||
|
@ -134,13 +127,18 @@ def create(tenant_id):
|
|||
req["name"] = req["name"].strip()
|
||||
if req["name"] == "":
|
||||
return get_error_data_result(message="`name` is not empty string!")
|
||||
if len(req["name"]) >= 128:
|
||||
return get_error_data_result(
|
||||
message="Dataset name should not be longer than 128 characters."
|
||||
)
|
||||
if KnowledgebaseService.query(
|
||||
name=req["name"], tenant_id=tenant_id, status=StatusEnum.VALID.value
|
||||
):
|
||||
return get_error_data_result(
|
||||
message="Duplicated dataset name in creating dataset."
|
||||
)
|
||||
req["tenant_id"] = req["created_by"] = tenant_id
|
||||
req["tenant_id"] = tenant_id
|
||||
req["created_by"] = tenant_id
|
||||
if not req.get("embedding_model"):
|
||||
req["embedding_model"] = t.embd_id
|
||||
else:
|
||||
|
@ -182,6 +180,10 @@ def create(tenant_id):
|
|||
if old_key in req
|
||||
}
|
||||
req.update(mapped_keys)
|
||||
flds = list(req.keys())
|
||||
for f in flds:
|
||||
if req[f] == "" and f in ["permission", "parser_id", "chunk_method"]:
|
||||
del req[f]
|
||||
if not KnowledgebaseService.save(**req):
|
||||
return get_error_data_result(message="Create dataset error.(Database error)")
|
||||
renamed_data = {}
|
||||
|
@ -226,6 +228,8 @@ def delete(tenant_id):
|
|||
schema:
|
||||
type: object
|
||||
"""
|
||||
errors = []
|
||||
success_count = 0
|
||||
req = request.json
|
||||
if not req:
|
||||
ids = None
|
||||
|
@ -241,12 +245,12 @@ def delete(tenant_id):
|
|||
for id in id_list:
|
||||
kbs = KnowledgebaseService.query(id=id, tenant_id=tenant_id)
|
||||
if not kbs:
|
||||
return get_error_data_result(message=f"You don't own the dataset {id}")
|
||||
errors.append(f"You don't own the dataset {id}")
|
||||
continue
|
||||
for doc in DocumentService.query(kb_id=id):
|
||||
if not DocumentService.remove_document(doc, tenant_id):
|
||||
return get_error_data_result(
|
||||
message="Remove document error.(Database error)"
|
||||
)
|
||||
errors.append(f"Remove document error for dataset {id}")
|
||||
continue
|
||||
f2d = File2DocumentService.get_by_document_id(doc.id)
|
||||
FileService.filter_delete(
|
||||
[
|
||||
|
@ -258,11 +262,21 @@ def delete(tenant_id):
|
|||
FileService.filter_delete(
|
||||
[File.source_type == FileSource.KNOWLEDGEBASE, File.type == "folder", File.name == kbs[0].name])
|
||||
if not KnowledgebaseService.delete_by_id(id):
|
||||
return get_error_data_result(message="Delete dataset error.(Database error)")
|
||||
errors.append(f"Delete dataset error for {id}")
|
||||
continue
|
||||
success_count += 1
|
||||
if errors:
|
||||
if success_count > 0:
|
||||
return get_result(
|
||||
data={"success_count": success_count, "errors": errors},
|
||||
message=f"Partially deleted {success_count} datasets with {len(errors)} errors"
|
||||
)
|
||||
else:
|
||||
return get_error_data_result(message="; ".join(errors))
|
||||
return get_result(code=settings.RetCode.SUCCESS)
|
||||
|
||||
|
||||
@manager.route("/datasets/<dataset_id>", methods=["PUT"]) # noqa: F821
|
||||
@manager.route("/datasets/<dataset_id>", methods=["PUT"]) # noqa: F821
|
||||
@token_required
|
||||
def update(tenant_id, dataset_id):
|
||||
"""
|
||||
|
@ -297,10 +311,6 @@ def update(tenant_id, dataset_id):
|
|||
type: string
|
||||
enum: ['me', 'team']
|
||||
description: Updated permission.
|
||||
language:
|
||||
type: string
|
||||
enum: ['Chinese', 'English']
|
||||
description: Updated language.
|
||||
chunk_method:
|
||||
type: string
|
||||
enum: ["naive", "manual", "qa", "table", "paper", "book", "laws",
|
||||
|
@ -320,15 +330,14 @@ def update(tenant_id, dataset_id):
|
|||
return get_error_data_result(message="You don't own the dataset")
|
||||
req = request.json
|
||||
e, t = TenantService.get_by_id(tenant_id)
|
||||
invalid_keys = {"id", "embd_id", "chunk_num", "doc_num", "parser_id"}
|
||||
invalid_keys = {"id", "embd_id", "chunk_num", "doc_num", "parser_id", "create_date", "create_time", "created_by", "status","token_num","update_date","update_time"}
|
||||
if any(key in req for key in invalid_keys):
|
||||
return get_error_data_result(message="The input parameters are invalid.")
|
||||
permission = req.get("permission")
|
||||
language = req.get("language")
|
||||
chunk_method = req.get("chunk_method")
|
||||
parser_config = req.get("parser_config")
|
||||
valid_parser_config(parser_config)
|
||||
valid_permission = ["me", "team"]
|
||||
valid_language = ["Chinese", "English"]
|
||||
valid_chunk_method = [
|
||||
"naive",
|
||||
"manual",
|
||||
|
@ -347,8 +356,6 @@ def update(tenant_id, dataset_id):
|
|||
check_validation = valid(
|
||||
permission,
|
||||
valid_permission,
|
||||
language,
|
||||
valid_language,
|
||||
chunk_method,
|
||||
valid_chunk_method,
|
||||
)
|
||||
|
@ -370,7 +377,7 @@ def update(tenant_id, dataset_id):
|
|||
if req["document_count"] != kb.doc_num:
|
||||
return get_error_data_result(message="Can't change `document_count`.")
|
||||
req.pop("document_count")
|
||||
if "chunk_method" in req:
|
||||
if req.get("chunk_method"):
|
||||
if kb.chunk_num != 0 and req["chunk_method"] != kb.parser_id:
|
||||
return get_error_data_result(
|
||||
message="If `chunk_count` is not 0, `chunk_method` is not changeable."
|
||||
|
@ -416,6 +423,10 @@ def update(tenant_id, dataset_id):
|
|||
req["embd_id"] = req.pop("embedding_model")
|
||||
if "name" in req:
|
||||
req["name"] = req["name"].strip()
|
||||
if len(req["name"]) >= 128:
|
||||
return get_error_data_result(
|
||||
message="Dataset name should not be longer than 128 characters."
|
||||
)
|
||||
if (
|
||||
req["name"].lower() != kb.name.lower()
|
||||
and len(
|
||||
|
@ -428,6 +439,10 @@ def update(tenant_id, dataset_id):
|
|||
return get_error_data_result(
|
||||
message="Duplicated dataset name in updating dataset."
|
||||
)
|
||||
flds = list(req.keys())
|
||||
for f in flds:
|
||||
if req[f] == "" and f in ["permission", "parser_id", "chunk_method"]:
|
||||
del req[f]
|
||||
if not KnowledgebaseService.update_by_id(kb.id, req):
|
||||
return get_error_data_result(message="Update dataset error.(Database error)")
|
||||
return get_result(code=settings.RetCode.SUCCESS)
|
||||
|
@ -435,7 +450,7 @@ def update(tenant_id, dataset_id):
|
|||
|
||||
@manager.route("/datasets", methods=["GET"]) # noqa: F821
|
||||
@token_required
|
||||
def list(tenant_id):
|
||||
def list_datasets(tenant_id):
|
||||
"""
|
||||
List datasets.
|
||||
---
|
||||
|
@ -504,7 +519,9 @@ def list(tenant_id):
|
|||
page_number = int(request.args.get("page", 1))
|
||||
items_per_page = int(request.args.get("page_size", 30))
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
if request.args.get("desc") == "False" or request.args.get("desc") == "false":
|
||||
if request.args.get("desc", "false").lower() not in ["true", "false"]:
|
||||
return get_error_data_result("desc should be true or false")
|
||||
if request.args.get("desc", "true").lower() == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
|
|
|
@ -240,6 +240,11 @@ def update_doc(tenant_id, dataset_id, document_id):
|
|||
if req["progress"] != doc.progress:
|
||||
return get_error_data_result(message="Can't change `progress`.")
|
||||
|
||||
if "meta_fields" in req:
|
||||
if not isinstance(req["meta_fields"], dict):
|
||||
return get_error_data_result(message="meta_fields must be a dictionary")
|
||||
DocumentService.update_meta_fields(document_id, req["meta_fields"])
|
||||
|
||||
if "name" in req and req["name"] != doc.name:
|
||||
if (
|
||||
pathlib.Path(req["name"].lower()).suffix
|
||||
|
@ -256,15 +261,12 @@ def update_doc(tenant_id, dataset_id, document_id):
|
|||
)
|
||||
if not DocumentService.update_by_id(document_id, {"name": req["name"]}):
|
||||
return get_error_data_result(message="Database error (Document rename)!")
|
||||
if "meta_fields" in req:
|
||||
if not isinstance(req["meta_fields"], dict):
|
||||
return get_error_data_result(message="meta_fields must be a dictionary")
|
||||
DocumentService.update_meta_fields(document_id, req["meta_fields"])
|
||||
|
||||
informs = File2DocumentService.get_by_document_id(document_id)
|
||||
if informs:
|
||||
e, file = FileService.get_by_id(informs[0].file_id)
|
||||
FileService.update_by_id(file.id, {"name": req["name"]})
|
||||
|
||||
if "parser_config" in req:
|
||||
DocumentService.update_parser_config(doc.id, req["parser_config"])
|
||||
if "chunk_method" in req:
|
||||
|
|
|
@ -259,6 +259,7 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
|||
# The choices field on the last chunk will always be an empty array [].
|
||||
def streamed_response_generator(chat_id, dia, msg):
|
||||
token_used = 0
|
||||
should_split_index = 0
|
||||
response = {
|
||||
"id": f"chatcmpl-{chat_id}",
|
||||
"choices": [
|
||||
|
@ -284,8 +285,13 @@ def chat_completion_openai_like(tenant_id, chat_id):
|
|||
try:
|
||||
for ans in chat(dia, msg, True):
|
||||
answer = ans["answer"]
|
||||
incremental = answer[token_used:]
|
||||
incremental = answer[should_split_index:]
|
||||
token_used += len(incremental)
|
||||
if incremental.endswith("</think>"):
|
||||
response_data_len = len(incremental.rstrip("</think>"))
|
||||
else:
|
||||
response_data_len = len(incremental)
|
||||
should_split_index += response_data_len
|
||||
response["choices"][0]["delta"]["content"] = incremental
|
||||
yield f"data:{json.dumps(response, ensure_ascii=False)}\n\n"
|
||||
except Exception as e:
|
||||
|
@ -365,6 +371,18 @@ def agent_completions(tenant_id, agent_id):
|
|||
conv = API4ConversationService.query(id=req["session_id"], dialog_id=agent_id)
|
||||
if not conv:
|
||||
return get_error_data_result(f"You don't own the session {req['session_id']}")
|
||||
# If an update to UserCanvas is detected, update the API4Conversation.dsl
|
||||
sync_dsl = req.get("sync_dsl", False)
|
||||
if sync_dsl is True and cvs[0].update_time > conv[0].update_time:
|
||||
current_dsl = conv[0].dsl
|
||||
new_dsl = json.loads(dsl)
|
||||
state_fields = ["history", "messages", "path", "reference"]
|
||||
states = {field: current_dsl.get(field, []) for field in state_fields}
|
||||
current_dsl.update(new_dsl)
|
||||
current_dsl.update(states)
|
||||
API4ConversationService.update_by_id(req["session_id"], {
|
||||
"dsl": current_dsl
|
||||
})
|
||||
else:
|
||||
req["question"] = ""
|
||||
if req.get("stream", True):
|
||||
|
@ -448,7 +466,10 @@ def list_agent_session(tenant_id, agent_id):
|
|||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id, user_id)
|
||||
# dsl defaults to True in all cases except for False and false
|
||||
include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false"
|
||||
convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id,
|
||||
user_id, include_dsl)
|
||||
if not convs:
|
||||
return get_result(data=[])
|
||||
for conv in convs:
|
||||
|
@ -511,6 +532,38 @@ def delete(tenant_id, chat_id):
|
|||
return get_result()
|
||||
|
||||
|
||||
@manager.route('/agents/<agent_id>/sessions', methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
def delete_agent_session(tenant_id, agent_id):
|
||||
req = request.json
|
||||
cvs = UserCanvasService.query(user_id=tenant_id, id=agent_id)
|
||||
if not cvs:
|
||||
return get_error_data_result(f"You don't own the agent {agent_id}")
|
||||
|
||||
convs = API4ConversationService.query(dialog_id=agent_id)
|
||||
if not convs:
|
||||
return get_error_data_result(f"Agent {agent_id} has no sessions")
|
||||
|
||||
if not req:
|
||||
ids = None
|
||||
else:
|
||||
ids = req.get("ids")
|
||||
|
||||
if not ids:
|
||||
conv_list = []
|
||||
for conv in convs:
|
||||
conv_list.append(conv.id)
|
||||
else:
|
||||
conv_list = ids
|
||||
|
||||
for session_id in conv_list:
|
||||
conv = API4ConversationService.query(id=session_id, dialog_id=agent_id)
|
||||
if not conv:
|
||||
return get_error_data_result(f"The agent doesn't own the session ${session_id}")
|
||||
API4ConversationService.delete_by_id(session_id)
|
||||
return get_result()
|
||||
|
||||
|
||||
@manager.route('/sessions/ask', methods=['POST']) # noqa: F821
|
||||
@token_required
|
||||
def ask_about(tenant_id):
|
||||
|
|
|
@ -201,7 +201,7 @@ def new_token():
|
|||
if not tenants:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
tenant_id = tenants[0].tenant_id
|
||||
tenant_id = [tenant for tenant in tenants if tenant.role == 'owner'][0].tenant_id
|
||||
obj = {
|
||||
"tenant_id": tenant_id,
|
||||
"token": generate_confirmation_token(tenant_id),
|
||||
|
@ -256,7 +256,7 @@ def token_list():
|
|||
if not tenants:
|
||||
return get_data_error_result(message="Tenant not found!")
|
||||
|
||||
tenant_id = tenants[0].tenant_id
|
||||
tenant_id = [tenant for tenant in tenants if tenant.role == 'owner'][0].tenant_id
|
||||
objs = APITokenService.query(tenant_id=tenant_id)
|
||||
objs = [o.to_dict() for o in objs]
|
||||
for o in objs:
|
||||
|
|
|
@ -843,8 +843,8 @@ class Task(DataBaseModel):
|
|||
id = CharField(max_length=32, primary_key=True)
|
||||
doc_id = CharField(max_length=32, null=False, index=True)
|
||||
from_page = IntegerField(default=0)
|
||||
|
||||
to_page = IntegerField(default=100000000)
|
||||
task_type = CharField(max_length=32, null=False, default="")
|
||||
|
||||
begin_at = DateTimeField(null=True, index=True)
|
||||
process_duation = FloatField(default=0)
|
||||
|
@ -935,7 +935,7 @@ class Conversation(DataBaseModel):
|
|||
class APIToken(DataBaseModel):
|
||||
tenant_id = CharField(max_length=32, null=False, index=True)
|
||||
token = CharField(max_length=255, null=False, index=True)
|
||||
dialog_id = CharField(max_length=32, null=False, index=True)
|
||||
dialog_id = CharField(max_length=32, null=True, index=True)
|
||||
source = CharField(max_length=16, null=True, help_text="none|agent|dialog", index=True)
|
||||
beta = CharField(max_length=255, null=True, index=True)
|
||||
|
||||
|
@ -1115,3 +1115,10 @@ def migrate_db():
|
|||
)
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
migrate(
|
||||
migrator.add_column("task", "task_type",
|
||||
CharField(max_length=32, null=False, default=""))
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
|
|
@ -160,7 +160,7 @@ def add_graph_templates():
|
|||
dir = os.path.join(get_project_base_directory(), "agent", "templates")
|
||||
for fnm in os.listdir(dir):
|
||||
try:
|
||||
cnvs = json.load(open(os.path.join(dir, fnm), "r"))
|
||||
cnvs = json.load(open(os.path.join(dir, fnm), "r",encoding="utf-8"))
|
||||
try:
|
||||
CanvasTemplateService.save(**cnvs)
|
||||
except Exception:
|
||||
|
|
|
@ -43,8 +43,12 @@ class API4ConversationService(CommonService):
|
|||
@DB.connection_context()
|
||||
def get_list(cls, dialog_id, tenant_id,
|
||||
page_number, items_per_page,
|
||||
orderby, desc, id, user_id=None):
|
||||
sessions = cls.model.select().where(cls.model.dialog_id == dialog_id)
|
||||
orderby, desc, id, user_id=None, include_dsl=True):
|
||||
if include_dsl:
|
||||
sessions = cls.model.select().where(cls.model.dialog_id == dialog_id)
|
||||
else:
|
||||
fields = [field for field in cls.model._meta.fields.values() if field.name != 'dsl']
|
||||
sessions = cls.model.select(*fields).where(cls.model.dialog_id == dialog_id)
|
||||
if id:
|
||||
sessions = sessions.where(cls.model.id == id)
|
||||
if user_id:
|
||||
|
|
|
@ -86,21 +86,9 @@ def completion(tenant_id, agent_id, question, session_id=None, stream=True, **kw
|
|||
"dsl": cvs.dsl
|
||||
}
|
||||
API4ConversationService.save(**conv)
|
||||
if query:
|
||||
yield "data:" + json.dumps({"code": 0,
|
||||
"message": "",
|
||||
"data": {
|
||||
"session_id": session_id,
|
||||
"answer": canvas.get_prologue(),
|
||||
"reference": [],
|
||||
"param": canvas.get_preset_param()
|
||||
}
|
||||
},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": True}, ensure_ascii=False) + "\n\n"
|
||||
return
|
||||
else:
|
||||
conv = API4Conversation(**conv)
|
||||
|
||||
|
||||
conv = API4Conversation(**conv)
|
||||
else:
|
||||
e, conv = API4ConversationService.get_by_id(session_id)
|
||||
assert e, "Session not found!"
|
||||
|
@ -130,7 +118,7 @@ def completion(tenant_id, agent_id, question, session_id=None, stream=True, **kw
|
|||
continue
|
||||
for k in ans.keys():
|
||||
final_ans[k] = ans[k]
|
||||
ans = {"answer": ans["content"], "reference": ans.get("reference", [])}
|
||||
ans = {"answer": ans["content"], "reference": ans.get("reference", []), "param": canvas.get_preset_param()}
|
||||
ans = structure_answer(conv, ans, message_id, session_id)
|
||||
yield "data:" + json.dumps({"code": 0, "message": "", "data": ans},
|
||||
ensure_ascii=False) + "\n\n"
|
||||
|
@ -160,8 +148,8 @@ def completion(tenant_id, agent_id, question, session_id=None, stream=True, **kw
|
|||
canvas.reference.append(final_ans["reference"])
|
||||
conv.dsl = json.loads(str(canvas))
|
||||
|
||||
result = {"answer": final_ans["content"], "reference": final_ans.get("reference", [])}
|
||||
result = {"answer": final_ans["content"], "reference": final_ans.get("reference", []) , "param": canvas.get_preset_param()}
|
||||
result = structure_answer(conv, result, message_id, session_id)
|
||||
API4ConversationService.append_message(conv.id, conv.to_dict())
|
||||
yield result
|
||||
break
|
||||
break
|
||||
|
|
|
@ -30,7 +30,8 @@ from api import settings
|
|||
from rag.app.resume import forbidden_select_fields4resume
|
||||
from rag.app.tag import label_question
|
||||
from rag.nlp.search import index_name
|
||||
from rag.prompts import kb_prompt, message_fit_in, llm_id2llm_type, keyword_extraction, full_question, chunks_format
|
||||
from rag.prompts import kb_prompt, message_fit_in, llm_id2llm_type, keyword_extraction, full_question, chunks_format, \
|
||||
citation_prompt
|
||||
from rag.utils import rmSpace, num_tokens_from_string
|
||||
from rag.utils.tavily_conn import Tavily
|
||||
|
||||
|
@ -72,7 +73,7 @@ def chat_solo(dialog, messages, stream=True):
|
|||
if prompt_config.get("tts"):
|
||||
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
|
||||
msg = [{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
|
||||
for m in messages if m["role"] != "system"]
|
||||
for m in messages if m["role"] != "system"]
|
||||
if stream:
|
||||
last_ans = ""
|
||||
for ans in chat_mdl.chat_streamly(prompt_config.get("system", ""), msg, dialog.llm_setting):
|
||||
|
@ -81,7 +82,9 @@ def chat_solo(dialog, messages, stream=True):
|
|||
if num_tokens_from_string(delta_ans) < 16:
|
||||
continue
|
||||
last_ans = answer
|
||||
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt":"", "created_at": time.time()}
|
||||
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
|
||||
if delta_ans:
|
||||
yield {"answer": answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans), "prompt": "", "created_at": time.time()}
|
||||
else:
|
||||
answer = chat_mdl.chat(prompt_config.get("system", ""), msg, dialog.llm_setting)
|
||||
user_content = msg[-1].get("content", "[content not available]")
|
||||
|
@ -90,126 +93,106 @@ def chat_solo(dialog, messages, stream=True):
|
|||
|
||||
|
||||
def chat(dialog, messages, stream=True, **kwargs):
|
||||
# 确保最后一条消息来自用户,否则抛出异常
|
||||
assert messages[-1]["role"] == "user", "The last content of this conversation is not from user."
|
||||
# 如果对话没有关联知识库,则调用chat_solo函数进行简单对话处理
|
||||
if not dialog.kb_ids:
|
||||
for ans in chat_solo(dialog, messages, stream):
|
||||
yield ans
|
||||
return
|
||||
# 记录聊天开始时间,用于后续性能分析
|
||||
|
||||
chat_start_ts = timer()
|
||||
# 根据对话配置的LLM类型获取模型配置
|
||||
|
||||
if llm_id2llm_type(dialog.llm_id) == "image2text":
|
||||
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
|
||||
else:
|
||||
llm_model_config = TenantLLMService.get_model_config(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||
# 获取模型支持的最大token数,默认为8192
|
||||
|
||||
max_tokens = llm_model_config.get("max_tokens", 8192)
|
||||
# 记录检查LLM配置的时间点
|
||||
|
||||
check_llm_ts = timer()
|
||||
# 获取对话关联的所有知识库
|
||||
|
||||
kbs = KnowledgebaseService.get_by_ids(dialog.kb_ids)
|
||||
# 提取所有知识库使用的嵌入模型ID,并确保它们使用相同的嵌入模型
|
||||
embedding_list = list(set([kb.embd_id for kb in kbs]))
|
||||
if len(embedding_list) != 1:
|
||||
# 如果知识库使用了不同的嵌入模型,返回错误信息
|
||||
yield {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
|
||||
return {"answer": "**ERROR**: Knowledge bases use different embedding models.", "reference": []}
|
||||
# 获取嵌入模型名称
|
||||
|
||||
embedding_model_name = embedding_list[0]
|
||||
# 获取检索器实例
|
||||
|
||||
retriever = settings.retrievaler
|
||||
# 提取最近3条用户消息作为问题上下文
|
||||
|
||||
questions = [m["content"] for m in messages if m["role"] == "user"][-3:]
|
||||
# 处理附件文档ID
|
||||
attachments = kwargs["doc_ids"].split(",") if "doc_ids" in kwargs else None
|
||||
if "doc_ids" in messages[-1]:
|
||||
attachments = messages[-1]["doc_ids"]
|
||||
# 记录创建检索器的时间点
|
||||
|
||||
create_retriever_ts = timer()
|
||||
# 初始化嵌入模型
|
||||
|
||||
embd_mdl = LLMBundle(dialog.tenant_id, LLMType.EMBEDDING, embedding_model_name)
|
||||
# 如果嵌入模型不存在,抛出异常
|
||||
if not embd_mdl:
|
||||
raise LookupError("Embedding model(%s) not found" % embedding_model_name)
|
||||
|
||||
# 记录绑定嵌入模型的时间点
|
||||
bind_embedding_ts = timer()
|
||||
# 初始化聊天模型
|
||||
|
||||
if llm_id2llm_type(dialog.llm_id) == "image2text":
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.IMAGE2TEXT, dialog.llm_id)
|
||||
else:
|
||||
chat_mdl = LLMBundle(dialog.tenant_id, LLMType.CHAT, dialog.llm_id)
|
||||
|
||||
# 记录绑定LLM模型的时间点
|
||||
|
||||
bind_llm_ts = timer()
|
||||
# 获取提示词配置
|
||||
|
||||
prompt_config = dialog.prompt_config
|
||||
# 获取字段映射,用于SQL检索
|
||||
field_map = KnowledgebaseService.get_field_map(dialog.kb_ids)
|
||||
tts_mdl = None
|
||||
# 初始化文本转语音模型(如果配置中启用)
|
||||
if prompt_config.get("tts"):
|
||||
tts_mdl = LLMBundle(dialog.tenant_id, LLMType.TTS)
|
||||
# try to use sql if field mapping is good to go
|
||||
# 尝试使用SQL检索(如果有字段映射)
|
||||
if field_map:
|
||||
logging.debug("Use SQL to retrieval:{}".format(questions[-1]))
|
||||
ans = use_sql(questions[-1], field_map, dialog.tenant_id, chat_mdl, prompt_config.get("quote", True))
|
||||
if ans:
|
||||
# 如果SQL检索成功,直接返回结果
|
||||
yield ans
|
||||
return
|
||||
# 处理提示词配置中的参数
|
||||
|
||||
for p in prompt_config["parameters"]:
|
||||
# 跳过knowledge参数,它将在后面处理
|
||||
if p["key"] == "knowledge":
|
||||
continue
|
||||
# 检查必需参数是否提供
|
||||
if p["key"] not in kwargs and not p["optional"]:
|
||||
raise KeyError("Miss parameter: " + p["key"])
|
||||
# 如果参数未提供且为可选,则在提示词中替换为空格
|
||||
if p["key"] not in kwargs:
|
||||
prompt_config["system"] = prompt_config["system"].replace(
|
||||
"{%s}" % p["key"], " ")
|
||||
# 处理多轮对话开启,并且问题数量大于1
|
||||
|
||||
if len(questions) > 1 and prompt_config.get("refine_multiturn"):
|
||||
# 使用full_question函数将多轮对话合并为一个完整问题
|
||||
questions = [full_question(dialog.tenant_id, dialog.llm_id, messages)]
|
||||
else:
|
||||
# 否则只使用最后一个问题
|
||||
questions = questions[-1:]
|
||||
# 记录问题精炼的时间点
|
||||
|
||||
refine_question_ts = timer()
|
||||
# 初始化重排序模型(如果配置)
|
||||
|
||||
rerank_mdl = None
|
||||
if dialog.rerank_id:
|
||||
rerank_mdl = LLMBundle(dialog.tenant_id, LLMType.RERANK, dialog.rerank_id)
|
||||
# 记录绑定重排序模型的时间点
|
||||
|
||||
bind_reranker_ts = timer()
|
||||
generate_keyword_ts = bind_reranker_ts
|
||||
# 初始化思考过程和知识库信息
|
||||
thought = ""
|
||||
kbinfos = {"total": 0, "chunks": [], "doc_aggs": []}
|
||||
# 检查是否需要知识库检索
|
||||
|
||||
if "knowledge" not in [p["key"] for p in prompt_config["parameters"]]:
|
||||
knowledges = []
|
||||
else:
|
||||
# 如果启用了关键词提取,则增强问题
|
||||
if prompt_config.get("keyword", False):
|
||||
questions[-1] += keyword_extraction(chat_mdl, questions[-1])
|
||||
generate_keyword_ts = timer()
|
||||
# 获取所有知识库的表ID
|
||||
|
||||
tenant_ids = list(set([kb.tenant_id for kb in kbs]))
|
||||
|
||||
knowledges = []
|
||||
# 如果启用了推理功能,使用DeepResearcher进行深度推理
|
||||
if prompt_config.get("reasoning", False):
|
||||
reasoner = DeepResearcher(chat_mdl,
|
||||
prompt_config,
|
||||
partial(retriever.retrieval, embd_mdl=embd_mdl, tenant_ids=tenant_ids, kb_ids=dialog.kb_ids, page=1, page_size=dialog.top_n, similarity_threshold=0.2, vector_similarity_weight=0.3))
|
||||
# 执行推理过程
|
||||
|
||||
for think in reasoner.thinking(kbinfos, " ".join(questions)):
|
||||
if isinstance(think, str):
|
||||
thought = think
|
||||
|
@ -217,7 +200,6 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||
elif stream:
|
||||
yield think
|
||||
else:
|
||||
# 使用标准检索方法
|
||||
kbinfos = retriever.retrieval(" ".join(questions), embd_mdl, tenant_ids, dialog.kb_ids, 1, dialog.top_n,
|
||||
dialog.similarity_threshold,
|
||||
dialog.vector_similarity_weight,
|
||||
|
@ -225,13 +207,11 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||
top=dialog.top_k, aggs=False, rerank_mdl=rerank_mdl,
|
||||
rank_feature=label_question(" ".join(questions), kbs)
|
||||
)
|
||||
# 如果配置了Tavily API,使用外部搜索增强检索结果
|
||||
if prompt_config.get("tavily_api_key"):
|
||||
tav = Tavily(prompt_config["tavily_api_key"])
|
||||
tav_res = tav.retrieve_chunks(" ".join(questions))
|
||||
kbinfos["chunks"].extend(tav_res["chunks"])
|
||||
kbinfos["doc_aggs"].extend(tav_res["doc_aggs"])
|
||||
# 如果启用了知识图谱,使用知识图谱检索结果
|
||||
if prompt_config.get("use_kg"):
|
||||
ck = settings.kg_retrievaler.retrieval(" ".join(questions),
|
||||
tenant_ids,
|
||||
|
@ -240,31 +220,31 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||
LLMBundle(dialog.tenant_id, LLMType.CHAT))
|
||||
if ck["content_with_weight"]:
|
||||
kbinfos["chunks"].insert(0, ck)
|
||||
# 将检索到的知识格式化为提示词
|
||||
|
||||
knowledges = kb_prompt(kbinfos, max_tokens)
|
||||
# 记录检索到的知识
|
||||
|
||||
logging.debug(
|
||||
"{}->{}".format(" ".join(questions), "\n->".join(knowledges)))
|
||||
# 记录检索完成的时间点
|
||||
|
||||
retrieval_ts = timer()
|
||||
# 如果没有检索到知识且配置了空响应,则返回空响应
|
||||
if not knowledges and prompt_config.get("empty_response"):
|
||||
empty_res = prompt_config["empty_response"]
|
||||
yield {"answer": empty_res, "reference": kbinfos, "audio_binary": tts(tts_mdl, empty_res)}
|
||||
yield {"answer": empty_res, "reference": kbinfos, "prompt": "\n\n### Query:\n%s" % " ".join(questions), "audio_binary": tts(tts_mdl, empty_res)}
|
||||
return {"answer": prompt_config["empty_response"], "reference": kbinfos}
|
||||
# 将检索到的知识添加到kwargs中,用于格式化系统提示词
|
||||
|
||||
kwargs["knowledge"] = "\n------\n" + "\n\n------\n\n".join(knowledges)
|
||||
# 获取LLM生成配置
|
||||
gen_conf = dialog.llm_setting
|
||||
# 构建消息列表,包括系统提示词和用户消息
|
||||
|
||||
msg = [{"role": "system", "content": prompt_config["system"].format(**kwargs)}]
|
||||
prompt4citation = ""
|
||||
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
|
||||
prompt4citation = citation_prompt()
|
||||
msg.extend([{"role": m["role"], "content": re.sub(r"##\d+\$\$", "", m["content"])}
|
||||
for m in messages if m["role"] != "system"])
|
||||
# 确保消息列表不超过模型最大token限制
|
||||
used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.97))
|
||||
used_token_count, msg = message_fit_in(msg, int(max_tokens * 0.95))
|
||||
assert len(msg) >= 2, f"message_fit_in has bug: {msg}"
|
||||
prompt = msg[0]["content"]
|
||||
# 调整生成配置中的max_tokens,确保不超过模型限制
|
||||
|
||||
if "max_tokens" in gen_conf:
|
||||
gen_conf["max_tokens"] = min(
|
||||
gen_conf["max_tokens"],
|
||||
|
@ -274,22 +254,29 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||
nonlocal prompt_config, knowledges, kwargs, kbinfos, prompt, retrieval_ts, questions
|
||||
|
||||
refs = []
|
||||
# 处理思考过程,如果答案包含</think>标签
|
||||
ans = answer.split("</think>")
|
||||
think = ""
|
||||
if len(ans) == 2:
|
||||
think = ans[0] + "</think>"
|
||||
answer = ans[1]
|
||||
# 如果有知识且启用了引用功能,添加引用
|
||||
if knowledges and (prompt_config.get("quote", True) and kwargs.get("quote", True)):
|
||||
answer, idx = retriever.insert_citations(answer,
|
||||
[ck["content_ltks"]
|
||||
for ck in kbinfos["chunks"]],
|
||||
[ck["vector"]
|
||||
for ck in kbinfos["chunks"]],
|
||||
embd_mdl,
|
||||
tkweight=1 - dialog.vector_similarity_weight,
|
||||
vtweight=dialog.vector_similarity_weight)
|
||||
answer = re.sub(r"##[ij]\$\$", "", answer, flags=re.DOTALL)
|
||||
if not re.search(r"##[0-9]+\$\$", answer):
|
||||
answer, idx = retriever.insert_citations(answer,
|
||||
[ck["content_ltks"]
|
||||
for ck in kbinfos["chunks"]],
|
||||
[ck["vector"]
|
||||
for ck in kbinfos["chunks"]],
|
||||
embd_mdl,
|
||||
tkweight=1 - dialog.vector_similarity_weight,
|
||||
vtweight=dialog.vector_similarity_weight)
|
||||
else:
|
||||
idx = set([])
|
||||
for r in re.finditer(r"##([0-9]+)\$\$", answer):
|
||||
i = int(r.group(1))
|
||||
if i < len(kbinfos["chunks"]):
|
||||
idx.add(i)
|
||||
|
||||
idx = set([kbinfos["chunks"][int(i)]["doc_id"] for i in idx])
|
||||
recall_docs = [
|
||||
d for d in kbinfos["doc_aggs"] if d["doc_id"] in idx]
|
||||
|
@ -304,9 +291,8 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||
|
||||
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
|
||||
answer += " Please set LLM API-Key in 'User Setting -> Model providers -> API-Key'"
|
||||
# 记录完成时间
|
||||
finish_chat_ts = timer()
|
||||
# 计算各阶段耗时
|
||||
|
||||
total_time_cost = (finish_chat_ts - chat_start_ts) * 1000
|
||||
check_llm_time_cost = (check_llm_ts - chat_start_ts) * 1000
|
||||
create_retriever_time_cost = (create_retriever_ts - check_llm_ts) * 1000
|
||||
|
@ -321,34 +307,25 @@ def chat(dialog, messages, stream=True, **kwargs):
|
|||
prompt += "\n\n### Query:\n%s" % " ".join(questions)
|
||||
prompt = f"{prompt}\n\n - Total: {total_time_cost:.1f}ms\n - Check LLM: {check_llm_time_cost:.1f}ms\n - Create retriever: {create_retriever_time_cost:.1f}ms\n - Bind embedding: {bind_embedding_time_cost:.1f}ms\n - Bind LLM: {bind_llm_time_cost:.1f}ms\n - Tune question: {refine_question_time_cost:.1f}ms\n - Bind reranker: {bind_reranker_time_cost:.1f}ms\n - Generate keyword: {generate_keyword_time_cost:.1f}ms\n - Retrieval: {retrieval_time_cost:.1f}ms\n - Generate answer: {generate_result_time_cost:.1f}ms"
|
||||
return {"answer": think+answer, "reference": refs, "prompt": re.sub(r"\n", " \n", prompt), "created_at": time.time()}
|
||||
# 根据stream参数决定是流式输出还是批量输出
|
||||
|
||||
if stream:
|
||||
# 流式输出模式
|
||||
last_ans = ""
|
||||
answer = ""
|
||||
# 调用LLM的流式生成接口
|
||||
for ans in chat_mdl.chat_streamly(prompt, msg[1:], gen_conf):
|
||||
# 如果有思考过程,从回答中移除思考部分
|
||||
for ans in chat_mdl.chat_streamly(prompt+prompt4citation, msg[1:], gen_conf):
|
||||
if thought:
|
||||
ans = re.sub(r"<think>.*</think>", "", ans, flags=re.DOTALL)
|
||||
answer = ans
|
||||
# 计算增量回答(与上次输出相比新增的内容)
|
||||
delta_ans = ans[len(last_ans):]
|
||||
# 如果增量内容太少,跳过本次输出
|
||||
if num_tokens_from_string(delta_ans) < 16:
|
||||
continue
|
||||
last_ans = answer
|
||||
# 生成包含回答、引用和音频的输出
|
||||
yield {"answer": thought+answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||
delta_ans = answer[len(last_ans):]
|
||||
# 处理最后一部分增量内容
|
||||
if delta_ans:
|
||||
yield {"answer": thought+answer, "reference": {}, "audio_binary": tts(tts_mdl, delta_ans)}
|
||||
# 最后输出装饰后的完整回答(包含引用和性能指标)
|
||||
yield decorate_answer(thought+answer)
|
||||
else:
|
||||
# 非流式输出模式
|
||||
answer = chat_mdl.chat(prompt, msg[1:], gen_conf)
|
||||
answer = chat_mdl.chat(prompt+prompt4citation, msg[1:], gen_conf)
|
||||
user_content = msg[-1].get("content", "[content not available]")
|
||||
logging.debug("User: {}|Assistant: {}".format(user_content, answer))
|
||||
res = decorate_answer(answer)
|
||||
|
@ -377,6 +354,7 @@ Please write the SQL, only SQL, without any other explanations or text.
|
|||
nonlocal sys_prompt, user_prompt, question, tried_times
|
||||
sql = chat_mdl.chat(sys_prompt, [{"role": "user", "content": user_prompt}], {
|
||||
"temperature": 0.06})
|
||||
sql = re.sub(r"<think>.*</think>", "", sql, flags=re.DOTALL)
|
||||
logging.debug(f"{question} ==> {user_prompt} get SQL: {sql}")
|
||||
sql = re.sub(r"[\r\n]+", " ", sql.lower())
|
||||
sql = re.sub(r".*select ", "select ", sql.lower())
|
||||
|
@ -549,12 +527,11 @@ def ask(question, kb_ids, tenant_id):
|
|||
|
||||
if answer.lower().find("invalid key") >= 0 or answer.lower().find("invalid api") >= 0:
|
||||
answer += " Please set LLM API-Key in 'User Setting -> Model Providers -> API-Key'"
|
||||
return {"answer": answer, "reference": chunks_format(refs)}
|
||||
refs["chunks"] = chunks_format(refs)
|
||||
return {"answer": answer, "reference": refs}
|
||||
|
||||
answer = ""
|
||||
for ans in chat_mdl.chat_streamly(prompt, msg, {"temperature": 0.1}):
|
||||
answer = ans
|
||||
yield {"answer": answer, "reference": {}}
|
||||
yield decorate_answer(answer)
|
||||
|
||||
|
||||
|
|
|
@ -22,13 +22,13 @@ from concurrent.futures import ThreadPoolExecutor
|
|||
from copy import deepcopy
|
||||
from datetime import datetime
|
||||
from io import BytesIO
|
||||
import trio
|
||||
|
||||
from peewee import fn
|
||||
|
||||
from api.db.db_utils import bulk_insert_into_db
|
||||
from api import settings
|
||||
from api.utils import current_timestamp, get_format_time, get_uuid
|
||||
from graphrag.general.mind_map_extractor import MindMapExtractor
|
||||
from rag.settings import SVR_QUEUE_NAME
|
||||
from rag.utils.storage_factory import STORAGE_IMPL
|
||||
from rag.nlp import search, rag_tokenizer
|
||||
|
@ -380,12 +380,6 @@ class DocumentService(CommonService):
|
|||
@classmethod
|
||||
@DB.connection_context()
|
||||
def update_progress(cls):
|
||||
MSG = {
|
||||
"raptor": "Start RAPTOR (Recursive Abstractive Processing for Tree-Organized Retrieval).",
|
||||
"graphrag": "Entities",
|
||||
"graph_resolution": "Resolution",
|
||||
"graph_community": "Communities"
|
||||
}
|
||||
docs = cls.get_unfinished_docs()
|
||||
for d in docs:
|
||||
try:
|
||||
|
@ -396,37 +390,31 @@ class DocumentService(CommonService):
|
|||
prg = 0
|
||||
finished = True
|
||||
bad = 0
|
||||
has_raptor = False
|
||||
has_graphrag = False
|
||||
e, doc = DocumentService.get_by_id(d["id"])
|
||||
status = doc.run # TaskStatus.RUNNING.value
|
||||
for t in tsks:
|
||||
if 0 <= t.progress < 1:
|
||||
finished = False
|
||||
prg += t.progress if t.progress >= 0 else 0
|
||||
if t.progress_msg not in msg:
|
||||
msg.append(t.progress_msg)
|
||||
if t.progress == -1:
|
||||
bad += 1
|
||||
prg += t.progress if t.progress >= 0 else 0
|
||||
msg.append(t.progress_msg)
|
||||
if t.task_type == "raptor":
|
||||
has_raptor = True
|
||||
elif t.task_type == "graphrag":
|
||||
has_graphrag = True
|
||||
prg /= len(tsks)
|
||||
if finished and bad:
|
||||
prg = -1
|
||||
status = TaskStatus.FAIL.value
|
||||
elif finished:
|
||||
m = "\n".join(sorted(msg))
|
||||
if d["parser_config"].get("raptor", {}).get("use_raptor") and m.find(MSG["raptor"]) < 0:
|
||||
queue_raptor_o_graphrag_tasks(d, "raptor", MSG["raptor"])
|
||||
if d["parser_config"].get("raptor", {}).get("use_raptor") and not has_raptor:
|
||||
queue_raptor_o_graphrag_tasks(d, "raptor")
|
||||
prg = 0.98 * len(tsks) / (len(tsks) + 1)
|
||||
elif d["parser_config"].get("graphrag", {}).get("use_graphrag") and m.find(MSG["graphrag"]) < 0:
|
||||
queue_raptor_o_graphrag_tasks(d, "graphrag", MSG["graphrag"])
|
||||
prg = 0.98 * len(tsks) / (len(tsks) + 1)
|
||||
elif d["parser_config"].get("graphrag", {}).get("use_graphrag") \
|
||||
and d["parser_config"].get("graphrag", {}).get("resolution") \
|
||||
and m.find(MSG["graph_resolution"]) < 0:
|
||||
queue_raptor_o_graphrag_tasks(d, "graph_resolution", MSG["graph_resolution"])
|
||||
prg = 0.98 * len(tsks) / (len(tsks) + 1)
|
||||
elif d["parser_config"].get("graphrag", {}).get("use_graphrag") \
|
||||
and d["parser_config"].get("graphrag", {}).get("community") \
|
||||
and m.find(MSG["graph_community"]) < 0:
|
||||
queue_raptor_o_graphrag_tasks(d, "graph_community", MSG["graph_community"])
|
||||
elif d["parser_config"].get("graphrag", {}).get("use_graphrag") and not has_graphrag:
|
||||
queue_raptor_o_graphrag_tasks(d, "graphrag")
|
||||
prg = 0.98 * len(tsks) / (len(tsks) + 1)
|
||||
else:
|
||||
status = TaskStatus.DONE.value
|
||||
|
@ -463,7 +451,7 @@ class DocumentService(CommonService):
|
|||
return False
|
||||
|
||||
|
||||
def queue_raptor_o_graphrag_tasks(doc, ty, msg):
|
||||
def queue_raptor_o_graphrag_tasks(doc, ty):
|
||||
chunking_config = DocumentService.get_chunking_config(doc["id"])
|
||||
hasher = xxhash.xxh64()
|
||||
for field in sorted(chunking_config.keys()):
|
||||
|
@ -476,7 +464,8 @@ def queue_raptor_o_graphrag_tasks(doc, ty, msg):
|
|||
"doc_id": doc["id"],
|
||||
"from_page": 100000000,
|
||||
"to_page": 100000000,
|
||||
"progress_msg": datetime.now().strftime("%H:%M:%S") + " " + msg
|
||||
"task_type": ty,
|
||||
"progress_msg": datetime.now().strftime("%H:%M:%S") + " created task " + ty
|
||||
}
|
||||
|
||||
task = new_task()
|
||||
|
@ -485,7 +474,6 @@ def queue_raptor_o_graphrag_tasks(doc, ty, msg):
|
|||
hasher.update(ty.encode("utf-8"))
|
||||
task["digest"] = hasher.hexdigest()
|
||||
bulk_insert_into_db(Task, [task], True)
|
||||
task["task_type"] = ty
|
||||
assert REDIS_CONN.queue_product(SVR_QUEUE_NAME, message=task), "Can't access Redis. Please check the Redis' status."
|
||||
|
||||
|
||||
|
@ -595,10 +583,11 @@ def doc_upload_and_parse(conversation_id, file_objs, user_id):
|
|||
cks = [c for c in docs if c["doc_id"] == doc_id]
|
||||
|
||||
if parser_ids[doc_id] != ParserType.PICTURE.value:
|
||||
from graphrag.general.mind_map_extractor import MindMapExtractor
|
||||
mindmap = MindMapExtractor(llm_bdl)
|
||||
try:
|
||||
mind_map = json.dumps(mindmap([c["content_with_weight"] for c in docs if c["doc_id"] == doc_id]).output,
|
||||
ensure_ascii=False, indent=2)
|
||||
mind_map = trio.run(mindmap, [c["content_with_weight"] for c in docs if c["doc_id"] == doc_id])
|
||||
mind_map = json.dumps(mind_map.output, ensure_ascii=False, indent=2)
|
||||
if len(mind_map) < 32:
|
||||
raise Exception("Few content: " + mind_map)
|
||||
cks.append({
|
||||
|
|
|
@ -22,6 +22,42 @@ from peewee import fn
|
|||
class KnowledgebaseService(CommonService):
|
||||
model = Knowledgebase
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def is_parsed_done(cls, kb_id):
|
||||
"""
|
||||
Check if all documents in the knowledge base have completed parsing
|
||||
|
||||
Args:
|
||||
kb_id: Knowledge base ID
|
||||
|
||||
Returns:
|
||||
If all documents are parsed successfully, returns (True, None)
|
||||
If any document is not fully parsed, returns (False, error_message)
|
||||
"""
|
||||
from api.db import TaskStatus
|
||||
from api.db.services.document_service import DocumentService
|
||||
|
||||
# Get knowledge base information
|
||||
kbs = cls.query(id=kb_id)
|
||||
if not kbs:
|
||||
return False, "Knowledge base not found"
|
||||
kb = kbs[0]
|
||||
|
||||
# Get all documents in the knowledge base
|
||||
docs, _ = DocumentService.get_by_kb_id(kb_id, 1, 1000, "create_time", True, "")
|
||||
|
||||
# Check parsing status of each document
|
||||
for doc in docs:
|
||||
# If document is being parsed, don't allow chat creation
|
||||
if doc['run'] == TaskStatus.RUNNING.value or doc['run'] == TaskStatus.CANCEL.value or doc['run'] == TaskStatus.FAIL.value:
|
||||
return False, f"Document '{doc['name']}' in dataset '{kb.name}' is still being parsed. Please wait until all documents are parsed before starting a chat."
|
||||
# If document is not yet parsed and has no chunks, don't allow chat creation
|
||||
if doc['run'] == TaskStatus.UNSTART.value and doc['chunk_num'] == 0:
|
||||
return False, f"Document '{doc['name']}' in dataset '{kb.name}' has not been parsed yet. Please parse all documents before starting a chat."
|
||||
|
||||
return True, None
|
||||
|
||||
@classmethod
|
||||
@DB.connection_context()
|
||||
def list_documents_by_ids(cls,kb_ids):
|
||||
|
|
|
@ -224,7 +224,7 @@ class TenantLLMService(CommonService):
|
|||
return list(objs)
|
||||
|
||||
|
||||
class LLMBundle(object):
|
||||
class LLMBundle:
|
||||
def __init__(self, tenant_id, llm_type, llm_name=None, lang="Chinese"):
|
||||
self.tenant_id = tenant_id
|
||||
self.llm_type = llm_type
|
||||
|
|
|
@ -42,16 +42,22 @@ from api.db.init_data import init_web_data
|
|||
from api.versions import get_ragflow_version
|
||||
from api.utils import show_configs
|
||||
from rag.settings import print_rag_settings
|
||||
from rag.utils.redis_conn import RedisDistributedLock
|
||||
|
||||
stop_event = threading.Event()
|
||||
|
||||
def update_progress():
|
||||
redis_lock = RedisDistributedLock("update_progress", timeout=60)
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
if not redis_lock.acquire():
|
||||
continue
|
||||
DocumentService.update_progress()
|
||||
stop_event.wait(6)
|
||||
except Exception:
|
||||
logging.exception("update_progress exception")
|
||||
finally:
|
||||
redis_lock.release()
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
logging.info("Received interrupt signal, shutting down...")
|
||||
|
|
|
@ -335,11 +335,9 @@ def generate_confirmation_token(tenent_id):
|
|||
return "ragflow-" + serializer.dumps(get_uuid(), salt=tenent_id)[2:34]
|
||||
|
||||
|
||||
def valid(permission, valid_permission, language, valid_language, chunk_method, valid_chunk_method):
|
||||
def valid(permission, valid_permission, chunk_method, valid_chunk_method):
|
||||
if valid_parameter(permission, valid_permission):
|
||||
return valid_parameter(permission, valid_permission)
|
||||
if valid_parameter(language, valid_language):
|
||||
return valid_parameter(language, valid_language)
|
||||
if valid_parameter(chunk_method, valid_chunk_method):
|
||||
return valid_parameter(chunk_method, valid_chunk_method)
|
||||
|
||||
|
@ -373,3 +371,32 @@ def get_parser_config(chunk_method, parser_config):
|
|||
"picture": None}
|
||||
parser_config = key_mapping[chunk_method]
|
||||
return parser_config
|
||||
|
||||
|
||||
def valid_parser_config(parser_config):
|
||||
if not parser_config:
|
||||
return
|
||||
scopes = set([
|
||||
"chunk_token_num",
|
||||
"delimiter",
|
||||
"raptor",
|
||||
"graphrag",
|
||||
"layout_recognize",
|
||||
"task_page_size",
|
||||
"pages",
|
||||
"html4excel",
|
||||
"auto_keywords",
|
||||
"auto_questions",
|
||||
"tag_kb_ids",
|
||||
"topn_tags",
|
||||
"filename_embd_weight"
|
||||
])
|
||||
for k in parser_config.keys():
|
||||
assert k in scopes, f"Abnormal 'parser_config'. Invalid key: {k}"
|
||||
|
||||
assert 1 <= parser_config.get("chunk_token_num", 1) < 100000000, "chunk_token_num should be in range from 1 to 100000000"
|
||||
assert 1 <= parser_config.get("task_page_size", 1) < 100000000, "task_page_size should be in range from 1 to 100000000"
|
||||
assert 0 <= parser_config.get("auto_keywords", 0) < 32, "auto_keywords should be in range from 0 to 32"
|
||||
assert 0 <= parser_config.get("auto_questions", 0) < 10, "auto_questions should be in range from 0 to 10"
|
||||
assert 0 <= parser_config.get("topn_tags", 0) < 10, "topn_tags should be in range from 0 to 10"
|
||||
assert isinstance(parser_config.get("html4excel", False), bool), "html4excel should be True or False"
|
||||
|
|
|
@ -17,6 +17,8 @@ import base64
|
|||
import json
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
import threading
|
||||
from io import BytesIO
|
||||
|
||||
import pdfplumber
|
||||
|
@ -30,6 +32,10 @@ from api.constants import IMG_BASE64_PREFIX
|
|||
PROJECT_BASE = os.getenv("RAG_PROJECT_BASE") or os.getenv("RAG_DEPLOY_BASE")
|
||||
RAG_BASE = os.getenv("RAG_BASE")
|
||||
|
||||
LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
|
||||
if LOCK_KEY_pdfplumber not in sys.modules:
|
||||
sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
|
||||
|
||||
|
||||
def get_project_base_directory(*args):
|
||||
global PROJECT_BASE
|
||||
|
@ -175,19 +181,20 @@ def thumbnail_img(filename, blob):
|
|||
"""
|
||||
filename = filename.lower()
|
||||
if re.match(r".*\.pdf$", filename):
|
||||
pdf = pdfplumber.open(BytesIO(blob))
|
||||
buffered = BytesIO()
|
||||
resolution = 32
|
||||
img = None
|
||||
for _ in range(10):
|
||||
# https://github.com/jsvine/pdfplumber?tab=readme-ov-file#creating-a-pageimage-with-to_image
|
||||
pdf.pages[0].to_image(resolution=resolution).annotated.save(buffered, format="png")
|
||||
img = buffered.getvalue()
|
||||
if len(img) >= 64000 and resolution >= 2:
|
||||
resolution = resolution / 2
|
||||
buffered = BytesIO()
|
||||
else:
|
||||
break
|
||||
with sys.modules[LOCK_KEY_pdfplumber]:
|
||||
pdf = pdfplumber.open(BytesIO(blob))
|
||||
buffered = BytesIO()
|
||||
resolution = 32
|
||||
img = None
|
||||
for _ in range(10):
|
||||
# https://github.com/jsvine/pdfplumber?tab=readme-ov-file#creating-a-pageimage-with-to_image
|
||||
pdf.pages[0].to_image(resolution=resolution).annotated.save(buffered, format="png")
|
||||
img = buffered.getvalue()
|
||||
if len(img) >= 64000 and resolution >= 2:
|
||||
resolution = resolution / 2
|
||||
buffered = BytesIO()
|
||||
else:
|
||||
break
|
||||
pdf.close()
|
||||
return img
|
||||
|
||||
|
|
|
@ -18,6 +18,8 @@ import os.path
|
|||
import logging
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
initialized_root_logger = False
|
||||
|
||||
def get_project_base_directory():
|
||||
PROJECT_BASE = os.path.abspath(
|
||||
os.path.join(
|
||||
|
@ -29,10 +31,13 @@ def get_project_base_directory():
|
|||
return PROJECT_BASE
|
||||
|
||||
def initRootLogger(logfile_basename: str, log_format: str = "%(asctime)-15s %(levelname)-8s %(process)d %(message)s"):
|
||||
logger = logging.getLogger()
|
||||
if logger.hasHandlers():
|
||||
global initialized_root_logger
|
||||
if initialized_root_logger:
|
||||
return
|
||||
initialized_root_logger = True
|
||||
|
||||
logger = logging.getLogger()
|
||||
logger.handlers.clear()
|
||||
log_path = os.path.abspath(os.path.join(get_project_base_directory(), "logs", f"{logfile_basename}.log"))
|
||||
|
||||
os.makedirs(os.path.dirname(log_path), exist_ok=True)
|
||||
|
|
BIN
assets/group.jpg
BIN
assets/group.jpg
Binary file not shown.
Before Width: | Height: | Size: 248 KiB After Width: | Height: | Size: 160 KiB |
|
@ -9,10 +9,10 @@
|
|||
"title_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
||||
"title_sm_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
||||
"name_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
||||
"important_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
||||
"tag_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
||||
"important_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
|
||||
"tag_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
|
||||
"important_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
||||
"question_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
||||
"question_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
|
||||
"question_tks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
||||
"content_with_weight": {"type": "varchar", "default": ""},
|
||||
"content_ltks": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
||||
|
@ -28,7 +28,7 @@
|
|||
"rank_flt": {"type": "float", "default": 0},
|
||||
"available_int": {"type": "integer", "default": 1},
|
||||
"knowledge_graph_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
||||
"entities_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace"},
|
||||
"entities_kwd": {"type": "varchar", "default": "", "analyzer": "whitespace-#"},
|
||||
"pagerank_fea": {"type": "integer", "default": 0},
|
||||
"tag_feas": {"type": "varchar", "default": ""},
|
||||
|
||||
|
|
|
@ -134,6 +134,18 @@
|
|||
"max_tokens": 32768,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "qwq-32b",
|
||||
"tags": "LLM,CHAT,128k",
|
||||
"max_tokens": 131072,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "qwq-plus",
|
||||
"tags": "LLM,CHAT,128k",
|
||||
"max_tokens": 131072,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "qwen-long",
|
||||
"tags": "LLM,CHAT,10000K",
|
||||
|
@ -663,80 +675,86 @@
|
|||
{
|
||||
"name": "Mistral",
|
||||
"logo": "",
|
||||
"tags": "LLM,TEXT EMBEDDING",
|
||||
"tags": "LLM,TEXT EMBEDDING,MODERATION",
|
||||
"status": "1",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "open-mixtral-8x22b",
|
||||
"tags": "LLM,CHAT,64k",
|
||||
"max_tokens": 64000,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "open-mixtral-8x7b",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32000,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "open-mistral-7b",
|
||||
"tags": "LLM,CHAT,128k",
|
||||
"max_tokens": 128000,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "ministral-8b-latest",
|
||||
"tags": "LLM,CHAT,128k",
|
||||
"max_tokens": 128000,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "ministral-3b-latest",
|
||||
"tags": "LLM,CHAT,128k",
|
||||
"max_tokens": 128000,
|
||||
"llm_name": "codestral-latest",
|
||||
"tags": "LLM,CHAT,256k",
|
||||
"max_tokens": 256000,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "mistral-large-latest",
|
||||
"tags": "LLM,CHAT,128k",
|
||||
"max_tokens": 128000,
|
||||
"tags": "LLM,CHAT,131k",
|
||||
"max_tokens": 131000,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "mistral-small-latest",
|
||||
"tags": "LLM,CHAT,128k",
|
||||
"max_tokens": 128000,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "codestral-latest",
|
||||
"llm_name": "mistral-saba-latest",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32000,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "mistral-nemo",
|
||||
"tags": "LLM,CHAT,128k",
|
||||
"max_tokens": 128000,
|
||||
"llm_name": "pixtral-large-latest",
|
||||
"tags": "LLM,CHAT,IMAGE2TEXT,131k",
|
||||
"max_tokens": 131000,
|
||||
"model_type": "image2text"
|
||||
},
|
||||
{
|
||||
"llm_name": "ministral-3b-latest",
|
||||
"tags": "LLM,CHAT,131k",
|
||||
"max_tokens": 131000,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "ministral-8b-latest",
|
||||
"tags": "LLM,CHAT,131k",
|
||||
"max_tokens": 131000,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "mistral-embed",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"tags": "TEXT EMBEDDING,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "embedding"
|
||||
},
|
||||
{
|
||||
"llm_name": "pixtral-large-latest",
|
||||
"llm_name": "mistral-moderation-latest",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "mistral-small-latest",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32000,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "pixtral-12b-2409",
|
||||
"tags": "LLM,IMAGE2TEXT,131k",
|
||||
"max_tokens": 131000,
|
||||
"model_type": "image2text"
|
||||
},
|
||||
{
|
||||
"llm_name": "pixtral-12b",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32000,
|
||||
"llm_name": "mistral-ocr-latest",
|
||||
"tags": "LLM,IMAGE2TEXT,131k",
|
||||
"max_tokens": 131000,
|
||||
"model_type": "image2text"
|
||||
},
|
||||
{
|
||||
"llm_name": "open-mistral-nemo",
|
||||
"tags": "LLM,CHAT,131k",
|
||||
"max_tokens": 131000,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "open-codestral-mamba",
|
||||
"tags": "LLM,CHAT,256k",
|
||||
"max_tokens": 256000,
|
||||
"model_type": "chat"
|
||||
}
|
||||
]
|
||||
},
|
||||
|
@ -2292,11 +2310,83 @@
|
|||
{
|
||||
"name": "novita.ai",
|
||||
"logo": "",
|
||||
"tags": "LLM",
|
||||
"tags": "LLM,IMAGE2TEXT",
|
||||
"status": "1",
|
||||
"llm": [
|
||||
{
|
||||
"llm_name": "meta-llama/llama-3-8b-instruct",
|
||||
"llm_name": "deepseek/deepseek-r1",
|
||||
"tags": "LLM,CHAT,64k",
|
||||
"max_tokens": 64000,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "deepseek/deepseek_v3",
|
||||
"tags": "LLM,CHAT,64k",
|
||||
"max_tokens": 64000,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "deepseek/deepseek-r1-distill-llama-70b",
|
||||
"tags": "LLM,CHAT,64k",
|
||||
"max_tokens": 64000,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "deepseek/deepseek-r1-distill-qwen-32b",
|
||||
"tags": "LLM,CHAT,64k",
|
||||
"max_tokens": 64000,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "deepseek/deepseek-r1-distill-qwen-14b",
|
||||
"tags": "LLM,CHAT,64k",
|
||||
"max_tokens": 64000,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "deepseek/deepseek-r1-distill-llama-8b",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32000,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "meta-llama/llama-3.3-70b-instruct",
|
||||
"tags": "LLM,CHAT,128k",
|
||||
"max_tokens": 131072,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "meta-llama/llama-3.2-11b-vision-instruct",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32768,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "meta-llama/llama-3.2-3b-instruct",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32768,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "meta-llama/llama-3.2-1b-instruct",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32768,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "meta-llama/llama-3.1-70b-instruct",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32768,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "meta-llama/llama-3.1-8b-instruct",
|
||||
"tags": "LLM,CHAT,16k",
|
||||
"max_tokens": 16384,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "meta-llama/llama-3.1-8b-instruct-bf16",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat"
|
||||
|
@ -2307,58 +2397,34 @@
|
|||
"max_tokens": 8192,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "meta-llama/llama-3-8b-instruct",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "qwen/qwen-2.5-72b-instruct",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32000,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "qwen/qwen-2-vl-72b-instruct",
|
||||
"tags": "LLM,IMAGE2TEXT,32k",
|
||||
"max_tokens": 32768,
|
||||
"model_type": "image2text"
|
||||
},
|
||||
{
|
||||
"llm_name": "qwen/qwen-2-7b-instruct",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32768,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "mistralai/mistral-nemo",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32768,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "microsoft/wizardlm-2-7b",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32768,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "openchat/openchat-7b",
|
||||
"tags": "LLM,CHAT,4k",
|
||||
"max_tokens": 4096,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "meta-llama/llama-3.1-8b-instruct",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "meta-llama/llama-3.1-70b-instruct",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "meta-llama/llama-3.1-405b-instruct",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32768,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "google/gemma-2-9b-it",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "jondurbin/airoboros-l2-70b",
|
||||
"tags": "LLM,CHAT,4k",
|
||||
"max_tokens": 4096,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "nousresearch/hermes-2-pro-llama-3-8b",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"tags": "LLM,CHAT,128k",
|
||||
"max_tokens": 131072,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
|
@ -2368,19 +2434,43 @@
|
|||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "cognitivecomputations/dolphin-mixtral-8x22b",
|
||||
"tags": "LLM,CHAT,15k",
|
||||
"max_tokens": 16000,
|
||||
"llm_name": "Sao10K/L3-8B-Stheno-v3.2",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "sao10k/l3-70b-euryale-v2.1",
|
||||
"tags": "LLM,CHAT,15k",
|
||||
"max_tokens": 16000,
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "sophosympatheia/midnight-rose-70b",
|
||||
"llm_name": "sao10k/l3-8b-lunaris",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "sao10k/l31-70b-euryale-v2.2",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "nousresearch/hermes-2-pro-llama-3-8b",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "nousresearch/nous-hermes-llama2-13b",
|
||||
"tags": "LLM,CHAT,4k",
|
||||
"max_tokens": 4096,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "openchat/openchat-7b",
|
||||
"tags": "LLM,CHAT,4k",
|
||||
"max_tokens": 4096,
|
||||
"model_type": "chat"
|
||||
|
@ -2392,19 +2482,25 @@
|
|||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "nousresearch/nous-hermes-llama2-13b",
|
||||
"tags": "LLM,CHAT,4k",
|
||||
"max_tokens": 4096,
|
||||
"llm_name": "microsoft/wizardlm-2-8x22b",
|
||||
"tags": "LLM,CHAT,65k",
|
||||
"max_tokens": 65535,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "Nous-Hermes-2-Mixtral-8x7B-DPO",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32768,
|
||||
"llm_name": "google/gemma-2-9b-it",
|
||||
"tags": "LLM,CHAT,8k",
|
||||
"max_tokens": 8192,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "lzlv_70b",
|
||||
"llm_name": "cognitivecomputations/dolphin-mixtral-8x22b",
|
||||
"tags": "LLM,CHAT,16k",
|
||||
"max_tokens": 16000,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "jondurbin/airoboros-l2-70b",
|
||||
"tags": "LLM,CHAT,4k",
|
||||
"max_tokens": 4096,
|
||||
"model_type": "chat"
|
||||
|
@ -2416,9 +2512,9 @@
|
|||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "microsoft/wizardlm-2-8x22b",
|
||||
"tags": "LLM,CHAT,64k",
|
||||
"max_tokens": 65535,
|
||||
"llm_name": "sophosympatheia/midnight-rose-70b",
|
||||
"tags": "LLM,CHAT,4k",
|
||||
"max_tokens": 4096,
|
||||
"model_type": "chat"
|
||||
}
|
||||
]
|
||||
|
@ -2513,6 +2609,12 @@
|
|||
"max_tokens": 32768,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "Qwen/QwQ-32B",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
"max_tokens": 32768,
|
||||
"model_type": "chat"
|
||||
},
|
||||
{
|
||||
"llm_name": "Qwen/Qwen2.5-Coder-32B-Instruct",
|
||||
"tags": "LLM,CHAT,32k",
|
||||
|
@ -3169,7 +3271,7 @@
|
|||
"tags": "TEXT EMBEDDING,32000",
|
||||
"max_tokens": 32000,
|
||||
"model_type": "embedding"
|
||||
},
|
||||
},
|
||||
{
|
||||
"llm_name": "rerank-1",
|
||||
"tags": "RE-RANK, 8000",
|
||||
|
@ -3206,7 +3308,7 @@
|
|||
{
|
||||
"name": "HuggingFace",
|
||||
"logo": "",
|
||||
"tags": "TEXT EMBEDDING",
|
||||
"tags": "TEXT EMBEDDING,TEXT RE-RANK",
|
||||
"status": "1",
|
||||
"llm": []
|
||||
},
|
||||
|
|
|
@ -113,4 +113,4 @@ PDF、DOCX、EXCEL和PPT四种文档格式都有相应的解析器。最复杂
|
|||
|
||||
### 简历
|
||||
|
||||
简历是一种非常复杂的文件。一份由各种布局的非结构化文本组成的简历可以分解为由近百个字段组成的结构化数据。我们还没有打开解析器,因为我们在解析过程之后打开了处理方法。
|
||||
简历是一种非常复杂的文档。由各种格式的非结构化文本构成的简历可以被解析为包含近百个字段的结构化数据。我们还没有启用解析器,因为在解析过程之后才会启动处理方法。
|
||||
|
|
|
@ -11,52 +11,68 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
|
||||
from openpyxl import load_workbook, Workbook
|
||||
import logging
|
||||
import sys
|
||||
from io import BytesIO
|
||||
|
||||
from rag.nlp import find_codec
|
||||
|
||||
import pandas as pd
|
||||
from openpyxl import Workbook, load_workbook
|
||||
|
||||
from rag.nlp import find_codec
|
||||
|
||||
|
||||
class RAGFlowExcelParser:
|
||||
def html(self, fnm, chunk_rows=256):
|
||||
|
||||
# if isinstance(fnm, str):
|
||||
# wb = load_workbook(fnm)
|
||||
# else:
|
||||
# wb = load_workbook(BytesIO(fnm))++
|
||||
@staticmethod
|
||||
def _load_excel_to_workbook(file_like_object):
|
||||
if isinstance(file_like_object, bytes):
|
||||
file_like_object = BytesIO(file_like_object)
|
||||
|
||||
s_fnm = fnm
|
||||
if not isinstance(fnm, str):
|
||||
s_fnm = BytesIO(fnm)
|
||||
else:
|
||||
pass
|
||||
# Read first 4 bytes to determine file type
|
||||
file_like_object.seek(0)
|
||||
file_head = file_like_object.read(4)
|
||||
file_like_object.seek(0)
|
||||
|
||||
if not (file_head.startswith(b'PK\x03\x04') or file_head.startswith(b'\xD0\xCF\x11\xE0')):
|
||||
logging.info("****wxy: Not an Excel file, converting CSV to Excel Workbook")
|
||||
|
||||
try:
|
||||
file_like_object.seek(0)
|
||||
df = pd.read_csv(file_like_object)
|
||||
return RAGFlowExcelParser._dataframe_to_workbook(df)
|
||||
|
||||
except Exception as e_csv:
|
||||
raise Exception(f"****wxy: Failed to parse CSV and convert to Excel Workbook: {e_csv}")
|
||||
|
||||
try:
|
||||
wb = load_workbook(s_fnm)
|
||||
return load_workbook(file_like_object)
|
||||
except Exception as e:
|
||||
print(f'****wxy: file parser error: {e}, s_fnm={s_fnm}, trying convert files')
|
||||
df = pd.read_excel(s_fnm)
|
||||
wb = Workbook()
|
||||
# if len(wb.worksheets) > 0:
|
||||
# del wb.worksheets[0]
|
||||
# else: pass
|
||||
ws = wb.active
|
||||
ws.title = "Data"
|
||||
for col_num, column_name in enumerate(df.columns, 1):
|
||||
ws.cell(row=1, column=col_num, value=column_name)
|
||||
else:
|
||||
pass
|
||||
for row_num, row in enumerate(df.values, 2):
|
||||
for col_num, value in enumerate(row, 1):
|
||||
ws.cell(row=row_num, column=col_num, value=value)
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
logging.info(f"****wxy: openpyxl load error: {e}, try pandas instead")
|
||||
try:
|
||||
file_like_object.seek(0)
|
||||
df = pd.read_excel(file_like_object)
|
||||
return RAGFlowExcelParser._dataframe_to_workbook(df)
|
||||
except Exception as e_pandas:
|
||||
raise Exception(f"****wxy: pandas.read_excel error: {e_pandas}, original openpyxl error: {e}")
|
||||
|
||||
@staticmethod
|
||||
def _dataframe_to_workbook(df):
|
||||
wb = Workbook()
|
||||
ws = wb.active
|
||||
ws.title = "Data"
|
||||
|
||||
for col_num, column_name in enumerate(df.columns, 1):
|
||||
ws.cell(row=1, column=col_num, value=column_name)
|
||||
|
||||
for row_num, row in enumerate(df.values, 2):
|
||||
for col_num, value in enumerate(row, 1):
|
||||
ws.cell(row=row_num, column=col_num, value=value)
|
||||
|
||||
return wb
|
||||
|
||||
def html(self, fnm, chunk_rows=256):
|
||||
file_like_object = BytesIO(fnm) if not isinstance(fnm, str) else fnm
|
||||
wb = RAGFlowExcelParser._load_excel_to_workbook(file_like_object)
|
||||
tb_chunks = []
|
||||
for sheetname in wb.sheetnames:
|
||||
ws = wb[sheetname]
|
||||
|
@ -74,7 +90,7 @@ class RAGFlowExcelParser:
|
|||
tb += f"<table><caption>{sheetname}</caption>"
|
||||
tb += tb_rows_0
|
||||
for r in list(
|
||||
rows[1 + chunk_i * chunk_rows: 1 + (chunk_i + 1) * chunk_rows]
|
||||
rows[1 + chunk_i * chunk_rows: 1 + (chunk_i + 1) * chunk_rows]
|
||||
):
|
||||
tb += "<tr>"
|
||||
for i, c in enumerate(r):
|
||||
|
@ -89,40 +105,8 @@ class RAGFlowExcelParser:
|
|||
return tb_chunks
|
||||
|
||||
def __call__(self, fnm):
|
||||
# if isinstance(fnm, str):
|
||||
# wb = load_workbook(fnm)
|
||||
# else:
|
||||
# wb = load_workbook(BytesIO(fnm))
|
||||
|
||||
s_fnm = fnm
|
||||
if not isinstance(fnm, str):
|
||||
s_fnm = BytesIO(fnm)
|
||||
else:
|
||||
pass
|
||||
|
||||
try:
|
||||
wb = load_workbook(s_fnm)
|
||||
except Exception as e:
|
||||
print(f'****wxy: file parser error: {e}, s_fnm={s_fnm}, trying convert files')
|
||||
df = pd.read_excel(s_fnm)
|
||||
wb = Workbook()
|
||||
if len(wb.worksheets) > 0:
|
||||
del wb.worksheets[0]
|
||||
else:
|
||||
pass
|
||||
ws = wb.active
|
||||
ws.title = "Data"
|
||||
for col_num, column_name in enumerate(df.columns, 1):
|
||||
ws.cell(row=1, column=col_num, value=column_name)
|
||||
else:
|
||||
pass
|
||||
for row_num, row in enumerate(df.values, 2):
|
||||
for col_num, value in enumerate(row, 1):
|
||||
ws.cell(row=row_num, column=col_num, value=value)
|
||||
else:
|
||||
pass
|
||||
else:
|
||||
pass
|
||||
file_like_object = BytesIO(fnm) if not isinstance(fnm, str) else fnm
|
||||
wb = RAGFlowExcelParser._load_excel_to_workbook(file_like_object)
|
||||
|
||||
res = []
|
||||
for sheetname in wb.sheetnames:
|
||||
|
@ -148,7 +132,7 @@ class RAGFlowExcelParser:
|
|||
@staticmethod
|
||||
def row_number(fnm, binary):
|
||||
if fnm.split(".")[-1].lower().find("xls") >= 0:
|
||||
wb = load_workbook(BytesIO(binary))
|
||||
wb = RAGFlowExcelParser._load_excel_to_workbook(BytesIO(binary))
|
||||
total = 0
|
||||
for sheetname in wb.sheetnames:
|
||||
ws = wb[sheetname]
|
||||
|
@ -164,4 +148,3 @@ class RAGFlowExcelParser:
|
|||
if __name__ == "__main__":
|
||||
psr = RAGFlowExcelParser()
|
||||
psr(sys.argv[1])
|
||||
|
||||
|
|
|
@ -22,27 +22,56 @@ class RAGFlowMarkdownParser:
|
|||
self.chunk_token_num = int(chunk_token_num)
|
||||
|
||||
def extract_tables_and_remainder(self, markdown_text):
|
||||
# Standard Markdown table
|
||||
table_pattern = re.compile(
|
||||
r'''
|
||||
(?:\n|^)
|
||||
(?:\|.*?\|.*?\|.*?\n)
|
||||
(?:\|(?:\s*[:-]+[-| :]*\s*)\|.*?\n)
|
||||
(?:\|.*?\|.*?\|.*?\n)+
|
||||
tables = []
|
||||
remainder = markdown_text
|
||||
if "|" in markdown_text: # for optimize performance
|
||||
# Standard Markdown table
|
||||
border_table_pattern = re.compile(
|
||||
r'''
|
||||
(?:\n|^)
|
||||
(?:\|.*?\|.*?\|.*?\n)
|
||||
(?:\|(?:\s*[:-]+[-| :]*\s*)\|.*?\n)
|
||||
(?:\|.*?\|.*?\|.*?\n)+
|
||||
''', re.VERBOSE)
|
||||
tables = table_pattern.findall(markdown_text)
|
||||
remainder = table_pattern.sub('', markdown_text)
|
||||
border_tables = border_table_pattern.findall(markdown_text)
|
||||
tables.extend(border_tables)
|
||||
remainder = border_table_pattern.sub('', remainder)
|
||||
|
||||
# Borderless Markdown table
|
||||
no_border_table_pattern = re.compile(
|
||||
# Borderless Markdown table
|
||||
no_border_table_pattern = re.compile(
|
||||
r'''
|
||||
(?:\n|^)
|
||||
(?:\S.*?\|.*?\n)
|
||||
(?:(?:\s*[:-]+[-| :]*\s*).*?\n)
|
||||
(?:\S.*?\|.*?\n)+
|
||||
''', re.VERBOSE)
|
||||
no_border_tables = no_border_table_pattern.findall(remainder)
|
||||
tables.extend(no_border_tables)
|
||||
remainder = no_border_table_pattern.sub('', remainder)
|
||||
|
||||
if "<table>" in remainder.lower(): # for optimize performance
|
||||
#HTML table extraction - handle possible html/body wrapper tags
|
||||
html_table_pattern = re.compile(
|
||||
r'''
|
||||
(?:\n|^)
|
||||
(?:\S.*?\|.*?\n)
|
||||
(?:(?:\s*[:-]+[-| :]*\s*).*?\n)
|
||||
(?:\S.*?\|.*?\n)+
|
||||
''', re.VERBOSE)
|
||||
no_border_tables = no_border_table_pattern.findall(remainder)
|
||||
tables.extend(no_border_tables)
|
||||
remainder = no_border_table_pattern.sub('', remainder)
|
||||
(?:\n|^)
|
||||
\s*
|
||||
(?:
|
||||
# case1: <html><body><table>...</table></body></html>
|
||||
(?:<html[^>]*>\s*<body[^>]*>\s*<table[^>]*>.*?</table>\s*</body>\s*</html>)
|
||||
|
|
||||
# case2: <body><table>...</table></body>
|
||||
(?:<body[^>]*>\s*<table[^>]*>.*?</table>\s*</body>)
|
||||
|
|
||||
# case3: only<table>...</table>
|
||||
(?:<table[^>]*>.*?</table>)
|
||||
)
|
||||
\s*
|
||||
(?=\n|$)
|
||||
''',
|
||||
re.VERBOSE | re.DOTALL | re.IGNORECASE
|
||||
)
|
||||
html_tables = html_table_pattern.findall(remainder)
|
||||
tables.extend(html_tables)
|
||||
remainder = html_table_pattern.sub('', remainder)
|
||||
|
||||
return remainder, tables
|
||||
|
|
|
@ -18,6 +18,8 @@ import logging
|
|||
import os
|
||||
import random
|
||||
from timeit import default_timer as timer
|
||||
import sys
|
||||
import threading
|
||||
|
||||
import xgboost as xgb
|
||||
from io import BytesIO
|
||||
|
@ -34,8 +36,23 @@ from rag.nlp import rag_tokenizer
|
|||
from copy import deepcopy
|
||||
from huggingface_hub import snapshot_download
|
||||
|
||||
LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
|
||||
if LOCK_KEY_pdfplumber not in sys.modules:
|
||||
sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
|
||||
|
||||
class RAGFlowPdfParser:
|
||||
def __init__(self):
|
||||
"""
|
||||
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
||||
|
||||
For Linux:
|
||||
export HF_ENDPOINT=https://hf-mirror.com
|
||||
|
||||
For Windows:
|
||||
Good luck
|
||||
^_-
|
||||
|
||||
"""
|
||||
self.ocr = OCR()
|
||||
if hasattr(self, "model_speciess"):
|
||||
self.layouter = LayoutRecognizer("layout." + self.model_speciess)
|
||||
|
@ -66,17 +83,6 @@ class RAGFlowPdfParser:
|
|||
model_dir, "updown_concat_xgb.model"))
|
||||
|
||||
self.page_from = 0
|
||||
"""
|
||||
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
||||
|
||||
For Linux:
|
||||
export HF_ENDPOINT=https://hf-mirror.com
|
||||
|
||||
For Windows:
|
||||
Good luck
|
||||
^_-
|
||||
|
||||
"""
|
||||
|
||||
def __char_width(self, c):
|
||||
return (c["x1"] - c["x0"]) // max(len(c["text"]), 1)
|
||||
|
@ -948,8 +954,9 @@ class RAGFlowPdfParser:
|
|||
@staticmethod
|
||||
def total_page_number(fnm, binary=None):
|
||||
try:
|
||||
pdf = pdfplumber.open(
|
||||
fnm) if not binary else pdfplumber.open(BytesIO(binary))
|
||||
with sys.modules[LOCK_KEY_pdfplumber]:
|
||||
pdf = pdfplumber.open(
|
||||
fnm) if not binary else pdfplumber.open(BytesIO(binary))
|
||||
total_page = len(pdf.pages)
|
||||
pdf.close()
|
||||
return total_page
|
||||
|
@ -968,17 +975,18 @@ class RAGFlowPdfParser:
|
|||
self.page_from = page_from
|
||||
start = timer()
|
||||
try:
|
||||
self.pdf = pdfplumber.open(fnm) if isinstance(
|
||||
fnm, str) else pdfplumber.open(BytesIO(fnm))
|
||||
self.page_images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
|
||||
enumerate(self.pdf.pages[page_from:page_to])]
|
||||
try:
|
||||
self.page_chars = [[c for c in page.dedupe_chars().chars if self._has_color(c)] for page in self.pdf.pages[page_from:page_to]]
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to extract characters for pages {page_from}-{page_to}: {str(e)}")
|
||||
self.page_chars = [[] for _ in range(page_to - page_from)] # If failed to extract, using empty list instead.
|
||||
|
||||
self.total_page = len(self.pdf.pages)
|
||||
with sys.modules[LOCK_KEY_pdfplumber]:
|
||||
self.pdf = pdfplumber.open(fnm) if isinstance(
|
||||
fnm, str) else pdfplumber.open(BytesIO(fnm))
|
||||
self.page_images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
|
||||
enumerate(self.pdf.pages[page_from:page_to])]
|
||||
try:
|
||||
self.page_chars = [[c for c in page.dedupe_chars().chars if self._has_color(c)] for page in self.pdf.pages[page_from:page_to]]
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to extract characters for pages {page_from}-{page_to}: {str(e)}")
|
||||
self.page_chars = [[] for _ in range(page_to - page_from)] # If failed to extract, using empty list instead.
|
||||
|
||||
self.total_page = len(self.pdf.pages)
|
||||
except Exception:
|
||||
logging.exception("RAGFlowPdfParser __images__")
|
||||
logging.info(f"__images__ dedupe_chars cost {timer() - start}s")
|
||||
|
@ -1162,7 +1170,7 @@ class RAGFlowPdfParser:
|
|||
return poss
|
||||
|
||||
|
||||
class PlainParser(object):
|
||||
class PlainParser:
|
||||
def __call__(self, filename, from_page=0, to_page=100000, **kwargs):
|
||||
self.outlines = []
|
||||
lines = []
|
||||
|
|
|
@ -19,7 +19,7 @@ from io import BytesIO
|
|||
from pptx import Presentation
|
||||
|
||||
|
||||
class RAGFlowPptParser(object):
|
||||
class RAGFlowPptParser:
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
|
|
@ -30,10 +30,10 @@ GOODS = pd.read_csv(
|
|||
GOODS["cid"] = GOODS["cid"].astype(str)
|
||||
GOODS = GOODS.set_index(["cid"])
|
||||
CORP_TKS = json.load(
|
||||
open(os.path.join(current_file_path, "res/corp.tks.freq.json"), "r")
|
||||
open(os.path.join(current_file_path, "res/corp.tks.freq.json"), "r",encoding="utf-8")
|
||||
)
|
||||
GOOD_CORP = json.load(open(os.path.join(current_file_path, "res/good_corp.json"), "r"))
|
||||
CORP_TAG = json.load(open(os.path.join(current_file_path, "res/corp_tag.json"), "r"))
|
||||
GOOD_CORP = json.load(open(os.path.join(current_file_path, "res/good_corp.json"), "r",encoding="utf-8"))
|
||||
CORP_TAG = json.load(open(os.path.join(current_file_path, "res/corp_tag.json"), "r",encoding="utf-8"))
|
||||
|
||||
|
||||
def baike(cid, default_v=0):
|
||||
|
|
|
@ -25,7 +25,7 @@ TBL = pd.read_csv(
|
|||
os.path.join(current_file_path, "res/schools.csv"), sep="\t", header=0
|
||||
).fillna("")
|
||||
TBL["name_en"] = TBL["name_en"].map(lambda x: x.lower().strip())
|
||||
GOOD_SCH = json.load(open(os.path.join(current_file_path, "res/good_sch.json"), "r"))
|
||||
GOOD_SCH = json.load(open(os.path.join(current_file_path, "res/good_sch.json"), "r",encoding="utf-8"))
|
||||
GOOD_SCH = set([re.sub(r"[,. &()()]+", "", c) for c in GOOD_SCH])
|
||||
|
||||
|
||||
|
|
|
@ -31,6 +31,7 @@ class RAGFlowTxtParser:
|
|||
raise TypeError("txt type should be str!")
|
||||
cks = [""]
|
||||
tk_nums = [0]
|
||||
delimiter = delimiter.encode('utf-8').decode('unicode_escape').encode('latin1').decode('utf-8')
|
||||
|
||||
def add_chunk(t):
|
||||
nonlocal cks, tk_nums, delimiter
|
||||
|
|
|
@ -14,7 +14,8 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
import io
|
||||
|
||||
import sys
|
||||
import threading
|
||||
import pdfplumber
|
||||
|
||||
from .ocr import OCR
|
||||
|
@ -23,6 +24,11 @@ from .layout_recognizer import LayoutRecognizer4YOLOv10 as LayoutRecognizer
|
|||
from .table_structure_recognizer import TableStructureRecognizer
|
||||
|
||||
|
||||
LOCK_KEY_pdfplumber = "global_shared_lock_pdfplumber"
|
||||
if LOCK_KEY_pdfplumber not in sys.modules:
|
||||
sys.modules[LOCK_KEY_pdfplumber] = threading.Lock()
|
||||
|
||||
|
||||
def init_in_out(args):
|
||||
from PIL import Image
|
||||
import os
|
||||
|
@ -36,9 +42,10 @@ def init_in_out(args):
|
|||
|
||||
def pdf_pages(fnm, zoomin=3):
|
||||
nonlocal outputs, images
|
||||
pdf = pdfplumber.open(fnm)
|
||||
images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
|
||||
enumerate(pdf.pages)]
|
||||
with sys.modules[LOCK_KEY_pdfplumber]:
|
||||
pdf = pdfplumber.open(fnm)
|
||||
images = [p.to_image(resolution=72 * zoomin).annotated for i, p in
|
||||
enumerate(pdf.pages)]
|
||||
|
||||
for i, page in enumerate(images):
|
||||
outputs.append(os.path.split(fnm)[-1] + f"_{i}.jpg")
|
||||
|
|
|
@ -122,7 +122,7 @@ def load_model(model_dir, nm):
|
|||
return loaded_model
|
||||
|
||||
|
||||
class TextRecognizer(object):
|
||||
class TextRecognizer:
|
||||
def __init__(self, model_dir):
|
||||
self.rec_image_shape = [int(v) for v in "3, 48, 320".split(",")]
|
||||
self.rec_batch_num = 16
|
||||
|
@ -393,7 +393,7 @@ class TextRecognizer(object):
|
|||
return rec_res, time.time() - st
|
||||
|
||||
|
||||
class TextDetector(object):
|
||||
class TextDetector:
|
||||
def __init__(self, model_dir):
|
||||
pre_process_list = [{
|
||||
'DetResizeForTest': {
|
||||
|
@ -506,7 +506,7 @@ class TextDetector(object):
|
|||
return dt_boxes, time.time() - st
|
||||
|
||||
|
||||
class OCR(object):
|
||||
class OCR:
|
||||
def __init__(self, model_dir=None):
|
||||
"""
|
||||
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -23,7 +23,7 @@ import math
|
|||
from PIL import Image
|
||||
|
||||
|
||||
class DecodeImage(object):
|
||||
class DecodeImage:
|
||||
""" decode image """
|
||||
|
||||
def __init__(self,
|
||||
|
@ -65,7 +65,7 @@ class DecodeImage(object):
|
|||
return data
|
||||
|
||||
|
||||
class StandardizeImage(object):
|
||||
class StandardizeImag:
|
||||
"""normalize image
|
||||
Args:
|
||||
mean (list): im - mean
|
||||
|
@ -102,7 +102,7 @@ class StandardizeImage(object):
|
|||
return im, im_info
|
||||
|
||||
|
||||
class NormalizeImage(object):
|
||||
class NormalizeImage:
|
||||
""" normalize image such as subtract mean, divide std
|
||||
"""
|
||||
|
||||
|
@ -129,7 +129,7 @@ class NormalizeImage(object):
|
|||
return data
|
||||
|
||||
|
||||
class ToCHWImage(object):
|
||||
class ToCHWImage:
|
||||
""" convert hwc image to chw image
|
||||
"""
|
||||
|
||||
|
@ -145,7 +145,7 @@ class ToCHWImage(object):
|
|||
return data
|
||||
|
||||
|
||||
class KeepKeys(object):
|
||||
class KeepKeys:
|
||||
def __init__(self, keep_keys, **kwargs):
|
||||
self.keep_keys = keep_keys
|
||||
|
||||
|
@ -156,7 +156,7 @@ class KeepKeys(object):
|
|||
return data_list
|
||||
|
||||
|
||||
class Pad(object):
|
||||
class Pad:
|
||||
def __init__(self, size=None, size_div=32, **kwargs):
|
||||
if size is not None and not isinstance(size, (int, list, tuple)):
|
||||
raise TypeError("Type of target_size is invalid. Now is {}".format(
|
||||
|
@ -194,7 +194,7 @@ class Pad(object):
|
|||
return data
|
||||
|
||||
|
||||
class LinearResize(object):
|
||||
class LinearResize:
|
||||
"""resize image by target_size and max_size
|
||||
Args:
|
||||
target_size (int): the target size of image
|
||||
|
@ -261,7 +261,7 @@ class LinearResize(object):
|
|||
return im_scale_y, im_scale_x
|
||||
|
||||
|
||||
class Resize(object):
|
||||
class Resize:
|
||||
def __init__(self, size=(640, 640), **kwargs):
|
||||
self.size = size
|
||||
|
||||
|
@ -291,7 +291,7 @@ class Resize(object):
|
|||
return data
|
||||
|
||||
|
||||
class DetResizeForTest(object):
|
||||
class DetResizeForTest:
|
||||
def __init__(self, **kwargs):
|
||||
super(DetResizeForTest, self).__init__()
|
||||
self.resize_type = 0
|
||||
|
@ -421,7 +421,7 @@ class DetResizeForTest(object):
|
|||
return img, [ratio_h, ratio_w]
|
||||
|
||||
|
||||
class E2EResizeForTest(object):
|
||||
class E2EResizeForTest:
|
||||
def __init__(self, **kwargs):
|
||||
super(E2EResizeForTest, self).__init__()
|
||||
self.max_side_len = kwargs['max_side_len']
|
||||
|
@ -489,7 +489,7 @@ class E2EResizeForTest(object):
|
|||
return im, (ratio_h, ratio_w)
|
||||
|
||||
|
||||
class KieResize(object):
|
||||
class KieResize:
|
||||
def __init__(self, **kwargs):
|
||||
super(KieResize, self).__init__()
|
||||
self.max_side, self.min_side = kwargs['img_scale'][0], kwargs[
|
||||
|
@ -539,7 +539,7 @@ class KieResize(object):
|
|||
return points
|
||||
|
||||
|
||||
class SRResize(object):
|
||||
class SRResize:
|
||||
def __init__(self,
|
||||
imgH=32,
|
||||
imgW=128,
|
||||
|
@ -576,7 +576,7 @@ class SRResize(object):
|
|||
return data
|
||||
|
||||
|
||||
class ResizeNormalize(object):
|
||||
class ResizeNormalize:
|
||||
def __init__(self, size, interpolation=Image.BICUBIC):
|
||||
self.size = size
|
||||
self.interpolation = interpolation
|
||||
|
@ -588,7 +588,7 @@ class ResizeNormalize(object):
|
|||
return img_numpy
|
||||
|
||||
|
||||
class GrayImageChannelFormat(object):
|
||||
class GrayImageChannelFormat:
|
||||
"""
|
||||
format gray scale image's channel: (3,h,w) -> (1,h,w)
|
||||
Args:
|
||||
|
@ -612,7 +612,7 @@ class GrayImageChannelFormat(object):
|
|||
return data
|
||||
|
||||
|
||||
class Permute(object):
|
||||
class Permute:
|
||||
"""permute image
|
||||
Args:
|
||||
to_bgr (bool): whether convert RGB to BGR
|
||||
|
@ -635,7 +635,7 @@ class Permute(object):
|
|||
return im, im_info
|
||||
|
||||
|
||||
class PadStride(object):
|
||||
class PadStride:
|
||||
""" padding image for model with FPN, instead PadBatch(pad_to_stride) in original config
|
||||
Args:
|
||||
stride (bool): model with FPN need image shape % stride == 0
|
||||
|
|
|
@ -38,7 +38,7 @@ def build_post_process(config, global_config=None):
|
|||
return module_class(**config)
|
||||
|
||||
|
||||
class DBPostProcess(object):
|
||||
class DBPostProcess:
|
||||
"""
|
||||
The post process for Differentiable Binarization (DB).
|
||||
"""
|
||||
|
@ -259,7 +259,7 @@ class DBPostProcess(object):
|
|||
return boxes_batch
|
||||
|
||||
|
||||
class BaseRecLabelDecode(object):
|
||||
class BaseRecLabelDecode:
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self, character_dict_path=None, use_space_char=False):
|
||||
|
|
|
@ -28,7 +28,7 @@ from .operators import preprocess
|
|||
from . import operators
|
||||
from .ocr import load_model
|
||||
|
||||
class Recognizer(object):
|
||||
class Recognizer:
|
||||
def __init__(self, label_list, task_name, model_dir=None):
|
||||
"""
|
||||
If you have trouble downloading HuggingFace models, -_^ this might help!!
|
||||
|
|
|
@ -80,13 +80,13 @@ REDIS_PASSWORD=infini_rag_flow
|
|||
SVR_HTTP_PORT=9380
|
||||
|
||||
# The RAGFlow Docker image to download.
|
||||
# Defaults to the v0.17.0-slim edition, which is the RAGFlow Docker image without embedding models.
|
||||
# RAGFLOW_IMAGE=infiniflow/ragflow:v0.17.0-slim
|
||||
# Defaults to the v0.17.2-slim edition, which is the RAGFlow Docker image without embedding models.
|
||||
# RAGFLOW_IMAGE=infiniflow/ragflow:v0.17.2-slim
|
||||
#
|
||||
# To download the RAGFlow Docker image with embedding models, uncomment the following line instead:
|
||||
RAGFLOW_IMAGE=infiniflow/ragflow:v0.17.2
|
||||
#
|
||||
# The Docker image of the v0.17.0 edition includes:
|
||||
# The Docker image of the v0.17.2 edition includes:
|
||||
# - Built-in embedding models:
|
||||
# - BAAI/bge-large-zh-v1.5
|
||||
# - BAAI/bge-reranker-v2-m3
|
||||
|
@ -122,7 +122,7 @@ TIMEZONE='Asia/Shanghai'
|
|||
# HF_ENDPOINT=https://hf-mirror.com
|
||||
|
||||
# Optimizations for MacOS
|
||||
# Uncomment the following line if your OS is MacOS:
|
||||
# Uncomment the following line if your operating system is MacOS:
|
||||
# MACOS=1
|
||||
|
||||
# The maximum file size for each uploaded file, in bytes.
|
||||
|
|
|
@ -78,8 +78,8 @@ The [.env](./.env) file contains important environment variables for Docker.
|
|||
- `RAGFLOW-IMAGE`
|
||||
The Docker image edition. Available editions:
|
||||
|
||||
- `infiniflow/ragflow:v0.17.0-slim` (default): The RAGFlow Docker image without embedding models.
|
||||
- `infiniflow/ragflow:v0.17.0`: The RAGFlow Docker image with embedding models including:
|
||||
- `infiniflow/ragflow:v0.17.2-slim` (default): The RAGFlow Docker image without embedding models.
|
||||
- `infiniflow/ragflow:v0.17.2`: The RAGFlow Docker image with embedding models including:
|
||||
- Built-in embedding models:
|
||||
- `BAAI/bge-large-zh-v1.5`
|
||||
- `BAAI/bge-reranker-v2-m3`
|
||||
|
|
|
@ -15,8 +15,9 @@ CONSUMER_NO_BEG=$1
|
|||
CONSUMER_NO_END=$2
|
||||
|
||||
function task_exe(){
|
||||
JEMALLOC_PATH=$(pkg-config --variable=libdir jemalloc)/libjemalloc.so
|
||||
while [ 1 -eq 1 ]; do
|
||||
$PY rag/svr/task_executor.py $1;
|
||||
LD_PRELOAD=$JEMALLOC_PATH $PY rag/svr/task_executor.py $1;
|
||||
done
|
||||
}
|
||||
|
||||
|
|
|
@ -17,8 +17,9 @@ if [[ -z "$WS" || $WS -lt 1 ]]; then
|
|||
fi
|
||||
|
||||
function task_exe(){
|
||||
JEMALLOC_PATH=$(pkg-config --variable=libdir jemalloc)/libjemalloc.so
|
||||
while [ 1 -eq 1 ];do
|
||||
$PY rag/svr/task_executor.py $1;
|
||||
LD_PRELOAD=$JEMALLOC_PATH $PY rag/svr/task_executor.py $1;
|
||||
done
|
||||
}
|
||||
|
||||
|
|
|
@ -5,8 +5,10 @@ set -e
|
|||
|
||||
# Unset HTTP proxies that might be set by Docker daemon
|
||||
export http_proxy=""; export https_proxy=""; export no_proxy=""; export HTTP_PROXY=""; export HTTPS_PROXY=""; export NO_PROXY=""
|
||||
export PYTHONPATH=$(pwd)
|
||||
|
||||
export LD_LIBRARY_PATH=/usr/lib/x86_64-linux-gnu/
|
||||
JEMALLOC_PATH=$(pkg-config --variable=libdir jemalloc)/libjemalloc.so
|
||||
|
||||
PY=python3
|
||||
|
||||
|
@ -47,7 +49,7 @@ task_exe(){
|
|||
local retry_count=0
|
||||
while ! $STOP && [ $retry_count -lt $MAX_RETRIES ]; do
|
||||
echo "Starting task_executor.py for task $task_id (Attempt $((retry_count+1)))"
|
||||
$PY rag/svr/task_executor.py "$task_id"
|
||||
LD_PRELOAD=$JEMALLOC_PATH $PY rag/svr/task_executor.py "$task_id"
|
||||
EXIT_CODE=$?
|
||||
if [ $EXIT_CODE -eq 0 ]; then
|
||||
echo "task_executor.py for task $task_id exited successfully."
|
||||
|
@ -100,4 +102,4 @@ run_server &
|
|||
PIDS+=($!)
|
||||
|
||||
# Wait for all background processes to finish
|
||||
wait
|
||||
wait
|
||||
|
|
|
@ -15,7 +15,7 @@ When it comes to system configurations, you will need to manage the following fi
|
|||
- [service_conf.yaml.template](https://github.com/infiniflow/ragflow/blob/main/docker/service_conf.yaml.template): Configures the back-end services. It specifies the system-level configuration for RAGFlow and is used by its API server and task executor. Upon container startup, the `service_conf.yaml` file will be generated based on this template file. This process replaces any environment variables within the template, allowing for dynamic configuration tailored to the container's environment.
|
||||
- [docker-compose.yml](https://github.com/infiniflow/ragflow/blob/main/docker/docker-compose.yml): The Docker Compose file for starting up the RAGFlow service.
|
||||
|
||||
To update the default HTTP serving port (80), go to [docker-compose.yml](./docker/docker-compose.yml) and change `80:80`
|
||||
To update the default HTTP serving port (80), go to [docker-compose.yml](https://github.com/infiniflow/ragflow/blob/main/docker/docker-compose.yml) and change `80:80`
|
||||
to `<YOUR_SERVING_PORT>:80`.
|
||||
|
||||
:::tip NOTE
|
||||
|
@ -97,8 +97,8 @@ The [.env](https://github.com/infiniflow/ragflow/blob/main/docker/.env) file con
|
|||
- `RAGFLOW-IMAGE`
|
||||
The Docker image edition. Available editions:
|
||||
|
||||
- `infiniflow/ragflow:v0.17.0-slim` (default): The RAGFlow Docker image without embedding models.
|
||||
- `infiniflow/ragflow:v0.17.0`: The RAGFlow Docker image with embedding models including:
|
||||
- `infiniflow/ragflow:v0.17.2-slim` (default): The RAGFlow Docker image without embedding models.
|
||||
- `infiniflow/ragflow:v0.17.2`: The RAGFlow Docker image with embedding models including:
|
||||
- Built-in embedding models:
|
||||
- `BAAI/bge-large-zh-v1.5`
|
||||
- `BAAI/bge-reranker-v2-m3`
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
{
|
||||
"label": "Developers",
|
||||
"position": 4,
|
||||
"link": {
|
||||
"type": "generated-index",
|
||||
"description": "Guides for hardcore developers"
|
||||
}
|
||||
}
|
|
@ -3,7 +3,7 @@ sidebar_position: 3
|
|||
slug: /acquire_ragflow_api_key
|
||||
---
|
||||
|
||||
# Acquire a RAGFlow API key
|
||||
# Acquire RAGFlow API key
|
||||
|
||||
A key is required for the RAGFlow server to authenticate your requests via HTTP or a Python API. This documents provides instructions on obtaining a RAGFlow API key.
|
||||
|
||||
|
@ -14,5 +14,5 @@ A key is required for the RAGFlow server to authenticate your requests via HTTP
|
|||

|
||||
|
||||
:::tip NOTE
|
||||
See the [RAGFlow HTTP API reference](../../references/http_api_reference.md) or the [RAGFlow Python API reference](../../references/python_api_reference.md) for a complete reference of RAGFlow's HTTP or Python APIs.
|
||||
See the [RAGFlow HTTP API reference](../references/http_api_reference.md) or the [RAGFlow Python API reference](../references/python_api_reference.md) for a complete reference of RAGFlow's HTTP or Python APIs.
|
||||
:::
|
|
@ -3,7 +3,7 @@ sidebar_position: 1
|
|||
slug: /build_docker_image
|
||||
---
|
||||
|
||||
# Build a RAGFlow Docker Image
|
||||
# Build RAGFlow Docker image
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
|
@ -77,17 +77,8 @@ After building the infiniflow/ragflow:nightly-slim image, you are ready to launc
|
|||
|
||||
1. Edit Docker Compose Configuration
|
||||
|
||||
Open the `docker/docker-compose-base.yml` file. Find the `infinity.image` setting and change the image reference from `infiniflow/infinity:v0.6.0-dev3` to `infiniflow/ragflow:nightly-slim` to use the pre-built image.
|
||||
Open the `docker/.env` file. Find the `RAGFLOW_IMAGE` setting and change the image reference from `infiniflow/ragflow:v0.17.2-slim` to `infiniflow/ragflow:nightly-slim` to use the pre-built image.
|
||||
|
||||
```yaml
|
||||
infinity:
|
||||
container_name: ragflow-infinity
|
||||
image: infiniflow/ragflow:nightly-slim # here
|
||||
volumes:
|
||||
- ...
|
||||
- ...
|
||||
...
|
||||
```
|
||||
|
||||
2. Launch the Service
|
||||
|
|
@ -3,11 +3,11 @@ sidebar_position: 2
|
|||
slug: /launch_ragflow_from_source
|
||||
---
|
||||
|
||||
# Launch a RAGFlow Service from Source
|
||||
# Launch service from source
|
||||
|
||||
A guide explaining how to set up a RAGFlow service from its source code. By following this guide, you'll be able to debug using the source code.
|
||||
|
||||
## Target Audience
|
||||
## Target audience
|
||||
|
||||
Developers who have added new features or modified existing code and wish to debug using the source code, *provided that* their machine has the target deployment environment set up.
|
||||
|
||||
|
@ -22,11 +22,11 @@ Developers who have added new features or modified existing code and wish to deb
|
|||
If you have not installed Docker on your local machine (Windows, Mac, or Linux), see the [Install Docker Engine](https://docs.docker.com/engine/install/) guide.
|
||||
:::
|
||||
|
||||
## Launch the Service from Source
|
||||
## Launch a service from source
|
||||
|
||||
To launch the RAGFlow service from source code:
|
||||
To launch a RAGFlow service from source code:
|
||||
|
||||
### Clone the RAGFlow Repository
|
||||
### Clone the RAGFlow repository
|
||||
|
||||
```bash
|
||||
git clone https://github.com/infiniflow/ragflow.git
|
||||
|
@ -52,7 +52,7 @@ cd ragflow/
|
|||
```
|
||||
*A virtual environment named `.venv` is created, and all Python dependencies are installed into the new environment.*
|
||||
|
||||
### Launch Third-party Services
|
||||
### Launch third-party services
|
||||
|
||||
The following command launches the 'base' services (MinIO, Elasticsearch, Redis, and MySQL) using Docker Compose:
|
||||
|
||||
|
@ -70,7 +70,7 @@ docker compose -f docker/docker-compose-base.yml up -d
|
|||
|
||||
2. In **docker/service_conf.yaml.template**, update mysql port to `5455` and es port to `1200`, as specified in **docker/.env**.
|
||||
|
||||
### Launch the RAGFlow Backend Service
|
||||
### Launch the RAGFlow backend service
|
||||
|
||||
1. Comment out the `nginx` line in **docker/entrypoint.sh**.
|
||||
|
|
@ -3,9 +3,9 @@ sidebar_position: 10
|
|||
slug: /faq
|
||||
---
|
||||
|
||||
# Frequently asked questions
|
||||
# FAQs
|
||||
|
||||
Queries regarding general features, troubleshooting, usage, and more.
|
||||
Answers to questions about general features, troubleshooting, usage, and more.
|
||||
|
||||
---
|
||||
|
||||
|
@ -37,12 +37,12 @@ If you build RAGFlow from source, the version number is also in the system log:
|
|||
/ _, _// ___ |/ /_/ // __/ / // /_/ /| |/ |/ /
|
||||
/_/ |_|/_/ |_|\____//_/ /_/ \____/ |__/|__/
|
||||
|
||||
2025-02-18 10:10:43,835 INFO 1445658 RAGFlow version: v0.17.0-50-g6daae7f2 full
|
||||
2025-02-18 10:10:43,835 INFO 1445658 RAGFlow version: v0.15.0-50-g6daae7f2 full
|
||||
```
|
||||
|
||||
Where:
|
||||
|
||||
- `v0.17.0`: The officially published release.
|
||||
- `v0.15.0`: The officially published release.
|
||||
- `50`: The number of git commits since the official release.
|
||||
- `g6daae7f2`: `g` is the prefix, and `6daae7f2` is the first seven characters of the current commit ID.
|
||||
- `full`/`slim`: The RAGFlow edition.
|
||||
|
@ -65,16 +65,16 @@ RAGFlow has a number of built-in models for document structure parsing, which ac
|
|||
|
||||
### Which architectures or devices does RAGFlow support?
|
||||
|
||||
We officially support x86 CPU and nvidia GPU. While we also test RAGFlow on ARM64 platforms, we do not maintain RAGFlow Docker images for ARM. If you are on an ARM platform, follow [this guide](https://ragflow.io/docs/dev/build_docker_image) to build a RAGFlow Docker image.
|
||||
We officially support x86 CPU and nvidia GPU. While we also test RAGFlow on ARM64 platforms, we do not maintain RAGFlow Docker images for ARM. If you are on an ARM platform, follow [this guide](./develop/build_docker_image.mdx) to build a RAGFlow Docker image.
|
||||
|
||||
---
|
||||
|
||||
### Which embedding models can be deployed locally?
|
||||
|
||||
RAGFlow offers two Docker image editions, `v0.17.0-slim` and `v0.17.0`:
|
||||
RAGFlow offers two Docker image editions, `v0.17.2-slim` and `v0.17.2`:
|
||||
|
||||
- `infiniflow/ragflow:v0.17.0-slim` (default): The RAGFlow Docker image without embedding models.
|
||||
- `infiniflow/ragflow:v0.17.0`: The RAGFlow Docker image with embedding models including:
|
||||
- `infiniflow/ragflow:v0.17.2-slim` (default): The RAGFlow Docker image without embedding models.
|
||||
- `infiniflow/ragflow:v0.17.2`: The RAGFlow Docker image with embedding models including:
|
||||
- Built-in embedding models:
|
||||
- `BAAI/bge-large-zh-v1.5`
|
||||
- `BAAI/bge-reranker-v2-m3`
|
||||
|
@ -94,7 +94,7 @@ RAGFlow offers two Docker image editions, `v0.17.0-slim` and `v0.17.0`:
|
|||
|
||||
### Do you offer an API for integration with third-party applications?
|
||||
|
||||
The corresponding APIs are now available. See the [RAGFlow HTTP API Reference](./http_api_reference.md) or the [RAGFlow Python API Reference](./python_api_reference.md) for more information.
|
||||
The corresponding APIs are now available. See the [RAGFlow HTTP API Reference](./references/http_api_reference.md) or the [RAGFlow Python API Reference](./references/python_api_reference.md) for more information.
|
||||
|
||||
---
|
||||
|
||||
|
@ -130,7 +130,7 @@ Yes, we support enhancing user queries based on existing context of an ongoing c
|
|||
|
||||
#### How to build the RAGFlow image from scratch?
|
||||
|
||||
See [Build a RAGFlow Docker image](https://ragflow.io/docs/dev/build_docker_image).
|
||||
See [Build a RAGFlow Docker image](./develop/build_docker_image.mdx).
|
||||
|
||||
---
|
||||
|
||||
|
@ -296,7 +296,7 @@ tail -f ragflow/docker/ragflow-logs/*.log
|
|||
cd29bcb254bc quay.io/minio/minio:RELEASE.2023-12-20T01-00-02Z "/usr/bin/docker-ent…" 2 weeks ago Up 11 hours 0.0.0.0:9001->9001/tcp, :::9001->9001/tcp, 0.0.0.0:9000->9000/tcp, :::9000->9000/tcp ragflow-minio
|
||||
```
|
||||
|
||||
2. Follow [this document](../guides/run_health_check.md) to check the health status of the Elasticsearch service.
|
||||
2. Follow [this document](./guides/run_health_check.md) to check the health status of the Elasticsearch service.
|
||||
|
||||
:::danger IMPORTANT
|
||||
The status of a Docker container status does not necessarily reflect the status of the service. You may find that your services are unhealthy even when the corresponding Docker containers are up running. Possible reasons for this include network failures, incorrect port numbers, or DNS issues.
|
||||
|
@ -318,7 +318,7 @@ The status of a Docker container status does not necessarily reflect the status
|
|||
91220e3285dd docker.elastic.co/elasticsearch/elasticsearch:8.11.3 "/bin/tini -- /usr/l…" 11 hours ago Up 11 hours (healthy) 9300/tcp, 0.0.0.0:9200->9200/tcp, :::9200->9200/tcp ragflow-es-01
|
||||
```
|
||||
|
||||
2. Follow [this document](../guides/run_health_check.md) to check the health status of the Elasticsearch service.
|
||||
2. Follow [this document](./guides/run_health_check.md) to check the health status of the Elasticsearch service.
|
||||
|
||||
:::danger IMPORTANT
|
||||
The status of a Docker container status does not necessarily reflect the status of the service. You may find that your services are unhealthy even when the corresponding Docker containers are up running. Possible reasons for this include network failures, incorrect port numbers, or DNS issues.
|
||||
|
@ -347,7 +347,7 @@ A correct Ollama IP address and port is crucial to adding models to Ollama:
|
|||
- If you are on demo.ragflow.io, ensure that the server hosting Ollama has a publicly accessible IP address. Note that 127.0.0.1 is not a publicly accessible IP address.
|
||||
- If you deploy RAGFlow locally, ensure that Ollama and RAGFlow are in the same LAN and can communicate with each other.
|
||||
|
||||
See [Deploy a local LLM](../guides/deploy_local_llm.mdx) for more information.
|
||||
See [Deploy a local LLM](./guides/models/deploy_local_llm.mdx) for more information.
|
||||
|
||||
---
|
||||
|
||||
|
@ -395,7 +395,7 @@ Ensure that you update the **MAX_CONTENT_LENGTH** environment variable:
|
|||
cd29bcb254bc quay.io/minio/minio:RELEASE.2023-12-20T01-00-02Z "/usr/bin/docker-ent…" 2 weeks ago Up 11 hours 0.0.0.0:9001->9001/tcp, :::9001->9001/tcp, 0.0.0.0:9000->9000/tcp, :::9000->9000/tcp ragflow-minio
|
||||
```
|
||||
|
||||
2. Follow [this document](../guides/run_health_check.md) to check the health status of the Elasticsearch service.
|
||||
2. Follow [this document](./guides/run_health_check.md) to check the health status of the Elasticsearch service.
|
||||
|
||||
:::danger IMPORTANT
|
||||
The status of a Docker container status does not necessarily reflect the status of the service. You may find that your services are unhealthy even when the corresponding Docker containers are up running. Possible reasons for this include network failures, incorrect port numbers, or DNS issues.
|
||||
|
@ -417,7 +417,7 @@ The status of a Docker container status does not necessarily reflect the status
|
|||
|
||||
### How to run RAGFlow with a locally deployed LLM?
|
||||
|
||||
You can use Ollama or Xinference to deploy local LLM. See [here](../guides/deploy_local_llm.mdx) for more information.
|
||||
You can use Ollama or Xinference to deploy local LLM. See [here](./guides/models/deploy_local_llm.mdx) for more information.
|
||||
|
||||
---
|
||||
|
||||
|
@ -434,7 +434,7 @@ If your model is not currently supported but has APIs compatible with those of O
|
|||
- If RAGFlow is locally deployed, ensure that your RAGFlow and Ollama are in the same LAN.
|
||||
- If you are using our online demo, ensure that the IP address of your Ollama server is public and accessible.
|
||||
|
||||
See [here](../guides/deploy_local_llm.mdx) for more information.
|
||||
See [here](./guides/models/deploy_local_llm.mdx) for more information.
|
||||
|
||||
---
|
||||
|
||||
|
@ -453,12 +453,12 @@ This error occurs because there are too many chunks matching your search criteri
|
|||
|
||||
### How to get an API key for integration with third-party applications?
|
||||
|
||||
See [Acquire a RAGFlow API key](../guides/develop/acquire_ragflow_api_key.md).
|
||||
See [Acquire a RAGFlow API key](./develop/acquire_ragflow_api_key.md).
|
||||
|
||||
---
|
||||
|
||||
### How to upgrade RAGFlow?
|
||||
|
||||
See [Upgrade RAGFlow](../guides/upgrade_ragflow.mdx) for more information.
|
||||
See [Upgrade RAGFlow](./guides/upgrade_ragflow.mdx) for more information.
|
||||
|
||||
---
|
|
@ -34,6 +34,7 @@ Evaluates whether the output of specific components meets certain conditions, wi
|
|||
|
||||
:::danger IMPORTANT
|
||||
When you have added multiple conditions for a specific case, a **Logical operator** field appears, requiring you to set the logical relationship between these conditions as either AND or OR.
|
||||

|
||||
:::
|
||||
|
||||
- **Component ID**: The ID of the corresponding component.
|
||||
|
|
|
@ -3,10 +3,14 @@ sidebar_position: 3
|
|||
slug: /embed_agent_into_webpage
|
||||
---
|
||||
|
||||
# Embed agent into a webpage
|
||||
# Embed agent into webpage
|
||||
|
||||
You can use iframe to embed an agent into a third-party webpage.
|
||||
|
||||
:::caution WARNING
|
||||
If your agent's **Begin** component takes a key of **file** type (a **file** type variable), you *cannot* embed it into a webpage.
|
||||
:::
|
||||
|
||||
1. Before proceeding, you must [acquire an API key](../models/llm_api_key_setup.md); otherwise, an error message would appear.
|
||||
2. On the **Agent** page, click an intended agent **>** **Edit** to access its editing page.
|
||||
3. Click **Embed into webpage** on the top right corner of the canvas to show the **iframe** window:
|
||||
|
|
|
@ -3,11 +3,15 @@ sidebar_position: 2
|
|||
slug: /general_purpose_chatbot
|
||||
---
|
||||
|
||||
# Create a general-purpose chatbot
|
||||
# Create chatbot
|
||||
|
||||
Create a general-purpose chatbot.
|
||||
|
||||
---
|
||||
|
||||
Chatbot is one of the most common AI scenarios. However, effectively understanding user queries and responding appropriately remains a challenge. RAGFlow's general-purpose chatbot agent is our attempt to tackle this longstanding issue.
|
||||
|
||||
This chatbot closely resembles the chatbot introduced in [Start an AI chat](../start_chat.md), but with a key difference - it introduces a reflective mechanism that allows it to improve the retrieval from the target knowledge bases by rewriting the user's query.
|
||||
This chatbot closely resembles the chatbot introduced in [Start an AI chat](../chat/start_chat.md), but with a key difference - it introduces a reflective mechanism that allows it to improve the retrieval from the target knowledge bases by rewriting the user's query.
|
||||
|
||||
This document provides guides on creating such a chatbot using our chatbot template.
|
||||
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
{
|
||||
"label": "Chat",
|
||||
"position": 1,
|
||||
"link": {
|
||||
"type": "generated-index",
|
||||
"description": "Chat-specific guides."
|
||||
}
|
||||
}
|
|
@ -1,25 +1,16 @@
|
|||
---
|
||||
sidebar_position: 2
|
||||
slug: /accelerate_doc_indexing_and_question_answering
|
||||
slug: /accelerate_question_answering
|
||||
---
|
||||
|
||||
# Accelerate document indexing and question answering
|
||||
# Accelerate answering
|
||||
import APITable from '@site/src/components/APITable';
|
||||
|
||||
A checklist to speed up document parsing and question answering.
|
||||
A checklist to speed up question answering.
|
||||
|
||||
---
|
||||
|
||||
Please note that some of your settings may consume a significant amount of time. If you often find that document parsing and question answering are time-consuming, here is a checklist to consider:
|
||||
|
||||
## 1. Accelerate document indexing
|
||||
|
||||
- Use GPU to reduce embedding time.
|
||||
- On the configuration page of your knowledge base, switch off **Use RAPTOR to enhance retrieval**.
|
||||
- Extracting knowledge graph (GraphRAG) is time-consuming.
|
||||
- Disable **Auto-keyword** and **Auto-question** on the configuration page of yor knowledge base, as both depend on the LLM.
|
||||
|
||||
## 2. Accelerate question answering
|
||||
Please note that some of your settings may consume a significant amount of time. If you often find that your question answering is time-consuming, here is a checklist to consider:
|
||||
|
||||
- In the **Prompt Engine** tab of your **Chat Configuration** dialogue, disabling **Multi-turn optimization** will reduce the time required to get an answer from the LLM.
|
||||
- In the **Prompt Engine** tab of your **Chat Configuration** dialogue, leaving the **Rerank model** field empty will significantly decrease retrieval time.
|
||||
|
@ -32,18 +23,18 @@ Please note that some of your settings may consume a significant amount of time.
|
|||
<APITable>
|
||||
```
|
||||
|
||||
| Item name | Description |
|
||||
| ----------------- | ------------------------------------------------------------ |
|
||||
| Item name | Description |
|
||||
| ----------------- | --------------------------------------------------------------------------------------------- |
|
||||
| Total | Total time spent on this conversation round, including chunk retrieval and answer generation. |
|
||||
| Check LLM | Time to validate the specified LLM. |
|
||||
| Create retriever | Time to create a chunk retriever. |
|
||||
| Bind embedding | Time to initialize an embedding model instance. |
|
||||
| Bind LLM | Time to initialize an LLM instance. |
|
||||
| Tune question | Time to optimize the user query using the context of the mult-turn conversation. |
|
||||
| Bind reranker | Time to initialize an reranker model instance for chunk retrieval. |
|
||||
| Generate keywords | Time to extract keywords from the user query. |
|
||||
| Retrieval | Time to retrieve the chunks. |
|
||||
| Generate answer | Time to generate the answer. |
|
||||
| Check LLM | Time to validate the specified LLM. |
|
||||
| Create retriever | Time to create a chunk retriever. |
|
||||
| Bind embedding | Time to initialize an embedding model instance. |
|
||||
| Bind LLM | Time to initialize an LLM instance. |
|
||||
| Tune question | Time to optimize the user query using the context of the mult-turn conversation. |
|
||||
| Bind reranker | Time to initialize an reranker model instance for chunk retrieval. |
|
||||
| Generate keywords | Time to extract keywords from the user query. |
|
||||
| Retrieval | Time to retrieve the chunks. |
|
||||
| Generate answer | Time to generate the answer. |
|
||||
|
||||
```mdx-code-block
|
||||
</APITable>
|
|
@ -0,0 +1,28 @@
|
|||
---
|
||||
sidebar_position: 3
|
||||
slug: /implement_deep_research
|
||||
---
|
||||
|
||||
# Implement deep research
|
||||
|
||||
Implements deep research for agentic reasoning.
|
||||
|
||||
---
|
||||
|
||||
From v0.17.0 onward, RAGFlow supports integrating agentic reasoning in an AI chat. The following diagram illustrates the workflow of RAGFlow's deep research:
|
||||
|
||||

|
||||
|
||||
To activate this feature:
|
||||
|
||||
1. Enable the **Reasoning** toggle under the **Prompt Engine** tab of your chat assistant dialogue.
|
||||
|
||||

|
||||
|
||||
2. Enter the correct Tavily API key under the **Assistant Setting** tab of your chat assistant dialogue to leverage Tavily-based web search
|
||||
|
||||

|
||||
|
||||
*The following is a screenshot of a conversation that integrates Deep Research:*
|
||||
|
||||

|
|
@ -3,13 +3,13 @@ sidebar_position: 1
|
|||
slug: /start_chat
|
||||
---
|
||||
|
||||
# Chat
|
||||
# Start AI chat
|
||||
|
||||
Initiate an AI-powered chat with a configured chat assistant.
|
||||
|
||||
---
|
||||
|
||||
Knowledge base, hallucination-free chat, and file management are the three pillars of RAGFlow. Chats in RAGFlow are based on a particular knowledge base or multiple knowledge bases. Once you have created your knowledge base and finished file parsing, you can go ahead and start an AI conversation.
|
||||
Knowledge base, hallucination-free chat, and file management are the three pillars of RAGFlow. Chats in RAGFlow are based on a particular knowledge base or multiple knowledge bases. Once you have created your knowledge base, finished file parsing, and [run a retrieval test](../dataset/run_retrieval_test.md), you can go ahead and start an AI conversation.
|
||||
|
||||
## Start an AI chat
|
||||
|
||||
|
@ -80,13 +80,13 @@ Hover over an intended chat assistant **>** **Edit** to show the chat configurat
|
|||
|
||||
RAGFlow offers HTTP and Python APIs for you to integrate RAGFlow's capabilities into your applications. Read the following documents for more information:
|
||||
|
||||
- [Acquire a RAGFlow API key](./models/llm_api_key_setup.md)
|
||||
- [HTTP API reference](../references/http_api_reference.md)
|
||||
- [Python API reference](../references/python_api_reference.md)
|
||||
- [Acquire a RAGFlow API key](../../develop/acquire_ragflow_api_key.md)
|
||||
- [HTTP API reference](../../references/http_api_reference.md)
|
||||
- [Python API reference](../../references/python_api_reference.md)
|
||||
|
||||
You can use iframe to embed the created chat assistant into a third-party webpage:
|
||||
|
||||
1. Before proceeding, you must [acquire an API key](./models/llm_api_key_setup.md); otherwise, an error message would appear.
|
||||
1. Before proceeding, you must [acquire an API key](../models/llm_api_key_setup.md); otherwise, an error message would appear.
|
||||
2. Hover over an intended chat assistant **>** **Edit** to show the **iframe** window:
|
||||
|
||||

|
|
@ -0,0 +1,19 @@
|
|||
---
|
||||
sidebar_position: 9
|
||||
slug: /accelerate_doc_indexing
|
||||
---
|
||||
|
||||
# Accelerate indexing
|
||||
import APITable from '@site/src/components/APITable';
|
||||
|
||||
A checklist to speed up document parsing and indexing.
|
||||
|
||||
---
|
||||
|
||||
Please note that some of your settings may consume a significant amount of time. If you often find that document parsing is time-consuming, here is a checklist to consider:
|
||||
|
||||
- Use GPU to reduce embedding time.
|
||||
- On the configuration page of your knowledge base, switch off **Use RAPTOR to enhance retrieval**.
|
||||
- Extracting knowledge graph (GraphRAG) is time-consuming.
|
||||
- Disable **Auto-keyword** and **Auto-question** on the configuration page of yor knowledge base, as both depend on the LLM.
|
||||
- **v0.17.0:** If your document is plain text PDF and does not require GPU-intensive processes like OCR (Optical Character Recognition), TSR (Table Structure Recognition), or DLA (Document Layout Analysis), you can choose **Naive** over **DeepDoc** or other time-consuming large model options in the **Document parser** dropdown. This will substantially reduce document parsing time.
|
|
@ -39,18 +39,18 @@ This section covers the following topics:
|
|||
|
||||
RAGFlow offers multiple chunking template to facilitate chunking files of different layouts and ensure semantic integrity. In **Chunk method**, you can choose the default template that suits the layouts and formats of your files. The following table shows the descriptions and the compatible file formats of each supported chunk template:
|
||||
|
||||
| **Template** | Description | File format |
|
||||
|--------------|-----------------------------------------------------------------------|------------------------------------------------------|
|
||||
| General | Files are consecutively chunked based on a preset chunk token number. | DOCX, EXCEL, PPT, PDF, TXT, JPEG, JPG, PNG, TIF, GIF |
|
||||
| Q&A | | XLSX, CSV/TXT |
|
||||
| Manual | | PDF |
|
||||
| Table | | XLSX, CSV/TXT |
|
||||
| Paper | | PDF |
|
||||
| Book | | DOCX, PDF, TXT |
|
||||
| Laws | | DOCX, PDF, TXT |
|
||||
| Presentation | | PDF, PPTX |
|
||||
| Picture | | JPEG, JPG, PNG, TIF, GIF |
|
||||
| One | The entire document is chunked as one. | DOCX, EXCEL, PDF, TXT |
|
||||
| **Template** | Description | File format |
|
||||
|--------------|-----------------------------------------------------------------------|------------------------------------------------------------------------------|
|
||||
| General | Files are consecutively chunked based on a preset chunk token number. | DOCX, XLSX, XLS (Excel97~2003), PPT, PDF, TXT, JPEG, JPG, PNG, TIF, GIF, CSV |
|
||||
| Q&A | | XLSX, XLS (Excel97~2003), CSV/TXT |
|
||||
| Manual | | PDF |
|
||||
| Table | | XLSX, XLS (Excel97~2003), CSV/TXT |
|
||||
| Paper | | PDF |
|
||||
| Book | | DOCX, PDF, TXT |
|
||||
| Laws | | DOCX, PDF, TXT |
|
||||
| Presentation | | PDF, PPTX |
|
||||
| Picture | | JPEG, JPG, PNG, TIF, GIF |
|
||||
| One | The entire document is chunked as one. | DOCX, XLSX, XLS (Excel97~2003), PDF, TXT |
|
||||
|
||||
You can also change a file's chunk method on the **Datasets** page.
|
||||
|
||||
|
@ -124,17 +124,19 @@ RAGFlow uses multiple recall of both full-text search and vector search in its c
|
|||
- Similarity threshold: Chunks with similarities below the threshold will be filtered. By default, it is set to 0.2.
|
||||
- Vector similarity weight: The percentage by which vector similarity contributes to the overall score. By default, it is set to 0.3.
|
||||
|
||||
See [Run retrieval test](./run_retrieval_test.md) for details.
|
||||
|
||||

|
||||
|
||||
## Search for knowledge base
|
||||
|
||||
As of RAGFlow v0.17.0, the search feature is still in a rudimentary form, supporting only knowledge base search by name.
|
||||
As of RAGFlow v0.17.2, the search feature is still in a rudimentary form, supporting only knowledge base search by name.
|
||||
|
||||

|
||||
|
||||
## Delete knowledge base
|
||||
|
||||
You are allowed to delete a knowledge base. Hover your mouse over the three dot of the intended knowledge base card and the **Delete** option appears. Once you delete a knowledge base, the associated folder under **root/.knowledge** directory is AUTOMATICALLY REMOVED. The consequence is:
|
||||
You are allowed to delete a knowledge base. Hover your mouse over the three dot of the intended knowledge base card and the **Delete** option appears. Once you delete a knowledge base, the associated folder under **root/.knowledge** directory is AUTOMATICALLY REMOVED. The consequence is:
|
||||
|
||||
- The files uploaded directly to the knowledge base are gone;
|
||||
- The file references, which you created from within **File Management**, are gone, but the associated files still exist in **File Management**.
|
||||
|
|
|
@ -13,7 +13,7 @@ To enhance multi-hop question-answering, RAGFlow adds a knowledge graph construc
|
|||
|
||||

|
||||
|
||||
As of v0.17.0, RAGFlow supports constructing a knowledge graph on a knowledge base, allowing you to construct a *unified* graph across multiple files within your knowledge base. When a newly uploaded file starts parsing, the generated graph will automatically update.
|
||||
From v0.16.0 onward, RAGFlow supports constructing a knowledge graph on a knowledge base, allowing you to construct a *unified* graph across multiple files within your knowledge base. When a newly uploaded file starts parsing, the generated graph will automatically update.
|
||||
|
||||
:::danger WARNING
|
||||
Constructing a knowledge graph requires significant memory, computational resources, and tokens.
|
||||
|
|
|
@ -9,7 +9,7 @@ Conduct a retrieval test on your knowledge base to check whether the intended ch
|
|||
|
||||
---
|
||||
|
||||
After your files are uploaded and parsed, it is recommended that you run a retrieval test before proceeding with the chat assistant configuration. Just like fine-tuning a precision instrument, RAGFlow requires careful tuning to deliver optimal question answering performance. Your knowledge base settings, chat assistant configurations, and the specified large and small models can all significantly impact the final results. Running a retrieval test verifies whether the intended chunks can be recovered, allowing you to quickly identify areas for improvement or pinpoint any issue that needs addressing. For instance, when debugging your question answering system, if you know that the correct chunks can be retrieved, you can focus your efforts elsewhere.
|
||||
After your files are uploaded and parsed, it is recommended that you run a retrieval test before proceeding with the chat assistant configuration. Running a retrieval test is *not* an unnecessary or superfluous step at all! Just like fine-tuning a precision instrument, RAGFlow requires careful tuning to deliver optimal question answering performance. Your knowledge base settings, chat assistant configurations, and the specified large and small models can all significantly impact the final results. Running a retrieval test verifies whether the intended chunks can be recovered, allowing you to quickly identify areas for improvement or pinpoint any issue that needs addressing. For instance, when debugging your question answering system, if you know that the correct chunks can be retrieved, you can focus your efforts elsewhere. For example, in issue [#5627](https://github.com/infiniflow/ragflow/issues/5627), the problem was found to be due to the LLM's limitations.
|
||||
|
||||
During a retrieval test, chunks created from your specified chunk method are retrieved using a hybrid search. This search combines weighted keyword similarity with either weighted vector cosine similarity or a weighted reranking score, depending on your settings:
|
||||
|
||||
|
@ -75,6 +75,10 @@ This field is where you put in your testing query.
|
|||
*The following is a screenshot of a retrieval test conducted using a knowledge graph. It shows that only vector similarity is used for knowledge graph-generated chunks:*
|
||||

|
||||
|
||||
:::caution WARNING
|
||||
If you have adjusted the default settings, such as keyword similarity weight or similarity threshold, to achieve the optimal results, be aware that these changes will not be automatically saved. You must apply them to your chat assistant settings or the **Retrieval** agent component settings.
|
||||
:::
|
||||
|
||||
## Frequently asked questions
|
||||
|
||||
### Is an LLM used when the Use Knowledge Graph switch is enabled?
|
||||
|
|
|
@ -0,0 +1,103 @@
|
|||
---
|
||||
sidebar_position: 6
|
||||
slug: /use_tag_sets
|
||||
---
|
||||
|
||||
# Use tag set
|
||||
|
||||
Use a tag set to tag chunks in your datasets.
|
||||
|
||||
---
|
||||
|
||||
Retrieval accuracy is the touchstone for a production-ready RAG framework. In addition to retrieval-enhancing approaches like auto-keyword, auto-question, and knowledge graph, RAGFlow introduces an auto-tagging feature to address semantic gaps. The auto-tagging feature automatically maps tags in the user-defined tag sets to relevant chunks within your knowledge base based on similarity with each chunk. This automation mechanism allows you to apply an additional "layer" of domain-specific knowledge to existing datasets, which is particularly useful when dealing with a large number of chunks.
|
||||
|
||||
To use this feature, ensure you have at least one properly configured tag set, specify the tag set(s) on the **Configuration** page of your knowledge base (dataset), and then re-parse your documents to initiate the auto-tag process. During this process, each chunk in your dataset is compared with every entry in the specified tag set(s), and tags are automatically applied based on similarity.
|
||||
|
||||
## Scenarios
|
||||
|
||||
Auto-tagging applies in situations where chunks are so similar to each other that the intended chunks cannot be distinguished from the rest. For example, when you have a few chunks about iPhone and a majority about iPhone case or iPhone accessaries, it becomes difficult to retrieve the iPhone-specific chunks without additional information.
|
||||
|
||||
## Create tag set
|
||||
|
||||
You can consider a tag set as a closed set, and the tags to attach to the chunks in your dataset (knowledge base) are *exclusively* from the specified tag set. You use a tag set to "inform" RAGFlow which chunks to tag and which tags to apply.
|
||||
|
||||
### Prepare a tag table file
|
||||
|
||||
A tag set can comprise one or multiple table files in XLSX, CSV, or TXT formats. Each table file in the tag set contains two columns, **Description** and **Tag**:
|
||||
|
||||
- The first column provides descriptions of the tags listed in the second column. These descriptions can be example chunks or example queries. Similarity will be calculated between each entry in this column and every chunk in your dataset.
|
||||
- The **Tag** column includes tags to pair with the description entries. Multiple tags should be separated by a comma (,).
|
||||
|
||||
:::tip NOTE
|
||||
As a rule of thumb, consider including the following entries in your tag table:
|
||||
|
||||
- Descriptions of intended chunks, along with their corresponding tags.
|
||||
- User queries that fail to retrieve the correct responses using other methods, ensuring their tags match the intended chunks in your dataset.
|
||||
:::
|
||||
|
||||
### Create a tag set
|
||||
|
||||
1. Click **+ Create knowledge base** to create a knowledge base.
|
||||
2. Navigate to the **Configuration** page of the created knowledge base and choose **Tag** as the default chunk method.
|
||||
3. Navigate to the **Dataset** page and upload and parse your table file in XLSX, CSV, or TXT formats.
|
||||
_A tag cloud appears under the **Tag view** section, indicating the tag set is created:_
|
||||

|
||||
4. Click the **Table** tab to view the tag frequency table:
|
||||

|
||||
|
||||
:::danger IMPORTANT
|
||||
A tag set is *not* involved in document indexing or retrieval. Do not specify a tag set when configuring your chat assistant or agent.
|
||||
:::
|
||||
|
||||
## Tag chunks
|
||||
|
||||
Once a tag set is created, you can apply it to your dataset:
|
||||
|
||||
1. Navigate to the **Configuration** page of your knowledge base (dataset).
|
||||
2. Select the tag set from the **Tag sets** dropdown and click **Save** to confirm.
|
||||
|
||||
:::tip NOTE
|
||||
If the tag set is missing from the dropdown, check that it has been created or configured correctly.
|
||||
:::
|
||||
|
||||
3. Re-parse your documents to start the auto-tagging process.
|
||||
_In an AI chat scenario using auto-tagged datasets, each query will be tagged using the corresponding tag set(s) and chunks with these tags will have a higher chance to be retrieved._
|
||||
|
||||
## Update tag set
|
||||
|
||||
Creating a tag set is *not* for once and for all. Oftentimes, you may find it necessary to update or delete existing tags or add new entries.
|
||||
|
||||
- You can update the existing tag set in the tag frequency table.
|
||||
- To add new entries, you can add and parse new table files in XLSX, CSV, or TXT formats.
|
||||
|
||||
### Update tag set in tag frequency table
|
||||
|
||||
1. Navigate to the **Configuration** page in your tag set.
|
||||
2. Click the **Table** tab under **Tag view** to view the tag frequncy table, where you can update tag names or delete tags.
|
||||
|
||||
:::danger IMPORTANT
|
||||
When a tag set is updated, you must re-parse the documents in your dataset so that their tags can be updated accordingly.
|
||||
:::
|
||||
|
||||
### Add new table files
|
||||
|
||||
1. Navigate to the **Configuration** page in your tag set.
|
||||
2. Navigate to the **Dataset** page and upload and parse your table file in XLSX, CSV, or TXT formats.
|
||||
|
||||
:::danger IMPORTANT
|
||||
If you add new table files to your tag set, it is at your own discretion whether to re-parse your documents in your datasets.
|
||||
:::
|
||||
|
||||
## Frequently asked questions
|
||||
|
||||
### Can I reference more than one tag set?
|
||||
|
||||
Yes, you can. Usually one tag set suffices. When using multiple tag sets, ensure they are independent of each other; otherwise, consider merging your tag sets.
|
||||
|
||||
### Difference between a tag set and a standard knowledge base?
|
||||
|
||||
A standard knowledge base is a dataset. It will be searched by RAGFlow's document engine and the retrieved chunks will be fed to the LLM. In contrast, a tag set is used solely to attach tags to chunks within your dataset. It does not directly participate in the retrieval process, and you should not choose a tag set when selecting datasets for your chat assistant or agent.
|
||||
|
||||
### Difference between auto-tag and auto-keyword?
|
||||
|
||||
Both features enhance retrieval in RAGFlow. The auto-keyword feature relies on the LLM and consumes a significant number of tokens, whereas the auto-tag feature is based on vector similarity and predefined tag set(s). You can view the keywords applied in the auto-keyword feature as an open set, as they are generated by the LLM. In contrast, a tag set can be considered a user-defined close set, requiring upload tag set(s) in specified formats before use.
|
|
@ -1,8 +0,0 @@
|
|||
{
|
||||
"label": "Develop",
|
||||
"position": 10,
|
||||
"link": {
|
||||
"type": "generated-index",
|
||||
"description": "Guides for Hardcore Developers"
|
||||
}
|
||||
}
|
|
@ -5,7 +5,11 @@ slug: /manage_files
|
|||
|
||||
# Files
|
||||
|
||||
Knowledge base, hallucination-free chat, and file management are the three pillars of RAGFlow. RAGFlow's file management allows you to upload files individually or in bulk. You can then link an uploaded file to multiple target knowledge bases. This guide showcases some basic usages of the file management feature.
|
||||
Knowledge base, hallucination-free chat, and file management are the three pillars of RAGFlow. RAGFlow's file management allows you to upload files individually or in bulk. You can then link an uploaded file to multiple target knowledge bases. This guide showcases some basic usages of the file management feature.
|
||||
|
||||
:::danger IMPORTANT
|
||||
Compared to uploading files directly to various knowledge bases, uploading them to RAGFlow's file management and then linking them to different knowledge bases is *not* an unnecessary step, particularly when you want to delete some parsed files or an entire knowledge base but retain the original files.
|
||||
:::
|
||||
|
||||
## Create folder
|
||||
|
||||
|
@ -35,7 +39,7 @@ RAGFlow's file management supports previewing files in the following formats:
|
|||
|
||||
## Link file to knowledge bases
|
||||
|
||||
RAGFlow's file management allows you to *link* an uploaded file to multiple knowledge bases, creating a file reference in each target knowledge base. Therefore, deleting a file in your file management will AUTOMATICALLY REMOVE all related file references across the knowledge bases.
|
||||
RAGFlow's file management allows you to *link* an uploaded file to multiple knowledge bases, creating a file reference in each target knowledge base. Therefore, deleting a file in your file management will AUTOMATICALLY REMOVE all related file references across the knowledge bases.
|
||||
|
||||

|
||||
|
||||
|
@ -81,4 +85,4 @@ RAGFlow's file management allows you to download an uploaded file:
|
|||
|
||||

|
||||
|
||||
> As of RAGFlow v0.17.0, bulk download is not supported, nor can you download an entire folder.
|
||||
> As of RAGFlow v0.17.2, bulk download is not supported, nor can you download an entire folder.
|
||||
|
|
|
@ -3,7 +3,7 @@ sidebar_position: 4
|
|||
slug: /manage_team_members
|
||||
---
|
||||
|
||||
# Manage team members
|
||||
# Team
|
||||
|
||||
Invite or remove team members, join or leave a team.
|
||||
|
||||
|
@ -11,8 +11,9 @@ Invite or remove team members, join or leave a team.
|
|||
|
||||
By default, each RAGFlow user is assigned a single team named after their name. RAGFlow allows you to invite RAGFlow users to your team. Your team members can help you:
|
||||
|
||||
- Upload documents to your datasets.
|
||||
- Upload documents to your datasets (knowledge bases).
|
||||
- Update document configurations in your datasets.
|
||||
- Update the default configurations for your datasets.
|
||||
- Parse documents in your datasets.
|
||||
|
||||
:::tip NOTE
|
||||
|
|
|
@ -3,7 +3,7 @@ sidebar_position: 2
|
|||
slug: /deploy_local_llm
|
||||
---
|
||||
|
||||
# Deploy a local LLM
|
||||
# Deploy LLM locally
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
|
@ -25,7 +25,6 @@ This user guide does not intend to cover much of the installation or configurati
|
|||
|
||||
:::note
|
||||
- For information about downloading Ollama, see [here](https://github.com/ollama/ollama?tab=readme-ov-file#ollama).
|
||||
- For information about configuring Ollama server, see [here](https://github.com/ollama/ollama/blob/main/docs/faq.md#how-do-i-configure-ollama-server).
|
||||
- For a complete list of supported models and variants, see the [Ollama model library](https://ollama.com/library).
|
||||
:::
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ sidebar_position: 11
|
|||
slug: /upgrade_ragflow
|
||||
---
|
||||
|
||||
# Upgrade RAGFlow
|
||||
# Upgrade
|
||||
import Tabs from '@theme/Tabs';
|
||||
import TabItem from '@theme/TabItem';
|
||||
|
||||
|
@ -62,16 +62,16 @@ To upgrade RAGFlow, you must upgrade **both** your code **and** your Docker imag
|
|||
git clone https://github.com/infiniflow/ragflow.git
|
||||
```
|
||||
|
||||
2. Switch to the latest, officially published release, e.g., `v0.17.0`:
|
||||
2. Switch to the latest, officially published release, e.g., `v0.17.2`:
|
||||
|
||||
```bash
|
||||
git checkout -f v0.17.0
|
||||
git checkout -f v0.17.2
|
||||
```
|
||||
|
||||
3. Update **ragflow/docker/.env** as follows:
|
||||
|
||||
```bash
|
||||
RAGFLOW_IMAGE=infiniflow/ragflow:v0.17.0
|
||||
RAGFLOW_IMAGE=infiniflow/ragflow:v0.17.2
|
||||
```
|
||||
|
||||
4. Update the RAGFlow image and restart RAGFlow:
|
||||
|
|
|
@ -20,7 +20,7 @@ This quick start guide describes a general process from:
|
|||
:::danger IMPORTANT
|
||||
We officially support x86 CPU and Nvidia GPU, and this document offers instructions on deploying RAGFlow using Docker on x86 platforms. While we also test RAGFlow on ARM64 platforms, we do not maintain RAGFlow Docker images for ARM.
|
||||
|
||||
If you are on an ARM platform, follow [this guide](https://ragflow.io/docs/dev/build_docker_image) to build a RAGFlow Docker image.
|
||||
If you are on an ARM platform, follow [this guide](./develop/build_docker_image.mdx) to build a RAGFlow Docker image.
|
||||
:::
|
||||
|
||||
## Prerequisites
|
||||
|
@ -39,7 +39,7 @@ This section provides instructions on setting up the RAGFlow server on Linux. If
|
|||
|
||||
`vm.max_map_count`. This value sets the maximum number of memory map areas a process may have. Its default value is 65530. While most applications require fewer than a thousand maps, reducing this value can result in abnormal behaviors, and the system will throw out-of-memory errors when a process reaches the limitation.
|
||||
|
||||
RAGFlow v0.17.0 uses Elasticsearch or [Infinity](https://github.com/infiniflow/infinity) for multiple recall. Setting the value of `vm.max_map_count` correctly is crucial to the proper functioning of the Elasticsearch component.
|
||||
RAGFlow v0.17.2 uses Elasticsearch or [Infinity](https://github.com/infiniflow/infinity) for multiple recall. Setting the value of `vm.max_map_count` correctly is crucial to the proper functioning of the Elasticsearch component.
|
||||
|
||||
<Tabs
|
||||
defaultValue="linux"
|
||||
|
@ -179,13 +179,13 @@ This section provides instructions on setting up the RAGFlow server on Linux. If
|
|||
```bash
|
||||
$ git clone https://github.com/infiniflow/ragflow.git
|
||||
$ cd ragflow/docker
|
||||
$ git checkout -f v0.17.0
|
||||
$ git checkout -f v0.17.2
|
||||
```
|
||||
|
||||
3. Use the pre-built Docker images and start up the server:
|
||||
|
||||
:::tip NOTE
|
||||
The command below downloads the `v0.17.0-slim` edition of the RAGFlow Docker image. Refer to the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.17.0-slim`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server. For example: set `RAGFLOW_IMAGE=infiniflow/ragflow:v0.17.0` for the full edition `v0.17.0`.
|
||||
The command below downloads the `v0.17.2-slim` edition of the RAGFlow Docker image. Refer to the following table for descriptions of different RAGFlow editions. To download a RAGFlow edition different from `v0.17.2-slim`, update the `RAGFLOW_IMAGE` variable accordingly in **docker/.env** before using `docker compose` to start the server. For example: set `RAGFLOW_IMAGE=infiniflow/ragflow:v0.17.2` for the full edition `v0.17.2`.
|
||||
:::
|
||||
|
||||
```bash
|
||||
|
@ -198,8 +198,8 @@ This section provides instructions on setting up the RAGFlow server on Linux. If
|
|||
|
||||
| RAGFlow image tag | Image size (GB) | Has embedding models and Python packages? | Stable? |
|
||||
| ------------------- | --------------- | ----------------------------------------- | ------------------------ |
|
||||
| `v0.17.0` | ≈9 | :heavy_check_mark: | Stable release |
|
||||
| `v0.17.0-slim` | ≈2 | ❌ | Stable release |
|
||||
| `v0.17.2` | ≈9 | :heavy_check_mark: | Stable release |
|
||||
| `v0.17.2-slim` | ≈2 | ❌ | Stable release |
|
||||
| `nightly` | ≈9 | :heavy_check_mark: | *Unstable* nightly build |
|
||||
| `nightly-slim` | ≈2 | ❌ | *Unstable* nightly build |
|
||||
|
||||
|
@ -356,7 +356,7 @@ Conversations in RAGFlow are based on a particular knowledge base or multiple kn
|
|||
:::tip NOTE
|
||||
RAGFlow also offers HTTP and Python APIs for you to integrate RAGFlow's capabilities into your applications. Read the following documents for more information:
|
||||
|
||||
- [Acquire a RAGFlow API key](./guides/develop/acquire_ragflow_api_key.md)
|
||||
- [Acquire a RAGFlow API key](./develop/acquire_ragflow_api_key.md)
|
||||
- [HTTP API reference](./references/http_api_reference.md)
|
||||
- [Python API reference](./references/python_api_reference.md)
|
||||
:::
|
||||
|
|
|
@ -178,7 +178,6 @@ Creates a dataset.
|
|||
- `"name"`: `string`
|
||||
- `"avatar"`: `string`
|
||||
- `"description"`: `string`
|
||||
- `"language"`: `string`
|
||||
- `"embedding_model"`: `string`
|
||||
- `"permission"`: `string`
|
||||
- `"chunk_method"`: `string`
|
||||
|
@ -214,11 +213,6 @@ curl --request POST \
|
|||
- `"description"`: (*Body parameter*), `string`
|
||||
A brief description of the dataset to create.
|
||||
|
||||
- `"language"`: (*Body parameter*), `string`
|
||||
The language setting of the dataset to create. Available options:
|
||||
- `"English"` (default)
|
||||
- `"Chinese"`
|
||||
|
||||
- `"embedding_model"`: (*Body parameter*), `string`
|
||||
The name of the embedding model to use. For example: `"BAAI/bge-zh-v1.5"`
|
||||
|
||||
|
@ -634,6 +628,7 @@ Updates configurations for a specified document.
|
|||
- `'Authorization: Bearer <YOUR_API_KEY>'`
|
||||
- Body:
|
||||
- `"name"`:`string`
|
||||
- `"meta_fields"`:`object`
|
||||
- `"chunk_method"`:`string`
|
||||
- `"parser_config"`:`object`
|
||||
|
||||
|
@ -660,6 +655,7 @@ curl --request PUT \
|
|||
- `document_id`: (*Path parameter*)
|
||||
The ID of the document to update.
|
||||
- `"name"`: (*Body parameter*), `string`
|
||||
- `"meta_fields"`: (*Body parameter*), `dict[str, Any]` The meta fields of the document.
|
||||
- `"chunk_method"`: (*Body parameter*), `string`
|
||||
The parsing method to apply to the document:
|
||||
- `"naive"`: General
|
||||
|
@ -672,8 +668,6 @@ curl --request PUT \
|
|||
- `"presentation"`: Presentation
|
||||
- `"picture"`: Picture
|
||||
- `"one"`: One
|
||||
- `"knowledge_graph"`: Knowledge Graph
|
||||
Ensure your LLM is properly configured on the **Settings** page before selecting this. Please also note that Knowledge Graph consumes a large number of Tokens!
|
||||
- `"email"`: Email
|
||||
- `"parser_config"`: (*Body parameter*), `object`
|
||||
The configuration settings for the dataset parser. The attributes in this JSON object vary with the selected `"chunk_method"`:
|
||||
|
@ -2519,6 +2513,7 @@ Asks a specified agent a question to start an AI-powered conversation.
|
|||
- `"stream"`: `boolean`
|
||||
- `"session_id"`: `string`
|
||||
- `"user_id"`: `string`(optional)
|
||||
- `"sync_dsl"`: `boolean` (optional)
|
||||
- other parameters: `string`
|
||||
##### Request example
|
||||
If the **Begin** component does not take parameters, the following code will create a session.
|
||||
|
@ -2571,6 +2566,8 @@ curl --request POST \
|
|||
The ID of the session. If it is not provided, a new session will be generated.
|
||||
- `"user_id"`: (*Body parameter*), `string`
|
||||
The optional user-defined ID. Valid *only* when no `session_id` is provided.
|
||||
- `"sync_dsl"`: (*Body parameter*), `boolean`
|
||||
Whether to synchronize the changes to existing sessions when an agent is modified, defaults to `false`.
|
||||
- Other parameters: (*Body Parameter*)
|
||||
Parameters specified in the **Begin** component.
|
||||
|
||||
|
@ -2722,7 +2719,7 @@ Failure:
|
|||
|
||||
### List agent sessions
|
||||
|
||||
**GET** `/api/v1/agents/{agent_id}/sessions?page={page}&page_size={page_size}&orderby={orderby}&desc={desc}&id={session_id}&user_id={user_id}`
|
||||
**GET** `/api/v1/agents/{agent_id}/sessions?page={page}&page_size={page_size}&orderby={orderby}&desc={desc}&id={session_id}&user_id={user_id}&dsl={dsl}`
|
||||
|
||||
Lists sessions associated with a specified agent.
|
||||
|
||||
|
@ -2759,7 +2756,9 @@ curl --request GET \
|
|||
The ID of the agent session to retrieve.
|
||||
- `user_id`: (*Filter parameter*), `string`
|
||||
The optional user-defined ID passed in when creating session.
|
||||
|
||||
- `dsl`: (*Filter parameter*), `boolean`
|
||||
Indicates whether to include the dsl field of the sessions in the response. Defaults to `true`.
|
||||
|
||||
#### Response
|
||||
|
||||
Success:
|
||||
|
@ -2767,7 +2766,7 @@ Success:
|
|||
```json
|
||||
{
|
||||
"code": 0,
|
||||
"data": {
|
||||
"data": [{
|
||||
"agent_id": "e9e2b9c2b2f911ef801d0242ac120006",
|
||||
"dsl": {
|
||||
"answer": [],
|
||||
|
@ -2899,7 +2898,7 @@ Success:
|
|||
],
|
||||
"source": "agent",
|
||||
"user_id": ""
|
||||
}
|
||||
}]
|
||||
}
|
||||
```
|
||||
|
||||
|
@ -2914,6 +2913,62 @@ Failure:
|
|||
|
||||
---
|
||||
|
||||
### Delete agent's sessions
|
||||
|
||||
**DELETE** `/api/v1/agents/{agent_id}/sessions`
|
||||
|
||||
Deletes sessions of a agent by ID.
|
||||
|
||||
#### Request
|
||||
|
||||
- Method: DELETE
|
||||
- URL: `/api/v1/agents/{agent_id}/sessions`
|
||||
- Headers:
|
||||
- `'content-Type: application/json'`
|
||||
- `'Authorization: Bearer <YOUR_API_KEY>'`
|
||||
- Body:
|
||||
- `"ids"`: `list[string]`
|
||||
|
||||
##### Request example
|
||||
|
||||
```bash
|
||||
curl --request DELETE \
|
||||
--url http://{address}/api/v1/agents/{agent_id}/sessions \
|
||||
--header 'Content-Type: application/json' \
|
||||
--header 'Authorization: Bearer <YOUR_API_KEY>' \
|
||||
--data '
|
||||
{
|
||||
"ids": ["test_1", "test_2"]
|
||||
}'
|
||||
```
|
||||
|
||||
##### Request Parameters
|
||||
|
||||
- `agent_id`: (*Path parameter*)
|
||||
The ID of the associated agent.
|
||||
- `"ids"`: (*Body Parameter*), `list[string]`
|
||||
The IDs of the sessions to delete. If it is not specified, all sessions associated with the specified agent will be deleted.
|
||||
|
||||
#### Response
|
||||
|
||||
Success:
|
||||
|
||||
```json
|
||||
{
|
||||
"code": 0
|
||||
}
|
||||
```
|
||||
|
||||
Failure:
|
||||
|
||||
```json
|
||||
{
|
||||
"code": 102,
|
||||
"message": "The agent doesn't own the session cbd31e52f73911ef93b232903b842af6"
|
||||
}
|
||||
```
|
||||
---
|
||||
|
||||
## AGENT MANAGEMENT
|
||||
|
||||
---
|
||||
|
|
|
@ -82,7 +82,6 @@ RAGFlow.create_dataset(
|
|||
avatar: str = "",
|
||||
description: str = "",
|
||||
embedding_model: str = "BAAI/bge-large-zh-v1.5",
|
||||
language: str = "English",
|
||||
permission: str = "me",
|
||||
chunk_method: str = "naive",
|
||||
parser_config: DataSet.ParserConfig = None
|
||||
|
@ -97,11 +96,6 @@ Creates a dataset.
|
|||
|
||||
The unique name of the dataset to create. It must adhere to the following requirements:
|
||||
|
||||
- Permitted characters include:
|
||||
- English letters (a-z, A-Z)
|
||||
- Digits (0-9)
|
||||
- "_" (underscore)
|
||||
- Must begin with an English letter or underscore.
|
||||
- Maximum 65,535 characters.
|
||||
- Case-insensitive.
|
||||
|
||||
|
@ -113,12 +107,6 @@ Base64 encoding of the avatar. Defaults to `""`
|
|||
|
||||
A brief description of the dataset to create. Defaults to `""`.
|
||||
|
||||
##### language: `str`
|
||||
|
||||
The language setting of the dataset to create. Available options:
|
||||
|
||||
- `"English"` (default)
|
||||
- `"Chinese"`
|
||||
|
||||
##### permission
|
||||
|
||||
|
@ -313,9 +301,6 @@ A dictionary representing the attributes to update, with the following keys:
|
|||
- `"picture"`: Picture
|
||||
- `"one"`: One
|
||||
- `"email"`: Email
|
||||
- `"knowledge_graph"`: Knowledge Graph
|
||||
Ensure your LLM is properly configured on the **Settings** page before selecting this. Please also note that Knowledge Graph consumes a large number of Tokens!
|
||||
- `"meta_fields"`: `dict[str, Any]` The meta fields of the dataset.
|
||||
|
||||
#### Returns
|
||||
|
||||
|
@ -384,6 +369,7 @@ Updates configurations for the current document.
|
|||
A dictionary representing the attributes to update, with the following keys:
|
||||
|
||||
- `"display_name"`: `str` The name of the document to update.
|
||||
- `"meta_fields"`: `dict[str, Any]` The meta fields of the document.
|
||||
- `"chunk_method"`: `str` The parsing method to apply to the document.
|
||||
- `"naive"`: General
|
||||
- `"manual`: Manual
|
||||
|
@ -1460,21 +1446,13 @@ while True:
|
|||
### Create session with agent
|
||||
|
||||
```python
|
||||
Agent.create_session(id,rag, **kwargs) -> Session
|
||||
Agent.create_session(**kwargs) -> Session
|
||||
```
|
||||
|
||||
Creates a session with the current agent.
|
||||
|
||||
#### Parameters
|
||||
|
||||
##### id: `str`, *Required*
|
||||
|
||||
The id of agent
|
||||
|
||||
##### rag:`RAGFlow object`
|
||||
|
||||
The RAGFlow object
|
||||
|
||||
##### **kwargs
|
||||
|
||||
The parameters in `begin` component.
|
||||
|
@ -1494,7 +1472,8 @@ from ragflow_sdk import RAGFlow, Agent
|
|||
|
||||
rag_object = RAGFlow(api_key="<YOUR_API_KEY>", base_url="http://<YOUR_BASE_URL>:9380")
|
||||
AGENT_ID = "AGENT_ID"
|
||||
session = Agent.create_session(AGENT_ID, rag_object)
|
||||
agent = rag_object.list_agents(id = AGENT_id)[0]
|
||||
session = agent.create_session()
|
||||
```
|
||||
|
||||
---
|
||||
|
@ -1571,7 +1550,8 @@ from ragflow_sdk import RAGFlow, Agent
|
|||
|
||||
rag_object = RAGFlow(api_key="<YOUR_API_KEY>", base_url="http://<YOUR_BASE_URL>:9380")
|
||||
AGENT_id = "AGENT_ID"
|
||||
session = Agent.create_session(AGENT_id, rag_object)
|
||||
agent = rag_object.list_agents(id = AGENT_id)[0]
|
||||
session = agent.create_session()
|
||||
|
||||
print("\n===== Miss R ====\n")
|
||||
print("Hello. What can I do for you?")
|
||||
|
@ -1592,8 +1572,6 @@ while True:
|
|||
|
||||
```python
|
||||
Agent.list_sessions(
|
||||
agent_id,
|
||||
rag
|
||||
page: int = 1,
|
||||
page_size: int = 30,
|
||||
orderby: str = "update_time",
|
||||
|
@ -1640,11 +1618,42 @@ The ID of the agent session to retrieve. Defaults to `None`.
|
|||
from ragflow_sdk import RAGFlow
|
||||
|
||||
rag_object = RAGFlow(api_key="<YOUR_API_KEY>", base_url="http://<YOUR_BASE_URL>:9380")
|
||||
agent_id = "2710f2269b4611ef8fdf0242ac120006"
|
||||
sessions=Agent.list_sessions(agent_id,rag_object)
|
||||
AGENT_id = "AGENT_ID"
|
||||
agent = rag_object.list_agents(id = AGENT_id)[0]
|
||||
sessons = agent.list_sessions()
|
||||
for session in sessions:
|
||||
print(session)
|
||||
```
|
||||
---
|
||||
### Delete agent's sessions
|
||||
|
||||
```python
|
||||
Agent.delete_sessions(ids: list[str] = None)
|
||||
```
|
||||
|
||||
Deletes sessions of a agent by ID.
|
||||
|
||||
#### Parameters
|
||||
|
||||
##### ids: `list[str]`
|
||||
|
||||
The IDs of the sessions to delete. Defaults to `None`. If it is not specified, all sessions associated with the agent will be deleted.
|
||||
|
||||
#### Returns
|
||||
|
||||
- Success: No value is returned.
|
||||
- Failure: `Exception`
|
||||
|
||||
#### Examples
|
||||
|
||||
```python
|
||||
from ragflow_sdk import RAGFlow
|
||||
|
||||
rag_object = RAGFlow(api_key="<YOUR_API_KEY>", base_url="http://<YOUR_BASE_URL>:9380")
|
||||
AGENT_id = "AGENT_ID"
|
||||
agent = rag_object.list_agents(id = AGENT_id)[0]
|
||||
agent.delete_sessions(ids=["id_1","id_2"])
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
|
|
|
@ -42,7 +42,6 @@ A complete list of models supported by RAGFlow, which will continue to expand.
|
|||
| Ollama | :heavy_check_mark: | :heavy_check_mark: | | :heavy_check_mark: | | |
|
||||
| OpenAI | :heavy_check_mark: | :heavy_check_mark: | | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| OpenAI-API-Compatible | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | |
|
||||
| VLLM | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | |
|
||||
| OpenRouter | :heavy_check_mark: | | | :heavy_check_mark: | | |
|
||||
| PerfXCloud | :heavy_check_mark: | :heavy_check_mark: | | | | |
|
||||
| Replicate | :heavy_check_mark: | :heavy_check_mark: | | | | |
|
||||
|
@ -54,6 +53,7 @@ A complete list of models supported by RAGFlow, which will continue to expand.
|
|||
| TogetherAI | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | |
|
||||
| Tongyi-Qianwen | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
| Upstage | :heavy_check_mark: | :heavy_check_mark: | | | | |
|
||||
| VLLM | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | |
|
||||
| VolcEngine | :heavy_check_mark: | | | | | |
|
||||
| Voyage AI | | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | |
|
||||
| Xinference | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
|
||||
|
|
|
@ -7,6 +7,100 @@ slug: /release_notes
|
|||
|
||||
Key features, improvements and bug fixes in the latest releases.
|
||||
|
||||
## v0.17.2
|
||||
|
||||
Released on March 13, 2025.
|
||||
|
||||
### Improvements
|
||||
|
||||
- Adds OpenAI-compatible APIs.
|
||||
- Introduces a German user interface.
|
||||
- Accelerates knowledge graph extraction.
|
||||
- Enables Tavily-based web search in the **Retrieval** agent component.
|
||||
- Adds Tongyi-Qianwen QwQ models (OpenAI-compatible).
|
||||
- Supports CSV files in the **General** chunk method.
|
||||
|
||||
### Fixed issues
|
||||
|
||||
- Unable to add models via Ollama/Xinference, an issue introduced in v0.17.1.
|
||||
|
||||
### Related APIs
|
||||
|
||||
#### HTTP APIs
|
||||
|
||||
[Create chat completion](./references/http_api_reference.md#openai-compatible-api)
|
||||
|
||||
#### Python APIs
|
||||
|
||||
[Create chat completion](./references/python_api_reference.md#openai-compatible-api)
|
||||
|
||||
## v0.17.1
|
||||
|
||||
Released on March 11, 2025.
|
||||
|
||||
### Improvements
|
||||
|
||||
- Improves English tokenization quality.
|
||||
- Improves the table extraction logic in Markdown document parsing.
|
||||
- Updates SiliconFlow's model list.
|
||||
- Supports parsing XLS files (Excel97~2003) with improved corresponding error handling.
|
||||
- Supports Huggingface rerank models.
|
||||
- Enables relative time expressions ("now", "yesterday", "last week", "next year", and more) in the **Rewrite** agent component.
|
||||
|
||||
### Fixed issues
|
||||
|
||||
- A repetitive knowledge graph extraction issue.
|
||||
- Issues with API calling.
|
||||
- Options in the **Document parser** dropdown are missing.
|
||||
- A Tavily web search issue.
|
||||
- Unable to preview diagrams or images in an AI chat.
|
||||
|
||||
### Documentation
|
||||
|
||||
#### Added documents
|
||||
|
||||
[Use tag set](./guides/dataset/use_tag_sets.md)
|
||||
|
||||
## v0.17.0
|
||||
|
||||
Released on March 3, 2025.
|
||||
|
||||
### New features
|
||||
|
||||
- AI chat: Implements Deep Research for agentic reasoning. To activate this, enable the **Reasoning** toggle under the **Prompt Engine** tab of your chat assistant dialogue.
|
||||
- AI chat: Leverages Tavily-based web search to enhance contexts in agentic reasoning. To activate this, enter the correct Tavily API key under the **Assistant Setting** tab of your chat assistant dialogue.
|
||||
- AI chat: Supports starting a chat without specifying knowledge bases.
|
||||
- AI chat: HTML files can also be previewed and referenced, in addition to PDF files.
|
||||
- Dataset: Adds a **Document parser** dropdown menu to dataset configurations. This includes a DeepDoc model option, which is time-consuming, a much faster **naive** option (plain text), which skips DLA (Document Layout Analysis), OCR (Optical Character Recognition), and TSR (Table Structure Recognition) tasks, and several currently *experimental* large model options.
|
||||
- Agent component: **(x)** or a forward slash `/` can be used to insert available keys (variables) in the system prompt field of the **Generate** or **Template** component.
|
||||
- Object storage: Supports using Aliyun OSS (Object Storage Service) as a file storage option.
|
||||
- Models: Updates the supported model list for Tongyi-Qianwen (Qwen), adding DeepSeek-specific models; adds ModelScope as a model provider.
|
||||
- APIs: Document metadata can be updated through an API.
|
||||
|
||||
The following diagram illustrates the workflow of RAGFlow's Deep Research:
|
||||
|
||||

|
||||
|
||||
The following is a screenshot of a conversation that integrates Deep Research:
|
||||
|
||||

|
||||
|
||||
### Related APIs
|
||||
|
||||
#### HTTP APIs
|
||||
|
||||
Adds a body parameter `"meta_fields"` to the [Update document](./references/http_api_reference.md#update-document) method.
|
||||
|
||||
#### Python APIs
|
||||
|
||||
Adds a key option `"meta_fields"` to the [Update document](./references/python_api_reference.md#update-document) method.
|
||||
|
||||
### Documentation
|
||||
|
||||
#### Added documents
|
||||
|
||||
[Run retrieval test](./guides/dataset/run_retrieval_test.md)
|
||||
|
||||
## v0.16.0
|
||||
|
||||
Released on February 6, 2025.
|
||||
|
@ -205,9 +299,9 @@ pip install ragflow-sdk==0.13.0
|
|||
|
||||
#### Added documents
|
||||
|
||||
- [Acquire a RAGFlow API key](https://ragflow.io/docs/dev/acquire_ragflow_api_key)
|
||||
- [HTTP API Reference](https://ragflow.io/docs/dev/http_api_reference)
|
||||
- [Python API Reference](https://ragflow.io/docs/dev/python_api_reference)
|
||||
- [Acquire a RAGFlow API key](./develop/acquire_ragflow_api_key.md)
|
||||
- [HTTP API Reference](./references/http_api_reference.md)
|
||||
- [Python API Reference](./references/python_api_reference.md)
|
||||
|
||||
## v0.12.0
|
||||
|
||||
|
@ -315,7 +409,7 @@ Released on May 31, 2024.
|
|||
:::danger IMPORTANT
|
||||
While we also test RAGFlow on ARM64 platforms, we do not maintain RAGFlow Docker images for ARM.
|
||||
|
||||
If you are on an ARM platform, follow [this guide](https://ragflow.io/docs/dev/build_docker_image) to build a RAGFlow Docker image.
|
||||
If you are on an ARM platform, follow [this guide](./develop/build_docker_image.mdx) to build a RAGFlow Docker image.
|
||||
:::
|
||||
|
||||
### Related APIs
|
||||
|
|
|
@ -21,13 +21,14 @@ from dataclasses import dataclass
|
|||
from typing import Any, Callable
|
||||
|
||||
import networkx as nx
|
||||
import trio
|
||||
|
||||
from graphrag.general.extractor import Extractor
|
||||
from rag.nlp import is_english
|
||||
import editdistance
|
||||
from graphrag.entity_resolution_prompt import ENTITY_RESOLUTION_PROMPT
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
from graphrag.utils import perform_variable_replacements
|
||||
from graphrag.utils import perform_variable_replacements, chat_limiter
|
||||
|
||||
DEFAULT_RECORD_DELIMITER = "##"
|
||||
DEFAULT_ENTITY_INDEX_DELIMITER = "<|>"
|
||||
|
@ -67,13 +68,13 @@ class EntityResolution(Extractor):
|
|||
self._resolution_result_delimiter_key = "resolution_result_delimiter"
|
||||
self._input_text_key = "input_text"
|
||||
|
||||
def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None) -> EntityResolutionResult:
|
||||
async def __call__(self, graph: nx.Graph, prompt_variables: dict[str, Any] | None = None, callback: Callable | None = None) -> EntityResolutionResult:
|
||||
"""Call method definition."""
|
||||
if prompt_variables is None:
|
||||
prompt_variables = {}
|
||||
|
||||
# Wire defaults into the prompt variables
|
||||
prompt_variables = {
|
||||
self.prompt_variables = {
|
||||
**prompt_variables,
|
||||
self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
|
||||
or DEFAULT_RECORD_DELIMITER,
|
||||
|
@ -93,85 +94,97 @@ class EntityResolution(Extractor):
|
|||
candidate_resolution = {entity_type: [] for entity_type in entity_types}
|
||||
for k, v in node_clusters.items():
|
||||
candidate_resolution[k] = [(a, b) for a, b in itertools.combinations(v, 2) if self.is_similarity(a, b)]
|
||||
num_candidates = sum([len(candidates) for _, candidates in candidate_resolution.items()])
|
||||
callback(msg=f"Identified {num_candidates} candidate pairs")
|
||||
|
||||
gen_conf = {"temperature": 0.5}
|
||||
resolution_result = set()
|
||||
for candidate_resolution_i in candidate_resolution.items():
|
||||
if candidate_resolution_i[1]:
|
||||
try:
|
||||
pair_txt = [
|
||||
f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
|
||||
for index, candidate in enumerate(candidate_resolution_i[1]):
|
||||
pair_txt.append(
|
||||
f'Question {index + 1}: name of{candidate_resolution_i[0]} A is {candidate[0]} ,name of{candidate_resolution_i[0]} B is {candidate[1]}')
|
||||
sent = 'question above' if len(pair_txt) == 1 else f'above {len(pair_txt)} questions'
|
||||
pair_txt.append(
|
||||
f'\nUse domain knowledge of {candidate_resolution_i[0]}s to help understand the text and answer the {sent} in the format: For Question i, Yes, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are the same {candidate_resolution_i[0]}./No, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are different {candidate_resolution_i[0]}s. For Question i+1, (repeat the above procedures)')
|
||||
pair_prompt = '\n'.join(pair_txt)
|
||||
|
||||
variables = {
|
||||
**prompt_variables,
|
||||
self._input_text_key: pair_prompt
|
||||
}
|
||||
text = perform_variable_replacements(self._resolution_prompt, variables=variables)
|
||||
|
||||
response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
|
||||
result = self._process_results(len(candidate_resolution_i[1]), response,
|
||||
prompt_variables.get(self._record_delimiter_key,
|
||||
DEFAULT_RECORD_DELIMITER),
|
||||
prompt_variables.get(self._entity_index_dilimiter_key,
|
||||
DEFAULT_ENTITY_INDEX_DELIMITER),
|
||||
prompt_variables.get(self._resolution_result_delimiter_key,
|
||||
DEFAULT_RESOLUTION_RESULT_DELIMITER))
|
||||
for result_i in result:
|
||||
resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
|
||||
except Exception:
|
||||
logging.exception("error entity resolution")
|
||||
async with trio.open_nursery() as nursery:
|
||||
for candidate_resolution_i in candidate_resolution.items():
|
||||
if not candidate_resolution_i[1]:
|
||||
continue
|
||||
nursery.start_soon(lambda: self._resolve_candidate(candidate_resolution_i, resolution_result))
|
||||
callback(msg=f"Resolved {num_candidates} candidate pairs, {len(resolution_result)} of them are selected to merge.")
|
||||
|
||||
connect_graph = nx.Graph()
|
||||
removed_entities = []
|
||||
connect_graph.add_edges_from(resolution_result)
|
||||
for sub_connect_graph in nx.connected_components(connect_graph):
|
||||
sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
|
||||
remove_nodes = list(sub_connect_graph.nodes)
|
||||
keep_node = remove_nodes.pop()
|
||||
self._merge_nodes(keep_node, self._get_entity_(remove_nodes))
|
||||
for remove_node in remove_nodes:
|
||||
removed_entities.append(remove_node)
|
||||
remove_node_neighbors = graph[remove_node]
|
||||
remove_node_neighbors = list(remove_node_neighbors)
|
||||
for remove_node_neighbor in remove_node_neighbors:
|
||||
rel = self._get_relation_(remove_node, remove_node_neighbor)
|
||||
if graph.has_edge(remove_node, remove_node_neighbor):
|
||||
graph.remove_edge(remove_node, remove_node_neighbor)
|
||||
if remove_node_neighbor == keep_node:
|
||||
if graph.has_edge(keep_node, remove_node):
|
||||
graph.remove_edge(keep_node, remove_node)
|
||||
continue
|
||||
if not rel:
|
||||
continue
|
||||
if graph.has_edge(keep_node, remove_node_neighbor):
|
||||
self._merge_edges(keep_node, remove_node_neighbor, [rel])
|
||||
else:
|
||||
pair = sorted([keep_node, remove_node_neighbor])
|
||||
graph.add_edge(pair[0], pair[1], weight=rel['weight'])
|
||||
self._set_relation_(pair[0], pair[1],
|
||||
dict(
|
||||
src_id=pair[0],
|
||||
tgt_id=pair[1],
|
||||
weight=rel['weight'],
|
||||
description=rel['description'],
|
||||
keywords=[],
|
||||
source_id=rel.get("source_id", ""),
|
||||
metadata={"created_at": time.time()}
|
||||
))
|
||||
graph.remove_node(remove_node)
|
||||
all_entities_data = []
|
||||
all_relationships_data = []
|
||||
all_remove_nodes = []
|
||||
|
||||
async with trio.open_nursery() as nursery:
|
||||
for sub_connect_graph in nx.connected_components(connect_graph):
|
||||
sub_connect_graph = connect_graph.subgraph(sub_connect_graph)
|
||||
remove_nodes = list(sub_connect_graph.nodes)
|
||||
keep_node = remove_nodes.pop()
|
||||
all_remove_nodes.append(remove_nodes)
|
||||
nursery.start_soon(lambda: self._merge_nodes(keep_node, self._get_entity_(remove_nodes), all_entities_data))
|
||||
for remove_node in remove_nodes:
|
||||
removed_entities.append(remove_node)
|
||||
remove_node_neighbors = graph[remove_node]
|
||||
remove_node_neighbors = list(remove_node_neighbors)
|
||||
for remove_node_neighbor in remove_node_neighbors:
|
||||
rel = self._get_relation_(remove_node, remove_node_neighbor)
|
||||
if graph.has_edge(remove_node, remove_node_neighbor):
|
||||
graph.remove_edge(remove_node, remove_node_neighbor)
|
||||
if remove_node_neighbor == keep_node:
|
||||
if graph.has_edge(keep_node, remove_node):
|
||||
graph.remove_edge(keep_node, remove_node)
|
||||
continue
|
||||
if not rel:
|
||||
continue
|
||||
if graph.has_edge(keep_node, remove_node_neighbor):
|
||||
nursery.start_soon(lambda: self._merge_edges(keep_node, remove_node_neighbor, [rel], all_relationships_data))
|
||||
else:
|
||||
pair = sorted([keep_node, remove_node_neighbor])
|
||||
graph.add_edge(pair[0], pair[1], weight=rel['weight'])
|
||||
self._set_relation_(pair[0], pair[1],
|
||||
dict(
|
||||
src_id=pair[0],
|
||||
tgt_id=pair[1],
|
||||
weight=rel['weight'],
|
||||
description=rel['description'],
|
||||
keywords=[],
|
||||
source_id=rel.get("source_id", ""),
|
||||
metadata={"created_at": time.time()}
|
||||
))
|
||||
graph.remove_node(remove_node)
|
||||
|
||||
return EntityResolutionResult(
|
||||
graph=graph,
|
||||
removed_entities=removed_entities
|
||||
)
|
||||
|
||||
async def _resolve_candidate(self, candidate_resolution_i, resolution_result):
|
||||
gen_conf = {"temperature": 0.5}
|
||||
pair_txt = [
|
||||
f'When determining whether two {candidate_resolution_i[0]}s are the same, you should only focus on critical properties and overlook noisy factors.\n']
|
||||
for index, candidate in enumerate(candidate_resolution_i[1]):
|
||||
pair_txt.append(
|
||||
f'Question {index + 1}: name of{candidate_resolution_i[0]} A is {candidate[0]} ,name of{candidate_resolution_i[0]} B is {candidate[1]}')
|
||||
sent = 'question above' if len(pair_txt) == 1 else f'above {len(pair_txt)} questions'
|
||||
pair_txt.append(
|
||||
f'\nUse domain knowledge of {candidate_resolution_i[0]}s to help understand the text and answer the {sent} in the format: For Question i, Yes, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are the same {candidate_resolution_i[0]}./No, {candidate_resolution_i[0]} A and {candidate_resolution_i[0]} B are different {candidate_resolution_i[0]}s. For Question i+1, (repeat the above procedures)')
|
||||
pair_prompt = '\n'.join(pair_txt)
|
||||
variables = {
|
||||
**self.prompt_variables,
|
||||
self._input_text_key: pair_prompt
|
||||
}
|
||||
text = perform_variable_replacements(self._resolution_prompt, variables=variables)
|
||||
logging.info(f"Created resolution prompt {len(text)} bytes for {len(candidate_resolution_i[1])} entity pairs of type {candidate_resolution_i[0]}")
|
||||
async with chat_limiter:
|
||||
response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
|
||||
logging.debug(f"_resolve_candidate chat prompt: {text}\nchat response: {response}")
|
||||
result = self._process_results(len(candidate_resolution_i[1]), response,
|
||||
self.prompt_variables.get(self._record_delimiter_key,
|
||||
DEFAULT_RECORD_DELIMITER),
|
||||
self.prompt_variables.get(self._entity_index_dilimiter_key,
|
||||
DEFAULT_ENTITY_INDEX_DELIMITER),
|
||||
self.prompt_variables.get(self._resolution_result_delimiter_key,
|
||||
DEFAULT_RESOLUTION_RESULT_DELIMITER))
|
||||
for result_i in result:
|
||||
resolution_result.add(candidate_resolution_i[1][result_i[0] - 1])
|
||||
|
||||
def _process_results(
|
||||
self,
|
||||
records_length: int,
|
||||
|
|
|
@ -1,268 +0,0 @@
|
|||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
"""
|
||||
Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
"""
|
||||
|
||||
import logging
|
||||
import argparse
|
||||
import json
|
||||
import re
|
||||
import traceback
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import tiktoken
|
||||
|
||||
from graphrag.general.claim_prompt import CLAIM_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
|
||||
from graphrag.general.extractor import Extractor
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
|
||||
|
||||
DEFAULT_TUPLE_DELIMITER = "<|>"
|
||||
DEFAULT_RECORD_DELIMITER = "##"
|
||||
DEFAULT_COMPLETION_DELIMITER = "<|COMPLETE|>"
|
||||
CLAIM_MAX_GLEANINGS = 1
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClaimExtractorResult:
|
||||
"""Claim extractor result class definition."""
|
||||
|
||||
output: list[dict]
|
||||
source_docs: dict[str, Any]
|
||||
|
||||
|
||||
class ClaimExtractor(Extractor):
|
||||
"""Claim extractor class definition."""
|
||||
|
||||
_extraction_prompt: str
|
||||
_summary_prompt: str
|
||||
_output_formatter_prompt: str
|
||||
_input_text_key: str
|
||||
_input_entity_spec_key: str
|
||||
_input_claim_description_key: str
|
||||
_tuple_delimiter_key: str
|
||||
_record_delimiter_key: str
|
||||
_completion_delimiter_key: str
|
||||
_max_gleanings: int
|
||||
_on_error: ErrorHandlerFn
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm_invoker: CompletionLLM,
|
||||
extraction_prompt: str | None = None,
|
||||
input_text_key: str | None = None,
|
||||
input_entity_spec_key: str | None = None,
|
||||
input_claim_description_key: str | None = None,
|
||||
input_resolved_entities_key: str | None = None,
|
||||
tuple_delimiter_key: str | None = None,
|
||||
record_delimiter_key: str | None = None,
|
||||
completion_delimiter_key: str | None = None,
|
||||
encoding_model: str | None = None,
|
||||
max_gleanings: int | None = None,
|
||||
on_error: ErrorHandlerFn | None = None,
|
||||
):
|
||||
"""Init method definition."""
|
||||
self._llm = llm_invoker
|
||||
self._extraction_prompt = extraction_prompt or CLAIM_EXTRACTION_PROMPT
|
||||
self._input_text_key = input_text_key or "input_text"
|
||||
self._input_entity_spec_key = input_entity_spec_key or "entity_specs"
|
||||
self._tuple_delimiter_key = tuple_delimiter_key or "tuple_delimiter"
|
||||
self._record_delimiter_key = record_delimiter_key or "record_delimiter"
|
||||
self._completion_delimiter_key = (
|
||||
completion_delimiter_key or "completion_delimiter"
|
||||
)
|
||||
self._input_claim_description_key = (
|
||||
input_claim_description_key or "claim_description"
|
||||
)
|
||||
self._input_resolved_entities_key = (
|
||||
input_resolved_entities_key or "resolved_entities"
|
||||
)
|
||||
self._max_gleanings = (
|
||||
max_gleanings if max_gleanings is not None else CLAIM_MAX_GLEANINGS
|
||||
)
|
||||
self._on_error = on_error or (lambda _e, _s, _d: None)
|
||||
|
||||
# Construct the looping arguments
|
||||
encoding = tiktoken.get_encoding(encoding_model or "cl100k_base")
|
||||
yes = encoding.encode("YES")
|
||||
no = encoding.encode("NO")
|
||||
self._loop_args = {"logit_bias": {yes[0]: 100, no[0]: 100}, "max_tokens": 1}
|
||||
|
||||
def __call__(
|
||||
self, inputs: dict[str, Any], prompt_variables: dict | None = None
|
||||
) -> ClaimExtractorResult:
|
||||
"""Call method definition."""
|
||||
if prompt_variables is None:
|
||||
prompt_variables = {}
|
||||
texts = inputs[self._input_text_key]
|
||||
entity_spec = str(inputs[self._input_entity_spec_key])
|
||||
claim_description = inputs[self._input_claim_description_key]
|
||||
resolved_entities = inputs.get(self._input_resolved_entities_key, {})
|
||||
source_doc_map = {}
|
||||
|
||||
prompt_args = {
|
||||
self._input_entity_spec_key: entity_spec,
|
||||
self._input_claim_description_key: claim_description,
|
||||
self._tuple_delimiter_key: prompt_variables.get(self._tuple_delimiter_key)
|
||||
or DEFAULT_TUPLE_DELIMITER,
|
||||
self._record_delimiter_key: prompt_variables.get(self._record_delimiter_key)
|
||||
or DEFAULT_RECORD_DELIMITER,
|
||||
self._completion_delimiter_key: prompt_variables.get(
|
||||
self._completion_delimiter_key
|
||||
)
|
||||
or DEFAULT_COMPLETION_DELIMITER,
|
||||
}
|
||||
|
||||
all_claims: list[dict] = []
|
||||
for doc_index, text in enumerate(texts):
|
||||
document_id = f"d{doc_index}"
|
||||
try:
|
||||
claims = self._process_document(prompt_args, text, doc_index)
|
||||
all_claims += [
|
||||
self._clean_claim(c, document_id, resolved_entities) for c in claims
|
||||
]
|
||||
source_doc_map[document_id] = text
|
||||
except Exception as e:
|
||||
logging.exception("error extracting claim")
|
||||
self._on_error(
|
||||
e,
|
||||
traceback.format_exc(),
|
||||
{"doc_index": doc_index, "text": text},
|
||||
)
|
||||
continue
|
||||
|
||||
return ClaimExtractorResult(
|
||||
output=all_claims,
|
||||
source_docs=source_doc_map,
|
||||
)
|
||||
|
||||
def _clean_claim(
|
||||
self, claim: dict, document_id: str, resolved_entities: dict
|
||||
) -> dict:
|
||||
# clean the parsed claims to remove any claims with status = False
|
||||
obj = claim.get("object_id", claim.get("object"))
|
||||
subject = claim.get("subject_id", claim.get("subject"))
|
||||
|
||||
# If subject or object in resolved entities, then replace with resolved entity
|
||||
obj = resolved_entities.get(obj, obj)
|
||||
subject = resolved_entities.get(subject, subject)
|
||||
claim["object_id"] = obj
|
||||
claim["subject_id"] = subject
|
||||
claim["doc_id"] = document_id
|
||||
return claim
|
||||
|
||||
def _process_document(
|
||||
self, prompt_args: dict, doc, doc_index: int
|
||||
) -> list[dict]:
|
||||
record_delimiter = prompt_args.get(
|
||||
self._record_delimiter_key, DEFAULT_RECORD_DELIMITER
|
||||
)
|
||||
completion_delimiter = prompt_args.get(
|
||||
self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER
|
||||
)
|
||||
variables = {
|
||||
self._input_text_key: doc,
|
||||
**prompt_args,
|
||||
}
|
||||
text = perform_variable_replacements(self._extraction_prompt, variables=variables)
|
||||
gen_conf = {"temperature": 0.5}
|
||||
results = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
|
||||
claims = results.strip().removesuffix(completion_delimiter)
|
||||
history = [{"role": "system", "content": text}, {"role": "assistant", "content": results}]
|
||||
|
||||
# Repeat to ensure we maximize entity count
|
||||
for i in range(self._max_gleanings):
|
||||
text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
|
||||
history.append({"role": "user", "content": text})
|
||||
extension = self._chat("", history, gen_conf)
|
||||
claims += record_delimiter + extension.strip().removesuffix(
|
||||
completion_delimiter
|
||||
)
|
||||
|
||||
# If this isn't the last loop, check to see if we should continue
|
||||
if i >= self._max_gleanings - 1:
|
||||
break
|
||||
|
||||
history.append({"role": "assistant", "content": extension})
|
||||
history.append({"role": "user", "content": LOOP_PROMPT})
|
||||
continuation = self._chat("", history, self._loop_args)
|
||||
if continuation != "YES":
|
||||
break
|
||||
|
||||
result = self._parse_claim_tuples(claims, prompt_args)
|
||||
for r in result:
|
||||
r["doc_id"] = f"{doc_index}"
|
||||
return result
|
||||
|
||||
def _parse_claim_tuples(
|
||||
self, claims: str, prompt_variables: dict
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Parse claim tuples."""
|
||||
record_delimiter = prompt_variables.get(
|
||||
self._record_delimiter_key, DEFAULT_RECORD_DELIMITER
|
||||
)
|
||||
completion_delimiter = prompt_variables.get(
|
||||
self._completion_delimiter_key, DEFAULT_COMPLETION_DELIMITER
|
||||
)
|
||||
tuple_delimiter = prompt_variables.get(
|
||||
self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER
|
||||
)
|
||||
|
||||
def pull_field(index: int, fields: list[str]) -> str | None:
|
||||
return fields[index].strip() if len(fields) > index else None
|
||||
|
||||
result: list[dict[str, Any]] = []
|
||||
claims_values = (
|
||||
claims.strip().removesuffix(completion_delimiter).split(record_delimiter)
|
||||
)
|
||||
for claim in claims_values:
|
||||
claim = claim.strip().removeprefix("(").removesuffix(")")
|
||||
claim = re.sub(r".*Output:", "", claim)
|
||||
|
||||
# Ignore the completion delimiter
|
||||
if claim == completion_delimiter:
|
||||
continue
|
||||
|
||||
claim_fields = claim.split(tuple_delimiter)
|
||||
o = {
|
||||
"subject_id": pull_field(0, claim_fields),
|
||||
"object_id": pull_field(1, claim_fields),
|
||||
"type": pull_field(2, claim_fields),
|
||||
"status": pull_field(3, claim_fields),
|
||||
"start_date": pull_field(4, claim_fields),
|
||||
"end_date": pull_field(5, claim_fields),
|
||||
"description": pull_field(6, claim_fields),
|
||||
"source_text": pull_field(7, claim_fields),
|
||||
"doc_id": pull_field(8, claim_fields),
|
||||
}
|
||||
if any([not o["subject_id"], not o["object_id"], o["subject_id"].lower() == "none", o["object_id"] == "none"]):
|
||||
continue
|
||||
result.append(o)
|
||||
return result
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
|
||||
parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
from api.db import LLMType
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api import settings
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
|
||||
kb_ids = KnowledgebaseService.get_kb_ids(args.tenant_id)
|
||||
|
||||
ex = ClaimExtractor(LLMBundle(args.tenant_id, LLMType.CHAT))
|
||||
docs = [d["content_with_weight"] for d in settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, kb_ids, max_count=12, fields=["content_with_weight"])]
|
||||
info = {
|
||||
"input_text": docs,
|
||||
"entity_specs": "organization, person",
|
||||
"claim_description": ""
|
||||
}
|
||||
claim = ex(info)
|
||||
logging.info(json.dumps(claim.output, ensure_ascii=False, indent=2))
|
|
@ -1,71 +0,0 @@
|
|||
# Copyright (c) 2024 Microsoft Corporation.
|
||||
# Licensed under the MIT License
|
||||
"""
|
||||
Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
"""
|
||||
|
||||
CLAIM_EXTRACTION_PROMPT = """
|
||||
################
|
||||
-Target activity-
|
||||
################
|
||||
You are an intelligent assistant that helps a human analyst to analyze claims against certain entities presented in a text document.
|
||||
|
||||
################
|
||||
-Goal-
|
||||
################
|
||||
Given a text document that is potentially relevant to this activity, an entity specification, and a claim description, extract all entities that match the entity specification and all claims against those entities.
|
||||
|
||||
################
|
||||
-Steps-
|
||||
################
|
||||
- 1. Extract all named entities that match the predefined entity specification. Entity specification can either be a list of entity names or a list of entity types.
|
||||
- 2. For each entity identified in step 1, extract all claims associated with the entity. Claims need to match the specified claim description, and the entity should be the subject of the claim.
|
||||
For each claim, extract the following information:
|
||||
- Subject: name of the entity that is subject of the claim, capitalized. The subject entity is one that committed the action described in the claim. Subject needs to be one of the named entities identified in step 1.
|
||||
- Object: name of the entity that is object of the claim, capitalized. The object entity is one that either reports/handles or is affected by the action described in the claim. If object entity is unknown, use **NONE**.
|
||||
- Claim Type: overall category of the claim, capitalized. Name it in a way that can be repeated across multiple text inputs, so that similar claims share the same claim type
|
||||
- Claim Status: **TRUE**, **FALSE**, or **SUSPECTED**. TRUE means the claim is confirmed, FALSE means the claim is found to be False, SUSPECTED means the claim is not verified.
|
||||
- Claim Description: Detailed description explaining the reasoning behind the claim, together with all the related evidence and references.
|
||||
- Claim Date: Period (start_date, end_date) when the claim was made. Both start_date and end_date should be in ISO-8601 format. If the claim was made on a single date rather than a date range, set the same date for both start_date and end_date. If date is unknown, return **NONE**.
|
||||
- Claim Source Text: List of **all** quotes from the original text that are relevant to the claim.
|
||||
|
||||
- 3. Format each claim as (<subject_entity>{tuple_delimiter}<object_entity>{tuple_delimiter}<claim_type>{tuple_delimiter}<claim_status>{tuple_delimiter}<claim_start_date>{tuple_delimiter}<claim_end_date>{tuple_delimiter}<claim_description>{tuple_delimiter}<claim_source>)
|
||||
- 4. Return output in language of the 'Text' as a single list of all the claims identified in steps 1 and 2. Use **{record_delimiter}** as the list delimiter.
|
||||
- 5. If there's nothing satisfy the above requirements, just keep output empty.
|
||||
- 6. When finished, output {completion_delimiter}
|
||||
|
||||
################
|
||||
-Examples-
|
||||
################
|
||||
Example 1:
|
||||
Entity specification: organization
|
||||
Claim description: red flags associated with an entity
|
||||
Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015.
|
||||
Output:
|
||||
(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.)
|
||||
{completion_delimiter}
|
||||
|
||||
###########################
|
||||
Example 2:
|
||||
Entity specification: Company A, Person C
|
||||
Claim description: red flags associated with an entity
|
||||
Text: According to an article on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B. The company is owned by Person C who was suspected of engaging in corruption activities in 2015.
|
||||
Output:
|
||||
(COMPANY A{tuple_delimiter}GOVERNMENT AGENCY B{tuple_delimiter}ANTI-COMPETITIVE PRACTICES{tuple_delimiter}TRUE{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}2022-01-10T00:00:00{tuple_delimiter}Company A was found to engage in anti-competitive practices because it was fined for bid rigging in multiple public tenders published by Government Agency B according to an article published on 2022/01/10{tuple_delimiter}According to an article published on 2022/01/10, Company A was fined for bid rigging while participating in multiple public tenders published by Government Agency B.)
|
||||
{record_delimiter}
|
||||
(PERSON C{tuple_delimiter}NONE{tuple_delimiter}CORRUPTION{tuple_delimiter}SUSPECTED{tuple_delimiter}2015-01-01T00:00:00{tuple_delimiter}2015-12-30T00:00:00{tuple_delimiter}Person C was suspected of engaging in corruption activities in 2015{tuple_delimiter}The company is owned by Person C who was suspected of engaging in corruption activities in 2015)
|
||||
{completion_delimiter}
|
||||
|
||||
################
|
||||
-Real Data-
|
||||
################
|
||||
Use the following input for your answer.
|
||||
Entity specification: {entity_specs}
|
||||
Claim description: {claim_description}
|
||||
Text: {input_text}
|
||||
Output:"""
|
||||
|
||||
|
||||
CONTINUE_PROMPT = "MANY entities were missed in the last extraction. Add them below using the same format(see 'Steps', start with the 'Output').\nOutput: "
|
||||
LOOP_PROMPT = "It appears some entities may have still been missed. Answer YES {tuple_delimiter} NO if there are still entities that need to be added.\n"
|
|
@ -17,9 +17,9 @@ from graphrag.general.community_report_prompt import COMMUNITY_REPORT_PROMPT
|
|||
from graphrag.general.extractor import Extractor
|
||||
from graphrag.general.leiden import add_community_info2graph
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types
|
||||
from graphrag.utils import perform_variable_replacements, dict_has_keys_with_types, chat_limiter
|
||||
from rag.utils import num_tokens_from_string
|
||||
from timeit import default_timer as timer
|
||||
import trio
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -52,7 +52,7 @@ class CommunityReportsExtractor(Extractor):
|
|||
self._extraction_prompt = COMMUNITY_REPORT_PROMPT
|
||||
self._max_report_length = max_report_length or 1500
|
||||
|
||||
def __call__(self, graph: nx.Graph, callback: Callable | None = None):
|
||||
async def __call__(self, graph: nx.Graph, callback: Callable | None = None):
|
||||
for node_degree in graph.degree:
|
||||
graph.nodes[str(node_degree[0])]["rank"] = int(node_degree[1])
|
||||
|
||||
|
@ -61,60 +61,69 @@ class CommunityReportsExtractor(Extractor):
|
|||
res_str = []
|
||||
res_dict = []
|
||||
over, token_count = 0, 0
|
||||
st = timer()
|
||||
for level, comm in communities.items():
|
||||
logging.info(f"Level {level}: Community: {len(comm.keys())}")
|
||||
for cm_id, ents in comm.items():
|
||||
weight = ents["weight"]
|
||||
ents = ents["nodes"]
|
||||
ent_df = pd.DataFrame(self._get_entity_(ents)).dropna()#[{"entity": n, **graph.nodes[n]} for n in ents])
|
||||
if ent_df.empty or "entity_name" not in ent_df.columns:
|
||||
continue
|
||||
ent_df["entity"] = ent_df["entity_name"]
|
||||
del ent_df["entity_name"]
|
||||
rela_df = pd.DataFrame(self._get_relation_(list(ent_df["entity"]), list(ent_df["entity"]), 10000))
|
||||
if rela_df.empty:
|
||||
continue
|
||||
rela_df["source"] = rela_df["src_id"]
|
||||
rela_df["target"] = rela_df["tgt_id"]
|
||||
del rela_df["src_id"]
|
||||
del rela_df["tgt_id"]
|
||||
async def extract_community_report(community):
|
||||
nonlocal res_str, res_dict, over, token_count
|
||||
cm_id, ents = community
|
||||
weight = ents["weight"]
|
||||
ents = ents["nodes"]
|
||||
ent_df = pd.DataFrame(self._get_entity_(ents)).dropna()
|
||||
if ent_df.empty or "entity_name" not in ent_df.columns:
|
||||
return
|
||||
ent_df["entity"] = ent_df["entity_name"]
|
||||
del ent_df["entity_name"]
|
||||
rela_df = pd.DataFrame(self._get_relation_(list(ent_df["entity"]), list(ent_df["entity"]), 10000))
|
||||
if rela_df.empty:
|
||||
return
|
||||
rela_df["source"] = rela_df["src_id"]
|
||||
rela_df["target"] = rela_df["tgt_id"]
|
||||
del rela_df["src_id"]
|
||||
del rela_df["tgt_id"]
|
||||
|
||||
prompt_variables = {
|
||||
"entity_df": ent_df.to_csv(index_label="id"),
|
||||
"relation_df": rela_df.to_csv(index_label="id")
|
||||
}
|
||||
text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
|
||||
gen_conf = {"temperature": 0.3}
|
||||
try:
|
||||
response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
|
||||
token_count += num_tokens_from_string(text + response)
|
||||
response = re.sub(r"^[^\{]*", "", response)
|
||||
response = re.sub(r"[^\}]*$", "", response)
|
||||
response = re.sub(r"\{\{", "{", response)
|
||||
response = re.sub(r"\}\}", "}", response)
|
||||
logging.debug(response)
|
||||
response = json.loads(response)
|
||||
if not dict_has_keys_with_types(response, [
|
||||
("title", str),
|
||||
("summary", str),
|
||||
("findings", list),
|
||||
("rating", float),
|
||||
("rating_explanation", str),
|
||||
]):
|
||||
continue
|
||||
response["weight"] = weight
|
||||
response["entities"] = ents
|
||||
except Exception:
|
||||
logging.exception("CommunityReportsExtractor got exception")
|
||||
continue
|
||||
prompt_variables = {
|
||||
"entity_df": ent_df.to_csv(index_label="id"),
|
||||
"relation_df": rela_df.to_csv(index_label="id")
|
||||
}
|
||||
text = perform_variable_replacements(self._extraction_prompt, variables=prompt_variables)
|
||||
gen_conf = {"temperature": 0.3}
|
||||
async with chat_limiter:
|
||||
response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
|
||||
token_count += num_tokens_from_string(text + response)
|
||||
response = re.sub(r"^[^\{]*", "", response)
|
||||
response = re.sub(r"[^\}]*$", "", response)
|
||||
response = re.sub(r"\{\{", "{", response)
|
||||
response = re.sub(r"\}\}", "}", response)
|
||||
logging.debug(response)
|
||||
try:
|
||||
response = json.loads(response)
|
||||
except json.JSONDecodeError as e:
|
||||
logging.error(f"Failed to parse JSON response: {e}")
|
||||
logging.error(f"Response content: {response}")
|
||||
return
|
||||
if not dict_has_keys_with_types(response, [
|
||||
("title", str),
|
||||
("summary", str),
|
||||
("findings", list),
|
||||
("rating", float),
|
||||
("rating_explanation", str),
|
||||
]):
|
||||
return
|
||||
response["weight"] = weight
|
||||
response["entities"] = ents
|
||||
add_community_info2graph(graph, ents, response["title"])
|
||||
res_str.append(self._get_text_output(response))
|
||||
res_dict.append(response)
|
||||
over += 1
|
||||
if callback:
|
||||
callback(msg=f"Communities: {over}/{total}, used tokens: {token_count}")
|
||||
|
||||
add_community_info2graph(graph, ents, response["title"])
|
||||
res_str.append(self._get_text_output(response))
|
||||
res_dict.append(response)
|
||||
over += 1
|
||||
if callback:
|
||||
callback(msg=f"Communities: {over}/{total}, elapsed: {timer() - st}s, used tokens: {token_count}")
|
||||
st = trio.current_time()
|
||||
async with trio.open_nursery() as nursery:
|
||||
for level, comm in communities.items():
|
||||
logging.info(f"Level {level}: Community: {len(comm.keys())}")
|
||||
for community in comm.items():
|
||||
nursery.start_soon(lambda: extract_community_report(community))
|
||||
if callback:
|
||||
callback(msg=f"Community reports done in {trio.current_time() - st:.2f}s, used tokens: {token_count}")
|
||||
|
||||
return CommunityReportsResult(
|
||||
structured_output=res_dict,
|
||||
|
|
|
@ -14,17 +14,17 @@
|
|||
# limitations under the License.
|
||||
#
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from collections import defaultdict, Counter
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from copy import deepcopy
|
||||
from typing import Callable
|
||||
import trio
|
||||
|
||||
from graphrag.general.graph_prompt import SUMMARIZE_DESCRIPTIONS_PROMPT
|
||||
from graphrag.utils import get_llm_cache, set_llm_cache, handle_single_entity_extraction, \
|
||||
handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list
|
||||
handle_single_relationship_extraction, split_string_by_multi_markers, flat_uniq_list, chat_limiter
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
from rag.prompts import message_fit_in
|
||||
from rag.utils import truncate
|
||||
|
||||
GRAPH_FIELD_SEP = "<SEP>"
|
||||
|
@ -59,7 +59,8 @@ class Extractor:
|
|||
response = get_llm_cache(self._llm.llm_name, system, hist, conf)
|
||||
if response:
|
||||
return response
|
||||
response = self._llm.chat(system, hist, conf)
|
||||
_, system_msg = message_fit_in([{"role": "system", "content": system}], int(self._llm.max_length * 0.97))
|
||||
response = self._llm.chat(system_msg[0]["content"], hist, conf)
|
||||
response = re.sub(r"<think>.*</think>", "", response, flags=re.DOTALL)
|
||||
if response.find("**ERROR**") >= 0:
|
||||
raise Exception(response)
|
||||
|
@ -91,54 +92,50 @@ class Extractor:
|
|||
)
|
||||
return dict(maybe_nodes), dict(maybe_edges)
|
||||
|
||||
def __call__(
|
||||
self, chunks: list[tuple[str, str]],
|
||||
async def __call__(
|
||||
self, doc_id: str, chunks: list[str],
|
||||
callback: Callable | None = None
|
||||
):
|
||||
|
||||
results = []
|
||||
max_workers = int(os.environ.get('GRAPH_EXTRACTOR_MAX_WORKERS', 10))
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as exe:
|
||||
threads = []
|
||||
for i, (cid, ck) in enumerate(chunks):
|
||||
self.callback = callback
|
||||
start_ts = trio.current_time()
|
||||
out_results = []
|
||||
async with trio.open_nursery() as nursery:
|
||||
for i, ck in enumerate(chunks):
|
||||
ck = truncate(ck, int(self._llm.max_length*0.8))
|
||||
threads.append(
|
||||
exe.submit(self._process_single_content, (cid, ck)))
|
||||
|
||||
for i, _ in enumerate(threads):
|
||||
n, r, tc = _.result()
|
||||
if not isinstance(n, Exception):
|
||||
results.append((n, r))
|
||||
if callback:
|
||||
callback(0.5 + 0.1 * i / len(threads), f"Entities extraction progress ... {i + 1}/{len(threads)} ({tc} tokens)")
|
||||
elif callback:
|
||||
callback(msg="Knowledge graph extraction error:{}".format(str(n)))
|
||||
nursery.start_soon(lambda: self._process_single_content((doc_id, ck), i, len(chunks), out_results))
|
||||
|
||||
maybe_nodes = defaultdict(list)
|
||||
maybe_edges = defaultdict(list)
|
||||
for m_nodes, m_edges in results:
|
||||
sum_token_count = 0
|
||||
for m_nodes, m_edges, token_count in out_results:
|
||||
for k, v in m_nodes.items():
|
||||
maybe_nodes[k].extend(v)
|
||||
for k, v in m_edges.items():
|
||||
maybe_edges[tuple(sorted(k))].extend(v)
|
||||
logging.info("Inserting entities into storage...")
|
||||
sum_token_count += token_count
|
||||
now = trio.current_time()
|
||||
if callback:
|
||||
callback(msg = f"Entities and relationships extraction done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {sum_token_count} tokens, {now-start_ts:.2f}s.")
|
||||
start_ts = now
|
||||
logging.info("Entities merging...")
|
||||
all_entities_data = []
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as exe:
|
||||
threads = []
|
||||
async with trio.open_nursery() as nursery:
|
||||
for en_nm, ents in maybe_nodes.items():
|
||||
threads.append(
|
||||
exe.submit(self._merge_nodes, en_nm, ents))
|
||||
for t in threads:
|
||||
n = t.result()
|
||||
if not isinstance(n, Exception):
|
||||
all_entities_data.append(n)
|
||||
elif callback:
|
||||
callback(msg="Knowledge graph nodes merging error: {}".format(str(n)))
|
||||
nursery.start_soon(lambda: self._merge_nodes(en_nm, ents, all_entities_data))
|
||||
now = trio.current_time()
|
||||
if callback:
|
||||
callback(msg = f"Entities merging done, {now-start_ts:.2f}s.")
|
||||
|
||||
logging.info("Inserting relationships into storage...")
|
||||
start_ts = now
|
||||
logging.info("Relationships merging...")
|
||||
all_relationships_data = []
|
||||
for (src, tgt), rels in maybe_edges.items():
|
||||
all_relationships_data.append(self._merge_edges(src, tgt, rels))
|
||||
async with trio.open_nursery() as nursery:
|
||||
for (src, tgt), rels in maybe_edges.items():
|
||||
nursery.start_soon(lambda: self._merge_edges(src, tgt, rels, all_relationships_data))
|
||||
now = trio.current_time()
|
||||
if callback:
|
||||
callback(msg = f"Relationships merging done, {now-start_ts:.2f}s.")
|
||||
|
||||
if not len(all_entities_data) and not len(all_relationships_data):
|
||||
logging.warning(
|
||||
|
@ -152,7 +149,7 @@ class Extractor:
|
|||
|
||||
return all_entities_data, all_relationships_data
|
||||
|
||||
def _merge_nodes(self, entity_name: str, entities: list[dict]):
|
||||
async def _merge_nodes(self, entity_name: str, entities: list[dict], all_relationships_data):
|
||||
if not entities:
|
||||
return
|
||||
already_entity_types = []
|
||||
|
@ -176,26 +173,22 @@ class Extractor:
|
|||
sorted(set([dp["description"] for dp in entities] + already_description))
|
||||
)
|
||||
already_source_ids = flat_uniq_list(entities, "source_id")
|
||||
try:
|
||||
description = self._handle_entity_relation_summary(
|
||||
entity_name, description
|
||||
)
|
||||
node_data = dict(
|
||||
entity_type=entity_type,
|
||||
description=description,
|
||||
source_id=already_source_ids,
|
||||
)
|
||||
node_data["entity_name"] = entity_name
|
||||
self._set_entity_(entity_name, node_data)
|
||||
return node_data
|
||||
except Exception as e:
|
||||
return e
|
||||
description = await self._handle_entity_relation_summary(entity_name, description)
|
||||
node_data = dict(
|
||||
entity_type=entity_type,
|
||||
description=description,
|
||||
source_id=already_source_ids,
|
||||
)
|
||||
node_data["entity_name"] = entity_name
|
||||
self._set_entity_(entity_name, node_data)
|
||||
all_relationships_data.append(node_data)
|
||||
|
||||
def _merge_edges(
|
||||
async def _merge_edges(
|
||||
self,
|
||||
src_id: str,
|
||||
tgt_id: str,
|
||||
edges_data: list[dict]
|
||||
edges_data: list[dict],
|
||||
all_relationships_data=None
|
||||
):
|
||||
if not edges_data:
|
||||
return
|
||||
|
@ -226,7 +219,7 @@ class Extractor:
|
|||
"description": description,
|
||||
"entity_type": 'UNKNOWN'
|
||||
})
|
||||
description = self._handle_entity_relation_summary(
|
||||
description = await self._handle_entity_relation_summary(
|
||||
f"({src_id}, {tgt_id})", description
|
||||
)
|
||||
edge_data = dict(
|
||||
|
@ -238,23 +231,27 @@ class Extractor:
|
|||
source_id=source_id
|
||||
)
|
||||
self._set_relation_(src_id, tgt_id, edge_data)
|
||||
if all_relationships_data is not None:
|
||||
all_relationships_data.append(edge_data)
|
||||
|
||||
return edge_data
|
||||
|
||||
def _handle_entity_relation_summary(
|
||||
async def _handle_entity_relation_summary(
|
||||
self,
|
||||
entity_or_relation_name: str,
|
||||
description: str
|
||||
) -> str:
|
||||
summary_max_tokens = 512
|
||||
use_description = truncate(description, summary_max_tokens)
|
||||
description_list=use_description.split(GRAPH_FIELD_SEP),
|
||||
if len(description_list) <= 12:
|
||||
return use_description
|
||||
prompt_template = SUMMARIZE_DESCRIPTIONS_PROMPT
|
||||
context_base = dict(
|
||||
entity_name=entity_or_relation_name,
|
||||
description_list=use_description.split(GRAPH_FIELD_SEP),
|
||||
description_list=description_list,
|
||||
language=self._language,
|
||||
)
|
||||
use_prompt = prompt_template.format(**context_base)
|
||||
logging.info(f"Trigger summary: {entity_or_relation_name}")
|
||||
summary = self._chat(use_prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.8})
|
||||
async with chat_limiter:
|
||||
summary = await trio.to_thread.run_sync(lambda: self._chat(use_prompt, [{"role": "user", "content": "Output: "}], {"temperature": 0.8}))
|
||||
return summary
|
||||
|
|
|
@ -5,15 +5,15 @@ Reference:
|
|||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Callable
|
||||
from dataclasses import dataclass
|
||||
import tiktoken
|
||||
import trio
|
||||
|
||||
from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS, DEFAULT_ENTITY_TYPES
|
||||
from graphrag.general.graph_prompt import GRAPH_EXTRACTION_PROMPT, CONTINUE_PROMPT, LOOP_PROMPT
|
||||
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
|
||||
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, chat_limiter
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
import networkx as nx
|
||||
from rag.utils import num_tokens_from_string
|
||||
|
@ -102,53 +102,47 @@ class GraphExtractor(Extractor):
|
|||
self._entity_types_key: ",".join(DEFAULT_ENTITY_TYPES),
|
||||
}
|
||||
|
||||
def _process_single_content(self,
|
||||
chunk_key_dp: tuple[str, str]
|
||||
):
|
||||
async def _process_single_content(self, chunk_key_dp: tuple[str, str], chunk_seq: int, num_chunks: int, out_results):
|
||||
token_count = 0
|
||||
|
||||
chunk_key = chunk_key_dp[0]
|
||||
content = chunk_key_dp[1]
|
||||
variables = {
|
||||
**self._prompt_variables,
|
||||
self._input_text_key: content,
|
||||
}
|
||||
try:
|
||||
gen_conf = {"temperature": 0.3}
|
||||
hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
|
||||
response = self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf)
|
||||
token_count += num_tokens_from_string(hint_prompt + response)
|
||||
|
||||
results = response or ""
|
||||
history = [{"role": "system", "content": hint_prompt}, {"role": "user", "content": response}]
|
||||
|
||||
# Repeat to ensure we maximize entity count
|
||||
for i in range(self._max_gleanings):
|
||||
text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
|
||||
history.append({"role": "user", "content": text})
|
||||
response = self._chat("", history, gen_conf)
|
||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
|
||||
results += response or ""
|
||||
|
||||
# if this is the final glean, don't bother updating the continuation flag
|
||||
if i >= self._max_gleanings - 1:
|
||||
break
|
||||
history.append({"role": "assistant", "content": response})
|
||||
history.append({"role": "user", "content": LOOP_PROMPT})
|
||||
continuation = self._chat("", history, {"temperature": 0.8})
|
||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
|
||||
if continuation != "YES":
|
||||
break
|
||||
|
||||
record_delimiter = variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER)
|
||||
tuple_delimiter = variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER)
|
||||
records = [re.sub(r"^\(|\)$", "", r.strip()) for r in results.split(record_delimiter)]
|
||||
records = [r for r in records if r.strip()]
|
||||
maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, tuple_delimiter)
|
||||
return maybe_nodes, maybe_edges, token_count
|
||||
except Exception as e:
|
||||
logging.exception("error extracting graph")
|
||||
return e, None, None
|
||||
gen_conf = {"temperature": 0.3}
|
||||
hint_prompt = perform_variable_replacements(self._extraction_prompt, variables=variables)
|
||||
async with chat_limiter:
|
||||
response = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf))
|
||||
token_count += num_tokens_from_string(hint_prompt + response)
|
||||
|
||||
results = response or ""
|
||||
history = [{"role": "system", "content": hint_prompt}, {"role": "user", "content": response}]
|
||||
|
||||
# Repeat to ensure we maximize entity count
|
||||
for i in range(self._max_gleanings):
|
||||
text = perform_variable_replacements(CONTINUE_PROMPT, history=history, variables=variables)
|
||||
history.append({"role": "user", "content": text})
|
||||
async with chat_limiter:
|
||||
response = await trio.to_thread.run_sync(lambda: self._chat("", history, gen_conf))
|
||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
|
||||
results += response or ""
|
||||
|
||||
# if this is the final glean, don't bother updating the continuation flag
|
||||
if i >= self._max_gleanings - 1:
|
||||
break
|
||||
history.append({"role": "assistant", "content": response})
|
||||
history.append({"role": "user", "content": LOOP_PROMPT})
|
||||
async with chat_limiter:
|
||||
continuation = await trio.to_thread.run_sync(lambda: self._chat("", history, {"temperature": 0.8}))
|
||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + response)
|
||||
if continuation != "YES":
|
||||
break
|
||||
record_delimiter = variables.get(self._record_delimiter_key, DEFAULT_RECORD_DELIMITER)
|
||||
tuple_delimiter = variables.get(self._tuple_delimiter_key, DEFAULT_TUPLE_DELIMITER)
|
||||
records = [re.sub(r"^\(|\)$", "", r.strip()) for r in results.split(record_delimiter)]
|
||||
records = [r for r in records if r.strip()]
|
||||
maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, tuple_delimiter)
|
||||
out_results.append((maybe_nodes, maybe_edges, token_count))
|
||||
if self.callback:
|
||||
self.callback(0.5+0.1*len(out_results)/num_chunks, msg = f"Entities extraction of chunk {chunk_seq} {len(out_results)}/{num_chunks} done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {token_count} tokens.")
|
||||
|
|
|
@ -15,183 +15,353 @@
|
|||
#
|
||||
import json
|
||||
import logging
|
||||
from functools import reduce, partial
|
||||
from functools import partial
|
||||
import networkx as nx
|
||||
import trio
|
||||
|
||||
from api import settings
|
||||
from graphrag.light.graph_extractor import GraphExtractor as LightKGExt
|
||||
from graphrag.general.graph_extractor import GraphExtractor as GeneralKGExt
|
||||
from graphrag.general.community_reports_extractor import CommunityReportsExtractor
|
||||
from graphrag.entity_resolution import EntityResolution
|
||||
from graphrag.general.extractor import Extractor
|
||||
from graphrag.general.graph_extractor import DEFAULT_ENTITY_TYPES
|
||||
from graphrag.utils import graph_merge, set_entity, get_relation, set_relation, get_entity, get_graph, set_graph, \
|
||||
chunk_id, update_nodes_pagerank_nhop_neighbour
|
||||
from graphrag.utils import (
|
||||
graph_merge,
|
||||
set_entity,
|
||||
get_relation,
|
||||
set_relation,
|
||||
get_entity,
|
||||
get_graph,
|
||||
set_graph,
|
||||
chunk_id,
|
||||
update_nodes_pagerank_nhop_neighbour,
|
||||
does_graph_contains,
|
||||
get_graph_doc_ids,
|
||||
)
|
||||
from rag.nlp import rag_tokenizer, search
|
||||
from rag.utils.redis_conn import RedisDistributedLock
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
|
||||
class Dealer:
|
||||
def __init__(self,
|
||||
extractor: Extractor,
|
||||
tenant_id: str,
|
||||
kb_id: str,
|
||||
llm_bdl,
|
||||
chunks: list[tuple[str, str]],
|
||||
language,
|
||||
entity_types=DEFAULT_ENTITY_TYPES,
|
||||
embed_bdl=None,
|
||||
callback=None
|
||||
):
|
||||
docids = list(set([docid for docid,_ in chunks]))
|
||||
self.llm_bdl = llm_bdl
|
||||
self.embed_bdl = embed_bdl
|
||||
ext = extractor(self.llm_bdl, language=language,
|
||||
entity_types=entity_types,
|
||||
get_entity=partial(get_entity, tenant_id, kb_id),
|
||||
set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
|
||||
get_relation=partial(get_relation, tenant_id, kb_id),
|
||||
set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl)
|
||||
)
|
||||
ents, rels = ext(chunks, callback)
|
||||
self.graph = nx.Graph()
|
||||
for en in ents:
|
||||
self.graph.add_node(en["entity_name"], entity_type=en["entity_type"])#, description=en["description"])
|
||||
def graphrag_task_set(tenant_id, kb_id, doc_id) -> bool:
|
||||
key = f"graphrag:{tenant_id}:{kb_id}"
|
||||
ok = REDIS_CONN.set(key, doc_id, exp=3600 * 24)
|
||||
if not ok:
|
||||
raise Exception(f"Faild to set the {key} to {doc_id}")
|
||||
|
||||
for rel in rels:
|
||||
self.graph.add_edge(
|
||||
rel["src_id"],
|
||||
rel["tgt_id"],
|
||||
weight=rel["weight"],
|
||||
#description=rel["description"]
|
||||
|
||||
def graphrag_task_get(tenant_id, kb_id) -> str | None:
|
||||
key = f"graphrag:{tenant_id}:{kb_id}"
|
||||
doc_id = REDIS_CONN.get(key)
|
||||
return doc_id
|
||||
|
||||
|
||||
async def run_graphrag(
|
||||
row: dict,
|
||||
language,
|
||||
with_resolution: bool,
|
||||
with_community: bool,
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
):
|
||||
start = trio.current_time()
|
||||
tenant_id, kb_id, doc_id = row["tenant_id"], str(row["kb_id"]), row["doc_id"]
|
||||
chunks = []
|
||||
for d in settings.retrievaler.chunk_list(
|
||||
doc_id, tenant_id, [kb_id], fields=["content_with_weight", "doc_id"]
|
||||
):
|
||||
chunks.append(d["content_with_weight"])
|
||||
|
||||
graph, doc_ids = await update_graph(
|
||||
LightKGExt
|
||||
if row["parser_config"]["graphrag"]["method"] != "general"
|
||||
else GeneralKGExt,
|
||||
tenant_id,
|
||||
kb_id,
|
||||
doc_id,
|
||||
chunks,
|
||||
language,
|
||||
row["parser_config"]["graphrag"]["entity_types"],
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
)
|
||||
if not graph:
|
||||
return
|
||||
if with_resolution or with_community:
|
||||
graphrag_task_set(tenant_id, kb_id, doc_id)
|
||||
if with_resolution:
|
||||
await resolve_entities(
|
||||
graph,
|
||||
doc_ids,
|
||||
tenant_id,
|
||||
kb_id,
|
||||
doc_id,
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
)
|
||||
if with_community:
|
||||
await extract_community(
|
||||
graph,
|
||||
doc_ids,
|
||||
tenant_id,
|
||||
kb_id,
|
||||
doc_id,
|
||||
chat_model,
|
||||
embedding_model,
|
||||
callback,
|
||||
)
|
||||
now = trio.current_time()
|
||||
callback(msg=f"GraphRAG for doc {doc_id} done in {now - start:.2f} seconds.")
|
||||
return
|
||||
|
||||
|
||||
async def update_graph(
|
||||
extractor: Extractor,
|
||||
tenant_id: str,
|
||||
kb_id: str,
|
||||
doc_id: str,
|
||||
chunks: list[str],
|
||||
language,
|
||||
entity_types,
|
||||
llm_bdl,
|
||||
embed_bdl,
|
||||
callback,
|
||||
):
|
||||
contains = await does_graph_contains(tenant_id, kb_id, doc_id)
|
||||
if contains:
|
||||
callback(msg=f"Graph already contains {doc_id}, cancel myself")
|
||||
return None, None
|
||||
start = trio.current_time()
|
||||
ext = extractor(
|
||||
llm_bdl,
|
||||
language=language,
|
||||
entity_types=entity_types,
|
||||
get_entity=partial(get_entity, tenant_id, kb_id),
|
||||
set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
|
||||
get_relation=partial(get_relation, tenant_id, kb_id),
|
||||
set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
|
||||
)
|
||||
ents, rels = await ext(doc_id, chunks, callback)
|
||||
subgraph = nx.Graph()
|
||||
for en in ents:
|
||||
subgraph.add_node(en["entity_name"], entity_type=en["entity_type"])
|
||||
|
||||
for rel in rels:
|
||||
subgraph.add_edge(
|
||||
rel["src_id"],
|
||||
rel["tgt_id"],
|
||||
weight=rel["weight"],
|
||||
# description=rel["description"]
|
||||
)
|
||||
# TODO: infinity doesn't support array search
|
||||
chunk = {
|
||||
"content_with_weight": json.dumps(
|
||||
nx.node_link_data(subgraph, edges="edges"), ensure_ascii=False, indent=2
|
||||
),
|
||||
"knowledge_graph_kwd": "subgraph",
|
||||
"kb_id": kb_id,
|
||||
"source_id": [doc_id],
|
||||
"available_int": 0,
|
||||
"removed_kwd": "N",
|
||||
}
|
||||
cid = chunk_id(chunk)
|
||||
await trio.to_thread.run_sync(
|
||||
lambda: settings.docStoreConn.insert(
|
||||
[{"id": cid, **chunk}], search.index_name(tenant_id), kb_id
|
||||
)
|
||||
)
|
||||
now = trio.current_time()
|
||||
callback(msg=f"generated subgraph for doc {doc_id} in {now - start:.2f} seconds.")
|
||||
start = now
|
||||
|
||||
while True:
|
||||
new_graph = subgraph
|
||||
now_docids = set([doc_id])
|
||||
old_graph, old_doc_ids = await get_graph(tenant_id, kb_id)
|
||||
if old_graph is not None:
|
||||
logging.info("Merge with an exiting graph...................")
|
||||
new_graph = graph_merge(old_graph, subgraph)
|
||||
await update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, new_graph, 2)
|
||||
if old_doc_ids:
|
||||
for old_doc_id in old_doc_ids:
|
||||
now_docids.add(old_doc_id)
|
||||
old_doc_ids2 = await get_graph_doc_ids(tenant_id, kb_id)
|
||||
delta_doc_ids = set(old_doc_ids2) - set(old_doc_ids)
|
||||
if delta_doc_ids:
|
||||
callback(
|
||||
msg="The global graph has changed during merging, try again"
|
||||
)
|
||||
|
||||
with RedisDistributedLock(kb_id, 60*60):
|
||||
old_graph, old_doc_ids = get_graph(tenant_id, kb_id)
|
||||
if old_graph is not None:
|
||||
logging.info("Merge with an exiting graph...................")
|
||||
self.graph = reduce(graph_merge, [old_graph, self.graph])
|
||||
update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, self.graph, 2)
|
||||
if old_doc_ids:
|
||||
docids.extend(old_doc_ids)
|
||||
docids = list(set(docids))
|
||||
set_graph(tenant_id, kb_id, self.graph, docids)
|
||||
await trio.sleep(1)
|
||||
continue
|
||||
break
|
||||
await set_graph(tenant_id, kb_id, new_graph, list(now_docids))
|
||||
now = trio.current_time()
|
||||
callback(
|
||||
msg=f"merging subgraph for doc {doc_id} into the global graph done in {now - start:.2f} seconds."
|
||||
)
|
||||
return new_graph, now_docids
|
||||
|
||||
|
||||
class WithResolution(Dealer):
|
||||
def __init__(self,
|
||||
tenant_id: str,
|
||||
kb_id: str,
|
||||
llm_bdl,
|
||||
embed_bdl=None,
|
||||
callback=None
|
||||
):
|
||||
self.llm_bdl = llm_bdl
|
||||
self.embed_bdl = embed_bdl
|
||||
async def resolve_entities(
|
||||
graph,
|
||||
doc_ids,
|
||||
tenant_id: str,
|
||||
kb_id: str,
|
||||
doc_id: str,
|
||||
llm_bdl,
|
||||
embed_bdl,
|
||||
callback,
|
||||
):
|
||||
working_doc_id = graphrag_task_get(tenant_id, kb_id)
|
||||
if doc_id != working_doc_id:
|
||||
callback(
|
||||
msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
|
||||
)
|
||||
return
|
||||
start = trio.current_time()
|
||||
er = EntityResolution(
|
||||
llm_bdl,
|
||||
get_entity=partial(get_entity, tenant_id, kb_id),
|
||||
set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
|
||||
get_relation=partial(get_relation, tenant_id, kb_id),
|
||||
set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
|
||||
)
|
||||
reso = await er(graph, callback=callback)
|
||||
graph = reso.graph
|
||||
callback(msg=f"Graph resolution removed {len(reso.removed_entities)} nodes.")
|
||||
await update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, 2)
|
||||
callback(msg="Graph resolution updated pagerank.")
|
||||
|
||||
with RedisDistributedLock(kb_id, 60*60):
|
||||
self.graph, doc_ids = get_graph(tenant_id, kb_id)
|
||||
if not self.graph:
|
||||
logging.error(f"Faild to fetch the graph. tenant_id:{kb_id}, kb_id:{kb_id}")
|
||||
if callback:
|
||||
callback(-1, msg="Faild to fetch the graph.")
|
||||
return
|
||||
working_doc_id = graphrag_task_get(tenant_id, kb_id)
|
||||
if doc_id != working_doc_id:
|
||||
callback(
|
||||
msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
|
||||
)
|
||||
return
|
||||
await set_graph(tenant_id, kb_id, graph, doc_ids)
|
||||
|
||||
if callback:
|
||||
callback(msg="Fetch the existing graph.")
|
||||
er = EntityResolution(self.llm_bdl,
|
||||
get_entity=partial(get_entity, tenant_id, kb_id),
|
||||
set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
|
||||
get_relation=partial(get_relation, tenant_id, kb_id),
|
||||
set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl))
|
||||
reso = er(self.graph)
|
||||
self.graph = reso.graph
|
||||
logging.info("Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
|
||||
if callback:
|
||||
callback(msg="Graph resolution is done. Remove {} nodes.".format(len(reso.removed_entities)))
|
||||
update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, self.graph, 2)
|
||||
set_graph(tenant_id, kb_id, self.graph, doc_ids)
|
||||
|
||||
settings.docStoreConn.delete({
|
||||
"knowledge_graph_kwd": "relation",
|
||||
"kb_id": kb_id,
|
||||
"from_entity_kwd": reso.removed_entities
|
||||
}, search.index_name(tenant_id), kb_id)
|
||||
settings.docStoreConn.delete({
|
||||
"knowledge_graph_kwd": "relation",
|
||||
"kb_id": kb_id,
|
||||
"to_entity_kwd": reso.removed_entities
|
||||
}, search.index_name(tenant_id), kb_id)
|
||||
settings.docStoreConn.delete({
|
||||
"knowledge_graph_kwd": "entity",
|
||||
"kb_id": kb_id,
|
||||
"entity_kwd": reso.removed_entities
|
||||
}, search.index_name(tenant_id), kb_id)
|
||||
|
||||
|
||||
class WithCommunity(Dealer):
|
||||
def __init__(self,
|
||||
tenant_id: str,
|
||||
kb_id: str,
|
||||
llm_bdl,
|
||||
embed_bdl=None,
|
||||
callback=None
|
||||
):
|
||||
|
||||
self.community_structure = None
|
||||
self.community_reports = None
|
||||
self.llm_bdl = llm_bdl
|
||||
self.embed_bdl = embed_bdl
|
||||
|
||||
with RedisDistributedLock(kb_id, 60*60):
|
||||
self.graph, doc_ids = get_graph(tenant_id, kb_id)
|
||||
if not self.graph:
|
||||
logging.error(f"Faild to fetch the graph. tenant_id:{kb_id}, kb_id:{kb_id}")
|
||||
if callback:
|
||||
callback(-1, msg="Faild to fetch the graph.")
|
||||
return
|
||||
if callback:
|
||||
callback(msg="Fetch the existing graph.")
|
||||
|
||||
cr = CommunityReportsExtractor(self.llm_bdl,
|
||||
get_entity=partial(get_entity, tenant_id, kb_id),
|
||||
set_entity=partial(set_entity, tenant_id, kb_id, self.embed_bdl),
|
||||
get_relation=partial(get_relation, tenant_id, kb_id),
|
||||
set_relation=partial(set_relation, tenant_id, kb_id, self.embed_bdl))
|
||||
cr = cr(self.graph, callback=callback)
|
||||
self.community_structure = cr.structured_output
|
||||
self.community_reports = cr.output
|
||||
set_graph(tenant_id, kb_id, self.graph, doc_ids)
|
||||
|
||||
if callback:
|
||||
callback(msg="Graph community extraction is done. Indexing {} reports.".format(len(cr.structured_output)))
|
||||
|
||||
settings.docStoreConn.delete({
|
||||
"knowledge_graph_kwd": "community_report",
|
||||
"kb_id": kb_id
|
||||
}, search.index_name(tenant_id), kb_id)
|
||||
|
||||
for stru, rep in zip(self.community_structure, self.community_reports):
|
||||
obj = {
|
||||
"report": rep,
|
||||
"evidences": "\n".join([f["explanation"] for f in stru["findings"]])
|
||||
}
|
||||
chunk = {
|
||||
"docnm_kwd": stru["title"],
|
||||
"title_tks": rag_tokenizer.tokenize(stru["title"]),
|
||||
"content_with_weight": json.dumps(obj, ensure_ascii=False),
|
||||
"content_ltks": rag_tokenizer.tokenize(obj["report"] +" "+ obj["evidences"]),
|
||||
"knowledge_graph_kwd": "community_report",
|
||||
"weight_flt": stru["weight"],
|
||||
"entities_kwd": stru["entities"],
|
||||
"important_kwd": stru["entities"],
|
||||
await trio.to_thread.run_sync(
|
||||
lambda: settings.docStoreConn.delete(
|
||||
{
|
||||
"knowledge_graph_kwd": "relation",
|
||||
"kb_id": kb_id,
|
||||
"source_id": doc_ids,
|
||||
"available_int": 0
|
||||
}
|
||||
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
|
||||
#try:
|
||||
# ebd, _ = self.embed_bdl.encode([", ".join(community["entities"])])
|
||||
# chunk["q_%d_vec" % len(ebd[0])] = ebd[0]
|
||||
#except Exception as e:
|
||||
# logging.exception(f"Fail to embed entity relation: {e}")
|
||||
settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id))
|
||||
"from_entity_kwd": reso.removed_entities,
|
||||
},
|
||||
search.index_name(tenant_id),
|
||||
kb_id,
|
||||
)
|
||||
)
|
||||
await trio.to_thread.run_sync(
|
||||
lambda: settings.docStoreConn.delete(
|
||||
{
|
||||
"knowledge_graph_kwd": "relation",
|
||||
"kb_id": kb_id,
|
||||
"to_entity_kwd": reso.removed_entities,
|
||||
},
|
||||
search.index_name(tenant_id),
|
||||
kb_id,
|
||||
)
|
||||
)
|
||||
await trio.to_thread.run_sync(
|
||||
lambda: settings.docStoreConn.delete(
|
||||
{
|
||||
"knowledge_graph_kwd": "entity",
|
||||
"kb_id": kb_id,
|
||||
"entity_kwd": reso.removed_entities,
|
||||
},
|
||||
search.index_name(tenant_id),
|
||||
kb_id,
|
||||
)
|
||||
)
|
||||
now = trio.current_time()
|
||||
callback(msg=f"Graph resolution done in {now - start:.2f}s.")
|
||||
|
||||
|
||||
async def extract_community(
|
||||
graph,
|
||||
doc_ids,
|
||||
tenant_id: str,
|
||||
kb_id: str,
|
||||
doc_id: str,
|
||||
llm_bdl,
|
||||
embed_bdl,
|
||||
callback,
|
||||
):
|
||||
working_doc_id = graphrag_task_get(tenant_id, kb_id)
|
||||
if doc_id != working_doc_id:
|
||||
callback(
|
||||
msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
|
||||
)
|
||||
return
|
||||
start = trio.current_time()
|
||||
ext = CommunityReportsExtractor(
|
||||
llm_bdl,
|
||||
get_entity=partial(get_entity, tenant_id, kb_id),
|
||||
set_entity=partial(set_entity, tenant_id, kb_id, embed_bdl),
|
||||
get_relation=partial(get_relation, tenant_id, kb_id),
|
||||
set_relation=partial(set_relation, tenant_id, kb_id, embed_bdl),
|
||||
)
|
||||
cr = await ext(graph, callback=callback)
|
||||
community_structure = cr.structured_output
|
||||
community_reports = cr.output
|
||||
working_doc_id = graphrag_task_get(tenant_id, kb_id)
|
||||
if doc_id != working_doc_id:
|
||||
callback(
|
||||
msg=f"Another graphrag task of doc_id {working_doc_id} is working on this kb, cancel myself"
|
||||
)
|
||||
return
|
||||
await set_graph(tenant_id, kb_id, graph, doc_ids)
|
||||
|
||||
now = trio.current_time()
|
||||
callback(
|
||||
msg=f"Graph extracted {len(cr.structured_output)} communities in {now - start:.2f}s."
|
||||
)
|
||||
start = now
|
||||
await trio.to_thread.run_sync(
|
||||
lambda: settings.docStoreConn.delete(
|
||||
{"knowledge_graph_kwd": "community_report", "kb_id": kb_id},
|
||||
search.index_name(tenant_id),
|
||||
kb_id,
|
||||
)
|
||||
)
|
||||
for stru, rep in zip(community_structure, community_reports):
|
||||
obj = {
|
||||
"report": rep,
|
||||
"evidences": "\n".join([f["explanation"] for f in stru["findings"]]),
|
||||
}
|
||||
chunk = {
|
||||
"docnm_kwd": stru["title"],
|
||||
"title_tks": rag_tokenizer.tokenize(stru["title"]),
|
||||
"content_with_weight": json.dumps(obj, ensure_ascii=False),
|
||||
"content_ltks": rag_tokenizer.tokenize(
|
||||
obj["report"] + " " + obj["evidences"]
|
||||
),
|
||||
"knowledge_graph_kwd": "community_report",
|
||||
"weight_flt": stru["weight"],
|
||||
"entities_kwd": stru["entities"],
|
||||
"important_kwd": stru["entities"],
|
||||
"kb_id": kb_id,
|
||||
"source_id": doc_ids,
|
||||
"available_int": 0,
|
||||
}
|
||||
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(
|
||||
chunk["content_ltks"]
|
||||
)
|
||||
# try:
|
||||
# ebd, _ = embed_bdl.encode([", ".join(community["entities"])])
|
||||
# chunk["q_%d_vec" % len(ebd[0])] = ebd[0]
|
||||
# except Exception as e:
|
||||
# logging.exception(f"Fail to embed entity relation: {e}")
|
||||
await trio.to_thread.run_sync(
|
||||
lambda: settings.docStoreConn.insert(
|
||||
[{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id)
|
||||
)
|
||||
)
|
||||
|
||||
now = trio.current_time()
|
||||
callback(
|
||||
msg=f"Graph indexed {len(cr.structured_output)} communities in {now - start:.2f}s."
|
||||
)
|
||||
return community_structure, community_reports
|
||||
|
|
|
@ -120,6 +120,9 @@ def run(graph: nx.Graph, args: dict[str, Any]) -> dict[int, dict[str, dict]]:
|
|||
result = {}
|
||||
results_by_level[level] = result
|
||||
for node_id, raw_community_id in node_id_to_community_map[level].items():
|
||||
if node_id not in graph.nodes:
|
||||
logging.warning(f"Node {node_id} not found in the graph.")
|
||||
continue
|
||||
community_id = str(raw_community_id)
|
||||
if community_id not in result:
|
||||
result[community_id] = {"weight": 0, "nodes": []}
|
||||
|
|
|
@ -16,16 +16,14 @@
|
|||
|
||||
import logging
|
||||
import collections
|
||||
import os
|
||||
import re
|
||||
import traceback
|
||||
from typing import Any
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
import trio
|
||||
|
||||
from graphrag.general.extractor import Extractor
|
||||
from graphrag.general.mind_map_prompt import MIND_MAP_EXTRACTION_PROMPT
|
||||
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements
|
||||
from graphrag.utils import ErrorHandlerFn, perform_variable_replacements, chat_limiter
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
import markdown_to_json
|
||||
from functools import reduce
|
||||
|
@ -80,63 +78,47 @@ class MindMapExtractor(Extractor):
|
|||
)
|
||||
return arr
|
||||
|
||||
def __call__(
|
||||
async def __call__(
|
||||
self, sections: list[str], prompt_variables: dict[str, Any] | None = None
|
||||
) -> MindMapResult:
|
||||
"""Call method definition."""
|
||||
if prompt_variables is None:
|
||||
prompt_variables = {}
|
||||
|
||||
try:
|
||||
res = []
|
||||
max_workers = int(os.environ.get('MINDMAP_EXTRACTOR_MAX_WORKERS', 12))
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as exe:
|
||||
threads = []
|
||||
token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512)
|
||||
texts = []
|
||||
cnt = 0
|
||||
for i in range(len(sections)):
|
||||
section_cnt = num_tokens_from_string(sections[i])
|
||||
if cnt + section_cnt >= token_count and texts:
|
||||
threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
|
||||
texts = []
|
||||
cnt = 0
|
||||
texts.append(sections[i])
|
||||
cnt += section_cnt
|
||||
if texts:
|
||||
threads.append(exe.submit(self._process_document, "".join(texts), prompt_variables))
|
||||
|
||||
for i, _ in enumerate(threads):
|
||||
res.append(_.result())
|
||||
|
||||
if not res:
|
||||
return MindMapResult(output={"id": "root", "children": []})
|
||||
|
||||
merge_json = reduce(self._merge, res)
|
||||
if len(merge_json) > 1:
|
||||
keys = [re.sub(r"\*+", "", k) for k, v in merge_json.items() if isinstance(v, dict)]
|
||||
keyset = set(i for i in keys if i)
|
||||
merge_json = {
|
||||
"id": "root",
|
||||
"children": [
|
||||
{
|
||||
"id": self._key(k),
|
||||
"children": self._be_children(v, keyset)
|
||||
}
|
||||
for k, v in merge_json.items() if isinstance(v, dict) and self._key(k)
|
||||
]
|
||||
}
|
||||
else:
|
||||
k = self._key(list(merge_json.keys())[0])
|
||||
merge_json = {"id": k, "children": self._be_children(list(merge_json.items())[0][1], {k})}
|
||||
|
||||
except Exception as e:
|
||||
logging.exception("error mind graph")
|
||||
self._on_error(
|
||||
e,
|
||||
traceback.format_exc(), None
|
||||
)
|
||||
merge_json = {"error": str(e)}
|
||||
res = []
|
||||
token_count = max(self._llm.max_length * 0.8, self._llm.max_length - 512)
|
||||
texts = []
|
||||
cnt = 0
|
||||
async with trio.open_nursery() as nursery:
|
||||
for i in range(len(sections)):
|
||||
section_cnt = num_tokens_from_string(sections[i])
|
||||
if cnt + section_cnt >= token_count and texts:
|
||||
nursery.start_soon(lambda: self._process_document("".join(texts), prompt_variables, res))
|
||||
texts = []
|
||||
cnt = 0
|
||||
texts.append(sections[i])
|
||||
cnt += section_cnt
|
||||
if texts:
|
||||
nursery.start_soon(lambda: self._process_document("".join(texts), prompt_variables, res))
|
||||
if not res:
|
||||
return MindMapResult(output={"id": "root", "children": []})
|
||||
merge_json = reduce(self._merge, res)
|
||||
if len(merge_json) > 1:
|
||||
keys = [re.sub(r"\*+", "", k) for k, v in merge_json.items() if isinstance(v, dict)]
|
||||
keyset = set(i for i in keys if i)
|
||||
merge_json = {
|
||||
"id": "root",
|
||||
"children": [
|
||||
{
|
||||
"id": self._key(k),
|
||||
"children": self._be_children(v, keyset)
|
||||
}
|
||||
for k, v in merge_json.items() if isinstance(v, dict) and self._key(k)
|
||||
]
|
||||
}
|
||||
else:
|
||||
k = self._key(list(merge_json.keys())[0])
|
||||
merge_json = {"id": k, "children": self._be_children(list(merge_json.items())[0][1], {k})}
|
||||
|
||||
return MindMapResult(output=merge_json)
|
||||
|
||||
|
@ -181,8 +163,8 @@ class MindMapExtractor(Extractor):
|
|||
|
||||
return self._list_to_kv(to_ret)
|
||||
|
||||
def _process_document(
|
||||
self, text: str, prompt_variables: dict[str, str]
|
||||
async def _process_document(
|
||||
self, text: str, prompt_variables: dict[str, str], out_res
|
||||
) -> str:
|
||||
variables = {
|
||||
**prompt_variables,
|
||||
|
@ -190,8 +172,9 @@ class MindMapExtractor(Extractor):
|
|||
}
|
||||
text = perform_variable_replacements(self._mind_map_prompt, variables=variables)
|
||||
gen_conf = {"temperature": 0.5}
|
||||
response = self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf)
|
||||
async with chat_limiter:
|
||||
response = await trio.to_thread.run_sync(lambda: self._chat(text, [{"role": "user", "content": "Output:"}], gen_conf))
|
||||
response = re.sub(r"```[^\n]*", "", response)
|
||||
logging.debug(response)
|
||||
logging.debug(self._todict(markdown_to_json.dictify(response)))
|
||||
return self._todict(markdown_to_json.dictify(response))
|
||||
out_res.append(self._todict(markdown_to_json.dictify(response)))
|
||||
|
|
|
@ -16,8 +16,9 @@
|
|||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
import logging
|
||||
import networkx as nx
|
||||
import trio
|
||||
|
||||
from api import settings
|
||||
from api.db import LLMType
|
||||
|
@ -25,39 +26,85 @@ from api.db.services.document_service import DocumentService
|
|||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.user_service import TenantService
|
||||
from graphrag.general.index import WithCommunity, Dealer, WithResolution
|
||||
from graphrag.light.graph_extractor import GraphExtractor
|
||||
from rag.utils.redis_conn import RedisDistributedLock
|
||||
from graphrag.general.graph_extractor import GraphExtractor
|
||||
from graphrag.general.index import update_graph, with_resolution, with_community
|
||||
|
||||
settings.init_settings()
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def callback(prog=None, msg="Processing..."):
|
||||
logging.info(msg)
|
||||
|
||||
|
||||
async def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
|
||||
parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--tenant_id",
|
||||
default=False,
|
||||
help="Tenant ID",
|
||||
action="store",
|
||||
required=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--doc_id",
|
||||
default=False,
|
||||
help="Document ID",
|
||||
action="store",
|
||||
required=True,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
e, doc = DocumentService.get_by_id(args.doc_id)
|
||||
if not e:
|
||||
raise LookupError("Document not found.")
|
||||
kb_id = doc.kb_id
|
||||
|
||||
chunks = [d["content_with_weight"] for d in
|
||||
settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, [kb_id], max_count=6,
|
||||
fields=["content_with_weight"])]
|
||||
chunks = [("x", c) for c in chunks]
|
||||
|
||||
RedisDistributedLock.clean_lock(kb_id)
|
||||
chunks = [
|
||||
d["content_with_weight"]
|
||||
for d in settings.retrievaler.chunk_list(
|
||||
args.doc_id,
|
||||
args.tenant_id,
|
||||
[kb_id],
|
||||
max_count=6,
|
||||
fields=["content_with_weight"],
|
||||
)
|
||||
]
|
||||
|
||||
_, tenant = TenantService.get_by_id(args.tenant_id)
|
||||
llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id)
|
||||
_, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
|
||||
|
||||
dealer = Dealer(GraphExtractor, args.tenant_id, kb_id, llm_bdl, chunks, "English", embed_bdl=embed_bdl)
|
||||
print(json.dumps(nx.node_link_data(dealer.graph), ensure_ascii=False, indent=2))
|
||||
graph, doc_ids = await update_graph(
|
||||
GraphExtractor,
|
||||
args.tenant_id,
|
||||
kb_id,
|
||||
args.doc_id,
|
||||
chunks,
|
||||
"English",
|
||||
llm_bdl,
|
||||
embed_bdl,
|
||||
callback,
|
||||
)
|
||||
print(json.dumps(nx.node_link_data(graph), ensure_ascii=False, indent=2))
|
||||
|
||||
dealer = WithResolution(args.tenant_id, kb_id, llm_bdl, embed_bdl)
|
||||
dealer = WithCommunity(args.tenant_id, kb_id, llm_bdl, embed_bdl)
|
||||
await with_resolution(
|
||||
args.tenant_id, kb_id, args.doc_id, llm_bdl, embed_bdl, callback
|
||||
)
|
||||
community_structure, community_reports = await with_community(
|
||||
args.tenant_id, kb_id, args.doc_id, llm_bdl, embed_bdl, callback
|
||||
)
|
||||
|
||||
print("------------------ COMMUNITY REPORT ----------------------\n", dealer.community_reports)
|
||||
print(json.dumps(dealer.community_structure, ensure_ascii=False, indent=2))
|
||||
print(
|
||||
"------------------ COMMUNITY STRUCTURE--------------------\n",
|
||||
json.dumps(community_structure, ensure_ascii=False, indent=2),
|
||||
)
|
||||
print(
|
||||
"------------------ COMMUNITY REPORTS----------------------\n",
|
||||
community_reports,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
trio.run(main)
|
||||
|
|
|
@ -4,16 +4,16 @@
|
|||
Reference:
|
||||
- [graphrag](https://github.com/microsoft/graphrag)
|
||||
"""
|
||||
import logging
|
||||
import re
|
||||
from typing import Any, Callable
|
||||
from dataclasses import dataclass
|
||||
from graphrag.general.extractor import Extractor, ENTITY_EXTRACTION_MAX_GLEANINGS
|
||||
from graphrag.light.graph_prompt import PROMPTS
|
||||
from graphrag.utils import pack_user_ass_to_openai_messages, split_string_by_multi_markers
|
||||
from graphrag.utils import pack_user_ass_to_openai_messages, split_string_by_multi_markers, chat_limiter
|
||||
from rag.llm.chat_model import Base as CompletionLLM
|
||||
import networkx as nx
|
||||
from rag.utils import num_tokens_from_string
|
||||
import trio
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -82,7 +82,7 @@ class GraphExtractor(Extractor):
|
|||
)
|
||||
self._left_token_count = max(llm_invoker.max_length * 0.6, self._left_token_count)
|
||||
|
||||
def _process_single_content(self, chunk_key_dp: tuple[str, str]):
|
||||
async def _process_single_content(self, chunk_key_dp: tuple[str, str], chunk_seq: int, num_chunks: int, out_results):
|
||||
token_count = 0
|
||||
chunk_key = chunk_key_dp[0]
|
||||
content = chunk_key_dp[1]
|
||||
|
@ -90,38 +90,39 @@ class GraphExtractor(Extractor):
|
|||
**self._context_base, input_text="{input_text}"
|
||||
).format(**self._context_base, input_text=content)
|
||||
|
||||
try:
|
||||
gen_conf = {"temperature": 0.8}
|
||||
final_result = self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf)
|
||||
token_count += num_tokens_from_string(hint_prompt + final_result)
|
||||
history = pack_user_ass_to_openai_messages("Output:", final_result, self._continue_prompt)
|
||||
for now_glean_index in range(self._max_gleanings):
|
||||
glean_result = self._chat(hint_prompt, history, gen_conf)
|
||||
history.extend([{"role": "assistant", "content": glean_result}, {"role": "user", "content": self._continue_prompt}])
|
||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt)
|
||||
final_result += glean_result
|
||||
if now_glean_index == self._max_gleanings - 1:
|
||||
break
|
||||
gen_conf = {"temperature": 0.8}
|
||||
async with chat_limiter:
|
||||
final_result = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, [{"role": "user", "content": "Output:"}], gen_conf))
|
||||
token_count += num_tokens_from_string(hint_prompt + final_result)
|
||||
history = pack_user_ass_to_openai_messages("Output:", final_result, self._continue_prompt)
|
||||
for now_glean_index in range(self._max_gleanings):
|
||||
async with chat_limiter:
|
||||
glean_result = await trio.to_thread.run_sync(lambda: self._chat(hint_prompt, history, gen_conf))
|
||||
history.extend([{"role": "assistant", "content": glean_result}, {"role": "user", "content": self._continue_prompt}])
|
||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + hint_prompt + self._continue_prompt)
|
||||
final_result += glean_result
|
||||
if now_glean_index == self._max_gleanings - 1:
|
||||
break
|
||||
|
||||
if_loop_result = self._chat(self._if_loop_prompt, history, gen_conf)
|
||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt)
|
||||
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
|
||||
if if_loop_result != "yes":
|
||||
break
|
||||
async with chat_limiter:
|
||||
if_loop_result = await trio.to_thread.run_sync(lambda: self._chat(self._if_loop_prompt, history, gen_conf))
|
||||
token_count += num_tokens_from_string("\n".join([m["content"] for m in history]) + if_loop_result + self._if_loop_prompt)
|
||||
if_loop_result = if_loop_result.strip().strip('"').strip("'").lower()
|
||||
if if_loop_result != "yes":
|
||||
break
|
||||
|
||||
records = split_string_by_multi_markers(
|
||||
final_result,
|
||||
[self._context_base["record_delimiter"], self._context_base["completion_delimiter"]],
|
||||
)
|
||||
rcds = []
|
||||
for record in records:
|
||||
record = re.search(r"\((.*)\)", record)
|
||||
if record is None:
|
||||
continue
|
||||
rcds.append(record.group(1))
|
||||
records = rcds
|
||||
maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, self._context_base["tuple_delimiter"])
|
||||
return maybe_nodes, maybe_edges, token_count
|
||||
except Exception as e:
|
||||
logging.exception("error extracting graph")
|
||||
return e, None, None
|
||||
records = split_string_by_multi_markers(
|
||||
final_result,
|
||||
[self._context_base["record_delimiter"], self._context_base["completion_delimiter"]],
|
||||
)
|
||||
rcds = []
|
||||
for record in records:
|
||||
record = re.search(r"\((.*)\)", record)
|
||||
if record is None:
|
||||
continue
|
||||
rcds.append(record.group(1))
|
||||
records = rcds
|
||||
maybe_nodes, maybe_edges = self._entities_and_relations(chunk_key, records, self._context_base["tuple_delimiter"])
|
||||
out_results.append((maybe_nodes, maybe_edges, token_count))
|
||||
if self.callback:
|
||||
self.callback(0.5+0.1*len(out_results)/num_chunks, msg = f"Entities extraction of chunk {chunk_seq} {len(out_results)}/{num_chunks} done, {len(maybe_nodes)} nodes, {len(maybe_edges)} edges, {token_count} tokens.")
|
||||
|
|
|
@ -18,22 +18,42 @@ import argparse
|
|||
import json
|
||||
from api import settings
|
||||
import networkx as nx
|
||||
import logging
|
||||
import trio
|
||||
|
||||
from api.db import LLMType
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.llm_service import LLMBundle
|
||||
from api.db.services.user_service import TenantService
|
||||
from graphrag.general.index import Dealer
|
||||
from graphrag.general.index import update_graph
|
||||
from graphrag.light.graph_extractor import GraphExtractor
|
||||
from rag.utils.redis_conn import RedisDistributedLock
|
||||
|
||||
settings.init_settings()
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
def callback(prog=None, msg="Processing..."):
|
||||
logging.info(msg)
|
||||
|
||||
|
||||
async def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('-t', '--tenant_id', default=False, help="Tenant ID", action='store', required=True)
|
||||
parser.add_argument('-d', '--doc_id', default=False, help="Document ID", action='store', required=True)
|
||||
parser.add_argument(
|
||||
"-t",
|
||||
"--tenant_id",
|
||||
default=False,
|
||||
help="Tenant ID",
|
||||
action="store",
|
||||
required=True,
|
||||
)
|
||||
parser.add_argument(
|
||||
"-d",
|
||||
"--doc_id",
|
||||
default=False,
|
||||
help="Document ID",
|
||||
action="store",
|
||||
required=True,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
e, doc = DocumentService.get_by_id(args.doc_id)
|
||||
|
@ -41,18 +61,36 @@ if __name__ == "__main__":
|
|||
raise LookupError("Document not found.")
|
||||
kb_id = doc.kb_id
|
||||
|
||||
chunks = [d["content_with_weight"] for d in
|
||||
settings.retrievaler.chunk_list(args.doc_id, args.tenant_id, [kb_id], max_count=6,
|
||||
fields=["content_with_weight"])]
|
||||
chunks = [("x", c) for c in chunks]
|
||||
|
||||
RedisDistributedLock.clean_lock(kb_id)
|
||||
chunks = [
|
||||
d["content_with_weight"]
|
||||
for d in settings.retrievaler.chunk_list(
|
||||
args.doc_id,
|
||||
args.tenant_id,
|
||||
[kb_id],
|
||||
max_count=6,
|
||||
fields=["content_with_weight"],
|
||||
)
|
||||
]
|
||||
|
||||
_, tenant = TenantService.get_by_id(args.tenant_id)
|
||||
llm_bdl = LLMBundle(args.tenant_id, LLMType.CHAT, tenant.llm_id)
|
||||
_, kb = KnowledgebaseService.get_by_id(kb_id)
|
||||
embed_bdl = LLMBundle(args.tenant_id, LLMType.EMBEDDING, kb.embd_id)
|
||||
|
||||
dealer = Dealer(GraphExtractor, args.tenant_id, kb_id, llm_bdl, chunks, "English", embed_bdl=embed_bdl)
|
||||
graph, doc_ids = await update_graph(
|
||||
GraphExtractor,
|
||||
args.tenant_id,
|
||||
kb_id,
|
||||
args.doc_id,
|
||||
chunks,
|
||||
"English",
|
||||
llm_bdl,
|
||||
embed_bdl,
|
||||
callback,
|
||||
)
|
||||
|
||||
print(json.dumps(nx.node_link_data(dealer.graph), ensure_ascii=False, indent=2))
|
||||
print(json.dumps(nx.node_link_data(graph), ensure_ascii=False, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
trio.run(main)
|
||||
|
|
|
@ -228,7 +228,7 @@ class KGSearch(Dealer):
|
|||
ents.append({
|
||||
"Entity": n,
|
||||
"Score": "%.2f" % (ent["sim"] * ent["pagerank"]),
|
||||
"Description": json.loads(ent["description"]).get("description", "")
|
||||
"Description": json.loads(ent["description"]).get("description", "") if ent["description"] else ""
|
||||
})
|
||||
max_token -= num_tokens_from_string(str(ents[-1]))
|
||||
if max_token <= 0:
|
||||
|
|
|
@ -15,6 +15,8 @@ from collections import defaultdict
|
|||
from copy import deepcopy
|
||||
from hashlib import md5
|
||||
from typing import Any, Callable
|
||||
import os
|
||||
import trio
|
||||
|
||||
import networkx as nx
|
||||
import numpy as np
|
||||
|
@ -28,6 +30,7 @@ from rag.utils.redis_conn import REDIS_CONN
|
|||
|
||||
ErrorHandlerFn = Callable[[BaseException | None, str | None, dict | None], None]
|
||||
|
||||
chat_limiter = trio.CapacityLimiter(int(os.environ.get('MAX_CONCURRENT_CHATS', 10)))
|
||||
|
||||
def perform_variable_replacements(
|
||||
input: str, history: list[dict] | None = None, variables: dict | None = None
|
||||
|
@ -234,8 +237,33 @@ def is_float_regex(value):
|
|||
def chunk_id(chunk):
|
||||
return xxhash.xxh64((chunk["content_with_weight"] + chunk["kb_id"]).encode("utf-8")).hexdigest()
|
||||
|
||||
def get_entity_cache(tenant_id, kb_id, ent_name) -> str | list[str]:
|
||||
hasher = xxhash.xxh64()
|
||||
hasher.update(str(tenant_id).encode("utf-8"))
|
||||
hasher.update(str(kb_id).encode("utf-8"))
|
||||
hasher.update(str(ent_name).encode("utf-8"))
|
||||
|
||||
k = hasher.hexdigest()
|
||||
bin = REDIS_CONN.get(k)
|
||||
if not bin:
|
||||
return
|
||||
return json.loads(bin)
|
||||
|
||||
|
||||
def set_entity_cache(tenant_id, kb_id, ent_name, content_with_weight):
|
||||
hasher = xxhash.xxh64()
|
||||
hasher.update(str(tenant_id).encode("utf-8"))
|
||||
hasher.update(str(kb_id).encode("utf-8"))
|
||||
hasher.update(str(ent_name).encode("utf-8"))
|
||||
|
||||
k = hasher.hexdigest()
|
||||
REDIS_CONN.set(k, content_with_weight.encode("utf-8"), 3600)
|
||||
|
||||
|
||||
def get_entity(tenant_id, kb_id, ent_name):
|
||||
cache = get_entity_cache(tenant_id, kb_id, ent_name)
|
||||
if cache:
|
||||
return cache
|
||||
conds = {
|
||||
"fields": ["content_with_weight"],
|
||||
"entity_kwd": ent_name,
|
||||
|
@ -247,6 +275,7 @@ def get_entity(tenant_id, kb_id, ent_name):
|
|||
for id in es_res.ids:
|
||||
try:
|
||||
if isinstance(ent_name, str):
|
||||
set_entity_cache(tenant_id, kb_id, ent_name, es_res.field[id]["content_with_weight"])
|
||||
return json.loads(es_res.field[id]["content_with_weight"])
|
||||
res.append(json.loads(es_res.field[id]["content_with_weight"]))
|
||||
except Exception:
|
||||
|
@ -269,6 +298,7 @@ def set_entity(tenant_id, kb_id, embd_mdl, ent_name, meta):
|
|||
"available_int": 0
|
||||
}
|
||||
chunk["content_sm_ltks"] = rag_tokenizer.fine_grained_tokenize(chunk["content_ltks"])
|
||||
set_entity_cache(tenant_id, kb_id, ent_name, chunk["content_with_weight"])
|
||||
res = settings.retrievaler.search({"entity_kwd": ent_name, "size": 1, "fields": []},
|
||||
search.index_name(tenant_id), [kb_id])
|
||||
if res.ids:
|
||||
|
@ -349,25 +379,57 @@ def set_relation(tenant_id, kb_id, embd_mdl, from_ent_name, to_ent_name, meta):
|
|||
chunk["q_%d_vec" % len(ebd)] = ebd
|
||||
settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id)
|
||||
|
||||
async def does_graph_contains(tenant_id, kb_id, doc_id):
|
||||
# Get doc_ids of graph
|
||||
fields = ["source_id"]
|
||||
condition = {
|
||||
"knowledge_graph_kwd": ["graph"],
|
||||
"removed_kwd": "N",
|
||||
}
|
||||
res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(fields, [], condition, [], OrderByExpr(), 0, 1, search.index_name(tenant_id), [kb_id]))
|
||||
fields2 = settings.docStoreConn.getFields(res, fields)
|
||||
graph_doc_ids = set()
|
||||
for chunk_id in fields2.keys():
|
||||
graph_doc_ids = set(fields2[chunk_id]["source_id"])
|
||||
return doc_id in graph_doc_ids
|
||||
|
||||
def get_graph(tenant_id, kb_id):
|
||||
async def get_graph_doc_ids(tenant_id, kb_id) -> list[str]:
|
||||
conds = {
|
||||
"fields": ["source_id"],
|
||||
"removed_kwd": "N",
|
||||
"size": 1,
|
||||
"knowledge_graph_kwd": ["graph"]
|
||||
}
|
||||
res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id]))
|
||||
doc_ids = []
|
||||
if res.total == 0:
|
||||
return doc_ids
|
||||
for id in res.ids:
|
||||
doc_ids = res.field[id]["source_id"]
|
||||
return doc_ids
|
||||
|
||||
|
||||
async def get_graph(tenant_id, kb_id):
|
||||
conds = {
|
||||
"fields": ["content_with_weight", "source_id"],
|
||||
"removed_kwd": "N",
|
||||
"size": 1,
|
||||
"knowledge_graph_kwd": ["graph"]
|
||||
}
|
||||
res = settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id])
|
||||
res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search(conds, search.index_name(tenant_id), [kb_id]))
|
||||
if res.total == 0:
|
||||
return None, []
|
||||
for id in res.ids:
|
||||
try:
|
||||
return json_graph.node_link_graph(json.loads(res.field[id]["content_with_weight"]), edges="edges"), \
|
||||
res.field[id]["source_id"]
|
||||
except Exception:
|
||||
continue
|
||||
return rebuild_graph(tenant_id, kb_id)
|
||||
result = await rebuild_graph(tenant_id, kb_id)
|
||||
return result
|
||||
|
||||
|
||||
def set_graph(tenant_id, kb_id, graph, docids):
|
||||
async def set_graph(tenant_id, kb_id, graph, docids):
|
||||
chunk = {
|
||||
"content_with_weight": json.dumps(nx.node_link_data(graph, edges="edges"), ensure_ascii=False,
|
||||
indent=2),
|
||||
|
@ -376,13 +438,13 @@ def set_graph(tenant_id, kb_id, graph, docids):
|
|||
"source_id": list(docids),
|
||||
"available_int": 0,
|
||||
"removed_kwd": "N"
|
||||
}
|
||||
res = settings.retrievaler.search({"knowledge_graph_kwd": "graph", "size": 1, "fields": []}, search.index_name(tenant_id), [kb_id])
|
||||
}
|
||||
res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "graph", "size": 1, "fields": []}, search.index_name(tenant_id), [kb_id]))
|
||||
if res.ids:
|
||||
settings.docStoreConn.update({"knowledge_graph_kwd": "graph"}, chunk,
|
||||
search.index_name(tenant_id), kb_id)
|
||||
await trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"knowledge_graph_kwd": "graph"}, chunk,
|
||||
search.index_name(tenant_id), kb_id))
|
||||
else:
|
||||
settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id)
|
||||
await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id))
|
||||
|
||||
|
||||
def is_continuous_subsequence(subseq, seq):
|
||||
|
@ -427,7 +489,7 @@ def merge_tuples(list1, list2):
|
|||
return result
|
||||
|
||||
|
||||
def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
|
||||
async def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
|
||||
def n_neighbor(id):
|
||||
nonlocal graph, n_hop
|
||||
count = 0
|
||||
|
@ -454,15 +516,16 @@ def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
|
|||
return nbrs
|
||||
|
||||
pr = nx.pagerank(graph)
|
||||
for n, p in pr.items():
|
||||
graph.nodes[n]["pagerank"] = p
|
||||
try:
|
||||
settings.docStoreConn.update({"entity_kwd": n, "kb_id": kb_id},
|
||||
{"rank_flt": p,
|
||||
"n_hop_with_weight": json.dumps(n_neighbor(n), ensure_ascii=False)},
|
||||
search.index_name(tenant_id), kb_id)
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
try:
|
||||
async with trio.open_nursery() as nursery:
|
||||
for n, p in pr.items():
|
||||
graph.nodes[n]["pagerank"] = p
|
||||
nursery.start_soon(lambda: trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"entity_kwd": n, "kb_id": kb_id},
|
||||
{"rank_flt": p,
|
||||
"n_hop_with_weight": json.dumps((n), ensure_ascii=False)},
|
||||
search.index_name(tenant_id), kb_id)))
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
ty2ents = defaultdict(list)
|
||||
for p, r in sorted(pr.items(), key=lambda x: x[1], reverse=True):
|
||||
|
@ -477,21 +540,21 @@ def update_nodes_pagerank_nhop_neighbour(tenant_id, kb_id, graph, n_hop):
|
|||
"knowledge_graph_kwd": "ty2ents",
|
||||
"available_int": 0
|
||||
}
|
||||
res = settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "size": 1, "fields": []},
|
||||
search.index_name(tenant_id), [kb_id])
|
||||
res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "size": 1, "fields": []},
|
||||
search.index_name(tenant_id), [kb_id]))
|
||||
if res.ids:
|
||||
settings.docStoreConn.update({"knowledge_graph_kwd": "ty2ents"},
|
||||
await trio.to_thread.run_sync(lambda: settings.docStoreConn.update({"knowledge_graph_kwd": "ty2ents"},
|
||||
chunk,
|
||||
search.index_name(tenant_id), kb_id)
|
||||
search.index_name(tenant_id), kb_id))
|
||||
else:
|
||||
settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id)
|
||||
await trio.to_thread.run_sync(lambda: settings.docStoreConn.insert([{"id": chunk_id(chunk), **chunk}], search.index_name(tenant_id), kb_id))
|
||||
|
||||
|
||||
def get_entity_type2sampels(idxnms, kb_ids: list):
|
||||
es_res = settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids,
|
||||
async def get_entity_type2sampels(idxnms, kb_ids: list):
|
||||
es_res = await trio.to_thread.run_sync(lambda: settings.retrievaler.search({"knowledge_graph_kwd": "ty2ents", "kb_id": kb_ids,
|
||||
"size": 10000,
|
||||
"fields": ["content_with_weight"]},
|
||||
idxnms, kb_ids)
|
||||
idxnms, kb_ids))
|
||||
|
||||
res = defaultdict(list)
|
||||
for id in es_res.ids:
|
||||
|
@ -519,18 +582,18 @@ def flat_uniq_list(arr, key):
|
|||
return list(set(res))
|
||||
|
||||
|
||||
def rebuild_graph(tenant_id, kb_id):
|
||||
async def rebuild_graph(tenant_id, kb_id):
|
||||
graph = nx.Graph()
|
||||
src_ids = []
|
||||
flds = ["entity_kwd", "entity_type_kwd", "from_entity_kwd", "to_entity_kwd", "weight_int", "knowledge_graph_kwd", "source_id"]
|
||||
bs = 256
|
||||
for i in range(0, 39*bs, bs):
|
||||
es_res = settings.docStoreConn.search(flds, [],
|
||||
es_res = await trio.to_thread.run_sync(lambda: settings.docStoreConn.search(flds, [],
|
||||
{"kb_id": kb_id, "knowledge_graph_kwd": ["entity", "relation"]},
|
||||
[],
|
||||
OrderByExpr(),
|
||||
i, bs, search.index_name(tenant_id), [kb_id]
|
||||
)
|
||||
))
|
||||
tot = settings.docStoreConn.getTotal(es_res)
|
||||
if tot == 0:
|
||||
return None, None
|
||||
|
|
|
@ -35,7 +35,7 @@ spec:
|
|||
{{- end }}
|
||||
backend:
|
||||
service:
|
||||
name: {{ include "ragflow.fullname" $ }}
|
||||
name: {{ .Release.Name }}
|
||||
port:
|
||||
name: http
|
||||
{{- end }}
|
||||
|
|
|
@ -31,6 +31,8 @@ spec:
|
|||
ports:
|
||||
- containerPort: 80
|
||||
name: http
|
||||
- containerPort: 9380
|
||||
name: http-api
|
||||
volumeMounts:
|
||||
- mountPath: /etc/nginx/conf.d/ragflow.conf
|
||||
subPath: ragflow.conf
|
||||
|
@ -70,3 +72,23 @@ spec:
|
|||
targetPort: http
|
||||
name: http
|
||||
type: {{ .Values.ragflow.service.type }}
|
||||
---
|
||||
{{- if .Values.ragflow.api.service.enabled }}
|
||||
apiVersion: v1
|
||||
kind: Service
|
||||
metadata:
|
||||
name: {{ .Release.Name }}-api
|
||||
labels:
|
||||
{{- include "ragflow.labels" . | nindent 4 }}
|
||||
app.kubernetes.io/component: ragflow
|
||||
spec:
|
||||
selector:
|
||||
{{- include "ragflow.selectorLabels" . | nindent 4 }}
|
||||
app.kubernetes.io/component: ragflow
|
||||
ports:
|
||||
- protocol: TCP
|
||||
port: 80
|
||||
targetPort: http-api
|
||||
name: http-api
|
||||
type: {{ .Values.ragflow.api.service.type }}
|
||||
{{- end }}
|
||||
|
|
|
@ -27,13 +27,13 @@ env:
|
|||
REDIS_PASSWORD: infini_rag_flow_helm
|
||||
|
||||
# The RAGFlow Docker image to download.
|
||||
# Defaults to the v0.17.0-slim edition, which is the RAGFlow Docker image without embedding models.
|
||||
RAGFLOW_IMAGE: infiniflow/ragflow:v0.17.0-slim
|
||||
# Defaults to the v0.17.2-slim edition, which is the RAGFlow Docker image without embedding models.
|
||||
RAGFLOW_IMAGE: infiniflow/ragflow:v0.17.2-slim
|
||||
#
|
||||
# To download the RAGFlow Docker image with embedding models, uncomment the following line instead:
|
||||
# RAGFLOW_IMAGE: infiniflow/ragflow:v0.17.0
|
||||
# RAGFLOW_IMAGE: infiniflow/ragflow:v0.17.2
|
||||
#
|
||||
# The Docker image of the v0.17.0 edition includes:
|
||||
# The Docker image of the v0.17.2 edition includes:
|
||||
# - Built-in embedding models:
|
||||
# - BAAI/bge-large-zh-v1.5
|
||||
# - BAAI/bge-reranker-v2-m3
|
||||
|
@ -69,6 +69,10 @@ ragflow:
|
|||
service:
|
||||
# Use LoadBalancer to expose the web interface externally
|
||||
type: ClusterIP
|
||||
api:
|
||||
service:
|
||||
enabled: true
|
||||
type: ClusterIP
|
||||
|
||||
infinity:
|
||||
image:
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
[project]
|
||||
name = "ragflow"
|
||||
version = "0.17.0"
|
||||
version = "0.17.2"
|
||||
description = "[RAGFlow](https://ragflow.io/) is an open-source RAG (Retrieval-Augmented Generation) engine based on deep document understanding. It offers a streamlined RAG workflow for businesses of any scale, combining LLM (Large Language Models) to provide truthful question-answering capabilities, backed by well-founded citations from various complex formatted data."
|
||||
authors = [
|
||||
{ name = "Zhichang Yu", email = "yuzhichang@gmail.com" }
|
||||
|
@ -122,7 +122,8 @@ dependencies = [
|
|||
"pyodbc>=5.2.0,<6.0.0",
|
||||
"pyicu>=2.13.1,<3.0.0",
|
||||
"flasgger>=0.9.7.1,<0.10.0",
|
||||
"xxhash>=3.5.0,<4.0.0"
|
||||
"xxhash>=3.5.0,<4.0.0",
|
||||
"trio>=0.29.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
@ -133,4 +134,7 @@ full = [
|
|||
"flagembedding==1.2.10",
|
||||
"torch>=2.5.0,<3.0.0",
|
||||
"transformers>=4.35.0,<5.0.0"
|
||||
]
|
||||
]
|
||||
|
||||
[[tool.uv.index]]
|
||||
url = "https://mirrors.aliyun.com/pypi/simple"
|
||||
|
|
|
@ -240,7 +240,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
|||
callback=callback)
|
||||
res = tokenize_table(tables, doc, is_english)
|
||||
|
||||
elif re.search(r"\.xlsx?$", filename, re.IGNORECASE):
|
||||
elif re.search(r"\.(csv|xlsx?)$", filename, re.IGNORECASE):
|
||||
callback(0.1, "Start to parse.")
|
||||
excel_parser = ExcelParser()
|
||||
if parser_config.get("html4excel"):
|
||||
|
@ -307,9 +307,7 @@ def chunk(filename, binary=None, from_page=0, to_page=100000,
|
|||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
|
||||
def dummy(prog=None, msg=""):
|
||||
pass
|
||||
|
||||
|
||||
chunk(sys.argv[1], from_page=0, to_page=10, callback=dummy)
|
||||
|
|
|
@ -20,7 +20,7 @@ from io import BytesIO
|
|||
from xpinyin import Pinyin
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from openpyxl import load_workbook
|
||||
# from openpyxl import load_workbook, Workbook
|
||||
from dateutil.parser import parse as datetime_parse
|
||||
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
|
@ -33,9 +33,9 @@ class Excel(ExcelParser):
|
|||
def __call__(self, fnm, binary=None, from_page=0,
|
||||
to_page=10000000000, callback=None):
|
||||
if not binary:
|
||||
wb = load_workbook(fnm)
|
||||
wb = Excel._load_excel_to_workbook(fnm)
|
||||
else:
|
||||
wb = load_workbook(BytesIO(binary))
|
||||
wb = Excel._load_excel_to_workbook(BytesIO(binary))
|
||||
total = 0
|
||||
for sheetname in wb.sheetnames:
|
||||
total += len(list(wb[sheetname].rows))
|
||||
|
|
|
@ -107,6 +107,7 @@ from .cv_model import (
|
|||
YiCV,
|
||||
HunyuanCV,
|
||||
)
|
||||
|
||||
from .rerank_model import (
|
||||
LocalAIRerank,
|
||||
DefaultRerank,
|
||||
|
@ -123,7 +124,9 @@ from .rerank_model import (
|
|||
VoyageRerank,
|
||||
QWenRerank,
|
||||
GPUStackRerank,
|
||||
HuggingfaceRerank,
|
||||
)
|
||||
|
||||
from .sequence2txt_model import (
|
||||
GPTSeq2txt,
|
||||
QWenSeq2txt,
|
||||
|
@ -132,6 +135,7 @@ from .sequence2txt_model import (
|
|||
TencentCloudSeq2txt,
|
||||
GPUStackSeq2txt,
|
||||
)
|
||||
|
||||
from .tts_model import (
|
||||
FishAudioTTS,
|
||||
QwenTTS,
|
||||
|
@ -255,6 +259,7 @@ RerankModel = {
|
|||
"Voyage AI": VoyageRerank,
|
||||
"Tongyi-Qianwen": QWenRerank,
|
||||
"GPUStack": GPUStackRerank,
|
||||
"HuggingFace": HuggingfaceRerank,
|
||||
}
|
||||
|
||||
Seq2txtModel = {
|
||||
|
|
|
@ -29,8 +29,8 @@ import json
|
|||
import requests
|
||||
import asyncio
|
||||
|
||||
LENGTH_NOTIFICATION_CN = "······\n由于长度的原因,回答被截断了,要继续吗?"
|
||||
LENGTH_NOTIFICATION_EN = "...\nFor the content length reason, it stopped, continue?"
|
||||
LENGTH_NOTIFICATION_CN = "······\n由于大模型的上下文窗口大小限制,回答已经被大模型截断。"
|
||||
LENGTH_NOTIFICATION_EN = "...\nThe answer is truncated by your chosen LLM due to its limitation on context length."
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
|
@ -42,6 +42,8 @@ class Base(ABC):
|
|||
def chat(self, system, history, gen_conf):
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
|
@ -62,6 +64,8 @@ class Base(ABC):
|
|||
def chat_streamly(self, system, history, gen_conf):
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
ans = ""
|
||||
total_tokens = 0
|
||||
try:
|
||||
|
@ -187,6 +191,8 @@ class BaiChuanChat(Base):
|
|||
def chat(self, system, history, gen_conf):
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
try:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model_name,
|
||||
|
@ -214,6 +220,8 @@ class BaiChuanChat(Base):
|
|||
def chat_streamly(self, system, history, gen_conf):
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
ans = ""
|
||||
total_tokens = 0
|
||||
try:
|
||||
|
@ -260,11 +268,13 @@ class QWenChat(Base):
|
|||
import dashscope
|
||||
dashscope.api_key = key
|
||||
self.model_name = model_name
|
||||
if model_name.lower().find("deepseek") >= 0:
|
||||
if self.is_reasoning_model(self.model_name):
|
||||
super().__init__(key, model_name, "https://dashscope.aliyuncs.com/compatible-mode/v1")
|
||||
|
||||
def chat(self, system, history, gen_conf):
|
||||
if self.model_name.lower().find("deepseek") >= 0:
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
if self.is_reasoning_model(self.model_name):
|
||||
return super().chat(system, history, gen_conf)
|
||||
|
||||
stream_flag = str(os.environ.get('QWEN_CHAT_BY_STREAM', 'true')).lower() == 'true'
|
||||
|
@ -305,6 +315,8 @@ class QWenChat(Base):
|
|||
from http import HTTPStatus
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
ans = ""
|
||||
tk_count = 0
|
||||
try:
|
||||
|
@ -334,11 +346,21 @@ class QWenChat(Base):
|
|||
yield tk_count
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
if self.model_name.lower().find("deepseek") >= 0:
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
if self.is_reasoning_model(self.model_name):
|
||||
return super().chat_streamly(system, history, gen_conf)
|
||||
|
||||
return self._chat_streamly(system, history, gen_conf)
|
||||
|
||||
@staticmethod
|
||||
def is_reasoning_model(model_name: str) -> bool:
|
||||
return any([
|
||||
model_name.lower().find("deepseek") >= 0,
|
||||
model_name.lower().find("qwq") >= 0 and model_name.lower() != 'qwq-32b-preview',
|
||||
])
|
||||
|
||||
|
||||
|
||||
class ZhipuChat(Base):
|
||||
def __init__(self, key, model_name="glm-3-turbo", **kwargs):
|
||||
|
@ -348,6 +370,8 @@ class ZhipuChat(Base):
|
|||
def chat(self, system, history, gen_conf):
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
try:
|
||||
if "presence_penalty" in gen_conf:
|
||||
del gen_conf["presence_penalty"]
|
||||
|
@ -371,6 +395,8 @@ class ZhipuChat(Base):
|
|||
def chat_streamly(self, system, history, gen_conf):
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
if "presence_penalty" in gen_conf:
|
||||
del gen_conf["presence_penalty"]
|
||||
if "frequency_penalty" in gen_conf:
|
||||
|
@ -412,6 +438,8 @@ class OllamaChat(Base):
|
|||
def chat(self, system, history, gen_conf):
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
try:
|
||||
options = {}
|
||||
if "temperature" in gen_conf:
|
||||
|
@ -438,6 +466,8 @@ class OllamaChat(Base):
|
|||
def chat_streamly(self, system, history, gen_conf):
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
options = {}
|
||||
if "temperature" in gen_conf:
|
||||
options["temperature"] = gen_conf["temperature"]
|
||||
|
@ -515,8 +545,6 @@ class LocalLLM(Base):
|
|||
from rag.svr.jina_server import Prompt
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
if "max_tokens" in gen_conf:
|
||||
gen_conf["max_new_tokens"] = gen_conf.pop("max_tokens")
|
||||
return Prompt(message=history, gen_conf=gen_conf)
|
||||
|
||||
def _stream_response(self, endpoint, prompt):
|
||||
|
@ -538,6 +566,8 @@ class LocalLLM(Base):
|
|||
yield num_tokens_from_string(answer)
|
||||
|
||||
def chat(self, system, history, gen_conf):
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
prompt = self._prepare_prompt(system, history, gen_conf)
|
||||
chat_gen = self._stream_response("/chat", prompt)
|
||||
ans = next(chat_gen)
|
||||
|
@ -545,6 +575,8 @@ class LocalLLM(Base):
|
|||
return ans, total_tokens
|
||||
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
prompt = self._prepare_prompt(system, history, gen_conf)
|
||||
return self._stream_response("/stream", prompt)
|
||||
|
||||
|
@ -606,6 +638,9 @@ class MiniMaxChat(Base):
|
|||
def chat_streamly(self, system, history, gen_conf):
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
for k in list(gen_conf.keys()):
|
||||
if k not in ["temperature", "top_p", "max_tokens"]:
|
||||
del gen_conf[k]
|
||||
ans = ""
|
||||
total_tokens = 0
|
||||
try:
|
||||
|
@ -713,7 +748,7 @@ class BedrockChat(Base):
|
|||
self.bedrock_sk = json.loads(key).get('bedrock_sk', '')
|
||||
self.bedrock_region = json.loads(key).get('bedrock_region', '')
|
||||
self.model_name = model_name
|
||||
|
||||
|
||||
if self.bedrock_ak == '' or self.bedrock_sk == '' or self.bedrock_region == '':
|
||||
# Try to create a client using the default credentials (AWS_PROFILE, AWS_DEFAULT_REGION, etc.)
|
||||
self.client = boto3.client('bedrock-runtime')
|
||||
|
@ -724,14 +759,8 @@ class BedrockChat(Base):
|
|||
def chat(self, system, history, gen_conf):
|
||||
from botocore.exceptions import ClientError
|
||||
for k in list(gen_conf.keys()):
|
||||
if k not in ["temperature", "top_p", "max_tokens"]:
|
||||
if k not in ["top_p", "max_tokens"]:
|
||||
del gen_conf[k]
|
||||
if "max_tokens" in gen_conf:
|
||||
gen_conf["maxTokens"] = gen_conf["max_tokens"]
|
||||
_ = gen_conf.pop("max_tokens")
|
||||
if "top_p" in gen_conf:
|
||||
gen_conf["topP"] = gen_conf["top_p"]
|
||||
_ = gen_conf.pop("top_p")
|
||||
for item in history:
|
||||
if not isinstance(item["content"], list) and not isinstance(item["content"], tuple):
|
||||
item["content"] = [{"text": item["content"]}]
|
||||
|
@ -755,14 +784,8 @@ class BedrockChat(Base):
|
|||
def chat_streamly(self, system, history, gen_conf):
|
||||
from botocore.exceptions import ClientError
|
||||
for k in list(gen_conf.keys()):
|
||||
if k not in ["temperature", "top_p", "max_tokens"]:
|
||||
if k not in ["top_p", "max_tokens"]:
|
||||
del gen_conf[k]
|
||||
if "max_tokens" in gen_conf:
|
||||
gen_conf["maxTokens"] = gen_conf["max_tokens"]
|
||||
_ = gen_conf.pop("max_tokens")
|
||||
if "top_p" in gen_conf:
|
||||
gen_conf["topP"] = gen_conf["top_p"]
|
||||
_ = gen_conf.pop("top_p")
|
||||
for item in history:
|
||||
if not isinstance(item["content"], list) and not isinstance(item["content"], tuple):
|
||||
item["content"] = [{"text": item["content"]}]
|
||||
|
@ -819,11 +842,8 @@ class GeminiChat(Base):
|
|||
|
||||
if system:
|
||||
self.model._system_instruction = content_types.to_content(system)
|
||||
|
||||
if 'max_tokens' in gen_conf:
|
||||
gen_conf['max_output_tokens'] = gen_conf['max_tokens']
|
||||
for k in list(gen_conf.keys()):
|
||||
if k not in ["temperature", "top_p", "max_output_tokens"]:
|
||||
if k not in ["temperature", "top_p", "max_tokens"]:
|
||||
del gen_conf[k]
|
||||
for item in history:
|
||||
if 'role' in item and item['role'] == 'assistant':
|
||||
|
@ -847,10 +867,8 @@ class GeminiChat(Base):
|
|||
|
||||
if system:
|
||||
self.model._system_instruction = content_types.to_content(system)
|
||||
if 'max_tokens' in gen_conf:
|
||||
gen_conf['max_output_tokens'] = gen_conf['max_tokens']
|
||||
for k in list(gen_conf.keys()):
|
||||
if k not in ["temperature", "top_p", "max_output_tokens"]:
|
||||
if k not in ["temperature", "top_p", "max_tokens"]:
|
||||
del gen_conf[k]
|
||||
for item in history:
|
||||
if 'role' in item and item['role'] == 'assistant':
|
||||
|
@ -992,6 +1010,8 @@ class CoHereChat(Base):
|
|||
def chat(self, system, history, gen_conf):
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
if "top_p" in gen_conf:
|
||||
gen_conf["p"] = gen_conf.pop("top_p")
|
||||
if "frequency_penalty" in gen_conf and "presence_penalty" in gen_conf:
|
||||
|
@ -1026,6 +1046,8 @@ class CoHereChat(Base):
|
|||
def chat_streamly(self, system, history, gen_conf):
|
||||
if system:
|
||||
history.insert(0, {"role": "system", "content": system})
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
if "top_p" in gen_conf:
|
||||
gen_conf["p"] = gen_conf.pop("top_p")
|
||||
if "frequency_penalty" in gen_conf and "presence_penalty" in gen_conf:
|
||||
|
@ -1122,7 +1144,7 @@ class ReplicateChat(Base):
|
|||
|
||||
def chat(self, system, history, gen_conf):
|
||||
if "max_tokens" in gen_conf:
|
||||
gen_conf["max_new_tokens"] = gen_conf.pop("max_tokens")
|
||||
del gen_conf["max_tokens"]
|
||||
if system:
|
||||
self.system = system
|
||||
prompt = "\n".join(
|
||||
|
@ -1141,7 +1163,7 @@ class ReplicateChat(Base):
|
|||
|
||||
def chat_streamly(self, system, history, gen_conf):
|
||||
if "max_tokens" in gen_conf:
|
||||
gen_conf["max_new_tokens"] = gen_conf.pop("max_tokens")
|
||||
del gen_conf["max_tokens"]
|
||||
if system:
|
||||
self.system = system
|
||||
prompt = "\n".join(
|
||||
|
@ -1185,6 +1207,8 @@ class HunyuanChat(Base):
|
|||
_history = [{k.capitalize(): v for k, v in item.items()} for item in history]
|
||||
if system:
|
||||
_history.insert(0, {"Role": "system", "Content": system})
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
if "temperature" in gen_conf:
|
||||
_gen_conf["Temperature"] = gen_conf["temperature"]
|
||||
if "top_p" in gen_conf:
|
||||
|
@ -1211,7 +1235,8 @@ class HunyuanChat(Base):
|
|||
_history = [{k.capitalize(): v for k, v in item.items()} for item in history]
|
||||
if system:
|
||||
_history.insert(0, {"Role": "system", "Content": system})
|
||||
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
if "temperature" in gen_conf:
|
||||
_gen_conf["Temperature"] = gen_conf["temperature"]
|
||||
if "top_p" in gen_conf:
|
||||
|
@ -1284,7 +1309,7 @@ class BaiduYiyanChat(Base):
|
|||
0)) / 2
|
||||
) + 1
|
||||
if "max_tokens" in gen_conf:
|
||||
gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
|
||||
del gen_conf["max_tokens"]
|
||||
ans = ""
|
||||
|
||||
try:
|
||||
|
@ -1308,7 +1333,7 @@ class BaiduYiyanChat(Base):
|
|||
0)) / 2
|
||||
) + 1
|
||||
if "max_tokens" in gen_conf:
|
||||
gen_conf["max_output_tokens"] = gen_conf["max_tokens"]
|
||||
del gen_conf["max_tokens"]
|
||||
ans = ""
|
||||
total_tokens = 0
|
||||
|
||||
|
@ -1344,8 +1369,6 @@ class AnthropicChat(Base):
|
|||
def chat(self, system, history, gen_conf):
|
||||
if system:
|
||||
self.system = system
|
||||
if "max_tokens" not in gen_conf:
|
||||
gen_conf["max_tokens"] = 4096
|
||||
if "presence_penalty" in gen_conf:
|
||||
del gen_conf["presence_penalty"]
|
||||
if "frequency_penalty" in gen_conf:
|
||||
|
@ -1377,8 +1400,6 @@ class AnthropicChat(Base):
|
|||
def chat_streamly(self, system, history, gen_conf):
|
||||
if system:
|
||||
self.system = system
|
||||
if "max_tokens" not in gen_conf:
|
||||
gen_conf["max_tokens"] = 4096
|
||||
if "presence_penalty" in gen_conf:
|
||||
del gen_conf["presence_penalty"]
|
||||
if "frequency_penalty" in gen_conf:
|
||||
|
@ -1458,8 +1479,8 @@ class GoogleChat(Base):
|
|||
self.system = system
|
||||
|
||||
if "claude" in self.model_name:
|
||||
if "max_tokens" not in gen_conf:
|
||||
gen_conf["max_tokens"] = 4096
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
try:
|
||||
response = self.client.messages.create(
|
||||
model=self.model_name,
|
||||
|
@ -1508,8 +1529,8 @@ class GoogleChat(Base):
|
|||
self.system = system
|
||||
|
||||
if "claude" in self.model_name:
|
||||
if "max_tokens" not in gen_conf:
|
||||
gen_conf["max_tokens"] = 4096
|
||||
if "max_tokens" in gen_conf:
|
||||
del gen_conf["max_tokens"]
|
||||
ans = ""
|
||||
total_tokens = 0
|
||||
try:
|
||||
|
@ -1556,6 +1577,7 @@ class GoogleChat(Base):
|
|||
|
||||
yield response._chunks[-1].usage_metadata.total_token_count
|
||||
|
||||
|
||||
class GPUStackChat(Base):
|
||||
def __init__(self, key=None, model_name="", base_url=""):
|
||||
if not base_url:
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue