-
Notifications
You must be signed in to change notification settings - Fork 0
/
testFunction_for_students_MTb.m
75 lines (54 loc) · 2.33 KB
/
testFunction_for_students_MTb.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
% Test Script to give to the students, March 2015
%% Continuous Position Estimator Test Script
% This function first calls the function "positionEstimatorTraining" to get
% the relevant modelParameters, and then calls the function
% "positionEstimator" to decode the trajectory.
function RMSE = testFunction_for_students_MTb(teamName)
% testFunction_for_students_MTb  Train and score a team's position estimator.
%
%   RMSE = testFunction_for_students_MTb(teamName) temporarily adds the
%   folder teamName to the MATLAB path, trains a model by calling
%   positionEstimatorTraining on 50 randomly selected rows of the variable
%   "trial" stored in monkeydata0.mat, and then calls positionEstimator on
%   the remaining rows. For every test trial the estimator is queried at
%   sample 320 and every 20 samples thereafter, each time receiving only
%   the spikes up to that sample. Decoded (red) and recorded (blue)
%   trajectories are drawn in a new figure.
%
%   Input:
%     teamName - folder containing positionEstimatorTraining and
%                positionEstimator.
%
%   Output:
%     RMSE     - square root of the mean squared Euclidean distance between
%                decoded positions and handPos(1:2,t) over all predictions.
%                NaN if no prediction was made.
%
%   positionEstimator must declare either two outputs [x, y] or three
%   outputs [x, y, newModelParameters]; any other signature raises an error.

    % Load into a struct so no variables are injected into this workspace.
    data = load('monkeydata0.mat');
    trial = data.trial;

    % Set random number generator (fixed seed gives a reproducible split).
    rng(2013);
    ix = randperm(size(trial,1));

    % Put the team's code on the path and guarantee that the path is
    % restored to its previous state on exit, including on error.
    oldPath = addpath(teamName);
    restorePath = onCleanup(@() path(oldPath)); %#ok<NASGU>

    % The number of declared outputs cannot change during the run, so query
    % it once and fail early (before training) on an unsupported signature.
    nEstimatorOutputs = nargout('positionEstimator');
    if nEstimatorOutputs ~= 2 && nEstimatorOutputs ~= 3
        error('testFunction_for_students_MTb:badEstimatorSignature', ...
            'positionEstimator must declare 2 or 3 outputs (nargout reported %d).', ...
            nEstimatorOutputs);
    end

    % Select training and testing data (you can choose to split your data
    % in a different way if you wish).
    trainingData = trial(ix(1:50),:);
    testData = trial(ix(51:end),:);

    fprintf('Testing the continuous position estimator...')

    meanSqError = 0;
    n_predictions = 0;

    figure
    hold on
    axis square
    grid on

    % Train model
    modelParameters = positionEstimatorTraining(trainingData);

    nTestTrials = size(testData,1);
    nDirections = size(testData,2);

    for tr = 1:nTestTrials
        disp(['Decoding block ',num2str(tr),' out of ',num2str(nTestTrials)]);
        pause(0.001) % yield briefly so the figure and console can refresh

        for direc = randperm(nDirections)
            currentTrial = testData(tr,direc);

            times = 320:20:size(currentTrial.spikes,2);
            if isempty(times)
                % Trial shorter than the first query time: nothing to
                % decode, score, or plot.
                continue
            end

            % Grown one column per step because the estimator is handed
            % exactly the positions decoded so far ([] on the first call).
            decodedHandPos = [];

            for t = times
                past_current_trial.trialId = currentTrial.trialId;
                past_current_trial.spikes = currentTrial.spikes(:,1:t);
                past_current_trial.decodedHandPos = decodedHandPos;
                past_current_trial.startHandPos = currentTrial.handPos(1:2,1);

                if nEstimatorOutputs == 3
                    % The estimator may return updated parameters that are
                    % carried forward to all subsequent calls.
                    [decodedPosX, decodedPosY, modelParameters] = ...
                        positionEstimator(past_current_trial, modelParameters);
                else
                    [decodedPosX, decodedPosY] = ...
                        positionEstimator(past_current_trial, modelParameters);
                end

                decodedPos = [decodedPosX; decodedPosY];
                decodedHandPos = [decodedHandPos decodedPos]; %#ok<AGROW>

                meanSqError = meanSqError + ...
                    norm(currentTrial.handPos(1:2,t) - decodedPos)^2;
            end

            n_predictions = n_predictions + length(times);

            plot(decodedHandPos(1,:), decodedHandPos(2,:), 'r');
            plot(currentTrial.handPos(1,times), currentTrial.handPos(2,times), 'b')
        end
    end

    legend('Decoded Position', 'Actual Position')

    % Left unsuppressed on purpose so the score is echoed to the console.
    RMSE = sqrt(meanSqError/n_predictions) %#ok<NOPRT>
end