使用langchain实现RAG(检索增强生成)

article/2025/6/20 3:47:11

概述

本文将从零开始实现一个langchain应用程序, 该应用支持读取pdf文档并embedding编码到Chroma数据库, 当用户提问时,
可以从网络搜索结果和本地向量数据库中收集数据, 传递给第三方LLM大模型, 所有使用到的工具完全免费

将使用如下技术或工具:

  • python3.9
  • langchain
  • Chroma DB
  • Huggingface xlm-roberta-large 词嵌入模型
  • TavilySearchAPI
  • pytesseract OCR
  • ...

依赖

本文使用python3.9

pip install fastapi langchain dotenv langchain_community

目录结构

- db (chroma_db存放位置)
- files (放所有要编码的pdf)- social (pdf分类文件夹)
- over (存放编码完成后的pdf)
- .env
- app.py (入口文件, 开放fastapi接口)
- models.py (定义使用的模型)

定义模型 models.py

首先获取讯飞开放API

如何使用api接入星火大模型(超详细,亲测有效!)_星火api-CSDN博客

将APIKEY写入.env文件

# 讯飞
IFLYTEK_SPARK_API_SECRET=
IFLYTEK_SPARK_API_KEY=
IFLYTEK_SPARK_APP_ID=## spark lite (免费不限量)
IFLYTEK_SPARK_MODEL=lite
IFLYTEK_SPARK_API_URL=wss://spark-api.xf-yun.com/v1.1/chat

定义模型

import os
from langchain_community.chat_models import ChatSparkLLMappid = os.getenv('iflytek_spark_app_id')
apikey = os.getenv('iflytek_spark_api_key')
apisecret = os.getenv('iflytek_spark_api_secret')
apiurl = os.getenv('iflytek_spark_api_url')
apimodel = os.getenv('iflytek_spark_model')llm = ChatSparkLLM(request_timeout=180,spark_api_url=apiurl,spark_llm_domain=apimodel,spark_app_id=appid, spark_api_key=apikey, spark_api_secret=apisecret
)
chat_model = llm

注意, 请在ide中Ctrl+左键点击进入ChatSparkLLM的源码, 然后添加一行break否则会遇到一直请求等待的bug

huggingface上的xlm-roberta-large模型

我们使用xlm-roberta-large作为本地embedding的模型, 下面这行代码在第一次运行会自动把模型下载到本地目录(
需要HUGGINGFACEHUB_API_TOKEN, 以及科学上网)
{% link 如何获取HuggingFace的Access Token;如何获取HuggingFace的API Key_huggingface access token-CSDN博客 %}

env文件

HUGGINGFACEHUB_API_TOKEN=

模型

from langchain.embeddings import HuggingFaceEmbeddings
embeddings = HuggingFaceEmbeddings(model_name="xlm-roberta-large",cache_folder='models/huggingface/')

生成答案时结合网络搜索

duckduckgo搜索因为不可抗力无法访问, 所以我现在使用有一定免费额度(一个月1000次request)的tavily搜索api (也可以使用wiki搜索库)

from langchain.utilities.tavily_search import TavilySearchAPIWrapper
from langchain.tools.tavily_search import TavilySearchResultssearch = TavilySearchAPIWrapper()
tavily_tool = TavilySearchResults(api_wrapper=search)
# search_agent = ChatOpenAI(model='gpt-4', temperature=0.7)
search_agent = llm

注意需要api key, 登录到控制台就可以获取
{% link Tavily %}

TAVILY_API_KEY=

Fastapi接口

import uvicorn
import os
import re
import shutil
import fitz # pdf
from fastapi import FastAPI, HTTPException, Form
from fastapi.middleware.cors import CORSMiddleware
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.embeddings import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.chains import RetrievalQAload_dotenv()app = FastAPI()# 导入模型
from models import embeddings, llm# 跨域 添加 CORS 中间件
app.add_middleware(CORSMiddleware,allow_origins=["http://localhost:3000","http://localhost:81",],  # 允许的原点allow_credentials=True,allow_methods=["*"],  # 允许所有方法allow_headers=["*"],  # 允许所有头部
)vectorstore = None
qa_chain = Noneif __name__ == '__main__':uvicorn.run(app='app:app', host="127.0.0.1", port=8000, reload=True)

