AI_Agent/classify_v3.py

347 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import glob
import os
from collections import Counter
import matplotlib
import numpy as np
from sklearn.cluster import KMeans, DBSCAN, AgglomerativeClustering
from sklearn.mixture import GaussianMixture
from pyexiv2 import Image
# Use a CJK-capable font so the Chinese titles, axis labels and legend text
# render, and disable the Unicode minus glyph so negative tick labels display
# correctly with that font.
matplotlib.rcParams.update({
    "font.family": ["SimHei"],
    "axes.unicode_minus": False,
})
# To only save figures without opening a window, enable the non-interactive
# backend below and rely on savefig():
# matplotlib.use('Agg')
import matplotlib.pyplot as plt

# Every figure and label file produced by this script is written here.
os.makedirs('clustering_results', exist_ok=True)
def read_xmp_target(p):
    """Read the DJI XMP tags used by this script from the image at ``p``.

    Parameters
    ----------
    p : str
        Path to the image file.

    Returns
    -------
    dict or None
        A dict mapping each requested ``Xmp.drone-dji.*`` tag to its raw
        value. A tag that is absent maps to ``None``, and a warning is
        printed for it. Returns ``None`` (after printing the error) when the
        image or its XMP packet cannot be read.
    """
    tag_names = (
        'Xmp.drone-dji.GpsLatitude',
        'Xmp.drone-dji.GpsLongitude',
        'Xmp.drone-dji.AbsoluteAltitude',
        'Xmp.drone-dji.GimbalRollDegree',
        'Xmp.drone-dji.GimbalYawDegree',
        'Xmp.drone-dji.GimbalPitchDegree',
        'Xmp.drone-dji.FlightRollDegree',
        'Xmp.drone-dji.FlightYawDegree',
        'Xmp.drone-dji.FlightPitchDegree',
        'Xmp.drone-dji.LRFTargetDistance',
        'Xmp.drone-dji.LRFTargetLon',
        'Xmp.drone-dji.LRFTargetLat',
        'Xmp.drone-dji.LRFTargetAlt',
        'Xmp.drone-dji.LRFTargetAbsAlt',
    )
    img = None
    try:
        img = Image(p)
        xmp = img.read_xmp()
        xmp_tags = {}
        for key in tag_names:
            if key in xmp:
                xmp_tags[key] = xmp[key]
            else:
                print(f"警告: 图像 {p} 中未找到标签 {key}")
                xmp_tags[key] = None
        return xmp_tags
    except Exception as e:
        # Best-effort reader: any failure skips this image instead of
        # aborting the whole batch.
        print(f"读取图像 {p} 的XMP数据时出错: {e}")
        return None
    finally:
        # pyexiv2 Image objects must be closed explicitly to release the
        # underlying image; previously every opened image was leaked.
        if img is not None:
            img.close()
def repair_neighbors_of_minus1(labels):
    """Smooth labels by majority vote inside each run delimited by -1 markers.

    The -1 entries act as separators. The sequence is cut right after every
    -1, and within each resulting run all entries that are not -1 are
    replaced by the most common non--1 label of that run. Ties go to the
    label seen first. The -1 entries themselves are left as -1 so callers
    can still identify them.

    Parameters
    ----------
    labels : sequence of int
        Cluster labels in sequence order; -1 marks a break point.

    Returns
    -------
    list of int
        A new list of the smoothed labels. The input is not modified.
    """
    result = np.array(labels, dtype=int)  # always a copy of the input
    breaks = np.flatnonzero(result == -1)
    run_starts = np.concatenate(([0], breaks + 1))
    run_stops = np.concatenate((breaks + 1, [len(result)]))
    for lo, hi in zip(run_starts, run_stops):
        run = result[lo:hi]  # view: writes go straight into result
        voters = run != -1
        if not voters.any():
            continue
        majority = Counter(run[voters].tolist()).most_common(1)[0][0]
        run[voters] = majority
    return result.tolist()
