Package smile.deep

Class Model

java.lang.Object
smile.deep.Model
All Implemented Interfaces:
Function<Tensor,Tensor>
Direct Known Subclasses:
VisionModel

public class Model extends Object implements Function<Tensor,Tensor>
A deep learning model that wraps a neural network and an optional data preprocessing function.
  • Constructor Details

    • Model

      public Model(LayerBlock net)
      Constructs a model from a neural network.
      Parameters:
      net - the neural network.
    • Model

      public Model(LayerBlock net, Function<Tensor,Tensor> transform)
      Constructs a model from a neural network and an optional data preprocessing function.
      Parameters:
      net - the neural network.
      transform - the optional data preprocessing function.
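A minimal sketch of how the optional `transform` could be combined with the network. It assumes two things the reference does not state: the transform runs on the input before the network, and a null transform means no preprocessing. `TransformedModel` is a hypothetical class, and `double[]` stands in for `Tensor` so the example runs without Smile or PyTorch.

```java
import java.util.function.Function;

/**
 * Hypothetical stand-in for a model with an optional preprocessing step.
 * double[] replaces Tensor so the sketch runs without Smile or PyTorch.
 */
final class TransformedModel implements Function<double[], double[]> {
    private final Function<double[], double[]> pipeline;

    TransformedModel(Function<double[], double[]> net) {
        this(net, null);
    }

    TransformedModel(Function<double[], double[]> net,
                     Function<double[], double[]> transform) {
        // Assumption: null means "no preprocessing"; otherwise the
        // transform is applied first and its output feeds the network.
        this.pipeline = transform == null ? net : transform.andThen(net);
    }

    @Override
    public double[] apply(double[] input) {
        return pipeline.apply(input);
    }
}
```

Composing once in the constructor keeps the per-call path a single `apply`, whichever constructor was used.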
  • Method Details

    • toString

      public String toString()
      Overrides:
      toString in class Object
    • asTorch

      public org.bytedeco.pytorch.Module asTorch()
      Returns the PyTorch Module object.
      Returns:
      the PyTorch Module object.
    • train

      public Model train()
      Puts the model in training mode.
      Returns:
      this model.
    • eval

      public Model eval()
      Puts the model in evaluation (inference) mode.
      Returns:
      this model.
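The fluent `train()`/`eval()` pair can be sketched with a hypothetical `ModalModel` whose inverted-dropout step is active only in training mode, which is the usual reason the mode matters. The class, its dropout layer, and its starting in training mode are assumptions for illustration; `double[]` stands in for `Tensor`.

```java
import java.util.Random;

/**
 * Hypothetical model with a mode flag: train() and eval() flip the mode
 * and return this, so calls can be chained.
 */
final class ModalModel {
    private boolean training = true; // assumption: new models start in training mode
    private final double dropout;
    private final Random random = new Random(42);

    ModalModel(double dropout) { this.dropout = dropout; }

    ModalModel train() { training = true; return this; }
    ModalModel eval()  { training = false; return this; }
    boolean isTraining() { return training; }

    /** Inverted dropout in training mode; identity in evaluation mode. */
    double[] forward(double[] x) {
        double[] y = x.clone();
        if (!training) return y;
        for (int i = 0; i < y.length; i++) {
            // Drop with probability `dropout`, otherwise rescale to keep the expectation.
            y[i] = random.nextDouble() < dropout ? 0.0 : y[i] / (1.0 - dropout);
        }
        return y;
    }
}
```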
    • device

      public Device device()
      Returns the device on which the model is stored.
      Returns:
      the compute device.
    • dtype

      public ScalarType dtype()
      Returns the data type of the model.
      Returns:
      the data type.
    • to

      public Model to(Device device)
      Moves the model to a device.
      Parameters:
      device - the compute device.
      Returns:
      this model.
    • to

      public Model to(Device device, ScalarType dtype)
      Moves the model to a device and converts it to the given data type.
      Parameters:
      device - the compute device.
      dtype - the data type.
      Returns:
      this model.
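A sketch of the `device()`/`dtype()`/`to(...)` contract using hypothetical `SketchDevice` and `SketchDType` enums in place of Smile's `Device` and `ScalarType`. It assumes `to(device)` keeps the current data type, and it shows that converting to single precision rounds each weight. Nothing actually moves to a GPU here; the device is only a label.

```java
/** Hypothetical stand-ins; Smile's real Device and ScalarType types differ. */
enum SketchDevice { CPU, CUDA }
enum SketchDType { FLOAT32, FLOAT64 }

final class PlacedModel {
    private SketchDevice device = SketchDevice.CPU;
    private SketchDType dtype = SketchDType.FLOAT64;
    private final double[] weights;

    PlacedModel(double... weights) { this.weights = weights.clone(); }

    SketchDevice device() { return device; }
    SketchDType dtype() { return dtype; }
    double weight(int i) { return weights[i]; }

    /** Moves to a device, keeping the current data type (assumption). */
    PlacedModel to(SketchDevice device) { return to(device, this.dtype); }

    /** Moves to a device and converts the weights; FLOAT32 rounds each value. */
    PlacedModel to(SketchDevice device, SketchDType dtype) {
        if (dtype == SketchDType.FLOAT32) {
            for (int i = 0; i < weights.length; i++) weights[i] = (float) weights[i];
        }
        this.device = device;
        this.dtype = dtype;
        return this;
    }
}
```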
    • load

      public Model load(String path)
      Loads a checkpoint.
      Parameters:
      path - the checkpoint file path.
      Returns:
      this model.
    • save

      public Model save(String path)
      Serializes the model to a checkpoint file.
      Parameters:
      path - the checkpoint file path.
      Returns:
      this model.
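A minimal save/load round trip, assuming a checkpoint only needs to hold the parameter values. `CheckpointModel` and its raw-double file format are hypothetical; the real checkpoints are presumably written through PyTorch's own serialization. Like the documented methods, both calls return `this`, and I/O failures are rethrown as unchecked exceptions.

```java
import java.io.*;

/** Hypothetical checkpointing sketch: parameters are stored as raw doubles. */
final class CheckpointModel {
    private double[] params;

    CheckpointModel(double... params) { this.params = params.clone(); }

    double[] params() { return params.clone(); }

    /** Writes the parameter count followed by each parameter; returns this. */
    CheckpointModel save(String path) {
        try (DataOutputStream out = new DataOutputStream(
                new BufferedOutputStream(new FileOutputStream(path)))) {
            out.writeInt(params.length);
            for (double p : params) out.writeDouble(p);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return this;
    }

    /** Replaces the current parameters with those read from path; returns this. */
    CheckpointModel load(String path) {
        try (DataInputStream in = new DataInputStream(
                new BufferedInputStream(new FileInputStream(path)))) {
            double[] loaded = new double[in.readInt()];
            for (int i = 0; i < loaded.length; i++) loaded[i] = in.readDouble();
            params = loaded;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return this;
    }
}
```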
    • apply

      public Tensor apply(Tensor input)
      Specified by:
      apply in interface Function<Tensor,Tensor>
    • forward

      public Tensor forward(Tensor input)
      Performs forward propagation (the forward pass) through the model.
      Parameters:
      input - the input tensor.
      Returns:
      the output tensor.
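Because `Model` implements `Function<Tensor,Tensor>`, a model can be passed wherever a `Function` is expected and chained with `andThen`. The hypothetical `LinearSketch` below shows the pattern with `double[]` in place of `Tensor`, assuming (not confirmed above) that `apply` simply delegates to `forward`.

```java
import java.util.function.Function;

/** Hypothetical linear layer: forward computes y = W x + b; apply delegates to it. */
final class LinearSketch implements Function<double[], double[]> {
    private final double[][] weights; // shape [out][in]
    private final double[] bias;      // shape [out]

    LinearSketch(double[][] weights, double[] bias) {
        this.weights = weights;
        this.bias = bias;
    }

    /** Forward pass: y = W x + b. */
    double[] forward(double[] x) {
        double[] y = bias.clone();
        for (int i = 0; i < weights.length; i++) {
            for (int j = 0; j < x.length; j++) y[i] += weights[i][j] * x[j];
        }
        return y;
    }

    @Override
    public double[] apply(double[] x) { return forward(x); }
}
```

For example, `a.andThen(b).apply(x)` runs `a` and feeds its output to `b`, the same way two models could be stacked through the `Function` interface.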
    • setLearningRateSchedule

      public void setLearningRateSchedule(TimeFunction learningRateSchedule)
      Sets the learning rate schedule.
      Parameters:
      learningRateSchedule - the learning rate schedule.
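A learning rate schedule maps the training step to a learning rate. The sketch below uses a hypothetical `LrSchedule` interface in place of Smile's `TimeFunction` and implements two common shapes, smooth exponential decay and step decay; the factory names and formulas are illustrative, not Smile's.

```java
/** Hypothetical schedule: maps a training step t to a learning rate. */
@FunctionalInterface
interface LrSchedule {
    double apply(long t);

    /** Exponential decay: lr0 * rate^(t / steps), with fractional exponents. */
    static LrSchedule exponentialDecay(double lr0, double rate, long steps) {
        return t -> lr0 * Math.pow(rate, (double) t / steps);
    }

    /** Step decay: multiply by rate once every `steps` steps (integer division). */
    static LrSchedule stepDecay(double lr0, double rate, long steps) {
        return t -> lr0 * Math.pow(rate, t / steps);
    }
}
```

Exponential decay lowers the rate a little on every step, while step decay holds it constant and then drops it in discrete jumps.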
    • train

      public void train(int epochs, Optimizer optimizer, Loss loss, Dataset train)
      Trains the model.
      Parameters:
      epochs - the number of training epochs.
      optimizer - the optimization algorithm.
      loss - the loss function.
      train - the training data.
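What one training epoch involves can be sketched on a toy problem: fit y = w*x + b with mean squared error loss and plain gradient descent, making one full pass over the data per epoch. `LinearRegressionSketch` is hypothetical; a real run would use the given `Optimizer`, `Loss`, and mini-batches drawn from the `Dataset`.

```java
/** Hypothetical training loop: MSE loss, plain gradient descent, full-batch epochs. */
final class LinearRegressionSketch {
    double w = 0.0, b = 0.0;

    /** MSE loss: (1/n) * sum of (w*x + b - y)^2. */
    double loss(double[] x, double[] y) {
        double sum = 0.0;
        for (int i = 0; i < x.length; i++) {
            double r = w * x[i] + b - y[i];
            sum += r * r;
        }
        return sum / x.length;
    }

    void train(int epochs, double learningRate, double[] x, double[] y) {
        int n = x.length;
        for (int epoch = 0; epoch < epochs; epoch++) {
            // Forward pass plus gradients of the MSE loss with respect to w and b.
            double gw = 0.0, gb = 0.0;
            for (int i = 0; i < n; i++) {
                double r = w * x[i] + b - y[i];
                gw += 2.0 * r * x[i] / n;
                gb += 2.0 * r / n;
            }
            // Optimizer step: plain gradient descent.
            w -= learningRate * gw;
            b -= learningRate * gb;
        }
    }
}
```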
    • train

      public void train(int epochs, Optimizer optimizer, Loss loss, Dataset train, Dataset test, String checkpoint, Metric... metrics)
      Trains the model, optionally evaluating it on validation data and saving checkpoints.
      Parameters:
      epochs - the number of training epochs.
      optimizer - the optimization algorithm.
      loss - the loss function.
      train - the training data.
      test - optional validation data.
      checkpoint - optional checkpoint file path.
      metrics - the evaluation metrics.
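With validation data and a checkpoint path, a common pattern is to score the model after every epoch and save a checkpoint only when the validation loss improves. This reference does not say whether Smile saves on every epoch or only on improvement, so the hypothetical `ValidatedTrainer` below illustrates the keep-the-best variant, with callbacks standing in for training, validation, and saving.

```java
import java.util.function.IntConsumer;
import java.util.function.IntToDoubleFunction;

/** Hypothetical epoch loop with validation and keep-the-best checkpointing. */
final class ValidatedTrainer {
    /**
     * Runs `epochs` epochs (numbered from 1) and returns the best epoch.
     *
     * @param trainEpoch runs one training epoch
     * @param validate   returns the validation loss after that epoch
     * @param save       writes a checkpoint for that epoch
     */
    static int fit(int epochs, IntConsumer trainEpoch,
                   IntToDoubleFunction validate, IntConsumer save) {
        double bestLoss = Double.POSITIVE_INFINITY;
        int bestEpoch = 0;
        for (int epoch = 1; epoch <= epochs; epoch++) {
            trainEpoch.accept(epoch);                    // one pass over the training data
            double loss = validate.applyAsDouble(epoch); // score on the validation data
            if (loss < bestLoss) {                       // save only on improvement
                bestLoss = loss;
                bestEpoch = epoch;
                save.accept(epoch);
            }
        }
        return bestEpoch;
    }
}
```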
    • eval

      public Map<String,Double> eval(Dataset dataset, Metric... metrics)
      Evaluates the model on a test dataset with the given metrics.
      Parameters:
      dataset - the test dataset.
      metrics - the evaluation metrics.
      Returns:
      the value of each metric, keyed by metric name.
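The `Map<String,Double>` result can be pictured as one entry per metric, keyed by name. The sketch assumes binary labels with 1 as the positive class and uses a hypothetical `LabelMetric` interface rather than Smile's `Metric`; "Accuracy" and "Precision" are illustrative keys, not necessarily the ones Smile uses.

```java
import java.util.LinkedHashMap;
import java.util.Map;

/** Hypothetical named metric over predicted and true class labels. */
interface LabelMetric {
    String name();
    double compute(int[] predicted, int[] truth);
}

final class Evaluator {
    static final LabelMetric ACCURACY = new LabelMetric() {
        public String name() { return "Accuracy"; }
        public double compute(int[] p, int[] t) {
            int correct = 0;
            for (int i = 0; i < t.length; i++) if (p[i] == t[i]) correct++;
            return (double) correct / t.length;
        }
    };

    /** Precision for the positive class 1 (assumption: binary labels). */
    static final LabelMetric PRECISION = new LabelMetric() {
        public String name() { return "Precision"; }
        public double compute(int[] p, int[] t) {
            int tp = 0, fp = 0;
            for (int i = 0; i < t.length; i++) {
                if (p[i] == 1) { if (t[i] == 1) tp++; else fp++; }
            }
            return tp + fp == 0 ? 0.0 : (double) tp / (tp + fp);
        }
    };

    /** Computes each metric and stores it under its name, in the order given. */
    static Map<String, Double> eval(int[] predicted, int[] truth, LabelMetric... metrics) {
        Map<String, Double> results = new LinkedHashMap<>();
        for (LabelMetric m : metrics) results.put(m.name(), m.compute(predicted, truth));
        return results;
    }
}
```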