训练接口

处理pdf

# pytesseract需要去github下载安装, 然后才能使用其本地OCR
import pytesseractpytesseract.pytesseract.tesseract_cmd = r'D:\applications\TesseractOCR5\tesseract.exe'def find_pdf_files(directory):pdf_files = []for root, _, files in os.walk(directory):for file in files:if file.endswith('.pdf'):pdf_files.append(os.path.join(root, file))elif file.endswith('.epub'):pdf_files.append(os.path.join(root, file))return pdf_filesdef process_pdf_func(pdf_path):documents = []doc = fitz.open(pdf_path)for page in doc:# 尝试直接获取文本text = page.get_text()# 如果文本为空,使用 OCRif not text.strip():print('start OCR')# 获取页面的图像pix = page.get_pixmap()img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)# text = pytesseract.image_to_string(img)# text = pytesseract.image_to_string(img, lang='chi_sim+jpn+eng')# text = pytesseract.image_to_string(img, lang='jpn+eng')# text = pytesseract.image_to_string(img, lang='jpn+chi_sim')# text = pytesseract.image_to_string(img, lang='jpn')text = pytesseract.image_to_string(img, lang='chi_sim')print(f'new text: {text}')# print(f'new text: {text[:5]}')if text.strip():documents.append(Document(page_content=text))print(f'new text: {text}')# print(f'new text: {text[:5]}')return documentsdef process_file(file_path):if file_path.endswith('.pdf'):return process_pdf_func(file_path)else:raise ValueError("Unsupported file format")

加载pdf进行训练

请求时的url: POST http://localhost:8000/process_pdf/{dirname}

dirname就是files文件夹下某个子文件夹的名称, 表示某个领域的pdf文档集合, 编码后会存入db/dirname下的xxx.db, 在问答时,
根据问答接口的不同, 选择不同的db进行RAG检索生成, 达到分不同领域进行训练和使用的效果

# 加载pdf, 创建langchain链
def adjust_batch_size(embedding_size, available_memory_gb):# 每个嵌入的字节数embedding_memory = embedding_size * 4# 转换可用内存为字节available_memory_bytes = available_memory_gb * 1024 ** 3# 计算最大批处理大小max_batch_size = available_memory_bytes // embedding_memory# 设置一个合理的上限return min(max_batch_size, 50)  # 100是一个假设的上限# return max_batch_size  # 100是一个假设的上限def load_pdf_and_create_qa_chain(file_path: str, sub_path: str = ''):print('enter load and process')# loadglobal vectorstore, qa_chaindocuments = process_file(file_path)print('load file documents over!')# 使用文本切分器进行处理text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)# text_splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=100)split_docs = text_splitter.split_documents(documents)print('split docs over!')# 创建目录(如果不存在)db_path = "db"if sub_path:db_path = os.path.join(db_path, sub_path)# 创建目录(如果不存在)os.makedirs(db_path, exist_ok=True)# Save document vectors to Chroma databasetry:# 计算批处理大小batch_size = adjust_batch_size(1024, 8)  # 使用8GB内存print(f"Recommended batch size: {batch_size}")for i in range(0, len(split_docs), batch_size):# for i in range(8100, len(split_docs), batch_size):print(f'start batch-{i} persist! total: {len(split_docs)}')batch_docs = split_docs[i:i + batch_size]# 保存文档向量到 Chroma 数据库vectorstore = Chroma.from_documents(batch_docs, embeddings, persist_directory=db_path)vectorstore.persist()  # 持久化到本地print(f'batch-{i} persisted! total: {len(split_docs)}')print('Vectorstore persist over!')except Exception as e:print(f'Error saving vectors: {e}')returnqa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=vectorstore.as_retriever())print(f'embedding {file_path} success')# Move the processed PDF to the 'over/sub_path' directoryover_path = os.path.join("over", sub_path)os.makedirs(over_path, exist_ok=True)shutil.copy(file_path, over_path)print(f'Moved {file_path} to {over_path}')

ask问答接口

