Skip to content

Remove all references to pylint #293

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 0 additions & 12 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,6 @@ repos:
args: [ --fix ]
- id: ruff-format

- repo: https://github.com/asottile/pyupgrade
rev: v3.15.0
hooks:
- id: pyupgrade
args: [--py38-plus]

- repo: https://github.com/PyCQA/pylint
rev: v3.0.3
hooks:
- id: pylint
args: ["--disable=import-error"]

- repo: local
hooks:
- id: mypy
Expand Down
6 changes: 3 additions & 3 deletions profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import pstats
import random

import torchvision # type: ignore[import-untyped] # pylint: disable=unused-import # noqa: F401
from tqdm import trange # pylint: disable=unused-import # noqa: F401
import torchvision # type: ignore[import-untyped] # noqa: F401
from tqdm import trange # noqa: F401

from torchinfo import summary # pylint: disable=unused-import # noqa: F401
from torchinfo import summary # noqa: F401


def profile() -> None:
Expand Down
5 changes: 0 additions & 5 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
black
codecov
flake8
isort
mypy
pycln
pylint
pytest
pytest-cov
pre-commit
Expand Down
7 changes: 4 additions & 3 deletions ruff.toml
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
target-version = "py37"
select = ["ALL"]
ignore = [
"ANN101", # Missing type annotation for `self` in method
Expand All @@ -13,20 +14,21 @@ ignore = [
"FBT002", # Boolean default value in function definition
"FBT003", # Boolean positional value in function call
"FIX002", # Line contains TODO
"ISC001", # single-line-implicit-string-concatenation
"ISC001", # single-line-implicit-string-concatenation (conflicts with the formatter)
"PLR0911", # Too many return statements (11 > 6)
"PLR2004", # Magic value used in comparison, consider replacing 2 with a constant variable
"PLR0912", # Too many branches
"PLR0913", # Too many arguments to function call
"PLR0915", # Too many statements
"PTH123", # `open()` should be replaced by `Path.open()`
"S101", # Use of `assert` detected
"S311", # Standard pseudo-random generators are not suitable for cryptographic purposes
"T201", # print() found
"T203", # pprint() found
"TCH001", # Move application import into a type-checking block
"TCH003", # Move standard library import into a type-checking block
"TD002", # Missing author in TODO; try: `# TODO(<author_name>): ...`
"TD003", # Missing issue link on the line following this TODO
"TD005", # Missing issue description after `TODO`
"TRY003", # Avoid specifying long messages outside the exception class

# torchinfo-specific ignores
Expand All @@ -42,7 +44,6 @@ ignore = [
"TRY004", # Prefer `TypeError` exception for invalid type
"TRY301", # Abstract `raise` to an inner function
]
target-version = "py37"
exclude = ["tests"] # TODO: check tests too

[flake8-pytest-style]
Expand Down
24 changes: 3 additions & 21 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -33,25 +33,7 @@ torchinfo = py.typed

[mypy]
strict = True
implicit_reexport = True
warn_unreachable = True
disallow_any_unimported = True
extra_checks = True
enable_error_code = ignore-without-code

[pylint.MESSAGES CONTROL]
extension-pkg-whitelist = torch
enable =
useless-suppression,
deprecated-pragma,
use-symbolic-message-instead,
disable =
missing-module-docstring,
missing-function-docstring,
too-many-instance-attributes,
too-many-arguments,
too-many-branches,
too-many-locals,
invalid-name,
line-too-long, # Covered by flake8
no-member,
fixme,
duplicate-code,
fail-on = useless-suppression
3 changes: 2 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import pytest

from torchinfo import ModelStatistics
from torchinfo.formatting import HEADER_TITLES, ColumnSettings
from torchinfo.enums import ColumnSettings
from torchinfo.formatting import HEADER_TITLES
from torchinfo.torchinfo import clear_cached_forward_pass


Expand Down
1 change: 0 additions & 1 deletion tests/fixtures/genotype.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# type: ignore
# pylint: skip-file
from collections import namedtuple

import torch
Expand Down
9 changes: 4 additions & 5 deletions tests/fixtures/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# pylint: disable=too-few-public-methods
from __future__ import annotations

import math
Expand Down Expand Up @@ -303,7 +302,7 @@ def __init__(self) -> None:
self.constant = 5

def forward(self, x: dict[int, torch.Tensor], scale_factor: int) -> torch.Tensor:
return scale_factor * (x[256] + x[512][0]) * self.constant
return cast(torch.Tensor, scale_factor * (x[256] + x[512][0]) * self.constant)


class ModuleDictModel(nn.Module):
Expand Down Expand Up @@ -359,7 +358,7 @@ def __int__(self) -> IntWithGetitem:
return self

def __getitem__(self, val: int) -> torch.Tensor:
return self.tensor * val
return cast(torch.Tensor, self.tensor * val)


class EdgecaseInputOutputModel(nn.Module):
Expand Down Expand Up @@ -576,7 +575,7 @@ def __init__(self) -> None:
self.b = nn.Parameter(torch.empty(10), requires_grad=False)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.w * x + self.b
return cast(torch.Tensor, self.w * x + self.b)


class MixedTrainable(nn.Module):
Expand Down Expand Up @@ -718,7 +717,7 @@ def __init__(
def forward(self, x: torch.Tensor) -> torch.Tensor:
h = torch.mm(x, self.a) + self.b
if self.output_dim is None:
return h
return cast(torch.Tensor, h)
return cast(torch.Tensor, self.fc2(h))


Expand Down
1 change: 0 additions & 1 deletion tests/fixtures/tmva_net.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
# type: ignore
# pylint: skip-file
import torch
import torch.nn.functional as F
from torch import nn
Expand Down
4 changes: 2 additions & 2 deletions tests/torchinfo_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,14 +541,14 @@ def test_empty_module_list() -> None:


def test_single_parameter_model() -> None:
class ParameterA(nn.Module): # pylint: disable=too-few-public-methods
class ParameterA(nn.Module):
"""A model with one parameter."""

def __init__(self) -> None:
super().__init__()
self.w = nn.Parameter(torch.zeros(1024))

class ParameterB(nn.Module): # pylint: disable=too-few-public-methods
class ParameterB(nn.Module):
"""A model with one parameter and one Conv2d layer."""

def __init__(self) -> None:
Expand Down
2 changes: 1 addition & 1 deletion torchinfo/layer_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def trainable(self) -> str:

@staticmethod
def calculate_size(
inputs: DETECTED_INPUT_OUTPUT_TYPES, batch_dim: int | None
inputs: DETECTED_INPUT_OUTPUT_TYPES | None, batch_dim: int | None
) -> tuple[list[int], int]:
"""
Set input_size or output_size using the model's inputs.
Expand Down
8 changes: 3 additions & 5 deletions torchinfo/torchinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,6 @@ class name as the key. If the forward pass is an expensive operation,
model_mode = Mode(mode)

if verbose is None:
# pylint: disable=no-member
verbose = 0 if hasattr(sys, "ps1") and sys.ps1 else 1

if cache_forward_pass is None:
Expand Down Expand Up @@ -269,7 +268,7 @@ def forward_pass(
**kwargs: Any,
) -> list[LayerInfo]:
"""Perform a forward pass on the model using forward hooks."""
global _cached_forward_pass # pylint: disable=global-variable-not-assigned
global _cached_forward_pass
model_name = model.__class__.__name__
if cache_forward_pass and model_name in _cached_forward_pass:
return _cached_forward_pass[model_name]
Expand Down Expand Up @@ -485,7 +484,7 @@ def get_device(
model_parameter = None

if model_parameter is not None and model_parameter.is_cuda:
return model_parameter.device
return model_parameter.device # type: ignore[no-any-return]
return torch.device("cuda" if torch.cuda.is_available() else "cpu")
return None

Expand Down Expand Up @@ -664,7 +663,6 @@ def apply_hooks(
# some unknown reason (infinite recursion)
stack += [
(name, mod, curr_depth + 1, global_layer_info[module_id])
# pylint: disable=protected-access
for name, mod in reversed(module._modules.items())
if mod is not None
]
Expand All @@ -673,5 +671,5 @@ def apply_hooks(

def clear_cached_forward_pass() -> None:
"""Clear the forward pass cache."""
global _cached_forward_pass # pylint: disable=global-statement
global _cached_forward_pass
_cached_forward_pass = {}