# -*- coding: UTF-8 -*-
"""
@author: hhyo、yyukai
@license: Apache Licence
@file: pgsql.py
@time: 2019/03/29
"""
import re
import psycopg2
import logging
import traceback
import sqlparse
from common.config import SysConfig
from common.utils.timer import FuncTimer
from sql.utils.sql_utils import get_syntax_type
from . import EngineBase
from .models import ResultSet, ReviewSet, ReviewResult
from sql.utils.data_masking import simple_column_mask
import subprocess
import sqlglot
from sqlglot import expressions as exp
from datetime import datetime
import MySQLdb
import uuid
import os
import sqlparse
import time
import re
import json
__author__ = "hhyo、yyukai"
logger = logging.getLogger("default")
class PgSQLEngine(EngineBase):
test_query = "SELECT 1"
def get_connection(self, db_name=None):
    """Return the cached psycopg2 connection, opening one on first use.

    :param db_name: database to connect to; falls back to ``self.db_name``
        and finally to ``"postgres"``. It is only used when no connection
        is cached yet.
    :return: the connection stored on ``self.conn``
    """
    target_db = db_name or self.db_name or "postgres"
    if not self.conn:
        connect_kwargs = {
            "host": self.host,
            "port": self.port,
            "user": self.user,
            "password": self.password,
            "dbname": target_db,
            "connect_timeout": 10,
        }
        self.conn = psycopg2.connect(**connect_kwargs)
    return self.conn
@staticmethod
def get_backup_connection():
    """Open a MySQL connection to the backup server.

    Host, port (default 3306), user and password are read from the
    ``inception_remote_backup_*`` keys of ``SysConfig``.

    :return: a new ``MySQLdb`` connection with autocommit enabled
    """
    settings = SysConfig()
    connect_kwargs = {
        "host": settings.get("inception_remote_backup_host"),
        "port": int(settings.get("inception_remote_backup_port", 3306)),
        "user": settings.get("inception_remote_backup_user"),
        "passwd": settings.get("inception_remote_backup_password"),
        "charset": "utf8mb4",
        "autocommit": True,
    }
    return MySQLdb.connect(**connect_kwargs)
def _extract_table_and_where(self, sql: str):
    """Parse an UPDATE/DELETE with sqlglot and return ``(table_name, where_clause)``.

    :param sql: a single SQL statement (a trailing ``;`` is tolerated)
    :return: ``(table_name, where_clause)``. ``table_name`` is
        ``schema.table`` when the statement names a schema, otherwise just
        the table; ``where_clause`` is the WHERE condition rendered for
        PostgreSQL, or ``"TRUE"`` when the statement has no WHERE.
        ``(None, None)`` is returned for any other statement type, for a
        target that is not a plain table, and when parsing fails.
    """
    try:
        tree = sqlglot.parse_one(sql.strip().rstrip(';'), dialect="postgres")
        if not isinstance(tree, (exp.Update, exp.Delete)):
            return None, None

        target = tree.args.get("this")
        where_node = tree.args.get("where")
        if not target or not isinstance(target, exp.Table):
            return None, None

        schema_part = target.args.get("db")
        name_part = target.this
        if schema_part and name_part:
            table_name = f"{schema_part}.{name_part}"
        else:
            table_name = str(name_part)

        # No WHERE (or an empty one) means every row is affected.
        where_clause = "TRUE"
        if where_node and isinstance(where_node, exp.Where):
            predicate = where_node.this
            if predicate:
                where_clause = predicate.sql(dialect="postgres")
        return table_name, where_clause
    except Exception as e:
        logger.warning(f"SQL parsing failed: {e}")
        return None, None
