Skip to content
This repository has been archived by the owner on Nov 9, 2019. It is now read-only.

Commit

Permalink
ImmutableComputableGraph code improvement + unit tests.
Browse files Browse the repository at this point in the history
- reviewed and restricted the access modifiers of all ICG-related classes
- got rid of the functionality to hold on to old caches: whenever a cache goes out of date, the reference is immediately made null in the new ICG
- completely rewrote ComputableGraphStructure in a functional style
- got rid of unused and unnecessary methods
- Created an ImmutableComputableGraphUtils and factored out the builder and other common methods
- math equality asserts for NDArray
- unit tests for ComputableGraphStructure
- unit tests for ImmutableComputableGraph
- unit tests for ImmutableComputableGraphUtils
  • Loading branch information
mbabadi committed May 16, 2017
1 parent 4da7b03 commit 65b124d
Show file tree
Hide file tree
Showing 16 changed files with 2,271 additions and 594 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,10 @@ public enum ModelInitializationStrategy {
public static final String NUM_LATENTS_SHORT_NAME = "NL";
public static final String NUM_LATENTS_LONG_NAME = "numLatents";

public static final int DEFAULT_NUMBER_OF_TARGET_SPACE_PARTITIONS = 1;
public static final String NUMBER_OF_TARGET_SPACE_PARTITIONS_SHORT_NAME = "NTSP";
public static final String NUMBER_OF_TARGET_SPACE_PARTITIONS_LONG_NAME = "numTargetSpacePartitions";

public static final int DEFAULT_MIN_LEARNING_READ_COUNT = 5;
public static final String MIN_LEARNING_READ_COUNT_SHORT_NAME = "MLRC";
public static final String MIN_LEARNING_READ_COUNT_LONG_NAME = "minimumLearningReadCount";
Expand Down Expand Up @@ -200,7 +204,6 @@ public enum ModelInitializationStrategy {
public static final String TARGET_SPECIFIC_VARIANCE_SOLVER_NUM_THREADS_SHORT_NAME = "TSVSNT";
public static final String TARGET_SPECIFIC_VARIANCE_SOLVER_NUM_THREADS_LONG_NAME = "targetSpecificVarianceSolverNumThreads";


/* bias covariates related */

public static final BiasCovariateSolverStrategy DEFAULT_BIAS_COVARIATES_SOLVER_TYPE = BiasCovariateSolverStrategy.SPARK;
Expand Down Expand Up @@ -276,6 +279,8 @@ public enum ModelInitializationStrategy {
public static final String FOURIER_REGULARIZATION_STRENGTH_LONG_NAME = "fourierRegularizationStrength";


/* copy ratio calling related */

public static final boolean DEFAULT_CR_UPDATE_ENABLED = true;
public static final String CR_UPDATE_ENABLED_SHORT_NAME = "CRU";
public static final String CR_UPDATE_ENABLED_LONG_NAME = "copyRatioUpdate";
Expand All @@ -284,10 +289,7 @@ public enum ModelInitializationStrategy {
public static final String CR_HMM_TYPE_SHORT_NAME = "CRHMM";
public static final String CR_HMM_TYPE_LONG_NAME = "copyRatioHMMType";


public static final int DEFAULT_NUMBER_OF_TARGET_SPACE_PARTITIONS = 1;
public static final String NUMBER_OF_TARGET_SPACE_PARTITIONS_SHORT_NAME = "NTSP";
public static final String NUMBER_OF_TARGET_SPACE_PARTITIONS_LONG_NAME = "numTargetSpacePartitions";
/* checkpointing related */

public static final int DEFAULT_RDD_CHECKPOINTING_INTERVAL = 10;
public static final String RDD_CHECKPOINTING_INTERVAL_SHORT_NAME = "RDDCPI";
Expand All @@ -301,7 +303,6 @@ public enum ModelInitializationStrategy {
public static final String RDD_CHECKPOINTING_PATH_SHORT_NAME = "RDDCPP";
public static final String RDD_CHECKPOINTING_PATH_LONG_NAME = "rddCheckpointingPath";


public static final int DEFAULT_RUN_CHECKPOINTING_INTERVAL = 1;
public static final String RUN_CHECKPOINTING_INTERVAL_SHORT_NAME = "RCPI";
public static final String RUN_CHECKPOINTING_INTERVAL_LONG_NAME = "runCheckpointingInterval";
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,20 @@
*
* @author Mehrtash Babadi <[email protected]>
*/
public abstract class CacheNode {

abstract class CacheNode {
/**
* A string identifier for the cache node
*/
private final String key;

/**
* The collection of string identifiers of the immediate parents of this node (can be empty)
*/
private final Collection<String> parents;

/**
* The collection of string identifiers of the tags associated to this node (can be empty)
*/
private final Collection<String> tags;

/**
Expand All @@ -28,65 +36,104 @@ public abstract class CacheNode {
* @param tags the tags associated to this cache node
* @param parents immediate parents of this cache node
*/
public CacheNode(@Nonnull final String key,
@Nonnull final Collection<String> tags,
@Nonnull final Collection<String> parents) {
CacheNode(@Nonnull final String key,
@Nonnull final Collection<String> tags,
@Nonnull final Collection<String> parents) {
this.key = Utils.nonNull(key, "The key of a cache node can not be null");
this.tags = Collections.unmodifiableCollection(Utils.nonNull(tags, "The tag collection of a cache node can not be null"));
this.parents = Collections.unmodifiableCollection(Utils.nonNull(parents, "The immediate parents of a cache node can not be null"));
}

public abstract Duplicable get(@Nonnull final Map<String, Duplicable> dict);
/**
* Get the value stored in the node
*
* @param parents parent values (as a map from their string identifiers to their values)
* @return a {@link Duplicable}; possibly by reference
*/
abstract Duplicable get(@Nonnull final Map<String, Duplicable> parents);

public abstract boolean isPrimitive();
/**
* Set the value of the node
*
* @param newValue new value; possibly stored by reference
* @throws UnsupportedOperationException if the node is automatically computable
*/
abstract void set(@Nullable final Duplicable newValue) throws UnsupportedOperationException;

public abstract boolean isStoredValueAvailable();
/**
* Is the node primitive?
*/
abstract boolean isPrimitive();

public abstract void set(@Nullable final Duplicable val);
/**
* Is the node initialized yet?
*/
abstract boolean hasValue();

public abstract boolean isExternallyComputable();
/**
* Is the node externally computed?
*/
abstract boolean isExternallyComputed();

public CacheNode duplicateWithUpdatedValue(final Duplicable newValue)
throws UnsupportedOperationException {
throw new UnsupportedOperationException();
}
/**
* Duplicate the node with updated value
*
* @param newValue new value; possibly stored by reference
* @return a new {@link CacheNode} with the same key, parents, and tags but with a new value
* @throws UnsupportedOperationException if the node is automatically computable
*/
abstract CacheNode duplicateWithUpdatedValue(final Duplicable newValue) throws UnsupportedOperationException;

public CacheNode duplicate()
throws UnsupportedOperationException {
throw new UnsupportedOperationException();
}
/**
* Make a deep copy of the node
*
* @return a deeply copied instance of {@link CacheNode}
*/
abstract CacheNode duplicate();

public String getKey() {
/**
* Get the string identifier of the node
* @return a non-null {@link String}
*/
final String getKey() {
return key;
}

public Collection<String> getParents() {
return parents;
/**
* Get the collection of string identifier of the parents of this node (can be empty)
*/
final Collection<String> getParents() {
return Collections.unmodifiableCollection(parents);
}

public Collection<String> getTags() {
return tags;
/**
* Get the collection of string identifier of the tags associated to this node (can be empty)
*/
final Collection<String> getTags() {
return Collections.unmodifiableCollection(tags);
}

@Override
public String toString() {
public final String toString() {
return key;
}

/**
* NOTE: equality comparison is done just based on the key
* @param other another object
*/
@Override
public boolean equals(Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;

CacheNode cacheNode = (CacheNode) o;

if (!key.equals(cacheNode.key)) return false;
if (!parents.equals(cacheNode.parents)) return false;
return tags.equals(cacheNode.tags);
public final boolean equals(Object other) {
if (this == other) return true;
if (other == null || getClass() != other.getClass()) return false;
return (key.equals(((CacheNode) other).key));
}

/**
* NOTE: hashcode is generated just based on the key
*/
@Override
public int hashCode() {
public final int hashCode() {
return key.hashCode();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
*
* @author Mehrtash Babadi &lt;[email protected]&gt;
*/
public final class ComputableCacheNode extends CacheNode {
final class ComputableCacheNode extends CacheNode {

private final boolean cacheEvals;
private final ComputableNodeFunction func;
private Duplicable cachedValue = null;
private final boolean isCaching;
private boolean isCacheCurrent;

/**
Expand All @@ -26,17 +26,17 @@ public final class ComputableCacheNode extends CacheNode {
* @param key the key of the node
* @param parents immediate parents of the node
* @param func a function from a map that (at least) contains parents data to the computed value of this node
* @param cacheEvals does it store the value or not
* @param isCaching does it store the value or not
*/
public ComputableCacheNode(@Nonnull final String key,
@Nonnull final Collection<String> tags,
@Nonnull final Collection<String> parents,
@Nullable final ComputableNodeFunction func,
final boolean cacheEvals) {
ComputableCacheNode(@Nonnull final String key,
@Nonnull final Collection<String> tags,
@Nonnull final Collection<String> parents,
@Nullable final ComputableNodeFunction func,
final boolean isCaching) {
super(key, tags, parents);
this.func = func;
this.cacheEvals = cacheEvals;
Utils.validateArg(func != null || cacheEvals, "A computable node with null evaluation function is externally" +
this.isCaching = isCaching;
Utils.validateArg(func != null || isCaching, "A computable node with null evaluation function is externally" +
" mutable and must cache its values");
isCacheCurrent = false;
}
Expand All @@ -45,54 +45,41 @@ private ComputableCacheNode(@Nonnull final String key,
@Nonnull final Collection<String> tags,
@Nonnull final Collection<String> parents,
@Nullable final ComputableNodeFunction func,
final boolean cacheEvals,
final boolean isCaching,
final Duplicable cachedValue,
final boolean isCacheCurrent) {
super(key, tags, parents);
this.func = func;
this.cacheEvals = cacheEvals;
this.isCaching = isCaching;
this.isCacheCurrent = isCacheCurrent;
this.cachedValue = cachedValue;
}

@Override
public boolean isPrimitive() { return false; }
boolean isPrimitive() { return false; }

@Override
public boolean isExternallyComputable() { return func == null; }
boolean isExternallyComputed() { return func == null; }

public boolean isCacheCurrent() { return isCacheCurrent; }

public boolean doesCacheEvaluations() { return cacheEvals; }
boolean isCaching() { return isCaching; }

/**
* Available means (1) the node caches its value, and (2) a value is already cached (though may
* not be up-to-date)
*
* @return a boolean
* @return true if the node is caching, has a non-null {@link Duplicable}, the cache is up to date, and the
* duplicable has a non-null value stored in it
*/
@Override
public boolean isStoredValueAvailable() {
return cacheEvals && cachedValue != null && !cachedValue.hasValue();
}

/**
* In addition to being available, this methods checks if the cached value is up-to-date
*
* @return a boolean
*/
public boolean isStoredValueAvailableAndCurrent() {
return isStoredValueAvailable() && isCacheCurrent();
boolean hasValue() {
return isCaching && isCacheCurrent && cachedValue != null && cachedValue.hasValue();
}

@Override
public void set(@Nullable final Duplicable val) {
if (isExternallyComputable()) {
void set(@Nullable final Duplicable val) {
if (isExternallyComputed()) {
cachedValue = val;
isCacheCurrent = true;
} else {
throw new UnsupportedOperationException("Can not explicitly set the value of a computable cache node with" +
" non-null function.");
" non-null function");
}
}

Expand All @@ -106,23 +93,23 @@ public void set(@Nullable final Duplicable val) {
* @throws ComputableNodeFunction.ParentValueNotFoundException if a required parent value is not given
*/
@Override
public Duplicable get(@Nonnull final Map<String, Duplicable> parentsValues)
Duplicable get(@Nonnull final Map<String, Duplicable> parentsValues)
throws ComputableNodeFunction.ParentValueNotFoundException, ExternallyComputableNodeValueUnavailableException {
if (isStoredValueAvailableAndCurrent()) {
if (hasValue()) {
return cachedValue;
} else if (!isExternallyComputable()) {
} else if (!isExternallyComputed()) {
return func.apply(parentsValues); /* may throw {@link ComputableNodeFunction.ParentValueNotFoundException} */
} else { /* externally computable node */
throw new ExternallyComputableNodeValueUnavailableException(getKey());
}
}

@Override
public ComputableCacheNode duplicate() {
if (isStoredValueAvailable()) {
ComputableCacheNode duplicate() {
if (hasValue()) {
return new ComputableCacheNode(getKey(), getTags(), getParents(), func, true, cachedValue.duplicate(), isCacheCurrent);
} else {
return new ComputableCacheNode(getKey(), getTags(), getParents(), func, cacheEvals, null, isCacheCurrent);
return new ComputableCacheNode(getKey(), getTags(), getParents(), func, isCaching, null, isCacheCurrent);
}
}

Expand All @@ -132,11 +119,11 @@ public ComputableCacheNode duplicate() {
* @param newValue the cache value to be replaced with the old value
* @return a new instance of {@link ComputableCacheNode}
*/
public ComputableCacheNode duplicateWithUpdatedValue(final Duplicable newValue) {
if (cacheEvals && newValue != null && !newValue.hasValue()) {
ComputableCacheNode duplicateWithUpdatedValue(final Duplicable newValue) {
if (isCaching && newValue != null && newValue.hasValue()) {
return new ComputableCacheNode(getKey(), getTags(), getParents(), func, true, newValue, true);
} else {
return new ComputableCacheNode(getKey(), getTags(), getParents(), func, cacheEvals, null, false);
return new ComputableCacheNode(getKey(), getTags(), getParents(), func, isCaching, null, false);
}
}

Expand All @@ -145,7 +132,7 @@ public ComputableCacheNode duplicateWithUpdatedValue(final Duplicable newValue)
*
* @return a function
*/
public ComputableNodeFunction getFunction() {
ComputableNodeFunction getFunction() {
return func;
}

Expand All @@ -155,8 +142,8 @@ public ComputableNodeFunction getFunction() {
*
* @return a new instance of {@link ComputableCacheNode}
*/
public ComputableCacheNode duplicateWithOutdatedCacheStatus() {
return new ComputableCacheNode(getKey(), getTags(), getParents(), func, cacheEvals, null, false);
ComputableCacheNode duplicateWithOutdatedCacheStatus() {
return new ComputableCacheNode(getKey(), getTags(), getParents(), func, isCaching, null, false);
}

/**
Expand Down
Loading

0 comments on commit 65b124d

Please sign in to comment.