public class CRF
extends java.lang.Object
implements java.io.Serializable
A conditional random field (CRF) is a Markov random field that is trained discriminatively. Because it models only the conditional distribution of the labels, it does not need to model the distribution over the variables that are always observed, which makes it possible to include arbitrarily complicated features of the observed variables in the model.
This class trains CRFs via gradient tree boosting. In tree boosting, the CRF potential functions are represented as weighted sums of regression trees, which provide compact representations of feature interactions, so the algorithm never has to consider the potentially large parameter space explicitly. As a result, gradient tree boosting scales linearly in the order of the Markov model and in the order of the feature interactions, rather than exponentially as in previous algorithms based on iterative scaling and gradient descent.
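The idea of a potential function stored as a weighted sum of regression trees can be sketched without the library. In the sketch below, `BoostedPotentialSketch`, `potential`, and the use of `ToDoubleFunction<double[]>` as a stand-in for a regression tree are all assumptions made for illustration; they are not part of this class, and using one shared shrinkage factor as the weight is only one plausible form of the sum.

```java
import java.util.List;
import java.util.function.ToDoubleFunction;

// Hypothetical sketch: a "tree" is any function from a feature vector to a real value.
class BoostedPotentialSketch {
    /** Evaluates F(x) = sum over trees of shrinkage * tree(x). */
    static double potential(List<ToDoubleFunction<double[]>> trees,
                            double shrinkage, double[] x) {
        double sum = 0.0;
        for (ToDoubleFunction<double[]> tree : trees) {
            sum += shrinkage * tree.applyAsDouble(x);
        }
        return sum;
    }

    public static void main(String[] args) {
        // Two depth-1 "trees" (stumps), each splitting on one feature.
        ToDoubleFunction<double[]> t1 = x -> x[0] > 0.5 ? 1.0 : -1.0;
        ToDoubleFunction<double[]> t2 = x -> x[1] > 0.0 ? 0.5 : -0.5;
        double f = potential(List.of(t1, t2), 0.5, new double[] {0.9, -1.0});
        System.out.println(f); // 0.5 * 1.0 + 0.5 * (-0.5) = 0.25
    }
}
```

Each tree only needs to encode the feature interactions along its own splits, which is why the sum stays compact even when the full parameter space is large.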
Constructor and Description |
---|
CRF(StructType schema, RegressionTree[][] potentials, double shrinkage): Constructor. |
Modifier and Type | Method and Description |
---|---|
static CRF | fit(Tuple[][] sequences, int[][] labels): Fits a CRF model. |
static CRF | fit(Tuple[][] sequences, int[][] labels, int ntrees, int maxDepth, int maxNodes, int nodeSize, double shrinkage): Fits a CRF model. |
static CRF | fit(Tuple[][] sequences, int[][] labels, java.util.Properties prop): Fits a CRF model. |
int[] | predict(Tuple[] x): Labels the feature sequence with the forward-backward algorithm, which chooses the most likely label at each position individually. |
int[] | viterbi(Tuple[] x): Labels the sequence with the Viterbi algorithm, which returns the jointly most likely label sequence. |
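The difference between forward-backward labeling and Viterbi labeling can be seen on a toy first-order chain. The sketch below is an illustration only, not this class's implementation: `ChainDecodingSketch`, its methods, and the fixed `node` and `trans` log-potential arrays are hypothetical, whereas the real model derives its potentials from the feature tuples. The forward-backward pass is left unscaled, which is adequate only for short sequences.

```java
import java.util.Arrays;

// Hypothetical sketch: node[t][j] is the log-potential of label j at position t,
// trans[i][j] is the log-potential of moving from label i to label j.
class ChainDecodingSketch {
    /** Viterbi: the single label sequence with the highest total score. */
    static int[] viterbi(double[][] node, double[][] trans) {
        int n = node.length, k = node[0].length;
        double[][] delta = new double[n][k];
        int[][] psi = new int[n][k]; // back-pointers
        delta[0] = node[0].clone();
        for (int t = 1; t < n; t++) {
            for (int j = 0; j < k; j++) {
                int best = 0;
                for (int i = 1; i < k; i++) {
                    if (delta[t - 1][i] + trans[i][j] > delta[t - 1][best] + trans[best][j]) best = i;
                }
                psi[t][j] = best;
                delta[t][j] = delta[t - 1][best] + trans[best][j] + node[t][j];
            }
        }
        int[] path = new int[n];
        for (int j = 1; j < k; j++) {
            if (delta[n - 1][j] > delta[n - 1][path[n - 1]]) path[n - 1] = j;
        }
        for (int t = n - 1; t > 0; t--) path[t - 1] = psi[t][path[t]];
        return path;
    }

    /** Forward-backward: the label with the highest marginal at each position. */
    static int[] posterior(double[][] node, double[][] trans) {
        int n = node.length, k = node[0].length;
        double[][] alpha = new double[n][k];
        double[][] beta = new double[n][k];
        for (int j = 0; j < k; j++) {
            alpha[0][j] = Math.exp(node[0][j]);
            beta[n - 1][j] = 1.0;
        }
        for (int t = 1; t < n; t++) {
            for (int j = 0; j < k; j++) {
                double s = 0.0;
                for (int i = 0; i < k; i++) s += alpha[t - 1][i] * Math.exp(trans[i][j]);
                alpha[t][j] = s * Math.exp(node[t][j]);
            }
        }
        for (int t = n - 2; t >= 0; t--) {
            for (int i = 0; i < k; i++) {
                double s = 0.0;
                for (int j = 0; j < k; j++) s += Math.exp(trans[i][j] + node[t + 1][j]) * beta[t + 1][j];
                beta[t][i] = s;
            }
        }
        int[] labels = new int[n];
        for (int t = 0; t < n; t++) {
            for (int j = 1; j < k; j++) {
                if (alpha[t][j] * beta[t][j] > alpha[t][labels[t]] * beta[t][labels[t]]) labels[t] = j;
            }
        }
        return labels;
    }

    public static void main(String[] args) {
        // Unnormalized pair weights: (0,0) is the best pair, but label 1 has more total mass at position 0.
        double[][] w = {{4, 0.1, 0.1}, {0.1, 3, 3}, {0.1, 0.1, 0.1}};
        double[][] trans = new double[3][3];
        for (int i = 0; i < 3; i++) for (int j = 0; j < 3; j++) trans[i][j] = Math.log(w[i][j]);
        double[][] node = new double[2][3]; // all-zero node potentials
        System.out.println(Arrays.toString(viterbi(node, trans)));   // [0, 0]
        System.out.println(Arrays.toString(posterior(node, trans))); // [1, 0]
    }
}
```

In this toy case the per-position result `[1, 0]` combines two individually likely labels into a pair that has low joint weight, while Viterbi returns the coherent pair `[0, 0]`.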
public CRF(StructType schema, RegressionTree[][] potentials, double shrinkage)
schema - the schema of features.
potentials - the potential functions.
shrinkage - the learning rate.

public int[] viterbi(Tuple[] x)
public int[] predict(Tuple[] x)
x - a sequence.

public static CRF fit(Tuple[][] sequences, int[][] labels)
sequences - the training data.
labels - the training sequence labels.

public static CRF fit(Tuple[][] sequences, int[][] labels, java.util.Properties prop)
sequences - the training data.
labels - the training sequence labels.

public static CRF fit(Tuple[][] sequences, int[][] labels, int ntrees, int maxDepth, int maxNodes, int nodeSize, double shrinkage)
sequences - the training data.
labels - the training sequence labels.
ntrees - the number of trees/iterations.
maxDepth - the maximum depth of the tree.
maxNodes - the maximum number of leaf nodes in the tree.
nodeSize - the number of instances in a node below which the tree will not split; setting nodeSize = 5 generally gives good results.
shrinkage - the shrinkage parameter in (0, 1], which controls the learning rate of the procedure.
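
How ntrees and shrinkage interact can be illustrated with a deliberately simplified boosting loop. This is a sketch under strong assumptions: `ShrinkageSketch` and `boost` are hypothetical names, the loss is squared error rather than the CRF log-likelihood, and each "tree" is reduced to a single constant (the mean residual).

```java
// Hypothetical sketch of boosting with shrinkage; not this class's training code.
class ShrinkageSketch {
    /** Runs ntrees boosting rounds and returns the final constant prediction. */
    static double boost(double[] y, int ntrees, double shrinkage) {
        double f = 0.0; // current prediction, shared by all instances
        for (int m = 0; m < ntrees; m++) {
            // "Fit a tree" to the residuals: here just their mean.
            double residualMean = 0.0;
            for (double yi : y) residualMean += (yi - f) / y.length;
            // Take a damped step toward the fit.
            f += shrinkage * residualMean;
        }
        return f;
    }

    public static void main(String[] args) {
        double[] y = {6.0, 10.0}; // target mean is 8
        System.out.println(boost(y, 1, 1.0)); // 8.0: one full step reaches the mean
        System.out.println(boost(y, 3, 0.5)); // 7.0: half steps give 4, 6, 7
    }
}
```

A smaller shrinkage value takes more cautious steps, so it generally needs a larger ntrees to reach the same fit.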