def split_insert_statements_robust(self, sql_text):
    """Extract every ``INSERT INTO ...`` statement from a block of SQL text.

    The text is scanned character by character. Anything that does not start
    with ``INSERT INTO`` (case-insensitive) is skipped. Inside a statement,
    single-quoted literals (with ``''`` escapes) and double-quoted
    identifiers are tracked so that parentheses and semicolons inside them
    are ignored. A statement ends at a ``)`` that brings the parenthesis
    depth back to zero and is immediately followed by ``;``.

    :param sql_text: raw SQL text
    :return: list of statement strings, each including its trailing ``;``.
        If the text ends before a terminator is found, the remainder is
        returned as the last element.
    """
    keyword = 'INSERT INTO'
    total = len(sql_text)
    found = []
    pos = 0
    while pos < total:
        # Skip whitespace and anything that is not the start of an INSERT.
        if sql_text[pos] in ' \t\r\n':
            pos += 1
            continue
        if sql_text[pos:pos + len(keyword)].upper() != keyword:
            pos += 1
            continue

        start = pos
        quote = None  # active quote character, None when outside any quotes
        depth = 0
        terminated = False
        cursor = pos
        while cursor < total:
            ch = sql_text[cursor]
            if quote is None:
                if ch in ("'", '"'):
                    quote = ch
            elif quote == "'":
                if ch == "'":
                    if cursor + 1 < total and sql_text[cursor + 1] == "'":
                        cursor += 1  # '' is an escaped quote: skip its second half
                    else:
                        quote = None
            elif ch == '"':
                # Double-quoted identifiers are closed by the next double quote.
                quote = None

            # Structural characters only count outside of quotes.
            if quote is None:
                if ch == '(':
                    depth += 1
                elif ch == ')':
                    depth -= 1
                    if depth == 0 and sql_text[cursor + 1:cursor + 2] == ';':
                        cursor += 2
                        found.append(sql_text[start:cursor])
                        terminated = True
                        break
            cursor += 1

        if not terminated:
            # No terminator found: keep the remainder rather than dropping it.
            found.append(sql_text[start:])
        pos = cursor
    return found
def _backup_dml_data_with_pgdump(self, workflow, db_name, table_name, statement, where_clause, statement_id):
    """Snapshot the rows a DML statement is about to touch and store them as rollback SQL.

    Steps:
    1. Copy the rows matching ``where_clause`` into a scratch table named
       ``backup_{uuid}_{timestamp}`` using a dedicated connection (``self.conn``
       is not touched).
    2. Dump that table with ``pg_dump --column-inserts --data-only`` into
       ``/tmp/backup_{uuid}_{timestamp}.sql``.
    3. Extract the INSERT statements from the dump, point them back at the
       original table and persist them through ``metdata_backup``.
    4. Always drop the scratch table, whether or not the backup succeeded.

    :param workflow: workflow object the rollback record is linked to
    :param db_name: database that holds the table
    :param table_name: ``table`` or ``schema.table``; when no schema is given
        the schema is taken to be ``self.user``
    :param statement: the DML statement being backed up (stored as redo SQL)
    :param where_clause: SQL condition selecting the affected rows
    :param statement_id: position of the statement in the workflow (unused here)
    :return: path of the pg_dump output file
    :raises Exception: if connecting, copying, dumping or storing the backup fails
    """
    unique_id = uuid.uuid4().hex
    backup_suffix = f"{unique_id}_{int(time.time())}"
    backup_table = f"backup_{backup_suffix}"
    backup_sql_file = f"/tmp/backup_{backup_suffix}.sql"

    temp_conn = None
    temp_cursor = None
    try:
        temp_conn = psycopg2.connect(
            host=self.host,
            port=self.port,
            user=self.user,
            password=self.password,
            dbname=db_name,
            connect_timeout=10,
        )
        temp_cursor = temp_conn.cursor()

        # Qualify the source table; an unqualified name is looked up in the
        # schema named after the connecting user.
        if '.' in table_name:
            source_table = table_name
        else:
            source_table = f'{self.user}.{table_name}'

        # === 1. Create the scratch backup table ===
        create_sql = f"""
        CREATE TABLE {backup_table} AS
        SELECT * FROM {source_table}
        WHERE {where_clause};
        """
        query_sql = f"""
        SELECT * FROM {source_table}
        WHERE {where_clause};
        """
        logger.info(f'执行备份语句是:{create_sql}')
        temp_cursor.execute(create_sql)
        temp_conn.commit()
        logger.debug(f"Created backup table: {backup_table}")

        # === 2. Dump the scratch table with pg_dump ===
        cmd = [
            "/u01/polardb_pg_tools/bin/pg_dump",
            "-h", self.host,
            "-p", str(self.port),
            "-U", self.user,
            "-d", db_name,
            "-t", backup_table,
            "--column-inserts",  # INSERT statements instead of COPY
            "--data-only",
            "--no-owner",
            "--no-privileges",
            "--no-tablespaces",
            "-f", backup_sql_file,
        ]
        env = os.environ.copy()
        env["PGPASSWORD"] = self.password
        env["PGCLIENTENCODING"] = "UTF8"
        result = subprocess.run(cmd, env=env, capture_output=True, text=True, timeout=300)
        if result.returncode != 0:
            # text=True means stderr is already a str; it must not be decoded.
            raise Exception(f"pg_dump failed: {result.stderr}")

        # === 3. Turn the dump into rollback INSERTs for the original table ===
        # The whole file is scanned instead of slicing off a fixed number of
        # header/footer lines: the splitter skips everything that is not an
        # INSERT, and the header length differs between pg_dump versions.
        with open(backup_sql_file, 'r', encoding='utf-8') as f:
            dump_text = f.read()
        dumped_name = f'{self.user}.{backup_table}'
        insert_statements = [
            stmt.strip().replace(dumped_name, source_table, 1)
            for stmt in self.split_insert_statements_robust(dump_text)
            if stmt.strip().upper().startswith('INSERT')
        ]
        logger.info(f'insert_statements语句{insert_statements}')

        if not insert_statements:
            # The statement matches no rows, so there is nothing to restore.
            logger.info("No rows matched the backup condition, nothing to store")
            return backup_sql_file

        undo_sql = '\n'.join(insert_statements)
        if not self.metdata_backup(workflow, query_sql, statement, undo_sql):
            raise Exception("Backup with pg_dump failed")
        logger.info("Backup with pg_dump complete")
        return backup_sql_file
    except Exception as e:
        logger.error(f"Backup with pg_dump failed: {e}")
        raise
    finally:
        # === 4. Drop the scratch table, even after a failure ===
        try:
            if temp_cursor is not None:
                try:
                    # A failed statement leaves the transaction aborted;
                    # roll back first so the DROP can actually run.
                    temp_conn.rollback()
                    temp_cursor.execute(f"DROP TABLE IF EXISTS \"{backup_table}\";")
                    temp_conn.commit()
                except Exception as cleanup_err:
                    logger.warning(f"Failed to clean up backup table: {cleanup_err}")
                temp_cursor.close()
        finally:
            if temp_conn is not None:
                temp_conn.close()
# end
# new add
def get_rollback(self, workflow):
    """Return the rollback statements stored for a workflow.

    Reads ``polar_archive.sql_rollback`` on the backup server.

    :param workflow: workflow object; only ``workflow.id`` is used
    :return: list of ``[redo_sql, undo_sql]`` pairs ordered by record id
    :raises Exception: wrapping any error raised while reading the backup
    """
    conn = None
    try:
        conn = self.get_backup_connection()
        cur = conn.cursor()
        cur.execute("use polar_archive;")
        # The workflow id is passed as a bound parameter, not formatted in.
        cur.execute(
            "select redo_sql,undo_sql from sql_rollback where workflow_id = %s order by id;",
            (workflow.id,),
        )
        # Each entry is [original statement, rollback statement].
        return [[redo_sql, undo_sql] for redo_sql, undo_sql in cur.fetchall()]
    except Exception as e:
        logger.error(f"获取回滚语句报错,异常信息{traceback.format_exc()}")
        raise Exception(e) from e
    finally:
        # Close the connection on both the success and the error path.
        if conn:
            conn.close()
#
@property
def name(self):
    """Short engine name."""
    engine_name = "PgSQL"
    return engine_name
@property
def info(self):
    """Human-readable engine description."""
    description = "PgSQL engine"
    return description
def get_all_databases(self):
    """List the databases of the instance.

    ``postgres``, ``template0`` and ``template1`` are left out.

    :return: the result of ``self.query`` with ``rows`` replaced by a flat
        list of database names
    """
    excluded = ("postgres", "template0", "template1")
    result = self.query(sql="SELECT datname FROM pg_database;")
    names = []
    for record in result.rows:
        datname = record[0]
        if datname not in excluded:
            names.append(datname)
    result.rows = names
    return result
def get_all_schemas(self, db_name, **kwargs):
    """List the schemas of a database.

    ``information_schema``, ``pg_catalog``, ``pg_toast``, ``pg_temp_1`` and
    ``pg_toast_temp_1`` are left out.

    :param db_name: database to inspect
    :return: the result of ``self.query`` with ``rows`` replaced by a flat
        list of schema names
    """
    hidden = (
        "information_schema",
        "pg_catalog",
        "pg_toast_temp_1",
        "pg_temp_1",
        "pg_toast",
    )
    result = self.query(
        db_name=db_name, sql="select schema_name from information_schema.schemata;"
    )
    visible = []
    for record in result.rows:
        schema = record[0]
        if schema in hidden:
            continue
        visible.append(schema)
    result.rows = visible
    return result
def get_all_tables(self, db_name, **kwargs):
    """List the tables of a schema.

    A table literally named ``test`` is left out of the result.

    :param db_name: database to inspect
    :param schema_name: (keyword) schema whose tables are listed
    :return: the result of ``self.query`` with ``rows`` replaced by a flat
        list of table names
    """
    schema_name = kwargs.get("schema_name")
    # self.query only accepts a finished SQL string, so the value is embedded
    # as a literal; doubling single quotes keeps it from terminating the
    # literal (valid for standard_conforming_strings=on, the server default).
    schema_literal = str(schema_name).replace("'", "''")
    sql = f"""SELECT table_name
    FROM information_schema.tables
    where table_schema ='{schema_literal}';"""
    result = self.query(db_name=db_name, sql=sql)
    result.rows = [row[0] for row in result.rows if row[0] not in ["test"]]
    return result
def get_all_columns_by_tb(self, db_name, tb_name, **kwargs):
    """List the column names of a table.

    :param db_name: database to inspect
    :param tb_name: table name
    :param schema_name: (keyword) schema that holds the table
    :return: the result of ``self.query`` with ``rows`` replaced by a flat
        list of column names
    """
    schema_name = kwargs.get("schema_name")
    # self.query only accepts a finished SQL string, so the values are
    # embedded as literals; doubling single quotes keeps them from
    # terminating the literal (valid for standard_conforming_strings=on).
    table_literal = str(tb_name).replace("'", "''")
    schema_literal = str(schema_name).replace("'", "''")
    sql = f"""SELECT column_name
    FROM information_schema.columns
    where table_name='{table_literal}'
    and table_schema ='{schema_literal}';"""
    result = self.query(db_name=db_name, sql=sql)
    result.rows = [row[0] for row in result.rows]
    return result
