Skip to content

Initial commit - separate retriever #830

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 41 additions & 17 deletions gtsfm/runner/gtsfm_runner_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,12 @@ def construct_argparser(self) -> argparse.ArgumentParser:
default=3,
help="Number of times to retry cluster connection if it fails.",
)
parser.add_argument(
"--num_clusters",
type=int,
default=2,
help="Number of images clusters to create for parallelizing reconstruction.",
)
return parser

@abstractmethod
Expand Down Expand Up @@ -271,12 +277,33 @@ def setup_ssh_cluster_with_retries(self) -> SSHCluster:
" attempts. Aborting..."
)
return cluster

def get_image_pair_indices(self, client: "Client"):
    """Generate the image pairs to process and collect retriever metrics.

    Image pair generation runs inside a dask performance report that is written to
    "retriever-dask-report.html". The wall-clock time measured here covers both pair
    generation and the retriever's `evaluate()` call.

    Args:
        client: Dask client, forwarded to the loader (to fetch images as futures) and to
            the image pairs generator.

    Returns:
        A 2-tuple `(retriever_metrics, image_pair_indices)`:
        - retriever_metrics: the object returned by the retriever's `evaluate()`, with an
          additional "retriever_duration_sec" metric added to it.
        - image_pair_indices: the value returned by `generate_image_pairs()`
          (presumably a list of (i1, i2) index pairs -- TODO confirm against the generator).
    """
    retriever_start_time = time.time()
    image_pairs_generator = self.scene_optimizer.image_pairs_generator

    with performance_report(filename="retriever-dask-report.html"):
        image_pair_indices = image_pairs_generator.generate_image_pairs(
            client=client,
            images=self.loader.get_all_images_as_futures(client),
            image_fnames=self.loader.image_filenames(),
            plots_output_dir=self.scene_optimizer._plot_base_path,
        )

    # NOTE(review): this reaches into the generator's private `_retriever` (and the scene
    # optimizer's private `_plot_base_path` above); consider exposing public accessors.
    retriever_metrics = image_pairs_generator._retriever.evaluate(len(self.loader), image_pair_indices)

    retriever_duration_sec = time.time() - retriever_start_time
    retriever_metrics.add_metric(GtsfmMetric("retriever_duration_sec", retriever_duration_sec))
    logger.info("Image pair retrieval took %.2f sec.", retriever_duration_sec)
    # Report through the module logger (not print) so the count follows the configured log
    # handlers, like the duration message above.
    logger.info("Total number of image pairs: %d", len(image_pair_indices))

    return retriever_metrics, image_pair_indices


def run(self) -> GtsfmData:
"""Run the SceneOptimizer."""
start_time = time.time()

# Create dask cluster.
if self.parsed_args.cluster_config:
cluster = self.setup_ssh_cluster_with_retries()
client = Client(cluster)
Expand All @@ -294,29 +321,26 @@ def run(self) -> GtsfmData:
local_cluster_kwargs["memory_limit"] = self.parsed_args.worker_memory_limit
cluster = LocalCluster(**local_cluster_kwargs)
client = Client(cluster)

retriever_metrics, image_pair_indices = self.get_image_pair_indices(client)

clusters = [ [tuple(pair) for pair in arr.tolist()] for arr in np.array_split(image_pair_indices, self.parsed_args.num_clusters)]

for graph_cluster in clusters:
self.optimize_scene(client, retriever_metrics, graph_cluster)

def optimize_scene(self, client, retriever_metrics,image_pair_indices)->GtsfmData:
print("Running scene optimizer with number of image pairs:", len(image_pair_indices))
start_time = time.time()

# Create process graph.
process_graph_generator = ProcessGraphGenerator()
if isinstance(self.scene_optimizer.correspondence_generator, ImageCorrespondenceGenerator):
process_graph_generator.is_image_correspondence = True
process_graph_generator.save_graph()

retriever_start_time = time.time()
with performance_report(filename="retriever-dask-report.html"):
image_pair_indices = self.scene_optimizer.image_pairs_generator.generate_image_pairs(
client=client,
images=self.loader.get_all_images_as_futures(client),
image_fnames=self.loader.image_filenames(),
plots_output_dir=self.scene_optimizer._plot_base_path,
)

retriever_metrics = self.scene_optimizer.image_pairs_generator._retriever.evaluate(
len(self.loader), image_pair_indices
)
retriever_duration_sec = time.time() - retriever_start_time
retriever_metrics.add_metric(GtsfmMetric("retriever_duration_sec", retriever_duration_sec))
logger.info("Image pair retrieval took %.2f sec.", retriever_duration_sec)

#split image_pair_indices into two and call this function for them.

intrinsics = self.loader.get_all_intrinsics()

with performance_report(filename="correspondence-generator-dask-report.html"):
Expand Down
Loading