class PromptRequest(BaseModel):question: str@app.post("/ask/{area}")
async def ask_ai(request: PromptRequest, area: str):if not area:chroma_db = Chroma(persist_directory="db", embedding_function=embeddings)else:# 初始化Chroma向量数据库chroma_db = Chroma(persist_directory="db/" + area, embedding_function=embeddings)# chroma_db = Chroma(persist_directory="db", embedding_function=OpenAIEmbeddings())# 创建RetrievalQA链qa_chain2 = RetrievalQA.from_chain_type(llm=chat_model, chain_type="stuff", retriever=chroma_db.as_retriever())# Define templates based on areatemplates = {'language': ("You are a language teacher specializing in teaching individuals who wish to settle abroad. ""Based on the following context and user's question, provide a detailed language knowledge answer. ""If unable to answer the user's question based on background knowledge, ask follow-up questions related ""to the background knowledge, limited to three questions.""Context: {context}\nQuestion: {question}\nAnswer(in chinese):"),'knowledge': ("You are a professor with expertise in multiple academic disciplines. ""Based on the following context and the user's academic question, provide an authoritative and professional answer. ""If unable to answer the user's question based on background knowledge, ask follow-up questions related ""to the background knowledge, limited to three questions.""Context: {context}\nQuestion: {question}\nAnswer(in chinese):"),'psychology': ("You are an expert professor in the field of psychology. ""Based on the following context and the user's academic question, provide an authoritative and professional answer. ""If unable to answer the user's question based on background knowledge, ask follow-up questions related ""to the background knowledge, limited to three questions.""Context: {context}\nQuestion: {question}\nAnswer(in chinese):")}template = templates.get(area, templates['language'])# 定义 PromptTemplateprompt_template = PromptTemplate(template=template,input_variables=["question"])try:# 使用 PromptTemplate 格式化提示formatted_prompt = prompt_template.format(question=request.question)# 使用 RetrievalQA 链处理格式化后的提示response = qa_chain2(formatted_prompt)print(f'qa chain response {response}')# 提取回答(falcon回答里有其他的东西)# Use regular expression to find text after "Helpful Answer:"match = re.search(r'Helpful Answer:\s*(.*)', response['result'], re.DOTALL)if match:helpful_answer = match.group(1).strip()print(helpful_answer)return {"response": (helpful_answer)}# return {"response": translate(helpful_answer)}else:print("No 'Helpful Answer:' found.")return {"response": (response['result'])}# return {"response": translate(response['result'])}except Exception as e:raise HTTPException(status_code=500, detail=str(e))

websocket聊天接口

from fastapi import WebSocket, WebSocketDisconnect# WebSocket连接处理
@app.websocket("/ws/ask/{area}")
async def websocket_endpoint(websocket: WebSocket, area: str):await websocket.accept()context = []  # Initialize context listmax_context_length = 1024  # Define a maximum context length# Initialize Chroma vector databasechroma_db = Chroma(persist_directory=f"db/{area}" if area else "db", embedding_function=embeddings)# Create RetrievalQA chainqa_chain2 = RetrievalQA.from_chain_type(llm=chat_model, chain_type="stuff",retriever=chroma_db.as_retriever(search_type="mmr",search_kwargs={'k': 2, 'lambda_mult': 0.75}))# Initialize web search tool (agent)agent_chain = initialize_agent([tavily_tool],search_agent,agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,verbose=True,)# Define templates based on areatemplates = {'knowledge': ("You are a professor with expertise in multiple academic disciplines. ""Based on the following context and the user's academic question, ""provide an authoritative and professional answer. ""If unable to answer the user's question based on background knowledge, ask follow-up questions related ""to the background knowledge, limited to three questions.""Context: {context}\nQuestion: {question}\nAnswer(in chinese):"),'psychology': ("You are an expert professor in the field of psychology. ""Based on the following context and the user's academic question, ""provide an authoritative and professional answer. ""If unable to answer the user's question based on background knowledge, ask follow-up questions related ""to the background knowledge, limited to three questions.""Context: {context}\nQuestion: {question}\nAnswer(in chinese):")}template = templates.get(area, templates['language'])# Define PromptTemplateprompt_template = PromptTemplate(template=template, input_variables=["context", "question"])try:while True:data = await websocket.receive_text()context.append(f"用户输入:{data}")print(f"用户输入:{data}")# Truncate context if it exceeds the maximum lengthcontext_str = "\n".join(context)if len(context_str) > max_context_length:# 保留最后1024位context_str = context_str[-max_context_length:]# Format the prompt with contextformatted_prompt = prompt_template.format(context=context_str[-800:],question=data)search_results = (agent_chain.run(data))print(f'搜索结果: \n{search_results}')combined_prompt = (f"{formatted_prompt}\n网络搜索结果: {search_results[-1200:]}")print(f"combined prompt: \n{combined_prompt}")# Process the prompt using RetrievalQA chain (RAG)response = qa_chain2(combined_prompt)model_reply = response['result']# # Extract answer using regexmatch = re.search(r'Helpful Answer:\s*(.*)', model_reply, re.DOTALL)if match:helpful_answer = match.group(1).strip()context.append(f"ai回复:{helpful_answer}")await websocket.send_text(helpful_answer)else:context.append(f"ai回复:{model_reply}")await websocket.send_text(model_reply)except WebSocketDisconnect:print("Client disconnected")

