vmap does not support Tensor.clone() #1143

Open
JustinS6626 opened this issue Apr 10, 2024 · 0 comments

@JustinS6626

Hi all,

First, a quick note: you will need PennyLane installed to run this toy example. Here is the code I am trying to execute:

import time
import os
import copy
from copy import deepcopy
# PyTorch
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision
from torchvision import datasets, transforms
from torch.func import functional_call
from torch.func import stack_module_state
from torch import vmap

# Pennylane
import pennylane as qml
from pennylane import numpy as np

torch.manual_seed(42)
#np.random.seed(42)
rng = np.random.default_rng(12345)

# Plotting
import matplotlib.pyplot as plt

n_qubits = 4                # Number of qubits
step = 0.0004               # Learning rate
batch_size = 4              # Number of samples for each training step
num_epochs = 3              # Number of training epochs
q_depth = 6                 # Depth of the quantum circuit (number of variational layers)
gamma_lr_scheduler = 0.1    # Learning rate reduction applied every 10 epochs.
q_delta = 0.01              # Initial spread of random quantum weights
start_time = time.time()    # Start of the computation timer
num_models = 10

#dev = qml.device("default.mixed", wires=n_qubits)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

data_transforms = {
    "train": transforms.Compose(
        [
            # transforms.RandomResizedCrop(224),     # uncomment for data augmentation
            # transforms.RandomHorizontalFlip(),     # uncomment for data augmentation
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            # Normalize input channels using mean values and standard deviations of ImageNet.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
    "val": transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    ),
}

data_dir = "hymenoptera_data"
image_datasets = {
    x if x == "train" else "validation": datasets.ImageFolder(
        os.path.join(data_dir, x), data_transforms[x]
    )
    for x in ["train", "val"]
}
dataset_sizes = {x: len(image_datasets[x]) for x in ["train", "validation"]}
class_names = image_datasets["train"].classes

# Initialize dataloader
dataloaders = {
    x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size, shuffle=True)
    for x in ["train", "validation"]
}

inputs, classes = next(iter(dataloaders["validation"]))

def H_layer(nqubits):
    """Layer of single-qubit Hadamard gates.
    """
    for idx in range(nqubits):
        qml.Hadamard(wires=idx)


def RY_layer(w):
    """Layer of parametrized qubit rotations around the y axis.
    """
    for idx, element in enumerate(w):
        qml.RY(element, wires=idx)


def entangling_layer(nqubits):
    """Layer of CNOTs followed by another shifted layer of CNOTs.
    """
    # In other words, it should apply something like:
    # CNOT  CNOT  CNOT  CNOT...  CNOT
    #   CNOT  CNOT  CNOT...  CNOT
    for i in range(0, nqubits - 1, 2):  # Loop over even indices: i=0,2,...N-2
        qml.CNOT(wires=[i, i + 1])
    for i in range(1, nqubits - 1, 2):  # Loop over odd indices:  i=1,3,...N-3
        qml.CNOT(wires=[i, i + 1])


#@qml.qnode(dev)
def quantum_net(q_input_features, q_weights_flat, noise_probs):
    """
    The variational quantum circuit.
    """
    #noise_probs = np.random.uniform(size=n_qubits)
    #noise_probs = rng.integers(2, size=n_qubits)
    print("Noise probabilities:")
    print(noise_probs)
    # Reshape weights
    q_weights = q_weights_flat.reshape(q_depth, n_qubits)

    # Start from state |+> , unbiased w.r.t. |0> and |1>
    H_layer(n_qubits)

    # Embed features in the quantum node
    RY_layer(q_input_features)
    for i in range(n_qubits):
        qml.BitFlip(noise_probs[i], wires=i)
        print("bit flip done")
    # Sequence of trainable variational layers
    for k in range(q_depth):
        entangling_layer(n_qubits)
        RY_layer(q_weights[k])

    # Expectation values in the Z basis
    exp_vals = [qml.expval(qml.PauliZ(position)) for position in range(n_qubits)]
    print("Exp vals calculated")
    return tuple(exp_vals)


class DressedQuantumNet(nn.Module):
    """
    Torch module implementing the *dressed* quantum net.
    """

    def __init__(self):
        """
        Definition of the *dressed* layout.
        """

        super().__init__()
        weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1
        self.premodel = torchvision.models.resnet18(weights=weights)
        self.premodel.fc = nn.Linear(512, n_qubits)
        self.q_params = nn.Parameter(q_delta * torch.randn(q_depth * n_qubits))
        self.post_net = nn.Linear(n_qubits, 2)


    def forward(self, input_features, bitflips):
        """
        Defining how tensors are supposed to move through the *dressed* quantum
        net.
        """

        # obtain the input features for the quantum circuit
        # by reducing the feature dimension from 512 to
        dev = qml.device("default.mixed", wires=n_qubits)
        node = qml.QNode(quantum_net, dev, interface="torch", diff_method="backprop")
        flips = bitflips.clone().detach().cpu().numpy()
        pre_out = self.premodel(input_features)
        q_in = torch.tanh(pre_out) * np.pi / 2.0

        # Apply the quantum circuit to each element of the batch and append to q_out
        q_out = torch.empty(0, n_qubits, device=device)
        for elem in q_in:
            q_out_elem = torch.hstack(node(elem, self.q_params, flips)).float().unsqueeze(0)
            q_out = torch.cat((q_out, q_out_elem))
            print("Q out calculated")

        # return the two-dimensional prediction from the postprocessing layer
        return self.post_net(q_out)


##weights = torchvision.models.ResNet18_Weights.IMAGENET1K_V1
##model_hybrid = torchvision.models.resnet18(weights=weights)



# DressedQuantumNet replaces ResNet18's final fc layer (premodel.fc) with a 512 -> n_qubits layer
model_hybrid = DressedQuantumNet()

for param in model_hybrid.parameters():
    param.requires_grad = False

# Use CUDA or CPU according to the "device" object.
model_hybrid = model_hybrid.to(device)

criterion = nn.CrossEntropyLoss()
optimizer_hybrid = optim.Adam(model_hybrid.parameters(), lr=step)

exp_lr_scheduler = lr_scheduler.StepLR(
    optimizer_hybrid, step_size=10, gamma=gamma_lr_scheduler
)

def train_model(model, criterion, optimizer, scheduler, num_epochs):
    
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    best_loss = 10000.0  # Large arbitrary number
    best_acc_train = 0.0
    best_loss_train = 10000.0  # Large arbitrary number
    print("Training started:")

    for epoch in range(num_epochs):
        models = [deepcopy(model) for _ in range(num_models)]
        params, buffers = stack_module_state(models)
        base_model = models[0]
        base_model = base_model.to('meta')
        def fmodel(params, buffers, x, flips):
            return functional_call(base_model, (params, buffers), (x, flips))
        # Each epoch has a training and validation phase
        for phase in ["train", "validation"]:
            if phase == "train":
                # Set model to training mode
                model.train()
            else:
                # Set model to evaluate mode
                model.eval()
            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            n_batches = dataset_sizes[phase] // batch_size
            it = 0
            for inputs, labels in dataloaders[phase]:
                since_batch = time.time()
                batch_size_ = len(inputs)
                inputs = inputs.to(device)
                labels = labels.to(device)
                all_inputs = torch.stack([inputs.clone() for i in range(num_models)])
                all_labels = torch.stack([labels.clone() for i in range(num_models)])
                bitflips = torch.randint(2, (num_models, n_qubits)).to(device)
                optimizer.zero_grad()

                # Track/compute gradient and make an optimization step only when training
                with torch.set_grad_enabled(phase == "train"):
                    #outputs = model(inputs)
                    outputs = vmap(fmodel)(params, buffers, all_inputs, bitflips)
                    print(outputs)
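                    # outputs has shape (num_models, batch, 2): take the argmax over classes
                    _, preds = torch.max(outputs, -1)
                    # Flatten model and batch dimensions so CrossEntropyLoss sees (N, C) vs (N,)
                    loss = criterion(outputs.reshape(-1, outputs.shape[-1]), all_labels.reshape(-1))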
                    if phase == "train":
                        loss.backward()
                        optimizer.step()

                # Print iteration results
                running_loss += loss.item() * batch_size_
                # Average the number of correct predictions over the ensemble members
                batch_corrects = torch.sum(preds == all_labels).item() / num_models
                running_corrects += batch_corrects
                print(
                    "Phase: {} Epoch: {}/{} Iter: {}/{} Batch time: {:.4f}".format(
                        phase,
                        epoch + 1,
                        num_epochs,
                        it + 1,
                        n_batches + 1,
                        time.time() - since_batch,
                    ),
                    end="\r",
                    flush=True,
                )
                it += 1

            # Print epoch results
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]
            print(
                "Phase: {} Epoch: {}/{} Loss: {:.4f} Acc: {:.4f}        ".format(
                    "train" if phase == "train" else "validation  ",
                    epoch + 1,
                    num_epochs,
                    epoch_loss,
                    epoch_acc,
                )
            )

            # Check if this is the best model wrt previous epochs
            if phase == "validation" and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == "validation" and epoch_loss < best_loss:
                best_loss = epoch_loss
            if phase == "train" and epoch_acc > best_acc_train:
                best_acc_train = epoch_acc
            if phase == "train" and epoch_loss < best_loss_train:
                best_loss_train = epoch_loss

            # Update learning rate
            if phase == "train":
                scheduler.step()

    # Print final results
    model.load_state_dict(best_model_wts)
    time_elapsed = time.time() - since
    print(
        "Training completed in {:.0f}m {:.0f}s".format(time_elapsed // 60, time_elapsed % 60)
    )
    print("Best test loss: {:.4f} | Best test accuracy: {:.4f}".format(best_loss, best_acc))
    return model


model_hybrid = train_model(
    model_hybrid, criterion, optimizer_hybrid, exp_lr_scheduler, num_epochs=num_epochs
)
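
For comparison, the ensembling pattern that `train_model` relies on (stack the per-model state with `stack_module_state`, keep a template copy on the `meta` device, and `vmap` over `functional_call`) can be reduced to a minimal sketch with no PennyLane or ResNet18. It assumes PyTorch 2.x with `torch.func`; `TinyNet`, its sizes, and the sign-flip stand-in for the bit-flip noise are hypothetical. This reduced version is not the code that produces the error below: in this sketch every operation in `forward` stays a torch operation, so vmap can batch it.

```python
import copy

import torch
import torch.nn as nn
from torch.func import functional_call, stack_module_state, vmap


class TinyNet(nn.Module):
    """Hypothetical stand-in for DressedQuantumNet: two linear layers, no PennyLane."""

    def __init__(self, n_in: int = 8, n_hidden: int = 4, n_out: int = 2):
        super().__init__()
        self.pre = nn.Linear(n_in, n_hidden)
        self.post = nn.Linear(n_hidden, n_out)

    def forward(self, x, flips):
        h = torch.tanh(self.pre(x))
        # Hypothetical noise stand-in: negate hidden units where flips == 1,
        # keeping the per-model flips as a tensor (no .numpy() round trip).
        h = torch.where(flips.bool(), -h, h)
        return self.post(h)


num_models, batch_size, n_hidden = 10, 4, 4
models = [TinyNet(n_hidden=n_hidden) for _ in range(num_models)]

# One stacked tensor per parameter/buffer, with a leading num_models dimension.
params, buffers = stack_module_state(models)

# Template copy on the "meta" device: only its structure is used by functional_call.
base_model = copy.deepcopy(models[0]).to("meta")


def fmodel(params, buffers, x, flips):
    return functional_call(base_model, (params, buffers), (x, flips))


x = torch.randn(batch_size, 8)
all_inputs = x.expand(num_models, -1, -1)  # same batch for every model
bitflips = torch.randint(2, (num_models, n_hidden))  # one flip mask per model

outputs = vmap(fmodel)(params, buffers, all_inputs, bitflips)
print(outputs.shape)  # torch.Size([10, 4, 2])
```

Each row of `bitflips` reaches its own model as a tensor, which is what lets `torch.where` consume it inside the vmapped call.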

When I run the code, I get the following error:

Traceback (most recent call last):
  File "/usr/lib/python3.9/idlelib/run.py", line 559, in runcode
    exec(code, self.locals)
  File "/home/justin/RandomEnsembleToyExample.py", line 311, in <module>
    model_hybrid = train_model(
  File "/home/justin/RandomEnsembleToyExample.py", line 247, in train_model
    outputs = vmap(fmodel)(params, buffers, all_inputs, bitflips)
  File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/apis.py", line 188, in wrapped
    return vmap_impl(func, in_dims, out_dims, randomness, chunk_size, *args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/vmap.py", line 278, in vmap_impl
    return _flat_vmap(
  File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/vmap.py", line 44, in fn
    return f(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/vmap.py", line 391, in _flat_vmap
    batched_outputs = func(*batched_inputs, **kwargs)
  File "/home/justin/RandomEnsembleToyExample.py", line 219, in fmodel
    return functional_call(base_model, (params, buffers), (x, flips))
  File "/usr/local/lib/python3.9/dist-packages/torch/_functorch/functional_call.py", line 143, in functional_call
    return nn.utils.stateless._functional_call(
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/utils/stateless.py", line 263, in _functional_call
    return module(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/justin/RandomEnsembleToyExample.py", line 166, in forward
    flips = bitflips.clone().detach().cpu().numpy()
RuntimeError: Cannot access data pointer of Tensor that doesn't have storage

Is this because Tensor.clone() is simply not supported under vmap, or is there a technique I can use to make it work? If it's the latter, any help would be much appreciated. If it's the former, is there an analogous approach that would give me the same result?
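
The failing line chains four calls (`clone()`, `detach()`, `cpu()`, `numpy()`), and the traceback reports only the line, not the call. A minimal sketch that runs them separately under `torch.func.vmap` may help narrow down which one raises. It assumes PyTorch 2.x and a CPU tensor shaped like `bitflips` in `train_model`; the expectation encoded here (that the torch-only calls are batched while the NumPy conversion fails) is an assumption to check against the installed version.

```python
import torch
from torch.func import vmap

# Shaped like bitflips in train_model: (num_models, n_qubits).
bitflips = torch.randint(2, (10, 4))

# clone(), detach() and cpu() are all torch operations, so vmap can batch them.
cloned = vmap(lambda f: f.clone().detach().cpu())(bitflips)
print(torch.equal(cloned, bitflips))

# .numpy() hands NumPy a raw data pointer. Inside vmap, each f is a batched
# wrapper without its own storage (assumed here to be the source of
# "Cannot access data pointer of Tensor that doesn't have storage").
try:
    vmap(lambda f: torch.as_tensor(f.numpy()))(bitflips)
    numpy_ok = True
except RuntimeError as err:
    numpy_ok = False
    print(f".numpy() under vmap failed: {err}")
```

If this reproduces on your setup, the conversion to NumPy (needed to pass `flips` into the QNode as a NumPy array) is the step to look at, rather than `clone()` itself.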

JustinS6626 changed the title from "vmap does not support Tensor.clone()" to "vmap does not support Tensor.clone()[high priority]" on Apr 10, 2024
JustinS6626 changed the title from "vmap does not support Tensor.clone()[high priority]" back to "vmap does not support Tensor.clone()" on Apr 10, 2024