package ai.h2o.sparkling.examples

import java.io.File
-
import ai.h2o.sparkling.H2OContext
import ai.h2o.sparkling.ml.algos.{H2ODeepLearning, H2OGBM}
import ai.h2o.sparkling.ml.models.H2OMOJOModel
@@ -27,57 +26,115 @@ import org.apache.spark.sql.{DataFrame, SparkSession}

object ChicagoCrimeApp {

-  private val seasonUdf = udf(getSeason _)
-  private val weekendUdf = udf(isWeekend _)
+  private val seasonUdf = udf(monthToSeason _)
+  private val weekendUdf = udf((isWeekend _).andThen(boolToInt))
  private val dayOfWeekUdf = udf(dayOfWeek _)

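The composed UDF above is the one behavioral change in this block: `isWeekend` now returns a `Boolean`, and `(isWeekend _).andThen(boolToInt)` chains it with `boolToInt` so the column still comes out as a 0/1 flag. A minimal, self-contained sketch of the same composition pattern (the local SparkSession setup and the sample data are illustrative, not part of the app):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.udf

object UdfCompositionSketch {
  private def isWeekend(dayOfWeek: Int): Boolean = dayOfWeek == 7 || dayOfWeek == 6
  private def boolToInt(bool: Boolean): Int = if (bool) 1 else 0

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("udf-composition").getOrCreate()
    import spark.implicits._

    // (isWeekend _).andThen(boolToInt) has type Int => Int: day number in, 0/1 flag out.
    val weekendUdf = udf((isWeekend _).andThen(boolToInt))

    // Days 6 and 7 map to 1, everything else to 0.
    Seq(1, 6, 7).toDF("WeekDay").select($"WeekDay", weekendUdf($"WeekDay").as("Weekend")).show()
  }
}
```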
  def main(args: Array[String]) {
-    val spark = SparkSession
+    implicit val spark = SparkSession
      .builder()
      .appName("Chicago Crime App")
      .getOrCreate()
+    import spark.implicits._

-    val weatherDataPath = "./examples/smalldata/chicago/chicagoAllWeather.csv"
-    val weatherDataFile = s"file://${new File(weatherDataPath).getAbsolutePath}"
-    val weatherTable = createWeatherTable(spark, weatherDataFile)
-    weatherTable.createOrReplaceTempView("chicagoWeather")
-
-    val censusDataPath = "./examples/smalldata/chicago/chicagoCensus.csv"
-    val censusDataFile = s"file://${new File(censusDataPath).getAbsolutePath}"
-    val censusTable = createCensusTable(spark, censusDataFile)
-    censusTable.createOrReplaceTempView("chicagoCensus")
-
-    val crimesDataPath = "./examples/smalldata/chicago/chicagoCrimes10k.csv"
-    val crimesDataFile = s"file://${new File(crimesDataPath).getAbsolutePath}"
-    val crimesTable = createCrimeTable(spark, crimesDataFile)
-    crimesTable.createOrReplaceTempView("chicagoCrime")
-
-    // Join crimes and weather tables
-    val crimeWeather = spark.sql("""SELECT
-      |a.Year, a.Month, a.Day, a.WeekNum, a.HourOfDay, a.Weekend, a.Season, a.WeekDay,
-      |a.IUCR, a.Primary_Type, a.Location_Description, a.Community_Area, a.District,
-      |a.Arrest, a.Domestic, a.Beat, a.Ward, a.FBI_Code,
-      |b.minTemp, b.maxTemp, b.meanTemp,
-      |c.PERCENT_AGED_UNDER_18_OR_OVER_64, c.PER_CAPITA_INCOME, c.HARDSHIP_INDEX,
-      |c.PERCENT_OF_HOUSING_CROWDED, c.PERCENT_HOUSEHOLDS_BELOW_POVERTY,
-      |c.PERCENT_AGED_16__UNEMPLOYED, c.PERCENT_AGED_25__WITHOUT_HIGH_SCHOOL_DIPLOMA
-      |FROM chicagoCrime a
-      |JOIN chicagoWeather b
-      |ON a.Year = b.year AND a.Month = b.month AND a.Day = b.day
-      |JOIN chicagoCensus c
-      |ON a.Community_Area = c.Community_Area_Number""".stripMargin)
-
-    val gbmModel = trainGBM(crimeWeather)
-    val dlModel = trainDeepLearning(crimeWeather)
-
-    val crimes = Seq(
-      Crime("02/08/2015 11:43:58 PM", 1811, "NARCOTICS", "STREET", Domestic = false, 422, 4, 7, 46, 18),
-      Crime("02/08/2015 11:00:39 PM", 1150, "DECEPTIVE PRACTICE", "RESIDENCE", Domestic = false, 923, 9, 14, 63, 11))
-    score(spark, crimes, gbmModel, dlModel, censusTable)
-  }
+    val weatherTable = loadCsv("./examples/smalldata/chicago/chicagoAllWeather.csv").drop("date")
+    val chicagoWeatherTableName = "chicagoWeather"
+    weatherTable.createOrReplaceTempView(chicagoWeatherTableName)
+
+    val censusTable = loadCsv("./examples/smalldata/chicago/chicagoCensus.csv")
+    val chicagoCensusTableName = "chicagoCensus"
+    censusTable.createOrReplaceTempView(chicagoCensusTableName)
+
+    val crimesTable = addAdditionalDateColumns(loadCsv("./examples/smalldata/chicago/chicagoCrimes10k.csv"))
+    val chicagoCrimeTableName = "chicagoCrime"
+    crimesTable.createOrReplaceTempView(chicagoCrimeTableName)
+
+    val crimeDataColumnsForTraining = Seq(
+      $"cr.Year",
+      $"cr.Month",
+      $"cr.Day",
+      $"WeekNum",
+      $"HourOfDay",
+      $"Weekend",
+      $"Season",
+      $"WeekDay",
+      $"IUCR",
+      $"Primary_Type",
+      $"Location_Description",
+      $"Community_Area",
+      $"District",
+      $"Arrest",
+      $"Domestic",
+      $"Beat",
+      $"Ward",
+      $"FBI_Code")
+
+    val censusDataColumnsForTraining = Seq(
+      $"PERCENT_AGED_UNDER_18_OR_OVER_64",
+      $"PER_CAPITA_INCOME",
+      $"HARDSHIP_INDEX",
+      $"PERCENT_OF_HOUSING_CROWDED",
+      $"PERCENT_HOUSEHOLDS_BELOW_POVERTY",
+      $"PERCENT_AGED_16_UNEMPLOYED",
+      $"PERCENT_AGED_25_WITHOUT_HIGH_SCHOOL_DIPLOMA")
+
+    val weatherDataColumnsForTraining = Seq($"minTemp", $"maxTemp", $"meanTemp")
+
+    val joinedDataForTraining = spark
+      .table(chicagoCrimeTableName)
+      .as("cr")
+      .join(
+        spark.table(chicagoWeatherTableName).as("we"),
+        $"cr.Year" === $"we.year" and $"cr.Month" === $"we.month" and $"cr.Day" === $"we.day")
+      .join(spark.table(chicagoCensusTableName).as("ce"), $"cr.Community_Area" === $"ce.Community_Area_Number")
+      .select(crimeDataColumnsForTraining ++ censusDataColumnsForTraining ++ weatherDataColumnsForTraining: _*)
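This alias join is the DataFrame counterpart of the SQL string it replaces: `.as("cr")` (and `"we"`, `"ce"`) gives each table a prefix so identically named columns stay addressable, and `select(cols: _*)` splices a `Seq[Column]` in as varargs. A toy sketch of the same two idioms on made-up data (table contents and the `HARDSHIP_INDEX` values here are invented):

```scala
import org.apache.spark.sql.SparkSession

object AliasJoinSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("join-sketch").getOrCreate()
    import spark.implicits._

    val crimes = Seq((2015, 46), (2015, 63)).toDF("Year", "Community_Area")
    val census = Seq((46, 75), (63, 93)).toDF("Community_Area_Number", "HARDSHIP_INDEX")

    // Aliases disambiguate columns once both frames sit in one plan.
    val joined = crimes
      .as("cr")
      .join(census.as("ce"), $"cr.Community_Area" === $"ce.Community_Area_Number")

    // A Seq[Column] is spliced into select with the : _* varargs ascription.
    val columns = Seq($"cr.Year", $"cr.Community_Area", $"ce.HARDSHIP_INDEX")
    joined.select(columns: _*).show()
  }
}
```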

-  def trainGBM(train: DataFrame): H2OMOJOModel = {
    H2OContext.getOrCreate()
+    val gbmModel = trainGBM(joinedDataForTraining)
+    val dlModel = trainDeepLearning(joinedDataForTraining)
+
+    val crimesToScore = Seq(
+      CrimeWithCensusData(
+        date = "02/08/2015 11:43:58 PM",
+        IUCR = 1811,
+        Primary_Type = "NARCOTICS",
+        Location_Description = "STREET",
+        Domestic = false,
+        Beat = 422,
+        District = 4,
+        Ward = 7,
+        Community_Area = 46,
+        FBI_Code = 18,
+        PERCENT_AGED_UNDER_18_OR_OVER_64 = 41.1,
+        PER_CAPITA_INCOME = 16579,
+        HARDSHIP_INDEX = 75,
+        PERCENT_OF_HOUSING_CROWDED = 4.7,
+        PERCENT_HOUSEHOLDS_BELOW_POVERTY = 29.8,
+        PERCENT_AGED_16_UNEMPLOYED = 19.7,
+        PERCENT_AGED_25_WITHOUT_HIGH_SCHOOL_DIPLOMA = 26.6),
+      CrimeWithCensusData(
+        date = "02/08/2015 11:00:39 PM",
+        IUCR = 1150,
+        Primary_Type = "DECEPTIVE PRACTICE",
+        Location_Description = "RESIDENCE",
+        Domestic = false,
+        Beat = 923,
+        District = 9,
+        Ward = 14,
+        Community_Area = 63,
+        FBI_Code = 11,
+        PERCENT_AGED_UNDER_18_OR_OVER_64 = 38.8,
+        PER_CAPITA_INCOME = 12171,
+        HARDSHIP_INDEX = 93,
+        PERCENT_OF_HOUSING_CROWDED = 15.8,
+        PERCENT_HOUSEHOLDS_BELOW_POVERTY = 23.4,
+        PERCENT_AGED_16_UNEMPLOYED = 18.2,
+        PERCENT_AGED_25_WITHOUT_HIGH_SCHOOL_DIPLOMA = 51.5)).toDF
+
+    score(addAdditionalDateColumns(crimesToScore), gbmModel, dlModel)
+  }
+
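The scoring rows above are built as case class instances and turned into a `DataFrame` with `.toDF`, which is why `main` now imports `spark.implicits._`: the implicit encoder derives the schema from the case class, and field names become column names that must line up with the columns the models were trained on. A stripped-down sketch of that mechanism (the `Person` case class is invented for illustration):

```scala
import org.apache.spark.sql.SparkSession

object CaseClassToDataFrameSketch {
  // Field names become column names; field types become the schema.
  case class Person(name: String, age: Int)

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("todf-sketch").getOrCreate()
    import spark.implicits._

    val people = Seq(Person("Ada", 36), Person("Alan", 41)).toDF
    people.printSchema() // columns: name (string), age (int)
    people.show()
  }
}
```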
+  private def trainGBM(train: DataFrame): H2OMOJOModel = {
    val gbm = new H2OGBM()
      .setSplitRatio(0.8)
      .setLabelCol("Arrest")
@@ -88,8 +145,7 @@ object ChicagoCrimeApp {
    gbm.fit(train)
  }

-  def trainDeepLearning(train: DataFrame): H2OMOJOModel = {
-    H2OContext.getOrCreate()
+  private def trainDeepLearning(train: DataFrame): H2OMOJOModel = {
    val dl = new H2ODeepLearning()
      .setSplitRatio(0.8)
      .setLabelCol("Arrest")
@@ -102,58 +158,33 @@ object ChicagoCrimeApp {
    dl.fit(train)
  }

-  def score(
-      spark: SparkSession,
-      crimes: Seq[Crime],
-      gbmModel: H2OMOJOModel,
-      dlModel: H2OMOJOModel,
-      censusTable: DataFrame): Unit = {
-    crimes.foreach { crime =>
-      val arrestGBM = scoreEvent(spark, crime, gbmModel, censusTable)
-      val arrestDL = scoreEvent(spark, crime, dlModel, censusTable)
-      println(s"""
-        |Crime: $crime
-        |  Will be arrested based on DeepLearning: $arrestDL
-        |  Will be arrested based on GBM: $arrestGBM
+  private def score(crimes: DataFrame, gbmModel: H2OMOJOModel, dlModel: H2OMOJOModel)(
+      implicit spark: SparkSession): Unit = {
+    import spark.implicits._
+    val arrestGBM = gbmModel.transform(crimes)
+    val arrestDL = dlModel.transform(crimes)
+    val willBeArrestedPrediction = $"prediction" === "1"
+    println(s"""
+      |  Will be arrested based on DeepLearning: ${arrestDL.where(willBeArrestedPrediction).count()}
+      |  Will be arrested based on GBM: ${arrestGBM.where(willBeArrestedPrediction).count()}
      |
      """.stripMargin)
-    }
-  }
-
-  def scoreEvent(spark: SparkSession, crime: Crime, model: H2OMOJOModel, censusTable: DataFrame): Boolean = {
-    // Create Spark DataFrame from a single row
-    import spark.implicits._
-    val df = addAdditionalDateColumns(spark, spark.sparkContext.parallelize(Seq(crime)).toDF)
-    // Join table with census data
-    val row = censusTable.join(df).where('Community_Area === 'Community_Area_Number)
-    val predictTable = model.transform(row)
-    predictTable.collect().head.getAs[String]("prediction") == "1"
  }

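The new `score` treats both MOJOs as ordinary Spark transformers: `transform` appends a `prediction` column rather than mutating its input, so one scoring frame can be pushed through several models, and `$"prediction" === "1"` is a reusable `Column` predicate rather than a computed value. A sketch of the counting idiom with the model call stubbed out (the pre-scored frame below stands in for `model.transform(crimes)`):

```scala
import org.apache.spark.sql.{DataFrame, SparkSession}

object PredictionCountSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("score-sketch").getOrCreate()
    import spark.implicits._

    // Stand-in for model.transform(crimes): a frame that already carries a prediction column.
    val scored: DataFrame =
      Seq(("NARCOTICS", "1"), ("DECEPTIVE PRACTICE", "0")).toDF("Primary_Type", "prediction")

    // The predicate is a Column expression; it is evaluated lazily by where().
    val willBeArrested = $"prediction" === "1"
    println(s"predicted arrests: ${scored.where(willBeArrested).count()}")
  }
}
```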
-  def createWeatherTable(spark: SparkSession, datafile: String): DataFrame = {
-    val df = spark.read.option("header", "true").option("inferSchema", "true").csv(datafile)
-    df.drop(df.columns(0))
-  }
-
-  def createCensusTable(spark: SparkSession, datafile: String): DataFrame = {
+  private def loadCsv(dataPath: String)(implicit spark: SparkSession): DataFrame = {
+    val datafile = s"file://${new File(dataPath).getAbsolutePath}"
    val df = spark.read.option("header", "true").option("inferSchema", "true").csv(datafile)
    val renamedColumns = df.columns.map { col =>
-      val name = col.trim.replace(' ', '_').replace('+', '_')
+      val name = col.trim
+        .replace(' ', '_')
+        .replace('+', '_')
+        .replace("__", "_")
      df(col).as(name)
    }
    df.select(renamedColumns: _*)
  }

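The added `.replace("__", "_")` is what lets the training column list use `PERCENT_AGED_16_UNEMPLOYED` with a single underscore where the removed SQL had to spell out `PERCENT_AGED_16__UNEMPLOYED`: a raw header shaped like `PERCENT AGED 16+ UNEMPLOYED` (assumed here to match the census CSV) maps both the space and the `+` to `_`, and the resulting double underscore is then collapsed. The sanitizer in isolation:

```scala
object HeaderSanitizerSketch extends App {
  // Mirrors the renaming inside loadCsv: spaces and '+' become '_', then a doubled '_' collapses.
  def sanitize(header: String): String =
    header.trim
      .replace(' ', '_')
      .replace('+', '_')
      .replace("__", "_")

  println(sanitize("PERCENT AGED 16+ UNEMPLOYED")) // PERCENT_AGED_16_UNEMPLOYED
}
```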
-  def createCrimeTable(spark: SparkSession, datafile: String): DataFrame = {
-    val df = spark.read.option("header", "true").option("inferSchema", "true").csv(datafile)
-    val renamedColumns = df.columns.map { col =>
-      val name = col.trim.replace(' ', '_').replace('+', '_')
-      df(col).as(name)
-    }
-    addAdditionalDateColumns(spark, df.select(renamedColumns: _*))
-  }
-
-  def addAdditionalDateColumns(spark: SparkSession, df: DataFrame): DataFrame = {
+  private def addAdditionalDateColumns(df: DataFrame)(implicit spark: SparkSession): DataFrame = {
    import org.apache.spark.sql.functions._
    import spark.implicits._
    df.withColumn("DateTmp", from_unixtime(unix_timestamp('Date, "MM/dd/yyyy hh:mm:ss a")))
@@ -168,18 +199,16 @@ object ChicagoCrimeApp {
      .drop('DateTmp)
  }

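`addAdditionalDateColumns` follows a parse-derive-drop pattern: parse the `MM/dd/yyyy hh:mm:ss a` strings once into a scratch column, derive the calendar features from it, then drop the scratch column. A condensed sketch of that pattern (the derived columns shown are a subset of what the app builds; the sample row is taken from the scoring data above):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object DateColumnsSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("date-sketch").getOrCreate()
    import spark.implicits._

    val df = Seq("02/08/2015 11:43:58 PM").toDF("Date")
    val withDates = df
      // Parse the string timestamp once into a temporary column...
      .withColumn("DateTmp", from_unixtime(unix_timestamp('Date, "MM/dd/yyyy hh:mm:ss a")))
      // ...derive the calendar features the model needs...
      .withColumn("Month", month('DateTmp))
      .withColumn("Day", dayofmonth('DateTmp))
      .withColumn("HourOfDay", hour('DateTmp))
      // ...and drop the scaffolding.
      .drop('DateTmp)
    withDates.show()
  }
}
```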
-  private def getSeason(month: Int): String = {
-    val seasonNum =
-      if (month >= 3 && month <= 5) 0 // Spring
-      else if (month >= 6 && month <= 8) 1 // Summer
-      else if (month >= 9 && month <= 10) 2 // Autumn
-      else 3 // Winter
-    SEASONS(seasonNum)
+  private def monthToSeason(month: Int): String = {
+    if (month >= 3 && month <= 5) "Spring"
+    else if (month >= 6 && month <= 8) "Summer"
+    else if (month >= 9 && month <= 10) "Autumn"
+    else "Winter"
  }

-  private def SEASONS: Array[String] = Array[String]("Spring", "Summer", "Autumn", "Winter")
+  private def isWeekend(dayOfWeek: Int): Boolean = dayOfWeek == 7 || dayOfWeek == 6

-  private def isWeekend(dayOfWeek: Int): Int = if (dayOfWeek == 7 || dayOfWeek == 6) 1 else 0
+  private def boolToInt(bool: Boolean): Int = if (bool) 1 else 0

  private def dayOfWeek(day: String): Int = {
    day match {
@@ -194,7 +223,7 @@ object ChicagoCrimeApp {
    }
  }

-  case class Crime(
+  case class CrimeWithCensusData(
      date: String,
      IUCR: Short,
      Primary_Type: String,
@@ -204,6 +233,13 @@ object ChicagoCrimeApp {
      District: Byte,
      Ward: Byte,
      Community_Area: Byte,
-      FBI_Code: Byte)
+      FBI_Code: Byte,
+      PERCENT_AGED_UNDER_18_OR_OVER_64: Double,
+      PER_CAPITA_INCOME: Int,
+      HARDSHIP_INDEX: Short,
+      PERCENT_OF_HOUSING_CROWDED: Double,
+      PERCENT_HOUSEHOLDS_BELOW_POVERTY: Double,
+      PERCENT_AGED_16_UNEMPLOYED: Double,
+      PERCENT_AGED_25_WITHOUT_HIGH_SCHOOL_DIPLOMA: Double)

}