Skip to content

Commit

Permalink
Use proper_type plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
sobolevn committed Aug 11, 2024
1 parent 532a3c2 commit f85e605
Show file tree
Hide file tree
Showing 10 changed files with 98 additions and 83 deletions.
18 changes: 9 additions & 9 deletions returns/contrib/mypy/_features/curry.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from mypy.plugin import FunctionContext
from mypy.types import AnyType, CallableType, FunctionLike, Overloaded
from mypy.types import Type as MypyType
from mypy.types import TypeOfAny
from mypy.types import TypeOfAny, get_proper_type

from returns.contrib.mypy._structures.args import FuncArg
from returns.contrib.mypy._typeops.transform_callable import (
Expand All @@ -20,14 +20,14 @@

def analyze(ctx: FunctionContext) -> MypyType:
"""Returns proper type for curried functions."""
if not isinstance(ctx.arg_types[0][0], CallableType):
return ctx.default_return_type
if not isinstance(ctx.default_return_type, CallableType):
return ctx.default_return_type
default_return = get_proper_type(ctx.default_return_type)
arg_type = get_proper_type(ctx.arg_types[0][0])
if not isinstance(arg_type, CallableType):
return default_return
if not isinstance(default_return, CallableType):
return default_return

return _CurryFunctionOverloads(
ctx.arg_types[0][0], ctx,
).build_overloads()
return _CurryFunctionOverloads(arg_type, ctx).build_overloads()


@final
Expand Down Expand Up @@ -147,7 +147,7 @@ def _build_overloads_from_argtree(self, argtree: _ArgTree) -> None:
# Will take `2` and apply its type to the previous function `1`.
# Will result in `def x -> y -> A`
# We also overloadify existing return types.
ret_type = argtree.case.ret_type
ret_type = get_proper_type(argtree.case.ret_type)
temp_any = isinstance(
ret_type, AnyType,
) and ret_type.type_of_any == TypeOfAny.implementation_artifact
Expand Down
13 changes: 7 additions & 6 deletions returns/contrib/mypy/_features/do_notation.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,22 +36,23 @@ def analyze(ctx: MethodContext) -> MypyType:
if generator expression has ``if`` conditions inside.
"""
default_return = get_proper_type(ctx.default_return_type)
if not ctx.args or not ctx.args[0]:
return ctx.default_return_type
return default_return

expr = ctx.args[0][0]
if not isinstance(expr, GeneratorExpr):
ctx.api.fail(_LITERAL_GENERATOR_EXPR_REQUIRED, expr)
return ctx.default_return_type
return default_return
if not isinstance(ctx.type, CallableType):
return ctx.default_return_type
if not isinstance(ctx.default_return_type, Instance):
return ctx.default_return_type
return default_return
if not isinstance(default_return, Instance):
return default_return

return _do_notation(
expr=expr,
type_info=ctx.type.type_object(),
default_return_type=ctx.default_return_type,
default_return_type=default_return,
ctx=ctx,
)

Expand Down
3 changes: 2 additions & 1 deletion returns/contrib/mypy/_features/flow.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from mypy.plugin import FunctionContext
from mypy.types import Type as MypyType
from mypy.types import get_proper_type

from returns.contrib.mypy._typeops.inference import PipelineInference

Expand Down Expand Up @@ -46,7 +47,7 @@ def analyze(ctx: FunctionContext) -> MypyType:
)

return PipelineInference(
ctx.arg_types[0][0],
get_proper_type(ctx.arg_types[0][0]),
).from_callable_sequence(
real_arg_types,
ctx.arg_kinds[1],
Expand Down
27 changes: 13 additions & 14 deletions returns/contrib/mypy/_features/kind.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def attribute_access(ctx: AttributeContext) -> MypyType:
"""
assert isinstance(ctx.type, Instance)
instance = ctx.type.args[0]
instance = get_proper_type(ctx.type.args[0])

if isinstance(instance, TypeVarType):
bound = get_proper_type(instance.upper_bound)
Expand Down Expand Up @@ -78,18 +78,15 @@ def dekind(ctx: FunctionContext) -> MypyType:
So, ``dekind(KindN[T, int])`` will fail.
"""
kind = get_proper_type(ctx.arg_types[0][0])
correct_args = (
isinstance(kind, Instance) and
isinstance(kind.args[0], Instance)
)
assert isinstance(kind, Instance) # mypy requires these lines

if not correct_args:
kind_inst = get_proper_type(kind.args[0])

if not isinstance(kind_inst, Instance):
ctx.api.fail(_KindErrors.dekind_not_instance, ctx.context)
return AnyType(TypeOfAny.from_error)

assert isinstance(kind, Instance) # mypy requires these lines
assert isinstance(kind.args[0], Instance)
return kind.args[0].copy_modified(args=_crop_kind_args(kind))
return kind_inst.copy_modified(args=_crop_kind_args(kind))


@asserts_fallback_to_any
Expand All @@ -101,9 +98,10 @@ def kinded_signature(ctx: MethodSigContext) -> CallableType:
See :class:`returns.primitives.hkt.Kinded` for more information.
"""
assert isinstance(ctx.type, Instance)
assert isinstance(ctx.type.args[0], FunctionLike)

wrapped_method = ctx.type.args[0]
wrapped_method = get_proper_type(ctx.type.args[0])
assert isinstance(wrapped_method, FunctionLike)

if isinstance(wrapped_method, Overloaded):
return ctx.default_signature

Expand Down Expand Up @@ -138,10 +136,11 @@ def kinded_get_descriptor(ctx: MethodContext) -> MypyType:
We do this due to ``__get__`` descriptor magic.
"""
assert isinstance(ctx.type, Instance)
assert isinstance(ctx.type.args[0], CallableType)

wrapped_method = ctx.type.args[0]
self_type = wrapped_method.arg_types[0]
wrapped_method = get_proper_type(ctx.type.args[0])
assert isinstance(wrapped_method, CallableType)

self_type = get_proper_type(wrapped_method.arg_types[0])
signature = bind_self(
wrapped_method,
is_classmethod=isinstance(self_type, TypeType),
Expand Down
37 changes: 22 additions & 15 deletions returns/contrib/mypy/_features/partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,15 @@

from mypy.nodes import ARG_STAR, ARG_STAR2
from mypy.plugin import FunctionContext
from mypy.types import CallableType, FunctionLike, Instance, Overloaded
from mypy.types import Type as MypyType
from mypy.types import TypeType
from mypy.types import (
CallableType,
FunctionLike,
Instance,
Overloaded,
ProperType,
TypeType,
get_proper_type,
)

from returns.contrib.mypy._structures.args import FuncArg
from returns.contrib.mypy._typeops.analtype import (
Expand All @@ -27,7 +33,7 @@
)


def analyze(ctx: FunctionContext) -> MypyType:
def analyze(ctx: FunctionContext) -> ProperType:
"""
This hook is used to make typed curring a thing in `returns` project.
Expand All @@ -40,26 +46,27 @@ def analyze(ctx: FunctionContext) -> MypyType:
Internally we just reduce the original function's argument count.
And drop some of them from function's signature.
"""
if not isinstance(ctx.default_return_type, CallableType):
return ctx.default_return_type
default_return = get_proper_type(ctx.default_return_type)
if not isinstance(default_return, CallableType):
return default_return

function_def = ctx.arg_types[0][0]
function_def = get_proper_type(ctx.arg_types[0][0])
func_args = _AppliedArgs(ctx)

if len(list(filter(len, ctx.arg_types))) == 1:
return function_def # this means, that `partial(func)` is called
elif not isinstance(function_def, _SUPPORTED_TYPES):
return ctx.default_return_type
return default_return
elif isinstance(function_def, (Instance, TypeType)):
# We force `Instance` and similar types to coercse to callable:
function_def = func_args.get_callable_from_context()

is_valid, applied_args = func_args.build_from_context()
if not isinstance(function_def, (CallableType, Overloaded)) or not is_valid:
return ctx.default_return_type
return default_return

return _PartialFunctionReducer(
ctx.default_return_type,
default_return,
function_def,
applied_args,
ctx,
Expand Down Expand Up @@ -118,7 +125,7 @@ def __init__(
self._case_functions: List[CallableType] = []
self._fallbacks: List[CallableType] = []

def new_partial(self) -> MypyType:
def new_partial(self) -> ProperType:
"""
Creates new partial functions.
Expand Down Expand Up @@ -182,7 +189,7 @@ def _create_partial_case(
return detach_callable(partial)
return partial.copy_modified(variables=[])

def _create_new_partial(self) -> MypyType:
def _create_new_partial(self) -> ProperType:
"""
Creates a new partial function-like from set of callables.
Expand Down Expand Up @@ -220,12 +227,12 @@ def __init__(self, function_ctx: FunctionContext) -> None:
self._function_ctx.arg_kinds[1:],
)

def get_callable_from_context(self) -> MypyType:
def get_callable_from_context(self) -> ProperType:
"""Returns callable type from the context."""
return safe_translate_to_function(
return get_proper_type(safe_translate_to_function(
self._function_ctx.arg_types[0][0],
self._function_ctx,
)
))

def build_from_context(self) -> Tuple[bool, List[FuncArg]]:
"""
Expand Down
27 changes: 14 additions & 13 deletions returns/contrib/mypy/_features/pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,9 @@

from mypy.nodes import ARG_POS
from mypy.plugin import FunctionContext, MethodContext, MethodSigContext
from mypy.types import AnyType, CallableType, FunctionLike, Instance
from mypy.types import AnyType, CallableType, FunctionLike, Instance, ProperType
from mypy.types import Type as MypyType
from mypy.types import TypeOfAny, UnionType, get_proper_type
from mypy.types import TypeOfAny, UnionType, get_proper_type, get_proper_types

from returns.contrib.mypy._typeops.analtype import translate_to_function
from returns.contrib.mypy._typeops.inference import PipelineInference
Expand All @@ -51,21 +51,22 @@

def analyze(ctx: FunctionContext) -> MypyType:
"""This hook helps when we create the pipeline from sequence of funcs."""
if not isinstance(ctx.default_return_type, Instance):
return ctx.default_return_type
default_return = get_proper_type(ctx.default_return_type)
if not isinstance(default_return, Instance):
return default_return

if not ctx.arg_types[0]: # We do require to pass `*functions` arg.
ctx.api.fail('Too few arguments for "pipe"', ctx.context)
return ctx.default_return_type
return default_return

arg_types = [arg_type[0] for arg_type in ctx.arg_types if arg_type]
first_step, last_step = _get_pipeline_def(arg_types, ctx)
if not isinstance(first_step, FunctionLike):
return ctx.default_return_type
return default_return
if not isinstance(last_step, FunctionLike):
return ctx.default_return_type
return default_return

return ctx.default_return_type.copy_modified(
return default_return.copy_modified(
args=[
# First type argument represents first function arguments type:
_unify_type(first_step, _get_first_arg_type),
Expand All @@ -82,9 +83,9 @@ def infer(ctx: MethodContext) -> MypyType:
if not isinstance(ctx.type, Instance):
return ctx.default_return_type

pipeline_functions = ctx.type.args[2:]
pipeline_functions = get_proper_types(ctx.type.args[2:])
return PipelineInference(
ctx.arg_types[0][0],
get_proper_type(ctx.arg_types[0][0]),
).from_callable_sequence(
pipeline_functions,
list((ARG_POS,) * len(pipeline_functions)),
Expand Down Expand Up @@ -117,12 +118,12 @@ def _unify_type(
def _get_pipeline_def(
arg_types: List[MypyType],
ctx: FunctionContext,
) -> Tuple[MypyType, MypyType]:
) -> Tuple[ProperType, ProperType]:
first_step = get_proper_type(arg_types[0])
last_step = get_proper_type(arg_types[-1])

if not isinstance(first_step, FunctionLike):
first_step = translate_to_function(first_step, ctx) # type: ignore
first_step = translate_to_function(first_step, ctx)
if not isinstance(last_step, FunctionLike):
last_step = translate_to_function(last_step, ctx) # type: ignore
last_step = translate_to_function(last_step, ctx)
return first_step, last_step
11 changes: 6 additions & 5 deletions returns/contrib/mypy/_typeops/analtype.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@

from mypy.checkmember import analyze_member_access
from mypy.nodes import ARG_NAMED, ARG_OPT
from mypy.types import CallableType, FunctionLike
from mypy.types import CallableType, FunctionLike, ProperType
from mypy.types import Type as MypyType
from mypy.types import get_proper_type
from typing_extensions import Literal

from returns.contrib.mypy._structures.args import FuncArg
Expand Down Expand Up @@ -99,17 +100,17 @@ def safe_translate_to_function(


def translate_to_function(
function_def: MypyType,
function_def: ProperType,
ctx: CallableContext,
) -> MypyType:
) -> ProperType:
"""
Tries to translate a type into callable by accessing ``__call__`` attr.
This might fail with ``mypy`` errors and that's how it must work.
This also preserves all type arguments as-is.
"""
checker = ctx.api.expr_checker # type: ignore
return analyze_member_access(
return get_proper_type(analyze_member_access(
'__call__',
function_def,
ctx.context,
Expand All @@ -120,4 +121,4 @@ def translate_to_function(
original_type=function_def,
chk=checker.chk,
in_literal_context=checker.is_literal_context(),
)
))
Loading

0 comments on commit f85e605

Please sign in to comment.