Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a retryable mechanism for SWA features when files get deleted #1127

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ import org.apache.spark.sql.SparkSession
*/
object LocalFeatureJoinJob {

// This is a config for local test only to induce a FileNotFoundException.
var shouldRetryAddingSWAFeatures = false

// for user convenience, create spark session within this function, so user does not need to create one
// this also ensure it has same setting as the real feathr join job
val ss: SparkSession = createSparkSession(enableHiveSupport = true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ private[offline] class DataFrameFeatureJoiner(logicalPlan: MultiStageJoinPlan, d
offline.FeatureDataFrame(obsToJoinWithFeatures, Map())
} else {
val swaJoiner = new SlidingWindowAggregationJoiner(featureGroups.allWindowAggFeatures, anchorToDataSourceMapper)
swaJoiner.joinWindowAggFeaturesAsDF(
val (featureDataFrame, retryableFeatureNames) = swaJoiner.joinWindowAggFeaturesAsDF(
ss,
obsToJoinWithFeatures,
joinConfig,
Expand All @@ -338,6 +338,28 @@ private[offline] class DataFrameFeatureJoiner(logicalPlan: MultiStageJoinPlan, d
swaObsTime,
failOnMissingPartition,
swaHandler)

// We will retry the SWA features which could not added because of changing data.
val retryableErasedEntityTaggedFeatures = requiredWindowAggFeatures.filter(x => retryableFeatureNames.contains(x.getFeatureName))

// Keep only the features which are to be retried.
val updatedWindowAggFeatureStages = windowAggFeatureStages.map(x => (x._1, x._2.intersect(retryableFeatureNames)))
if (retryableFeatureNames.nonEmpty) {
swaJoiner.joinWindowAggFeaturesAsDF(
ss,
featureDataFrame.df,
joinConfig,
keyTagIntsToStrings,
updatedWindowAggFeatureStages,
retryableErasedEntityTaggedFeatures,
bloomFilters,
swaObsTime,
failOnMissingPartition,
swaHandler,
false)._1
} else {
featureDataFrame
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import com.linkedin.feathr.offline.anchored.keyExtractor.{MVELSourceKeyExtractor
import com.linkedin.feathr.offline.client.DataFrameColName
import com.linkedin.feathr.offline.config.FeatureJoinConfig
import com.linkedin.feathr.offline.exception.FeathrIllegalStateException
import com.linkedin.feathr.offline.job.PreprocessedDataFrameManager
import com.linkedin.feathr.offline.job.{LocalFeatureJoinJob, PreprocessedDataFrameManager}
import com.linkedin.feathr.offline.join.DataFrameKeyCombiner
import com.linkedin.feathr.offline.source.DataSource
import com.linkedin.feathr.offline.source.accessor.DataSourceAccessor
Expand All @@ -22,11 +22,14 @@ import com.linkedin.feathr.offline.{FeatureDataFrame, JoinStage}
import com.linkedin.feathr.swj.{FactData, LabelData, SlidingWindowJoin}
import com.linkedin.feathr.{common, offline}
import org.apache.logging.log4j.LogManager
import org.apache.spark.SparkException
import org.apache.spark.sql.functions.{col, lit}
import org.apache.spark.sql.{DataFrame, SparkSession}
import org.apache.spark.util.sketch.BloomFilter

import java.io.FileNotFoundException
import scala.collection.mutable
import scala.collection.mutable.ArrayBuffer

/**
* Case class containing other SWA handler methods
Expand Down Expand Up @@ -65,6 +68,9 @@ private[offline] class SlidingWindowAggregationJoiner(
* @param obsDF Observation data
* @param swaObsTimeOpt start and end time of observation data
* @param failOnMissingPartition whether to fail the data loading if some of the date partitions are missing.
* @param swaHandler External SWA libraries if any should handle the SWA join
* @param shouldRetry If this is a retry attempt to retry adding features which were missed because of IOExceptions.
* Default is set to true.
* @return pair of :
* 1) dataframe with feature column appended to the obsData,
* it can be converted to a pair RDD of (observation data record, feature record),
Expand All @@ -81,7 +87,9 @@ private[offline] class SlidingWindowAggregationJoiner(
bloomFilters: Option[Map[Seq[Int], BloomFilter]],
swaObsTimeOpt: Option[DateTimeInterval],
failOnMissingPartition: Boolean,
swaHandler: Option[SWAHandler]): FeatureDataFrame = {
swaHandler: Option[SWAHandler],
shouldRetry: Boolean = true): (FeatureDataFrame, Seq[String]) = {
val retryableSwaFeatures = ArrayBuffer.empty[String]
val joinConfigSettings = joinConfig.settings
// extract time window settings
if (joinConfigSettings.isEmpty) {
Expand Down Expand Up @@ -255,11 +263,23 @@ private[offline] class SlidingWindowAggregationJoiner(
SlidingWindowFeatureUtils.getFactDataDef(filteredFactData, anchorWithSourceToDFMap.keySet.toSeq, featuresToDelayImmutableMap, selectedFeatures)
}
val origContextObsColumns = labelDataDef.dataSource.columns
val shouldRetryForMissingData = FeathrUtils.getFeathrJobParam(ss.sparkContext.getConf, FeathrUtils.RETRY_ADDING_MISSING_SWA_FEATURES).toBoolean
try {
// THIS IS FOR LOCAL TEST ONLY. It is to induce a spark exception with the root cause of FileNotFoundException.
if (shouldRetry && FeathrUtils.getFeathrJobParam(ss.sparkContext.getConf, FeathrUtils.LOCAL_RETRY_ADDING_MISSING_SWA_FEATURES).toBoolean)
throw new SparkException("file not found", new FileNotFoundException())
contextDF = if (swaHandler.isDefined) swaHandler.get.join(labelDataDef, factDataDefs.toList) else SlidingWindowJoin.join(labelDataDef, factDataDefs.toList)
} catch {
// Many times the files which are to be loaded gets deleted midway. We will retry all the features at this stage again by reloading the datasets.
case exception: SparkException => if (shouldRetry && shouldRetryForMissingData && exception.getCause != null && exception.getCause.isInstanceOf[FileNotFoundException]) {
val unjoinedFeatures = factDataDefs.flatMap(factData => factData.aggFeatures.map(_.name))
retryableSwaFeatures ++= unjoinedFeatures
}
}

contextDF = if (swaHandler.isDefined) swaHandler.get.join(labelDataDef, factDataDefs.toList) else SlidingWindowJoin.join(labelDataDef, factDataDefs.toList)

val finalJoinedFeatures = joinedFeatures diff retryableSwaFeatures
contextDF = if (shouldFilterNulls && !factDataRowsWithNulls.isEmpty) {
val nullDfWithFeatureCols = joinedFeatures.foldLeft(factDataRowsWithNulls)((s, x) => s.withColumn(x, lit(null)))
val nullDfWithFeatureCols = finalJoinedFeatures.foldLeft(factDataRowsWithNulls)((s, x) => s.withColumn(x, lit(null)))
contextDF.union(nullDfWithFeatureCols)
} else contextDF

Expand All @@ -272,13 +292,13 @@ private[offline] class SlidingWindowAggregationJoiner(
.asInstanceOf[TimeWindowConfigurableAnchorExtractor].features(nameToFeatureAnchor._1).columnFormat)

val FeatureDataFrame(withFDSFeatureDF, inferredTypes) =
SlidingWindowFeatureUtils.convertSWADFToFDS(contextDF, joinedFeatures.toSet, featureNameToColumnFormat, userSpecifiedTypesConfig)
SlidingWindowFeatureUtils.convertSWADFToFDS(contextDF, finalJoinedFeatures.toSet, featureNameToColumnFormat, userSpecifiedTypesConfig)
// apply default on FDS dataset
val withFeatureContextDF =
substituteDefaults(withFDSFeatureDF, defaults.keys.filter(joinedFeatures.contains).toSeq, defaults, userSpecifiedTypesConfig, ss)
substituteDefaults(withFDSFeatureDF, defaults.keys.filter(finalJoinedFeatures.contains).toSeq, defaults, userSpecifiedTypesConfig, ss)

allInferredFeatureTypes ++= inferredTypes
contextDF = standardizeFeatureColumnNames(origContextObsColumns, withFeatureContextDF, joinedFeatures, keyTags.map(keyTagList))
contextDF = standardizeFeatureColumnNames(origContextObsColumns, withFeatureContextDF, finalJoinedFeatures, keyTags.map(keyTagList))
if (shouldCheckPoint(ss)) {
// checkpoint complicated dataframe for each stage to avoid Spark failure
contextDF = contextDF.checkpoint(true)
Expand All @@ -292,7 +312,7 @@ private[offline] class SlidingWindowAggregationJoiner(
}
}
}})
offline.FeatureDataFrame(contextDF, allInferredFeatureTypes.toMap)
(offline.FeatureDataFrame(contextDF, allInferredFeatureTypes.toMap), retryableSwaFeatures)
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ private[feathr] object FeathrUtils {
val ENABLE_SANITY_CHECK_MODE = "enable.sanity.check.mode"
val SANITY_CHECK_MODE_ROW_COUNT = "sanity.check.row.count"
val FILTER_NULLS = "filter.nulls"
val RETRY_ADDING_MISSING_SWA_FEATURES = "retry.adding.missing.swa.features"
// Retryer to be configured only for local tests
val LOCAL_RETRY_ADDING_MISSING_SWA_FEATURES = "local.retry.adding.missing.swa.features"
val STRING_PARAMETER_DELIMITER = ","

// Used to check if the current dataframe has satisfied the checkpoint frequency
Expand Down Expand Up @@ -86,7 +89,9 @@ private[feathr] object FeathrUtils {
SPARK_JOIN_MIN_PARALLELISM -> (SQLConf.buildConf(getFullConfigKeyName(SPARK_JOIN_MIN_PARALLELISM )).stringConf.createOptional, "10"),
ENABLE_SANITY_CHECK_MODE -> (SQLConf.buildConf(getFullConfigKeyName(ENABLE_SANITY_CHECK_MODE )).stringConf.createOptional, "false"),
SANITY_CHECK_MODE_ROW_COUNT -> (SQLConf.buildConf(getFullConfigKeyName(SANITY_CHECK_MODE_ROW_COUNT )).stringConf.createOptional, "10"),
FILTER_NULLS -> (SQLConf.buildConf(getFullConfigKeyName(FILTER_NULLS )).stringConf.createOptional, "false")
FILTER_NULLS -> (SQLConf.buildConf(getFullConfigKeyName(FILTER_NULLS)).stringConf.createOptional, "false"),
LOCAL_RETRY_ADDING_MISSING_SWA_FEATURES -> (SQLConf.buildConf(getFullConfigKeyName(LOCAL_RETRY_ADDING_MISSING_SWA_FEATURES)).stringConf.createOptional, "false"),
RETRY_ADDING_MISSING_SWA_FEATURES -> (SQLConf.buildConf(getFullConfigKeyName(RETRY_ADDING_MISSING_SWA_FEATURES)).stringConf.createOptional, "true")
)

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package com.linkedin.feathr.offline

import com.linkedin.feathr.offline.AssertFeatureUtils.{rowApproxEquals, validateRows}
import com.linkedin.feathr.offline.util.FeathrUtils
import com.linkedin.feathr.offline.util.FeathrUtils.{FILTER_NULLS, SKIP_MISSING_FEATURE, setFeathrJobParam}
import com.linkedin.feathr.offline.util.FeathrUtils.{FILTER_NULLS, LOCAL_RETRY_ADDING_MISSING_SWA_FEATURES, SKIP_MISSING_FEATURE, setFeathrJobParam}
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.types.{LongType, StructField, StructType}
Expand Down Expand Up @@ -328,6 +328,80 @@ class SlidingWindowAggIntegTest extends FeathrIntegTest {
setFeathrJobParam(FILTER_NULLS, "false")
}

/**
* test SWA with dense vector feature with retry.
* This should get handled by the SWA retry method.
*/
@Test
def testLocalAnchorSWAWithDenseVectorWithRetry(): Unit = {
setFeathrJobParam(LOCAL_RETRY_ADDING_MISSING_SWA_FEATURES, "true")
val res = runLocalFeatureJoinForTest(
"""
| settings: {
| joinTimeSettings: {
| timestampColumn: {
| def: "timestamp"
| format: "yyyy-MM-dd"
| }
| simulateTimeDelay: 1d
| }
|}
|
|features: [
| {
| key: [mId],
| featureList: ["aEmbedding", "memberEmbeddingAutoTZ"]
| }
|]
""".stripMargin,
"""
|sources: {
| swaSource: {
| location: { path: "generation/daily" }
| timePartitionPattern: "yyyy/MM/dd"
| timeWindowParameters: {
| timestampColumn: "timestamp"
| timestampColumnFormat: "yyyy-MM-dd"
| }
| }
|}
|
|anchors: {
| swaAnchor: {
| source: "swaSource"
| key: "x"
| features: {
| aEmbedding: {
| def: "embedding"
| aggregation: LATEST
| window: 3d
| }
| memberEmbeddingAutoTZ: {
| def: "embedding"
| aggregation: LATEST
| window: 3d
| type: {
| type: TENSOR
| tensorCategory: SPARSE
| dimensionType: [INT]
| valType: FLOAT
| }
| }
| }
| }
|}
""".stripMargin,
observationDataPath = "slidingWindowAgg/csvTypeTimeFile1.csv").data

val featureList = res.collect().sortBy(row => if (row.get(0) != null) row.getAs[String]("mId") else "null")

assertEquals(featureList.size, 2)
assertEquals(featureList(0).getAs[Row]("aEmbedding"), mutable.WrappedArray.make(Array(5.5f, 5.8f)))
assertEquals(featureList(0).getAs[Row]("memberEmbeddingAutoTZ"),
TestUtils.build1dSparseTensorFDSRow(Array(0, 1), Array(5.5f, 5.8f)))
setFeathrJobParam(LOCAL_RETRY_ADDING_MISSING_SWA_FEATURES, "true")
}

/**
* test SWA with dense vector feature
* The feature dataset generation/daily has different but compatible schema for different partitions,
Expand Down