Comparing PyStruct and SVM-Struct

This example compares the performance of pystruct and SVM^struct on a multi-class problem. For the example to work, you need to install SVM^multiclass and set the path in this file. We are not using SVM^python, as that would be much slower, and we would need to implement our own model in a SVM^python compatible way. Instead, we just call the SVM^multiclass binary.

This comparison is only meaningful in the sense that both libraries use general structured prediction solvers to solve the task. The specialized implementation of the Crammer-Singer SVM in LibLinear is much faster than either one.

For SVM^struct, the plot shows CPU time as reported by SVM^struct. For PyStruct, the plot shows the time spent in the fit function according to time.clock.

Both models have disabled constraint caching. With constraint caching, SVM^struct is somewhat faster, but PyStruct doesn’t gain anything.

import tempfile
import os
import subprocess
from time import clock

import numpy as np
from sklearn.datasets import dump_svmlight_file
from sklearn.datasets import fetch_mldata, load_iris, load_digits
from sklearn.metrics import accuracy_score
from sklearn.cross_validation import train_test_split
import matplotlib.pyplot as plt

from pystruct.models import MultiClassClf
from pystruct.learners import OneSlackSSVM

# Directory containing the SVM^multiclass binaries (svm_multiclass_learn and
# svm_multiclass_classify). Either edit the default below or set the
# SVMSTRUCT_PATH environment variable to override it without editing the file.
# os.path.join(..., "") guarantees a trailing separator, because the binary
# paths are built from this value by plain string concatenation.
svmstruct_path = os.path.join(
    os.environ.get("SVMSTRUCT_PATH", "/home/user/amueller/tools/svm_multiclass/"),
    "")


def _temp_path(suffix):
    """Create an empty temporary file and return its path.

    ``tempfile.mkstemp`` creates the file atomically, unlike the deprecated
    and race-prone ``tempfile.mktemp``. The descriptor is closed right away
    because the file is written later through its path (by
    ``dump_svmlight_file`` or by the external binaries).
    """
    fd, path = tempfile.mkstemp(suffix=suffix)
    os.close(fd)
    return path


def _remove_quietly(path):
    """Best-effort delete of ``path``; a missing file is not an error."""
    try:
        os.remove(path)
    except OSError:
        pass


