diff --git a/src/main/java/org/apache/commons/lang3/ArrayUtils.java b/src/main/java/org/apache/commons/lang3/ArrayUtils.java index 56b5df638c2..6f2a34d01f2 100644 --- a/src/main/java/org/apache/commons/lang3/ArrayUtils.java +++ b/src/main/java/org/apache/commons/lang3/ArrayUtils.java @@ -1433,7 +1433,8 @@ public static T arraycopy(final T source, final int sourcePos, final T dest, } /** - * Searches element in array sorted by key. + * Searches element in array sorted by key. If there are multiple elements matching, it returns first occurrence. + * If the array is not sorted, the result is undefined. * * @param array * array sorted by key field @@ -1445,25 +1446,26 @@ public static T arraycopy(final T source, final int sourcePos, final T dest, * comparator for keys * * @return - * index of the search key, if it is contained in the array; otherwise, (-first_greater - 1). - * The first_greater is the index of lowest greater element in the list - if all elements are lower, the - * first_greater is defined as array.length. + * index of the first occurrence of search key, if it is contained in the array; otherwise, + * (-first_greater - 1). The first_greater is the index of lowest greater element in the list - if all elements + * are lower, the first_greater is defined as array.length. * * @param * type of array element * @param * type of key */ - public static int binarySearch( + public static int binarySearchFirst( T[] array, K key, Function keyExtractor, Comparator comparator ) { - return binarySearch0(array, 0, array.length, key, keyExtractor, comparator); + return binarySearchFirst0(array, 0, array.length, key, keyExtractor, comparator); } /** - * Searches element in array sorted by key, within range fromIndex (inclusive) - toIndex (exclusive). + * Searches element in array sorted by key, within range fromIndex (inclusive) - toIndex (exclusive). If there are + * multiple elements matching, it returns first occurrence. If the array is not sorted, the result is undefined. * * @param array * array sorted by key field @@ -1479,9 +1481,9 @@ public static int binarySearch( * comparator for keys * * @return - * index of the search key, if it is contained in the array within specified range; otherwise, - * (-first_greater - 1). The first_greater is the index of lowest greater element in the list - if all elements - * are lower, the first_greater is defined as toIndex. + * index of the first occurrence of search key, if it is contained in the array within specified range; + * otherwise, (-first_greater - 1). The first_greater is the index of lowest greater element in the list - if + * all elements are lower, the first_greater is defined as toIndex. * * @throws ArrayIndexOutOfBoundsException * when fromIndex or toIndex is out of array range @@ -1493,28 +1495,124 @@ public static int binarySearch( * @param * type of key */ - public static int binarySearch( + public static int binarySearchFirst( T[] array, int fromIndex, int toIndex, K key, Function keyExtractor, Comparator comparator ) { - if (fromIndex > toIndex) { - throw new IllegalArgumentException( - "fromIndex(" + fromIndex + ") > toIndex(" + toIndex + ")"); - } - if (fromIndex < 0) { - throw new ArrayIndexOutOfBoundsException(fromIndex); - } - if (toIndex > array.length) { - throw new ArrayIndexOutOfBoundsException(toIndex); + checkRange(array.length, fromIndex, toIndex); + + return binarySearchFirst0(array, fromIndex, toIndex, key, keyExtractor, comparator); + } + + // common implementation for binarySearch methods, with same semantics: + private static int binarySearchFirst0( + T[] array, + int fromIndex, int toIndex, + K key, + Function keyExtractor, Comparator comparator + ) { + int l = fromIndex; + int h = toIndex - 1; + + while (l <= h) { + final int m = (l + h) >>> 1; // unsigned shift to avoid overflow + final K value = keyExtractor.apply(array[m]); + final int c = comparator.compare(value, key); + if (c < 0) { + l = m + 1; + } else if (c > 0) { + h = m - 1; + } else if (l < h) { + // possibly multiple matching items remaining: + h = m; + } else { + // single matching item remaining: + return m; + } } - return binarySearch0(array, fromIndex, toIndex, key, keyExtractor, comparator); + // not found, the l points to the lowest higher match: + return -l - 1; + } + + /** + * Searches element in array sorted by key. If there are multiple elements matching, it returns last occurrence. + * If the array is not sorted, the result is undefined. + * + * @param array + * array sorted by key field + * @param key + * key to search for + * @param keyExtractor + * function to extract key from element + * @param comparator + * comparator for keys + * + * @return + * index of the last occurrence of search key, if it is contained in the array; otherwise, + * (-first_greater - 1). The first_greater is the index of lowest greater element in the list - if all elements + * are lower, the first_greater is defined as array.length. + * + * @param + * type of array element + * @param + * type of key + */ + public static int binarySearchLast( + T[] array, + K key, + Function keyExtractor, Comparator comparator + ) { + return binarySearchLast0(array, 0, array.length, key, keyExtractor, comparator); + } + + /** + * Searches element in array sorted by key, within range fromIndex (inclusive) - toIndex (exclusive). If there are + * multiple elements matching, it returns last occurrence. If the array is not sorted, the result is undefined. + * + * @param array + * array sorted by key field + * @param fromIndex + * start index (inclusive) + * @param toIndex + * end index (exclusive) + * @param key + * key to search for + * @param keyExtractor + * function to extract key from element + * @param comparator + * comparator for keys + * + * @return + * index of the last occurrence of search key, if it is contained in the array within specified range; + * otherwise, (-first_greater - 1). The first_greater is the index of lowest greater element in the list - if + * all elements are lower, the first_greater is defined as toIndex. + * + * @throws ArrayIndexOutOfBoundsException + * when fromIndex or toIndex is out of array range + * @throws IllegalArgumentException + * when fromIndex is greater than toIndex + * + * @param + * type of array element + * @param + * type of key + */ + public static int binarySearchLast( + T[] array, + int fromIndex, int toIndex, + K key, + Function keyExtractor, Comparator comparator + ) { + checkRange(array.length, fromIndex, toIndex); + + return binarySearchLast0(array, fromIndex, toIndex, key, keyExtractor, comparator); } // common implementation for binarySearch methods, with same semantics: - private static int binarySearch0( + private static int binarySearchLast0( T[] array, int fromIndex, int toIndex, K key, @@ -1531,8 +1629,16 @@ private static int binarySearch0( l = m + 1; } else if (c > 0) { h = m - 1; + } else if (m + 1 < h) { + // matching, more than two items remaining: + l = m; + } else if (m + 1 == h) { + // two items remaining, next loops would result in unchanged l and h, we have to choose m or h: + final K valueH = keyExtractor.apply(array[h]); + final int cH = comparator.compare(valueH, key); + return cH == 0 ? h : m; } else { - // 0, found + // one item remaining, single match: return m; } } @@ -9573,4 +9679,18 @@ public static String[] toStringArray(final Object[] array, final String valueFor public ArrayUtils() { // empty } + + static void checkRange(int length, int fromIndex, int toIndex) { + if (fromIndex > toIndex) { + throw new IllegalArgumentException( + "fromIndex(" + fromIndex + ") > toIndex(" + toIndex + ")"); + } + if (fromIndex < 0) { + throw new ArrayIndexOutOfBoundsException(fromIndex); + } + if (toIndex > length) { + throw new ArrayIndexOutOfBoundsException(toIndex); + } + + } } diff --git a/src/test/java/org/apache/commons/lang3/ArrayUtilsBinarySearchTest.java b/src/test/java/org/apache/commons/lang3/ArrayUtilsBinarySearchTest.java index f04dba91721..62219ae829c 100644 --- a/src/test/java/org/apache/commons/lang3/ArrayUtilsBinarySearchTest.java +++ b/src/test/java/org/apache/commons/lang3/ArrayUtilsBinarySearchTest.java @@ -30,63 +30,99 @@ public class ArrayUtilsBinarySearchTest extends AbstractLangTest { @Test - public void binarySearch_whenLowHigherThanEnd_throw() { + public void binarySearchFirst_whenLowHigherThanEnd_throw() { final Data[] list = createList(0, 1); - assertThrowsExactly(IllegalArgumentException.class, () -> ArrayUtils.binarySearch(list, 1, 0, 0, Data::getValue, Integer::compare)); + assertThrowsExactly(IllegalArgumentException.class, () -> + ArrayUtils.binarySearchFirst(list, 1, 0, 0, Data::getValue, Integer::compare)); } @Test - public void binarySearch_whenLowNegative_throw() { + public void binarySearchFirst_whenLowNegative_throw() { final Data[] list = createList(0, 1); - assertThrowsExactly(ArrayIndexOutOfBoundsException.class, () -> ArrayUtils.binarySearch(list, -1, 0, 0, Data::getValue, Integer::compare)); + assertThrowsExactly(ArrayIndexOutOfBoundsException.class, () -> + ArrayUtils.binarySearchFirst(list, -1, 0, 0, Data::getValue, Integer::compare)); } @Test - public void binarySearch_whenEndBeyondLength_throw() { + public void binarySearchFirst_whenEndBeyondLength_throw() { final Data[] list = createList(0, 1); - assertThrowsExactly(ArrayIndexOutOfBoundsException.class, () -> ArrayUtils.binarySearch(list, 0, 3, 0, Data::getValue, Integer::compare)); + assertThrowsExactly(ArrayIndexOutOfBoundsException.class, () -> + ArrayUtils.binarySearchFirst(list, 0, 3, 0, Data::getValue, Integer::compare)); } @Test - public void binarySearch_whenEmpty_returnM1() { + public void binarySearchLast_whenLowHigherThanEnd_throw() { + final Data[] list = createList(0, 1); + assertThrowsExactly(IllegalArgumentException.class, () -> + ArrayUtils.binarySearchLast(list, 1, 0, 0, Data::getValue, Integer::compare)); + } + + @Test + public void binarySearchFirst_whenEmpty_returnM1() { final Data[] list = createList(); - final int found = ArrayUtils.binarySearch(list, 0, Data::getValue, Integer::compare); + final int found = ArrayUtils.binarySearchFirst(list, 0, Data::getValue, Integer::compare); assertEquals(-1, found); } @Test - public void binarySearch_whenExists_returnIndex() { + public void binarySearchFirst_whenExists_returnIndex() { final Data[] list = createList(0, 1, 2, 4, 7, 9, 12, 15, 17, 19, 25); - final int found = ArrayUtils.binarySearch(list, 9, Data::getValue, Integer::compare); + final int found = ArrayUtils.binarySearchFirst(list, 9, Data::getValue, Integer::compare); assertEquals(5, found); } @Test - public void binarySearch_whenNotExistsMiddle_returnMinusInsertion() { + @Timeout(10) + public void binarySearchFirst_whenMultiple_returnFirst() { + final Data[] list = createList(3, 4, 6, 6, 6, 7, 7, 8, 8, 9, 9, 9); + for (int i = 0; i < list.length; ++i) { + if (i > 0 && list[i].value == list[i - 1].value) { + continue; + } + final int found = ArrayUtils.binarySearchFirst(list, list[i].value, Data::getValue, Integer::compare); + assertEquals(i, found); + } + } + + @Test + @Timeout(10) + public void binarySearchLast_whenMultiple_returnFirst() { + final Data[] list = createList(3, 4, 6, 6, 6, 7, 7, 8, 8, 9, 9, 9); + for (int i = 0; i < list.length; ++i) { + if (i < list.length - 1 && list[i].value == list[i + 1].value) { + continue; + } + final int found = ArrayUtils.binarySearchLast(list, list[i].value, Data::getValue, Integer::compare); + assertEquals(i, found); + } + } + + @Test + public void binarySearchFirst_whenNotExistsMiddle_returnMinusInsertion() { final Data[] list = createList(0, 1, 2, 4, 7, 9, 12, 15, 17, 19, 25); - final int found = ArrayUtils.binarySearch(list, 8, Data::getValue, Integer::compare); + final int found = ArrayUtils.binarySearchFirst(list, 8, Data::getValue, Integer::compare); assertEquals(-6, found); } @Test - public void binarySearch_whenNotExistsBeginning_returnMinus1() { + public void binarySearchFirst_whenNotExistsBeginning_returnMinus1() { final Data[] list = createList(0, 1, 2, 4, 7, 9, 12, 15, 17, 19, 25); - final int found = ArrayUtils.binarySearch(list, -3, Data::getValue, Integer::compare); + final int found = ArrayUtils.binarySearchFirst(list, -3, Data::getValue, Integer::compare); assertEquals(-1, found); } @Test - public void binarySearch_whenNotExistsEnd_returnMinusLength() { + public void binarySearchFirst_whenNotExistsEnd_returnMinusLength() { final Data[] list = createList(0, 1, 2, 4, 7, 9, 12, 15, 17, 19, 25); - final int found = ArrayUtils.binarySearch(list, 29, Data::getValue, Integer::compare); + final int found = ArrayUtils.binarySearchFirst(list, 29, Data::getValue, Integer::compare); assertEquals(-(list.length + 1), found); } @Test @Timeout(10) - public void binarySearch_whenUnsorted_dontInfiniteLoop() { + public void binarySearchFirst_whenUnsorted_dontInfiniteLoop() { final Data[] list = createList(7, 1, 4, 9, 11, 8); - final int found = ArrayUtils.binarySearch(list, 10, Data::getValue, Integer::compare); + final int found = ArrayUtils.binarySearchFirst(list, 10, Data::getValue, Integer::compare); } private Data[] createList(int... values) {