347 lines
13 KiB
Python
347 lines
13 KiB
Python
import glob
|
||
import os
|
||
from collections import Counter
|
||
|
||
import matplotlib
|
||
import numpy as np
|
||
from sklearn.cluster import KMeans, DBSCAN, AgglomerativeClustering
|
||
from sklearn.mixture import GaussianMixture
|
||
from pyexiv2 import Image
|
||
|
||
# 设置中文字体,解决显示问题
|
||
matplotlib.rcParams["font.family"] = ["SimHei", ]
|
||
matplotlib.rcParams['axes.unicode_minus'] = False # 正确显示负号
|
||
# matplotlib.use('Agg') # 仅保存图片不显示,需配合 plt.savefig()
|
||
import matplotlib.pyplot as plt
|
||
|
||
# 确保输出目录存在
|
||
os.makedirs('clustering_results', exist_ok=True)
|
||
|
||
|
||
def read_xmp_target(p):
|
||
"""读取图像的XMP元数据"""
|
||
try:
|
||
img = Image(p)
|
||
xmp = img.read_xmp()
|
||
|
||
# 定义需要提取的XMP标签
|
||
xmp_tags = {
|
||
'Xmp.drone-dji.GpsLatitude': None,
|
||
'Xmp.drone-dji.GpsLongitude': None,
|
||
'Xmp.drone-dji.AbsoluteAltitude': None,
|
||
'Xmp.drone-dji.GimbalRollDegree': None,
|
||
'Xmp.drone-dji.GimbalYawDegree': None,
|
||
'Xmp.drone-dji.GimbalPitchDegree': None,
|
||
'Xmp.drone-dji.FlightRollDegree': None,
|
||
'Xmp.drone-dji.FlightYawDegree': None,
|
||
'Xmp.drone-dji.FlightPitchDegree': None,
|
||
'Xmp.drone-dji.LRFTargetDistance': None,
|
||
'Xmp.drone-dji.LRFTargetLon': None,
|
||
'Xmp.drone-dji.LRFTargetLat': None,
|
||
'Xmp.drone-dji.LRFTargetAlt': None,
|
||
'Xmp.drone-dji.LRFTargetAbsAlt': None,
|
||
}
|
||
|
||
# 提取标签值
|
||
for key in xmp_tags.keys():
|
||
if key in xmp:
|
||
xmp_tags[key] = xmp[key]
|
||
else:
|
||
print(f"警告: 图像 {p} 中未找到标签 {key}")
|
||
|
||
return xmp_tags
|
||
except Exception as e:
|
||
print(f"读取图像 {p} 的XMP数据时出错: {e}")
|
||
return None
|
||
|
||
|
||
def repair_neighbors_of_minus1(labels):
|
||
"""修复标签中的-1值,使用相邻区域的多数标签填充"""
|
||
out = np.array(labels, dtype=int)
|
||
n = len(out)
|
||
|
||
# 找到所有 -1 的下标
|
||
ind = np.where(out == -1)[0]
|
||
|
||
# 把区间切成段:[0, ind[0]), [ind[0]+1, ind[1]), ..., [ind[-1]+1, n)
|
||
split_indices = [0] + (ind + 1).tolist() + [n]
|
||
|
||
for k in range(len(split_indices) - 1):
|
||
start = split_indices[k]
|
||
end = split_indices[k + 1]
|
||
seg = out[start:end]
|
||
|
||
# 只统计非 -1 的标签
|
||
votes = seg[seg != -1]
|
||
if len(votes) == 0:
|
||
continue
|
||
winner = Counter(votes).most_common(1)[0][0]
|
||
|
||
# 把区间内非 -1 标签统一成 winner
|
||
out[start:end] = np.where(seg == -1, -1, winner)
|
||
|
||
return out.tolist()
|
||
|
||
|
||
def visualize_clusters(data, labels, cluster_centers, method_name):
|
||
"""可视化数据点和聚类中心,包含-1标签和图片名称标注"""
|
||
# 提取需要可视化的数据
|
||
norm_lat = [d['norm_lat'] for d in data]
|
||
norm_lon = [d['norm_lon'] for d in data]
|
||
alt = [d['clustering_alt'] for d in data] # 使用聚类时实际用的高度
|
||
image_names = [os.path.basename(d['path'])[-10:] for d in data] # 取图片名后10个字符
|
||
|
||
# 创建3D图形
|
||
fig = plt.figure(figsize=(12, 10))
|
||
ax = fig.add_subplot(111, projection='3d')
|
||
|
||
# 绘制数据点(使用标签区分颜色)
|
||
unique_labels = np.unique(labels)
|
||
colors = plt.cm.Spectral(np.linspace(0, 1, len(unique_labels)))
|
||
|
||
for label, color in zip(unique_labels, colors):
|
||
if label == -1:
|
||
# -1标签点用黄色带黑色边缘突出显示
|
||
color = [1, 1, 0, 0.8] # 黄色半透明
|
||
edgecolor = 'black'
|
||
size = 80
|
||
else:
|
||
edgecolor = 'none'
|
||
size = 50
|
||
|
||
mask = np.array(labels) == label
|
||
scatter = ax.scatter(
|
||
np.array(norm_lat)[mask],
|
||
np.array(norm_lon)[mask],
|
||
np.array(alt)[mask],
|
||
c=[color], alpha=0.7, s=size, edgecolors=edgecolor,
|
||
label=f'类别 {label}' if label != -1 else '待重新分类 (-1)'
|
||
)
|
||
|
||
# 标注图片名称(后10个字符)
|
||
for i, (x, y, z, name) in enumerate(zip(norm_lat, norm_lon, alt, image_names)):
|
||
# 只标注-1标签的数据点或每隔一定数量标注,避免过于拥挤
|
||
if labels[i] == -1 : # 每隔5个点标注一个 or i % 5 == 0:
|
||
ax.text(x, y, z, name, fontsize=16,
|
||
bbox=dict(facecolor='white', alpha=0.5, boxstyle='round,pad=0.5'))
|
||
|
||
# 如果有聚类中心,则绘制
|
||
if cluster_centers is not None:
|
||
centers_x, centers_y, centers_z = zip(*cluster_centers)
|
||
ax.scatter(
|
||
centers_x, centers_y, centers_z,
|
||
c='red', marker='X', s=200, edgecolors='black',
|
||
label='聚类中心'
|
||
)
|
||
|
||
# 设置坐标轴标签
|
||
ax.set_xlabel('归一化纬度 (×100000)')
|
||
ax.set_ylabel('归一化经度 (×100000)')
|
||
ax.set_zlabel('高度')
|
||
|
||
# 添加标题和图例
|
||
ax.set_title(f'{method_name} 聚类结果的3D分布')
|
||
ax.legend()
|
||
|
||
# 调整视角,让标注更清晰
|
||
ax.view_init(elev=30, azim=45)
|
||
|
||
# 保存图形
|
||
plt.tight_layout()
|
||
plt.savefig(f'clustering_results/{method_name}_clusters.jpg', dpi=300)
|
||
plt.show()
|
||
# plt.close()
|
||
|
||
|
||
def process_images_with_method(image_paths, method, method_name):
|
||
"""使用指定的聚类方法处理图像数据"""
|
||
# 1. 读取数据
|
||
data = []
|
||
for path in image_paths:
|
||
xmp = read_xmp_target(path)
|
||
if not xmp:
|
||
continue
|
||
|
||
try:
|
||
lat = float(xmp['Xmp.drone-dji.GpsLatitude']) if xmp['Xmp.drone-dji.GpsLatitude'] else 0
|
||
lon = float(xmp['Xmp.drone-dji.GpsLongitude']) if xmp['Xmp.drone-dji.GpsLongitude'] else 0
|
||
alt = float(xmp['Xmp.drone-dji.AbsoluteAltitude']) if xmp['Xmp.drone-dji.AbsoluteAltitude'] else 0
|
||
yaw = float(xmp['Xmp.drone-dji.GimbalYawDegree']) if xmp['Xmp.drone-dji.GimbalYawDegree'] else 0
|
||
pitch = float(xmp['Xmp.drone-dji.GimbalPitchDegree']) if xmp['Xmp.drone-dji.GimbalPitchDegree'] else 0
|
||
|
||
data.append({
|
||
'path': path,
|
||
'lat': lat, 'lon': lon, 'alt': alt,
|
||
'yaw': yaw, 'pitch': pitch,
|
||
'lrf_lat': float(xmp['Xmp.drone-dji.LRFTargetLat']) if xmp['Xmp.drone-dji.LRFTargetLat'] else 0,
|
||
'lrf_lon': float(xmp['Xmp.drone-dji.LRFTargetLon']) if xmp['Xmp.drone-dji.LRFTargetLon'] else 0,
|
||
'lrf_alt': float(xmp['Xmp.drone-dji.LRFTargetAlt']) if xmp['Xmp.drone-dji.LRFTargetAlt'] else 0
|
||
})
|
||
except Exception as e:
|
||
print(f"处理图像 {path} 时出错: {e}")
|
||
continue
|
||
|
||
if not data:
|
||
print("没有有效数据可处理")
|
||
return None, None, None
|
||
|
||
# 2. 归一化GPS并考虑LRF数据
|
||
# 首先收集所有可能用到的纬度和经度(包括LRF数据)用于计算最小值
|
||
all_latitudes = []
|
||
all_longitudes = []
|
||
|
||
for d in data:
|
||
all_latitudes.append(d['lat'])
|
||
all_latitudes.append(d['lrf_lat'])
|
||
all_longitudes.append(d['lon'])
|
||
all_longitudes.append(d['lrf_lon'])
|
||
|
||
lat_min, lon_min = min(all_latitudes), min(all_longitudes)
|
||
distance_threshold = 10.0 # 距离阈值,单位:米(可根据实际情况调整)
|
||
|
||
for d in data:
|
||
# 计算无人机位置与LRF目标点的距离(简化的欧氏距离计算)
|
||
dx = (d['lrf_lon'] - d['lon']) * 111319 # 经度差转米 (1度≈111319米)
|
||
dy = (d['lrf_lat'] - d['lat']) * 111319 # 纬度差转米
|
||
distance = np.sqrt(dx ** 2 + dy ** 2)
|
||
|
||
# 根据距离决定使用哪种数据
|
||
if distance < distance_threshold and d['lrf_lat'] != 0 and d['lrf_lon'] != 0:
|
||
# 距离近且LRF数据有效,使用LRF数据
|
||
d['norm_lat'] = (d['lrf_lat'] - lat_min) * 100000
|
||
d['norm_lon'] = (d['lrf_lon'] - lon_min) * 100000
|
||
d['used_data'] = 'LRF' # 标记使用了哪种数据
|
||
# 使用LRF的高度
|
||
d['clustering_alt'] = d['lrf_alt']
|
||
else:
|
||
# 距离远或LRF数据无效,使用原始GPS数据
|
||
d['norm_lat'] = (d['lat'] - lat_min) * 100000
|
||
d['norm_lon'] = (d['lon'] - lon_min) * 100000
|
||
d['used_data'] = 'GPS' # 标记使用了哪种数据
|
||
# 使用原始高度
|
||
d['clustering_alt'] = d['alt']
|
||
|
||
# 3. 准备聚类数据,使用决定后的高度数据
|
||
X = np.array([[d['norm_lat'], d['norm_lon'], d['clustering_alt']] for d in data])
|
||
|
||
# 4. 应用聚类算法
|
||
labels = None
|
||
cluster_centers = None
|
||
|
||
if isinstance(method, KMeans):
|
||
method.fit(X)
|
||
labels = method.labels_
|
||
cluster_centers = method.cluster_centers_
|
||
elif isinstance(method, DBSCAN):
|
||
labels = method.fit_predict(X)
|
||
# DBSCAN没有传统意义上的聚类中心,这里计算每个簇的质心作为参考
|
||
unique_labels = np.unique(labels)
|
||
cluster_centers = []
|
||
for label in unique_labels:
|
||
if label != -1: # 排除噪声点
|
||
cluster_points = X[labels == label]
|
||
cluster_centers.append(np.mean(cluster_points, axis=0))
|
||
if not cluster_centers:
|
||
cluster_centers = None
|
||
elif isinstance(method, AgglomerativeClustering):
|
||
labels = method.fit_predict(X)
|
||
# 计算每个簇的质心作为参考
|
||
unique_labels = np.unique(labels)
|
||
cluster_centers = []
|
||
for label in unique_labels:
|
||
cluster_points = X[labels == label]
|
||
cluster_centers.append(np.mean(cluster_points, axis=0))
|
||
elif isinstance(method, GaussianMixture):
|
||
labels = method.fit_predict(X)
|
||
cluster_centers = method.means_
|
||
|
||
# 5. 时序突变检测
|
||
vec = []
|
||
if len(data) > 1:
|
||
for i in range(1, len(data)):
|
||
prev, curr = data[i - 1], data[i]
|
||
yaw_change = abs(curr['yaw'] - prev['yaw'])
|
||
yaw_change = min(yaw_change, 360 - yaw_change) # 处理跨0°
|
||
pitch_change = abs(curr['pitch'] - prev['pitch'])
|
||
pos_change = np.linalg.norm([
|
||
curr['norm_lat'] - prev['norm_lat'],
|
||
curr['norm_lon'] - prev['norm_lon'],
|
||
curr['clustering_alt'] - prev['clustering_alt']
|
||
])
|
||
vec.append([
|
||
curr['norm_lat'] - prev['norm_lat'],
|
||
curr['norm_lon'] - prev['norm_lon'],
|
||
curr['clustering_alt'] - prev['clustering_alt'],
|
||
yaw_change,
|
||
pitch_change
|
||
])
|
||
|
||
# 若姿态突变但位置变化小→换面(保持标签)
|
||
# 若两者突变→换叶片(标记需重新检查)
|
||
if (yaw_change > 10 or pitch_change > 5) and pos_change > 10:
|
||
labels[i] = -1 # 待重新分类
|
||
|
||
# 6. 修复标签
|
||
if labels is not None:
|
||
labels = repair_neighbors_of_minus1(labels)
|
||
|
||
# 7. 可视化聚类结果
|
||
visualize_clusters(data, labels, cluster_centers, method_name)
|
||
|
||
return labels, vec, data
|
||
|
||
|
||
def compare_clustering_methods(image_paths):
|
||
"""比较多种聚类方法"""
|
||
# 定义要比较的聚类方法
|
||
n_clusters = 3 # 已知目标有3类
|
||
clustering_methods = {
|
||
'KMeans': KMeans(n_clusters=n_clusters, random_state=0),
|
||
'DBSCAN': DBSCAN(eps=5, min_samples=3), # 参数可根据实际数据调整
|
||
'Agglomerative': AgglomerativeClustering(n_clusters=n_clusters),
|
||
'GaussianMixture': GaussianMixture(n_components=n_clusters, random_state=0)
|
||
}
|
||
|
||
results = {}
|
||
|
||
# 对每种方法执行聚类
|
||
for name, method in clustering_methods.items():
|
||
print(f"正在使用 {name} 进行聚类...")
|
||
labels, vec, data = process_images_with_method(image_paths, method, name)
|
||
if labels is not None:
|
||
results[name] = {
|
||
'labels': labels,
|
||
'vec': vec,
|
||
'data': data
|
||
}
|
||
print(f"{name} 聚类完成\n")
|
||
|
||
return results
|
||
|
||
|
||
if __name__ == '__main__':
|
||
# 获取图像路径
|
||
image_dir = r'D:\4.work\study\classify\img'
|
||
image_paths = glob.glob(os.path.join(image_dir, '*.jpg'))
|
||
|
||
if not image_paths:
|
||
print(f"在目录 {image_dir} 中未找到任何JPG图像")
|
||
else:
|
||
print(f"找到 {len(image_paths)} 张图像,开始聚类比较...")
|
||
results = compare_clustering_methods(image_paths)
|
||
|
||
# 输出每种方法的聚类结果
|
||
for method_name, result in results.items():
|
||
print(f"\n{method_name} 聚类结果:")
|
||
output_file = f'clustering_results/{method_name}_labels.txt'
|
||
with open(output_file, 'w', encoding='utf-8') as f:
|
||
for i, (path, label) in enumerate(zip(
|
||
[d['path'] for d in result['data']],
|
||
result['labels'])):
|
||
line = f"图像: {os.path.basename(path)}, 类别: {label}, 使用数据: {result['data'][i]['used_data']}"
|
||
if i < len(result['vec']):
|
||
line += f", 变化向量: {result['vec'][i]}"
|
||
print(line)
|
||
f.write(line + '\n')
|
||
print("\n所有聚类结果已保存到 clustering_results 目录")
|