From 12389a33a7eff8cabefe2db0627f62f902055b25 Mon Sep 17 00:00:00 2001
From: Zhiqiang Wang
Date: Sun, 28 Nov 2021 03:54:55 +0800
Subject: [PATCH] Add YOLOInference for downstream inference frameworks (#238)

* Fixing docstring
* Fixing compatibility issue
* Minor fix of the docstring
* Separate out the concatenation part in PostProcess
* Deprecated the ONNX exporting in ncnn
* Fixing jit trace
* Add YOLOInference for downstream inference frameworks
* Add test_yolo_inference
* Use attempt_download
* Add test_attempt_download
* Abstract the _decode_pred_logits
* Abstract the LogitsDecoder
---
 deployment/ncnn/export_onnx.py               |  64 -----------
 deployment/ncnn/tools/__init__.py            |   0
 .../ncnn/tools/yolort_deploy_friendly.py     |  76 ------------
 test/test_models.py                          |   9 +-
 test/test_relaying.py                        |  37 +++++-
 test/test_utils.py                           |  16 +--
 test/test_v5.py                              |  27 +++--
 yolort/models/box_head.py                    | 108 ++++++++++++++----
 yolort/models/yolo.py                        |   2 +
 yolort/relaying/__init__.py                  |   3 +
 yolort/relaying/yolo_inference.py            |  42 +++++++
 yolort/v5/utils/downloads.py                 |  77 ++++++++++---
 12 files changed, 251 insertions(+), 210 deletions(-)
 delete mode 100644 deployment/ncnn/export_onnx.py
 delete mode 100644 deployment/ncnn/tools/__init__.py
 delete mode 100644 deployment/ncnn/tools/yolort_deploy_friendly.py
 create mode 100644 yolort/relaying/yolo_inference.py

diff --git a/deployment/ncnn/export_onnx.py b/deployment/ncnn/export_onnx.py
deleted file mode 100644
index fe742ef7..00000000
--- a/deployment/ncnn/export_onnx.py
+++ /dev/null
@@ -1,64 +0,0 @@
-# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
-import argparse
-
-import torch
-from tools.yolort_deploy_friendly import yolov5s_r40_deploy_ncnn
-
-
-def get_parser():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--weights", type=str, default="./yolov5s.pt", help="weights path")
-    parser.add_argument(
-        "--output_path",
-        type=str,
-        default="./yolov5s.onnx",
-        help="path of exported onnx",
-    )
-    parser.add_argument(
-        "--img_size",
-        nargs="+",
-        type=int,
-        default=[640, 640],
-        help="image (height, width)",
-    )
-    parser.add_argument("--num_classes", type=int, default=80, help="number of classes")
-    parser.add_argument("--batch_size", type=int, default=1, help="batch size")
-    parser.add_argument("--device", default="cpu", help="cuda device, i.e. 0 or 0,1,2,3 or cpu")
-    parser.add_argument("--half", action="store_true", help="FP16 half-precision export")
-    parser.add_argument("--dynamic", action="store_true", help="ONNX: dynamic axes")
-    parser.add_argument("--simplify", action="store_true", help="ONNX: simplify model")
-    parser.add_argument("--opset", type=int, default=11, help="ONNX: opset version")
-    return parser
-
-
-def cli_main():
-    parser = get_parser()
-    args = parser.parse_args()
-    print(args)
-    export_onnx(args)
-
-
-def export_onnx(args):
-
-    model = yolov5s_r40_deploy_ncnn(
-        pretrained=True,
-        num_classes=args.num_classes,
-    )
-    img = torch.rand(args.batch_size, 3, 640, 640)
-    outputs = model(img)
-    assert len(outputs) == 3
-
-    torch.onnx.export(
-        model,
-        img,
-        args.output_path,
-        verbose=False,
-        opset_version=args.opset,
-        do_constant_folding=True,
-        input_names=["images"],
-        output_names=["h1", "h2", "h3"],
-    )
-
-
-if __name__ == "__main__":
-    cli_main()
diff --git a/deployment/ncnn/tools/__init__.py b/deployment/ncnn/tools/__init__.py
deleted file mode 100644
index e69de29b..00000000
diff --git a/deployment/ncnn/tools/yolort_deploy_friendly.py b/deployment/ncnn/tools/yolort_deploy_friendly.py
deleted file mode 100644
index 9c985116..00000000
--- a/deployment/ncnn/tools/yolort_deploy_friendly.py
+++ /dev/null
@@ -1,76 +0,0 @@
-# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
-from typing import Any, List, Optional
-
-from torch import nn, Tensor
-from torchvision.models.utils import load_state_dict_from_url
-from yolort.models.backbone_utils import darknet_pan_backbone
-from yolort.models.yolo import YOLO, model_urls
-
-
-class YOLODeployFriendly(YOLO):
-    """
-    Deployment Friendly Wrapper of YOLO.
-    """
-
-    def __init__(
-        self,
-        backbone: nn.Module,
-        num_classes: int,
-        # Anchor parameters
-        anchor_grids: Optional[List[List[float]]] = None,
-        anchor_generator: Optional[nn.Module] = None,
-        head: Optional[nn.Module] = None,
-    ):
-        super().__init__(
-            backbone,
-            num_classes,
-            anchor_grids=anchor_grids,
-            anchor_generator=anchor_generator,
-            head=head,
-        )
-
-    def forward(self, samples: Tensor):
-        """
-        Arguments:
-            samples (Tensor): batched images, of shape [batch_size x 3 x H x W]
-        """
-        # get the features from the backbone
-        features = self.backbone(samples)
-
-        # compute the yolo heads outputs using the features
-        head_outputs = self.head(features)
-        return head_outputs
-
-
-def yolov5s_r40_deploy_ncnn(
-    pretrained: bool = False,
-    progress: bool = True,
-    num_classes: int = 80,
-    **kwargs: Any,
-) -> YOLODeployFriendly:
-    """
-    Deployment friendly Wrapper of yolov5s for ncnn.
-
-    Args:
-        pretrained (bool): If True, returns a model pre-trained on ImageNet
-        progress (bool): If True, displays a progress bar of the download to stderr
-    """
-    backbone_name = "darknet_s_r4_0"
-    weights_name = "yolov5_darknet_pan_s_r40_coco"
-    depth_multiple = 0.33
-    width_multiple = 0.5
-    version = "r4.0"
-
-    backbone = darknet_pan_backbone(backbone_name, depth_multiple, width_multiple, version=version)
-
-    model = YOLODeployFriendly(backbone, num_classes, **kwargs)
-    if pretrained:
-        if model_urls.get(weights_name, None) is None:
-            raise ValueError(f"No checkpoint is available for model {weights_name}")
-        state_dict = load_state_dict_from_url(model_urls[weights_name], progress=progress)
-        model.load_state_dict(state_dict)
-
-    del model.anchor_generator
-    del model.post_process
-
-    return model
diff --git a/test/test_models.py b/test/test_models.py
index 782d2fd5..3dc03415 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -12,7 +12,7 @@
 from yolort.models.backbone_utils import darknet_pan_backbone
 from yolort.models.box_head import YOLOHead, PostProcess, SetCriterion
 from yolort.models.transformer import darknet_tan_backbone
-from yolort.v5 import get_yolov5_size
+from yolort.v5 import get_yolov5_size, attempt_download


 @contextlib.contextmanager
@@ -351,16 +351,11 @@ def test_load_from_yolov5(
     hash_prefix: str,
 ):
     img_path = "test/assets/bus.jpg"
-    checkpoint_path = f"{arch}_{upstream_version}_{hash_prefix}"
     base_url = "https://github.com/ultralytics/yolov5/releases/download/"
     model_url = f"{base_url}/{upstream_version}/{arch}.pt"
+    checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix)

-    torch.hub.download_url_to_file(
-        model_url,
-        checkpoint_path,
-        hash_prefix=hash_prefix,
-    )
     score_thresh = 0.25

     model_yolov5 = YOLOv5.load_from_yolov5(
diff --git a/test/test_relaying.py b/test/test_relaying.py
index 93f4372f..3b8fb90a 100644
--- a/test/test_relaying.py
+++ b/test/test_relaying.py
@@ -1,6 +1,10 @@
+import pytest
+import torch
+from torch import Tensor
 from torch.jit._trace import TopLevelTracedModule
 from yolort.models import yolov5s
-from yolort.relaying import get_trace_module
+from yolort.relaying import get_trace_module, YOLOInference
+from yolort.v5 import attempt_download


 def test_get_trace_module():
@@ -8,3 +12,34 @@ def test_get_trace_module():
     script_module = get_trace_module(model_func, input_shape=(416, 320))
     assert isinstance(script_module, TopLevelTracedModule)
     assert script_module.code is not None
+
+
+@pytest.mark.parametrize(
+    "arch, version, upstream_version, hash_prefix",
+    [
+        ("yolov5s", "r4.0", "v4.0", "9ca9a642"),
+        ("yolov5n", "r6.0", "v6.0", "649e089f"),
+        ("yolov5s", "r6.0", "v6.0", "c3b140f3"),
+        ("yolov5n6", "r6.0", "v6.0", "beecbbae"),
+    ],
+)
+def test_yolo_inference(arch, version, upstream_version, hash_prefix):
+
+    base_url = "https://github.com/ultralytics/yolov5/releases/download/"
+    model_url = f"{base_url}/{upstream_version}/{arch}.pt"
+    checkpoint_path = attempt_download(model_url)
+    score_thresh = 0.25
+
+    model = YOLOInference(
+        checkpoint_path,
+        score_thresh=score_thresh,
+        version=version,
+    )
+    model.eval()
+    samples = torch.rand(1, 3, 320, 320)
+    outs = model(samples)
+
+    assert isinstance(outs[0], dict)
+    assert isinstance(outs[0]["boxes"], Tensor)
+    assert isinstance(outs[0]["labels"], Tensor)
+    assert isinstance(outs[0]["scores"], Tensor)
diff --git a/test/test_utils.py b/test/test_utils.py
index b5eb503e..76c6740b 100644
--- a/test/test_utils.py
+++ b/test/test_utils.py
@@ -18,6 +18,7 @@
     load_yolov5_model,
     scale_coords,
     non_max_suppression,
+    attempt_download,
 )


@@ -36,15 +37,10 @@ def test_load_from_ultralytics(
     hash_prefix: str,
     use_p6: bool,
 ):
-    checkpoint_path = f"{arch}_{upstream_version}_{hash_prefix}"
     base_url = "https://github.com/ultralytics/yolov5/releases/download/"
     model_url = f"{base_url}/{upstream_version}/{arch}.pt"
+    checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix)

-    torch.hub.download_url_to_file(
-        model_url,
-        checkpoint_path,
-        hash_prefix=hash_prefix,
-    )
     model_info = load_from_ultralytics(checkpoint_path, version=version)
     assert isinstance(model_info, dict)
     assert model_info["num_classes"] == 80
@@ -65,16 +61,10 @@ def test_load_from_ultralytics_voc(
     hash_prefix: str,
 ):
     img_path = "test/assets/bus.jpg"
-    checkpoint_path = f"{arch}_{upstream_version}_{hash_prefix}"
     base_url = "https://github.com/ultralytics/yolov5/releases/download/"
     model_url = f"{base_url}/{upstream_version}/{arch}.pt"
-
-    torch.hub.download_url_to_file(
-        model_url,
-        checkpoint_path,
-        hash_prefix=hash_prefix,
-    )
+    checkpoint_path = attempt_download(model_url, hash_prefix=hash_prefix)

     # Preprocess
     img_raw = cv2.imread(img_path)
diff --git a/test/test_v5.py b/test/test_v5.py
index df6a1135..da2aaa43 100644
--- a/test/test_v5.py
+++ b/test/test_v5.py
@@ -1,24 +1,27 @@
 # Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
-from pathlib import Path
+import hashlib

-import torch
 from torch import Tensor
-from yolort.v5 import load_yolov5_model
+from yolort.v5 import load_yolov5_model, attempt_download
+
+
+def test_attempt_download():
+
+    model_url = "https://github.com/ultralytics/yolov5/releases/download/v4.0/yolov5s.pt"
+    checkpoint_path = attempt_download(model_url, hash_prefix="9ca9a642")
+    with open(checkpoint_path, "rb") as f:
+        bytes = f.read()  # read entire file as bytes
+        readable_hash = hashlib.sha256(bytes).hexdigest()
+    assert readable_hash[:8] == "9ca9a642"


 def test_load_yolov5_model():
     img_path = "test/assets/zidane.jpg"

-    yolov5s_r40_path = Path("yolov5s.pt")
-
-    if not yolov5s_r40_path.exists():
-        torch.hub.download_url_to_file(
-            "https://github.com/ultralytics/yolov5/releases/download/v4.0/yolov5s.pt",
-            yolov5s_r40_path,
-            hash_prefix="9ca9a642",
-        )
+    model_url = "https://github.com/ultralytics/yolov5/releases/download/v4.0/yolov5s.pt"
+    checkpoint_path = attempt_download(model_url, hash_prefix="9ca9a642")

-    model = load_yolov5_model(str(yolov5s_r40_path), autoshape=True, verbose=False)
+    model = load_yolov5_model(checkpoint_path, autoshape=True, verbose=False)
     results = model(img_path)

     assert isinstance(results.pred, list)
diff --git a/yolort/models/box_head.py b/yolort/models/box_head.py
index 50442daa..5a9c4570 100644
--- a/yolort/models/box_head.py
+++ b/yolort/models/box_head.py
@@ -196,7 +196,8 @@ def __call__(
         # Classification
         if self.num_classes > 1:  # cls loss (only if multiple classes)
-            t = torch.full_like(pred_logits_subset[:, 5:], self.smooth_neg, device=device)  # targets
+            # targets
+            t = torch.full_like(pred_logits_subset[:, 5:], self.smooth_neg, device=device)
             t[torch.arange(num_targets), target_cls[i]] = self.smooth_pos
             loss_cls += F.binary_cross_entropy_with_logits(
                 pred_logits_subset[:, 5:], t, pos_weight=pos_weight_cls
             )
@@ -316,7 +317,83 @@ def build_targets(
     return target_cls, target_box, indices, anch


-class PostProcess(nn.Module):
+def _concat_pred_logits(head_outputs: List[Tensor]) -> Tensor:
+    # Concat all pred logits
+    batch_size, _, _, _, K = head_outputs[0].shape
+
+    all_pred_logits = []
+    for pred_logits in head_outputs:
+        pred_logits = pred_logits.reshape(batch_size, -1, K)  # Size=(N, HWA, K)
+        all_pred_logits.append(pred_logits)
+
+    all_pred_logits = torch.cat(all_pred_logits, dim=1)
+    return all_pred_logits
+
+
+class LogitsDecoder(nn.Module):
+    """
+    This is a simplified version of PostProcess to remove the ``torchvision::nms`` module.
+
+    Args:
+        score_thresh (float): Score threshold used for postprocessing the detections.
+    """
+
+    def __init__(self, score_thresh: float = 0.25) -> None:
+        super().__init__()
+        self.score_thresh = score_thresh
+
+    def _decode_pred_logits(
+        self,
+        pred_logits: Tensor,
+        idx: int,
+        anchors_tuple: Tuple[Tensor, Tensor, Tensor],
+    ):
+        """
+        Decode the prediction logits, factored out of ``PostProcess``.
+        """
+        pred_logits = torch.sigmoid(pred_logits[idx])
+
+        # Compute conf
+        # box_conf x class_conf, w/ shape: num_anchors x num_classes
+        scores = pred_logits[:, 5:] * pred_logits[:, 4:5]
+
+        boxes = det_utils.decode_single(pred_logits[:, :4], anchors_tuple)
+
+        # remove low scoring boxes
+        inds, labels = torch.where(scores > self.score_thresh)
+        boxes, scores = boxes[inds], scores[inds, labels]
+
+        return scores, labels, boxes
+
+    def forward(
+        self,
+        head_outputs: List[Tensor],
+        anchors_tuple: Tuple[Tensor, Tensor, Tensor],
+    ) -> List[Dict[str, Tensor]]:
+        """
+        Concatenate and decode the prediction logits, skipping the ``torchvision::nms``
+        step used in the original ``yolort.models.box_head.PostProcess``.
+
+        Args:
+            head_outputs (List[Tensor]): The predicted locations and class/object confidence,
+                shape of the element is (N, A, H, W, K).
+            anchors_tuple (Tuple[Tensor, Tensor, Tensor]):
+        """
+        batch_size = len(head_outputs[0])
+
+        all_pred_logits = _concat_pred_logits(head_outputs)
+
+        detections: List[Dict[str, Tensor]] = []
+
+        for idx in range(batch_size):  # image idx, image inference
+            scores, labels, boxes = self._decode_pred_logits(all_pred_logits, idx, anchors_tuple)
+
+            detections.append({"scores": scores, "labels": labels, "boxes": boxes})
+
+        return detections
+
+
+class PostProcess(LogitsDecoder):
     """
     Performs Non-Maximum Suppression (NMS) on inference results
     """
@@ -333,8 +410,7 @@ def __init__(
             nms_thresh (float): NMS threshold used for postprocessing the detections.
             detections_per_img (int): Number of best detections to keep after NMS.
         """
-        super().__init__()
-        self.score_thresh = score_thresh
+        super().__init__(score_thresh=score_thresh)
         self.nms_thresh = nms_thresh
         self.detections_per_img = detections_per_img

@@ -354,33 +430,19 @@ def forward(
                 shape of the element is (N, A, H, W, K).
             anchors_tuple (Tuple[Tensor, Tensor, Tensor]):
         """
-        batch_size, _, _, _, K = head_outputs[0].shape
-
-        all_pred_logits = []
-        for pred_logits in head_outputs:
-            pred_logits = pred_logits.reshape(batch_size, -1, K)  # Size=(N, HWA, K)
-            all_pred_logits.append(pred_logits)
+        batch_size = len(head_outputs[0])

-        all_pred_logits = torch.cat(all_pred_logits, dim=1)
+        all_pred_logits = _concat_pred_logits(head_outputs)

         detections: List[Dict[str, Tensor]] = []

         for idx in range(batch_size):  # image idx, image inference
-            pred_logits = torch.sigmoid(all_pred_logits[idx])
-
-            # Compute conf
-            # box_conf x class_conf, w/ shape: num_anchors x num_classes
-            scores = pred_logits[:, 5:] * pred_logits[:, 4:5]
-
-            boxes = det_utils.decode_single(pred_logits[:, :4], anchors_tuple)
-
-            # remove low scoring boxes
-            inds, labels = torch.where(scores > self.score_thresh)
-            boxes, scores = boxes[inds], scores[inds, labels]
+            # Decode the prediction logits
+            scores, labels, boxes = self._decode_pred_logits(all_pred_logits, idx, anchors_tuple)

             # non-maximum suppression, independently done per level
             keep = box_ops.batched_nms(boxes, scores, labels, self.nms_thresh)
-            # keep only topk scoring head_outputs
+            # Keep only topk scoring head_outputs
             keep = keep[: self.detections_per_img]
             boxes, scores, labels = boxes[keep], scores[keep], labels[keep]
diff --git a/yolort/models/yolo.py b/yolort/models/yolo.py
index 4fe6e357..34ba6ef8 100644
--- a/yolort/models/yolo.py
+++ b/yolort/models/yolo.py
@@ -190,6 +190,7 @@ def load_from_yolov5(
         score_thresh: float = 0.25,
         nms_thresh: float = 0.45,
         version: str = "r6.0",
+        post_process: Optional[nn.Module] = None,
     ):
         """
         Load model state from the checkpoint trained by YOLOv5.
@@ -220,6 +221,7 @@ def load_from_yolov5(
             anchor_grids=model_info["anchor_grids"],
             score_thresh=score_thresh,
             nms_thresh=nms_thresh,
+            post_process=post_process,
         )

         model.load_state_dict(model_info["state_dict"])
diff --git a/yolort/relaying/__init__.py b/yolort/relaying/__init__.py
index 82822225..e051b86b 100644
--- a/yolort/relaying/__init__.py
+++ b/yolort/relaying/__init__.py
@@ -1,2 +1,5 @@
 # Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
 from .trace_wrapper import get_trace_module
+from .yolo_inference import YOLOInference
+
+__all__ = ["get_trace_module", "YOLOInference"]
diff --git a/yolort/relaying/yolo_inference.py b/yolort/relaying/yolo_inference.py
new file mode 100644
index 00000000..8d3b411b
--- /dev/null
+++ b/yolort/relaying/yolo_inference.py
@@ -0,0 +1,42 @@
+# Copyright (c) 2021, Zhiqiang Wang. All Rights Reserved.
+import torch
+from torch import nn, Tensor
+from yolort.models import YOLO
+from yolort.models.box_head import LogitsDecoder
+
+__all__ = ["YOLOInference"]
+
+
+class YOLOInference(nn.Module):
+    """
+    A deployment friendly wrapper of YOLO.
+
+    The ``torchvision::nms`` operator is removed in this wrapper because some third-party
+    inference frameworks currently do not support it well.
+    """
+
+    def __init__(
+        self,
+        checkpoint_path: str,
+        score_thresh: float = 0.25,
+        version: str = "r6.0",
+    ):
+        super().__init__()
+        post_process = LogitsDecoder(score_thresh)
+
+        self.model = YOLO.load_from_yolov5(
+            checkpoint_path,
+            version=version,
+            post_process=post_process,
+        )
+
+    @torch.no_grad()
+    def forward(self, inputs: Tensor):
+        """
+        Args:
+            inputs (Tensor): batched images, of shape [batch_size x 3 x H x W]
+        """
+        # Compute the detections
+        outputs = self.model(inputs)
+
+        return outputs
diff --git a/yolort/v5/utils/downloads.py b/yolort/v5/utils/downloads.py
index 4cb6ab57..0cd04360 100644
--- a/yolort/v5/utils/downloads.py
+++ b/yolort/v5/utils/downloads.py
@@ -1,24 +1,36 @@
-# YOLOv5 by Ultralytics, GPL-3.0 license
+# YOLOv5 🚀 by Ultralytics, GPL-3.0 license
 """
 Download utils
 """

 import os
+import platform
 import subprocess
+import time
 import urllib
 from pathlib import Path
+from zipfile import ZipFile

 import requests
-import torch
+from torch.hub import download_url_to_file


-def safe_download(file, url, url2=None, min_bytes=1e0, error_msg=""):
-    # Attempts to download file from url or url2, checks and removes incomplete downloads < min_bytes
+def gsutil_getsize(url=""):
+    # gs://bucket/file size https://cloud.google.com/storage/docs/gsutil/commands/du
+    s = subprocess.check_output(f"gsutil du {url}", shell=True).decode("utf-8")
+    return eval(s.split(" ")[0]) if len(s) else 0  # bytes
+
+
+def safe_download(file, url, url2=None, min_bytes=1e0, error_msg="", hash_prefix=None):
+    """
+    Attempts to download file from url or url2, checks
+    and removes incomplete downloads < min_bytes
+    """
     file = Path(file)
     assert_msg = f"Downloaded file '{file}' does not exist or size is < min_bytes={min_bytes}"
     try:  # url1
         print(f"Downloading {url} to {file}...")
-        torch.hub.download_url_to_file(url, str(file))
+        download_url_to_file(url, str(file), hash_prefix=hash_prefix)
         assert file.exists() and file.stat().st_size > min_bytes, assert_msg  # check
     except Exception as e:  # url2
         file.unlink(missing_ok=True)  # remove partial downloads
@@ -32,7 +44,7 @@ def safe_download(file, url, url2=None, min_bytes=1e0, error_msg=""):
         print("")


-def attempt_download(file, repo="ultralytics/yolov5"):
+def attempt_download(file, repo="ultralytics/yolov5", hash_prefix=None):
     # Attempt file download if does not exist
     file = Path(str(file).strip().replace("'", ""))

@@ -42,24 +54,25 @@ def attempt_download(file, repo="ultralytics/yolov5"):
         if str(file).startswith(("http:/", "https:/")):  # download
             url = str(file).replace(":/", "://")  # Pathlib turns :// -> :/
             name = name.split("?")[0]  # parse authentication https://url.com/file.txt?auth...
-            safe_download(file=name, url=url, min_bytes=1e5)
+            safe_download(file=name, url=url, min_bytes=1e5, hash_prefix=hash_prefix)
             return name

         # GitHub assets
-        file.parent.mkdir(parents=True, exist_ok=True)  # make parent dir (if required)
+        file.parent.mkdir(parents=True, exist_ok=True)
         try:  # github api
             response = requests.get(f"https://api.github.com/repos/{repo}/releases/latest").json()
-            # release assets, i.e. ['yolov5s.pt', 'yolov5m.pt', ...]
             assets = [x["name"] for x in response["assets"]]
             tag = response["tag_name"]  # i.e. 'v1.0'
-        except requests.exceptions.RequestException as e:  # fallback plan
-            print(str(e))
+        except Exception as e:  # fallback plan
+            print(f"Error when calling GitHub API: {e}")
             assets = [
+                "yolov5n.pt",
                 "yolov5s.pt",
                 "yolov5m.pt",
                 "yolov5l.pt",
                 "yolov5x.pt",
+                "yolov5n6.pt",
                 "yolov5s6.pt",
                 "yolov5m6.pt",
                 "yolov5l6.pt",
@@ -71,14 +84,14 @@ def attempt_download(file, repo="ultralytics/yolov5"):
                 .decode()
                 .split()[-1]
             )
-        except subprocess.CalledProcessError:
-            tag = "v5.0"  # current release
+        except Exception as e:
+            print(f"Error when getting GitHub tag: {e}")
+            tag = "v6.0"  # current release

         if name in assets:
             safe_download(
                 file,
                 url=f"https://github.com/{repo}/releases/download/{tag}/{name}",
-                # url2=f'https://storage.googleapis.com/{repo}/ckpt/{name}',  # backup url (optional)
                 min_bytes=1e5,
                 error_msg=f"{file} missing, try downloading from https://github.com/{repo}/releases/",
             )

     return str(file)

@@ -86,6 +99,42 @@ def attempt_download(file, repo="ultralytics/yolov5"):

+def gdrive_download(id="16TiPfZj7htmTyhntwcZyEEAejOUxuT6m", file="tmp.zip"):
+    # Downloads a file from Google Drive.
+    t = time.time()
+    file = Path(file)
+    cookie = Path("cookie")  # gdrive cookie
+    print(f"Downloading https://drive.google.com/uc?export=download&id={id} as {file}... ", end="")
+    file.unlink(missing_ok=True)  # remove existing file
+    cookie.unlink(missing_ok=True)  # remove existing cookie
+
+    # Attempt file download
+    out = "NUL" if platform.system() == "Windows" else "/dev/null"
+    os.system(f'curl -c ./cookie -s -L "drive.google.com/uc?export=download&id={id}" > {out}')
+    if os.path.exists("cookie"):  # large file
+        s = f"curl -Lb ./cookie 'drive.google.com/uc?export=download&confirm={get_token()}&id={id}'"
+    else:  # small file
+        s = f'curl -s -L "drive.google.com/uc?export=download&id={id}"'
+    download_execute = f"{s} -o {file}"
+    r = os.system(download_execute)
+    cookie.unlink(missing_ok=True)  # remove existing cookie
+
+    # Error check
+    if r != 0:
+        file.unlink(missing_ok=True)  # remove partial
+        print("Download error")  # raise Exception('Download error')
+        return r
+
+    # Unzip if archive
+    if file.suffix == ".zip":
+        print("unzipping... ", end="")
+        ZipFile(file).extractall(path=file.parent)  # unzip
+        file.unlink()  # remove zip
+
+    print(f"Done ({time.time() - t:.1f}s)")
+    return r
+
+
 def get_token(cookie="./cookie"):
     with open(cookie) as f:
         for line in f:
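
Example usage of the new YOLOInference wrapper, as a minimal sketch mirroring the
test_yolo_inference case above. The release URL and hash prefix are taken from the
test's parameter matrix; the input shape is arbitrary.

    import torch
    from yolort.relaying import YOLOInference
    from yolort.v5 import attempt_download

    # Fetch a released checkpoint (the v6.0 yolov5s entry from the test matrix above)
    model_url = "https://github.com/ultralytics/yolov5/releases/download/v6.0/yolov5s.pt"
    checkpoint_path = attempt_download(model_url, hash_prefix="c3b140f3")

    # Build the NMS-free wrapper and run it on a dummy batch
    model = YOLOInference(checkpoint_path, score_thresh=0.25, version="r6.0")
    model.eval()

    samples = torch.rand(1, 3, 320, 320)
    detections = model(samples)  # List[Dict[str, Tensor]] with "boxes", "labels", "scores"

Because LogitsDecoder drops the torchvision::nms step, the returned boxes are only
score-thresholded; a downstream inference framework is expected to apply its own NMS.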