Skip to content

Commit

Permalink
Merge pull request #1161 from bghira/feature/validation-lycoris-strength
Browse files Browse the repository at this point in the history
validation: allow setting a non-default strength for validation with lycoris
  • Loading branch information
bghira authored Nov 16, 2024
2 parents ea620b2 + c4f30d1 commit d223171
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 11 deletions.
20 changes: 13 additions & 7 deletions documentation/LYCORIS.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,17 @@

## Using LyCORIS

To use LyCORIS, set `--lora_type=lycoris` and then set `--lycoris_config=config/lycoris_config.json`, where `config/lycoris_config.json` is the location of your LyCORIS configuration file:
To use LyCORIS, set `--lora_type=lycoris` and then set `--lycoris_config=config/lycoris_config.json`, where `config/lycoris_config.json` is the location of your LyCORIS configuration file.

```bash
MODEL_TYPE=lora
# We use trainer_extra_args for now, as Lycoris support is so new.
TRAINER_EXTRA_ARGS+=" --lora_type=lycoris --lycoris_config=config/lycoris_config.json"
The following will go into your `config.json`:
```json
{
"model_type": "lora",
"lora_type": "lycoris",
"lycoris_config": "config/lycoris_config.json",
"validation_lycoris_strength": 1.0,
...the rest of your settings...
}
```


Expand Down Expand Up @@ -48,7 +53,7 @@ Optional fields:
- any keyword arguments specific to the selected algorithm, at the end.

Mandatory fields:
- multiplier
- multiplier, which should be set to 1.0 only unless you know what to expect
- linear_dim
- linear_alpha

Expand Down Expand Up @@ -81,7 +86,8 @@ vae = AutoencoderKL.from_pretrained(bfl_repo, subfolder="vae", torch_dtype=dtype
transformer = FluxTransformer2DModel.from_pretrained(bfl_repo, subfolder="transformer")

lycoris_safetensors_path = 'pytorch_lora_weights.safetensors'
wrapper, _ = create_lycoris_from_weights(1.0, lycoris_safetensors_path, transformer)
lycoris_strength = 1.0
wrapper, _ = create_lycoris_from_weights(lycoris_strength, lycoris_safetensors_path, transformer)
wrapper.merge_to() # using apply_to() will be slower.

transformer.to(device, dtype=dtype)
Expand Down
9 changes: 9 additions & 0 deletions helpers/configuration/cmd_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1369,6 +1369,15 @@ def get_argument_parser():
" and submit debug.log to a new Github issue report."
),
)
parser.add_argument(
"--validation_lycoris_strength",
type=float,
default=1.0,
help=(
"When inferencing for validations, the Lycoris model will by default be run at its training strength, 1.0."
" However, this value can be increased to a value of around 1.3 or 1.5 to get a stronger effect from the model."
),
)
parser.add_argument(
"--validation_torch_compile",
action="store_true",
Expand Down
11 changes: 7 additions & 4 deletions helpers/publishing/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def code_example(args, repo_id: str = None):
image = pipeline(
prompt=prompt,{_negative_prompt(args, in_call=True) if args.model_family.lower() != 'flux' else ''}
num_inference_steps={args.validation_num_inference_steps},
generator=torch.Generator(device={_torch_device()}).manual_seed(1641421826),
generator=torch.Generator(device={_torch_device()}).manual_seed({args.validation_seed or args.seed or 42}),
{_validation_resolution(args)}
guidance_scale={args.validation_guidance},{_guidance_rescale(args)}{_skip_layers(args)}
).images[0]
Expand Down Expand Up @@ -293,7 +293,7 @@ def lora_info(args):
lycoris_config = json.load(file)
except:
lycoris_config = {"error": "could not locate or load LyCORIS config."}
return f"""- LyCORIS Config:\n```json\n{json.dumps(lycoris_config, indent=4)}\n```"""
return f"""### LyCORIS Config:\n```json\n{json.dumps(lycoris_config, indent=4)}\n```"""


def model_card_note(args):
Expand Down Expand Up @@ -516,16 +516,19 @@ def save_model_card(
- Training epochs: {StateTracker.get_epoch() - 1}
- Training steps: {StateTracker.get_global_step()}
- Learning rate: {StateTracker.get_args().learning_rate}
- Learning rate schedule: {StateTracker.get_args().lr_scheduler}
- Warmup steps: {StateTracker.get_args().lr_warmup_steps}
- Max grad norm: {StateTracker.get_args().max_grad_norm}
- Effective batch size: {StateTracker.get_args().train_batch_size * StateTracker.get_args().gradient_accumulation_steps * StateTracker.get_accelerator().num_processes}
- Micro-batch size: {StateTracker.get_args().train_batch_size}
- Gradient accumulation steps: {StateTracker.get_args().gradient_accumulation_steps}
- Number of GPUs: {StateTracker.get_accelerator().num_processes}
- Gradient checkpointing: {StateTracker.get_args().gradient_checkpointing}
- Prediction type: {'flow-matching' if (StateTracker.get_args().model_family in ["sd3", "flux"]) else StateTracker.get_args().prediction_type}{model_schedule_info(args=StateTracker.get_args())}
- Optimizer: {StateTracker.get_args().optimizer}{optimizer_config if optimizer_config is not None else ''}
- Trainable parameter precision: {'Pure BF16' if torch.backends.mps.is_available() or StateTracker.get_args().mixed_precision == "bf16" else 'FP32'}
- Quantised base model: {f'Yes ({StateTracker.get_args().base_model_precision})' if StateTracker.get_args().base_model_precision != "no_change" else 'No'}
- Xformers: {'Enabled' if StateTracker.get_args().enable_xformers_memory_efficient_attention else 'Not used'}
- Caption dropout probability: {StateTracker.get_args().caption_dropout_probability * 100}%
{'- Xformers: Enabled' if StateTracker.get_args().enable_xformers_memory_efficient_attention else ''}
{lora_info(args=StateTracker.get_args())}
## Datasets
Expand Down
6 changes: 6 additions & 0 deletions helpers/training/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,6 +957,10 @@ def setup_scheduler(self):
return scheduler

def setup_pipeline(self, validation_type, enable_ema_model: bool = True):
if hasattr(self.accelerator, "_lycoris_wrapped_network"):
self.accelerator._lycoris_wrapped_network.set_multiplier(float(getattr(
self.args, "validation_lycoris_strength", 1.0
)))
if validation_type == "intermediary" and self.args.use_ema:
if enable_ema_model:
if self.unet is not None:
Expand Down Expand Up @@ -1120,6 +1124,8 @@ def setup_pipeline(self, validation_type, enable_ema_model: bool = True):

def clean_pipeline(self):
"""Remove the pipeline."""
if hasattr(self.accelerator, "_lycoris_wrapped_network"):
self.accelerator._lycoris_wrapped_network.set_multiplier(1.0)
if self.pipeline is not None:
del self.pipeline
self.pipeline = None
Expand Down

0 comments on commit d223171

Please sign in to comment.