Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

local checkpoint error #963

Open
Lee-ray-a opened this issue Dec 27, 2023 · 0 comments
Open

local checkpoint error #963

Lee-ray-a opened this issue Dec 27, 2023 · 0 comments

Comments

@Lee-ray-a
Copy link

trying to do inference

  • this is my infer_Test.py
import os

import jax
from matplotlib import pyplot as plt
import numpy as np
from scenic.projects.owl_vit import configs
from scenic.projects.owl_vit import models
from scipy.special import expit as sigmoid
import skimage
from skimage import io as skimage_io
from skimage import transform as skimage_transform
import tensorflow as tf

devices = jax.devices('gpu')[0]

'''Choose config''' 
config = configs.owl_v2_clip_b16.get_config(init_mode='canonical_checkpoint')

'''Load the model and variables'''
module = models.TextZeroShotDetectionModule(
    body_configs=config.model.body,
    objectness_head_configs=config.model.objectness_head,
    normalize=config.model.normalize,
    box_bias=config.model.box_bias)
variables = module.load_variables(config.init_from.checkpoint_path)

'''Prepare image'''
# Load example image:

image_uint8 = skimage_io.imread('/projects/TianchiCup/test_images/picture/ele_0cfcef679f3c1880f69ab9bdf596ee5f.jpg')
image = image_uint8.astype(np.float32) / 255.0

# Pad to square with gray pixels on bottom and right:
h, w, _ = image.shape
size = max(h, w)
image_padded = np.pad(
    image, ((0, size - h), (0, size - w), (0, 0)), constant_values=0.5)

# Resize to model input size:
input_image = skimage.transform.resize(
    image_padded,
    (config.dataset_configs.input_size, config.dataset_configs.input_size),
    anti_aliasing=True)

'''Prepare text queries'''
text_queries = ['people smoking','shirtless','mouse','cat','dog']
tokenized_queries = np.array([
    module.tokenize(q, config.dataset_configs.max_query_length)
    for q in text_queries
])

# Pad tokenized queries to avoid recompilation if number of queries changes:
tokenized_queries = np.pad(
    tokenized_queries,
    pad_width=((0, 100 - len(text_queries)), (0, 0)),
    constant_values=0)

'''Get predictions'''
jitted = jax.jit(module.apply, static_argnames=('train',))
# Note: The model expects a batch dimension.
predictions = jitted(
    variables,
    input_image[None, ...],
    tokenized_queries[None, ...],
    train=False)

# Remove batch dimension and convert to numpy:
predictions = jax.tree_util.tree_map(lambda x: np.array(x[0]), predictions )

'''Plot predictions'''
score_threshold = 0.2

logits = predictions['pred_logits'][..., :len(text_queries)]  # Remove padding.
scores = sigmoid(np.max(logits, axis=-1))
labels = np.argmax(predictions['pred_logits'], axis=-1)
boxes = predictions['pred_boxes']

# fig, ax = plt.subplots(1, 1, figsize=(8, 8))
# ax.imshow(input_image, extent=(0, 1, 1, 0))
# ax.set_axis_off()

for score, box, label in zip(scores, boxes, labels):
  if score < score_threshold:
    continue
  cx, cy, w, h = box
  rr, cc = skimage.draw.polygon_perimeter([cy - h / 2, cy + h / 2, cy + h / 2, cy - h / 2, cy - h / 2],
                                 [cx - w / 2, cx - w / 2, cx + w / 2, cx + w / 2, cx - w / 2],
                                 shape=image.shape, clip=True)
  input_image[rr, cc] = [255, 0, 0]  
  skimage_io.imsave('output.jpg', input_image)

this is the local jax checkpoint I want to load

  • owl_v2_clip_b16.py
# pylint: disable=line-too-long
r"""OWL v2 CLIP B/16 config."""
import ml_collections


CHECKPOINTS = {
    # https://arxiv.org/abs/2306.09683 Table 1 row 11:
    'owl2-b16-960-st-ngrams': 'gs://scenic-bucket/owl_vit/checkpoints/owl2-b16-960-st-ngrams_c7e1b9a',
    # https://arxiv.org/abs/2306.09683 Table 1 row 14:
    'owl2-b16-960-st-ngrams-ft-lvisbase': 'gs://scenic-bucket/owl_vit/checkpoints/owl2-b16-960-st-ngrams-ft-lvisbase_d368398',
    # https://arxiv.org/abs/2306.09683 Figure 5 weight ensemble:

      '''I add local path to this place'''
    'owl2-b16-960-st-ngrams-curated-ft-lvisbase-ens-cold-weight-05': '/projects/TianchiCup/scenic/owl2-b16-960-st/owl2-b16-960-st-ngrams-curated-ft-lvisbase-ens-cold-weight-05',
}

CHECKPOINTS['canonical_checkpoint'] = CHECKPOINTS[
    'owl2-b16-960-st-ngrams-curated-ft-lvisbase-ens-cold-weight-05'
]


def get_config(init_mode='canonical_checkpoint'):
  """Returns the configuration for text-query-based detection using OWL-ViT."""
  config = ml_collections.ConfigDict()
  config.experiment_name = 'owl_vit_detection'

  # Dataset.
  config.dataset_name = 'owl_vit'
  config.dataset_configs = ml_collections.ConfigDict()
  config.dataset_configs.input_size = 960
  config.dataset_configs.input_range = None
  config.dataset_configs.max_query_length = 16

  # Model.
  config.model_name = 'text_zero_shot_detection'

  config.model = ml_collections.ConfigDict()
  config.model.normalize = True

  config.model.body = ml_collections.ConfigDict()
  config.model.body.type = 'clip'
  config.model.body.variant = 'vit_b16'
  config.model.body.merge_class_token = 'mul-ln'
  config.model.box_bias = 'both'

  # Objectness head.
  config.model.objectness_head = ml_collections.ConfigDict()
  config.model.objectness_head.stop_gradient = True

  # Init.
  config.init_from = ml_collections.ConfigDict()
  checkpoint_path = CHECKPOINTS.get(init_mode, None)
  print('checkpoint_path: ',checkpoint_path)
  if checkpoint_path is None:
    raise ValueError('Unknown init_mode: {}'.format(init_mode))
  config.init_from.checkpoint_path = checkpoint_path

  return config

the output error is

image

Where did I do wrong? please give some insight. Thank you

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant