
Commit c3d656d

Refactorization. (#963)
1 parent 1c883d8 commit c3d656d

2 files changed: +19 -10 lines changed

axlearn/common/loss.py

Lines changed: 5 additions & 5 deletions

@@ -29,7 +29,7 @@
 """Loss functions."""
 # pylint: disable=too-many-lines
 import enum
-from typing import Optional
+from typing import Optional, Union
 
 import jax
 import jax.numpy as jnp
@@ -454,7 +454,7 @@ def asymmetric_contrastive_loss_from_logits(
     logits: Tensor,
     *,
     key_paddings: Tensor = None,
-    temperature: float = 1.0,
+    temperature: Union[Tensor, float] = 1.0,
     soft_labels: Optional[Tensor] = None,
 ) -> Tensor:
     """Asymmetric contrastive loss from logits.
@@ -522,7 +522,7 @@ def asymmetric_contrastive_loss_from_features(
     *,
     negative_keys: Tensor = None,
     negative_key_paddings: Tensor = None,
-    temperature: float = 1.0,
+    temperature: Union[Tensor, float] = 1.0,
     soft_labels: Optional[Tensor] = None,
 ):
     """Asymmetric contrastive loss from features.
@@ -578,7 +578,7 @@ def symmetric_contrastive_loss_from_logits(  # pylint: disable=missing-param-doc
     *,
     y_as_key_paddings: Tensor = None,
     x_as_key_paddings: Tensor = None,
-    temperature: float = 1.0,
+    temperature: Union[float, Tensor] = 1.0,
     y_as_key_soft_labels: Optional[Tensor] = None,
     x_as_key_soft_labels: Optional[Tensor] = None,
 ):
@@ -628,7 +628,7 @@ def symmetric_contrastive_loss_from_features(
     y_negatives: Tensor = None,
     x_negative_paddings: Tensor = None,
     y_negative_paddings: Tensor = None,
-    temperature: float = 1.0,
+    temperature: Union[Tensor, float] = 1.0,
     y_as_key_soft_labels: Optional[Tensor] = None,
     x_as_key_soft_labels: Optional[Tensor] = None,
 ):
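
Why the Union[Tensor, float] widening matters: CLIPFusionNetwork.forward (in the clip.py diff below) derives temperature from a learned log_logit_scale, so the value reaching these loss functions is a traced JAX array, not a Python float. A minimal sketch of such a call (not part of this commit; shapes and values are illustrative):

import jax.numpy as jnp

from axlearn.common.loss import symmetric_contrastive_loss_from_logits

# Hypothetical paired features; shapes are illustrative.
x = jnp.ones((4, 8))  # e.g. image embeddings
y = jnp.ones((4, 8))  # e.g. text embeddings
similarity = x @ y.T  # [4, 4] similarity logits

# Mirrors clip.py: temperature comes from a (normally learned) scalar,
# so it is a 0-dim Tensor rather than a Python float.
log_logit_scale = jnp.log(1 / 0.07)
temperature = 1 / jnp.exp(log_logit_scale)

loss = symmetric_contrastive_loss_from_logits(similarity, similarity.T, temperature=temperature)

Annotations are not enforced at runtime, so such calls already worked; the change makes the declared signatures agree with the Tensor-valued temperature that clip.py actually passes.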

axlearn/vision/clip.py

Lines changed: 14 additions & 5 deletions

@@ -12,7 +12,7 @@
 """
 # pylint: disable=duplicate-code
 
-from typing import Optional, Union
+from typing import Optional, Protocol, Union
 
 import jax.numpy as jnp
 import numpy as np
@@ -31,10 +31,12 @@
 from axlearn.common.bert import bert_embedding_config, bert_model_config, bert_transformer_config
 from axlearn.common.config import (
     REQUIRED,
+    ConfigOr,
     FunctionConfigBase,
     InstantiableConfig,
     Required,
     config_class,
+    maybe_instantiate,
 )
 from axlearn.common.embedding import TransformerTextEmbeddings
 from axlearn.common.layers import (
@@ -51,7 +53,7 @@
 from axlearn.common.param_init import ConstantInitializer
 from axlearn.common.poolings import BasePoolingLayer, FirstNTokenPooling, LastNTokenPooling
 from axlearn.common.text_encoder import TEXT_EMBEDDINGS, TextEmbeddingEncoder
-from axlearn.common.utils import NestedTensor
+from axlearn.common.utils import NestedTensor, Tensor
 from axlearn.common.vision_transformer import VisionTransformer, layer_norm_config
 from axlearn.vision.mobilenets import MobileNets
 
@@ -452,6 +454,13 @@ def set_bert_text_encoder_config(
     return text_encoder_cfg
 
 
+class _ContrastiveLossFn(Protocol):
+    def __call__(
+        self, x_y_logits: Tensor, y_x_logits: Tensor, *, temperature: Union[Tensor, float]
+    ) -> Tensor:
+        ...
+
+
 class CLIPFusionNetwork(FusionNetwork):
     """CLIP fusion network. See also CLIPModel."""
 
@@ -465,11 +474,13 @@ class Config(BaseLayer.Config):
             value=np.log(1 / 0.07)
         )
         temperature_max_cap: float = 100
+        contrastive_loss_fn: ConfigOr[_ContrastiveLossFn] = symmetric_contrastive_loss_from_logits
 
     def __init__(self, cfg: Config, *, parent: Optional[Module]):
         super().__init__(cfg, parent=parent)
         cfg = self.config
         self._log_logit_scale_init = cfg.log_logit_scale_init.instantiate()
+        self._contrastive_loss_fn = maybe_instantiate(cfg.contrastive_loss_fn)
 
     def _create_layer_parameter_specs(self) -> dict[str, ParameterSpec]:
         param_specs = {}
@@ -515,9 +526,7 @@ def forward(self, input_batch: NestedTensor) -> NestedTensor:
         log_logit_scale = jnp.clip(log_logit_scale, a_max=jnp.log(cfg.temperature_max_cap))
         temperature = 1 / jnp.exp(log_logit_scale)
         similarity = contrastive_logits(x, y)
-        loss = symmetric_contrastive_loss_from_logits(
-            similarity, similarity.T, temperature=temperature
-        )
+        loss = self._contrastive_loss_fn(similarity, similarity.T, temperature=temperature)
         self.add_summary("temperature", temperature)
         # Show the first 2048 samples. As the data is randomly sampled, this is
         # an approximation of the whole datapoints.
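
Since contrastive_loss_fn is declared as ConfigOr[_ContrastiveLossFn], the field accepts either a config to instantiate or a plain callable, and maybe_instantiate resolves both cases in __init__. A hedged sketch of overriding the default with a custom function that structurally satisfies the _ContrastiveLossFn protocol (the function itself is illustrative, not part of the commit, and it assumes the standard default_config() entry point):

from typing import Union

from axlearn.common.loss import asymmetric_contrastive_loss_from_logits
from axlearn.common.utils import Tensor
from axlearn.vision.clip import CLIPFusionNetwork


def one_directional_loss(
    x_y_logits: Tensor, y_x_logits: Tensor, *, temperature: Union[Tensor, float]
) -> Tensor:
    # Structurally matches _ContrastiveLossFn; this variant keeps only the x->y term.
    del y_x_logits  # unused in this illustrative variant
    return asymmetric_contrastive_loss_from_logits(x_y_logits, temperature=temperature)


cfg = CLIPFusionNetwork.default_config()
# A bare callable is a valid ConfigOr value; maybe_instantiate() returns it as-is.
cfg.contrastive_loss_fn = one_directional_loss

The default value remains symmetric_contrastive_loss_from_logits, so existing configs keep their previous behavior; forward now routes through self._contrastive_loss_fn instead of the hard-coded call.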
