背景
完成了embedding模型和向量数据库之后,就只剩下导入大语言模型,并把各个部分组装起来了
实现配置类
计划使用yml文件作为配置文件,方便后续的修改
YamlUtil
通过为 `!env` 标签注册自定义构造器,让yml文件可以读取环境变量
"""
-*- coding: UTF-8 -*-
@Author :Leezed
@Date :2025/8/10 11:43
"""
import os
import yaml
def env_var_constructor(loader, node):
    """Resolve environment-variable references in a scalar tagged with ``!env``.

    A value that is a single reference (``${NAME}``, ``$NAME`` or a bare
    ``NAME``) is replaced by the value of that environment variable. A value
    that embeds one or more ``${NAME}`` placeholders inside other text
    (e.g. ``https://${HOST}/v1``) has each placeholder substituted in place.
    References to variables that are not set are left untouched, so an
    unresolved ``${NAME}`` stays visible in the loaded config.

    :param loader: the PyYAML loader currently constructing the document
    :param node: the scalar node carrying the ``!env`` tag
    :return: the scalar text with environment variables resolved
    """
    import re  # local import: the surrounding module only imports os and yaml

    raw_value = loader.construct_scalar(node)

    # Single-reference form: drop the "${", "}" and surrounding blanks.
    var_name = raw_value.strip('${} ')
    if not any(ch in var_name for ch in '${}'):
        return os.getenv(var_name, raw_value)

    # Embedded form: substitute every ${NAME}; keep unknown ones verbatim.
    placeholder = re.compile(r'\$\{\s*([^{}\s]+)\s*\}')
    return placeholder.sub(
        lambda match: os.getenv(match.group(1), match.group(0)),
        raw_value,
    )
