Skip to content

Commit

Permalink
Incorporate module initialization code (#114)
Browse files Browse the repository at this point in the history
* Add tests for wala#202.

* Remove whitespace.

* Add log.

* Fix wala#202.
  • Loading branch information
khatchad authored Jul 8, 2024
1 parent 4e44687 commit 8f850b8
Show file tree
Hide file tree
Showing 10 changed files with 137 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2606,6 +2606,34 @@ public void testModule53()
new int[] {2});
}

/** Test https://github.com/wala/ML/issues/202. */
@Test
public void testModule54()
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
test(
new String[] {"proj51/src/__init__.py", "proj51/src/module.py", "proj51/client.py"},
"src/module.py",
"f",
"proj51",
1,
1,
new int[] {2});
}

/** Test https://github.com/wala/ML/issues/202. */
@Test
public void testModule55()
throws ClassHierarchyException, IllegalArgumentException, CancelException, IOException {
test(
new String[] {"proj52/src/__init__.py", "proj52/src/module.py", "proj52/client.py"},
"src/module.py",
"f",
"proj52",
1,
1,
new int[] {2});
}

@Test
public void testStaticMethod() throws ClassHierarchyException, CancelException, IOException {
test("tf2_test_static_method.py", "MyClass.the_static_method", 1, 1, 2);
Expand Down
2 changes: 2 additions & 0 deletions com.ibm.wala.cast.python.test/.pydevproject
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,7 @@
<path>/${PROJECT_DIR_NAME}/data/proj45</path>
<path>/${PROJECT_DIR_NAME}/data/proj49</path>
<path>/${PROJECT_DIR_NAME}/data/proj50</path>
<path>/${PROJECT_DIR_NAME}/data/proj51</path>
<path>/${PROJECT_DIR_NAME}/data/proj52</path>
</pydev_pathproperty>
</pydev_project>
6 changes: 6 additions & 0 deletions com.ibm.wala.cast.python.test/data/proj51/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Test https://github.com/wala/ML/issues/163.

from tensorflow import ones
from src import f

f(ones([1, 2]))
1 change: 1 addition & 0 deletions com.ibm.wala.cast.python.test/data/proj51/src/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .module import f
7 changes: 7 additions & 0 deletions com.ibm.wala.cast.python.test/data/proj51/src/module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Test https://github.com/wala/ML/issues/163.

from tensorflow import Tensor


def f(a):
assert isinstance(a, Tensor)
6 changes: 6 additions & 0 deletions com.ibm.wala.cast.python.test/data/proj52/client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Test https://github.com/wala/ML/issues/163.

from tensorflow import ones
from src import f

f(ones([1, 2]))
1 change: 1 addition & 0 deletions com.ibm.wala.cast.python.test/data/proj52/src/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .module import *
7 changes: 7 additions & 0 deletions com.ibm.wala.cast.python.test/data/proj52/src/module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
# Test https://github.com/wala/ML/issues/163.

from tensorflow import Tensor


def f(a):
assert isinstance(a, Tensor)
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,19 @@
*****************************************************************************/
package com.ibm.wala.cast.python.ipa.callgraph;

import static com.ibm.wala.cast.python.ir.PythonLanguage.MODULE_INITIALIZATION_FILENAME;

import com.google.common.collect.Maps;
import com.ibm.wala.cast.ipa.callgraph.AstSSAPropagationCallGraphBuilder;
import com.ibm.wala.cast.ipa.callgraph.GlobalObjectKey;
import com.ibm.wala.cast.ipa.callgraph.ScopeMappingInstanceKeys.ScopeMappingInstanceKey;
import com.ibm.wala.cast.ir.ssa.AstGlobalRead;
import com.ibm.wala.cast.ir.ssa.AstPropertyRead;
import com.ibm.wala.cast.python.ir.PythonLanguage;
import com.ibm.wala.cast.python.ssa.PythonInstructionVisitor;
import com.ibm.wala.cast.python.ssa.PythonInvokeInstruction;
import com.ibm.wala.cast.python.types.PythonTypes;
import com.ibm.wala.cast.python.util.Util;
import com.ibm.wala.classLoader.IClass;
import com.ibm.wala.classLoader.IField;
import com.ibm.wala.classLoader.NewSiteReference;
Expand All @@ -29,6 +33,7 @@
import com.ibm.wala.ipa.callgraph.CallGraph;
import com.ibm.wala.ipa.callgraph.IAnalysisCacheView;
import com.ibm.wala.ipa.callgraph.propagation.AbstractFieldPointerKey;
import com.ibm.wala.ipa.callgraph.propagation.AllocationSiteInNode;
import com.ibm.wala.ipa.callgraph.propagation.InstanceKey;
import com.ibm.wala.ipa.callgraph.propagation.PointerAnalysis;
import com.ibm.wala.ipa.callgraph.propagation.PointerKey;
Expand Down Expand Up @@ -313,6 +318,36 @@ public void visitPropertyRead(AstPropertyRead instruction) {
});
}
}

// check if we are reading from an module initialization script.
PointerKey objRefPK = getPointerKeyForLocal(instruction.getObjectRef());
OrdinalSet<InstanceKey> objRefPointsToSet =
getBuilder().getPointerAnalysis().getPointsToSet(objRefPK);

