This post is a follow-up to the earlier lerobot notes on [deployment, dataset metadata, loading datasets].
Contents
- Evaluating the policy
  - Loading the model
  - Loading the environment
  - Interaction
- Training the policy
  - Loading the dataset
  - Loading the policy
  - Loading the optimizer
  - Forward pass, backward pass, parameter update
- References
- Next up
Evaluating the policy
Policy evaluation follows a fairly boilerplate three-step routine: load the model, load the environment, then loop { read the current observation, feed it into the policy to get an action, step the environment with that action }. A rough skeleton of the whole routine is sketched below.
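As a sketch of that structure before filling in the details (build_policy, build_env and preprocess are placeholders for the concrete calls in the following subsections; env.reset/env.step is the standard Gymnasium interface):

import torch

policy = build_policy()   # placeholder: see "Loading the model"
env = build_env()         # placeholder: see "Loading the environment"
numpy_observation, info = env.reset()
done = False
while not done:
    observation = preprocess(numpy_observation)   # placeholder: numpy -> torch tensors on the device
    with torch.inference_mode():
        action = policy.select_action(observation)
    numpy_observation, reward, terminated, truncated, info = env.step(action.squeeze(0).cpu().numpy())
    done = terminated or truncated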
Loading the model
TODO: analysis of the diffusion model
If pretrained_policy_path can be found locally, the local weights are used; otherwise they are downloaded from the Hugging Face Hub.
from lerobot.common.policies.diffusion.modeling_diffusion import DiffusionPolicy  # import path as in the lerobot examples

pretrained_policy_path = "lerobot/diffusion_pusht"
policy = DiffusionPolicy.from_pretrained(pretrained_policy_path)
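from_pretrained accepts a local directory as well as a Hub repo id (the usual huggingface_hub behaviour), so the local-first logic described above can also be written out explicitly; the local checkpoint path here is hypothetical:

from pathlib import Path

local_checkpoint = Path("outputs/train/diffusion_pusht")   # hypothetical local checkpoint directory
policy = DiffusionPolicy.from_pretrained(
    local_checkpoint if local_checkpoint.is_dir() else "lerobot/diffusion_pusht"
)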
Loading the environment
The environment used here (the gym_pusht package) must already have been installed during the earlier environment-setup step.
import gymnasium as gym
import gym_pusht  # noqa: F401 -- importing this registers the PushT environments

env = gym.make("gym_pusht/PushT-v0", obs_type="pixels_agent_pos", max_episode_steps=300)
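To see what obs_type="pixels_agent_pos" actually provides, printing the spaces helps (these are standard Gymnasium attributes; the exact shapes depend on the gym_pusht version, so the comments are my expectation rather than a guarantee):

print(env.observation_space)  # expected: Dict with "pixels" (an H x W x 3 image) and "agent_pos" (a 2-D position)
print(env.action_space)       # expected: a 2-D continuous target position for the agent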
Interaction
The loop: read the state, call policy(state) to get an action, then step the environment with env.step(action).
import torch

# The policy must live on the same device as the tensors built below.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
policy.to(device)

numpy_observation, info = env.reset()
rewards = []
frames = [env.render()]
step = 0
done = False
while not done:
    # Convert the raw numpy observation into float tensors.
    state = torch.from_numpy(numpy_observation["agent_pos"])
    image = torch.from_numpy(numpy_observation["pixels"])
    state = state.to(torch.float32)
    image = image.to(torch.float32) / 255
    image = image.permute(2, 0, 1)  # HWC -> CHW
    state = state.to(device, non_blocking=True)
    image = image.to(device, non_blocking=True)
    state = state.unsqueeze(0)  # add a batch dimension
    image = image.unsqueeze(0)
    observation = {
        "observation.state": state,
        "observation.image": image,
    }
    # Query the policy for the next action.
    with torch.inference_mode():
        action = policy.select_action(observation)
    numpy_action = action.squeeze(0).to("cpu").numpy()
    # Step the environment with that action.
    numpy_observation, reward, terminated, truncated, info = env.step(numpy_action)
    print(f"{step=} {reward=} {terminated=}")
    rewards.append(reward)
    frames.append(env.render())
    done = terminated | truncated | done
    step += 1

if terminated:
    print("Success!")
else:
    print("Failure!")
Finally there is a visualization step: the collected frames are stitched together into a video.
import imageio
import numpy

video_path = output_directory / "rollout.mp4"  # output_directory is a pathlib.Path chosen beforehand
imageio.mimsave(str(video_path), numpy.stack(frames), fps=env.metadata["render_fps"])
TODO: insert the video
Training the policy
The training procedure is also fairly fixed, and since this is offline training it stays simple: load the dataset, load the policy, load the optimizer, then loop over forward pass, backward pass, parameter update.
Loading the dataset
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.common.datasets.utils import dataset_to_policy_features
from lerobot.configs.types import FeatureType

# Read the dataset metadata and split its features into policy inputs and outputs.
dataset_metadata = LeRobotDatasetMetadata("lerobot/pusht")
features = dataset_to_policy_features(dataset_metadata.features)
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
input_features = {key: ft for key, ft in features.items() if key not in output_features}

# Relative timestamps (in seconds) of the frames returned for each key:
# two observation frames (t-0.1 and t) and a 16-step action horizon.
delta_timestamps = {
    "observation.image": [-0.1, 0.0],
    "observation.state": [-0.1, 0.0],
    "action": [-0.1, 0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4],
}

dataset = LeRobotDataset("lerobot/pusht", delta_timestamps=delta_timestamps)
dataloader = torch.utils.data.DataLoader(
    dataset,
    num_workers=4,
    batch_size=64,
    shuffle=True,
    pin_memory=device.type != "cpu",
    drop_last=True,
)
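To sanity-check what delta_timestamps does, pull one batch and look at the shapes; the sizes in the comments are what I would expect for PushT with the settings above (2 observation timestamps, 16 action timestamps), so treat them as an assumption:

batch = next(iter(dataloader))
print(batch["observation.image"].shape)  # expected (64, 2, 3, 96, 96): two stacked image frames per sample
print(batch["observation.state"].shape)  # expected (64, 2, 2)
print(batch["action"].shape)             # expected (64, 16, 2): the 16-step action horizon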
Loading the policy
from lerobot.common.policies.diffusion.configuration_diffusion import DiffusionConfig

cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
policy = DiffusionPolicy(cfg, dataset_stats=dataset_metadata.stats)
policy.train()
policy.to(device)
Loading the optimizer
optimizer = torch.optim.Adam(policy.parameters(), lr=1e-4)
Forward pass, backward pass, parameter update
training_steps = 5000  # total number of optimization steps; pick to taste
log_freq = 250         # how often to print the loss

step = 0
done = False
while not done:
    for batch in dataloader:
        # Move every tensor in the batch to the device.
        batch = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in batch.items()}
        # Forward pass: the policy returns its training loss.
        loss, _ = policy.forward(batch)
        # Backward pass and parameter update.
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        if step % log_freq == 0:
            print(f"step: {step} loss: {loss.item():.3f}")
        step += 1
        if step >= training_steps:
            done = True
            break
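After training, it is worth saving the weights so the evaluation section above can load them from disk. DiffusionPolicy follows the Hub mixin pattern (from_pretrained / save_pretrained), so something along these lines should work; the output path is hypothetical:

from pathlib import Path

output_directory = Path("outputs/train/my_diffusion_pusht")  # hypothetical output directory
output_directory.mkdir(parents=True, exist_ok=True)
policy.save_pretrained(output_directory)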
References
https://huggingface.co/lerobot
Next up
Analysis of the Diffusion Policy and ACT algorithms