Skip to content

Parameter initialization

Albert Zeyer edited this page Jan 19, 2022 · 16 revisions

Terminology and formulas

  • std: standard deviation
  • var: variance
  • E[X]: expected value (mean, average)
  • std = sqrt(var), var = std ** 2, var = E[(X - E[X]) ** 2] = E[X**2] - (E[X])**2
  • TF VarianceScaling
    • std = sqrt(scale / fan)
    • distribution="normal" is always a truncated normal; use distribution="untruncated_normal" for a plain (non-truncated) normal
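  • Truncated normal: TF VarianceScaling samples from a normal truncated at ±2 stddev, and passes stddev = std / .87962566103423978 so that the truncated samples still have the target std. The constant is scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)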
  • Uniform bound = sqrt(3) * std: uniformly drawing samples from the interval [-bound, bound] results in the given std, because a uniform distribution on [-b, b] has variance b**2 / 3
  • fan_in, fan_out: input/output dimension, multiplied by the filter size (receptive field size) in case of convolution. Note the different weight layouts: PyTorch uses (out, in, *kernel), TF uses (*kernel, in, out)
  • fan_avg = (fan_in + fan_out) / 2
  • Xavier Glorot (paper 2010): VarianceScaling(scale=1.0, mode="fan_avg", usually distribution="uniform")
  • Kaiming He (paper 2015): VarianceScaling(scale=2., mode="fan_in", usually distribution="normal")

Linear

  • RETURNN Theano: VarianceScaling(scale=6.0, mode="fan_avg", distribution="normal")

  • RETURNN TensorFlow: Glorot uniform = VarianceScaling(scale=1.0, mode="fan_avg", distribution="uniform")

  • Keras: Glorot uniform

  • Lingvo: Glorot uniform

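  • PyTorch: uniform(-1/sqrt(fan_in), 1/sqrt(fan_in)) = VarianceScaling(scale=1. / 3, mode="fan_in", distribution="uniform"). This is what nn.Linear's default kaiming_uniform_(a=math.sqrt(5)) amounts to: gain = sqrt(2 / (1 + 5)) = sqrt(1/3), so bound = sqrt(3) * gain / sqrt(fan_in) = 1 / sqrt(fan_in)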

  • PyTorch proposed: kaiming_normal_(mode='fan_in') = VarianceScaling(scale=2., mode="fan_in", distribution="untruncated_normal") (PyTorch draws from a plain normal, while TF's distribution="normal" is truncated)

  • Transformer:

  • PyTorch #18182

  • comment

PyTorch nn.init code

PyTorch nn.init code:

def kaiming_uniform_(tensor, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    """Fill ``tensor`` in place with He/Kaiming uniform values.

    Samples come from U(-bound, bound) with bound = sqrt(3) * gain / sqrt(fan),
    so their std is gain / sqrt(fan).

    Args:
        tensor: weight tensor to initialize (modified in place).
        a: passed to ``calculate_gain`` as its ``param`` (the negative slope
            for 'leaky_relu').
        mode: 'fan_in' or 'fan_out'; selects the fan that normalizes the std.
        nonlinearity: name passed to ``calculate_gain``.

    Returns:
        The tensor returned by ``tensor.uniform_`` (PyTorch in-place ops return
        the tensor itself). A zero-element tensor is returned unchanged.
    """
    import warnings

    # A zero-element tensor has nothing to fill, and its fan can be 0, which
    # would make the std computation below divide by zero.
    if 0 in tensor.shape:
        warnings.warn("Initializing zero-element tensors is a no-op")
        return tensor
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    # U(-b, b) has std b / sqrt(3), hence b = sqrt(3) * std.
    bound = math.sqrt(3.0) * std
    with torch.no_grad():
        return tensor.uniform_(-bound, bound)

def _calculate_correct_fan(tensor, mode):
    """Return the fan of ``tensor`` selected by ``mode``.

    Args:
        tensor: weight tensor, passed to ``_calculate_fan_in_and_fan_out``.
        mode: 'fan_in' or 'fan_out' (case-insensitive).

    Returns:
        fan_in for 'fan_in', fan_out for 'fan_out'.

    Raises:
        ValueError: if ``mode`` is not one of the supported modes.
    """
    mode = mode.lower()
    valid_modes = ['fan_in', 'fan_out']
    # Reject unknown modes explicitly. Otherwise anything other than 'fan_in'
    # (a typo, or TF's 'fan_avg') would silently select fan_out.
    if mode not in valid_modes:
        raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    return fan_in if mode == 'fan_in' else fan_out

def _calculate_fan_in_and_fan_out(tensor):
    """Return ``(fan_in, fan_out)`` for a weight tensor in PyTorch layout.

    Dim 0 is the number of output feature maps and dim 1 the number of input
    feature maps. All remaining dims form the receptive field (e.g. a conv
    kernel), and its total size multiplies both fans.

    Raises:
        ValueError: if the tensor has fewer than 2 dimensions.
    """
    ndim = tensor.dim()
    if ndim < 2:
        raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")

    # Product of the trailing (kernel) dims. It stays 1 for a plain 2-D weight,
    # because shape[2:] is then empty.
    # The loop is explicit: math.prod is not always available, and
    # functools.reduce is not supported by TorchScript.
    receptive = 1
    for kernel_dim in tensor.shape[2:]:
        receptive *= kernel_dim

    n_out, n_in = tensor.size(0), tensor.size(1)
    return n_in * receptive, n_out * receptive

def calculate_gain(nonlinearity, param=None):
    """Return the recommended init gain for weights followed by ``nonlinearity``.

    Args:
        nonlinearity: name of the nonlinearity: a linear/conv layer name,
            'sigmoid', 'tanh', 'relu', 'leaky_relu' or 'selu'.
        param: negative slope for 'leaky_relu' (default 0.01). It must be an
            int or a float, but not a bool. It is ignored for the other
            nonlinearities.

    Returns:
        1 for linear/conv layers and sigmoid, 5/3 for tanh, sqrt(2) for relu,
        sqrt(2 / (1 + slope**2)) for leaky_relu, and 3/4 for selu.

    Raises:
        ValueError: for an unsupported nonlinearity or an invalid leaky_relu
            slope.
    """
    # Layers without a nonlinearity of their own, plus sigmoid, keep gain 1.
    unit_gain = ('linear', 'conv1d', 'conv2d', 'conv3d',
                 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d',
                 'sigmoid')
    if nonlinearity in unit_gain:
        return 1
    if nonlinearity == 'tanh':
        return 5.0 / 3
    if nonlinearity == 'relu':
        return math.sqrt(2.0)
    if nonlinearity == 'selu':
        # Value found empirically (https://github.com/pytorch/pytorch/pull/50664)
        return 3.0 / 4
    if nonlinearity != 'leaky_relu':
        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))

    if param is None:
        slope = 0.01
    elif isinstance(param, float) or (isinstance(param, int) and not isinstance(param, bool)):
        # bool is a subclass of int, so True/False must be rejected explicitly.
        slope = param
    else:
        raise ValueError("negative_slope {} not a valid number".format(param))
    return math.sqrt(2.0 / (1 + slope ** 2))

TensorFlow

class VarianceScaling(Initializer):
  """Initializer whose samples have variance ``scale / n``.

  ``n`` is selected by ``mode``: fan_in, fan_out, or their average, as
  computed by ``_compute_fans``. It is clamped to at least 1.

  Distributions:
    * "normal" / "truncated_normal": truncated normal via
      ``random_ops.truncated_normal``. The stddev passed in is
      sqrt(scale / n) / .87962566103423978, which compensates for the variance
      lost by truncating at +-2 stddev.
    * "untruncated_normal": plain normal with stddev sqrt(scale / n).
    * "uniform": uniform on [-limit, limit] with limit = sqrt(3 * scale / n).

  Args:
    scale: positive scaling factor for the variance.
    mode: one of "fan_in", "fan_out", "fan_avg".
    distribution: one of "normal", "uniform", "truncated_normal",
      "untruncated_normal" (case-insensitive).
    seed: passed through to the random op.
    dtype: default dtype for ``__call__``. It is validated by
      ``_assert_float_dtype`` (presumably it must be a floating point type).

  Raises:
    ValueError: for a non-positive ``scale`` or an invalid ``mode`` /
      ``distribution``.
  """

  def __init__(self,
               scale=1.0,
               mode="fan_in",
               distribution="truncated_normal",
               seed=None,
               dtype=dtypes.float32):
    if scale <= 0.:
      raise ValueError("Argument `scale` must be a positive float. Received: "
                       f"{scale}")
    if mode not in {"fan_in", "fan_out", "fan_avg"}:
      raise ValueError("Argument `mode` should be one of ('fan_in', 'fan_out', "
                       f"'fan_avg'). Received: {mode}")
    distribution = distribution.lower()
    if distribution not in {
        "normal", "uniform", "truncated_normal", "untruncated_normal"
    }:
      raise ValueError("Argument `distribution` should be one of ('normal', "
                       "'uniform', 'truncated_normal', 'untruncated_normal'). "
                       f"Received: {distribution}")
    self.scale = scale
    self.mode = mode
    self.distribution = distribution
    self.seed = seed
    self.dtype = _assert_float_dtype(dtypes.as_dtype(dtype))

  def __call__(self, shape, dtype=None, partition_info=None):
    """Return a random tensor of ``shape`` drawn per the configured settings.

    Args:
      shape: shape of the tensor to generate.
      dtype: dtype of the result; defaults to ``self.dtype``.
      partition_info: if given, the fans are computed from
        ``partition_info.full_shape`` instead of ``shape``.
    """
    if dtype is None:
      dtype = self.dtype
    scale = self.scale
    scale_shape = shape
    if partition_info is not None:
      scale_shape = partition_info.full_shape
    fan_in, fan_out = _compute_fans(scale_shape)
    # Divide by the chosen fan; max(1., ...) avoids dividing by zero.
    if self.mode == "fan_in":
      scale /= max(1., fan_in)
    elif self.mode == "fan_out":
      scale /= max(1., fan_out)
    else:
      scale /= max(1., (fan_in + fan_out) / 2.)
    if self.distribution == "normal" or self.distribution == "truncated_normal":
      # constant taken from scipy.stats.truncnorm.std(a=-2, b=2, loc=0., scale=1.)
      stddev = math.sqrt(scale) / .87962566103423978
      return random_ops.truncated_normal(
          shape, 0.0, stddev, dtype, seed=self.seed)
    elif self.distribution == "untruncated_normal":
      stddev = math.sqrt(scale)
      return random_ops.random_normal(shape, 0.0, stddev, dtype, seed=self.seed)
    else:
      # U(-limit, limit) has variance limit**2 / 3, which equals scale here.
      limit = math.sqrt(3.0 * scale)
      return random_ops.random_uniform(
          shape, -limit, limit, dtype, seed=self.seed)

def _compute_fans(shape):
  """Return ``(fan_in, fan_out)`` as Python ints for a weight of ``shape``.

  Rank 0 gives (1, 1), and rank 1 uses its single dim for both fans. Rank 2
  is read as (fan_in, fan_out). Higher ranks are treated as convolution
  kernels laid out as (..., input_depth, depth), where every leading dim
  multiplies both fans.

  Args:
    shape: Integer shape tuple or TF tensor shape.

  Returns:
    A tuple of integer scalars (fan_in, fan_out).
  """
  rank = len(shape)
  if rank == 0:
    # Just to avoid errors for constants.
    return 1, 1
  if rank == 1:
    units = int(shape[0])
    return units, units
  if rank == 2:
    return int(shape[0]), int(shape[1])

  # Convolution kernel: fold all leading (spatial) dims into one factor.
  spatial = 1
  for kernel_dim in shape[:-2]:
    spatial *= kernel_dim
  return int(shape[-2] * spatial), int(shape[-1] * spatial)