Record Class ModelArgs

java.lang.Object
java.lang.Record
smile.llm.llama.ModelArgs
Record Components:
dim - the dimension of token embedding.
numLayers - the number of transformer blocks.
numHeads - the number of attention heads.
numKvHeads - the number of key and value heads.
vocabSize - the size of the vocabulary.
multipleOf - forces the SwiGLU hidden layer size to be a multiple of this value, typically a large power of 2.
ffnDimMultiplier - the multiplier for the hidden dimension of the feedforward layers.
normEps - the epsilon value used for numerical stability in normalization layers.
ropeTheta - the theta parameter in rotary positional encoding.
scaledRope - whether to apply scaling to the rotary positional encoding.
maxBatchSize - the maximum batch size.
maxSeqLen - the maximum sequence length for input data.
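
How dim, multipleOf and ffnDimMultiplier interact is easiest to see with numbers. The sketch below follows the feed-forward sizing rule of Meta's reference LLaMA implementation; that Smile applies exactly the same rule is an assumption. The class FfnHiddenDim and its method hiddenDim are hypothetical names used for illustration and are not part of smile.llm.llama.

```java
final class FfnHiddenDim {
    /**
     * Hypothetical helper: computes the SwiGLU hidden layer size from the
     * embedding dimension, assuming the reference LLaMA sizing rule.
     */
    static int hiddenDim(int dim, int multipleOf, Double ffnDimMultiplier) {
        // Start from 4 * dim and keep two thirds of it (integer division).
        int hidden = 2 * (4 * dim) / 3;
        // The multiplier is optional, hence the boxed Double.
        if (ffnDimMultiplier != null) {
            hidden = (int) (ffnDimMultiplier * hidden);
        }
        // Round up to the next multiple of multipleOf.
        return multipleOf * ((hidden + multipleOf - 1) / multipleOf);
    }

    public static void main(String[] args) {
        // Illustrative inputs only.
        System.out.println(hiddenDim(4096, 1024, 1.3));  // prints 14336
        System.out.println(hiddenDim(4096, 256, null));  // prints 11008
    }
}
```

With a null multiplier the result depends only on dim and multipleOf, which is why ffnDimMultiplier is declared as a nullable Double rather than a primitive.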

public record ModelArgs(int dim, int numLayers, int numHeads, Integer numKvHeads, int vocabSize, int multipleOf, Double ffnDimMultiplier, double normEps, double ropeTheta, boolean scaledRope, int maxBatchSize, int maxSeqLen) extends Record
LLaMA model hyperparameters.
  • Constructor Details

    • ModelArgs

      public ModelArgs()
      Constructor with default parameter values.
    • ModelArgs

      public ModelArgs(int dim, int numLayers, int numHeads, Integer numKvHeads, int vocabSize, int multipleOf, Double ffnDimMultiplier, double normEps, double ropeTheta, boolean scaledRope, int maxBatchSize, int maxSeqLen)
      Creates an instance of a ModelArgs record class.
      Parameters:
      dim - the value for the dim record component
      numLayers - the value for the numLayers record component
      numHeads - the value for the numHeads record component
      numKvHeads - the value for the numKvHeads record component
      vocabSize - the value for the vocabSize record component
      multipleOf - the value for the multipleOf record component
      ffnDimMultiplier - the value for the ffnDimMultiplier record component
      normEps - the value for the normEps record component
      ropeTheta - the value for the ropeTheta record component
      scaledRope - the value for the scaledRope record component
      maxBatchSize - the value for the maxBatchSize record component
      maxSeqLen - the value for the maxSeqLen record component
  • Method Details

    • from

      public static ModelArgs from(String path, int maxBatchSize, int maxSeqLen) throws IOException
      Loads the model hyperparameters from a JSON file.
      Parameters:
      path - the file path.
      maxBatchSize - the maximum batch size.
      maxSeqLen - the maximum sequence length for input data.
      Returns:
      the model hyperparameters.
      Throws:
      IOException - if the file cannot be opened or read.
    • toString

      public final String toString()
      Returns a string representation of this record class. The representation contains the name of the class, followed by the name and value of each of the record components.
      Specified by:
      toString in class Record
      Returns:
      a string representation of this object
    • hashCode

      public final int hashCode()
      Returns a hash code value for this object. The value is derived from the hash code of each of the record components.
      Specified by:
      hashCode in class Record
      Returns:
      a hash code value for this object
    • equals

      public final boolean equals(Object o)
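      Indicates whether some other object is "equal to" this one. The objects are equal if the other object is of the same class and if all the record components are equal. Reference components (numKvHeads, ffnDimMultiplier) are compared with Objects::equals(Object,Object); primitive components are compared with the compare method of their corresponding wrapper classes, so two double components that are both NaN are considered equal, unlike with '=='.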
      Specified by:
      equals in class Record
      Parameters:
      o - the object with which to compare
      Returns:
      true if this object is the same as the o argument; false otherwise.
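
The equality rules for records can be checked on a small stand-in. The sketch below uses a hypothetical two-component record, Pair, that mixes a primitive double with a boxed Integer as ModelArgs does; it is not part of Smile and requires Java 16 or later. Under the java.lang.Record contract, primitive components are compared through their wrapper class's compare method, so NaN matches NaN while 0.0 and -0.0 differ.

```java
final class RecordEqualityDemo {
    // Hypothetical stand-in for ModelArgs: one primitive and one boxed component.
    record Pair(double eps, Integer kvHeads) {}

    public static void main(String[] args) {
        Pair a = new Pair(Double.NaN, null);
        Pair b = new Pair(Double.NaN, null);

        // Double.compare(NaN, NaN) == 0 and Objects.equals(null, null) is true.
        System.out.println(a.equals(b));                   // prints true
        // Equal records must report equal hash codes.
        System.out.println(a.hashCode() == b.hashCode());  // prints true
        // Double.compare distinguishes positive and negative zero.
        System.out.println(new Pair(0.0, 8).equals(new Pair(-0.0, 8)));  // prints false
        // toString lists the record name and each component.
        System.out.println(a);  // prints Pair[eps=NaN, kvHeads=null]
    }
}
```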
    • dim

      public int dim()
      Returns the value of the dim record component.
      Returns:
      the value of the dim record component
    • numLayers

      public int numLayers()
      Returns the value of the numLayers record component.
      Returns:
      the value of the numLayers record component
    • numHeads

      public int numHeads()
      Returns the value of the numHeads record component.
      Returns:
      the value of the numHeads record component
    • numKvHeads

      public Integer numKvHeads()
      Returns the value of the numKvHeads record component.
      Returns:
      the value of the numKvHeads record component
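
Because numKvHeads is a boxed Integer, the accessor may in principle return null. A minimal sketch of one way to handle that, assuming the convention of Meta's reference LLaMA code in which a missing value means one key/value head per attention head; whether Smile resolves null the same way is not stated on this page. KvHeads, effectiveKvHeads and queriesPerKvHead are hypothetical helpers.

```java
final class KvHeads {
    /** Falls back to numHeads when numKvHeads is absent (assumed convention). */
    static int effectiveKvHeads(int numHeads, Integer numKvHeads) {
        return numKvHeads == null ? numHeads : numKvHeads;
    }

    /** Number of query heads that share each key/value head. */
    static int queriesPerKvHead(int numHeads, Integer numKvHeads) {
        int kv = effectiveKvHeads(numHeads, numKvHeads);
        if (kv <= 0 || numHeads % kv != 0) {
            throw new IllegalArgumentException(
                "numHeads must be a positive multiple of numKvHeads");
        }
        return numHeads / kv;
    }

    public static void main(String[] args) {
        System.out.println(effectiveKvHeads(32, null));  // prints 32
        System.out.println(queriesPerKvHead(32, 8));     // prints 4
    }
}
```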
    • vocabSize

      public int vocabSize()
      Returns the value of the vocabSize record component.
      Returns:
      the value of the vocabSize record component
    • multipleOf

      public int multipleOf()
      Returns the value of the multipleOf record component.
      Returns:
      the value of the multipleOf record component
    • ffnDimMultiplier

      public Double ffnDimMultiplier()
      Returns the value of the ffnDimMultiplier record component.
      Returns:
      the value of the ffnDimMultiplier record component
    • normEps

      public double normEps()
      Returns the value of the normEps record component.
      Returns:
      the value of the normEps record component
    • ropeTheta

      public double ropeTheta()
      Returns the value of the ropeTheta record component.
      Returns:
      the value of the ropeTheta record component
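
To show where ropeTheta enters the computation, the sketch below derives the per-pair rotation frequencies of the standard rotary positional encoding, 1 / theta^(2i / headDim), with headDim taken as dim / numHeads. It covers the unscaled case only; any adjustment made when scaledRope is true is not shown. RopeFrequencies and inverseFrequencies are hypothetical names, not Smile API.

```java
final class RopeFrequencies {
    /** Returns headDim / 2 inverse frequencies, one per rotated pair of dimensions. */
    static double[] inverseFrequencies(int headDim, double theta) {
        double[] freqs = new double[headDim / 2];
        for (int i = 0; i < freqs.length; i++) {
            // A larger theta makes the frequencies fall off faster across pairs.
            freqs[i] = 1.0 / Math.pow(theta, (2.0 * i) / headDim);
        }
        return freqs;
    }

    public static void main(String[] args) {
        int headDim = 4096 / 32;  // illustrative dim / numHeads
        double[] freqs = inverseFrequencies(headDim, 10000.0);
        System.out.println(freqs.length);  // prints 64
        System.out.println(freqs[0]);      // prints 1.0
    }
}
```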
    • scaledRope

      public boolean scaledRope()
      Returns the value of the scaledRope record component.
      Returns:
      the value of the scaledRope record component
    • maxBatchSize

      public int maxBatchSize()
      Returns the value of the maxBatchSize record component.
      Returns:
      the value of the maxBatchSize record component
    • maxSeqLen

      public int maxSeqLen()
      Returns the value of the maxSeqLen record component.
      Returns:
      the value of the maxSeqLen record component