def describe_table(self, db_name, tb_name, **kwargs):
    """Return the column definitions of a table.

    Each row holds: column name, data type, character maximum length,
    numeric precision, numeric scale, nullability, default and the column
    comment (from ``pg_description``), ordered by ordinal position.

    :param db_name: database to inspect
    :param tb_name: table name
    :param schema_name: (keyword) schema that holds the table; it is also
        forwarded to ``self.query``
    :return: the result of ``self.query``
    """
    schema_name = kwargs.get("schema_name")
    # self.query only accepts a finished SQL string, so the values are
    # embedded as literals; doubling single quotes keeps them from
    # terminating the literal (valid for standard_conforming_strings=on).
    table_literal = str(tb_name).replace("'", "''")
    schema_literal = str(schema_name).replace("'", "''")
    sql = f"""select
    col.column_name,
    col.data_type,
    col.character_maximum_length,
    col.numeric_precision,
    col.numeric_scale,
    col.is_nullable,
    col.column_default,
    des.description
    from
    information_schema.columns col left join pg_description des on
    col.table_name::regclass = des.objoid
    and col.ordinal_position = des.objsubid
    where table_name = '{table_literal}'
    and col.table_schema = '{schema_literal}'
    order by ordinal_position;"""
    return self.query(db_name=db_name, schema_name=schema_name, sql=sql)
def query_check(self, db_name=None, sql=""):
    """Validate a query before it is run.

    Comments are stripped and only the first statement is kept.

    :param db_name: unused
    :param sql: raw SQL text entered by the user
    :return: dict with ``msg``, ``bad_query`` (True when the input has no
        statement or the statement is not a SELECT), ``filtered_sql`` (the
        statement to run) and ``has_star`` (True when it contains ``*``)
    """
    result = {"msg": "", "bad_query": False, "filtered_sql": sql, "has_star": False}
    try:
        sql = sqlparse.format(sql, strip_comments=True)
        sql = sqlparse.split(sql)[0]
        result["filtered_sql"] = sql.strip()
    except IndexError:
        # Nothing left after stripping comments: report that and stop, so the
        # message is not replaced by the syntax-type check below.
        result["bad_query"] = True
        result["msg"] = "没有有效的SQL语句"
        return result
    if re.match(r"^select", sql, re.I) is None:
        result["bad_query"] = True
        result["msg"] = "不支持的查询语法类型!"
    if "*" in sql:
        result["has_star"] = True
        result["msg"] = "SQL语句中含有 * "
    return result
def query(self, db_name=None, sql="", limit_num=0, close_conn=True, **kwargs):
    """Run a statement and return its rows as a ``ResultSet``.

    :param db_name: database to connect to
    :param sql: statement to execute
    :param limit_num: when > 0, fetch at most this many rows
    :param close_conn: close the engine connection afterwards
    :param schema_name: (keyword) if set, used as ``search_path`` first
    :param max_execution_time: (keyword) ``statement_timeout`` value, default 0
    :return: ``ResultSet``; on failure ``error`` holds the exception text
    """
    schema_name = kwargs.get("schema_name")
    max_execution_time = kwargs.get("max_execution_time", 0)
    result_set = ResultSet(full_sql=sql)
    try:
        conn = self.get_connection(db_name=db_name)
        cursor = conn.cursor()
        try:
            cursor.execute(f"SET statement_timeout TO {max_execution_time};")
        except Exception as timeout_err:
            # Setting the timeout is best effort, but a failed statement
            # aborts the transaction: roll back so the query can still run.
            conn.rollback()
            logger.warning(f"Failed to set statement_timeout: {timeout_err}")
        if schema_name:
            # set_config takes the value as a bound parameter and parses it
            # like SET search_path does, without string-building the SQL.
            cursor.execute(
                "SELECT set_config('search_path', %s, false);", (str(schema_name),)
            )
        cursor.execute(sql)
        effect_row = cursor.rowcount
        if int(limit_num) > 0:
            rows = cursor.fetchmany(size=int(limit_num))
        else:
            rows = cursor.fetchall()
        fields = cursor.description
        result_set.column_list = [i[0] for i in fields] if fields else []
        result_set.rows = rows
        result_set.affected_rows = effect_row
    except Exception as e:
        logger.warning(f"PgSQL命令执行报错,语句:{sql}, 错误信息:{traceback.format_exc()}")
        result_set.error = str(e)
    finally:
        if close_conn:
            self.close()
    return result_set
