Skip to content
Merged
127 changes: 127 additions & 0 deletions ParquetSharp.Dataset.Test/Filter/TestIntFilter.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
using System;
using System.Collections.Generic;
using System.Linq;
using Apache.Arrow;
using NUnit.Framework;
using ParquetSharp.Dataset.Filter;

namespace ParquetSharp.Dataset.Test.Filter;

Expand Down Expand Up @@ -128,6 +130,33 @@ public void TestComputeIntRangeMask((long, long) filterRange)
TestComputeIntRangeMask<ulong, UInt64Array, UInt64Array.Builder>(rangeStart, rangeEnd, ULongValues, val => checked((long)val));
}

[Theory]
public void TestIntEqualityIncludeRowGroup(long filterValue)
{
TestIntEqualityIncludeRowGroup(filterValue, SByteValues, val => val);
TestIntEqualityIncludeRowGroup(filterValue, ShortValues, val => val);
TestIntEqualityIncludeRowGroup(filterValue, IntValues, val => val);
TestIntEqualityIncludeRowGroup(filterValue, LongValues, val => val);
TestIntEqualityIncludeRowGroup(filterValue, ByteValues, val => val);
TestIntEqualityIncludeRowGroup(filterValue, UShortValues, val => val);
TestIntEqualityIncludeRowGroup(filterValue, UIntValues, val => val);
TestIntEqualityIncludeRowGroup(filterValue, ULongValues, val => checked((long)val));
}

[Theory]
public void TestIntRangeIncludeRowGroup((long, long) filterRange)
{
var (rangeStart, rangeEnd) = filterRange;
TestIntRangeIncludeRowGroup(rangeStart, rangeEnd, SByteValues, val => val);
TestIntRangeIncludeRowGroup(rangeStart, rangeEnd, ShortValues, val => val);
TestIntRangeIncludeRowGroup(rangeStart, rangeEnd, IntValues, val => val);
TestIntRangeIncludeRowGroup(rangeStart, rangeEnd, LongValues, val => val);
TestIntRangeIncludeRowGroup(rangeStart, rangeEnd, ByteValues, val => val);
TestIntRangeIncludeRowGroup(rangeStart, rangeEnd, UShortValues, val => val);
TestIntRangeIncludeRowGroup(rangeStart, rangeEnd, UIntValues, val => val);
TestIntRangeIncludeRowGroup(rangeStart, rangeEnd, ULongValues, val => checked((long)val));
}

private static void TestComputeIntEqualityMask<T, TArray, TBuilder>(long filterValue, T[] values, Func<T, long> checkedCast)
where T : struct
where TArray : PrimitiveArray<T>
Expand Down Expand Up @@ -203,6 +232,104 @@ private static void TestComputeIntRangeMask<T, TArray, TBuilder>(long rangeStart
}
}

private static void TestIntEqualityIncludeRowGroup<T>(long filterValue, T[] values, Func<T, long> checkedCast)
where T : IComparable<T>
{
var filter = Col.Named("x").IsEqualTo(filterValue);

var statsRanges = values
.SelectMany(min => values.Select(max => (min, max)))
.Where(range => range.max.CompareTo(range.min) >= 0)
.ToArray();
foreach (var statsRange in statsRanges)
{
var rowGroupStats = new Dictionary<string, LogicalStatistics>
{
{ "x", new LogicalStatistics<T>(statsRange.min, statsRange.max) }
};

var filterValueInRange = true;
try
{
var longMin = checkedCast(statsRange.min);
if (filterValue < longMin)
{
filterValueInRange = false;
}
}
catch (OverflowException)
{
filterValueInRange = false;
}

try
{
var longMax = checkedCast(statsRange.max);
if (filterValue > longMax)
{
filterValueInRange = false;
}
}
catch (OverflowException)
{
}

var includeRowGroup = filter.IncludeRowGroup(rowGroupStats);
Assert.That(
includeRowGroup, Is.EqualTo(filterValueInRange),
$"Expected {typeof(T)} stats range [{statsRange.min}, {statsRange.max}] inclusion to be {filterValueInRange}");
}
}

private static void TestIntRangeIncludeRowGroup<T>(long rangeStart, long rangeEnd, T[] values, Func<T, long> checkedCast)
where T : IComparable<T>
{
var filter = Col.Named("x").IsInRange(rangeStart, rangeEnd);

var statsRanges = values
.SelectMany(min => values.Select(max => (min, max)))
.Where(range => range.max.CompareTo(range.min) >= 0)
.ToArray();
foreach (var statsRange in statsRanges)
{
var rowGroupStats = new Dictionary<string, LogicalStatistics>
{
{ "x", new LogicalStatistics<T>(statsRange.min, statsRange.max) }
};

var rangesOverlap = true;
try
{
var longMin = checkedCast(statsRange.min);
if (longMin > rangeEnd)
{
rangesOverlap = false;
}
}
catch (OverflowException)
{
rangesOverlap = false;
}

try
{
var longMax = checkedCast(statsRange.max);
if (longMax < rangeStart)
{
rangesOverlap = false;
}
}
catch (OverflowException)
{
}

var includeRowGroup = filter.IncludeRowGroup(rowGroupStats);
Assert.That(
includeRowGroup, Is.EqualTo(rangesOverlap),
$"Expected {typeof(T)} stats range [{statsRange.min}, {statsRange.max}] inclusion to be {rangesOverlap}");
}
}

private static TArray BuildArray<T, TArray, TBuilder>(T[] values)
where T : struct
where TArray : IArrowArray
Expand Down
Loading