diff --git a/src/main/java/org/broadinstitute/hellbender/tools/coveragemodel/CoverageModelEMComputeBlock.java b/src/main/java/org/broadinstitute/hellbender/tools/coveragemodel/CoverageModelEMComputeBlock.java index 22679b446..78dbdd0bb 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/coveragemodel/CoverageModelEMComputeBlock.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/coveragemodel/CoverageModelEMComputeBlock.java @@ -138,9 +138,11 @@ public enum CoverageModelICGCacheNode { loglike_reg("Contribution of this block of targets to the model log likelihood (w/ Fourier regularizer)"); public final String description; + public final CacheNode.NodeKey key; CoverageModelICGCacheNode(final String description) { this.description = description; + this.key = new CacheNode.NodeKey(name()); } } @@ -158,10 +160,13 @@ public enum CoverageModelICGCacheTag { M_STEP_PSI("Cache nodes to be updated for the M-step for target-specific unexplained variance"), LOGLIKE_UNREG("Cache nodes to be updated for log likelihood calculation (w/o regularization)"), LOGLIKE_REG("Cache nodes to be updated for log likelihood calculation (w/ regularization)"); + public final String description; + public final CacheNode.NodeTag tag; CoverageModelICGCacheTag(final String description) { this.description = description; + this.tag = new CacheNode.NodeTag(name()); } } @@ -190,26 +195,26 @@ private static ImmutableComputableGraph createEmptyCacheGraph(final boolean bias */ cgbuilder /* raw read counts */ - .primitiveNodeWithEmptyNDArray(CoverageModelICGCacheNode.n_st.name()) + .primitiveNodeWithEmptyNDArray(CoverageModelICGCacheNode.n_st.key) /* mask */ - .primitiveNodeWithEmptyNDArray(CoverageModelICGCacheNode.M_st.name()) + .primitiveNodeWithEmptyNDArray(CoverageModelICGCacheNode.M_st.key) /* mapping error probability */ - .primitiveNodeWithEmptyNDArray(CoverageModelICGCacheNode.err_st.name()); + .primitiveNodeWithEmptyNDArray(CoverageModelICGCacheNode.err_st.key); /* * Model parameters */ cgbuilder /* mean log bias */ - .primitiveNodeWithEmptyNDArray(CoverageModelICGCacheNode.m_t.name()) + .primitiveNodeWithEmptyNDArray(CoverageModelICGCacheNode.m_t.key) /* unexplained variance */ - .primitiveNodeWithEmptyNDArray(CoverageModelICGCacheNode.Psi_t.name()); + .primitiveNodeWithEmptyNDArray(CoverageModelICGCacheNode.Psi_t.key); /* if ARD is enabled, add a node for ARD coefficients */ if (ardEnabled) { cgbuilder /* precision of bias covariates */ - .primitiveNodeWithEmptyNDArray(CoverageModelICGCacheNode.alpha_l.name()); + .primitiveNodeWithEmptyNDArray(CoverageModelICGCacheNode.alpha_l.key); } /* @@ -217,24 +222,24 @@ private static ImmutableComputableGraph createEmptyCacheGraph(final boolean bias */ cgbuilder /* E[log(c_{st})] */ - .externallyComputableNode(CoverageModelICGCacheNode.log_c_st.name()) + .untrackedExternallyComputableNode(CoverageModelICGCacheNode.log_c_st.key) /* var[log(c_{st})] */ - .externallyComputableNode(CoverageModelICGCacheNode.var_log_c_st.name()) + .untrackedExternallyComputableNode(CoverageModelICGCacheNode.var_log_c_st.key) /* E[log(d_s)] */ - .externallyComputableNode(CoverageModelICGCacheNode.log_d_s.name()) + .untrackedExternallyComputableNode(CoverageModelICGCacheNode.log_d_s.key) /* var[log(d_s)] */ - .externallyComputableNode(CoverageModelICGCacheNode.var_log_d_s.name()) + .untrackedExternallyComputableNode(CoverageModelICGCacheNode.var_log_d_s.key) /* E[\gamma_s] */ - .externallyComputableNode(CoverageModelICGCacheNode.gamma_s.name()); + .untrackedExternallyComputableNode(CoverageModelICGCacheNode.gamma_s.key); if (biasCovariatesEnabled) { cgbuilder /* mean of bias covariates */ - .externallyComputableNode(CoverageModelICGCacheNode.W_tl.name()) + .untrackedExternallyComputableNode(CoverageModelICGCacheNode.W_tl.key) /* E[z_{sm}] */ - .externallyComputableNode(CoverageModelICGCacheNode.z_sl.name()) + .untrackedExternallyComputableNode(CoverageModelICGCacheNode.z_sl.key) /* E[z_{sm} z_{sn}] */ - .externallyComputableNode(CoverageModelICGCacheNode.zz_sll.name()); + .untrackedExternallyComputableNode(CoverageModelICGCacheNode.zz_sll.key); } /* @@ -242,178 +247,178 @@ private static ImmutableComputableGraph createEmptyCacheGraph(final boolean bias */ cgbuilder /* log read counts */ - .computableNode(CoverageModelICGCacheNode.log_n_st.name(), - new String[]{ - CoverageModelICGCacheTag.M_STEP_M.name(), - CoverageModelICGCacheTag.E_STEP_D.name()}, - new String[]{ - CoverageModelICGCacheNode.n_st.name(), - CoverageModelICGCacheNode.M_st.name()}, + .computableNode(CoverageModelICGCacheNode.log_n_st.key, + new CacheNode.NodeTag[]{ + CoverageModelICGCacheTag.M_STEP_M.tag, + CoverageModelICGCacheTag.E_STEP_D.tag}, + new CacheNode.NodeKey[] { + CoverageModelICGCacheNode.n_st.key, + CoverageModelICGCacheNode.M_st.key}, calculate_log_n_st, true) /* Poisson noise */ - .computableNode(CoverageModelICGCacheNode.Sigma_st.name(), - new String[]{}, - new String[]{ - CoverageModelICGCacheNode.n_st.name(), - CoverageModelICGCacheNode.M_st.name()}, + .computableNode(CoverageModelICGCacheNode.Sigma_st.key, + new CacheNode.NodeTag[]{}, + new CacheNode.NodeKey[] { + CoverageModelICGCacheNode.n_st.key, + CoverageModelICGCacheNode.M_st.key}, calculate_Sigma_st, true) /* \sum_s M_{st} */ - .computableNode(CoverageModelICGCacheNode.sum_M_t.name(), - new String[]{}, - new String[]{CoverageModelICGCacheNode.M_st.name()}, + .computableNode(CoverageModelICGCacheNode.sum_M_t.key, + new CacheNode.NodeTag[]{}, + new CacheNode.NodeKey[] {CoverageModelICGCacheNode.M_st.key}, calculate_sum_M_t, true) /* \sum_t M_{st} */ - .computableNode(CoverageModelICGCacheNode.sum_M_s.name(), - new String[]{ - CoverageModelICGCacheTag.LOGLIKE_UNREG.name(), - CoverageModelICGCacheTag.LOGLIKE_REG.name()}, - new String[]{CoverageModelICGCacheNode.M_st.name()}, + .computableNode(CoverageModelICGCacheNode.sum_M_s.key, + new CacheNode.NodeTag[]{ + CoverageModelICGCacheTag.LOGLIKE_UNREG.tag, + CoverageModelICGCacheTag.LOGLIKE_REG.tag}, + new CacheNode.NodeKey[] {CoverageModelICGCacheNode.M_st.key}, calculate_sum_M_s, true) /* \Psi_{st} = \Psi_t + \Sigma_{st} + E[\gamma_s] */ - .computableNode(CoverageModelICGCacheNode.tot_Psi_st.name(), - new String[]{}, - new String[]{ - CoverageModelICGCacheNode.Sigma_st.name(), - CoverageModelICGCacheNode.Psi_t.name(), - CoverageModelICGCacheNode.gamma_s.name()}, + .computableNode(CoverageModelICGCacheNode.tot_Psi_st.key, + new CacheNode.NodeTag[] {}, + new CacheNode.NodeKey[] { + CoverageModelICGCacheNode.Sigma_st.key, + CoverageModelICGCacheNode.Psi_t.key, + CoverageModelICGCacheNode.gamma_s.key}, calculate_tot_Psi_st, true) /* log(n_{st}) - E[log(c_{st})] - E[log(d_s)] - m_t */ - .computableNode(CoverageModelICGCacheNode.Delta_st.name(), - new String[]{CoverageModelICGCacheTag.E_STEP_Z.name()}, - new String[]{ - CoverageModelICGCacheNode.log_n_st.name(), - CoverageModelICGCacheNode.log_c_st.name(), - CoverageModelICGCacheNode.log_d_s.name(), - CoverageModelICGCacheNode.m_t.name()}, + .computableNode(CoverageModelICGCacheNode.Delta_st.key, + new CacheNode.NodeTag[] {CoverageModelICGCacheTag.E_STEP_Z.tag}, + new CacheNode.NodeKey[] { + CoverageModelICGCacheNode.log_n_st.key, + CoverageModelICGCacheNode.log_c_st.key, + CoverageModelICGCacheNode.log_d_s.key, + CoverageModelICGCacheNode.m_t.key}, calculate_Delta_st, true) /* \sum_{t} M_{st} \log(\Psi_{st}) */ - .computableNode(CoverageModelICGCacheNode.loglike_normalization_s.name(), - new String[]{ - CoverageModelICGCacheTag.LOGLIKE_REG.name(), - CoverageModelICGCacheTag.LOGLIKE_UNREG.name()}, - new String[]{ - CoverageModelICGCacheNode.M_st.name(), - CoverageModelICGCacheNode.tot_Psi_st.name(), - CoverageModelICGCacheNode.log_n_st.name()}, + .computableNode(CoverageModelICGCacheNode.loglike_normalization_s.key, + new CacheNode.NodeTag[] { + CoverageModelICGCacheTag.LOGLIKE_REG.tag, + CoverageModelICGCacheTag.LOGLIKE_UNREG.tag}, + new CacheNode.NodeKey[] { + CoverageModelICGCacheNode.M_st.key, + CoverageModelICGCacheNode.tot_Psi_st.key, + CoverageModelICGCacheNode.log_n_st.key}, calculate_loglike_normalization_s, true) /* M_{st} \Psi_{st}^{-1} */ - .computableNode(CoverageModelICGCacheNode.M_Psi_inv_st.name(), - new String[]{ - CoverageModelICGCacheTag.E_STEP_W_UNREG.name(), - CoverageModelICGCacheTag.E_STEP_W_REG.name(), - CoverageModelICGCacheTag.M_STEP_M.name(), - CoverageModelICGCacheTag.E_STEP_Z.name(), - CoverageModelICGCacheTag.E_STEP_D.name(), - CoverageModelICGCacheTag.E_STEP_C.name()}, - new String[]{ - CoverageModelICGCacheNode.M_st.name(), - CoverageModelICGCacheNode.tot_Psi_st.name()}, + .computableNode(CoverageModelICGCacheNode.M_Psi_inv_st.key, + new CacheNode.NodeTag[] { + CoverageModelICGCacheTag.E_STEP_W_UNREG.tag, + CoverageModelICGCacheTag.E_STEP_W_REG.tag, + CoverageModelICGCacheTag.M_STEP_M.tag, + CoverageModelICGCacheTag.E_STEP_Z.tag, + CoverageModelICGCacheTag.E_STEP_D.tag, + CoverageModelICGCacheTag.E_STEP_C.tag}, + new CacheNode.NodeKey[] { + CoverageModelICGCacheNode.M_st.key, + CoverageModelICGCacheNode.tot_Psi_st.key}, calculate_M_Psi_inv_st, true) /* log likelihood (w/o regularization) */ - .computableNode(CoverageModelICGCacheNode.loglike_unreg.name(), - new String[]{CoverageModelICGCacheTag.LOGLIKE_UNREG.name()}, - new String[]{ - CoverageModelICGCacheNode.B_st.name(), - CoverageModelICGCacheNode.M_Psi_inv_st.name(), - CoverageModelICGCacheNode.loglike_normalization_s.name()}, + .computableNode(CoverageModelICGCacheNode.loglike_unreg.key, + new CacheNode.NodeTag[] {CoverageModelICGCacheTag.LOGLIKE_UNREG.tag}, + new CacheNode.NodeKey[] { + CoverageModelICGCacheNode.B_st.key, + CoverageModelICGCacheNode.M_Psi_inv_st.key, + CoverageModelICGCacheNode.loglike_normalization_s.key}, calculate_loglike_unreg, true); /* nodes specific to bias covariates */ if (biasCovariatesEnabled) { cgbuilder /* E[W] E[z_s] */ - .computableNode(CoverageModelICGCacheNode.Wz_st.name(), - new String[]{CoverageModelICGCacheTag.M_STEP_M.name(), - CoverageModelICGCacheTag.E_STEP_D.name(), - CoverageModelICGCacheTag.E_STEP_C.name()}, - new String[]{ - CoverageModelICGCacheNode.W_tl.name(), - CoverageModelICGCacheNode.z_sl.name()}, + .computableNode(CoverageModelICGCacheNode.Wz_st.key, + new CacheNode.NodeTag[] {CoverageModelICGCacheTag.M_STEP_M.tag, + CoverageModelICGCacheTag.E_STEP_D.tag, + CoverageModelICGCacheTag.E_STEP_C.tag}, + new CacheNode.NodeKey[] { + CoverageModelICGCacheNode.W_tl.key, + CoverageModelICGCacheNode.z_sl.key}, calculate_Wz_st, true) /* (E[z_s z_s^T] E[W W^T])_{tt} */ - .computableNode(CoverageModelICGCacheNode.WzzWT_st.name(), - new String[]{}, - new String[]{ - CoverageModelICGCacheNode.W_tl.name(), - CoverageModelICGCacheNode.zz_sll.name()}, + .computableNode(CoverageModelICGCacheNode.WzzWT_st.key, + new CacheNode.NodeTag[] {}, + new CacheNode.NodeKey[] { + CoverageModelICGCacheNode.W_tl.key, + CoverageModelICGCacheNode.zz_sll.key}, calculate_WzzWT_st, true) /* v_{t\mu} */ - .computableNode(CoverageModelICGCacheNode.v_tl.name(), - new String[]{ - CoverageModelICGCacheTag.E_STEP_W_REG.name(), - CoverageModelICGCacheTag.E_STEP_W_UNREG.name()}, - new String[]{ - CoverageModelICGCacheNode.M_Psi_inv_st.name(), - CoverageModelICGCacheNode.Delta_st.name(), - CoverageModelICGCacheNode.z_sl.name()}, + .computableNode(CoverageModelICGCacheNode.v_tl.key, + new CacheNode.NodeTag[] { + CoverageModelICGCacheTag.E_STEP_W_REG.tag, + CoverageModelICGCacheTag.E_STEP_W_UNREG.tag}, + new CacheNode.NodeKey[] { + CoverageModelICGCacheNode.M_Psi_inv_st.key, + CoverageModelICGCacheNode.Delta_st.key, + CoverageModelICGCacheNode.z_sl.key}, calculate_v_tl, false) /* Q_{t\mu\nu} */ - .computableNode(CoverageModelICGCacheNode.Q_tll.name(), - new String[]{ - CoverageModelICGCacheTag.E_STEP_W_REG.name(), - CoverageModelICGCacheTag.E_STEP_W_UNREG.name()}, - new String[]{ - CoverageModelICGCacheNode.M_Psi_inv_st.name(), - CoverageModelICGCacheNode.zz_sll.name()}, + .computableNode(CoverageModelICGCacheNode.Q_tll.key, + new CacheNode.NodeTag[] { + CoverageModelICGCacheTag.E_STEP_W_REG.tag, + CoverageModelICGCacheTag.E_STEP_W_UNREG.tag}, + new CacheNode.NodeKey[] { + CoverageModelICGCacheNode.M_Psi_inv_st.key, + CoverageModelICGCacheNode.zz_sll.key}, calculate_Q_tll, false) /* B_{st} */ - .computableNode(CoverageModelICGCacheNode.B_st.name(), - new String[]{ - CoverageModelICGCacheTag.M_STEP_PSI.name(), - CoverageModelICGCacheTag.E_STEP_GAMMA.name()}, - new String[]{ - CoverageModelICGCacheNode.Delta_st.name(), - CoverageModelICGCacheNode.var_log_c_st.name(), - CoverageModelICGCacheNode.var_log_d_s.name(), - CoverageModelICGCacheNode.WzzWT_st.name(), - CoverageModelICGCacheNode.Wz_st.name()}, + .computableNode(CoverageModelICGCacheNode.B_st.key, + new CacheNode.NodeTag[] { + CoverageModelICGCacheTag.M_STEP_PSI.tag, + CoverageModelICGCacheTag.E_STEP_GAMMA.tag}, + new CacheNode.NodeKey[] { + CoverageModelICGCacheNode.Delta_st.key, + CoverageModelICGCacheNode.var_log_c_st.key, + CoverageModelICGCacheNode.var_log_d_s.key, + CoverageModelICGCacheNode.WzzWT_st.key, + CoverageModelICGCacheNode.Wz_st.key}, calculate_B_st_with_bias_covariates, true) /* M_{st} . (log(n_{st}) - E[log(c_{st})] - E[log(d_s)]) - mean --- externally computed */ - .computableNode(CoverageModelICGCacheNode.Delta_PCA_st.name(), - new String[]{}, - new String[]{ - CoverageModelICGCacheNode.log_n_st.name(), - CoverageModelICGCacheNode.log_c_st.name(), - CoverageModelICGCacheNode.log_d_s.name()}, + .computableNode(CoverageModelICGCacheNode.Delta_PCA_st.key, + new CacheNode.NodeTag[] {}, + new CacheNode.NodeKey[] { + CoverageModelICGCacheNode.log_n_st.key, + CoverageModelICGCacheNode.log_c_st.key, + CoverageModelICGCacheNode.log_d_s.key}, null, true); } else { /* no bias covariates */ cgbuilder /* B_{st} */ - .computableNode(CoverageModelICGCacheNode.B_st.name(), - new String[]{ - CoverageModelICGCacheTag.M_STEP_PSI.name(), - CoverageModelICGCacheTag.E_STEP_GAMMA.name()}, - new String[]{ - CoverageModelICGCacheNode.Delta_st.name(), - CoverageModelICGCacheNode.var_log_c_st.name(), - CoverageModelICGCacheNode.var_log_d_s.name()}, + .computableNode(CoverageModelICGCacheNode.B_st.key, + new CacheNode.NodeTag[] { + CoverageModelICGCacheTag.M_STEP_PSI.tag, + CoverageModelICGCacheTag.E_STEP_GAMMA.tag}, + new CacheNode.NodeKey[] { + CoverageModelICGCacheNode.Delta_st.key, + CoverageModelICGCacheNode.var_log_c_st.key, + CoverageModelICGCacheNode.var_log_d_s.key}, calculate_B_st_without_bias_covariates, true); } // TODO github/gatk-protected issue #701 -- this class is part of the upcoming CNV-avoiding regularizer // // /* FFT[W] */ - // .computableNode(CoverageModelICGCacheNode.F_W_tl.name(), + // .computableNode(CoverageModelICGCacheNode.F_W_tl.key, // new String[]{}, - // new String[]{CoverageModelICGCacheNode.W_tl.name()}, + // new String[]{CoverageModelICGCacheNode.W_tl.key}, // null, true); // // /* log likelihood (w/ regularization) */ - // .computableNode(CoverageModelICGCacheNode.loglike_reg.name(), - // new String[]{CoverageModelICGCacheTag.LOGLIKE_REG.name()}, + // .computableNode(CoverageModelICGCacheNode.loglike_reg.key, + // new String[]{CoverageModelICGCacheTag.LOGLIKE_REG.key}, // new String[]{ - // CoverageModelICGCacheNode.B_st.name(), - // CoverageModelICGCacheNode.M_Psi_inv_st.name(), - // CoverageModelICGCacheNode.loglike_normalization_s.name(), - // CoverageModelICGCacheNode.W_tl.name(), - // CoverageModelICGCacheNode.F_W_tl.name(), - // CoverageModelICGCacheNode.zz_sll.name()}, + // CoverageModelICGCacheNode.B_st.key, + // CoverageModelICGCacheNode.M_Psi_inv_st.key, + // CoverageModelICGCacheNode.loglike_normalization_s.key, + // CoverageModelICGCacheNode.W_tl.key, + // CoverageModelICGCacheNode.F_W_tl.key, + // CoverageModelICGCacheNode.zz_sll.key}, // calculate_loglike_reg, true); // /* \sum_t Q_{t\mu\nu} */ - // .computableNode(CoverageModelICGCacheNode.sum_Q_ll.name(), - // new String[]{CoverageModelICGCacheTag.E_STEP_W_REG.name()}, - // new String[]{CoverageModelICGCacheNode.Q_tll.name()}, + // .computableNode(CoverageModelICGCacheNode.sum_Q_ll.key, + // new String[]{CoverageModelICGCacheTag.E_STEP_W_REG.key}, + // new String[]{CoverageModelICGCacheNode.Q_tll.key}, // calculate_sum_Q_ll, true); return cgbuilder.build(); @@ -475,7 +480,7 @@ public LinearlySpacedIndexBlock getTargetSpaceBlock() { */ @QueriesICG public INDArray getINDArrayFromCache(final CoverageModelICGCacheNode key) { - return ((DuplicableNDArray)icg.fetchWithRequiredEvaluations(key.name())).value(); + return ((DuplicableNDArray)icg.fetchWithRequiredEvaluations(key.key)).value(); } /** @@ -488,7 +493,7 @@ public INDArray getINDArrayFromCache(final CoverageModelICGCacheNode key) { */ @QueriesICG public double getDoubleFromCache(final CoverageModelICGCacheNode key) { - return ((DuplicableNumber) icg.fetchWithRequiredEvaluations(key.name())).value().doubleValue(); + return ((DuplicableNumber) icg.fetchWithRequiredEvaluations(key.key)).value().doubleValue(); } private void assertBiasCovariatesEnabled() { @@ -650,9 +655,9 @@ public List> getSampleCopyRatioLatentPo final INDArray mu_st; if (biasCovariatesEnabled) { final INDArray Wz_st = getINDArrayFromCache(CoverageModelICGCacheNode.Wz_st); - mu_st = Wz_st.addRowVector(m_t); + mu_st = Wz_st.addRowVector(m_t); } else { - mu_st = Nd4j.vstack(Collections.nCopies(numSamples, m_t)); + mu_st = Nd4j.vstack(Collections.nCopies(numSamples, m_t)); } final double[] psiArray = Psi_t.dup().data().asDouble(); @@ -1163,7 +1168,7 @@ public CoverageModelEMComputeBlock cloneWithRemovedPCAInitializationData() { public CoverageModelEMComputeBlock cloneWithUpdatedPrimitive(@Nonnull final CoverageModelICGCacheNode key, @Nullable final INDArray value) { return new CoverageModelEMComputeBlock(targetBlock, numSamples, numLatents, ardEnabled, - icg.setValue(key.name(), new DuplicableNDArray(value)), latestMStepSignal); + icg.setValue(key.key, new DuplicableNDArray(value)), latestMStepSignal); } /** @@ -1180,10 +1185,10 @@ public CoverageModelEMComputeBlock cloneWithUpdatedPrimitiveAndSignal(@Nonnull f @Nonnull final SubroutineSignal latestMStepSignal) { if (value == null) { return new CoverageModelEMComputeBlock(targetBlock, numSamples, numLatents, ardEnabled, - icg.setValue(key.name(), new DuplicableNDArray()), latestMStepSignal); + icg.setValue(key.key, new DuplicableNDArray()), latestMStepSignal); } else { return new CoverageModelEMComputeBlock(targetBlock, numSamples, numLatents, ardEnabled, - icg.setValue(key.name(), new DuplicableNDArray(value)), latestMStepSignal); + icg.setValue(key.key, new DuplicableNDArray(value)), latestMStepSignal); } } @@ -1206,7 +1211,7 @@ public CoverageModelEMComputeBlock cloneWithUpdatedSignal(@Nonnull final Subrout */ public CoverageModelEMComputeBlock cloneWithUpdatedCachesByTag(final CoverageModelICGCacheTag tag) { return new CoverageModelEMComputeBlock(targetBlock, numSamples, numLatents, ardEnabled, - icg.updateCachesForTag(tag.name()), latestMStepSignal); + icg.updateCachesForTag(tag.tag), latestMStepSignal); } /** @@ -1231,25 +1236,25 @@ public void performGarbageCollection() { /* dependents: [M_st] */ private static final ComputableNodeFunction calculate_sum_M_t = new ComputableNodeFunction() { @Override - public Duplicable apply(final Map parents) { - return new DuplicableNDArray(fetchINDArray(CoverageModelICGCacheNode.M_st.name(), parents).sum(0)); + public Duplicable apply(final Map parents) { + return new DuplicableNDArray(fetchINDArray(CoverageModelICGCacheNode.M_st.key, parents).sum(0)); } }; /* dependents: [M_st] */ private static final ComputableNodeFunction calculate_sum_M_s = new ComputableNodeFunction() { @Override - public Duplicable apply(final Map parents) { - return new DuplicableNDArray(fetchINDArray(CoverageModelICGCacheNode.M_st.name(), parents).sum(1)); + public Duplicable apply(final Map parents) { + return new DuplicableNDArray(fetchINDArray(CoverageModelICGCacheNode.M_st.key, parents).sum(1)); } }; /* dependents: [n_st, M_st] */ private static final ComputableNodeFunction calculate_Sigma_st = new ComputableNodeFunction() { @Override - public Duplicable apply(final Map parents) { - final INDArray n_st = fetchINDArray(CoverageModelICGCacheNode.n_st.name(), parents); - final INDArray M_st = fetchINDArray(CoverageModelICGCacheNode.M_st.name(), parents); + public Duplicable apply(final Map parents) { + final INDArray n_st = fetchINDArray(CoverageModelICGCacheNode.n_st.key, parents); + final INDArray M_st = fetchINDArray(CoverageModelICGCacheNode.M_st.key, parents); final INDArray Sigma_st = replaceMaskedEntries( Nd4j.ones(n_st.shape()).divi(n_st), M_st, @@ -1261,9 +1266,9 @@ public Duplicable apply(final Map parents) { /* dependents: [n_st, M_st] */ private static final ComputableNodeFunction calculate_log_n_st = new ComputableNodeFunction() { @Override - public Duplicable apply(final Map parents) { - final INDArray n_st = fetchINDArray(CoverageModelICGCacheNode.n_st.name(), parents); - final INDArray M_st = fetchINDArray(CoverageModelICGCacheNode.M_st.name(), parents); + public Duplicable apply(final Map parents) { + final INDArray n_st = fetchINDArray(CoverageModelICGCacheNode.n_st.key, parents); + final INDArray M_st = fetchINDArray(CoverageModelICGCacheNode.M_st.key, parents); final INDArray log_n_st = replaceMaskedEntries( Transforms.log(n_st, true), M_st, @@ -1274,11 +1279,11 @@ public Duplicable apply(final Map parents) { private static final ComputableNodeFunction calculate_Delta_st = new ComputableNodeFunction() { @Override - public Duplicable apply(final Map parents) { - final INDArray log_n_st = fetchINDArray(CoverageModelICGCacheNode.log_n_st.name(), parents); - final INDArray log_c_st = fetchINDArray(CoverageModelICGCacheNode.log_c_st.name(), parents); - final INDArray log_d_s = fetchINDArray(CoverageModelICGCacheNode.log_d_s.name(), parents); - final INDArray m_t = fetchINDArray(CoverageModelICGCacheNode.m_t.name(), parents); + public Duplicable apply(final Map parents) { + final INDArray log_n_st = fetchINDArray(CoverageModelICGCacheNode.log_n_st.key, parents); + final INDArray log_c_st = fetchINDArray(CoverageModelICGCacheNode.log_c_st.key, parents); + final INDArray log_d_s = fetchINDArray(CoverageModelICGCacheNode.log_d_s.key, parents); + final INDArray m_t = fetchINDArray(CoverageModelICGCacheNode.m_t.key, parents); return new DuplicableNDArray(log_n_st.sub(log_c_st).subiColumnVector(log_d_s).subiRowVector(m_t)); } }; @@ -1286,9 +1291,9 @@ public Duplicable apply(final Map parents) { /* dependents: [W_tl, z_sl] */ private static final ComputableNodeFunction calculate_Wz_st = new ComputableNodeFunction() { @Override - public Duplicable apply(final Map parents) { - final INDArray W_tl = fetchINDArray(CoverageModelICGCacheNode.W_tl.name(), parents); - final INDArray z_sl = fetchINDArray(CoverageModelICGCacheNode.z_sl.name(), parents); + public Duplicable apply(final Map parents) { + final INDArray W_tl = fetchINDArray(CoverageModelICGCacheNode.W_tl.key, parents); + final INDArray z_sl = fetchINDArray(CoverageModelICGCacheNode.z_sl.key, parents); return new DuplicableNDArray(W_tl.mmul(z_sl.transpose()).transpose()); } }; @@ -1296,9 +1301,9 @@ public Duplicable apply(final Map parents) { /* dependents: [W_tl, zz_sll] */ private static final ComputableNodeFunction calculate_WzzWT_st = new ComputableNodeFunction() { @Override - public Duplicable apply(final Map parents) { - final INDArray W_tl = fetchINDArray(CoverageModelICGCacheNode.W_tl.name(), parents); - final INDArray zz_sll = fetchINDArray(CoverageModelICGCacheNode.zz_sll.name(), parents); + public Duplicable apply(final Map parents) { + final INDArray W_tl = fetchINDArray(CoverageModelICGCacheNode.W_tl.key, parents); + final INDArray zz_sll = fetchINDArray(CoverageModelICGCacheNode.zz_sll.key, parents); final int numSamples = zz_sll.shape()[0]; final int numTargets = W_tl.shape()[0]; final INDArray WzzWT_st = Nd4j.create(numSamples, numTargets); @@ -1315,10 +1320,10 @@ public Duplicable apply(final Map parents) { /* dependents: ["Sigma_st", "Psi_t"] */ private static final ComputableNodeFunction calculate_tot_Psi_st = new ComputableNodeFunction() { @Override - public Duplicable apply(final Map parents) { - final INDArray Sigma_st = fetchINDArray(CoverageModelICGCacheNode.Sigma_st.name(), parents); - final INDArray Psi_t = fetchINDArray(CoverageModelICGCacheNode.Psi_t.name(), parents); - final INDArray gamma_s = fetchINDArray(CoverageModelICGCacheNode.gamma_s.name(), parents); + public Duplicable apply(final Map parents) { + final INDArray Sigma_st = fetchINDArray(CoverageModelICGCacheNode.Sigma_st.key, parents); + final INDArray Psi_t = fetchINDArray(CoverageModelICGCacheNode.Psi_t.key, parents); + final INDArray gamma_s = fetchINDArray(CoverageModelICGCacheNode.gamma_s.key, parents); return new DuplicableNDArray(Sigma_st.addRowVector(Psi_t).addiColumnVector(gamma_s)); } }; @@ -1326,10 +1331,10 @@ public Duplicable apply(final Map parents) { /* dependents: ["M_st", "tot_Psi_st", "log_n_st"] */ private static final ComputableNodeFunction calculate_loglike_normalization_s = new ComputableNodeFunction() { @Override - public Duplicable apply(final Map parents) { - final INDArray M_st = fetchINDArray(CoverageModelICGCacheNode.M_st.name(), parents); - final INDArray tot_Psi_st = fetchINDArray(CoverageModelICGCacheNode.tot_Psi_st.name(), parents); - final INDArray log_n_st = fetchINDArray(CoverageModelICGCacheNode.log_n_st.name(), parents); + public Duplicable apply(final Map parents) { + final INDArray M_st = fetchINDArray(CoverageModelICGCacheNode.M_st.key, parents); + final INDArray tot_Psi_st = fetchINDArray(CoverageModelICGCacheNode.tot_Psi_st.key, parents); + final INDArray log_n_st = fetchINDArray(CoverageModelICGCacheNode.log_n_st.key, parents); final INDArray loglike_normalization_s = Transforms.log(tot_Psi_st, true) .addi(FastMath.log(2 * FastMath.PI)) .muli(-0.5) @@ -1343,19 +1348,19 @@ public Duplicable apply(final Map parents) { /* dependents: ["M_st", "tot_Psi_st"] */ private static final ComputableNodeFunction calculate_M_Psi_inv_st = new ComputableNodeFunction() { @Override - public Duplicable apply(final Map parents) { - return new DuplicableNDArray(fetchINDArray(CoverageModelICGCacheNode.M_st.name(), parents).div( - fetchINDArray(CoverageModelICGCacheNode.tot_Psi_st.name(), parents))); + public Duplicable apply(final Map parents) { + return new DuplicableNDArray(fetchINDArray(CoverageModelICGCacheNode.M_st.key, parents).div( + fetchINDArray(CoverageModelICGCacheNode.tot_Psi_st.key, parents))); } }; /* dependents: ["M_Psi_inv_st", "Delta_st", "z_sl"] */ private static final ComputableNodeFunction calculate_v_tl = new ComputableNodeFunction() { @Override - public Duplicable apply(final Map parents) { - final INDArray M_Psi_inv_st = fetchINDArray(CoverageModelICGCacheNode.M_Psi_inv_st.name(), parents); - final INDArray Delta_st = fetchINDArray(CoverageModelICGCacheNode.Delta_st.name(), parents); - final INDArray z_sl = fetchINDArray(CoverageModelICGCacheNode.z_sl.name(), parents); + public Duplicable apply(final Map parents) { + final INDArray M_Psi_inv_st = fetchINDArray(CoverageModelICGCacheNode.M_Psi_inv_st.key, parents); + final INDArray Delta_st = fetchINDArray(CoverageModelICGCacheNode.Delta_st.key, parents); + final INDArray z_sl = fetchINDArray(CoverageModelICGCacheNode.z_sl.key, parents); return new DuplicableNDArray(M_Psi_inv_st.mul(Delta_st).transpose().mmul(z_sl)); } }; @@ -1363,9 +1368,9 @@ public Duplicable apply(final Map parents) { /* dependents: ["M_Psi_inv_st", "zz_sll"] */ private static final ComputableNodeFunction calculate_Q_tll = new ComputableNodeFunction() { @Override - public Duplicable apply(final Map parents) { - final INDArray M_Psi_inv_st_trans = fetchINDArray(CoverageModelICGCacheNode.M_Psi_inv_st.name(), parents).transpose(); - final INDArray zz_sll = fetchINDArray(CoverageModelICGCacheNode.zz_sll.name(), parents); + public Duplicable apply(final Map parents) { + final INDArray M_Psi_inv_st_trans = fetchINDArray(CoverageModelICGCacheNode.M_Psi_inv_st.key, parents).transpose(); + final INDArray zz_sll = fetchINDArray(CoverageModelICGCacheNode.zz_sll.key, parents); final int numTargets = M_Psi_inv_st_trans.shape()[0]; final int numLatents = zz_sll.shape()[1]; final INDArray res = Nd4j.create(numTargets, numLatents, numLatents); @@ -1380,20 +1385,20 @@ public Duplicable apply(final Map parents) { /* dependents: ["Q_tll"] */ private static final ComputableNodeFunction calculate_sum_Q_ll = new ComputableNodeFunction() { @Override - public Duplicable apply(final Map parents) { - return new DuplicableNDArray(fetchINDArray(CoverageModelICGCacheNode.Q_tll.name(), parents).sum(0)); + public Duplicable apply(final Map parents) { + return new DuplicableNDArray(fetchINDArray(CoverageModelICGCacheNode.Q_tll.key, parents).sum(0)); } }; /* dependents: ["Delta_st", "var_log_c_st", "var_log_d_s", "WzzWT_st", "Wz_st"] */ private static final ComputableNodeFunction calculate_B_st_with_bias_covariates = new ComputableNodeFunction() { @Override - public Duplicable apply(final Map parents) { - final INDArray Delta_st = fetchINDArray(CoverageModelICGCacheNode.Delta_st.name(), parents); - final INDArray var_log_c_st = fetchINDArray(CoverageModelICGCacheNode.var_log_c_st.name(), parents); - final INDArray var_log_d_s = fetchINDArray(CoverageModelICGCacheNode.var_log_d_s.name(), parents); - final INDArray WzzWT_st = fetchINDArray(CoverageModelICGCacheNode.WzzWT_st.name(), parents); - final INDArray Wz_st = fetchINDArray(CoverageModelICGCacheNode.Wz_st.name(), parents); + public Duplicable apply(final Map parents) { + final INDArray Delta_st = fetchINDArray(CoverageModelICGCacheNode.Delta_st.key, parents); + final INDArray var_log_c_st = fetchINDArray(CoverageModelICGCacheNode.var_log_c_st.key, parents); + final INDArray var_log_d_s = fetchINDArray(CoverageModelICGCacheNode.var_log_d_s.key, parents); + final INDArray WzzWT_st = fetchINDArray(CoverageModelICGCacheNode.WzzWT_st.key, parents); + final INDArray Wz_st = fetchINDArray(CoverageModelICGCacheNode.Wz_st.key, parents); return new DuplicableNDArray( Delta_st.mul(Delta_st) .addi(var_log_c_st) @@ -1406,10 +1411,10 @@ public Duplicable apply(final Map parents) { /* dependents: ["Delta_st", "var_log_c_st", "var_log_d_s"] */ private static final ComputableNodeFunction calculate_B_st_without_bias_covariates = new ComputableNodeFunction() { @Override - public Duplicable apply(final Map parents) { - final INDArray Delta_st = fetchINDArray(CoverageModelICGCacheNode.Delta_st.name(), parents); - final INDArray var_log_c_st = fetchINDArray(CoverageModelICGCacheNode.var_log_c_st.name(), parents); - final INDArray var_log_d_s = fetchINDArray(CoverageModelICGCacheNode.var_log_d_s.name(), parents); + public Duplicable apply(final Map parents) { + final INDArray Delta_st = fetchINDArray(CoverageModelICGCacheNode.Delta_st.key, parents); + final INDArray var_log_c_st = fetchINDArray(CoverageModelICGCacheNode.var_log_c_st.key, parents); + final INDArray var_log_d_s = fetchINDArray(CoverageModelICGCacheNode.var_log_d_s.key, parents); return new DuplicableNDArray(Delta_st.mul(Delta_st).addi(var_log_c_st).addiColumnVector(var_log_d_s)); } }; @@ -1421,11 +1426,11 @@ public Duplicable apply(final Map parents) { /* dependents: ["B_st", "M_Psi_inv_st", "loglike_normalization_s"] */ private static final ComputableNodeFunction calculate_loglike_unreg = new ComputableNodeFunction() { @Override - public Duplicable apply(final Map parents) { + public Duplicable apply(final Map parents) { /* fetch */ - final INDArray B_st = fetchINDArray(CoverageModelICGCacheNode.B_st.name(), parents); - final INDArray M_Psi_inv_st = fetchINDArray(CoverageModelICGCacheNode.M_Psi_inv_st.name(), parents); - final INDArray loglike_normalization_s = fetchINDArray(CoverageModelICGCacheNode.loglike_normalization_s.name(), parents); + final INDArray B_st = fetchINDArray(CoverageModelICGCacheNode.B_st.key, parents); + final INDArray M_Psi_inv_st = fetchINDArray(CoverageModelICGCacheNode.M_Psi_inv_st.key, parents); + final INDArray loglike_normalization_s = fetchINDArray(CoverageModelICGCacheNode.loglike_normalization_s.key, parents); return new DuplicableNDArray(B_st.mul(M_Psi_inv_st).sum(1) .muli(-0.5) @@ -1437,17 +1442,17 @@ public Duplicable apply(final Map parents) { /* dependents: ["B_st", "M_Psi_inv_st", "loglike_normalization_s", "W_tl", "F_W_tl", "zz_sll"] */ private static final ComputableNodeFunction calculate_loglike_reg = new ComputableNodeFunction() { @Override - public Duplicable apply(final Map parents) { - final INDArray B_st = fetchINDArray(CoverageModelICGCacheNode.B_st.name(), parents); - final INDArray M_Psi_inv_st = fetchINDArray(CoverageModelICGCacheNode.M_Psi_inv_st.name(), parents); - final INDArray loglike_normalization_s = fetchINDArray(CoverageModelICGCacheNode.loglike_normalization_s.name(), parents); + public Duplicable apply(final Map parents) { + final INDArray B_st = fetchINDArray(CoverageModelICGCacheNode.B_st.key, parents); + final INDArray M_Psi_inv_st = fetchINDArray(CoverageModelICGCacheNode.M_Psi_inv_st.key, parents); + final INDArray loglike_normalization_s = fetchINDArray(CoverageModelICGCacheNode.loglike_normalization_s.key, parents); final INDArray regularPart = B_st.mul(M_Psi_inv_st).sum(1) .muli(-0.5) .addi(loglike_normalization_s); - final INDArray W_tl = fetchINDArray(CoverageModelICGCacheNode.W_tl.name(), parents); - final INDArray F_W_tl = fetchINDArray(CoverageModelICGCacheNode.F_W_tl.name(), parents); - final INDArray zz_sll = fetchINDArray(CoverageModelICGCacheNode.zz_sll.name(), parents); + final INDArray W_tl = fetchINDArray(CoverageModelICGCacheNode.W_tl.key, parents); + final INDArray F_W_tl = fetchINDArray(CoverageModelICGCacheNode.F_W_tl.key, parents); + final INDArray zz_sll = fetchINDArray(CoverageModelICGCacheNode.zz_sll.key, parents); final INDArray WFW_ll = W_tl.transpose().mmul(F_W_tl).muli(-0.5); final int numSamples = B_st.shape()[0]; final INDArray filterPart = Nd4j.create(new int[]{numSamples, 1}, diff --git a/src/main/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/CacheNode.java b/src/main/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/CacheNode.java index 2aca3c771..72369e36f 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/CacheNode.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/CacheNode.java @@ -13,32 +13,32 @@ * * @author Mehrtash Babadi <mehrtash@broadinstitute.org> */ -abstract class CacheNode { +public abstract class CacheNode { /** - * A string identifier for the cache node + * A key for identifying this cache node */ - private final String key; + private final NodeKey key; /** - * The collection of string identifiers of the immediate parents of this node (can be empty) + * The collection of lookup keys for parents of this node (can be empty) */ - private final Collection parents; + private final Collection parents; /** - * The collection of string identifiers of the tags associated to this node (can be empty) + * The collection of tags associated to this node (can be empty) */ - private final Collection tags; + private final Collection tags; /** * Public constructor * - * @param key string identifier of the cache node + * @param key lookup key of the cache node * @param tags the tags associated to this cache node - * @param parents immediate parents of this cache node + * @param parents lookup keys of the parents of this cache node */ - CacheNode(@Nonnull final String key, - @Nonnull final Collection tags, - @Nonnull final Collection parents) { + CacheNode(@Nonnull final NodeKey key, + @Nonnull final Collection tags, + @Nonnull final Collection parents) { this.key = Utils.nonNull(key, "The key of a cache node can not be null"); this.tags = Collections.unmodifiableCollection(Utils.nonNull(tags, "The tag collection of a cache node can not be null")); this.parents = Collections.unmodifiableCollection(Utils.nonNull(parents, "The immediate parents of a cache node can not be null")); @@ -47,10 +47,10 @@ abstract class CacheNode { /** * Get the value stored in the node * - * @param parents parent values (as a map from their string identifiers to their values) + * @param parents parent values (as a map from their keys to their values) * @return a {@link Duplicable}; possibly by reference */ - abstract Duplicable get(@Nonnull final Map parents); + abstract Duplicable get(@Nonnull final Map parents); /** * Set the value of the node @@ -95,27 +95,27 @@ abstract class CacheNode { * Get the string identifier of the node * @return a non-null {@link String} */ - final String getKey() { + final NodeKey getKey() { return key; } /** - * Get the collection of string identifier of the parents of this node (can be empty) + * Get the collection of keys of the parents of this node (can be empty) */ - final Collection getParents() { + final Collection getParents() { return Collections.unmodifiableCollection(parents); } /** - * Get the collection of string identifier of the tags associated to this node (can be empty) + * Get the collection of tags associated to this node (can be empty) */ - final Collection getTags() { + final Collection getTags() { return Collections.unmodifiableCollection(tags); } @Override public final String toString() { - return key; + return key.toString(); } /** @@ -136,4 +136,66 @@ public final boolean equals(Object other) { public final int hashCode() { return key.hashCode(); } + + /** + * This class represents a node key. It is a wrapper around a String. + */ + public static class NodeKey { + private final String key; + + public NodeKey(final String key) { + this.key = Utils.nonNull(key, "Node key must be non-null"); + } + + @Override + public String toString() { + return key; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + NodeKey nodeKey = (NodeKey) o; + + return key.equals(nodeKey.key); + } + + @Override + public int hashCode() { + return key.hashCode(); + } + } + + /** + * This class represents a node tag. It is a wrapper around a String. + */ + public static class NodeTag { + private final String tag; + + public NodeTag(final String tag) { + this.tag = Utils.nonNull(tag, "Node tag must be non-null"); + } + + @Override + public String toString() { + return tag; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + NodeTag nodeTag = (NodeTag) o; + + return tag.equals(nodeTag.tag); + } + + @Override + public int hashCode() { + return tag.hashCode(); + } + } } diff --git a/src/main/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/ComputableCacheNode.java b/src/main/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/ComputableCacheNode.java index 637c14430..3fb80bcfc 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/ComputableCacheNode.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/ComputableCacheNode.java @@ -24,13 +24,13 @@ final class ComputableCacheNode extends CacheNode { * Public constructor * * @param key the key of the node - * @param parents immediate parents of the node + * @param parents parents of the node * @param func a function from a map that (at least) contains parents data to the computed value of this node * @param isCaching does it store the value or not */ - ComputableCacheNode(@Nonnull final String key, - @Nonnull final Collection tags, - @Nonnull final Collection parents, + ComputableCacheNode(@Nonnull final NodeKey key, + @Nonnull final Collection tags, + @Nonnull final Collection parents, @Nullable final ComputableNodeFunction func, final boolean isCaching) { super(key, tags, parents); @@ -41,9 +41,9 @@ final class ComputableCacheNode extends CacheNode { isCacheCurrent = false; } - private ComputableCacheNode(@Nonnull final String key, - @Nonnull final Collection tags, - @Nonnull final Collection parents, + private ComputableCacheNode(@Nonnull final NodeKey key, + @Nonnull final Collection tags, + @Nonnull final Collection parents, @Nullable final ComputableNodeFunction func, final boolean isCaching, final Duplicable cachedValue, @@ -93,7 +93,7 @@ void set(@Nullable final Duplicable val) { * @throws ComputableNodeFunction.ParentValueNotFoundException if a required parent value is not given */ @Override - Duplicable get(@Nonnull final Map parentsValues) + Duplicable get(@Nonnull final Map parentsValues) throws ComputableNodeFunction.ParentValueNotFoundException, ExternallyComputableNodeValueUnavailableException { if (hasValue()) { return cachedValue; @@ -119,6 +119,7 @@ ComputableCacheNode duplicate() { * @param newValue the cache value to be replaced with the old value * @return a new instance of {@link ComputableCacheNode} */ + @Override ComputableCacheNode duplicateWithUpdatedValue(final Duplicable newValue) { if (isCaching && newValue != null && newValue.hasValue()) { return new ComputableCacheNode(getKey(), getTags(), getParents(), func, true, newValue, true); @@ -153,7 +154,7 @@ static final class ExternallyComputableNodeValueUnavailableException extends Run implements Serializable { private static final long serialVersionUID = 9056196660803073912L; - private ExternallyComputableNodeValueUnavailableException(final String nodeKey) { + private ExternallyComputableNodeValueUnavailableException(final NodeKey nodeKey) { super(String.format("Either the externally mutable node \"%s\" is not initialized or is outdated", nodeKey)); } diff --git a/src/main/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/ComputableGraphStructure.java b/src/main/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/ComputableGraphStructure.java index 5856bf1cc..e8b3016c0 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/ComputableGraphStructure.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/ComputableGraphStructure.java @@ -16,11 +16,16 @@ * * - assertion for existence of no cycles * - construction of maps of nodes to their descendants and ancestors - * - propagation of tags from descendants to ancestors + * - propagation of tags from descendants to ancestors (tags are instances of {@link CacheNode.NodeTag} that are used + * for marking one or more cache nodes that are required for a specific computation; for example, see the javadoc + * of {@link ImmutableComputableGraph} for a concrete use case) * - topological order for evaluating a computable node * - topological order for mutating a primitive/externally-computed node and updating the caches of all involved nodes * - topological order for evaluating all nodes associated to a tag (see {@link ImmutableComputableGraph}) - * - topological order for complete computation of the graph + * - topological order for evaluating all nodes in the graph + * + * @implNote the implementation of some of the algorithms in this class, while being fast for small graphs, are not + * optimized for large graphs and can run into {@link StackOverflowError}. * * @author Mehrtash Babadi <mehrtash@broadinstitute.org> */ @@ -28,18 +33,18 @@ final class ComputableGraphStructure implements Serializable { private static final long serialVersionUID = -3124293279477371159L; - private final Set nodeKeysSet; - private final Set nodeTagsSet; - private final Map> inducedTagsMap; - private final Map> descendantsMap; - private final Map> childrenMap; - private final Map> parentsMap; - private final Map> ancestorsMap; - private final Map topologicalOrderMap; - private final Map> topologicalOrderForNodeEvaluation; - private final Map> topologicalOrderForNodeMutation; - private final Map> topologicalOrderForTagEvaluation; - private final List topologicalOrderForCompleteEvaluation; + private final Set nodeKeysSet; + private final Set nodeTagsSet; + private final Map> inducedTagsMap; + private final Map> descendantsMap; + private final Map> childrenMap; + private final Map> parentsMap; + private final Map> ancestorsMap; + private final Map topologicalOrderMap; + private final Map> topologicalOrderForNodeEvaluation; + private final Map> topologicalOrderForNodeMutation; + private final Map> topologicalOrderForTagEvaluation; + private final List topologicalOrderForCompleteEvaluation; /** * An arbitrary negative number to denote the to-be-determined topological order a node @@ -53,114 +58,111 @@ final class ComputableGraphStructure implements Serializable { * @param nodeSet a set of {@link CacheNode}s */ ComputableGraphStructure(@Nonnull final Set nodeSet) { - - /* create the set of keys */ Utils.nonNull(nodeSet, "The given set of nodes must be non-null"); nodeKeysSet = extractKeys(nodeSet); nodeTagsSet = extractTags(nodeSet); assertParentKeysExist(nodeSet, nodeKeysSet); - /* nodeKey -> set of parents' keys */ parentsMap = getParentsMap(nodeSet); - - /* nodeKey -> set of children's keys */ childrenMap = getChildrenMap(nodeSet); - - /* nodeKey -> topological order */ topologicalOrderMap = getTopologicalOrderMap(parentsMap, nodeKeysSet); - - /* topological order -> set of node keys */ - final Map> nodesByTopologicalOrderMap = getNodesByTopologicalOrderMap(topologicalOrderMap, - nodeKeysSet); - - /* nodeKey -> set of descendants' keys */ + final Map> nodesByTopologicalOrderMap = + getNodesByTopologicalOrderMap(topologicalOrderMap, nodeKeysSet); descendantsMap = getDescendantsMap(childrenMap, nodesByTopologicalOrderMap); - - /* nodeKey -> set of ancestors' keys */ ancestorsMap = getAncestorsMap(parentsMap, nodesByTopologicalOrderMap); - - /* nodeKey -> set of upward-propagated tags */ inducedTagsMap = getInducedTagsMap(nodeSet, nodesByTopologicalOrderMap, ancestorsMap); - - /* tag -> set of tagged nodes, including upward-propagation */ - final Map> nodesByInducedTagMap = getNodesByInducedTagMap(inducedTagsMap, nodeKeysSet, - nodeTagsSet); - - /* topological order for evaluating a single node */ + final Map> nodesByInducedTagMap = + getNodesByInducedTagMap(inducedTagsMap, nodeKeysSet, nodeTagsSet); topologicalOrderForNodeEvaluation = getTopologicalOrderForNodeEvaluation(nodeKeysSet, topologicalOrderMap, ancestorsMap); - - /* topological order for evaluating all nodes associated to a tag */ topologicalOrderForTagEvaluation = getTopologicalOrderForTagEvaluation(nodeTagsSet, topologicalOrderMap, nodesByInducedTagMap, ancestorsMap); - - /* topological order for evaluating all nodes */ topologicalOrderForCompleteEvaluation = getTopologicalOrderForCompleteEvaluation(nodeKeysSet, topologicalOrderMap); - - /* topological order for updating the descendants of a mutated node */ topologicalOrderForNodeMutation = getTopologicalOrderForNodeMutation(nodeKeysSet, ancestorsMap, descendantsMap, topologicalOrderMap); } private static void assertParentKeysExist(@Nonnull final Set nodeSet, - @Nonnull final Set nodeKeysSet) { + @Nonnull final Set nodeKeysSet) { for (final CacheNode node : nodeSet) { if (!nodeKeysSet.containsAll(node.getParents())) { - final Set undefinedParents = Sets.difference(new HashSet<>(node.getParents()), nodeKeysSet); - throw new NonexistentParentNodeKey("Node " + ImmutableComputableGraphUtils.quote(node.getKey()) + + final Set undefinedParents = Sets.difference(new HashSet<>(node.getParents()), nodeKeysSet); + throw new NonexistentParentNodeKey("Node " + ImmutableComputableGraphUtils.quote(node.getKey().toString()) + " depends on undefined parent(s): " + undefinedParents.stream() + .map(CacheNode.NodeKey::toString) .map(ImmutableComputableGraphUtils::quote).collect(Collectors.joining(", "))); } } } - private static Set extractTags(@Nonnull Set nodeSet) { + private static Set extractTags(@Nonnull Set nodeSet) { return nodeSet.stream().map(CacheNode::getTags).flatMap(Collection::stream).collect(Collectors.toSet()); } - private static Set extractKeys(@Nonnull Set nodeSet) { + private static Set extractKeys(@Nonnull Set nodeSet) { return nodeSet.stream().map(CacheNode::getKey).collect(Collectors.toSet()); } - private static Map getTopologicalOrderMap(@Nonnull final Map> immediateParentsMap, - @Nonnull final Set nodeKeysSet) { - final Map topologicalOrderMap = new HashMap<>(); + /** + * Sorts nodes by topological order using depth-first search. The output is a map from node keys to + * their depth (root nodes have 0 depth). + */ + private static Map getTopologicalOrderMap( + @Nonnull final Map> parentsMap, + @Nonnull final Set nodeKeysSet) { + final Map topologicalOrderMap = new HashMap<>(); nodeKeysSet.forEach(key -> topologicalOrderMap.put(key, UNDEFINED_TOPOLOGICAL_ORDER)); - nodeKeysSet.forEach(nodeKey -> updateDepth(nodeKey, 0, nodeKeysSet, immediateParentsMap, topologicalOrderMap)); + nodeKeysSet.forEach(nodeKey -> updateDepth(nodeKey, 0, nodeKeysSet, parentsMap, topologicalOrderMap)); return topologicalOrderMap; } - private static Map> getNodesByTopologicalOrderMap(Map topologicalOrderMap, - Set nodeKeysSet) { + /** + * Generates a (topological order -> set of node keys) map. The map clusters the nodes at the + * same depth (root nodes have 0 depth). + */ + private static Map> getNodesByTopologicalOrderMap( + @Nonnull final Map topologicalOrderMap, + @Nonnull final Set nodeKeysSet) { final int maxDepth = Collections.max(topologicalOrderMap.values()); - final Map> nodesByTopologicalOrderMap = new HashMap<>(); - IntStream.range(0, maxDepth + 1).forEach(depth -> - nodesByTopologicalOrderMap.put(depth, - nodeKeysSet.stream().filter(node -> - topologicalOrderMap.get(node) == depth).collect(Collectors.toSet()))); + final Map> nodesByTopologicalOrderMap = new HashMap<>(); + IntStream.range(0, maxDepth + 1) + .forEach(depth -> nodesByTopologicalOrderMap.put(depth, + nodeKeysSet.stream() + .filter(node -> topologicalOrderMap.get(node) == depth) + .collect(Collectors.toSet()))); return nodesByTopologicalOrderMap; } - private static Map> getParentsMap(@Nonnull final Set nodeSet) { + /** + * Generates a (nodeKey -> set of parents' keys) map. + */ + private static Map> getParentsMap(@Nonnull final Set nodeSet) { return nodeSet.stream() .collect(Collectors.toMap(CacheNode::getKey, node -> new HashSet<>(node.getParents()))); } - private static Map> getChildrenMap(@Nonnull final Set nodeSet) { - final Map> childrenMap = nodeSet.stream() - .collect(Collectors.toMap(CacheNode::getKey, node -> new HashSet())); + /** + * Generates a (nodeKey -> set of children's keys) map. + */ + private static Map> getChildrenMap(@Nonnull final Set nodeSet) { + final Map> childrenMap = nodeSet.stream() + .collect(Collectors.toMap(CacheNode::getKey, node -> new HashSet())); nodeSet.forEach(node -> node.getParents().forEach(parentKey -> childrenMap.get(parentKey) .add(node.getKey()))); return childrenMap; } - private static Map> getInducedTagsMap(@Nonnull final Set nodeSet, - @Nonnull final Map> nodesByTopologicalOrderMap, - @Nonnull final Map> allParentsMap) { + /** + * Generates a (nodeKey -> set of upward-propagated tags) map. + */ + private static Map> getInducedTagsMap( + @Nonnull final Set nodeSet, + @Nonnull final Map> nodesByTopologicalOrderMap, + @Nonnull final Map> allParentsMap) { final int maxDepth = Collections.max(nodesByTopologicalOrderMap.keySet()); /* initialize with given tags */ - final Map> allTagsMap = nodeSet.stream() + final Map> allTagsMap = nodeSet.stream() .collect(Collectors.toMap(CacheNode::getKey, node -> new HashSet<>(node.getTags()))); /* propagate tags to all parents */ for (int depth = maxDepth; depth >= 0; depth--) { @@ -171,29 +173,37 @@ private static Map> getInducedTagsMap(@Nonnull final Set> getNodesByInducedTagMap(@Nonnull final Map> allTagsMap, - @Nonnull final Set nodeKeysSet, - @Nonnull final Set nodeTagsSet) { - final Map> nodesByTagMap = nodeTagsSet.stream() - .collect(Collectors.toMap(Function.identity(), tag -> new HashSet())); + /** + * Generates a (tag -> set of tagged nodes, including upward-propagation) map. + */ + private static Map> getNodesByInducedTagMap( + @Nonnull final Map> allTagsMap, + @Nonnull final Set nodeKeysSet, + @Nonnull final Set nodeTagsSet) { + final Map> nodesByTagMap = nodeTagsSet.stream() + .collect(Collectors.toMap(Function.identity(), tag -> new HashSet())); nodeKeysSet.forEach(nodeKey -> allTagsMap.get(nodeKey).forEach(tag -> nodesByTagMap.get(tag).add(nodeKey))); return nodesByTagMap; } - private static Map> getDescendantsMap(@Nonnull final Map> childrenMap, - @Nonnull final Map> nodesByTopologicalOrderMap) { - final Map> descendantsMap = new HashMap<>(); + /** + * Generates a (nodeKey -> set of descendants' keys) map. + */ + private static Map> getDescendantsMap( + @Nonnull final Map> childrenMap, + @Nonnull final Map> nodesByTopologicalOrderMap) { + final Map> descendantsMap = new HashMap<>(); final int maxDepth = Collections.max(nodesByTopologicalOrderMap.keySet()); /* deepest nodes have no descendants */ nodesByTopologicalOrderMap.get(maxDepth).forEach(node -> descendantsMap.put(node, new HashSet<>())); /* get all descendants by ascending the tree */ for (int depth = maxDepth - 1; depth >= 0; depth -= 1) { - for (final String node : nodesByTopologicalOrderMap.get(depth)) { - final Set nodeDescendants = new HashSet<>(); + for (final CacheNode.NodeKey node : nodesByTopologicalOrderMap.get(depth)) { + final Set nodeDescendants = new HashSet<>(); nodeDescendants.addAll(childrenMap.get(node)); - for (final String child : childrenMap.get(node)) { + for (final CacheNode.NodeKey child : childrenMap.get(node)) { nodeDescendants.addAll(descendantsMap.get(child)); } descendantsMap.put(node, nodeDescendants); @@ -202,16 +212,20 @@ private static Map> getDescendantsMap(@Nonnull final Map> getAncestorsMap(@Nonnull final Map> parentsMap, - @Nonnull final Map> nodesByTopologicalOrderMap) { - final Map> ancestorsMap = new HashMap<>(); + /** + * Generates a (nodeKey -> set of ancestors' keys) map. + */ + private static Map> getAncestorsMap( + @Nonnull final Map> parentsMap, + @Nonnull final Map> nodesByTopologicalOrderMap) { + final Map> ancestorsMap = new HashMap<>(); final int maxDepth = Collections.max(nodesByTopologicalOrderMap.keySet()); nodesByTopologicalOrderMap.get(0).forEach(node -> ancestorsMap.put(node, new HashSet<>())); for (int depth = 1; depth <= maxDepth; depth += 1) { - for (final String node : nodesByTopologicalOrderMap.get(depth)) { - final Set nodeAncestors = new HashSet<>(); + for (final CacheNode.NodeKey node : nodesByTopologicalOrderMap.get(depth)) { + final Set nodeAncestors = new HashSet<>(); nodeAncestors.addAll(parentsMap.get(node)); - for (final String parent : parentsMap.get(node)) { + for (final CacheNode.NodeKey parent : parentsMap.get(node)) { nodeAncestors.addAll(ancestorsMap.get(parent)); } ancestorsMap.put(node, nodeAncestors); @@ -220,12 +234,17 @@ private static Map> getAncestorsMap(@Nonnull final Map> getTopologicalOrderForNodeEvaluation(@Nonnull final Set nodeKeysSet, - @Nonnull final Map topologicalOrderMap, - @Nonnull final Map> ancestorsMap) { - final Map> topologicalOrderForNodeEvaluation = new HashMap<>(); - for (final String nodeKey : nodeKeysSet) { - final List allParentsIncludingTheNode = new ArrayList<>(); + /** + * Generates the topological order for evaluating a single node as a (nodeKey -> list of topologically-ordered + * node keys) map. + */ + private static Map> getTopologicalOrderForNodeEvaluation( + @Nonnull final Set nodeKeysSet, + @Nonnull final Map topologicalOrderMap, + @Nonnull final Map> ancestorsMap) { + final Map> topologicalOrderForNodeEvaluation = new HashMap<>(); + for (final CacheNode.NodeKey nodeKey : nodeKeysSet) { + final List allParentsIncludingTheNode = new ArrayList<>(); allParentsIncludingTheNode.addAll(ancestorsMap.get(nodeKey)); allParentsIncludingTheNode.add(nodeKey); /* sort by depth */ @@ -235,18 +254,23 @@ private static Map> getTopologicalOrderForNodeEvaluation(@N return topologicalOrderForNodeEvaluation; } - private static Map> getTopologicalOrderForTagEvaluation(@Nonnull final Set nodeTagsSet, - @Nonnull final Map topologicalOrderMap, - @Nonnull final Map> nodesByTagMap, - @Nonnull final Map> ancestorsMap) { - final Map> topologicalOrderForTagEvaluation = new HashMap<>(); - for (final String tag : nodeTagsSet) { - final Set ancestorsIncludingTheTaggedNodesSet = new HashSet<>(); - for (final String node : nodesByTagMap.get(tag)) { + /** + * Generates the topological order for evaluating all nodes associated to a tag as a (nodeKey -> list of + * topologically-ordered node keys) map. + */ + private static Map> getTopologicalOrderForTagEvaluation( + @Nonnull final Set nodeTagsSet, + @Nonnull final Map topologicalOrderMap, + @Nonnull final Map> nodesByTagMap, + @Nonnull final Map> ancestorsMap) { + final Map> topologicalOrderForTagEvaluation = new HashMap<>(); + for (final CacheNode.NodeTag tag : nodeTagsSet) { + final Set ancestorsIncludingTheTaggedNodesSet = new HashSet<>(); + for (final CacheNode.NodeKey node : nodesByTagMap.get(tag)) { ancestorsIncludingTheTaggedNodesSet.addAll(ancestorsMap.get(node)); ancestorsIncludingTheTaggedNodesSet.add(node); } - final List ancestorsIncludingTheTaggedNodesList = new ArrayList<>(); + final List ancestorsIncludingTheTaggedNodesList = new ArrayList<>(); ancestorsIncludingTheTaggedNodesList.addAll(ancestorsIncludingTheTaggedNodesSet); /* sort by depth */ ancestorsIncludingTheTaggedNodesList.sort(Comparator.comparingInt(topologicalOrderMap::get)); @@ -255,32 +279,38 @@ private static Map> getTopologicalOrderForTagEvaluation(@No return topologicalOrderForTagEvaluation; } - private static List getTopologicalOrderForCompleteEvaluation(@Nonnull final Set nodeKeysSet, - @Nonnull final Map topologicalOrderMap) { + /** + * Generates topological order for evaluating all nodes in the graph as a (nodeKey -> list of topologically-ordered + * node keys) map. + */ + private static List getTopologicalOrderForCompleteEvaluation( + @Nonnull final Set nodeKeysSet, + @Nonnull final Map topologicalOrderMap) { return new ArrayList<>(nodeKeysSet).stream() .sorted(Comparator.comparingInt(topologicalOrderMap::get)) .collect(Collectors.toList()); } /** - * Topological order for mutating a primitive/externally-computed node and updating the caches of the + * Generate topological order for mutating a primitive/externally-computed node and updating the caches of the * involved nodes. These include mutated node, its descendants, and the ancestors of all of its descendants. * The latter is required for updating the caches of the descendants. */ - private static Map> getTopologicalOrderForNodeMutation(@Nonnull final Set nodeKeysSet, - @Nonnull final Map> ancestorsMap, - @Nonnull final Map> descendantsMap, - @Nonnull final Map topologicalOrderMap) { - final Map> topologicalOrderForNodeMutation = new HashMap<>(); - for (final String nodeKey : nodeKeysSet) { - final Set allInvolvedNodesSet = new HashSet<>(); + private static Map> getTopologicalOrderForNodeMutation( + @Nonnull final Set nodeKeysSet, + @Nonnull final Map> ancestorsMap, + @Nonnull final Map> descendantsMap, + @Nonnull final Map topologicalOrderMap) { + final Map> topologicalOrderForNodeMutation = new HashMap<>(); + for (final CacheNode.NodeKey nodeKey : nodeKeysSet) { + final Set allInvolvedNodesSet = new HashSet<>(); allInvolvedNodesSet.add(nodeKey); allInvolvedNodesSet.addAll(descendantsMap.get(nodeKey)); - for (final String descendantNodeKey : descendantsMap.get(nodeKey)) { + for (final CacheNode.NodeKey descendantNodeKey : descendantsMap.get(nodeKey)) { allInvolvedNodesSet.add(descendantNodeKey); allInvolvedNodesSet.addAll(ancestorsMap.get(descendantNodeKey)); } - final List allInvolvedNodesList = new ArrayList<>(); + final List allInvolvedNodesList = new ArrayList<>(); allInvolvedNodesList.addAll(allInvolvedNodesSet); allInvolvedNodesList.sort(Comparator.comparingInt(topologicalOrderMap::get)); topologicalOrderForNodeMutation.put(nodeKey, allInvolvedNodesList); @@ -290,11 +320,12 @@ private static Map> getTopologicalOrderForNodeMutation(@Non /** * Updates the depth of a node recursively - * */ - private static void updateDepth(@Nonnull final String nodeKey, final int recursion, - @Nonnull Set nodeKeysSet, - @Nonnull Map> parentsMap, - @Nonnull Map topologicalOrderMap) { + */ + private static void updateDepth(@Nonnull final CacheNode.NodeKey nodeKey, + final int recursion, + @Nonnull Set nodeKeysSet, + @Nonnull Map> parentsMap, + @Nonnull Map topologicalOrderMap) { if (recursion > nodeKeysSet.size()) { throw new CyclicGraphException("The graph is not acyclic"); } @@ -312,47 +343,52 @@ private static void updateDepth(@Nonnull final String nodeKey, final int recursi /* do nothing otherwise -- we already have the order for this node */ } - Set getNodeKeysSet() { return nodeKeysSet; } + Set getNodeKeysSet() { return nodeKeysSet; } - Set getNodeTagsSet() { return nodeTagsSet; } + Set getNodeTagsSet() { return nodeTagsSet; } - Set getInducedTagsForNode(final String nodeKey) { + Set getInducedTagsForNode(final CacheNode.NodeKey nodeKey) { return inducedTagsMap.get(nodeKey); } - int getTopologicalOrder(@Nonnull final String nodeKey) { + /** + * Return the topological order for a node. Note: nodes at the same depth have the same topological order. + * @param nodeKey node key + * @return (integer) topological order + */ + int getTopologicalOrder(@Nonnull final CacheNode.NodeKey nodeKey) { return topologicalOrderMap.get(nodeKey); } - Set getChildren(@Nonnull final String nodeKey) { + Set getChildren(@Nonnull final CacheNode.NodeKey nodeKey) { return childrenMap.get(nodeKey); } - Set getParents(@Nonnull final String nodeKey) { + Set getParents(@Nonnull final CacheNode.NodeKey nodeKey) { return parentsMap.get(nodeKey); } - Set getAncestors(@Nonnull final String nodeKey) { + Set getAncestors(@Nonnull final CacheNode.NodeKey nodeKey) { return ancestorsMap.get(nodeKey); } - Set getDescendants(@Nonnull final String nodeKey) { + Set getDescendants(@Nonnull final CacheNode.NodeKey nodeKey) { return descendantsMap.get(nodeKey); } - List getTopologicalOrderForNodeEvaluation(final String nodeKey) { + List getTopologicalOrderForNodeEvaluation(final CacheNode.NodeKey nodeKey) { return topologicalOrderForNodeEvaluation.get(nodeKey); } - List getTopologicalOrderForNodeMutation(final String nodeKey) { + List getTopologicalOrderForNodeMutation(final CacheNode.NodeKey nodeKey) { return topologicalOrderForNodeMutation.get(nodeKey); } - List getTopologicalOrderForTagEvaluation(final String tagKey) { + List getTopologicalOrderForTagEvaluation(final CacheNode.NodeTag tagKey) { return topologicalOrderForTagEvaluation.get(tagKey); } - List getTopologicalOrderForCompleteEvaluation() { + List getTopologicalOrderForCompleteEvaluation() { return topologicalOrderForCompleteEvaluation; } diff --git a/src/main/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/ComputableNodeFunction.java b/src/main/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/ComputableNodeFunction.java index db4f25241..84317cf87 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/ComputableNodeFunction.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/ComputableNodeFunction.java @@ -13,7 +13,7 @@ @FunctionalInterface public interface ComputableNodeFunction { - Duplicable apply(final Map parents) throws ParentValueNotFoundException; + Duplicable apply(final Map parents) throws ParentValueNotFoundException; /** * Fetches a parent node value from a given map @@ -24,7 +24,8 @@ public interface ComputableNodeFunction { * @return an instance of {@link Duplicable} by reference * @throws ParentValueNotFoundException if the parent key is not in the map */ - default Duplicable fetch(final String key, final Map parents) throws ParentValueNotFoundException { + default Duplicable fetch(final CacheNode.NodeKey key, final Map parents) + throws ParentValueNotFoundException { if (!parents.containsKey(key)) { throw new ParentValueNotFoundException(key); } @@ -38,7 +39,7 @@ default Duplicable fetch(final String key, final Map parents * @param parents parent key-value map * @throws ParentValueNotFoundException if the parent key is not in the map */ - default INDArray fetchINDArray(final String key, final Map parents) + default INDArray fetchINDArray(final CacheNode.NodeKey key, final Map parents) throws ParentValueNotFoundException, ClassCastException { return ((DuplicableNDArray)fetch(key, parents)).value(); } @@ -51,7 +52,7 @@ default INDArray fetchINDArray(final String key, final Map p * @throws ParentValueNotFoundException if the parent key is not in the map */ @SuppressWarnings("unchecked") - default double fetchDouble(final String key, final Map parents) + default double fetchDouble(final CacheNode.NodeKey key, final Map parents) throws ParentValueNotFoundException, ClassCastException { return ((DuplicableNumber)fetch(key, parents)).value(); } @@ -63,7 +64,7 @@ default double fetchDouble(final String key, final Map paren final class ParentValueNotFoundException extends RuntimeException implements Serializable { private static final long serialVersionUID = -4557250891066141519L; - private ParentValueNotFoundException(final String nodeKey) { + private ParentValueNotFoundException(final CacheNode.NodeKey nodeKey) { super(String.format("The value of node \"%s\" is required for computation but it is not available", nodeKey)); } } diff --git a/src/main/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/ImmutableComputableGraph.java b/src/main/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/ImmutableComputableGraph.java index 21d8b86af..c74af1238 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/ImmutableComputableGraph.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/ImmutableComputableGraph.java @@ -53,13 +53,13 @@ * (1) Caching: these nodes may store the values they evaluate for future lookup * (2) Non-caching: these nodes are compute-on-demand * (3) Externally-computed: these nodes are constructed using a {@code null} evaluation function. The user is - * responsible for evaluating these nodes when required and using {@link #setValue(String, Duplicable)} + * responsible for evaluating these nodes when required and using {@link #setValue(CacheNode.NodeKey, Duplicable)} * to update them. This class performs the bookkeeping of the status of their values (up to date or * out of date) based on the provided parents list. * * One can require the caching computable nodes to be updated and cached automatically after each mutation of a * primitive or externally-computed node by invoking {@link ImmutableComputableGraphBuilder#withCacheAutoUpdate()}. - * This feature, however, is _not_ recommended as one may not need all the caches to be up-to-date at all times + * This feature, however, is not recommended as one may not need all the caches to be up-to-date at all times * (see below for updating caches selectively). This class throws exceptions if an out-of-date cached value is queried * in order to notify the user to update the cache manually. * @@ -67,9 +67,9 @@ * ================ * * If cache auto update is not enabled, the user is responsible for updating the caches by calling - * either {@link #updateAllCaches()}, {@link #updateCachesForNode(String)}, or {@link #updateCachesForTag(String)}. + * either {@link #updateAllCaches()}, {@link #updateCachesForNode(CacheNode.NodeKey)}, or {@link #updateCachesForTag(CacheNode.NodeTag)}. * These methods also come with counterparts {@link #updateAllCachesIfPossible()}, - * {@link #updateCachesForNodeIfPossible(String)}, and {@link #updateCachesForTagIfPossible(String)}. These methods + * {@link #updateCachesForNodeIfPossible(CacheNode.NodeKey)}, and {@link #updateCachesForTagIfPossible(CacheNode.NodeTag)}. These methods * are safeguarded against throwing exceptions (e.g. if a required primitive or externally-computed value is not * available) and try to update as many cache nodes as possible. * @@ -79,7 +79,7 @@ * Tags are arbitrary string identifiers used for grouping nodes that are semantically related. The user may want to update the * nodes associated to the same tag simultaneously, for example, when performing an operation that requires updated * values for all nodes that denote subexpressions of a larger expression. In this case, the user will tag all nodes - * that appear in the larger expression with a common name and calls {@link #updateCachesForTag(String)}. + * that appear in the larger expression with a common name and calls {@link #updateCachesForTag(CacheNode.NodeTag)}. * * Tags are inherited from descendents to ancestors. In the above example, if Q_2 is tagged with {"FOO"} and Q_1 is tagged * with {"BAR"}, then X and Y both inherit {"FOO", "BAR"} tags whereas Z only inherits {"FOO"}. @@ -87,10 +87,10 @@ * Querying: * ========= * - * The graph is queried either by calling {@link #fetchDirectly(String)} or - * {@link #fetchWithRequiredEvaluations(String)}. The former only fetches the values and throws an exception + * The graph is queried either by calling {@link #fetchDirectly(CacheNode.NodeKey)} or + * {@link #fetchWithRequiredEvaluations(CacheNode.NodeKey)}. The former only fetches the values and throws an exception * if the some of the required nodes are out of date, or a non-caching computable node is queried. The latter performs - * the required _necessary_ evaluations along the way. + * the required necessary evaluations along the way. * * IMPORTANT NOTE: the queried values are returned by reference. It is the user's responsibility not to mutate * them. Otherwise, immutability will be broken. @@ -98,13 +98,13 @@ * Mutation: * ========= * - * The primitive values can be mutated by calling {@link #setValue(String, Duplicable)}. Mutations do not occur in place; + * The primitive values can be mutated by calling {@link #setValue(CacheNode.NodeKey, Duplicable)}. Mutations do not occur in place; * rather, a new instance of {@link ImmutableComputableGraph} is created along with new instances for the updated nodes. * Immutability is desired, for instance, if this class is used as elements of a {@link org.apache.spark.api.java.JavaRDD}. * If cache auto update is enabled, the affected nodes will be automatically computed and cached. Otherwise, only the * cache status go out of date and the old stored values are {@code null}ed. * - * Note: the new {@link ImmutableComputableGraph} instance returned by {@link #setValue(String, Duplicable)} is _not_ a + * Note: the new {@link ImmutableComputableGraph} instance returned by {@link #setValue(CacheNode.NodeKey, Duplicable)} is not a * deep copy and may hold references to {@link CacheNode}s contained the previous instance(s). * * @author Mehrtash Babadi <mehrtash@broadinstitute.org> @@ -113,9 +113,9 @@ public final class ImmutableComputableGraph implements Serializable { private static final long serialVersionUID = -1162776031416105027L; - private static Map EMPTY_MAP = new HashMap<>(); + private static Map EMPTY_NODE_KEY_VALUE_MAP = new HashMap<>(); - private final Map nodesMap; + private final Map nodesMap; private final boolean cacheAutoUpdate; private final ComputableGraphStructure cgs; @@ -142,7 +142,7 @@ public static ImmutableComputableGraphBuilder builder() { * @param nodesMap a previously constructed key -> {@link CacheNode} map * @param cgs a previously constructed {@link ComputableGraphStructure} */ - private ImmutableComputableGraph(@Nonnull final Map nodesMap, + private ImmutableComputableGraph(@Nonnull final Map nodesMap, @Nonnull final ComputableGraphStructure cgs, final boolean cacheAutoUpdate) { this.nodesMap = nodesMap; @@ -159,20 +159,20 @@ private ImmutableComputableGraph(@Nonnull final Map nodesMap, * @throws IllegalArgumentException if the node does not exist * @throws UnsupportedOperationException if the node is non-primitive */ - public ImmutableComputableGraph setValue(@Nonnull final String nodeKey, + public ImmutableComputableGraph setValue(@Nonnull final CacheNode.NodeKey nodeKey, @Nonnull final Duplicable newValue) throws IllegalArgumentException, UnsupportedOperationException { CacheNode node = nodesMap.get(assertNodeExists(nodeKey)); if (!node.isExternallyComputed()) { throw new UnsupportedOperationException("Can not explicitly set the value of a non-primitive cache node."); } - final Map updatedNodesMap = new HashMap<>(); + final Map updatedNodesMap = new HashMap<>(); updatedNodesMap.put(nodeKey, node.duplicateWithUpdatedValue(newValue)); final ImmutableComputableGraph out = duplicateWithUpdatedNodes( addDuplicateOfOutdatedDescendants(nodeKey, updatedNodesMap)); if (cacheAutoUpdate) { try { /* try to update caches; it is not guaranteed if some of the nodes are not initialized */ - final Map accumulatedValues = out.evaluateInTopologicalOrder( + final Map accumulatedValues = out.evaluateInTopologicalOrder( cgs.getTopologicalOrderForNodeMutation(nodeKey)); return out.updateCachesFromAccumulatedValues(accumulatedValues); } catch (final PrimitiveCacheNode.PrimitiveValueNotInitializedException | @@ -190,9 +190,10 @@ public ImmutableComputableGraph setValue(@Nonnull final String nodeKey, * @param key key of the updated node * @param updatedNodesMap a key -> node map */ - private Map addDuplicateOfOutdatedDescendants(@Nonnull final String key, - @Nonnull final Map updatedNodesMap) { - for (final String descendant : cgs.getDescendants(key)) { + private Map addDuplicateOfOutdatedDescendants( + @Nonnull final CacheNode.NodeKey key, + @Nonnull final Map updatedNodesMap) { + for (final CacheNode.NodeKey descendant : cgs.getDescendants(key)) { CacheNode oldDescendant = nodesMap.get(descendant); /* all of the descendants are computable nodes and can be safely up-casted */ updatedNodesMap.put(descendant, ((ComputableCacheNode)oldDescendant).duplicateWithOutdatedCacheStatus()); @@ -203,7 +204,7 @@ private Map addDuplicateOfOutdatedDescendants(@Nonnull final /** * Returns a reference to the value of a given node * - * Note: this function is purposefully meant to be _light_ in the following sense: + * Note: this function is purposefully meant to be light in the following sense: * * (1) it does not update out-of-date caching computable nodes, and * (2) it does not evaluate non-caching computable nodes. @@ -214,28 +215,28 @@ private Map addDuplicateOfOutdatedDescendants(@Nonnull final * not initialized * @throws IllegalArgumentException if the node does not exist */ - public Duplicable fetchDirectly(@Nonnull final String nodeKey) throws IllegalArgumentException { - return nodesMap.get(assertNodeExists(nodeKey)).get(EMPTY_MAP); + public Duplicable fetchDirectly(@Nonnull final CacheNode.NodeKey nodeKey) throws IllegalArgumentException { + return nodesMap.get(assertNodeExists(nodeKey)).get(EMPTY_NODE_KEY_VALUE_MAP); } /** * Returns the value of a node. If the node is computable and its value is not available, all of the required * intermediate calculations will be done. * - * Note: the result of intermediate calculations are _not_ stored back into the (possibly out-of-date) ancestor + * Note: the result of intermediate calculations are not stored back into the (possibly out-of-date) ancestor * nodes. This may result in redundant calculations. In generic situations, the most efficient approach is to * update the required cache nodes first, and then fetch the up-to-date cached values using - * {@link #fetchDirectly(String)} instead. + * {@link #fetchDirectly(CacheNode.NodeKey)} instead. * * @param nodeKey key of the node * @return value of the node * @throws IllegalArgumentException if the node does not exist */ - public Duplicable fetchWithRequiredEvaluations(@Nonnull final String nodeKey) + public Duplicable fetchWithRequiredEvaluations(@Nonnull final CacheNode.NodeKey nodeKey) throws IllegalStateException, IllegalArgumentException { final CacheNode node = nodesMap.get(assertNodeExists(nodeKey)); if (node.hasValue()) { - return node.get(EMPTY_MAP); + return node.get(EMPTY_NODE_KEY_VALUE_MAP); } else { return evaluateInTopologicalOrder(cgs.getTopologicalOrderForNodeEvaluation(nodeKey)).get(nodeKey); } @@ -262,9 +263,10 @@ public Duplicable fetchWithRequiredEvaluations(@Nonnull final String nodeKey) * initialized * @return a map from node keys to their values accumulated during computation */ - private Map evaluateInTopologicalOrder(@Nonnull final List topologicallyOrderedNodeKeys) { - final Map accumulatedValues = new HashMap<>(); - for (final String nodeKey : topologicallyOrderedNodeKeys) { + private Map evaluateInTopologicalOrder( + @Nonnull final List topologicallyOrderedNodeKeys) { + final Map accumulatedValues = new HashMap<>(); + for (final CacheNode.NodeKey nodeKey : topologicallyOrderedNodeKeys) { accumulatedValues.put(nodeKey, nodesMap.get(nodeKey).get(accumulatedValues)); } return accumulatedValues; @@ -278,9 +280,10 @@ private Map evaluateInTopologicalOrder(@Nonnull final List evaluateInTopologicalOrderIfPossible(@Nonnull final List topologicallyOrderedNodeKeys) { - final Map accumulatedValues = new HashMap<>(); - for (final String nodeKey : topologicallyOrderedNodeKeys) { + private Map evaluateInTopologicalOrderIfPossible( + @Nonnull final List topologicallyOrderedNodeKeys) { + final Map accumulatedValues = new HashMap<>(); + for (final CacheNode.NodeKey nodeKey : topologicallyOrderedNodeKeys) { Duplicable value = null; try { value = nodesMap.get(nodeKey).get(accumulatedValues); @@ -301,7 +304,8 @@ private Map evaluateInTopologicalOrderIfPossible(@Nonnull fi * @param accumulatedValues a nodekey -> duplicable map for possibly affected nodes * @return a new instance of {@link ImmutableComputableGraph} with new instances of updated nodes */ - private ImmutableComputableGraph updateCachesFromAccumulatedValues(final Map accumulatedValues) { + private ImmutableComputableGraph updateCachesFromAccumulatedValues( + @Nonnull final Map accumulatedValues) { /* since accumulatedValues may contain unchanged values (by reference), we filter and only update * the affected nodes */ return duplicateWithUpdatedNodes( @@ -321,7 +325,7 @@ private ImmutableComputableGraph updateCachesFromAccumulatedValues(final Map, Map> topologicalEvaluator, - @Nonnull final List topologicallyOrderedNodeKeys) { + @Nonnull final Function, Map> topologicalEvaluator, + @Nonnull final List topologicallyOrderedNodeKeys) { return updateCachesFromAccumulatedValues(topologicalEvaluator.apply(topologicallyOrderedNodeKeys)); } @@ -386,9 +390,9 @@ private ImmutableComputableGraph evaluateAndUpdateCaches( * @param updatedNodesMap nodes to be replaced and their new values * @return a new instance of {@link ImmutableComputableGraph} with new instances of updated nodes */ - private ImmutableComputableGraph duplicateWithUpdatedNodes(final Map updatedNodesMap) { - final Map newNodesMap = new HashMap<>(); - final Set updatedNodeKeys = updatedNodesMap.keySet(); + private ImmutableComputableGraph duplicateWithUpdatedNodes(final Map updatedNodesMap) { + final Map newNodesMap = new HashMap<>(); + final Set updatedNodeKeys = updatedNodesMap.keySet(); /* intact nodes */ cgs.getNodeKeysSet().stream() .filter(node -> !updatedNodeKeys.contains(node)) @@ -398,19 +402,21 @@ private ImmutableComputableGraph duplicateWithUpdatedNodes(final Map nodes; - private final Set keys; + private final Set keys; private boolean cacheAutoUpdate; ImmutableComputableGraphBuilder() { @@ -35,8 +35,11 @@ public static class ImmutableComputableGraphBuilder { cacheAutoUpdate = false; } - public ImmutableComputableGraphBuilder primitiveNode(@Nonnull final String key, - @Nonnull final String[] tags, + /** + * Add a primitive node with specified value + */ + public ImmutableComputableGraphBuilder primitiveNode(@Nonnull final CacheNode.NodeKey key, + @Nonnull final CacheNode.NodeTag[] tags, @Nonnull Duplicable value) { Utils.nonNull(key); Utils.nonNull(tags); @@ -47,14 +50,19 @@ public ImmutableComputableGraphBuilder primitiveNode(@Nonnull final String key, return this; } - - public ImmutableComputableGraphBuilder primitiveNodeWithEmptyNDArray(@Nonnull final String key) { - return primitiveNode(key, new String[]{}, new DuplicableNDArray()); + /** + * Add an uninitialized primitive node holding an empty {@link DuplicableNDArray} and no tags + */ + public ImmutableComputableGraphBuilder primitiveNodeWithEmptyNDArray(@Nonnull final CacheNode.NodeKey key) { + return primitiveNode(key, new CacheNode.NodeTag[] {}, new DuplicableNDArray()); } - public ImmutableComputableGraphBuilder computableNode(@Nonnull final String key, - @Nonnull final String[] tags, - @Nonnull final String[] parents, + /** + * Add a computable node + */ + public ImmutableComputableGraphBuilder computableNode(@Nonnull final CacheNode.NodeKey key, + @Nonnull final CacheNode.NodeTag[] tags, + @Nonnull final CacheNode.NodeKey[] parents, @Nullable final ComputableNodeFunction func, final boolean cacheEvals) { Utils.nonNull(key); @@ -69,23 +77,34 @@ public ImmutableComputableGraphBuilder computableNode(@Nonnull final String key, return this; } - public ImmutableComputableGraphBuilder externallyComputableNode(@Nonnull final String key) { - return computableNode(key, new String[] {}, new String[] {}, null, true); + /** + * Add an "untracked" externally computable node (i.e. no parents and no tags). Note: it is the user's + * responsibility to keep track of the value of these nodes and updating them if necessary. Since these nodes + * have no parents, {@link ImmutableComputableGraph} will not perform any bookkeeping on them. + */ + public ImmutableComputableGraphBuilder untrackedExternallyComputableNode(@Nonnull final CacheNode.NodeKey key) { + return computableNode(key, new CacheNode.NodeTag[] {}, new CacheNode.NodeKey[] {}, null, true); } + /** + * Enable cache auto-update for all computable nodes in the graph + */ public ImmutableComputableGraphBuilder withCacheAutoUpdate() { cacheAutoUpdate = true; return this; } + /** + * Disable cache auto-update for all computable nodes in the graph + */ public ImmutableComputableGraphBuilder withoutCacheAutoUpdate() { cacheAutoUpdate = false; return this; } - private void assertKeyUniqueness(@Nonnull final String key) { + private void assertKeyUniqueness(@Nonnull final CacheNode.NodeKey key) { if (keys.contains(key)) { - throw new DuplicateNodeKeyException("A node with key " + quote(key) + " already exists"); + throw new DuplicateNodeKeyException("A node with key " + quote(key.toString()) + " already exists"); } } @@ -112,4 +131,4 @@ static final class DuplicateNodeKeyException extends RuntimeException { static String quote(final String str) { return "\"" + str + "\""; } -} +} \ No newline at end of file diff --git a/src/main/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/PrimitiveCacheNode.java b/src/main/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/PrimitiveCacheNode.java index 81cc6277e..07d976eac 100644 --- a/src/main/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/PrimitiveCacheNode.java +++ b/src/main/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/PrimitiveCacheNode.java @@ -27,8 +27,8 @@ void set(@Nullable final Duplicable val) { value = val; } - PrimitiveCacheNode(@Nonnull final String key, - @Nonnull final Collection tags, + PrimitiveCacheNode(@Nonnull final NodeKey key, + @Nonnull final Collection tags, @Nullable final Duplicable val) { super(key, tags, Collections.emptyList()); set(val); @@ -40,13 +40,12 @@ boolean hasValue() { } @Override - Duplicable get(@Nullable final Map parentsValues) + Duplicable get(@Nullable final Map parentsValues) throws PrimitiveValueNotInitializedException { if (hasValue()) { return value; } else { - throw new PrimitiveValueNotInitializedException(String.format( - "The primitive cache \"%s\" is not initialized yet", getKey())); + throw new PrimitiveValueNotInitializedException(getKey()); } } @@ -70,8 +69,9 @@ PrimitiveCacheNode duplicateWithUpdatedValue(final Duplicable newValue) { static final class PrimitiveValueNotInitializedException extends RuntimeException implements Serializable { private static final long serialVersionUID = 6036472510998845566L; - private PrimitiveValueNotInitializedException(String s) { - super(s); + PrimitiveValueNotInitializedException(final NodeKey nodeKey) { + super("The primitive cache " + ImmutableComputableGraphUtils.quote(nodeKey.toString()) + + " is not initialized yet"); } } } diff --git a/src/test/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/CacheNodeUnitTest.java b/src/test/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/CacheNodeUnitTest.java new file mode 100644 index 000000000..562a9853e --- /dev/null +++ b/src/test/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/CacheNodeUnitTest.java @@ -0,0 +1,149 @@ +package org.broadinstitute.hellbender.tools.coveragemodel.cachemanager; + +import org.broadinstitute.hellbender.utils.MathObjectAsserts; +import org.broadinstitute.hellbender.utils.test.BaseTest; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.testng.Assert; +import org.testng.annotations.Test; + +import java.util.*; + +/** + * Unit tests for {@link CacheNode}. + * + * @author Mehrtash Babadi <mehrtash@broadinstitute.org> + */ +public class CacheNodeUnitTest extends BaseTest { + private static final List EMPTY_NODE_TAG_LIST = Collections.emptyList(); + private static final List EMPTY_NODE_KEY_LIST = Collections.emptyList(); + private static final Map EMPTY_PARENTS = Collections.emptyMap(); + + /** + * Asserts that the equality comparison of two {@link CacheNode}s is done just based on their key + */ + @Test + public void testEquality() { + final List nodesWithOneKey = getRandomCollectionOfNodesWithTheSameKey(new CacheNode.NodeKey("ONE_KEY")); + final List nodesWithAnotherKey = getRandomCollectionOfNodesWithTheSameKey(new CacheNode.NodeKey("ANOTHER_KEY")); + + for (final CacheNode node_0 : nodesWithOneKey) { + for (final CacheNode node_1 : nodesWithOneKey) { + Assert.assertTrue(node_0.equals(node_1) || (node_0.getClass() != node_1.getClass())); + } + } + + for (final CacheNode node_0 : nodesWithAnotherKey) { + for (final CacheNode node_1 : nodesWithAnotherKey) { + Assert.assertTrue(node_0.equals(node_1) || (node_0.getClass() != node_1.getClass())); + } + } + + for (final CacheNode node_0 : nodesWithOneKey) { + for (final CacheNode node_1 : nodesWithAnotherKey) { + Assert.assertTrue(!node_0.equals(node_1)); + } + } + } + + @Test + public void testToString() { + final List nodesWithOneKey = getRandomCollectionOfNodesWithTheSameKey(new CacheNode.NodeKey("ONE_KEY")); + for (final CacheNode node : nodesWithOneKey) { + Assert.assertTrue(node.toString().equals("ONE_KEY")); + } + } + + private List getRandomCollectionOfNodesWithTheSameKey(final CacheNode.NodeKey key) { + final List collection = new ArrayList<>(); + collection.add(new PrimitiveCacheNode(key, EMPTY_NODE_TAG_LIST, null)); + collection.add(new PrimitiveCacheNode(key, + Arrays.asList(new CacheNode.NodeTag("a"), new CacheNode.NodeTag("b"), new CacheNode.NodeTag("c")), + new DuplicableNumber<>(1.0))); + collection.add(new ComputableCacheNode(key, EMPTY_NODE_TAG_LIST, + Arrays.asList(new CacheNode.NodeKey("d"), new CacheNode.NodeKey("e")), + ImmutableComputableGraphUnitTest.f_computation_function, false)); + collection.add(new ComputableCacheNode(key, + EMPTY_NODE_TAG_LIST, + EMPTY_NODE_KEY_LIST, + ImmutableComputableGraphUnitTest.f_computation_function, false)); + collection.add(new ComputableCacheNode(key, + Arrays.asList(new CacheNode.NodeTag("f")), + Arrays.asList(new CacheNode.NodeKey("g")), null, true)); + return collection; + } + + @Test(expectedExceptions = UnsupportedOperationException.class) + public void testSetValueOfAutomaticallyComputableNode() { + new ComputableCacheNode(new CacheNode.NodeKey("TEST"), + EMPTY_NODE_TAG_LIST, EMPTY_NODE_KEY_LIST, ImmutableComputableGraphUnitTest.h_computation_function, false) + .set(new DuplicableNDArray(ImmutableComputableGraphUnitTest.getRandomINDArray())); + } + + @Test + public void testSetValueOfExternallyComputableNode() { + final ComputableCacheNode node = new ComputableCacheNode(new CacheNode.NodeKey("TEST"), + EMPTY_NODE_TAG_LIST, EMPTY_NODE_KEY_LIST, null, true); + final INDArray arr = ImmutableComputableGraphUnitTest.getRandomINDArray(); + node.set(new DuplicableNDArray(arr)); + MathObjectAsserts.assertNDArrayEquals((INDArray)node.get(EMPTY_PARENTS).value(), arr); + } + + @Test + public void testSetValueOfPrimitiveNode() { + final PrimitiveCacheNode node = new PrimitiveCacheNode(new CacheNode.NodeKey("TEST"), EMPTY_NODE_TAG_LIST, null); + final INDArray arr = ImmutableComputableGraphUnitTest.getRandomINDArray(); + node.set(new DuplicableNDArray(arr)); + MathObjectAsserts.assertNDArrayEquals((INDArray)node.get(EMPTY_PARENTS).value(), arr); + } + + @Test + public void testPrimitiveNodeDuplication() { + final PrimitiveCacheNode node = new PrimitiveCacheNode(new CacheNode.NodeKey("TEST"), EMPTY_NODE_TAG_LIST, + new DuplicableNDArray(ImmutableComputableGraphUnitTest.getRandomINDArray())); + final PrimitiveCacheNode dupNode = node.duplicate(); + MathObjectAsserts.assertNDArrayEquals((INDArray)node.get(EMPTY_PARENTS).value(), + (INDArray)dupNode.get(EMPTY_PARENTS).value()); + Assert.assertTrue(dupNode.hasValue()); + Assert.assertTrue(dupNode.getKey().equals(new CacheNode.NodeKey("TEST"))); + } + + @Test + public void testCachingComputableNodeDuplication() { + final INDArray testArray = ImmutableComputableGraphUnitTest.getRandomINDArray(); + final Duplicable testDuplicable = new DuplicableNDArray(testArray); + final ComputableNodeFunction trivialFunction = parents -> testDuplicable; + + final ComputableCacheNode cachingAutoNodeUncached = new ComputableCacheNode(new CacheNode.NodeKey("TEST"), + EMPTY_NODE_TAG_LIST, EMPTY_NODE_KEY_LIST, trivialFunction, true); + final ComputableCacheNode cachingAutoNodeUncachedDup = cachingAutoNodeUncached.duplicate(); + + final ComputableCacheNode cachingAutoNodeCached = cachingAutoNodeUncached.duplicateWithUpdatedValue(testDuplicable); + final ComputableCacheNode cachingAutoNodeCachedDup = cachingAutoNodeCached.duplicate(); + + final ComputableCacheNode cachingAutoNodeCachedOutdated = cachingAutoNodeCached.duplicateWithOutdatedCacheStatus(); + final ComputableCacheNode cachingAutoNodeCachedOutdatedDup = cachingAutoNodeCachedOutdated.duplicate(); + + Assert.assertTrue(cachingAutoNodeUncached.isCaching()); + Assert.assertTrue(cachingAutoNodeUncachedDup.isCaching()); + Assert.assertTrue(!cachingAutoNodeUncached.isExternallyComputed()); + Assert.assertTrue(!cachingAutoNodeUncachedDup.isExternallyComputed()); + Assert.assertTrue(!cachingAutoNodeUncached.hasValue()); + Assert.assertTrue(!cachingAutoNodeUncachedDup.hasValue()); + + Assert.assertTrue(cachingAutoNodeCached.isCaching()); + Assert.assertTrue(cachingAutoNodeCachedDup.isCaching()); + Assert.assertTrue(!cachingAutoNodeCached.isExternallyComputed()); + Assert.assertTrue(!cachingAutoNodeCachedDup.isExternallyComputed()); + Assert.assertTrue(cachingAutoNodeCached.hasValue()); + Assert.assertTrue(cachingAutoNodeCachedDup.hasValue()); + MathObjectAsserts.assertNDArrayEquals((INDArray)cachingAutoNodeCached.get(EMPTY_PARENTS).value(), + (INDArray)cachingAutoNodeCachedDup.get(EMPTY_PARENTS).value()); + + Assert.assertTrue(cachingAutoNodeCachedOutdated.isCaching()); + Assert.assertTrue(cachingAutoNodeCachedOutdatedDup.isCaching()); + Assert.assertTrue(!cachingAutoNodeCachedOutdated.isExternallyComputed()); + Assert.assertTrue(!cachingAutoNodeCachedOutdatedDup.isExternallyComputed()); + Assert.assertTrue(!cachingAutoNodeCachedOutdated.hasValue()); /* outdated caches must drop out */ + Assert.assertTrue(!cachingAutoNodeCachedOutdatedDup.hasValue()); /* outdated caches must drop out */ + } +} diff --git a/src/test/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/ComputableGraphStructureUnitTest.java b/src/test/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/ComputableGraphStructureUnitTest.java index b80a92756..2264f4a7c 100644 --- a/src/test/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/ComputableGraphStructureUnitTest.java +++ b/src/test/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/ComputableGraphStructureUnitTest.java @@ -26,48 +26,57 @@ public class ComputableGraphStructureUnitTest extends BaseTest { private static final int MAX_PARENTS_PER_NODE = 10; private static final int NUM_TRIALS = 5; + private static final CacheNode.NodeKey X_KEY = new CacheNode.NodeKey("x"); + private static final CacheNode.NodeKey Y_KEY = new CacheNode.NodeKey("y"); + private static final CacheNode.NodeKey Z_KEY = new CacheNode.NodeKey("z"); + private static final CacheNode.NodeKey F_KEY = new CacheNode.NodeKey("f"); + private static final CacheNode.NodeKey G_KEY = new CacheNode.NodeKey("g"); + private static final CacheNode.NodeKey H_KEY = new CacheNode.NodeKey("h"); + @Test(expectedExceptions = ComputableGraphStructure.NonexistentParentNodeKey.class) public void testMissingParents() { + final CacheNode.NodeKey Q_KEY = new CacheNode.NodeKey("q"); ImmutableComputableGraph.builder() - .primitiveNode("x", new String[] {}, new DuplicableNDArray()) - .primitiveNode("y", new String[] {}, new DuplicableNumber()) - .primitiveNode("z", new String[] {}, new DuplicableNDArray()) - .computableNode("f", new String[] {}, new String[] {"x", "y", "q"}, null, true) /* q is undefined */ - .computableNode("g", new String[] {}, new String[] {"y", "z"}, null, true) - .computableNode("h", new String[] {}, new String[] {"f", "g"}, null, true) + .primitiveNode(X_KEY, new CacheNode.NodeTag[] {}, new DuplicableNDArray()) + .primitiveNode(Y_KEY, new CacheNode.NodeTag[] {}, new DuplicableNumber()) + .primitiveNode(Z_KEY, new CacheNode.NodeTag[] {}, new DuplicableNDArray()) + .computableNode(F_KEY, new CacheNode.NodeTag[] {}, new CacheNode.NodeKey[] {X_KEY, Y_KEY, Z_KEY, Q_KEY}, null, true) /* q is undefined */ + .computableNode(G_KEY, new CacheNode.NodeTag[] {}, new CacheNode.NodeKey[] {Y_KEY, Z_KEY}, null, true) + .computableNode(H_KEY, new CacheNode.NodeTag[] {}, new CacheNode.NodeKey[] {F_KEY, G_KEY}, null, true) .build(); } @Test(expectedExceptions = ComputableGraphStructure.CyclicGraphException.class) public void testCyclicGraphException_1() { + final CacheNode.NodeKey W_KEY = new CacheNode.NodeKey("w"); ImmutableComputableGraph.builder() - .primitiveNode("x", new String[] {}, new DuplicableNDArray()) - .computableNode("y", new String[] {}, new String[] {"x", "w"}, null, true) /* cycle */ - .computableNode("z", new String[] {}, new String[] {"y"}, null, true) - .computableNode("w", new String[] {}, new String[] {"z"}, null, true) + .primitiveNode(X_KEY, new CacheNode.NodeTag[] {}, new DuplicableNDArray()) + .computableNode(Y_KEY, new CacheNode.NodeTag[] {}, new CacheNode.NodeKey[] {X_KEY, W_KEY}, null, true) /* cycle */ + .computableNode(Z_KEY, new CacheNode.NodeTag[] {}, new CacheNode.NodeKey[] {Y_KEY}, null, true) + .computableNode(W_KEY, new CacheNode.NodeTag[] {}, new CacheNode.NodeKey[] {Z_KEY}, null, true) .build(); } @Test(expectedExceptions = ComputableGraphStructure.CyclicGraphException.class) public void testCyclicGraphException_2() { ImmutableComputableGraph.builder() - .primitiveNode("x", new String[] {}, new DuplicableNDArray()) - .primitiveNode("y", new String[] {}, new DuplicableNDArray()) - .primitiveNode("z", new String[] {}, new DuplicableNDArray()) - .computableNode("f", new String[] {}, new String[] {"x", "y", "h"}, null, true) /* cycle */ - .computableNode("g", new String[] {}, new String[] {"y", "z"}, null, true) - .computableNode("h", new String[] {}, new String[] {"f", "g"}, null, true) + .primitiveNode(X_KEY, new CacheNode.NodeTag[] {}, new DuplicableNDArray()) + .primitiveNode(Y_KEY, new CacheNode.NodeTag[] {}, new DuplicableNDArray()) + .primitiveNode(Z_KEY, new CacheNode.NodeTag[] {}, new DuplicableNDArray()) + .computableNode(F_KEY, new CacheNode.NodeTag[] {}, new CacheNode.NodeKey[] {X_KEY, Y_KEY, H_KEY}, null, true) /* cycle */ + .computableNode(G_KEY, new CacheNode.NodeTag[] {}, new CacheNode.NodeKey[] {Y_KEY, Z_KEY}, null, true) + .computableNode(H_KEY, new CacheNode.NodeTag[] {}, new CacheNode.NodeKey[] {F_KEY, G_KEY}, null, true) .build(); } @Test(invocationCount = NUM_TRIALS) - public void testNodeTagsAndKeys() { + public void testNodeTagsAndKeysInitializationAndAccessors() { final RandomDAG dag = RandomDAG.getRandomDAG(); final ComputableGraphStructure cgs = new ComputableGraphStructure(dag.getEquivalentCacheNodeSet()); - final Set cgsNodeTagsSet = cgs.getNodeTagsSet(); - final Set dagNodeTagsSet = dag.tagsSet; - final Set cgsNodeKeysSet = cgs.getNodeKeysSet(); - final Set dagNodeKeysSet = dag.nodeKeysSet; + final Set cgsNodeTagsSet = cgs.getNodeTagsSet(); + final Set dagNodeTagsSet = dag.tagsSet; + final Set cgsNodeKeysSet = cgs.getNodeKeysSet(); + final Set dagNodeKeysSet = dag.nodeKeysSet; Assert.assertTrue(cgsNodeKeysSet.equals(dagNodeKeysSet)); Assert.assertTrue(cgsNodeTagsSet.equals(dagNodeTagsSet)); } @@ -143,7 +152,7 @@ public void testTopologicalOrderForTagEvaluation() { public void testTopologicalOrderForCompleteEvaluation() { final RandomDAG dag = RandomDAG.getRandomDAG(); final ComputableGraphStructure cgs = new ComputableGraphStructure(dag.getEquivalentCacheNodeSet()); - final List orderedNodes = new ArrayList<>(dag.nodeKeysSet); + final List orderedNodes = new ArrayList<>(dag.nodeKeysSet); orderedNodes.sort(Comparator.comparingInt(dag.topologicalOrderMap::get)); Assert.assertTrue(isTopologicallyEquivalent(cgs.getTopologicalOrderForCompleteEvaluation(), orderedNodes, dag.topologicalOrderMap)); @@ -181,12 +190,12 @@ private static final class RandomDAG { private static final int TAG_LENGTH = 32; private static final int NODE_KEY_LENGTH = 32; - final Map> nodesByTopologicalOrder; - final Map topologicalOrderMap; - final Map> parentsMap; - final Map> tagsMap; - final Set nodeKeysSet; - final Set tagsSet; + final Map> nodesByTopologicalOrder; + final Map topologicalOrderMap; + final Map> parentsMap; + final Map> tagsMap; + final Set nodeKeysSet; + final Set tagsSet; private RandomDAG(final int depth, final int maxNodesPerLayer, final int maxParentsPerNode, final int maxTagsPerNode) { Utils.validateArg(depth >= 0, "DAG depth must be >= 0"); @@ -200,20 +209,20 @@ private RandomDAG(final int depth, final int maxNodesPerLayer, final int maxPare nodeKeysSet = new HashSet<>(); for (int d = 0; d <= depth; d++) { - final Set randomNodes = getRandomNodeKeys(maxNodesPerLayer); + final Set randomNodes = getUniqueRandomNodeKeys(maxNodesPerLayer, d); nodeKeysSet.addAll(randomNodes); nodesByTopologicalOrder.put(d, randomNodes); randomNodes.forEach(nodeKey -> tagsMap.put(nodeKey, getRandomTags(maxTagsPerNode))); randomNodes.forEach(nodeKey -> parentsMap.put(nodeKey, new HashSet<>())); if (d > 0) { - final Set possibleAncestors = IntStream.range(0, d) + final Set possibleAncestors = IntStream.range(0, d) .mapToObj(nodesByTopologicalOrder::get) .flatMap(Set::stream) .collect(Collectors.toSet()); - for (final String nodeKey : randomNodes) { - final String randomParent = getRandomElement(nodesByTopologicalOrder.get(d - 1)); + for (final CacheNode.NodeKey nodeKey : randomNodes) { + final CacheNode.NodeKey randomParent = getRandomElement(nodesByTopologicalOrder.get(d - 1)); final int numParents = rng.nextInt(maxParentsPerNode); - final Set randomAncestors = IntStream.range(0, numParents) + final Set randomAncestors = IntStream.range(0, numParents) .mapToObj(i -> getRandomElement(possibleAncestors)) .collect(Collectors.toSet()); parentsMap.put(nodeKey, new HashSet<>()); @@ -235,15 +244,15 @@ static RandomDAG getRandomDAG() { 1 + rng.nextInt(MAX_TAGS_PER_NODE)); } - Set getParents(final String nodeKey) { + Set getParents(final CacheNode.NodeKey nodeKey) { return parentsMap.get(nodeKey); } /** * Brute-force method */ - Set getAncestors(final String nodeKey) { - final Set parents = getParents(nodeKey); + Set getAncestors(final CacheNode.NodeKey nodeKey) { + final Set parents = getParents(nodeKey); return Sets.union(parents, parents.stream() .map(this::getAncestors) .flatMap(Set::stream) @@ -253,7 +262,7 @@ Set getAncestors(final String nodeKey) { /** * Brute-force method */ - Set getChildren(final String nodeKey) { + Set getChildren(final CacheNode.NodeKey nodeKey) { return nodeKeysSet.stream() .filter(key -> getParents(key).contains(nodeKey)) .collect(Collectors.toSet()); @@ -262,24 +271,24 @@ Set getChildren(final String nodeKey) { /** * Brute-force method */ - Set getDescendents(final String nodeKey) { - final Set children = getChildren(nodeKey); + Set getDescendents(final CacheNode.NodeKey nodeKey) { + final Set children = getChildren(nodeKey); return Sets.union(children, children.stream() .map(this::getDescendents) .flatMap(Set::stream) .collect(Collectors.toSet())); } - Set getInducedTags(final String nodeKey) { - final Set allNodes = Sets.union(getDescendents(nodeKey), Collections.singleton(nodeKey)); + Set getInducedTags(final CacheNode.NodeKey nodeKey) { + final Set allNodes = Sets.union(getDescendents(nodeKey), Collections.singleton(nodeKey)); return allNodes.stream() .map(tagsMap::get) .flatMap(Set::stream) .collect(Collectors.toSet()); } - List getTopologicalOrderForNodeEvaluation(final String nodeKey) { - final List sortedNodes = new ArrayList<>(Sets.union(getAncestors(nodeKey), + List getTopologicalOrderForNodeEvaluation(final CacheNode.NodeKey nodeKey) { + final List sortedNodes = new ArrayList<>(Sets.union(getAncestors(nodeKey), Collections.singleton(nodeKey))); sortedNodes.sort(Comparator.comparingInt(topologicalOrderMap::get)); return sortedNodes; @@ -288,16 +297,16 @@ List getTopologicalOrderForNodeEvaluation(final String nodeKey) { /** * Brute-force method */ - List getTopologicalOrderForNodeMutation(final String nodeKey) { - final Set mutatedNodeAndDescendants = Sets.union(getDescendents(nodeKey), + List getTopologicalOrderForNodeMutation(final CacheNode.NodeKey nodeKey) { + final Set mutatedNodeAndDescendants = Sets.union(getDescendents(nodeKey), Collections.singleton(nodeKey)); - final Set ancestorsOfDescendants = getDescendents(nodeKey) + final Set ancestorsOfDescendants = getDescendents(nodeKey) .stream() .map(this::getAncestors) .flatMap(Set::stream) .collect(Collectors.toSet()); - final Set involvedNodes = Sets.union(mutatedNodeAndDescendants, ancestorsOfDescendants); - final List topologicallySortedInvolvedNodes = new ArrayList<>(involvedNodes); + final Set involvedNodes = Sets.union(mutatedNodeAndDescendants, ancestorsOfDescendants); + final List topologicallySortedInvolvedNodes = new ArrayList<>(involvedNodes); topologicallySortedInvolvedNodes.sort(Comparator.comparingInt(topologicalOrderMap::get)); return topologicallySortedInvolvedNodes; } @@ -305,22 +314,23 @@ List getTopologicalOrderForNodeMutation(final String nodeKey) { /** * Brute-force method */ - List getTopologicalOrderForTagEvaluation(final String tag) { - final Set taggedNodes = nodeKeysSet.stream() + List getTopologicalOrderForTagEvaluation(final CacheNode.NodeTag tag) { + final Set taggedNodes = nodeKeysSet.stream() .filter(nodeKey -> getInducedTags(nodeKey).contains(tag)) .collect(Collectors.toSet()); - final Set taggedNodesAndTheirAncestors = Sets.union(taggedNodes, + final Set taggedNodesAndTheirAncestors = Sets.union(taggedNodes, taggedNodes.stream() .map(this::getAncestors) .flatMap(Set::stream) .collect(Collectors.toSet())); - final List topologicallySortedtaggedNodesAndTheirAncestors = new ArrayList<>(taggedNodesAndTheirAncestors); + final List topologicallySortedtaggedNodesAndTheirAncestors = + new ArrayList<>(taggedNodesAndTheirAncestors); topologicallySortedtaggedNodesAndTheirAncestors.sort(Comparator.comparingInt(topologicalOrderMap::get)); return topologicallySortedtaggedNodesAndTheirAncestors; } Set getEquivalentCacheNodeSet() { - final List shuffledNodeList = new ArrayList<>(nodeKeysSet); + final List shuffledNodeList = new ArrayList<>(nodeKeysSet); Collections.shuffle(shuffledNodeList, rng); return shuffledNodeList.stream() .map(nodeKey -> topologicalOrderMap.get(nodeKey) == 0 @@ -329,26 +339,31 @@ Set getEquivalentCacheNodeSet() { .collect(Collectors.toSet()); } - private static Set getRandomNodeKeys(final int maxNodes) { - return IntStream.range(0, rng.nextInt(maxNodes) + 1) - .mapToObj(i -> "NODE_KEY_" + RandomStringUtils.randomAlphanumeric(NODE_KEY_LENGTH)) - .collect(Collectors.toSet()); + private static Set getUniqueRandomNodeKeys(final int maxNodes, final int depth) { + Set randomNodeKeySet = new HashSet<>(); + final int count = rng.nextInt(maxNodes) + 1; + while (randomNodeKeySet.stream().distinct().count() != count) { + randomNodeKeySet = IntStream.range(0, count) + .mapToObj(i -> new CacheNode.NodeKey("NODE_KEY_" + depth + "_" + RandomStringUtils.randomAlphanumeric(NODE_KEY_LENGTH))) + .collect(Collectors.toSet()); + } + return randomNodeKeySet; } - private static Set getRandomTags(final int maxTags) { + private static Set getRandomTags(final int maxTags) { return IntStream.range(0, rng.nextInt(maxTags)) - .mapToObj(i -> "TAG_" + RandomStringUtils.randomAlphanumeric(TAG_LENGTH)) + .mapToObj(i -> new CacheNode.NodeTag("TAG_" + RandomStringUtils.randomAlphanumeric(TAG_LENGTH))) .collect(Collectors.toSet()); } - private static String getRandomElement(final Set set) { - final List list = new ArrayList<>(set); + private static T getRandomElement(final Set set) { + final List list = new ArrayList<>(set); return list.get(rng.nextInt(list.size())); } } - private boolean isTopologicallyEquivalent(final List actual, final List expected, - final Map topologicalOrderMap) { + private static boolean isTopologicallyEquivalent(final List actual, final List expected, + final Map topologicalOrderMap) { if (actual == null || expected == null || !(new HashSet<>(actual).equals(new HashSet<>(expected)))) { return false; } diff --git a/src/test/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/ImmutableComputableGraphUnitTest.java b/src/test/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/ImmutableComputableGraphUnitTest.java index 24a749c06..360bf4415 100644 --- a/src/test/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/ImmutableComputableGraphUnitTest.java +++ b/src/test/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/ImmutableComputableGraphUnitTest.java @@ -55,14 +55,21 @@ public class ImmutableComputableGraphUnitTest extends BaseTest { */ private static final int NUM_TRIALS = 5; - private static final Set ALL_NODES = new HashSet<>(Arrays.asList("x", "y", "z", "f", "g", "h")); - private static final Set ALL_PRIMITIVE_NODES = new HashSet<>(Arrays.asList("x", "y", "z")); - private static final Set ALL_COMPUTABLE_NODES = new HashSet<>(Arrays.asList("f", "g", "h")); + private static final CacheNode.NodeKey X_KEY = new CacheNode.NodeKey("x"); + private static final CacheNode.NodeKey Y_KEY = new CacheNode.NodeKey("y"); + private static final CacheNode.NodeKey Z_KEY = new CacheNode.NodeKey("z"); + private static final CacheNode.NodeKey F_KEY = new CacheNode.NodeKey("f"); + private static final CacheNode.NodeKey G_KEY = new CacheNode.NodeKey("g"); + private static final CacheNode.NodeKey H_KEY = new CacheNode.NodeKey("h"); + + private static final Set ALL_NODES = new HashSet<>(Arrays.asList(X_KEY, Y_KEY, Z_KEY, F_KEY, G_KEY, H_KEY)); + private static final Set ALL_PRIMITIVE_NODES = new HashSet<>(Arrays.asList(X_KEY, Y_KEY, Z_KEY)); + private static final Set ALL_COMPUTABLE_NODES = new HashSet<>(Arrays.asList(F_KEY, G_KEY, H_KEY)); /** * A static counter to keep track of the number of function evaluations */ - private static final Counter counter = new Counter("f", "g", "h"); + private static final Counter counter = new Counter(F_KEY, G_KEY, H_KEY); /** * Shape of NDArrays in the ICG @@ -86,11 +93,8 @@ public class ImmutableComputableGraphUnitTest extends BaseTest { "cos", UNARY_FUNCTION_COSINE, "id", UNARY_FUNCTION_IDENTITY); - private static final List EMPTY_STRING_LIST = new ArrayList<>(); - private static final Map EMPTY_PARENTS = new HashMap<>(); - /** - * Instructions for computing f(x, y) = F[2] ( F[0](x), F[1](x) ) + * Instructions for computing f(x, y) = F[2] ( F[0](x), F[1](y) ) * * The first two strings describe one of the unary functions in {@link #TEST_UNARY_FUNCTIONS} * The last string describes a binary function in {@link #TEST_BINARY_FUNCTIONS} @@ -136,7 +140,7 @@ private static void generateNewRandomFunctionalComposition() { H_COMPUTATION_INSTRUCTIONS.add(getRandomChoice(TEST_BINARY_FUNCTIONS.keySet())); } - private static INDArray getRandomINDArray() { + static INDArray getRandomINDArray() { return Nd4j.rand(TEST_NDARRAY_SHAPE); } @@ -152,11 +156,11 @@ private static T getRandomChoice(final List collection) { return collection.get(rng.nextInt(collection.size())); } - private Set getRandomSetOfTags() { + private Set getRandomSetOfTags() { final int MAX_NUM_TAGS = 5; final int TAG_LENGTH = 12; return IntStream.range(0, rng.nextInt(MAX_NUM_TAGS)) - .mapToObj(n -> RandomStringUtils.randomAlphanumeric(TAG_LENGTH)) + .mapToObj(n -> new CacheNode.NodeTag(RandomStringUtils.randomAlphanumeric(TAG_LENGTH))) .collect(Collectors.toSet()); } @@ -165,7 +169,7 @@ private static Counter getCounterInstance() { } /** - * Computes "f" from "x" and "y" according to the instructions in {@link #F_COMPUTATION_INSTRUCTIONS} + * Computes F_KEY from X_KEY and Y_KEY according to the instructions in {@link #F_COMPUTATION_INSTRUCTIONS} */ private static INDArray f_computer(final INDArray x, final INDArray y) { final INDArray xTrans = TEST_UNARY_FUNCTIONS.get(F_COMPUTATION_INSTRUCTIONS.get(0)).apply(x); @@ -174,7 +178,7 @@ private static INDArray f_computer(final INDArray x, final INDArray y) { } /** - * Computes "g" from "y" and "z" according to the instructions in {@link #G_COMPUTATION_INSTRUCTIONS} + * Computes G_KEY from Y_KEY and Z_KEY according to the instructions in {@link #G_COMPUTATION_INSTRUCTIONS} */ private static INDArray g_computer(final INDArray y, final INDArray z) { final INDArray yTrans = TEST_UNARY_FUNCTIONS.get(G_COMPUTATION_INSTRUCTIONS.get(0)).apply(y); @@ -183,7 +187,7 @@ private static INDArray g_computer(final INDArray y, final INDArray z) { } /** - * Computes "h" from "f", "g" and "x" according to the instructions in {@link #H_COMPUTATION_INSTRUCTIONS} + * Computes H_KEY from F_KEY, G_KEY and X_KEY according to the instructions in {@link #H_COMPUTATION_INSTRUCTIONS} */ private static INDArray h_computer(final INDArray f, final INDArray g, final INDArray x) { final INDArray fTrans = TEST_UNARY_FUNCTIONS.get(H_COMPUTATION_INSTRUCTIONS.get(0)).apply(f); @@ -195,52 +199,52 @@ private static INDArray h_computer(final INDArray f, final INDArray g, final IND /** * An instance of {@link ComputableNodeFunction} for calculating f(x, y) automatically in the - * {@link ImmutableComputableGraph} representation of the problem. It computes "f" by calling - * {@link #f_computer(INDArray, INDArray)} and increments the static "f"-function evaluation counter. + * {@link ImmutableComputableGraph} representation of the problem. It computes F_KEY by calling + * {@link #f_computer(INDArray, INDArray)} and increments the static F_KEY-function evaluation counter. */ - private static ComputableNodeFunction f_computation_function = new ComputableNodeFunction() { + static ComputableNodeFunction f_computation_function = new ComputableNodeFunction() { @Override - public Duplicable apply(Map parents) throws ParentValueNotFoundException { - final INDArray x = fetchINDArray("x", parents); - final INDArray y = Nd4j.zeros(x.shape()).add(fetchDouble("y", parents)); + public Duplicable apply(Map parents) throws ParentValueNotFoundException { + final INDArray x = fetchINDArray(X_KEY, parents); + final INDArray y = Nd4j.zeros(x.shape()).add(fetchDouble(Y_KEY, parents)); final INDArray result = f_computer(x, y); - counter.increment("f"); + counter.increment(F_KEY); return new DuplicableNDArray(result); } }; /** * An instance of {@link ComputableNodeFunction} for calculating g(y, z) automatically in the - * {@link ImmutableComputableGraph} representation of the problem. It computes "g" by calling - * {@link #g_computer(INDArray, INDArray)} and increments the static "g"-function evaluation counter. + * {@link ImmutableComputableGraph} representation of the problem. It computes G_KEY by calling + * {@link #g_computer(INDArray, INDArray)} and increments the static G_KEY-function evaluation counter. * - * Note: "y" will be casted into an {@link INDArray} + * Note: Y_KEY will be casted into an {@link INDArray} */ - private static ComputableNodeFunction g_computation_function = new ComputableNodeFunction() { + static ComputableNodeFunction g_computation_function = new ComputableNodeFunction() { @Override - public Duplicable apply(Map parents) throws ParentValueNotFoundException { - final INDArray z = fetchINDArray("z", parents); - final INDArray y = Nd4j.zeros(z.shape()).add(fetchDouble("y", parents)); + public Duplicable apply(Map parents) throws ParentValueNotFoundException { + final INDArray z = fetchINDArray(Z_KEY, parents); + final INDArray y = Nd4j.zeros(z.shape()).add(fetchDouble(Y_KEY, parents)); final INDArray result = g_computer(y, z); - counter.increment("g"); + counter.increment(G_KEY); return new DuplicableNDArray(result); } }; /** * An instance of {@link ComputableNodeFunction} for calculating h(f, g, x) automatically in the - * {@link ImmutableComputableGraph} representation of the problem. It computes "h" by calling - * {@link #h_computer(INDArray, INDArray, INDArray)} and increments the static "h"-function evaluation + * {@link ImmutableComputableGraph} representation of the problem. It computes H_KEY by calling + * {@link #h_computer(INDArray, INDArray, INDArray)} and increments the static H_KEY-function evaluation * counter. */ - public static ComputableNodeFunction h_computation_function = new ComputableNodeFunction() { + static ComputableNodeFunction h_computation_function = new ComputableNodeFunction() { @Override - public Duplicable apply(Map parents) throws ParentValueNotFoundException { - final INDArray f = fetchINDArray("f", parents); - final INDArray g = fetchINDArray("g", parents); - final INDArray x = fetchINDArray("x", parents); + public Duplicable apply(Map parents) throws ParentValueNotFoundException { + final INDArray f = fetchINDArray(F_KEY, parents); + final INDArray g = fetchINDArray(G_KEY, parents); + final INDArray x = fetchINDArray(X_KEY, parents); final INDArray result = h_computer(f, g, x); - counter.increment("h"); + counter.increment(H_KEY); return new DuplicableNDArray(result); } }; @@ -249,17 +253,17 @@ private static ImmutableComputableGraphUtils.ImmutableComputableGraphBuilder get final boolean f_caching, final boolean f_external, final boolean g_caching, final boolean g_external, final boolean h_caching, final boolean h_external, - final String[] x_tags, final String[] y_tags, final String[] z_tags, - final String[] f_tags, final String[] g_tags, final String[] h_tags) { + final CacheNode.NodeTag[] x_tags, final CacheNode.NodeTag[] y_tags, final CacheNode.NodeTag[] z_tags, + final CacheNode.NodeTag[] f_tags, final CacheNode.NodeTag[] g_tags, final CacheNode.NodeTag[] h_tags) { return ImmutableComputableGraph.builder() - .primitiveNode("x", x_tags, new DuplicableNDArray()) - .primitiveNode("y", y_tags, new DuplicableNumber()) - .primitiveNode("z", z_tags, new DuplicableNDArray()) - .computableNode("f", f_tags, new String[]{"x", "y"}, + .primitiveNode(X_KEY, x_tags, new DuplicableNDArray()) + .primitiveNode(Y_KEY, y_tags, new DuplicableNumber()) + .primitiveNode(Z_KEY, z_tags, new DuplicableNDArray()) + .computableNode(F_KEY, f_tags, new CacheNode.NodeKey[] {X_KEY, Y_KEY}, f_external ? null : f_computation_function, f_caching) - .computableNode("g", g_tags, new String[]{"y", "z"}, + .computableNode(G_KEY, g_tags, new CacheNode.NodeKey[] {Y_KEY, Z_KEY}, g_external ? null : g_computation_function, g_caching) - .computableNode("h", h_tags, new String[]{"f", "g", "x"}, + .computableNode(H_KEY, h_tags, new CacheNode.NodeKey[] {F_KEY, G_KEY, X_KEY}, h_external ? null : h_computation_function, h_caching); } @@ -268,8 +272,8 @@ private static ImmutableComputableGraphUtils.ImmutableComputableGraphBuilder get final boolean g_caching, final boolean g_external, final boolean h_caching, final boolean h_external) { return getTestICGBuilder(f_caching, f_external, g_caching, g_external, h_caching, h_external, - new String[] {}, new String[] {}, new String[] {}, - new String[] {}, new String[] {}, new String[] {}); + new CacheNode.NodeTag[] {}, new CacheNode.NodeTag[] {}, new CacheNode.NodeTag[] {}, + new CacheNode.NodeTag[] {}, new CacheNode.NodeTag[] {}, new CacheNode.NodeTag[] {}); } /** @@ -291,39 +295,39 @@ private static void assertCorrectness(final INDArray xExpected, final INDArray y private boolean assertIntactReferences(@Nonnull final ImmutableComputableGraph original, @Nonnull final ImmutableComputableGraph other, - @Nonnull final Set unaffectedNodeKeys) { - final Set affectedNodeKeys = unaffectedNodeKeys.stream() + @Nonnull final Set unaffectedNodeKeys) { + final Set affectedNodeKeys = unaffectedNodeKeys.stream() .filter(nodeKey -> original.getCacheNode(nodeKey) != other.getCacheNode(nodeKey)) .collect(Collectors.toSet()); if (!affectedNodeKeys.isEmpty()) { - throw new AssertionError("Some of the node references have changed but they were supposed to remain" + - " intact: " + affectedNodeKeys.stream().collect(Collectors.joining(", "))); + throw new AssertionError("Some of the node references have changed but they were supposed to remain intact: " + + affectedNodeKeys.stream().map(CacheNode.NodeKey::toString).collect(Collectors.joining(", "))); } return true; } private boolean assertChangedReferences(@Nonnull final ImmutableComputableGraph original, @Nonnull final ImmutableComputableGraph other, - @Nonnull final Set affectedNodeKeys) { - final Set unaffectedNodeKeys = affectedNodeKeys.stream() + @Nonnull final Set affectedNodeKeys) { + final Set unaffectedNodeKeys = affectedNodeKeys.stream() .filter(nodeKey -> original.getCacheNode(nodeKey) == other.getCacheNode(nodeKey)) .collect(Collectors.toSet()); if (!unaffectedNodeKeys.isEmpty()) { throw new AssertionError("Some of the node references have not changed but they were supposed to change: " + - unaffectedNodeKeys.stream().collect(Collectors.joining(", "))); + unaffectedNodeKeys.stream().map(CacheNode.NodeKey::toString).collect(Collectors.joining(", "))); } return true; } private boolean assertIntactReferences(@Nonnull final ImmutableComputableGraph original, @Nonnull final ImmutableComputableGraph other, - @Nonnull final String... unaffectedNodeKeys) { + @Nonnull final CacheNode.NodeKey... unaffectedNodeKeys) { return assertIntactReferences(original, other, Arrays.stream(unaffectedNodeKeys).collect(Collectors.toSet())); } private boolean assertChangedReferences(@Nonnull final ImmutableComputableGraph original, @Nonnull final ImmutableComputableGraph other, - @Nonnull final String... affectedNodeKeys) { + @Nonnull final CacheNode.NodeKey... affectedNodeKeys) { return assertChangedReferences(original, other, Arrays.stream(affectedNodeKeys).collect(Collectors.toSet())); } @@ -341,15 +345,15 @@ public void testAutoUpdateCache() { Counter startCounts = getCounterInstance(); ImmutableComputableGraph icg_1 = icg_0 - .setValue("x", new DuplicableNDArray(x)) - .setValue("y", new DuplicableNumber<>(y)) - .setValue("z", new DuplicableNDArray(z)); - final INDArray xICG = (INDArray)icg_1.fetchDirectly("x").value(); - final double yICG = (Double)icg_1.fetchDirectly("y").value(); - final INDArray zICG = (INDArray)icg_1.fetchDirectly("z").value(); - final INDArray fICG = (INDArray)icg_1.fetchDirectly("f").value(); - final INDArray gICG = (INDArray)icg_1.fetchDirectly("g").value(); - final INDArray hICG = (INDArray)icg_1.fetchDirectly("h").value(); + .setValue(X_KEY, new DuplicableNDArray(x)) + .setValue(Y_KEY, new DuplicableNumber<>(y)) + .setValue(Z_KEY, new DuplicableNDArray(z)); + final INDArray xICG = (INDArray)icg_1.fetchDirectly(X_KEY).value(); + final double yICG = (Double)icg_1.fetchDirectly(Y_KEY).value(); + final INDArray zICG = (INDArray)icg_1.fetchDirectly(Z_KEY).value(); + final INDArray fICG = (INDArray)icg_1.fetchDirectly(F_KEY).value(); + final INDArray gICG = (INDArray)icg_1.fetchDirectly(G_KEY).value(); + final INDArray hICG = (INDArray)icg_1.fetchDirectly(H_KEY).value(); Counter diffCounts = getCounterInstance().diff(startCounts); assertCorrectness(x, Nd4j.zeros(TEST_NDARRAY_SHAPE).add(y), z, @@ -357,9 +361,9 @@ public void testAutoUpdateCache() { fICG, gICG, hICG); /* each function must be calculated only once; otherwise, ICG is doing redundant computations */ - Assert.assertEquals(diffCounts.getCount("f"), 1); - Assert.assertEquals(diffCounts.getCount("g"), 1); - Assert.assertEquals(diffCounts.getCount("h"), 1); + Assert.assertEquals(diffCounts.getCount(F_KEY), 1); + Assert.assertEquals(diffCounts.getCount(G_KEY), 1); + Assert.assertEquals(diffCounts.getCount(H_KEY), 1); /* if we update all caches again, nothing should change */ startCounts = getCounterInstance(); @@ -393,42 +397,42 @@ public void testBookkeeping(final boolean f_caching, final boolean f_external, final ImmutableComputableGraph icg_0 = getTestICGBuilder(f_caching, f_external, g_caching, g_external, h_caching, h_external).build(); - Assert.assertTrue(!icg_0.isValueDirectlyAvailable("x")); - Assert.assertTrue(!icg_0.isValueDirectlyAvailable("y")); - Assert.assertTrue(!icg_0.isValueDirectlyAvailable("z")); - Assert.assertTrue(!icg_0.isValueDirectlyAvailable("f")); - Assert.assertTrue(!icg_0.isValueDirectlyAvailable("g")); - Assert.assertTrue(!icg_0.isValueDirectlyAvailable("h")); + Assert.assertTrue(!icg_0.isValueDirectlyAvailable(X_KEY)); + Assert.assertTrue(!icg_0.isValueDirectlyAvailable(Y_KEY)); + Assert.assertTrue(!icg_0.isValueDirectlyAvailable(Z_KEY)); + Assert.assertTrue(!icg_0.isValueDirectlyAvailable(F_KEY)); + Assert.assertTrue(!icg_0.isValueDirectlyAvailable(G_KEY)); + Assert.assertTrue(!icg_0.isValueDirectlyAvailable(H_KEY)); ImmutableComputableGraph icg_tmp = icg_0; - icg_tmp = icg_tmp.setValue("x", new DuplicableNDArray(getRandomINDArray())); - Assert.assertTrue(icg_tmp.isValueDirectlyAvailable("x")); - Assert.assertTrue(!icg_tmp.isValueDirectlyAvailable("y")); - Assert.assertTrue(!icg_tmp.isValueDirectlyAvailable("z")); - Assert.assertTrue(!icg_tmp.isValueDirectlyAvailable("f")); - Assert.assertTrue(!icg_tmp.isValueDirectlyAvailable("g")); - Assert.assertTrue(!icg_tmp.isValueDirectlyAvailable("h")); - assertIntactReferences(icg_0, icg_tmp, "y", "z", "g"); + icg_tmp = icg_tmp.setValue(X_KEY, new DuplicableNDArray(getRandomINDArray())); + Assert.assertTrue(icg_tmp.isValueDirectlyAvailable(X_KEY)); + Assert.assertTrue(!icg_tmp.isValueDirectlyAvailable(Y_KEY)); + Assert.assertTrue(!icg_tmp.isValueDirectlyAvailable(Z_KEY)); + Assert.assertTrue(!icg_tmp.isValueDirectlyAvailable(F_KEY)); + Assert.assertTrue(!icg_tmp.isValueDirectlyAvailable(G_KEY)); + Assert.assertTrue(!icg_tmp.isValueDirectlyAvailable(H_KEY)); + assertIntactReferences(icg_0, icg_tmp, Y_KEY, Z_KEY, G_KEY); ImmutableComputableGraph icg_tmp_old = icg_tmp; - icg_tmp = icg_tmp.setValue("y", new DuplicableNumber<>(getRandomDouble())); - Assert.assertTrue(icg_tmp.isValueDirectlyAvailable("x")); - Assert.assertTrue(icg_tmp.isValueDirectlyAvailable("y")); - Assert.assertTrue(!icg_tmp.isValueDirectlyAvailable("z")); - Assert.assertTrue(!icg_tmp.isValueDirectlyAvailable("f")); - Assert.assertTrue(!icg_tmp.isValueDirectlyAvailable("g")); - Assert.assertTrue(!icg_tmp.isValueDirectlyAvailable("h")); - assertIntactReferences(icg_tmp_old, icg_tmp, "x", "z"); + icg_tmp = icg_tmp.setValue(Y_KEY, new DuplicableNumber<>(getRandomDouble())); + Assert.assertTrue(icg_tmp.isValueDirectlyAvailable(X_KEY)); + Assert.assertTrue(icg_tmp.isValueDirectlyAvailable(Y_KEY)); + Assert.assertTrue(!icg_tmp.isValueDirectlyAvailable(Z_KEY)); + Assert.assertTrue(!icg_tmp.isValueDirectlyAvailable(F_KEY)); + Assert.assertTrue(!icg_tmp.isValueDirectlyAvailable(G_KEY)); + Assert.assertTrue(!icg_tmp.isValueDirectlyAvailable(H_KEY)); + assertIntactReferences(icg_tmp_old, icg_tmp, X_KEY, Z_KEY); icg_tmp_old = icg_tmp; - icg_tmp = icg_tmp.setValue("z", new DuplicableNDArray(getRandomINDArray())); - Assert.assertTrue(icg_tmp.isValueDirectlyAvailable("x")); - Assert.assertTrue(icg_tmp.isValueDirectlyAvailable("y")); - Assert.assertTrue(icg_tmp.isValueDirectlyAvailable("z")); - Assert.assertTrue(!icg_tmp.isValueDirectlyAvailable("f")); - Assert.assertTrue(!icg_tmp.isValueDirectlyAvailable("g")); - Assert.assertTrue(!icg_tmp.isValueDirectlyAvailable("h")); - assertIntactReferences(icg_tmp_old, icg_tmp, "x", "y", "f"); + icg_tmp = icg_tmp.setValue(Z_KEY, new DuplicableNDArray(getRandomINDArray())); + Assert.assertTrue(icg_tmp.isValueDirectlyAvailable(X_KEY)); + Assert.assertTrue(icg_tmp.isValueDirectlyAvailable(Y_KEY)); + Assert.assertTrue(icg_tmp.isValueDirectlyAvailable(Z_KEY)); + Assert.assertTrue(!icg_tmp.isValueDirectlyAvailable(F_KEY)); + Assert.assertTrue(!icg_tmp.isValueDirectlyAvailable(G_KEY)); + Assert.assertTrue(!icg_tmp.isValueDirectlyAvailable(H_KEY)); + assertIntactReferences(icg_tmp_old, icg_tmp, X_KEY, Y_KEY, F_KEY); icg_tmp_old = icg_tmp; try { @@ -440,130 +444,130 @@ public void testBookkeeping(final boolean f_caching, final boolean f_external, icg_tmp = icg_tmp.updateAllCachesIfPossible(); /* this will not throw exception by design */ } } - assertIntactReferences(icg_tmp_old, icg_tmp, "x", "y", "z"); + assertIntactReferences(icg_tmp_old, icg_tmp, X_KEY, Y_KEY, Z_KEY); - Assert.assertTrue((!f_caching && assertIntactReferences(icg_tmp_old, icg_tmp, "f")) || - (f_external && !icg_tmp.isValueDirectlyAvailable("f") && assertIntactReferences(icg_tmp_old, icg_tmp, "f")) || - (!f_external && icg_tmp.isValueDirectlyAvailable("f") && assertChangedReferences(icg_tmp_old, icg_tmp, "f"))); + Assert.assertTrue((!f_caching && assertIntactReferences(icg_tmp_old, icg_tmp, F_KEY)) || + (f_external && !icg_tmp.isValueDirectlyAvailable(F_KEY) && assertIntactReferences(icg_tmp_old, icg_tmp, F_KEY)) || + (!f_external && icg_tmp.isValueDirectlyAvailable(F_KEY) && assertChangedReferences(icg_tmp_old, icg_tmp, F_KEY))); - Assert.assertTrue((!g_caching && assertIntactReferences(icg_tmp_old, icg_tmp, "g")) || - (g_external && !icg_tmp.isValueDirectlyAvailable("g") && assertIntactReferences(icg_tmp_old, icg_tmp, "g")) || - (!g_external && icg_tmp.isValueDirectlyAvailable("g") && assertChangedReferences(icg_tmp_old, icg_tmp, "g"))); + Assert.assertTrue((!g_caching && assertIntactReferences(icg_tmp_old, icg_tmp, G_KEY)) || + (g_external && !icg_tmp.isValueDirectlyAvailable(G_KEY) && assertIntactReferences(icg_tmp_old, icg_tmp, G_KEY)) || + (!g_external && icg_tmp.isValueDirectlyAvailable(G_KEY) && assertChangedReferences(icg_tmp_old, icg_tmp, G_KEY))); if (!f_external && !g_external) { - Assert.assertTrue((!h_caching && assertIntactReferences(icg_tmp_old, icg_tmp, "h")) || - (h_external && !icg_tmp.isValueDirectlyAvailable("h") && assertIntactReferences(icg_tmp_old, icg_tmp, "h")) || - (!h_external && icg_tmp.isValueDirectlyAvailable("h") && assertChangedReferences(icg_tmp_old, icg_tmp, "h"))); + Assert.assertTrue((!h_caching && assertIntactReferences(icg_tmp_old, icg_tmp, H_KEY)) || + (h_external && !icg_tmp.isValueDirectlyAvailable(H_KEY) && assertIntactReferences(icg_tmp_old, icg_tmp, H_KEY)) || + (!h_external && icg_tmp.isValueDirectlyAvailable(H_KEY) && assertChangedReferences(icg_tmp_old, icg_tmp, H_KEY))); } else { - Assert.assertTrue(!icg_tmp.isValueDirectlyAvailable("h") && assertIntactReferences(icg_tmp_old, icg_tmp, "h")); + Assert.assertTrue(!icg_tmp.isValueDirectlyAvailable(H_KEY) && assertIntactReferences(icg_tmp_old, icg_tmp, H_KEY)); } /* fill in the external values */ if (f_external) { icg_tmp_old = icg_tmp; - icg_tmp = icg_tmp.setValue("f", f_computation_function.apply( - ImmutableMap.of("x", icg_tmp.fetchDirectly("x"), "y", icg_tmp.fetchDirectly("y")))); - Assert.assertTrue(icg_tmp.isValueDirectlyAvailable("f")); - assertIntactReferences(icg_tmp_old, icg_tmp, "x", "y", "z", "g"); - assertChangedReferences(icg_tmp_old, icg_tmp, "f", "h"); + icg_tmp = icg_tmp.setValue(F_KEY, f_computation_function.apply( + ImmutableMap.of(X_KEY, icg_tmp.fetchDirectly(X_KEY), Y_KEY, icg_tmp.fetchDirectly(Y_KEY)))); + Assert.assertTrue(icg_tmp.isValueDirectlyAvailable(F_KEY)); + assertIntactReferences(icg_tmp_old, icg_tmp, X_KEY, Y_KEY, Z_KEY, G_KEY); + assertChangedReferences(icg_tmp_old, icg_tmp, F_KEY, H_KEY); } if (g_external) { icg_tmp_old = icg_tmp; - icg_tmp = icg_tmp.setValue("g", g_computation_function.apply( - ImmutableMap.of("y", icg_tmp.fetchDirectly("y"), "z", icg_tmp.fetchDirectly("z")))); - Assert.assertTrue(icg_tmp.isValueDirectlyAvailable("g")); - assertIntactReferences(icg_tmp_old, icg_tmp, "x", "y", "z", "f"); - assertChangedReferences(icg_tmp_old, icg_tmp, "g", "h"); + icg_tmp = icg_tmp.setValue(G_KEY, g_computation_function.apply( + ImmutableMap.of(Y_KEY, icg_tmp.fetchDirectly(Y_KEY), Z_KEY, icg_tmp.fetchDirectly(Z_KEY)))); + Assert.assertTrue(icg_tmp.isValueDirectlyAvailable(G_KEY)); + assertIntactReferences(icg_tmp_old, icg_tmp, X_KEY, Y_KEY, Z_KEY, F_KEY); + assertChangedReferences(icg_tmp_old, icg_tmp, G_KEY, H_KEY); } if (h_external) { icg_tmp_old = icg_tmp; - icg_tmp = icg_tmp.setValue("h", h_computation_function.apply(ImmutableMap.of( - "f", icg_tmp.fetchWithRequiredEvaluations("f"), - "g", icg_tmp.fetchWithRequiredEvaluations("g"), - "x", icg_tmp.fetchDirectly("x")))); - Assert.assertTrue(icg_tmp.isValueDirectlyAvailable("h")); - assertIntactReferences(icg_tmp_old, icg_tmp, "x", "y", "z", "f", "g"); - assertChangedReferences(icg_tmp_old, icg_tmp, "h"); + icg_tmp = icg_tmp.setValue(H_KEY, h_computation_function.apply(ImmutableMap.of( + F_KEY, icg_tmp.fetchWithRequiredEvaluations(F_KEY), + G_KEY, icg_tmp.fetchWithRequiredEvaluations(G_KEY), + X_KEY, icg_tmp.fetchDirectly(X_KEY)))); + Assert.assertTrue(icg_tmp.isValueDirectlyAvailable(H_KEY)); + assertIntactReferences(icg_tmp_old, icg_tmp, X_KEY, Y_KEY, Z_KEY, F_KEY, G_KEY); + assertChangedReferences(icg_tmp_old, icg_tmp, H_KEY); } /* since all externally computed nodes are initialized, a call to updateAllCaches() must succeed */ icg_tmp = icg_tmp.updateAllCaches(); /* at this point, every caching node must be up-to-date */ - Assert.assertTrue(icg_tmp.isValueDirectlyAvailable("x")); - Assert.assertTrue(icg_tmp.isValueDirectlyAvailable("y")); - Assert.assertTrue(icg_tmp.isValueDirectlyAvailable("z")); - Assert.assertTrue(!f_caching || icg_tmp.isValueDirectlyAvailable("f")); - Assert.assertTrue(!g_caching || icg_tmp.isValueDirectlyAvailable("g")); - Assert.assertTrue(!h_caching || icg_tmp.isValueDirectlyAvailable("h")); + Assert.assertTrue(icg_tmp.isValueDirectlyAvailable(X_KEY)); + Assert.assertTrue(icg_tmp.isValueDirectlyAvailable(Y_KEY)); + Assert.assertTrue(icg_tmp.isValueDirectlyAvailable(Z_KEY)); + Assert.assertTrue(!f_caching || icg_tmp.isValueDirectlyAvailable(F_KEY)); + Assert.assertTrue(!g_caching || icg_tmp.isValueDirectlyAvailable(G_KEY)); + Assert.assertTrue(!h_caching || icg_tmp.isValueDirectlyAvailable(H_KEY)); /* update x -- f and h must go out of date */ - ImmutableComputableGraph icg_tmp_x = icg_tmp.setValue("x", new DuplicableNDArray(getRandomINDArray())); - Assert.assertTrue(!icg_tmp_x.isValueDirectlyAvailable("f")); - Assert.assertTrue(!g_caching || icg_tmp_x.isValueDirectlyAvailable("g")); - Assert.assertTrue(!icg_tmp_x.isValueDirectlyAvailable("h")); + ImmutableComputableGraph icg_tmp_x = icg_tmp.setValue(X_KEY, new DuplicableNDArray(getRandomINDArray())); + Assert.assertTrue(!icg_tmp_x.isValueDirectlyAvailable(F_KEY)); + Assert.assertTrue(!g_caching || icg_tmp_x.isValueDirectlyAvailable(G_KEY)); + Assert.assertTrue(!icg_tmp_x.isValueDirectlyAvailable(H_KEY)); /* update y -- f, g and h must go out of date */ - ImmutableComputableGraph icg_tmp_y = icg_tmp.setValue("y", new DuplicableNumber<>(getRandomDouble())); - Assert.assertTrue(!icg_tmp_y.isValueDirectlyAvailable("f")); - Assert.assertTrue(!icg_tmp_y.isValueDirectlyAvailable("g")); - Assert.assertTrue(!icg_tmp_y.isValueDirectlyAvailable("h")); + ImmutableComputableGraph icg_tmp_y = icg_tmp.setValue(Y_KEY, new DuplicableNumber<>(getRandomDouble())); + Assert.assertTrue(!icg_tmp_y.isValueDirectlyAvailable(F_KEY)); + Assert.assertTrue(!icg_tmp_y.isValueDirectlyAvailable(G_KEY)); + Assert.assertTrue(!icg_tmp_y.isValueDirectlyAvailable(H_KEY)); /* update z -- g and h must go out of date */ - ImmutableComputableGraph icg_tmp_z = icg_tmp.setValue("z", new DuplicableNDArray(getRandomINDArray())); - Assert.assertTrue(!f_caching || icg_tmp_z.isValueDirectlyAvailable("f")); - Assert.assertTrue(!icg_tmp_z.isValueDirectlyAvailable("g")); - Assert.assertTrue(!icg_tmp_z.isValueDirectlyAvailable("h")); + ImmutableComputableGraph icg_tmp_z = icg_tmp.setValue(Z_KEY, new DuplicableNDArray(getRandomINDArray())); + Assert.assertTrue(!f_caching || icg_tmp_z.isValueDirectlyAvailable(F_KEY)); + Assert.assertTrue(!icg_tmp_z.isValueDirectlyAvailable(G_KEY)); + Assert.assertTrue(!icg_tmp_z.isValueDirectlyAvailable(H_KEY)); /* update x and y -- f, g and h must go out of date */ ImmutableComputableGraph icg_tmp_xy = icg_tmp - .setValue("x", new DuplicableNDArray(getRandomINDArray())) - .setValue("y", new DuplicableNumber<>(getRandomDouble())); - Assert.assertTrue(!icg_tmp_xy.isValueDirectlyAvailable("f")); - Assert.assertTrue(!icg_tmp_xy.isValueDirectlyAvailable("g")); - Assert.assertTrue(!icg_tmp_xy.isValueDirectlyAvailable("h")); + .setValue(X_KEY, new DuplicableNDArray(getRandomINDArray())) + .setValue(Y_KEY, new DuplicableNumber<>(getRandomDouble())); + Assert.assertTrue(!icg_tmp_xy.isValueDirectlyAvailable(F_KEY)); + Assert.assertTrue(!icg_tmp_xy.isValueDirectlyAvailable(G_KEY)); + Assert.assertTrue(!icg_tmp_xy.isValueDirectlyAvailable(H_KEY)); /* update x and z -- f, g and h must go out of date */ ImmutableComputableGraph icg_tmp_xz = icg_tmp - .setValue("x", new DuplicableNDArray(getRandomINDArray())) - .setValue("z", new DuplicableNDArray(getRandomINDArray())); - Assert.assertTrue(!icg_tmp_xz.isValueDirectlyAvailable("f")); - Assert.assertTrue(!icg_tmp_xz.isValueDirectlyAvailable("g")); - Assert.assertTrue(!icg_tmp_xz.isValueDirectlyAvailable("h")); + .setValue(X_KEY, new DuplicableNDArray(getRandomINDArray())) + .setValue(Z_KEY, new DuplicableNDArray(getRandomINDArray())); + Assert.assertTrue(!icg_tmp_xz.isValueDirectlyAvailable(F_KEY)); + Assert.assertTrue(!icg_tmp_xz.isValueDirectlyAvailable(G_KEY)); + Assert.assertTrue(!icg_tmp_xz.isValueDirectlyAvailable(H_KEY)); /* update x and z -- f, g and h must go out of date */ ImmutableComputableGraph icg_tmp_xyz = icg_tmp - .setValue("x", new DuplicableNDArray(getRandomINDArray())) - .setValue("y", new DuplicableNumber<>(getRandomDouble())) - .setValue("z", new DuplicableNDArray(getRandomINDArray())); - Assert.assertTrue(!icg_tmp_xyz.isValueDirectlyAvailable("f")); - Assert.assertTrue(!icg_tmp_xyz.isValueDirectlyAvailable("g")); - Assert.assertTrue(!icg_tmp_xyz.isValueDirectlyAvailable("h")); + .setValue(X_KEY, new DuplicableNDArray(getRandomINDArray())) + .setValue(Y_KEY, new DuplicableNumber<>(getRandomDouble())) + .setValue(Z_KEY, new DuplicableNDArray(getRandomINDArray())); + Assert.assertTrue(!icg_tmp_xyz.isValueDirectlyAvailable(F_KEY)); + Assert.assertTrue(!icg_tmp_xyz.isValueDirectlyAvailable(G_KEY)); + Assert.assertTrue(!icg_tmp_xyz.isValueDirectlyAvailable(H_KEY)); if (f_external) { /* update f -- h must go out of date */ ImmutableComputableGraph icg_tmp_f = icg_tmp - .setValue("f", new DuplicableNDArray(getRandomINDArray())); - Assert.assertTrue(!g_caching || icg_tmp_f.isValueDirectlyAvailable("g")); - Assert.assertTrue(!icg_tmp_f.isValueDirectlyAvailable("h")); + .setValue(F_KEY, new DuplicableNDArray(getRandomINDArray())); + Assert.assertTrue(!g_caching || icg_tmp_f.isValueDirectlyAvailable(G_KEY)); + Assert.assertTrue(!icg_tmp_f.isValueDirectlyAvailable(H_KEY)); } if (g_external) { /* update g -- h must go out of date */ ImmutableComputableGraph icg_tmp_g = icg_tmp - .setValue("g", new DuplicableNDArray(getRandomINDArray())); - Assert.assertTrue(!f_caching || icg_tmp_g.isValueDirectlyAvailable("f")); - Assert.assertTrue(!icg_tmp_g.isValueDirectlyAvailable("h")); + .setValue(G_KEY, new DuplicableNDArray(getRandomINDArray())); + Assert.assertTrue(!f_caching || icg_tmp_g.isValueDirectlyAvailable(F_KEY)); + Assert.assertTrue(!icg_tmp_g.isValueDirectlyAvailable(H_KEY)); } if (f_external && g_external) { /* update f and g -- h must go out of date */ ImmutableComputableGraph icg_tmp_fg = icg_tmp - .setValue("f", new DuplicableNDArray(getRandomINDArray())) - .setValue("g", new DuplicableNDArray(getRandomINDArray())); - Assert.assertTrue(!icg_tmp_fg.isValueDirectlyAvailable("h")); + .setValue(F_KEY, new DuplicableNDArray(getRandomINDArray())) + .setValue(G_KEY, new DuplicableNDArray(getRandomINDArray())); + Assert.assertTrue(!icg_tmp_fg.isValueDirectlyAvailable(H_KEY)); } } @@ -574,31 +578,31 @@ public void testBookkeeping(final boolean f_caching, final boolean f_external, public void testTagPropagation(final boolean f_caching, final boolean f_external, final boolean g_caching, final boolean g_external, final boolean h_caching, final boolean h_external) { - final Set x_tags = getRandomSetOfTags(); - final Set y_tags = getRandomSetOfTags(); - final Set z_tags = getRandomSetOfTags(); - final Set f_tags = getRandomSetOfTags(); - final Set g_tags = getRandomSetOfTags(); - final Set h_tags = getRandomSetOfTags(); + final Set x_tags = getRandomSetOfTags(); + final Set y_tags = getRandomSetOfTags(); + final Set z_tags = getRandomSetOfTags(); + final Set f_tags = getRandomSetOfTags(); + final Set g_tags = getRandomSetOfTags(); + final Set h_tags = getRandomSetOfTags(); final ImmutableComputableGraph icg = getTestICGBuilder( f_caching, f_external, g_caching, g_external, h_caching, h_external, - x_tags.toArray(new String[0]), y_tags.toArray(new String[0]), - z_tags.toArray(new String[0]), f_tags.toArray(new String[0]), - g_tags.toArray(new String[0]), h_tags.toArray(new String[0])).build(); - - final Set all_x_tags = Sets.union(Sets.union(x_tags, f_tags), h_tags); - final Set all_y_tags = Sets.union(Sets.union(Sets.union(y_tags, f_tags), g_tags), h_tags); - final Set all_z_tags = Sets.union(Sets.union(z_tags, g_tags), h_tags); - final Set all_f_tags = Sets.union(f_tags, h_tags); - final Set all_g_tags = Sets.union(g_tags, h_tags); - final Set all_h_tags = h_tags; - - final Set all_x_tags_actual = icg.getComputableGraphStructure().getInducedTagsForNode("x"); - final Set all_y_tags_actual = icg.getComputableGraphStructure().getInducedTagsForNode("y"); - final Set all_z_tags_actual = icg.getComputableGraphStructure().getInducedTagsForNode("z"); - final Set all_f_tags_actual = icg.getComputableGraphStructure().getInducedTagsForNode("f"); - final Set all_g_tags_actual = icg.getComputableGraphStructure().getInducedTagsForNode("g"); - final Set all_h_tags_actual = icg.getComputableGraphStructure().getInducedTagsForNode("h"); + x_tags.toArray(new CacheNode.NodeTag[0]), y_tags.toArray(new CacheNode.NodeTag[0]), + z_tags.toArray(new CacheNode.NodeTag[0]), f_tags.toArray(new CacheNode.NodeTag[0]), + g_tags.toArray(new CacheNode.NodeTag[0]), h_tags.toArray(new CacheNode.NodeTag[0])).build(); + + final Set all_x_tags = Sets.union(Sets.union(x_tags, f_tags), h_tags); + final Set all_y_tags = Sets.union(Sets.union(Sets.union(y_tags, f_tags), g_tags), h_tags); + final Set all_z_tags = Sets.union(Sets.union(z_tags, g_tags), h_tags); + final Set all_f_tags = Sets.union(f_tags, h_tags); + final Set all_g_tags = Sets.union(g_tags, h_tags); + final Set all_h_tags = h_tags; + + final Set all_x_tags_actual = icg.getComputableGraphStructure().getInducedTagsForNode(X_KEY); + final Set all_y_tags_actual = icg.getComputableGraphStructure().getInducedTagsForNode(Y_KEY); + final Set all_z_tags_actual = icg.getComputableGraphStructure().getInducedTagsForNode(Z_KEY); + final Set all_f_tags_actual = icg.getComputableGraphStructure().getInducedTagsForNode(F_KEY); + final Set all_g_tags_actual = icg.getComputableGraphStructure().getInducedTagsForNode(G_KEY); + final Set all_h_tags_actual = icg.getComputableGraphStructure().getInducedTagsForNode(H_KEY); Assert.assertTrue(all_x_tags.equals(all_x_tags_actual)); Assert.assertTrue(all_y_tags.equals(all_y_tags_actual)); @@ -608,18 +612,18 @@ public void testTagPropagation(final boolean f_caching, final boolean f_external Assert.assertTrue(all_h_tags.equals(all_h_tags_actual)); } - private Map getExpectedComputableNodeValues(final Duplicable x, final Duplicable y, final Duplicable z) { + private Map getExpectedComputableNodeValues(final Duplicable x, final Duplicable y, final Duplicable z) { final INDArray xVal = (INDArray)x.value(); final INDArray yVal = Nd4j.zeros(TEST_NDARRAY_SHAPE).add((Double)y.value()); final INDArray zVal = (INDArray)z.value(); final INDArray fExpected = f_computer(xVal, yVal); final INDArray gExpected = g_computer(yVal, zVal); final INDArray hExpected = h_computer(fExpected, gExpected, xVal); - return ImmutableMap.of("f", fExpected, "g", gExpected, "h", hExpected); + return ImmutableMap.of(F_KEY, fExpected, G_KEY, gExpected, H_KEY, hExpected); } /** - * Tests {@link ImmutableComputableGraph#updateCachesForTag(String)}} + * Tests {@link ImmutableComputableGraph#updateCachesForTag(CacheNode.NodeTag)}} */ @Test(dataProvider = "allPossibleNodeFlags", invocationCount = NUM_TRIALS) public void testUpdateCachesByTag(final boolean f_caching, final boolean f_external, @@ -628,17 +632,17 @@ public void testUpdateCachesByTag(final boolean f_caching, final boolean f_exter generateNewRandomFunctionalComposition(); final ImmutableComputableGraph icg_empty = getTestICGBuilder( f_caching, f_external, g_caching, g_external, h_caching, h_external, - getRandomSetOfTags().toArray(new String[0]), getRandomSetOfTags().toArray(new String[0]), - getRandomSetOfTags().toArray(new String[0]), getRandomSetOfTags().toArray(new String[0]), - getRandomSetOfTags().toArray(new String[0]), getRandomSetOfTags().toArray(new String[0])).build(); - - final Set all_x_tags = icg_empty.getComputableGraphStructure().getInducedTagsForNode("x"); - final Set all_y_tags = icg_empty.getComputableGraphStructure().getInducedTagsForNode("y"); - final Set all_z_tags = icg_empty.getComputableGraphStructure().getInducedTagsForNode("z"); - final Set all_f_tags = icg_empty.getComputableGraphStructure().getInducedTagsForNode("f"); - final Set all_g_tags = icg_empty.getComputableGraphStructure().getInducedTagsForNode("g"); - final Set all_h_tags = icg_empty.getComputableGraphStructure().getInducedTagsForNode("h"); - final Set all_tags = new HashSet<>(); + getRandomSetOfTags().toArray(new CacheNode.NodeTag[0]), getRandomSetOfTags().toArray(new CacheNode.NodeTag[0]), + getRandomSetOfTags().toArray(new CacheNode.NodeTag[0]), getRandomSetOfTags().toArray(new CacheNode.NodeTag[0]), + getRandomSetOfTags().toArray(new CacheNode.NodeTag[0]), getRandomSetOfTags().toArray(new CacheNode.NodeTag[0])).build(); + + final Set all_x_tags = icg_empty.getComputableGraphStructure().getInducedTagsForNode(X_KEY); + final Set all_y_tags = icg_empty.getComputableGraphStructure().getInducedTagsForNode(Y_KEY); + final Set all_z_tags = icg_empty.getComputableGraphStructure().getInducedTagsForNode(Z_KEY); + final Set all_f_tags = icg_empty.getComputableGraphStructure().getInducedTagsForNode(F_KEY); + final Set all_g_tags = icg_empty.getComputableGraphStructure().getInducedTagsForNode(G_KEY); + final Set all_h_tags = icg_empty.getComputableGraphStructure().getInducedTagsForNode(H_KEY); + final Set all_tags = new HashSet<>(); all_tags.addAll(all_x_tags); all_tags.addAll(all_y_tags); all_tags.addAll(all_z_tags); all_tags.addAll(all_f_tags); all_tags.addAll(all_g_tags); all_tags.addAll(all_h_tags); @@ -646,13 +650,13 @@ public void testUpdateCachesByTag(final boolean f_caching, final boolean f_exter final double y = getRandomDouble(); final INDArray z = getRandomINDArray(); final ImmutableComputableGraph icg_0 = icg_empty - .setValue("x", new DuplicableNDArray(x)) - .setValue("y", new DuplicableNumber<>(y)) - .setValue("z", new DuplicableNDArray(z)); - final Map expectedComputableNodeValues = getExpectedComputableNodeValues( - icg_0.fetchDirectly("x"), icg_0.fetchDirectly("y"), icg_0.fetchDirectly("z")); + .setValue(X_KEY, new DuplicableNDArray(x)) + .setValue(Y_KEY, new DuplicableNumber<>(y)) + .setValue(Z_KEY, new DuplicableNDArray(z)); + final Map expectedComputableNodeValues = getExpectedComputableNodeValues( + icg_0.fetchDirectly(X_KEY), icg_0.fetchDirectly(Y_KEY), icg_0.fetchDirectly(Z_KEY)); - for (final String tag : all_tags) { + for (final CacheNode.NodeTag tag : all_tags) { ImmutableComputableGraph icg_1; Counter startCounter; try { @@ -669,51 +673,51 @@ public void testUpdateCachesByTag(final boolean f_caching, final boolean f_exter final Counter evalCounts = getCounterInstance().diff(startCounter); /* check updated caches */ - final Set updatedNodesExpected = new HashSet<>(); + final Set updatedNodesExpected = new HashSet<>(); if (!f_external && f_caching && all_f_tags.contains(tag)) { - updatedNodesExpected.add("f"); + updatedNodesExpected.add(F_KEY); } if (!g_external && g_caching && all_g_tags.contains(tag)) { - updatedNodesExpected.add("g"); + updatedNodesExpected.add(G_KEY); } if (!h_external && !f_external && !g_external && h_caching && all_h_tags.contains(tag)) { - updatedNodesExpected.add("h"); + updatedNodesExpected.add(H_KEY); } assertChangedReferences(icg_0, icg_1, updatedNodesExpected); assertIntactReferences(icg_0, icg_1, Sets.difference(ALL_NODES, updatedNodesExpected)); - for (final String nodeKey : updatedNodesExpected) { + for (final CacheNode.NodeKey nodeKey : updatedNodesExpected) { Assert.assertTrue(icg_1.isValueDirectlyAvailable(nodeKey)); MathObjectAsserts.assertNDArrayEquals((INDArray)icg_1.fetchDirectly(nodeKey).value(), expectedComputableNodeValues.get(nodeKey)); } - for (final String nodeKey : Sets.difference(ALL_COMPUTABLE_NODES, updatedNodesExpected)) { + for (final CacheNode.NodeKey nodeKey : Sets.difference(ALL_COMPUTABLE_NODES, updatedNodesExpected)) { Assert.assertTrue(!icg_1.isValueDirectlyAvailable(nodeKey)); } /* check function evaluation counts */ if ((!f_external && all_f_tags.contains(tag)) /* f is computable and caching */ || (all_h_tags.contains(tag) && !f_external && !g_external && !h_external) /* h, as a descendant, is computable */) { - Assert.assertEquals(evalCounts.getCount("f"), 1); + Assert.assertEquals(evalCounts.getCount(F_KEY), 1); } else { - Assert.assertEquals(evalCounts.getCount("f"), 0); + Assert.assertEquals(evalCounts.getCount(F_KEY), 0); } if ((!g_external && all_g_tags.contains(tag)) /* g is computable and caching */ || (all_h_tags.contains(tag) && !g_external && !f_external && !h_external) /* h, as a descendant, is computable */) { - Assert.assertEquals(evalCounts.getCount("g"), 1); + Assert.assertEquals(evalCounts.getCount(G_KEY), 1); } else { - Assert.assertEquals(evalCounts.getCount("g"), 0); + Assert.assertEquals(evalCounts.getCount(G_KEY), 0); } if (all_h_tags.contains(tag) && !f_external && !g_external && !h_external) { - Assert.assertEquals(evalCounts.getCount("h"), 1); + Assert.assertEquals(evalCounts.getCount(H_KEY), 1); } else { - Assert.assertEquals(evalCounts.getCount("h"), 0); + Assert.assertEquals(evalCounts.getCount(H_KEY), 0); } } } /** - * Tests {@link ImmutableComputableGraph#updateCachesForNode(String)}} + * Tests {@link ImmutableComputableGraph#updateCachesForNode(CacheNode.NodeKey)}} */ @Test(dataProvider = "allPossibleNodeFlags", invocationCount = NUM_TRIALS) public void testUpdateCacheByNode(final boolean f_caching, final boolean f_external, @@ -727,13 +731,13 @@ public void testUpdateCacheByNode(final boolean f_caching, final boolean f_exter final double y = getRandomDouble(); final INDArray z = getRandomINDArray(); final ImmutableComputableGraph icg_0 = icg_empty - .setValue("x", new DuplicableNDArray(x)) - .setValue("y", new DuplicableNumber<>(y)) - .setValue("z", new DuplicableNDArray(z)); - final Map expectedComputableNodeValues = getExpectedComputableNodeValues( - icg_0.fetchDirectly("x"), icg_0.fetchDirectly("y"), icg_0.fetchDirectly("z")); + .setValue(X_KEY, new DuplicableNDArray(x)) + .setValue(Y_KEY, new DuplicableNumber<>(y)) + .setValue(Z_KEY, new DuplicableNDArray(z)); + final Map expectedComputableNodeValues = getExpectedComputableNodeValues( + icg_0.fetchDirectly(X_KEY), icg_0.fetchDirectly(Y_KEY), icg_0.fetchDirectly(Z_KEY)); - for (final String nodeKey : ALL_PRIMITIVE_NODES) { + for (final CacheNode.NodeKey nodeKey : ALL_PRIMITIVE_NODES) { Counter startCounter = getCounterInstance(); ImmutableComputableGraph icg_1 = icg_0.updateCachesForNode(nodeKey); final Counter evalCounts = getCounterInstance().diff(startCounter); @@ -745,115 +749,115 @@ public void testUpdateCacheByNode(final boolean f_caching, final boolean f_exter Counter startCounter; Counter diff; - /* tests for "f" */ + /* tests for F_KEY */ try { startCounter = getCounterInstance(); - icg_1 = icg_0.updateCachesForNode("f"); + icg_1 = icg_0.updateCachesForNode(F_KEY); } catch (final Exception ex) { /* should fail only if some of the tagged nodes are external */ if (!f_external && !g_external && !h_external) { throw new AssertionError("Could not update tagged nodes but it should have been possible"); } startCounter = getCounterInstance(); - icg_1 = icg_0.updateCachesForNodeIfPossible("f"); + icg_1 = icg_0.updateCachesForNodeIfPossible(F_KEY); } diff = getCounterInstance().diff(startCounter); - assertIntactReferences(icg_0, icg_1, "x", "y", "z", "g"); + assertIntactReferences(icg_0, icg_1, X_KEY, Y_KEY, Z_KEY, G_KEY); if (f_external) { assertIntactReferences(icg_0, icg_1, ALL_NODES); diff.assertZero(); } else { - Assert.assertEquals(diff.getCount("f"), 1); - Assert.assertEquals(diff.getCount("g"), 0); - Assert.assertEquals(diff.getCount("h"), 0); + Assert.assertEquals(diff.getCount(F_KEY), 1); + Assert.assertEquals(diff.getCount(G_KEY), 0); + Assert.assertEquals(diff.getCount(H_KEY), 0); if (f_caching) { - Assert.assertTrue(icg_1.isValueDirectlyAvailable("f")); - MathObjectAsserts.assertNDArrayEquals((INDArray)icg_1.fetchDirectly("f").value(), - expectedComputableNodeValues.get("f")); + Assert.assertTrue(icg_1.isValueDirectlyAvailable(F_KEY)); + MathObjectAsserts.assertNDArrayEquals((INDArray)icg_1.fetchDirectly(F_KEY).value(), + expectedComputableNodeValues.get(F_KEY)); } else { final Counter before = getCounterInstance(); - Assert.assertTrue(!icg_1.isValueDirectlyAvailable("f")); - MathObjectAsserts.assertNDArrayEquals((INDArray)icg_1.fetchWithRequiredEvaluations("f").value(), - expectedComputableNodeValues.get("f")); + Assert.assertTrue(!icg_1.isValueDirectlyAvailable(F_KEY)); + MathObjectAsserts.assertNDArrayEquals((INDArray)icg_1.fetchWithRequiredEvaluations(F_KEY).value(), + expectedComputableNodeValues.get(F_KEY)); final Counter diff2 = getCounterInstance().diff(before); - Assert.assertEquals(diff2.getCount("f"), 1); - Assert.assertEquals(diff2.getCount("g"), 0); - Assert.assertEquals(diff2.getCount("h"), 0); + Assert.assertEquals(diff2.getCount(F_KEY), 1); + Assert.assertEquals(diff2.getCount(G_KEY), 0); + Assert.assertEquals(diff2.getCount(H_KEY), 0); } } - /* tests for "g" */ + /* tests for G_KEY */ try { startCounter = getCounterInstance(); - icg_1 = icg_0.updateCachesForNode("g"); + icg_1 = icg_0.updateCachesForNode(G_KEY); } catch (final Exception ex) { /* should fail only if some of the tagged nodes are external */ if (!f_external && !g_external && !h_external) { throw new AssertionError("Could not update tagged nodes but it should have been possible"); } startCounter = getCounterInstance(); - icg_1 = icg_0.updateCachesForNodeIfPossible("g"); + icg_1 = icg_0.updateCachesForNodeIfPossible(G_KEY); } diff = getCounterInstance().diff(startCounter); - assertIntactReferences(icg_0, icg_1, "x", "y", "z", "f"); + assertIntactReferences(icg_0, icg_1, X_KEY, Y_KEY, Z_KEY, F_KEY); if (g_external) { assertIntactReferences(icg_0, icg_1, ALL_NODES); diff.assertZero(); } else { - Assert.assertEquals(diff.getCount("f"), 0); - Assert.assertEquals(diff.getCount("g"), 1); - Assert.assertEquals(diff.getCount("h"), 0); + Assert.assertEquals(diff.getCount(F_KEY), 0); + Assert.assertEquals(diff.getCount(G_KEY), 1); + Assert.assertEquals(diff.getCount(H_KEY), 0); if (g_caching) { - Assert.assertTrue(icg_1.isValueDirectlyAvailable("g")); - MathObjectAsserts.assertNDArrayEquals((INDArray)icg_1.fetchDirectly("g").value(), - expectedComputableNodeValues.get("g")); + Assert.assertTrue(icg_1.isValueDirectlyAvailable(G_KEY)); + MathObjectAsserts.assertNDArrayEquals((INDArray)icg_1.fetchDirectly(G_KEY).value(), + expectedComputableNodeValues.get(G_KEY)); } else { - Assert.assertTrue(!icg_1.isValueDirectlyAvailable("g")); + Assert.assertTrue(!icg_1.isValueDirectlyAvailable(G_KEY)); final Counter before = getCounterInstance(); - MathObjectAsserts.assertNDArrayEquals((INDArray)icg_1.fetchWithRequiredEvaluations("g").value(), - expectedComputableNodeValues.get("g")); + MathObjectAsserts.assertNDArrayEquals((INDArray)icg_1.fetchWithRequiredEvaluations(G_KEY).value(), + expectedComputableNodeValues.get(G_KEY)); final Counter diff2 = getCounterInstance().diff(before); - Assert.assertEquals(diff2.getCount("f"), 0); - Assert.assertEquals(diff2.getCount("g"), 1); - Assert.assertEquals(diff2.getCount("h"), 0); + Assert.assertEquals(diff2.getCount(F_KEY), 0); + Assert.assertEquals(diff2.getCount(G_KEY), 1); + Assert.assertEquals(diff2.getCount(H_KEY), 0); } } - /* tests for "h" */ + /* tests for H_KEY */ try { startCounter = getCounterInstance(); - icg_1 = icg_0.updateCachesForNode("h"); + icg_1 = icg_0.updateCachesForNode(H_KEY); } catch (final Exception ex) { /* should fail only if some of the tagged nodes are external */ if (!f_external && !g_external && !h_external) { throw new AssertionError("Could not update tagged nodes but it should have been possible"); } startCounter = getCounterInstance(); - icg_1 = icg_0.updateCachesForNodeIfPossible("h"); + icg_1 = icg_0.updateCachesForNodeIfPossible(H_KEY); } diff = getCounterInstance().diff(startCounter); - assertIntactReferences(icg_0, icg_1, "x", "y", "z"); + assertIntactReferences(icg_0, icg_1, X_KEY, Y_KEY, Z_KEY); if (h_external && f_external && g_external) { assertIntactReferences(icg_0, icg_1, ALL_NODES); diff.assertZero(); } else if (!h_external && !f_external && !g_external) { - Assert.assertEquals(diff.getCount("f"), 1); - Assert.assertEquals(diff.getCount("g"), 1); - Assert.assertEquals(diff.getCount("h"), 1); + Assert.assertEquals(diff.getCount(F_KEY), 1); + Assert.assertEquals(diff.getCount(G_KEY), 1); + Assert.assertEquals(diff.getCount(H_KEY), 1); if (h_caching) { - Assert.assertTrue(icg_1.isValueDirectlyAvailable("h")); + Assert.assertTrue(icg_1.isValueDirectlyAvailable(H_KEY)); MathObjectAsserts.assertNDArrayEquals( - (INDArray)icg_1.fetchDirectly("h").value(), - expectedComputableNodeValues.get("h")); + (INDArray)icg_1.fetchDirectly(H_KEY).value(), + expectedComputableNodeValues.get(H_KEY)); } else { - Assert.assertTrue(!icg_1.isValueDirectlyAvailable("h")); + Assert.assertTrue(!icg_1.isValueDirectlyAvailable(H_KEY)); final Counter before = getCounterInstance(); MathObjectAsserts.assertNDArrayEquals( - (INDArray)icg_1.fetchWithRequiredEvaluations("h").value(), - expectedComputableNodeValues.get("h")); + (INDArray)icg_1.fetchWithRequiredEvaluations(H_KEY).value(), + expectedComputableNodeValues.get(H_KEY)); final Counter diff2 = getCounterInstance().diff(before); - Assert.assertEquals(diff2.getCount("f"), f_caching ? 0 : 1); - Assert.assertEquals(diff2.getCount("g"), g_caching ? 0 : 1); - Assert.assertEquals(diff2.getCount("h"), 1); + Assert.assertEquals(diff2.getCount(F_KEY), f_caching ? 0 : 1); + Assert.assertEquals(diff2.getCount(G_KEY), g_caching ? 0 : 1); + Assert.assertEquals(diff2.getCount(H_KEY), 1); } } } @@ -861,8 +865,8 @@ public void testUpdateCacheByNode(final boolean f_caching, final boolean f_exter @Test public void testUninitializedPrimitiveNode() { final ImmutableComputableGraph icg = getTestICGBuilder(true, false, true, false, true, false).build() - .setValue("x", new DuplicableNDArray(getRandomINDArray())) - .setValue("y", new DuplicableNumber<>(getRandomDouble())); + .setValue(X_KEY, new DuplicableNDArray(getRandomINDArray())) + .setValue(Y_KEY, new DuplicableNumber<>(getRandomDouble())); boolean failed = false; try { icg.updateAllCaches(); @@ -873,11 +877,11 @@ public void testUninitializedPrimitiveNode() { throw new AssertionError("Expected PrimitiveValueNotInitializedException but it was not thrown"); } - icg.updateCachesForNode("f"); /* should not fail */ + icg.updateCachesForNode(F_KEY); /* should not fail */ failed = false; try { - icg.updateCachesForNode("g"); + icg.updateCachesForNode(G_KEY); } catch (final PrimitiveCacheNode.PrimitiveValueNotInitializedException ex) { failed = true; } @@ -887,7 +891,7 @@ public void testUninitializedPrimitiveNode() { failed = false; try { - icg.updateCachesForNode("h"); + icg.updateCachesForNode(H_KEY); } catch (final PrimitiveCacheNode.PrimitiveValueNotInitializedException ex) { failed = true; } @@ -899,9 +903,9 @@ public void testUninitializedPrimitiveNode() { @Test public void testExternallyComputedNode() { final ImmutableComputableGraph icg = getTestICGBuilder(true, true, true, false, true, false).build() - .setValue("x", new DuplicableNDArray(getRandomINDArray())) - .setValue("y", new DuplicableNumber<>(getRandomDouble())) - .setValue("z", new DuplicableNDArray(getRandomINDArray())); + .setValue(X_KEY, new DuplicableNDArray(getRandomINDArray())) + .setValue(Y_KEY, new DuplicableNumber<>(getRandomDouble())) + .setValue(Z_KEY, new DuplicableNDArray(getRandomINDArray())); boolean failed = false; try { icg.updateAllCaches(); @@ -912,11 +916,11 @@ public void testExternallyComputedNode() { throw new AssertionError("Expected ExternallyComputableNodeValueUnavailableException but it was not thrown"); } - icg.updateCachesForNode("g"); /* should not fail */ + icg.updateCachesForNode(G_KEY); /* should not fail */ failed = false; try { - icg.updateCachesForNode("f"); + icg.updateCachesForNode(F_KEY); } catch (final ComputableCacheNode.ExternallyComputableNodeValueUnavailableException ex) { failed = true; } @@ -926,7 +930,7 @@ public void testExternallyComputedNode() { failed = false; try { - icg.updateCachesForNode("h"); + icg.updateCachesForNode(H_KEY); } catch (final ComputableCacheNode.ExternallyComputableNodeValueUnavailableException ex) { failed = true; } @@ -935,187 +939,67 @@ public void testExternallyComputedNode() { } /* supply f */ - ImmutableComputableGraph icg_1 = icg.setValue("f", f_computation_function.apply( - ImmutableMap.of("x", icg.fetchDirectly("x"), "y", icg.fetchDirectly("y")))); - Assert.assertTrue(icg_1.isValueDirectlyAvailable("f")); + ImmutableComputableGraph icg_1 = icg.setValue(F_KEY, f_computation_function.apply( + ImmutableMap.of(X_KEY, icg.fetchDirectly(X_KEY), Y_KEY, icg.fetchDirectly(Y_KEY)))); + Assert.assertTrue(icg_1.isValueDirectlyAvailable(F_KEY)); /* cache g */ - Assert.assertTrue(!icg_1.isValueDirectlyAvailable("g")); + Assert.assertTrue(!icg_1.isValueDirectlyAvailable(G_KEY)); Counter before = getCounterInstance(); - icg_1 = icg_1.updateCachesForNode("g"); - Assert.assertTrue(icg_1.isValueDirectlyAvailable("g")); + icg_1 = icg_1.updateCachesForNode(G_KEY); + Assert.assertTrue(icg_1.isValueDirectlyAvailable(G_KEY)); Counter diff = getCounterInstance().diff(before); - Assert.assertEquals(diff.getCount("f"), 0); - Assert.assertEquals(diff.getCount("g"), 1); - Assert.assertEquals(diff.getCount("h"), 0); + Assert.assertEquals(diff.getCount(F_KEY), 0); + Assert.assertEquals(diff.getCount(G_KEY), 1); + Assert.assertEquals(diff.getCount(H_KEY), 0); /* cache h -- now, it is computable */ - Assert.assertTrue(!icg_1.isValueDirectlyAvailable("h")); + Assert.assertTrue(!icg_1.isValueDirectlyAvailable(H_KEY)); before = getCounterInstance(); - icg_1 = icg_1.updateCachesForNode("h"); - Assert.assertTrue(icg_1.isValueDirectlyAvailable("h")); + icg_1 = icg_1.updateCachesForNode(H_KEY); + Assert.assertTrue(icg_1.isValueDirectlyAvailable(H_KEY)); diff = getCounterInstance().diff(before); - Assert.assertEquals(diff.getCount("f"), 0); - Assert.assertEquals(diff.getCount("g"), 0); - Assert.assertEquals(diff.getCount("h"), 1); + Assert.assertEquals(diff.getCount(F_KEY), 0); + Assert.assertEquals(diff.getCount(G_KEY), 0); + Assert.assertEquals(diff.getCount(H_KEY), 1); /* updating all caches must have no effect */ before = getCounterInstance(); ImmutableComputableGraph icg_2 = icg_1.updateAllCaches(); getCounterInstance().diff(before).assertZero(); - Assert.assertTrue(icg_2.isValueDirectlyAvailable("f")); - Assert.assertTrue(icg_2.isValueDirectlyAvailable("g")); - Assert.assertTrue(icg_2.isValueDirectlyAvailable("h")); + Assert.assertTrue(icg_2.isValueDirectlyAvailable(F_KEY)); + Assert.assertTrue(icg_2.isValueDirectlyAvailable(G_KEY)); + Assert.assertTrue(icg_2.isValueDirectlyAvailable(H_KEY)); assertIntactReferences(icg_1, icg_2, ALL_NODES); } - /** - * Asserts that the equality comparison of two {@link CacheNode}s is done just based on their key - */ - @Test - public void testEquality() { - final List nodesWithOneKey = getRandomCollectionOfNodesWithTheSameKey("ONE_KEY"); - final List nodesWithAnotherKey = getRandomCollectionOfNodesWithTheSameKey("ANOTHER_KEY"); - - for (final CacheNode node_0 : nodesWithOneKey) { - for (final CacheNode node_1 : nodesWithOneKey) { - Assert.assertTrue(node_0.equals(node_1) || (node_0.getClass() != node_1.getClass())); - } - } - - for (final CacheNode node_0 : nodesWithAnotherKey) { - for (final CacheNode node_1 : nodesWithAnotherKey) { - Assert.assertTrue(node_0.equals(node_1) || (node_0.getClass() != node_1.getClass())); - } - } - - for (final CacheNode node_0 : nodesWithOneKey) { - for (final CacheNode node_1 : nodesWithAnotherKey) { - Assert.assertTrue(!node_0.equals(node_1)); - } - } - } - - @Test - public void testToString() { - final List nodesWithOneKey = getRandomCollectionOfNodesWithTheSameKey("ONE_KEY"); - for (final CacheNode node : nodesWithOneKey) { - Assert.assertTrue(node.toString().equals("ONE_KEY")); - } - } - - private List getRandomCollectionOfNodesWithTheSameKey(final String key) { - final List collection = new ArrayList<>(); - collection.add(new PrimitiveCacheNode(key, EMPTY_STRING_LIST, null)); - collection.add(new PrimitiveCacheNode(key, Arrays.asList("a", "b", "c"), new DuplicableNumber<>(1.0))); - collection.add(new ComputableCacheNode(key, EMPTY_STRING_LIST, Arrays.asList("d", "e"), - f_computation_function, false)); - collection.add(new ComputableCacheNode(key, EMPTY_STRING_LIST, EMPTY_STRING_LIST, - f_computation_function, false)); - collection.add(new ComputableCacheNode(key, Arrays.asList("f"), Arrays.asList("g"), null, true)); - return collection; - } - - @Test(expectedExceptions = UnsupportedOperationException.class) - public void testSetValueOfAutomaticallyComputableNode() { - new ComputableCacheNode("TEST", EMPTY_STRING_LIST, EMPTY_STRING_LIST, h_computation_function, false) - .set(new DuplicableNDArray(getRandomINDArray())); - } - - @Test - public void testSetValueOfExternallyComputableNode() { - final ComputableCacheNode node = new ComputableCacheNode("TEST", EMPTY_STRING_LIST, EMPTY_STRING_LIST, null, true); - final INDArray arr = getRandomINDArray(); - node.set(new DuplicableNDArray(arr)); - MathObjectAsserts.assertNDArrayEquals((INDArray)node.get(EMPTY_PARENTS).value(), arr); - } - - @Test - public void testSetValueOfPrimitiveNode() { - final PrimitiveCacheNode node = new PrimitiveCacheNode("TEST", EMPTY_STRING_LIST, null); - final INDArray arr = getRandomINDArray(); - node.set(new DuplicableNDArray(arr)); - MathObjectAsserts.assertNDArrayEquals((INDArray)node.get(EMPTY_PARENTS).value(), arr); - } - - @Test - public void testPrimitiveNodeDuplication() { - final PrimitiveCacheNode node = new PrimitiveCacheNode("TEST", EMPTY_STRING_LIST, - new DuplicableNDArray(getRandomINDArray())); - final PrimitiveCacheNode dupNode = node.duplicate(); - MathObjectAsserts.assertNDArrayEquals((INDArray)node.get(EMPTY_PARENTS).value(), - (INDArray)dupNode.get(EMPTY_PARENTS).value()); - Assert.assertTrue(dupNode.hasValue()); - Assert.assertTrue(dupNode.getKey().equals("TEST")); - } - - @Test - public void testCachingComputableNodeDuplication() { - final INDArray testArray = getRandomINDArray(); - final Duplicable testDuplicable = new DuplicableNDArray(testArray); - final ComputableNodeFunction trivialFunction = parents -> testDuplicable; - - final ComputableCacheNode cachingAutoNodeUncached = new ComputableCacheNode("TEST", EMPTY_STRING_LIST, - EMPTY_STRING_LIST, trivialFunction, true); - final ComputableCacheNode cachingAutoNodeUncachedDup = cachingAutoNodeUncached.duplicate(); - - final ComputableCacheNode cachingAutoNodeCached = cachingAutoNodeUncached.duplicateWithUpdatedValue(testDuplicable); - final ComputableCacheNode cachingAutoNodeCachedDup = cachingAutoNodeCached.duplicate(); - - final ComputableCacheNode cachingAutoNodeCachedOutdated = cachingAutoNodeCached.duplicateWithOutdatedCacheStatus(); - final ComputableCacheNode cachingAutoNodeCachedOutdatedDup = cachingAutoNodeCachedOutdated.duplicate(); - - Assert.assertTrue(cachingAutoNodeUncached.isCaching()); - Assert.assertTrue(cachingAutoNodeUncachedDup.isCaching()); - Assert.assertTrue(!cachingAutoNodeUncached.isExternallyComputed()); - Assert.assertTrue(!cachingAutoNodeUncachedDup.isExternallyComputed()); - Assert.assertTrue(!cachingAutoNodeUncached.hasValue()); - Assert.assertTrue(!cachingAutoNodeUncachedDup.hasValue()); - - Assert.assertTrue(cachingAutoNodeCached.isCaching()); - Assert.assertTrue(cachingAutoNodeCachedDup.isCaching()); - Assert.assertTrue(!cachingAutoNodeCached.isExternallyComputed()); - Assert.assertTrue(!cachingAutoNodeCachedDup.isExternallyComputed()); - Assert.assertTrue(cachingAutoNodeCached.hasValue()); - Assert.assertTrue(cachingAutoNodeCachedDup.hasValue()); - MathObjectAsserts.assertNDArrayEquals((INDArray)cachingAutoNodeCached.get(EMPTY_PARENTS).value(), - (INDArray)cachingAutoNodeCachedDup.get(EMPTY_PARENTS).value()); - - Assert.assertTrue(cachingAutoNodeCachedOutdated.isCaching()); - Assert.assertTrue(cachingAutoNodeCachedOutdatedDup.isCaching()); - Assert.assertTrue(!cachingAutoNodeCachedOutdated.isExternallyComputed()); - Assert.assertTrue(!cachingAutoNodeCachedOutdatedDup.isExternallyComputed()); - Assert.assertTrue(!cachingAutoNodeCachedOutdated.hasValue()); /* outdated caches must drop out */ - Assert.assertTrue(!cachingAutoNodeCachedOutdatedDup.hasValue()); /* outdated caches must drop out */ - } - /** * A simple helper class for keeping track of ICG function evaluations */ private static class Counter { - final Map counts; + final Map counts; - Counter(String ... keys) { + Counter(CacheNode.NodeKey ... keys) { counts = new HashMap<>(); - for (final String key : keys) { + for (final CacheNode.NodeKey key : keys) { counts.put(key, 0); } } - private Counter(final Map otherCounts) { + private Counter(final Map otherCounts) { counts = new HashMap<>(otherCounts.size()); counts.putAll(otherCounts); } - void increment(final String key) { + void increment(final CacheNode.NodeKey key) { counts.put(key, getCount(key) + 1); } - int getCount(final String key) { + int getCount(final CacheNode.NodeKey key) { return counts.get(key); } - Set getKeys() { + Set getKeys() { return counts.keySet(); } @@ -1126,7 +1010,7 @@ Counter copy() { Counter diff(final Counter oldCounter) { Utils.validateArg(Sets.symmetricDifference(oldCounter.getKeys(), getKeys()).isEmpty(), "the counters must have the same keys"); - final Map diffMap = new HashMap<>(getKeys().size()); + final Map diffMap = new HashMap<>(getKeys().size()); getKeys().forEach(key -> diffMap.put(key, getCount(key) - oldCounter.getCount(key))); return new Counter(diffMap); } diff --git a/src/test/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/ImmutableComputableGraphUtilsUnitTest.java b/src/test/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/ImmutableComputableGraphUtilsUnitTest.java index 6122b0804..b4bd8c385 100644 --- a/src/test/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/ImmutableComputableGraphUtilsUnitTest.java +++ b/src/test/java/org/broadinstitute/hellbender/tools/coveragemodel/cachemanager/ImmutableComputableGraphUtilsUnitTest.java @@ -10,19 +10,21 @@ */ public class ImmutableComputableGraphUtilsUnitTest extends BaseTest { + private static final CacheNode.NodeKey X_KEY = new CacheNode.NodeKey("x"); + @Test(expectedExceptions = ImmutableComputableGraphUtils.ImmutableComputableGraphBuilder.DuplicateNodeKeyException.class) public void testDuplicatePrimitiveNode() { ImmutableComputableGraph.builder() - .primitiveNode("x", new String[] {}, new DuplicableNDArray()) - .primitiveNode("x", new String[] {}, new DuplicableNDArray()) + .primitiveNode(X_KEY, new CacheNode.NodeTag[] {}, new DuplicableNDArray()) + .primitiveNode(X_KEY, new CacheNode.NodeTag[] {}, new DuplicableNDArray()) .build(); } @Test(expectedExceptions = ImmutableComputableGraphUtils.ImmutableComputableGraphBuilder.DuplicateNodeKeyException.class) public void testDuplicateComputableNode() { ImmutableComputableGraph.builder() - .computableNode("x", new String[] {}, new String[] {}, null, true) - .computableNode("x", new String[] {}, new String[] {}, null, true) + .computableNode(X_KEY, new CacheNode.NodeTag[] {}, new CacheNode.NodeKey[] {}, null, true) + .computableNode(X_KEY, new CacheNode.NodeTag[] {}, new CacheNode.NodeKey[] {}, null, true) .build(); } }