@@ -96,32 +96,34 @@ abstract class RDD[T: ClassManifest](@transient sc: SparkContext) extends Serial
96
96
var total = 0
97
97
var multiplier = 3.0
98
98
var initialCount = count()
99
-
99
+ var maxSelected = 0
100
+
101
+ if (initialCount > Integer .MAX_VALUE ) {
102
+ maxSelected = Integer .MAX_VALUE
103
+ }
104
+ else {
105
+ maxSelected = initialCount.toInt
106
+ }
107
+
100
108
if (num > initialCount) {
101
- total = Math .min(initialCount, Integer .MAX_VALUE )
102
- total = total.toInt
103
- fraction = 1.0
109
+ total = maxSelected
110
+ fraction = Math .min(multiplier* (maxSelected+ 1 )/ initialCount, 1.0 )
104
111
}
105
112
else if (num < 0 ) {
106
- throw (new IllegalArgumentException ())
113
+ throw (new IllegalArgumentException (" Negative number of elements requested " ))
107
114
}
108
115
else {
109
- fraction = Math .min(multiplier* (num+ 1 )/ count() , 1.0 )
116
+ fraction = Math .min(multiplier* (num+ 1 )/ initialCount , 1.0 )
110
117
total = num.toInt
111
118
}
112
119
113
- var r = new SampledRDD (this , withReplacement, fraction, seed)
114
- var samples = r.collect()
120
+ var samples = this .sample(withReplacement, fraction, seed).collect()
115
121
116
122
while (samples.length < total) {
117
- r = new SampledRDD ( this , withReplacement, fraction, seed)
123
+ samples = this .sample( withReplacement, fraction, seed).collect( )
118
124
}
119
125
120
- var arr = new Array [T ](total)
121
-
122
- for (i <- 0 to total - 1 ) {
123
- arr(i) = samples(i)
124
- }
126
+ val arr = samples.take(total)
125
127
126
128
return arr
127
129
}
0 commit comments