import numpy as np
import torch

import d4rl  # noqa
import gym
import tqdm
import train_diffuser
from diffusers import DDPMScheduler, UNet1DModel
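# `train_diffuser` is a local helper module assumed to live next to this script;
# the rollout below relies on it providing normalize / de_normalize, to_torch,
# reset_x0, helpers.to_np, MuJoCoRenderer and show_sample.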


env_name = "hopper-medium-expert-v2"
env = gym.make(env_name)
data = env.get_dataset()  # dataset is only used for normalization in this colab

DEVICE = "cpu"
DTYPE = torch.float

# diffusion model settings
n_samples = 4  # number of trajectories planned via diffusion
horizon = 128  # length of sampled trajectories
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
num_inference_steps = 100  # number of diffusion steps


# a single CPU generator is shared by all sampling in the diffusion loop so it works in colab
generator_cpu = torch.Generator(device="cpu")

scheduler = DDPMScheduler(num_train_timesteps=100, beta_schedule="squaredcos_cap_v2")
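# the scheduler defines the inference noise schedule; num_train_timesteps should
# match the schedule the network was trained with, and "squaredcos_cap_v2" is
# diffusers' capped cosine schedule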
# 3 different pretrained models are available for this task.
# The horizon represents the length of trajectories used in training.
network = UNet1DModel.from_pretrained("fusing/ddpm-unet-rl-hopper-hor128").to(device=DEVICE)
# network = UNet1DModel.from_pretrained("fusing/ddpm-unet-rl-hopper-hor256").to(device=DEVICE)
# network = UNet1DModel.from_pretrained("fusing/ddpm-unet-rl-hopper-hor512").to(device=DEVICE)


# network specific constants for inference
clip_denoised = network.clip_denoised
predict_epsilon = network.predict_epsilon
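# predict_epsilon toggles whether the network predicts the noise residual or the
# denoised sample directly; clip_denoised (read but unused below) would clamp the
# reconstructed sample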

obs = env.reset()
total_reward = 0
done = False
T = 300
rollout = [obs.copy()]

try:
    for t in tqdm.tqdm(range(T)):
        obs_raw = obs

        # normalize observations for forward passes
        obs = train_diffuser.normalize(obs, data, "observations")
        # [ observation_dim ] --> [ n_samples x observation_dim ]
        obs = obs[None].repeat(n_samples, axis=0)
        conditions = {0: train_diffuser.to_torch(obs, device=DEVICE)}

        # constants for inference
        batch_size = len(conditions[0])
        shape = (batch_size, horizon, state_dim + action_dim)
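        # trajectory layout: each of the `horizon` timesteps stores the action
        # channels [0:action_dim] followed by the state channels [action_dim:],
        # which is why plans are sliced out of the first action_dim channels below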

        # sample random initial noise vector
        x1 = torch.randn(shape, device=DEVICE, generator=generator_cpu)

        # this model is conditioned on an initial state, so this function appears
        # multiple times to reset the initial state of the generated data to the
        # state produced by env.reset() above or env.step() below
        x = train_diffuser.reset_x0(x1, conditions, action_dim)
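        # (assumption about the helper: reset_x0 presumably overwrites the state
        # channels x[:, k, action_dim:] for every timestep key k in `conditions`,
        # pinning the sampled trajectory to the current observation)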

        # convert a np observation to torch for model forward pass
        x = train_diffuser.to_torch(x)

        eta = 1.0  # noise factor for sampling reconstructed state
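        # eta = 0 would make the reverse process deterministic (DDIM-style);
        # eta = 1 keeps the full stochastic DDPM posterior noise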

        # run the diffusion process
        # for i in tqdm.tqdm(reversed(range(num_inference_steps)), total=num_inference_steps):
        for i in tqdm.tqdm(scheduler.timesteps):
            # create batch of timesteps to pass into model
            timesteps = torch.full((batch_size,), i, device=DEVICE, dtype=torch.long)

            # 1. generate prediction from model
            with torch.no_grad():
                residual = network(x, timesteps).sample

            # 2. use the model prediction to reconstruct an observation (de-noise)
            obs_reconstruct = scheduler.step(residual, i, x, predict_epsilon=predict_epsilon)["prev_sample"]

            # 3. [optional] add posterior noise to the sample
            if eta > 0:
                noise = torch.randn(obs_reconstruct.shape, generator=generator_cpu).to(obs_reconstruct.device)
                posterior_variance = scheduler._get_variance(i)  # * noise
                # no noise when t == 0
                # NOTE: original implementation missing sqrt on posterior_variance
                obs_reconstruct = (
                    obs_reconstruct + int(i > 0) * (0.5 * posterior_variance) * eta * noise
                )  # MJ had as log var, exponentiated

            # 4. apply conditions to the trajectory
            obs_reconstruct_postcond = train_diffuser.reset_x0(obs_reconstruct, conditions, action_dim)
            x = train_diffuser.to_torch(obs_reconstruct_postcond)
        plans = train_diffuser.helpers.to_np(x[:, :, :action_dim])
        # select random plan
        idx = np.random.randint(plans.shape[0])
        # select action at correct time
        action = plans[idx, 0, :]
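        # receding-horizon control: only the first action of the sampled plan is
        # executed; a fresh plan is generated at the next environment step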
        # de-normalize the selected action before stepping the environment
        action = train_diffuser.de_normalize(action, data, "actions")
        # execute action in environment
        next_observation, reward, terminal, _ = env.step(action)

        # update return
        total_reward += reward
        print(f"Step: {t}, Reward: {reward}, Total Reward: {total_reward}")

        # save observations for rendering
        rollout.append(next_observation.copy())
        obs = next_observation
except KeyboardInterrupt:
    pass

print(f"Total reward: {total_reward}")
render = train_diffuser.MuJoCoRenderer(env)
train_diffuser.show_sample(render, np.expand_dims(np.stack(rollout), axis=0))