@@ -577,6 +577,25 @@ def dbpedia_entities_openai_1M(out_fn, n = None):
577
577
578
578
write_output (X_train , X_test , out_fn , "angular" )
579
579
580
def coco(out_fn: str, kind: str):
    """Build a COCO feature dataset and write it to ``out_fn``.

    The features come from the precomputed file
    ``coco-clip-b16-512-features.hdf5``, fetched with ``download`` from the
    ``str-encoders`` v0.1.3 release on GitHub.

    Args:
        out_fn: Output path, passed through to ``write_output``.
        kind: Which query modality to use for the test set:
            ``'i2i'`` -- queries are image features held out from
            ``img_feats``;
            ``'t2i'`` -- queries are text (caption) features taken from
            ``txt_feats``, while the train set is still image features.

    Raises:
        ValueError: If ``kind`` is neither ``'t2i'`` nor ``'i2i'``.

    NOTE(review): ``train_test_split`` is assumed to return a
    ``(train, test)`` pair with ``test_size`` rows in the test part --
    confirm against its definition elsewhere in this file.
    """
    # Explicit raise rather than `assert`: asserts are stripped under
    # `python -O`, which would let an invalid kind fall through as 'i2i'.
    if kind not in ("t2i", "i2i"):
        raise ValueError(f"kind must be 't2i' or 'i2i', got {kind!r}")

    local_fn = "coco-clip-b16-512-features.hdf5"
    url = "https://github.com/fabiocarrara/str-encoders/releases/download/v0.1.3/%s" % local_fn
    download(url, local_fn)

    # Read everything needed while the HDF5 file is open; the datasets
    # cannot be sliced once the file handle has been closed.
    with h5py.File(local_fn, "r") as f:
        img_X = f["img_feats"][:]
        # there are 5 captions per image, take the first one
        txt_X = f["txt_feats"][::5] if kind == "t2i" else None

    X_train, X_test = train_test_split(img_X, test_size=10_000)

    if txt_X is not None:
        # Text-to-image: keep the image train set, but query with captions.
        _, X_test = train_test_split(txt_X, test_size=10_000)

    write_output(X_train, X_test, out_fn, "angular")
598
+
580
599
581
600
DATASETS : Dict [str , Callable [[str ], None ]] = {
582
601
"deep-image-96-angular" : deep_image ,
@@ -606,6 +625,8 @@ def dbpedia_entities_openai_1M(out_fn, n = None):
606
625
"movielens1m-jaccard" : movielens1m ,
607
626
"movielens10m-jaccard" : movielens10m ,
608
627
"movielens20m-jaccard" : movielens20m ,
628
+ "coco-i2i-512-angular" : lambda out_fn : coco (out_fn , "i2i" ),
629
+ "coco-t2i-512-angular" : lambda out_fn : coco (out_fn , "t2i" ),
609
630
}
610
631
611
632
DATASETS .update ({
0 commit comments