Commit 48a7414

bglick13 and Nathan Lambert authored
Add Value Function and corresponding example script to Diffuser implementation (#884)
* valuefunction code
* start example scripts
* missing imports
* bug fixes and placeholder example script
* add value function scheduler
* load value function from hub and get best actions in example
* very close to working example
* larger batch size for planning
* more tests
* merge unet1d changes
* wandb for debugging, use newer models
* success!
* turns out we just need more diffusion steps
* run on modal
* merge and code cleanup
* use same api for rl model
* fix variance type
* wrong normalization function
* add tests
* style
* style and quality
* edits based on comments
* style and quality
* remove unused var
* hack unet1d into a value function
* add pipeline
* fix arg order
* add pipeline to core library
* community pipeline
* fix couple shape bugs
* style
* Apply suggestions from code review

Co-authored-by: Nathan Lambert <[email protected]>
1 parent a6314f6 commit 48a7414

File tree

14 files changed
+1143 −28 lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
@@ -163,4 +163,6 @@ tags
 *.lock

 # DS_Store (MacOS)
-.DS_Store
+.DS_Store
+# RL pipelines may produce mp4 outputs
+*.mp4

examples/community/pipeline.py

Lines changed: 99 additions & 0 deletions
import torch

import tqdm
from diffusers import DDPMScheduler, DiffusionPipeline
from diffusers.models.unet_1d import UNet1DModel


class ValueGuidedDiffuserPipeline(DiffusionPipeline):
    def __init__(
        self,
        value_function: UNet1DModel,
        unet: UNet1DModel,
        scheduler: DDPMScheduler,
        env,
    ):
        super().__init__()
        self.value_function = value_function
        self.unet = unet
        self.scheduler = scheduler
        self.env = env
        self.data = env.get_dataset()
        # per-key normalization statistics; skip dataset entries that are not numeric arrays
        self.means = dict()
        for key in self.data.keys():
            try:
                self.means[key] = self.data[key].mean()
            except Exception:
                pass
        self.stds = dict()
        for key in self.data.keys():
            try:
                self.stds[key] = self.data[key].std()
            except Exception:
                pass
        self.state_dim = env.observation_space.shape[0]
        self.action_dim = env.action_space.shape[0]

    def normalize(self, x_in, key):
        return (x_in - self.means[key]) / self.stds[key]

    def de_normalize(self, x_in, key):
        return x_in * self.stds[key] + self.means[key]

    def to_torch(self, x_in):
        if type(x_in) is dict:
            return {k: self.to_torch(v) for k, v in x_in.items()}
        elif torch.is_tensor(x_in):
            return x_in.to(self.unet.device)
        return torch.tensor(x_in, device=self.unet.device)

    def reset_x0(self, x_in, cond, act_dim):
        # overwrite the observation part of the trajectory at the conditioned timesteps
        for key, val in cond.items():
            x_in[:, key, act_dim:] = val.clone()
        return x_in

    def run_diffusion(self, x, conditions, n_guide_steps, scale):
        batch_size = x.shape[0]
        y = None
        for i in tqdm.tqdm(self.scheduler.timesteps):
            # create batch of timesteps to pass into model
            timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
            # 1. nudge the trajectory up the value-function gradient
            for _ in range(n_guide_steps):
                with torch.enable_grad():
                    x.requires_grad_()
                    y = self.value_function(x.permute(0, 2, 1), timesteps).sample
                    grad = torch.autograd.grad([y.sum()], [x])[0]

                    posterior_variance = self.scheduler._get_variance(i)
                    model_std = torch.exp(0.5 * posterior_variance)
                    grad = model_std * grad
                grad[timesteps < 2] = 0
                x = x.detach()
                x = x + scale * grad
                x = self.reset_x0(x, conditions, self.action_dim)
            # 2. de-noise one step with the planning model
            prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
            x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]

            # 3. apply conditions to the trajectory
            x = self.reset_x0(x, conditions, self.action_dim)
            x = self.to_torch(x)
        return x, y

    def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
        # normalize the current observation and tile it across the planning batch
        obs = self.normalize(obs, "observations")
        obs = obs[None].repeat(batch_size, axis=0)
        conditions = {0: self.to_torch(obs)}
        shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)
        x1 = torch.randn(shape, device=self.unet.device)
        x = self.reset_x0(x1, conditions, self.action_dim)
        x = self.to_torch(x)
        x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)
        # rank the sampled plans by predicted value and return the first action of the best plan
        sorted_idx = y.argsort(0, descending=True).squeeze()
        sorted_values = x[sorted_idx]
        actions = sorted_values[:, :, : self.action_dim]
        actions = actions.detach().cpu().numpy()
        denorm_actions = self.de_normalize(actions, key="actions")
        denorm_actions = denorm_actions[0, 0]
        return denorm_actions
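
For orientation, here is a minimal usage sketch of the new pipeline (not part of this commit). It assumes a D4RL hopper environment and value-function/planning-UNet checkpoints on the Hub; the model IDs below are hypothetical placeholders, and the import assumes the script runs from examples/community.

import d4rl  # noqa
import gym
from diffusers import DDPMScheduler
from diffusers.models.unet_1d import UNet1DModel
from pipeline import ValueGuidedDiffuserPipeline  # examples/community/pipeline.py

env = gym.make("hopper-medium-v2")
scheduler = DDPMScheduler(num_train_timesteps=100, beta_schedule="squaredcos_cap_v2")
# hypothetical Hub IDs -- substitute whichever checkpoints you trained or downloaded
value_function = UNet1DModel.from_pretrained("your-org/hopper-value-function")
unet = UNet1DModel.from_pretrained("your-org/hopper-planning-unet")

pipeline = ValueGuidedDiffuserPipeline(value_function, unet, scheduler, env)

obs = env.reset()
# plan a batch of trajectories and execute the first action of the highest-value plan
action = pipeline(obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1)
next_obs, reward, done, _ = env.step(action)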

examples/diffuser/run_diffuser.py

Lines changed: 122 additions & 0 deletions
import numpy as np
import torch

import d4rl  # noqa
import gym
import tqdm
import train_diffuser
from diffusers import DDPMScheduler, UNet1DModel


env_name = "hopper-medium-expert-v2"
env = gym.make(env_name)
data = env.get_dataset()  # dataset is only used for normalization in this colab

DEVICE = "cpu"
DTYPE = torch.float

# diffusion model settings
n_samples = 4  # number of trajectories planned via diffusion
horizon = 128  # length of sampled trajectories
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
num_inference_steps = 100  # number of diffusion steps

# Two generators for different parts of the diffusion loop to work in colab
generator_cpu = torch.Generator(device="cpu")

scheduler = DDPMScheduler(num_train_timesteps=100, beta_schedule="squaredcos_cap_v2")

# 3 different pretrained models are available for this task.
# The horizon represents the length of trajectories used in training.
network = UNet1DModel.from_pretrained("fusing/ddpm-unet-rl-hopper-hor128").to(device=DEVICE)
# network = UNet1DModel.from_pretrained("fusing/ddpm-unet-rl-hopper-hor256").to(device=DEVICE)
# network = UNet1DModel.from_pretrained("fusing/ddpm-unet-rl-hopper-hor512").to(device=DEVICE)

# network specific constants for inference
clip_denoised = network.clip_denoised
predict_epsilon = network.predict_epsilon

# [ observation_dim ] --> [ n_samples x observation_dim ]
obs = env.reset()
total_reward = 0
done = False
T = 300
rollout = [obs.copy()]

try:
    for t in tqdm.tqdm(range(T)):
        obs_raw = obs

        # normalize observations for forward passes
        obs = train_diffuser.normalize(obs, data, "observations")
        obs = obs[None].repeat(n_samples, axis=0)
        conditions = {0: train_diffuser.to_torch(obs, device=DEVICE)}

        # constants for inference
        batch_size = len(conditions[0])
        shape = (batch_size, horizon, state_dim + action_dim)

        # sample random initial noise vector
        x1 = torch.randn(shape, device=DEVICE, generator=generator_cpu)

        # this model is conditioned from an initial state, so you will see this function
        # multiple times to change the initial state of generated data to the state
        # generated via env.reset() above or env.step() below
        x = train_diffuser.reset_x0(x1, conditions, action_dim)

        # convert a np observation to torch for model forward pass
        x = train_diffuser.to_torch(x)

        eta = 1.0  # noise factor for sampling reconstructed state

        # run the diffusion process
        # for i in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
        for i in tqdm.tqdm(scheduler.timesteps):
            # create batch of timesteps to pass into model
            timesteps = torch.full((batch_size,), i, device=DEVICE, dtype=torch.long)

            # 1. generate prediction from model
            with torch.no_grad():
                residual = network(x, timesteps).sample

            # 2. use the model prediction to reconstruct an observation (de-noise)
            obs_reconstruct = scheduler.step(residual, i, x, predict_epsilon=predict_epsilon)["prev_sample"]

            # 3. [optional] add posterior noise to the sample
            if eta > 0:
                noise = torch.randn(obs_reconstruct.shape, generator=generator_cpu).to(obs_reconstruct.device)
                posterior_variance = scheduler._get_variance(i)  # * noise
                # no noise when t == 0
                # NOTE: original implementation missing sqrt on posterior_variance
                obs_reconstruct = (
                    obs_reconstruct + int(i > 0) * (0.5 * posterior_variance) * eta * noise
                )  # MJ had as log var, exponentiated

            # 4. apply conditions to the trajectory
            obs_reconstruct_postcond = train_diffuser.reset_x0(obs_reconstruct, conditions, action_dim)
            x = train_diffuser.to_torch(obs_reconstruct_postcond)
        plans = train_diffuser.helpers.to_np(x[:, :, :action_dim])
        # select random plan
        idx = np.random.randint(plans.shape[0])
        # select action at correct time
        action = plans[idx, 0, :]
        # de-normalize the action before executing it in the environment
        action = train_diffuser.de_normalize(action, data, "actions")
        next_observation, reward, terminal, _ = env.step(action)

        # update return
        total_reward += reward
        print(f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}")

        # save observations for rendering
        rollout.append(next_observation.copy())
        obs = next_observation
except KeyboardInterrupt:
    pass

print(f"Total reward: {total_reward}")
render = train_diffuser.MuJoCoRenderer(env)
train_diffuser.show_sample(render, np.expand_dims(np.stack(rollout), axis=0))
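
Note the design contrast between the two new files: run_diffuser.py samples several candidate plans and executes one chosen at random (np.random.randint over the plan batch), while ValueGuidedDiffuserPipeline steers the denoising process with value-function gradients and then executes the first action of the highest-value plan.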
