def _check_per_origin_dest_counts(self):
    """Cache per-origin TargetDestinationCount values from the input origins, if present.

    If the input Origins table has no TargetDestinationCount field, self.df_dest_count
    is left as None and the tool-level number of destinations applies to every origin.
    Otherwise, store the per-origin values in a dataframe indexed by the origins'
    ObjectID, because the OriginOID field in the output OD lines corresponds to it.
    """
    desc = arcpy.Describe(self.origins)
    if "targetdestinationcount" not in [f.name.lower() for f in desc.fields]:
        # No per-origin overrides; no additional processing needed.
        return

    # Build the OriginOID -> TargetDestinationCount lookup dataframe.
    fields = [desc.oidFieldName, "TargetDestinationCount"]
    columns = ["OriginOID", "TargetDestinationCount"]
    with arcpy.da.SearchCursor(self.origins, fields) as cur:  # pylint: disable=no-member
        self.df_dest_count = pd.DataFrame(cur, columns=columns)
    # Use the default number of destinations to find for any nulls.
    # NOTE: assign the filled column back instead of the original chained
    # fillna(..., inplace=True), which is deprecated in pandas 2.x and does not
    # reliably modify the parent dataframe under copy-on-write.
    if self.num_destinations:
        self.df_dest_count["TargetDestinationCount"] = (
            self.df_dest_count["TargetDestinationCount"].fillna(self.num_destinations)
        )
    self.df_dest_count.set_index("OriginOID", inplace=True)
inplace=True) + def _update_df_for_k_nearest_and_destination_rank(self, df): """Drop all but the k nearest records for each Origin from the dataframe and calculate DestinationRank.""" # Sort according to OriginOID and cost field df.sort_values(["OriginOID", self.optimized_cost_field], inplace=True) + # Keep only the first k records for each OriginOID - if self.num_destinations: + if self.df_dest_count is not None: + # Preserve the first k destinations for each origin using the per-origin TargetDestination count value + # preserved in the self.df_dest_count dataframe + def drop_rows(group): + """Drop rows in group according to the number specified in TargetDestinationCount.""" + dest_count = self.df_dest_count.loc[group.iloc[0]["OriginOID"]]["TargetDestinationCount"] + if not pd.isna(dest_count): + return group.head(int(dest_count)) + else: + return group + df = df.groupby("OriginOID").apply(lambda g: drop_rows(g)).reset_index(drop=True) + elif self.num_destinations: + # Keep only the first k records for each OriginOID df = df.groupby("OriginOID").head(self.num_destinations).reset_index(drop=True) # Properly calculate the DestinationRank field df["DestinationRank"] = df.groupby("OriginOID").cumcount() + 1 diff --git a/unittests/input_data_helper.py b/unittests/input_data_helper.py index 1ddab0a..055e3fb 100644 --- a/unittests/input_data_helper.py +++ b/unittests/input_data_helper.py @@ -86,6 +86,31 @@ def get_tract_centroids_with_cutoff(sf_gdb): return new_fc +def get_stores_with_dest_count(sf_gdb): + """Create the Stores_DestCount feature class in the SanFrancisco.gdb/Analysis for use in unit tests.""" + new_fc = os.path.join(sf_gdb, "Analysis", "Stores_DestCount") + if arcpy.Exists(new_fc): + # The feature class exists already, so there's no need to do anything. 
def get_stores_with_dest_count(sf_gdb):
    """Create the Stores_DestCount feature class in SanFrancisco.gdb/Analysis for use in unit tests.

    Copies the tutorial dataset's Stores feature class and adds a TargetDestinationCount
    field populated for Store_1 (3) and Store_2 (2), leaving all other rows null.

    Returns:
        str: Catalog path to the Stores_DestCount feature class.

    Raises:
        ValueError: If the source Stores feature class is missing.
    """
    new_fc = os.path.join(sf_gdb, "Analysis", "Stores_DestCount")
    if arcpy.Exists(new_fc):
        # The feature class exists already, so there's no need to do anything.
        return new_fc
    # Copy the tutorial dataset's Stores feature class to the new feature class
    print(f"Creating {new_fc} for test input...")
    orig_fc = os.path.join(sf_gdb, "Analysis", "Stores")
    if not arcpy.Exists(orig_fc):
        raise ValueError(f"{orig_fc} is missing.")
    arcpy.management.Copy(orig_fc, new_fc)
    # Add and populate the TargetDestinationCount field
    # (previous comment said "Cutoff field" — a copy-paste from get_tract_centroids_with_cutoff)
    arcpy.management.AddField(new_fc, "TargetDestinationCount", "LONG")
    with arcpy.da.UpdateCursor(new_fc, ["NAME", "TargetDestinationCount"]) as cur:  # pylint: disable=no-member
        # Give Store_1 a TargetDestinationCount of 3 and Store_2 a TargetDestinationCount of 2
        # and leave the rest as null
        for row in cur:
            if row[0] == "Store_1":
                cur.updateRow([row[0], 3])
            elif row[0] == "Store_2":  # elif: a row cannot match both names
                cur.updateRow([row[0], 2])
    return new_fc
def test_run_tool_per_origin_dest_count(self):
    """Test that the tool correctly uses the TargetDestinationCount field in the input origins layer."""
    # Run tool
    origins = input_data_helper.get_stores_with_dest_count(self.sf_gdb)
    out_od_lines = os.path.join(self.output_gdb, "PerOriginDestCount_ODLines")
    out_origins = os.path.join(self.output_gdb, "PerOriginDestCount_Origins")
    out_destinations = os.path.join(self.output_gdb, "PerOriginDestCount_Destinations")
    arcpy.LargeNetworkAnalysisTools.SolveLargeODCostMatrix(  # pylint: disable=no-member
        origins,
        self.destinations,
        self.local_nd,
        self.local_tm_time,
        "Minutes",
        "Miles",
        10,  # chunk size
        4,  # max processes
        out_origins,
        out_destinations,
        "Feature class",
        out_od_lines,
        None,
        None,  # cutoff
        1,  # number of destinations - overridden for Store_1 and Store_2
        None,  # time of day
        None,  # barriers
        True,  # precalculate network locations
        True  # Spatially sort inputs
    )
    # Check results
    self.assertTrue(arcpy.Exists(out_od_lines))
    self.assertTrue(arcpy.Exists(out_origins))
    self.assertTrue(arcpy.Exists(out_destinations))
    self.assertEqual(28, int(arcpy.management.GetCount(out_od_lines).getOutput(0)), "Incorrect number of OD lines")

    def check_store(store_name, expected_count):
        """Verify destination count, DestinationRank sequence, and ascending Total_Time for one origin."""
        num_rows = 0
        prev_time = 0
        where = f"OriginName = '{store_name}'"
        for row in arcpy.da.SearchCursor(out_od_lines, ["Total_Time", "DestinationRank"], where):
            num_rows += 1
            self.assertEqual(num_rows, row[1], f"Incorrect DestinationRank value for {store_name}")
            self.assertGreater(row[0], prev_time, f"Total_Time value for {store_name} isn't increasing")
            prev_time = row[0]
        self.assertEqual(
            expected_count, num_rows, f"Incorrect number of destinations found for {store_name}")

    # Store_1 should have 3 destinations; Store_2 should have 2 (their per-origin overrides).
    # The duplicated per-store verification loops from the original are deduplicated above.
    check_store("Store_1", 3)
    check_store("Store_2", 2)
number of destinations found for Store_2") + def test_error_required_output_od_lines(self): """Test for correct error when output format is Feature class and output OD Lines not specified.""" with self.assertRaises(arcpy.ExecuteError) as ex: