6.3 完成 RAG

背景

完成了embedding模型和向量数据库,之后就只剩下导入大语言模型和组装起来

实现配置类

计划使用yml文件作为配置类,方便后续的修改

YamlUtil

通过自定义环境变量构造器来实现yml读取环境变量

"""
-*- coding: UTF-8 -*-
@Author  :Leezed
@Date    :2025/8/10 11:43
"""

import os
import yaml


def env_var_constructor(loader, node):
    value = loader.construct_scalar(node)  # PyYAML loader的固定方法,用于根据当前节点构造一个变量值
    var_name = value.strip('${} ')  # 去除变量值(例如${USER})前后的特殊字符及空格
    return os.getenv(var_name, value)  # 尝试在环境变量中获取变量名(如USER)对应的值,获取不到使用默认值value(即原来的${USER})


yaml.SafeLoader.add_constructor('!env', env_var_constructor)  # 为SafeLoader添加新的tag和构造器


class YamlUtil():
    def __init__(self, yaml_path, yaml_name):
        """
        初始化,当前这个文件操作效率不高,懒得优化了,随便写的,哪天心情好了在优化
        :param yaml_path:
        :param yaml_name:
        """
        self.yaml_path = yaml_path
        self.yaml_name = yaml_name
        self.yaml_file = os.path.join(self.yaml_path, self.yaml_name)

        # 判断当前路径下文件是否存在,如果存在则读取,如果不存在则创建
        if not os.path.exists(self.yaml_file):
            with open(self.yaml_file, 'w') as f:
                f.write('')
            self.content = {}
            self.write_yaml()

        self.content = self.read_yaml()  # 读取到的yaml文件内容

    def __getattr__(self, key):
        return self.get_value(key)

    def get_value(self, key):
        """
        获取yaml中的值
        :param key: key路径,数组形式 ['a','b','c'] 或者字符串形式 'a.b.c'
        :return: key对应的值
        """
        really_key_path = self.is_exist_key(key)

        content = self.content
        for key in really_key_path:
            content = content[key]
        return content

    def add_root_key(self, key, value=None):
        """
        添加根key
        :param key: 添加的key
        :param value: 添加的值
        """
        if key not in self.content:
            self.content[key] = {} if value is None else value
        else:
            raise Exception('key已存在')

    def add_key(self, key_path, value=None):
        """
        添加key 当前只能添加最后一个key,比如a.b.c,只能添加c,如果b不存在就会报错
        :param key_path: key路径,数组形式 ['a','b','c'] 或者字符串形式 'a.b.c'
        :param value: 添加的值
        """
        # 判断key_path的类型,如果是字符串则转换为数组
        really_key_path = self.is_exist_key(key_path)

        content = self.content
        for key in really_key_path[:-1]:
            content = content[key]

        if really_key_path[-1] in content:
            raise Exception("当前key已存在")

        content[really_key_path[-1]] = {} if value is None else value

    def debug(self):
        print(self.content)

    # 修改某个key中的内容
    def modify_key(self, key_path, value=None):
        """
        修改内存中yaml中某个key的内容
        :param key_path: key路径,数组形式 ['a','b','c'] 或者字符串形式 'a.b.c'
        :param value: 修改后的值
        """
        # 判断key_path的类型,如果是字符串则转换为数组
        if value is None:
            value = {}
        really_key_path = self.is_exist_key(key_path, check_last=True)
        content = self.content
        for key in really_key_path[:-1]:
            content = content[key]
        content[really_key_path[-1]] = {} if value is None else value

    def is_exist_key(self, key_path, check_last=False):
        """
        判断yaml中是否存在某个key路径
        :param key_path: key_path,数组形式 ['a','b','c'] 或者字符串形式 'a.b.c'
        :param check_last: 是否检查最后一个key,因为对于最后一个key来说,可能是空的,所以默认不检查
        :return: 返回key_path,如果不存在key_path为[]
        """
        # 判断key_path的类型,如果是字符串则转换为数组
        really_key_path = key_path.split('.') if isinstance(key_path, str) else key_path

        assert isinstance(really_key_path, list), 'key_path必须是list或者str类型'

        content = self.content
        for index, key in enumerate(really_key_path):

            if not check_last and index == len(really_key_path) - 1:
                return really_key_path

            # if index == len(really_key_path) - 1 and content == {}:
            #     return True, really_key_path

            if key in content:
                content = content[key]
            else:
                raise Exception('key错误')
        return really_key_path

    def read_yaml(self):
        """
        读取整个yaml文件
        :return: 读取到的文件内容
        """
        with open(self.yaml_file, 'r') as f:
            content = yaml.safe_load(f)
        return content

    def write_yaml(self):
        """
        将content写入yaml文件
        """
        with open(self.yaml_file, 'w') as f:
            yaml.safe_dump(self.content, f)

CFGLoader

在对读取出来的yml文件封装一层,方便使用

"""
-*- coding: utf-8 -*-
@Author : Leezed
@Time : 2024/3/20 20:58
"""
from utils.YamlUtil import YamlUtil
import os

current_directory_path = os.path.dirname(os.path.abspath(__file__))
configs_path = os.path.join(current_directory_path, '../../configs')


class CFGLoader(dict):
    def __getattr__(self, item):
        if item in self:
            value = self[item]
            if isinstance(value, dict):
                return CFGLoader(value)
            else:
                return value
        else:
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")

    def __setattr__(self, key, value):
        print("in set")
        print(key)
        print(value)
        print("before change :", self[key])

        self[key] = value
        print("after change :", self[key])


