Skip to content
This repository was archived by the owner on Jan 8, 2025. It is now read-only.

Commit 3de834d

Browse files
authored
Merge branch 'ml4ai:main' into main
2 parents a8987b9 + abef48f commit 3de834d

28 files changed

+1874
-675
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
<math><mfrac><mrow><mi>d</mi><mi>E</mi></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mi>&#x03B2;</mi><mi>I</mi><mi>S</mi><mo>&#x2212;</mo><mi>&#x03B4;</mi><mi>E</mi></math>
2+
<math><mfrac><mrow><mi>d</mi><mi>R</mi></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mo>(</mo><mn>1</mn><mo>&#x2212;</mo><mi>&#x03B1;</mi><mo>)</mo><mi>&#x03B3;</mi><mi>I</mi></math>
3+
<math><mfrac><mrow><mi>d</mi><mi>I</mi></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mi>&#x03B4;</mi><mi>E</mi><mo>&#x2212;</mo><mo>(</mo><mn>1</mn><mo>&#x2212;</mo><mi>&#x03B1;</mi><mo>)</mo><mi>&#x03B3;</mi><mi>I</mi><mo>&#x2212;</mo><mi>&#x03B1;</mi><mi>&#x03C1;</mi><mi>I</mi></math>
4+
<math><mfrac><mrow><mi>d</mi><mi>D</mi></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mi>&#x03B1;</mi><mi>&#x03C1;</mi><mi>I</mi></math>
5+
<math><mfrac><mrow><mi>d</mi><mi>S</mi></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mo>&#x2212;</mo><mi>&#x03B2;</mi><mi>I</mi><mi>S</mi></math>
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
<math display="block"><mfrac><mrow><mi>d</mi><mi>S</mi><mo>(</mo><mi>t</mi><mo>)</mo></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mo>&#x2212;</mo><mi>&#x03B2;</mi><mi>I</mi><mo>(</mo><mi>t</mi><mo>)</mo><mi>S</mi><mo>(</mo><mi>t</mi><mo>)</mo></math>
2+
<math display="block"><mfrac><mrow><mi>d</mi><mi>E</mi><mo>(</mo><mi>t</mi><mo>)</mo></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mi>&#x03B2;</mi><mi>I</mi><mo>(</mo><mi>t</mi><mo>)</mo><mi>S</mi><mo>(</mo><mi>t</mi><mo>)</mo><mo>&#x2212;</mo><mi>&#x03B4;</mi><mi>E</mi><mo>(</mo><mi>t</mi><mo>)</mo></math>
3+
<math display="block"><mfrac><mrow><mi>d</mi><mi>D</mi><mo>(</mo><mi>t</mi><mo>)</mo></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mi>&#x03B1;</mi><mi>&#x03C1;</mi><mi>I</mi><mo>(</mo><mi>t</mi><mo>)</mo></math>
4+
<math display="block"><mfrac><mrow><mi>d</mi><mi>R</mi><mo>(</mo><mi>t</mi><mo>)</mo></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mo>(</mo><mn>1</mn><mo>&#x2212;</mo><mi>&#x03B1;</mi><mo>)</mo><mi>&#x03B3;</mi><mi>I</mi><mo>(</mo><mi>t</mi><mo>)</mo></math>
5+
<math display="block"><mfrac><mrow><mi>d</mi><mi>I</mi><mo>(</mo><mi>t</mi><mo>)</mo></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mi>&#x03B4;</mi><mi>E</mi><mo>(</mo><mi>t</mi><mo>)</mo><mo>&#x2212;</mo><mo>(</mo><mn>1</mn><mo>&#x2212;</mo><mi>&#x03B1;</mi><mo>)</mo><mi>&#x03B3;</mi><mi>I</mi><mo>(</mo><mi>t</mi><mo>)</mo><mo>&#x2212;</mo><mi>&#x03B1;</mi><mi>&#x03C1;</mi><mi>I</mi><mo>(</mo><mi>t</mi><mo>)</mo></math>
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
<math display="block"><mfrac><mrow><mi>d</mi><mi>S</mi><mo>(</mo><mi>t</mi><mo>)</mo></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mo>&#x2212;</mo><mi>&#x03B2;</mi><mi>I</mi><mo>(</mo><mi>t</mi><mo>)</mo><mfrac><mrow><mi>S</mi><mo>(</mo><mi>t</mi><mo>)</mo></mrow><mi>N</mi></mfrac></math>
2+
<math display="block"><mfrac><mrow><mi>d</mi><mi>E</mi><mo>(</mo><mi>t</mi><mo>)</mo></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mi>&#x03B2;</mi><mi>I</mi><mo>(</mo><mi>t</mi><mo>)</mo><mfrac><mrow><mi>S</mi><mo>(</mo><mi>t</mi><mo>)</mo></mrow><mi>N</mi></mfrac><mo>&#x2212;</mo><mi>&#x03B4;</mi><mi>E</mi><mo>(</mo><mi>t</mi><mo>)</mo></math>
3+
<math display="block"><mfrac><mrow><mi>d</mi><mi>I</mi><mo>(</mo><mi>t</mi><mo>)</mo></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mi>&#x03B4;</mi><mi>E</mi><mo>(</mo><mi>t</mi><mo>)</mo><mo>&#x2212;</mo><mo>(</mo><mn>1</mn><mo>&#x2212;</mo><mi>&#x03B1;</mi><mo>)</mo><mi>&#x03B3;</mi><mi>I</mi><mo>(</mo><mi>t</mi><mo>)</mo><mo>&#x2212;</mo><mi>&#x03B1;</mi><mi>&#x03C1;</mi><mi>I</mi><mo>(</mo><mi>t</mi><mo>)</mo></math>
4+
<math display="block"><mfrac><mrow><mi>d</mi><mi>R</mi><mo>(</mo><mi>t</mi><mo>)</mo></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mo>(</mo><mn>1</mn><mo>&#x2212;</mo><mi>&#x03B1;</mi><mo>)</mo><mi>&#x03B3;</mi><mi>I</mi><mo>(</mo><mi>t</mi><mo>)</mo></math>
5+
<math display="block"><mfrac><mrow><mi>d</mi><mi>D</mi><mo>(</mo><mi>t</mi><mo>)</mo></mrow><mrow><mi>d</mi><mi>t</mi></mrow></mfrac><mo>=</mo><mi>&#x03B1;</mi><mi>&#x03C1;</mi><mi>I</mi><mo>(</mo><mi>t</mi><mo>)</mo></math>

skema/img2mml/api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@ def retrieve_model(model_path=None) -> str:
2525
cwd = Path(__file__).parents[0]
2626
MODEL_BASE_ADDRESS = "https://artifacts.askem.lum.ai/skema/img2mml/models"
2727
MODEL_NAME = "cnn_xfmer_arxiv_im2mml_with_fonts_boldface_best.pt"
28-
29-
if model_path is None:
28+
# If the model path is None or doesn't exist, the default model will be downloaded from the server.
29+
if model_path is None or not os.path.exists(model_path):
3030
model_path = cwd / "trained_models" / MODEL_NAME
3131

3232
# Check if the model file already exists
Lines changed: 77 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,76 +1,53 @@
11
from typing import List, Dict
22
from skema.program_analysis.CAST2FN.model.cast import SourceRef
33

4-
5-
class NodeHelper(object):
6-
def __init__(self, source_file_name: str, source: str):
7-
self.source_file_name = source_file_name
4+
from tree_sitter import Node
5+
6+
CONTROL_CHARACTERS = [
7+
",",
8+
"=",
9+
"==",
10+
"(",
11+
")",
12+
"(/",
13+
"/)",
14+
":",
15+
"::",
16+
"+",
17+
"-",
18+
"*",
19+
"**",
20+
"/",
21+
">",
22+
"<",
23+
"<=",
24+
">=",
25+
"only",
26+
]
27+
28+
class NodeHelper():
29+
def __init__(self, source: str, source_file_name: str):
830
self.source = source
31+
self.source_file_name = source_file_name
932

10-
def parse_tree_to_dict(self, node) -> Dict:
11-
node_dict = {
12-
"type": self.get_node_type(node),
13-
"source_refs": [self.get_node_source_ref(node)],
14-
"identifier": self.get_node_identifier(node),
15-
"original_children_order": [],
16-
"children": [],
17-
"comments": [],
18-
"control": [],
19-
}
2033

21-
for child in node.children:
22-
child_dict = self.parse_tree_to_dict(child)
23-
node_dict["original_children_order"].append(child_dict)
24-
if self.is_comment_node(child):
25-
node_dict["comments"].append(child_dict)
26-
elif self.is_control_character_node(child):
27-
node_dict["control"].append(child_dict)
28-
else:
29-
node_dict["children"].append(child_dict)
30-
31-
return node_dict
32-
33-
def is_comment_node(self, node):
34-
if node.type == "comment":
35-
return True
36-
return False
37-
38-
def is_control_character_node(self, node):
39-
control_characters = [
40-
",",
41-
"=",
42-
"(",
43-
")",
44-
":",
45-
"::",
46-
"+",
47-
"-",
48-
"*",
49-
"**",
50-
"/",
51-
">",
52-
"<",
53-
"<=",
54-
">=",
55-
]
56-
return node.type in control_characters
57-
58-
def get_node_source_ref(self, node) -> SourceRef:
34+
def get_source_ref(self, node: Node) -> SourceRef:
35+
"""Given a node and file name, return a CAST SourceRef object."""
5936
row_start, col_start = node.start_point
6037
row_end, col_end = node.end_point
6138
return SourceRef(self.source_file_name, col_start, col_end, row_start, row_end)
6239

63-
def get_node_identifier(self, node) -> str:
64-
source_ref = self.get_node_source_ref(node)
6540

41+
def get_identifier(self, node: Node) -> str:
42+
"""Given a node, return the identifier it represents. ie. The code between node.start_point and node.end_point"""
6643
line_num = 0
6744
column_num = 0
6845
in_identifier = False
6946
identifier = ""
7047
for i, char in enumerate(self.source):
71-
if line_num == source_ref.row_start and column_num == source_ref.col_start:
48+
if line_num == node.start_point[0] and column_num == node.start_point[1]:
7249
in_identifier = True
73-
elif line_num == source_ref.row_end and column_num == source_ref.col_end:
50+
elif line_num == node.end_point[0] and column_num == node.end_point[1]:
7451
break
7552

7653
if char == "\n":
@@ -84,19 +61,51 @@ def get_node_identifier(self, node) -> str:
8461

8562
return identifier
8663

87-
def get_node_type(self, node) -> str:
88-
return node.type
64+
def get_first_child_by_type(node: Node, type: str, recurse=False):
    """Return the first child of ``node`` whose ``type`` attribute equals ``type``.

    Direct children are always examined first, in order. If none of them match
    and ``recurse`` is truthy, the subtree under each child is searched in order
    with the same rule, and the first truthy result is returned. When nothing
    matches, None is returned.
    """
    direct_match = next((kid for kid in node.children if kid.type == type), None)
    if direct_match is not None:
        return direct_match

    if not recurse:
        return None

    # No direct hit: descend into each child's subtree in document order.
    for kid in node.children:
        found = get_first_child_by_type(kid, type, True)
        if found:
            return found
    return None
78+
79+
80+
def get_children_by_types(node: Node, types: List):
    """Collect, in order, every direct child of ``node`` whose type is in ``types``.

    An empty list is returned when no child matches.
    """
    matches = []
    for kid in node.children:
        if kid.type in types:
            matches.append(kid)
    return matches
83+
84+
85+
def get_first_child_index(node, type: str):
    """Return the position in ``node.children`` of the first child whose type is ``type``.

    None is returned when no child has that type.
    """
    return next(
        (position for position, kid in enumerate(node.children) if kid.type == type),
        None,
    )
90+
91+
92+
def get_last_child_index(node, type: str):
    """Get the index of the last child of node with type type.

    Returns the position in ``node.children`` of the last matching child, or
    None when no child has that type.
    """
    last = None
    for i, child in enumerate(node.children):
        if child.type == type:
            # Record the position, not the child itself, so the result is an
            # index as documented (mirrors get_first_child_index).
            last = i
    return last
99+
89100

90-
def get_first_child_by_type(self, node: Dict, node_type: str) -> Dict:
91-
children = self.get_children_by_type(node, node_type)
92-
if len(children) >= 1:
93-
return children[0]
101+
def get_control_children(node: Node):
    """Return, in order, the direct children of ``node`` whose type is listed in CONTROL_CHARACTERS."""
    return [kid for kid in node.children if kid.type in CONTROL_CHARACTERS]
94103

95-
def get_children_by_type(self, node: Dict, node_type: str) -> List:
96-
children = []
97104

98-
for child in node["children"]:
99-
if child["type"] == node_type:
100-
children.append(child)
105+
def get_non_control_children(node: Node):
    """Return, in order, the direct children of ``node`` whose type is not listed in CONTROL_CHARACTERS."""
    return [kid for kid in node.children if kid.type not in CONTROL_CHARACTERS]

0 commit comments

Comments
 (0)