Commit d29b712

Merge commit for internal changes
2 parents: 26d4765 + 6ed75e6

127 files changed: +3799 additions, −1565 deletions. Only a subset of the changed files is shown below.

configure.py

Lines changed: 19 additions & 44 deletions
@@ -1122,12 +1122,16 @@ def toolkit_exists(toolkit_path):
   write_action_env_to_bazelrc('COMPUTECPP_TOOLKIT_PATH',
                               computecpp_toolkit_path)
 
+
 def set_trisycl_include_dir(environ_cp):
   """Set TRISYCL_INCLUDE_DIR."""
-  ask_trisycl_include_dir = (
-      'Please specify the location of the triSYCL include directory. (Use '
-      '--config=sycl_trisycl when building with Bazel) '
-      '[Default is %s]: ') % _DEFAULT_TRISYCL_INCLUDE_DIR
+
+  ask_trisycl_include_dir = ('Please specify the location of the triSYCL '
+                             'include directory. (Use --config=sycl_trisycl '
+                             'when building with Bazel) '
+                             '[Default is %s]: '
+                            ) % (_DEFAULT_TRISYCL_INCLUDE_DIR)
+
   while True:
     trisycl_include_dir = get_from_env_or_user_or_default(
         environ_cp, 'TRISYCL_INCLUDE_DIR', ask_trisycl_include_dir,
@@ -1201,46 +1205,10 @@ def set_other_mpi_vars(environ_cp):
     raise ValueError('Cannot find the MPI library file in %s/lib' % mpi_home)
 
 
-def set_mkl():
-  write_to_bazelrc('build:mkl --define using_mkl=true')
-  write_to_bazelrc('build:mkl -c opt')
-  print(
-      'Add "--config=mkl" to your bazel command to build with MKL '
-      'support.\nPlease note that MKL on MacOS or windows is still not '
-      'supported.\nIf you would like to use a local MKL instead of '
-      'downloading, please set the environment variable \"TF_MKL_ROOT\" every '
-      'time before build.\n')
-
-
-def set_monolithic():
-  # Add --config=monolithic to your bazel command to use a mostly-static
-  # build and disable modular op registration support (this will revert to
-  # loading TensorFlow with RTLD_GLOBAL in Python). By default (without
-  # --config=monolithic), TensorFlow will build with a dependence on
-  # //tensorflow:libtensorflow_framework.so.
-  write_to_bazelrc('build:monolithic --define framework_shared_object=false')
-  # For projects which use TensorFlow as part of a Bazel build process, putting
-  # nothing in a bazelrc will default to a monolithic build. The following line
-  # opts in to modular op registration support by default:
-  write_to_bazelrc('build --define framework_shared_object=true')
-
-
-def create_android_bazelrc_configs():
-  # Flags for --config=android
-  write_to_bazelrc('build:android --crosstool_top=//external:android/crosstool')
-  write_to_bazelrc(
-      'build:android --host_crosstool_top=@bazel_tools//tools/cpp:toolchain')
-  # Flags for --config=android_arm
-  write_to_bazelrc('build:android_arm --config=android')
-  write_to_bazelrc('build:android_arm --cpu=armeabi-v7a')
-  # Flags for --config=android_arm64
-  write_to_bazelrc('build:android_arm64 --config=android')
-  write_to_bazelrc('build:android_arm64 --cpu=arm64-v8a')
-
-
 def set_grpc_build_flags():
   write_to_bazelrc('build --define grpc_no_ares=true')
 
+
 def set_windows_build_flags():
   if is_windows():
     # The non-monolithic build is not supported yet
@@ -1251,6 +1219,11 @@ def set_windows_build_flags():
   write_to_bazelrc('build --verbose_failures')
 
 
+def config_info_line(name, help_text):
+  """Helper function to print formatted help text for Bazel config options."""
+  print('\t--config=%-12s\t# %s' % (name, help_text))
+
+
 def main():
   # Make a copy of os.environ to be clear when functions and getting and setting
   # environment variables.
@@ -1336,10 +1309,7 @@ def main():
 
   set_grpc_build_flags()
   set_cc_opt_flags(environ_cp)
-  set_mkl()
-  set_monolithic()
   set_windows_build_flags()
-  create_android_bazelrc_configs()
 
   if workspace_has_any_android_rule():
     print('The WORKSPACE file has at least one of ["android_sdk_repository", '
@@ -1357,6 +1327,11 @@ def main():
     create_android_ndk_rule(environ_cp)
     create_android_sdk_rule(environ_cp)
 
+  print('Preconfigured Bazel build configs. You can use any of the below by '
+        'adding "--config=<>" to your build command. See tools/bazel.rc for '
+        'more details.')
+  config_info_line('mkl', 'Build with MKL support.')
+  config_info_line('monolithic', 'Config for mostly static monolithic build.')
 
 if __name__ == '__main__':
   main()

tensorflow/c/eager/BUILD

Lines changed: 6 additions & 0 deletions
@@ -117,3 +117,9 @@ cc_library(
         "//tensorflow/core:lib",
     ],
 )
+
+filegroup(
+    name = "headers",
+    srcs = ["c_api.h"],
+    visibility = ["//tensorflow:__subpackages__"],
+)

tensorflow/c/eager/c_api.cc

Lines changed: 15 additions & 3 deletions
@@ -98,7 +98,10 @@ TFE_Context* TFE_NewContext(const TFE_ContextOptions* opts, TF_Status* status) {
 
 void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status) {
   status->status = tensorflow::Status::OK();
-  tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache);
+  {
+    tensorflow::mutex_lock ml(ctx->cache_mu);
+    tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache);
+  }
   TF_Graph* graph = ctx->session->graph;
   TF_DeleteSession(ctx->session, status);
   TF_DeleteGraph(graph);
@@ -110,6 +113,11 @@ TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx, TF_Status* status) {
   return TF_SessionListDevices(ctx->session, status);
 }
 
+void TFE_ContextClearCaches(TFE_Context* ctx) {
+  tensorflow::mutex_lock ml(ctx->cache_mu);
+  tensorflow::gtl::STLDeleteValues(&ctx->kernel_cache);
+}
+
 TFE_TensorHandle* TFE_NewTensorHandle(TF_Tensor* t, TF_Status* status) {
   tensorflow::Tensor tensor;
   status->status = tensorflow::TF_TensorToTensor(t, &tensor);
@@ -489,8 +497,11 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
   std::vector<tensorflow::Tensor> outputs(1);
   const tensorflow::MemoryTypeVector* output_memory_types = nullptr;
   tensorflow::Fprint128 cache_key = op->attrs.CacheKey(device->name());
-  tensorflow::KernelAndDevice* kernel =
-      tensorflow::gtl::FindPtrOrNull(ctx->kernel_cache, cache_key);
+  tensorflow::KernelAndDevice* kernel;
+  {
+    tensorflow::tf_shared_lock l(ctx->cache_mu);
+    kernel = tensorflow::gtl::FindPtrOrNull(ctx->kernel_cache, cache_key);
+  }
   if (kernel == nullptr) {
     const tensorflow::NodeDef& ndef = op->attrs.BuildNodeDef();
     kernel = new tensorflow::KernelAndDevice(ctx->rendezvous);
@@ -506,6 +517,7 @@ void TFE_Execute(TFE_Op* op, TFE_TensorHandle** retvals, int* num_retvals,
       delete kernel;
      return;
    }
+    tensorflow::mutex_lock ml(ctx->cache_mu);
     tensorflow::gtl::InsertOrUpdate(&(ctx->kernel_cache), cache_key, kernel);
   }
   std::vector<TFE_TensorHandle*> copied_tensors;
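The hunks above follow a standard read-mostly cache idiom: lookups take a shared lock, while insertions and cache clears take an exclusive lock. Below is a minimal stand-alone sketch of that same idiom, using standard C++17 primitives rather than TensorFlow's tensorflow::mutex / tf_shared_lock wrappers; the KernelCache class and its string value type here are hypothetical placeholders, not code from this commit.

#include <cstdint>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for the guarded kernel cache: readers take a shared
// lock, writers (insert/clear) take an exclusive lock, mirroring the
// tf_shared_lock vs. mutex_lock blocks in TFE_Execute / TFE_ContextClearCaches.
class KernelCache {
 public:
  // Returns nullptr when the key is absent, like gtl::FindPtrOrNull.
  std::shared_ptr<std::string> Find(uint64_t key) const {
    std::shared_lock<std::shared_mutex> lock(mu_);
    auto it = cache_.find(key);
    return it == cache_.end() ? nullptr : it->second;
  }

  void InsertOrUpdate(uint64_t key, std::shared_ptr<std::string> value) {
    std::unique_lock<std::shared_mutex> lock(mu_);
    cache_[key] = std::move(value);
  }

  // Analogue of the STLDeleteValues calls; shared_ptr values make explicit
  // deletion unnecessary in this sketch.
  void Clear() {
    std::unique_lock<std::shared_mutex> lock(mu_);
    cache_.clear();
  }

 private:
  mutable std::shared_mutex mu_;
  std::unordered_map<uint64_t, std::shared_ptr<std::string>> cache_;
};

In the actual diff the cache stores raw KernelAndDevice pointers that are deleted explicitly, which is why clearing the cache also needs the exclusive lock around STLDeleteValues.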

tensorflow/c/eager/c_api.h

Lines changed: 6 additions & 0 deletions
@@ -17,6 +17,8 @@ limitations under the License.
 #define TENSORFLOW_C_EAGER_C_API_H_
 
 // C API extensions to experiment with eager execution of kernels.
+// WARNING: Unlike tensorflow/c/c_api.h, the API here is not guaranteed to be
+// stable and can change without notice.
 
 #include "tensorflow/c/c_api.h"
 
@@ -87,6 +89,10 @@ TF_CAPI_EXPORT extern void TFE_DeleteContext(TFE_Context* ctx, TF_Status* status
 TF_CAPI_EXPORT extern TF_DeviceList* TFE_ContextListDevices(TFE_Context* ctx,
                                                             TF_Status* status);
 
+// Clears the internal caches in the TFE context. Useful when reseeding random
+// ops.
+TF_CAPI_EXPORT extern void TFE_ContextClearCaches(TFE_Context* ctx);
+
 // A handle to a tensor on a device.
 //
 // Like a TF_Tensor, a TFE_TensorHandle refers to a tensor with a value, shape,
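The new TFE_ContextClearCaches entry point is documented above as useful when reseeding random ops. Here is a hedged usage sketch; it assumes the TFE_ContextOptions helpers (TFE_NewContextOptions, TFE_DeleteContextOptions) declared elsewhere in this header, and the eager-op section in the middle is elided.

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"

int main() {
  TF_Status* status = TF_NewStatus();
  TFE_ContextOptions* opts = TFE_NewContextOptions();  // assumed helper
  TFE_Context* ctx = TFE_NewContext(opts, status);
  TFE_DeleteContextOptions(opts);                      // assumed helper
  if (TF_GetCode(status) != TF_OK) {
    TF_DeleteStatus(status);
    return 1;
  }

  // ... build and run eager ops with TFE_NewOp / TFE_Execute ...

  // Drop cached kernels so subsequently executed ops (e.g. random ops after
  // reseeding) are re-instantiated instead of being served from the cache.
  TFE_ContextClearCaches(ctx);

  TFE_DeleteContext(ctx, status);
  TF_DeleteStatus(status);
  return 0;
}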

tensorflow/c/eager/c_api_internal.h

Lines changed: 2 additions & 1 deletion
@@ -58,9 +58,10 @@ struct TFE_Context {
   // session->devices[i].
   std::unique_ptr<tensorflow::ProcessFunctionLibraryRuntime> pflr;
 
+  tensorflow::mutex cache_mu;
   std::unordered_map<tensorflow::Fprint128, tensorflow::KernelAndDevice*,
                      tensorflow::Fprint128Hasher>
-      kernel_cache;
+      kernel_cache GUARDED_BY(cache_mu);
 
   tensorflow::FunctionLibraryRuntime* func_lib(tensorflow::Device* d) {
     return pflr->GetFLR(d->name());
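The GUARDED_BY annotation added above is the usual Clang thread-safety-analysis idiom: declare the mutex next to the data it protects so the compiler can warn on unlocked access. A minimal sketch of the shape of that idiom follows, with a no-op placeholder macro instead of TensorFlow's real GUARDED_BY (which expands to Clang's guarded_by attribute); the Cache struct is hypothetical.

#include <mutex>
#include <unordered_map>

// Placeholder for illustration only; TensorFlow's macro expands to
// __attribute__((guarded_by(x))) under Clang so -Wthread-safety can check it.
#define GUARDED_BY(x)

struct Cache {
  std::mutex mu;
  std::unordered_map<int, int> entries GUARDED_BY(mu);

  // Every access to `entries` goes through a scope that holds `mu`, matching
  // the mutex_lock / tf_shared_lock blocks in the c_api.cc diff above.
  int Lookup(int key) {
    std::lock_guard<std::mutex> lock(mu);
    auto it = entries.find(key);
    return it == entries.end() ? -1 : it->second;
  }
};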

tensorflow/compiler/tests/BUILD

Lines changed: 2 additions & 0 deletions
@@ -248,7 +248,9 @@ tf_xla_py_test(
     tags = ["optonly"],
     deps = [
         ":xla_test",
+        "//tensorflow/contrib/signal:signal_py",
         "//tensorflow/python:array_ops",
+        "//tensorflow/python:extra_py_tests_deps",
         "//tensorflow/python:framework_for_generated_wrappers",
         "//tensorflow/python:platform_test",
         "//tensorflow/python:spectral_ops",

tensorflow/compiler/tests/fft_test.py

Lines changed: 25 additions & 0 deletions
@@ -21,8 +21,10 @@
 import itertools
 
 import numpy as np
+import scipy.signal as sps
 
 from tensorflow.compiler.tests.xla_test import XLATestCase
+from tensorflow.contrib.signal.python.ops import spectral_ops as signal
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import spectral_ops
@@ -76,6 +78,29 @@ def _VerifyFftMethod(self, inner_dims, complex_to_input, input_to_expected,
       value = sess.run(out, {ph: data})
       self.assertAllClose(expected, value, rtol=RTOL, atol=ATOL)
 
+  def testContribSignalSTFT(self):
+    ws = 512
+    hs = 128
+    dims = (ws * 20,)
+    shape = BATCH_DIMS + dims
+    data = np.arange(np.prod(shape)) / np.prod(dims)
+    np.random.seed(123)
+    np.random.shuffle(data)
+    data = np.reshape(data.astype(np.float32), shape)
+    window = sps.get_window("hann", ws)
+    expected = sps.stft(
+        data, nperseg=ws, noverlap=ws - hs, boundary=None, window=window)[2]
+    expected = np.swapaxes(expected, -1, -2)
+    expected *= window.sum()  # scipy divides by window sum
+    with self.test_session() as sess:
+      with self.test_scope():
+        ph = array_ops.placeholder(
+            dtypes.as_dtype(data.dtype), shape=data.shape)
+        out = signal.stft(ph, ws, hs)
+
+      value = sess.run(out, {ph: data})
+      self.assertAllClose(expected, value, rtol=RTOL, atol=ATOL)
+
   def testFFT(self):
     self._VerifyFftMethod(INNER_DIMS_1D, lambda x: x, np.fft.fft,
                           spectral_ops.fft)
