Skip to content

Latest commit

 

History

History
186 lines (137 loc) · 5.16 KB

functional_programming7.md

File metadata and controls

186 lines (137 loc) · 5.16 KB

Multimethods

本文为大家介绍Guido于2005年对于泛型函数的一些构想。文章原文链接参见文末参考文献。

构想与实现

正如@singledispatch一样,我们可以利用装饰器来定义多参数的泛型函数:

from mm import multimethod

@multimethod(int, int)
def foo(a, b):
    '''code for two ints'''

@multimethod(float, float):
def foo(a, b):
    '''code for two floats'''

@multimethod(str, str):
def foo(a, b):
    '''code for two strings'''

上述multimethod装饰器可以这样来简单实现。首先我们定义一个类来存储函数的映射关系:

registry = {}

class MultiMethod:
    def __init__(self, name):
        self.name = name
        self.typemap = {}
    
    def __call__(self, *args):
        types = tuple(arg.__class__ for arg in args)
        function = self.typemap.get(types)
        if function is None:
            raise TypeError("no match")
        return function(*args)
    
    def register(self, types, function):
        if types in self.typemap:
            raise TypeError("duplicate registration")
        self.typemap[types] = function

通过register方法可以注册泛型函数和对应的参数类型,而采用特殊方法__call__的原因后续会看到。这里,只有一个MultiMethod类是不够的,我们需要的是装饰器@multimethod。装饰器的作用是将类型和函数对应起来,因而我们的@multimethod只需要返回一个MultiMethod对象即可:

def multimethod(*args):
    def wrapper(function):
        name = function.__name__
        mm = registry.get(name)
        if mm is None:
            mm = registry[name] = MultiMethod(name)
        mm.register(args, function)
        return mm
    return wrapper

这里我们通过全局变量registry记录了函数名与MultiMethod对象的对应关系。我们再回过头看看当我们写下@multimethod定义函数时发生了什么。wrapper首先查找registry是否定义了这个函数的MultiMethod对象mm。之后,调用mmregister方法来记录args和函数,其中args就是@multimethod括号中的参数。register将会记录下args对应的函数。也就是说,一个泛型函数对应一个MultiMethod对象,存储于全局字典registry中;一个对象内存储着参数列表和对应函数的映射。最后,当函数调用时(实际是MultiMethod对象进行调用),执行的是__call__方法。我们来看看效果:

@multimethod(int, float)
def add(a, b):
    return a + b

@multimethod(int, list)
def add(a, b):
    return [x + a for x in b]

@multimethod(int, float, complex)
def add(a, b, c):
    return a + b + c

print(add(1, 1.0))
2.0

print(add(2, [1, 2, 3]))
[3, 4, 5]

print(add(1, 1.0, 1+2j))
(3+2j)

这里,@multimethod仅仅支持位置参数。如果要支持关键字参数则比较复杂,因为关键字参数并不要求参数的顺序,而泛型函数需要明确顺序来获得类型组合。我们尝试给__call__增加关键字参数:

# class MultiMethod
def __call__(self, *args, **kwargs):
    types = tuple(arg.__class__ for arg in args)\
    		+ tuple(kwargs[key].__class__ for key in kwargs)
    ...
    return function(*args, **kwargs)

试着调用一下:****

print(add(a=1, b=2.0))
3.0

print(add(b=2.0, a=1))
#TypeError: no match

相同的参数列表却得到了不同的结果。

默认参数

具有默认参数的函数与泛型函数有一丝冲突,因为默认参数在调用时可以给出也可以不必给出,而泛型函数则需要获得所有参数的类型。

@multimethod(int, int)
def add(a, b=1):
    return a * b

上述add等价于下面两个函数的结合体。

@multimethod(int, int)
def add(a, b):
    return a * b

@multimethod(int)
def add(a):
    return add(a, b=1)

这两个函数的函数体是一样的(但是上面的定义是无法使用的)。一个比较优雅的书写方式是装饰器的嵌套:

@multimethod(int, int)
@multimethod(int)
def add(a, b=1):
    return a * b

怎么实现呢?由于经过一次装饰后的函数获得的是一个MultiMethod对象,这个对象无法第二次再进行装饰(因为不存在name属性),因而我们只需要修改multimethod函数,利用一个属性将原始函数记录下来即可:

def multimethod(*types):
    def register(function):
        function = getattr(function, "__lastreg__", function)
        name = function.__name__
        mm = registry.get(name)
        if mm is None:
            mm = registry[name] = MultiMethod(name)
        mm.register(types, function)
        mm.__lastreg__ = function
        return mm
    return register

之后我们可以嵌套@multimethod,并且能够支持默认参数:

@multimethod(float, int)
@multimethod(int, int)
@multimethod(int)
@multimethod(float)
def add(a, b=1):
    return a * b

print(add(1))
1

print(add(2.0))
2.0

print(add(1, 2))
2

print(add(1.0, 3))
3.0

https://www.artima.com/weblogs/viewpost.jsp?thread=101605

Source : http://inst.eecs.berkeley.edu/~cs61A/book/chapters/objects.html#generic-functions