-
Notifications
You must be signed in to change notification settings - Fork 0
/
train2.py
110 lines (91 loc) · 3.04 KB
/
train2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# -*- coding: utf-8 -*-
"""
Train the network; very simple, just a few lines of code.
Author: 殷和义 (Yin Heyi)
Date: 2018-03-13
"""
import scipy.io
import random
import net2
import numpy as np
import matplotlib.pyplot as plt
# Load the data: read every variable stored in the MATLAB file 'clean.mat'.
data = scipy.io.loadmat('clean.mat')

# Training split: three feature arrays plus the matching labels.
fft, power, dps3, train_label = (
    data[key]
    for key in ('fft_tr180', 'power_tr180', 'dps3_tr180', 'train_label')
)

# Test split, laid out the same way as the training split.
fft_test, power_test, dps3_test, test_label = (
    data[key]
    for key in ('fft_tst180', 'power_tst180', 'dps3_tst180', 'test_label')
)
# Shuffle the training data. The same permutation must be applied to all
# three feature arrays and to the labels so that they stay aligned.
#
# The permutation is sized from the data rather than a fixed sample count,
# and is drawn from range() because random.sample() only accepts a real
# sequence on Python 3 (a NumPy array is rejected there).
# NOTE(review): assumes every array holds one sample per row along axis 0
# and that all four have the same number of rows -- confirm against
# clean.mat. If only a 200-sample subset was ever meant to be used, that
# restriction has to be re-applied explicitly here.
# NOTE: despite its name, random_seed is a list of shuffled row indices,
# not an RNG seed.
num_samples = fft.shape[0]
random_seed = random.sample(range(num_samples), num_samples)
fft = fft[random_seed]
power = power[random_seed]
dps3 = dps3[random_seed]
train_label = train_label[random_seed]
# Key parameters for the run.
# Number of training iterations, and how often (in iterations) to run a test pass.
num_train, test_interval = 2000, 20
# Both values are handed to net2.net() when the solver is created.
lr, weight_decay = 0.1, 0.001
# Batch sizes used in the training phase and in the test phase.
train_batch_size, test_batch_size = 50, 250
# Create the network and load the samples.
# The solver is built from the training batch size, lr and weight_decay.
solver = net2.net(train_batch_size, lr, weight_decay)
# Hand the three training feature arrays and their labels to the solver ...
solver.load_sample_and_label(fft, power, dps3, train_label)
# ... and likewise the three test feature arrays and their labels.
solver.load_sample_and_label_test(fft_test, power_test, dps3_test, test_label)
# Initialise the weights.
solver.initial()
# Initialise the arrays used to record results.
train_sequence = range(num_train)  # iteration indices 0 .. num_train - 1
# A test pass runs whenever i % test_interval == 0, i.e. at
# i = 0, test_interval, 2 * test_interval, ... which is
# ceil(num_train / test_interval) passes in total.
num_tests = (num_train - 1) // test_interval + 1
# Sized from num_tests so it always matches the accuracy arrays, including
# when num_train is not a multiple of test_interval.
test_sequence = range(num_tests)
train_error = np.zeros(num_train)  # one loss value per training iteration
#weight1 = np.zeros(num_train)
#weight2 = np.zeros(num_train)
#weight3 = np.zeros(num_train)
# One accuracy array per forward_test() call made during each test pass.
acc1 = np.zeros(num_tests)
acc2 = np.zeros(num_tests)
acc3 = np.zeros(num_tests)
acc4 = np.zeros(num_tests)
# Training loop: one forward / backward / update step per iteration, with a
# test pass every test_interval iterations.
for i in train_sequence:
    # A single pre-formatted argument prints the same text under both the
    # Python 2 print statement and the Python 3 print function.
    print('第 %d 次迭代' % i)
    # Publish the current iteration number on net2's update method.
    net2.layer.update_method.iteration = i
    solver.forward()
    solver.backward()
    solver.update()
    # Record some values.
    # weight1[i] = solver.fu.weights[0]
    # weight2[i] = solver.fu.weights[1]
    # weight3[i] = solver.fu.weights[2]
    train_error[i] = solver.loss.loss
    if i % test_interval == 0:
        # Switch the solver to the test batch size, run four test forward
        # passes and store the accuracy reported after each one, then
        # switch back to the training batch size.
        # NOTE(review): presumably each forward_test() call evaluates the
        # next test batch -- verify in net2.
        solver.turn_to_test(test_batch_size)
        for acc in (acc1, acc2, acc3, acc4):
            solver.forward_test()
            acc[i // test_interval] = solver.loss.accuracy
        solver.turn_to_train(train_batch_size)
#fig1 = plt.figure(1)
#plt.plot(train_sequence, weight1, 'b', train_sequence, weight2, 'g', train_sequence, weight3, 'r',
# train_sequence[::20], weight1[::20],' bo', train_sequence[::20], weight2[::20],' gs',
# train_sequence[::20], weight3[::20], ' r^',label = ' dddddddsf')
## train_sequence, train_error)
#plt.legend()
#fig1.savefig('3 weights.svg')
#
##fig2 = plt.figure(2)
##plt.plot(test_sequence, acc1, test_sequence, acc2, test_sequence, acc3, test_sequence, acc4)
##fig2.savefig('accuracy.svg')
#
#plt.show()
#print '权值系数:', weight1[-1], weight2[-1], weight3[-1]
# Report the four accuracies from the last test pass and the loss recorded
# at the final training iteration. Each line is formatted into one string so
# the output is identical under the Python 2 print statement and the
# Python 3 print function.
print('识别率: %s %s %s %s' % (acc1[-1], acc2[-1], acc3[-1], acc4[-1]))
print('误差: %s' % train_error[-1])