# Temp/Depth-Anything-V2-main/深度估计_叶片区域_高光阴影_多线程.py
# ("depth estimation _ leaf region _ highlight & shadow _ multithreaded")
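"""
Leaf-region highlight/shadow correction driven by Depth Anything V2 depth maps,
processed with a thread pool.

For every .png/.jpg/.jpeg/.arw/.dng file under a user-selected folder the script:
  1. estimates a relative depth map with Depth Anything V2 and Otsu-thresholds it
     into a leaf (lamina) mask;
  2. brightens shadow pixels and darkens overexposed pixels inside that mask with
     gamma lookup tables;
  3. saves the result as a maximum-quality JPEG in a mirrored folder whose name is
     the input folder plus "+调整".

Expects the checkpoint at checkpoints/depth_anything_v2_<encoder>.pth.
"""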
import torch
import cv2
import rawpy
from PIL import Image
from tqdm import tqdm
import tkinter as tk
from tkinter import filedialog
from concurrent.futures import ThreadPoolExecutor
import argparse
import numpy as np
import os
from depth_anything_v2.dpt import DepthAnythingV2


def extraction_win_lamina_mask(raw_image):
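    """
    Estimate a leaf (lamina) mask for an RGB image.

    Runs Depth Anything V2 on the image, normalises the relative depth map to
    0-255 and applies Otsu thresholding. The returned uint8 mask is 255 on the
    (assumed nearer) leaf region and 0 elsewhere. Note that the checkpoint is
    reloaded on every call; see the caching sketch below this function.
    """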
    # Upstream readers produce RGB images; Depth Anything V2's infer_image expects
    # BGR (cv2-style) input, so swap the channel order here.
    raw_image = cv2.cvtColor(raw_image, cv2.COLOR_RGB2BGR)
    parser = argparse.ArgumentParser(description='Depth Anything V2')
    parser.add_argument('--input-size', type=int, default=518)
    parser.add_argument('--encoder', type=str, default='vitl', choices=['vits', 'vitb', 'vitl', 'vitg'])
    args = parser.parse_args()
    DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
    model_configs = {
        'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
        'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
        'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
        'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
    }
    depth_anything = DepthAnythingV2(**model_configs[args.encoder])
    depth_anything.load_state_dict(
        torch.load(f'checkpoints/depth_anything_v2_{args.encoder}.pth', map_location='cpu'))
    depth_anything = depth_anything.to(DEVICE).eval()
    depth = depth_anything.infer_image(raw_image, args.input_size)
    # Normalise the relative depth map to 0-255 and binarise it with Otsu's threshold;
    # the leaf (foreground) is assumed to be the nearer region.
    depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
    depth = depth.astype(np.uint8)
    _, otsu_mask = cv2.threshold(depth, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
    return otsu_mask
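

# Note: extraction_win_lamina_mask() re-parses the command line and reloads the
# Depth Anything V2 checkpoint for every single image, which dominates the runtime
# when many files are processed from the thread pool. A minimal sketch of a
# module-level cache is shown below for reference; the helper name _get_depth_model
# is illustrative only and is not called anywhere in this file.
_MODEL_CONFIGS = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]},
}
_MODEL_CACHE = {}


def _get_depth_model(encoder='vitl', device='cpu'):
    """Illustrative sketch: load each (encoder, device) model once and reuse it."""
    key = (encoder, device)
    if key not in _MODEL_CACHE:
        model = DepthAnythingV2(**_MODEL_CONFIGS[encoder])
        model.load_state_dict(
            torch.load(f'checkpoints/depth_anything_v2_{encoder}.pth', map_location='cpu'))
        _MODEL_CACHE[key] = model.to(device).eval()
    return _MODEL_CACHE[key]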


def check_overexposure(image, lamina_mask, threshold):
    """
    Find overexposed pixels inside the leaf region.
    :param image: input RGB image
    :param lamina_mask: leaf mask from extraction_win_lamina_mask (255 = leaf)
    :param threshold: brightness threshold; pixels brighter than this count as overexposed
    :return: (boolean overexposure mask, mean brightness of the overexposed pixels, number of overexposed pixels)
    """
    # Work in float and take per-pixel luminance (the input is RGB).
    image_float = image.astype(np.float32)
    gray_image = cv2.cvtColor(image_float, cv2.COLOR_RGB2GRAY)
    # Mark pixels brighter than the threshold, restricted to the leaf mask.
    overexposure_mask = gray_image > threshold
    overexposure_mask = (lamina_mask == 255) & overexposure_mask
    overexposure_number = int(np.count_nonzero(overexposure_mask))
    overexposed_pixels = gray_image[overexposure_mask]
    # Mean brightness of the overexposed region (0 if there is none).
    if overexposed_pixels.size > 0:
        avg_overexposed_value = np.mean(overexposed_pixels)
    else:
        avg_overexposed_value = 0
    return overexposure_mask, avg_overexposed_value, overexposure_number


def check_shawn(image, lamina_mask, threshold):
    """
    Find shadow ("shawn" in the original naming) pixels inside the leaf region.
    :param image: input RGB image
    :param lamina_mask: leaf mask from extraction_win_lamina_mask (255 = leaf)
    :param threshold: brightness threshold; pixels darker than this count as shadow
    :return: (boolean shadow mask, mean brightness of the shadow pixels, number of shadow pixels)
    """
    # Work in float and take per-pixel luminance (the input is RGB).
    image_float = image.astype(np.float32)
    gray_image = cv2.cvtColor(image_float, cv2.COLOR_RGB2GRAY)
    # Mark pixels darker than the threshold, restricted to the leaf mask.
    shawn_mask = gray_image < threshold
    shawn_mask = (lamina_mask == 255) & shawn_mask
    shawn_number = int(np.count_nonzero(shawn_mask))
    shawn_pixels = gray_image[shawn_mask]
    # Mean brightness of the shadow region (0 if there is none).
    if shawn_pixels.size > 0:
        avg_shawn_value = np.mean(shawn_pixels)
    else:
        avg_shawn_value = 0
    return shawn_mask, avg_shawn_value, shawn_number


def smooth_overexposed_regions(image, overexposure_mask, kernel_size=(15, 15)):
    """
    Smooth the overexposed regions so they blend with their surroundings.
    :param image: input image
    :param overexposure_mask: boolean mask of the overexposed region
    :param kernel_size: Gaussian kernel size
    :return: image with the overexposed region smoothed
    """
    # Blur the whole image, then keep the blurred pixels only inside the mask.
    smoothed_image = cv2.GaussianBlur(image, kernel_size, 0)
    fixed_image = np.where(overexposure_mask[..., None], smoothed_image, image)
    return fixed_image


def bilateral_filter_adjustment(image, high_light_mask, d=15, sigma_color=75, sigma_space=75):
    """
    Use a bilateral filter to smooth the transition between highlights and their surroundings.
    :param image: input image
    :param high_light_mask: mask of the highlight region
    :param d: diameter of the pixel neighbourhood
    :param sigma_color: filter sigma in the colour space
    :param sigma_space: filter sigma in the coordinate space
    :return: smoothed image
    """
    # Filter the whole image, then keep the original pixels outside the highlight mask.
    filtered_image = cv2.bilateralFilter(image, d, sigma_color, sigma_space)
    final_image = np.where(high_light_mask[..., None] == 0, image, filtered_image)
    return final_image
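
# Note: smooth_overexposed_regions and bilateral_filter_adjustment are currently
# unused; the calls to them in adjust_highlights_shadows below are commented out.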


def adjust_highlights_shadows(image, threshold_shawn, threshold):
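    """
    Brighten shadow pixels and darken overexposed pixels inside the leaf region.

    Shadow pixels (luminance < threshold_shawn) are brightened with a gamma LUT,
    gamma = 1 - (threshold_shawn - avg_shadow) / 255, so darker shadows get a
    smaller gamma and a stronger lift. Worked example: with threshold_shawn = 180
    and an average shadow luminance of 120, gamma = 1 - 60/255 ≈ 0.765, and a
    pixel of value 120 maps to 255 * (120/255) ** 0.765 ≈ 143. Overexposed pixels
    (luminance > threshold) are then darkened with a second LUT whose exponent is
    1 + (1 - (avg_overexposed - threshold) / 255), i.e. slightly below 2 when the
    average exceeds the threshold.

    :param image: input RGB image (modified in place during the shadow step)
    :param threshold_shawn: shadow-selection threshold (0-255)
    :param threshold: overexposure-selection threshold (0-255)
    :return: the adjusted image
    """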
    lamina_mask = extraction_win_lamina_mask(image)
    shawn_mask, avg_shawn_value, shawn_number = check_shawn(image, lamina_mask, threshold_shawn)
    # Gamma correction: gamma < 1 brightens, gamma > 1 darkens.
    gamma = 1 - (threshold_shawn - avg_shawn_value) / 255.0
    # print('shadow-region gamma: ' + str(gamma))
    lookup_table = np.array([((i / 255.0) ** gamma) * 255 for i in range(256)]).astype('uint8')
    # Apply the gamma LUT to the shadow pixels only.
    if shawn_number != 0:
        image[shawn_mask] = cv2.LUT(image[shawn_mask], lookup_table)
        gamma_corrected_image = image
    else:
        gamma_corrected_image = image
    # gamma_corrected_image = cv2.LUT(image, lookup_table)
    # Locate the overexposed region on the shadow-corrected image.
    overexposure_mask, avg_overexposed_value, overexposure_number = check_overexposure(gamma_corrected_image, lamina_mask, threshold)
    reduction_factor = 1 - (avg_overexposed_value - threshold) / 255
    # reduction_factor = (avg_overexposed_value / 255) * scale
    # print("highlight reduction factor: " + str(reduction_factor))
    # Gamma correction: exponent < 1 brightens (smaller = brighter), > 1 darkens (larger = darker).
    # print("highlight gamma: " + str(1 + reduction_factor))
    lookup_table = np.array([((i / 255.0) ** (1 + reduction_factor)) * 255 for i in range(256)]).astype('uint8')
    # Apply the gamma LUT to the overexposed pixels only.
    if overexposure_number != 0:
        gamma_corrected_image[overexposure_mask] = cv2.LUT(gamma_corrected_image[overexposure_mask], lookup_table)
        # gamma_corrected_image[overexposure_mask] = np.clip(gamma_corrected_image[overexposure_mask] * reduction_factor, 0, 255)
    # smoothed_image = smooth_overexposed_regions(gamma_corrected_image, overexposure_mask)
    # smoothed_image = bilateral_filter_adjustment(gamma_corrected_image, overexposure_mask, d=15, sigma_color=75, sigma_space=75)
    return gamma_corrected_image


def process_single_image(input_path, output_path, gamma, threshold, no_auto_bright):
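    """
    Read one image (RAW .arw/.dng via rawpy, anything else via Pillow), correct the
    leaf shadows/highlights with adjust_highlights_shadows, and save the result as a
    maximum-quality JPEG at the mirrored output path. Note that `gamma` here is the
    shadow-selection threshold passed on as threshold_shawn, not a gamma exponent.
    """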
    if input_path.lower().endswith(('.arw', '.dng')):
        with rawpy.imread(input_path) as raw:
            # Demosaic to 16-bit RGB with the camera white balance, then scale down to 8-bit.
            rgb_image = raw.postprocess(use_camera_wb=True, no_auto_bright=no_auto_bright, output_bps=16)
            rgb_image = np.uint8(rgb_image / 256)
        if rgb_image is not None:
            adjusted_image = adjust_highlights_shadows(rgb_image, gamma, threshold)
            # Save with Pillow; RAW inputs always get a .JPG extension.
            image = Image.fromarray(adjusted_image)
            image.save(os.path.splitext(output_path)[0] + '.JPG', 'JPEG', quality=100)
        else:
            print(f"Failed to read image: {input_path}")
    else:
        # Force 3-channel RGB so palette/RGBA PNGs do not break the grayscale conversion or the JPEG save.
        image_pil = Image.open(input_path).convert('RGB')
        image = np.array(image_pil)
        if image is not None:
            adjusted_image = adjust_highlights_shadows(image, gamma, threshold)
            images = Image.fromarray(adjusted_image)
            # Note: the file is always JPEG-encoded, even when output_path keeps a .png extension.
            images.save(output_path, 'JPEG', quality=100)
        else:
            print(f"Failed to read image: {input_path}")


def process_images(input_dir, output_dir, gamma, threshold, no_auto_bright):
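    """
    Recursively collect every .png/.jpg/.jpeg/.arw/.dng file under input_dir and
    process the files in a ThreadPoolExecutor, mirroring the directory structure
    under output_dir. A tqdm progress bar is advanced by a done-callback on each future.
    """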
    all_files = []
    for root, dirs, files in os.walk(input_dir):
        for file in files:
            if file.lower().endswith(('.png', '.jpg', '.jpeg', '.arw', '.dng')):
                input_path = os.path.join(root, file)
                output_subdir = os.path.join(output_dir, os.path.relpath(root, input_dir))
                os.makedirs(output_subdir, exist_ok=True)
                output_path = os.path.join(output_subdir, file)
                all_files.append((input_path, output_path))
    with tqdm(total=len(all_files), desc="Processing images", unit="file") as pbar:
        with ThreadPoolExecutor() as executor:
            futures = []
            for input_path, output_path in all_files:
                future = executor.submit(process_single_image, input_path, output_path, gamma, threshold, no_auto_bright)
                future.add_done_callback(lambda _: pbar.update(1))
                futures.append(future)
            # Propagate any exception raised inside the worker threads.
            for future in futures:
                future.result()


if __name__ == "__main__":
    # input_dir = input("Enter the folder of images to process: ")
    # output_dir = input("Enter the folder to save the processed images: ")
    root = tk.Tk()
    root.withdraw()  # hide the main Tk window
    # Folder-selection dialog ("select the folder of images to process; every image inside it will be processed").
    input_dir = filedialog.askdirectory(title="请选择需要处理图片的文件夹,将会处理此文件夹中所有图片")
    # output_dir = filedialog.askdirectory(title="请选择处理后保存的文件夹,处理后与处理前文件命名与路径相同")
    # Results go to a sibling folder: the input folder name with "+调整" ("adjusted") appended.
    output_dir = input_dir + "+调整"
    # if not os.path.exists(output_dir):
    #     os.makedirs(output_dir, exist_ok=True)
    # gamma = float(input("Shadow threshold (0-255), 120-200 recommended, 160 preferred; increase it if brightening the shadows is not noticeable: "))
    # threshold = float(input("Overexposure threshold (0-255), 240-254 recommended, 245 preferred; decrease it if the highlights barely darken or the region is too large: "))
    gamma = 180      # shadow-selection threshold, passed on as threshold_shawn
    threshold = 253  # overexposure-selection threshold
    no_auto_bright = False
    process_images(input_dir, output_dir, gamma, threshold, no_auto_bright)
    print("处理完成,输出路径:", output_dir)  # "Processing finished, output path:"