
Initializing joint marginals is pain #393

Open
bvdmitri opened this issue Dec 19, 2024 · 0 comments

@bvdmitri
Member

We need a convenient function for this. Here is the snippet I wrote for @ismailsenoz that such a function could be based on:

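```julia
using RxInfer

# Sets initial marginals on every GCV node before inference starts.
# Presumably passed to `infer` via `callbacks = (before_inference = before_inference,)`.
function before_inference(model)
    graph = RxInfer.getmodel(model)
    # Labels of all GCV factor nodes in the underlying GraphPPL model
    gcv_labels = collect(filter(RxInfer.GraphPPL.as_node(GCV), graph))
    @info "Number of GCV nodes is $(length(gcv_labels))"
    for label in gcv_labels
        # The ReactiveMP factor node stored as extra data on the GraphPPL node
        factor_node = RxInfer.GraphPPL.getextra(graph[label], RxInfer.ReactiveMPExtraFactorNodeKey)
        # Marginal names join the cluster's variable names, e.g. :y_x
        for marginal in factor_node.localclusters.marginals
            _name = RxInfer.ReactiveMP.name(marginal)
            stream = RxInfer.ReactiveMP.getmarginal(marginal)
            if _name === :y_x
                RxInfer.ReactiveMP.setmarginal!(stream, MvNormalMeanCovariance([ 0.0, 0.0 ], [ 1.0 0.0; 0.0 1.0 ]))
            elseif _name === :z_κ
                RxInfer.ReactiveMP.setmarginal!(stream, MvNormalMeanCovariance([ 0.0, 0.0 ], [ 1.0 0.0; 0.0 1.0 ]))
            # NOTE: the marginal name in the next branch is missing from the original snippet
            elseif _name === 
                RxInfer.ReactiveMP.setmarginal!(stream, NormalMeanVariance(0.0, 1.0))
            end
        end
    end
end
```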
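A minimal sketch of the kind of helper this issue asks for, assuming the RxInfer/GraphPPL/ReactiveMP internals used in the snippet above stay as they are. `initialize_node_marginals!` is a hypothetical name, not part of RxInfer; it reuses only calls that already appear in the snippet and replaces the `if`/`elseif` chain with a `Dict` from marginal names to initial distributions. The example covers only the two named marginals, since the third name is missing from the snippet.

```julia
using RxInfer

# Hypothetical helper (NOT an existing RxInfer function), generalizing the
# `before_inference` snippet: for every factor node of type `node_type` in `model`,
# set the initial value of each local marginal whose name (e.g. :y_x) is a key of
# `initials`. Marginals with other names are left untouched.
function initialize_node_marginals!(model, node_type, initials::AbstractDict{Symbol})
    graph = RxInfer.getmodel(model)
    labels = collect(filter(RxInfer.GraphPPL.as_node(node_type), graph))
    for label in labels
        factor_node = RxInfer.GraphPPL.getextra(graph[label], RxInfer.ReactiveMPExtraFactorNodeKey)
        for marginal in factor_node.localclusters.marginals
            # Look up an initial distribution by the marginal's name, if one was given
            init = get(initials, RxInfer.ReactiveMP.name(marginal), nothing)
            if !isnothing(init)
                RxInfer.ReactiveMP.setmarginal!(RxInfer.ReactiveMP.getmarginal(marginal), init)
            end
        end
    end
    # Return the number of matched nodes so the caller can sanity-check the model
    return length(labels)
end

# Example: the GCV initialization from the snippet, expressed as a lookup table
before_inference(model) = initialize_node_marginals!(model, GCV, Dict{Symbol, Any}(
    :y_x => MvNormalMeanCovariance([0.0, 0.0], [1.0 0.0; 0.0 1.0]),
    :z_κ => MvNormalMeanCovariance([0.0, 0.0], [1.0 0.0; 0.0 1.0]),
))
```

Returning the node count instead of logging it keeps the helper quiet; the caller can log it or assert that it is non-zero.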

GraphPPL also exposes a `NodeId` plugin, which RxInfer does not currently use but which might be very handy here.

bvdmitri self-assigned this on Dec 19, 2024