Skip to content

Commit

Permalink
add mask and value clipping to normalization wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
kpertsch committed Dec 21, 2023
1 parent 5bd39e8 commit 1477edb
Showing 1 changed file with 29 additions and 10 deletions.
39 changes: 29 additions & 10 deletions octo/utils/gym_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,26 +271,45 @@ def __init__(
super().__init__(env)

def unnormalize(self, data, metadata):
mask = metadata.get("mask", np.ones_like(metadata["mean"], dtype=bool))
if self.normalization_type == "normal":
return (data * metadata["std"]) + metadata["mean"]
return np.where(
mask,
(data * metadata["std"]) + metadata["mean"],
data,
)
elif self.normalization_type == "bounds":
return (
(data + 1) / 2 * (metadata["max"] - metadata["min"] + 1e-8)
) + metadata["min"]
return np.where(
mask,
((data + 1) / 2 * (metadata["max"] - metadata["min"] + 1e-8))
+ metadata["min"],
data,
)
else:
raise ValueError(
f"Unknown action/proprio normalization type: {self.normalization_type}"
)

def normalize(self, data, metadata):
mask = metadata.get("mask", np.ones_like(metadata["mean"], dtype=bool))
if self.normalization_type == "normal":
return (data - metadata["mean"]) / (metadata["std"] + 1e-8)
return np.where(
mask,
(data - metadata["mean"]) / (metadata["std"] + 1e-8),
data,
)
elif self.normalization_type == "bounds":
return (
2
* (data - metadata["min"])
/ (metadata["max"] - metadata["min"] + 1e-8)
- 1
return np.where(
mask,
np.clip(
2
* (data - metadata["min"])
/ (metadata["max"] - metadata["min"] + 1e-8)
- 1,
-1,
1,
),
data,
)
else:
raise ValueError(
Expand Down

0 comments on commit 1477edb

Please sign in to comment.