Add correlation similarity metric #158
Merged
Changes from all commits (5 commits)
@@ -1,11 +1,13 @@

"""Statistical Metrics to compare column pairs."""

from sdmetrics.column_pairs.statistical.contingency_similarity import ContingencySimilarity
from sdmetrics.column_pairs.statistical.correlation_similarity import CorrelationSimilarity
from sdmetrics.column_pairs.statistical.kl_divergence import (
    ContinuousKLDivergence, DiscreteKLDivergence)

__all__ = [
    'ContingencySimilarity',
    'ContinuousKLDivergence',
    'CorrelationSimilarity',
    'DiscreteKLDivergence',
]
sdmetrics/column_pairs/statistical/correlation_similarity.py
97 changes: 97 additions & 0 deletions

@@ -0,0 +1,97 @@
"""Correlation Similarity Metric.""" | ||
|
||
import pandas as pd | ||
from scipy.stats import pearsonr, spearmanr | ||
|
||
from sdmetrics.column_pairs.base import ColumnPairsMetric | ||
from sdmetrics.goal import Goal | ||
from sdmetrics.utils import is_datetime | ||
|
||
|
||
class CorrelationSimilarity(ColumnPairsMetric): | ||
"""Correlation similarity metric. | ||
|
||
Attributes: | ||
name (str): | ||
Name to use when reports about this metric are printed. | ||
goal (sdmetrics.goal.Goal): | ||
The goal of this metric. | ||
min_value (Union[float, tuple[float]]): | ||
Minimum value or values that this metric can take. | ||
max_value (Union[float, tuple[float]]): | ||
Maximum value or values that this metric can take. | ||
""" | ||
|
||
name = 'CorrelationSimilarity' | ||
goal = Goal.MAXIMIZE | ||
min_value = 0.0 | ||
max_value = 1.0 | ||
|
||
@classmethod | ||
def compute_breakdown(cls, real_data, synthetic_data, coefficient='Pearson'): | ||
"""Compare the breakdown of correlation similarity of two continuous columns. | ||
|
||
Args: | ||
real_data (Union[numpy.ndarray, pandas.Series]): | ||
The values from the real dataset. | ||
synthetic_data (Union[numpy.ndarray, pandas.Series]): | ||
The values from the synthetic dataset. | ||
|
||
Returns: | ||
dict: | ||
A dict containing the score, and the real and synthetic metric values. | ||
""" | ||
real_data[pd.isna(real_data)] = 0.0 | ||
synthetic_data[pd.isna(synthetic_data)] = 0.0 | ||
column1, column2 = real_data.columns[:2] | ||
|
||
if is_datetime(real_data): | ||
real_data = pd.to_numeric(real_data) | ||
synthetic_data = pd.to_numeric(synthetic_data) | ||
|
||
correlation_fn = None | ||
if coefficient == 'Pearson': | ||
correlation_fn = pearsonr | ||
elif coefficient == 'Spearman': | ||
correlation_fn = spearmanr | ||
else: | ||
raise ValueError(f'requested coefficient {coefficient} is not valid. ' | ||
'Please choose either Pearson or Spearman.') | ||
|
||
correlation_real = correlation_fn(real_data[column1], real_data[column2]) | ||
correlation_synthetic = correlation_fn(synthetic_data[column1], synthetic_data[column2]) | ||
return { | ||
'score': 1 - abs(correlation_real - correlation_synthetic) / 2, | ||
'real': correlation_real, | ||
'synthetic': correlation_synthetic, | ||
} | ||
|
||
@classmethod | ||
def compute(cls, real_data, synthetic_data, coefficient='Pearson'): | ||
"""Compare the correlation similarity of two continuous columns. | ||
|
||
Args: | ||
real_data (Union[numpy.ndarray, pandas.Series]): | ||
The values from the real dataset. | ||
synthetic_data (Union[numpy.ndarray, pandas.Series]): | ||
The values from the synthetic dataset. | ||
|
||
Returns: | ||
float: | ||
The correlation similarity of the two columns. | ||
""" | ||
return cls.compute_breakdown(real_data, synthetic_data, coefficient)['score'] | ||
|
||
@classmethod | ||
def normalize(cls, raw_score): | ||
"""Return the `raw_score` as is, since it is already normalized. | ||
|
||
Args: | ||
raw_score (float): | ||
The value of the metric from `compute`. | ||
|
||
Returns: | ||
float: | ||
The normalized value of the metric | ||
""" | ||
return super().normalize(raw_score) |
@@ -0,0 +1 @@

"""Unit tests for the column pairs module."""
@@ -0,0 +1 @@

"""Unit tests for the column pairs statistical metrics."""
tests/unit/column_pairs/statistical/test_correlation_similarity.py
150 changes: 150 additions & 0 deletions

@@ -0,0 +1,150 @@
from datetime import datetime
from unittest.mock import Mock, call, patch

import pandas as pd

from sdmetrics.column_pairs.statistical import CorrelationSimilarity
from tests.utils import SeriesMatcher


