273 lines
13 KiB
Python
273 lines
13 KiB
Python
import cv2
|
||
import rawpy
|
||
from PIL import Image
|
||
from tqdm import tqdm
|
||
import tkinter as tk
|
||
from tkinter import filedialog
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
import argparse
|
||
import numpy as np
|
||
import os
|
||
import torch
|
||
from depth_anything_v2.dpt import DepthAnythingV2
|
||
from utils import specify_name_group_blade,get_name
|
||
parser = argparse.ArgumentParser(description='Depth Anything V2')
|
||
parser.add_argument('--input-size', type=int, default=518)
|
||
parser.add_argument('--encoder', type=str, default='vitl', choices=['vits', 'vitb', 'vitl', 'vitg'])
|
||
args = parser.parse_args()
|
||
DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
|
||
model_configs = {
|
||
'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
|
||
'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
|
||
'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
|
||
'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
|
||
}
|
||
|
||
depth_anything = DepthAnythingV2(**model_configs[args.encoder])
|
||
depth_anything.load_state_dict(
|
||
torch.load(f'checkpoints/depth_anything_v2_{args.encoder}.pth', map_location='cpu'))
|
||
depth_anything = depth_anything.to(DEVICE).eval()
|
||
print("model loaded")
|
||
def extraction_win_lamina_mask(raw_image):
|
||
# 前面读的都是RGB图像,深度估计需要BGR,即下面是把RGB→BGR
|
||
raw_image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB)
|
||
|
||
depth = depth_anything.infer_image(raw_image, args.input_size)
|
||
|
||
depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
|
||
depth = depth.astype(np.uint8)
|
||
|
||
_, otsu_mask = cv2.threshold(depth, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
|
||
return otsu_mask
|
||
|
||
|
||
def check_overexposure(image, lamina_mask,threshold):
|
||
"""
|
||
检查图像中是否有过曝区域
|
||
:param image: 输入图像
|
||
:param threshold: 亮度值的阈值,像素值超过该值则认为是过曝区域
|
||
:return: 一个二值图像,过曝区域为白色,其它区域为黑色
|
||
"""
|
||
# 转换为浮动数据类型并归一化
|
||
image_float = image.astype(np.float32)
|
||
|
||
# 获取每个像素的亮度值(可以使用 YUV 或 RGB 亮度)
|
||
gray_image = cv2.cvtColor(image_float, cv2.COLOR_BGR2GRAY)
|
||
|
||
# 标记过曝区域,亮度大于阈值的像素为过曝区域
|
||
overexposure_mask = gray_image > threshold
|
||
|
||
overexposure_mask = (lamina_mask == 255) & overexposure_mask #只取叶片的亮光区mask
|
||
|
||
overexposure_number = np.sum(overexposure_mask * 1)
|
||
overexposed_pixels = gray_image[overexposure_mask]
|
||
# 计算曝光区域的平均亮度
|
||
if overexposed_pixels.size > 0:
|
||
avg_overexposed_value = np.mean(overexposed_pixels)
|
||
else:
|
||
avg_overexposed_value = 0 # 如果没有曝光区域,返回 0
|
||
|
||
return overexposure_mask, avg_overexposed_value,overexposure_number
|
||
|
||
def check_shawn(image,lamina_mask, threshold):
|
||
"""
|
||
检查图像中是否有阴影区域
|
||
:param image: 输入图像
|
||
:param threshold: 亮度值的阈值,像素值小于该值则认为是过阴影域
|
||
:return: 一个二值图像,过曝区域为白色,其它区域为黑色
|
||
"""
|
||
|
||
# 转换为浮动数据类型并归一化
|
||
image_float = image.astype(np.float32)
|
||
# 获取每个像素的亮度值(可以使用 YUV 或 RGB 亮度)
|
||
gray_image = cv2.cvtColor(image_float, cv2.COLOR_BGR2GRAY)
|
||
# 标记阴影区域,亮度小于阈值的像素为阴影区域
|
||
shawn_mask = gray_image < threshold
|
||
|
||
shawn_mask = (lamina_mask == 255) & shawn_mask
|
||
|
||
shawn_number = np.sum(shawn_mask * 1)
|
||
shawn_pixels = gray_image[shawn_mask]
|
||
# 计算阴影区域的平均亮度
|
||
if shawn_pixels.size > 0:
|
||
avg_shawn_value = np.mean(shawn_pixels)
|
||
else:
|
||
avg_shawn_value = 0
|
||
|
||
return shawn_mask, avg_shawn_value,shawn_number
|
||
|
||
def smooth_overexposed_regions(image, overexposure_mask, kernel_size=(15, 15)):
|
||
"""
|
||
对过曝区域进行平滑处理,修复与周围区域的过渡
|
||
:param image: 输入图像
|
||
:param overexposure_mask: 过曝区域的掩码
|
||
:param kernel_size: 高斯核大小
|
||
:return: 平滑过曝区域后的图像
|
||
"""
|
||
# 使用高斯模糊平滑过曝区域和周围区域
|
||
smoothed_image = cv2.GaussianBlur(image, kernel_size, 0)
|
||
|
||
# 将平滑后的图像和原始图像合成,修复过曝区域
|
||
fixed_image = np.where(overexposure_mask[..., None] == 255, image,smoothed_image)
|
||
|
||
return fixed_image
|
||
|
||
def bilateral_filter_adjustment(image, high_light_mask, d=15, sigma_color=75, sigma_space=75):
|
||
"""
|
||
使用双边滤波器平滑高光区域与周围区域的过渡
|
||
:param image: 输入图像
|
||
:param high_light_mask: 高光区域的掩码
|
||
:param d: 邻域的直径
|
||
:param sigma_color: 颜色空间的标准差
|
||
:param sigma_space: 坐标空间的标准差
|
||
:return: 平滑后的图像
|
||
"""
|
||
# 对整个图像应用双边滤波
|
||
filtered_image = cv2.bilateralFilter(image, d, sigma_color, sigma_space)
|
||
|
||
# 合成平滑后的图像和原图,保留非高光部分
|
||
final_image = np.where(high_light_mask[..., None] == 0, image, filtered_image)
|
||
|
||
return final_image
|
||
|
||
def adjust_highlights_shadows(image, threshold_shawn,threshold):
|
||
|
||
lamina_mask = extraction_win_lamina_mask(image)
|
||
|
||
shawn_mask,avg_shawn_value,shawn_number = check_shawn(image,lamina_mask,threshold_shawn)
|
||
# 调整亮度,gamma < 1 时变亮,gamma > 1 时变暗
|
||
gamma = 1 - (threshold_shawn-avg_shawn_value)/255.0
|
||
# print('阴影区调整的gama: '+str(gamma))#+str('\n'))
|
||
lookup_table = np.array([((i / 255.0) ** gamma) * 255 for i in range(256)]).astype('uint8')
|
||
# 应用 gamma 校正
|
||
if shawn_number !=0:
|
||
image[shawn_mask == True] = cv2.LUT(image[shawn_mask == True], lookup_table)
|
||
gamma_corrected_image = image
|
||
else:
|
||
gamma_corrected_image = image
|
||
#gamma_corrected_image = cv2.LUT(image, lookup_table)
|
||
|
||
#寻找过爆区域
|
||
overexposure_mask,avg_overexposed_value,overexposure_number = check_overexposure(gamma_corrected_image,lamina_mask,threshold)
|
||
reduction_factor = 1-(avg_overexposed_value-threshold) / 255
|
||
#reduction_factor = (avg_overexposed_value/255)*scale
|
||
# print("降低曝光区比例:" + str(reduction_factor))
|
||
|
||
# 调整亮度,gamma < 1 时变亮(越小越亮),gamma > 1 时变暗(越大越暗),
|
||
print((1+reduction_factor))
|
||
lookup_table = np.array([((i / 255.0) ** (1+reduction_factor)) * 255 for i in range(256)]).astype('uint8')
|
||
# 应用 gamma 校正
|
||
if overexposure_number !=0:
|
||
gamma_corrected_image[overexposure_mask == True] = cv2.LUT(gamma_corrected_image[overexposure_mask == True], lookup_table)
|
||
|
||
|
||
#gamma_corrected_image[overexposure_mask == True] = np.clip(gamma_corrected_image[overexposure_mask == True] * reduction_factor, 0, 255)
|
||
|
||
|
||
#smoothed_image = smooth_overexposed_regions(gamma_corrected_image, overexposure_mask)
|
||
#smoothed_image = bilateral_filter_adjustment(gamma_corrected_image, overexposure_mask, d=15, sigma_color=75, sigma_space=75)
|
||
|
||
return gamma_corrected_image
|
||
|
||
|
||
def process_single_image(input_path, output_path, gamma, threshold, no_auto_bright,config=None):
|
||
|
||
if input_path.lower().endswith(('.arw', '.dng')):
|
||
with rawpy.imread(input_path) as raw:
|
||
# 获取 RGB 图像数据
|
||
rgb_image = raw.postprocess(use_camera_wb=True, no_auto_bright=no_auto_bright, output_bps=16)
|
||
rgb_image = np.uint8(rgb_image / 256)
|
||
|
||
if rgb_image is not None:
|
||
adjusted_image = adjust_highlights_shadows(rgb_image, gamma, threshold)
|
||
# 使用 Pillow 创建图像
|
||
image = Image.fromarray(adjusted_image)
|
||
# print(exif)
|
||
image.save(output_path[:-4]+".JPG" , format='JPEG', quality=100,)#'JPEG',exif=exif
|
||
else:
|
||
print(f"Failed to read image: {input_path}")
|
||
else:
|
||
exif = Image.open(input_path).getexif()
|
||
image_pil = Image.open(input_path)
|
||
image = np.array(image_pil)
|
||
if image is not None:
|
||
adjusted_image = adjust_highlights_shadows(image, gamma, threshold)
|
||
images = Image.fromarray(adjusted_image)
|
||
images.save(output_path, 'JPEG', quality=100,exif=exif)
|
||
else:
|
||
print(f"Failed to read image: {input_path}")
|
||
|
||
def process_images(input_dir, output_dir,gamma,threshold,no_auto_bright,config=None):
|
||
all_files = []
|
||
# with tqdm(total=len(all_files), desc="Processing images", unit="file") as pbar:
|
||
for root, dirs, files in os.walk(input_dir):
|
||
for file in files:
|
||
if file.lower().endswith(('.png', '.jpg', '.jpeg','.arw','.dng')):
|
||
# print(file)
|
||
input_path = os.path.join(root, file)
|
||
output_subdir = os.path.join(output_dir, os.path.relpath(root, input_dir))
|
||
|
||
os.makedirs(output_subdir, exist_ok=True)
|
||
if len(dirs) == 0 or dirs[0] == 'CaptureOne':
|
||
# 图片目录 file为叶片图片
|
||
forward = '-'.join(os.path.normpath(root).split('/')[-2:])
|
||
else:
|
||
# 机组目录 file为机组图片
|
||
forward = os.path.normpath(root).split('/')[-1] + '-xxx'
|
||
output_path = os.path.join(output_subdir, get_name(config,forward,file))
|
||
all_files.append((input_path, output_path))
|
||
# pbar.update(1)
|
||
with tqdm(total=len(all_files), desc="Processing images", unit="file") as pbar:
|
||
with ThreadPoolExecutor() as executor:
|
||
futures = []
|
||
for input_path, output_path in all_files:
|
||
future = executor.submit(process_single_image, input_path, output_path, gamma, threshold,no_auto_bright,config)
|
||
future.add_done_callback(lambda _: pbar.update(1))
|
||
futures.append(future)
|
||
for future in futures:
|
||
future.result()
|
||
|
||
# if file.lower().endswith(('.arw','.dng')):
|
||
# with rawpy.imread(input_path) as raw:
|
||
# # 获取 RGB 图像数据
|
||
# rgb_image = raw.postprocess(use_camera_wb=True, no_auto_bright=no_auto_bright, output_bps=16)
|
||
# rgb_image = np.uint8(rgb_image / 256)
|
||
#
|
||
# if rgb_image is not None:
|
||
# adjusted_image = adjust_highlights_shadows(rgb_image,gamma,threshold)
|
||
# # 使用 Pillow 创建图像
|
||
# image = Image.fromarray(adjusted_image)
|
||
# image.save(os.path.splitext(output_path)[0]+'.JPG', 'JPEG', quality=100)
|
||
# else:
|
||
# print(f"Failed to read image: {input_path}")
|
||
# else:
|
||
# #image = cv2.imread(input_path)
|
||
# image_pil = Image.open(input_path)
|
||
# image = np.array(image_pil)
|
||
# if image is not None:
|
||
# adjusted_image = adjust_highlights_shadows(image, gamma,threshold)
|
||
# images = Image.fromarray(adjusted_image)
|
||
# images.save(output_path, 'JPEG', quality=100)
|
||
# else:
|
||
# print(f"Failed to read image: {input_path}")
|
||
|
||
if __name__ == "__main__":
|
||
# input_dir = input("请输入处理图片路径: ")
|
||
# output_dir = input("请输入保存图片路径: ")
|
||
root = tk.Tk()
|
||
root.withdraw() # 不显示主窗口
|
||
# 打开文件选择对话框
|
||
input_dir = filedialog.askdirectory(title="请选择需要处理图片的文件夹,将会处理此文件夹中所有图片")
|
||
#output_dir = filedialog.askdirectory(title="请选择处理后保存的文件夹,处理后与处理前文件命名与路径相同")
|
||
output_dir = input_dir+"+调整"
|
||
|
||
# if not os.path.exists(output_dir):
|
||
# os.makedirs(output_dir,exist_ok=True)
|
||
|
||
# gamma = float(input("请输入筛选阴影区阈值(0~255),建议120~200,优先160,若阴影亮度调高效果不明显可以调大值: "))
|
||
# threshold = float(input("请输入筛选爆光区阈值(0~255),建议240~254,优先245,若曝光区亮度降低不明显或区域较大可以调小: "))
|
||
gamma = 180
|
||
threshold = 253
|
||
no_auto_bright = False
|
||
process_images(input_dir, output_dir,gamma,threshold,no_auto_bright)
|
||
print("处理完成,输出路径:", output_dir) |