Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 127 additions & 0 deletions ParquetSharp.Dataset.Test/Filter/TestIntFilter.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Apache.Arrow;
using NUnit.Framework;
using ParquetSharp.Dataset.Filter;

namespace ParquetSharp.Dataset.Test.Filter;

Expand Down Expand Up @@ -128,6 +130,33 @@ public void TestComputeIntRangeMask((long, long) filterRange)
TestComputeIntRangeMask<ulong, UInt64Array, UInt64Array.Builder>(rangeStart, rangeEnd, ULongValues, val => checked((long)val));
}

[Theory]
public void TestIntEqualityIncludeRowGroup(long filterValue)
{
TestIntEqualityIncludeRowGroup(filterValue, SByteValues, val => val);
TestIntEqualityIncludeRowGroup(filterValue, ShortValues, val => val);
TestIntEqualityIncludeRowGroup(filterValue, IntValues, val => val);
TestIntEqualityIncludeRowGroup(filterValue, LongValues, val => val);
TestIntEqualityIncludeRowGroup(filterValue, ByteValues, val => val);
TestIntEqualityIncludeRowGroup(filterValue, UShortValues, val => val);
TestIntEqualityIncludeRowGroup(filterValue, UIntValues, val => val);
TestIntEqualityIncludeRowGroup(filterValue, ULongValues, val => checked((long)val));
}

[Theory]
public void TestIntRangeIncludeRowGroup((long, long) filterRange)
{
var (rangeStart, rangeEnd) = filterRange;
TestIntRangeIncludeRowGroup(rangeStart, rangeEnd, SByteValues, val => val);
TestIntRangeIncludeRowGroup(rangeStart, rangeEnd, ShortValues, val => val);
TestIntRangeIncludeRowGroup(rangeStart, rangeEnd, IntValues, val => val);
TestIntRangeIncludeRowGroup(rangeStart, rangeEnd, LongValues, val => val);
TestIntRangeIncludeRowGroup(rangeStart, rangeEnd, ByteValues, val => val);
TestIntRangeIncludeRowGroup(rangeStart, rangeEnd, UShortValues, val => val);
TestIntRangeIncludeRowGroup(rangeStart, rangeEnd, UIntValues, val => val);
TestIntRangeIncludeRowGroup(rangeStart, rangeEnd, ULongValues, val => checked((long)val));
}

private static void TestComputeIntEqualityMask<T, TArray, TBuilder>(long filterValue, T[] values, Func<T, long> checkedCast)
where T : struct
where TArray : PrimitiveArray<T>
Expand Down Expand Up @@ -203,6 +232,104 @@ private static void TestComputeIntRangeMask<T, TArray, TBuilder>(long rangeStart
}
}

private static void TestIntEqualityIncludeRowGroup<T>(long filterValue, T[] values, Func<T, long> checkedCast)
where T : IComparable<T>
{
var filter = Col.Named("x").IsEqualTo(filterValue);

var statsRanges = values
.SelectMany(min => values.Select(max => (min, max)))
.Where(range => range.max.CompareTo(range.min) >= 0)
.ToArray();
foreach (var statsRange in statsRanges)
{
var rowGroupStats = new Dictionary<string, LogicalStatistics>
{
{ "x", new LogicalStatistics<T>(statsRange.min, statsRange.max) }
};

var filterValueInRange = true;
try
{
var longMin = checkedCast(statsRange.min);
if (filterValue < longMin)
{
filterValueInRange = false;
}
}
catch (OverflowException)
{
filterValueInRange = false;
}

try
{
var longMax = checkedCast(statsRange.max);
if (filterValue > longMax)
{
filterValueInRange = false;
}
}
catch (OverflowException)
{
}

var includeRowGroup = filter.IncludeRowGroup(rowGroupStats);
Assert.That(
includeRowGroup, Is.EqualTo(filterValueInRange),
$"Expected {typeof(T)} stats range [{statsRange.min}, {statsRange.max}] inclusion to be {filterValueInRange}");
}
}

private static void TestIntRangeIncludeRowGroup<T>(long rangeStart, long rangeEnd, T[] values, Func<T, long> checkedCast)
where T : IComparable<T>
{
var filter = Col.Named("x").IsInRange(rangeStart, rangeEnd);

var statsRanges = values
.SelectMany(min => values.Select(max => (min, max)))
.Where(range => range.max.CompareTo(range.min) >= 0)
.ToArray();
foreach (var statsRange in statsRanges)
{
var rowGroupStats = new Dictionary<string, LogicalStatistics>
{
{ "x", new LogicalStatistics<T>(statsRange.min, statsRange.max) }
};

var rangesOverlap = true;
try
{
var longMin = checkedCast(statsRange.min);
if (longMin > rangeEnd)
{
rangesOverlap = false;
}
}
catch (OverflowException)
{
rangesOverlap = false;
}

try
{
var longMax = checkedCast(statsRange.max);
if (longMax < rangeStart)
{
rangesOverlap = false;
}
}
catch (OverflowException)
{
}

var includeRowGroup = filter.IncludeRowGroup(rowGroupStats);
Assert.That(
includeRowGroup, Is.EqualTo(rangesOverlap),
$"Expected {typeof(T)} stats range [{statsRange.min}, {statsRange.max}] inclusion to be {rangesOverlap}");
}
}

private static TArray BuildArray<T, TArray, TBuilder>(T[] values)
where T : struct
where TArray : IArrowArray
Expand Down
Loading