Package smile.llm

Interface RotaryPositionalEncoding


public interface RotaryPositionalEncoding
Rotary positional encoding (RoPE). RoPE encodes the absolute position of each token with a rotation matrix while making the relative position between tokens explicit in the self-attention formulation. As a result, RoPE is flexible in sequence length, makes the inter-token dependency decay as the relative distance increases, and can equip linear self-attention with relative position encoding.
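The relative-position property can be checked numerically. The following is a minimal sketch in plain Java on a single two-dimensional pair. It does not use the SMILE Tensor API; the class RopeRelativeDemo, its methods, and the frequency value 0.1 are hypothetical and chosen only for illustration.

```java
/** Illustrative only: plain-array RoPE on one 2-D pair, not the SMILE implementation. */
final class RopeRelativeDemo {
    /** Rotates the pair (x[0], x[1]) by the angle position * freq. */
    static double[] rotate(double[] x, int position, double freq) {
        double angle = position * freq;
        double cos = Math.cos(angle), sin = Math.sin(angle);
        return new double[] { x[0] * cos - x[1] * sin, x[0] * sin + x[1] * cos };
    }

    /** Dot product of a query rotated to position m and a key rotated to position n. */
    static double score(double[] q, int m, double[] k, int n, double freq) {
        double[] rq = rotate(q, m, freq);
        double[] rk = rotate(k, n, freq);
        return rq[0] * rk[0] + rq[1] * rk[1];
    }

    public static void main(String[] args) {
        double[] q = {1.0, 2.0};
        double[] k = {0.5, -1.0};
        double freq = 0.1; // hypothetical rotation frequency
        // Both position pairs share the same offset m - n = 3.
        System.out.println(score(q, 5, k, 2, freq));
        System.out.println(score(q, 105, k, 102, freq));
    }
}
```

The two printed scores agree up to floating-point rounding, because the attention score between a rotated query and a rotated key depends only on the offset between their positions.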
  • Method Summary

    Static Methods
    Modifier and Type
    Method
    Description
    static Tuple2<Tensor,Tensor>
    apply(Tensor xq, Tensor xk, Tensor cis)
    Applies rotary embeddings to the input query and key tensors.
    static Tensor
    computeFreqCis(int dim, int end)
    Precomputes the frequency tensor for complex exponentials (cis) with the default theta of 10000.0.
    static Tensor
    computeFreqCis(int dim, int end, double theta, boolean scaling)
    Precomputes the frequency tensor for complex exponentials (cis).
    static Tensor
    reshapeForBroadcast(Tensor cis, Tensor x)
    Reshapes the cis tensor to match the shape of the target tensor x for broadcasting purposes, allowing for element-wise operations between tensors of compatible shapes.
    static Tensor
    scale(Tensor freqs)
    Adapts RoPE to longer input lengths.
  • Method Details

    • apply

      static Tuple2<Tensor,Tensor> apply(Tensor xq, Tensor xk, Tensor cis)
      Applies rotary embeddings to the input query and key tensors. It ensures that the output tensors have the same data type as the input tensors.
      Parameters:
      xq - the query tensor.
      xk - the key tensor.
      cis - the precomputed frequency tensor for complex exponentials.
      Returns:
      the tuple of modified query tensor and key tensor with rotary embeddings.
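      The rotation itself can be sketched on plain arrays. This is a minimal illustration, not the SMILE implementation: it assumes an even embedding size, assumes that adjacent elements (2i, 2i+1) form one complex pair, and takes the cosine and sine parts of cis as separate arrays. The class RopeApplySketch and its method are hypothetical.

```java
/** Illustrative only: rotates adjacent pairs of a [seqLen][dim] array. */
final class RopeApplySketch {
    /**
     * Rotates each adjacent pair (x[t][2i], x[t][2i+1]) by the angle whose
     * cosine and sine are cos[t][i] and sin[t][i]. Assumes an even dim.
     * The math runs in double, and the result is cast back to float so the
     * output has the same element type as the input.
     */
    static float[][] rotate(float[][] x, double[][] cos, double[][] sin) {
        float[][] out = new float[x.length][];
        for (int t = 0; t < x.length; t++) {
            out[t] = new float[x[t].length];
            for (int i = 0; i < x[t].length / 2; i++) {
                double re = x[t][2 * i];
                double im = x[t][2 * i + 1];
                // Complex multiplication (re + i*im) * (cos + i*sin).
                out[t][2 * i]     = (float) (re * cos[t][i] - im * sin[t][i]);
                out[t][2 * i + 1] = (float) (re * sin[t][i] + im * cos[t][i]);
            }
        }
        return out;
    }
}
```

      In this sketch the same rotation would be applied to the query and to the key, each with the cos and sin rows for its own positions.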
    • computeFreqCis

      static Tensor computeFreqCis(int dim, int end)
      Precomputes the frequency tensor for complex exponentials (cis) with the default theta of 10000.0.
      Parameters:
      dim - the dimension of the frequency tensor.
      end - the end index for precomputing frequencies.
      Returns:
      the precomputed frequency tensor for complex exponentials.
    • computeFreqCis

      static Tensor computeFreqCis(int dim, int end, double theta, boolean scaling)
      Precomputes the frequency tensor for complex exponentials (cis).
      Parameters:
      dim - the dimension of the frequency tensor.
      end - the end index for precomputing frequencies.
      theta - the scaling factor for frequency computation.
      scaling - if true, scale the frequency tensor.
      Returns:
      the precomputed frequency tensor for complex exponentials.
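      As a rough illustration of what such a precomputation involves, the sketch below uses the common RoPE formula f_i = theta^(-2i/dim) and stores cos(t * f_i) and sin(t * f_i) for every position t below end. It assumes this formula and an even dim, omits the scaling option, and uses plain arrays instead of the Tensor type; the class FreqCisSketch is hypothetical and is not the library's code.

```java
/** Illustrative only: precomputes cos/sin tables for RoPE on plain arrays. */
final class FreqCisSketch {
    /**
     * Returns cis[t][i] = {cos(t * f_i), sin(t * f_i)} for t in [0, end)
     * and i in [0, dim / 2), where f_i = theta^(-2i / dim).
     */
    static double[][][] computeFreqCis(int dim, int end, double theta) {
        int half = dim / 2;
        double[][][] cis = new double[end][half][2];
        for (int i = 0; i < half; i++) {
            // Rotation frequency of the i-th pair; it decreases as i grows.
            double freq = 1.0 / Math.pow(theta, (2.0 * i) / dim);
            for (int t = 0; t < end; t++) {
                cis[t][i][0] = Math.cos(t * freq); // real part
                cis[t][i][1] = Math.sin(t * freq); // imaginary part
            }
        }
        return cis;
    }
}
```

      With dim = 4 and theta = 10000.0, the two frequencies in this sketch are 1 and 0.01, so the second pair rotates one hundred times more slowly than the first.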
    • reshapeForBroadcast

      static Tensor reshapeForBroadcast(Tensor cis, Tensor x)
      Reshapes the cis tensor to match the shape of the target tensor x for broadcasting purposes, allowing for element-wise operations between tensors of compatible shapes.
      Parameters:
      cis - the frequency tensor for complex exponentials.
      x - the target tensor for broadcasting.
      Returns:
      the reshaped cis tensor view.
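      The shape arithmetic behind such a reshape can be sketched as follows. The sketch assumes the convention of the Llama reference code, in which x has shape [batch, seqLen, heads, dim/2] and cis has shape [seqLen, dim/2], so the view keeps dimension 1 and the last dimension of x and sets every other dimension to 1. Whether reshapeForBroadcast follows exactly this convention is an assumption, and the class BroadcastShapeSketch is hypothetical.

```java
/** Illustrative only: computes a broadcast-compatible view shape for cis. */
final class BroadcastShapeSketch {
    /**
     * Keeps dimension 1 (sequence) and the last dimension (pair index) of
     * xShape and replaces every other dimension with 1.
     */
    static long[] broadcastShape(long[] xShape) {
        long[] shape = new long[xShape.length];
        for (int i = 0; i < xShape.length; i++) {
            boolean keep = (i == 1) || (i == xShape.length - 1);
            shape[i] = keep ? xShape[i] : 1;
        }
        return shape;
    }
}
```

      For an x of shape [2, 16, 8, 64], this yields [1, 16, 1, 64], which broadcasts over the batch and head dimensions.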
    • scale

      static Tensor scale(Tensor freqs)
      Adapts RoPE to longer input lengths.
      Parameters:
      freqs - the frequency tensor.
      Returns:
      the scaled frequency tensor.
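      This page does not state which scaling scheme is used. One published scheme is the frequency scaling in the Llama 3.1 reference code, sketched below on a plain array: high frequencies are kept, low frequencies are divided by a constant factor, and the band in between is interpolated. The constants (8, 1, 4, 8192) come from that reference code, and assuming that scale behaves the same way is a guess, not a documented fact. The class RopeScaleSketch is hypothetical.

```java
/** Illustrative only: Llama 3.1 style RoPE frequency scaling on a plain array. */
final class RopeScaleSketch {
    static double[] scale(double[] freqs) {
        // Constants from the Llama 3.1 reference code; assumed, not confirmed for SMILE.
        final double scaleFactor = 8.0;
        final double lowFreqFactor = 1.0;
        final double highFreqFactor = 4.0;
        final double oldContextLen = 8192.0;

        double lowFreqWavelen = oldContextLen / lowFreqFactor;
        double highFreqWavelen = oldContextLen / highFreqFactor;
        double[] out = new double[freqs.length];
        for (int i = 0; i < freqs.length; i++) {
            double freq = freqs[i];
            double wavelen = 2 * Math.PI / freq;
            if (wavelen < highFreqWavelen) {
                out[i] = freq;                 // short wavelength: keep as is
            } else if (wavelen > lowFreqWavelen) {
                out[i] = freq / scaleFactor;   // long wavelength: slow down
            } else {
                // Intermediate band: blend the scaled and unscaled frequency.
                double smooth = (oldContextLen / wavelen - lowFreqFactor)
                        / (highFreqFactor - lowFreqFactor);
                out[i] = (1 - smooth) * freq / scaleFactor + smooth * freq;
            }
        }
        return out;
    }
}
```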