Environment
$ sw_vers
ProductName: Mac OS X
ProductVersion: 10.12.5
BuildVersion: 16F73
$ pyenv versions
system
* 3.6.1 (set by /Users/dela/.pyenv/version)
$ pip freeze | grep chainer
chainer==2.0.0
Setup
pip install scikit-learn
pip install chainer
Code
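from sklearn.datasets import fetch_mldata
# sklearn.cross_validation is deprecated since scikit-learn 0.18 in favor of
# sklearn.model_selection, which is why the run emits a DeprecationWarning.
from sklearn.cross_validation import train_test_split
from sklearn.svm import LinearSVC as Classifier
from sklearn.metrics import confusion_matrix
import numpy as np

# Download MNIST (cached under the current directory) and convert to float32.
mnist = fetch_mldata("MNIST original", data_home=".")
data = np.asarray(mnist.data, np.float32)

# Hold out 20% of the samples for testing.
data_train, data_test, label_train, label_test = train_test_split(data, mnist.target, test_size=0.2)

# Train a linear SVM and predict labels for the held-out samples.
classifier = Classifier()
classifier.fit(data_train, label_train)
result = classifier.predict(data_test)

# Rows are true digits, columns are predicted digits.
cmat = confusion_matrix(label_test, result)
print(cmat)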
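For newer scikit-learn releases, where sklearn.cross_validation no longer exists, the same flow can be sketched with sklearn.model_selection. This is a minimal sketch under stated assumptions, not the script that was actually run: it swaps in load_digits (the small 8x8 digit set bundled with scikit-learn) so that it runs offline, and it fixes random_state, which the original does not. For the real MNIST data, fetch_openml("mnist_784") would be the likely replacement for fetch_mldata, but that is not used here.

```python
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.svm import LinearSVC as Classifier
from sklearn.metrics import confusion_matrix
import numpy as np

# Assumption: load_digits (8x8 images bundled with scikit-learn) stands in
# for the MNIST download so that this sketch runs without network access.
digits = load_digits()
data = np.asarray(digits.data, np.float32)

# Hold out 20% for testing. random_state is fixed only to make the split
# reproducible; the original script does not set it.
data_train, data_test, label_train, label_test = train_test_split(
    data, digits.target, test_size=0.2, random_state=0)

# Train a linear SVM and predict labels for the held-out samples.
classifier = Classifier()
classifier.fit(data_train, label_train)
result = classifier.predict(data_test)

# Rows are true digits, columns are predicted digits.
cmat = confusion_matrix(label_test, result)
print(cmat)
```

The log in the next section comes from the original script above, not from this sketch.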
Execution log
/Users/dela/.pyenv/versions/3.6.1/lib/python3.6/site-packages/sklearn/cross_validation.py:44: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.
"This module will be removed in 0.20.", DeprecationWarning)
[[1306    0    7   20    2    4    2    1   15    1]
 [   0 1511    6   22    2    0    0    5   62    2]
 [  13   24 1068  112   21    4    5   33  149    3]
 [   4    2   21 1360    2    5    2   16   48    9]
 [   5    2    8   14 1228    0    7   13   64   28]
 [  27    2    5  302   21  672    5   15  209   22]
 [  17    5   31   45    8   19 1101    0   68    0]
 [   4    2    4   36   21    0    0 1311   22   21]
 [   9   14    3  102    9    5    2   10 1192   12]
 [  12    2    7   63   62    3    0  158  102 1000]]
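
The diagonal of the confusion matrix counts the correct predictions, so overall accuracy and per-digit recall can be read off it directly. A minimal sketch, with the matrix above copied in by hand; the variable names are illustrative and not part of the original script.

```python
import numpy as np

# Confusion matrix copied from the log above
# (rows: true digit, columns: predicted digit).
cmat = np.array([
    [1306,    0,    7,   20,    2,    4,    2,    1,   15,    1],
    [   0, 1511,    6,   22,    2,    0,    0,    5,   62,    2],
    [  13,   24, 1068,  112,   21,    4,    5,   33,  149,    3],
    [   4,    2,   21, 1360,    2,    5,    2,   16,   48,    9],
    [   5,    2,    8,   14, 1228,    0,    7,   13,   64,   28],
    [  27,    2,    5,  302,   21,  672,    5,   15,  209,   22],
    [  17,    5,   31,   45,    8,   19, 1101,    0,   68,    0],
    [   4,    2,    4,   36,   21,    0,    0, 1311,   22,   21],
    [   9,   14,    3,  102,    9,    5,    2,   10, 1192,   12],
    [  12,    2,    7,   63,   62,    3,    0,  158,  102, 1000],
])

# Overall accuracy: correct predictions (the diagonal) over all test samples.
total = cmat.sum()
correct = np.trace(cmat)
accuracy = correct / total

# Per-digit recall: correct predictions over the number of true samples of that digit.
recall = np.diag(cmat) / cmat.sum(axis=1)
worst_digit = int(np.argmin(recall))

print(f"accuracy: {accuracy:.4f}")
for digit, r in enumerate(recall):
    print(f"digit {digit}: recall {r:.3f}")
print(f"lowest recall: digit {worst_digit}")
```

With these numbers, 11749 of the 14000 test samples are classified correctly (about 83.9%). Digit 5 has the lowest recall: many true 5s are predicted as 3 (302 samples) or 8 (209 samples).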