-
Notifications
You must be signed in to change notification settings - Fork 2
/
eval.py
39 lines (26 loc) · 1.11 KB
/
eval.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
"""Evaluate a quantization-aware ResNet-18 checkpoint on the test dataset.

Usage::

    python eval.py --check_point /path/to/model.ckpt
"""
import argparse
import os

from mindspore import Model, context, load_checkpoint
from mindspore.compression.quant import QuantizationAwareTraining
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits

from preprocess import create_dataset
from resnetv2 import resnet18

# Number of output classes the ResNet-18 head is built with.
NUM_CLASSES = 10


def _parse_args(argv=None):
    """Parse and validate the command line.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with a ``check_point`` attribute holding the path
        of an existing checkpoint file.

    Exits (via ``parser.error``) when the checkpoint path is missing or does
    not point to a regular file.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate a quantized ResNet-18 checkpoint.")
    # Required: there is no sensible default checkpoint location.
    parser.add_argument("-loc", "--check_point", required=True,
                        help="Model checkpoint file")
    args = parser.parse_args(argv)
    # Fail fast with a clear CLI error instead of an opaque load failure later.
    if not os.path.isfile(args.check_point):
        parser.error(f"checkpoint file not found: {args.check_point}")
    return args


def main(argv=None):
    """Build the quantized network, load the checkpoint and print test accuracy.

    Args:
        argv: Optional argument list (see ``_parse_args``).

    Returns:
        The metrics dict returned by ``Model.eval`` (contains key ``'acc'``).
    """
    args = _parse_args(argv)
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

    network = resnet18(NUM_CLASSES)
    # NOTE(review): the quantizer config presumably has to match the one used
    # when the checkpoint was trained so parameter names line up — confirm
    # against the training script.
    quantizer = QuantizationAwareTraining(bn_fold=False)
    quant_net = quantizer.quantize(network)
    load_checkpoint(args.check_point, net=quant_net)  # custom trained weights

    eval_data = create_dataset(training=False)  # the test split
    loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
    # Evaluation only runs forward passes, so no optimizer is attached.
    model = Model(quant_net, loss_fn=loss, metrics={'acc'})
    acc = model.eval(eval_data)
    print('Accuracy of model is: ', acc['acc'] * 100)
    return acc


if __name__ == "__main__":
    main()