Skip to content

Simplify the Function class #955

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 7 additions & 30 deletions pytensor/compile/function/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -792,18 +792,14 @@ def __call__(self, *args, **kwargs):
The function inputs can be passed as keyword argument. For this, use
the name of the input or the input instance as the key.

Keyword argument ``output_subset`` is a list of either indices of the
function's outputs or the keys belonging to the `output_keys` dict
and represent outputs that are requested to be calculated. Regardless
of the presence of ``output_subset``, the updates are always calculated
and processed. To disable the updates, you should use the ``copy``
The updates are always calculated and processed.
To disable the updates, you should use the ``copy``
method with ``delete_updates=True``.

Returns
-------
list
List of outputs on indices/keys from ``output_subset`` or all of them,
if ``output_subset`` is not passed.
List of outputs.
"""

def restore_defaults():
Expand All @@ -816,10 +812,6 @@ def restore_defaults():
profile = self.profile
t0 = time.perf_counter()

output_subset = kwargs.pop("output_subset", None)
if output_subset is not None and self.output_keys is not None:
output_subset = [self.output_keys.index(key) for key in output_subset]

# Reinitialize each container's 'provided' counter
if self.trust_input:
i = 0
Expand Down Expand Up @@ -955,11 +947,7 @@ def restore_defaults():
# Do the actual work
t0_fn = time.perf_counter()
try:
outputs = (
self.vm()
if output_subset is None
else self.vm(output_subset=output_subset)
)
outputs = self.vm()
except Exception:
restore_defaults()
if hasattr(self.vm, "position_of_error"):
Expand Down Expand Up @@ -1040,24 +1028,13 @@ def restore_defaults():
profile.ignore_first_call = False
if self.return_none:
return None
elif self.unpack_single and len(outputs) == 1 and output_subset is None:
elif self.unpack_single and len(outputs) == 1:
return outputs[0]
else:
if self.output_keys is not None:
assert len(self.output_keys) == len(outputs)

if output_subset is None:
return dict(zip(self.output_keys, outputs))
else:
return {
self.output_keys[index]: outputs[index]
for index in output_subset
}

if output_subset is None:
return outputs
else:
return [outputs[i] for i in output_subset]
return dict(zip(self.output_keys, outputs))
return outputs

value = property(
lambda self: self._value,
Expand Down
Loading