1
1
from typing import List , Dict
2
2
from skema .program_analysis .CAST2FN .model .cast import SourceRef
3
3
4
-
5
- class NodeHelper (object ):
6
- def __init__ (self , source_file_name : str , source : str ):
7
- self .source_file_name = source_file_name
4
+ from tree_sitter import Node
5
+
6
# Node types treated as control tokens (separators, brackets, operators and
# the "only" keyword). get_control_children() selects children with these
# types and get_non_control_children() filters them out.
# NOTE(review): "(/", "/)" and "only" look like Fortran tokens - confirm.
CONTROL_CHARACTERS = [
    ",", "=", "==",
    "(", ")", "(/", "/)",
    ":", "::",
    "+", "-", "*", "**", "/",
    ">", "<", "<=", ">=",
    "only",
]
27
+
28
+ class NodeHelper ():
29
+ def __init__ (self , source : str , source_file_name : str ):
8
30
self .source = source
31
+ self .source_file_name = source_file_name
9
32
10
- def parse_tree_to_dict (self , node ) -> Dict :
11
- node_dict = {
12
- "type" : self .get_node_type (node ),
13
- "source_refs" : [self .get_node_source_ref (node )],
14
- "identifier" : self .get_node_identifier (node ),
15
- "original_children_order" : [],
16
- "children" : [],
17
- "comments" : [],
18
- "control" : [],
19
- }
20
33
21
- for child in node .children :
22
- child_dict = self .parse_tree_to_dict (child )
23
- node_dict ["original_children_order" ].append (child_dict )
24
- if self .is_comment_node (child ):
25
- node_dict ["comments" ].append (child_dict )
26
- elif self .is_control_character_node (child ):
27
- node_dict ["control" ].append (child_dict )
28
- else :
29
- node_dict ["children" ].append (child_dict )
30
-
31
- return node_dict
32
-
33
- def is_comment_node (self , node ):
34
- if node .type == "comment" :
35
- return True
36
- return False
37
-
38
- def is_control_character_node (self , node ):
39
- control_characters = [
40
- "," ,
41
- "=" ,
42
- "(" ,
43
- ")" ,
44
- ":" ,
45
- "::" ,
46
- "+" ,
47
- "-" ,
48
- "*" ,
49
- "**" ,
50
- "/" ,
51
- ">" ,
52
- "<" ,
53
- "<=" ,
54
- ">=" ,
55
- ]
56
- return node .type in control_characters
57
-
58
- def get_node_source_ref (self , node ) -> SourceRef :
34
def get_source_ref(self, node: Node) -> SourceRef:
    """Build a CAST SourceRef from the start and end points of ``node``.

    The SourceRef receives, in order: the stored source file name, the
    start column, the end column, the start row and the end row.
    """
    (start_row, start_col), (end_row, end_col) = node.start_point, node.end_point
    return SourceRef(self.source_file_name, start_col, end_col, start_row, end_row)
62
39
63
- def get_node_identifier (self , node ) -> str :
64
- source_ref = self .get_node_source_ref (node )
65
40
41
+ def get_identifier (self , node : Node ) -> str :
42
+ """Given a node, return the identifier it represents. ie. The code between node.start_point and node.end_point"""
66
43
line_num = 0
67
44
column_num = 0
68
45
in_identifier = False
69
46
identifier = ""
70
47
for i , char in enumerate (self .source ):
71
- if line_num == source_ref . row_start and column_num == source_ref . col_start :
48
+ if line_num == node . start_point [ 0 ] and column_num == node . start_point [ 1 ] :
72
49
in_identifier = True
73
- elif line_num == source_ref . row_end and column_num == source_ref . col_end :
50
+ elif line_num == node . end_point [ 0 ] and column_num == node . end_point [ 1 ] :
74
51
break
75
52
76
53
if char == "\n " :
@@ -84,19 +61,51 @@ def get_node_identifier(self, node) -> str:
84
61
85
62
return identifier
86
63
87
- def get_node_type (self , node ) -> str :
88
- return node .type
64
def get_first_child_by_type(node: Node, type: str, recurse=False):
    """Return the first child of ``node`` whose type equals ``type``.

    Direct children are checked first, in order. If none matches and
    ``recurse`` is true, each child's subtree is searched in order using
    the same rule (its direct children first, then deeper levels).

    Args:
        node: Node whose ``children`` are searched.
        type: Node type string to look for.
        recurse: When true, also search below the direct children.

    Returns:
        The first matching node, or None when nothing matches.
    """
    for child in node.children:
        if child.type == type:
            return child

    if recurse:
        for child in node.children:
            match = get_first_child_by_type(child, type, True)
            # Test against None explicitly: the direct-child branch above
            # returns any match, so a match found deeper must not be
            # discarded just because the node object evaluates as falsy.
            if match is not None:
                return match
    return None
78
+
79
+
80
def get_children_by_types(node: Node, types: List):
    """Collect the direct children of ``node`` whose type is in ``types``.

    Children keep their original order. The result is an empty list when
    no child matches.
    """
    matches = []
    for candidate in node.children:
        if candidate.type in types:
            matches.append(candidate)
    return matches
83
+
84
+
85
def get_first_child_index(node, type: str):
    """Return the position of the first child of ``node`` with type ``type``.

    Returns None when no child has that type.
    """
    matching_positions = (
        position
        for position, candidate in enumerate(node.children)
        if candidate.type == type
    )
    return next(matching_positions, None)
90
+
91
+
92
def get_last_child_index(node, type: str):
    """Return the position of the last child of ``node`` with type ``type``.

    Mirrors get_first_child_index(): the result is an index into
    ``node.children``, or None when no child has that type.
    """
    last = None
    for i, child in enumerate(node.children):
        if child.type == type:
            # Record the position, not the child node itself, so the result
            # is an index as the function name and docstring promise.
            last = i
    return last
99
+
89
100
90
- def get_first_child_by_type (self , node : Dict , node_type : str ) -> Dict :
91
- children = self .get_children_by_type (node , node_type )
92
- if len (children ) >= 1 :
93
- return children [0 ]
101
def get_control_children(node: Node):
    """Return the direct children of ``node`` whose type is listed in CONTROL_CHARACTERS."""
    return [child for child in node.children if child.type in CONTROL_CHARACTERS]
94
103
95
- def get_children_by_type (self , node : Dict , node_type : str ) -> List :
96
- children = []
97
104
98
- for child in node ["children" ]:
99
- if child ["type" ] == node_type :
100
- children .append (child )
105
def get_non_control_children(node: Node):
    """Return the direct children of ``node`` whose type is not listed in CONTROL_CHARACTERS.

    Children keep their original order.
    """
    return [
        candidate
        for candidate in node.children
        if candidate.type not in CONTROL_CHARACTERS
    ]
0 commit comments