Skip to content

Commit 78c2a8c

Browse files
committed
Make root requirement check a decorator
1 parent f3e05e1 commit 78c2a8c

File tree

4 files changed

+67
-49
lines changed

4 files changed

+67
-49
lines changed

src/manage_iocs/commands.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import inspect
2-
import os
32
import socket
43
import sys
54
from subprocess import PIPE, Popen
@@ -75,6 +74,7 @@ def report():
7574
)
7675

7776

77+
@utils.requires_root
7878
def disable(ioc: str):
7979
"""Disable autostart for the given IOC."""
8080

@@ -86,6 +86,7 @@ def disable(ioc: str):
8686
return ret
8787

8888

89+
@utils.requires_root
8990
def enable(ioc: str):
9091
"""Enable autostart for the given IOC."""
9192

@@ -139,6 +140,7 @@ def stopall():
139140
return ret
140141

141142

143+
@utils.requires_root
142144
def enableall():
143145
"""Enable autostart for all IOCs on this host."""
144146

@@ -149,6 +151,7 @@ def enableall():
149151
return ret
150152

151153

154+
@utils.requires_root
152155
def disableall():
153156
"""Disable autostart for all IOCs on this host."""
154157

@@ -170,12 +173,10 @@ def restart(ioc: str):
170173
return ret
171174

172175

176+
@utils.requires_root
173177
def uninstall(ioc: str):
174178
"""Uninstall the given IOC."""
175179

176-
if not os.geteuid() == 0:
177-
raise RuntimeError("You must be root to uninstall an IOC!")
178-
179180
_, _, ret = utils.systemctl_passthrough("stop", ioc)
180181
if ret != 0:
181182
raise RuntimeError(f"Failed to stop IOC '{ioc}' before uninstalling!")
@@ -190,12 +191,10 @@ def uninstall(ioc: str):
190191
return ret
191192

192193

194+
@utils.requires_root
193195
def install(ioc: str):
194196
"""Install the given IOC."""
195197

196-
if not os.geteuid() == 0:
197-
raise RuntimeError("You must be root to install an IOC!")
198-
199198
service_file = utils.SYSTEMD_SERVICE_PATH / f"softioc-{ioc}.service"
200199
ioc_config = utils.find_iocs()[ioc]
201200

src/manage_iocs/utils.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import os
22
import socket
3+
from collections.abc import Callable
34
from dataclasses import dataclass
45
from pathlib import Path
56
from subprocess import PIPE, Popen
@@ -107,3 +108,16 @@ def get_ioc_statuses(ioc_name: str) -> tuple[int, tuple[str, str]]:
107108
ret = ret + ret_enable
108109

109110
return ret, (status, enabled.capitalize())
111+
112+
113+
def requires_root(func: Callable):
114+
def wrapper(*args, **kwargs):
115+
if os.geteuid() != 0:
116+
raise PermissionError(f"Command {func.__name__} requires root privileges.")
117+
return func(*args, **kwargs)
118+
119+
# Preserve the original function's name and docstring
120+
wrapper.__doc__ = func.__doc__
121+
wrapper.__name__ = func.__name__
122+
123+
return wrapper

tests/conftest.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@ def all_manage_iocs_commands():
1414
return [obj for name, obj in inspect.getmembers(cmds) if inspect.isfunction(obj)]
1515

1616

