Commit 89fd2d5

[SW-2674] ChicagoCrimeApp refactor (#2708)
* [SW-2673] ChicagoCrimeApp refactor
* avoid joining single row
1 parent 0a04576 commit 89fd2d5
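
Before this change, each Crime was scored one at a time in scoreEvent: the event became a one-row DataFrame that was cross-joined against the entire census table just to pick up demographic columns. After the change, the census fields travel inside the CrimeWithCensusData case class, so both models score the whole batch with a single transform call. A minimal sketch of the difference, assuming simplified columns (Crime, CrimeWithCensus, the local session, and the sample values are illustrative, not the app's real schema):

import org.apache.spark.sql.SparkSession

object SingleRowJoinSketch {
  case class Crime(Community_Area: Int, IUCR: Int)
  case class CrimeWithCensus(Community_Area: Int, IUCR: Int, HARDSHIP_INDEX: Int)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[*]")
      .appName("single-row-join-sketch")
      .config("spark.sql.crossJoin.enabled", "true") // the condition-less join below needs this on Spark 2.x
      .getOrCreate()
    import spark.implicits._

    val census = Seq((46, 75), (63, 93)).toDF("Community_Area_Number", "HARDSHIP_INDEX")

    // Old shape: a cross join + filter per scored event, each on a 1-row DataFrame.
    val oldStyle = census
      .join(Seq(Crime(46, 1811)).toDF)
      .where($"Community_Area" === $"Community_Area_Number")

    // New shape: rows carry their census columns up front, so the whole batch
    // can go through model.transform(...) in one pass, with no join at all.
    val newStyle = Seq(CrimeWithCensus(46, 1811, 75), CrimeWithCensus(63, 1150, 93)).toDF

    oldStyle.show()
    newStyle.show()
    spark.stop()
  }
}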

1 file changed: +133 −97 lines changed

examples/src/main/scala/ai/h2o/sparkling/examples/ChicagoCrimeApp.scala

Lines changed: 133 additions & 97 deletions
@@ -18,7 +18,6 @@
 package ai.h2o.sparkling.examples
 
 import java.io.File
-
 import ai.h2o.sparkling.H2OContext
 import ai.h2o.sparkling.ml.algos.{H2ODeepLearning, H2OGBM}
 import ai.h2o.sparkling.ml.models.H2OMOJOModel
@@ -27,57 +26,115 @@ import org.apache.spark.sql.{DataFrame, SparkSession}
 
 object ChicagoCrimeApp {
 
-  private val seasonUdf = udf(getSeason _)
-  private val weekendUdf = udf(isWeekend _)
+  private val seasonUdf = udf(monthToSeason _)
+  private val weekendUdf = udf((isWeekend _).andThen(boolToInt))
   private val dayOfWeekUdf = udf(dayOfWeek _)
 
   def main(args: Array[String]) {
-    val spark = SparkSession
+    implicit val spark = SparkSession
       .builder()
      .appName("Chicago Crime App")
      .getOrCreate()
+    import spark.implicits._
 
-    val weatherDataPath = "./examples/smalldata/chicago/chicagoAllWeather.csv"
-    val weatherDataFile = s"file://${new File(weatherDataPath).getAbsolutePath}"
-    val weatherTable = createWeatherTable(spark, weatherDataFile)
-    weatherTable.createOrReplaceTempView("chicagoWeather")
-
-    val censusDataPath = "./examples/smalldata/chicago/chicagoCensus.csv"
-    val censusDataFile = s"file://${new File(censusDataPath).getAbsolutePath}"
-    val censusTable = createCensusTable(spark, censusDataFile)
-    censusTable.createOrReplaceTempView("chicagoCensus")
-
-    val crimesDataPath = "./examples/smalldata/chicago/chicagoCrimes10k.csv"
-    val crimesDataFile = s"file://${new File(crimesDataPath).getAbsolutePath}"
-    val crimesTable = createCrimeTable(spark, crimesDataFile)
-    crimesTable.createOrReplaceTempView("chicagoCrime")
-
-    // Join crimes and weather tables
-    val crimeWeather = spark.sql("""SELECT
-      |a.Year, a.Month, a.Day, a.WeekNum, a.HourOfDay, a.Weekend, a.Season, a.WeekDay,
-      |a.IUCR, a.Primary_Type, a.Location_Description, a.Community_Area, a.District,
-      |a.Arrest, a.Domestic, a.Beat, a.Ward, a.FBI_Code,
-      |b.minTemp, b.maxTemp, b.meanTemp,
-      |c.PERCENT_AGED_UNDER_18_OR_OVER_64, c.PER_CAPITA_INCOME, c.HARDSHIP_INDEX,
-      |c.PERCENT_OF_HOUSING_CROWDED, c.PERCENT_HOUSEHOLDS_BELOW_POVERTY,
-      |c.PERCENT_AGED_16__UNEMPLOYED, c.PERCENT_AGED_25__WITHOUT_HIGH_SCHOOL_DIPLOMA
-      |FROM chicagoCrime a
-      |JOIN chicagoWeather b
-      |ON a.Year = b.year AND a.Month = b.month AND a.Day = b.day
-      |JOIN chicagoCensus c
-      |ON a.Community_Area = c.Community_Area_Number""".stripMargin)
-
-    val gbmModel = trainGBM(crimeWeather)
-    val dlModel = trainDeepLearning(crimeWeather)
-
-    val crimes = Seq(
-      Crime("02/08/2015 11:43:58 PM", 1811, "NARCOTICS", "STREET", Domestic = false, 422, 4, 7, 46, 18),
-      Crime("02/08/2015 11:00:39 PM", 1150, "DECEPTIVE PRACTICE", "RESIDENCE", Domestic = false, 923, 9, 14, 63, 11))
-    score(spark, crimes, gbmModel, dlModel, censusTable)
-  }
+    val weatherTable = loadCsv("./examples/smalldata/chicago/chicagoAllWeather.csv").drop("date")
+    val chicagoWeatherTableName = "chicagoWeather"
+    weatherTable.createOrReplaceTempView(chicagoWeatherTableName)
+
+    val censusTable = loadCsv("./examples/smalldata/chicago/chicagoCensus.csv")
+    val chicagoCensusTableName = "chicagoCensus"
+    censusTable.createOrReplaceTempView(chicagoCensusTableName)
+
+    val crimesTable = addAdditionalDateColumns(loadCsv("./examples/smalldata/chicago/chicagoCrimes10k.csv"))
+    val chicagoCrimeTableName = "chicagoCrime"
+    crimesTable.createOrReplaceTempView(chicagoCrimeTableName)
+
+    val crimeDataColumnsForTraining = Seq(
+      $"cr.Year",
+      $"cr.Month",
+      $"cr.Day",
+      $"WeekNum",
+      $"HourOfDay",
+      $"Weekend",
+      $"Season",
+      $"WeekDay",
+      $"IUCR",
+      $"Primary_Type",
+      $"Location_Description",
+      $"Community_Area",
+      $"District",
+      $"Arrest",
+      $"Domestic",
+      $"Beat",
+      $"Ward",
+      $"FBI_Code")
+
+    val censusDataColumnsForTraining = Seq(
+      $"PERCENT_AGED_UNDER_18_OR_OVER_64",
+      $"PER_CAPITA_INCOME",
+      $"HARDSHIP_INDEX",
+      $"PERCENT_OF_HOUSING_CROWDED",
+      $"PERCENT_HOUSEHOLDS_BELOW_POVERTY",
+      $"PERCENT_AGED_16_UNEMPLOYED",
+      $"PERCENT_AGED_25_WITHOUT_HIGH_SCHOOL_DIPLOMA")
+
+    val weatherDataColumnsForTraining = Seq($"minTemp", $"maxTemp", $"meanTemp")
+
+    val joinedDataForTraining = spark
+      .table(chicagoCrimeTableName)
+      .as("cr")
+      .join(
+        spark.table(chicagoWeatherTableName).as("we"),
+        $"cr.Year" === $"we.year" and $"cr.Month" === $"we.month" and $"cr.Day" === $"we.day")
+      .join(spark.table(chicagoCensusTableName).as("ce"), $"cr.Community_Area" === $"ce.Community_Area_Number")
+      .select(crimeDataColumnsForTraining ++ censusDataColumnsForTraining ++ weatherDataColumnsForTraining: _*)
 
-  def trainGBM(train: DataFrame): H2OMOJOModel = {
     H2OContext.getOrCreate()
+    val gbmModel = trainGBM(joinedDataForTraining)
+    val dlModel = trainDeepLearning(joinedDataForTraining)
+
+    val crimesToScore = Seq(
+      CrimeWithCensusData(
+        date = "02/08/2015 11:43:58 PM",
+        IUCR = 1811,
+        Primary_Type = "NARCOTICS",
+        Location_Description = "STREET",
+        Domestic = false,
+        Beat = 422,
+        District = 4,
+        Ward = 7,
+        Community_Area = 46,
+        FBI_Code = 18,
+        PERCENT_AGED_UNDER_18_OR_OVER_64 = 41.1,
+        PER_CAPITA_INCOME = 16579,
+        HARDSHIP_INDEX = 75,
+        PERCENT_OF_HOUSING_CROWDED = 4.7,
+        PERCENT_HOUSEHOLDS_BELOW_POVERTY = 29.8,
+        PERCENT_AGED_16_UNEMPLOYED = 19.7,
+        PERCENT_AGED_25_WITHOUT_HIGH_SCHOOL_DIPLOMA = 26.6),
+      CrimeWithCensusData(
+        date = "02/08/2015 11:00:39 PM",
+        IUCR = 1150,
+        Primary_Type = "DECEPTIVE PRACTICE",
+        Location_Description = "RESIDENCE",
+        Domestic = false,
+        Beat = 923,
+        District = 9,
+        Ward = 14,
+        Community_Area = 63,
+        FBI_Code = 11,
+        PERCENT_AGED_UNDER_18_OR_OVER_64 = 38.8,
+        PER_CAPITA_INCOME = 12171,
+        HARDSHIP_INDEX = 93,
+        PERCENT_OF_HOUSING_CROWDED = 15.8,
+        PERCENT_HOUSEHOLDS_BELOW_POVERTY = 23.4,
+        PERCENT_AGED_16_UNEMPLOYED = 18.2,
+        PERCENT_AGED_25_WITHOUT_HIGH_SCHOOL_DIPLOMA = 51.5)).toDF
+
+    score(addAdditionalDateColumns(crimesToScore), gbmModel, dlModel)
+  }
+
+  private def trainGBM(train: DataFrame): H2OMOJOModel = {
     val gbm = new H2OGBM()
       .setSplitRatio(0.8)
       .setLabelCol("Arrest")
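
The new weekendUdf in the hunk above is built by composing two plain Scala functions before lifting them into a UDF: (isWeekend _).andThen(boolToInt) has type Int => Int, so the Boolean test and the 0/1 encoding stay separate and individually testable. A self-contained sketch of the same pattern (the column name and local session are illustrative):

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf

object UdfCompositionSketch {
  private def isWeekend(dayOfWeek: Int): Boolean = dayOfWeek == 7 || dayOfWeek == 6
  private def boolToInt(bool: Boolean): Int = if (bool) 1 else 0

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("udf-sketch").getOrCreate()
    import spark.implicits._

    // Function1.andThen chains Int => Boolean with Boolean => Int, and
    // udf(...) lifts the composed Int => Int function onto Columns.
    val weekendUdf = udf((isWeekend _).andThen(boolToInt))

    Seq(1, 5, 6, 7)
      .toDF("WeekDay")
      .withColumn("Weekend", weekendUdf($"WeekDay"))
      .show() // days 6 and 7 yield 1, the rest 0
    spark.stop()
  }
}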
@@ -88,8 +145,7 @@ object ChicagoCrimeApp {
     gbm.fit(train)
   }
 
-  def trainDeepLearning(train: DataFrame): H2OMOJOModel = {
-    H2OContext.getOrCreate()
+  private def trainDeepLearning(train: DataFrame): H2OMOJOModel = {
     val dl = new H2ODeepLearning()
       .setSplitRatio(0.8)
       .setLabelCol("Arrest")
@@ -102,58 +158,33 @@ object ChicagoCrimeApp {
     dl.fit(train)
   }
 
-  def score(
-      spark: SparkSession,
-      crimes: Seq[Crime],
-      gbmModel: H2OMOJOModel,
-      dlModel: H2OMOJOModel,
-      censusTable: DataFrame): Unit = {
-    crimes.foreach { crime =>
-      val arrestGBM = scoreEvent(spark, crime, gbmModel, censusTable)
-      val arrestDL = scoreEvent(spark, crime, dlModel, censusTable)
-      println(s"""
-        |Crime: $crime
-        | Will be arrested based on DeepLearning: $arrestDL
-        | Will be arrested based on GBM: $arrestGBM
+  private def score(crimes: DataFrame, gbmModel: H2OMOJOModel, dlModel: H2OMOJOModel)(
+      implicit spark: SparkSession): Unit = {
+    import spark.implicits._
+    val arrestGBM = gbmModel.transform(crimes)
+    val arrestDL = dlModel.transform(crimes)
+    val willBeArrestedPrediction = $"prediction" === "1"
+    println(s"""
+      | Will be arrested based on DeepLearning: ${arrestDL.where(willBeArrestedPrediction).count()}
+      | Will be arrested based on GBM: ${arrestGBM.where(willBeArrestedPrediction).count()}
       |
       """.stripMargin)
-    }
-  }
-
-  def scoreEvent(spark: SparkSession, crime: Crime, model: H2OMOJOModel, censusTable: DataFrame): Boolean = {
-    // Create Spark DataFrame from a single row
-    import spark.implicits._
-    val df = addAdditionalDateColumns(spark, spark.sparkContext.parallelize(Seq(crime)).toDF)
-    // Join table with census data
-    val row = censusTable.join(df).where('Community_Area === 'Community_Area_Number)
-    val predictTable = model.transform(row)
-    predictTable.collect().head.getAs[String]("prediction") == "1"
   }
 
-  def createWeatherTable(spark: SparkSession, datafile: String): DataFrame = {
-    val df = spark.read.option("header", "true").option("inferSchema", "true").csv(datafile)
-    df.drop(df.columns(0))
-  }
-
-  def createCensusTable(spark: SparkSession, datafile: String): DataFrame = {
+  private def loadCsv(dataPath: String)(implicit spark: SparkSession): DataFrame = {
+    val datafile = s"file://${new File(dataPath).getAbsolutePath}"
     val df = spark.read.option("header", "true").option("inferSchema", "true").csv(datafile)
     val renamedColumns = df.columns.map { col =>
-      val name = col.trim.replace(' ', '_').replace('+', '_')
+      val name = col.trim
+        .replace(' ', '_')
+        .replace('+', '_')
+        .replace("__", "_")
       df(col).as(name)
     }
     df.select(renamedColumns: _*)
   }
 
-  def createCrimeTable(spark: SparkSession, datafile: String): DataFrame = {
-    val df = spark.read.option("header", "true").option("inferSchema", "true").csv(datafile)
-    val renamedColumns = df.columns.map { col =>
-      val name = col.trim.replace(' ', '_').replace('+', '_')
-      df(col).as(name)
-    }
-    addAdditionalDateColumns(spark, df.select(renamedColumns: _*))
-  }
-
-  def addAdditionalDateColumns(spark: SparkSession, df: DataFrame): DataFrame = {
+  private def addAdditionalDateColumns(df: DataFrame)(implicit spark: SparkSession): DataFrame = {
     import org.apache.spark.sql.functions._
     import spark.implicits._
     df.withColumn("DateTmp", from_unixtime(unix_timestamp('Date, "MM/dd/yyyy hh:mm:ss a")))
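
In the hunk above, the three near-duplicate loaders (createWeatherTable, createCensusTable, createCrimeTable) collapse into one loadCsv, and the header normalization gains a replace("__", "_") step. That is why the training columns earlier in the diff read PERCENT_AGED_16_UNEMPLOYED where the old SQL had PERCENT_AGED_16__UNEMPLOYED: rewriting '+' to '_' in a header like "PERCENT AGED 16+ UNEMPLOYED" used to leave a double underscore. A quick sketch of the renaming rule in isolation (the sample headers are illustrative):

object HeaderNormalizationSketch {
  // The same chain loadCsv applies to every CSV header: trim, turn spaces
  // and '+' into '_', then collapse the '__' that the '+' rewrite leaves.
  private def normalize(header: String): String =
    header.trim
      .replace(' ', '_')
      .replace('+', '_')
      .replace("__", "_")

  def main(args: Array[String]): Unit = {
    Seq("PERCENT AGED 16+ UNEMPLOYED", " PER CAPITA INCOME ", "HARDSHIP INDEX")
      .foreach(h => println(s"$h -> ${normalize(h)}"))
    // PERCENT AGED 16+ UNEMPLOYED -> PERCENT_AGED_16_UNEMPLOYED
  }
}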
@@ -168,18 +199,16 @@ object ChicagoCrimeApp {
       .drop('DateTmp)
   }
 
-  private def getSeason(month: Int): String = {
-    val seasonNum =
-      if (month >= 3 && month <= 5) 0 // Spring
-      else if (month >= 6 && month <= 8) 1 // Summer
-      else if (month >= 9 && month <= 10) 2 // Autumn
-      else 3 // Winter
-    SEASONS(seasonNum)
+  private def monthToSeason(month: Int): String = {
+    if (month >= 3 && month <= 5) "Spring"
+    else if (month >= 6 && month <= 8) "Summer"
+    else if (month >= 9 && month <= 10) "Autumn"
+    else "Winter"
   }
 
-  private def SEASONS: Array[String] = Array[String]("Spring", "Summer", "Autumn", "Winter")
+  private def isWeekend(dayOfWeek: Int): Boolean = dayOfWeek == 7 || dayOfWeek == 6
 
-  private def isWeekend(dayOfWeek: Int): Int = if (dayOfWeek == 7 || dayOfWeek == 6) 1 else 0
+  private def boolToInt(bool: Boolean): Int = if (bool) 1 else 0
 
   private def dayOfWeek(day: String): Int = {
     day match {
@@ -194,7 +223,7 @@ object ChicagoCrimeApp {
     }
   }
 
-  case class Crime(
+  case class CrimeWithCensusData(
       date: String,
       IUCR: Short,
       Primary_Type: String,
@@ -204,6 +233,13 @@ object ChicagoCrimeApp {
       District: Byte,
       Ward: Byte,
       Community_Area: Byte,
-      FBI_Code: Byte)
+      FBI_Code: Byte,
+      PERCENT_AGED_UNDER_18_OR_OVER_64: Double,
+      PER_CAPITA_INCOME: Int,
+      HARDSHIP_INDEX: Short,
+      PERCENT_OF_HOUSING_CROWDED: Double,
+      PERCENT_HOUSEHOLDS_BELOW_POVERTY: Double,
+      PERCENT_AGED_16_UNEMPLOYED: Double,
+      PERCENT_AGED_25_WITHOUT_HIGH_SCHOOL_DIPLOMA: Double)
 
 }

Comments (0)