
Fixing tensor.numpy on wrapped tensors #627

Open · vfdev-5 wants to merge 2 commits into main

Conversation

vfdev-5 (Contributor) commented Mar 29, 2022

Fixes #626

Description:

  • Fixing tensor.numpy on wrapped tensors
  • Added a test
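
The kind of call this targets is roughly the following (a minimal sketch, assuming the failure in #626 shows up when .numpy() is called on a functorch-wrapped tensor inside a transform such as grad; the function and values are illustrative and this is not the PR's actual test):

import torch
from functorch import grad

def f(x):
    # Inside the transform, x is a functorch wrapper tensor rather than a plain
    # torch.Tensor, so x.detach().numpy() previously raised instead of returning
    # the underlying data. (Assumed repro shape; see #626 for the real report.)
    print(x.detach().numpy())
    return (x * x).sum()

grad(f)(torch.tensor([1.0, 2.0, 3.0]))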

Comment on lines +109 to +128
level = _C.maybe_get_level(tensor)
if level == -1:
    return _old_numpy(tensor)

if _C.is_functionaltensor(tensor):
    # Since we're unwrapping the FunctionalTensorWrapper, we need to make sure
    # that it's up to date first
    torch._sync(tensor)

value = _C.get_unwrapped(tensor)
dl_enabled = _C.tls_set_is_included()
try:
    # Temporarily disable kDynamicLayerFrontModeKey/kDynamicLayerBackModeKey as included dispatch keys
    if dl_enabled:
        _C._set_dynamic_layer_keys_included(False)
    return value.numpy()
finally:
    # Re-enable kDynamicLayerFrontModeKey/kDynamicLayerBackModeKey as included dispatch keys
    if dl_enabled:
        _C._set_dynamic_layer_keys_included(True)
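
For context, the _old_numpy reference suggests this function is installed as a monkey patch over torch.Tensor.numpy. The wiring presumably looks something like the sketch below; the surrounding module is not part of the quoted diff, so treat this as an assumption rather than the PR's actual scaffolding:

import torch
from functorch import _C

# Keep a handle on the original method, then route every Tensor.numpy() call
# through the wrapper-aware version (hypothetical wiring, not the quoted diff).
_old_numpy = torch.Tensor.numpy

def _numpy(tensor):
    # Body as in the diff above: unwrap functorch wrappers, sync functional
    # tensors, temporarily drop the dynamic-layer dispatch keys, then call
    # numpy() on the unwrapped value.
    ...

torch.Tensor.numpy = _numpy
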
Review comment (Contributor):
Hmm, I think this is a little more complicated than that.

When someone calls .numpy() under vmap, we probably want to error out. Otherwise, some weird things might happen:

def f(x):
    return torch.tensor(x.numpy())

x = torch.randn(B)
vmap(f)(x)  # returns a Tensor of size (B, B) -- is that what we want?
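
One way to realize the "error out under vmap" behavior would be a guard in front of the unwrapping logic, along these lines (a sketch only; it assumes functorch._C exposes an is_batchedtensor predicate next to the helpers used in the diff, which I have not verified for this revision):

# Hypothetical guard at the top of the patched numpy():
if _C.is_batchedtensor(tensor):
    raise RuntimeError(
        "Tensor.numpy() is not supported inside vmap: converting a "
        "BatchedTensor to NumPy would expose the batch dimension."
    )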

Review comment (Contributor):
When someone calls .numpy() under the grad transform, we should support it (as long as there are no vmaps involved). I'm not sure what the best way to support this is... one thing we could do is keep unwrapping the Tensor and check that no BatchedTensors are involved.

In the long term we want a better fix for this, perhaps one that involves making the PyTorch dispatcher recognize .numpy() as an operation.
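
A rough sketch of that "keep unwrapping and check for BatchedTensors" idea, reusing the _C helpers from the diff and again assuming an is_batchedtensor predicate exists (illustrative only, not a reviewed implementation):

def _involves_batched_tensor(tensor):
    # Walk down through the functorch wrappers until we reach a plain tensor
    # (maybe_get_level returns -1), flagging any BatchedTensor along the way.
    while _C.maybe_get_level(tensor) != -1:
        if _C.is_batchedtensor(tensor):
            return True
        tensor = _C.get_unwrapped(tensor)
    return False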

Successfully merging this pull request may close these issues.

Issue with tensor.numpy() for wrapped tensors
3 participants