Skip to content

Commit

Permalink
Support TargetDestinationCount field in input Origins
Browse files Browse the repository at this point in the history
  • Loading branch information
mmorang committed Nov 20, 2024
1 parent 7eede79 commit a4f1851
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 2 deletions.
36 changes: 35 additions & 1 deletion parallel_odcm.py
Original file line number Diff line number Diff line change
Expand Up @@ -777,6 +777,7 @@ def __init__( # pylint: disable=too-many-locals, too-many-arguments
self.total_jobs = len(self.origin_ranges) * len(destination_ranges)

self.optimized_cost_field = None
self.df_dest_count = None

def _validate_od_settings(self):
"""Validate OD cost matrix settings before spinning up a bunch of parallel processes doomed to failure.
Expand Down Expand Up @@ -846,6 +847,7 @@ def solve_od_in_parallel(self):
# Post-process outputs
if self.od_line_files:
self.logger.info("Post-processing OD Cost Matrix results...")
self._check_per_origin_dest_counts()
self.od_line_files = sorted(self.od_line_files)
if self.output_format is helpers.OutputFormat.featureclass:
self._post_process_od_line_fcs()
Expand Down Expand Up @@ -1033,12 +1035,44 @@ def _post_process_od_line_arrow_files(self):
for arrow_file in files_for_origin_range:
os.remove(arrow_file)

def _check_per_origin_dest_counts(self):
    """Preserve per-origin TargetDestinationCount values from the input origins, if present.

    If the input Origins table has a TargetDestinationCount field (matched case-insensitively),
    read its values into the self.df_dest_count dataframe indexed by the origins' ObjectIDs so the
    OD Lines post-processing can apply a per-origin destination limit.  If the field is absent,
    self.df_dest_count remains None and no extra processing occurs.
    """
    # Check if the input Origins table has a per-origin TargetDestinationCount
    desc = arcpy.Describe(self.origins)
    if "targetdestinationcount" not in [f.name.lower() for f in desc.fields]:
        # No additional processing needed
        return

    # Create a dataframe to hold the per-origin TargetDestinationCount values
    # Use the OID field because the OriginOID field in the output OD lines corresponds to this
    fields = [desc.oidFieldName, "TargetDestinationCount"]
    columns = ["OriginOID", "TargetDestinationCount"]
    with arcpy.da.SearchCursor(self.origins, fields) as cur:  # pylint: disable=no-member
        self.df_dest_count = pd.DataFrame(cur, columns=columns)
    # Use the default number of destinations to find for any nulls.
    # Reassign the column instead of calling fillna(inplace=True) on the column selection:
    # chained-assignment inplace fillna is deprecated in pandas 2.x and does nothing under
    # copy-on-write semantics (the default in pandas 3).
    if self.num_destinations:
        self.df_dest_count["TargetDestinationCount"] = \
            self.df_dest_count["TargetDestinationCount"].fillna(self.num_destinations)
    self.df_dest_count.set_index("OriginOID", inplace=True)

def _update_df_for_k_nearest_and_destination_rank(self, df):
"""Drop all but the k nearest records for each Origin from the dataframe and calculate DestinationRank."""
# Sort according to OriginOID and cost field
df.sort_values(["OriginOID", self.optimized_cost_field], inplace=True)

# Keep only the first k records for each OriginOID
if self.num_destinations:
if self.df_dest_count is not None:
# Preserve the first k destinations for each origin using the per-origin TargetDestination count value
# preserved in the self.df_dest_count dataframe
def drop_rows(group):
"""Drop rows in group according to the number specified in TargetDestinationCount."""
dest_count = self.df_dest_count.loc[group.iloc[0]["OriginOID"]]["TargetDestinationCount"]
if not pd.isna(dest_count):
return group.head(int(dest_count))
else:
return group
df = df.groupby("OriginOID").apply(lambda g: drop_rows(g)).reset_index(drop=True)
elif self.num_destinations:
# Keep only the first k records for each OriginOID
df = df.groupby("OriginOID").head(self.num_destinations).reset_index(drop=True)
# Properly calculate the DestinationRank field
df["DestinationRank"] = df.groupby("OriginOID").cumcount() + 1
Expand Down
25 changes: 25 additions & 0 deletions unittests/input_data_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,31 @@ def get_tract_centroids_with_cutoff(sf_gdb):
return new_fc


def get_stores_with_dest_count(sf_gdb):
    """Create the Stores_DestCount feature class in SanFrancisco.gdb/Analysis for use in unit tests.

    Copies the tutorial dataset's Stores feature class and adds a long TargetDestinationCount
    field populated for two stores (Store_1 -> 3, Store_2 -> 2), leaving all other rows null.
    If the feature class already exists, it is reused as-is.

    Args:
        sf_gdb: Path to the SanFrancisco.gdb file geodatabase.

    Returns:
        Full catalog path to the Stores_DestCount feature class.

    Raises:
        ValueError: If the tutorial dataset's Stores feature class is missing.
    """
    new_fc = os.path.join(sf_gdb, "Analysis", "Stores_DestCount")
    if arcpy.Exists(new_fc):
        # The feature class exists already, so there's no need to do anything.
        return new_fc
    # Copy the tutorial dataset's Stores feature class to the new feature class
    print(f"Creating {new_fc} for test input...")
    orig_fc = os.path.join(sf_gdb, "Analysis", "Stores")
    if not arcpy.Exists(orig_fc):
        raise ValueError(f"{orig_fc} is missing.")
    arcpy.management.Copy(orig_fc, new_fc)
    # Add and populate the TargetDestinationCount field
    # (the previous comment said "Cutoff field", a copy-paste leftover from the cutoff helper)
    arcpy.management.AddField(new_fc, "TargetDestinationCount", "LONG")
    # Give Store_1 a TargetDestinationCount of 3 and Store_2 a TargetDestinationCount of 2
    # and leave the rest as null
    dest_counts = {"Store_1": 3, "Store_2": 2}
    with arcpy.da.UpdateCursor(new_fc, ["NAME", "TargetDestinationCount"]) as cur:  # pylint: disable=no-member
        for row in cur:
            if row[0] in dest_counts:
                cur.updateRow([row[0], dest_counts[row[0]]])
    return new_fc


def get_od_pair_csv(input_data_folder):
"""Create the od_pairs.csv input file in the input data folder for use in unit testing."""
od_pair_file = os.path.join(input_data_folder, "od_pairs.csv")
Expand Down
54 changes: 53 additions & 1 deletion unittests/test_SolveLargeODCostMatrix_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ def test_run_tool_per_origin_cutoff(self):
"Feature class",
out_od_lines,
None,
1, # cutoff - tiny cutoff that is overridden for one destination
1, # cutoff - tiny cutoff that is overridden for one origin
None, # number of destinations
None, # time of day
None, # barriers
Expand All @@ -237,6 +237,58 @@ def test_run_tool_per_origin_cutoff(self):
self.assertLessEqual(row[1], 1, "Travel time is out of bounds for origin with with default cutoff")
self.assertEqual(13, num_dests, "Incorrect number of destinations found for origin with its own cutoff.")

def test_run_tool_per_origin_dest_count(self):
    """Test that the tool correctly uses the TargetDestinationCount field in the input origins layer."""
    # Run tool
    origins = input_data_helper.get_stores_with_dest_count(self.sf_gdb)
    out_od_lines = os.path.join(self.output_gdb, "PerOriginDestCount_ODLines")
    out_origins = os.path.join(self.output_gdb, "PerOriginDestCount_Origins")
    out_destinations = os.path.join(self.output_gdb, "PerOriginDestCount_Destinations")
    arcpy.LargeNetworkAnalysisTools.SolveLargeODCostMatrix(  # pylint: disable=no-member
        origins,
        self.destinations,
        self.local_nd,
        self.local_tm_time,
        "Minutes",
        "Miles",
        10,  # chunk size
        4,  # max processes
        out_origins,
        out_destinations,
        "Feature class",
        out_od_lines,
        None,
        None,  # cutoff
        1,  # number of destinations - overridden for one origin
        None,  # time of day
        None,  # barriers
        True,  # precalculate network locations
        True  # Spatially sort inputs
    )
    # Check results
    self.assertTrue(arcpy.Exists(out_od_lines))
    self.assertTrue(arcpy.Exists(out_origins))
    self.assertTrue(arcpy.Exists(out_destinations))
    self.assertEqual(28, int(arcpy.management.GetCount(out_od_lines).getOutput(0)), "Incorrect number of OD lines")

    def check_store(store_name, expected_num_dests):
        """Verify the destination count, DestinationRank sequence, and increasing Total_Time for one store."""
        num_rows = 0
        prev_time = 0
        where = f"OriginName = '{store_name}'"
        for row in arcpy.da.SearchCursor(out_od_lines, ["Total_Time", "DestinationRank"], where):
            num_rows += 1
            self.assertEqual(num_rows, row[1], f"Incorrect DestinationRank value for {store_name}")
            self.assertGreater(row[0], prev_time, f"Total_Time value for {store_name} isn't increasing")
            prev_time = row[0]
        self.assertEqual(
            expected_num_dests, num_rows, f"Incorrect number of destinations found for {store_name}")

    # Check Store_1, which should have 3 destinations, and Store_2, which should have 2.
    # The per-store checks were previously copy-pasted; use one helper to keep them in sync.
    check_store("Store_1", 3)
    check_store("Store_2", 2)

def test_error_required_output_od_lines(self):
"""Test for correct error when output format is Feature class and output OD Lines not specified."""
with self.assertRaises(arcpy.ExecuteError) as ex:
Expand Down

0 comments on commit a4f1851

Please sign in to comment.