Exporting MultiInputActorCriticPolicy as ONNX #1873
Hello, please provide a minimal and working code example (see the link in the issue template for what that means).
Hello, thanks for your response. I have tried a couple of things so far. First I tried converting my model into an onnxable policy using the method shown in the documentation. My code is as follows:

```python
class OnnxablePolicy(th.nn.Module):
    def __init__(self, policy):
        super(OnnxablePolicy, self).__init__()
        self.policy = policy

    def forward(self, input):
        return self.policy(input)


model = PPO.load("Models/ppo.zip")
onnx_policy = OnnxablePolicy(model.policy)

th.onnx.export(
    onnx_policy,
    obs_dict,
    "ONNX/ppo_model.onnx",
    opset_version=17,
    input_names=["input"],
)
```

To get the dummy input, which I am here calling `obs_dict`, I used the following code snippet:

```python
obs = env.reset()
obs_dict = {}
for key in obs.keys():
    obs_dict[key] = th.from_numpy(np.array([obs[key]])).float()
```

This creates an input with the same structure as the observation space.

Afterwards I also tried the approach seen here, and created the following code:

```python
class OnnxablePolicy(th.nn.Module):
    def __init__(self, extractor, action_net, value_net):
        super(OnnxablePolicy, self).__init__()
        self.extractor = extractor
        self.action_net = action_net
        self.value_net = value_net

    def forward(self, input):
        action_hidden = value_hidden = self.extractor(input)
        return self.action_net(action_hidden), self.value_net(value_hidden)


onnx_policy = OnnxablePolicy(
    model.policy.features_extractor, model.policy.action_net, model.policy.value_net
)

th.onnx.export(
    onnx_policy,
    obs_dict,
    "ONNX/ppo_model.onnx",
    opset_version=17,
    input_names=["input"],
)
```

which resulted in the same error as before. Finally I tried using the policy as is:

```python
model = PPO.load("Models/ppo.zip")
obs = env.reset()

th.onnx.export(
    model.policy,
    obs,
    "ONNX/ppo_model.onnx",
    opset_version=17,
    input_names=["input"],
)
```

This seemingly got me the furthest, producing the new error:
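One detail worth noting about the `th.onnx.export` calls above (a general property of PyTorch's exporter, not something stated in the thread): when `args` ends with a dict, the exporter treats that dict as keyword arguments for `forward` rather than as an input. A dict observation therefore needs to be passed as a positional argument followed by an empty dict, for example (a minimal sketch reusing `onnx_policy` and `obs_dict` from the snippets above):

```python
# Sketch: pass the dict observation positionally and append an empty dict so
# the exporter does not interpret obs_dict as keyword arguments for forward().
th.onnx.export(
    onnx_policy,
    args=(obs_dict, {}),
    f="ONNX/ppo_model.onnx",
    opset_version=17,
    input_names=["input"],
)
```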
I gave it a try, but this one seems to be a bit hard; you probably need to use the experimental ONNX export from PyTorch (using dynamo). My current attempt (the export seems to work but the loading doesn't :/):

```python
import torch as th
from typing import Tuple

import gymnasium as gym
from stable_baselines3 import PPO
from stable_baselines3.common.policies import BasePolicy
import onnx
import onnxruntime as ort
import numpy as np


class OnnxableSB3Policy(th.nn.Module):
    def __init__(self, policy: BasePolicy):
        super().__init__()
        self.policy = policy

    def forward(self, observation):
        # Debug: inspect the traced input and short-circuit with one of its entries
        print(observation)
        return observation["a"]
        # NOTE: Preprocessing is included, but postprocessing
        # (clipping/unscaling actions) is not,
        # If needed, you also need to transpose the images so that they are channel first
        # use deterministic=False if you want to export the stochastic policy
        return self.policy._predict(observation, deterministic=True)


class Custom(gym.Env):
    def __init__(self):
        super().__init__()
        self.observation_space = gym.spaces.Dict(
            {
                "a": gym.spaces.Box(low=-1, high=1, shape=(3,), dtype=np.float32),
                # "b": gym.spaces.Discrete(5),
            }
        )
        self.action_space = gym.spaces.Discrete(2)

    def reset(self, seed=None):
        return self.observation_space.sample(), {}

    def step(self, action):
        return self.observation_space.sample(), 0.0, False, False, {}


env = Custom()
obs, _ = env.reset()

# Example: model = PPO("MlpPolicy", "Pendulum-v1")
PPO("MultiInputPolicy", env).save("PathToTrainedModel")

model = PPO.load("PathToTrainedModel.zip", device="cpu")

onnx_policy = OnnxableSB3Policy(model.policy)

observation_size = model.observation_space.shape
# Add batch dimension
dummy_input = {
    # "a": np.array(obs["a"])[np.newaxis, ...],
    "a": np.array(obs["a"]),
    # "b": np.array(obs["b"])[np.newaxis, ...],
}
dummy_input_tensor = {
    "a": th.as_tensor(dummy_input["a"]),
    # "b": th.as_tensor(dummy_input["b"]),
}

print(model.predict(dummy_input, deterministic=True))

th.onnx.export(
    onnx_policy,
    args=(dummy_input_tensor, {}),
    f="my_ppo_model.onnx",
    opset_version=17,
    input_names=["input"],
)

##### Load and test with onnx
onnx_path = "my_ppo_model.onnx"
onnx_model = onnx.load(onnx_path)
onnx.checker.check_model(onnx_model)

observation = dummy_input.copy()

ort_sess = ort.InferenceSession(onnx_path)
# print(ort_sess.get_inputs()[0].name)
# print(ort_sess.get_inputs())
output = ort_sess.run(None, {"input": observation})

print(output)

# Check that the predictions are the same
# with th.no_grad():
#     print(model.policy(th.as_tensor(observation), deterministic=True))
```
" from https://pytorch.org/docs/stable/onnx_dynamo.html#torch.onnx.ONNXProgram.adapt_torch_inputs_to_onnx |
Hi all, I wouldn't really export the sampling procedure to ONNX here (`self.policy._predict(observation, deterministic=True)`).
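A rough sketch along those lines (not from the thread; it assumes the toy Dict space with a single Box key "a" from the script above and SB3's ActorCriticPolicy layout with `extract_features` / `mlp_extractor` / `action_net` / `value_net`): export only the network outputs, take plain tensors at the ONNX interface, and leave sampling or argmax to whoever consumes the model.

```python
import torch as th
from stable_baselines3 import PPO


class OnnxableNetsOnly(th.nn.Module):
    """Export action logits and value instead of sampled actions (sketch)."""

    def __init__(self, policy):
        super().__init__()
        self.policy = policy

    def forward(self, obs_a):
        # Rebuild the dict observation from a plain tensor input, so the
        # exporter only ever sees tensors at the model boundary.
        obs = {"a": obs_a}
        features = self.policy.extract_features(obs)
        latent_pi, latent_vf = self.policy.mlp_extractor(features)
        # Raw logits for the Discrete action space plus the value estimate;
        # argmax/sampling happens outside the exported graph.
        return self.policy.action_net(latent_pi), self.policy.value_net(latent_vf)


model = PPO.load("PathToTrainedModel.zip", device="cpu")
onnx_policy = OnnxableNetsOnly(model.policy)

dummy_a = th.zeros(1, 3, dtype=th.float32)  # batch of one "a" observation

th.onnx.export(
    onnx_policy,
    (dummy_a,),
    "my_ppo_nets_only.onnx",
    opset_version=17,
    input_names=["a"],
    output_names=["action_logits", "value"],
)
```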
Just wondering if there has been any progress here? I've got the export to work, but when I try to predict, it requires a bunch of "_obs.17", "_obs.23", etc. observations which the original model doesn't require.
❓ Question
Hi,
I am looking into the use of ONNX with SB3. I have tested two models (A2C and PPO) on a custom environment using a MultiInputActorCriticPolicy. The observation space of the environment is of type dict. So far I have not been able to produce an onnxable policy.
In the documentation, the note "The following examples are for MlpPolicy only, and are general examples" can be found. Is it possible to export a model of my type to ONNX? And if so, would it be possible to provide an example? Thanks