Package smile.deep
Class Model
java.lang.Object
smile.deep.Model
- Direct Known Subclasses:
VisionModel
The deep learning models.
-
Constructor Summary
Model(net)  Constructor.
-
Method Summary
Modifier and Type            Method                                  Description
org.bytedeco.pytorch.Module  asTorch()                               Returns the PyTorch Module object.
                             device()                                Returns the device on which the model is stored.
                             eval()                                  Sets the model in the evaluation/inference mode.
                             eval(dataset, metrics)                  Evaluates the model accuracy on a test dataset.
                             forward(input)                          Forward propagation (or forward pass) through the model.
                             load(path)                              Loads a checkpoint.
                             save(path)                              Serializes the model as a checkpoint.
void                         setLearningRateSchedule(TimeFunction learningRateSchedule)
                                                                     Sets the learning rate schedule.
                             to(device)                              Moves the model to a device.
                             toString()
                             train()                                 Sets the model in the training mode.
void                         train(epochs, optimizer, loss, train)   Trains the model.
void                         train(int epochs, Optimizer optimizer, Loss loss, Dataset train, Dataset val, String checkpoint, Metric... metrics)
                                                                     Trains the model.
-
Constructor Details
-
Model
Constructor.
Parameters:
net - the neural network.
-
-
Method Details
-
toString
-
asTorch
public org.bytedeco.pytorch.Module asTorch()
Returns the PyTorch Module object.
Returns:
the PyTorch Module object.
-
train
Sets the model in the training mode.
Returns:
this model.
-
eval
Sets the model in the evaluation/inference mode.
Returns:
this model.
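Switching modes matters because some layers, such as dropout, behave differently during training and inference, and returning the model lets train() and eval() be chained with other calls. The following is a minimal, self-contained sketch of that contract on plain Java arrays; `ToyDropoutModel` and its inverted-dropout layer are hypothetical illustrations, not Smile classes, and they say nothing about how Smile switches modes internally.

```java
import java.util.Random;

/**
 * Hypothetical toy model illustrating train()/eval() mode switching with fluent returns.
 * Not a Smile class; its only "layer" is inverted dropout on a plain array.
 */
final class ToyDropoutModel {
    private boolean training = true;  // starts in training mode in this sketch
    private final double p;           // probability of zeroing an element while training
    private final Random rng;

    ToyDropoutModel(double p, long seed) {
        if (p < 0.0 || p >= 1.0) {
            throw new IllegalArgumentException("p must be in [0, 1)");
        }
        this.p = p;
        this.rng = new Random(seed);
    }

    /** Sets training mode and returns this model so calls can be chained. */
    ToyDropoutModel train() {
        training = true;
        return this;
    }

    /** Sets evaluation/inference mode and returns this model so calls can be chained. */
    ToyDropoutModel eval() {
        training = false;
        return this;
    }

    boolean isTraining() {
        return training;
    }

    /** Training: zero each element with probability p, scale survivors by 1/(1-p). Evaluation: identity. */
    double[] forward(double[] x) {
        double[] y = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            if (training) {
                y[i] = rng.nextDouble() < p ? 0.0 : x[i] / (1.0 - p);
            } else {
                y[i] = x[i];
            }
        }
        return y;
    }
}
```

Because both methods return the model, `model.eval().forward(x)` switches to inference and runs the pass in one expression; in this sketch the result then equals the input exactly.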
-
device
Returns the device on which the model is stored.
Returns:
the compute device.
-
to
Moves the model to a device.
Parameters:
device - the compute device.
Returns:
this model.
-
load
Loads a checkpoint.
Parameters:
path - the checkpoint file path.
Returns:
this model.
-
save
Serializes the model as a checkpoint.
Parameters:
path - the checkpoint file path.
Returns:
this model.
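save() and load() are two halves of a round trip: whatever one writes, the other must read back into a model of the same shape. A minimal sketch of that round trip on a plain parameter array, assuming a simple length-prefixed binary layout; `ToyCheckpointModel` and its file format are hypothetical and do not reproduce Smile's checkpoint format.

```java
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;

/** Hypothetical stand-in: a model whose parameters are a flat double array (not Smile's format). */
final class ToyCheckpointModel {
    double[] weights;

    ToyCheckpointModel(int size) {
        weights = new double[size];
    }

    /** Writes the parameter count, then every parameter; returns this model for chaining. */
    ToyCheckpointModel save(String path) throws IOException {
        try (DataOutputStream out = new DataOutputStream(
                new BufferedOutputStream(Files.newOutputStream(Paths.get(path))))) {
            out.writeInt(weights.length);
            for (double w : weights) {
                out.writeDouble(w);
            }
        }
        return this;
    }

    /** Reads a file written by save(); the stored count must match this model's size. */
    ToyCheckpointModel load(String path) throws IOException {
        try (DataInputStream in = new DataInputStream(
                new BufferedInputStream(Files.newInputStream(Paths.get(path))))) {
            int n = in.readInt();
            if (n != weights.length) {
                throw new IOException("checkpoint holds " + n + " parameters, model expects " + weights.length);
            }
            for (int i = 0; i < n; i++) {
                weights[i] = in.readDouble();
            }
        }
        return this;
    }
}
```

Checking the stored size before reading catches the common mistake of loading a checkpoint into a model with a different architecture.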
-
apply
-
forward
Forward propagation (or forward pass) through the model.
Parameters:
input - the input tensor.
Returns:
the output tensor.
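In a forward pass, each layer transforms its input and hands the result to the next layer until the output is produced. A minimal sketch on plain arrays for a hypothetical two-layer network (linear, ReLU, linear); `ToyForward` and its array-based layout are illustrations only, whereas Smile's forward() takes and returns tensors.

```java
/** Hypothetical two-layer forward pass: output = W2 * relu(W1 * x + b1) + b2. Not Smile code. */
final class ToyForward {
    /** Computes W x + b for a row-major weight matrix W (one row per output unit). */
    static double[] linear(double[][] w, double[] b, double[] x) {
        double[] y = new double[w.length];
        for (int i = 0; i < w.length; i++) {
            double sum = b[i];
            for (int j = 0; j < x.length; j++) {
                sum += w[i][j] * x[j];
            }
            y[i] = sum;
        }
        return y;
    }

    /** Element-wise rectified linear unit: max(0, x). */
    static double[] relu(double[] x) {
        double[] y = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            y[i] = Math.max(0.0, x[i]);
        }
        return y;
    }

    /** The forward pass: each layer's output becomes the next layer's input. */
    static double[] forward(double[][] w1, double[] b1, double[][] w2, double[] b2, double[] x) {
        return linear(w2, b2, relu(linear(w1, b1, x)));
    }
}
```

With identity `W1`, bias `{0, -5}`, input `{2, 3}`, the hidden layer is `relu({2, -2}) = {2, 0}`, and `W2 = {{1, 1}}` with bias `0.5` yields `{2.5}`.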
-
setLearningRateSchedule
Sets the learning rate schedule.
Parameters:
learningRateSchedule - the learning rate schedule.
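A learning rate schedule maps a time index (an update step or epoch count) to the learning rate used at that point. Smile's setter takes a `TimeFunction`; to stay self-contained, the sketch below uses the JDK's `IntToDoubleFunction` as a stand-in, and the two schedules shown (step decay and linear warmup) are common examples rather than anything this page prescribes. `ToySchedules` is hypothetical.

```java
import java.util.function.IntToDoubleFunction;

/** Hypothetical learning-rate schedules as functions of a non-negative time index t. */
final class ToySchedules {
    /** Step decay: lr0 * gamma^floor(t / stepSize), i.e. multiply by gamma every stepSize steps. */
    static IntToDoubleFunction stepDecay(double lr0, double gamma, int stepSize) {
        return t -> lr0 * Math.pow(gamma, t / stepSize);  // integer division gives the floor
    }

    /** Linear warmup: ramps from lr0/warmup up to lr0 over the first warmup steps, then holds lr0. */
    static IntToDoubleFunction linearWarmup(double lr0, int warmup) {
        return t -> t < warmup ? lr0 * (t + 1) / warmup : lr0;
    }
}
```

For example, `stepDecay(0.1, 0.5, 10)` gives 0.1 for steps 0 to 9, 0.05 for steps 10 to 19, and 0.025 for steps 20 to 29.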
-
train
Trains the model.
Parameters:
epochs - the number of training epochs.
optimizer - the optimization algorithm.
loss - the loss function.
train - the training data.
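Conceptually, each epoch passes over the training data, runs the forward pass, measures the loss, and lets the optimizer update the parameters. The sketch below shows that loop for a hypothetical one-feature linear model, with mean squared error and plain stochastic gradient descent hard-coded in place of the `Loss` and `Optimizer` arguments; `ToyTrainer` is illustrative and is not Smile's training code.

```java
/**
 * Hypothetical training loop for y = w * x + b with squared-error loss and plain SGD.
 * Stands in for train(epochs, optimizer, loss, train); not Smile's implementation.
 */
final class ToyTrainer {
    double w = 0.0;
    double b = 0.0;

    /** Mean squared error of the current parameters over the whole dataset. */
    double loss(double[] xs, double[] ys) {
        double sum = 0.0;
        for (int i = 0; i < xs.length; i++) {
            double err = w * xs[i] + b - ys[i];
            sum += err * err;
        }
        return sum / xs.length;
    }

    /** Runs `epochs` passes over the data, applying one SGD update per sample. */
    void train(int epochs, double learningRate, double[] xs, double[] ys) {
        for (int epoch = 0; epoch < epochs; epoch++) {
            for (int i = 0; i < xs.length; i++) {
                double err = w * xs[i] + b - ys[i];  // forward pass and residual
                double gradW = 2.0 * err * xs[i];    // d(err^2)/dw
                double gradB = 2.0 * err;            // d(err^2)/db
                w -= learningRate * gradW;           // optimizer step
                b -= learningRate * gradB;
            }
        }
    }
}
```

On noiseless data generated by `y = 2x + 1`, a few hundred epochs with a learning rate of 0.05 drive `w` and `b` to about 2 and 1.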
-
train
public void train(int epochs, Optimizer optimizer, Loss loss, Dataset train, Dataset val, String checkpoint, Metric... metrics)
Trains the model.
Parameters:
epochs - the number of training epochs.
optimizer - the optimization algorithm.
loss - the loss function.
train - the training data.
val - optional validation data.
checkpoint - optional checkpoint file path.
metrics - the evaluation metrics.
-
eval
Evaluates the model accuracy on a test dataset.
Parameters:
dataset - the test dataset.
metrics - the evaluation metrics.
Returns:
the accuracy.
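For a classifier, accuracy is the fraction of test samples whose predicted class (the index of the largest output) matches the label. A minimal sketch on plain arrays; `ToyAccuracy` is hypothetical, whereas Smile's eval() takes `Metric` arguments whose exact behavior is not described on this page.

```java
/** Hypothetical accuracy metric: argmax of each output row compared with the integer class label. */
final class ToyAccuracy {
    /** Index of the largest entry; ties go to the lowest index. */
    static int argmax(double[] row) {
        int best = 0;
        for (int j = 1; j < row.length; j++) {
            if (row[j] > row[best]) {
                best = j;
            }
        }
        return best;
    }

    /** Fraction of samples whose predicted class equals the label. */
    static double accuracy(double[][] outputs, int[] labels) {
        if (labels.length == 0 || outputs.length != labels.length) {
            throw new IllegalArgumentException("outputs and labels must have the same, non-zero length");
        }
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            if (argmax(outputs[i]) == labels[i]) {
                correct++;
            }
        }
        return (double) correct / labels.length;
    }
}
```

With predictions `{1, 0, 1, 0}` against labels `{1, 0, 0, 0}`, three of four samples match, so the accuracy is 0.75.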
-