class MultiSVM():
    """scikit-learn compatible interface for SVM^multiclass.

    Dumps the data to a temporary svmlight file and calls the
    ``svm_multiclass_learn`` / ``svm_multiclass_classify`` binaries found in
    ``svmstruct_path``.

    Parameters
    ----------
    C : float, default=1.
        Regularization strength. It is multiplied by ``100 * n_samples``
        before being passed to ``svm_multiclass_learn -c``.

    Attributes
    ----------
    model_file : str
        Path of the model file written by the most recent ``fit``.
    output_ : list of str
        Lines printed by ``svm_multiclass_learn`` during the most recent fit.
    runtime_ : float
        Training time parsed from ``output_``.
    """
    def __init__(self, C=1.):
        self.C = C

    def fit(self, X, y):
        """Train on ``X`` with 0-based integer labels ``y`` and return self.

        Raises ``subprocess.CalledProcessError`` if the learner exits with a
        non-zero status.
        """
        # A refit replaces the model, so delete the previous model file.
        # Otherwise a sweep over C leaks one temp file per fit.
        previous_model = getattr(self, "model_file", None)
        if previous_model is not None:
            _remove_quietly(previous_model)
        self.model_file = _temp_path('.svm')
        train_data_file = _temp_path('.svm_dat')
        try:
            # The SVM^multiclass data format uses labels starting at 1.
            dump_svmlight_file(X, np.asarray(y) + 1, train_data_file,
                               zero_based=False)
            # np.shape works for arrays, lists and scipy sparse matrices.
            C = self.C * 100. * np.shape(X)[0]
            # Pass an argument list rather than a shell string. Paths with
            # spaces then work, nothing is shell-interpreted, and failures
            # raise instead of being silently ignored.
            output = subprocess.check_output(
                [svmstruct_path + "svm_multiclass_learn", "-w", "3",
                 "-c", "%f" % C, train_data_file, self.model_file],
                universal_newlines=True)
        finally:
            _remove_quietly(train_data_file)
        self.output_ = output.split("\n")
        # NOTE(review): this parse depends on SVM^multiclass printing its
        # runtime on the 4th-from-last output line as "<label>: <seconds>".
        self.runtime_ = float(self.output_[-4].split(":")[1])
        return self

    def _predict(self, X, y=None):
        """Run ``svm_multiclass_classify`` on ``X`` and return its output table.

        The result is always 2-D with one row per sample. Column 0 holds the
        predicted label (1-based, see ``predict``). The remaining columns are
        what ``decision_function`` returns.
        """
        if y is None:
            # Labels are only needed to write a valid svmlight file. This
            # wrapper never reads them back.
            y = np.ones(np.shape(X)[0])
        test_data_file = _temp_path('.svm_dat')
        prediction_file = _temp_path('.out')
        try:
            dump_svmlight_file(X, y, test_data_file, zero_based=False)
            subprocess.check_call(
                [svmstruct_path + "svm_multiclass_classify", test_data_file,
                 self.model_file, prediction_file])
            # ndmin=2 keeps the table 2-D even for a single sample, so the
            # column slicing in predict/decision_function always works.
            return np.loadtxt(prediction_file, ndmin=2)
        finally:
            _remove_quietly(test_data_file)
            _remove_quietly(prediction_file)

    def predict(self, X):
        """Return predicted labels, shifted back to start at 0."""
        return self._predict(X)[:, 0] - 1

    def score(self, X, y):
        """Return the accuracy of ``predict(X)`` against ``y``."""
        y_pred = self.predict(X)
        return accuracy_score(y, y_pred)

    def decision_function(self, X):
        """Return the per-sample values written after the predicted label."""
        return self._predict(X)[:, 1:]


def eval_on_data(X_train, y_train, X_test, y_test, svm, Cs):
    """Sweep the regularization parameter of ``svm`` over ``Cs``.

    For each value, the estimator's ``C`` attribute is overwritten. The model
    is then refit on the training data and scored on the test data.

    Returns
    -------
    accuracies : list of float
        Test accuracy for each entry of ``Cs``.
    times : list of float
        Learning time for each entry of ``Cs``. This is the estimator's own
        ``runtime_`` attribute if it has one; otherwise it is the ``clock()``
        difference measured around ``fit``.
    """
    accuracies = []
    times = []
    for c_value in Cs:
        svm.C = c_value
        tic = clock()
        svm.fit(X_train, y_train)
        # Prefer a runtime the estimator reports about itself. For MultiSVM
        # the training happens in a separate process.
        elapsed = svm.runtime_ if hasattr(svm, "runtime_") else clock() - tic
        times.append(elapsed)
        predictions = svm.predict(X_test)
        accuracies.append(accuracy_score(y_test, predictions))
    return accuracies, times


def plot_curves(curve_svmstruct, curve_pystruct, Cs, title="", filename=""):
    """Draw one figure comparing a per-C quantity for both implementations.

    Parameters
    ----------
    curve_svmstruct, curve_pystruct : sequence of float
        One value per entry of ``Cs`` (e.g. accuracy or learning time).
    Cs : sequence of float
        Regularization values. They label the x ticks; the curves themselves
        are plotted against their index.
    title : str
        Figure title.
    filename : str
        If non-empty, the figure is also saved to this path.
    """
    styles = [
        (curve_svmstruct, "--", "SVM^struct", 'red'),
        (curve_pystruct, "-.", "PyStruct", 'blue'),
    ]
    plt.figure(figsize=(7, 4))
    for values, linestyle, name, colour in styles:
        plt.plot(values, linestyle, label=name, c=colour, linewidth=3)
    plt.xlabel("C")
    plt.xticks(np.arange(len(Cs)), Cs)
    plt.legend(loc='best')
    plt.title(title)
    if not filename:
        return
    plt.savefig("%s" % filename, bbox_inches='tight')


