From 73318f39273df3080297eb402f77e66badf0c9f2 Mon Sep 17 00:00:00 2001 From: Martin Junghanns Date: Fri, 17 Nov 2023 12:53:43 +0100 Subject: [PATCH] Derive node count from node values if present --- crates/builder/src/graph/csr.rs | 55 ++++++++++++++++++++++++++++++--- 1 file changed, 50 insertions(+), 5 deletions(-) diff --git a/crates/builder/src/graph/csr.rs b/crates/builder/src/graph/csr.rs index 5b37270..346cb25 100644 --- a/crates/builder/src/graph/csr.rs +++ b/crates/builder/src/graph/csr.rs @@ -543,13 +543,14 @@ where { fn from((node_values, edge_list, csr_option): (NodeValues, E, CsrLayout)) -> Self { info!("Creating directed graph"); - let node_count = edge_list.max_node_id() + NI::new(1); + let node_count = NI::new(node_values.0.len()); + let node_count_from_edge_list = edge_list.max_node_id() + NI::new(1); assert!( - node_values.0.len() >= node_count.index(), + node_count >= node_count_from_edge_list, "number of node values ({}) does not match node count of edge list ({})", - node_values.0.len(), - node_count.index() + node_count.index(), + node_count_from_edge_list.index() ); let start = Instant::now(); @@ -743,7 +744,15 @@ where { fn from((node_values, edge_list, csr_option): (NodeValues, E, CsrLayout)) -> Self { info!("Creating undirected graph"); - let node_count = edge_list.max_node_id() + NI::new(1); + let node_count = NI::new(node_values.0.len()); + let node_count_from_edge_list = edge_list.max_node_id() + NI::new(1); + + assert!( + node_count >= node_count_from_edge_list, + "number of node values ({}) does not match node count of edge list ({})", + node_count.index(), + node_count_from_edge_list.index() + ); let start = Instant::now(); let csr = Csr::from((&edge_list, node_count, Direction::Undirected, csr_option)); @@ -1193,4 +1202,40 @@ mod tests { assert_eq!(ug.neighbors(0).as_slice(), &[1, 3, 7, 21, 42]); }); } + + #[test] + fn directed_from_node_values_exceeding_edge_list_max_id() { + let g0: DirectedCsrGraph = GraphBuilder::new() + .edges(vec![(0, 1), (1, 2)]) + .node_values(vec![0, 1, 2, 3]) + .build(); + + assert_eq!(g0.node_count(), 4); + for node in 0..4 { + assert_eq!(g0.node_value(node), &node); + } + + assert_eq!(g0.out_degree(0), 1); + assert_eq!(g0.out_degree(1), 1); + assert_eq!(g0.out_degree(2), 0); + assert_eq!(g0.out_degree(3), 0); + } + + #[test] + fn undirected_from_node_values_exceeding_edge_list_max_id() { + let g0: UndirectedCsrGraph = GraphBuilder::new() + .edges(vec![(0, 1), (1, 2)]) + .node_values(vec![0, 1, 2, 3]) + .build(); + + assert_eq!(g0.node_count(), 4); + for node in 0..4 { + assert_eq!(g0.node_value(node), &node); + } + + assert_eq!(g0.degree(0), 1); + assert_eq!(g0.degree(1), 2); + assert_eq!(g0.degree(2), 1); + assert_eq!(g0.degree(3), 0); + } }