Archery 平台 PostgreSQL 审核（带备份）代码

# -*- coding: UTF-8 -*-
""" 
@author: hhyo、yyukai
@license: Apache Licence
@file: pgsql.py
@time: 2019/03/29
"""
import re
import psycopg2
import logging
import traceback
import sqlparse

from common.config import SysConfig
from common.utils.timer import FuncTimer
from sql.utils.sql_utils import get_syntax_type
from . import EngineBase
from .models import ResultSet, ReviewSet, ReviewResult
from sql.utils.data_masking import simple_column_mask
import subprocess
import sqlglot
from sqlglot import expressions as exp
from datetime import datetime
import MySQLdb
import uuid
import os
import sqlparse
import time
import re
import json


__author__ = "hhyo、yyukai"

# Module-wide logger. It uses the shared "default" logger name (not __name__);
# presumably that logger is configured in the project's logging settings.
logger = logging.getLogger("default")


class PgSQLEngine(EngineBase):
    """Archery engine for PostgreSQL (PolarDB-PG tools are used for backups).

    Besides the usual query / audit / execute workflow, UPDATE and DELETE
    statements run through :meth:`execute_workflow` are backed up first. The
    rows they will touch are copied into a scratch table and dumped with
    ``pg_dump --column-inserts``. The resulting INSERT statements are then
    stored as undo SQL in the MySQL backup database
    (``polar_archive.sql_rollback``).
    """

    test_query = "SELECT 1"
    # pg_dump binary used for DML backups. It is a class attribute so that a
    # deployment with a different tool location can override it.
    pg_dump_path = "/u01/polardb_pg_tools/bin/pg_dump"

    def get_connection(self, db_name=None):
        """Return the cached psycopg2 connection, opening one if needed.

        :param db_name: database to connect to. Falls back to ``self.db_name``
            and then ``"postgres"``. Ignored when ``self.conn`` is already set.
        :return: the psycopg2 connection stored on ``self.conn``.
        """
        db_name = db_name or self.db_name or "postgres"
        if self.conn:
            return self.conn
        self.conn = psycopg2.connect(
            host=self.host,
            port=self.port,
            user=self.user,
            password=self.password,
            dbname=db_name,
            connect_timeout=10,
        )
        return self.conn

    @staticmethod
    def get_backup_connection():
        """Open an autocommit MySQL connection to the backup database.

        Host, port, user and password come from the ``inception_remote_backup_*``
        system settings. The port defaults to 3306 when unset or empty.
        """
        archer_config = SysConfig()
        backup_host = archer_config.get("inception_remote_backup_host")
        # `or 3306` also covers an empty-string setting, which int() rejects.
        backup_port = int(archer_config.get("inception_remote_backup_port") or 3306)
        backup_user = archer_config.get("inception_remote_backup_user")
        backup_password = archer_config.get("inception_remote_backup_password")
        return MySQLdb.connect(
            host=backup_host,
            port=backup_port,
            user=backup_user,
            passwd=backup_password,
            charset="utf8mb4",
            autocommit=True,
        )

    @staticmethod
    def _quote_ident(name):
        """Quote ``name`` as a PostgreSQL identifier (embedded ``"`` doubled)."""
        return '"' + str(name).replace('"', '""') + '"'

    def _extract_table_and_where(self, sql: str):
        """Parse an UPDATE/DELETE with sqlglot and return ``(table_name, where_clause)``.

        ``table_name`` is ``schema.table`` when the statement qualifies the
        table, otherwise just ``table``. Identifiers are rendered in
        PostgreSQL syntax. ``where_clause`` is the WHERE condition rendered as
        PostgreSQL SQL, or ``"TRUE"`` when the statement has no WHERE.

        :return: ``(None, None)`` for any other statement type, for a target
            that is not a plain table, or when parsing fails.
        """
        try:
            parsed = sqlglot.parse_one(sql.strip().rstrip(";"), dialect="postgres")
            if not isinstance(parsed, (exp.Update, exp.Delete)):
                return None, None

            table = parsed.args.get("this")
            if not isinstance(table, exp.Table) or table.this is None:
                return None, None

            schema_part = table.args.get("db")
            name_part = table.this.sql(dialect="postgres")
            if schema_part:
                table_name = f"{schema_part.sql(dialect='postgres')}.{name_part}"
            else:
                table_name = name_part

            where = parsed.args.get("where")
            if isinstance(where, exp.Where) and where.this is not None:
                where_clause = where.this.sql(dialect="postgres")
            else:
                where_clause = "TRUE"

            return table_name, where_clause

        except Exception as e:
            logger.warning(f"SQL parsing failed: {e}")
            return None, None

    def split_insert_statements_robust(self, sql_text):
        """Split pg_dump ``--column-inserts`` output into INSERT statements.

        A statement starts at ``INSERT INTO`` (case-insensitive). It ends at the
        first ``);`` that closes its outermost parenthesis while no quote is
        open. Quote handling assumes ``standard_conforming_strings=on``: inside
        '...' a doubled ``''`` is an escaped quote and backslashes are literal,
        and every ``"`` toggles the double-quote state. Text between
        statements is skipped. An unterminated trailing statement is returned
        as-is.

        :param sql_text: dump text to scan.
        :return: list of statement strings.
        """
        statements = []
        i = 0
        n = len(sql_text)

        while i < n:
            # Skip whitespace.
            while i < n and sql_text[i] in ' \t\r\n':
                i += 1
            if i >= n:
                break

            # Does a statement start here? ("INSERT INTO" is 11 characters.)
            if i + 11 <= n and sql_text[i:i + 11].upper() == 'INSERT INTO':
                stmt_start = i
                in_single_quote = False  # inside '...'
                in_double_quote = False  # inside "..."
                paren_depth = 0
                j = i

                while j < n:
                    c = sql_text[j]

                    # Update quote state.
                    if not in_single_quote and not in_double_quote:
                        if c == "'":
                            in_single_quote = True
                        elif c == '"':
                            in_double_quote = True
                    elif in_single_quote:
                        if c == "'" and (j + 1 >= n or sql_text[j + 1] != "'"):
                            in_single_quote = False
                        elif c == "'" and j + 1 < n and sql_text[j + 1] == "'":
                            j += 1  # skip the second quote of an escaped ''
                    elif in_double_quote:
                        if c == '"':
                            in_double_quote = False

                    # Structural characters only count outside quotes.
                    if not in_single_quote and not in_double_quote:
                        if c == '(':
                            paren_depth += 1
                        elif c == ')':
                            paren_depth -= 1
                            # Statement ends with ");" at depth 0.
                            if paren_depth == 0 and j + 1 < n and sql_text[j + 1] == ';':
                                end = j + 2
                                statements.append(sql_text[stmt_start:end])
                                j = end
                                break

                    j += 1
                else:
                    # No terminator found: keep the remainder as one statement.
                    statements.append(sql_text[stmt_start:])

                i = j
            else:
                i += 1

        return statements

    def _qualify_table_name(self, cursor, table_name):
        """Return ``table_name`` qualified with its schema.

        An already qualified name is returned unchanged. A bare name is
        resolved through the session's search path (``::regclass``), just as
        the DML statement itself would be. A missing table therefore raises
        instead of guessing a schema.
        """
        if "." in table_name:
            return table_name
        cursor.execute(
            "SELECT n.nspname FROM pg_catalog.pg_class c "
            "JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace "
            "WHERE c.oid = %s::regclass;",
            (table_name,),
        )
        schema = cursor.fetchone()[0]
        return f"{self._quote_ident(schema)}.{table_name}"

    def _retarget_dump_inserts(self, dump_text, backup_table, target_table):
        """Extract the INSERTs from a pg_dump file and point them at ``target_table``.

        pg_dump writes ``INSERT INTO [schema.]<backup_table> ...``. The schema
        prefix, which may be absent or quoted, is matched with a regex rather
        than assumed, so the rewrite works whichever schema held the scratch
        table.
        """
        target_pattern = re.compile(
            r'^INSERT INTO\s+(?:(?:"(?:[^"]|"")*"|[^\s."]+)\.)?'
            + re.escape(backup_table)
            + r'(?![\w$])',
            re.IGNORECASE,
        )
        replacement = f"INSERT INTO {target_table}"
        retargeted = []
        for stmt in self.split_insert_statements_robust(dump_text):
            stmt = stmt.strip()
            if not stmt.upper().startswith("INSERT"):
                continue
            # A function replacement avoids backslash processing of target_table.
            retargeted.append(target_pattern.sub(lambda _m: replacement, stmt, count=1))
        return retargeted

    def _backup_dml_data_with_pgdump(self, workflow, db_name, table_name, statement, where_clause, statement_id):
        """Back up the rows a DML statement is about to touch, using pg_dump.

        Steps:
          1. On a dedicated connection (``self.conn`` is not used), run
             ``CREATE TABLE backup_<uuid>_<ts> AS SELECT * FROM <table> WHERE <where_clause>``.
          2. Run ``pg_dump --data-only --column-inserts -t backup_<uuid>_<ts>``
             into ``/tmp/backup_<uuid>_<ts>.sql``.
          3. Extract the INSERT statements, point them at the original table and
             store them via :meth:`metdata_backup` (only if any rows were captured).
          4. Drop the scratch table, always, including on failure.

        :param workflow: workflow object, forwarded to :meth:`metdata_backup`.
        :param db_name: database that holds the table.
        :param table_name: ``schema.table`` or bare ``table`` (see
            :meth:`_extract_table_and_where`).
        :param statement: the DML statement, stored as redo SQL.
        :param where_clause: WHERE condition of the DML statement.
        :param statement_id: statement number within the workflow (currently unused).
        :return: path of the pg_dump output file. The file is left on disk.
        :raises Exception: if any step fails, including a non-zero pg_dump exit
            or a pg_dump timeout.
        """
        backup_suffix = f"{uuid.uuid4().hex}_{int(time.time())}"
        backup_table = f"backup_{backup_suffix}"
        backup_sql_file = f"/tmp/backup_{backup_suffix}.sql"

        temp_conn = None
        try:
            temp_conn = psycopg2.connect(
                host=self.host,
                port=self.port,
                user=self.user,
                password=self.password,
                dbname=db_name,
                connect_timeout=10,
            )
            temp_cursor = temp_conn.cursor()

            quoted_table = self._qualify_table_name(temp_cursor, table_name)
            select_sql = f"SELECT * FROM {quoted_table} WHERE {where_clause}"
            query_sql = f"{select_sql};"

            # 1. Scratch table holding the affected rows.
            create_sql = f"CREATE TABLE {backup_table} AS {select_sql};"
            logger.info(f'执行备份语句是:{create_sql}')
            temp_cursor.execute(create_sql)
            temp_conn.commit()
            logger.debug(f"Created backup table: {backup_table}")

            # 2. Dump it as column-qualified INSERTs.
            cmd = [
                self.pg_dump_path,
                "-h", self.host,
                "-p", str(self.port),
                "-U", self.user,
                "-d", db_name,
                "-t", backup_table,
                "--column-inserts",
                "--data-only",
                "--no-owner",
                "--no-privileges",
                "--no-tablespaces",
                "-f", backup_sql_file,
            ]
            env = os.environ.copy()
            env["PGPASSWORD"] = self.password
            env["PGCLIENTENCODING"] = "UTF8"

            result = subprocess.run(cmd, env=env, capture_output=True, text=True, timeout=300)
            if result.returncode != 0:
                # text=True: stderr is already a str.
                raise Exception(f"pg_dump failed: {result.stderr}")

            # 3. Turn the dump into undo SQL against the original table. The
            # whole file is scanned: only "INSERT INTO" statements are picked
            # up, so no version-specific header/footer slicing is needed.
            with open(backup_sql_file, 'r', encoding='utf-8') as dump_file:
                dump_text = dump_file.read()
            insert_statements = self._retarget_dump_inserts(dump_text, backup_table, quoted_table)
            logger.info(f'insert_statements语句{insert_statements}')

            if insert_statements:
                if not self.metdata_backup(workflow, query_sql, statement, '\n'.join(insert_statements)):
                    raise Exception("Backup with pg_dump failed")
                logger.info("Backup with pg_dump complete")
            return backup_sql_file

        except Exception as e:
            logger.error(f"Backup with pg_dump failed: {e}")
            raise
        finally:
            # 4. Remove the scratch table, even after a failure.
            if temp_conn is not None:
                try:
                    # A failed statement leaves the transaction aborted; roll
                    # back first so the DROP can run.
                    temp_conn.rollback()
                    with temp_conn.cursor() as cleanup_cursor:
                        cleanup_cursor.execute(f"DROP TABLE IF EXISTS \"{backup_table}\";")
                    temp_conn.commit()
                except Exception as cleanup_err:
                    logger.warning(f"Failed to clean up backup table: {cleanup_err}")
                finally:
                    temp_conn.close()

    def get_rollback(self, workflow):
        """Return the stored ``[redo_sql, undo_sql]`` pairs of a workflow.

        Rows are read from ``polar_archive.sql_rollback`` newest first, so the
        undo statements appear in reverse execution order.

        :raises Exception: wrapping any database error (also logged).
        """
        list_backup_sql = []
        conn = None
        try:
            conn = self.get_backup_connection()
            cur = conn.cursor()
            cur.execute("use polar_archive;")
            cur.execute(
                "select redo_sql,undo_sql from sql_rollback where workflow_id = %s order by id desc;",
                (workflow.id,),
            )
            for redo_sql, undo_sql in cur.fetchall():
                list_backup_sql.append([redo_sql, undo_sql])
        except Exception as e:
            logger.error(f"获取回滚语句报错,异常信息{traceback.format_exc()}")
            raise Exception(e) from e
        finally:
            if conn:
                conn.close()
        return list_backup_sql

    @property
    def name(self):
        return "PgSQL"

    @property
    def info(self):
        return "PgSQL engine"

    def get_all_databases(self):
        """Return a ResultSet whose rows are database names, excluding postgres/template0/template1."""
        result = self.query(sql="SELECT datname FROM pg_database;")
        db_list = [
            row[0]
            for row in result.rows
            if row[0] not in ["postgres", "template0", "template1"]
        ]
        result.rows = db_list
        return result

    def get_all_schemas(self, db_name, **kwargs):
        """Return a ResultSet whose rows are the schema names of ``db_name``, excluding system schemas."""
        result = self.query(
            db_name=db_name, sql="select schema_name from information_schema.schemata;"
        )
        schema_list = [
            row[0]
            for row in result.rows
            if row[0]
            not in [
                "information_schema",
                "pg_catalog",
                "pg_toast_temp_1",
                "pg_temp_1",
                "pg_toast",
            ]
        ]
        result.rows = schema_list
        return result

    def get_all_tables(self, db_name, **kwargs):
        """Return a ResultSet whose rows are the table names in ``schema_name`` (kwarg).

        The name ``"test"`` is filtered out.
        """
        schema_name = kwargs.get("schema_name")
        sql = """SELECT table_name 
        FROM information_schema.tables 
        where table_schema = %(schema_name)s;"""
        result = self.query(
            db_name=db_name, sql=sql, parameters={"schema_name": schema_name}
        )
        tb_list = [row[0] for row in result.rows if row[0] not in ["test"]]
        result.rows = tb_list
        return result

    def get_all_columns_by_tb(self, db_name, tb_name, **kwargs):
        """Return a ResultSet whose rows are the column names of ``schema_name.tb_name``."""
        schema_name = kwargs.get("schema_name")
        sql = """SELECT column_name
        FROM information_schema.columns 
        where table_name = %(tb_name)s
        and table_schema = %(schema_name)s;"""
        result = self.query(
            db_name=db_name,
            sql=sql,
            parameters={"tb_name": tb_name, "schema_name": schema_name},
        )
        column_list = [row[0] for row in result.rows]
        result.rows = column_list
        return result

    def describe_table(self, db_name, tb_name, **kwargs):
        """Return column metadata (type, length, precision, nullability, default, comment).

        ``search_path`` is set to ``schema_name`` so that ``table_name::regclass``
        resolves inside that schema.
        """
        schema_name = kwargs.get("schema_name")
        sql = """select
        col.column_name,
        col.data_type,
        col.character_maximum_length,
        col.numeric_precision,
        col.numeric_scale,
        col.is_nullable,
        col.column_default,
        des.description
        from
        information_schema.columns col left join pg_description des on
        col.table_name::regclass = des.objoid
        and col.ordinal_position = des.objsubid
        where table_name = %(tb_name)s
        and col.table_schema = %(schema_name)s
        order by ordinal_position;"""
        result = self.query(
            db_name=db_name,
            schema_name=schema_name,
            sql=sql,
            parameters={"tb_name": tb_name, "schema_name": schema_name},
        )
        return result

    def query_check(self, db_name=None, sql=""):
        """Strip comments, keep the first statement and check that it is a SELECT.

        :return: dict with ``msg``, ``bad_query``, ``filtered_sql`` and ``has_star``.
        """
        result = {"msg": "", "bad_query": False, "filtered_sql": sql, "has_star": False}
        try:
            sql = sqlparse.format(sql, strip_comments=True)
            sql = sqlparse.split(sql)[0]
            result["filtered_sql"] = sql.strip()
        except IndexError:
            result["bad_query"] = True
            result["msg"] = "没有有效的SQL语句"
        if re.match(r"^select", sql, re.I) is None:
            result["bad_query"] = True
            result["msg"] = "不支持的查询语法类型!"
        if "*" in sql:
            result["has_star"] = True
            result["msg"] = "SQL语句中含有 * "
        return result

    def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
        """Run ``sql`` and return a ResultSet. Errors are stored in ``error`` rather than raised.

        Optional keyword arguments:
          * ``schema_name``: sets ``search_path``. It is quoted as a single
            identifier, so it cannot inject extra SQL.
          * ``max_execution_time``: value for ``statement_timeout`` (default 0).
          * ``parameters``: bind parameters for ``cursor.execute``. When it is
            omitted, the SQL is sent as-is with no ``%`` interpolation.

        :param limit_num: fetch at most this many rows when > 0.
        :param close_conn: close the cached connection afterwards.
        """
        schema_name = kwargs.get("schema_name")
        parameters = kwargs.get("parameters")
        result_set = ResultSet(full_sql=sql)
        try:
            conn = self.get_connection(db_name=db_name)
            max_execution_time = kwargs.get("max_execution_time", 0)
            cursor = conn.cursor()
            try:
                cursor.execute(f"SET statement_timeout TO {max_execution_time};")
            except Exception:
                # Best effort only. The failed SET aborted the transaction,
                # which would make every later statement fail, so reset it.
                logger.debug("Failed to set statement_timeout, ignored", exc_info=True)
                conn.rollback()
            if schema_name:
                cursor.execute(f"SET search_path TO {self._quote_ident(schema_name)};")
            cursor.execute(sql, parameters)
            effect_row = cursor.rowcount
            if int(limit_num) > 0:
                rows = cursor.fetchmany(size=int(limit_num))
            else:
                rows = cursor.fetchall()
            fields = cursor.description

            result_set.column_list = [i[0] for i in fields] if fields else []
            result_set.rows = rows
            result_set.affected_rows = effect_row
        except Exception as e:
            logger.warning(f"PgSQL命令执行报错,语句:{sql}, 错误信息:{traceback.format_exc()}")
            result_set.error = str(e)
        finally:
            if close_conn:
                self.close()
        return result_set

    def filter_sql(self, sql="", limit_num=0):
        """Append ``limit {limit_num}`` to a SELECT that has no row limit yet.

        Recognised existing limits are ``LIMIT n``, ``LIMIT n OFFSET m``,
        ``OFFSET m LIMIT n`` and MySQL-style ``LIMIT m, n``. Any SQL is
        returned trimmed, with exactly one trailing semicolon.
        """
        # TODO: limit rewriting could be smarter (e.g. FETCH FIRST ... ROWS ONLY).
        sql = sql.strip().rstrip(";").rstrip()
        sql_lower = sql.lower()
        has_limit = re.search(
            r"(limit\s+\d+(\s+offset\s+\d+)?|offset\s+\d+\s+limit\s+\d+|limit\s+\d+\s*,\s*\d+)$",
            sql_lower,
        )
        if re.match(r"^select", sql_lower) and has_limit is None:
            return f"{sql} limit {limit_num};"
        return f"{sql};"

    def query_masking(self, db_name=None, sql="", resultset=None):
        """Apply the simple column-masking rules. Only SELECT results are masked."""
        if re.match(r"^select", sql, re.I):
            filtered_result = simple_column_mask(self.instance, resultset)
            filtered_result.is_masked = True
        else:
            filtered_result = resultset
        return filtered_result

    def execute_check(self, db_name=None, sql=""):
        """Audit a workflow's SQL before execution and return a ReviewSet.

        SELECT statements and statements matching ``critical_ddl_regex`` are
        rejected (errlevel 2). Everything else passes. ``syntax_type`` is 1 if
        any statement is DDL, otherwise 2.
        """
        config = SysConfig()
        check_result = ReviewSet(full_sql=sql)
        line = 1
        critical_ddl_regex = config.get("critical_ddl_regex", "")
        p = re.compile(critical_ddl_regex)
        check_result.syntax_type = 2  # TODO workflow type: 0 other, 1 DDL, 2 DML
        split_sql = self.reassemble_sql_statements(sql)
        for statement in split_sql:
            statement = sqlparse.format(statement, strip_comments=True)
            # Forbidden: queries belong to the SQL query feature.
            if re.match(r"^select", statement.lower()):
                result = ReviewResult(
                    id=line,
                    errlevel=2,
                    stagestatus="驳回不支持语句",
                    errormessage="仅支持DML和DDL语句,查询语句请使用SQL查询功能!",
                    sql=statement,
                )
            # High-risk statements.
            elif critical_ddl_regex and p.match(statement.strip().lower()):
                result = ReviewResult(
                    id=line,
                    errlevel=2,
                    stagestatus="驳回高危SQL",
                    errormessage="禁止提交匹配" + critical_ddl_regex + "条件的语句!",
                    sql=statement,
                )
            # Normal statements.
            else:
                result = ReviewResult(
                    id=line,
                    errlevel=0,
                    stagestatus="Audit completed",
                    errormessage="None",
                    sql=statement,
                    affected_rows=0,
                    execute_time=0,
                )
            if get_syntax_type(statement) == "DDL":
                check_result.syntax_type = 1
            check_result.rows += [result]
            line += 1
        # Count warnings and errors.
        for r in check_result.rows:
            if r.errlevel == 1:
                check_result.warning_count += 1
            if r.errlevel == 2:
                check_result.error_count += 1
        return check_result

    def metdata_backup(self, workflow, query_sql, statement, insert_statements):
        """Store one DML statement and its undo SQL in the MySQL backup database.

        The ``polar_archive`` database and ``sql_rollback`` table are created
        when missing, then one row is inserted.

        :param workflow: workflow object; ``workflow.id`` links the row to the ticket.
        :param query_sql: SELECT used to capture the affected rows.
        :param statement: the DML statement about to run (redo SQL).
        :param insert_statements: newline-joined INSERTs that restore the captured rows (undo SQL).
        :return: True on success, False on any failure (the error is logged).
        """
        conn = None
        try:
            workflow_id = workflow.id
            conn = self.get_backup_connection()
            backup_cursor = conn.cursor()
            backup_cursor.execute("""create database if not exists polar_archive;""")
            backup_cursor.execute("use polar_archive;")
            backup_cursor.execute(
                f"""CREATE TABLE if not exists `sql_rollback` (
                                       `id` bigint(20) NOT NULL AUTO_INCREMENT,
                                       `redo_sql` text,
                                       `query_sql` text,
                                       `undo_sql` text,
                                       `workflow_id` bigint(20) NOT NULL,
                                        PRIMARY KEY (`id`),
                                        key `idx_sql_rollback_01` (`workflow_id`)
                                     ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;"""
            )

            logger.info(f"新增语句是{insert_statements}")
            sql = f"""insert into sql_rollback(redo_sql,query_sql,undo_sql,workflow_id) values(%s, %s, %s, %s);"""
            logger.info(f'新增语句是{sql}')
            backup_cursor.execute(sql, (statement, query_sql, insert_statements, workflow_id))
        except Exception:
            logger.warning(f"PG备份失败,错误信息{traceback.format_exc()}")
            return False
        finally:
            if conn:
                conn.close()
        return True

    def reassemble_sql_statements(self, raw_sql: str, dialect="postgres"):
        """Split ``raw_sql`` into statements regenerated by sqlglot.

        Each statement is regenerated without comments and ends with ``;``.
        Identifiers keep the quoting they were written with: quoting every
        identifier would change PostgreSQL's case folding. If sqlglot cannot
        parse or regenerate the text, ``sqlparse.split`` is used instead, so no
        statement is ever silently dropped.

        :return: list of statements (empty for blank input).
        """
        if not raw_sql or not raw_sql.strip():
            return []

        try:
            normalized_statements = []
            for parsed in sqlglot.parse(raw_sql, read=dialect):
                if parsed is None:
                    continue  # empty statement, e.g. ";;"
                regenerated = parsed.sql(dialect=dialect, comments=False, pretty=False)
                if not regenerated.rstrip().endswith(';'):
                    regenerated += ';'
                normalized_statements.append(regenerated)
            return normalized_statements
        except Exception as e:
            logger.warning(f"sqlglot could not reassemble SQL, falling back to sqlparse: {e}")
            return [stmt for stmt in sqlparse.split(raw_sql) if stmt.strip()]

    def execute_workflow(self, workflow, close_conn=True):
        """Execute a workflow's statements one by one and return a ReviewSet.

        UPDATE/DELETE statements are backed up with
        :meth:`_backup_dml_data_with_pgdump` before they run. A failed backup
        aborts the workflow before the statement executes. On any error, the
        failing statement is recorded as "Execute Failed" and the remaining
        statements as audited but not executed.
        """
        sql = workflow.sqlworkflowcontent.sql_content
        execute_result = ReviewSet(full_sql=sql)
        sql = sqlparse.format(sql, strip_comments=True)
        split_sql = self.reassemble_sql_statements(sql)
        line = 1
        statement = None
        db_name = workflow.db_name
        try:
            conn = self.get_connection(db_name=db_name)
            cursor = conn.cursor()
            for statement in split_sql:
                statement = statement.rstrip(";")
                if not statement:
                    continue

                # Only UPDATE / DELETE statements are backed up.
                table_name, where_clause = None, None
                if get_syntax_type(statement) == "DML" and re.match(r"^(update|delete)", statement.lower()):
                    table_name, where_clause = self._extract_table_and_where(statement)
                    logger.info(f"备份的表名字{table_name},备份的表条件{where_clause}")
                needs_backup = bool(table_name and where_clause)

                backup_file = None
                with FuncTimer() as backup_timer:
                    if needs_backup:
                        try:
                            backup_file = self._backup_dml_data_with_pgdump(
                                workflow=workflow,
                                db_name=db_name,
                                table_name=table_name,
                                statement=statement,
                                where_clause=where_clause,
                                statement_id=line,
                            )
                            logger.info(f"Backup created for statement {line}: {backup_file}")
                        except Exception as e:
                            # Never run a DML whose data could not be backed up.
                            logger.error(f"Backup failed for statement {line}, aborting execution: {e}")
                            raise Exception(f"脚本问题,无法正常备份: {e}") from e

                with FuncTimer() as t:
                    cursor.execute(statement)
                    conn.commit()
                execute_result.rows.append(
                    ReviewResult(
                        id=line,
                        errlevel=0,
                        errormessage="None",
                        stagestatus='Execute Successfully Backup Successfully' if needs_backup else 'Execute Successfully',
                        sql=statement,
                        affected_rows=cursor.rowcount,
                        execute_time=t.cost,
                        backup_time=backup_timer.cost if needs_backup else '',
                        backup_file=backup_file,
                    )
                )
                line += 1
        except Exception as e:
            logger.warning(
                f"PGSQL命令执行报错,语句:{statement or sql}, 错误信息:{traceback.format_exc()}"
            )
            execute_result.error = str(e)
            # Record the failing statement.
            execute_result.rows.append(
                ReviewResult(
                    id=line,
                    errlevel=2,
                    stagestatus="Execute Failed",
                    errormessage=f"异常信息:{e}",
                    sql=statement or sql,
                    affected_rows=0,
                    execute_time=0,
                )
            )
            line += 1
            # Mark the remaining statements as audited but not executed.
            for statement in split_sql[line - 1 :]:
                execute_result.rows.append(
                    ReviewResult(
                        id=line,
                        errlevel=0,
                        stagestatus="Audit completed",
                        errormessage=f"前序语句失败, 未执行",
                        sql=statement,
                        affected_rows=0,
                        execute_time=0,
                    )
                )
                line += 1
        finally:
            if close_conn:
                self.close()
        return execute_result

    def close(self):
        """Close and forget the cached connection, if any."""
        if self.conn:
            self.conn.close()
            self.conn = None

评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包
实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值