for (InstanceKey objIK : objRefPointsToSet) {
if (objIK instanceof AllocationSiteInNode || objIK instanceof ScopeMappingInstanceKey) {
AllocationSiteInNode asin = Util.getAllocationSiteInNode(objIK);
NewSiteReference site = asin.getSite();
TypeReference declaredType = site.getDeclaredType();
TypeName scriptTypeName = declaredType.getName();

if (scriptTypeName.toString().endsWith("/" + MODULE_INITIALIZATION_FILENAME)) {
// the "receiver" is a module initialization script.
Atom scriptPackage = scriptTypeName.getPackage();

String scriptName =
(scriptPackage == null
? scriptTypeName.getClassName()
: scriptPackage.toString() + "/" + scriptTypeName.getClassName())
.toString();
logger.finer("Script name is: " + scriptName + ".");

// check if the constant refers to a field that is being imported by a wildcard in the
// corresponding module's initialization script.
processWildcardImports(instruction, scriptName, constantValue.toString());
}
}
}
}
}

Expand All @@ -329,9 +364,25 @@ public void visitAstGlobalRead(AstGlobalRead globalRead) {
.toString();
logger.finer("Script name is: " + scriptName + ".");

String fieldName = getStrippedDeclaredFieldName(globalRead);

processWildcardImports(globalRead, scriptName, fieldName);
}

/**
* Processes the given {@link SSAInstruction} for any potential wildcard imports being utilized
* by the instruction.
*
* @param instruction The {@link SSAInstruction} whose definition may depend on a wildcard
* import.
* @param scriptName The name of the script to check for wildcard imports.
* @param fieldName The name of the field that may be imported using a wildcard.
*/
private void processWildcardImports(
SSAInstruction instruction, String scriptName, String fieldName) {
// Are there any wildcard imports for this script?
if (scriptToWildcardImports.containsKey(scriptName)) {
logger.info("Found wildcard imports in " + scriptName + " for " + globalRead + ".");
logger.info("Found wildcard imports in " + scriptName + " for " + instruction + ".");

Deque<MethodReference> deque = scriptToWildcardImports.get(scriptName);

Expand All @@ -341,14 +392,13 @@ public void visitAstGlobalRead(AstGlobalRead globalRead) {
+ importMethodReference.getDeclaringClass().getName().getClassName()
+ ".");

String globalFieldName = getStrippedDeclaredFieldName(globalRead);
logger.fine("Examining global: " + globalFieldName + " for wildcard import.");
logger.fine("Examining global: " + fieldName + " for wildcard import.");

CallGraph callGraph = this.getBuilder().getCallGraph();
Set<CGNode> nodes = callGraph.getNodes(importMethodReference);

PointerKey globalDefPK = this.getPointerKeyForLocal(globalRead.getDef());
assert globalDefPK != null;
PointerKey defPK = this.getPointerKeyForLocal(instruction.getDef());
assert defPK != null;

for (CGNode n : nodes) {
for (Iterator<NewSiteReference> nit = n.iterateNewSites(); nit.hasNext(); ) {
Expand All @@ -357,15 +407,14 @@ public void visitAstGlobalRead(AstGlobalRead globalRead) {
String name = newSiteReference.getDeclaredType().getName().getClassName().toString();
logger.finest("Examining: " + name + ".");

if (name.equals(globalFieldName)) {
if (name.equals(fieldName)) {
logger.info("Found wildcard import for: " + name + ".");

InstanceKey instanceKey =
this.getBuilder().getInstanceKeyForAllocation(n, newSiteReference);

if (this.system.newConstraint(globalDefPK, instanceKey)) {
logger.fine(
"Added constraint that: " + globalDefPK + " gets: " + instanceKey + ".");
if (this.system.newConstraint(defPK, instanceKey)) {
logger.fine("Added constraint that: " + defPK + " gets: " + instanceKey + ".");
return;
}
}
Expand All @@ -381,24 +430,20 @@ public void visitAstGlobalRead(AstGlobalRead globalRead) {
public void visitPut(SSAPutInstruction putInstruction) {
FieldReference putField = putInstruction.getDeclaredField();

if (globalFieldName.equals(putField.getName().toString())) {
if (fieldName.equals(putField.getName().toString())) {
// Found it.
int putVal = putInstruction.getVal();

// Make the global def point to the put instruction value.
// Make the def point to the put instruction value.
PointerKey putValPK =
PythonSSAPropagationCallGraphBuilder.PythonConstraintVisitor.this
.getBuilder()
.getPointerKeyForLocal(n, putVal);

if (PythonSSAPropagationCallGraphBuilder.PythonConstraintVisitor.this
.system.newConstraint(globalDefPK, assignOperator, putValPK))
.system.newConstraint(defPK, assignOperator, putValPK))
logger.fine(
"Added constraint that: "
+ globalDefPK
+ " gets: "
+ putValPK
+ ".");
"Added constraint that: " + defPK + " gets: " + putValPK + ".");
}
}
});
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1001,6 +1001,23 @@ protected void doPrimitive(int resultVal, WalkContext context, CAstNode primitiv
((AstInstructionFactory) insts)
.PropertyRead(
idx, resultVal, resultVal, context.currentScope().getConstantValue(eltName)));

// if the module is the special initialization module.
if (context.getName().endsWith("/" + MODULE_INITIALIZATION_FILENAME)) {
// add the imported name to the module so that other files can use it.
FieldReference eltField =
FieldReference.findOrCreate(
PythonTypes.Root, Atom.findOrCreateUnicodeAtom(eltName), PythonTypes.Root);

LOGGER.info("Adding write of field: " + eltField + " to initialization script.");

// The script should be in v1.
idx = context.cfg().getCurrentInstruction();
context
.cfg()
.addInstruction(
((AstInstructionFactory) insts).PutInstruction(idx, 1, resultVal, eltField));
}
}
}

Expand Down

0 comments on commit 8f850b8

Please sign in to comment.