|
| 1 | +from typing import List |
| 2 | +import unittest |
| 3 | + |
| 4 | +class Solution: |
| 5 | + def findMedianSortedArrays(self, nums1: List[int], nums2: List[int]) -> float: |
| 6 | + l = len(nums1) + len(nums2) |
| 7 | + if l % 2 == 1: |
| 8 | + return self.getKth(nums1, nums2, l // 2) |
| 9 | + else: |
| 10 | + return (self.getKth(nums1, nums2, l // 2) + self.getKth(nums1, nums2, l // 2 - 1)) / 2 |
| 11 | + |
| 12 | + def getKth(self, nums1: List[int], nums2: List[int], k: int) -> float: |
| 13 | + if not nums1: |
| 14 | + return nums2[k] |
| 15 | + if not nums2: |
| 16 | + return nums1[k] |
| 17 | + mid1, mid2 = len(nums1) // 2, len(nums2) // 2 |
| 18 | + |
| 19 | + if mid1 + mid2 < k: |
| 20 | + # if a's median is bigger than b's, b's first half doesn't include k |
| 21 | + if nums1[mid1] > nums2[mid2]: |
| 22 | + return self.getKth(nums1, nums2[mid2+1:], k - mid2 - 1) |
| 23 | + else: |
| 24 | + return self.getKth(nums1[mid1+1:], nums2, k - mid1 - 1) |
| 25 | + else: |
| 26 | + if nums1[mid1] > nums2[mid2]: |
| 27 | + return self.getKth(nums1[:mid1], nums2, k) |
| 28 | + else: |
| 29 | + return self.getKth(nums1, nums2[:mid2], k) |
| 30 | + |
| 31 | +class TestSolution(unittest.TestCase): |
| 32 | + def testFindMedianSortedArrays(self): |
| 33 | + sol = Solution() |
| 34 | + self.assertEqual(sol.findMedianSortedArrays([1,2], [3]), 2.0) |
| 35 | + self.assertEqual(sol.findMedianSortedArrays([1,2], [3,4]), 2.5) |
| 36 | + |
| 37 | +if __name__ == '__main__': |
| 38 | + unittest.main() |
0 commit comments