11
11
# ANY KIND, either express or implied. See the License for the specific language
12
12
# governing permissions and limitations under the License.
13
13
14
- from types import ModuleType
15
- from typing import Any , Union
14
+ import inspect
15
+ from collections import defaultdict
16
+ from functools import wraps
17
+ from typing import Any , Callable , Optional
16
18
17
19
import modin .pandas as pd
20
+ from modin .config import Backend
21
+
22
+ # This type describes a defaultdict that maps backend name (or `None` for
23
+ # method implementation and not bound to any one extension) to the dictionary of
24
+ # extensions for that backend. The keys of the inner dictionary are the names of
25
+ # the extensions, and the values are the extensions themselves.
26
+ EXTENSION_DICT_TYPE = defaultdict [Optional [str ], dict [str , Any ]]
27
+
28
+ _attrs_to_delete_on_test = defaultdict (list )
29
+
30
+ _NON_EXTENDABLE_ATTRIBUTES = (
31
+ # we use these attributes to implement the extension system, so we can't
32
+ # allow extensions to override them.
33
+ "__getattribute__" ,
34
+ "__setattr__" ,
35
+ "__delattr__" ,
36
+ "get_backend" ,
37
+ "set_backend" ,
38
+ "__getattr__" ,
39
+ "_get_extension" ,
40
+ "_getattribute__from_extension_impl" ,
41
+ "_getattr__from_extension_impl" ,
42
+ "_query_compiler" ,
43
+ )
18
44
19
45
20
46
def _set_attribute_on_obj (
21
- name : str , extensions_dict : dict , obj : Union [ pd . DataFrame , pd . Series , ModuleType ]
47
+ name : str , extensions : dict , backend : Optional [ str ], obj : type
22
48
):
23
49
"""
24
50
Create a new or override existing attribute on obj.
@@ -27,8 +53,11 @@ def _set_attribute_on_obj(
27
53
----------
28
54
name : str
29
55
The name of the attribute to assign to `obj`.
30
- extensions_dict : dict
56
+ extensions : dict
31
57
The dictionary mapping extension name to `new_attr` (assigned below).
58
+ backend : Optional[str]
59
+ The backend to which the accessor applies. If `None`, this accessor
60
+ will become the default for all backends.
32
61
obj : DataFrame, Series, or modin.pandas
33
62
The object we are assigning the new attribute to.
34
63
@@ -37,10 +66,12 @@ def _set_attribute_on_obj(
37
66
decorator
38
67
Returns the decorator function.
39
68
"""
69
+ if name in _NON_EXTENDABLE_ATTRIBUTES :
70
+ raise ValueError (f"Cannot register an extension with the reserved name { name } ." )
40
71
41
72
def decorator (new_attr : Any ):
42
73
"""
43
- The decorator for a function or class to be assigned to name
74
+ Decorate a function or class to be assigned to the given name.
44
75
45
76
Parameters
46
77
----------
@@ -52,14 +83,24 @@ def decorator(new_attr: Any):
52
83
new_attr
53
84
Unmodified new_attr is returned from the decorator.
54
85
"""
55
- extensions_dict [name ] = new_attr
56
- setattr (obj , name , new_attr )
86
+ extensions [None if backend is None else Backend .normalize (backend )][
87
+ name
88
+ ] = new_attr
89
+ if callable (new_attr ) and name not in dir (obj ):
90
+ # For callable extensions, we add a method to the class that
91
+ # dispatches to the correct implementation.
92
+ setattr (
93
+ obj ,
94
+ name ,
95
+ wrap_method_in_backend_dispatcher (name , new_attr , extensions ),
96
+ )
97
+ _attrs_to_delete_on_test [obj ].append (name )
57
98
return new_attr
58
99
59
100
return decorator
60
101
61
102
62
- def register_dataframe_accessor (name : str ):
103
+ def register_dataframe_accessor (name : str , * , backend : Optional [ str ] = None ):
63
104
"""
64
105
Registers a dataframe attribute with the name provided.
65
106
@@ -88,13 +129,16 @@ def my_new_dataframe_method(*args, **kwargs):
88
129
-------
89
130
decorator
90
131
Returns the decorator function.
132
+ backend : Optional[str]
133
+ The backend to which the accessor applies. If ``None``, this accessor
134
+ will become the default for all backends.
91
135
"""
92
136
return _set_attribute_on_obj (
93
- name , pd .dataframe ._DATAFRAME_EXTENSIONS_ , pd .DataFrame
137
+ name , pd .dataframe ._DATAFRAME_EXTENSIONS_ , backend , pd . dataframe .DataFrame
94
138
)
95
139
96
140
97
- def register_series_accessor (name : str ):
141
+ def register_series_accessor (name : str , * , backend : Optional [ str ] = None ):
98
142
"""
99
143
Registers a series attribute with the name provided.
100
144
@@ -118,13 +162,61 @@ def my_new_series_method(*args, **kwargs):
118
162
----------
119
163
name : str
120
164
The name of the attribute to assign to Series.
165
+ backend : Optional[str]
166
+ The backend to which the accessor applies. If ``None``, this accessor
167
+ will become the default for all backends.
168
+
169
+ Returns
170
+ -------
171
+ decorator
172
+ Returns the decorator function.
173
+ """
174
+ return _set_attribute_on_obj (
175
+ name , pd .series ._SERIES_EXTENSIONS_ , backend = backend , obj = pd .series .Series
176
+ )
177
+
178
+
179
def register_base_accessor(name: str, *, backend: Optional[str] = None):
    """
    Register a base attribute with the name provided.

    This is a decorator that registers a new attribute for BasePandasDataset.
    It can be used with the following syntax:

    ```
    @register_base_accessor("new_method")
    def my_new_base_method(*args, **kwargs):
        # logic goes here
        return
    ```

    The new attribute can then be accessed with the name provided:

    ```
    obj.new_method(*my_args, **my_kwargs)
    ```

    Parameters
    ----------
    name : str
        The name of the attribute to assign to BasePandasDataset.
    backend : Optional[str]
        The backend to which the accessor applies. If ``None``, this accessor
        will become the default for all backends.

    Returns
    -------
    decorator
        Returns the decorator function.
    """
    # The base module is imported at call time rather than at module import.
    # NOTE(review): presumably to avoid a circular import -- confirm.
    import modin.pandas.base

    base_module = modin.pandas.base
    return _set_attribute_on_obj(
        name,
        base_module._BASE_EXTENSIONS,
        backend,
        base_module.BasePandasDataset,
    )
128
220
129
221
130
222
def register_pd_accessor (name : str ):
@@ -160,4 +252,128 @@ def my_new_pd_function(*args, **kwargs):
160
252
decorator
161
253
Returns the decorator function.
162
254
"""
163
- return _set_attribute_on_obj (name , pd ._PD_EXTENSIONS_ , pd )
255
+
256
+ def decorator (new_attr : Any ):
257
+ """
258
+ The decorator for a function or class to be assigned to name
259
+
260
+ Parameters
261
+ ----------
262
+ new_attr : Any
263
+ The new attribute to assign to name.
264
+
265
+ Returns
266
+ -------
267
+ new_attr
268
+ Unmodified new_attr is returned from the decorator.
269
+ """
270
+ pd ._PD_EXTENSIONS_ [name ] = new_attr
271
+ setattr (pd , name , new_attr )
272
+ return new_attr
273
+
274
+ return decorator
275
+
276
+
277
def wrap_method_in_backend_dispatcher(
    name: str, method: Callable, extensions: EXTENSION_DICT_TYPE
) -> Callable:
    """
    Wrap a method so that calls dispatch to the correct backend implementation.

    The returned function inspects its first positional argument (assumed to
    be `self`). If that object has a `_query_compiler` attribute and
    `extensions` holds an entry called `name` for the backend reported by
    `self.get_backend()`, that entry is called. Otherwise the entry called
    `name` under the `None` key of `extensions` is called.

    Parameters
    ----------
    name : str
        The name of the method being wrapped.
    method : Callable
        The method being wrapped. It is called directly only when the
        dispatcher receives no positional arguments.
    extensions : EXTENSION_DICT_TYPE
        The extensions dictionary for the class this method is defined on.

    Returns
    -------
    Callable
        Returns the wrapped function.

    Raises
    ------
    AttributeError
        Raised by the returned function when no backend-specific extension
        applies and `extensions[None]` has no entry called `name`.
    """

    @wraps(method)
    def method_dispatcher(*args, **kwargs):
        if not args:
            # Without a positional `self` there is nothing to dispatch on
            # (e.g. cases like __init__(), or every argument passed by
            # keyword), so call the wrapped method unchanged instead of
            # failing on `args[0]`.
            return method(*args, **kwargs)
        # TODO(https://github.com/modin-project/modin/issues/7470): this
        # method may take dataframes and series backed by different backends
        # as input, e.g. if we are here because of a call like
        # `pd.DataFrame().set_backend('python_test').merge(pd.DataFrame().set_backend('pandas'))`.
        # In that case, we should determine which backend to cast to, cast all
        # arguments to that backend, and then choose the appropriate extension
        # method, if it exists.

        # Assume that `self` is the first argument.
        self = args[0]
        remaining_args = args[1:]

        if hasattr(self, "_query_compiler"):
            # Resolve the backend once; `.get` avoids inserting an empty entry
            # for this backend into the defaultdict.
            backend_extensions = extensions.get(self.get_backend())
            if backend_extensions is not None and name in backend_extensions:
                # If `self` is using a query compiler whose backend has an
                # extension for this method, use that extension.
                return backend_extensions[name](self, *remaining_args, **kwargs)

        # Otherwise, use the default implementation.
        default_extensions = extensions[None]
        if name not in default_extensions:
            raise AttributeError(
                f"{type(self).__name__} object has no attribute {name}"
            )
        return default_extensions[name](self, *remaining_args, **kwargs)

    return method_dispatcher
334
+
335
+
336
def wrap_class_methods_in_backend_dispatcher(
    extensions: EXTENSION_DICT_TYPE,
) -> Callable:
    """
    Build a class decorator that routes the class's functions through backend dispatch.

    Parameters
    ----------
    extensions : EXTENSION_DICT_TYPE
        The extension dictionary for the class.

    Returns
    -------
    Callable
        A function that takes a class, replaces its function attributes with
        backend dispatchers in place, and returns that same class.
    """

    def wrap_methods(cls: type):
        # Several names can be bound to one function object (synonyms like
        # __add__() and add()). Remember the dispatcher built for each
        # function so that every synonym shares a single wrapper.
        dispatcher_by_function: dict[Callable, Callable] = {}
        function_members = inspect.getmembers(cls, predicate=inspect.isfunction)
        for attr_name, function in function_members:
            if function in dispatcher_by_function:
                # Synonym of a function that is already wrapped: reuse it.
                setattr(cls, attr_name, dispatcher_by_function[function])
                continue
            if attr_name in _NON_EXTENDABLE_ATTRIBUTES:
                # Reserved names are left untouched and never registered.
                continue

            # The class's own implementation becomes the default (backend
            # `None`) entry that the dispatcher falls back to.
            extensions[None][attr_name] = function
            dispatcher = wrap_method_in_backend_dispatcher(
                attr_name, function, extensions
            )
            if attr_name not in cls.__dict__:
                # The function comes from a superclass (it is not in
                # cls.__dict__); mark the dispatcher so that
                # modin.utils._inherit_docstrings knows that the method must
                # get its docstrings from its superclass.
                dispatcher._wrapped_superclass_method = function
            setattr(cls, attr_name, dispatcher)
            dispatcher_by_function[function] = getattr(cls, attr_name)
        return cls

    return wrap_methods
0 commit comments