yaml.SafeLoader.add_constructor('!env', env_var_constructor) # Register the '!env' tag and its constructor on SafeLoader
class YamlUtil():
    """In-memory view of a single YAML file, stored as a nested dict in ``content``.

    Keys are addressed by a "key path": either a list such as
    ``['a', 'b', 'c']`` or a dotted string such as ``'a.b.c'``. Top-level keys
    can also be read as attributes (``util.a``). Changes made through
    ``add_root_key`` / ``add_key`` / ``modify_key`` only touch ``content``;
    call ``write_yaml`` to persist them.
    """

    def __init__(self, yaml_path, yaml_name):
        """Load ``yaml_name`` from ``yaml_path``, creating an empty file if it is missing.

        :param yaml_path: directory that holds the YAML file
        :param yaml_name: file name of the YAML file
        """
        self.yaml_path = yaml_path
        self.yaml_name = yaml_name
        self.yaml_file = os.path.join(self.yaml_path, self.yaml_name)
        # Start from an empty mapping so `content` always exists on the instance.
        self.content = {}
        if not os.path.exists(self.yaml_file):
            # Missing file: create it with an empty mapping.
            self.write_yaml()
        self.content = self.read_yaml()  # parsed content of the YAML file

    def __getattr__(self, key):
        """Expose keys of ``content`` as attributes.

        Only called when normal attribute lookup fails. ``content`` itself and
        dunder names are rejected with ``AttributeError``: resolving them
        through ``get_value`` would read ``self.content`` again and recurse
        without end on an instance whose ``content`` is not set yet.
        """
        if key == 'content' or (key.startswith('__') and key.endswith('__')):
            raise AttributeError(key)
        return self.get_value(key)

    def get_value(self, key):
        """Return the value stored under a key path.

        :param key: key path, list form ['a','b','c'] or string form 'a.b.c'
        :return: the value stored at that path
        :raises Exception: if a parent key of the path does not exist
        :raises KeyError: if only the last key of the path does not exist
        """
        node = self.content
        for part in self.is_exist_key(key):
            node = node[part]
        return node

    def add_root_key(self, key, value=None):
        """Add a new top-level key.

        :param key: key to add
        :param value: value to store; an empty dict is stored when None
        :raises Exception: if the key already exists
        """
        if key in self.content:
            raise Exception('key已存在')
        self.content[key] = {} if value is None else value

    def add_key(self, key_path, value=None):
        """Add the last key of a key path.

        Only the final key is created: for 'a.b.c' the parents 'a' and 'a.b'
        must already exist.

        :param key_path: key path, list form ['a','b','c'] or string form 'a.b.c'
        :param value: value to store; an empty dict is stored when None
        :raises Exception: if a parent key is missing or the final key exists
        """
        really_key_path = self.is_exist_key(key_path)
        parent = self.content
        for key in really_key_path[:-1]:
            parent = parent[key]
        last_key = really_key_path[-1]
        if last_key in parent:
            raise Exception("当前key已存在")
        parent[last_key] = {} if value is None else value

    def debug(self):
        """Print the in-memory content."""
        print(self.content)

    def modify_key(self, key_path, value=None):
        """Replace the value of an existing key in the in-memory content.

        :param key_path: key path, list form ['a','b','c'] or string form 'a.b.c'
        :param value: new value; an empty dict is stored when None
        :raises Exception: if any key of the path, including the last, is missing
        """
        really_key_path = self.is_exist_key(key_path, check_last=True)
        parent = self.content
        for key in really_key_path[:-1]:
            parent = parent[key]
        parent[really_key_path[-1]] = {} if value is None else value

    def is_exist_key(self, key_path, check_last=False):
        """Validate a key path against the in-memory content and normalise it to a list.

        :param key_path: key path, list form ['a','b','c'] or string form 'a.b.c'
        :param check_last: also require the last key to exist; off by default
            because callers such as ``add_key`` validate a path whose last key
            does not exist yet
        :return: the key path as a list
        :raises Exception: if a checked key is missing, or a parent on the
            path is not a mapping
        """
        really_key_path = key_path.split('.') if isinstance(key_path, str) else key_path
        assert isinstance(really_key_path, list), 'key_path必须是list或者str类型'
        last_index = len(really_key_path) - 1
        node = self.content
        for index, key in enumerate(really_key_path):
            if not check_last and index == last_index:
                return really_key_path
            # A scalar parent cannot hold child keys; treat it like a missing
            # key instead of doing a substring test on a string value.
            if not isinstance(node, dict) or key not in node:
                raise Exception('key错误')
            node = node[key]
        return really_key_path

    def read_yaml(self):
        """Read and parse the whole YAML file.

        :return: the parsed content; an empty dict when the file is empty
        """
        with open(self.yaml_file, 'r', encoding='utf-8') as f:
            content = yaml.safe_load(f)
        # safe_load returns None for an empty document.
        return {} if content is None else content

    def write_yaml(self):
        """Dump the in-memory content to the YAML file, overwriting it."""
        with open(self.yaml_file, 'w', encoding='utf-8') as f:
            yaml.safe_dump(self.content, f)
CFGLoader
再对读取出来的yml文件封装一层,方便使用
"""
-*- coding: utf-8 -*-
@Author : Leezed
@Time : 2024/3/20 20:58
"""
from utils.YamlUtil import YamlUtil
import os
# Absolute path of the directory that contains this module.
current_directory_path = os.path.dirname(os.path.abspath(__file__))
# Config directory used by env_setting: `configs`, two levels above this module's directory.
configs_path = os.path.join(current_directory_path, '../../configs')
class CFGLoader(dict):
    """A dict whose keys can also be read and written as attributes.

    ``cfg.model.name`` is equivalent to ``cfg['model']['name']``. Names that
    exist as regular ``dict`` attributes (``keys``, ``items``, ``get`` ...)
    resolve to the dict method, so such keys must be read with ``cfg['keys']``.
    """

    def __getattr__(self, item):
        """Return ``self[item]``, wrapping nested dicts so they support attribute access.

        :raises AttributeError: if ``item`` is not a key of this mapping
        """
        try:
            value = self[item]
        except KeyError:
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '{item}'"
            ) from None
        if isinstance(value, dict) and not isinstance(value, CFGLoader):
            # Store the wrapper back: returning a throw-away copy would make
            # `cfg.a.b = x` modify the copy and silently lose the change.
            value = CFGLoader(value)
            dict.__setitem__(self, item, value)
        return value

    def __setattr__(self, key, value):
        """Store ``value`` under ``key``; works for new and existing keys alike."""
        self[key] = value
