Integrate image preprocessing, add a preprocessing flowchart, and refine the program's user-interaction logic before processing starts.

Voge1imkafig 2025-08-06 17:56:29 +08:00
parent 17f82c78ee
commit 832c91454e
23 changed files with 2061 additions and 135 deletions

.gitignore (vendored, 3 changed lines)

@@ -201,3 +201,6 @@ __marimo__/
 # Streamlit
 .streamlit/secrets.toml
+测试数据/
+model/

.vscode/launch.json (vendored, new file, 12 lines)

@ -0,0 +1,12 @@
{
"version": "0.2.0",
"configurations": [
{
"name": "Run Main",
"type": "debugpy",
"request": "launch",
"program": "${workspaceFolder}/main.py", //
"console": "integratedTerminal"
}
]
}

MainWindow/mainwindow.py (new file, 127 lines)

@ -0,0 +1,127 @@
from PySide6.QtWidgets import (QMainWindow, QWidget, QGridLayout,
QPushButton, QSizePolicy, QSplitter, QToolBar)
from PySide6.QtGui import QFontDatabase
from PySide6.QtCore import Signal, Qt
import os
from info_core.defines import *
from info_core.MyQtClass import ConfigComboBoxGroup, FolderDropWidget
class ReportGeneratorUI(QMainWindow):
send_baogao_choose_info = Signal(list[str])
def __init__(self):
super().__init__()
# 加载字体
self.load_font()
# 设置窗口属性
self.setWindowTitle("报告生成器")
self.setMinimumSize(WINDOW_MIN_WIDTH, WINDOW_MIN_HEIGHT)
# 主窗口部件
self.central_widget = QWidget()
self.setCentralWidget(self.central_widget)
# 主布局
self.main_layout = QGridLayout(self.central_widget)
self.main_layout.setSpacing(MAIN_LAYOUT_SPACING)
self.main_layout.setContentsMargins(*MAIN_LAYOUT_MARGINS)
# 初始化UI
self.init_ui()
def load_font(self):
"""加载自定义字体"""
if os.path.exists(FONT_PATH):
font_id = QFontDatabase.addApplicationFont(FONT_PATH)
if font_id == -1:
print("字体加载失败,将使用系统默认字体")
else:
print(f"字体文件未找到: {FONT_PATH},将使用系统默认字体")
def init_ui(self):
"""初始化所有UI组件"""
# 第一行:项目信息和人员配置
self.project_group = ConfigComboBoxGroup("项目基本信息")
self.staff_group = ConfigComboBoxGroup("单次检查配置信息", is_project=False)
# 第二行:导入图片路径、填写机组信息
self.picture_group = FolderDropWidget()
# self.image_analysis =
# self.main_layout.addWidget(self.image_analysis, 1, 1)
# 第三行:生成报告按钮(跨两列)
self.fill_turbine_info_button()
self.fill_btn.clicked.connect(self.on_fill_clicked)
# 创建一个垂直分割器
self.splitter = QSplitter(Qt.Vertical)
self.splitter.setStyleSheet(SPLITTER_STYLE)
# 创建顶部和底部容器
top_container = QWidget()
top_container.setLayout(QGridLayout())
top_container.layout().addWidget(self.project_group, 0, 0)
top_container.layout().addWidget(self.staff_group, 0, 1)
middle_container = QWidget()
middle_container.setLayout(QGridLayout())
middle_container.layout().addWidget(self.picture_group, 0, 0)
# 添加部件到分割器
self.splitter.addWidget(top_container)
self.splitter.addWidget(middle_container)
# 设置主布局
self.main_layout.addWidget(self.splitter, 0, 0, 2, 2) # 占据前两行两列
self.main_layout.addWidget(self.fill_btn, 2, 0, 1, 2)
# 设置分割器初始比例
self.splitter.setStretchFactor(0, 1)
self.splitter.setStretchFactor(1, 4)
self.toolbar = QToolBar()
self.addToolBar(self.toolbar)
self.toolbar.setMovable(False)
self.toolbar.setFloatable(False)
new_action = self.toolbar.addAction("重置布局比例")
self.toolbar.addSeparator()
new_action.triggered.connect(self.reset_splitter)
def reset_splitter(self):
"""重置分割器的比例"""
total_size = sum(self.splitter.sizes()) # 获取当前总大小
self.splitter.setSizes([
int(total_size * 0.2), # 第一部分占 20%(比例 1:4
int(total_size * 0.8) # 第二部分占 80%
])
def on_fill_clicked(self):
"""填写信息"""
# 读取各个配置信息
turbine_file_list = self.picture_group.get_selected_folders()
print(turbine_file_list)
# search_file_list = []
# if self.image_analysis.check_is_waibu:
# search_file_list.append("外汇总")
# if self.image_analysis.check_is_neibu:
# search_file_list.append("内汇总")
# if self.image_analysis.check_is_fanglei:
# search_file_list.append("防汇总")
# self.send_baogao_choose_info.emit(search_file_list)
def create_button(self, text):
"""创建统一风格的按钮"""
btn = QPushButton(text)
btn.setStyleSheet(BUTTON_STYLE)
btn.setFixedSize(BUTTON_WIDTH, BUTTON_HEIGHT)
return btn
def fill_turbine_info_button(self):
"""创建生成报告按钮"""
self.fill_btn = QPushButton("开始填写各个机组信息")
self.fill_btn.setStyleSheet(PRIMARY_BUTTON_STYLE)
self.fill_btn.setFixedHeight(50)
self.fill_btn.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)

README.md

@@ -3,5 +3,7 @@
 - Data preprocessing: brighten dark areas and enhance detail.
 - Data report generation: generate reports in batches from templates.
 # Project architecture diagram
 ![Project architecture](工具流程.png)
+# Preprocessing flowchart
+![Preprocessing flow](原始数据预处理流程图.png)
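The dark-area brightening described above is implemented later in this commit in tool/lighter.py, which gamma-corrects shadow pixels inside a depth-derived blade mask. Below is a minimal standalone sketch of just the gamma step, with a hypothetical threshold and no depth mask, so it is an illustration of the idea rather than the project's exact pipeline.

import cv2
import numpy as np

def brighten_shadows(image_bgr, shadow_threshold=60):
    # Hypothetical illustration: pixels darker than the threshold are lifted
    # with a gamma < 1 curve, similar in spirit to adjust_highlights_shadows()
    # in tool/lighter.py (which additionally restricts the change to a
    # depth-based blade mask).
    gray = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2GRAY)
    shadow_mask = gray < shadow_threshold
    if not shadow_mask.any():
        return image_bgr
    avg_dark = float(gray[shadow_mask].mean())
    gamma = 1 - (shadow_threshold - avg_dark) / 255.0   # same form as in lighter.py
    lut = np.array([((i / 255.0) ** gamma) * 255 for i in range(256)]).astype("uint8")
    out = image_bgr.copy()
    out[shadow_mask] = cv2.LUT(out[shadow_mask], lut)   # brighten only the shadow pixels
    return out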

depth_anything_v2/dinov2.py (new file, 416 lines)

@ -0,0 +1,416 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
from functools import partial
import math
import logging
from typing import Sequence, Tuple, Union, Callable
import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.nn.init import trunc_normal_
from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
logger = logging.getLogger("dinov2")
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
if not depth_first and include_root:
fn(module=module, name=name)
for child_name, child_module in module.named_children():
child_name = ".".join((name, child_name)) if name else child_name
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
if depth_first and include_root:
fn(module=module, name=name)
return module
class BlockChunk(nn.ModuleList):
def forward(self, x):
for b in self:
x = b(x)
return x
class DinoVisionTransformer(nn.Module):
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.0,
qkv_bias=True,
ffn_bias=True,
proj_bias=True,
drop_path_rate=0.0,
drop_path_uniform=False,
init_values=None, # for layerscale: None or 0 => no layerscale
embed_layer=PatchEmbed,
act_layer=nn.GELU,
block_fn=Block,
ffn_layer="mlp",
block_chunks=1,
num_register_tokens=0,
interpolate_antialias=False,
interpolate_offset=0.1,
):
"""
Args:
img_size (int, tuple): input image size
patch_size (int, tuple): patch size
in_chans (int): number of input channels
embed_dim (int): embedding dimension
depth (int): depth of transformer
num_heads (int): number of attention heads
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
qkv_bias (bool): enable bias for qkv if True
proj_bias (bool): enable bias for proj in attn if True
ffn_bias (bool): enable bias for ffn if True
drop_path_rate (float): stochastic depth rate
drop_path_uniform (bool): apply uniform drop rate across blocks
weight_init (str): weight init scheme
init_values (float): layer-scale init values
embed_layer (nn.Module): patch embedding layer
act_layer (nn.Module): MLP activation layer
block_fn (nn.Module): transformer block class
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
num_register_tokens: (int) number of extra cls tokens (so-called "registers")
interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
"""
super().__init__()
norm_layer = partial(nn.LayerNorm, eps=1e-6)
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
self.num_tokens = 1
self.n_blocks = depth
self.num_heads = num_heads
self.patch_size = patch_size
self.num_register_tokens = num_register_tokens
self.interpolate_antialias = interpolate_antialias
self.interpolate_offset = interpolate_offset
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
num_patches = self.patch_embed.num_patches
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
assert num_register_tokens >= 0
self.register_tokens = (
nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
)
if drop_path_uniform is True:
dpr = [drop_path_rate] * depth
else:
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
if ffn_layer == "mlp":
logger.info("using MLP layer as FFN")
ffn_layer = Mlp
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
logger.info("using SwiGLU layer as FFN")
ffn_layer = SwiGLUFFNFused
elif ffn_layer == "identity":
logger.info("using Identity layer as FFN")
def f(*args, **kwargs):
return nn.Identity()
ffn_layer = f
else:
raise NotImplementedError
blocks_list = [
block_fn(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
ffn_bias=ffn_bias,
drop_path=dpr[i],
norm_layer=norm_layer,
act_layer=act_layer,
ffn_layer=ffn_layer,
init_values=init_values,
)
for i in range(depth)
]
if block_chunks > 0:
self.chunked_blocks = True
chunked_blocks = []
chunksize = depth // block_chunks
for i in range(0, depth, chunksize):
# this is to keep the block index consistent if we chunk the block list
chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
else:
self.chunked_blocks = False
self.blocks = nn.ModuleList(blocks_list)
self.norm = norm_layer(embed_dim)
self.head = nn.Identity()
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
self.init_weights()
def init_weights(self):
trunc_normal_(self.pos_embed, std=0.02)
nn.init.normal_(self.cls_token, std=1e-6)
if self.register_tokens is not None:
nn.init.normal_(self.register_tokens, std=1e-6)
named_apply(init_weights_vit_timm, self)
def interpolate_pos_encoding(self, x, w, h):
previous_dtype = x.dtype
npatch = x.shape[1] - 1
N = self.pos_embed.shape[1] - 1
if npatch == N and w == h:
return self.pos_embed
pos_embed = self.pos_embed.float()
class_pos_embed = pos_embed[:, 0]
patch_pos_embed = pos_embed[:, 1:]
dim = x.shape[-1]
w0 = w // self.patch_size
h0 = h // self.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
# DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
# w0, h0 = w0 + 0.1, h0 + 0.1
sqrt_N = math.sqrt(N)
sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
scale_factor=(sx, sy),
# (int(w0), int(h0)), # to solve the upsampling shape issue
mode="bicubic"
)
# antialias=self.interpolate_antialias
# )
assert int(w0) == patch_pos_embed.shape[-2]
assert int(h0) == patch_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
def prepare_tokens_with_masks(self, x, masks=None):
B, nc, w, h = x.shape
x = self.patch_embed(x)
if masks is not None:
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
x = x + self.interpolate_pos_encoding(x, w, h)
if self.register_tokens is not None:
x = torch.cat(
(
x[:, :1],
self.register_tokens.expand(x.shape[0], -1, -1),
x[:, 1:],
),
dim=1,
)
return x
def forward_features_list(self, x_list, masks_list):
x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
for blk in self.blocks:
x = blk(x)
all_x = x
output = []
for x, masks in zip(all_x, masks_list):
x_norm = self.norm(x)
output.append(
{
"x_norm_clstoken": x_norm[:, 0],
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
"x_prenorm": x,
"masks": masks,
}
)
return output
def forward_features(self, x, masks=None):
if isinstance(x, list):
return self.forward_features_list(x, masks)
x = self.prepare_tokens_with_masks(x, masks)
for blk in self.blocks:
x = blk(x)
x_norm = self.norm(x)
return {
"x_norm_clstoken": x_norm[:, 0],
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
"x_prenorm": x,
"masks": masks,
}
def _get_intermediate_layers_not_chunked(self, x, n=1):
x = self.prepare_tokens_with_masks(x)
# If n is an int, take the n last blocks. If it's a list, take them
output, total_block_len = [], len(self.blocks)
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
for i, blk in enumerate(self.blocks):
x = blk(x)
if i in blocks_to_take:
output.append(x)
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
return output
def _get_intermediate_layers_chunked(self, x, n=1):
x = self.prepare_tokens_with_masks(x)
output, i, total_block_len = [], 0, len(self.blocks[-1])
# If n is an int, take the n last blocks. If it's a list, take them
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
for block_chunk in self.blocks:
for blk in block_chunk[i:]: # Passing the nn.Identity()
x = blk(x)
if i in blocks_to_take:
output.append(x)
i += 1
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
return output
def get_intermediate_layers(
self,
x: torch.Tensor,
n: Union[int, Sequence] = 1, # Layers or n last layers to take
reshape: bool = False,
return_class_token: bool = False,
norm=True
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
if self.chunked_blocks:
outputs = self._get_intermediate_layers_chunked(x, n)
else:
outputs = self._get_intermediate_layers_not_chunked(x, n)
if norm:
outputs = [self.norm(out) for out in outputs]
class_tokens = [out[:, 0] for out in outputs]
outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
if reshape:
B, _, w, h = x.shape
outputs = [
out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
for out in outputs
]
if return_class_token:
return tuple(zip(outputs, class_tokens))
return tuple(outputs)
def forward(self, *args, is_training=False, **kwargs):
ret = self.forward_features(*args, **kwargs)
if is_training:
return ret
else:
return self.head(ret["x_norm_clstoken"])
def init_weights_vit_timm(module: nn.Module, name: str = ""):
"""ViT weight initialization, original timm impl (for reproducibility)"""
if isinstance(module, nn.Linear):
trunc_normal_(module.weight, std=0.02)
if module.bias is not None:
nn.init.zeros_(module.bias)
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
model = DinoVisionTransformer(
patch_size=patch_size,
embed_dim=384,
depth=12,
num_heads=6,
mlp_ratio=4,
block_fn=partial(Block, attn_class=MemEffAttention),
num_register_tokens=num_register_tokens,
**kwargs,
)
return model
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
model = DinoVisionTransformer(
patch_size=patch_size,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4,
block_fn=partial(Block, attn_class=MemEffAttention),
num_register_tokens=num_register_tokens,
**kwargs,
)
return model
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
model = DinoVisionTransformer(
patch_size=patch_size,
embed_dim=1024,
depth=24,
num_heads=16,
mlp_ratio=4,
block_fn=partial(Block, attn_class=MemEffAttention),
num_register_tokens=num_register_tokens,
**kwargs,
)
return model
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
"""
Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
"""
model = DinoVisionTransformer(
patch_size=patch_size,
embed_dim=1536,
depth=40,
num_heads=24,
mlp_ratio=4,
block_fn=partial(Block, attn_class=MemEffAttention),
num_register_tokens=num_register_tokens,
**kwargs,
)
return model
def DINOv2(model_name):
model_zoo = {
"vits": vit_small,
"vitb": vit_base,
"vitl": vit_large,
"vitg": vit_giant2
}
return model_zoo[model_name](
img_size=518,
patch_size=14,
init_values=1.0,
ffn_layer="mlp" if model_name != "vitg" else "swiglufused",
block_chunks=0,
num_register_tokens=0,
interpolate_antialias=False,
interpolate_offset=0.1
)
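For reference, dpt.py (added later in this commit) consumes this backbone through get_intermediate_layers. A minimal sketch of that call, assuming the repository root is on sys.path and using a randomly initialized 'vitl' backbone just to show the shapes:

import torch
from depth_anything_v2.dinov2 import DINOv2

backbone = DINOv2('vitl').eval()            # randomly initialized, shapes only
x = torch.randn(1, 3, 518, 518)             # 518 = 37 * 14, a multiple of patch_size
with torch.no_grad():
    feats = backbone.get_intermediate_layers(
        x,
        n=[4, 11, 17, 23],                  # the 'vitl' indices DepthAnythingV2 uses below
        return_class_token=True,
    )
# feats is a tuple of four (patch_tokens, cls_token) pairs;
# for 'vitl', each patch_tokens tensor has shape (1, 37 * 37, 1024).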

depth_anything_v2/dinov2_layers/__init__.py (new file)

@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from .mlp import Mlp
from .patch_embed import PatchEmbed
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
from .block import NestedTensorBlock
from .attention import MemEffAttention

depth_anything_v2/dinov2_layers/attention.py (new file)

@ -0,0 +1,83 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
import logging
from torch import Tensor
from torch import nn
logger = logging.getLogger("dinov2")
try:
from xformers.ops import memory_efficient_attention, unbind, fmha
XFORMERS_AVAILABLE = True
except ImportError:
logger.warning("xFormers not available")
XFORMERS_AVAILABLE = False
class Attention(nn.Module):
def __init__(
self,
dim: int,
num_heads: int = 8,
qkv_bias: bool = False,
proj_bias: bool = True,
attn_drop: float = 0.0,
proj_drop: float = 0.0,
) -> None:
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim**-0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim, bias=proj_bias)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x: Tensor) -> Tensor:
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
attn = q @ k.transpose(-2, -1)
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
class MemEffAttention(Attention):
def forward(self, x: Tensor, attn_bias=None) -> Tensor:
if not XFORMERS_AVAILABLE:
assert attn_bias is None, "xFormers is required for nested tensors usage"
return super().forward(x)
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
q, k, v = unbind(qkv, 2)
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
x = x.reshape([B, N, C])
x = self.proj(x)
x = self.proj_drop(x)
return x

depth_anything_v2/dinov2_layers/block.py (new file)

@ -0,0 +1,252 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
import logging
from typing import Callable, List, Any, Tuple, Dict
import torch
from torch import nn, Tensor
from .attention import Attention, MemEffAttention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp
logger = logging.getLogger("dinov2")
try:
from xformers.ops import fmha
from xformers.ops import scaled_index_add, index_select_cat
XFORMERS_AVAILABLE = True
except ImportError:
logger.warning("xFormers not available")
XFORMERS_AVAILABLE = False
class Block(nn.Module):
def __init__(
self,
dim: int,
num_heads: int,
mlp_ratio: float = 4.0,
qkv_bias: bool = False,
proj_bias: bool = True,
ffn_bias: bool = True,
drop: float = 0.0,
attn_drop: float = 0.0,
init_values=None,
drop_path: float = 0.0,
act_layer: Callable[..., nn.Module] = nn.GELU,
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
attn_class: Callable[..., nn.Module] = Attention,
ffn_layer: Callable[..., nn.Module] = Mlp,
) -> None:
super().__init__()
# print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
self.norm1 = norm_layer(dim)
self.attn = attn_class(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
proj_bias=proj_bias,
attn_drop=attn_drop,
proj_drop=drop,
)
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = ffn_layer(
in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop,
bias=ffn_bias,
)
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.sample_drop_ratio = drop_path
def forward(self, x: Tensor) -> Tensor:
def attn_residual_func(x: Tensor) -> Tensor:
return self.ls1(self.attn(self.norm1(x)))
def ffn_residual_func(x: Tensor) -> Tensor:
return self.ls2(self.mlp(self.norm2(x)))
if self.training and self.sample_drop_ratio > 0.1:
# the overhead is compensated only for a drop path rate larger than 0.1
x = drop_add_residual_stochastic_depth(
x,
residual_func=attn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
)
x = drop_add_residual_stochastic_depth(
x,
residual_func=ffn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
)
elif self.training and self.sample_drop_ratio > 0.0:
x = x + self.drop_path1(attn_residual_func(x))
x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
else:
x = x + attn_residual_func(x)
x = x + ffn_residual_func(x)
return x
def drop_add_residual_stochastic_depth(
x: Tensor,
residual_func: Callable[[Tensor], Tensor],
sample_drop_ratio: float = 0.0,
) -> Tensor:
# 1) extract subset using permutation
b, n, d = x.shape
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
x_subset = x[brange]
# 2) apply residual_func to get residual
residual = residual_func(x_subset)
x_flat = x.flatten(1)
residual = residual.flatten(1)
residual_scale_factor = b / sample_subset_size
# 3) add the residual
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
return x_plus_residual.view_as(x)
def get_branges_scales(x, sample_drop_ratio=0.0):
b, n, d = x.shape
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
residual_scale_factor = b / sample_subset_size
return brange, residual_scale_factor
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
if scaling_vector is None:
x_flat = x.flatten(1)
residual = residual.flatten(1)
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
else:
x_plus_residual = scaled_index_add(
x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
)
return x_plus_residual
attn_bias_cache: Dict[Tuple, Any] = {}
def get_attn_bias_and_cat(x_list, branges=None):
"""
this will perform the index select, cat the tensors, and provide the attn_bias from cache
"""
batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
if all_shapes not in attn_bias_cache.keys():
seqlens = []
for b, x in zip(batch_sizes, x_list):
for _ in range(b):
seqlens.append(x.shape[1])
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
attn_bias._batch_sizes = batch_sizes
attn_bias_cache[all_shapes] = attn_bias
if branges is not None:
cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
else:
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
cat_tensors = torch.cat(tensors_bs1, dim=1)
return attn_bias_cache[all_shapes], cat_tensors
def drop_add_residual_stochastic_depth_list(
x_list: List[Tensor],
residual_func: Callable[[Tensor, Any], Tensor],
sample_drop_ratio: float = 0.0,
scaling_vector=None,
) -> Tensor:
# 1) generate random set of indices for dropping samples in the batch
branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
branges = [s[0] for s in branges_scales]
residual_scale_factors = [s[1] for s in branges_scales]
# 2) get attention bias and index+concat the tensors
attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
# 3) apply residual_func to get residual, and split the result
residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
outputs = []
for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
return outputs
class NestedTensorBlock(Block):
def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
"""
x_list contains a list of tensors to nest together and run
"""
assert isinstance(self.attn, MemEffAttention)
if self.training and self.sample_drop_ratio > 0.0:
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
return self.attn(self.norm1(x), attn_bias=attn_bias)
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
return self.mlp(self.norm2(x))
x_list = drop_add_residual_stochastic_depth_list(
x_list,
residual_func=attn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
)
x_list = drop_add_residual_stochastic_depth_list(
x_list,
residual_func=ffn_residual_func,
sample_drop_ratio=self.sample_drop_ratio,
scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
)
return x_list
else:
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
return self.ls2(self.mlp(self.norm2(x)))
attn_bias, x = get_attn_bias_and_cat(x_list)
x = x + attn_residual_func(x, attn_bias=attn_bias)
x = x + ffn_residual_func(x)
return attn_bias.split(x)
def forward(self, x_or_x_list):
if isinstance(x_or_x_list, Tensor):
return super().forward(x_or_x_list)
elif isinstance(x_or_x_list, list):
assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
return self.forward_nested(x_or_x_list)
else:
raise AssertionError

depth_anything_v2/dinov2_layers/drop_path.py (new file)

@ -0,0 +1,35 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
from torch import nn
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
if keep_prob > 0.0:
random_tensor.div_(keep_prob)
output = x * random_tensor
return output
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)

depth_anything_v2/dinov2_layers/layer_scale.py (new file)

@ -0,0 +1,28 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
from typing import Union
import torch
from torch import Tensor
from torch import nn
class LayerScale(nn.Module):
def __init__(
self,
dim: int,
init_values: Union[float, Tensor] = 1e-5,
inplace: bool = False,
) -> None:
super().__init__()
self.inplace = inplace
self.gamma = nn.Parameter(init_values * torch.ones(dim))
def forward(self, x: Tensor) -> Tensor:
return x.mul_(self.gamma) if self.inplace else x * self.gamma

depth_anything_v2/dinov2_layers/mlp.py (new file)

@ -0,0 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
from typing import Callable, Optional
from torch import Tensor, nn
class Mlp(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: Callable[..., nn.Module] = nn.GELU,
drop: float = 0.0,
bias: bool = True,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
self.drop = nn.Dropout(drop)
def forward(self, x: Tensor) -> Tensor:
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x

depth_anything_v2/dinov2_layers/patch_embed.py (new file)

@ -0,0 +1,89 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
from typing import Callable, Optional, Tuple, Union
from torch import Tensor
import torch.nn as nn
def make_2tuple(x):
if isinstance(x, tuple):
assert len(x) == 2
return x
assert isinstance(x, int)
return (x, x)
class PatchEmbed(nn.Module):
"""
2D image to patch embedding: (B,C,H,W) -> (B,N,D)
Args:
img_size: Image size.
patch_size: Patch token size.
in_chans: Number of input image channels.
embed_dim: Number of linear projection output channels.
norm_layer: Normalization layer.
"""
def __init__(
self,
img_size: Union[int, Tuple[int, int]] = 224,
patch_size: Union[int, Tuple[int, int]] = 16,
in_chans: int = 3,
embed_dim: int = 768,
norm_layer: Optional[Callable] = None,
flatten_embedding: bool = True,
) -> None:
super().__init__()
image_HW = make_2tuple(img_size)
patch_HW = make_2tuple(patch_size)
patch_grid_size = (
image_HW[0] // patch_HW[0],
image_HW[1] // patch_HW[1],
)
self.img_size = image_HW
self.patch_size = patch_HW
self.patches_resolution = patch_grid_size
self.num_patches = patch_grid_size[0] * patch_grid_size[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
self.flatten_embedding = flatten_embedding
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x: Tensor) -> Tensor:
_, _, H, W = x.shape
patch_H, patch_W = self.patch_size
assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
x = self.proj(x) # B C H W
H, W = x.size(2), x.size(3)
x = x.flatten(2).transpose(1, 2) # B HW C
x = self.norm(x)
if not self.flatten_embedding:
x = x.reshape(-1, H, W, self.embed_dim) # B H W C
return x
def flops(self) -> float:
Ho, Wo = self.patches_resolution
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
if self.norm is not None:
flops += Ho * Wo * self.embed_dim
return flops

depth_anything_v2/dinov2_layers/swiglu_ffn.py (new file)

@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable, Optional
from torch import Tensor, nn
import torch.nn.functional as F
class SwiGLUFFN(nn.Module):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: Callable[..., nn.Module] = None,
drop: float = 0.0,
bias: bool = True,
) -> None:
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
def forward(self, x: Tensor) -> Tensor:
x12 = self.w12(x)
x1, x2 = x12.chunk(2, dim=-1)
hidden = F.silu(x1) * x2
return self.w3(hidden)
try:
from xformers.ops import SwiGLU
XFORMERS_AVAILABLE = True
except ImportError:
SwiGLU = SwiGLUFFN
XFORMERS_AVAILABLE = False
class SwiGLUFFNFused(SwiGLU):
def __init__(
self,
in_features: int,
hidden_features: Optional[int] = None,
out_features: Optional[int] = None,
act_layer: Callable[..., nn.Module] = None,
drop: float = 0.0,
bias: bool = True,
) -> None:
out_features = out_features or in_features
hidden_features = hidden_features or in_features
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
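# For example, with hidden_features = 4096: int(4096 * 2 / 3) = 2730 and
# (2730 + 7) // 8 * 8 = 2736, i.e. roughly two thirds of the MLP width,
# rounded up to a multiple of 8.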
super().__init__(
in_features=in_features,
hidden_features=hidden_features,
out_features=out_features,
bias=bias,
)

depth_anything_v2/dpt.py (new file, 221 lines)

@ -0,0 +1,221 @@
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Compose
from .dinov2 import DINOv2
from .util.blocks import FeatureFusionBlock, _make_scratch
from .util.transform import Resize, NormalizeImage, PrepareForNet
def _make_fusion_block(features, use_bn, size=None):
return FeatureFusionBlock(
features,
nn.ReLU(False),
deconv=False,
bn=use_bn,
expand=False,
align_corners=True,
size=size,
)
class ConvBlock(nn.Module):
def __init__(self, in_feature, out_feature):
super().__init__()
self.conv_block = nn.Sequential(
nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
nn.BatchNorm2d(out_feature),
nn.ReLU(True)
)
def forward(self, x):
return self.conv_block(x)
class DPTHead(nn.Module):
def __init__(
self,
in_channels,
features=256,
use_bn=False,
out_channels=[256, 512, 1024, 1024],
use_clstoken=False
):
super(DPTHead, self).__init__()
self.use_clstoken = use_clstoken
self.projects = nn.ModuleList([
nn.Conv2d(
in_channels=in_channels,
out_channels=out_channel,
kernel_size=1,
stride=1,
padding=0,
) for out_channel in out_channels
])
self.resize_layers = nn.ModuleList([
nn.ConvTranspose2d(
in_channels=out_channels[0],
out_channels=out_channels[0],
kernel_size=4,
stride=4,
padding=0),
nn.ConvTranspose2d(
in_channels=out_channels[1],
out_channels=out_channels[1],
kernel_size=2,
stride=2,
padding=0),
nn.Identity(),
nn.Conv2d(
in_channels=out_channels[3],
out_channels=out_channels[3],
kernel_size=3,
stride=2,
padding=1)
])
if use_clstoken:
self.readout_projects = nn.ModuleList()
for _ in range(len(self.projects)):
self.readout_projects.append(
nn.Sequential(
nn.Linear(2 * in_channels, in_channels),
nn.GELU()))
self.scratch = _make_scratch(
out_channels,
features,
groups=1,
expand=False,
)
self.scratch.stem_transpose = None
self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
head_features_1 = features
head_features_2 = 32
self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
self.scratch.output_conv2 = nn.Sequential(
nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
nn.ReLU(True),
nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
nn.ReLU(True),
nn.Identity(),
)
def forward(self, out_features, patch_h, patch_w):
out = []
for i, x in enumerate(out_features):
if self.use_clstoken:
x, cls_token = x[0], x[1]
readout = cls_token.unsqueeze(1).expand_as(x)
x = self.readout_projects[i](torch.cat((x, readout), -1))
else:
x = x[0]
x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
x = self.projects[i](x)
x = self.resize_layers[i](x)
out.append(x)
layer_1, layer_2, layer_3, layer_4 = out
layer_1_rn = self.scratch.layer1_rn(layer_1)
layer_2_rn = self.scratch.layer2_rn(layer_2)
layer_3_rn = self.scratch.layer3_rn(layer_3)
layer_4_rn = self.scratch.layer4_rn(layer_4)
path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
out = self.scratch.output_conv1(path_1)
out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
out = self.scratch.output_conv2(out)
return out
class DepthAnythingV2(nn.Module):
def __init__(
self,
encoder='vitl',
features=256,
out_channels=[256, 512, 1024, 1024],
use_bn=False,
use_clstoken=False
):
super(DepthAnythingV2, self).__init__()
self.intermediate_layer_idx = {
'vits': [2, 5, 8, 11],
'vitb': [2, 5, 8, 11],
'vitl': [4, 11, 17, 23],
'vitg': [9, 19, 29, 39]
}
self.encoder = encoder
self.pretrained = DINOv2(model_name=encoder)
self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
def forward(self, x):
patch_h, patch_w = x.shape[-2] // 14, x.shape[-1] // 14
features = self.pretrained.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], return_class_token=True)
depth = self.depth_head(features, patch_h, patch_w)
depth = F.relu(depth)
return depth.squeeze(1)
@torch.no_grad()
def infer_image(self, raw_image, input_size=518):
image, (h, w) = self.image2tensor(raw_image, input_size)
depth = self.forward(image)
depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0]
return depth.cpu().numpy()
def image2tensor(self, raw_image, input_size=518):
transform = Compose([
Resize(
width=input_size,
height=input_size,
resize_target=False,
keep_aspect_ratio=True,
ensure_multiple_of=14,
resize_method='lower_bound',
image_interpolation_method=cv2.INTER_CUBIC,
),
NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
PrepareForNet(),
])
h, w = raw_image.shape[:2]
image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0
image = transform({'image': image})['image']
image = torch.from_numpy(image).unsqueeze(0)
DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
image = image.to(DEVICE)
return image, (h, w)
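A minimal usage sketch for the class above; the checkpoint path and image filename are hypothetical (pretrained weights are not part of this commit, and model/ was just added to .gitignore):

import cv2
import torch
from depth_anything_v2.dpt import DepthAnythingV2

DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'

model = DepthAnythingV2(encoder='vitl')
# hypothetical local checkpoint; downloaded separately
model.load_state_dict(torch.load('model/depth_anything_v2_vitl.pth', map_location='cpu'))
model = model.to(DEVICE).eval()

raw = cv2.imread('some_blade_photo.jpg')   # BGR uint8, as read by OpenCV
depth = model.infer_image(raw)             # float32 HxW relative depth map (numpy)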

depth_anything_v2/util/blocks.py (new file)

@ -0,0 +1,148 @@
import torch.nn as nn
def _make_scratch(in_shape, out_shape, groups=1, expand=False):
scratch = nn.Module()
out_shape1 = out_shape
out_shape2 = out_shape
out_shape3 = out_shape
if len(in_shape) >= 4:
out_shape4 = out_shape
if expand:
out_shape1 = out_shape
out_shape2 = out_shape * 2
out_shape3 = out_shape * 4
if len(in_shape) >= 4:
out_shape4 = out_shape * 8
scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
if len(in_shape) >= 4:
scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups)
return scratch
class ResidualConvUnit(nn.Module):
"""Residual convolution module.
"""
def __init__(self, features, activation, bn):
"""Init.
Args:
features (int): number of features
"""
super().__init__()
self.bn = bn
self.groups=1
self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
if self.bn == True:
self.bn1 = nn.BatchNorm2d(features)
self.bn2 = nn.BatchNorm2d(features)
self.activation = activation
self.skip_add = nn.quantized.FloatFunctional()
def forward(self, x):
"""Forward pass.
Args:
x (tensor): input
Returns:
tensor: output
"""
out = self.activation(x)
out = self.conv1(out)
if self.bn == True:
out = self.bn1(out)
out = self.activation(out)
out = self.conv2(out)
if self.bn == True:
out = self.bn2(out)
if self.groups > 1:
out = self.conv_merge(out)
return self.skip_add.add(out, x)
class FeatureFusionBlock(nn.Module):
"""Feature fusion block.
"""
def __init__(
self,
features,
activation,
deconv=False,
bn=False,
expand=False,
align_corners=True,
size=None
):
"""Init.
Args:
features (int): number of features
"""
super(FeatureFusionBlock, self).__init__()
self.deconv = deconv
self.align_corners = align_corners
self.groups=1
self.expand = expand
out_features = features
if self.expand == True:
out_features = features // 2
self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1)
self.resConfUnit1 = ResidualConvUnit(features, activation, bn)
self.resConfUnit2 = ResidualConvUnit(features, activation, bn)
self.skip_add = nn.quantized.FloatFunctional()
self.size=size
def forward(self, *xs, size=None):
"""Forward pass.
Returns:
tensor: output
"""
output = xs[0]
if len(xs) == 2:
res = self.resConfUnit1(xs[1])
output = self.skip_add.add(output, res)
output = self.resConfUnit2(output)
if (size is None) and (self.size is None):
modifier = {"scale_factor": 2}
elif size is None:
modifier = {"size": self.size}
else:
modifier = {"size": size}
output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
output = self.out_conv(output)
return output

depth_anything_v2/util/transform.py (new file)

@ -0,0 +1,158 @@
import numpy as np
import cv2
class Resize(object):
"""Resize sample to given size (width, height).
"""
def __init__(
self,
width,
height,
resize_target=True,
keep_aspect_ratio=False,
ensure_multiple_of=1,
resize_method="lower_bound",
image_interpolation_method=cv2.INTER_AREA,
):
"""Init.
Args:
width (int): desired output width
height (int): desired output height
resize_target (bool, optional):
True: Resize the full sample (image, mask, target).
False: Resize image only.
Defaults to True.
keep_aspect_ratio (bool, optional):
True: Keep the aspect ratio of the input sample.
Output sample might not have the given width and height, and
resize behaviour depends on the parameter 'resize_method'.
Defaults to False.
ensure_multiple_of (int, optional):
Output width and height is constrained to be multiple of this parameter.
Defaults to 1.
resize_method (str, optional):
"lower_bound": Output will be at least as large as the given size.
"upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.)
"minimal": Scale as least as possible. (Output size might be smaller than given size.)
Defaults to "lower_bound".
"""
self.__width = width
self.__height = height
self.__resize_target = resize_target
self.__keep_aspect_ratio = keep_aspect_ratio
self.__multiple_of = ensure_multiple_of
self.__resize_method = resize_method
self.__image_interpolation_method = image_interpolation_method
def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)
if max_val is not None and y > max_val:
y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)
if y < min_val:
y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)
return y
def get_size(self, width, height):
# determine new height and width
scale_height = self.__height / height
scale_width = self.__width / width
if self.__keep_aspect_ratio:
if self.__resize_method == "lower_bound":
# scale such that output size is lower bound
if scale_width > scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == "upper_bound":
# scale such that output size is upper bound
if scale_width < scale_height:
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
elif self.__resize_method == "minimal":
# scale as least as possbile
if abs(1 - scale_width) < abs(1 - scale_height):
# fit width
scale_height = scale_width
else:
# fit height
scale_width = scale_height
else:
raise ValueError(f"resize_method {self.__resize_method} not implemented")
if self.__resize_method == "lower_bound":
new_height = self.constrain_to_multiple_of(scale_height * height, min_val=self.__height)
new_width = self.constrain_to_multiple_of(scale_width * width, min_val=self.__width)
elif self.__resize_method == "upper_bound":
new_height = self.constrain_to_multiple_of(scale_height * height, max_val=self.__height)
new_width = self.constrain_to_multiple_of(scale_width * width, max_val=self.__width)
elif self.__resize_method == "minimal":
new_height = self.constrain_to_multiple_of(scale_height * height)
new_width = self.constrain_to_multiple_of(scale_width * width)
else:
raise ValueError(f"resize_method {self.__resize_method} not implemented")
return (new_width, new_height)
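# Worked example (assumed settings): width=height=518, keep_aspect_ratio=True,
# resize_method="lower_bound", ensure_multiple_of=14. A 1920x1080 input scales by
# 518/1080, so get_size(1920, 1080) returns (924, 518): both sides are multiples
# of 14 and neither is below 518.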
def __call__(self, sample):
width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])
# resize sample
sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method)
if self.__resize_target:
if "depth" in sample:
sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)
if "mask" in sample:
sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST)
return sample
class NormalizeImage(object):
"""Normlize image by given mean and std.
"""
def __init__(self, mean, std):
self.__mean = mean
self.__std = std
def __call__(self, sample):
sample["image"] = (sample["image"] - self.__mean) / self.__std
return sample
class PrepareForNet(object):
"""Prepare sample for usage as network input.
"""
def __init__(self):
pass
def __call__(self, sample):
image = np.transpose(sample["image"], (2, 0, 1))
sample["image"] = np.ascontiguousarray(image).astype(np.float32)
if "depth" in sample:
depth = sample["depth"].astype(np.float32)
sample["depth"] = np.ascontiguousarray(depth)
if "mask" in sample:
sample["mask"] = sample["mask"].astype(np.float32)
sample["mask"] = np.ascontiguousarray(sample["mask"])
return sample

info_core/MyQtClass.py

@@ -740,8 +740,8 @@ class FolderDropWidget(QWidget):
     def init_ui(self):
         main_layout = QVBoxLayout()
-        main_layout.setSpacing(15)
-        main_layout.setContentsMargins(15, 15, 15, 15)
+        main_layout.setSpacing(5)
+        main_layout.setContentsMargins(0, 0, 0, 0)
         # 顶部按钮区域
         btn_layout = QHBoxLayout()
@@ -750,12 +750,24 @@ class FolderDropWidget(QWidget):
         self.reset_btn.setFixedSize(120, 40)
         self.reset_btn.setVisible(False)
-        self.clear_btn = QPushButton("清空选中机组列表")
+        self.clear_btn = QPushButton("清空")
         self.clear_btn.clicked.connect(self.clear_selected)
         self.clear_btn.setFixedSize(120, 40)
         self.clear_btn.setVisible(False)
+        self.full_select_btn = QPushButton("全选")
+        self.full_select_btn.clicked.connect(self.select_all_folders)
+        self.full_select_btn.setFixedSize(120, 40)
+        self.full_select_btn.setVisible(False)
+        # 添加返回按钮
+        self.back_btn = QPushButton("返回项目选择")
+        self.back_btn.clicked.connect(lambda: (self.stacked_layout.setCurrentIndex(1) or self.clear_selected() or self.reset_btn.setVisible(True) or self.back_btn.setVisible(False)))
+        self.back_btn.setVisible(False)
         btn_layout.addWidget(self.reset_btn)
+        btn_layout.addWidget(self.back_btn)
+        btn_layout.addWidget(self.full_select_btn)
         btn_layout.addWidget(self.clear_btn)
         btn_layout.addStretch()
         main_layout.addLayout(btn_layout)
@@ -776,6 +788,8 @@ class FolderDropWidget(QWidget):
         # 状态2: 显示第一层文件夹 (项目选择)
         self.level1_container = QWidget()
         level1_main_layout = QVBoxLayout()
+        level1_main_layout.setSpacing(0)
+        level1_main_layout.setContentsMargins(0, 0, 0, 0)
         # 添加状态标签
         self.level1_title = QLabel("项目选择")
@@ -797,6 +811,8 @@ class FolderDropWidget(QWidget):
         # 状态3: 显示第二层文件夹 (机组选择)
         self.level2_container = QWidget()
         level2_main_layout = QVBoxLayout()
+        level2_main_layout.setSpacing(0)
+        level2_main_layout.setContentsMargins(0, 0, 0, 0)
         # 添加状态标签
         self.level2_title = QLabel("机组选择")
@@ -811,11 +827,6 @@ class FolderDropWidget(QWidget):
         self.level2_list.setSpacing(10)
         level2_main_layout.addWidget(self.level2_list)
-        # 添加返回按钮
-        self.back_btn = QPushButton("返回项目选择")
-        self.back_btn.clicked.connect(lambda: self.stacked_layout.setCurrentIndex(1))
-        level2_main_layout.addWidget(self.back_btn)
         self.level2_container.setLayout(level2_main_layout)
         self.stacked_layout.addWidget(self.level2_container)
@@ -836,6 +847,13 @@ class FolderDropWidget(QWidget):
         self.setLayout(main_layout)
+    def select_all_folders(self):
+        self.selected_folders.update(
+            set(item.data(Qt.ItemDataRole.UserRole) for item in self.level2_list.findItems("", Qt.MatchFlag.MatchContains))
+        )
+        self.update_checkbox_states()
+        self.selection_changed.emit(self.get_selected_folders())
     def apply_styles(self):
         self.prompt_label.setStyleSheet(f"""
             {PATH_DISPLAY_STYLE}
@@ -872,6 +890,7 @@ class FolderDropWidget(QWidget):
         """)
         self.selected_group.setStyleSheet(GROUP_BOX_STYLE)
+        self.full_select_btn.setStyleSheet(BUTTON_STYLE)
         self.reset_btn.setStyleSheet(BUTTON_STYLE)
         self.clear_btn.setStyleSheet(BUTTON_STYLE)
         self.back_btn.setStyleSheet(BUTTON_STYLE)
@@ -890,6 +909,7 @@ class FolderDropWidget(QWidget):
         self.stacked_layout.setCurrentIndex(0)  # 显示提示
         self.reset_btn.setVisible(False)
         self.clear_btn.setVisible(False)
+        self.back_btn.setVisible(False)
         self.selected_group.setVisible(False)
         self.selection_changed.emit([])
@@ -904,7 +924,9 @@ class FolderDropWidget(QWidget):
         self.load_level1_folders(path)
         self.stacked_layout.setCurrentIndex(1)  # 显示第一层
         self.reset_btn.setVisible(True)
+        self.back_btn.setVisible(False)
         self.clear_btn.setVisible(True)
+        self.full_select_btn.setVisible(True)
         self.selected_group.setVisible(True)

     def dragEnterEvent(self, event: QDragEnterEvent):
@@ -933,6 +955,8 @@ class FolderDropWidget(QWidget):
     def show_level2_folders(self, item):
         self.level2_list.clear()
+        self.reset_btn.setVisible(False)
+        self.back_btn.setVisible(True)
         folder_path = item.data(Qt.ItemDataRole.UserRole)
         for sub_item in os.listdir(folder_path):

info_core/defines.py

@@ -84,7 +84,7 @@ COMBO_BOX_STYLE = f"""
 GROUP_BOX_MIN_WIDTH = 300
 GROUP_BOX_MIN_HEIGHT = 120
 GROUP_BOX_SPACING = 5
-GROUP_BOX_MARGINS = (5, 5, 5, 5)
+GROUP_BOX_MARGINS = (1, 1, 1, 1)
 GROUP_BOX_STYLE = f"""
 QGroupBox {{
     font-family: "{FONT_FAMILY}";

main.py (126 changed lines)

@ -1,127 +1,5 @@
-from PySide6.QtWidgets import (QApplication, QMainWindow, QWidget, QGridLayout,
-                               QPushButton, QSizePolicy, QSplitter, QToolBar)
+from PySide6.QtWidgets import QApplication
+from MainWindow.mainwindow import ReportGeneratorUI
from PySide6.QtGui import QFontDatabase
from PySide6.QtCore import Signal, Qt
import os
from info_core.defines import *
from info_core.MyQtClass import ConfigComboBoxGroup, FolderDropWidget, DraggableLine
class ReportGeneratorUI(QMainWindow):
send_baogao_choose_info = Signal(list[str])
def __init__(self):
super().__init__()
# 加载字体
self.load_font()
# 设置窗口属性
self.setWindowTitle("报告生成器")
self.setMinimumSize(WINDOW_MIN_WIDTH, WINDOW_MIN_HEIGHT)
# 主窗口部件
self.central_widget = QWidget()
self.setCentralWidget(self.central_widget)
# 主布局
self.main_layout = QGridLayout(self.central_widget)
self.main_layout.setSpacing(MAIN_LAYOUT_SPACING)
self.main_layout.setContentsMargins(*MAIN_LAYOUT_MARGINS)
# 初始化UI
self.init_ui()
def load_font(self):
"""加载自定义字体"""
if os.path.exists(FONT_PATH):
font_id = QFontDatabase.addApplicationFont(FONT_PATH)
if font_id == -1:
print("字体加载失败,将使用系统默认字体")
else:
print(f"字体文件未找到: {FONT_PATH},将使用系统默认字体")
def init_ui(self):
"""初始化所有UI组件"""
# 第一行:项目信息和人员配置
self.project_group = ConfigComboBoxGroup("项目基本信息")
self.staff_group = ConfigComboBoxGroup("单次检查配置信息", is_project=False)
# 第二行:导入图片路径、填写机组信息
self.picture_group = FolderDropWidget()
# self.image_analysis =
# self.main_layout.addWidget(self.image_analysis, 1, 1)
# 第三行:生成报告按钮(跨两列)
self.create_generate_button()
self.generate_btn.setEnabled(False)
# 创建一个垂直分割器
self.splitter = QSplitter(Qt.Vertical)
self.splitter.setStyleSheet(SPLITTER_STYLE)
# 创建顶部和底部容器
top_container = QWidget()
top_container.setLayout(QGridLayout())
top_container.layout().addWidget(self.project_group, 0, 0)
top_container.layout().addWidget(self.staff_group, 0, 1)
middle_container = QWidget()
middle_container.setLayout(QGridLayout())
middle_container.layout().addWidget(self.picture_group, 0, 0)
# 添加部件到分割器
self.splitter.addWidget(top_container)
self.splitter.addWidget(middle_container)
# 设置主布局
self.main_layout.addWidget(self.splitter, 0, 0, 2, 2) # 占据前两行两列
self.main_layout.addWidget(self.generate_btn, 2, 0, 1, 2)
# 设置分割器初始比例
self.splitter.setStretchFactor(0, 1)
self.splitter.setStretchFactor(1, 4)
self.toolbar = QToolBar()
self.addToolBar(self.toolbar)
self.toolbar.setMovable(False)
self.toolbar.setFloatable(False)
new_action = self.toolbar.addAction("重置布局比例")
self.toolbar.addSeparator()
new_action.triggered.connect(self.reset_splitter)
def reset_splitter(self):
"""重置分割器的比例"""
total_size = sum(self.splitter.sizes()) # 获取当前总大小
self.splitter.setSizes([
int(total_size * 0.2), # 第一部分占 20%(比例 1:4
int(total_size * 0.8) # 第二部分占 80%
])
def on_generate_path_selected(self, path):
self.generate_btn.setEnabled(True)
# search_file_list = []
# if self.image_analysis.check_is_waibu:
# search_file_list.append("外汇总")
# if self.image_analysis.check_is_neibu:
# search_file_list.append("内汇总")
# if self.image_analysis.check_is_fanglei:
# search_file_list.append("防汇总")
# self.send_baogao_choose_info.emit(search_file_list)
def create_button(self, text):
"""创建统一风格的按钮"""
btn = QPushButton(text)
btn.setStyleSheet(BUTTON_STYLE)
btn.setFixedSize(BUTTON_WIDTH, BUTTON_HEIGHT)
return btn
def create_generate_button(self):
"""创建生成报告按钮"""
self.generate_btn = QPushButton("生成报告")
self.generate_btn.setStyleSheet(PRIMARY_BUTTON_STYLE)
self.generate_btn.setFixedHeight(50)
self.generate_btn.setSizePolicy(QSizePolicy.Expanding, QSizePolicy.Fixed)
 if __name__ == "__main__":
     app = QApplication([])

tool/lighter.py (new file, 176 lines)

@ -0,0 +1,176 @@
import cv2
import numpy as np
# import rawpy
# from PIL import Image
# from tqdm import tqdm
# import tkinter as tk
#from tkinter import filedialog
#from concurrent.futures import ThreadPoolExecutor
#import argparse
#import os
#import torch
#from depth_anything_v2.dpt import DepthAnythingV2
#from utils import specify_name_group_blade,get_name
#parser = argparse.ArgumentParser(description='Depth Anything V2')
# parser.add_argument('--input-size', type=int, default=518)
# parser.add_argument('--encoder', type=str, default='vitl', choices=['vits', 'vitb', 'vitl', 'vitg'])
# args = parser.parse_args()
def extraction_win_lamina_mask(raw_image, depth_anything):
    # Images are read with OpenCV and therefore arrive as BGR; convert to RGB here
    # before handing the frame to the depth model.
    raw_image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB)
depth = depth_anything.infer_image(raw_image)
depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0
depth = depth.astype(np.uint8)
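    # Otsu's threshold on the normalized depth map splits it into near/far regions; the
    # assumption is that the blade occupies the near range, so the resulting binary mask
    # roughly isolates the blade from the background.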
_, otsu_mask = cv2.threshold(depth, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
return otsu_mask
def check_overexposure(image, lamina_mask,threshold):
"""
检查图像中是否有过曝区域
:param image: 输入图像
:param threshold: 亮度值的阈值像素值超过该值则认为是过曝区域
:return: 一个二值图像过曝区域为白色其它区域为黑色
"""
# 转换为浮动数据类型并归一化
image_float = image.astype(np.float32)
# 获取每个像素的亮度值(可以使用 YUV 或 RGB 亮度)
gray_image = cv2.cvtColor(image_float, cv2.COLOR_BGR2GRAY)
# 标记过曝区域,亮度大于阈值的像素为过曝区域
overexposure_mask = gray_image > threshold
overexposure_mask = (lamina_mask == 255) & overexposure_mask #只取叶片的亮光区mask
overexposure_number = np.sum(overexposure_mask * 1)
overexposed_pixels = gray_image[overexposure_mask]
# 计算曝光区域的平均亮度
if overexposed_pixels.size > 0:
avg_overexposed_value = np.mean(overexposed_pixels)
else:
avg_overexposed_value = 0 # 如果没有曝光区域,返回 0
return overexposure_mask, avg_overexposed_value,overexposure_number
def check_shawn(image,lamina_mask, threshold):
"""
检查图像中是否有阴影区域
:param image: 输入图像
:param threshold: 亮度值的阈值像素值小于该值则认为是过阴影域
:return: 一个二值图像过曝区域为白色其它区域为黑色
"""
# 转换为浮动数据类型并归一化
image_float = image.astype(np.float32)
# 获取每个像素的亮度值(可以使用 YUV 或 RGB 亮度)
gray_image = cv2.cvtColor(image_float, cv2.COLOR_BGR2GRAY)
# 标记阴影区域,亮度小于阈值的像素为阴影区域
shawn_mask = gray_image < threshold
shawn_mask = (lamina_mask == 255) & shawn_mask
shawn_number = np.sum(shawn_mask * 1)
shawn_pixels = gray_image[shawn_mask]
# 计算阴影区域的平均亮度
if shawn_pixels.size > 0:
avg_shawn_value = np.mean(shawn_pixels)
else:
avg_shawn_value = 0
return shawn_mask, avg_shawn_value,shawn_number
def smooth_overexposed_regions(image, overexposure_mask, kernel_size=(15, 15)):
"""
对过曝区域进行平滑处理修复与周围区域的过渡
:param image: 输入图像
:param overexposure_mask: 过曝区域的掩码
:param kernel_size: 高斯核大小
:return: 平滑过曝区域后的图像
"""
# 使用高斯模糊平滑过曝区域和周围区域
smoothed_image = cv2.GaussianBlur(image, kernel_size, 0)
    # Composite: take the blurred pixels inside the overexposed region and keep the original
    # image elsewhere. The mask from check_overexposure is boolean, so test truthiness instead
    # of comparing against 255.
    fixed_image = np.where(overexposure_mask[..., None], smoothed_image, image)
return fixed_image
def bilateral_filter_adjustment(image, high_light_mask, d=15, sigma_color=75, sigma_space=75):
"""
使用双边滤波器平滑高光区域与周围区域的过渡
:param image: 输入图像
:param high_light_mask: 高光区域的掩码
:param d: 邻域的直径
:param sigma_color: 颜色空间的标准差
:param sigma_space: 坐标空间的标准差
:return: 平滑后的图像
"""
# 对整个图像应用双边滤波
filtered_image = cv2.bilateralFilter(image, d, sigma_color, sigma_space)
# 合成平滑后的图像和原图,保留非高光部分
final_image = np.where(high_light_mask[..., None] == 0, image, filtered_image)
return final_image
def adjust_highlights_shadows(image, threshold_shawn,threshold, depth_anything):
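    # Overall flow: 1) estimate a blade mask with the depth model, 2) brighten shadow pixels
    # inside the mask with a gamma < 1 lookup table, 3) darken overexposed pixels inside the
    # mask with a gamma > 1 lookup table. Both corrections touch only their masked pixels.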
lamina_mask = extraction_win_lamina_mask(image, depth_anything)
shawn_mask,avg_shawn_value,shawn_number = check_shawn(image,lamina_mask,threshold_shawn)
# 调整亮度gamma < 1 时变亮gamma > 1 时变暗
gamma = 1 - (threshold_shawn-avg_shawn_value)/255.0
# print('阴影区调整的gama: '+str(gamma))#+str('\n'))
lookup_table = np.array([((i / 255.0) ** gamma) * 255 for i in range(256)]).astype('uint8')
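    # Illustrative numbers (the average below is made up): with threshold_shawn = 180 and an
    # average shadow brightness of 100, gamma = 1 - 80/255 ≈ 0.69, so the LUT maps a pixel
    # value of 100 to about 255 * (100/255) ** 0.69 ≈ 134, i.e. shadows get lifted.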
# 应用 gamma 校正
if shawn_number !=0:
image[shawn_mask == True] = cv2.LUT(image[shawn_mask == True], lookup_table)
gamma_corrected_image = image
else:
gamma_corrected_image = image
#gamma_corrected_image = cv2.LUT(image, lookup_table)
#寻找过爆区域
overexposure_mask,avg_overexposed_value,overexposure_number = check_overexposure(gamma_corrected_image,lamina_mask,threshold)
reduction_factor = 1-(avg_overexposed_value-threshold) / 255
#reduction_factor = (avg_overexposed_value/255)*scale
# print("降低曝光区比例:" + str(reduction_factor))
# 调整亮度gamma < 1 时变亮越小越亮gamma > 1 时变暗(越大越暗)
    # print("highlight gamma:", 1 + reduction_factor)
lookup_table = np.array([((i / 255.0) ** (1+reduction_factor)) * 255 for i in range(256)]).astype('uint8')
# 应用 gamma 校正
if overexposure_number !=0:
gamma_corrected_image[overexposure_mask == True] = cv2.LUT(gamma_corrected_image[overexposure_mask == True], lookup_table)
#gamma_corrected_image[overexposure_mask == True] = np.clip(gamma_corrected_image[overexposure_mask == True] * reduction_factor, 0, 255)
#smoothed_image = smooth_overexposed_regions(gamma_corrected_image, overexposure_mask)
#smoothed_image = bilateral_filter_adjustment(gamma_corrected_image, overexposure_mask, d=15, sigma_color=75, sigma_space=75)
return gamma_corrected_image
# path=r'/home/dtyx/下载/1_涂层损伤_叶尖3m_2m_一般_紧急_一般_尽快打磨维修.jpg'
# from time import time
# if __name__ == "__main__":
# # input_dir = input("请输入处理图片路径: ")
# # output_dir = input("请输入保存图片路径: ")
# time_start = time()
# imgpath = path
# img = cv2.imdecode(np.fromfile(imgpath, dtype=np.uint8), cv2.IMREAD_COLOR)
# cv2.imshow('img',cv2.resize(img,(800,600)))
# mask=adjust_highlights_shadows(img,180,253)
# mask=cv2.resize(mask,(800,600))
# time_end = time()
# print("Time used:", time_end - time_start)
# cv2.imshow('mask',mask)
# cv2.waitKey(0)
# cv2.destroyAllWindows()
# # gamma = float(input("请输入 gamma 值: "))
18
tool/model_start.py Normal file
View File
@@ -0,0 +1,18 @@
import torch
from depth_anything_v2.dpt import DepthAnythingV2
def model_start(model_path):
DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
model_configs = {
'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
}
    # model_configs is defined above but was never used; pass an encoder config explicitly.
    # 'vitl' is assumed here to match the depth_anything_v2_vitl.pth checkpoint referenced
    # elsewhere in the repo; pick the matching entry when loading a different checkpoint.
    depth_anything = DepthAnythingV2(**model_configs['vitl'])
    depth_anything.load_state_dict(
        torch.load(model_path, map_location='cpu'))
    depth_anything = depth_anything.to(DEVICE).eval()
print("model loaded")
return depth_anything
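# Usage sketch (illustrative, not part of this module); the checkpoint and output paths are the
# ones the repo's commented-out examples use, while './input_photos' is a made-up folder:
#
#   from tool.model_start import model_start
#   from tool.process_image import process_images
#
#   depth_anything = model_start('./model/depth_anything_v2_vitl.pth')
#   process_images(['./input_photos'], './output', depth_anything, workers=4)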
141
tool/process_image.py Normal file
View File
@@ -0,0 +1,141 @@
import os
import threading
import numpy as np
import cv2
from tool.lighter import adjust_highlights_shadows
from concurrent.futures import ThreadPoolExecutor
from PIL import Image
import piexif

# Shared lock for EXIF reads across worker threads; it was previously only defined in the
# commented-out __main__ block, which left exif_lock undefined in save_image_with_exif.
exif_lock = threading.Lock()
# parser = argparse.ArgumentParser(description='Depth Anything V2')
# parser.add_argument('--input-paths', type=str, nargs='+', required=True,
# help='输入文件夹列表(多个路径,用空格分隔)')
# parser.add_argument('--output-path', type=str, default='./output',
# help="按输出路径的结构输出处理好的图片")
# parser.add_argument('--model-path', type=str, default='./model/depth_anything_v2_vitl.pth',
# help='模型路径')
# args = parser.parse_args()
def process_single_image(img_path, depth_anything):
"""处理单个图片
Args:
img_path: 图片路径
Returns:
tuple: (处理后的图片numpy数组, 原始图片路径)
"""
# 使用OpenCV读取图片
img = cv2.imdecode(np.fromfile(img_path, dtype=np.uint8), cv2.IMREAD_COLOR)
if img is None:
raise ValueError(f"无法读取图片: {img_path}")
# 处理图片 - 这里假设adjust_highlights_shadows能处理OpenCV格式的图片
processed_img = adjust_highlights_shadows(img, 180, 253, depth_anything)
return processed_img, img_path
def save_image_with_exif(input_path, output_path, processed_img):
"""保存图片并保留EXIF信息
Args:
input_path: 原始图片路径
output_path: 输出图片路径
processed_img: 处理后的numpy数组图片
"""
# 确保输出目录存在
os.makedirs(os.path.dirname(output_path), exist_ok=True)
# 读取原始图片的EXIF信息
exif_dict = None
try:
with exif_lock: # 使用锁防止多线程同时读取EXIF
with Image.open(input_path) as img:
if 'exif' in img.info:
exif_dict = piexif.load(img.info['exif'])
except Exception as e:
print(f"Warning: 无法读取 {input_path} 的EXIF信息: {str(e)}")
# 将OpenCV格式转换为PIL格式以便保存EXIF
processed_img_rgb = cv2.cvtColor(processed_img, cv2.COLOR_BGR2RGB)
pil_img = Image.fromarray(processed_img_rgb)
# 保存图片
try:
if exif_dict:
pil_img.save(output_path, exif=piexif.dump(exif_dict))
else:
pil_img.save(output_path)
except Exception as e:
print(f"Error: 无法保存图片 {output_path}: {str(e)}")
raise
def process_single_file(input_path, output_root, input_base, depth_anything):
"""处理单个文件"""
try:
# 计算相对路径以保持文件夹结构
rel_path = os.path.relpath(os.path.dirname(input_path), start=input_base)
output_dir = os.path.join(output_root, rel_path)
output_path = os.path.join(output_dir, os.path.basename(input_path))
# 处理图片
processed_img, _ = process_single_image(input_path, depth_anything)
# 保存图片
save_image_with_exif(input_path, output_path, processed_img)
print(f"Processed and saved: {output_path}")
except Exception as e:
print(f"Error processing {input_path}: {str(e)}")
def process_images(input_paths, output_root, depth_anything, workers=4):
"""处理所有图片并保持原文件夹结构(多线程版本)
Args:
input_paths: 输入路径列表
output_root: 输出根目录
workers: 线程池大小
"""
# 收集所有需要处理的文件
all_files = []
for input_path in input_paths:
for root, dirs, files in os.walk(input_path):
for file in files:
if file.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff')):
input_file_path = os.path.join(root, file)
all_files.append((input_file_path, input_path))
# 使用线程池处理文件
with ThreadPoolExecutor(max_workers=workers) as executor:
futures = []
for file_path, input_base in all_files:
futures.append(
executor.submit(
process_single_file,
file_path,
output_root,
input_base,
depth_anything
)
)
# 等待所有任务完成
for future in futures:
try:
future.result()
except Exception as e:
print(f"Error in processing: {str(e)}")
# if __name__ == '__main__':
# model_path = args.model_path
# input_paths = args.input_paths
# output_path = args.output_path
# # 线程锁防止多线程同时访问EXIF数据时出现问题
# exif_lock = threading.Lock()
# input_paths = ["/home/dtyx/桌面/yhh/ReportGenerator/测试数据/山东国华无棣风电场叶片外部数据/A131301",]
# model_path = "./model/depth_anything_v2_vitl.pth"
# output_path = "./output"
# #启动模型
# depth_anything = model_start(model_path)
# # 创建输出目录
# os.makedirs(output_path, exist_ok=True)
# # 处理所有图片
# process_images(input_paths, output_path, depth_anything)
Binary file not shown. (new image file, 453 KiB)