This example compares the performance of pystruct and SVM^struct on a multi-class problem. For the example to work, you need to install SVM^multiclass and set the path in this file. We are not using SVM^python, as that would be much slower, and we would need to implement our own model in an SVM^python-compatible way. Instead, we just call the SVM^multiclass binary.
This comparison is only meaningful in the sense that both libraries use general structured prediction solvers to solve the task. The specialized implementation of the Crammer-Singer SVM in LibLinear is much faster than either one.
For SVM^struct, the plot shows CPU time as reported by SVM^struct. For pystruct, the plot shows the time spent in the fit function according to time.clock.
Both models have disabled constraint caching. With constraint caching, SVM^struct is somewhat faster, but PyStruct doesn’t gain anything.
import tempfile
import os
from time import clock
import numpy as np
from sklearn.datasets import dump_svmlight_file
from sklearn.datasets import fetch_mldata, load_iris, load_digits
from sklearn.metrics import accuracy_score
from sklearn.cross_validation import train_test_split
import matplotlib.pyplot as plt
from pystruct.models import MultiClassClf
from pystruct.learners import OneSlackSSVM
# Please set the path to the directory holding the svm-struct multiclass
# binaries here. The value is concatenated directly with the binary names
# ("svm_multiclass_learn", "svm_multiclass_classify"), so it must end with a
# path separator.
svmstruct_path = "/home/user/amueller/tools/svm_multiclass/"
class MultiSVM(object):
    """scikit-learn compatible interface for SVM^multi.

    Dumps the data to a file and calls the binary.

    Parameters
    ----------
    C : float, default 1.
        Regularization parameter. It is multiplied by ``100 * len(X)`` before
        being handed to the binary (presumably to match the scaling of C used
        by SVM^struct -- confirm against the SVM^multiclass documentation).

    Attributes
    ----------
    model_file : str
        Path of the temporary file holding the model written by the last
        call to ``fit``.
    output_ : list of str
        Lines printed by ``svm_multiclass_learn`` during the last ``fit``.
    runtime_ : float
        Number parsed from the fourth-from-last line of ``output_``; used by
        this example as the training time reported by the binary.
    """

    def __init__(self, C=1.):
        self.C = C

    @staticmethod
    def _new_temp_path(suffix):
        """Create an empty temporary file and return its path.

        ``tempfile.mkstemp`` reserves the name atomically (unlike the
        deprecated ``tempfile.mktemp``). The descriptor is closed right away
        because only the path is passed on to other writers.
        """
        handle, path = tempfile.mkstemp(suffix=suffix)
        os.close(handle)
        return path

    @staticmethod
    def _remove_if_present(path):
        """Delete ``path``; cleanup is best-effort, so failures are ignored."""
        try:
            os.remove(path)
        except OSError:
            pass

    def fit(self, X, y):
        """Train by dumping the data and running ``svm_multiclass_learn``.

        Labels are shifted by one before dumping (``y + 1``) and features are
        written one-based. Returns ``self``.
        """
        # A refit replaces the model, so drop the file of the previous one.
        stale_model = getattr(self, "model_file", None)
        if stale_model is not None:
            self._remove_if_present(stale_model)

        self.model_file = self._new_temp_path('.svm')
        train_data_file = self._new_temp_path('.svm_dat')
        try:
            dump_svmlight_file(X, y + 1, train_data_file, zero_based=False)
            C = self.C * 100. * len(X)
            command = svmstruct_path + (
                "svm_multiclass_learn -w 3 -c %f %s %s"
                % (C, train_data_file, self.model_file))
            svmstruct_process = os.popen(command)
            try:
                self.output_ = svmstruct_process.read().split("\n")
            finally:
                # Close the pipe so the child process is reaped.
                svmstruct_process.close()
        finally:
            self._remove_if_present(train_data_file)
        # The runtime is read from a fixed position near the end of the
        # output, after the first ":" on that line.
        self.runtime_ = float(self.output_[-4].split(":")[1])
        return self

    def _predict(self, X, y=None):
        """Run ``svm_multiclass_classify`` on ``X`` and load its output.

        When ``y`` is None, a vector of ones is written as placeholder labels.
        The result is always 2-D (one row per sample), even for one sample.
        """
        if y is None:
            y = np.ones(len(X))
        test_data_file = self._new_temp_path('.svm_dat')
        prediction_file = self._new_temp_path('.out')
        try:
            dump_svmlight_file(X, y, test_data_file, zero_based=False)
            os.system(svmstruct_path + "svm_multiclass_classify %s %s %s"
                      % (test_data_file, self.model_file, prediction_file))
            # ndmin=2 keeps a single prediction row two-dimensional so the
            # column slicing in predict / decision_function still works.
            return np.loadtxt(prediction_file, ndmin=2)
        finally:
            self._remove_if_present(test_data_file)
            self._remove_if_present(prediction_file)

    def predict(self, X):
        """Return the first output column shifted back by one (undoing fit)."""
        return self._predict(X)[:, 0] - 1

    def score(self, X, y):
        """Return the accuracy of ``predict(X)`` with respect to ``y``."""
        y_pred = self.predict(X)
        return accuracy_score(y, y_pred)

    def decision_function(self, X):
        """Return all output columns after the first one."""
        return self._predict(X)[:, 1:]
