Add types to the graph.

This adds a fourth node type and a fourth edge flow, both called
"type". The idea is to represent types as first-class elements in the
graph representation. This allows greater compositionality by breaking
up composite types into subcomponents, and decreases the vocabulary
size required to achieve a given coverage.

Background
----------

Currently, type information is stored in the "text" field of nodes for
constants and variables, e.g.:

    node {
      type: VARIABLE
      text: "i8"
    }

There are two issues with this:

 * Composite types end up with long textual representations,
   e.g. "struct foo { i32 a; i32 b; ... }". Since there is an
   unbounded number of possible structs, this prevents 100% vocabulary
   coverage on any IR with structs (or other composite types).

 * In the future, we will want to encode different information on data
   nodes, such as embedding literal values. Moving the type information
   out of the data node "frees up" space for something else.
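
To make the vocabulary argument concrete, here is a minimal Python sketch. It is not part of this change: the list of primitives and the struct text format are made up for illustration. It compares the number of distinct vocabulary entries when each struct is one string against the number needed when structs are decomposed into a "struct" token plus member tokens.

```python
from itertools import product

# Hypothetical primitive type names, chosen only for illustration.
PRIMITIVES = ["i8", "i32", "i64", "float"]


def struct_text(members):
    """Render a struct as a single string, as the old scheme would store it."""
    return "struct { " + ", ".join(members) + " }"


# Every struct with two or three members drawn from PRIMITIVES.
structs = [m for n in (2, 3) for m in product(PRIMITIVES, repeat=n)]

# Old scheme: one vocabulary entry per distinct struct string.
old_vocab = {struct_text(m) for m in structs}

# New scheme: one entry for "struct" plus one per primitive member type.
new_vocab = {"struct"} | {member for m in structs for member in m}

print(len(old_vocab), len(new_vocab))
```

The first count grows with every new combination of members, while the second is bounded by the number of primitive types.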

Overview
--------

This change represents types as first-class elements in the graph. A
"type" node represents a type using its "text" field, and a new "type"
edge connects this type to variables or constants of that type, e.g. a
variable "int x" could be represented as:

    node {
      type: VARIABLE
      text: "var"
    }
    node {
      type: TYPE
      text: "i32"
    }
    edge {
      flow: TYPE
      source: 1
    }
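
As a rough illustration of how a consumer might read this representation, the Python sketch below models nodes and edges with plain dataclasses. These are stand-ins, not the generated protobuf classes, and the helper type_of() is hypothetical. It assumes that unset integer fields default to 0, which would explain why the edge above lists only "source: 1".

```python
from dataclasses import dataclass


@dataclass
class Node:
    type: str
    text: str


@dataclass
class Edge:
    flow: str
    source: int = 0  # Assumed default, mirroring an omitted field.
    target: int = 0
    position: int = 0


# The "int x" example: node 0 is the variable, node 1 is its type.
nodes = [Node("VARIABLE", "var"), Node("TYPE", "i32")]
edges = [Edge("TYPE", source=1)]  # target defaults to 0, the variable.


def type_of(node_index):
    """Return the text of the type node attached to a node, or None."""
    for edge in edges:
        if edge.flow == "TYPE" and edge.target == node_index:
            return nodes[edge.source].text
    return None


print(type_of(0))
```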

Composite types
---------------

Types may be composed by connecting multiple type nodes using type
edges. This allows you to break down complex types into a graph of
primitive parts. The meaning of composite types will depend on the
IR being targeted. The remainder of this message describes the process
for LLVM-IR.

Pointer types
-------------

A pointer is a composite of two types:

    [variable] <- [pointer] <- [pointed-type]

For example:

    int32_t* instance;

Would be represented as:

    node {
      type: TYPE
      text: "i32"
    }
    node {
      type: TYPE
      text: "*"
    }
    node {
      type: VARIABLE
      text: "var"
    }
    edge {
      flow: TYPE
      target: 1
    }
    edge {
      flow: TYPE
      source: 1
      target: 2
    }

Where variables/constants of this type receive an incoming type edge
from the [pointer] node, which in turn receives an incoming type edge
from the [pointed-type] node.

One [pointer] node is generated for each unique pointer type. If a
graph contains multiple pointer types, there will be multiple
[pointer] nodes, one for each pointed type.
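
The "one [pointer] node per unique pointer type" rule can be sketched as a get-or-create cache. The Python below is an illustration only: the helper names and the use of type names as cache keys are assumptions, not the builder's real interface.

```python
nodes = []  # (type, text) tuples; the list index is the node index.
edges = []  # (flow, source, target, position) tuples.
primitive_cache = {}  # type name -> node index
pointer_cache = {}  # pointee type name -> pointer node index


def add_node(kind, text):
    nodes.append((kind, text))
    return len(nodes) - 1


def get_primitive(name):
    """Return the shared type node for a primitive, creating it if needed."""
    if name not in primitive_cache:
        primitive_cache[name] = add_node("TYPE", name)
    return primitive_cache[name]


def get_pointer_to(name):
    """Return the single [pointer] node for this pointee type."""
    if name not in pointer_cache:
        pointee = get_primitive(name)
        pointer = add_node("TYPE", "*")
        edges.append(("TYPE", pointee, pointer, 0))  # [pointed-type] -> [pointer]
        pointer_cache[name] = pointer
    return pointer_cache[name]


def add_variable(type_index):
    variable = add_node("VARIABLE", "var")
    edges.append(("TYPE", type_index, variable, 0))  # [pointer] -> [variable]
    return variable


a = add_variable(get_pointer_to("i32"))
b = add_variable(get_pointer_to("i32"))  # Re-uses the i32 pointer node.
c = add_variable(get_pointer_to("i8"))  # Creates a second pointer node.

pointer_nodes = [i for i, (_, text) in enumerate(nodes) if text == "*"]
print(pointer_nodes)
```

Two variables of type i32* share one [pointer] node, while the i8* variable gets its own.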

Struct types
------------

A struct is a composite type where each member is a type node which
points to the parent node. Variable/constant instances of a struct
receive an incoming type edge from the root struct node. Note that
the graph of type nodes representing a composite struct type may be
cyclical, since a struct can contain a pointer of the same type (think
of a binary tree implementation). For all other member types, a new
type node is produced. For example, a struct with two integer members
will produce two integer type nodes; they are not shared.

The type edges from member nodes to the parent struct are
positional. The position indicates the element number. E.g. for a
struct with three elements, the incoming type edges to the struct node
will have positions 0, 1, and 2.

This example struct:

    struct s {
      int8_t a;
      int8_t b;
      struct s* c;
    };

    struct s instance;

Would be represented as:

    node {
      type: TYPE
      text: "struct"
    }
    node {
      type: TYPE
      text: "i8"
    }
    node {
      type: TYPE
      text: "i8"
    }
    node {
      type: TYPE
      text: "*"
    }
    node {
      type: VARIABLE
      text: "var"
    }
    edge {
      flow: TYPE
      source: 1
    }
    edge {
      flow: TYPE
      source: 2
      position: 1
    }
    edge {
      flow: TYPE
      source: 3
      position: 2
    }
    edge {
      flow: TYPE
      target: 3
    }
    edge {
      flow: TYPE
      target: 4
    }
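
The struct rules (positional member edges pointing at the parent, unshared member nodes, and a cycle through a self-referencing pointer) can be sketched in Python as follows. The dict-based nodes and edges and the helper names are illustrative assumptions, not the real builder API.

```python
nodes = []
edges = []


def add_node(kind, text):
    nodes.append({"type": kind, "text": text})
    return len(nodes) - 1


def add_type_edge(source, target, position=0):
    edges.append(
        {"flow": "TYPE", "source": source, "target": target, "position": position}
    )


# struct s { int8_t a; int8_t b; struct s* c; };
struct = add_node("TYPE", "struct")
for position, member in enumerate(["i8", "i8", "*"]):
    member_node = add_node("TYPE", member)  # Member nodes are never shared.
    add_type_edge(member_node, struct, position)  # member -> parent struct
    if member == "*":
        # The pointee of the pointer member is the struct itself.
        add_type_edge(struct, member_node)

instance = add_node("VARIABLE", "var")
add_type_edge(struct, instance)

# Incoming type edges to the struct node, ordered by position.
incoming = sorted(
    (e["position"], nodes[e["source"]]["text"])
    for e in edges
    if e["target"] == struct
)

# A two-node cycle exists if some edge has a matching reverse edge.
pairs = {(e["source"], e["target"]) for e in edges}
has_cycle = any((t, s) in pairs for s, t in pairs)

print(incoming, has_cycle)
```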

Array types
-----------

An array is a composite type [variable] <- [array] <- [element-type].
For example, the array:

    int a[10];

Would be represented as:

    node {
      type: TYPE
      text: "i32"
    }
    node {
      type: TYPE
      text: "[]"
    }
    node {
      type: VARIABLE
      text: "var"
    }
    edge {
      flow: TYPE
      target: 1
    }
    edge {
      flow: TYPE
      source: 1
      target: 2
    }

Function pointers
-----------------

A function pointer is represented by a type node that uniquely identifies the
*signature* of a function, i.e. its return type and parameter types. One caveat
is that pointers to different functions which have the same signature
will resolve to the same type node. Additionally, there is no edge connecting a
function pointer type to the instructions of the function it points to.
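
The signature caveat can be illustrated with a short Python sketch. The signature string format and the helper get_function_type() are assumptions made for this example; the change itself takes the signature text from the LLVM text encoder.

```python
nodes = []  # (type, text) tuples; the list index is the node index.
function_types = {}  # signature string -> type node index


def get_function_type(return_type, param_types):
    """Return the type node for a function signature, creating it if needed."""
    signature = f"{return_type} ({', '.join(param_types)})"
    if signature not in function_types:
        nodes.append(("TYPE", signature))
        function_types[signature] = len(nodes) - 1
    return function_types[signature]


# Pointers to two different functions with the same signature...
add_ptr = get_function_type("i32", ["i32", "i32"])
sub_ptr = get_function_type("i32", ["i32", "i32"])
# ...and a pointer to a function with a different signature.
neg_ptr = get_function_type("i32", ["i32"])

print(add_ptr == sub_ptr, add_ptr == neg_ptr)
```

Because the node is keyed on the signature alone, the graph cannot tell add_ptr and sub_ptr apart.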

github.com//issues/82
ChrisCummins committed Jul 16, 2022
1 parent d8758c6 commit 8318ab9
Showing 7 changed files with 204 additions and 12 deletions.
26 changes: 18 additions & 8 deletions programl/graph/format/graphviz_converter.cc
@@ -170,27 +170,31 @@ class GraphVizSerializer {
   template <typename T>
   void SetVertexAttributes(const Node& node, T& attributes) {
     attributes["label"] = GetNodeLabel(node);
+    attributes["style"] = "filled";
     switch (node.type()) {
       case Node::INSTRUCTION:
         attributes["shape"] = "box";
-        attributes["style"] = "filled";
         attributes["fillcolor"] = "#3c78d8";
         attributes["fontcolor"] = "#ffffff";
         break;
       case Node::VARIABLE:
         attributes["shape"] = "ellipse";
-        attributes["style"] = "filled";
         attributes["fillcolor"] = "#f4cccc";
         attributes["color"] = "#990000";
         attributes["fontcolor"] = "#990000";
         break;
       case Node::CONSTANT:
-        attributes["shape"] = "diamond";
-        attributes["style"] = "filled";
+        attributes["shape"] = "octagon";
         attributes["fillcolor"] = "#e99c9c";
         attributes["color"] = "#990000";
         attributes["fontcolor"] = "#990000";
         break;
+      case Node::TYPE:
+        attributes["shape"] = "diamond";
+        attributes["fillcolor"] = "#cccccc";
+        attributes["color"] = "#cccccc";
+        attributes["fontcolor"] = "#222222";
+        break;
       default:
         LOG(FATAL) << "unreachable";
     }
@@ -204,7 +208,7 @@ class GraphVizSerializer {
       const Node& node = graph_.node(i);
       // Determine the subgraph to add this node to.
       boost::subgraph<GraphvizGraph>* dst = defaultGraph;
-      if (i && node.type() != Node::CONSTANT) {
+      if (i && (node.type() == Node::INSTRUCTION || node.type() == Node::VARIABLE)) {
         dst = &(*functionGraphs)[node.function()].get();
       }
       auto vertex = add_vertex(i, *dst);
@@ -229,16 +233,22 @@ class GraphVizSerializer {
           attributes["color"] = "#65ae4d";
           attributes["weight"] = "1";
           break;
+        case Edge::TYPE:
+          attributes["color"] = "#aaaaaa";
+          attributes["weight"] = "1";
+          attributes["penwidth"] = "1.5";
+          break;
         default:
           LOG(FATAL) << "unreachable";
       }

       // Set the edge label.
       if (edge.position()) {
         // Position labels for control edge are drawn close to the originating
-        // instruction. For data edges, they are drawn closer to the consuming
-        // instruction.
-        const string label = edge.flow() == Edge::DATA ? "headlabel" : "taillabel";
+        // instruction. For control edges, they are drawn close to the branching
+        // instruction. For data and type edges, they are drawn close to the
+        // consuming node.
+        const string label = edge.flow() == Edge::CONTROL ? "taillabel" : "headlabel";
         attributes[label] = std::to_string(edge.position());
         attributes["labelfontcolor"] = attributes["color"];
       }
21 changes: 21 additions & 0 deletions programl/graph/program_graph_builder.cc
@@ -64,6 +64,8 @@ Node* ProgramGraphBuilder::AddVariable(const string& text, const Function* funct

 Node* ProgramGraphBuilder::AddConstant(const string& text) { return AddNode(Node::CONSTANT, text); }

+Node* ProgramGraphBuilder::AddType(const string& text) { return AddNode(Node::TYPE, text); }
+
 labm8::StatusOr<Edge*> ProgramGraphBuilder::AddControlEdge(int32_t position, const Node* source,
                                                            const Node* target) {
   DCHECK(source) << "nullptr argument";
@@ -131,6 +133,25 @@ labm8::StatusOr<Edge*> ProgramGraphBuilder::AddCallEdge(const Node* source, cons
   return AddEdge(Edge::CALL, /*position=*/0, source, target);
 }

+labm8::StatusOr<Edge*> ProgramGraphBuilder::AddTypeEdge(int32_t position, const Node* source,
+                                                        const Node* target) {
+  DCHECK(source) << "nullptr argument";
+  DCHECK(target) << "nullptr argument";
+
+  if (source->type() != Node::TYPE) {
+    return Status(labm8::error::Code::INVALID_ARGUMENT,
+                  "Invalid source type ({}) for type edge. Expected type",
+                  Node::Type_Name(source->type()));
+  }
+  if (target->type() == Node::INSTRUCTION) {
+    return Status(labm8::error::Code::INVALID_ARGUMENT,
+                  "Invalid destination type (instruction) for type edge. "
+                  "Expected {variable,constant,type}");
+  }
+
+  return AddEdge(Edge::TYPE, position, source, target);
+}
+
 labm8::StatusOr<ProgramGraph> ProgramGraphBuilder::Build() {
   if (options().strict()) {
     RETURN_IF_ERROR(ValidateGraph());
7 changes: 6 additions & 1 deletion programl/graph/program_graph_builder.h
@@ -64,6 +64,8 @@ class ProgramGraphBuilder {

   Node* AddConstant(const string& text);

+  Node* AddType(const string& text);
+
   // Edge factories.
   [[nodiscard]] labm8::StatusOr<Edge*> AddControlEdge(int32_t position, const Node* source,
                                                       const Node* target);
@@ -73,6 +75,9 @@ class ProgramGraphBuilder {

   [[nodiscard]] labm8::StatusOr<Edge*> AddCallEdge(const Node* source, const Node* target);

+  [[nodiscard]] labm8::StatusOr<Edge*> AddTypeEdge(int32_t position, const Node* source,
+                                                   const Node* target);
+
   const Node* GetRootNode() const { return &graph_.node(0); }

   // Return the graph protocol buffer.
@@ -116,7 +121,7 @@ class ProgramGraphBuilder {
   int32_t GetIndex(const Function* function);
   int32_t GetIndex(const Node* node);

-  // Maps which covert store the index of objects in repeated field lists.
+  // Maps that store the index of objects in repeated field lists.
   absl::flat_hash_map<Module*, int32_t> moduleIndices_;
   absl::flat_hash_map<Function*, int32_t> functionIndices_;
   absl::flat_hash_map<Node*, int32_t> nodeIndices_;
7 changes: 7 additions & 0 deletions programl/ir/llvm/inst2vec_encoder.py
@@ -92,6 +92,7 @@ def Encode(self, proto: ProgramGraph, ir: Optional[str] = None) -> ProgramGraph:
         # Add the node features.
         var_embedding = self.dictionary["!IDENTIFIER"]
         const_embedding = self.dictionary["!IMMEDIATE"]
+        type_embedding = self.dictionary["!IMMEDIATE"]  # Types are immediates

         text_index = 0
         for node in proto.node:
@@ -113,6 +114,12 @@ def Encode(self, proto: ProgramGraph, ir: Optional[str] = None) -> ProgramGraph:
                 node.features.feature["inst2vec_embedding"].int64_list.value.append(
                     const_embedding
                 )
+            elif node.type == node_pb2.Node.TYPE:
+                node.features.feature["inst2vec_embedding"].int64_list.value.append(
+                    type_embedding
+                )
+            else:
+                raise TypeError(f"Unknown node type {node}")

         proto.features.feature["inst2vec_annotated"].int64_list.value.append(1)
         return proto
119 changes: 116 additions & 3 deletions programl/ir/llvm/internal/program_graph_builder.cc
@@ -20,6 +20,7 @@

 #include "absl/container/flat_hash_map.h"
 #include "absl/container/flat_hash_set.h"
+#include "labm8/cpp/logging.h"
 #include "labm8/cpp/status_macros.h"
 #include "labm8/cpp/string.h"
 #include "llvm/IR/BasicBlock.h"
@@ -323,29 +324,131 @@ Node* ProgramGraphBuilder::AddLlvmInstruction(const ::llvm::Instruction* instruc
 Node* ProgramGraphBuilder::AddLlvmVariable(const ::llvm::Instruction* operand,
                                            const programl::Function* function) {
   const LlvmTextComponents text = textEncoder_.Encode(operand);
-  Node* node = AddVariable(text.lhs_type, function);
+  Node* node = AddVariable("var", function);
   node->set_block(blockCount_);
   graph::AddScalarFeature(node, "full_text", text.lhs);

+  compositeTypeParts_.clear();  // Reset after previous call.
+  Node* type = GetOrCreateType(operand->getType());
+  CHECK(AddTypeEdge(/*position=*/0, type, node).ok());
+
   return node;
 }

 Node* ProgramGraphBuilder::AddLlvmVariable(const ::llvm::Argument* argument,
                                            const programl::Function* function) {
   const LlvmTextComponents text = textEncoder_.Encode(argument);
-  Node* node = AddVariable(text.lhs_type, function);
+  Node* node = AddVariable("var", function);
   node->set_block(blockCount_);
   graph::AddScalarFeature(node, "full_text", text.lhs);

+  compositeTypeParts_.clear();  // Reset after previous call.
+  Node* type = GetOrCreateType(argument->getType());
+  CHECK(AddTypeEdge(/*position=*/0, type, node).ok());
+
   return node;
 }

 Node* ProgramGraphBuilder::AddLlvmConstant(const ::llvm::Constant* constant) {
   const LlvmTextComponents text = textEncoder_.Encode(constant);
-  Node* node = AddConstant(text.lhs_type);
+  Node* node = AddConstant("val");
   node->set_block(blockCount_);
   graph::AddScalarFeature(node, "full_text", text.text);

+  compositeTypeParts_.clear();  // Reset after previous call.
+  Node* type = GetOrCreateType(constant->getType());
+  CHECK(AddTypeEdge(/*position=*/0, type, node).ok());
+
   return node;
 }

+Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::Type* type) {
+  // Dispatch to the type-specific handlers.
+  if (::llvm::dyn_cast<::llvm::StructType>(type)) {
+    return AddLlvmType(::llvm::dyn_cast<::llvm::StructType>(type));
+  } else if (::llvm::dyn_cast<::llvm::PointerType>(type)) {
+    return AddLlvmType(::llvm::dyn_cast<::llvm::PointerType>(type));
+  } else if (::llvm::dyn_cast<::llvm::FunctionType>(type)) {
+    return AddLlvmType(::llvm::dyn_cast<::llvm::FunctionType>(type));
+  } else if (::llvm::dyn_cast<::llvm::ArrayType>(type)) {
+    return AddLlvmType(::llvm::dyn_cast<::llvm::ArrayType>(type));
+  } else if (::llvm::dyn_cast<::llvm::VectorType>(type)) {
+    return AddLlvmType(::llvm::dyn_cast<::llvm::VectorType>(type));
+  } else {
+    const LlvmTextComponents text = textEncoder_.Encode(type);
+    Node* node = AddType(text.text);
+    graph::AddScalarFeature(node, "llvm_string", text.text);
+    return node;
+  }
+}
+
+Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::StructType* type) {
+  Node* node = AddType("struct");
+  compositeTypeParts_[type] = node;
+  graph::AddScalarFeature(node, "llvm_string", textEncoder_.Encode(type).text);
+
+  // Add types for the struct elements, and add type edges.
+  for (int i = 0; i < type->getNumElements(); ++i) {
+    const auto& member = type->elements()[i];
+    // Don't re-use member types in structs, always create a new type. For
+    // example, the code:
+    //
+    //     struct S {
+    //       int a;
+    //       int b;
+    //     };
+    //     int c;
+    //     int d;
+    //
+    // would produce four type nodes: one for S.a, one for S.b, and one which
+    // is shared by c and d.
+    Node* memberNode = AddLlvmType(member);
+    CHECK(AddTypeEdge(/*position=*/i, memberNode, node).ok());
+  }
+
+  return node;
+}
+
+Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::PointerType* type) {
+  Node* node = AddType("*");
+  graph::AddScalarFeature(node, "llvm_string", textEncoder_.Encode(type).text);
+
+  auto elementType = type->getElementType();
+  auto parent = compositeTypeParts_.find(elementType);
+  if (parent == compositeTypeParts_.end()) {
+    // Re-use the type if it already exists to prevent duplication.
+    auto elementNode = GetOrCreateType(type->getElementType());
+    CHECK(AddTypeEdge(/*position=*/0, elementNode, node).ok());
+  } else {
+    // Bottom-out for self-referencing types.
+    CHECK(AddTypeEdge(/*position=*/0, parent->second, node).ok());
+  }
+
+  return node;
+}
+
+Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::FunctionType* type) {
+  const std::string signature = textEncoder_.Encode(type).text;
+  Node* node = AddType(signature);
+  graph::AddScalarFeature(node, "llvm_string", signature);
+  return node;
+}
+
+Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::ArrayType* type) {
+  Node* node = AddType("[]");
+  graph::AddScalarFeature(node, "llvm_string", textEncoder_.Encode(type).text);
+  // Re-use the type if it already exists to prevent duplication.
+  auto elementType = GetOrCreateType(type->getElementType());
+  CHECK(AddTypeEdge(/*position=*/0, elementType, node).ok());
+  return node;
+}
+
+Node* ProgramGraphBuilder::AddLlvmType(const ::llvm::VectorType* type) {
+  Node* node = AddType("vector");
+  graph::AddScalarFeature(node, "llvm_string", textEncoder_.Encode(type).text);
+  // Re-use the type if it already exists to prevent duplication.
+  auto elementType = GetOrCreateType(type->getElementType());
+  CHECK(AddTypeEdge(/*position=*/0, elementType, node).ok());
+  return node;
+}
+
@@ -461,6 +564,16 @@ void ProgramGraphBuilder::Clear() {
   programl::graph::ProgramGraphBuilder::Clear();
 }

+Node* ProgramGraphBuilder::GetOrCreateType(const ::llvm::Type* type) {
+  auto it = types_.find(type);
+  if (it == types_.end()) {
+    Node* node = AddLlvmType(type);
+    types_[type] = node;
+    return node;
+  }
+  return it->second;
+}
+
 }  // namespace internal
 }  // namespace llvm
 }  // namespace ir
32 changes: 32 additions & 0 deletions programl/ir/llvm/internal/program_graph_builder.h
@@ -70,6 +70,12 @@ class ProgramGraphBuilder : public programl::graph::ProgramGraphBuilder {

   void Clear();

+  // Return the node representing a type. If no node already exists
+  // for this type, a new node is created and added to the graph. In
+  // the case of composite types, multiple new nodes may be added by
+  // this call, and the root type returned.
+  Node* GetOrCreateType(const ::llvm::Type* type);
+
  protected:
   [[nodiscard]] labm8::StatusOr<FunctionEntryExits> VisitFunction(const ::llvm::Function& function,
                                                                   const Function* functionMessage);
@@ -85,6 +91,12 @@ class ProgramGraphBuilder : public programl::graph::ProgramGraphBuilder {
   Node* AddLlvmVariable(const ::llvm::Instruction* operand, const Function* function);
   Node* AddLlvmVariable(const ::llvm::Argument* argument, const Function* function);
   Node* AddLlvmConstant(const ::llvm::Constant* constant);
+  Node* AddLlvmType(const ::llvm::Type* type);
+  Node* AddLlvmType(const ::llvm::StructType* type);
+  Node* AddLlvmType(const ::llvm::PointerType* type);
+  Node* AddLlvmType(const ::llvm::FunctionType* type);
+  Node* AddLlvmType(const ::llvm::ArrayType* type);
+  Node* AddLlvmType(const ::llvm::VectorType* type);

  private:
   TextEncoder textEncoder_;
@@ -99,6 +111,26 @@ class ProgramGraphBuilder : public programl::graph::ProgramGraphBuilder {
   // populated by VisitBasicBlock() and consumed once all functions have been
   // visited.
   absl::flat_hash_map<const ::llvm::Constant*, std::vector<PositionalNode>> constants_;
+
+  // A map from an LLVM type to the node message that represents it.
+  absl::flat_hash_map<const ::llvm::Type*, Node*> types_;
+
+  // When adding a new type to the graph we need to know whether the type that
+  // we are adding is part of a composite type that references itself. For
+  // example:
+  //
+  //     struct BinaryTree {
+  //       int data;
+  //       struct BinaryTree* left;
+  //       struct BinaryTree* right;
+  //     }
+  //
+  // When the recursive GetOrCreateType() resolves the "left" member, it needs
+  // to know that the parent BinaryTree type has already been processed. This
+  // map stores the Nodes corresponding to any parent structs that have been
+  // already added in a call to GetOrCreateType(). It must be cleared between
+  // calls.
+  absl::flat_hash_map<const ::llvm::Type*, Node*> compositeTypeParts_;
 };

 }  // namespace internal
4 changes: 4 additions & 0 deletions programl/proto/program_graph.proto
@@ -55,6 +55,8 @@ message Node {
     VARIABLE = 1;
     // A constant.
     CONSTANT = 2;
+    // A type.
+    TYPE = 3;
   }
   // The type of the node.
   Type type = 1;
@@ -92,6 +94,8 @@ message Edge {
     DATA = 1;
     // A call relation.
     CALL = 2;
+    // A type relation.
+    TYPE = 3;
   }
   // The type of relation of this edge.
   Flow flow = 1;
