Skip to content

Commit

Permalink
Update evaluate code
Browse files Browse the repository at this point in the history
  • Loading branch information
jiaowoguanren0615 committed Jun 14, 2024
1 parent 6d6a85b commit 27ab4a4
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
5 changes: 1 addition & 4 deletions estimate_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def predictor(model, img, mask, device):
return prediction


def run_pred(args, weights_path, img_path, roi_mask_path):
def run_pred(args, model, weights_path, img_path, roi_mask_path):
assert os.path.exists(weights_path), f"weights {weights_path} not found."
assert os.path.exists(img_path), f"image {img_path} not found."
assert os.path.exists(roi_mask_path), f"image {roi_mask_path} not found."
Expand All @@ -38,9 +38,6 @@ def run_pred(args, weights_path, img_path, roi_mask_path):
device = args.device
print("using {} device.".format(device))

# create model
model = UKAN_large(num_classes=args.nb_classes)

# load weights
model.load_state_dict(torch.load(weights_path, map_location='cpu')['model_state'])
model.to(device)
Expand Down
7 changes: 6 additions & 1 deletion train_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,11 +345,16 @@ def main(args):
print('*********No improving mIOU, No saving checkpoint*********')

if args.predict and utils.is_main_process():
model_pred = create_model(
args.model,
num_classes=args.nb_classes,
args=args
)
print('*******************STARTING PREDICT*******************')
weights_path = f'./{args.save_weights_dir}/{args.model}_best_model.pth'
img_path = "/mnt/d/MedicalSeg/CVC-ClinicDB/Original/1.png"
roi_mask_path = "/mnt/d/MedicalSeg/CVC-ClinicDB/Ground Truth/1.png"
run_pred(args, weights_path, img_path, roi_mask_path)
run_pred(args, model_pred, weights_path, img_path, roi_mask_path)



Expand Down

0 comments on commit 27ab4a4

Please sign in to comment.