Skip to content

Commit e2ccc96

Browse files
authored
Feature PCA outlier detector (#728)
* Add linear and kernel PCA outlier detectors
1 parent 7e6d9df commit e2ccc96

File tree

13 files changed

+829
-18
lines changed

13 files changed

+829
-18
lines changed

alibi_detect/od/_knn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from typing import Callable, Union, Optional, Dict, Any, List, Tuple
22
from typing import TYPE_CHECKING
3+
from typing_extensions import Literal
34

45
import numpy as np
56

6-
from typing_extensions import Literal
77
from alibi_detect.base import outlier_prediction_dict
88
from alibi_detect.exceptions import _catch_error as catch_error
99
from alibi_detect.od.base import TransformProtocol, TransformProtocolType

alibi_detect/od/_mahalanobis.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
from typing import Union, Optional, Dict, Any
22
from typing import TYPE_CHECKING
33
from alibi_detect.exceptions import _catch_error as catch_error
4-
4+
from typing_extensions import Literal
55

66
import numpy as np
77

8-
from alibi_detect.utils._types import Literal
98
from alibi_detect.base import BaseDetector, FitMixin, ThresholdMixin, outlier_prediction_dict
109
from alibi_detect.od.pytorch import MahalanobisTorch
1110
from alibi_detect.utils.frameworks import BackendValidator

alibi_detect/od/_pca.py

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
from typing import Union, Optional, Callable, Dict, Any
2+
from typing import TYPE_CHECKING
3+
from typing_extensions import Literal
4+
5+
import numpy as np
6+
7+
from alibi_detect.base import outlier_prediction_dict
8+
from alibi_detect.base import BaseDetector, ThresholdMixin, FitMixin
9+
from alibi_detect.od.pytorch import KernelPCATorch, LinearPCATorch
10+
from alibi_detect.utils.frameworks import BackendValidator
11+
from alibi_detect.version import __version__
12+
from alibi_detect.exceptions import _catch_error as catch_error
13+
14+
15+
if TYPE_CHECKING:
16+
import torch
17+
18+
19+
# Maps each supported backend name to its pair of backend classes, ordered as
# (kernel PCA class, linear PCA class) -- the order in which `PCA.__init__` unpacks them.
backends = {
    'pytorch': (KernelPCATorch, LinearPCATorch)
}
22+
23+
24+
class PCA(BaseDetector, ThresholdMixin, FitMixin):
    def __init__(
        self,
        n_components: int,
        kernel: Optional[Callable] = None,
        backend: Literal['pytorch'] = 'pytorch',
        device: Optional[Union[Literal['cuda', 'gpu', 'cpu'], 'torch.device']] = None,
    ) -> None:
        """Principal Component Analysis (PCA) outlier detector.

        The detector is based on the Principal Component Analysis (PCA) algorithm. There are two variants of PCA:
        linear PCA and kernel PCA. Linear PCA computes the eigenvectors of the covariance matrix of the data. Kernel
        PCA computes the eigenvectors of the kernel matrix of the data.

        When scoring a test instance using the linear variant we compute the distance to the principal subspace
        spanned by the first `n_components` eigenvectors.

        When scoring a test instance using the kernel variant we project it onto the largest eigenvectors and
        compute its score using the L2 norm.

        If a threshold is fitted we use this to determine whether the instance is an outlier or not.

        Parameters
        ----------
        n_components
            The number of dimensions in the principal subspace. For linear PCA should have
            ``1 <= n_components < dim(data)``. For kernel PCA should have ``1 <= n_components < len(data)``.
        kernel
            Kernel function to use for outlier detection. If ``None``, linear PCA is used instead of the
            kernel variant.
        backend
            Backend used for outlier detection. Defaults to ``'pytorch'``. Options are ``'pytorch'``. The name is
            matched case-insensitively.
        device
            Device type used. The default tries to use the GPU and falls back on CPU if needed. Can be specified by
            passing either ``'cuda'``, ``'gpu'``, ``'cpu'`` or an instance of ``torch.device``.

        Raises
        ------
        NotImplementedError
            If choice of `backend` is not implemented.
        ValueError
            If `n_components` is less than 1.
        """
        super().__init__()

        backend_str: str = backend.lower()
        BackendValidator(
            backend_options={'pytorch': ['pytorch']},
            construct_name=self.__class__.__name__
        ).verify_backend(backend_str)

        # Index with the normalised name that was just validated. Using the raw `backend` argument here
        # would raise a bare ``KeyError`` for inputs such as ``'PyTorch'`` that pass validation.
        kernel_backend_cls, linear_backend_cls = backends[backend_str]

        # A kernel selects the kernel PCA backend; otherwise the linear variant is used.
        self.backend: Union[KernelPCATorch, LinearPCATorch]
        if kernel is not None:
            self.backend = kernel_backend_cls(
                n_components=n_components,
                device=device,
                kernel=kernel
            )
        else:
            self.backend = linear_backend_cls(
                n_components=n_components,
                device=device,
            )

    def fit(self, x_ref: np.ndarray) -> None:
        """Fit the detector on reference data.

        In the linear case we compute the principal components of the reference data using the
        covariance matrix and then remove the largest `n_components` eigenvectors. The remaining
        eigenvectors correspond to the invariant dimensions of the data. Changes in these
        dimensions are used to compute the outlier score which is the distance to the principal
        subspace spanned by the first `n_components` eigenvectors.

        In the kernel case we compute the principal components of the reference data using the
        kernel matrix and then return the largest `n_components` eigenvectors. These are then
        normalized to have length equal to `1/eigenvalue`. Note that this differs from the
        linear case where we remove the largest eigenvectors.

        In both cases we then store the computed components to use later when we score test
        instances.

        Parameters
        ----------
        x_ref
            Reference data used to fit the detector.

        Raises
        ------
        ValueError
            If using linear PCA variant and `n_components` is greater than or equal to number of
            features or if using kernel PCA variant and `n_components` is greater than or equal
            to number of instances.
        """
        self.backend.fit(self.backend._to_tensor(x_ref))

    @catch_error('NotFittedError')
    def score(self, x: np.ndarray) -> np.ndarray:
        """Score `x` instances using the detector.

        Project `x` onto the eigenvectors and compute the score using the L2 norm.

        Parameters
        ----------
        x
            Data to score. The shape of `x` should be `(n_instances, n_features)`.

        Returns
        -------
        Outlier scores. The shape of the scores is `(n_instances,)`. The higher the score, the more anomalous the \
        instance.

        Raises
        ------
        NotFittedError
            If called before detector has been fit.
        """
        score = self.backend.score(self.backend._to_tensor(x))
        return self.backend._to_numpy(score)

    @catch_error('NotFittedError')
    def infer_threshold(self, x: np.ndarray, fpr: float) -> None:
        """Infer the threshold for the PCA detector.

        The threshold is computed so that the outlier detector would incorrectly classify `fpr` proportion of the
        reference data as outliers.

        Parameters
        ----------
        x
            Reference data used to infer the threshold.
        fpr
            False positive rate used to infer the threshold. The false positive rate is the proportion of
            instances in `x` that are incorrectly classified as outliers. The false positive rate should
            be in the range ``(0, 1)``.

        Raises
        ------
        ValueError
            Raised if `fpr` is not in ``(0, 1)``.
        NotFittedError
            If called before detector has been fit.
        """
        self.backend.infer_threshold(self.backend._to_tensor(x), fpr)

    @catch_error('NotFittedError')
    def predict(self, x: np.ndarray) -> Dict[str, Any]:
        """Predict whether the instances in `x` are outliers or not.

        Scores the instances in `x` and if the threshold was inferred, returns the outlier labels and p-values as well.

        Parameters
        ----------
        x
            Data to predict. The shape of `x` should be `(n_instances, n_features)`.

        Returns
        -------
        Dictionary with keys 'data' and 'meta'. 'data' contains the outlier scores. If threshold inference was \
        performed, 'data' also contains the threshold value, outlier labels and p-values. The shape of the scores is \
        `(n_instances,)`. The higher the score, the more anomalous the instance. 'meta' contains information about \
        the detector.

        Raises
        ------
        NotFittedError
            If called before detector has been fit.
        """
        outputs = self.backend.predict(self.backend._to_tensor(x))
        output = outlier_prediction_dict()
        # Merge the backend results into the default 'data' entries of the prediction template.
        output['data'] = {
            **output['data'],
            **self.backend._to_numpy(outputs)
        }
        # Extend the template metadata with detector-specific information.
        output['meta'] = {
            **output['meta'],
            'name': self.__class__.__name__,
            'detector_type': 'outlier',
            'online': False,
            'version': __version__,
        }
        return output

alibi_detect/od/pytorch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22

33
KNNTorch = import_optional('alibi_detect.od.pytorch.knn', ['KNNTorch'])
44
MahalanobisTorch = import_optional('alibi_detect.od.pytorch.mahalanobis', ['MahalanobisTorch'])
5+
KernelPCATorch, LinearPCATorch = import_optional('alibi_detect.od.pytorch.pca', ['KernelPCATorch', 'LinearPCATorch'])
56
Ensembler = import_optional('alibi_detect.od.pytorch.ensemble', ['Ensembler'])

alibi_detect/od/pytorch/base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from typing import List, Union, Optional, Dict
2+
from typing_extensions import Literal
23
from dataclasses import dataclass, asdict
34
from abc import ABC, abstractmethod
45

@@ -60,7 +61,10 @@ class TorchOutlierDetector(torch.nn.Module, FitMixinTorch, ABC):
6061
threshold_inferred = False
6162
threshold = None
6263

63-
def __init__(self, device: Optional[Union[str, torch.device]] = None):
64+
def __init__(
65+
self,
66+
device: Optional[Union[Literal['cuda', 'gpu', 'cpu'], 'torch.device']] = None,
67+
):
6468
self.device = get_device(device)
6569
super().__init__()
6670

alibi_detect/od/pytorch/knn.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from typing import Optional, Union, List, Tuple
2-
2+
from typing_extensions import Literal
33
import numpy as np
44
import torch
55

@@ -13,7 +13,7 @@ def __init__(
1313
k: Union[np.ndarray, List, Tuple, int],
1414
kernel: Optional[torch.nn.Module] = None,
1515
ensembler: Optional[Ensembler] = None,
16-
device: Optional[Union[str, torch.device]] = None
16+
device: Optional[Union[Literal['cuda', 'gpu', 'cpu'], 'torch.device']] = None,
1717
):
1818
"""PyTorch backend for KNN detector.
1919

alibi_detect/od/pytorch/mahalanobis.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from typing import Optional, Union
2-
2+
from typing_extensions import Literal
33
import torch
44

55
from alibi_detect.od.pytorch.base import TorchOutlierDetector
@@ -11,7 +11,7 @@ class MahalanobisTorch(TorchOutlierDetector):
1111
def __init__(
1212
self,
1313
min_eigenvalue: float = 1e-6,
14-
device: Optional[Union[str, torch.device]] = None
14+
device: Optional[Union[Literal['cuda', 'gpu', 'cpu'], 'torch.device']] = None,
1515
):
1616
"""PyTorch backend for Mahalanobis detector.
1717
@@ -20,8 +20,8 @@ def __init__(
2020
min_eigenvalue
2121
Eigenvectors with eigenvalues below this value will be discarded.
2222
device
23-
Device type used. The default None tries to use the GPU and falls back on CPU if needed.
24-
Can be specified by passing either ``'cuda'``, ``'gpu'`` or ``'cpu'``.
23+
Device type used. The default tries to use the GPU and falls back on CPU if needed. Can be specified by
24+
passing either ``'cuda'``, ``'gpu'``, ``'cpu'`` or an instance of ``torch.device``.
2525
"""
2626
super().__init__(device=device)
2727
self.min_eigenvalue = min_eigenvalue

0 commit comments

Comments
 (0)