
[typing] nn.Parameter return type identified as Tensor by pyright #125105

Closed
@randolf-scholz

Description


πŸ› Describe the bug

Type checking the following script with pyright fails (but not with mypy):

from typing import assert_type, reveal_type
import torch

w = torch.randn(3)
p = torch.nn.Parameter(w)
assert_type(p, torch.nn.Parameter)  # ❌ expected "Parameter" but received "Tensor"
reveal_type(torch.nn.Parameter.__new__)
# (self: type[Self@TensorBase], *args: Unknown, **kwargs: Unknown) -> Tensor

My guess at what's happening:

  1. TensorBase defines def __new__(self, *args, **kwargs) -> Tensor: ... in the stub torch/_C/__init__.pyi.
  2. Parameter.__new__ is untyped. Both pyright and mypy show its signature as self: type[Self@TensorBase] -> Tensor, i.e. they reuse the superclass's type.
  3. mypy ignores the discrepancy and treats the result as an nn.Parameter anyway.
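A minimal, torch-free sketch of the pattern described above, assuming simplified stand-in classes: FakeTensorBase and FakeParameter are hypothetical names, not PyTorch classes. The base __new__ hard-codes the base class as its return type and the subclass inherits it without overriding, so at runtime the object really is a FakeParameter, while a checker that follows the declared return type (as pyright does in this report) would infer FakeTensorBase.

```python
from __future__ import annotations

from typing import Any


class FakeTensorBase:
    # Hypothetical stand-in mirroring the stub: the return type is fixed to the base class.
    def __new__(cls, *args: Any, **kwargs: Any) -> FakeTensorBase:
        return super().__new__(cls)


class FakeParameter(FakeTensorBase):
    # Hypothetical stand-in: no __new__ override, so the base signature is inherited.
    pass


p = FakeParameter()
# At runtime the instance is a FakeParameter...
print(type(p).__name__)  # prints "FakeParameter"
# ...but the declared return type of the inherited __new__ says FakeTensorBase.
```

The runtime object is always correct here; the mismatch exists only in the static annotation.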

It seems the issue can be fixed by simply changing the signature of TensorBase.__new__ to:

def __new__(cls, *args, **kwargs) -> Self: ...
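A sketch of what the proposed Self-returning signature does, again assuming hypothetical stand-in classes (FakeTensorBase, FakeParameter) rather than the real torch stubs. Self is imported only under TYPE_CHECKING, and `from __future__ import annotations` keeps the annotation unevaluated at runtime, so the snippet runs on any Python 3 version without typing_extensions installed.

```python
from __future__ import annotations

from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Only the type checker needs this; typing.Self also works on Python 3.11+.
    from typing_extensions import Self


class FakeTensorBase:
    # Returning Self binds the return type to whichever class __new__ is called on.
    def __new__(cls, *args: Any, **kwargs: Any) -> Self:
        return super().__new__(cls)


class FakeParameter(FakeTensorBase):
    # Inherits __new__ unchanged; a checker should now infer FakeParameter for calls.
    pass


p = FakeParameter()
b = FakeTensorBase()
```

Runtime behavior is identical to a version annotated with the base class; the only change is that reveal_type(p) under pyright should now report FakeParameter instead of the base class.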

Versions

  • torch 2.3.0
  • pyright 1.1.360