def visualize_clusters(data, labels, cluster_centers, method_name,
                       output_dir='clustering_results'):
    """Plot clustered points in 3D and save the figure as a JPG.

    Points labelled -1 ("to be reclassified") are drawn in yellow with a
    black edge and annotated with the last 10 characters of their file name.

    Parameters
    ----------
    data : list of dict
        Records with 'norm_lat', 'norm_lon', 'clustering_alt' and 'path'.
    labels : sequence of int
        One cluster label per record (-1 = needs reclassification).
    cluster_centers : array-like of shape (k, 3) or None
        Centers to overlay; skipped when None or empty.
    method_name : str
        Used in the title and the output file name.
    output_dir : str, optional
        Directory for ``{method_name}_clusters.jpg``. It is created if
        missing.
    """
    # Build the coordinate arrays once instead of once per label.
    norm_lat = np.array([d['norm_lat'] for d in data])
    norm_lon = np.array([d['norm_lon'] for d in data])
    alt = np.array([d['clustering_alt'] for d in data])  # altitude actually used for clustering
    image_names = [os.path.basename(d['path'])[-10:] for d in data]
    label_arr = np.asarray(labels)

    fig = plt.figure(figsize=(12, 10))
    ax = fig.add_subplot(111, projection='3d')

    unique_labels = np.unique(label_arr)
    colors = plt.cm.Spectral(np.linspace(0, 1, len(unique_labels)))
    for label, color in zip(unique_labels, colors):
        if label == -1:
            # Highlight points awaiting reclassification: semi-transparent yellow.
            color = [1, 1, 0, 0.8]
            edgecolor = 'black'
            size = 80
        else:
            edgecolor = 'none'
            size = 50
        mask = label_arr == label
        ax.scatter(
            norm_lat[mask], norm_lon[mask], alt[mask],
            c=[color], alpha=0.7, s=size, edgecolors=edgecolor,
            label=f'类别 {label}' if label != -1 else '待重新分类 (-1)'
        )

    # Annotate only the -1 points so the plot does not get crowded.
    for x, y, z, name, label in zip(norm_lat, norm_lon, alt, image_names, label_arr):
        if label == -1:
            ax.text(x, y, z, name, fontsize=16,
                    bbox=dict(facecolor='white', alpha=0.5, boxstyle='round,pad=0.5'))

    # An empty list would make zip(*...) fail to unpack, so skip that case too.
    if cluster_centers is not None and len(cluster_centers) > 0:
        centers_x, centers_y, centers_z = zip(*cluster_centers)
        ax.scatter(
            centers_x, centers_y, centers_z,
            c='red', marker='X', s=200, edgecolors='black',
            label='聚类中心'
        )

    ax.set_xlabel('归一化纬度 (×100000)')
    ax.set_ylabel('归一化经度 (×100000)')
    ax.set_zlabel('高度')
    ax.set_title(f'{method_name} 聚类结果的3D分布')
    ax.legend()
    # Viewing angle chosen so the annotations are readable.
    ax.view_init(elev=30, azim=45)

    fig.tight_layout()
    os.makedirs(output_dir, exist_ok=True)
    fig.savefig(os.path.join(output_dir, f'{method_name}_clusters.jpg'), dpi=300)
    plt.show()
    # Release the figure; otherwise every call leaves one open.
    plt.close(fig)
_METERS_PER_DEGREE = 111319  # approximate metres per degree of latitude


def _load_image_records(image_paths):
    """Build one record dict per image whose XMP could be read and parsed.

    Missing or empty tag values become 0. An image whose values cannot be
    converted to float is skipped with a printed message.
    """
    def as_float(xmp, key):
        value = xmp[key]
        return float(value) if value else 0

    records = []
    for path in image_paths:
        xmp = read_xmp_target(path)
        if not xmp:
            continue
        try:
            records.append({
                'path': path,
                'lat': as_float(xmp, 'Xmp.drone-dji.GpsLatitude'),
                'lon': as_float(xmp, 'Xmp.drone-dji.GpsLongitude'),
                'alt': as_float(xmp, 'Xmp.drone-dji.AbsoluteAltitude'),
                'yaw': as_float(xmp, 'Xmp.drone-dji.GimbalYawDegree'),
                'pitch': as_float(xmp, 'Xmp.drone-dji.GimbalPitchDegree'),
                'lrf_lat': as_float(xmp, 'Xmp.drone-dji.LRFTargetLat'),
                'lrf_lon': as_float(xmp, 'Xmp.drone-dji.LRFTargetLon'),
                'lrf_alt': as_float(xmp, 'Xmp.drone-dji.LRFTargetAlt'),
            })
        except (KeyError, TypeError, ValueError) as e:
            print(f"处理图像 {path} 时出错: {e}")
    return records


def _select_and_normalize_positions(data, distance_threshold=10.0):
    """Choose GPS or LRF coordinates per record and add normalised fields.

    Adds 'used_data' ('LRF' or 'GPS'), 'clustering_alt', 'norm_lat' and
    'norm_lon' to every record in place. The LRF target is used only when
    its coordinates are non-zero and it lies within ``distance_threshold``
    metres (horizontally) of the drone. Otherwise the drone's GPS position
    and AbsoluteAltitude are used.
    """
    if not data:
        return
    chosen = []
    for d in data:
        lrf_valid = d['lrf_lat'] != 0 and d['lrf_lon'] != 0
        # Equirectangular approximation: a degree of longitude is shorter
        # than a degree of latitude by a factor of cos(latitude).
        dx = (d['lrf_lon'] - d['lon']) * _METERS_PER_DEGREE * np.cos(np.radians(d['lat']))
        dy = (d['lrf_lat'] - d['lat']) * _METERS_PER_DEGREE
        if lrf_valid and np.hypot(dx, dy) < distance_threshold:
            d['used_data'] = 'LRF'
            d['clustering_alt'] = d['lrf_alt']
            chosen.append((d['lrf_lat'], d['lrf_lon']))
        else:
            d['used_data'] = 'GPS'
            d['clustering_alt'] = d['alt']
            chosen.append((d['lat'], d['lon']))
    # Offset by the minimum of the coordinates actually used. Including the
    # placeholder 0 of missing LRF data would shift the origin to (0, 0).
    lat_min = min(lat for lat, _ in chosen)
    lon_min = min(lon for _, lon in chosen)
    for d, (lat, lon) in zip(data, chosen):
        d['norm_lat'] = (lat - lat_min) * 100000
        d['norm_lon'] = (lon - lon_min) * 100000


def _fit_clusters(method, X):
    """Fit ``method`` on ``X`` and return ``(labels, cluster_centers)``.

    ``labels`` is always a private copy, because callers overwrite entries
    with -1. scikit-learn's ``fit_predict`` for DBSCAN and Agglomerative
    returns the estimator's own ``labels_`` array. KMeans and
    GaussianMixture report their fitted centres/means. For any other
    estimator the per-cluster mean of ``X`` is used, excluding the noise
    label -1, or ``None`` if there is no cluster.
    """
    if isinstance(method, KMeans):
        method.fit(X)
        return np.array(method.labels_, copy=True), method.cluster_centers_
    if not hasattr(method, 'fit_predict'):
        raise TypeError(f"不支持的聚类方法: {type(method).__name__}")
    labels = np.array(method.fit_predict(X), copy=True)
    if isinstance(method, GaussianMixture):
        return labels, method.means_
    centers = [X[labels == lab].mean(axis=0) for lab in np.unique(labels) if lab != -1]
    return labels, (centers if centers else None)


def _mark_attitude_jumps(data, labels, yaw_threshold=10, pitch_threshold=5,
                         pos_threshold=10):
    """Flag images where attitude and position both jump from the previous one.

    An attitude jump with only a small position change is treated as a
    switch to another face (the label is kept). If both jump, it is treated
    as a switch to another blade, and ``labels[i]`` is set to -1 in place
    for re-checking.

    Returns
    -------
    list of list
        ``vec[k]`` = [d_lat, d_lon, d_alt, yaw_change, pitch_change] for the
        transition from image k to image k + 1.
    """
    vec = []
    for i in range(1, len(data)):
        prev, curr = data[i - 1], data[i]
        d_lat = curr['norm_lat'] - prev['norm_lat']
        d_lon = curr['norm_lon'] - prev['norm_lon']
        d_alt = curr['clustering_alt'] - prev['clustering_alt']
        # Shortest angular distance, so a change across 0°/360° is handled.
        yaw_change = abs(curr['yaw'] - prev['yaw']) % 360
        yaw_change = min(yaw_change, 360 - yaw_change)
        pitch_change = abs(curr['pitch'] - prev['pitch'])
        pos_change = np.linalg.norm([d_lat, d_lon, d_alt])
        vec.append([d_lat, d_lon, d_alt, yaw_change, pitch_change])
        if ((yaw_change > yaw_threshold or pitch_change > pitch_threshold)
                and pos_change > pos_threshold):
            labels[i] = -1  # needs reclassification
    return vec


def process_images_with_method(image_paths, method, method_name):
    """Cluster the images in ``image_paths`` with ``method`` and plot the result.

    Steps: read XMP, choose GPS or LRF positions, cluster on
    (norm_lat, norm_lon, clustering_alt), flag sequence breaks as -1, smooth
    labels between breaks, then save the 3D plot.

    The sequence-break step compares consecutive entries, so
    ``image_paths`` is assumed to be in capture order.

    Parameters
    ----------
    image_paths : sequence of str
        Image files to process.
    method : estimator
        KMeans, DBSCAN, AgglomerativeClustering, GaussianMixture, or any
        other estimator with ``fit_predict``.
    method_name : str
        Name used for the plot title and file name.

    Returns
    -------
    tuple
        ``(labels, vec, data)``: the smoothed label list, the per-transition
        change vectors, and the record dicts. Returns ``(None, None, None)``
        when no image could be read.
    """
    data = _load_image_records(image_paths)
    if not data:
        print("没有有效数据可处理")
        return None, None, None

    _select_and_normalize_positions(data)
    X = np.array([[d['norm_lat'], d['norm_lon'], d['clustering_alt']] for d in data])

    labels, cluster_centers = _fit_clusters(method, X)
    vec = _mark_attitude_jumps(data, labels)
    labels = repair_neighbors_of_minus1(labels)

    visualize_clusters(data, labels, cluster_centers, method_name)
    return labels, vec, data
def compare_clustering_methods(image_paths, n_clusters=3):
    """Run several clustering algorithms on the same images and collect results.

    Parameters
    ----------
    image_paths : sequence of str
        Image files, passed unchanged to ``process_images_with_method``.
    n_clusters : int, optional
        Cluster count for KMeans, Agglomerative and GaussianMixture. The
        default is 3, the known number of target classes. DBSCAN ignores it.

    Returns
    -------
    dict
        Maps method name to ``{'labels', 'vec', 'data'}`` for every method
        that produced labels. A method that raises ``ValueError`` (e.g. fewer
        images than ``n_clusters``) is reported and skipped, so the remaining
        methods still run.
    """
    clustering_methods = {
        'KMeans': KMeans(n_clusters=n_clusters, random_state=0),
        'DBSCAN': DBSCAN(eps=5, min_samples=3),  # tune eps/min_samples to the data
        'Agglomerative': AgglomerativeClustering(n_clusters=n_clusters),
        'GaussianMixture': GaussianMixture(n_components=n_clusters, random_state=0),
    }
    results = {}
    for name, method in clustering_methods.items():
        print(f"正在使用 {name} 进行聚类...")
        try:
            labels, vec, data = process_images_with_method(image_paths, method, name)
        except ValueError as e:
            print(f"{name} 聚类失败: {e}")
            continue
        if labels is not None:
            results[name] = {
                'labels': labels,
                'vec': vec,
                'data': data,
            }
        print(f"{name} 聚类完成\n")
    return results
if __name__ == '__main__':
    import sys

    # Image directory: the first command-line argument, else the default path.
    image_dir = sys.argv[1] if len(sys.argv) > 1 else r'D:\4.work\study\classify\img'
    # glob's order is filesystem-dependent, but the sequence-break detection
    # compares consecutive images, so sort by file name. Match both .jpg and
    # .JPG, because glob is case-sensitive on POSIX; the set removes
    # duplicates on case-insensitive filesystems.
    found = set()
    for pattern in ('*.jpg', '*.JPG'):
        found.update(glob.glob(os.path.join(image_dir, pattern)))
    image_paths = sorted(found)

    if not image_paths:
        print(f"在目录 {image_dir} 中未找到任何JPG图像")
    else:
        print(f"找到 {len(image_paths)} 张图像,开始聚类比较...")
        results = compare_clustering_methods(image_paths)
        # Print and save each method's labels.
        for method_name, result in results.items():
            print(f"\n{method_name} 聚类结果:")
            output_file = f'clustering_results/{method_name}_labels.txt'
            with open(output_file, 'w', encoding='utf-8') as f:
                for i, (record, label) in enumerate(zip(result['data'], result['labels'])):
                    line = f"图像: {os.path.basename(record['path'])}, 类别: {label}, 使用数据: {record['used_data']}"
                    # vec[k] describes the change from image k to image k + 1,
                    # so image i is paired with vec[i - 1]; image 0 has none.
                    if 0 < i <= len(result['vec']):
                        line += f", 变化向量: {result['vec'][i - 1]}"
                    print(line)
                    f.write(line + '\n')
        print("\n所有聚类结果已保存到 clustering_results 目录")