def get_cfg(path, cfg_name):
    """Load a YAML config file and return its parsed content.

    :param path: directory that holds the config file
    :param cfg_name: config file name; ``.yml`` is appended unless the name
        already ends with ``.yml`` or ``.yaml``
    :return: the ``content`` of the ``YamlUtil`` built for that file
    """
    has_yaml_suffix = cfg_name.endswith((".yml", ".yaml"))
    file_name = cfg_name if has_yaml_suffix else cfg_name + ".yml"
    return YamlUtil(path, file_name).content
def env_setting(cfg_name=None):
    """Load a config file from ``configs_path`` and wrap it in a ``CFGLoader``.

    :param cfg_name: config file name; ``local.yml`` is used when None
    :return: the loaded config as a ``CFGLoader``
    """
    selected_name = 'local.yml' if cfg_name is None else cfg_name
    return CFGLoader(get_cfg(configs_path, selected_name))
配置文件编写
# Which chat model and which embedding model to use.
model:
  name: DeepSeek
  embedding:
    name: BGE_base_zh
# Available embedding models, keyed by the name used in model.embedding.name;
# ClassName is the class looked up in the module of the same name.
embeddings:
  BGE_base_zh:
    ClassName: BGEBaseZH
  GTE_base_zh:
    ClassName: GTEBaseZH
# API keys; the !env tag reads the value from the environment.
api:
  key:
    DeepSeek: !env ${DEEPSEEK_API_KEY}
导入大语言模型
定义基类
主要目的是统一定义操作方法,就一个chat和一个load_model
"""
-*- coding: UTF-8 -*-
@Author :Leezed
@Date :2025/8/10 11:31
"""
class BaseModel:
    """Common interface for chat models: subclasses provide ``load_model`` and ``chat``."""

    def __init__(self, cfg):
        """Keep the config and initialise the model slot.

        :param cfg: configuration object, stored as ``self.cfg``
        """
        self.cfg = cfg
        self.model = None  # initialised empty; not assigned anywhere else in this class

    def chat(self, query, history=None, **kwargs):
        """Answer ``query``; must be overridden.

        :param query: the user input
        :param history: optional earlier conversation
        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError("Subclasses must implement this method")

    def load_model(self):
        """Prepare the underlying model or client; must be overridden.

        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError("Subclasses must implement this method")
实现DeepSeek
选择使用DeepSeek的原因就是DeepSeek相对而言比较便宜,方便调试
"""
-*- coding: UTF-8 -*-
@Author :Leezed
@Date :2025/8/10 11:59
"""
import os.path
from TinyRAG.models.BaseModel import BaseModel
from openai import OpenAI
from utils.CFGLoader import CFGLoader, get_cfg
class DeepSeek(BaseModel):
    """Chat model backed by the DeepSeek API, called through the OpenAI client.

    The conversation is kept in ``self.history`` as a list of
    ``{"role": ..., "content": ...}`` messages and is sent along with every
    new query, so consecutive ``chat`` calls form one multi-turn session.
    """

    def __init__(self, cfg):
        """Store the config and create the API client.

        :param cfg: config object; ``cfg.api.key.DeepSeek`` is read as the API key
        """
        super().__init__(cfg)
        self.model_name = "deepseek-chat"
        self.path = "https://api.deepseek.com"
        self.history = []
        self.client = None
        self.load_model()

    def load_model(self):
        """Create the OpenAI client pointed at the DeepSeek endpoint."""
        self.client = OpenAI(api_key=self.cfg.api.key.DeepSeek, base_url=self.path)

    def new_session(self):
        """Forget the conversation so far."""
        self.history = []

    def chat(self, query, history=None, **kwargs):
        """Send ``query`` together with the conversation so far and record the reply.

        :param query: the user message
        :param history: optional list of earlier messages; when given it
            replaces the stored conversation before the query is sent
        :param kwargs: forwarded unchanged to ``chat.completions.create``
        :return: tuple ``(reply_text, history)`` where ``history`` includes
            this exchange
        """
        if history is not None:
            self.history = list(history)
        user_message = {"role": "user", "content": query}
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=self.history + [user_message],
            **kwargs
        )
        result = response.choices[0].message.content
        # The model's reply is recorded as "assistant"; "system" is the role
        # for instructions, not for earlier model output.
        assistant_message = {"role": "assistant", "content": result}
        self.history = self.history + [user_message, assistant_message]
        return result, self.history
if __name__ == '__main__':
    # Manual smoke test: load the RAG config next to this package and send one greeting.
    module_dir = os.path.dirname(os.path.abspath(__file__))
    config_path = os.path.join(module_dir, '../config')
    model = DeepSeek(CFGLoader(get_cfg(config_path, "RAG")))
    reply, _ = model.chat("你好,DeepSeek!")
    print(reply)
组装起来
"""
-*- coding: UTF-8 -*-
@Author :Leezed
@Date :2025/8/10 13:42
"""
import os
from utils.CFGLoader import CFGLoader, get_cfg
import importlib
from TinyRAG.VectorBase.LocalVectorBase import LocalVectorBase
from TinyRAG.models.DeepSeek import DeepSeek
from TinyRAG import PROMPT_TEMPLATE
class TinyRAG:
    """Retrieval-augmented question answering: vector store lookup plus an LLM call."""

    def __init__(self, config_name='RAG'):
        """Load the config and build the embedding model, vector store and LLM.

        :param config_name: name of the config file under ``TinyRAG/config``
            next to this module
        """
        config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'TinyRAG/config')
        self.cfg = CFGLoader(get_cfg(config_path, config_name))
        self.embedding_model = self.get_embedding_model()
        self.database = LocalVectorBase(self.embedding_model)
        self.debug = False
        self.llm = DeepSeek(self.cfg)

    def get_embedding_model(self):
        """Instantiate the embedding model selected in the config.

        The module ``TinyRAG.Embedding.<name>`` is imported, where ``<name>``
        is ``cfg.model.embedding.name``, and the class named by
        ``cfg.embeddings[<name>]["ClassName"]`` is created without arguments.

        :return: the embedding model instance
        """
        embedding_name = self.cfg.model.embedding.name
        embedding_module = importlib.import_module(f"TinyRAG.Embedding.{embedding_name}")
        embedding_class = getattr(embedding_module, self.cfg.embeddings[embedding_name]["ClassName"])
        return embedding_class()

    def debug_on(self):
        """Enable printing of the retrieved context and the final prompt."""
        self.debug = True

    def debug_off(self):
        """Disable debug printing."""
        self.debug = False

    def get_db_content(self, question, k=2):
        """Query the vector store and merge the hits into one text.

        :param question: the question to search for
        :param k: number of results requested from the store
        :return: the returned passages concatenated without a separator
        """
        db_res = self.database.query(question, k=k)
        content = "".join(db_res)
        if self.debug:
            print("数据库查询结果:", content)
        return content

    def query(self, question, k=2):
        """Answer ``question`` using retrieved context, in a fresh LLM session.

        :param question: the user's question
        :param k: number of passages to retrieve as context
        :return: the LLM's answer text
        """
        self.llm.new_session()
        content = self.get_db_content(question, k=k)
        prompted_question = PROMPT_TEMPLATE['InternLM_PROMPT_TEMPALTE'].format(question=question, context=content)
        if self.debug:
            print("Prompted Question:", prompted_question)
        res, _ = self.llm.chat(prompted_question)
        return res
其中我们的Prompt 定义为:
# Prompt templates keyed by name. Each template is a str.format string with
# two placeholders: {question} (the user's question) and {context} (the
# retrieved text). The key spelling 'InternLM_PROMPT_TEMPALTE' is the one
# TinyRAG.query looks up, so it must not be "corrected" here alone.
PROMPT_TEMPLATE = dict(
InternLM_PROMPT_TEMPALTE="""先对上下文进行内容总结,再使用上下文来回答用户的问题。如果你不知道答案,就说你不知道。总是使用中文回答。
问题: {question}
可参考的上下文:
···
{context}
···
如果给定的上下文无法让你做出回答,请回答数据库中没有这个内容,你不知道。
有用的回答:"""
)
最后就只要运行就可以了
if __name__ == '__main__':
    # Demo run: ask one question with debug output switched on.
    tiny_rag = TinyRAG()
    tiny_rag.debug_on()
    answer = tiny_rag.query("请你讲讲git push的用法")
    print(answer)
结果




1323

被折叠的 条评论
为什么被折叠?



