|
16 | 16 | from aiida.cmdline.utils.template_config import ( |
17 | 17 | _process_content, |
18 | 18 | load_and_process_template, |
| 19 | + parse_template_vars, |
19 | 20 | ) |
20 | 21 |
|
21 | 22 |
|
| 23 | +def _fake_response(text, *, ok=True): |
| 24 | + """Build a fake ``requests.Response`` for monkeypatching ``requests.get``.""" |
| 25 | + import requests |
| 26 | + |
| 27 | + class _Response: |
| 28 | + status_code = 200 if ok else 500 |
| 29 | + |
| 30 | + def raise_for_status(self): |
| 31 | + if not ok: |
| 32 | + raise requests.ConnectionError(text) |
| 33 | + |
| 34 | + _Response.text = text |
| 35 | + return _Response() |
| 36 | + |
| 37 | + |
22 | 38 | class TestProcessContent: |
23 | 39 | """Tests for :func:`_process_content` — the core processing pipeline.""" |
24 | 40 |
|
@@ -58,17 +74,24 @@ def test_template_vars_without_metadata(self): |
58 | 74 | result = _process_content(content, interactive=False, template_vars={'label': 'test'}) |
59 | 75 | assert result == {'label': 'test', 'hostname': 'localhost'} |
60 | 76 |
|
| 77 | + def test_template_vars_without_metadata_warns_interactive(self, capsys): |
| 78 | + """In interactive mode, template variables without metadata emit a warning but don't fail.""" |
| 79 | + content = "label: '{{ label }}'\nhostname: localhost\n" |
| 80 | + result = _process_content(content, interactive=True) |
| 81 | + assert result == {'label': '{{ label }}', 'hostname': 'localhost'} |
| 82 | + assert 'no metadata found' in capsys.readouterr().out |
| 83 | + |
61 | 84 | @pytest.mark.parametrize( |
62 | 85 | ('content', 'match'), |
63 | 86 | [ |
64 | | - (':\ninvalid: [yaml', 'Invalid YAML'), |
65 | | - ('- item1\n- item2\n', 'Expected a YAML mapping'), |
66 | | - ( |
| 87 | + pytest.param(':\ninvalid: [yaml', 'Invalid YAML', id='invalid_yaml'), |
| 88 | + pytest.param('- item1\n- item2\n', 'Expected a YAML mapping', id='non_dict_yaml'), |
| 89 | + pytest.param( |
67 | 90 | "label: '{{ label }}'\nmetadata:\n template_variables:\n label:\n description: x\n", |
68 | 91 | r'Template variables detected.*but no values provided', |
| 92 | + id='missing_vars_non_interactive', |
69 | 93 | ), |
70 | 94 | ], |
71 | | - ids=['invalid_yaml', 'non_dict_yaml', 'missing_vars_non_interactive'], |
72 | 95 | ) |
73 | 96 | def test_invalid_input_raises(self, content, match): |
74 | 97 | with pytest.raises(click.BadParameter, match=match): |
@@ -101,12 +124,15 @@ def test_registry_computer_format(self): |
101 | 124 | interactive=False, |
102 | 125 | template_vars={'label': 'eiger-mc', 'slurm_partition': 'normal', 'slurm_account': 'my_project'}, |
103 | 126 | ) |
104 | | - assert result['label'] == 'eiger-mc' |
105 | | - assert '#SBATCH --partition=normal' in result['prepend_text'] |
106 | | - assert '#SBATCH --account=my_project' in result['prepend_text'] |
107 | | - assert '{username}' in result['work_dir'] |
108 | | - assert '{tot_num_mpiprocs}' in result['mpirun_command'] |
109 | | - assert 'metadata' not in result |
| 127 | + assert result == { |
| 128 | + 'label': 'eiger-mc', |
| 129 | + 'hostname': 'eiger.cscs.ch', |
| 130 | + 'transport': 'core.ssh', |
| 131 | + 'scheduler': 'core.slurm', |
| 132 | + 'work_dir': '/scratch/{username}/aiida_run/', |
| 133 | + 'mpirun_command': 'srun -n {tot_num_mpiprocs}', |
| 134 | + 'prepend_text': '#SBATCH --partition=normal\n#SBATCH --account=my_project', |
| 135 | + } |
110 | 136 |
|
111 | 137 | def test_registry_code_format_with_multiline_expression(self): |
112 | 138 | """Realistic code YAML with a multi-line Jinja2 expression (``{{ }}`` spanning two lines).""" |
@@ -141,6 +167,70 @@ def test_from_file(self, tmp_path): |
141 | 167 | result = load_and_process_template(str(filepath), interactive=False) |
142 | 168 | assert result == {'label': 'my-computer', 'hostname': 'localhost'} |
143 | 169 |
|
| 170 | + def test_from_url(self, monkeypatch): |
| 171 | + """Loading a template from a URL fetches and processes the content.""" |
| 172 | + import requests |
| 173 | + |
| 174 | + monkeypatch.setattr( |
| 175 | + requests, 'get', lambda *a, **kw: _fake_response('label: url-computer\nhostname: remote-host\n') |
| 176 | + ) |
| 177 | + result = load_and_process_template('https://example.com/config.yaml', interactive=False) |
| 178 | + assert result == {'label': 'url-computer', 'hostname': 'remote-host'} |
| 179 | + |
144 | 180 | def test_file_not_found(self): |
145 | 181 | with pytest.raises(click.BadParameter, match='Failed to read file'): |
146 | 182 | load_and_process_template('/nonexistent/path.yaml', interactive=False) |
| 183 | + |
| 184 | + def test_url_failure(self, monkeypatch): |
| 185 | + """A failing URL request raises ``click.BadParameter``.""" |
| 186 | + import requests |
| 187 | + |
| 188 | + monkeypatch.setattr(requests, 'get', lambda *a, **kw: (_ for _ in ()).throw(requests.ConnectionError('fail'))) |
| 189 | + with pytest.raises(click.BadParameter, match='Failed to fetch URL'): |
| 190 | + load_and_process_template('https://example.com/config.yaml', interactive=False) |
| 191 | + |
| 192 | + |
| 193 | +class TestParseTemplateVars: |
| 194 | + """Tests for :func:`parse_template_vars` — file / URL / JSON resolution chain.""" |
| 195 | + |
| 196 | + def test_from_json_string(self): |
| 197 | + result = parse_template_vars('{"key": "value", "num": "42"}') |
| 198 | + assert result == {'key': 'value', 'num': '42'} |
| 199 | + |
| 200 | + def test_from_yaml_file(self, tmp_path): |
| 201 | + filepath = tmp_path / 'vars.yaml' |
| 202 | + filepath.write_text('account: my_project\npartition: normal\n') |
| 203 | + result = parse_template_vars(str(filepath)) |
| 204 | + assert result == {'account': 'my_project', 'partition': 'normal'} |
| 205 | + |
| 206 | + def test_from_url(self, monkeypatch): |
| 207 | + import requests |
| 208 | + |
| 209 | + monkeypatch.setattr(requests, 'get', lambda *a, **kw: _fake_response('account: remote_project\n')) |
| 210 | + result = parse_template_vars('https://example.com/vars.yaml') |
| 211 | + assert result == {'account': 'remote_project'} |
| 212 | + |
| 213 | + @pytest.mark.parametrize( |
| 214 | + ('value', 'match'), |
| 215 | + [ |
| 216 | + pytest.param('not valid json', 'Invalid JSON', id='invalid_json'), |
| 217 | + pytest.param('["a", "b"]', 'must contain a YAML/JSON mapping', id='json_array'), |
| 218 | + pytest.param('"just a string"', 'must contain a YAML/JSON mapping', id='json_string'), |
| 219 | + ], |
| 220 | + ) |
| 221 | + def test_invalid_input_raises(self, value, match): |
| 222 | + with pytest.raises(click.BadParameter, match=match): |
| 223 | + parse_template_vars(value) |
| 224 | + |
| 225 | + @pytest.mark.parametrize( |
| 226 | + ('file_content', 'match'), |
| 227 | + [ |
| 228 | + pytest.param(':\n [invalid yaml', 'Invalid YAML', id='invalid_yaml'), |
| 229 | + pytest.param('- item1\n- item2\n', 'must contain a YAML/JSON mapping', id='non_dict_yaml'), |
| 230 | + ], |
| 231 | + ) |
| 232 | + def test_file_with_bad_content_raises(self, tmp_path, file_content, match): |
| 233 | + filepath = tmp_path / 'bad.yaml' |
| 234 | + filepath.write_text(file_content) |
| 235 | + with pytest.raises(click.BadParameter, match=match): |
| 236 | + parse_template_vars(str(filepath)) |
0 commit comments