Skip to content

Commit 50b2d68

Browse files
committed
Final edits to figures and figure scripts
1 parent c9a619c commit 50b2d68

23 files changed

+372
-92
lines changed

.gitignore

+2-1
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ dev_files
99

1010
# These might be added later when the full repo is complete
1111
mile_marker.txt
12-
checker.ipynb
12+
checker.ipynb
13+
*.pkl

figure_and_table_generation/figure_scripts/box_share_VA.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,9 @@
168168
ax.legend(
169169
handles=handles,
170170
labels=[
171-
"RRC CD 12 (5B proposed)",
172-
"RRC CD 16 (5B proposed)",
173-
"RRC Rand Plan (5B proposed)",
171+
"RevReCom Seed 1 (5B proposed)",
172+
"RevReCom Seed 2 (5B proposed)",
173+
"RevReCom Seed 3 (5B proposed)",
174174
"Forest (10M proposed)",
175175
],
176176
loc="upper left",

figure_and_table_generation/figure_scripts/helper_files/wasserstein_trace_tally.py

+117
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,120 @@ def wasserstein_trace_v_full(
169169
xticks.append(step)
170170
trace.append(distance)
171171
return xticks, trace
172+
173+
174+
def wasserstein_trace_shares(shares1_df, shares2_df, weights1, weights2, resolution):
    """
    Computes the Wasserstein trace between two ensembles of sorted shares.

    Both ensembles are accumulated in lockstep, one row at a time. The shares
    in each row are sorted in ascending order, so position ``k`` of a row holds
    the k-th smallest share of that row, and a weighted histogram is kept for
    every position of each ensemble. Whenever the (0-based) step is a positive
    multiple of ``resolution``, the Wasserstein distance between the two
    histograms is computed at every position and the distances are summed to
    give one point of the trace.

    Parameters
    ----------
    shares1_df : pandas.DataFrame
        The dataframe of shares for the first ensemble. One row per step; the
        number of columns is used as the number of districts.
    shares2_df : pandas.DataFrame
        The dataframe of shares for the second ensemble. Must have the same
        shape and the same columns as ``shares1_df``.
    weights1 : array-like
        The weight of each row of ``shares1_df``.
    weights2 : array-like
        The weight of each row of ``shares2_df``.
    resolution : int
        The number of steps between two consecutive points of the trace.

    Returns
    -------
    (list, list):
        The xticks (step indices) for use in plotting and the trace of the
        summed Wasserstein distances recorded at those steps.

    Raises
    ------
    AssertionError
        If the two dataframes differ in shape or in their columns.

    Notes
    -----
    Iteration stops at the shortest of the two dataframes and the two weight
    sequences. The point recorded at step ``s`` includes rows ``0`` through
    ``s`` of both ensembles.
    """
    # Check the shapes first: comparing column indexes of different lengths
    # elementwise raises instead of evaluating to False.
    assert shares1_df.shape == shares2_df.shape
    assert all(shares1_df.columns == shares2_df.columns)

    # Sort every row once, up front. Because each row is ordered by value, the
    # order of the columns in the dataframes has no effect on the result.
    sorted1 = np.sort(shares1_df.to_numpy(), axis=1)
    sorted2 = np.sort(shares2_df.to_numpy(), axis=1)

    n_steps, n_districts = sorted1.shape

    xticks = []
    trace = []
    # One weighted histogram per sorted position for each ensemble
    hist1 = [Counter() for _ in range(n_districts)]
    hist2 = [Counter() for _ in range(n_districts)]

    for step, (s1, w1, s2, w2) in enumerate(
        tqdm(zip(sorted1, weights1, sorted2, weights2), total=n_steps)
    ):
        # Add the weight of this row to the histogram of each sorted position
        for hist, v in zip(hist1, s1):
            hist[v] += w1
        for hist, v in zip(hist2, s2):
            hist[v] += w2

        # Compute the Wasserstein trace at the specified resolution
        if step > 0 and step % resolution == 0:
            distance = sum(
                wasserstein_distance(
                    list(h1.keys()),
                    list(h2.keys()),
                    list(h1.values()),
                    list(h2.values()),
                )
                for h1, h2 in zip(hist1, hist2)
            )
            xticks.append(step)
            trace.append(distance)

    return xticks, trace

figure_and_table_generation/figure_scripts/histogram_comparison_50x50.py

+20-19
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@
2525
"#80b1d3",
2626
]
2727

28+
script_dir = Path(__file__).resolve().parent
29+
top_dir = script_dir.parents[1]
30+
2831

2932
def make_method_plot(rrc_forest, lower, upper, n_dists, methods="forest_rrc"):
3033
"""
@@ -50,7 +53,7 @@ def make_method_plot(rrc_forest, lower, upper, n_dists, methods="forest_rrc"):
5053
"""
5154

5255
methods = methods.replace(" ", "_")
53-
out_path = Path("../figures")
56+
out_path = Path(f"{script_dir}/../figures")
5457

5558
_, ax = plt.subplots(figsize=(25, 10), dpi=400)
5659

@@ -73,7 +76,7 @@ def make_method_plot(rrc_forest, lower, upper, n_dists, methods="forest_rrc"):
7376
ax.set_xticks(list(range(lower, upper, 50)))
7477
ax.set_xticklabels([str(i) for i in range(lower, upper, 50)], fontsize=16)
7578
ax.set_yticks([])
76-
ax.legend(loc="right", bbox_to_anchor=(1.18, 0.5), prop={"size": 16})
79+
ax.legend(loc="right", bbox_to_anchor=(1.21, 0.5), prop={"size": 16})
7780
plt.savefig(
7881
out_path.joinpath(f"50x50_{n_dists}_dist_{methods}_comparison.png"),
7982
bbox_inches="tight",
@@ -100,7 +103,7 @@ def make_recom_plot(lower, upper, n_dists, glob_expr):
100103
-------
101104
None
102105
"""
103-
out_path = Path("../figures")
106+
out_path = Path(f"{script_dir}/../figures")
104107
data_path = Path(f"{top_dir}/hpc_files/hpc_processed_data/50x50")
105108
_, ax = plt.subplots(figsize=(25, 10), dpi=500)
106109

@@ -114,8 +117,6 @@ def make_recom_plot(lower, upper, n_dists, glob_expr):
114117
lst = file.name.split("_")
115118
all_recom_files[f"{lst[1]} (1B proposed)"] = file
116119

117-
all_recom_files
118-
119120
for i, (n, f) in enumerate(all_recom_files.items()):
120121
df = pd.read_parquet(f)
121122
prob_df = df.groupby("cut_edges").sum().reset_index()
@@ -128,7 +129,7 @@ def make_recom_plot(lower, upper, n_dists, glob_expr):
128129
edgecolor=None,
129130
color=colors[i + 3],
130131
alpha=0.8,
131-
label=n.replace("Recom", "ReCom "),
132+
label=n.replace("ReCom", "ReCom-"),
132133
)
133134

134135
ax.set_xlim(lower - 20, upper + 20)
@@ -147,43 +148,43 @@ def make_recom_plot(lower, upper, n_dists, glob_expr):
147148
top_dir = script_dir.parents[1]
148149

149150
rrc_forest_10 = {
150-
"RRC (10B proposed)": f"{top_dir}/hpc_files/hpc_processed_data/50x50/50x50_RevReCom_steps_10000000000_plan_50x5_strip_20240618_174413_cut_edges.parquet",
151+
"RevReCom (10B proposed)": f"{top_dir}/hpc_files/hpc_processed_data/50x50/50x50_RevReCom_steps_10000000000_plan_50x5_strip_20240618_174413_cut_edges.parquet",
151152
"Forest (10M proposed)": f"{top_dir}/hpc_files/hpc_processed_data/50x50/50x50_Forest_steps_10000000_rng_seed_278986_gamma_0.0_alpha_1.0_ndists_10_20240830_142334_cut_edges.parquet",
152153
}
153154

154155
rrc_forest_smc_10 = rrc_forest_10.copy()
155-
rrc_forest_smc_10["SMC (100k Samples)"] = (
156+
rrc_forest_smc_10["SMC (100K)"] = (
156157
f"{top_dir}/hpc_files/hpc_processed_data/50x50/50x50_SMC_batch_size_100000_rng_seed_278986_dists_10_20250129_150813_cut_edges.parquet"
157158
)
158159

159-
make_method_plot(rrc_forest_10, 350, 601, 10, "forest_rrc")
160-
make_method_plot(rrc_forest_smc_10, 350, 601, 10, "forest_rrc_smc")
160+
# make_method_plot(rrc_forest_10, 350, 601, 10, "forest_rrc")
161+
# make_method_plot(rrc_forest_smc_10, 350, 601, 10, "forest_rrc_smc")
161162
make_recom_plot(350, 601, 10, "*_ReCom*50x5_*")
162163

163164
rrc_forest_25 = {
164-
"RRC (10B proposed)": f"{top_dir}/hpc_files/hpc_processed_data/50x50/50x50_RevReCom_steps_10000000000_plan_10x10_square_20240618_174413_cut_edges.parquet",
165+
"RevReCom (10B proposed)": f"{top_dir}/hpc_files/hpc_processed_data/50x50/50x50_RevReCom_steps_10000000000_plan_10x10_square_20240618_174413_cut_edges.parquet",
165166
"Forest (10M proposed)": f"{top_dir}/hpc_files/hpc_processed_data/50x50/50x50_Forest_steps_10000000_rng_seed_278986_gamma_0.0_alpha_1.0_ndists_25_20240830_142334_cut_edges.parquet",
166167
}
167168

168169
rrc_forest_smc_25 = rrc_forest_25.copy()
169-
rrc_forest_smc_25["SMC (100k Samples)"] = (
170+
rrc_forest_smc_25["SMC (100K)"] = (
170171
f"{top_dir}/hpc_files/hpc_processed_data/50x50/50x50_SMC_batch_size_100000_rng_seed_278986_dists_25_20250129_150813_cut_edges.parquet"
171172
)
172173

173-
make_method_plot(rrc_forest_25, 650, 880, 25, "forest_rrc")
174-
make_method_plot(rrc_forest_smc_25, 650, 880, 25, "forest_rrc_smc")
175-
make_recom_plot(650, 880, 25, "*_ReCom*10x10_*")
174+
# make_method_plot(rrc_forest_25, 650, 851, 25, "forest_rrc")
175+
# make_method_plot(rrc_forest_smc_25, 650, 851, 25, "forest_rrc_smc")
176+
make_recom_plot(650, 851, 25, "*_ReCom*10x10_*")
176177

177178
rrc_forest_50 = {
178-
"RRC (10B proposed)": f"{top_dir}/hpc_files/hpc_processed_data/50x50/50x50_RevReCom_steps_10000000000_plan_50x1_strip_20240618_174413_cut_edges.parquet",
179+
"RevReCom (10B proposed)": f"{top_dir}/hpc_files/hpc_processed_data/50x50/50x50_RevReCom_steps_10000000000_plan_50x1_strip_20240618_174413_cut_edges.parquet",
179180
"Forest (10M proposed)": f"{top_dir}/hpc_files/hpc_processed_data/50x50/50x50_Forest_steps_10000000_rng_seed_278986_gamma_0.0_alpha_1.0_ndists_50_20240830_142334_cut_edges.parquet",
180181
}
181182

182183
rrc_forest_smc_50 = rrc_forest_50.copy()
183-
rrc_forest_smc_50["SMC (100k Samples)"] = (
184+
rrc_forest_smc_50["SMC (100K)"] = (
184185
f"{top_dir}/hpc_files/hpc_processed_data/50x50/50x50_SMC_batch_size_100000_rng_seed_278986_dists_50_20250129_150813_cut_edges.parquet"
185186
)
186187

187-
make_method_plot(rrc_forest_50, 900, 1180, 50, "forest_rrc")
188-
make_method_plot(rrc_forest_smc_50, 900, 1180, 50, "forest_rrc_smc")
188+
# make_method_plot(rrc_forest_50, 900, 1151, 50, "forest_rrc")
189+
# make_method_plot(rrc_forest_smc_50, 900, 1151, 50, "forest_rrc_smc")
189190
make_recom_plot(900, 1180, 50, "*_ReCom*50x1_*")

0 commit comments

Comments
 (0)