AI_Agent/classify_v3.py

347 lines
13 KiB
Python
Raw Normal View History

2025-08-01 18:03:56 +08:00
import glob
import os
from collections import Counter
import matplotlib
import numpy as np
from sklearn.cluster import KMeans, DBSCAN, AgglomerativeClustering
from sklearn.mixture import GaussianMixture
from pyexiv2 import Image
# 设置中文字体,解决显示问题
matplotlib.rcParams["font.family"] = ["SimHei", ]
matplotlib.rcParams['axes.unicode_minus'] = False # 正确显示负号
# matplotlib.use('Agg') # 仅保存图片不显示,需配合 plt.savefig()
import matplotlib.pyplot as plt
# 确保输出目录存在
os.makedirs('clustering_results', exist_ok=True)
def read_xmp_target(p):
"""读取图像的XMP元数据"""
try:
img = Image(p)
xmp = img.read_xmp()
# 定义需要提取的XMP标签
xmp_tags = {
'Xmp.drone-dji.GpsLatitude': None,
'Xmp.drone-dji.GpsLongitude': None,
'Xmp.drone-dji.AbsoluteAltitude': None,
'Xmp.drone-dji.GimbalRollDegree': None,
'Xmp.drone-dji.GimbalYawDegree': None,
'Xmp.drone-dji.GimbalPitchDegree': None,
'Xmp.drone-dji.FlightRollDegree': None,
'Xmp.drone-dji.FlightYawDegree': None,
'Xmp.drone-dji.FlightPitchDegree': None,
'Xmp.drone-dji.LRFTargetDistance': None,
'Xmp.drone-dji.LRFTargetLon': None,
'Xmp.drone-dji.LRFTargetLat': None,
'Xmp.drone-dji.LRFTargetAlt': None,
'Xmp.drone-dji.LRFTargetAbsAlt': None,
}
# 提取标签值
for key in xmp_tags.keys():
if key in xmp:
xmp_tags[key] = xmp[key]
else:
print(f"警告: 图像 {p} 中未找到标签 {key}")
return xmp_tags
except Exception as e:
print(f"读取图像 {p} 的XMP数据时出错: {e}")
return None
def repair_neighbors_of_minus1(labels):
"""修复标签中的-1值使用相邻区域的多数标签填充"""
out = np.array(labels, dtype=int)
n = len(out)
# 找到所有 -1 的下标
ind = np.where(out == -1)[0]
# 把区间切成段:[0, ind[0]), [ind[0]+1, ind[1]), ..., [ind[-1]+1, n)
split_indices = [0] + (ind + 1).tolist() + [n]
for k in range(len(split_indices) - 1):
start = split_indices[k]
end = split_indices[k + 1]
seg = out[start:end]
# 只统计非 -1 的标签
votes = seg[seg != -1]
if len(votes) == 0:
continue
winner = Counter(votes).most_common(1)[0][0]
# 把区间内非 -1 标签统一成 winner
out[start:end] = np.where(seg == -1, -1, winner)
return out.tolist()
def visualize_clusters(data, labels, cluster_centers, method_name):
"""可视化数据点和聚类中心,包含-1标签和图片名称标注"""
# 提取需要可视化的数据
norm_lat = [d['norm_lat'] for d in data]
norm_lon = [d['norm_lon'] for d in data]
alt = [d['clustering_alt'] for d in data] # 使用聚类时实际用的高度
image_names = [os.path.basename(d['path'])[-10:] for d in data] # 取图片名后10个字符
# 创建3D图形
fig = plt.figure(figsize=(12, 10))
ax = fig.add_subplot(111, projection='3d')
# 绘制数据点(使用标签区分颜色)
unique_labels = np.unique(labels)
colors = plt.cm.Spectral(np.linspace(0, 1, len(unique_labels)))
for label, color in zip(unique_labels, colors):
if label == -1:
# -1标签点用黄色带黑色边缘突出显示
color = [1, 1, 0, 0.8] # 黄色半透明
edgecolor = 'black'
size = 80
else:
edgecolor = 'none'
size = 50
mask = np.array(labels) == label
scatter = ax.scatter(
np.array(norm_lat)[mask],
np.array(norm_lon)[mask],
np.array(alt)[mask],
c=[color], alpha=0.7, s=size, edgecolors=edgecolor,
label=f'类别 {label}' if label != -1 else '待重新分类 (-1)'
)
# 标注图片名称后10个字符
for i, (x, y, z, name) in enumerate(zip(norm_lat, norm_lon, alt, image_names)):
# 只标注-1标签的数据点或每隔一定数量标注避免过于拥挤
if labels[i] == -1 : # 每隔5个点标注一个 or i % 5 == 0:
ax.text(x, y, z, name, fontsize=16,
bbox=dict(facecolor='white', alpha=0.5, boxstyle='round,pad=0.5'))
# 如果有聚类中心,则绘制
if cluster_centers is not None:
centers_x, centers_y, centers_z = zip(*cluster_centers)
ax.scatter(
centers_x, centers_y, centers_z,
c='red', marker='X', s=200, edgecolors='black',
label='聚类中心'
)
# 设置坐标轴标签
ax.set_xlabel('归一化纬度 (×100000)')
ax.set_ylabel('归一化经度 (×100000)')
ax.set_zlabel('高度')
# 添加标题和图例
ax.set_title(f'{method_name} 聚类结果的3D分布')
ax.legend()
# 调整视角,让标注更清晰
ax.view_init(elev=30, azim=45)
# 保存图形
plt.tight_layout()
plt.savefig(f'clustering_results/{method_name}_clusters.jpg', dpi=300)
plt.show()
# plt.close()
def process_images_with_method(image_paths, method, method_name):
"""使用指定的聚类方法处理图像数据"""
# 1. 读取数据
data = []
for path in image_paths:
xmp = read_xmp_target(path)
if not xmp:
continue
try:
lat = float(xmp['Xmp.drone-dji.GpsLatitude']) if xmp['Xmp.drone-dji.GpsLatitude'] else 0
lon = float(xmp['Xmp.drone-dji.GpsLongitude']) if xmp['Xmp.drone-dji.GpsLongitude'] else 0
alt = float(xmp['Xmp.drone-dji.AbsoluteAltitude']) if xmp['Xmp.drone-dji.AbsoluteAltitude'] else 0
yaw = float(xmp['Xmp.drone-dji.GimbalYawDegree']) if xmp['Xmp.drone-dji.GimbalYawDegree'] else 0
pitch = float(xmp['Xmp.drone-dji.GimbalPitchDegree']) if xmp['Xmp.drone-dji.GimbalPitchDegree'] else 0
data.append({
'path': path,
'lat': lat, 'lon': lon, 'alt': alt,
'yaw': yaw, 'pitch': pitch,
'lrf_lat': float(xmp['Xmp.drone-dji.LRFTargetLat']) if xmp['Xmp.drone-dji.LRFTargetLat'] else 0,
'lrf_lon': float(xmp['Xmp.drone-dji.LRFTargetLon']) if xmp['Xmp.drone-dji.LRFTargetLon'] else 0,
'lrf_alt': float(xmp['Xmp.drone-dji.LRFTargetAlt']) if xmp['Xmp.drone-dji.LRFTargetAlt'] else 0
})
except Exception as e:
print(f"处理图像 {path} 时出错: {e}")
continue
if not data:
print("没有有效数据可处理")
return None, None, None
# 2. 归一化GPS并考虑LRF数据
# 首先收集所有可能用到的纬度和经度包括LRF数据用于计算最小值
all_latitudes = []
all_longitudes = []
for d in data:
all_latitudes.append(d['lat'])
all_latitudes.append(d['lrf_lat'])
all_longitudes.append(d['lon'])
all_longitudes.append(d['lrf_lon'])
lat_min, lon_min = min(all_latitudes), min(all_longitudes)
distance_threshold = 10.0 # 距离阈值,单位:米(可根据实际情况调整)
for d in data:
# 计算无人机位置与LRF目标点的距离简化的欧氏距离计算
dx = (d['lrf_lon'] - d['lon']) * 111319 # 经度差转米 (1度≈111319米)
dy = (d['lrf_lat'] - d['lat']) * 111319 # 纬度差转米
distance = np.sqrt(dx ** 2 + dy ** 2)
# 根据距离决定使用哪种数据
if distance < distance_threshold and d['lrf_lat'] != 0 and d['lrf_lon'] != 0:
# 距离近且LRF数据有效使用LRF数据
d['norm_lat'] = (d['lrf_lat'] - lat_min) * 100000
d['norm_lon'] = (d['lrf_lon'] - lon_min) * 100000
d['used_data'] = 'LRF' # 标记使用了哪种数据
# 使用LRF的高度
d['clustering_alt'] = d['lrf_alt']
else:
# 距离远或LRF数据无效使用原始GPS数据
d['norm_lat'] = (d['lat'] - lat_min) * 100000
d['norm_lon'] = (d['lon'] - lon_min) * 100000
d['used_data'] = 'GPS' # 标记使用了哪种数据
# 使用原始高度
d['clustering_alt'] = d['alt']
# 3. 准备聚类数据,使用决定后的高度数据
X = np.array([[d['norm_lat'], d['norm_lon'], d['clustering_alt']] for d in data])
# 4. 应用聚类算法
labels = None
cluster_centers = None
if isinstance(method, KMeans):
method.fit(X)
labels = method.labels_
cluster_centers = method.cluster_centers_
elif isinstance(method, DBSCAN):
labels = method.fit_predict(X)
# DBSCAN没有传统意义上的聚类中心这里计算每个簇的质心作为参考
unique_labels = np.unique(labels)
cluster_centers = []
for label in unique_labels:
if label != -1: # 排除噪声点
cluster_points = X[labels == label]
cluster_centers.append(np.mean(cluster_points, axis=0))
if not cluster_centers:
cluster_centers = None
elif isinstance(method, AgglomerativeClustering):
labels = method.fit_predict(X)
# 计算每个簇的质心作为参考
unique_labels = np.unique(labels)
cluster_centers = []
for label in unique_labels:
cluster_points = X[labels == label]
cluster_centers.append(np.mean(cluster_points, axis=0))
elif isinstance(method, GaussianMixture):
labels = method.fit_predict(X)
cluster_centers = method.means_
# 5. 时序突变检测
vec = []
if len(data) > 1:
for i in range(1, len(data)):
prev, curr = data[i - 1], data[i]
yaw_change = abs(curr['yaw'] - prev['yaw'])
yaw_change = min(yaw_change, 360 - yaw_change) # 处理跨0°
pitch_change = abs(curr['pitch'] - prev['pitch'])
pos_change = np.linalg.norm([
curr['norm_lat'] - prev['norm_lat'],
curr['norm_lon'] - prev['norm_lon'],
curr['clustering_alt'] - prev['clustering_alt']
])
vec.append([
curr['norm_lat'] - prev['norm_lat'],
curr['norm_lon'] - prev['norm_lon'],
curr['clustering_alt'] - prev['clustering_alt'],
yaw_change,
pitch_change
])
# 若姿态突变但位置变化小→换面(保持标签)
# 若两者突变→换叶片(标记需重新检查)
if (yaw_change > 10 or pitch_change > 5) and pos_change > 10:
labels[i] = -1 # 待重新分类
# 6. 修复标签
if labels is not None:
labels = repair_neighbors_of_minus1(labels)
# 7. 可视化聚类结果
visualize_clusters(data, labels, cluster_centers, method_name)
return labels, vec, data
def compare_clustering_methods(image_paths):
"""比较多种聚类方法"""
# 定义要比较的聚类方法
n_clusters = 3 # 已知目标有3类
clustering_methods = {
'KMeans': KMeans(n_clusters=n_clusters, random_state=0),
'DBSCAN': DBSCAN(eps=5, min_samples=3), # 参数可根据实际数据调整
'Agglomerative': AgglomerativeClustering(n_clusters=n_clusters),
'GaussianMixture': GaussianMixture(n_components=n_clusters, random_state=0)
}
results = {}
# 对每种方法执行聚类
for name, method in clustering_methods.items():
print(f"正在使用 {name} 进行聚类...")
labels, vec, data = process_images_with_method(image_paths, method, name)
if labels is not None:
results[name] = {
'labels': labels,
'vec': vec,
'data': data
}
print(f"{name} 聚类完成\n")
return results
if __name__ == '__main__':
# 获取图像路径
image_dir = r'D:\4.work\study\classify\img'
image_paths = glob.glob(os.path.join(image_dir, '*.jpg'))
if not image_paths:
print(f"在目录 {image_dir} 中未找到任何JPG图像")
else:
print(f"找到 {len(image_paths)} 张图像,开始聚类比较...")
results = compare_clustering_methods(image_paths)
# 输出每种方法的聚类结果
for method_name, result in results.items():
print(f"\n{method_name} 聚类结果:")
output_file = f'clustering_results/{method_name}_labels.txt'
with open(output_file, 'w', encoding='utf-8') as f:
for i, (path, label) in enumerate(zip(
[d['path'] for d in result['data']],
result['labels'])):
line = f"图像: {os.path.basename(path)}, 类别: {label}, 使用数据: {result['data'][i]['used_data']}"
if i < len(result['vec']):
line += f", 变化向量: {result['vec'][i]}"
print(line)
f.write(line + '\n')
print("\n所有聚类结果已保存到 clustering_results 目录")