Skip to content

Commit

Permalink
fix: tf backend get_item and add set_item (#28823)
Browse files Browse the repository at this point in the history
Co-authored-by: ivy-dev-bot <[email protected]>
  • Loading branch information
Sam-Armstrong and ivy-dev-bot committed Sep 17, 2024
1 parent 4a7b58b commit f5cc3e8
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 45 deletions.
12 changes: 7 additions & 5 deletions binaries.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,9 @@
{
"utils": [
"C",
"CC",
"CD",
"CI",
"CL",
"CM",
"CV",
"CX",
"D",
Expand Down Expand Up @@ -43,7 +41,6 @@
"III",
"IIL",
"IIM",
"IIV",
"IIX",
"IL",
"ILC",
Expand All @@ -62,7 +59,6 @@
"IMV",
"IMX",
"IV",
"IVC",
"IVD",
"IVI",
"IVL",
Expand Down Expand Up @@ -103,7 +99,13 @@
"VCV",
"VCX",
"VD",
"VDC",
"VDD",
"VDI",
"VDL",
"VDM",
"VDV",
"VDX",
"VI",
"VIC",
"VID",
Expand Down Expand Up @@ -149,4 +151,4 @@
}
]
}
}
}
31 changes: 13 additions & 18 deletions ivy/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@


def clear_graph_cache():
"""Clears the graph cache which gets populated if `graph_caching` is set to
`True` in `ivy.trace_graph`, `ivy.transpile` or `ivy.unify`. Use this to
"""Clears the graph cache which gets populated if `graph_caching` is set
to `True` in `ivy.trace_graph`, `ivy.transpile` or `ivy.unify`. Use this to
reset or clear the graph cache if needed.
Examples
--------
>>> import ivy
>>> ivy.clear_graph_cache()
"""
>>> ivy.clear_graph_cache()"""

from ._compiler import clear_graph_cache as _clear_graph_cache

return _clear_graph_cache()
Expand Down Expand Up @@ -55,8 +55,8 @@ def graph_transpile(
Returns
-------
Either a transpiled Graph or a non-initialized LazyGraph.
"""
Either a transpiled Graph or a non-initialized LazyGraph."""

from ._compiler import graph_transpile as _graph_transpile

return _graph_transpile(
Expand Down Expand Up @@ -96,7 +96,6 @@ def source_to_source(
e.g. (source="torch_frontend", target="ivy") or (source="torch_frontend", target="tensorflow") etc.
Args:
----
object: The object (class/function) to be translated.
source (str, optional): The source framework. Defaults to 'torch'.
target (str, optional): The target framework. Defaults to 'tensorflow'.
Expand All @@ -107,9 +106,8 @@ def source_to_source(
the old implementation. Defaults to 'True'.
Returns:
-------
The translated object.
"""
The translated object."""

from ._compiler import source_to_source as _source_to_source

return _source_to_source(
Expand Down Expand Up @@ -140,8 +138,7 @@ def trace_graph(
params_v=None,
v=None
):
"""Takes `fn` and traces it into a more efficient composition of backend
operations.
"""Takes `fn` and traces it into a more efficient composition of backend operations.
Parameters
----------
Expand Down Expand Up @@ -211,8 +208,8 @@ def trace_graph(
>>> start = time.time()
>>> graph(x)
>>> print(time.time() - start)
0.0001785755157470703
"""
0.0001785755157470703"""

from ._compiler import trace_graph as _trace_graph

return _trace_graph(
Expand Down Expand Up @@ -252,7 +249,6 @@ def transpile(
e.g. (source="torch_frontend", target="ivy") or (source="torch_frontend", target="tensorflow") etc.
Args:
----
object: The object (class/function) to be translated.
source (str, optional): The source framework. Defaults to 'torch'.
target (str, optional): The target framework. Defaults to 'tensorflow'.
Expand All @@ -263,9 +259,8 @@ def transpile(
the old implementation. Defaults to 'True'.
Returns:
-------
The translated object.
"""
The translated object."""

from ._compiler import transpile as _transpile

return _transpile(
Expand Down
46 changes: 45 additions & 1 deletion ivy/functional/backends/tensorflow/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,51 @@ def get_item(
) -> Union[tf.Tensor, tf.Variable]:
if ivy.is_array(query) and ivy.is_bool_dtype(query) and not len(query.shape):
return tf.expand_dims(x, 0)
return x[query]
if isinstance(query, (tf.Tensor, tf.Variable)):
if query.dtype == tf.bool:
return tf.boolean_mask(x, query, axis=0)
else:
query = tf.cast(query, tf.int64)
return tf.gather(x, query, axis=0)
else:
if any([isinstance(q, slice) for q in query]):
# convert any lists/tuples within the query to slices
query = tuple([slice(*q) if isinstance(q, (list, tuple)) else q for q in query])
# for slices and other basic indexing, use __getitem__
return x[query]


def set_item(
x: Union[tf.Tensor, tf.Variable],
query: Union[tf.Tensor, tf.Variable, Tuple],
val: Union[tf.Tensor, tf.Variable],
/,
*,
copy: Optional[bool] = False,
) -> Union[tf.Tensor, tf.Variable]:
# TODO: we should re-write this at some point so it's compatible with tf.function (don't use numpy as an intermediary)
# when doing this, be sure to check the performance of the function on large tensors, compared to this implementation

if tf.is_tensor(x):
x = x.numpy()
if tf.is_tensor(val):
val = val.numpy()

if isinstance(query, (tf.Tensor, tf.Variable)):
query = query.numpy()
elif isinstance(query, tuple):
query = tuple(
q.numpy() if isinstance(q, (tf.Tensor, tf.Variable)) else q
for q in query
)

x[query] = val

if isinstance(x, tf.Variable) and not copy:
x.assign(x)
return x
else:
return tf.Variable(x) if isinstance(x, tf.Variable) else tf.convert_to_tensor(x)


def to_numpy(x: Union[tf.Tensor, tf.Variable], /, *, copy: bool = True) -> np.ndarray:
Expand Down
20 changes: 1 addition & 19 deletions ivy/functional/ivy/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -2879,25 +2879,7 @@ def set_item(
ivy.array([[ 0, -1, 20],
[10, 10, 10]])
"""
if copy:
x = ivy.copy_array(x)
if not ivy.is_array(val):
val = ivy.array(val)
if 0 in x.shape or 0 in val.shape:
return x
if ivy.is_array(query) and ivy.is_bool_dtype(query):
if not len(query.shape):
query = ivy.tile(query, (x.shape[0],))
indices = ivy.nonzero(query, as_tuple=False)
else:
indices, target_shape, _ = _parse_query(
query, ivy.shape(x, as_array=True), scatter=True
)
if indices is None:
return x
val = val.astype(x.dtype)
ret = ivy.scatter_nd(indices, val, reduction="replace", out=x)
return ret
return current_backend(x).set_item(x, query, val, copy=copy)


set_item.mixed_backend_wrappers = {
Expand Down
2 changes: 0 additions & 2 deletions ivy/utils/decorator_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,6 @@ def handle_get_item(fn):
def wrapper(inp, query, **kwargs):
try:
res = inp.__getitem__(query)
except IndexError:
raise
except Exception:
res = fn(inp, query, **kwargs)
return res
Expand Down

0 comments on commit f5cc3e8

Please sign in to comment.