def filter_sql(self, sql="", limit_num=0):
    """Append a ``limit`` to a SELECT that does not already end with one.

    Recognised endings are ``limit N``, ``limit N offset M`` and
    ``limit M, N``. Statements that are not SELECTs are returned unchanged
    apart from normalising the trailing semicolon.

    :param sql: statement to rewrite
    :param limit_num: row limit to append
    :return: the statement terminated by exactly one ``;``
    """
    # Drop surrounding whitespace and trailing semicolons so the limit is
    # never appended after a ';'.
    statement = sql.strip().rstrip(";").rstrip()
    lowered = statement.lower()
    has_limit = re.search(
        r"limit\s+\d+(?:\s*,\s*\d+|\s+offset\s+\d+)?$", lowered
    )
    if re.match(r"^select", lowered) and has_limit is None:
        return f"{statement} limit {limit_num};"
    return f"{statement};"
def query_masking(self, db_name=None, sql="", resultset=None):
    """Apply the simple column-masking rules to a SELECT result.

    :param db_name: unused
    :param sql: the statement that produced ``resultset``
    :param resultset: result to mask
    :return: the masked result (flagged with ``is_masked = True``) for a
        SELECT, otherwise ``resultset`` untouched
    """
    if not re.match(r"^select", sql, re.I):
        return resultset
    masked = simple_column_mask(self.instance, resultset)
    masked.is_masked = True
    return masked
def execute_check(self, db_name=None, sql=""):
    """Review a workflow's SQL before execution and return a ``ReviewSet``.

    Every statement gets one ``ReviewResult``: SELECTs are rejected,
    statements matching the ``critical_ddl_regex`` setting are rejected, and
    everything else passes. The set's ``syntax_type`` is 2 (DML) unless a
    DDL statement is found among the passing statements, in which case it
    becomes 1.

    :param db_name: unused
    :param sql: full SQL text of the workflow
    :return: ``ReviewSet`` with per-statement rows and warning/error counts
    """
    config = SysConfig()
    check_result = ReviewSet(full_sql=sql)
    critical_ddl_regex = config.get("critical_ddl_regex", "") or ""
    p = re.compile(critical_ddl_regex)
    check_result.syntax_type = 2  # TODO workflow type: 0 other, 1 DDL, 2 DML

    split_sql = self.reassemble_sql_statements(sql)
    if isinstance(split_sql, str):
        # A string result means "nothing to split" or "splitting failed";
        # iterating it would yield single characters, so use sqlparse instead.
        split_sql = sqlparse.split(sql) if split_sql else []

    for line, statement in enumerate(split_sql, start=1):
        statement = sqlparse.format(statement, strip_comments=True)
        if re.match(r"^select", statement.lower()):
            # Queries are not allowed in a workflow.
            result = ReviewResult(
                id=line,
                errlevel=2,
                stagestatus="驳回不支持语句",
                errormessage="仅支持DML和DDL语句,查询语句请使用SQL查询功能!",
                sql=statement,
            )
        elif critical_ddl_regex and p.match(statement.strip().lower()):
            # High-risk statement as configured by the administrator.
            result = ReviewResult(
                id=line,
                errlevel=2,
                stagestatus="驳回高危SQL",
                errormessage="禁止提交匹配" + critical_ddl_regex + "条件的语句!",
                sql=statement,
            )
        else:
            result = ReviewResult(
                id=line,
                errlevel=0,
                stagestatus="Audit completed",
                errormessage="None",
                sql=statement,
                affected_rows=0,
                execute_time=0,
            )
            if get_syntax_type(statement) == "DDL":
                check_result.syntax_type = 1
        check_result.rows.append(result)

    # Tally warnings and errors.
    for r in check_result.rows:
        if r.errlevel == 1:
            check_result.warning_count += 1
        elif r.errlevel == 2:
            check_result.error_count += 1
    return check_result
def metdata_backup(self, workflow, query_sql, statement, insert_statements):
    """Store one rollback record in the MySQL backup database.

    The record goes into ``polar_archive.sql_rollback`` (database and table
    are created on demand) and is linked to the workflow by ``workflow.id``.

    :param workflow: workflow object the record belongs to
    :param query_sql: SELECT that identifies the affected rows
    :param statement: the executed statement (stored as ``redo_sql``)
    :param insert_statements: SQL that restores the rows (stored as ``undo_sql``)
    :return: True when the record was written, False on any error (logged)
    """
    # Bound before the try so the finally clause can always test it.
    conn = None
    try:
        workflow_id = workflow.id
        conn = self.get_backup_connection()
        backup_cursor = conn.cursor()
        backup_cursor.execute("""create database if not exists polar_archive;""")
        backup_cursor.execute("use polar_archive;")
        backup_cursor.execute(
            """CREATE TABLE if not exists `sql_rollback` (
            `id` bigint(20) NOT NULL AUTO_INCREMENT,
            `redo_sql` text,
            `query_sql` text,
            `undo_sql` text,
            `workflow_id` bigint(20) NOT NULL,
            PRIMARY KEY (`id`),
            key `idx_sql_rollback_01` (`workflow_id`)
            ) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4;"""
        )
        # Write the rollback record.
        logger.info(f"新增语句是{insert_statements}")
        sql = """insert into sql_rollback(redo_sql,query_sql,undo_sql,workflow_id) values(%s, %s, %s, %s);"""
        logger.info(f'新增语句是{sql}')
        backup_cursor.execute(sql, (statement, query_sql, insert_statements, workflow_id))
    except Exception:
        logger.warning(f"PG备份失败,错误信息{traceback.format_exc()}")
        return False
    finally:
        if conn:
            conn.close()
    return True
def reassemble_sql_statements(self, raw_sql: str, dialect="postgres"):
    """Split a SQL script into individual, normalised statements.

    The script is parsed with sqlglot and each statement is regenerated
    without comments and terminated with ``;``. If sqlglot cannot parse or
    regenerate the script, the statements are split with sqlparse instead,
    so that no statement is silently dropped.

    :param raw_sql: SQL script, possibly containing several statements
    :param dialect: sqlglot dialect used for parsing and generation
    :return: list of statements; an empty list for empty/blank input
    """
    if not raw_sql or not raw_sql.strip():
        return []
    try:
        normalized_statements = []
        for parsed in sqlglot.parse(raw_sql, read=dialect):
            if parsed is None:
                continue  # empty statement, e.g. a stray ';'
            regenerated = parsed.sql(
                dialect=dialect,
                # Only identifiers quoted in the input stay quoted: quoting an
                # unquoted mixed-case name would stop PostgreSQL from folding
                # it to lower case and change which object it refers to.
                identify=False,
                comments=False,
                pretty=False,
            )
            if not regenerated.rstrip().endswith(';'):
                regenerated += ';'
            normalized_statements.append(regenerated)
        return normalized_statements
    except Exception as e:
        logger.warning(f"sqlglot could not reassemble the script, falling back to sqlparse: {e}")
        return [stmt for stmt in sqlparse.split(raw_sql) if stmt.strip()]
