Class Model
java.lang.Object
smile.deep.Model
Direct Known Subclasses:
VisionModel
Constructor Summary
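Constructor | Description
Model(LayerBlock net) | Creates a model from the given neural network.
Model(LayerBlock net, Function<Tensor, Tensor> transform) | Creates a model from the given neural network and an optional data preprocessing function.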
Method Summary
Method | Description
org.bytedeco.pytorch.Module asTorch() | Returns the PyTorch Module object.
device() | Returns the device on which the model is stored.
dtype() | Returns the data type.
eval() | Sets the model to evaluation (inference) mode.
eval(...) | Evaluates the model accuracy on a test dataset.
forward(...) | Performs forward propagation (a forward pass) through the model.
load(...) | Loads a checkpoint.
save(...) | Serializes the model as a checkpoint.
void setLearningRateSchedule(TimeFunction learningRateSchedule) | Sets the learning rate schedule.
to(...) | Moves the model to a device.
to(Device device, ScalarType dtype) | Moves the model to a device.
String toString() | Returns a string representation of the model.
train() | Sets the model to training mode.
void train(...) | Trains the model.
void train(int epochs, Optimizer optimizer, Loss loss, Dataset train, Dataset test, String checkpoint, Metric... metrics) | Trains the model.
Constructor Details
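Model
public Model(LayerBlock net)
Creates a model from the given neural network.
Parameters:
net - the neural network.
Model
public Model(LayerBlock net, Function<Tensor, Tensor> transform)
Creates a model from the given neural network and an optional data preprocessing function.
Parameters:
net - the neural network.
transform - the optional data preprocessing function.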
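The two constructors differ only in the optional data preprocessing function. A minimal plain-Java sketch of that wrapper pattern follows. `TransformingModel` and `TransformingModelDemo` are hypothetical names, `double[]` stands in for `Tensor`, a plain `Function` stands in for `LayerBlock`, and applying the transform to each input before the network is an assumption, not something this page states.

```java
import java.util.Arrays;
import java.util.function.Function;

// Hypothetical stand-in for Model: double[] replaces Tensor and a plain
// Function replaces LayerBlock. This is not the Smile implementation.
class TransformingModel {
    private final Function<double[], double[]> net;
    private final Function<double[], double[]> transform; // may be null

    // Mirrors Model(LayerBlock net): no preprocessing.
    TransformingModel(Function<double[], double[]> net) {
        this(net, null);
    }

    // Mirrors Model(LayerBlock net, Function<Tensor, Tensor> transform).
    TransformingModel(Function<double[], double[]> net,
                      Function<double[], double[]> transform) {
        this.net = net;
        this.transform = transform;
    }

    // Assumption: the optional transform preprocesses the raw input first.
    double[] forward(double[] input) {
        double[] x = transform == null ? input : transform.apply(input);
        return net.apply(x);
    }
}

class TransformingModelDemo {
    public static void main(String[] args) {
        // Toy "network": sums the input features into a single output.
        Function<double[], double[]> net =
            x -> new double[] { Arrays.stream(x).sum() };
        // Toy preprocessing: scales pixel values from [0, 255] to [0, 1].
        Function<double[], double[]> scale =
            x -> Arrays.stream(x).map(v -> v / 255.0).toArray();

        double[] pixels = { 255.0, 0.0, 255.0 };
        System.out.println(new TransformingModel(net).forward(pixels)[0]);        // 510.0
        System.out.println(new TransformingModel(net, scale).forward(pixels)[0]); // 2.0
    }
}
```

Keeping the transform separate from the network lets the same network be reused with different input pipelines.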
Method Details
toString
asTorch
public org.bytedeco.pytorch.Module asTorch()
Returns the PyTorch Module object.
Returns:
the PyTorch Module object.
train
Sets the model to training mode.
eval
Sets the model to evaluation (inference) mode.
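The mode switch matters because some layers behave differently during training and inference. As an illustration, assuming inverted dropout as the example layer, the hypothetical `DropoutSketch` below (not a Smile class) randomly zeroes activations in training mode and passes inputs through unchanged in evaluation mode.

```java
import java.util.Random;

// Hypothetical module showing why train()/eval() matter: inverted dropout
// is active only in training mode. Not Smile's implementation.
class DropoutSketch {
    private final double p;          // probability of dropping an activation
    private final Random rng;
    private boolean training = true; // PyTorch modules also start in training mode

    DropoutSketch(double p, long seed) {
        this.p = p;
        this.rng = new Random(seed);
    }

    DropoutSketch train() { training = true; return this; }
    DropoutSketch eval()  { training = false; return this; }

    double[] forward(double[] x) {
        double[] y = x.clone();
        if (!training) {
            return y; // inference: identity, so outputs are deterministic
        }
        double scale = 1.0 / (1.0 - p);
        for (int i = 0; i < y.length; i++) {
            // Zero each activation with probability p and rescale survivors
            // so the expected activation matches inference mode.
            y[i] = rng.nextDouble() < p ? 0.0 : y[i] * scale;
        }
        return y;
    }
}
```

Forgetting to switch to evaluation mode before measuring accuracy would leave dropout active and make results noisy.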
device
Returns the device on which the model is stored.
Returns:
the compute device.
dtype
Returns the data type.
to
Moves the model to a device.
to
Moves the model to a device.
Parameters:
device - the compute device.
dtype - the data type.
Returns:
this model.
load
Loads a checkpoint.
save
Serializes the model as a checkpoint.
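A checkpoint persists the learned parameters so that training can be resumed or the model reused. The sketch below is a hypothetical stand-in: `CheckpointSketch` writes a flat `double[]` of parameters as a length-prefixed binary file, which is an assumed format; the format Smile actually uses is not described on this page.

```java
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;

// Hypothetical checkpoint helper: stores a flat parameter vector as a
// length-prefixed sequence of doubles. Smile's real format may differ.
final class CheckpointSketch {
    private CheckpointSketch() {}

    // Writes the parameter count followed by each value (big-endian).
    static void save(double[] params, Path path) throws IOException {
        try (DataOutputStream out = new DataOutputStream(
                new BufferedOutputStream(Files.newOutputStream(path)))) {
            out.writeInt(params.length);
            for (double v : params) {
                out.writeDouble(v);
            }
        }
    }

    // Reads back a parameter vector written by save().
    static double[] load(Path path) throws IOException {
        try (DataInputStream in = new DataInputStream(
                new BufferedInputStream(Files.newInputStream(path)))) {
            double[] params = new double[in.readInt()];
            for (int i = 0; i < params.length; i++) {
                params[i] = in.readDouble();
            }
            return params;
        }
    }
}
```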
apply
forward
Performs forward propagation (a forward pass) through the model.
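Forward propagation evaluates the network layer by layer, feeding each layer's output into the next. The hypothetical `DenseReluSketch` below (not a Smile class) shows one fully connected layer with ReLU activation, computing y = max(0, Wx + b) with `double[]` arrays in place of tensors.

```java
// Hypothetical fully connected layer with ReLU activation, illustrating one
// step of forward propagation: y = max(0, W x + b).
class DenseReluSketch {
    private final double[][] weight; // shape [out][in]
    private final double[] bias;     // shape [out]

    DenseReluSketch(double[][] weight, double[] bias) {
        this.weight = weight;
        this.bias = bias;
    }

    double[] forward(double[] x) {
        double[] y = new double[bias.length];
        for (int i = 0; i < y.length; i++) {
            double z = bias[i];
            for (int j = 0; j < x.length; j++) {
                z += weight[i][j] * x[j];
            }
            y[i] = Math.max(0.0, z); // ReLU
        }
        return y;
    }
}
```

Stacking layers composes these calls, for example `layer2.forward(layer1.forward(x))`, which is how a forward pass runs through a whole multi-layer network.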
setLearningRateSchedule
public void setLearningRateSchedule(TimeFunction learningRateSchedule)
Sets the learning rate schedule.
Parameters:
learningRateSchedule - the learning rate schedule.
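A learning rate schedule maps the training step (or epoch) to a learning rate. The sketch below uses a hypothetical `LrSchedule` interface in place of `TimeFunction` and implements step decay, lr(t) = initial * factor^floor(t / stepSize). Whether `Model` queries its schedule per batch or per epoch is not stated on this page.

```java
// Hypothetical stand-in for a learning rate schedule: maps a training step
// (or epoch) index to a learning rate. Not Smile's TimeFunction API.
@FunctionalInterface
interface LrSchedule {
    double apply(int t);

    // Step decay: multiplies the initial rate by `factor` every `stepSize` steps.
    static LrSchedule stepDecay(double initial, double factor, int stepSize) {
        // t / stepSize is integer division, i.e. floor for non-negative t.
        return t -> initial * Math.pow(factor, t / stepSize);
    }
}
```

With `stepDecay(0.1, 0.5, 10)`, steps 0 to 9 use 0.1, steps 10 to 19 use 0.05, and so on. Other common shapes, such as exponential or linear decay, differ only in the lambda body.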
train
Trains the model.
train
public void train(int epochs, Optimizer optimizer, Loss loss, Dataset train, Dataset test, String checkpoint, Metric... metrics)
Trains the model.
Parameters:
epochs - the number of training epochs.
optimizer - the optimization algorithm.
loss - the loss function.
train - the training data.
test - the optional validation data.
checkpoint - the optional checkpoint file path.
metrics - the evaluation metrics.
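An epoch-based training method of this shape typically runs forward passes, computes the loss, updates parameters with the optimizer, and optionally evaluates on validation data after each epoch. As a heavily simplified sketch (all names hypothetical, not Smile's code), `TrainLoopSketch` fits a one-dimensional linear model y = w*x + b, with mean squared error and full-batch gradient descent standing in for the loss and optimizer.

```java
// Hypothetical sketch of an epoch-based training loop; not Smile's code.
// A 1-D linear model y = w*x + b, mean squared error (MSE), and full-batch
// gradient descent stand in for the network, loss, and optimizer.
class TrainLoopSketch {
    double w = 0.0;
    double b = 0.0;

    double predict(double x) {
        return w * x + b;
    }

    // Mean squared error of the current parameters on (x, y).
    double mse(double[] x, double[] y) {
        double sum = 0.0;
        for (int i = 0; i < x.length; i++) {
            double e = predict(x[i]) - y[i];
            sum += e * e;
        }
        return sum / x.length;
    }

    // Returns the validation MSE after the last epoch, or NaN when no
    // validation data is given (testX == null).
    double train(int epochs, double lr, double[] trainX, double[] trainY,
                 double[] testX, double[] testY) {
        double valLoss = Double.NaN;
        int n = trainX.length;
        for (int epoch = 0; epoch < epochs; epoch++) {
            // Forward pass plus gradient of the MSE with respect to w and b.
            double gw = 0.0;
            double gb = 0.0;
            for (int i = 0; i < n; i++) {
                double err = predict(trainX[i]) - trainY[i];
                gw += 2.0 * err * trainX[i] / n;
                gb += 2.0 * err / n;
            }
            // Optimizer step: plain gradient descent.
            w -= lr * gw;
            b -= lr * gb;
            // Optional validation pass; a checkpoint could also be written here.
            if (testX != null) {
                valLoss = mse(testX, testY);
            }
        }
        return valLoss;
    }
}
```

For example, training on points sampled from y = 2x + 1 with x in [0, 1] (learning rate 0.5, 2000 epochs) drives w toward 2 and b toward 1.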
eval
Evaluates the model accuracy on a test dataset.
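Evaluation runs the model on a test dataset and summarizes its predictions with metrics. For classification accuracy, the predicted class is usually the argmax of each row of scores. The hypothetical `EvalSketch` below computes that from precomputed scores; returning a metric-name-to-value map is an assumption, not a documented return type.

```java
import java.util.LinkedHashMap;
import java.util.Map;

// Hypothetical evaluation helper: computes classification accuracy from
// per-class scores by taking the argmax of each row. Not Smile's API.
final class EvalSketch {
    private EvalSketch() {}

    // Index of the largest score (first one wins on ties).
    static int argmax(double[] scores) {
        int best = 0;
        for (int k = 1; k < scores.length; k++) {
            if (scores[k] > scores[best]) {
                best = k;
            }
        }
        return best;
    }

    // Fraction of rows whose predicted class equals the true label.
    static Map<String, Double> evaluate(double[][] scores, int[] labels) {
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            if (argmax(scores[i]) == labels[i]) {
                correct++;
            }
        }
        Map<String, Double> result = new LinkedHashMap<>();
        result.put("accuracy", (double) correct / labels.length);
        return result;
    }
}
```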