class TestCorrelationSimilarity:

    @patch('sdmetrics.column_pairs.statistical.correlation_similarity.pearsonr')
    def test_compute_breakdown(self, pearson_mock):
        """Test the ``compute_breakdown`` method.

        Expect that the selected coefficient is used to compare the real and synthetic data.

        Setup:
            - Patch the ``scipy.stats.pearsonr`` method to return a test result.

        Input:
            - Mocked real data.
            - Mocked synthetic data.

        Output:
            - A mapping of the metric results, containing the score and the real and
              synthetic results.
        """
        # Setup
        real_data = pd.DataFrame({'col1': [1.0, 2.4, 2.6, 0.8], 'col2': [1, 2, 3, 4]})
        synthetic_data = pd.DataFrame({'col1': [0.9, 1.8, 3.1, 5.0], 'col2': [2, 3, 4, 1]})
        score_real = -0.451
        score_synthetic = -0.003
        pearson_mock.side_effect = [score_real, score_synthetic]
        expected_score_breakdown = {
            'score': 1 - abs(score_real - score_synthetic) / 2,
            'real': score_real,
            'synthetic': score_synthetic,
        }

        # Run
        metric = CorrelationSimilarity()
        result = metric.compute_breakdown(real_data, synthetic_data, coefficient='Pearson')

        # Assert
        assert pearson_mock.has_calls(
            call(SeriesMatcher(real_data['col1']), SeriesMatcher(real_data['col2'])),
            call(SeriesMatcher(synthetic_data['col1']), SeriesMatcher(synthetic_data['col2'])),
        )
        assert result == expected_score_breakdown

    @patch('sdmetrics.column_pairs.statistical.correlation_similarity.pearsonr')
    def test_compute_breakdown_datetime(self, pearson_mock):
        """Test the ``compute_breakdown`` method with datetime input.

        Expect that the selected coefficient is used to compare the real and synthetic data.

        Setup:
            - Patch the ``scipy.stats.pearsonr`` method to return a test result.

        Input:
            - Mocked real data.
            - Mocked synthetic data.

        Output:
            - A mapping of the metric results, containing the score and the real and
              synthetic results.
        """
        # Setup
        real_data = pd.DataFrame({
            'col1': [datetime(2020, 1, 3), datetime(2020, 10, 13), datetime(2021, 5, 3)],
            'col2': [datetime(2021, 7, 23), datetime(2021, 8, 3), datetime(2020, 9, 24)],
        })
        synthetic_data = pd.DataFrame({
            'col1': [datetime(2021, 9, 19), datetime(2021, 10, 1), datetime(2020, 3, 1)],
            'col2': [datetime(2022, 4, 28), datetime(2021, 7, 31), datetime(2020, 4, 2)],
        })
        score_real = 0.2
        score_synthetic = 0.1
        pearson_mock.side_effect = [score_real, score_synthetic]
        expected_score_breakdown = {
            'score': 1 - abs(score_real - score_synthetic) / 2,
            'real': score_real,
            'synthetic': score_synthetic,
        }

        # Run
        metric = CorrelationSimilarity()
        result = metric.compute_breakdown(real_data, synthetic_data, coefficient='Pearson')

        # Assert
        assert pearson_mock.has_calls(
            call(SeriesMatcher(real_data['col1']), SeriesMatcher(real_data['col2'])),
            call(SeriesMatcher(synthetic_data['col1']), SeriesMatcher(synthetic_data['col2'])),
        )
        assert result == expected_score_breakdown

    def test_compute(self):
        """Test the ``compute`` method.

        Expect that the selected coefficient is used to compare the real and synthetic data.

        Setup:
            - Mock the ``compute`` method to return a test score.

        Input:
            - Real data.
            - Synthetic data.

        Output:
            - The evaluated metric.
        """
        # Setup
        test_score = 0.2
        score_breakdown = {'score': test_score}
        metric = CorrelationSimilarity()

        # Run
        with patch.object(
            CorrelationSimilarity,
            'compute_breakdown',
            return_value=score_breakdown,
        ):
            result = metric.compute(Mock(), Mock(), coefficient='Pearson')

        # Assert
        assert result == test_score

    @patch(
        'sdmetrics.column_pairs.statistical.correlation_similarity.ColumnPairsMetric.normalize'
    )
    def test_normalize(self, normalize_mock):
        """Test the ``normalize`` method.

        Expect that the inherited ``normalize`` method is called.

        Input:
            - Raw score

        Output:
            - The output of the inherited ``normalize`` method.
        """
        # Setup
        metric = CorrelationSimilarity()
        raw_score = 0.9

        # Run
        result = metric.normalize(raw_score)

        # Assert
        normalize_mock.assert_called_once_with(raw_score)
        assert result == normalize_mock.return_value

Review comment on test_compute_breakdown: can we do an example with datetime columns?
Reviewer: should we error if synthetic data isn't also datetime?

Author: Hm, so I have a couple thoughts - this applies across the column_pairs and single_column metrics. I do think it would be nice to add more comprehensive validation. I'm not sure if it's worth adding it in this PR, since that will make it less unified. I'm in favor of opening another issue around adding data type verification on the base classes of ColumnPairMetric and SingleColumnMetric, and addressing it for all metrics. Let me know what you think.

Reviewer: Makes sense to me

Author: Cool, added an issue for tracking here: #168