Python 媒体识别包 Mediapipe - 姿势分类(PoseClassification)
姿势分类(PoseClassification)
模型训练
官方详细说明:https://mediapipe.page.link/pose_classification_basic
import os import csv from BootstrapHelper import BootstrapHelper from FullBodyPoseEmbedder import FullBodyPoseEmbedder from PoseClassification import PoseClassifier # region Bootstrap images # images_in_folder 的必需结构: # # fitness_poses_images_in/ # pushups_up/ # image_001.jpg # image_002.jpg # ... # pushups_down/ # image_001.jpg # image_002.jpg # ... # ... bootstrap_images_in_folder = 'mediapipe/images/pose/extended/train' # 引导图像和 CSV 的输出文件夹。 bootstrap_images_out_folder = 'mediapipe/images/pose/extended/images_out' bootstrap_csvs_out_folder = 'mediapipe/images/pose/extended/csvs_out' # 初始化助手。 bootstrap_helper = BootstrapHelper( images_in_folder=bootstrap_images_in_folder, images_out_folder=bootstrap_images_out_folder, csvs_out_folder=bootstrap_csvs_out_folder, ) # 检查有多少姿势类和图像可用。 bootstrap_helper.print_images_in_statistics() # 引导所有图像。 # 将限制设置为一些小的数字以进行调试。 bootstrap_helper.bootstrap(per_pose_class_limit=None) # 检查引导了多少图像。 bootstrap_helper.print_images_out_statistics() # 没有检测到姿势的初始引导图像仍然保存在 # 用于调试目的的文件夹(但不在 CSV 中)。 让我们移除它们。 bootstrap_helper.align_images_and_csvs(print_removed_items=False) bootstrap_helper.print_images_out_statistics() # endregion Bootstrap images # region Manual filtration # # 请手动验证预测并删除姿势预测错误的样本(图像)。 检查是否要求您仅根据预测的地标对姿势进行分类。 如果你不能 - 删除它。 # # 完成后对齐 CSV 和图像文件夹。 # # 将 CSV 与过滤后的图像对齐。 # bootstrap_helper.align_images_and_csvs(print_removed_items=False) # bootstrap_helper.print_images_out_statistics() # endregion Manual filtration # region Automatic filtration # 找出异常值。 # 将姿势地标转换为嵌入。 pose_embedder = FullBodyPoseEmbedder() # 根据姿势数据库对姿势进行分类。 pose_classifier = PoseClassifier( pose_samples_folder=bootstrap_csvs_out_folder, pose_embedder=pose_embedder, top_n_by_max_distance=30, top_n_by_mean_distance=10) outliers = pose_classifier.find_pose_sample_outliers() print('Number of outliers: ', len(outliers)) # 分析异常值。 bootstrap_helper.analyze_outliers(outliers) # 删除所有异常值(如果您不想手动选择)。 bootstrap_helper.remove_outliers(outliers) # 去除异常值后将 CSV 与图像对齐。 bootstrap_helper.align_images_and_csvs(print_removed_items=False) bootstrap_helper.print_images_out_statistics() # endregion Automatic filtration # region Dump for the App def dump_for_the_app(): pose_samples_folder = 'mediapipe/images/pose/extended/fitness_poses_csvs_out' pose_samples_csv_path = 'fitness_poses_csvs_out.csv' file_extension = 'csv' file_separator = ',' # 文件夹中的每个文件代表一个姿势类。 file_names = [name for name in os.listdir( pose_samples_folder) if name.endswith(file_extension)] with open(pose_samples_csv_path, 'w') as csv_out: csv_out_writer = csv.writer( csv_out, delimiter=file_separator, quoting=csv.QUOTE_MINIMAL) for file_name in file_names: # Use file name as pose class name. class_name = file_name[:-(len(file_extension) + 1)] # 一个文件行:`sample_00001,x1,y1,x2,y2,....`。 with open(os.path.join(pose_samples_folder, file_name)) as csv_in: csv_in_reader = csv.reader(csv_in, delimiter=file_separator) for row in csv_in_reader: row.insert(1, class_name) csv_out_writer.writerow(row) # 将过滤后的姿势转储到 CSV 并下载。 如何在 ML Kit 示例应用程序中使用此 CSV。 # dump_for_the_app() # endregion Dump for the App
姿势识别
官方详细说明:https://mediapipe.page.link/pose_classification_extended
import cv2 import numpy as np import tqdm from mediapipe.python.solutions import drawing_utils as mp_drawing from mediapipe.python.solutions import pose as mp_pose from EMADictSmoothing import EMADictSmoothing from FullBodyPoseEmbedder import FullBodyPoseEmbedder from PoseClassificationVisualizer import PoseClassificationVisualizer from RepetitionCounter import RepetitionCounter from PoseClassification import PoseClassifier # 指定您的视频名称和目标姿势类以计算重复次数。 video_path = 'mediapipe/images/pose/extended/test/5.mp4' class_name = 'down' out_video_path = 'mediapipe/images/pose/extended/test_out/5_out.mp4' # 打开视频。 video_cap = cv2.VideoCapture(video_path) # 获取一些视频参数以生成带分类的输出视频。 video_n_frames = video_cap.get(cv2.CAP_PROP_FRAME_COUNT) video_fps = video_cap.get(cv2.CAP_PROP_FPS) video_width = int(video_cap.get(cv2.CAP_PROP_FRAME_WIDTH)) video_height = int(video_cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) # 初始化跟踪器、分类器和计数器。 # 在每个视频之前都这样做,因为它们都有状态。 # 包含姿势类 CSV 的文件夹。 那应该是你使用的同一个文件夹 # 构建分类器以输出 CSV。 pose_samples_folder = 'mediapipe/images/pose/extended/csvs_out' # 初始化跟踪器。 pose_tracker = mp_pose.Pose(static_image_mode=False) # 初始化嵌入器。 pose_embedder = FullBodyPoseEmbedder() # 初始化分类器。 # 确认您使用的参数与引导期间相同。 pose_classifier = PoseClassifier( pose_samples_folder=pose_samples_folder, pose_embedder=pose_embedder, top_n_by_max_distance=30, top_n_by_mean_distance=10) # 取消注释以验证分类器使用的目标姿势并查找异常值。 # outliers =pose_classifier.find_pose_sample_outliers() # print('姿势样本异常值的数量(考虑删除它们):', len(outliers)) # 初始化 EMA 平滑。 pose_classification_filter = EMADictSmoothing( window_size=10, alpha=0.2) # 初始化计数器。 repetition_counter = RepetitionCounter( class_name=class_name, enter_threshold=6, exit_threshold=4) # 初始化渲染器。 pose_classification_visualizer = PoseClassificationVisualizer( class_name=class_name, plot_x_max=video_n_frames, # 如果与 `top_n_by_mean_distance` 相同,则图形看起来更好。 plot_y_max=10) # 对视频运行分类。 # 打开输出视频。 out_video = cv2.VideoWriter(out_video_path, cv2.VideoWriter_fourcc( *'mp4v'), video_fps, (video_width, video_height)) frame_idx = 0 output_frame = None with tqdm.tqdm(total=video_n_frames, position=0, leave=True) as pbar: while True: # 获取视频的下一帧。 success, input_frame = video_cap.read() if not success: break # 运行姿势跟踪器。 input_frame = cv2.cvtColor(input_frame, cv2.COLOR_BGR2RGB) result = pose_tracker.process(image=input_frame) pose_landmarks = result.pose_landmarks # 绘制姿势预测。 output_frame = input_frame.copy() if pose_landmarks is not None: mp_drawing.draw_landmarks( image=output_frame, landmark_list=pose_landmarks, connections=mp_pose.POSE_CONNECTIONS) if pose_landmarks is not None: # 获取地标。 frame_height, frame_width = output_frame.shape[0], output_frame.shape[1] pose_landmarks = np.array([[lmk.x * frame_width, lmk.y * frame_height, lmk.z * frame_width] for lmk in pose_landmarks.landmark], dtype=np.float32) assert pose_landmarks.shape == ( 33, 3), '意外的地标形状: {}'.format(pose_landmarks.shape) # 对当前帧的姿势进行分类。 pose_classification = pose_classifier(pose_landmarks) # 使用 EMA 进行平滑分类。 pose_classification_filtered = pose_classification_filter( pose_classification) # 计算重复次数。 repetitions_count = repetition_counter( pose_classification_filtered) else: # 没有姿势 => 当前帧没有分类。 pose_classification = None # 仍然向过滤器添加空分类以保持未来帧的正确平滑。 pose_classification_filtered = pose_classification_filter(dict()) pose_classification_filtered = None # 不要假设那个人被“冻结”而更新计数器。 只是取最新的重复次数。 repetitions_count = repetition_counter.n_repeats # 绘制分类图和重复计数器。 output_frame = pose_classification_visualizer( frame=output_frame, pose_classification=pose_classification, pose_classification_filtered=pose_classification_filtered, repetitions_count=repetitions_count) # 保存输出帧。 out_video.write(cv2.cvtColor( np.array(output_frame), cv2.COLOR_RGB2BGR)) # 显示视频的中间帧以跟踪进度。 cv2.imshow('姿势分类 - 俯卧撑(down)', np.array(output_frame)) cv2.waitKey(1) frame_idx += 1 pbar.update() # 关闭输出视频。 out_video.release() # 释放MediaPipe资源。 pose_tracker.close() cv2.waitKey()