You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
I've tried several different versions of the following code. All of them work when run locally, but they hang forever on Databricks
(single node, 13.3 LTS ML runtime):
import os
import pandas as pd
import numpy as np
from petastorm.spark import SparkDatasetConverter, make_spark_converter
from pyspark.sql import SparkSession
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import (
Dense,
Input,
Dropout,
concatenate,
)
from tensorflow.keras.models import Model
spark = SparkSession.builder.getOrCreate()

# Toy problem size: one float-array feature column per entry of input_shapes
# (c1, c2, ...) plus a one-hot integer "target" column with n_outcomes classes.
n_rows = 100
n_outcomes = 15
input_shapes = [10, 5, 2]
feature_columns = [f"c{i}" for i in range(1, len(input_shapes) + 1)]

possible_outcomes = np.eye(n_outcomes).astype(int)
# Columns are filled in order (all feature columns, then target), matching the
# order in which the original hand-written dict drew its random numbers.
data = {
    name: [[np.random.rand() for _ in range(shape)] for __ in range(n_rows)]
    for name, shape in zip(feature_columns, input_shapes)
}
data['target'] = [
    [int(x) for x in possible_outcomes[np.random.choice(possible_outcomes.shape[0])]]
    for __ in range(n_rows)
]
df = pd.DataFrame(data)

from pyspark.sql.types import IntegerType, ArrayType, StructType, StructField, FloatType

# Feature columns first, target last: the dataset map() below relies on the
# target being the final field of each element.
schema = StructType(
    [StructField(name, ArrayType(FloatType())) for name in feature_columns]
    + [StructField("target", ArrayType(IntegerType()))]
)
sdf = spark.createDataFrame(df, schema=schema)
sdf = sdf.repartition(5)
sdf.show(4)

# One small dense tower per feature column; the towers are concatenated and fed
# through a shared head that predicts a softmax over n_outcomes classes.
inputs = []
outputs = []
for shape in input_shapes:
    input_layer = Input(shape=(shape,))
    x = Dense(24, activation='relu')(input_layer)
    x = Dropout(0.2)(x)
    output_layer = Dense(24, activation='relu')(x)
    inputs.append(input_layer)
    outputs.append(output_layer)

combined_input = concatenate(outputs)
z = Dense(256, activation='relu')(combined_input)
z = Dropout(0.2)(z)
z = Dense(256, activation='relu')(z)
final_output = Dense(n_outcomes, activation='softmax')(z)
model = Model(inputs=inputs, outputs=final_output)
model.compile(
    loss='categorical_crossentropy',
    optimizer=Adam(),
    experimental_run_tf_function=False,
)

# Databricks clusters set DATABRICKS_RUNTIME_VERSION. There, cache the
# converted Parquet on the DBFS FUSE mount instead of the driver's working
# directory; everywhere else use ./cache.
is_local_mode = "DATABRICKS_RUNTIME_VERSION" not in os.environ
if is_local_mode:
    cache_dir_url = f"file://{os.getcwd()}/cache"
else:
    cache_dir_url = "file:///dbfs/tmp/petastorm/cache"
spark.conf.set(SparkDatasetConverter.PARENT_CACHE_DIR_URL_CONF, cache_dir_url)

converter_train = make_spark_converter(sdf)

BATCH_SIZE = 10
NUM_EPOCHS = 5

# NOTE: make_tf_dataset() is called without num_epochs, so the dataset is an
# endless stream and fit() must be bounded by steps_per_epoch. If MLflow
# autologging is active (reportedly on by default on Databricks ML runtimes —
# confirm for your workspace), its dataset logging tries to iterate this
# stream and fit() appears to hang forever. Disable dataset logging before
# calling fit(), e.g. `mlflow.tensorflow.autolog(log_datasets=False)`.
with converter_train.make_tf_dataset(batch_size=BATCH_SIZE) as train_dataset:
    # Split each element into (features tuple, target); presumably the
    # element is a namedtuple in schema order, so the target is last.
    train_dataset = train_dataset.map(lambda x: (x[:-1], x[-1]))
    # At least one step, even when there are fewer rows than BATCH_SIZE.
    steps_per_epoch = max(1, len(converter_train) // BATCH_SIZE)
    print(list(train_dataset.take(1)))
    hist = model.fit(
        train_dataset,
        steps_per_epoch=steps_per_epoch,
        epochs=NUM_EPOCHS,
        verbose=2,
    )
The text was updated successfully, but these errors were encountered:
I had the same problem in Databricks. Is MLflow autologging on by any chance?
It seems that MLflow tries to load the dataset into memory for logging purposes. This is impossible for the endless stream that Petastorm generates when `num_epochs` is not specified in `make_tf_dataset`. Even when `num_epochs` is defined, it can be very slow and prone to out-of-memory (OOM) errors.
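Adding the following flag to the autologging call fixed it for me (the snippet did not survive here; presumably `log_datasets=False`, e.g. `mlflow.autolog(log_datasets=False)`):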
I've tried several different versions of the following code. All of them work when run locally, but they hang forever on Databricks
(single node, 13.3 LTS ML runtime):
The text was updated successfully, but these errors were encountered: