diff --git a/edgraph/namespace.go b/edgraph/namespace.go index 4531ea05469..e7373b109b7 100644 --- a/edgraph/namespace.go +++ b/edgraph/namespace.go @@ -310,7 +310,7 @@ func getNamespaceIDFromName(ctx context.Context, nsName string) (uint64, error) return 0, err } if len(data.Namespaces) == 0 { - return 0, errors.Errorf("namespace %q not found", nsName) + return 0, errors.Wrapf(x.ErrNamespaceNotFound, "namespace %q not found", nsName) } glog.Infof("Found namespace [%v] with id [%d]", nsName, data.Namespaces[0].ID) diff --git a/edgraph/server.go b/edgraph/server.go index da496fd8637..1f4cf4242cf 100644 --- a/edgraph/server.go +++ b/edgraph/server.go @@ -1194,7 +1194,26 @@ func (s *Server) Query(ctx context.Context, req *api.Request) (*api.Response, er // Query handles queries or mutations func (s *Server) QueryNoGrpc(ctx context.Context, req *api.Request) (*api.Response, error) { - ctx = x.AttachJWTNamespace(ctx) + // If the `namespace-str` is present in the metadata, use it to attach the namespace + // otherwise, use the namespace from the JWT. + var attached bool + nsStr, _ := x.ExtractNamespaceStr(ctx) + if nsStr != "" { + ns, err := getNamespaceIDFromName(x.AttachNamespace(ctx, x.RootNamespace), nsStr) + if err == nil { + ctx = x.AttachNamespace(ctx, ns) + attached = true + } else { + if !errors.Is(err, x.ErrNamespaceNotFound) { + glog.Warningf("Error getting namespace ID from name: %v. Defaulting to default or JWT namespace", err) + } + } + } + if !attached { + ctx = x.AttachJWTNamespace(ctx) + } + // If acl is enabled, then the namespace from the JWT will be applied in the test below + // overriding any namespace from the metadata obtained from the `namespace-str` above. if x.WorkerConfig.AclEnabled && req.GetStartTs() != 0 { // A fresh StartTs is assigned if it is 0. ns, err := x.ExtractNamespace(ctx) diff --git a/x/x.go b/x/x.go index db2053f0706..2e5c2cc4bb5 100644 --- a/x/x.go +++ b/x/x.go @@ -65,6 +65,8 @@ var ( ErrConflict = errors.New("Transaction conflict") // ErrHashMismatch is returned when the hash does not matches the startTs ErrHashMismatch = errors.New("hash mismatch the claimed startTs|namespace") + // ErrNamespaceNotFound is returned when a namespace is not found. + ErrNamespaceNotFound = errors.New("namespace not found") ) const ( @@ -268,6 +270,19 @@ func ExtractNamespace(ctx context.Context) (uint64, error) { return namespace, nil } +// ExtractNamespaceStr parses the namespace string value from the incoming gRPC context. +func ExtractNamespaceStr(ctx context.Context) (string, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return "", errors.New("No metadata in the context") + } + ns := md.Get("namespace-str") + if len(ns) == 0 { + return "", errors.New("No namespace-str in the metadata of context") + } + return ns[0], nil +} + func IsRootNsOperation(ctx context.Context) bool { md, ok := metadata.FromIncomingContext(ctx) if !ok {