Class Attention

java.lang.Object
smile.llm.llama.Attention

public class Attention extends Object
Multi-head attention. It caches key and value information, applies rotary embeddings, and performs linear transformations.
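To make "applies rotary embeddings" concrete, here is a minimal plain-Java sketch of the rotary scheme commonly used in Llama-style models. It does not use this class or the Tensor type. The names RotarySketch, precomputeFreqsCis, and applyRotary, and the array layout, are illustrative assumptions rather than part of this API, and the library's internal implementation may differ.

```java
/** Hypothetical illustration of rotary position embeddings on plain arrays. */
final class RotarySketch {

    /**
     * Precomputes cos/sin pairs for every position and every feature pair.
     * Returns an array of shape [maxSeqLen][dim / 2][2] holding {cos, sin}.
     * Assumes dim is even; theta is the base of the frequency schedule.
     */
    static double[][][] precomputeFreqsCis(int dim, int maxSeqLen, double theta) {
        int half = dim / 2;
        double[][][] cis = new double[maxSeqLen][half][2];
        for (int pos = 0; pos < maxSeqLen; pos++) {
            for (int i = 0; i < half; i++) {
                // Frequency decreases geometrically with the pair index.
                double freq = 1.0 / Math.pow(theta, (2.0 * i) / dim);
                double angle = pos * freq;
                cis[pos][i][0] = Math.cos(angle);
                cis[pos][i][1] = Math.sin(angle);
            }
        }
        return cis;
    }

    /**
     * Rotates each consecutive pair (x[2i], x[2i+1]) by the angle stored for
     * that pair, treating the pair as a complex number multiplied by cos + i*sin.
     */
    static double[] applyRotary(double[] x, double[][] cisAtPos) {
        double[] out = new double[x.length];
        for (int i = 0; i < x.length / 2; i++) {
            double re = x[2 * i];
            double im = x[2 * i + 1];
            double c = cisAtPos[i][0];
            double s = cisAtPos[i][1];
            out[2 * i] = re * c - im * s;
            out[2 * i + 1] = re * s + im * c;
        }
        return out;
    }
}
```

Because each pair is only rotated, the vector's length is preserved, and position 0 leaves the vector unchanged. In this scheme the rotation is applied to queries and keys before the attention scores are computed.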
  • Constructor Details

    • Attention

      public Attention(ModelArgs args)
      Creates an attention module from the given model configuration.
      Parameters:
      args - the model configuration parameters.
  • Method Details

    • forward

      public Tensor forward(Tensor x, int startPos, Tensor cis, Tensor mask)
      Forward pass through the attention module.
      Parameters:
      x - the input tensor.
      startPos - the starting position for attention caching.
      cis - the precomputed frequency tensor for the rotary embeddings.
      mask - the attention mask tensor.
      Returns:
      the output tensor.
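
The role of startPos and the key/value cache in a forward pass can be illustrated with a simplified, single-head sketch on plain arrays. It omits the linear projections, rotary embeddings, batching, and the explicit mask tensor (causality is enforced by limiting the visible positions). CachedAttentionSketch and its method signature are hypothetical and are not this class's implementation.

```java
/** Hypothetical single-head attention with a key/value cache on plain arrays. */
final class CachedAttentionSketch {
    private final double[][] cacheK; // shape [maxSeqLen][dim]
    private final double[][] cacheV; // shape [maxSeqLen][dim]
    private final int dim;

    CachedAttentionSketch(int maxSeqLen, int dim) {
        this.dim = dim;
        this.cacheK = new double[maxSeqLen][dim];
        this.cacheV = new double[maxSeqLen][dim];
    }

    /**
     * q, k, v have shape [seqLen][dim] and describe the new tokens, which
     * occupy positions startPos .. startPos + seqLen - 1 in the sequence.
     */
    double[][] forward(double[][] q, double[][] k, double[][] v, int startPos) {
        int seqLen = q.length;
        // Store the new keys and values in the cache at their positions.
        for (int i = 0; i < seqLen; i++) {
            cacheK[startPos + i] = k[i].clone();
            cacheV[startPos + i] = v[i].clone();
        }
        double[][] out = new double[seqLen][dim];
        double scale = 1.0 / Math.sqrt(dim);
        for (int i = 0; i < seqLen; i++) {
            // Causal masking: a token sees cached positions up to and including its own.
            int visible = startPos + i + 1;
            double[] scores = new double[visible];
            double max = Double.NEGATIVE_INFINITY;
            for (int j = 0; j < visible; j++) {
                double dot = 0.0;
                for (int d = 0; d < dim; d++) {
                    dot += q[i][d] * cacheK[j][d];
                }
                scores[j] = dot * scale;
                max = Math.max(max, scores[j]);
            }
            // Numerically stable softmax over the visible positions.
            double sum = 0.0;
            for (int j = 0; j < visible; j++) {
                scores[j] = Math.exp(scores[j] - max);
                sum += scores[j];
            }
            // Weighted sum of the cached values.
            for (int j = 0; j < visible; j++) {
                double w = scores[j] / sum;
                for (int d = 0; d < dim; d++) {
                    out[i][d] += w * cacheV[j][d];
                }
            }
        }
        return out;
    }
}
```

In this sketch, a prompt would be processed with startPos = 0, and each later call would pass only the newly generated token together with a startPos equal to the number of tokens already cached, so earlier keys and values are reused rather than recomputed.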