叶片号聚类分拣
This commit is contained in:
parent
bd3b95395a
commit
a2fc3e2381
|
@ -0,0 +1,346 @@
|
|||
import glob
|
||||
import os
|
||||
from collections import Counter
|
||||
|
||||
import matplotlib
|
||||
import numpy as np
|
||||
from sklearn.cluster import KMeans, DBSCAN, AgglomerativeClustering
|
||||
from sklearn.mixture import GaussianMixture
|
||||
from pyexiv2 import Image
|
||||
|
||||
# 设置中文字体,解决显示问题
|
||||
matplotlib.rcParams["font.family"] = ["SimHei", ]
|
||||
matplotlib.rcParams['axes.unicode_minus'] = False # 正确显示负号
|
||||
# matplotlib.use('Agg') # 仅保存图片不显示,需配合 plt.savefig()
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# 确保输出目录存在
|
||||
os.makedirs('clustering_results', exist_ok=True)
|
||||
|
||||
|
||||
def read_xmp_target(p):
|
||||
"""读取图像的XMP元数据"""
|
||||
try:
|
||||
img = Image(p)
|
||||
xmp = img.read_xmp()
|
||||
|
||||
# 定义需要提取的XMP标签
|
||||
xmp_tags = {
|
||||
'Xmp.drone-dji.GpsLatitude': None,
|
||||
'Xmp.drone-dji.GpsLongitude': None,
|
||||
'Xmp.drone-dji.AbsoluteAltitude': None,
|
||||
'Xmp.drone-dji.GimbalRollDegree': None,
|
||||
'Xmp.drone-dji.GimbalYawDegree': None,
|
||||
'Xmp.drone-dji.GimbalPitchDegree': None,
|
||||
'Xmp.drone-dji.FlightRollDegree': None,
|
||||
'Xmp.drone-dji.FlightYawDegree': None,
|
||||
'Xmp.drone-dji.FlightPitchDegree': None,
|
||||
'Xmp.drone-dji.LRFTargetDistance': None,
|
||||
'Xmp.drone-dji.LRFTargetLon': None,
|
||||
'Xmp.drone-dji.LRFTargetLat': None,
|
||||
'Xmp.drone-dji.LRFTargetAlt': None,
|
||||
'Xmp.drone-dji.LRFTargetAbsAlt': None,
|
||||
}
|
||||
|
||||
# 提取标签值
|
||||
for key in xmp_tags.keys():
|
||||
if key in xmp:
|
||||
xmp_tags[key] = xmp[key]
|
||||
else:
|
||||
print(f"警告: 图像 {p} 中未找到标签 {key}")
|
||||
|
||||
return xmp_tags
|
||||
except Exception as e:
|
||||
print(f"读取图像 {p} 的XMP数据时出错: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def repair_neighbors_of_minus1(labels):
|
||||
"""修复标签中的-1值,使用相邻区域的多数标签填充"""
|
||||
out = np.array(labels, dtype=int)
|
||||
n = len(out)
|
||||
|
||||
# 找到所有 -1 的下标
|
||||
ind = np.where(out == -1)[0]
|
||||
|
||||
# 把区间切成段:[0, ind[0]), [ind[0]+1, ind[1]), ..., [ind[-1]+1, n)
|
||||
split_indices = [0] + (ind + 1).tolist() + [n]
|
||||
|
||||
for k in range(len(split_indices) - 1):
|
||||
start = split_indices[k]
|
||||
end = split_indices[k + 1]
|
||||
seg = out[start:end]
|
||||
|
||||
# 只统计非 -1 的标签
|
||||
votes = seg[seg != -1]
|
||||
if len(votes) == 0:
|
||||
continue
|
||||
winner = Counter(votes).most_common(1)[0][0]
|
||||
|
||||
# 把区间内非 -1 标签统一成 winner
|
||||
out[start:end] = np.where(seg == -1, -1, winner)
|
||||
|
||||
return out.tolist()
|
||||
|
||||
|
||||
def visualize_clusters(data, labels, cluster_centers, method_name):
|
||||
"""可视化数据点和聚类中心,包含-1标签和图片名称标注"""
|
||||
# 提取需要可视化的数据
|
||||
norm_lat = [d['norm_lat'] for d in data]
|
||||
norm_lon = [d['norm_lon'] for d in data]
|
||||
alt = [d['clustering_alt'] for d in data] # 使用聚类时实际用的高度
|
||||
image_names = [os.path.basename(d['path'])[-10:] for d in data] # 取图片名后10个字符
|
||||
|
||||
# 创建3D图形
|
||||
fig = plt.figure(figsize=(12, 10))
|
||||
ax = fig.add_subplot(111, projection='3d')
|
||||
|
||||
# 绘制数据点(使用标签区分颜色)
|
||||
unique_labels = np.unique(labels)
|
||||
colors = plt.cm.Spectral(np.linspace(0, 1, len(unique_labels)))
|
||||
|
||||
for label, color in zip(unique_labels, colors):
|
||||
if label == -1:
|
||||
# -1标签点用黄色带黑色边缘突出显示
|
||||
color = [1, 1, 0, 0.8] # 黄色半透明
|
||||
edgecolor = 'black'
|
||||
size = 80
|
||||
else:
|
||||
edgecolor = 'none'
|
||||
size = 50
|
||||
|
||||
mask = np.array(labels) == label
|
||||
scatter = ax.scatter(
|
||||
np.array(norm_lat)[mask],
|
||||
np.array(norm_lon)[mask],
|
||||
np.array(alt)[mask],
|
||||
c=[color], alpha=0.7, s=size, edgecolors=edgecolor,
|
||||
label=f'类别 {label}' if label != -1 else '待重新分类 (-1)'
|
||||
)
|
||||
|
||||
# 标注图片名称(后10个字符)
|
||||
for i, (x, y, z, name) in enumerate(zip(norm_lat, norm_lon, alt, image_names)):
|
||||
# 只标注-1标签的数据点或每隔一定数量标注,避免过于拥挤
|
||||
if labels[i] == -1 : # 每隔5个点标注一个 or i % 5 == 0:
|
||||
ax.text(x, y, z, name, fontsize=16,
|
||||
bbox=dict(facecolor='white', alpha=0.5, boxstyle='round,pad=0.5'))
|
||||
|
||||
# 如果有聚类中心,则绘制
|
||||
if cluster_centers is not None:
|
||||
centers_x, centers_y, centers_z = zip(*cluster_centers)
|
||||
ax.scatter(
|
||||
centers_x, centers_y, centers_z,
|
||||
c='red', marker='X', s=200, edgecolors='black',
|
||||
label='聚类中心'
|
||||
)
|
||||
|
||||
# 设置坐标轴标签
|
||||
ax.set_xlabel('归一化纬度 (×100000)')
|
||||
ax.set_ylabel('归一化经度 (×100000)')
|
||||
ax.set_zlabel('高度')
|
||||
|
||||
# 添加标题和图例
|
||||
ax.set_title(f'{method_name} 聚类结果的3D分布')
|
||||
ax.legend()
|
||||
|
||||
# 调整视角,让标注更清晰
|
||||
ax.view_init(elev=30, azim=45)
|
||||
|
||||
# 保存图形
|
||||
plt.tight_layout()
|
||||
plt.savefig(f'clustering_results/{method_name}_clusters.jpg', dpi=300)
|
||||
plt.show()
|
||||
# plt.close()
|
||||
|
||||
|
||||
def process_images_with_method(image_paths, method, method_name):
|
||||
"""使用指定的聚类方法处理图像数据"""
|
||||
# 1. 读取数据
|
||||
data = []
|
||||
for path in image_paths:
|
||||
xmp = read_xmp_target(path)
|
||||
if not xmp:
|
||||
continue
|
||||
|
||||
try:
|
||||
lat = float(xmp['Xmp.drone-dji.GpsLatitude']) if xmp['Xmp.drone-dji.GpsLatitude'] else 0
|
||||
lon = float(xmp['Xmp.drone-dji.GpsLongitude']) if xmp['Xmp.drone-dji.GpsLongitude'] else 0
|
||||
alt = float(xmp['Xmp.drone-dji.AbsoluteAltitude']) if xmp['Xmp.drone-dji.AbsoluteAltitude'] else 0
|
||||
yaw = float(xmp['Xmp.drone-dji.GimbalYawDegree']) if xmp['Xmp.drone-dji.GimbalYawDegree'] else 0
|
||||
pitch = float(xmp['Xmp.drone-dji.GimbalPitchDegree']) if xmp['Xmp.drone-dji.GimbalPitchDegree'] else 0
|
||||
|
||||
data.append({
|
||||
'path': path,
|
||||
'lat': lat, 'lon': lon, 'alt': alt,
|
||||
'yaw': yaw, 'pitch': pitch,
|
||||
'lrf_lat': float(xmp['Xmp.drone-dji.LRFTargetLat']) if xmp['Xmp.drone-dji.LRFTargetLat'] else 0,
|
||||
'lrf_lon': float(xmp['Xmp.drone-dji.LRFTargetLon']) if xmp['Xmp.drone-dji.LRFTargetLon'] else 0,
|
||||
'lrf_alt': float(xmp['Xmp.drone-dji.LRFTargetAlt']) if xmp['Xmp.drone-dji.LRFTargetAlt'] else 0
|
||||
})
|
||||
except Exception as e:
|
||||
print(f"处理图像 {path} 时出错: {e}")
|
||||
continue
|
||||
|
||||
if not data:
|
||||
print("没有有效数据可处理")
|
||||
return None, None, None
|
||||
|
||||
# 2. 归一化GPS并考虑LRF数据
|
||||
# 首先收集所有可能用到的纬度和经度(包括LRF数据)用于计算最小值
|
||||
all_latitudes = []
|
||||
all_longitudes = []
|
||||
|
||||
for d in data:
|
||||
all_latitudes.append(d['lat'])
|
||||
all_latitudes.append(d['lrf_lat'])
|
||||
all_longitudes.append(d['lon'])
|
||||
all_longitudes.append(d['lrf_lon'])
|
||||
|
||||
lat_min, lon_min = min(all_latitudes), min(all_longitudes)
|
||||
distance_threshold = 10.0 # 距离阈值,单位:米(可根据实际情况调整)
|
||||
|
||||
for d in data:
|
||||
# 计算无人机位置与LRF目标点的距离(简化的欧氏距离计算)
|
||||
dx = (d['lrf_lon'] - d['lon']) * 111319 # 经度差转米 (1度≈111319米)
|
||||
dy = (d['lrf_lat'] - d['lat']) * 111319 # 纬度差转米
|
||||
distance = np.sqrt(dx ** 2 + dy ** 2)
|
||||
|
||||
# 根据距离决定使用哪种数据
|
||||
if distance < distance_threshold and d['lrf_lat'] != 0 and d['lrf_lon'] != 0:
|
||||
# 距离近且LRF数据有效,使用LRF数据
|
||||
d['norm_lat'] = (d['lrf_lat'] - lat_min) * 100000
|
||||
d['norm_lon'] = (d['lrf_lon'] - lon_min) * 100000
|
||||
d['used_data'] = 'LRF' # 标记使用了哪种数据
|
||||
# 使用LRF的高度
|
||||
d['clustering_alt'] = d['lrf_alt']
|
||||
else:
|
||||
# 距离远或LRF数据无效,使用原始GPS数据
|
||||
d['norm_lat'] = (d['lat'] - lat_min) * 100000
|
||||
d['norm_lon'] = (d['lon'] - lon_min) * 100000
|
||||
d['used_data'] = 'GPS' # 标记使用了哪种数据
|
||||
# 使用原始高度
|
||||
d['clustering_alt'] = d['alt']
|
||||
|
||||
# 3. 准备聚类数据,使用决定后的高度数据
|
||||
X = np.array([[d['norm_lat'], d['norm_lon'], d['clustering_alt']] for d in data])
|
||||
|
||||
# 4. 应用聚类算法
|
||||
labels = None
|
||||
cluster_centers = None
|
||||
|
||||
if isinstance(method, KMeans):
|
||||
method.fit(X)
|
||||
labels = method.labels_
|
||||
cluster_centers = method.cluster_centers_
|
||||
elif isinstance(method, DBSCAN):
|
||||
labels = method.fit_predict(X)
|
||||
# DBSCAN没有传统意义上的聚类中心,这里计算每个簇的质心作为参考
|
||||
unique_labels = np.unique(labels)
|
||||
cluster_centers = []
|
||||
for label in unique_labels:
|
||||
if label != -1: # 排除噪声点
|
||||
cluster_points = X[labels == label]
|
||||
cluster_centers.append(np.mean(cluster_points, axis=0))
|
||||
if not cluster_centers:
|
||||
cluster_centers = None
|
||||
elif isinstance(method, AgglomerativeClustering):
|
||||
labels = method.fit_predict(X)
|
||||
# 计算每个簇的质心作为参考
|
||||
unique_labels = np.unique(labels)
|
||||
cluster_centers = []
|
||||
for label in unique_labels:
|
||||
cluster_points = X[labels == label]
|
||||
cluster_centers.append(np.mean(cluster_points, axis=0))
|
||||
elif isinstance(method, GaussianMixture):
|
||||
labels = method.fit_predict(X)
|
||||
cluster_centers = method.means_
|
||||
|
||||
# 5. 时序突变检测
|
||||
vec = []
|
||||
if len(data) > 1:
|
||||
for i in range(1, len(data)):
|
||||
prev, curr = data[i - 1], data[i]
|
||||
yaw_change = abs(curr['yaw'] - prev['yaw'])
|
||||
yaw_change = min(yaw_change, 360 - yaw_change) # 处理跨0°
|
||||
pitch_change = abs(curr['pitch'] - prev['pitch'])
|
||||
pos_change = np.linalg.norm([
|
||||
curr['norm_lat'] - prev['norm_lat'],
|
||||
curr['norm_lon'] - prev['norm_lon'],
|
||||
curr['clustering_alt'] - prev['clustering_alt']
|
||||
])
|
||||
vec.append([
|
||||
curr['norm_lat'] - prev['norm_lat'],
|
||||
curr['norm_lon'] - prev['norm_lon'],
|
||||
curr['clustering_alt'] - prev['clustering_alt'],
|
||||
yaw_change,
|
||||
pitch_change
|
||||
])
|
||||
|
||||
# 若姿态突变但位置变化小→换面(保持标签)
|
||||
# 若两者突变→换叶片(标记需重新检查)
|
||||
if (yaw_change > 10 or pitch_change > 5) and pos_change > 10:
|
||||
labels[i] = -1 # 待重新分类
|
||||
|
||||
# 6. 修复标签
|
||||
if labels is not None:
|
||||
labels = repair_neighbors_of_minus1(labels)
|
||||
|
||||
# 7. 可视化聚类结果
|
||||
visualize_clusters(data, labels, cluster_centers, method_name)
|
||||
|
||||
return labels, vec, data
|
||||
|
||||
|
||||
def compare_clustering_methods(image_paths):
|
||||
"""比较多种聚类方法"""
|
||||
# 定义要比较的聚类方法
|
||||
n_clusters = 3 # 已知目标有3类
|
||||
clustering_methods = {
|
||||
'KMeans': KMeans(n_clusters=n_clusters, random_state=0),
|
||||
'DBSCAN': DBSCAN(eps=5, min_samples=3), # 参数可根据实际数据调整
|
||||
'Agglomerative': AgglomerativeClustering(n_clusters=n_clusters),
|
||||
'GaussianMixture': GaussianMixture(n_components=n_clusters, random_state=0)
|
||||
}
|
||||
|
||||
results = {}
|
||||
|
||||
# 对每种方法执行聚类
|
||||
for name, method in clustering_methods.items():
|
||||
print(f"正在使用 {name} 进行聚类...")
|
||||
labels, vec, data = process_images_with_method(image_paths, method, name)
|
||||
if labels is not None:
|
||||
results[name] = {
|
||||
'labels': labels,
|
||||
'vec': vec,
|
||||
'data': data
|
||||
}
|
||||
print(f"{name} 聚类完成\n")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# 获取图像路径
|
||||
image_dir = r'D:\4.work\study\classify\img'
|
||||
image_paths = glob.glob(os.path.join(image_dir, '*.jpg'))
|
||||
|
||||
if not image_paths:
|
||||
print(f"在目录 {image_dir} 中未找到任何JPG图像")
|
||||
else:
|
||||
print(f"找到 {len(image_paths)} 张图像,开始聚类比较...")
|
||||
results = compare_clustering_methods(image_paths)
|
||||
|
||||
# 输出每种方法的聚类结果
|
||||
for method_name, result in results.items():
|
||||
print(f"\n{method_name} 聚类结果:")
|
||||
output_file = f'clustering_results/{method_name}_labels.txt'
|
||||
with open(output_file, 'w', encoding='utf-8') as f:
|
||||
for i, (path, label) in enumerate(zip(
|
||||
[d['path'] for d in result['data']],
|
||||
result['labels'])):
|
||||
line = f"图像: {os.path.basename(path)}, 类别: {label}, 使用数据: {result['data'][i]['used_data']}"
|
||||
if i < len(result['vec']):
|
||||
line += f", 变化向量: {result['vec'][i]}"
|
||||
print(line)
|
||||
f.write(line + '\n')
|
||||
print("\n所有聚类结果已保存到 clustering_results 目录")
|
Loading…
Reference in New Issue