def eval_on_data(X_train, y_train, X_test, y_test, svm, Cs):
    """Evaluate a single svm for each value of C in ``Cs``.

    For every C the learner's ``C`` attribute is set, the learner is fit on
    the training data and scored on the test data.

    Returns
    -------
    (accuracies, times) : tuple of lists
        Test accuracy and training time per entry of ``Cs``, in order.
    """
    test_accuracies = []
    fit_times = []
    for reg_strength in Cs:
        svm.C = reg_strength
        t_begin = clock()
        svm.fit(X_train, y_train)
        # Use the runtime the learner reports itself when it has one;
        # otherwise fall back to the time measured around fit.
        fit_times.append(svm.runtime_ if hasattr(svm, "runtime_")
                         else clock() - t_begin)
        predictions = svm.predict(X_test)
        test_accuracies.append(accuracy_score(y_test, predictions))
    return test_accuracies, fit_times
def plot_curves(curve_svmstruct, curve_pystruct, Cs, title="", filename=""):
    """Plot one value per C for both implementations in a single figure.

    The x axis is labelled with the entries of ``Cs``. The figure is saved
    to ``filename`` only when a non-empty name is given.
    """
    plt.figure(figsize=(7, 4))
    # (values, line style, legend label, colour) for each implementation
    series = (
        (curve_svmstruct, "--", "SVM^struct", 'red'),
        (curve_pystruct, "-.", "PyStruct", 'blue'),
    )
    for values, line_style, legend_label, colour in series:
        plt.plot(values, line_style, label=legend_label, c=colour,
                 linewidth=3)
    tick_positions = np.arange(len(Cs))
    plt.xlabel("C")
    plt.xticks(tick_positions, Cs)
    plt.legend(loc='best')
    plt.title(title)
    if not filename:
        return
    plt.savefig("%s" % filename, bbox_inches='tight')
def do_comparison(X_train, y_train, X_test, y_test, dataset):
    """Evaluate both svms on one dataset and generate the comparison plots.

    Both learners are run over the same grid of C values. Two figures are
    produced and saved as ``times_<dataset>.pdf`` and ``accs_<dataset>.pdf``.
    """
    c_grid = 10. ** np.arange(-4, 1)
    split = (X_train, y_train, X_test, y_test)

    svmstruct_learner = MultiSVM()
    pystruct_learner = OneSlackSSVM(MultiClassClf(), tol=0.01)

    # PyStruct is evaluated first, then SVM^struct.
    accs_pystruct, times_pystruct = eval_on_data(
        *split, svm=pystruct_learner, Cs=c_grid)
    accs_svmstruct, times_svmstruct = eval_on_data(
        *split, svm=svmstruct_learner, Cs=c_grid)

    time_title = "learning time (s) %s" % dataset
    time_file = "times_%s.pdf" % dataset
    plot_curves(times_svmstruct, times_pystruct, Cs=c_grid,
                title=time_title, filename=time_file)

    acc_title = "accuracy %s" % dataset
    acc_file = "accs_%s.pdf" % dataset
    plot_curves(accs_svmstruct, accs_pystruct, Cs=c_grid,
                title=acc_title, filename=acc_file)
def main():
    """Run the comparison on every selected dataset and show the plots.

    Prints a hint and returns immediately when the ``svm_multiclass_learn``
    binary is not found under ``svmstruct_path``.
    """
    if not os.path.exists(svmstruct_path + "svm_multiclass_learn"):
        print("Please install SVM^multi and set the svmstruct_path variable "
              "to run this example.")
        return

    # Add 'usps' and / or 'mnist' to this list to also run those datasets.
    datasets = ['iris', 'digits']

    # IRIS
    if 'iris' in datasets:
        iris = load_iris()
        X, y = iris.data, iris.target
        X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                            random_state=0)
        do_comparison(X_train, y_train, X_test, y_test, "iris")

    # DIGITS
    if 'digits' in datasets:
        digits = load_digits()
        X, y = digits.data / 16., digits.target
        X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                            random_state=0)
        do_comparison(X_train, y_train, X_test, y_test, "digits")

    # USPS
    if 'usps' in datasets:
        digits = fetch_mldata("USPS")
        # Builtin int replaces the removed np.int alias (same dtype).
        X, y = digits.data, digits.target.astype(int) - 1
        X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                            random_state=0)
        do_comparison(X_train, y_train, X_test, y_test, "USPS")

    # MNIST
    if 'mnist' in datasets:
        digits = fetch_mldata("MNIST original")
        X, y = digits.data / 255., digits.target.astype(int)
        # The first 60000 samples are used for training, the rest for testing.
        X_train, X_test = X[:60000], X[60000:]
        y_train, y_test = y[:60000], y[60000:]
        do_comparison(X_train, y_train, X_test, y_test, "MNIST")

    plt.show()


if __name__ == "__main__":
    main()
Total running time of the script: (0 minutes 0.000 seconds)
Download Python source code: multiclass_comparision_svm_struct.py