Skip to content

Commit 712a3e4

Browse files
committed
refactor imports and fix light mode inputs in MatterViz widgets in dark VSCode Jupyter notebooks
- fix widget asset paths in `matterviz.py`
1 parent a15ae29 commit 712a3e4

File tree

15 files changed

+154
-187
lines changed

15 files changed

+154
-187
lines changed

assets/scripts/cluster/composition/cluster_compositions_matbench.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,7 @@
2424
from pymatgen.util.string import htmlify
2525

2626
import pymatviz as pmv
27-
from pymatviz.cluster.composition import (
28-
EmbeddingMethod,
29-
matminer_featurize,
30-
one_hot_encode,
31-
)
27+
import pymatviz.cluster.composition as pcc
3228
from pymatviz.enums import Key
3329

3430

@@ -64,7 +60,7 @@ def process_dataset(
6460
target_key: str,
6561
target_label: str,
6662
target_symbol: str,
67-
embed_method: EmbeddingMethod,
63+
embed_method: pcc.EmbeddingMethod,
6864
projection: ProjectionMethod,
6965
n_components: int,
7066
**kwargs: Any,
@@ -113,9 +109,9 @@ def process_dataset(
113109
if embeddings_dict is None:
114110
# Create embeddings
115111
if embed_method == "one-hot":
116-
embeddings = one_hot_encode(compositions)
112+
embeddings = pcc.one_hot_encode(compositions)
117113
elif embed_method in ["magpie", "matscholar_el"]:
118-
embeddings = matminer_featurize(compositions, preset=embed_method)
114+
embeddings = pcc.matminer_featurize(compositions, preset=embed_method)
119115
else:
120116
raise ValueError(f"Unknown {embed_method=}")
121117

@@ -226,7 +222,9 @@ def annotate_top_points(row: pd.Series) -> dict[str, Any] | None:
226222
"K<sub>VRH</sub>",
227223
)
228224
plot_combinations: list[
229-
tuple[str, str, str, str, EmbeddingMethod, ProjectionMethod, int, dict[str, Any]]
225+
tuple[
226+
str, str, str, str, pcc.EmbeddingMethod, ProjectionMethod, int, dict[str, Any]
227+
]
230228
] = [
231229
# 1. Steels with PCA (2D) - shows clear linear trends
232230
(*mb_steels, "magpie", "pca", 2, dict(x=0.01, xanchor="left")),

examples/matbench/dielectric/explore_dielectric.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,7 @@
1414

1515
# %%
1616
import plotly.express as px
17-
from matbench_discovery.structure.prototype import (
18-
count_wyckoff_positions,
19-
get_protostructure_label,
20-
)
17+
from matbench_discovery.structure import prototype
2118
from matminer.datasets import load_dataset
2219
from tqdm import tqdm
2320

@@ -36,10 +33,12 @@
3633
proto_label_key = f"{Key.protostructure}_moyo"
3734
n_wyckoff_pos_key = f"{Key.n_wyckoff_pos}_moyo"
3835
df_diel[proto_label_key] = [
39-
get_protostructure_label(struct)
36+
prototype.get_protostructure_label(struct)
4037
for struct in tqdm(df_diel[Key.structure], desc="Getting Wyckoff strings")
4138
]
42-
df_diel[n_wyckoff_pos_key] = df_diel[proto_label_key].map(count_wyckoff_positions)
39+
df_diel[n_wyckoff_pos_key] = df_diel[proto_label_key].map(
40+
prototype.count_wyckoff_positions
41+
)
4342

4443
df_diel[Key.crystal_system] = df_diel[Key.spg_num].map(pmv.utils.spg_to_crystal_sys)
4544

examples/matbench/log_g+kvrh/explore_log_g+krvh.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,7 @@
1515
import numpy as np
1616
import pandas as pd
1717
import plotly.express as px
18-
from matbench_discovery.structure.prototype import (
19-
count_wyckoff_positions,
20-
get_protostructure_label,
21-
)
18+
from matbench_discovery.structure import prototype
2219
from matminer.datasets import load_dataset
2320
from pymatgen.core import Structure
2421
from tqdm import tqdm
@@ -37,7 +34,7 @@
3734
)
3835
df_sym[Key.crystal_system] = df_sym["number"].map(pmv.utils.spg_to_crystal_sys)
3936
df_grvh[Key.protostructure] = [
40-
get_protostructure_label(struct)
37+
prototype.get_protostructure_label(struct)
4138
for struct in tqdm(df_grvh[Key.structure], desc="matbench_log_gvrh Wyckoff strings")
4239
]
4340
df_kvrh[Key.protostructure] = df_grvh[Key.protostructure]
@@ -46,7 +43,9 @@
4643
df[[Key.spg_num, Key.wyckoff_symbols]] = df_sym[["number", "wyckoffs"]]
4744
df[Key.crystal_system] = df_sym[Key.crystal_system]
4845

49-
df[Key.n_wyckoff_pos] = df[Key.protostructure].map(count_wyckoff_positions)
46+
df[Key.n_wyckoff_pos] = df[Key.protostructure].map(
47+
prototype.count_wyckoff_positions
48+
)
5049
df[Key.formula] = [x.formula for x in df[Key.structure]]
5150

5251

pymatviz/structure/helpers.py

Lines changed: 16 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,7 @@
1616
from pymatgen.core.periodic_table import Element
1717
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
1818

19-
from pymatviz.colors import (
20-
ELEM_COLORS_ALLOY,
21-
ELEM_COLORS_JMOL,
22-
ELEM_COLORS_PASTEL,
23-
ELEM_COLORS_VESTA,
24-
)
19+
from pymatviz import colors
2520
from pymatviz.enums import ElemColorScheme, Key, SiteCoords
2621
from pymatviz.typing import Xyz
2722
from pymatviz.utils import df_ptable, pick_max_contrast_color
@@ -103,7 +98,7 @@ def get_struct_prop(
10398
)
10499

105100

106-
def _get_site_symbol(site: PeriodicSite) -> str:
101+
def get_site_symbol(site: PeriodicSite) -> str:
107102
"""Get a single element symbol for a site.
108103
109104
Handles disordered sites by picking the element with the highest fraction.
@@ -268,14 +263,8 @@ def get_elem_colors(
268263
"""Get element colors based on the provided scheme or custom dictionary."""
269264
if isinstance(elem_colors, dict):
270265
return elem_colors
271-
if str(elem_colors) == str(ElemColorScheme.jmol):
272-
return ELEM_COLORS_JMOL # type: ignore[return-value]
273-
if str(elem_colors) == str(ElemColorScheme.vesta):
274-
return ELEM_COLORS_VESTA # type: ignore[return-value]
275-
if str(elem_colors) == str(ElemColorScheme.alloy):
276-
return ELEM_COLORS_ALLOY # type: ignore[return-value]
277-
if str(elem_colors) == str(ElemColorScheme.pastel):
278-
return ELEM_COLORS_PASTEL # type: ignore[return-value]
266+
if color_dict := getattr(colors, f"ELEM_COLORS_{str(elem_colors).upper()}", None):
267+
return color_dict
279268
raise ValueError(
280269
f"colors must be a dict or one of ('{', '.join(ElemColorScheme)}')"
281270
)
@@ -313,12 +302,12 @@ def generate_site_label(
313302
return None
314303

315304
if site_labels == "symbol":
316-
return _get_site_symbol(site)
305+
return get_site_symbol(site)
317306
if site_labels == "species":
318307
return site.species_string # Use full species string for disordered
319308

320309
label_text = ""
321-
symbol = _get_site_symbol(site) # Majority element symbol of site
310+
symbol = get_site_symbol(site) # Majority element symbol of site
322311

323312
if isinstance(site_labels, dict):
324313
# Use provided label for symbol, else symbol itself, or empty if not found &
@@ -423,7 +412,7 @@ def format_coord(coord_val: float) -> str:
423412
return out_text
424413

425414

426-
def _process_element_color(raw_color_from_map: ColorType) -> str:
415+
def normalize_elem_color(raw_color_from_map: ColorType) -> str:
427416
"""Process a color from the element color map into a consistent RGB string format.
428417
429418
Args:
@@ -538,7 +527,7 @@ def draw_site(
538527
raw_color_from_map = elem_colors.get(majority_species.symbol, "gray")
539528

540529
# Process the color from the map into a string format
541-
atom_color = _process_element_color(raw_color_from_map)
530+
atom_color = normalize_elem_color(raw_color_from_map)
542531

543532
site_hover_text = get_site_hover_text(site, hover_text, majority_species, float_fmt)
544533

@@ -582,7 +571,7 @@ def draw_site(
582571
fig.add_scatter(**scatter_kwargs, row=row, col=col)
583572

584573

585-
def _create_disordered_site_legend_name(
574+
def get_disordered_site_legend_name(
586575
sorted_species: list[tuple[Species, float]], *, is_image: bool = False
587576
) -> str:
588577
"""Create a legend name for a disordered site showing all elements with occupancies.
@@ -699,7 +688,7 @@ def draw_disordered_site(
699688
sorted_species = sorted(species.items(), key=lambda x: x[1], reverse=True)
700689

701690
# Create a combined legend name showing all elements with occupancies
702-
legend_name = _create_disordered_site_legend_name(sorted_species, is_image=is_image)
691+
legend_name = get_disordered_site_legend_name(sorted_species, is_image=is_image)
703692

704693
# Set up legendgroup - use site_idx if not provided to group all parts together
705694
if legendgroup is None:
@@ -724,7 +713,7 @@ def draw_disordered_site(
724713
elem_symbol = element_species.symbol
725714
raw_color_from_map = elem_colors.get(elem_symbol, "gray")
726715

727-
atom_color = _process_element_color(raw_color_from_map)
716+
atom_color = normalize_elem_color(raw_color_from_map)
728717

729718
# Calculate the angular span for this species based on occupancy
730719
angle_span = 2 * np.pi * occupancy
@@ -735,7 +724,7 @@ def draw_disordered_site(
735724

736725
# Generate the spherical wedge mesh
737726
x_coords, y_coords, z_coords, i_indices, j_indices, k_indices = (
738-
_generate_spherical_wedge_mesh(
727+
get_spherical_wedge_mesh(
739728
center=coords,
740729
radius=wedge_radius,
741730
start_angle=current_angle,
@@ -847,7 +836,7 @@ def draw_disordered_site(
847836
for species_idx, (element_species, occupancy) in enumerate(sorted_species):
848837
elem_symbol = element_species.symbol
849838
raw_color_from_map = elem_colors.get(elem_symbol, "gray")
850-
atom_color = _process_element_color(raw_color_from_map)
839+
atom_color = normalize_elem_color(raw_color_from_map)
851840

852841
# Calculate angular width for this species based on occupancy
853842
angular_width = 2 * math.pi * occupancy
@@ -984,7 +973,7 @@ def draw_disordered_site(
984973
MAX_PIE_SLICE_POINTS = 20 # Base points for 2D pie slices
985974

986975

987-
def _generate_spherical_wedge_mesh(
976+
def get_spherical_wedge_mesh(
988977
center: np.ndarray,
989978
radius: float,
990979
start_angle: float,
@@ -1496,8 +1485,8 @@ def parse_color(color_val: Any) -> str:
14961485

14971486
if current_bond_color_setting is True:
14981487
# Default gradient: use element colors of bonded sites
1499-
elem1_symbol = _get_site_symbol(site1)
1500-
elem2_symbol = _get_site_symbol(site2)
1488+
elem1_symbol = get_site_symbol(site1)
1489+
elem2_symbol = get_site_symbol(site2)
15011490
color1_rgb_str = parse_color(_elem_colors.get(elem1_symbol, "gray"))
15021491
color2_rgb_str = parse_color(_elem_colors.get(elem2_symbol, "gray"))
15031492
color_for_segment_calc = (color1_rgb_str, color2_rgb_str)

pymatviz/structure/plotly.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ def structure_2d(
242242
zip(struct_i, rotated_coords_all_sites, strict=False)
243243
):
244244
# Determine legend parameters for primary sites
245-
symbol = helpers._get_site_symbol(site)
245+
symbol = helpers.get_site_symbol(site)
246246
legendgroup = None
247247
showlegend = False
248248
if site_labels == "legend":
@@ -621,7 +621,7 @@ def structure_3d(
621621
else site_idx_loop % len(struct_i)
622622
)
623623

624-
symbol = helpers._get_site_symbol(site)
624+
symbol = helpers.get_site_symbol(site)
625625

626626
# Determine legend parameters for primary sites only
627627
legendgroup = None

pymatviz/widgets/matterviz.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ def fetch_widget_asset(filename: str, version_override: str | None = None) -> st
3434
asset_version = version_override or PKG_VERSION
3535

3636
# Paths
37-
local_path = f"{os.path.dirname(__file__)}/build/{filename}"
37+
local_path = f"{os.path.dirname(__file__)}/web/build/{filename}"
3838
cache_dir = f"{os.path.expanduser('~/.cache/pymatviz/build')}/{asset_version}"
3939
os.makedirs(cache_dir, exist_ok=True)
4040
cache_path = f"{cache_dir}/{filename}"

pymatviz/widgets/mime.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,9 @@ def _register_renderers() -> None:
8989

9090
def register_matterviz_widgets() -> None:
9191
"""Register widgets for automatic display."""
92-
from pymatviz.process_data import (
93-
is_composition_like,
94-
is_structure_like,
95-
is_trajectory_like,
96-
)
97-
98-
WIDGET_REGISTRY[is_trajectory_like] = "trajectory"
99-
WIDGET_REGISTRY[is_composition_like] = "composition"
100-
WIDGET_REGISTRY[is_structure_like] = "structure"
92+
from pymatviz import process_data as pd
93+
94+
WIDGET_REGISTRY[pd.is_trajectory_like] = StructureType
95+
WIDGET_REGISTRY[pd.is_composition_like] = CompositionType
96+
WIDGET_REGISTRY[pd.is_structure_like] = StructureType
10197
_register_renderers()

pymatviz/widgets/web/anywidget.ts

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,23 @@ function inject_app_css(theme_type?: ThemeType): void {
3232
background-color: transparent !important;
3333
}
3434
35+
/* Dark mode input styling for Jupyter notebooks and interactive windows in VSCode */
36+
:is(.vscode-dark, .dark-theme, [data-jp-theme-light="false"]) :is(input, textarea, select) {
37+
background-color: #2d2d2d;
38+
color: #ffffff;
39+
border: 1px solid #555555;
40+
border-radius: 4px;
41+
padding: 6px 8px;
42+
}
43+
:is(.vscode-dark, .dark-theme, [data-jp-theme-light="false"]) :is(input, textarea, select):focus {
44+
outline: none;
45+
border-color: #007acc;
46+
box-shadow: 0 0 0 2px rgba(0, 122, 204, 0.2);
47+
}
48+
:is(.vscode-dark, .dark-theme, [data-jp-theme-light="false"]) :is(input, textarea)::placeholder {
49+
color: #888888;
50+
}
51+
3552
/* scope global styles to matterviz widgets to prevent site styles leaking into notebook styles */
3653
/* this is brittle, will break should component CSS classes in matterviz change, try to find better solution */
3754
div:is(.structure, .trajectory, .composition) {

pymatviz/widgets/web/deno.jsonc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
{
22
"tasks": {
33
"dev": "deno run -A --node-modules-dir npm:vite dev",
4-
"build": "rm -rf build && deno run -A --node-modules-dir npm:vite build",
5-
"ext-install": "deno task build && find . -name 'matterviz-v*.vsix' -type f | head -1 | xargs -I {} code --install-extension {}"
4+
"build": "rm -rf build && deno run -A --node-modules-dir npm:vite build"
65
},
76
"lock": false,
87
"lint": {

tests/files/.pytest-split-durations

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@
489489
"tests/structure/test_structure_helpers.py::test_angles_to_rotation_matrix[30x,45y,60z-expected_shape1]": 0.00016524898819625378,
490490
"tests/structure/test_structure_helpers.py::test_angles_to_rotation_matrix_invalid_input": 0.00013866700464859605,
491491
"tests/structure/test_structure_helpers.py::test_constants": 9.791701450012624e-05,
492-
"tests/structure/test_structure_helpers.py::test_create_disordered_site_legend_name": 0.00012099897139705718,
492+
"tests/structure/test_structure_helpers.py::test_get_disordered_site_legend_name": 0.00012099897139705718,
493493
"tests/structure/test_structure_helpers.py::test_draw_bonds[False-bond_kwargs2-plotted_sites_coords_param2-0]": 0.0062469589756801724,
494494
"tests/structure/test_structure_helpers.py::test_draw_bonds[False-bond_kwargs4-None-16]": 0.008371708041522652,
495495
"tests/structure/test_structure_helpers.py::test_draw_bonds[True-None-None-160]": 0.026459541986696422,

0 commit comments

Comments
 (0)