def do_comparison(X_train, y_train, X_test, y_test, dataset, Cs=None):
    """Evaluate SVM^multiclass and PyStruct on one dataset and plot the results.

    Both learners are trained once per value in ``Cs``. Learning time and
    test accuracy are plotted and saved as ``times_<dataset>.pdf`` and
    ``accs_<dataset>.pdf``.

    Parameters
    ----------
    X_train, y_train, X_test, y_test : array-like
        Train/test split with integer class labels starting at 0
        (``MultiSVM`` shifts them by one for SVM^multiclass).
    dataset : str
        Name used in plot titles and output file names.
    Cs : array-like of float, optional
        Regularization values to sweep. Defaults to ``10. ** np.arange(-4, 1)``,
        i.e. 1e-4, 1e-3, ..., 1.
    """
    if Cs is None:
        Cs = 10. ** np.arange(-4, 1)
    multisvm = MultiSVM()
    svm = OneSlackSSVM(MultiClassClf(), tol=0.01)

    accs_pystruct, times_pystruct = eval_on_data(X_train, y_train, X_test,
                                                 y_test, svm, Cs=Cs)
    accs_svmstruct, times_svmstruct = eval_on_data(X_train, y_train,
                                                   X_test, y_test,
                                                   multisvm, Cs=Cs)

    plot_curves(times_svmstruct, times_pystruct, Cs=Cs,
                title="learning time (s) %s" % dataset,
                filename="times_%s.pdf" % dataset)
    plot_curves(accs_svmstruct, accs_pystruct, Cs=Cs,
                title="accuracy %s" % dataset,
                filename="accs_%s.pdf" % dataset)


def main(datasets=('iris', 'digits')):
    """Run the comparison on the selected datasets and show the plots.

    Parameters
    ----------
    datasets : sequence of str, default=('iris', 'digits')
        Any subset of 'iris', 'digits', 'usps', 'mnist'. 'usps' and 'mnist'
        are fetched with ``fetch_mldata`` and are considerably larger, so
        they are not enabled by default.
    """
    if not os.path.exists(svmstruct_path + "svm_multiclass_learn"):
        print("Please install SVM^multi and set the svmstruct_path variable "
              "to run this example.")
        return

    # IRIS
    if 'iris' in datasets:
        iris = load_iris()
        X, y = iris.data, iris.target
        X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                            random_state=0)
        do_comparison(X_train, y_train, X_test, y_test, "iris")

    # DIGITS
    if 'digits' in datasets:
        digits = load_digits()
        X, y = digits.data / 16., digits.target
        X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                            random_state=0)
        do_comparison(X_train, y_train, X_test, y_test, "digits")

    # USPS
    if 'usps' in datasets:
        digits = fetch_mldata("USPS")
        # Builtin int instead of np.int: np.int was only an alias for int and
        # was removed in NumPy 1.24. The -1 shifts the targets so labels start
        # at 0 (presumably stored 1-based).
        X, y = digits.data, digits.target.astype(int) - 1
        X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                            random_state=0)
        do_comparison(X_train, y_train, X_test, y_test, "USPS")

    # MNIST
    if 'mnist' in datasets:
        digits = fetch_mldata("MNIST original")
        X, y = digits.data / 255., digits.target.astype(int)
        # The first 60000 samples are used for training, the rest for testing.
        X_train, X_test = X[:60000], X[60000:]
        y_train, y_test = y[:60000], y[60000:]
        do_comparison(X_train, y_train, X_test, y_test, "MNIST")

    plt.show()


# Run the comparison only when executed as a script, not when imported.
if __name__ == "__main__":
    main()

Total running time of the script: (0 minutes 0.000 seconds)

Download Python source code: multiclass_comparision_svm_struct.py