Package smile.llm
Interface RotaryPositionalEncoding
public interface RotaryPositionalEncoding
Rotary positional encoding (RoPE). RoPE encodes the absolute position of
each token with a rotation matrix and, at the same time, incorporates an
explicit relative-position dependency into the self-attention formulation.
Notably, RoPE supports flexible sequence lengths, makes the inter-token
dependency decay as the relative distance increases, and can equip linear
self-attention with relative position encoding.
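For reference, the rotation described above can be written out as follows. This is the standard RoPE formulation; the symbols (q, k, positions m and n, head dimension d, base theta) are notation introduced here for illustration and are not identifiers of this interface.

```latex
% Each consecutive pair of coordinates (x_{2i}, x_{2i+1}) at position m is rotated by m\theta_i.
\begin{aligned}
\theta_i &= \mathtt{theta}^{-2i/d}, \qquad i = 0, \dots, d/2 - 1 \\
R_{m,i} &= \begin{pmatrix} \cos m\theta_i & -\sin m\theta_i \\ \sin m\theta_i & \cos m\theta_i \end{pmatrix},
\qquad R_m = \operatorname{diag}\!\left(R_{m,0}, \dots, R_{m,d/2-1}\right) \\
\langle R_m q,\; R_n k \rangle &= q^\top R_m^\top R_n\, k = q^\top R_{n-m}\, k
\end{aligned}
```

The last line is the relative-position property: the attention score between a query at position m and a key at position n depends on the positions only through the offset n - m.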
-
Method Summary
Modifier and Type | Method | Description
 | apply | Applies rotary embeddings to the input query and key tensors.
static Tensor | computeFreqCis(int dim, int end) | Precomputes the frequency tensor for complex exponentials (cis).
static Tensor | computeFreqCis(int dim, int end, double theta, boolean scaling) | Precomputes the frequency tensor for complex exponentials (cis).
static Tensor | reshapeForBroadcast(Tensor cis, Tensor x) | Reshapes the cis tensor to match the shape of the target tensor x for broadcasting purposes, allowing for element-wise operations between tensors of compatible shapes.
static Tensor | scale | Adapts RoPE to longer input lengths.
-
Method Details
-
apply
Applies rotary embeddings to the input query and key tensors. It ensures that the output tensors have the same data type as the input tensors.
Parameters:
xq - the query tensor.
xk - the key tensor.
Returns:
the tuple of the modified query tensor and key tensor with rotary embeddings.
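To make the rotation concrete, here is a minimal plain-Java sketch on a single head vector. It does not use this library's Tensor class: the class RopeApplySketch, its methods, and the interleaved-pair layout are assumptions chosen for illustration, and the sketch is not the implementation behind apply. It computes in double and returns float, loosely mirroring the note that the outputs keep the input data type.

```java
final class RopeApplySketch {
    /**
     * Rotates each consecutive pair (x[2i], x[2i+1]) by the angle pos * freqs[i],
     * treating the pair as one complex number. Hypothetical helper for illustration.
     */
    static float[] rotate(float[] x, int pos, double[] freqs) {
        float[] out = new float[x.length];
        for (int i = 0; i < x.length / 2; i++) {
            double angle = pos * freqs[i];
            double cos = Math.cos(angle);
            double sin = Math.sin(angle);
            double re = x[2 * i];
            double im = x[2 * i + 1];
            // Complex multiplication (re + i*im) * (cos + i*sin), cast back to float.
            out[2 * i] = (float) (re * cos - im * sin);
            out[2 * i + 1] = (float) (re * sin + im * cos);
        }
        return out;
    }

    /** Plain dot product, standing in for one attention score. */
    static double dot(float[] a, float[] b) {
        double sum = 0.0;
        for (int i = 0; i < a.length; i++) {
            sum += (double) a[i] * b[i];
        }
        return sum;
    }

    public static void main(String[] args) {
        double[] freqs = {1.0, 0.01};          // one frequency per pair (made-up values)
        float[] q = {1f, 0f, 0.5f, -0.5f};
        float[] k = {0f, 1f, 0.25f, 0.75f};
        // Same relative offset (3) at two different absolute positions.
        double near = dot(rotate(q, 2, freqs), rotate(k, 5, freqs));
        double far = dot(rotate(q, 10, freqs), rotate(k, 13, freqs));
        // The two scores agree up to float rounding.
        System.out.println(near + " vs " + far);
    }
}
```

Because only the offset between the two positions matters, shifting both the query and the key by the same amount leaves the score unchanged.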
-
computeFreqCis
Precomputes the frequency tensor for complex exponentials (cis) with the default theta of 10000.0.
Parameters:
dim - the dimension of the frequency tensor.
end - the end index for precomputing frequencies.
Returns:
the precomputed frequency tensor for complex exponentials.
-
computeFreqCis
Precomputes the frequency tensor for complex exponentials (cis).
Parameters:
dim - the dimension of the frequency tensor.
end - the end index for precomputing frequencies.
theta - the scaling factor for frequency computation; the two-argument overload uses 10000.0.
scaling - if true, scales the frequency tensor.
Returns:
the precomputed frequency tensor for complex exponentials.
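As a rough illustration of what a precomputed cis table contains, the sketch below builds one with plain arrays. It assumes the common RoPE convention in which pair i uses the frequency theta^(-2i/dim) and entry (pos, i) holds cos and sin of pos times that frequency. The class FreqCisSketch and its array layout are hypothetical, the scaling flag is left out, and the actual tensor layout returned by this interface may differ.

```java
final class FreqCisSketch {
    /**
     * Returns cis[pos][i] = {cos(pos * f_i), sin(pos * f_i)} with f_i = theta^(-2i/dim),
     * for pos in [0, end) and i in [0, dim/2). Hypothetical helper for illustration.
     */
    static double[][][] computeFreqCis(int dim, int end, double theta) {
        int half = dim / 2;
        double[][][] cis = new double[end][half][2];
        for (int i = 0; i < half; i++) {
            // Frequencies fall geometrically from 1 toward 1/theta as i grows.
            double freq = 1.0 / Math.pow(theta, (2.0 * i) / dim);
            for (int pos = 0; pos < end; pos++) {
                cis[pos][i][0] = Math.cos(pos * freq); // real part
                cis[pos][i][1] = Math.sin(pos * freq); // imaginary part
            }
        }
        return cis;
    }

    public static void main(String[] args) {
        double[][][] cis = computeFreqCis(8, 16, 10000.0);
        // One row per position, one complex entry per coordinate pair.
        System.out.println(cis.length + " positions x " + cis[0].length + " pairs");
    }
}
```

Every entry lies on the unit circle, so multiplying a query or key pair by it changes the pair's direction but not its length.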
-
reshapeForBroadcast
Reshapes the cis tensor to match the shape of the target tensor x for broadcasting purposes, allowing for element-wise operations between tensors of compatible shapes.
Parameters:
cis - the frequency tensor for complex exponentials.
x - the target tensor for broadcasting.
Returns:
the reshaped cis tensor view.
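The shape arithmetic behind such a reshape can be sketched without any tensor library. The example assumes the layout used in common RoPE reference code, where cis has shape {seqLen, pairs} and x has shape {batch, seqLen, heads, pairs}. That layout, the class BroadcastShapeSketch, and the method broadcastShape are assumptions for illustration; this page does not state which dimensions the method keeps.

```java
import java.util.Arrays;

final class BroadcastShapeSketch {
    /**
     * Computes a view shape for cis that broadcasts against x: dimension 1 and the
     * last dimension of x are kept, every other dimension becomes 1.
     * Hypothetical helper; the dimension layout is an assumption.
     */
    static long[] broadcastShape(long[] cisShape, long[] xShape) {
        int n = xShape.length;
        if (n < 2 || cisShape.length != 2
                || cisShape[0] != xShape[1] || cisShape[1] != xShape[n - 1]) {
            throw new IllegalArgumentException(
                "cis " + Arrays.toString(cisShape) + " does not match x " + Arrays.toString(xShape));
        }
        long[] shape = new long[n];
        for (int i = 0; i < n; i++) {
            shape[i] = (i == 1 || i == n - 1) ? xShape[i] : 1;
        }
        return shape;
    }

    public static void main(String[] args) {
        long[] cis = {16, 32};        // {seqLen, pairs}
        long[] x = {2, 16, 8, 32};    // {batch, seqLen, heads, pairs}
        System.out.println(Arrays.toString(broadcastShape(cis, x)));
    }
}
```

With size-1 dimensions in the batch and head positions, one table of rotations is reused for every sequence in the batch and every attention head.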
-
scale
Adapts RoPE to longer input lengths.
Parameters:
freqs - the frequency tensor.
Returns:
the scaled frequency tensor.
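This page does not say how the frequencies are adapted. One published approach, used in the Llama 3.1 reference code, leaves high frequencies untouched, divides low frequencies by a fixed factor, and blends smoothly in between. The sketch below implements that scheme on a plain array. Whether scale follows it, and with which constants, is an assumption; the class RopeScaleSketch and all parameter names are hypothetical.

```java
final class RopeScaleSketch {
    /**
     * Wavelength-based frequency scaling in the style of the Llama 3.1 reference code.
     * Hypothetical helper; not confirmed to match this library's scale method.
     */
    static double[] scale(double[] freqs, double scaleFactor,
                          double lowFreqFactor, double highFreqFactor, int oldContextLen) {
        double lowFreqWavelen = oldContextLen / lowFreqFactor;
        double highFreqWavelen = oldContextLen / highFreqFactor;
        double[] out = new double[freqs.length];
        for (int i = 0; i < freqs.length; i++) {
            double freq = freqs[i];
            double wavelen = 2.0 * Math.PI / freq;
            if (wavelen < highFreqWavelen) {
                out[i] = freq;                 // short wavelengths: keep as-is
            } else if (wavelen > lowFreqWavelen) {
                out[i] = freq / scaleFactor;   // long wavelengths: stretch fully
            } else {
                // In between: interpolate linearly between the two regimes.
                double smooth = (oldContextLen / wavelen - lowFreqFactor)
                        / (highFreqFactor - lowFreqFactor);
                out[i] = (1.0 - smooth) * freq / scaleFactor + smooth * freq;
            }
        }
        return out;
    }

    public static void main(String[] args) {
        double[] freqs = {1.0, 2.0 * Math.PI / 4096.0, 1e-5};
        // Constants below are example values (assumed), not read from this library.
        double[] scaled = scale(freqs, 8.0, 1.0, 4.0, 8192);
        for (int i = 0; i < freqs.length; i++) {
            System.out.println(freqs[i] + " -> " + scaled[i]);
        }
    }
}
```

The design idea is that fast-rotating pairs already distinguish nearby tokens well at any length, while slow-rotating pairs are the ones that run out of range when the context grows, so only those are stretched.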
-