Petastorm hangs forever in DataBricks #804

Open

juzzmac opened this issue Feb 20, 2024 · 1 comment

Comments


juzzmac commented Feb 20, 2024

I've tried several different versions of the following code, all of which work when running locally but hang forever in Databricks (single node, 13.3 LTS ML runtime):

import os
import pandas as pd
import numpy as np
from petastorm.spark import SparkDatasetConverter, make_spark_converter
from pyspark.sql import SparkSession
from tensorflow.keras.optimizers import Adam

from tensorflow.keras.layers import (
    Dense,
    Input,
    Dropout,
    concatenate,
)

from tensorflow.keras.models import Model

spark = SparkSession.builder.getOrCreate()

# Build a small synthetic dataset: three float-array feature columns and a one-hot target.
n_rows = 100
n_outcomes = 15
input_shapes = [10, 5, 2]
possible_outcomes = np.eye(n_outcomes).astype(int)
df = pd.DataFrame(
    {
        'c1': [[np.random.rand() for _ in range(input_shapes[0])] for __ in range(n_rows)],
        'c2': [[np.random.rand() for _ in range(input_shapes[1])] for __ in range(n_rows)],
        'c3': [[np.random.rand() for _ in range(input_shapes[2])] for __ in range(n_rows)],
        'target': [[int(x) for x in possible_outcomes[np.random.choice(possible_outcomes.shape[0])]] for __ in range(n_rows)],
    }
)

from pyspark.sql.types import IntegerType, ArrayType, StructType, StructField, FloatType

# Explicit schema so the array columns keep their element types in the Spark DataFrame.
schema = StructType([
    StructField("c1", ArrayType(FloatType())),
    StructField("c2", ArrayType(FloatType())),
    StructField("c3", ArrayType(FloatType())),
    StructField("target", ArrayType(IntegerType()))
])
sdf = spark.createDataFrame(df, schema=schema)
sdf = sdf.repartition(5)
sdf.show(4)

# One small dense branch per feature column; the branches are concatenated below.
inputs = []
outputs = []
for shape in input_shapes:
    input_layer = Input(shape=(shape,))
    x = Dense(24, activation='relu')(input_layer)
    x = Dropout(0.2)(x)
    output_layer = Dense(24, activation='relu')(x)
    inputs.append(input_layer)
    outputs.append(output_layer)

combined_input = concatenate(outputs)
z = Dense(256, activation='relu')(combined_input)
z = Dropout(0.2)(z)
z = Dense(256, activation='relu')(z)
final_output = Dense(n_outcomes, activation='softmax')(z)
model = Model(inputs=inputs, outputs=final_output)

model.compile(
    loss='categorical_crossentropy',
    optimizer=Adam(),
    experimental_run_tf_function=False,
)

# Petastorm needs a parent cache directory where the converter materializes the dataset.
is_local_mode = True

if is_local_mode:
    cwd = os.getcwd()
    spark.conf.set(SparkDatasetConverter.PARENT_CACHE_DIR_URL_CONF, f"file://{cwd}/cache")
else:
    spark.conf.set(SparkDatasetConverter.PARENT_CACHE_DIR_URL_CONF, "file:///dbfs/tmp/petastorm/cache")

converter_train = make_spark_converter(sdf)
BATCH_SIZE = 10
NUM_EPOCHS = 5

with converter_train.make_tf_dataset(batch_size=BATCH_SIZE) as train_dataset:
    # Split each batched namedtuple into (inputs, target); this assumes 'target' is the last field.
    train_dataset = train_dataset.map(lambda x: (x[:-1], x[-1]))
    steps_per_epoch = len(converter_train) // BATCH_SIZE
    print([i for i in train_dataset.take(1)])

    hist = model.fit(train_dataset,
                     steps_per_epoch=steps_per_epoch,
                     epochs=NUM_EPOCHS,
                     verbose=2)

@joramdevreede2

Hi @juzzmac,

I had the same problem in Databricks. Is MLflow autologging on by any chance?

It seems that MLflow tries to load the dataset into memory for logging purposes, which is not possible for the endless stream that Petastorm generates when num_epochs is not specified in make_tf_dataset. Additionally, it can be very slow and prone to OOM errors even when num_epochs is defined.
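For reference, this is roughly what a bounded stream looks like; a sketch that reuses BATCH_SIZE, NUM_EPOCHS, converter_train and model from the snippet above (even then, dataset logging can still be slow):

# Sketch: bound the Petastorm stream by passing num_epochs to make_tf_dataset
# instead of leaving it at the default (None = loop forever).
with converter_train.make_tf_dataset(batch_size=BATCH_SIZE,
                                     num_epochs=NUM_EPOCHS) as train_dataset:
    train_dataset = train_dataset.map(lambda x: (x[:-1], x[-1]))
    hist = model.fit(train_dataset,
                     steps_per_epoch=len(converter_train) // BATCH_SIZE,
                     epochs=NUM_EPOCHS,
                     verbose=2)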

Adding the following flag to the autologging call fixed it for me:

mlflow.tensorflow.autolog(log_datasets=False)
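In the notebook this goes before the training code; a minimal sketch assuming MLflow autologging is already enabled (e.g. by the Databricks ML runtime), with everything else unchanged from the snippet above:

import mlflow

# Keep autologging, but stop MLflow from trying to materialize the Petastorm stream.
mlflow.tensorflow.autolog(log_datasets=False)

with converter_train.make_tf_dataset(batch_size=BATCH_SIZE) as train_dataset:
    train_dataset = train_dataset.map(lambda x: (x[:-1], x[-1]))
    hist = model.fit(train_dataset,
                     steps_per_epoch=len(converter_train) // BATCH_SIZE,
                     epochs=NUM_EPOCHS,
                     verbose=2)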

Hope this solves it!

See also: mlflow/mlflow#9600
