This repository has been archived by the owner on May 13, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathassignment-01.r
94 lines (59 loc) · 1.69 KB
/
assignment-01.r
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
###############
# Init system #
###############
# Attach the packages this script relies on: e1071 provides svm(),
# keras provides dataset_mnist() and array_reshape().
for (pkg in c("e1071", "keras")) {
  library(pkg, character.only = TRUE)
}
# Suppress all warnings for the rest of the session (a negative
# 'warn' option tells R to ignore warnings).
options(warn = -1)
# Start from a clean state: close every open graphics device, then empty
# the global workspace (this also removes the loop variable above).
graphics.off()
rm(list = ls())
# Experimental setup: number of samples taken from each split.
# These must be assigned after the workspace is cleared.
num_train <- 5000 # 60000 for full data set
num_test <- 500   # 10000 for full data set
#########################
# Load and prepare data #
#########################
# Load MNIST data through keras
mnist = dataset_mnist()
# Assign train and test data+labels
train_data = mnist$train$x
train_label = mnist$train$y
test_data = mnist$test$x
test_label = mnist$test$y
# Reshape images to vectors: one row of 784 pixel values per sample
train_data = array_reshape(train_data, c(nrow(train_data), 784))
test_data = array_reshape(test_data, c(nrow(test_data), 784))
# Indices of the samples to keep. The requested size is clamped to the
# number of available rows, so an oversized request uses the full split
# instead of failing with "subscript out of bounds". seq_len() avoids the
# 1:0 == c(1, 0) pitfall when a size of 0 is requested.
train_idx = seq_len(min(num_train, nrow(train_data)))
test_idx = seq_len(min(num_test, nrow(test_data)))
# Select the subsets for training and testing *before* rescaling, so only
# the rows actually used are converted to double precision.
# drop = FALSE keeps a matrix even when a single row is selected.
train_data = train_data[train_idx, , drop = FALSE]
train_label = train_label[train_idx]
test_data = test_data[test_idx, , drop = FALSE]
test_label = test_label[test_idx]
rm(train_idx, test_idx)
# Rescale pixel values from [0,255] to the range [0,1]
train_data = train_data / 255
test_data = test_data / 255
################################
# Train and run classification #
################################
# Percentage of predictions that equal the true labels.
# Both sides are compared as character strings instead of as factors:
# '==' on two factors stops with "level sets of factors are different"
# when the evaluated subset lacks some class seen during training, which
# easily happens for small test subsets.
accuracy_pct = function(pred, truth) {
  mean(as.character(pred) == as.character(truth)) * 100
}
# Init timer (the measured span covers training and both evaluations)
t1 = proc.time()
# Train SVM; the labels are passed as a factor so svm() does classification
S = svm(train_data, factor(train_label))
cat("\n\nCorrect classification results:")
# Eval SVM on training data
pr_tr = predict(S, train_data)
success = accuracy_pct(pr_tr, train_label)
res_s = sprintf('\n Train: %5.2f\n\n', success)
cat(res_s)
# Eval SVM on test data
pr_te = predict(S, test_data)
success = accuracy_pct(pr_te, test_label)
res_s = sprintf('\n\n Test: %5.2f\n\n', success)
cat(res_s)
# End time, calculate elapsed time (named 'elapsed' rather than 't' to
# avoid shadowing base::t)
t2 = proc.time()
elapsed = t2 - t1
cat("Computation time:\n\n")
print(elapsed)