AGI|三种高级RAG检索方法帮企业告别冗长文档!
基于LangChain实现的RAG检索方法
RAG(检索增强生成)已经成为当前企业级AI问答应用的主流技术路线。相比重新训练一个大模型来理解私域数据,RAG不需要额外的预训练成本,同时又能保持相近的回答质量。但在实际落地过程中,有一个始终绕不开的瓶颈——大语言模型的上下文长度限制。
虽然各大模型厂商已经将上下文token从年初的2k、4k一路推到了128k、192k,但面对企业真实场景下的合同、标书、技术文档,这点容量可能连一份完整的文件都塞不进去。更关键的是,即便能塞得下,过长的prompt反而会稀释大模型的注意力,导致中间部分的信息被忽略。所以,如何在尽可能压缩输入规模的同时,还能保证内容的完整性,就成了提升RAG应用准确度的一个核心命题。
这篇文章就从LangChain出发,拆解三种高级检索策略:
句子窗口检索
自动合并检索
多路召回检索
Part1 前言
RAG的核心逻辑并不复杂:先把企业私域知识做切片、向量化,存到向量数据库里;用户提问时,把问题也向量化,去库里搜索最相关的几段内容,再塞进prompt交给大模型生成答案。流程看起来很顺畅,但真正的挑战在于——当知识库非常大时,你不可能把所有相关文本都扔给模型,你又必须保证召回来的那些片段恰好包含回答问题所需的关键信息。
这就回到了切片和召回的设计问题上。切片太粗,可能把不相关的信息也打包进来;切片太细,单个片段信息不完整,RAG流程中很容易出现上下文断裂。下面介绍的三种方法,正是从不同角度来缓解这个矛盾。
Part2 先验知识
RAG简要流程

加载文档 → 切分划片 → 嵌入向量表示 → 存入数据库

向量化问题 → 向量召回文档 → 合并放入Prompt → LLM生成答案
Part3 句子窗口检索
一、概念
文档被切分后,每一块会被封装成一个Langchain自定义的Document对象,它有两个重要属性:page_content记录该片段的文本内容,meta_data用来存放额外信息。
句子窗口检索的思路很简单:在切片阶段,就把每个片段的前后相邻片段的内容也一并存进它的meta_data里。等到召回时,命中文档就能从meta_data中拿回自己的上下文,并把它们拼接到page_content中。这样一来,最终交付给大模型的就不是孤立的片段,而是带完整上下文的段落。
实际操作中的建议:切片器尽量选用依据标点符号切分的类型,把粒度控制得足够细;同时在封装和召回阶段,适当把窗口大小扩大一些,保证上下文信息的完整性。
二、代码实现
(1) 元数据封装
def metadata_format(self, ordered_text, **kwargs):
count = kwargs.get("split_count", 1)
for i, document in enumerate(ordered_text):
if i > 0:
document.metadata['previous_page'] = ordered_text[i-count].page_content
else:
document.metadata['previous_page'] = ''
if i < len(ordered_text) - 1:
document.metadata['next_page'] = ordered_text[i+count].page_content
else:
document.metadata['next_page'] = ''
return ordered_text
(2) 数据重构
def search_and_format(self, databases, query, **kwargs):
top_documents = []
for db in databases:
top_documents.append(db.similarity_search_with_score(query))
docs = []
for doc, _ in top_documents:
doc.page_content = doc.metadata.get("previous_page") + doc.page_content + doc.metadata.get("next_page")
docs.append(doc)
return docs
(3) 调用示例
# load document
......
# split
......
# use smartvision sdk to format
sentence_window_retrival = SentenceWindow()
formatted_documents = sentence_window_retrival.metadata_format(documents, split_count=2)
# embedding
......
# load in local vector db
......
# use smartvision sdk to do search and multiple recall
databases = [db]
query = "烟草专卖品的运输"
top_documents = sentence_window_retrival.search_and_format(databases, query)
print(top_documents)

