Commit 6e46457

initial commit

deeptimhe committed Dec 7, 2024
1 parent 8a0b799 commit 6e46457
Showing 34 changed files with 7,655 additions and 406 deletions.
563 changes: 166 additions & 397 deletions .gitignore

306 changes: 297 additions & 9 deletions README.md

Binary file added assets/example.mp4
Binary file added assets/radar.png
118 changes: 118 additions & 0 deletions configs/fsq_causal_41616_262144.yaml
@@ -0,0 +1,118 @@
model:
  base_learning_rate: 1e-5
  target: vidtok.models.autoencoder.AutoencodingEngine
  params:
    monitor: val/rec_loss
    mode: min
    # ckpt_path: checkpoints/fsq_causal_41616_262144.ckpt  # train from existing checkpoint
    ignore_keys: []
    # ema_decay: 0.999

    encoder_config:
      target: vidtok.modules.model_3dcausal.EncoderCausal3DPadding
      params:
        double_z: false
        z_channels: 6
        in_channels: 3
        out_ch: 3
        ch: 128
        ch_mult: [1, 2, 4, 4, 4]
        time_downsample_factor: 4
        num_res_blocks: 2
        dropout: 0.0
        use_checkpoint: true
        init_pad_mode: replicate
        norm_type: layernorm  # layernorm, groupnorm
        fix_encoder: false  # if True, fix it without updating params
        fix_decoder: false  # if True, fix it without updating params

    decoder_config:
      target: vidtok.modules.model_3dcausal.DecoderCausal3DPadding
      params: ${model.params.encoder_config.params}

    regularizer_config:
      target: vidtok.modules.regularizers.FSQRegularizer
      params:
        levels: [8, 8, 8, 8, 8, 8]  # codebook size: 8*8*8*8*8*8=262144
        entropy_loss_weight: 0.1
        entropy_loss_annealing_steps: 2000
        entropy_loss_annealing_factor: 3
        commitment_loss_weight: 0.25

    loss_config:
      target: vidtok.modules.losses.GeneralLPIPSWithDiscriminator
      params:
        dims: 3  # video - [t,h,w]
        perceptual_weight: 1.0
        disc_start: 20001
        disc_weight: 0.2
        disc_type: 2d  # 2d, 3d
        learn_logvar: true
        gen_loss_cross_entropy: true
        lecam_loss_weight: 0.005
        regularization_weights: {'aux_loss': 1.0, 'kl_loss': 0.000001}

data:
  target: vidtok.data.datamodule.DataModuleFromConfig
  params:
    batch_size: 2
    num_workers: 12

    train:
      target: vidtok.data.vidtok.VidTokDataset
      params:
        data_dir: DATA_DIR_1  # DATA_DIR for training data
        meta_path: META_PATH_1  # path to the .csv meta file of training data
        video_params:
          input_height: INPUT_HEIGHT_1
          input_width: INPUT_WIDTH_1
          sample_num_frames: 17
          sample_fps: 3

    validation:
      target: vidtok.data.vidtok.VidTokDataset
      params:
        data_dir: DATA_DIR_2  # DATA_DIR for validation data
        meta_path: META_PATH_2  # path to the .csv meta file of validation data
        video_params:
          input_height: INPUT_HEIGHT_2
          input_width: INPUT_WIDTH_2
          sample_num_frames: 17
          sample_fps: 8
        start_index: 0

lightning:
  strategy:
    target: lightning.pytorch.strategies.DDPStrategy
    params:
      find_unused_parameters: true

  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    image_logger:
      target: vidtok.modules.logger.ImageVideoLogger
      params:
        disabled: false
        rescale: true
        enable_autocast: false
        batch_frequency: 5000
        max_samples: 2
        increase_log_steps: false
        log_first_step: false
        log_before_first_step: false
        log_images_kwargs:
          n_rows: 17

  trainer:
    precision: bf16-mixed
    devices: auto
    num_nodes: 1
    benchmark: true
    num_sanity_val_steps: 10
    val_check_interval: 2000
    check_val_every_n_epoch: null  # default: 1
    accumulate_grad_batches: 1
    max_epochs: 1000
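
The regularizer_config above selects finite scalar quantization (FSQ): each of the z_channels = 6 latent channels is bounded and rounded to one of levels[i] values, so the implicit codebook has 8*8*8*8*8*8 = 262144 entries, matching the comment in the config. Below is a minimal sketch of the rounding step, following the original FSQ formulation rather than vidtok's exact FSQRegularizer (which additionally applies the entropy, annealing, and commitment terms configured above):

```python
import torch


def fsq_quantize(z: torch.Tensor, levels=(8, 8, 8, 8, 8, 8)) -> torch.Tensor:
    """Round each latent channel to one of `levels[i]` uniformly spaced values.

    `z` has shape (..., len(levels)); here len(levels) matches z_channels = 6,
    so the implicit codebook has 8**6 = 262144 entries.
    """
    L = torch.tensor(levels, dtype=z.dtype, device=z.device)
    half_l = (L - 1) / 2
    offset = 0.5 * (L % 2 == 0).to(z.dtype)             # even level counts need a half-step shift
    shift = torch.atanh(offset / half_l)
    bounded = torch.tanh(z + shift) * half_l - offset    # channel i now spans levels[i] integer bins
    quantized = torch.round(bounded)
    # Straight-through estimator: rounded values forward, smooth gradients backward.
    return bounded + (quantized - bounded).detach()


z = torch.randn(2, 5, 32, 32, 6, requires_grad=True)  # (batch, latent frames, h, w, z_channels)
tokens = fsq_quantize(z)
tokens.sum().backward()                                # gradients flow through the pre-rounding path
print(tokens.shape, z.grad.shape)
```
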
118 changes: 118 additions & 0 deletions configs/fsq_causal_488_262144.yaml
@@ -0,0 +1,118 @@
model:
  base_learning_rate: 1e-5
  target: vidtok.models.autoencoder.AutoencodingEngine
  params:
    monitor: val/rec_loss
    mode: min
    # ckpt_path: checkpoints/fsq_causal_488_262144.ckpt  # train from existing checkpoint
    ignore_keys: []
    # ema_decay: 0.999

    encoder_config:
      target: vidtok.modules.model_3dcausal.EncoderCausal3DPadding
      params:
        double_z: false
        z_channels: 6
        in_channels: 3
        out_ch: 3
        ch: 128
        ch_mult: [1, 2, 4, 4]
        time_downsample_factor: 4
        num_res_blocks: 2
        dropout: 0.0
        use_checkpoint: true
        init_pad_mode: replicate
        norm_type: layernorm  # layernorm, groupnorm
        fix_encoder: false  # if True, fix it without updating params
        fix_decoder: false  # if True, fix it without updating params

    decoder_config:
      target: vidtok.modules.model_3dcausal.DecoderCausal3DPadding
      params: ${model.params.encoder_config.params}

    regularizer_config:
      target: vidtok.modules.regularizers.FSQRegularizer
      params:
        levels: [8, 8, 8, 8, 8, 8]  # codebook size: 8*8*8*8*8*8=262144
        entropy_loss_weight: 0.1
        entropy_loss_annealing_steps: 2000
        entropy_loss_annealing_factor: 3
        commitment_loss_weight: 0.25

    loss_config:
      target: vidtok.modules.losses.GeneralLPIPSWithDiscriminator
      params:
        dims: 3  # video - [t,h,w]
        perceptual_weight: 1.0
        disc_start: 20001
        disc_weight: 0.2
        disc_type: 2d  # 2d, 3d
        learn_logvar: true
        gen_loss_cross_entropy: true
        lecam_loss_weight: 0.005
        regularization_weights: {'aux_loss': 1.0, 'kl_loss': 0.000001}

data:
  target: vidtok.data.datamodule.DataModuleFromConfig
  params:
    batch_size: 2
    num_workers: 12

    train:
      target: vidtok.data.vidtok.VidTokDataset
      params:
        data_dir: DATA_DIR_1  # DATA_DIR for training data
        meta_path: META_PATH_1  # path to the .csv meta file of training data
        video_params:
          input_height: INPUT_HEIGHT_1
          input_width: INPUT_WIDTH_1
          sample_num_frames: 17
          sample_fps: 3

    validation:
      target: vidtok.data.vidtok.VidTokDataset
      params:
        data_dir: DATA_DIR_2  # DATA_DIR for validation data
        meta_path: META_PATH_2  # path to the .csv meta file of validation data
        video_params:
          input_height: INPUT_HEIGHT_2
          input_width: INPUT_WIDTH_2
          sample_num_frames: 17
          sample_fps: 8
        start_index: 0

lightning:
  strategy:
    target: lightning.pytorch.strategies.DDPStrategy
    params:
      find_unused_parameters: true

  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    image_logger:
      target: vidtok.modules.logger.ImageVideoLogger
      params:
        disabled: false
        rescale: true
        enable_autocast: false
        batch_frequency: 5000
        max_samples: 2
        increase_log_steps: false
        log_first_step: false
        log_before_first_step: false
        log_images_kwargs:
          n_rows: 17

  trainer:
    precision: bf16-mixed
    devices: auto
    num_nodes: 1
    benchmark: true
    num_sanity_val_steps: 10
    val_check_interval: 2000
    check_val_every_n_epoch: null  # default: 1
    accumulate_grad_batches: 1
    max_epochs: 1000
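
These files use the common target/params convention: target is a dotted import path and params are its constructor keyword arguments, while ${model.params.encoder_config.params} is an OmegaConf interpolation that makes the decoder mirror the encoder's hyperparameters. A hedged sketch of how such a config is typically consumed; the instantiate_from_config helper below is illustrative and may differ from vidtok's actual entry point:

```python
import importlib

from omegaconf import OmegaConf


def instantiate_from_config(cfg):
    """Import `cfg.target` and call it with `cfg.params` as keyword arguments."""
    module_name, cls_name = cfg["target"].rsplit(".", 1)
    cls = getattr(importlib.import_module(module_name), cls_name)
    return cls(**cfg.get("params", {}))


cfg = OmegaConf.load("configs/fsq_causal_488_262144.yaml")

# The interpolation resolves so that the decoder params equal the encoder params.
enc = OmegaConf.to_container(cfg.model.params.encoder_config.params, resolve=True)
dec = OmegaConf.to_container(cfg.model.params.decoder_config.params, resolve=True)
assert enc == dec

# Builds the autoencoder from its encoder/decoder/regularizer/loss sub-configs;
# a real training run would wrap this in a Lightning Trainer per the `lightning` section.
model = instantiate_from_config(cfg.model)
```
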
118 changes: 118 additions & 0 deletions configs/fsq_causal_488_32768.yaml
@@ -0,0 +1,118 @@
model:
  base_learning_rate: 1e-5
  target: vidtok.models.autoencoder.AutoencodingEngine
  params:
    monitor: val/rec_loss
    mode: min
    # ckpt_path: checkpoints/fsq_causal_488_32768.ckpt  # train from existing checkpoint
    ignore_keys: []
    # ema_decay: 0.999

    encoder_config:
      target: vidtok.modules.model_3dcausal.EncoderCausal3DPadding
      params:
        double_z: false
        z_channels: 5
        in_channels: 3
        out_ch: 3
        ch: 128
        ch_mult: [1, 2, 4, 4]
        time_downsample_factor: 4
        num_res_blocks: 2
        dropout: 0.0
        use_checkpoint: true
        init_pad_mode: replicate
        norm_type: layernorm  # layernorm, groupnorm
        fix_encoder: false  # if True, fix it without updating params
        fix_decoder: false  # if True, fix it without updating params

    decoder_config:
      target: vidtok.modules.model_3dcausal.DecoderCausal3DPadding
      params: ${model.params.encoder_config.params}

    regularizer_config:
      target: vidtok.modules.regularizers.FSQRegularizer
      params:
        levels: [8, 8, 8, 8, 8]  # codebook size: 8*8*8*8*8=32768
        entropy_loss_weight: 0.1
        entropy_loss_annealing_steps: 2000
        entropy_loss_annealing_factor: 3
        commitment_loss_weight: 0.25

    loss_config:
      target: vidtok.modules.losses.GeneralLPIPSWithDiscriminator
      params:
        dims: 3  # video - [t,h,w]
        perceptual_weight: 1.0
        disc_start: 20001
        disc_weight: 0.2
        disc_type: 2d  # 2d, 3d
        learn_logvar: true
        gen_loss_cross_entropy: true
        lecam_loss_weight: 0.005
        regularization_weights: {'aux_loss': 1.0, 'kl_loss': 0.000001}

data:
  target: vidtok.data.datamodule.DataModuleFromConfig
  params:
    batch_size: 2
    num_workers: 12

    train:
      target: vidtok.data.vidtok.VidTokDataset
      params:
        data_dir: DATA_DIR_1  # DATA_DIR for training data
        meta_path: META_PATH_1  # path to the .csv meta file of training data
        video_params:
          input_height: INPUT_HEIGHT_1
          input_width: INPUT_WIDTH_1
          sample_num_frames: 17
          sample_fps: 3

    validation:
      target: vidtok.data.vidtok.VidTokDataset
      params:
        data_dir: DATA_DIR_2  # DATA_DIR for validation data
        meta_path: META_PATH_2  # path to the .csv meta file of validation data
        video_params:
          input_height: INPUT_HEIGHT_2
          input_width: INPUT_WIDTH_2
          sample_num_frames: 17
          sample_fps: 8
        start_index: 0

lightning:
  strategy:
    target: lightning.pytorch.strategies.DDPStrategy
    params:
      find_unused_parameters: true

  modelcheckpoint:
    params:
      every_n_train_steps: 5000

  callbacks:
    image_logger:
      target: vidtok.modules.logger.ImageVideoLogger
      params:
        disabled: false
        rescale: true
        enable_autocast: false
        batch_frequency: 5000
        max_samples: 2
        increase_log_steps: false
        log_first_step: false
        log_before_first_step: false
        log_images_kwargs:
          n_rows: 17

  trainer:
    precision: bf16-mixed
    devices: auto
    num_nodes: 1
    benchmark: true
    num_sanity_val_steps: 10
    val_check_interval: 2000
    check_val_every_n_epoch: null  # default: 1
    accumulate_grad_batches: 1
    max_epochs: 1000
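
Comparing the three file names against their configs, the middle number appears to encode the temporal and spatial downsampling factors and the trailing number the FSQ codebook size: time_downsample_factor: 4 gives the 4x temporal compression, each ch_mult stage after the first halves the spatial resolution (three extra stages for 8x, four for 16x), and the product of levels gives 262144 or 32768. A small sanity check of that reading; the causal frame formula 1 + (T - 1) / 4 and the 256x256 input resolution are assumptions, not values taken from this diff:

```python
from math import prod


def latent_shape(frames, height, width, ch_mult, t_factor=4, causal=True):
    """Token-grid size implied by the encoder config (assumed formula, see above)."""
    s_factor = 2 ** (len(ch_mult) - 1)  # one 2x spatial downsample per stage after the first
    t = 1 + (frames - 1) // t_factor if causal else frames // t_factor
    return t, height // s_factor, width // s_factor


for name, ch_mult, levels in [
    ("fsq_causal_41616_262144", [1, 2, 4, 4, 4], [8] * 6),
    ("fsq_causal_488_262144",   [1, 2, 4, 4],    [8] * 6),
    ("fsq_causal_488_32768",    [1, 2, 4, 4],    [8] * 5),
]:
    # 17 frames as in the data config, at an assumed 256x256 input resolution.
    print(name, latent_shape(17, 256, 256, ch_mult), "codebook:", prod(levels))

# fsq_causal_41616_262144 (5, 16, 16) codebook: 262144
# fsq_causal_488_262144   (5, 32, 32) codebook: 262144
# fsq_causal_488_32768    (5, 32, 32) codebook: 32768
```
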