Skip to content

Commit

Permalink
Add own flag for saveing ply files
Browse files Browse the repository at this point in the history
  • Loading branch information
MrNeRF committed Sep 28, 2024
1 parent cba1d3e commit d30a1a9
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion examples/simple_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ class Config:
eval_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
# Steps to save the model
save_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])
# Steps to save the model as ply
ply_steps: List[int] = field(default_factory=lambda: [7_000, 30_000])

# Initialization strategy
init_type: str = "sfm"
Expand Down Expand Up @@ -165,6 +167,7 @@ class Config:
def adjust_steps(self, factor: float):
self.eval_steps = [int(i * factor) for i in self.eval_steps]
self.save_steps = [int(i * factor) for i in self.save_steps]
self.ply_steps = [int(i * factor) for i in self.ply_steps]
self.max_steps = int(self.max_steps * factor)
self.sh_degree_interval = int(self.sh_degree_interval * factor)

Expand Down Expand Up @@ -765,7 +768,8 @@ def train(self):
torch.save(
data, f"{self.ckpt_dir}/ckpt_{step}_rank{self.world_rank}.pt"
)
save_ply(self.splats, f"{self.ply_dir}/point_cloud.ply")
if step in [i - 1 for i in cfg.ply_steps] or step == max_steps - 1:
save_ply(self.splats, f"{self.ply_dir}/point_cloud_{step}.ply")

if isinstance(self.cfg.strategy, DefaultStrategy):
self.cfg.strategy.step_post_backward(
Expand Down

0 comments on commit d30a1a9

Please sign in to comment.