over


http://www.hkcw.cn/article/nGoxSrweeC.shtml

相关文章

力扣HOT100之动态规划:139. 单词拆分

这道题之前刷代码随想录的时候已经做过了,但是现在再做一遍还是不会,直接去看视频了。感觉这道题的dp数组很难想到,感觉做不出来也是情有可原吧。这道题目也是一个完全背包问题,字典里的单词就相当于物品,而字符串相当…

趋势直线指标

趋势直线副图和主图指标,旨在通过技术分析工具帮助交易者识别市场趋势和潜在的买卖点。 副图指标:基于KDJ指标的交易策略 1. RSV值计算: - RSV(未成熟随机值)反映了当前收盘价在过去一段时间内的相对位置。通过计算当前…

应急响应靶机-web3-知攻善防实验室

题目: 1.攻击者的两个IP地址 2.攻击者隐藏用户名称 3.三个攻击者留下的flag 密码:yj123456 解题: 1.攻击者的两个IP地址 一个可能是远程,D盾,404.php,192.168.75.129 找到远程连接相关的英文,1149代表远程连接成功…

前端-不对用户显示

这是steam的商店偏好设置界面,在没有被锁在国区的steam账号会有5个选项,而被锁在国区的账号只有3个选项,这里使用的技术手段仅仅在前端隐藏了这个其他两个按钮。 单击F12打开开发者模式 单击1处,找到这一行代码,可以看…

C++单调栈(递增、递减)

定义 先说单调栈的定义 单调栈,是指栈内数据逐步上升(一个比一个大),或逐步下降(一个比一个小)的栈,其并没有独立的代码,而是在stack的基础上加以限制及条件形成的。 比如&#x…

WIN11+CUDA11.8+VS2019配置BundleFusion

参考: BundleFusion:VS2019 2017 ,CUDA11.5,win11,Realsense D435i离线数据包跑通,环境搭建 - 知乎 Win10VS2017CUDA10.1环境下配置BundleFusion - 知乎 BundleFusionWIN11VS2019 CUDA11.7环境配置-CSDN博客 我的环境:Win 11…

【基于SpringBoot的图书购买系统】Redis中的数据以分页的形式展示:从配置到前后端交互的完整实现

引言 在当今互联网应用开发中,高性能和高并发已经成为系统设计的核心考量因素。Redis作为一款高性能的内存数据库,以其快速的读写速度、丰富的数据结构和灵活的扩展性,成为解决系统缓存、高并发访问等场景的首选技术之一。在图书管理系统中&…

Leetcode LCR 187. 破冰游戏

1.题目基本信息 1.1.题目描述 社团共有 num 位成员参与破冰游戏,编号为 0 ~ num-1。成员们按照编号顺序围绕圆桌而坐。社长抽取一个数字 target,从 0 号成员起开始计数,排在第 target 位的成员离开圆桌,且成员离开后从下一个成员…

任务20:实现各省份平均气温预测

任务描述 知识点: 时间序列分析 重 点: 指数平滑法Python连接数据库,更新数据 内 容: 读取所有省份各月的平均气温数据预测各省份下一年1-12月的气温,并存储到MySQL数据库 任务指导 1. 读取所有省份各月的平…

【Unity】AudioSource超过MaxDistance还是能听见

