MoRA Implementation (#9562)
* MoRA Implementation

* MoRA algorithm

* MoRA

* MoRA

* MoRA

* MoRA

* MoRA

* MoRA

* MoRA

* MoRA
lcykww authored Dec 18, 2024
1 parent 605a4ea commit 90bc68e
Showing 10 changed files with 657 additions and 64 deletions.
3 changes: 2 additions & 1 deletion llm/config/llama/lora_argument.json
@@ -31,5 +31,6 @@
"zero_padding": false,
"use_flash_attention": true,
"unified_checkpoint": true,
"pissa": false
"pissa": false,
"use_mora": false
}
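For reference, the new switch defaults to false, so existing configs keep plain LoRA behavior. A small sketch of toggling it programmatically rather than hand-editing the file (the path is the config touched in this diff):

```python
import json

# Flip the new MoRA switch in the fine-tuning config shown above.
path = "llm/config/llama/lora_argument.json"
with open(path) as f:
    cfg = json.load(f)

cfg["use_mora"] = True  # new flag introduced by this PR; defaults to false
with open(path, "w") as f:
    json.dump(cfg, f, indent=4)
```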
1 change: 1 addition & 0 deletions llm/run_finetune.py
@@ -522,6 +522,7 @@ def create_peft_model(model_args, reft_args, training_args, dtype, model_config,
base_model_name_or_path=model_args.model_name_or_path,
use_quick_lora=model_args.use_quick_lora,
lora_use_mixer=model_args.lora_use_mixer,
use_mora=model_args.use_mora,
)
model = LoRAModel(model, lora_config)
else:
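The change above only threads the new flag into the LoRAConfig that run_finetune.py builds. A minimal stand-alone sketch of the same wiring, assuming a LLaMA base model and an illustrative target_modules pattern (neither is taken from this diff):

```python
from paddlenlp.peft import LoRAConfig, LoRAModel
from paddlenlp.transformers import AutoModelForCausalLM

# Hypothetical base model name, for illustration only.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b")

lora_config = LoRAConfig(
    target_modules=[".*q_proj.*", ".*v_proj.*"],  # assumed pattern
    r=8,
    lora_alpha=16,
    use_mora=True,  # flag added by this PR; swaps the A/B pair for a square MoRA matrix
)
model = LoRAModel(model, lora_config)
model.mark_only_lora_as_trainable()
```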
49 changes: 29 additions & 20 deletions llm/tools/merge_lora_params.py
@@ -13,7 +13,6 @@
# limitations under the License.
import argparse
import copy
import math
import os

import numpy as np
@@ -79,49 +78,56 @@ def weight_process(name, quant_config, lora_config, state_dict, device):
raise ValueError(f"quant_config.weight_quantize_algo {quant_config.weight_quantize_algo} is not supported.")


def lora_process(name, lora_config, state_dict, device, lora_state_dict=None):
def lora_process(name, layer, lora_config, state_dict, device, lora_state_dict=None):
target_device = device if device == "cpu" else device + ":0"

if (name + ".weight") not in state_dict.keys():
return

weight = state_dict.pop(name + ".weight")
lora_use_mixer = lora_config.lora_use_mixer
use_mora = lora_config.use_mora
if lora_state_dict is None:
lora_A = state_dict.pop(name + ".lora_A")
lora_B = state_dict.pop(name + ".lora_B")
if not use_mora:
lora_B = state_dict.pop(name + ".lora_B")
if lora_use_mixer:
lora_AB = state_dict.pop(name + ".lora_AB")
else:
lora_A = lora_state_dict.pop(name + ".lora_A")
lora_B = lora_state_dict.pop(name + ".lora_B")
if not use_mora:
lora_B = lora_state_dict.pop(name + ".lora_B")
if lora_use_mixer:
lora_AB = lora_state_dict.pop(name + ".lora_AB")
if device != "cpu":
weight = weight.to(target_device)
lora_A = lora_A.to(target_device)
lora_B = lora_B.to(target_device)
if not use_mora:
lora_B = lora_B.to(target_device)
if lora_use_mixer:
lora_AB = lora_AB.to(target_device)
if not lora_config.rslora:
scaling = lora_config.lora_alpha / lora_config.r
else:
scaling = lora_config.lora_alpha / math.sqrt(lora_config.r)

if device == "cpu" and weight.dtype.name == "BF16":
weight = weight.astype("float32")
lora_A = lora_A.astype("float32")
lora_B = lora_B.astype("float32")
if not use_mora:
lora_B = lora_B.astype("float32")
if lora_use_mixer:
lora_AB = lora_AB.astype(lora_config.dtype)
out = (weight + lora_A @ lora_AB @ lora_B * scaling).astype(lora_config.dtype)
delta_weight = layer.get_delta_weight(lora_A, lora_B, lora_AB)
elif use_mora:
delta_weight = layer.get_delta_weight(lora_A)
else:
out = (weight + lora_A @ lora_B * scaling).astype(lora_config.dtype)
delta_weight = layer.get_delta_weight(lora_A, lora_B)
out = (weight + delta_weight).astype(lora_config.dtype)
else:
if lora_use_mixer:
out = (weight + lora_A @ lora_AB @ lora_B * scaling).cpu()
delta_weight = layer.get_delta_weight(lora_A, lora_B, lora_AB)
elif use_mora:
delta_weight = layer.get_delta_weight(lora_A)
else:
out = (weight + lora_A @ lora_B * scaling).cpu()
delta_weight = layer.get_delta_weight(lora_A, lora_B)
out = (weight + delta_weight).cpu()

state_dict[name + ".weight"] = out

@@ -220,12 +226,15 @@ def merge():
if isinstance(layer, paddle.nn.Linear) or isinstance(layer, QuantizationLinear):
weight_process(name, quant_config, lora_config, model_state_dict, args.device)

lora_name_list = []
for key in model_state_dict.keys():
if "lora_A" in key:
lora_name_list.append(key[:-7])
for name in lora_name_list:
lora_process(name, lora_config, model_state_dict, args.device)
lora_info = {}
for sublayer_name, sublayer in model.named_sublayers():
if isinstance(sublayer, paddle.nn.Linear):
for param_name, param in sublayer.named_parameters():
if "lora_A" in param_name:
lora_info[sublayer_name[6:]] = sublayer

for name, layer in lora_info.items():
lora_process(name, layer, lora_config, model_state_dict, args.device)

logger.info("Begin to save merged model")
if args.safe_serialization:
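The net effect of this file's change: instead of recomputing lora_A @ lora_B * scaling itself (and importing math for the rsLoRA case), the merge script now collects the actual LoRA sublayers and asks each one for its delta weight, so plain LoRA, the lora_use_mixer variant, and MoRA all merge through one code path. A minimal sketch of that contract, with variable names assumed and weights in Paddle's [in_features, out_features] layout:

```python
# `layer` is any LoRA sublayer collected into lora_info above.
def merge_one(weight, layer, dtype):
    delta = layer.get_delta_weight()       # MoRA reconstructs this from lora_A alone
    return (weight + delta).astype(dtype)  # same call for plain LoRA and the mixer variant
```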
3 changes: 3 additions & 0 deletions paddlenlp/peft/lora/lora_config.py
@@ -77,6 +77,9 @@ class LoRAConfig:
rslora: bool = field(default=False, metadata={"help": "Whether to use RsLoRA"})
pissa: bool = field(default=False, metadata={"help": "Whether to use Pissa: https://arxiv.org/pdf/2404.02948.pdf"})
loraga: bool = field(default=False, metadata={"help": "Whether to LoRA-GA"})
use_mora: bool = field(
default=False, metadata={"help": "Whether to use MoRA: https://arxiv.org/pdf/2405.12130.pdf"}
)
lora_plus_scale: float = field(default=1.0, metadata={"help": "Lora B scale in LoRA+"})
base_model_name_or_path: Optional[str] = field(
default=None, metadata={"help": "The name of the base model to use."}
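Because use_mora is an ordinary dataclass field, it round-trips through the adapter config like the other switches. A quick sketch, assuming PaddleNLP's existing LoRAConfig save/load API and an illustrative path:

```python
from paddlenlp.peft import LoRAConfig

cfg = LoRAConfig(r=8, lora_alpha=16, use_mora=True)
cfg.save_pretrained("./checkpoints/mora_adapter")  # writes the adapter config JSON
restored = LoRAConfig.from_pretrained("./checkpoints/mora_adapter")
assert restored.use_mora
```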
185 changes: 146 additions & 39 deletions paddlenlp/peft/lora/lora_layers.py
@@ -40,7 +40,6 @@
ReduceScatterOp = None
mark_as_sequence_parallel_parameter = None


from ...transformers.mc2_parallel_linear import (
MC2ColumnParallelCoreLinear,
MC2ColumnSeqParallelCoreLinear,
@@ -65,11 +64,13 @@ def __init__(
lora_plus_scale: float = 1.0,
pissa: bool = False,
lora_use_mixer: bool = False,
use_mora: bool = False,
**kwargs
):
nn.Linear.__init__(self, in_features, out_features, **kwargs)
if not isinstance(r, int) or r <= 0:
raise ValueError("Lora rank r should be a positive integer")
self.use_mora = use_mora
self.r = r
self.lora_alpha = lora_alpha
# Optional dropout
@@ -83,36 +84,58 @@ def __init__(
self.lora_use_mixer = lora_use_mixer

# Actual trainable parameters
self.lora_A = self.create_parameter(
shape=[in_features, r],
dtype=self._dtype,
is_bias=False,
default_initializer=nn.initializer.KaimingUniform(negative_slope=math.sqrt(5), nonlinearity="leaky_relu"),
)
if self.lora_use_mixer:
self.lora_AB = self.create_parameter(
shape=[r, r],
if use_mora: # reset the rank and create high rank matrix
self.in_features = in_features
self.out_features = out_features
new_r = int(math.sqrt((in_features + out_features) * r) + 0.5)
new_r = new_r // 2 * 2
self.r = new_r
self.lora_A = self.create_parameter(
shape=[self.r, self.r],
dtype=self._dtype,
is_bias=False,
default_initializer=nn.initializer.Constant(value=0.0),
)
self.cos = None
self.sin = None
# Count the number of tiles
self.rb1 = self.in_features // self.r if self.in_features % self.r == 0 else self.in_features // self.r + 1
self.rb2 = (
self.out_features // self.r if self.out_features % self.r == 0 else self.out_features // self.r + 1
)
self.rope_init()
else:
self.lora_A = self.create_parameter(
shape=[in_features, r],
dtype=self._dtype,
is_bias=False,
default_initializer=nn.initializer.KaimingUniform(
negative_slope=math.sqrt(5), nonlinearity="leaky_relu"
),
)
self.lora_B = self.create_parameter(
shape=[r, out_features],
dtype=self._dtype,
is_bias=False,
attr=paddle.ParamAttr(
initializer=paddle.nn.initializer.Constant(value=0.0),
learning_rate=lora_plus_scale,
),
)
if self.lora_use_mixer:
self.lora_AB = self.create_parameter(
shape=[r, r],
dtype=self._dtype,
is_bias=False,
default_initializer=nn.initializer.KaimingUniform(
negative_slope=math.sqrt(5), nonlinearity="leaky_relu"
),
)
self.lora_B = self.create_parameter(
shape=[r, out_features],
dtype=self._dtype,
is_bias=False,
attr=paddle.ParamAttr(
initializer=paddle.nn.initializer.Constant(value=0.0),
learning_rate=lora_plus_scale,
),
)
self.apply_pissa = False

if not rslora and not pissa:
self.scaling = self.lora_alpha / self.r
elif pissa:
if use_mora or pissa:
self.scaling = 1.0
elif not rslora:
self.scaling = self.lora_alpha / self.r
else:
self.scaling = self.lora_alpha / math.sqrt(self.r)
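For intuition on the rank reset above: MoRA keeps roughly the same parameter budget as the LoRA pair it replaces, but packs it into a single square, much higher-rank matrix (and, per the branch above, runs with a fixed scaling of 1.0). A worked example with hypothetical square-projection shapes, not taken from this diff:

```python
import math

# Hypothetical 4096x4096 projection (e.g. a LLaMA q_proj) with LoRA rank 8.
in_features, out_features, r = 4096, 4096, 8

# MoRA's rank reset from __init__ above: same budget, rounded down to an even
# rank so the RoPE half-rotation splits each block cleanly.
new_r = int(math.sqrt((in_features + out_features) * r) + 0.5)
new_r = new_r // 2 * 2                            # 256

lora_params = in_features * r + r * out_features  # 65,536 (A: 4096x8, B: 8x4096)
mora_params = new_r * new_r                       # 65,536 (one 256x256 matrix)
print(new_r, lora_params, mora_params)
```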

@@ -121,10 +144,6 @@ def __init__(
self._use_quick_lora = use_quick_lora and lora_dropout == 0.0
self.disable_lora = False

@property
def use_quick_lora(self):
return self._use_quick_lora and self.training and not self.merged

def pissa_init(self, rank):
weight = self.weight
dtype = weight.dtype
@@ -144,21 +163,102 @@ def pissa_init(self, rank):
weight = res.astype(dtype)
self.weight.set_value(weight)

def rope_init(self):
if self.cos is None or self.sin is None:
inv_freq = 1.0 / (10000 ** (paddle.arange(0, self.r, 2, dtype=paddle.float32) / self.r))
t = paddle.arange(self.rb1, dtype=paddle.float32)
freqs = t.unsqueeze(1) @ inv_freq.unsqueeze(0)
emb = paddle.concat([freqs, freqs], axis=-1)
self.cos = paddle.unsqueeze(paddle.cos(emb), axis=0).astype(self._dtype)
self.sin = paddle.unsqueeze(paddle.sin(emb), axis=0).astype(self._dtype)

@property
def use_quick_lora(self):
return self._use_quick_lora and self.training and not self.merged

def _apply_mora(self, x):
r = self.r

# Calculate grouping
sum_inter = self.in_features // r

# padding
if self.in_features % r != 0:
pad_size = r - self.in_features % r
x = paddle.concat([x, x[..., :pad_size]], axis=-1)
sum_inter += 1

# reshape the input to apply RoPE
in_x = x.reshape([*x.shape[:-1], sum_inter, r])

# apply RoPE rotation
rh_in_x = paddle.concat([-in_x[..., r // 2 :], in_x[..., : r // 2]], axis=-1)
in_x = in_x * self.cos + rh_in_x * self.sin

# matmul with high rank matrix
out_x = in_x @ self.lora_A

# reshape the output
out_x = out_x.reshape([*x.shape[:-1], -1])[..., : self.out_features]
if out_x.shape[-1] < self.out_features:
repeat_time = self.out_features // out_x.shape[-1]
if self.out_features % out_x.shape[-1] != 0:
repeat_time += 1
out_x = paddle.concat([out_x] * repeat_time, axis=-1)[..., : self.out_features]

return out_x

def get_delta_weight(self, lora_A=None, lora_B=None, lora_AB=None):
# compute the delta weight, which is used to merge weights
if self.lora_use_mixer:
lora_A = lora_A if lora_A is not None else self.lora_A
lora_B = lora_B if lora_B is not None else self.lora_B
lora_AB = lora_AB if lora_AB is not None else self.lora_AB
delta_weight = lora_A @ lora_AB @ lora_B * self.scaling
elif self.use_mora:
lora_A = lora_A if lora_A is not None else self.lora_A
r = self.r
# compute padding
pad_size = r - self.in_features % r if self.in_features % r != 0 else 0
# initialize weights
w = paddle.zeros([self.in_features + pad_size, self.in_features], dtype=lora_A.dtype)

# create the weights after rotation
aw2 = paddle.concat([lora_A[:, r // 2 :], -lora_A[:, : r // 2]], axis=-1)
# apply RoPE
for i in range(self.rb1 - 1):
w[i * r : (i + 1) * r, i * r : (i + 1) * r] = aw2 * self.sin[:, i] + lora_A * self.cos[:, i]
# Process the last chunk that may be incomplete
i = self.rb1 - 1
w[i * r :, i * r :] = (aw2 * self.sin[:, i] + lora_A * self.cos[:, i])[:, : r - pad_size]
# padding
if pad_size > 0:
w[i * r :, :pad_size] = (aw2 * self.sin[:, i] + lora_A * self.cos[:, i])[:, r - pad_size :]
# reshape the weights
if self.in_features < self.out_features:
w = paddle.concat([w] * self.rb2, axis=0)[: self.out_features]
else:
w = w[: self.out_features]
final_weight = w
delta_weight = final_weight.T
else:
lora_A = lora_A if lora_A is not None else self.lora_A
lora_B = lora_B if lora_B is not None else self.lora_B
delta_weight = lora_A @ lora_B * self.scaling

return delta_weight

def merge(self):
if not self.merged:
if self.lora_use_mixer:
new_weight = self.weight + self.lora_A @ self.lora_AB @ self.lora_B * self.scaling
else:
new_weight = self.weight + self.lora_A @ self.lora_B * self.scaling
delta_weight = self.get_delta_weight()
new_weight = self.weight + delta_weight
self.weight.set_value(new_weight)
self.merged = True

def unmerge(self):
if self.merged:
if self.lora_use_mixer:
new_weight = self.weight - self.lora_A @ self.lora_AB @ self.lora_B * self.scaling
else:
new_weight = self.weight - self.lora_A @ self.lora_B * self.scaling
delta_weight = self.get_delta_weight()
new_weight = self.weight - delta_weight
self.weight.set_value(new_weight)
self.merged = False
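To make the tiling in _apply_mora and get_delta_weight concrete: the input is padded up to a multiple of the MoRA rank by wrapping its leading features, rotated chunk by chunk with the cached RoPE tables, multiplied by the square lora_A, and the compressed result is then truncated or tiled back out to out_features. A worked example for an asymmetric projection, with hypothetical shapes not taken from this diff:

```python
import math

# Hypothetical LLaMA-style up-projection: in=4096, out=11008, LoRA rank 8.
in_features, out_features, r = 4096, 11008, 8

new_r = int(math.sqrt((in_features + out_features) * r) + 0.5) // 2 * 2  # 348
rb1 = -(-in_features // new_r)    # 12 input tiles  (ceil(4096 / 348))
rb2 = -(-out_features // new_r)   # 32 output tiles (ceil(11008 / 348))

pad_size = rb1 * new_r - in_features      # 80 leading input features reused as padding
compressed = rb1 * new_r                  # 4176-dim flattened output of in_x @ lora_A
repeats = -(-out_features // compressed)  # tiled 3x, then cut back to 11008 features
print(new_r, rb1, rb2, pad_size, compressed, repeats)
```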

@@ -171,6 +271,11 @@ def forward(self, input: paddle.Tensor, *args, **kwargs):
elif self.use_quick_lora:
# Use the quick lora implementation
result = quick_lora(input, self.lora_A, self.lora_B, self.weight, self.bias, self.scaling)
elif self.use_mora:
result = F.linear(x=input, weight=self.weight, bias=self.bias, name=self.name)
input = self.lora_dropout(input)
mora_out = self._apply_mora(input)
result += mora_out
else:
result = F.linear(x=input, weight=self.weight, bias=self.bias, name=self.name)
if self.lora_use_mixer:
Expand All @@ -196,14 +301,15 @@ def __init__(
lora_plus_scale: float = 1.0,
use_quick_lora: bool = False,
pissa: bool = False,
use_mora: bool = False,
**kwargs
):
RowParallelLinear.__init__(self, in_features, out_features, **kwargs)
if not isinstance(r, int) or r <= 0:
raise ValueError("Lora rank r should be a positive integer")

if pissa:
raise ValueError("Pissa is not supported in model parallel by now")
if pissa or use_mora:
raise ValueError("Pissa or Mora is not supported in model parallel by now")

self.r = r
self.lora_alpha = lora_alpha
@@ -461,14 +567,15 @@ def __init__(
lora_A_weight_attr: Optional[paddle.ParamAttr] = None,
use_quick_lora: bool = False,
pissa: bool = False,
use_mora: bool = False,
**kwargs
):
ColumnParallelLinear.__init__(self, in_features, out_features, **kwargs)
if not isinstance(r, int) or r <= 0:
raise ValueError("Lora rank r should be a positive integer")

if pissa:
raise ValueError("Pissa is not supported in model parallel by now")
if pissa or use_mora:
raise ValueError("Pissa or Mora is not supported in model parallel by now")

self.r = r
self.lora_alpha = lora_alpha