diff --git a/classify_v3.py b/classify_v3.py
new file mode 100644
index 0000000..bd70d65
--- /dev/null
+++ b/classify_v3.py
@@ -0,0 +1,346 @@
+import glob
+import os
+from collections import Counter
+
+import matplotlib
+import numpy as np
+from sklearn.cluster import KMeans, DBSCAN, AgglomerativeClustering
+from sklearn.mixture import GaussianMixture
+from pyexiv2 import Image
+
+# Font configuration: SimHei renders CJK glyphs correctly; keep the minus sign displayed properly.
+matplotlib.rcParams["font.family"] = ["SimHei"]
+matplotlib.rcParams['axes.unicode_minus'] = False
+# matplotlib.use('Agg')  # save figures without displaying them; use together with plt.savefig()
+import matplotlib.pyplot as plt
+
+# Make sure the output directory exists
+os.makedirs('clustering_results', exist_ok=True)
+
+
+def read_xmp_target(p):
+    """Read the DJI XMP metadata of an image."""
+    try:
+        img = Image(p)
+        xmp = img.read_xmp()
+        img.close()
+
+        # XMP tags to extract
+        xmp_tags = {
+            'Xmp.drone-dji.GpsLatitude': None,
+            'Xmp.drone-dji.GpsLongitude': None,
+            'Xmp.drone-dji.AbsoluteAltitude': None,
+            'Xmp.drone-dji.GimbalRollDegree': None,
+            'Xmp.drone-dji.GimbalYawDegree': None,
+            'Xmp.drone-dji.GimbalPitchDegree': None,
+            'Xmp.drone-dji.FlightRollDegree': None,
+            'Xmp.drone-dji.FlightYawDegree': None,
+            'Xmp.drone-dji.FlightPitchDegree': None,
+            'Xmp.drone-dji.LRFTargetDistance': None,
+            'Xmp.drone-dji.LRFTargetLon': None,
+            'Xmp.drone-dji.LRFTargetLat': None,
+            'Xmp.drone-dji.LRFTargetAlt': None,
+            'Xmp.drone-dji.LRFTargetAbsAlt': None,
+        }
+
+        # Fill in the tag values that are present
+        for key in xmp_tags.keys():
+            if key in xmp:
+                xmp_tags[key] = xmp[key]
+            else:
+                print(f"Warning: tag {key} not found in image {p}")
+
+        return xmp_tags
+    except Exception as e:
+        print(f"Error reading XMP data of image {p}: {e}")
+        return None
+
+
+def repair_neighbors_of_minus1(labels):
+    """Unify the labels around each -1 marker.
+
+    The label sequence is split at every -1; within each segment the non -1
+    labels are replaced by the segment's majority label. The -1 markers
+    themselves are kept so they can still be flagged for re-checking.
+    """
+    out = np.array(labels, dtype=int)
+    n = len(out)
+
+    # Indices of all -1 entries
+    ind = np.where(out == -1)[0]
+
+    # Split into segments: [0, ind[0]+1), [ind[0]+1, ind[1]+1), ..., [ind[-1]+1, n)
+    split_indices = [0] + (ind + 1).tolist() + [n]
+
+    for k in range(len(split_indices) - 1):
+        start = split_indices[k]
+        end = split_indices[k + 1]
+        seg = out[start:end]
+
+        # Vote only over the non -1 labels
+        votes = seg[seg != -1]
+        if len(votes) == 0:
+            continue
+        winner = Counter(votes).most_common(1)[0][0]
+
+        # Set all non -1 labels in the segment to the winner; keep the -1 markers
+        out[start:end] = np.where(seg == -1, -1, winner)
+
+    return out.tolist()
+
+
+def visualize_clusters(data, labels, cluster_centers, method_name):
+    """Visualize the data points and cluster centers, including -1 labels and image-name annotations."""
+    # Data to visualize
+    norm_lat = [d['norm_lat'] for d in data]
+    norm_lon = [d['norm_lon'] for d in data]
+    alt = [d['clustering_alt'] for d in data]  # altitude actually used for clustering
+    image_names = [os.path.basename(d['path'])[-10:] for d in data]  # last 10 characters of the file name
+
+    # 3D figure
+    fig = plt.figure(figsize=(12, 10))
+    ax = fig.add_subplot(111, projection='3d')
+
+    # Plot the data points, coloured by label
+    unique_labels = np.unique(labels)
+    colors = plt.cm.Spectral(np.linspace(0, 1, len(unique_labels)))
+
+    for label, color in zip(unique_labels, colors):
+        if label == -1:
+            # Highlight -1 points in yellow with a black edge
+            color = [1, 1, 0, 0.8]  # semi-transparent yellow
+            edgecolor = 'black'
+            size = 80
+        else:
+            edgecolor = 'none'
+            size = 50
+
+        mask = np.array(labels) == label
+        ax.scatter(
+            np.array(norm_lat)[mask],
+            np.array(norm_lon)[mask],
+            np.array(alt)[mask],
+            c=[color], alpha=0.7, s=size, edgecolors=edgecolor,
+            label=f'Cluster {label}' if label != -1 else 'To re-classify (-1)'
+        )
+
+    # Annotate image names (last 10 characters)
+    for i, (x, y, z, name) in enumerate(zip(norm_lat, norm_lon, alt, image_names)):
+        # Only annotate -1 points to avoid clutter (add `or i % 5 == 0` to also label every 5th image)
+        if labels[i] == -1:
+            ax.text(x, y, z, name, fontsize=16,
+                    bbox=dict(facecolor='white', alpha=0.5, boxstyle='round,pad=0.5'))
+
+    # Plot the cluster centers, if any
+    if cluster_centers is not None:
+        centers_x, centers_y, centers_z = zip(*cluster_centers)
+        ax.scatter(
+            centers_x, centers_y, centers_z,
+            c='red', marker='X', s=200, edgecolors='black',
+            label='Cluster centers'
+        )
+
+    # Axis labels
+    ax.set_xlabel('Normalized latitude (×100000)')
+    ax.set_ylabel('Normalized longitude (×100000)')
+    ax.set_zlabel('Altitude')
+
+    # Title and legend
+    ax.set_title(f'3D distribution of {method_name} clustering results')
+    ax.legend()
+
+    # Adjust the view so the annotations are easier to read
+    ax.view_init(elev=30, azim=45)
+
+    # Save the figure
+    plt.tight_layout()
+    plt.savefig(f'clustering_results/{method_name}_clusters.jpg', dpi=300)
+    plt.show()
+    # plt.close()
+
+
+def process_images_with_method(image_paths, method, method_name):
+    """Process the image data with the given clustering method."""
+    # 1. Read the data
+    data = []
+    for path in image_paths:
+        xmp = read_xmp_target(path)
+        if not xmp:
+            continue
+
+        try:
+            lat = float(xmp['Xmp.drone-dji.GpsLatitude']) if xmp['Xmp.drone-dji.GpsLatitude'] else 0
+            lon = float(xmp['Xmp.drone-dji.GpsLongitude']) if xmp['Xmp.drone-dji.GpsLongitude'] else 0
+            alt = float(xmp['Xmp.drone-dji.AbsoluteAltitude']) if xmp['Xmp.drone-dji.AbsoluteAltitude'] else 0
+            yaw = float(xmp['Xmp.drone-dji.GimbalYawDegree']) if xmp['Xmp.drone-dji.GimbalYawDegree'] else 0
+            pitch = float(xmp['Xmp.drone-dji.GimbalPitchDegree']) if xmp['Xmp.drone-dji.GimbalPitchDegree'] else 0
+
+            data.append({
+                'path': path,
+                'lat': lat, 'lon': lon, 'alt': alt,
+                'yaw': yaw, 'pitch': pitch,
+                'lrf_lat': float(xmp['Xmp.drone-dji.LRFTargetLat']) if xmp['Xmp.drone-dji.LRFTargetLat'] else 0,
+                'lrf_lon': float(xmp['Xmp.drone-dji.LRFTargetLon']) if xmp['Xmp.drone-dji.LRFTargetLon'] else 0,
+                'lrf_alt': float(xmp['Xmp.drone-dji.LRFTargetAlt']) if xmp['Xmp.drone-dji.LRFTargetAlt'] else 0
+            })
+        except Exception as e:
+            print(f"Error processing image {path}: {e}")
+            continue
+
+    if not data:
+        print("No valid data to process")
+        return None, None, None
+
+    # 2. Normalize the GPS coordinates, taking the LRF (laser range finder) target data into account
+    # Collect every latitude/longitude that may be used (including the LRF values) to compute the minima
+    all_latitudes = []
+    all_longitudes = []
+
+    for d in data:
+        all_latitudes.append(d['lat'])
+        all_latitudes.append(d['lrf_lat'])
+        all_longitudes.append(d['lon'])
+        all_longitudes.append(d['lrf_lon'])
+
+    lat_min, lon_min = min(all_latitudes), min(all_longitudes)
+    distance_threshold = 10.0  # distance threshold in metres (tune for the actual data)
+
+    for d in data:
+        # Distance between the drone position and the LRF target point
+        # (simplified planar approximation: 1 degree ≈ 111319 m, ignoring the cos(latitude) scaling of longitude)
+        dx = (d['lrf_lon'] - d['lon']) * 111319  # longitude difference in metres
+        dy = (d['lrf_lat'] - d['lat']) * 111319  # latitude difference in metres
+        distance = np.sqrt(dx ** 2 + dy ** 2)
+
+        # Decide which data source to use based on the distance
+        if distance < distance_threshold and d['lrf_lat'] != 0 and d['lrf_lon'] != 0:
+            # LRF target is close and valid: use the LRF data
+            d['norm_lat'] = (d['lrf_lat'] - lat_min) * 100000
+            d['norm_lon'] = (d['lrf_lon'] - lon_min) * 100000
+            d['used_data'] = 'LRF'  # record which source was used
+            # Use the LRF altitude
+            d['clustering_alt'] = d['lrf_alt']
+        else:
+            # LRF target is far away or invalid: fall back to the raw GPS data
+            d['norm_lat'] = (d['lat'] - lat_min) * 100000
+            d['norm_lon'] = (d['lon'] - lon_min) * 100000
+            d['used_data'] = 'GPS'  # record which source was used
+            # Use the raw altitude
+            d['clustering_alt'] = d['alt']
+
+    # 3. Build the clustering matrix from the selected coordinates and altitude
+    X = np.array([[d['norm_lat'], d['norm_lon'], d['clustering_alt']] for d in data])
+
+    # 4. Apply the clustering algorithm
+    labels = None
+    cluster_centers = None
+
+    if isinstance(method, KMeans):
+        method.fit(X)
+        labels = method.labels_
+        cluster_centers = method.cluster_centers_
+    elif isinstance(method, DBSCAN):
+        labels = method.fit_predict(X)
+        # DBSCAN has no cluster centers in the usual sense; use each cluster's centroid as a reference
+        unique_labels = np.unique(labels)
+        cluster_centers = []
+        for label in unique_labels:
+            if label != -1:  # skip noise points
+                cluster_points = X[labels == label]
+                cluster_centers.append(np.mean(cluster_points, axis=0))
+        if not cluster_centers:
+            cluster_centers = None
+    elif isinstance(method, AgglomerativeClustering):
+        labels = method.fit_predict(X)
+        # Use each cluster's centroid as a reference
+        unique_labels = np.unique(labels)
+        cluster_centers = []
+        for label in unique_labels:
+            cluster_points = X[labels == label]
+            cluster_centers.append(np.mean(cluster_points, axis=0))
+    elif isinstance(method, GaussianMixture):
+        labels = method.fit_predict(X)
+        cluster_centers = method.means_
+
+    # 5. Detect abrupt changes in the time series
+    vec = []
+    if len(data) > 1:
+        for i in range(1, len(data)):
+            prev, curr = data[i - 1], data[i]
+            yaw_change = abs(curr['yaw'] - prev['yaw'])
+            yaw_change = min(yaw_change, 360 - yaw_change)  # handle wrap-around at 0°
+            pitch_change = abs(curr['pitch'] - prev['pitch'])
+            pos_change = np.linalg.norm([
+                curr['norm_lat'] - prev['norm_lat'],
+                curr['norm_lon'] - prev['norm_lon'],
+                curr['clustering_alt'] - prev['clustering_alt']
+            ])
+            vec.append([
+                curr['norm_lat'] - prev['norm_lat'],
+                curr['norm_lon'] - prev['norm_lon'],
+                curr['clustering_alt'] - prev['clustering_alt'],
+                yaw_change,
+                pitch_change
+            ])
+
+            # Attitude jump with a small position change -> another face of the same target (keep the label)
+            # Attitude jump together with a large position change -> another blade (flag for re-checking)
+            if (yaw_change > 10 or pitch_change > 5) and pos_change > 10:
+                labels[i] = -1  # to be re-classified
+
+    # 6. Repair the labels
+    if labels is not None:
+        labels = repair_neighbors_of_minus1(labels)
+
+    # 7. Visualize the clustering result
+    visualize_clusters(data, labels, cluster_centers, method_name)
+
+    return labels, vec, data
+
+
+def compare_clustering_methods(image_paths):
+    """Compare several clustering methods."""
+    # Clustering methods to compare
+    n_clusters = 3  # the targets are known to fall into 3 classes
+    clustering_methods = {
+        'KMeans': KMeans(n_clusters=n_clusters, random_state=0),
+        'DBSCAN': DBSCAN(eps=5, min_samples=3),  # tune these parameters for the actual data
+        'Agglomerative': AgglomerativeClustering(n_clusters=n_clusters),
+        'GaussianMixture': GaussianMixture(n_components=n_clusters, random_state=0)
+    }
+
+    results = {}
+
+    # Run every method
+    for name, method in clustering_methods.items():
+        print(f"Clustering with {name}...")
+        labels, vec, data = process_images_with_method(image_paths, method, name)
+        if labels is not None:
+            results[name] = {
+                'labels': labels,
+                'vec': vec,
+                'data': data
+            }
+            print(f"{name} clustering finished\n")
+
+    return results
+
+
+if __name__ == '__main__':
+    # Collect the image paths
+    image_dir = r'D:\4.work\study\classify\img'
+    image_paths = glob.glob(os.path.join(image_dir, '*.jpg'))
+
+    if not image_paths:
+        print(f"No JPG images found in directory {image_dir}")
+    else:
+        print(f"Found {len(image_paths)} images, starting the clustering comparison...")
+        results = compare_clustering_methods(image_paths)
+
+        # Print and save the results of every method
+        for method_name, result in results.items():
+            print(f"\n{method_name} clustering result:")
+            output_file = f'clustering_results/{method_name}_labels.txt'
+            with open(output_file, 'w', encoding='utf-8') as f:
+                for i, (path, label) in enumerate(zip(
+                        [d['path'] for d in result['data']],
+                        result['labels'])):
+                    line = f"Image: {os.path.basename(path)}, cluster: {label}, data source: {result['data'][i]['used_data']}"
+                    if i < len(result['vec']):
+                        line += f", change vector: {result['vec'][i]}"
+                    print(line)
+                    f.write(line + '\n')
+        print("\nAll clustering results have been saved to the clustering_results directory")
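
Example (a minimal sketch, not part of the patch): assuming the file is importable as `classify_v3` and its dependencies (pyexiv2, scikit-learn, matplotlib) are installed, `repair_neighbors_of_minus1` unifies each segment ending at a -1 marker to its majority label while keeping the -1 markers themselves:

    from classify_v3 import repair_neighbors_of_minus1

    labels = [0, 0, 1, -1, 2, 2, 0, -1, 1, 1]
    print(repair_neighbors_of_minus1(labels))
    # expected: [0, 0, 0, -1, 2, 2, 2, -1, 1, 1]

Note that importing the module also creates the clustering_results directory as a side effect of the module-level os.makedirs call.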