Skip to content

Commit 7dbb175

Browse files
authored
Derive node count from node values if present (#117)
1 parent 77e74bf commit 7dbb175

File tree

1 file changed

+50
-5
lines changed
  • crates/builder/src/graph

1 file changed

+50
-5
lines changed

crates/builder/src/graph/csr.rs

Lines changed: 50 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -543,13 +543,14 @@ where
543543
{
544544
fn from((node_values, edge_list, csr_option): (NodeValues<NV>, E, CsrLayout)) -> Self {
545545
info!("Creating directed graph");
546-
let node_count = edge_list.max_node_id() + NI::new(1);
546+
let node_count = NI::new(node_values.0.len());
547+
let node_count_from_edge_list = edge_list.max_node_id() + NI::new(1);
547548

548549
assert!(
549-
node_values.0.len() >= node_count.index(),
550+
node_count >= node_count_from_edge_list,
550551
"number of node values ({}) does not match node count of edge list ({})",
551-
node_values.0.len(),
552-
node_count.index()
552+
node_count.index(),
553+
node_count_from_edge_list.index()
553554
);
554555

555556
let start = Instant::now();
@@ -743,7 +744,15 @@ where
743744
{
744745
fn from((node_values, edge_list, csr_option): (NodeValues<NV>, E, CsrLayout)) -> Self {
745746
info!("Creating undirected graph");
746-
let node_count = edge_list.max_node_id() + NI::new(1);
747+
let node_count = NI::new(node_values.0.len());
748+
let node_count_from_edge_list = edge_list.max_node_id() + NI::new(1);
749+
750+
assert!(
751+
node_count >= node_count_from_edge_list,
752+
"number of node values ({}) does not match node count of edge list ({})",
753+
node_count.index(),
754+
node_count_from_edge_list.index()
755+
);
747756

748757
let start = Instant::now();
749758
let csr = Csr::from((&edge_list, node_count, Direction::Undirected, csr_option));
@@ -1193,4 +1202,40 @@ mod tests {
11931202
assert_eq!(ug.neighbors(0).as_slice(), &[1, 3, 7, 21, 42]);
11941203
});
11951204
}
1205+
1206+
#[test]
1207+
fn directed_from_node_values_exceeding_edge_list_max_id() {
1208+
let g0: DirectedCsrGraph<u32, u32> = GraphBuilder::new()
1209+
.edges(vec![(0, 1), (1, 2)])
1210+
.node_values(vec![0, 1, 2, 3])
1211+
.build();
1212+
1213+
assert_eq!(g0.node_count(), 4);
1214+
for node in 0..4 {
1215+
assert_eq!(g0.node_value(node), &node);
1216+
}
1217+
1218+
assert_eq!(g0.out_degree(0), 1);
1219+
assert_eq!(g0.out_degree(1), 1);
1220+
assert_eq!(g0.out_degree(2), 0);
1221+
assert_eq!(g0.out_degree(3), 0);
1222+
}
1223+
1224+
#[test]
1225+
fn undirected_from_node_values_exceeding_edge_list_max_id() {
1226+
let g0: UndirectedCsrGraph<u32, u32> = GraphBuilder::new()
1227+
.edges(vec![(0, 1), (1, 2)])
1228+
.node_values(vec![0, 1, 2, 3])
1229+
.build();
1230+
1231+
assert_eq!(g0.node_count(), 4);
1232+
for node in 0..4 {
1233+
assert_eq!(g0.node_value(node), &node);
1234+
}
1235+
1236+
assert_eq!(g0.degree(0), 1);
1237+
assert_eq!(g0.degree(1), 2);
1238+
assert_eq!(g0.degree(2), 1);
1239+
assert_eq!(g0.degree(3), 0);
1240+
}
11961241
}

0 commit comments

Comments
 (0)