sp_detector.py
import abc
import copy

import numpy as np

from arff_helper import ArffHelper


class SmoothPursuitDetector(object):
"""
DBSCAN-based smooth pursuit detector. All the logic is in a DBSCANWithMinPts class, this is just a wrapper
that based on the arguments to __init__ method initiates DBSCANWithMinPts
"""
    def __init__(self, param):
        """
        Initialize the SmoothPursuitDetector object.

        :param param: dictionary of detector parameters, with the following keys:
            'EPS_DEG': spatial Euclidean distance threshold that defines the
                neighbourhood in the XY-plane; given in degrees of visual field.
            'TIME_SLICE_MILLISEC': width of the time slice that defines the size of
                the neighbourhood on the time axis; given in milliseconds
                (the neighbourhood essentially has a cylindrical shape).
            'MIN_PTS': integer indicating the minimum number of points required to
                form a "valid" neighbourhood (that of a core point).
        """
        min_pts = param['MIN_PTS']
        eps_deg = param['EPS_DEG']
        time_slice_millisec = param['TIME_SLICE_MILLISEC']
        self.clustering = DBSCANWithMinPts(eps_deg=eps_deg, time_slice_millisec=time_slice_millisec,
                                           min_pts=min_pts)

    def detect(self, gaze_points_list, inplace=False):
        return self.clustering.cluster(gaze_points_list=gaze_points_list,
                                       inplace=inplace)
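
# Example usage (a minimal sketch; the parameter values are the defaults documented
# below, and `gaze_points` is a hypothetical arff-style dictionary whose 'data' field
# is a structured numpy array with 'time', 'x', 'y' and 'EYE_MOVEMENT_TYPE' columns,
# as expected by DBSCANWithTimeSlice.cluster):
#
#     params = {'EPS_DEG': 2.0, 'TIME_SLICE_MILLISEC': 40, 'MIN_PTS': 1}
#     detector = SmoothPursuitDetector(params)
#     labelled = detector.detect(gaze_points, inplace=False)
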
class DBSCANWithTimeSlice(object):
    """
    The class is based on the DBSCAN algorithm for density-based data clustering
    (we run this to detect SP, after pre-filtering has removed saccades and fixations).
    Rather than only using spatial locations, the algorithm uses spatio-temporal
    information, i.e. we cluster gaze points in three-dimensional (t, x, y) space.

    Since there is no a priori optimal scaling factor between time and space,
    we modify the classical DBSCAN notion of the neighbourhood (a sphere of radius @eps).
    Instead, we consider a cylinder with its axis aligned with the time axis.
    This way we have an XY-neighbourhood defined by Euclidean distance with threshold @eps,
    and on the temporal axis we take a time slice of @time_slice_millisec width (hence the class name).

    Neighbourhood validation is implemented by two classes that implement the DBSCANWithTimeSlice
    interface. It is done in two different ways, namely "minPts" (validating that the number of
    other gaze points in the neighbourhood is at least @min_pts, closer to the original DBSCAN)
    and "minObservers" (validating that samples of at least @min_observers different observers
    are present in the neighbourhood).
    """
    __metaclass__ = abc.ABCMeta

    def __init__(self, eps_deg=2.0, time_slice_millisec=40):
        """
        :param eps_deg: Spatial Euclidean distance threshold that defines the neighbourhood
                        in the XY-plane. Given in degrees of visual field.
        :param time_slice_millisec: Width of the time slice that defines the size of the
                                    neighbourhood on the time axis. Given in milliseconds.
        """
        self.time_slice = time_slice_millisec
        self.eps_deg = eps_deg
        # initialize empty data
        self._data_set = None
        # store timestamps separately for efficiency
        self._timestamps = None
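
    def _setup_internal_parameters(self, gaze_points_list):
        """
        Hook called at the beginning of cluster(). Subclasses override it to derive
        parameters from the data (e.g. DBSCANWithMinPts resolves min_pts='num_observers'
        here, as its docstring describes); no-op by default.
        """
        pass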

    def cluster(self, gaze_points_list, inplace=False):
        """
        Find clusters in the input gaze data and label clustered points as smooth pursuit.

        Labels (sets the 'EYE_MOVEMENT_TYPE' field of) the clustered data points as 'SP',
        all other samples as 'NOISE_CLUSTER'. A new column 'CLUSTER_ID' is added to the
        @DATA section of the input arff object, indicating the cluster group ID.

        :param gaze_points_list: an arff object (a dictionary with fields such as 'data' and 'metadata')
        :param inplace: whether to modify the original input gaze data or work on a copy
        :return: gaze data after clustering, in the same form as the input data.
        """
        if not inplace:
            gaze_points_list = copy.deepcopy(gaze_points_list)
        # resolve parameters that depend on the data (e.g. min_pts='num_observers')
        self._setup_internal_parameters(gaze_points_list)
        # add global indexing to be able to reference a particular sample even after
        # clustering all the data in one structure
        ArffHelper.add_column(gaze_points_list, name='global_index', dtype='INTEGER', default_value=-1)
        gaze_points_list['data']['global_index'] = np.arange(gaze_points_list['data'].shape[0])
        self._data_set = self._aggregate_data(gaze_points_list)
        # has to be a copy, so that it is placed contiguously in memory
        self._timestamps = self._data_set['time'].copy()

        current_cluster_id = 0
        for i in range(len(self._data_set)):
            if self._data_set[i]['visited_flag'] == 1:
                continue
            self._data_set[i]['visited_flag'] = 1
            neighbourhood = self._get_neighbourhood(i)
            if self._validate_neighbourhood(neighbourhood):
                self._expand_cluster(i, neighbourhood, current_cluster_id)
                current_cluster_id += 1
            # otherwise the point keeps CLUSTER_ID == -1, i.e. remains noise

        # create a new column in gaze_points_list for CLUSTER_ID
        ArffHelper.add_column(gaze_points_list, 'CLUSTER_ID', 'NUMERIC', -1)
        # label data in gaze_points_list as SP according to CLUSTER_ID
        for i in range(len(self._data_set)):
            global_index = self._data_set[i]['global_index']
            if self._data_set[i]['CLUSTER_ID'] != -1:
                gaze_points_list['data']['EYE_MOVEMENT_TYPE'][global_index] = 'SP'
                gaze_points_list['data']['CLUSTER_ID'][global_index] = self._data_set[i]['CLUSTER_ID']
            else:
                gaze_points_list['data']['EYE_MOVEMENT_TYPE'][global_index] = 'NOISE_CLUSTER'
        # can now remove the global_index column
        ArffHelper.remove_column(gaze_points_list, name='global_index')
        return gaze_points_list

    def _expand_cluster(self, current_point, neighbourhood, current_cluster_id):
        """
        Check all points within the neighbourhood of the current core point in order
        to expand the cluster. Processes points in @self._data_set
        (a 6-column numpy array as the data set to be clustered).

        :param current_point: index of the current core point.
        :param neighbourhood: index list as the neighbourhood of the current core point.
        :param current_cluster_id: index of the current cluster.
        :return: index list of the expanded neighbourhood.
        """
        self._data_set[current_point]['CLUSTER_ID'] = current_cluster_id
        for neighbour in neighbourhood:
            if self._data_set[neighbour]['visited_flag'] == 0:
                self._data_set[neighbour]['visited_flag'] = 1
                new_neighbourhood = self._get_neighbourhood(neighbour)
                if self._validate_neighbourhood(new_neighbourhood):
                    new_neighbourhood_set = set(new_neighbourhood)
                    new_neighbours = list(new_neighbourhood_set.difference(neighbourhood))
                    # deliberately extend the list being iterated over, so that the newly
                    # found points get visited too (updating a set instead would not work here)
                    neighbourhood.extend(new_neighbours)
            if self._data_set[neighbour]['CLUSTER_ID'] == -1:
                self._data_set[neighbour]['CLUSTER_ID'] = current_cluster_id
        return neighbourhood

    def _aggregate_data(self, gaze_points_list):
        """
        Aggregate the samples still labelled 'UNKNOWN' in the @DATA section of the input
        arff object into a new data set in the form of a numpy array.

        :param gaze_points_list: gaze data to be clustered, an arff object.
        :return: data set to be clustered in the form of a 6-column numpy array,
                 i.e. ['time', 'x', 'y', 'global_index', 'CLUSTER_ID', 'visited_flag'],
                 ordered by the 'time' column.
        """
        gaze_points_data = gaze_points_list['data'][
            (gaze_points_list['data']['EYE_MOVEMENT_TYPE'] == 'UNKNOWN')][['time', 'x', 'y', 'global_index']]
        gaze_points_data = ArffHelper.add_column_to_array(gaze_points_data, 'CLUSTER_ID', 'NUMERIC', -1)
        gaze_points_data = ArffHelper.add_column_to_array(gaze_points_data, 'visited_flag', 'NUMERIC', 0)
        data_set = np.sort(gaze_points_data, order='time')
        return data_set

    def _get_neighbourhood(self, current_point):
        """
        Get the neighbourhood of the current point in @self._data_set
        (a 6-column numpy array as the data set to be clustered).

        :param current_point: index of the current core point candidate.
        :return: index list of the neighbourhood of the current point.
        """
        # find the borders of the time slice via binary search in the (sorted) timestamps;
        # cast the threshold to the timestamps' dtype just in case
        start_index = np.searchsorted(self._timestamps,
                                      self._timestamps[current_point] - self._timestamps.dtype.type(self.time_slice),
                                      side='left')
        end_index = np.searchsorted(self._timestamps,
                                    self._timestamps[current_point] + self._timestamps.dtype.type(self.time_slice),
                                    side='right')
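        # Worked example (illustrative numbers): with self._timestamps == [0, 10, 20, 30, 40]
        # (ms), time_slice == 15 and current_point == 2 (t=20), start_index becomes
        # searchsorted(..., 5, side='left') == 1 and end_index becomes
        # searchsorted(..., 35, side='right') == 4, so only samples 1..3 are distance-checked.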
        # Euclidean distance in the XY-plane to all samples inside the time slice
        distance = np.linalg.norm([self._data_set[start_index:end_index]['x'] - self._data_set[current_point]['x'],
                                   self._data_set[start_index:end_index]['y'] - self._data_set[current_point]['y']],
                                  axis=0)
        # indices within the slice are shifted by start_index to become global indices
        neighbourhood = (np.where(distance <= self.eps_deg)[0] + start_index).tolist()
        return neighbourhood

    @abc.abstractmethod
    def _validate_neighbourhood(self, *args, **kwargs):
        """
        Should return a boolean value after neighbourhood validation: True if the point
        with such a neighbourhood is a core point (see the DBSCAN method description for
        details).

        Abstract method, implemented in subclasses.
        """
        raise NotImplementedError("Implemented in subclass methods.")


class DBSCANWithMinPts(DBSCANWithTimeSlice):
    """
    DBSCAN with time slice that uses MinPts as the neighbourhood validation method
    (validating that the number of other gaze points in the neighbourhood is at least
    @min_pts before declaring this a core point).

    This method depends on the frame rate of the gaze position recording, since the
    number of points in a fixed temporal slice grows proportionally to the recording fps.
    If more independence from the fps is desired, use DBSCANWithMinObservers. The default
    value here was tuned on a dataset recorded with a 250 Hz tracker.
    """
    def __init__(self, eps_deg=2.0, time_slice_millisec=40, min_pts=1):
        """
        Initialize the DBSCANWithMinPts object.

        :param eps_deg: Spatial Euclidean distance threshold that defines the neighbourhood
                        in the XY-plane. Given in degrees of visual field; a pixel value is
                        assigned when the recordings' data is provided.
        :param time_slice_millisec: Width of the time slice that defines the size of the
                                    neighbourhood on the time axis. Given in milliseconds.
        :param min_pts: integer indicating the minimum number of points required to
                        form a "valid" neighbourhood (that of a core point).
                        Can also be the string 'num_observers', in which case the actual
                        value is determined during the self._setup_internal_parameters() call.
        """
        super(DBSCANWithMinPts, self).__init__(eps_deg=eps_deg, time_slice_millisec=time_slice_millisec)
        self.min_pts = min_pts
        if isinstance(self.min_pts, int):
            self.min_pts_abs_value = self.min_pts

    def _setup_internal_parameters(self, gaze_points_list):
        """
        If @self.min_pts was given as 'num_observers', resolve it to an absolute value here.

        :param gaze_points_list: the same gaze data structure that is passed to cluster().
        """
        if self.min_pts == 'num_observers':
            # NB: this assumes that @gaze_points_list is a list with one arff object per
            # observer; for a single arff object this would not yield an observer count
            self.min_pts_abs_value = len(gaze_points_list)

    def _validate_neighbourhood(self, neighbourhood):
        """
        Compare the size of @neighbourhood with @self.min_pts and return the result of the
        validation: True if this is the neighbourhood of a core point, False otherwise.
        @self._data_set (a 6-column numpy array as the data set to be clustered) is used
        to interpret the @neighbourhood list.

        :param neighbourhood: index list as the neighbourhood to be validated.
        :return: True if @neighbourhood contains at least @self.min_pts points, False otherwise.
        """
        return len(neighbourhood) >= self.min_pts_abs_value
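

# The docstrings above mention a "minObservers" validation variant (DBSCANWithMinObservers),
# which is not defined in this file. Below is a minimal, hypothetical sketch of what such a
# subclass could look like; it ASSUMES that _aggregate_data is extended to carry an
# 'observer_id' column in @self._data_set, which the implementation above does not do.
class DBSCANWithMinObservers(DBSCANWithTimeSlice):
    """
    DBSCAN with time slice that validates a neighbourhood by the number of distinct
    observers whose samples fall into it (sketch only, see the assumption above).
    Unlike the MinPts criterion, this is largely independent of the recording frame rate.
    """
    def __init__(self, eps_deg=2.0, time_slice_millisec=40, min_observers=2):
        super(DBSCANWithMinObservers, self).__init__(eps_deg=eps_deg,
                                                     time_slice_millisec=time_slice_millisec)
        # min_observers=2 is an illustrative default, not a validated choice
        self.min_observers = min_observers

    def _validate_neighbourhood(self, neighbourhood):
        # count the distinct observers among the neighbourhood samples
        observers = {self._data_set[i]['observer_id'] for i in neighbourhood}
        return len(observers) >= self.min_observers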