概述
本文将从零开始实现一个langchain应用程序, 该应用支持读取pdf文档并embedding编码到Chroma数据库, 当用户提问时,
可以从网络搜索结果和本地向量数据库中收集数据, 传递给第三方LLM大模型, 所有使用到的工具完全免费
将使用如下技术或工具:
- python3.9
- langchain
- Chroma DB
- Huggingface xlm-roberta-large 词嵌入模型
- TavilySearchAPI
- pytesseract OCR
- ...
依赖
本文使用python3.9
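pip install fastapi uvicorn langchain langchain_community python-dotenv chromadb sentence-transformers pymupdf pytesseract pillow tavily-python websocket-client
(注意: 提供 `dotenv` 模块的包名是 python-dotenv, 而不是 dotenv; fitz 对应的包名是 pymupdf)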
目录结构
- db (chroma_db存放位置)
- files (放所有要编码的pdf)
  - social (pdf分类文件夹)
- over (存放编码完成后的pdf)
- .env
- app.py (入口文件, 开放fastapi接口)
- models.py (定义使用的模型)
定义模型 models.py
首先获取讯飞开放API
如何使用api接入星火大模型(超详细,亲测有效!)_星火api-CSDN博客
将APIKEY写入.env文件
# 讯飞
IFLYTEK_SPARK_API_SECRET=
IFLYTEK_SPARK_API_KEY=
IFLYTEK_SPARK_APP_ID=
## spark lite (免费不限量)
IFLYTEK_SPARK_MODEL=lite
IFLYTEK_SPARK_API_URL=wss://spark-api.xf-yun.com/v1.1/chat
定义模型
import os

from langchain_community.chat_models import ChatSparkLLM

# Credentials and endpoint come from the environment. The names must match the
# keys in .env exactly: environment variables are case-sensitive on
# Linux/macOS, so lower-case lookups return None there.
# NOTE: the .env file has to be loaded (load_dotenv()) before this module is
# imported, because the values are read at import time.
appid = os.getenv('IFLYTEK_SPARK_APP_ID')
apikey = os.getenv('IFLYTEK_SPARK_API_KEY')
apisecret = os.getenv('IFLYTEK_SPARK_API_SECRET')
apiurl = os.getenv('IFLYTEK_SPARK_API_URL')
apimodel = os.getenv('IFLYTEK_SPARK_MODEL')

# iFlytek Spark chat model shared by the whole application.
llm = ChatSparkLLM(
    request_timeout=180,
    spark_api_url=apiurl,
    spark_llm_domain=apimodel,
    spark_app_id=appid,
    spark_api_key=apikey,
    spark_api_secret=apisecret,
)
# Second name for the same model object.
chat_model = llm
注意, 请在ide中Ctrl+左键点击进入ChatSparkLLM的源码, 然后添加一行break否则会遇到一直请求等待的bug
huggingface上的xlm-roberta-large模型
我们使用xlm-roberta-large作为本地embedding的模型, 下面这行代码在第一次运行会自动把模型下载到本地目录(
需要HUGGINGFACEHUB_API_TOKEN, 以及科学上网)
{% link 如何获取HuggingFace的Access Token;如何获取HuggingFace的API Key_huggingface access token-CSDN博客 %}
env文件
HUGGINGFACEHUB_API_TOKEN=
模型
from langchain.embeddings import HuggingFaceEmbeddings

# Local embedding model; the model files are cached under models/huggingface/.
embeddings = HuggingFaceEmbeddings(
    model_name="xlm-roberta-large",
    cache_folder='models/huggingface/',
)
生成答案时结合网络搜索
duckduckgo搜索因为不可抗力无法访问, 所以我现在使用有一定免费额度(一个月1000次request)的tavily搜索api (也可以使用wiki搜索库)
from langchain.utilities.tavily_search import TavilySearchAPIWrapper
from langchain.tools.tavily_search import TavilySearchResults

# Web-search tool backed by the Tavily API.
search = TavilySearchAPIWrapper()
tavily_tool = TavilySearchResults(api_wrapper=search)

# LLM that drives the search agent; it reuses the Spark model defined above.
# Alternative: search_agent = ChatOpenAI(model='gpt-4', temperature=0.7)
search_agent = llm
注意需要api key, 登录到控制台就可以获取
{% link Tavily %}
TAVILY_API_KEY=
Fastapi接口
import os
import re
import shutil

import fitz  # PyMuPDF, used to read the PDF pages
import uvicorn
from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Form
from fastapi.middleware.cors import CORSMiddleware
from langchain.chains import RetrievalQA
from langchain.embeddings import OpenAIEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma

# Load .env BEFORE importing models: models.py reads its API keys from the
# environment at import time.
load_dotenv()

app = FastAPI()

# Import the models
from models import embeddings, llm  # noqa: E402

# Cross-origin requests: add the CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "http://localhost:3000",
        "http://localhost:81",
    ],  # allowed origins
    allow_credentials=True,
    allow_methods=["*"],  # allow all methods
    allow_headers=["*"],  # allow all headers
)

# Module-level handles, initially empty.
vectorstore = None
qa_chain = None

# In the assembled app.py keep this guard at the very bottom of the file so
# that every route is defined before the server starts.
if __name__ == '__main__':
    uvicorn.run(app='app:app', host="127.0.0.1", port=8000, reload=True)
训练接口
处理pdf
# pytesseract is only a wrapper: the Tesseract OCR engine itself has to be
# downloaded from GitHub and installed before local OCR can be used.
import pytesseract
from langchain.schema import Document
from PIL import Image

pytesseract.pytesseract.tesseract_cmd = r'D:\applications\TesseractOCR5\tesseract.exe'

# File types that are collected for embedding and accepted by process_file().
SUPPORTED_SUFFIXES = ('.pdf', '.epub')


def find_pdf_files(directory):
    """Recursively collect the paths of all PDF/EPUB files below *directory*.

    The suffix check is case-insensitive, so ``REPORT.PDF`` is found as well.

    Args:
        directory: Root folder to walk.

    Returns:
        A list of file paths (joined with the folder they were found in).
    """
    return [
        os.path.join(root, name)
        for root, _, names in os.walk(directory)
        for name in names
        if name.lower().endswith(SUPPORTED_SUFFIXES)
    ]


def process_pdf_func(pdf_path):
    """Extract the text of every page, falling back to OCR for empty pages.

    Args:
        pdf_path: Path of the document to open with PyMuPDF.

    Returns:
        A list with one ``Document`` per page that yielded non-blank text.
    """
    documents = []
    # The context manager closes the document even if a page fails.
    with fitz.open(pdf_path) as doc:
        for page in doc:
            # Try the embedded text layer first.
            text = page.get_text()
            if not text.strip():
                # No text layer (e.g. a scanned page): render it and run OCR.
                print('start OCR')
                pix = page.get_pixmap()
                img = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
                text = pytesseract.image_to_string(img, lang='chi_sim')
            if text.strip():
                documents.append(Document(page_content=text))
                print(f'new text: {text}')
    return documents


def process_file(file_path):
    """Turn a supported file into a list of per-page ``Document`` objects.

    ``.epub`` files are accepted in addition to ``.pdf`` so that everything
    returned by ``find_pdf_files`` can be processed; both are opened through
    PyMuPDF.

    Raises:
        ValueError: If the file suffix is not supported.
    """
    if file_path.lower().endswith(SUPPORTED_SUFFIXES):
        return process_pdf_func(file_path)
    raise ValueError("Unsupported file format")
加载pdf进行训练
请求时的url: POST http://localhost:8000/process_pdf/{dirname}
dirname就是files文件夹下某个子文件夹的名称, 表示某个领域的pdf文档集合, 编码后会存入db/dirname下的xxx.db, 在问答时,
根据问答接口的不同, 选择不同的db进行RAG检索生成, 达到分不同领域进行训练和使用的效果
# Load a PDF, embed it into Chroma and create the LangChain QA chain.
def adjust_batch_size(embedding_size, available_memory_gb, max_batch_size=50):
    """Return how many embeddings to process per batch.

    Args:
        embedding_size: Number of values in one embedding vector.
        available_memory_gb: Memory budget in GiB (int or float).
        max_batch_size: Upper limit for the result (defaults to 50).

    Returns:
        An ``int`` between 1 and ``max_batch_size``. It is never 0 and never a
        float, so it is always a valid ``range()`` step.
    """
    # Bytes per embedding, counting 4 bytes per value.
    bytes_per_embedding = embedding_size * 4
    # Convert the memory budget to bytes.
    available_bytes = available_memory_gb * 1024 ** 3
    fitting = int(available_bytes // bytes_per_embedding)
    return max(1, min(fitting, max_batch_size))


def load_pdf_and_create_qa_chain(file_path: str, sub_path: str = ''):
    """Embed one file into the Chroma DB and rebuild the global QA chain.

    The file is split into chunks, the chunks are added in batches to the
    Chroma store under ``db/<sub_path>``, the module-level ``vectorstore`` and
    ``qa_chain`` are replaced, and the file is copied to ``over/<sub_path>``.

    Args:
        file_path: Path of the file to embed.
        sub_path: Optional sub-folder (topic) below ``db`` and ``over``.

    Returns:
        None. If storing the vectors fails, the error is printed and the
        function returns early without touching the globals.
    """
    global vectorstore, qa_chain
    print('enter load and process')
    documents = process_file(file_path)
    print('load file documents over!')

    # Split the pages into overlapping chunks.
    text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
    split_docs = text_splitter.split_documents(documents)
    print('split docs over!')

    # Create the target directory if it does not exist yet.
    db_path = os.path.join("db", sub_path) if sub_path else "db"
    os.makedirs(db_path, exist_ok=True)

    # Save the document vectors to the Chroma database.
    try:
        batch_size = adjust_batch_size(1024, 8)  # 1024 values per vector, 8 GB budget
        print(f"Recommended batch size: {batch_size}")
        # Open the store once and append to it, instead of re-creating it for
        # every batch. This also yields a usable store when there are no chunks.
        store = Chroma(persist_directory=db_path, embedding_function=embeddings)
        total = len(split_docs)
        for i in range(0, total, batch_size):
            print(f'start batch-{i} persist! total: {total}')
            store.add_documents(split_docs[i:i + batch_size])
            store.persist()  # persist to local disk
            print(f'batch-{i} persisted! total: {total}')
        print('Vectorstore persist over!')
    except Exception as e:
        print(f'Error saving vectors: {e}')
        return

    vectorstore = store
    qa_chain = RetrievalQA.from_chain_type(llm=llm, chain_type="stuff", retriever=vectorstore.as_retriever())
    print(f'embedding {file_path} success')

    # Copy the processed file to the 'over/sub_path' directory.
    # NOTE(review): the original stays in place, so embedding the same folder
    # again would add its chunks a second time - confirm whether a move is wanted.
    over_path = os.path.join("over", sub_path)
    os.makedirs(over_path, exist_ok=True)
    shutil.copy(file_path, over_path)
    print(f'Copied {file_path} to {over_path}')
ask问答接口
from langchain.prompts import PromptTemplate
from pydantic import BaseModel

from models import chat_model


class PromptRequest(BaseModel):
    """Request body of the /ask endpoint."""

    question: str


@app.post("/ask/{area}")
async def ask_ai(request: PromptRequest, area: str):
    """Answer a question with RAG over the Chroma DB of the given area.

    Args:
        request: Body carrying the user's question.
        area: Topic name; selects the ``db/<area>`` store and the prompt.

    Returns:
        ``{"response": <answer text>}``.

    Raises:
        HTTPException: 400 for an invalid area name, 500 if the chain fails.
    """
    if not area:
        db_dir = "db"
    else:
        # area comes from the URL and is used in a filesystem path: allow only
        # plain names so that values such as ".." cannot leave the db folder.
        if not re.fullmatch(r'[\w-]+', area):
            raise HTTPException(status_code=400, detail="Invalid area name")
        db_dir = "db/" + area
    # Open the Chroma vector database.
    chroma_db = Chroma(persist_directory=db_dir, embedding_function=embeddings)

    # Prompt templates per area.
    templates = {
        'language': (
            "You are a language teacher specializing in teaching individuals who wish to settle abroad. "
            "Based on the following context and user's question, provide a detailed language knowledge answer. "
            "If unable to answer the user's question based on background knowledge, ask follow-up questions related "
            "to the background knowledge, limited to three questions."
            "Context: {context}\nQuestion: {question}\nAnswer(in chinese):"
        ),
        'knowledge': (
            "You are a professor with expertise in multiple academic disciplines. "
            "Based on the following context and the user's academic question, provide an authoritative and professional answer. "
            "If unable to answer the user's question based on background knowledge, ask follow-up questions related "
            "to the background knowledge, limited to three questions."
            "Context: {context}\nQuestion: {question}\nAnswer(in chinese):"
        ),
        'psychology': (
            "You are an expert professor in the field of psychology. "
            "Based on the following context and the user's academic question, provide an authoritative and professional answer. "
            "If unable to answer the user's question based on background knowledge, ask follow-up questions related "
            "to the background knowledge, limited to three questions."
            "Context: {context}\nQuestion: {question}\nAnswer(in chinese):"
        ),
    }
    template = templates.get(area, templates['language'])

    # The template has two placeholders. It is handed to the chain, which
    # fills {context} with the retrieved documents and {question} with the
    # query. Formatting it by hand with only `question` raises KeyError.
    prompt_template = PromptTemplate(template=template, input_variables=["context", "question"])

    # Create the RetrievalQA chain with the area-specific prompt.
    qa_chain2 = RetrievalQA.from_chain_type(
        llm=chat_model,
        chain_type="stuff",
        retriever=chroma_db.as_retriever(),
        chain_type_kwargs={"prompt": prompt_template},
    )

    try:
        # NOTE: this call is synchronous, so it blocks until the model answers.
        response = qa_chain2({"query": request.question})
        print(f'qa chain response {response}')
        # Some models echo the prompt; keep only the text after "Helpful Answer:".
        match = re.search(r'Helpful Answer:\s*(.*)', response['result'], re.DOTALL)
        if match:
            helpful_answer = match.group(1).strip()
            print(helpful_answer)
            return {"response": helpful_answer}
        print("No 'Helpful Answer:' found.")
        return {"response": response['result']}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
websocket聊天接口
from fastapi import WebSocket, WebSocketDisconnect
from langchain.agents import AgentType, initialize_agent
from langchain.prompts import PromptTemplate

from models import chat_model, search_agent, tavily_tool


# WebSocket connection handling
@app.websocket("/ws/ask/{area}")
async def websocket_endpoint(websocket: WebSocket, area: str):
    """Chat over a WebSocket, combining chat history, web search and RAG.

    For every received message the endpoint runs the web-search agent, builds
    a prompt from the recent conversation and the search result, sends it
    through the RetrievalQA chain of ``db/<area>`` and returns the answer.

    Args:
        websocket: The client connection.
        area: Topic name; selects the Chroma store and the prompt template.
    """
    # area comes from the URL and is used in a filesystem path: reject
    # anything that is not a plain name (e.g. "..") before accepting.
    if area and not re.fullmatch(r'[\w-]+', area):
        await websocket.close(code=1008)
        return

    await websocket.accept()
    context = []  # conversation history, one entry per message
    max_context_length = 1024  # maximum number of history characters kept

    # Open the Chroma vector database.
    chroma_db = Chroma(persist_directory=f"db/{area}" if area else "db", embedding_function=embeddings)

    # Create the RetrievalQA chain.
    qa_chain2 = RetrievalQA.from_chain_type(
        llm=chat_model,
        chain_type="stuff",
        retriever=chroma_db.as_retriever(
            search_type="mmr",
            search_kwargs={'k': 2, 'lambda_mult': 0.75},
        ),
    )

    # Web-search agent.
    agent_chain = initialize_agent(
        [tavily_tool],
        search_agent,
        agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
    )

    # Prompt templates per area. 'language' is the fallback and therefore has
    # to exist: the default passed to dict.get() is evaluated on every call,
    # so a missing key raises KeyError for every connection.
    templates = {
        'language': (
            "You are a language teacher specializing in teaching individuals who wish to settle abroad. "
            "Based on the following context and user's question, provide a detailed language knowledge answer. "
            "If unable to answer the user's question based on background knowledge, ask follow-up questions related "
            "to the background knowledge, limited to three questions."
            "Context: {context}\nQuestion: {question}\nAnswer(in chinese):"
        ),
        'knowledge': (
            "You are a professor with expertise in multiple academic disciplines. "
            "Based on the following context and the user's academic question, "
            "provide an authoritative and professional answer. "
            "If unable to answer the user's question based on background knowledge, ask follow-up questions related "
            "to the background knowledge, limited to three questions."
            "Context: {context}\nQuestion: {question}\nAnswer(in chinese):"
        ),
        'psychology': (
            "You are an expert professor in the field of psychology. "
            "Based on the following context and the user's academic question, "
            "provide an authoritative and professional answer. "
            "If unable to answer the user's question based on background knowledge, ask follow-up questions related "
            "to the background knowledge, limited to three questions."
            "Context: {context}\nQuestion: {question}\nAnswer(in chinese):"
        ),
    }
    template = templates.get(area, templates['language'])
    prompt_template = PromptTemplate(template=template, input_variables=["context", "question"])

    try:
        while True:
            data = await websocket.receive_text()
            context.append(f"用户输入:{data}")
            print(f"用户输入:{data}")

            # Keep only the tail of the history if it is too long.
            context_str = "\n".join(context)
            if len(context_str) > max_context_length:
                context_str = context_str[-max_context_length:]

            # Fill the prompt with the recent history and the new question.
            formatted_prompt = prompt_template.format(context=context_str[-800:], question=data)

            search_results = agent_chain.run(data)
            print(f'搜索结果: \n{search_results}')
            combined_prompt = f"{formatted_prompt}\n网络搜索结果: {search_results[-1200:]}"
            print(f"combined prompt: \n{combined_prompt}")

            # Run the combined prompt through the RetrievalQA chain (RAG).
            response = qa_chain2(combined_prompt)
            model_reply = response['result']

            # Some models echo the prompt; keep only the text after "Helpful Answer:".
            match = re.search(r'Helpful Answer:\s*(.*)', model_reply, re.DOTALL)
            answer = match.group(1).strip() if match else model_reply
            context.append(f"ai回复:{answer}")
            await websocket.send_text(answer)
    except WebSocketDisconnect:
        print("Client disconnected")