17+
@pytest.fixture(autouse=True)
18+
def sim_running_as_root(monkeypatch):
19+
monkeypatch.setattr(os, "geteuid", lambda: 0) # Mock as root user
20+
21+
1722
@pytest.fixture
1823
def sample_config_file_factory(tmp_path):
1924
def _simple_config_file(

tests/test_commands.py

Lines changed: 42 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -45,26 +45,37 @@ def normalize_whitespace(s: str) -> str:
4545

4646

4747
@pytest.mark.parametrize(
48-
"ioc_name, command, before_state, before_enabled, after_state, after_enabled",
48+
"ioc_name, command, before_state, before_enabled, after_state, after_enabled, as_root",
4949
[
50-
("ioc1", cmds.stop, "Running", "Enabled", "Stopped", "Enabled"),
51-
("ioc3", cmds.stop, "Running", "Disabled", "Stopped", "Disabled"),
52-
("ioc4", cmds.stop, "Stopped", "Disabled", "Stopped", "Disabled"),
53-
("ioc1", cmds.start, "Running", "Enabled", "Running", "Enabled"),
54-
("ioc3", cmds.start, "Running", "Disabled", "Running", "Disabled"),
55-
("ioc4", cmds.start, "Stopped", "Disabled", "Running", "Disabled"),
56-
("ioc3", cmds.enable, "Running", "Disabled", "Running", "Enabled"),
57-
("ioc4", cmds.enable, "Stopped", "Disabled", "Stopped", "Enabled"),
58-
("ioc1", cmds.disable, "Running", "Enabled", "Running", "Disabled"),
59-
("ioc3", cmds.disable, "Running", "Disabled", "Running", "Disabled"),
60-
("ioc1", cmds.restart, "Running", "Enabled", "Running", "Enabled"),
61-
("ioc4", cmds.restart, "Stopped", "Disabled", "Running", "Disabled"),
62-
("ioc3", cmds.restart, "Running", "Disabled", "Running", "Disabled"),
50+
("ioc1", cmds.stop, "Running", "Enabled", "Stopped", "Enabled", False),
51+
("ioc3", cmds.stop, "Running", "Disabled", "Stopped", "Disabled", False),
52+
("ioc4", cmds.stop, "Stopped", "Disabled", "Stopped", "Disabled", False),
53+
("ioc1", cmds.start, "Running", "Enabled", "Running", "Enabled", False),
54+
("ioc3", cmds.start, "Running", "Disabled", "Running", "Disabled", False),
55+
("ioc4", cmds.start, "Stopped", "Disabled", "Running", "Disabled", False),
56+
("ioc3", cmds.enable, "Running", "Disabled", "Running", "Enabled", True),
57+
("ioc4", cmds.enable, "Stopped", "Disabled", "Stopped", "Enabled", True),
58+
("ioc1", cmds.disable, "Running", "Enabled", "Running", "Disabled", True),
59+
("ioc3", cmds.disable, "Running", "Disabled", "Running", "Disabled", True),
60+
("ioc1", cmds.restart, "Running", "Enabled", "Running", "Enabled", False),
61+
("ioc4", cmds.restart, "Stopped", "Disabled", "Running", "Disabled", False),
62+
("ioc3", cmds.restart, "Running", "Disabled", "Running", "Disabled", False),
6363
],
6464
)
6565
def test_state_change_commands(
66-
sample_iocs, ioc_name, command, before_state, before_enabled, after_state, after_enabled
66+
sample_iocs,
67+
ioc_name,
68+
command,
69+
before_state,
70+
before_enabled,
71+
after_state,
72+
after_enabled,
73+
as_root,
74+
monkeypatch,
6775
):
76+
if not as_root:
77+
monkeypatch.setattr(os, "geteuid", lambda: 1000) # Mock as non-root user
78+
6879
_, status = get_ioc_statuses(ioc_name)
6980
assert status == (before_state, before_enabled)
7081

@@ -76,8 +87,6 @@ def test_state_change_commands(
7687

7788

7889
def test_install_new_ioc(sample_iocs, monkeypatch):
79-
monkeypatch.setattr(os, "geteuid", lambda: 0) # Mock as root user
80-
8190
assert "ioc2" not in find_installed_iocs()
8291

8392
rc = cmds.install("ioc2")
@@ -89,32 +98,19 @@ def test_install_new_ioc(sample_iocs, monkeypatch):
8998
assert "ioc2" in find_installed_iocs()
9099

91100

92-
def test_install_ioc_not_root(sample_iocs, monkeypatch):
93-
monkeypatch.setattr(os, "geteuid", lambda: 1000) # Mock as non-root user
94-
95-
with pytest.raises(RuntimeError, match="You must be root to install an IOC!"):
96-
cmds.install("ioc3")
97-
98-
99101
def test_install_ioc_wrong_host(sample_iocs, monkeypatch):
100-
monkeypatch.setattr(os, "geteuid", lambda: 0) # Mock as root user
101-
102102
with pytest.raises(RuntimeError, match="Cannot install IOC 'ioc1' on this host"):
103103
cmds.install("ioc1")
104104

105105

106106
def test_install_already_installed_ioc(sample_iocs, monkeypatch):
107-
monkeypatch.setattr(os, "geteuid", lambda: 0) # Mock as root user
108-
109107
assert "ioc3" in find_installed_iocs()
110108

111109
with pytest.raises(RuntimeError, match="Failed to install IOC 'ioc3'!"):
112110
cmds.install("ioc3")
113111

114112

115113
def test_uninstall_ioc(sample_iocs, monkeypatch):
116-
monkeypatch.setattr(os, "geteuid", lambda: 0) # Mock as root user
117-
118114
assert "ioc5" in find_installed_iocs()
119115

120116
rc = cmds.uninstall("ioc5")
@@ -123,11 +119,15 @@ def test_uninstall_ioc(sample_iocs, monkeypatch):
123119
assert "ioc5" not in find_installed_iocs()
124120

125121

126-
def test_uninstall_ioc_not_root(sample_iocs, monkeypatch):
122+
@pytest.mark.parametrize(
123+
"command",
124+
[cmds.enable, cmds.disable, cmds.enableall, cmds.disableall, cmds.install, cmds.uninstall],
125+
)
126+
def test_requires_root(sample_iocs, monkeypatch, command):
127127
monkeypatch.setattr(os, "geteuid", lambda: 1000) # Mock as non-root user
128128

129-
with pytest.raises(RuntimeError, match="You must be root to uninstall an IOC!"):
130-
cmds.uninstall("ioc1")
129+
with pytest.raises(PermissionError, match="requires root privileges."):
130+
command("ioc1")
131131

132132

133133
@pytest.mark.parametrize(
@@ -155,15 +155,21 @@ def test_state_change_all(sample_iocs, cmd, expected_state, expected_enabled):
155155
assert get_ioc_statuses(ioc.name)[1][1] == expected_enabled
156156

157157

158-
def test_attach(sample_iocs, monkeypatch, dummy_popen):
159-
monkeypatch.setattr(os, "geteuid", lambda: 1000) # Mock as non-root user
158+
@pytest.mark.parametrize("as_root", [True, False])
159+
def test_attach(sample_iocs, monkeypatch, dummy_popen, as_root):
160+
if not as_root:
161+
monkeypatch.setattr(os, "geteuid", lambda: 1000) # Mock as non-root user
160162

161163
ret = cmds.attach("ioc3")
162164

163165
assert ret == ["telnet", "localhost", "3456"]
164166

165167

166-
def test_status(sample_iocs, capsys):
168+
@pytest.mark.parametrize("as_root", [True, False])
169+
def test_status(sample_iocs, capsys, monkeypatch, as_root):
170+
if not as_root:
171+
monkeypatch.setattr(os, "geteuid", lambda: 1000) # Mock as non-root user
172+
167173
rc = cmds.status()
168174
captured = capsys.readouterr()
169175
expected_output = """IOC Status Auto-Start
@@ -194,8 +200,6 @@ def normalize_whitespace(s: str) -> str:
194200
],
195201
)
196202
def test_command_failures(sample_iocs, monkeypatch, cmd, expected_message):
197-
monkeypatch.setattr(os, "geteuid", lambda: 0) # Mock as root user
198-
199203
def failing_systemctl_passthrough(action: str, ioc: str) -> tuple[str, str, int]:
200204
return ("", "Simulated failure", 1)
201205

@@ -214,8 +218,6 @@ def failing_systemctl_passthrough(action: str, ioc: str) -> tuple[str, str, int]
214218
],
215219
)
216220
def test_uninstall_failures(sample_iocs, monkeypatch, failed_action, expected_message):
217-
monkeypatch.setattr(os, "geteuid", lambda: 0) # Mock as root user
218-
219221
def failing_systemctl_passthrough(action: str, ioc: str) -> tuple[str, str, int]:
220222
if action == failed_action:
221223
return ("", "Simulated failure", 1)
@@ -228,8 +230,6 @@ def failing_systemctl_passthrough(action: str, ioc: str) -> tuple[str, str, int]
228230

229231

230232
def test_fail_to_install_ioc_to_run_as_root(sample_iocs, monkeypatch, sample_config_file_factory):
231-
monkeypatch.setattr(os, "geteuid", lambda: 0) # Mock as root user
232-
233233
sample_config_file_factory(name="ioc1", user="root")
234234
with pytest.raises(RuntimeError, match="Refusing to install IOC 'ioc1' to run as user 'root'!"):
235235
cmds.install("ioc1")

0 commit comments

Comments
 (0)