1212"""
1313# pylint: disable=duplicate-code
1414
15- from typing import Optional , Union
15+ from typing import Optional , Protocol , Union
1616
1717import jax .numpy as jnp
1818import numpy as np
3131from axlearn .common .bert import bert_embedding_config , bert_model_config , bert_transformer_config
3232from axlearn .common .config import (
3333 REQUIRED ,
34+ ConfigOr ,
3435 FunctionConfigBase ,
3536 InstantiableConfig ,
3637 Required ,
3738 config_class ,
39+ maybe_instantiate ,
3840)
3941from axlearn .common .embedding import TransformerTextEmbeddings
4042from axlearn .common .layers import (
5153from axlearn .common .param_init import ConstantInitializer
5254from axlearn .common .poolings import BasePoolingLayer , FirstNTokenPooling , LastNTokenPooling
5355from axlearn .common .text_encoder import TEXT_EMBEDDINGS , TextEmbeddingEncoder
54- from axlearn .common .utils import NestedTensor
56+ from axlearn .common .utils import NestedTensor , Tensor
5557from axlearn .common .vision_transformer import VisionTransformer , layer_norm_config
5658from axlearn .vision .mobilenets import MobileNets
5759
@@ -452,6 +454,13 @@ def set_bert_text_encoder_config(
452454 return text_encoder_cfg
453455
454456
457+ class _ContrastiveLossFn (Protocol ):
458+ def __call__ (
459+ self , x_y_logits : Tensor , y_x_logits : Tensor , * , temperature : Union [Tensor , float ]
460+ ) -> Tensor :
461+ ...
462+
463+
455464class CLIPFusionNetwork (FusionNetwork ):
456465 """CLIP fusion network. See also CLIPModel."""
457466
@@ -465,11 +474,13 @@ class Config(BaseLayer.Config):
465474 value = np .log (1 / 0.07 )
466475 )
467476 temperature_max_cap : float = 100
477+ contrastive_loss_fn : ConfigOr [_ContrastiveLossFn ] = symmetric_contrastive_loss_from_logits
468478
469479 def __init__ (self , cfg : Config , * , parent : Optional [Module ]):
470480 super ().__init__ (cfg , parent = parent )
471481 cfg = self .config
472482 self ._log_logit_scale_init = cfg .log_logit_scale_init .instantiate ()
483+ self ._contrastive_loss_fn = maybe_instantiate (cfg .contrastive_loss_fn )
473484
474485 def _create_layer_parameter_specs (self ) -> dict [str , ParameterSpec ]:
475486 param_specs = {}
@@ -515,9 +526,7 @@ def forward(self, input_batch: NestedTensor) -> NestedTensor:
515526 log_logit_scale = jnp .clip (log_logit_scale , a_max = jnp .log (cfg .temperature_max_cap ))
516527 temperature = 1 / jnp .exp (log_logit_scale )
517528 similarity = contrastive_logits (x , y )
518- loss = symmetric_contrastive_loss_from_logits (
519- similarity , similarity .T , temperature = temperature
520- )
529+ loss = self ._contrastive_loss_fn (similarity , similarity .T , temperature = temperature )
521530 self .add_summary ("temperature" , temperature )
522531 # Show the first 2048 samples. As the data is randomly sampled, this is
523532 # an approximation of the whole datapoints.
0 commit comments