Fmask Cloud Detection for a Single Multispectral Image


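Cloud detection is a critical step in remote sensing image processing. Clouds hide surface information, and they also interfere with index calculation, change detection and land-cover analysis. This post shares a Python implementation of the Fmask (Function of mask) algorithm that detects clouds in a single multispectral image.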

1. Introduction to the Fmask Algorithm

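Fmask is a standard cloud detection method for Landsat and Sentinel imagery. Its core idea is to combine spectral features with threshold rules to decide whether a pixel is cloud. The implementation in this post uses a simplified, Fmask-style set of such rules, based on the following criteria: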

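  • NDVI (Normalized Difference Vegetation Index)
    Cloud pixels usually have low NDVI, because the cloud hides the vegetation beneath it.

  • Whiteness index
    Clouds reflect strongly and evenly across the visible bands, which can be measured from the differences among the red, green and blue bands.

  • NIR vs. SWIR ratio
    Near infrared (NIR) is sensitive to vegetation, while short-wave infrared (SWIR) is sensitive to cloud and snow; a difference index of the two helps separate cloud further.

  • Blue/red ratio
    In cloud regions blue reflectance is usually higher than red, which helps detect thin cloud or snow.

  • Morphological processing
    Dilation, erosion and hole filling remove noise and small false detections.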
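The spectral rules can be checked on two hand-picked pixels. This is a minimal NumPy sketch: `spectral_indices` is a hypothetical helper whose formulas mirror `calculate_ndvi`, `calculate_whiteness_index` and `calculate_nir_dsni` in the class below, and the reflectance values are made-up illustrative numbers, not measurements.

```python
import numpy as np


def spectral_indices(blue, green, red, nir, swir1, eps=1e-10):
    """Per-pixel indices used by the rules (inputs are reflectances in 0-1)."""
    ndvi = (nir - red) / (nir + red + eps)
    mean_vis = (blue + green + red) / 3.0
    # Total absolute deviation from the visible mean: small for flat, white-ish spectra
    whiteness = np.abs(blue - mean_vis) + np.abs(green - mean_vis) + np.abs(red - mean_vis)
    # Normalized NIR-SWIR1 difference, same form as calculate_nir_dsni
    nir_swir = (nir - swir1) / (nir + swir1 + eps)
    blue_red = blue / (red + eps)
    return ndvi, whiteness, nir_swir, blue_red


# Two illustrative pixels (made-up values): index 0 = bright cloud, index 1 = vegetation
blue = np.array([0.45, 0.04])
green = np.array([0.46, 0.08])
red = np.array([0.44, 0.05])
nir = np.array([0.50, 0.40])
swir1 = np.array([0.40, 0.20])

ndvi, whiteness, nir_swir, blue_red = spectral_indices(blue, green, red, nir, swir1)
print("NDVI:", np.round(ndvi, 3))
print("whiteness:", np.round(whiteness, 3))
print("NIR-SWIR1:", np.round(nir_swir, 3))
print("blue/red:", np.round(blue_red, 3))
```

The cloud pixel has the lower NDVI, whiteness and NIR-SWIR1 values and the higher blue/red ratio, but its blue/red ratio (about 1.02) stays below the 1.2 cutoff used in `detect_clouds`. Not every rule fires for every cloud, which is why the rules are combined as weighted votes rather than applied as hard filters.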

Following these steps, Fmask produces a cloud probability map and a binary cloud mask, which give later analysis a reliable data basis.
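The step from probability map to cleaned binary mask can be sketched compactly with `scipy.ndimage`. This is a simplified variant under stated assumptions, not a copy of the class below. `probability_to_mask` is a hypothetical helper that counts component sizes with `np.bincount` and uses `ndimage.binary_fill_holes`, which fills every hole fully enclosed by cloud regardless of its size. The threshold of 0.8 and `min_size=4` are illustrative values for a tiny test array; the class below uses `min_cloud_size=500`.

```python
import numpy as np
from scipy import ndimage


def probability_to_mask(prob, threshold=0.8, min_size=4):
    """Threshold a cloud-probability map, drop small blobs, fill enclosed holes."""
    mask = prob >= threshold
    structure = ndimage.generate_binary_structure(2, 2)  # 8-connectivity
    labeled, _ = ndimage.label(mask, structure=structure)
    sizes = np.bincount(labeled.ravel())  # pixel count per label
    keep = sizes >= min_size
    keep[0] = False                       # label 0 is the clear-sky background
    mask = keep[labeled]                  # keep only large enough cloud regions
    # Fill clear pixels that are not connected to the image border
    return ndimage.binary_fill_holes(mask)


prob = np.zeros((8, 8))
prob[1:6, 1:6] = 0.9   # a 5x5 cloud...
prob[3, 3] = 0.2       # ...with a one-pixel hole in its centre
prob[7, 7] = 0.95      # an isolated single-pixel false positive
mask = probability_to_mask(prob)
print(mask.astype(int))
```

The isolated pixel at (7, 7) is removed as too small, and the hole at (3, 3) is filled because it is completely surrounded by cloud.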

2. Python Implementation Overview

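The FmaskCloudDetector class presented here performs cloud detection on a single multispectral image. Its core features are:

  • Band extraction: extracts the blue, green, red, NIR and SWIR bands from a multispectral image; supports Landsat 8, Sentinel-2 and generic images.

  • Cloud probability: computes a per-pixel cloud probability from rules based on NDVI, the whiteness index, the NIR-SWIR difference, the blue/red ratio and related tests.

  • Adaptive threshold: raises the decision threshold to the median probability of the pixels above the base threshold, so that only the most confident cloud pixels are kept.

  • Morphological post-processing: removes small cloud patches and fills holes inside clouds to produce a smooth mask.

  • Visualization and output: produces both the cloud probability map and the binary cloud mask, and can save them to a GeoTIFF.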
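The adaptive-threshold step can be reproduced in isolation. The sketch follows the logic of `detect_clouds` as written: it takes the median of the probabilities at or above the base threshold and never goes lower than the base threshold. `adaptive_cloud_mask` is a hypothetical helper name, and the probability values are made up.

```python
import numpy as np


def adaptive_cloud_mask(probability, base_threshold=0.8):
    """Median-based adaptive threshold, following the logic in detect_clouds."""
    high = probability[probability >= base_threshold]
    if high.size == 0:
        # Nothing reaches the base threshold: report no cloud
        return np.zeros(probability.shape, dtype=bool), None
    # Median of the high-probability pixels, never below the base threshold
    threshold = max(float(np.percentile(high, 50)), base_threshold)
    return probability >= threshold, threshold


prob = np.array([[0.10, 0.85, 0.90],
                 [0.95, 1.00, 0.60]])
mask, t = adaptive_cloud_mask(prob)
print(t)
print(mask)
```

The four pixels at or above 0.8 have a median of 0.925, so only the two pixels at 0.95 and 1.00 are kept. Because the threshold becomes the median of the already-high pixels, roughly half of them are dropped. This step therefore makes the detector more conservative, not more sensitive.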

import os

import numpy as np
import rasterio
import matplotlib.pyplot as plt
from scipy.ndimage import binary_dilation, binary_erosion, binary_closing, \
    generate_binary_structure, label, find_objects


class FmaskCloudDetector:
    """
    Fmask-style cloud detection for a single multispectral image
    """

    def __init__(self, cloud_prob_threshold=0.8):
        self.nir_threshold = 0.3    # NIR reflectance above which a pixel gets the NIR vote
        self.swir_threshold = 0.1   # not used by the rules below
        self.cloud_prob_threshold = cloud_prob_threshold

    def extract_bands_from_multispectral(self, data, sensor_type='landsat8'):
        """
        Extract individual bands from a multispectral image

        Args:
            data: multispectral array of shape (bands, height, width) or (height, width, bands)
            sensor_type: sensor type ('landsat8', 'sentinel2', 'generic')

        Returns:
            bands_dict: dict mapping band names to float64 arrays
        """
        if data.ndim != 3:
            raise ValueError("Input data must be a 3-D array")

        # Assume the band axis is the smallest dimension; move it to the front if needed
        if np.argmin(data.shape) == 2:
            data = np.transpose(data, (2, 0, 1))
        bands_data = data.astype(np.float64)
        num_bands = bands_data.shape[0]

        # Zero-based band positions in each sensor's standard band order.
        # Sentinel-2 assumes all 13 bands are present: B1..B8, B8A, B9..B12.
        # 'generic' assumes the Landsat 8 OLI order (coastal, blue, green, red, nir, swir1, swir2),
        # except that a 4-band image is assumed to be ordered blue, green, red, nir.
        if sensor_type == 'sentinel2':
            index = {'blue': 1, 'green': 2, 'red': 3, 'nir': 7, 'swir1': 11, 'swir2': 12}
        elif sensor_type == 'generic' and num_bands == 4:
            index = {'blue': 0, 'green': 1, 'red': 2, 'nir': 3}
        else:
            index = {'blue': 1, 'green': 2, 'red': 3, 'nir': 4, 'swir1': 5, 'swir2': 6}

        # Keep every band under a generic name, then add the named bands that exist
        bands_dict = {f'band{i + 1}': bands_data[i] for i in range(num_bands)}
        for name, i in index.items():
            if i < num_bands:
                bands_dict[name] = bands_data[i]

        return bands_dict

    def calculate_ndvi(self, nir, red):
        """Compute NDVI"""
        return (nir - red) / (nir + red + 1e-10)

    def calculate_whiteness_index(self, blue, green, red):
        """Compute the whiteness index: total absolute deviation of the visible bands from their mean"""
        mean_vis = (blue + green + red) / 3.0
        whiteness = np.abs(blue - mean_vis) + np.abs(green - mean_vis) + np.abs(red - mean_vis)
        return whiteness

    def calculate_nir_dsni(self, nir, swir1):
        """Compute the normalized NIR-SWIR1 difference index"""
        return (nir - swir1) / (nir + swir1 + 1e-10)

    def detect_clouds(self, bands_dict):
        """
        Main Fmask-style cloud detection routine

        Args:
            bands_dict: dict of bands (as returned by extract_bands_from_multispectral)

        Returns:
            cloud_mask: cloud mask (cloud = white/1, clear = black/0)
            probability: cloud probability map
        """
        blue = bands_dict.get('blue')
        green = bands_dict.get('green')
        red = bands_dict.get('red')
        nir = bands_dict.get('nir')

        # Check the required bands before deriving fallbacks from them
        if any(b is None for b in [blue, green, red, nir]):
            raise ValueError("At least the blue, green, red and nir bands are required")

        # Rough fallbacks when SWIR bands are missing (approximations, not measured values)
        swir1 = bands_dict.get('swir1', nir * 0.5)
        swir2 = bands_dict.get('swir2', swir1 * 0.8)

        print("Band statistics:")
        print(f"  Blue - min: {np.nanmin(blue):.4f}, max: {np.nanmax(blue):.4f}, mean: {np.nanmean(blue):.4f}")
        print(f"  Green - min: {np.nanmin(green):.4f}, max: {np.nanmax(green):.4f}, mean: {np.nanmean(green):.4f}")
        print(f"  Red - min: {np.nanmin(red):.4f}, max: {np.nanmax(red):.4f}, mean: {np.nanmean(red):.4f}")
        print(f"  NIR - min: {np.nanmin(nir):.4f}, max: {np.nanmax(nir):.4f}, mean: {np.nanmean(nir):.4f}")

        # Scale digital numbers to 0-1 by the maximum of the visible and NIR bands
        max_val = max(np.nanmax(blue), np.nanmax(green), np.nanmax(red), np.nanmax(nir))
        if max_val > 1:
            print("Band values exceed 1; normalizing to the 0-1 range...")
            blue = blue / max_val
            green = green / max_val
            red = red / max_val
            nir = nir / max_val
            swir1 = swir1 / max_val
            swir2 = swir2 / max_val

        probability = np.zeros_like(blue, dtype=np.float64)

        # Rule 1 (+0.25): NDVI below 0.8 (not dense vegetation)
        ndvi = self.calculate_ndvi(nir, red)
        ndvi_valid = ~np.isnan(ndvi)
        ndvi_mask = ndvi < 0.8
        probability[ndvi_mask & ndvi_valid] += 0.25

        # Rule 2 (up to +0.25): spectrally flat visible bands score higher
        whiteness = self.calculate_whiteness_index(blue, green, red)
        whiteness_norm = whiteness / (np.nanmax(whiteness) + 1e-10)
        probability += (1 - whiteness_norm) * 0.25

        # Rule 3 (+0.15): SWIR1/SWIR2 ratio below 1.5
        swir_ratio = swir1 / (swir2 + 1e-10)
        swir_ratio_mask = swir_ratio < 1.5
        probability[swir_ratio_mask & np.isfinite(swir_ratio)] += 0.15

        # Rule 4 (+0.15): bright NIR
        nir_mask = nir > self.nir_threshold
        probability[nir_mask] += 0.15

        # Rule 5 (+0.1): NIR-SWIR1 index not above the 85th percentile of its positive values
        dsni = self.calculate_nir_dsni(nir, swir1)
        dsni_valid = np.isfinite(dsni)
        dsni_positive = dsni > 0
        if np.any(dsni_positive & dsni_valid):
            dsni_threshold = np.percentile(dsni[dsni_positive & dsni_valid], 85)
        else:
            dsni_threshold = 0.1
        dsni_mask = dsni > dsni_threshold
        probability[~dsni_mask & dsni_valid] += 0.1

        # Rule 6 (+0.1): blue/red ratio above 1.2
        blue_red_ratio = blue / (red + 1e-10)
        br_mask = blue_red_ratio > 1.2
        probability[br_mask & np.isfinite(blue_red_ratio)] += 0.1

        # The weights sum to 1.0; clipping guards against rounding
        probability = np.clip(probability, 0, 1)

        print(
            f"\nProbability map: min={np.nanmin(probability):.4f}, max={np.nanmax(probability):.4f}, mean={np.nanmean(probability):.4f}")

        high_prob_pixels = np.sum(probability >= self.cloud_prob_threshold)
        total_pixels = probability.size
        print(
            f"High-probability cloud pixels (>={self.cloud_prob_threshold}): {high_prob_pixels}/{total_pixels} ({high_prob_pixels / total_pixels * 100:.2f}%)")

        if high_prob_pixels > 0:
            # Adaptive threshold: median probability of the pixels already above the base threshold
            adaptive_threshold = np.percentile(probability[probability >= self.cloud_prob_threshold], 50)
            adaptive_threshold = max(adaptive_threshold, self.cloud_prob_threshold)
            print(f"Adaptive threshold within the region with probability >= {self.cloud_prob_threshold}: {adaptive_threshold:.4f}")
            cloud_mask = probability >= adaptive_threshold
        else:
            print(f"Warning: no pixels with probability >= {self.cloud_prob_threshold}; no clouds detected")
            cloud_mask = np.zeros_like(probability, dtype=bool)

        # Grow the mask slightly (dilate 2, erode 1), then clean it up
        cloud_mask = binary_dilation(cloud_mask, iterations=2)
        cloud_mask = binary_erosion(cloud_mask, iterations=1)

        cloud_mask = self.process_cloud_masks_morphological(cloud_mask)

        cloud_pixels = np.sum(cloud_mask)
        total_pixels = cloud_mask.size
        print(f"Cloud pixels detected: {cloud_pixels}/{total_pixels} ({cloud_pixels / total_pixels * 100:.2f}%)")

        return cloud_mask.astype(np.uint8), probability

    def process_cloud_masks_morphological(self, cloud_mask, min_cloud_size=500, hole_fill_size=100):
        """
        Morphological post-processing: remove small clouds and fill holes inside clouds

        Args:
            cloud_mask: initial cloud mask
            min_cloud_size: minimum cloud area in pixels; smaller cloud regions are removed
            hole_fill_size: maximum size in pixels of a hole near the image border that may still
                be filled when cloud covers most of its bounding box; holes that do not touch
                the border are always filled

        Returns:
            processed_mask: post-processed cloud mask (enclosed holes filled as cloud)
        """
        structure = generate_binary_structure(2, 2)  # 8-connectivity

        # Remove cloud regions smaller than min_cloud_size (pixel counts per label via bincount)
        labeled_array, num_features = label(cloud_mask, structure=structure)
        cloud_sizes = np.bincount(labeled_array.ravel())
        keep = cloud_sizes >= min_cloud_size
        keep[0] = False  # label 0 is the background
        remove_small_clouds = keep[labeled_array]
        num_kept = int(np.sum(keep))

        # Closing (dilate, then erode) to bridge narrow gaps between cloud pieces
        dilated_mask = binary_dilation(remove_small_clouds, structure=structure, iterations=3)
        eroded_mask = binary_erosion(dilated_mask, structure=structure, iterations=3)

        # Label the clear-sky regions; those enclosed by cloud are holes to fill
        filled_holes = np.zeros_like(eroded_mask, dtype=bool)
        labeled_holes, num_holes = label(~eroded_mask, structure=structure)
        height, width = eroded_mask.shape
        margin = 5

        for i, box in enumerate(find_objects(labeled_holes), start=1):
            rows, cols = box
            y_min, y_max = rows.start, rows.stop - 1
            x_min, x_max = cols.start, cols.stop - 1
            hole_region = labeled_holes[box] == i  # hole pixels within its bounding box

            # A hole within `margin` pixels of the image border counts as an edge hole
            is_edge_hole = (y_min <= margin or y_max >= height - 1 - margin or
                            x_min <= margin or x_max >= width - 1 - margin)

            if not is_edge_hole:
                filled_holes[box] |= hole_region
            elif np.count_nonzero(hole_region) <= hole_fill_size:
                # Small edge hole: fill it only if cloud covers most of its bounding box.
                # The size limit keeps the large clear-sky background from being filled.
                window = eroded_mask[max(0, y_min - 1):y_max + 2, max(0, x_min - 1):x_max + 2]
                if window.mean() > 0.5:
                    filled_holes[box] |= hole_region

        final_mask = eroded_mask | filled_holes
        final_mask = binary_closing(final_mask, structure=structure, iterations=3)

        print("\nMorphological post-processing finished:")
        print(f"  - initial cloud regions: {num_features}")
        print(f"  - cloud regions kept: {num_kept}")
        print(f"  - hole pixels filled as cloud: {np.sum(filled_holes)}")

        return final_mask

    def process_single_image(self, input_path, output_path=None,
                             sensor_type='generic', display_results=True):
        """
        Process a single multispectral image

        Args:
            input_path: path to the input image (multiband GeoTIFF)
            output_path: output path (optional)
            sensor_type: sensor type ('landsat8', 'sentinel2', 'generic')
            display_results: whether to plot the results

        Returns:
            cloud_mask, shadow_mask, cloud_prob: cloud mask, shadow mask (placeholder,
            all zeros; shadow detection is not implemented), cloud probability
        """
        with rasterio.open(input_path) as src:
            if src.count > 1:
                data = src.read()
            else:
                raise ValueError("The input image must contain multiple bands")

            profile = src.profile.copy()

        print(f"Input data shape: {data.shape}")
        print(f"Input data type: {data.dtype}")
        print(f"Bands: {data.shape[0]}, height: {data.shape[1]}, width: {data.shape[2]}")

        bands_dict = self.extract_bands_from_multispectral(data, sensor_type)

        cloud_mask, cloud_prob = self.detect_clouds(bands_dict)

        # Placeholder: cloud-shadow detection is not implemented, so the shadow mask stays empty
        shadow_mask = np.zeros_like(cloud_mask)

        if display_results:
            fig, axes = plt.subplots(1, 2, figsize=(18, 6))

            im0 = axes[0].imshow(cloud_prob, cmap='jet')
            axes[0].set_title('Cloud probability distribution', fontsize=14, color='red')
            axes[0].axis('off')
            plt.colorbar(im0, ax=axes[0], fraction=0.046, pad=0.04)

            axes[1].imshow(cloud_mask, cmap='gray')
            axes[1].set_title('Cloud detection results', fontsize=14, color='blue')
            axes[1].axis('off')

            plt.tight_layout()
            # Save the figure next to the input, e.g. input/test.tif -> input/test_fmask_result.png
            figure_path = os.path.splitext(input_path)[0] + '_fmask_result.png'
            plt.savefig(figure_path, dpi=300, bbox_inches='tight')
            plt.show()

        if output_path:
            # Up to six original bands (each scaled to 0-1 by its own maximum),
            # followed by the cloud mask and the probability map scaled to 0-255
            num_original_bands = min(data.shape[0], 6)

            output_bands = []

            for i in range(num_original_bands):
                band_data = data[i].astype(np.float32)
                band_max = np.nanmax(band_data)
                if band_max > 1:
                    band_data = band_data / band_max
                output_bands.append(band_data)

            output_bands.append(cloud_mask.astype(np.float32))
            output_bands.append((cloud_prob * 255).astype(np.float32))

            profile.update({
                'driver': 'GTiff',
                'compress': 'lzw',
                'count': len(output_bands),
                'dtype': 'float32'
            })

            # Create the output directory if it does not exist yet
            output_dir = os.path.dirname(output_path)
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)

            with rasterio.open(output_path, 'w', **profile) as dst:
                for i, band in enumerate(output_bands):
                    dst.write(band, i + 1)

            print(f"\nOutput saved to: {output_path}")
            print(f"The output file contains {len(output_bands)} bands:")
            print(f"  - bands 1-{num_original_bands}: original reflectance bands (normalized)")
            print(f"  - band {num_original_bands + 1}: cloud mask (1 = cloud, shown as white)")
            print(f"  - band {num_original_bands + 2}: cloud probability (0-255)")

        print(f"\nCloud pixel fraction: {np.mean(cloud_mask) * 100:.2f}%")

        return cloud_mask, shadow_mask, cloud_prob


def fmask_single_image(input_path, output_path=None, cloud_prob_threshold=0.8):
    """
    Convenience function: run Fmask-style cloud detection on a single image

    Args:
        input_path: input image path
        output_path: output path (optional)
        cloud_prob_threshold: cloud probability threshold; pixels below it are never labeled as cloud (default 0.8)

    Returns:
        cloud_mask, cloud_prob: cloud mask and cloud probability map
    """
    detector = FmaskCloudDetector(cloud_prob_threshold=cloud_prob_threshold)
    cloud_mask, shadow_mask, cloud_prob = detector.process_single_image(
        input_path,
        output_path,
        sensor_type='generic',
        display_results=True
    )
    return cloud_mask, cloud_prob


if __name__ == "__main__":
    input_image = "input/test.tif"

    try:
        cloud_mask, cloud_prob = fmask_single_image(
            input_image,
            output_path="output/fmask_result.tif",
            cloud_prob_threshold=0.8
        )

        print("\n✓ Fmask cloud detection finished!")

    except Exception as e:
        print(f"✗ Error: {e}")
        print("\nPlease make sure the input file path is correct")