Part4 自动合并检索
一、概念
自动合并检索的思路来自LlamaIndex封装的同名方法,但整套流程需要我们自己把它对齐到LangChain的规范中。具体做法是:在读取并切分文档之后,先把Langchain格式的Document对象转换成LlamaIndex的Document对象,然后利用LlamaIndex的算法自动将切片列表划分为子节点和父节点的层级结构。最后再转回Langchain格式,并将父节点信息、层级深度等一并封装进每个节点的meta_data里。
检索时,先召回最相关的若干个叶子节点。遍历这些节点时,如果发现超过K个节点(K是用户自定义的阈值,通常取父节点子节点总数的一半)同时指向同一个父节点,就把这个父节点下属的所有子节点合并,直接返回父节点的完整内容。这种做法使得原本可能分散的、较小的上下文片段有机会组合成一个更完整的信息块,有助于大模型生成更准确的回答。
二、代码实现
(1) 元数据封装
def auto_merge_format(documents, **kwargs):
if documents is None:
raise ValueError('documents is required')
formatted_documents = []
doc_text = "\n\n".join([d.page_content for d in documents])
docs = [Document(text=doc_text)]
node_parser = HierarchicalNodeParser.from_defaults(
chunk_sizes=kwargs.get("pc_chunk_size", [2048, 512, 128]),
chunk_overlap=kwargs.get("pc_chunk_overlap", 10)
)
nodes = node_parser.get_nodes_from_documents(docs)
leaf_nodes = get_leaf_nodes(nodes)
root_nodes = get_root_nodes(nodes)
middle_nodes = get_middle_node(nodes, leaf_nodes, root_nodes)
root_context_dict = {}
for root_node in nodes:
root_context_dict[root_node.node_id] = root_node.get_content()
for node in nodes:
if node.parent_node:
node_id = node.node_id
root_node_id = node.parent_node.node_id
root_node_content = root_context_dict.get(node.parent_node.node_id)
root_node_child_count = 0
for parent_node in root_nodes + middle_nodes:
if parent_node.node_id == node.parent_node.node_id:
root_node_child_count = len(parent_node.child_nodes)
break
depth = 2 if node in middle_nodes else 3
child_count = len(node.child_nodes) if node.child_nodes is not None else 0
document = langchain.schema.Document(
page_content=node.get_content(),
metadata={
"node_id": node_id,
"root_node_id": root_node_id,
"root_node_content": root_node_content,
"root_node_child_count": root_node_child_count,
"depth": depth,
"child_count": child_count
}
)
formatted_documents.append(document)
return formatted_documents
(2) 数据重构
def search_and_format(self, databases, query, **kwargs):
top_documents = []
for db in databases:
top_document = db.similarity_search_with_score(query)
top_documents.append(top_document)
leaf_nodes = [doc for doc, _ in top_documents]
return do_merge(leaf_nodes, **kwargs)
def group_nodes_by_depth(nodes, depth):
return [node for node in nodes if node.metadata.get("depth") == depth]
def process_group(nodes, threshold):
grouped_by_root_id = {}
for node in nodes:
root_id = node.metadata.get("root_node_id")
grouped_by_root_id.setdefault(root_id, []).append(node)
merge_context = []
for group in grouped_by_root_id.values():
node_count = len(group)
child_count = group[0].metadata.get("root_node_child_count")
if node_count / child_count >= threshold:
merge_context.append(langchain.schema.Document(
page_content=group[0].metadata.get("root_node_content")
))
else:
for document in group:
merge_context.append(document)
return merge_context
def do_merge(nodes, **kwargs) -> List[langchain.schema.Document]:
threshold = kwargs.get("threshold", 0.5)
leaf_nodes = group_nodes_by_depth(nodes, 3)
middle_nodes = group_nodes_by_depth(nodes, 2)
leaf_merge_context = process_group(leaf_nodes, threshold)
middle_merge_context = process_group(middle_nodes, threshold)
merge_content = leaf_merge_context + middle_merge_context
return merge_content
def get_middle_node(nodes, leaf_nodes, root_nodes):
middle_node = []
for node in nodes:
if node not in leaf_nodes and node not in root_nodes:
middle_node.append(node)
return middle_node
(3) 调用示例
# load document
......
# split
......
# use smartvision sdk to format
auto_merge_retrival = AutoMergeRetrieval()
formatted_documents = auto_merge_retrival.metadata_format(
documents,
pc_chunk_size=[1024, 128, 32],
pc_chunk_overlap=4
)
# embedding
......
# load in local vector db
......
# use smartvision sdk to do search and multiple recall
top_documents = auto_merge_retrival.search_and_format(databases, query, threshold=0.5)
print(top_documents)

Part5 多路召回检索
一、概念
多路召回检索跟前两种方法不同,它在元数据封装阶段不做任何特殊处理。真正的变化发生在检索阶段:它允许用户同时把多个数据集或不同类型的向量数据库作为检索源,应对文档类型多样、数量庞大的场景。
实际流程是:从多个数据源分别召回文档列表,然后用一个rerank模型对每个文档和问题之间的相关性进行评分,最后只保留评分超过一定阈值的文档,组合成最终的prompt。
多路召回在数据源“广而杂”的情况下优势非常明显。不过必须意识到,引入rerank模型虽然在准确度上做了二次把关,但它是以牺牲响应速度和执行效率为代价的。在实际设计中,是否启用多路召回、是否加入rerank,需要根据业务场景对实时性和准确度的具体要求来权衡。
二、代码实现
(1) 元数据封装
def metadata_format(self, ordered_text, **kwargs):
"""
默认rag,不做任何处理
"""
return ordered_text
(2) 数据重构
def search_and_format(self, databases, query, **kwargs):
top_documents = []
result_data = []
for db in databases:
top_document = db.similarity_search_with_score(query)
top_documents.append(top_document)
pairs = [[query, item.page_content] for item in top_documents]
with torch.no_grad():
rerank_tokenizer = AutoTokenizer.from_pretrained(RERANK_FILE_PATH)
inputs = rerank_tokenizer(pairs, padding=True, truncation=True, return_tensors='pt', max_length=512)
rerank_model = AutoModelForSequenceClassification.from_pretrained(RERANK_FILE_PATH)
scores = rerank_model(**inputs, return_dict=True).logits.view(-1, ).float()
for i, score in enumerate(scores):
data = {
"text": top_documents[i].page_content,
"score": float(score)
}
result_data.append(data)
return result_data
Part6 结语
以上三种高级RAG检索方法,主要目标都是改善检索召回环节的信息残缺问题。事实上RAG的完整流程中还有多处可以优化的环节——比如文档预处理、切片策略、向量化模型的选择、Prompt的编排等等——但从实践效果来看,改进召回方式仍然是回报最直接的优化方向。
贯穿这三种方法都有一个共同的关键点: