00 使用 sb3 RL 框架进行机械臂抓取仿真

项目概述

这是一个基于 Stable Baselines3 (SB3)强化学习框架和 MuJoCo 物理引擎的机械臂抓取任务仿真训练系统。项目旨在训练机械臂完成"抓取随机放置的方块并放入目标区域"的复杂操作任务。

核心文件

  • mujoco_env_test.py :主训练/测试脚本,负责模型的训练配置、训练执行和推理测试
  • my_env_5dof_obs_no_target_pos_2cam_0923.py :自定义环境实现,包含机械臂控制逻辑、状态观察和奖励函数设计
  • model/aloha/aloha_single_2cam.xml :MuJoCo 物理模型配置文件,定义了机械臂、摄像头和环境物体

训练流程详解

1. 环境初始化与配置

# 创建环境工厂函数
def make_env(model_path, rank, seed=0, render_mode=None):
    def _init():
        env = RobotArmEnv(model_path, render_mode=render_mode, rank=rank, use_image_obs=True, frame_skip=3)
        env.reset(seed=seed + rank)
        return env
    return _init

# 配置参数
save_path = "/mnt/679135f3-d669-4d3d-8bea-b8c0b7074491/yyy/PycharmProjects/rl_test/model_sb3_targets_GRIPER_skp3_obs_targets_5dof_0923/"
policy = "MultiInputPolicy" if use_image else "MlpPolicy"

2. 模型训练配置

  • 算法选择 :使用 SAC (Soft Actor-Critic)算法,特别适合机械臂这类高维度连续控制任务
  • 并行训练 :创建 60 个并行环境加速训练
  • 断点续训 :通过 is_continue 参数控制是否从已有模型继续训练
  • 自动保存 :每 10,000 步自动保存模型检查点
# 创建并行环境
env_num = 60
env = SubprocVecEnv([make_env('model/aloha/aloha_single_2cam.xml', i, render_mode="rgb_array") for i in range(env_num)])

# 加载已有模型或创建新模型
if is_continue:
    model = SAC.load(path=save_path + "sac_robot_arm_1800000_steps.zip",
                    env=env,
                    verbose=0,
                    use_sde=True,
                    buffer_size=250_000)
else:
    model = SAC(policy, env, verbose=0, use_sde=True, buffer_size=250_000)

# 启动训练
model.learn(total_timesteps=100_000_000, callback=checkpoint_callback)

3. 推理模式

训练完成后,可以使用训练好的模型进行推理测试,验证机械臂的抓取能力:

# 不学习,只进行推理
print("开始使用模型进行推理...")

# 创建一个用于推理的环境(单环境,不是并行环境)
env_infer = RobotArmEnv('model/aloha/aloha_single_2cam.xml', render_mode="rgb_array", use_image_obs=True)

model = SAC.load(path=save_path + "sac_robot_arm_8400000_steps.zip",
                env=env_infer,
                verbose=0,
                use_sde=True,
                buffer_size=200_000)

环境设计详解

1. 机械臂配置

  • 自由度 :5 个激活的关节(waist, shoulder, elbow, wrist_angle, left_finger)
  • 关节范围 :每个关节都有特定的运动范围限制
  • 动作空间 :Box 空间,对应 5 个关节的控制值
self.joint_names = ['waist', 'shoulder', 'elbow', 'forearm_roll', 'wrist_angle', 'wrist_rotate', 'left_finger']
self.activated_joint_names = ['waist', 'shoulder', 'elbow', 'wrist_angle', 'left_finger']
self.joint_ranges = np.array([
    [-3.14158, 3.14158],   # waist
    [-1.85005, 1.25664],   # shoulder
    [-1.76278, 1.6057],    # elbow
    [-1.8675, 2.23402],    # wrist_angle
    [-0.1, 0.1]            # left_finger
], dtype=np.float32)

2. 观察空间设计

环境使用混合观察空间,包括图像观察和状态向量观察:

  • 图像观察 :来自两个摄像头的图像

    • 手腕摄像头(wrist_cam_left)
    • 俯视摄像头(top_down_camera)
  • 状态观察 :包含以下组件的向量

    • 关节角度和速度
    • 末端执行器位置和速度
    • 方块位置
    • 目标点位置
    • 抓取状态(二进制值)
if self.use_image_obs:
    self.observation_space = gym.spaces.Dict({
        "wrist_image": gym.spaces.Box(
            low=0, high=255,
            shape=(self.image_size[0], self.image_size[1], 3), 
            dtype=np.uint8
        ),
        "overhead_image": gym.spaces.Box(
            low=0, high=255,
            shape=(self.image_size[0], self.image_size[1], 3), 
            dtype=np.uint8
        ),
        "state": gym.spaces.Box(
            low=-np.inf, high=np.inf, 
            shape=(obs_dim,), 
            dtype=np.float32
        )
    })

3. 奖励函数设计

项目采用了复杂的多阶段奖励机制,引导机械臂按步骤完成任务:

  1. 接近奖励 :机械臂末端接近方块时获得奖励

    self.joint_names = ['waist', 'shoulder', 'elbow', 'forearm_roll', 'wrist_angle', 'wrist_rotate', 'left_finger']
    self.activated_joint_names = ['waist', 'shoulder', 'elbow', 'wrist_angle', 'left_finger']
    self.joint_ranges = np.array([
        [-3.14158, 3.14158],   # waist
        [-1.85005, 1.25664],   # shoulder
        [-1.76278, 1.6057],    # elbow
        [-1.8675, 2.23402],    # wrist_angle
        [-0.1, 0.1]            # left_finger
    ], dtype=np.float32)
    
  2. 抓取奖励 :当末端与方块足够接近时获得奖励

    if self.use_image_obs:
        self.observation_space = gym.spaces.Dict({
            "wrist_image": gym.spaces.Box(
                low=0, high=255,
                shape=(self.image_size[0], self.image_size[1], 3), 
                dtype=np.uint8
            ),
            "overhead_image": gym.spaces.Box(
                low=0, high=255,
                shape=(self.image_size[0], self.image_size[1], 3), 
                dtype=np.uint8
            ),
            "state": gym.spaces.Box(
                low=-np.inf, high=np.inf, 
                shape=(obs_dim,), 
                dtype=np.float32
            )
        })
    
  3. 抬起奖励 :当方块被成功抓取并抬起时获得奖励

    reach_dist = np.linalg.norm(grip_pos - test_cube_pos)
    reach_reward = 2 * np.exp(-3 * reach_dist)
    
  4. 运输奖励 :当方块被抬起并向目标移动时获得奖励

    place_dist = np.linalg.norm(test_cube_pos - target_point)
    transport_reward = transport_multiple * self.target_idx
    # 距离目标越近奖励越大
    if place_dist > dist_setted:
        transport_reward += (dist_setted - place_dist) * 20
    else:
        normalized_dist = place_dist / dist_setted
        transport_reward += transport_multiple * e_fun(-3, 1 - normalized_dist)
    
  5. 成功奖励 :当方块成功放入目标区域并松开时获得大额奖励

    success = self._is_block_in_target() and grasp_reward == 0.0
    success_reward = 18.0 if success else 0.0
    
  6. 惩罚机制 :

    • 长时间不动惩罚
    • 超时惩罚(3 秒)
    • 方块或机械臂超出目标区域惩罚

4. 任务完成条件

episode 结束的条件包括:

  • 方块成功放入目标区域并保持 20 步
  • 超时(超过 N 秒)
  • 机械臂长时间不动(超过 100 步移动距离小于阈值)
  • 方块或机械臂末端超出指定范围
def _check_done(self):
    # 条件1: 方块成功放入目标区域,且保持一段时间
    if self._is_block_in_target():
        self.success_counter += 1
        if self.success_counter >= 20:
            self.success_records.append(True)
            return True
    # 条件2: 超时
    if self.data.time > 3.0:  # 3秒超时
        self.success_records.append(False)
        return True
    return False

目标点与导航机制

项目实现了一种分段式导航机制,通过两个关键目标点引导机械臂完成任务:

  1. 第一个目标点:方块上方的位置
  2. 第二个目标点:目标区域上方的位置

机械臂需要依次到达这些目标点,系统会根据任务进度自动切换目标点:

self.targets = [
    [self.data.joint('test_cube_joint').qpos[0], self.data.joint('test_cube_joint').qpos[1],  0.25],
    [self.data.joint('dropbox_joint').qpos[0].copy(), self.data.joint('dropbox_joint').qpos[1].copy(), 0.18],
]

# 当运输奖励足够高时,切换到下一个目标点
if transport_reward >= transport_multiple * 0.8 and self.target_idx < len(self.targets) - 1:
    self.target_idx += 1

实验记录与可视化

项目包含完善的实验记录和可视化功能:

  1. 成功率统计 :记录最近 10 次任务的成功情况
  2. 实时监控 :显示各种奖励组件、关节状态、物体位置等信息
  3. 图像可视化 :显示摄像头观察图像,便于调试和监控训练进度
if self.rank == 0:
    self.success_rate = sum(self.success_records) / 10 if self.success_records else 0
    print(
        f"reach_reward:{np.round(reach_reward, 3)}"
        f"grasp_reward:{np.round(grasp_reward, 3)}"
        # ... 其他监控信息 ...
        f"Success_rate:{np.round(self.success_rate, 3) if self.success_records else 0}"
    )

项目执行流程

  1. 环境初始化 :加载 MuJoCo 模型,设置关节参数、动作空间和观察空间

  2. 训练阶段 :

    • 创建并行环境
    • 初始化或加载 SAC 模型
    • 执行强化学习训练,定期保存模型
  3. 推理阶段 (可选):

    • 加载训练好的模型
    • 创建单环境进行测试
    • 执行任务并评估性能

测试视频:

  • 机器学习

    机器学习(Machine Learning)是一门多领域交叉学科,涉及概率论、统计学、逼近论、凸分析、算法复杂度理论等多门学科。专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之不断改善自身的性能。

    78 引用 • 37 回帖
2 操作
yyy666guamiha 在 2025-10-25 16:42:19 更新了该帖
yyy666guamiha 在 2025-10-25 16:41:43 更新了该帖

相关帖子

欢迎来到这里!

我们正在构建一个小众社区,大家在这里相互信任,以平等 • 自由 • 奔放的价值观进行分享交流。最终,希望大家能够找到与自己志同道合的伙伴,共同成长。

注册 关于
请输入回帖内容 ...