unity版本:2022.3.51f1c1 将SpatialBlend拉到1即可 或者这里改到0 Hearing audio outside max distance - #11 by wderstine - Questions & Answers - Unity Discussions

VulnStack|红日靶场——红队评估四

信息收集及漏洞利用 扫描跟kali处在同一网段的设备,找出目标IP arp-scan -l 扫描目标端口 nmap -p- -n -O -A -Pn -v -sV 192.168.126.154 3个端口上有web服务,分别对应三个漏洞环境 :2001——Struts2、2002——Tomcat、2003——phpMyAd…

在 RK3588 上通过 VSCode 远程开发配置指南

在 RK3588 上通过 VSCode 远程开发配置指南 RK3588 设备本身不具备可视化编程环境,但可以通过 VSCode 的 Remote - SSH 插件 实现远程代码编写与调试。以下是完整的配置流程。 一、连接 RK3588 1. 安装 Debian 系统 先在 RK3588 上安装 Debian 操作系统。 2. 安…

Docker-搭建MySQL主从复制与双主双从

Docker -- 搭建MySQL主从复制与双主双从 一、MySQL主从复制1.1 准备工作从 Harbor 私有仓库拉取镜像直接拉取镜像运行容器 1.2 配置主、从服务器1.3 创建主、从服务器1.4 启动主库,创建同步用户1.5 配置启动从库1.6 主从复制测试 二、MySQL双主双从2.1 创建网络2.2 …

累加法求数列通项公式

文章目录 前言如何判断注意事项适用类型方法介绍典例剖析对应练习 前言 累加法,顾名思义,就是多次相加的意思。求通项公式题型中,如果给定条件最终可以转化为 a n 1 − a n f ( n ) a_{n1}-a_nf(n) an1​−an​f(n)的形式,或者…

vue3的watch用法

<template><div class"container mx-auto p-4"><h1 class"text-2xl font-bold mb-4">Vue 3 Watch 示例</h1><div class"grid grid-cols-1 md:grid-cols-2 gap-6"><!-- 基本数据监听 --><div class"…

day15 leetcode-hot100-28(链表7)

2. 两数相加 - 力扣&#xff08;LeetCode&#xff09; 1.模拟 思路 最核心的一点就是将两个链表模拟为等长&#xff0c;不足的假设为0&#xff1b; &#xff08;1&#xff09;设置一个新链表newl来代表相加结果。 &#xff08;2&#xff09;链表1与链表2相加&#xff0c;具…

边缘计算场景下的大模型落地:基于 Cherry Studio 的 DeepSeek-R1-0528 本地部署

前言 作为学生&#xff0c;我选择用 Cherry Studio 在本地调用 DeepSeek-R1-0528&#xff0c;完全是被它的实用性和 “性价比” 圈粉。最近在 GitHub 和 AI 社群里&#xff0c;大家都在热议 DeepSeek-R1-0528&#xff0c;尤其是它的数学解题和编程能力。像我在准备数学建模竞赛…

Tomcat的整体架构及其设计精髓

1.Tomcat介绍 官方文档&#xff1a;https://tomcat.apache.org/tomcat-9.0-doc/index.html 1.1 Tomcat概念 Tomcat是Apache Software Foundation&#xff08;Apache软件基金会&#xff09;开发的一款开源的Java Servlet 容器。它是一种Web服务器&#xff0c;用于在服务器端运行…

使用 Let‘s Encrypt 和 Certbot 为 Cloudflare 托管的域名申请 SSL 证书

一、准备工作 1. 确保域名解析在 Cloudflare 确保你的域名 jessi53.com 和 www.jessi53.com 的 DNS 记录已经正确配置在 Cloudflare 中&#xff0c;并且状态为 Active。 2. 安装 Certbot 在你的服务器上安装 Certbot 和 Cloudflare 插件。以下是基于 Debian/Ubuntu 和 Cent…

JAVA最新版本详细安装教程(附安装包)

目录 文章自述 一、JAVA下载 二、JAVA安装 1.首先在D盘创建【java/jdk-23】文件夹 2.把下载的压缩包移动到【jdk-23】文件夹内&#xff0c;右键点击【解压到当前文件夹】 3.如图解压会有【jdk-23.0.1】文件 4.右键桌面此电脑&#xff0c;点击【属性】 5.下滑滚动条&…