-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtrain_advanced.py
executable file
·83 lines (71 loc) · 2.82 KB
/
train_advanced.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
# USAGE
# python train_advanced.py --samples output/examples --char-classifier output/adv_char.cpickle --digit-classifier output/adv_digit.cpickle
# import the necessary packages
from __future__ import print_function

import argparse
import glob
import os
import pickle
import random

import cv2
from imutils import paths
from sklearn.svm import LinearSVC

from pyimagesearch.descriptors import BlockBinaryPixelSum
from pyimagesearch.license_plate import LicensePlateDetector
# build the command line interface: three required paths (the samples
# directory and the two classifier output files) plus one optional integer
ap = argparse.ArgumentParser()
ap.add_argument(
    "-s", "--samples",
    help="path to the training samples directory",
    required=True,
)
ap.add_argument(
    "-c", "--char-classifier",
    help="path to the output character classifier",
    required=True,
)
ap.add_argument(
    "-d", "--digit-classifier",
    help="path to the output digit classifier",
    required=True,
)
# NOTE: despite its name, this value is used further down as the upper bound
# on how many images are randomly drawn from each sample directory
ap.add_argument(
    "-m", "--min-samples",
    help="minimum # of samples per character",
    type=int,
    default=20,
)
# expose the parsed options as a plain dictionary keyed by option name
# (dashes become underscores, e.g. args["char_classifier"])
args = vars(ap.parse_args())
# set up the feature descriptor with its four block sizes and a 30x15
# target size
blockSizes = (
    (5, 5),
    (5, 10),
    (10, 5),
    (10, 10),
)
desc = BlockBinaryPixelSum(targetSize=(30, 15), blockSizes=blockSizes)

# feature vectors and their labels are accumulated separately for letters
# and for digits, since each family gets its own classifier
alphabetData, alphabetLabels = [], []
digitsData, digitsLabels = [], []
# loop over the sample character directories in sorted order; the pattern is
# built with os.path.join so it is correct regardless of the path separator
for samplePath in sorted(glob.glob(os.path.join(args["samples"], "*"))):
    # the last path component is the class label -- os.path.basename copes
    # with both "/" and "\\" separators, unlike searching for "/" by hand
    sampleName = os.path.basename(samplePath)

    # grab all images in the sample path and randomly keep at most
    # --min-samples of them (all of them when fewer are available)
    imagePaths = list(paths.list_images(samplePath))
    imagePaths = random.sample(imagePaths, min(len(imagePaths), args["min_samples"]))

    # the label decides which training set this sample feeds: purely numeric
    # names go to the digit set, everything else to the alphabet set
    if sampleName.isdigit():
        targetData, targetLabels = digitsData, digitsLabels
    else:
        targetData, targetLabels = alphabetData, alphabetLabels

    # loop over the selected images in the sample path
    for imagePath in imagePaths:
        # cv2.imread returns None instead of raising when a file is missing,
        # corrupt or not a decodable image; skip it rather than crash later
        char = cv2.imread(imagePath)
        if char is None:
            print("[WARN] skipping unreadable image: {}".format(imagePath))
            continue

        # convert the character to grayscale, preprocess it, and describe it
        char = cv2.cvtColor(char, cv2.COLOR_BGR2GRAY)
        char = LicensePlateDetector.preprocessChar(char)
        features = desc.describe(char)

        # record the feature vector together with its label
        targetData.append(features)
        targetLabels.append(sampleName)
# fit one linear SVM per character family, both with the same settings;
# fit() hands back the fitted estimator, so construction and training are
# chained into a single expression

# alphabetical characters
print("[INFO] fitting character model...")
charModel = LinearSVC(C=1.0, random_state=42).fit(alphabetData, alphabetLabels)

# digits
print("[INFO] fitting digit model...")
digitModel = LinearSVC(C=1.0, random_state=42).fit(digitsData, digitsLabels)
# serialize both trained classifiers to disk; the context managers guarantee
# each file handle is closed even if pickling or writing raises

# dump the character classifier to file
print("[INFO] dumping character model...")
with open(args["char_classifier"], "wb") as f:
    f.write(pickle.dumps(charModel))

# dump the digit classifier to file
print("[INFO] dumping digit model...")
with open(args["digit_classifier"], "wb") as f:
    f.write(pickle.dumps(digitModel))