diff --git a/CPC/extract_embedding_and_plot.py b/CPC/extract_embedding_and_plot.py index 149e410..f492ff6 100644 --- a/CPC/extract_embedding_and_plot.py +++ b/CPC/extract_embedding_and_plot.py @@ -22,8 +22,8 @@ table_location = data_location / "tabular_data" table_location.mkdir(exist_ok=True) -version = 1 -model_name = "ben_model_04_masked_BF" +version = 0 +model_name = "ben_model_04_masked_BRA" out_tabular_data = table_location / model_name out_tabular_data.mkdir(exist_ok=True) out_tabular_data = out_tabular_data / f"version_{str(version)}" @@ -63,7 +63,7 @@ loading_transforms = trans.Compose([ CropAndReshapeTL(1,0,598,0), - SelectChannel(0,0), + SelectChannel(1,0), CustomToTensor(dtype=torch.float), v2.Resize((576,576)), ])