Refactors clip.py to make the contrastive loss configurable. #963

Open. Wants to merge 1 commit into base: main.
axlearn/common/loss.py (10 changes: 5 additions & 5 deletions)
@@ -29,7 +29,7 @@
 """Loss functions."""
 # pylint: disable=too-many-lines
 import enum
-from typing import Optional
+from typing import Optional, Union
 
 import jax
 import jax.numpy as jnp
@@ -454,7 +454,7 @@ def asymmetric_contrastive_loss_from_logits(
     logits: Tensor,
     *,
     key_paddings: Tensor = None,
-    temperature: float = 1.0,
+    temperature: Union[Tensor, float] = 1.0,
     soft_labels: Optional[Tensor] = None,
 ) -> Tensor:
     """Asymmetric contrastive loss from logits.
@@ -522,7 +522,7 @@ def asymmetric_contrastive_loss_from_features(
     *,
     negative_keys: Tensor = None,
     negative_key_paddings: Tensor = None,
-    temperature: float = 1.0,
+    temperature: Union[Tensor, float] = 1.0,
     soft_labels: Optional[Tensor] = None,
 ):
     """Asymmetric contrastive loss from features.
@@ -578,7 +578,7 @@ def symmetric_contrastive_loss_from_logits(  # pylint: disable=missing-param-doc
     *,
     y_as_key_paddings: Tensor = None,
     x_as_key_paddings: Tensor = None,
-    temperature: float = 1.0,
+    temperature: Union[float, Tensor] = 1.0,
     y_as_key_soft_labels: Optional[Tensor] = None,
     x_as_key_soft_labels: Optional[Tensor] = None,
 ):
@@ -628,7 +628,7 @@ def symmetric_contrastive_loss_from_features(
     y_negatives: Tensor = None,
     x_negative_paddings: Tensor = None,
     y_negative_paddings: Tensor = None,
-    temperature: float = 1.0,
+    temperature: Union[Tensor, float] = 1.0,
     y_as_key_soft_labels: Optional[Tensor] = None,
     x_as_key_soft_labels: Optional[Tensor] = None,
 ):
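With `temperature` widened to `Union[Tensor, float]`, a traced JAX scalar (for example one derived from a learned `log_logit_scale`) can be passed directly instead of a Python float. A minimal sketch of such a call; the shapes and values are illustrative, not from the PR:

import jax.numpy as jnp

from axlearn.common.loss import symmetric_contrastive_loss_from_logits

# [batch, batch] image-to-text similarity logits; values are illustrative.
similarity = jnp.array([[5.0, 1.0], [0.5, 4.0]])
# Tensor-valued temperature, e.g. 1 / exp(log_logit_scale) as in clip.py.
temperature = jnp.asarray(0.07)
loss = symmetric_contrastive_loss_from_logits(
    similarity, similarity.T, temperature=temperature
)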
axlearn/vision/clip.py (19 changes: 14 additions & 5 deletions)
@@ -12,7 +12,7 @@
 """
 # pylint: disable=duplicate-code
 
-from typing import Optional, Union
+from typing import Optional, Protocol, Union
 
 import jax.numpy as jnp
 import numpy as np
@@ -31,10 +31,12 @@
 from axlearn.common.bert import bert_embedding_config, bert_model_config, bert_transformer_config
 from axlearn.common.config import (
     REQUIRED,
+    ConfigOr,
     FunctionConfigBase,
     InstantiableConfig,
     Required,
     config_class,
+    maybe_instantiate,
 )
 from axlearn.common.embedding import TransformerTextEmbeddings
 from axlearn.common.layers import (
@@ -51,7 +53,7 @@
 from axlearn.common.param_init import ConstantInitializer
 from axlearn.common.poolings import BasePoolingLayer, FirstNTokenPooling, LastNTokenPooling
 from axlearn.common.text_encoder import TEXT_EMBEDDINGS, TextEmbeddingEncoder
-from axlearn.common.utils import NestedTensor
+from axlearn.common.utils import NestedTensor, Tensor
 from axlearn.common.vision_transformer import VisionTransformer, layer_norm_config
 from axlearn.vision.mobilenets import MobileNets
 
@@ -452,6 +454,13 @@ def set_bert_text_encoder_config(
     return text_encoder_cfg
 
 
+class _ContrastiveLossFn(Protocol):
+    def __call__(
+        self, x_y_logits: Tensor, y_x_logits: Tensor, *, temperature: Union[Tensor, float]
+    ) -> Tensor:
+        ...
+
+
 class CLIPFusionNetwork(FusionNetwork):
     """CLIP fusion network. See also CLIPModel."""
 
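Any callable matching this protocol can now be plugged in. As a hypothetical illustration (the function and its 0.7/0.3 weighting are invented for this sketch, not part of the PR), a custom loss conforming to `_ContrastiveLossFn`:

from typing import Union

from axlearn.common.loss import asymmetric_contrastive_loss_from_logits
from axlearn.common.utils import Tensor

def weighted_contrastive_loss(
    x_y_logits: Tensor, y_x_logits: Tensor, *, temperature: Union[Tensor, float]
) -> Tensor:
    # Weight the two contrastive directions unequally; the split is made up.
    x_y = asymmetric_contrastive_loss_from_logits(x_y_logits, temperature=temperature)
    y_x = asymmetric_contrastive_loss_from_logits(y_x_logits, temperature=temperature)
    return 0.7 * x_y + 0.3 * y_x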
@@ -465,11 +474,13 @@ class Config(BaseLayer.Config):
             value=np.log(1 / 0.07)
         )
         temperature_max_cap: float = 100
+        contrastive_loss_fn: ConfigOr[_ContrastiveLossFn] = symmetric_contrastive_loss_from_logits
 
     def __init__(self, cfg: Config, *, parent: Optional[Module]):
         super().__init__(cfg, parent=parent)
         cfg = self.config
         self._log_logit_scale_init = cfg.log_logit_scale_init.instantiate()
+        self._contrastive_loss_fn = maybe_instantiate(cfg.contrastive_loss_fn)
 
     def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]:
         param_specs = {}
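Because the field is a `ConfigOr`, either a bare callable or a config may be set; `maybe_instantiate` passes callables through unchanged and instantiates configs. A sketch of overriding the loss, assuming the hypothetical `weighted_contrastive_loss` above:

from axlearn.vision.clip import CLIPFusionNetwork

fusion_cfg = CLIPFusionNetwork.default_config().set(name="clip_fusion")
# A bare callable is returned unchanged by maybe_instantiate() in __init__;
# a config set here would be instantiated instead.
fusion_cfg.contrastive_loss_fn = weighted_contrastive_loss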
@@ -515,9 +526,7 @@ def forward(self, input_batch: NestedTensor) -> NestedTensor:
         log_logit_scale = jnp.clip(log_logit_scale, a_max=jnp.log(cfg.temperature_max_cap))
         temperature = 1 / jnp.exp(log_logit_scale)
         similarity = contrastive_logits(x, y)
-        loss = symmetric_contrastive_loss_from_logits(
-            similarity, similarity.T, temperature=temperature
-        )
+        loss = self._contrastive_loss_fn(similarity, similarity.T, temperature=temperature)
         self.add_summary("temperature", temperature)
         # Show the first 2048 samples. As the data is randomly sampled, this is
         # an approximation of the whole datapoints.
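With the default config this dispatch should be behavior-preserving, since `contrastive_loss_fn` defaults to `symmetric_contrastive_loss_from_logits` itself. A quick sanity check, as an untested sketch:

from axlearn.common.loss import symmetric_contrastive_loss_from_logits
from axlearn.vision.clip import CLIPFusionNetwork

cfg = CLIPFusionNetwork.default_config()
# The field default is the function itself, so identity should hold.
assert cfg.contrastive_loss_fn is symmetric_contrastive_loss_from_logits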