Skip to content

Commit

Permalink
zarr dataloader mitocondria
Browse files Browse the repository at this point in the history
  • Loading branch information
Alecampoy committed Sep 3, 2024
1 parent 6f05288 commit a24c559
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 12 deletions.
6 changes: 3 additions & 3 deletions scripts/training_loop_VaeResnet18_ac.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,10 @@
import yaml

# Parameters
model_name = "test_linear_ac_latent_128"
run_name= "Linear_dataset_split_17_latent_128"
model_name = "test_linear_ac_latent_128_b5e-6"
run_name= "Linear_dataset_split_17_latent_128_b5e-6"
latent_space_dim = 128
beta = 1e-4
beta = 5e-6
n_epochs = 15
find_port = False #change to false if you already have tensorboard running

Expand Down
25 changes: 25 additions & 0 deletions src/embed_time/launch_tensorboard_ac.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import os
import subprocess
import pandas as pd
import numpy as np
from torch.utils.tensorboard import SummaryWriter

def find_free_port():
import socket

with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("", 0))
return s.getsockname()[1]

# Launch TensorBoard on the browser
def launch_tensorboard(log_dir):
port = find_free_port()
tensorboard_cmd = f"tensorboard --logdir={log_dir} --port={port}"
process = subprocess.Popen(tensorboard_cmd, shell=True)
print(
f"TensorBoard started at http://localhost:{port}. \n"
"If you are using VSCode remote session, forward the port using the PORTS tab next to TERMINAL."
)
return process

tensorboard_process = launch_tensorboard("embed_time_static_runs")
18 changes: 9 additions & 9 deletions src/embed_time/zarr_dataloader_ac.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,16 @@ def __getitem__(self, idx):

# transformation
transform_array = np.max(array, axis=2).squeeze(0)
print(transform_array.shape)
transform_array = scipy.ndimage.zoom(transform_array, zoom=(1, 0.5, 0.5))
print(transform_array.shape)
flip_prob = np.random.rand(1)
# print(transform_array.shape)
# transform_array = scipy.ndimage.zoom(transform_array, zoom=(1, 0.5, 0.5))
# print(transform_array.shape)
# flip_prob = np.random.rand(1)

if flip_prob >0.5:
transform_array = np.flip(transform_array, axis=(1,2)) #(2,2024,2024)
print(transform_array.shape)
transform_array = np.rot90(transform_array, k=np.random.randint(4), axes=(1,2))
print(transform_array.shape)
# if flip_prob >0.5:
# transform_array = np.flip(transform_array, axis=(1,2)) #(2,2024,2024)
# print(transform_array.shape)
# transform_array = np.rot90(transform_array, k=np.random.randint(4), axes=(1,2))
# print(transform_array.shape)
return transform_array

dataset = ZarrDataset("/home/S-ac/embed_time/zarrdata/mitochondria.zarr")
Expand Down

0 comments on commit a24c559

Please sign in to comment.