def get_cfg(path, cfg_name):
    return YamlUtil(path, cfg_name if cfg_name.endswith(".yml") else cfg_name + ".yml").content


def env_setting(cfg_name):
    """
    获取配置文件 ,默认是local.yaml
    :param cfg_name:  如果不想获取local.yaml,可以传入其他的配置文件名
    :return: 返回配置文件
    """
    cfg = CFGLoader(get_cfg(configs_path, cfg_name if cfg_name is not None else 'local.yml'))
    return cfg

配置文件编写

model:
  name: DeepSeek
  embedding:
    name: BGE_base_zh

embeddings:
  BGE_base_zh:
    ClassName: BGEBaseZH
  GTE_base_zh:
    ClassName: GTEBaseZH
api:
  key:
    DeepSeek: !env ${DEEPSEEK_API_KEY}

导入大语言模型

定义基类

主要目的是统一定义操作方法,就一个chat和一个load_model

"""
-*- coding: UTF-8 -*-
@Author  :Leezed
@Date    :2025/8/10 11:31 
"""


class BaseModel:
    def __init__(self, cfg):
        self.cfg = cfg
        self.model = None
        pass

    def chat(self, query, history=None, **kwargs):
        raise NotImplementedError("Subclasses must implement this method")

    def load_model(self):
        raise NotImplementedError("Subclasses must implement this method")

实现DeepSeek

选择使用DeepSeek的原因就是DeepSeek相对而言比较便宜,方便调试

"""
-*- coding: UTF-8 -*-
@Author  :Leezed
@Date    :2025/8/10 11:59 
"""
import os.path

from TinyRAG.models.BaseModel import BaseModel
from openai import OpenAI
from utils.CFGLoader import CFGLoader, get_cfg


class DeepSeek(BaseModel):
    def __init__(self, cfg):
        super().__init__(cfg)
        self.model_name = "deepseek-chat"
        self.path = "https://api.deepseek.com"
        self.history = []
        self.client = None
        self.load_model()

    def load_model(self):
        self.client = OpenAI(api_key=self.cfg.api.key.DeepSeek, base_url=self.path)

    def new_session(self):
        self.history = []

    def chat(self, query, **kwargs):
        messages = self.history + [{"role": "user", "content": query}]
        response = self.client.chat.completions.create(
            model=self.model_name,
            messages=messages,
            **kwargs
        )
        result = response.choices[0].message.content
        self.history = self.history + [{"role": "user", "content": query}, {"role": "system", "content": result}]
        return result, self.history


if __name__ == '__main__':
    config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), '../config')

    cfg = CFGLoader(get_cfg(config_path, "RAG"))
    model = DeepSeek(cfg)
    result, history = model.chat("你好,DeepSeek!")
    print(result)

组装起来

"""
-*- coding: UTF-8 -*-
@Author  :Leezed
@Date    :2025/8/10 13:42 
"""
import os
from utils.CFGLoader import CFGLoader, get_cfg
import importlib
from TinyRAG.VectorBase.LocalVectorBase import LocalVectorBase
from TinyRAG.models.DeepSeek import DeepSeek
from TinyRAG import PROMPT_TEMPLATE


class TinyRAG:
    def __init__(self, config_name='RAG'):
        config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'TinyRAG/config')
        self.cfg = CFGLoader(get_cfg(config_path, config_name))

        self.embedding_model = self.get_embedding_model()
        self.database = LocalVectorBase(self.embedding_model)

        self.debug = False

        self.llm = DeepSeek(self.cfg)

    def get_embedding_model(self):
        embedding_name = self.cfg.model.embedding.name

        embedding_module = importlib.import_module(f"TinyRAG.Embedding.{embedding_name}")
        embedding_class = getattr(embedding_module, self.cfg.embeddings[embedding_name]["ClassName"])
        embedding_instance = embedding_class()
        return embedding_instance

    def debug_on(self):
        self.debug = True

    def debug_off(self):
        self.debug = False

    def get_db_content(self, question):
        '''
        根据提问去数据库中查询相关内容
        :param question: string
        :return:
        '''
        db_res = self.database.query(question, k=2)
        # 将content 压缩成 一段文本
        content = ""
        for s in db_res:
            content += s

        if self.debug:
            print("数据库查询结果:", content)

        return content

    def query(self, question):
        self.llm.new_session()
        content = self.get_db_content(question)
        prompted_question = PROMPT_TEMPLATE['InternLM_PROMPT_TEMPALTE'].format(question=question, context=content)
        if self.debug:
            print("Prompted Question:", prompted_question)
        res,his = self.llm.chat(prompted_question)
        return res

其中我们的Prompt 定义为:

PROMPT_TEMPLATE = dict(
    InternLM_PROMPT_TEMPALTE="""先对上下文进行内容总结,再使用上下文来回答用户的问题。如果你不知道答案,就说你不知道。总是使用中文回答。
        问题: {question}
        可参考的上下文:
        ···
        {context}
        ···
        如果给定的上下文无法让你做出回答,请回答数据库中没有这个内容,你不知道。
        有用的回答:"""
)

最后就只要运行就可以了


if __name__ == '__main__':
    tiny_rag = TinyRAG()
    question = "请你讲讲git push的用法"
    tiny_rag.debug_on()
    res = tiny_rag.query(question)
    print(res)

结果

在这里插入图片描述
在这里插入图片描述

完整代码

可以再这个仓库里下载

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

理智点

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值