def execute_workflow(self, workflow, close_conn=True):
    """Execute a workflow statement by statement and return a ``ReviewSet``.

    Comments are stripped and the script is split into statements. Before an
    UPDATE/DELETE is executed, the rows it targets are backed up with
    ``_backup_dml_data_with_pgdump``; if that backup fails, the statement is
    not executed and the run stops. Each statement is committed on its own.
    When a statement fails, it is reported as failed and every statement
    after it is reported as "not executed".

    :param workflow: workflow object providing ``sqlworkflowcontent.sql_content``
        and ``db_name``
    :param close_conn: close the engine connection when done
    :return: ``ReviewSet`` with one ``ReviewResult`` per statement
    """
    sql = workflow.sqlworkflowcontent.sql_content
    execute_result = ReviewSet(full_sql=sql)
    sql = sqlparse.format(sql, strip_comments=True)

    split_sql = self.reassemble_sql_statements(sql)
    if isinstance(split_sql, str):
        # A string result means "nothing to split" or "splitting failed";
        # iterating it would yield single characters, so use sqlparse instead.
        split_sql = sqlparse.split(sql) if split_sql else []
    # Drop empty statements up front so that ``line`` always matches the
    # position in ``split_sql`` when the unexecuted remainder is reported.
    split_sql = [stmt for stmt in split_sql if stmt.rstrip(";").strip()]

    line = 1
    statement = None
    db_name = workflow.db_name
    try:
        conn = self.get_connection(db_name=db_name)
        cursor = conn.cursor()
        for statement in split_sql:
            statement = statement.rstrip(";")

            # Only UPDATE / DELETE statements are backed up.
            table_name, where_clause = None, None
            if get_syntax_type(statement) == "DML" and re.match(
                r"^(update|delete)", statement.lower()
            ):
                table_name, where_clause = self._extract_table_and_where(statement)
                logger.info(f"备份的表名字{table_name},备份的表条件{where_clause}")
            needs_backup = bool(table_name and where_clause)

            backup_file = None
            backup_error = None
            with FuncTimer() as f:
                if needs_backup:
                    try:
                        backup_file = self._backup_dml_data_with_pgdump(
                            workflow=workflow,
                            db_name=db_name,
                            table_name=table_name,
                            statement=statement,
                            where_clause=where_clause,
                            statement_id=line,
                        )
                        logger.info(f"Backup created for statement {line}: {backup_file}")
                    except Exception as e:
                        backup_error = e
            if backup_error is not None:
                # Without a backup the statement must not run: stop here.
                logger.error(f"Backup failed for statement {line}, abort execution: {backup_error}")
                raise Exception("脚本问题,无法正常备份") from backup_error

            with FuncTimer() as t:
                cursor.execute(statement)
                conn.commit()
            execute_result.rows.append(
                ReviewResult(
                    id=line,
                    errlevel=0,
                    errormessage="None",
                    stagestatus='Execute Successfully Backup Successfully' if needs_backup else 'Execute Successfully',
                    sql=statement,
                    affected_rows=cursor.rowcount,
                    execute_time=t.cost,
                    backup_time=f.cost if needs_backup else '',
                    backup_file=backup_file,
                )
            )
            line += 1
    except Exception as e:
        logger.warning(
            f"PGSQL命令执行报错,语句:{statement or sql}, 错误信息:{traceback.format_exc()}"
        )
        # Leave the connection usable when the caller keeps it open.
        if self.conn:
            try:
                self.conn.rollback()
            except Exception as rollback_err:
                logger.warning(f"Rollback after failed statement failed: {rollback_err}")
        execute_result.error = str(e)
        # Report the failing statement.
        execute_result.rows.append(
            ReviewResult(
                id=line,
                errlevel=2,
                stagestatus="Execute Failed",
                errormessage=f"异常信息:{e}",
                sql=statement or sql,
                affected_rows=0,
                execute_time=0,
            )
        )
        line += 1
        # Report everything after the failing statement as not executed.
        for statement in split_sql[line - 1:]:
            execute_result.rows.append(
                ReviewResult(
                    id=line,
                    errlevel=0,
                    stagestatus="Audit completed",
                    errormessage="前序语句失败, 未执行",
                    sql=statement,
                    affected_rows=0,
                    execute_time=0,
                )
            )
            line += 1
    finally:
        if close_conn:
            self.close()
    return execute_result
def close(self):
    """Close the cached connection, if any, and forget it."""
    active = self.conn
    if not active:
        return
    active.close()
    self.conn = None

05-16
973
973
12-30
1025
1025

被折叠的 条评论
为什么被折叠?



