Skip to content

Commit

Permalink
Merge branch 'patch-1.1.6' into v1.2-dev
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Jul 20, 2020
1 parent c073eca commit ca778da
Show file tree
Hide file tree
Showing 42 changed files with 856 additions and 542 deletions.
6 changes: 5 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ python:
jobs:
allow_failures:
- python: "3.8"
include:
- stage: lint
python: "3.8"
script: flake8 tests gunpowder

before_install:
- sudo apt-get update
Expand All @@ -23,7 +27,7 @@ before_install:
- source activate gp-env

install:
- make install-full
- make install-dev

script:
- make test
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ install-full:

.PHONY: install-dev
install-dev:
pip install -r requirements-dev.txt
pip install -e .[full]

.PHONY: test
test:
python -m tests -v
pytest -v --cov gunpowder

.PHONY: publish
publish:
Expand Down
4 changes: 2 additions & 2 deletions gunpowder/batch_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def __init__(self, *args, random_seed=None, **kwargs):
)
super().__init__(*args, **kwargs)

def add(self, key, shape, voxel_size=None, placeholder=False):
def add(self, key, shape, voxel_size=None, directed=None, placeholder=False):
"""Convenience method to add an array or graph spec by providing only
the shape of a ROI (in world units).
Expand All @@ -60,7 +60,7 @@ def add(self, key, shape, voxel_size=None, placeholder=False):
if isinstance(key, ArrayKey):
spec = ArraySpec(placeholder=placeholder)
elif isinstance(key, GraphKey):
spec = GraphSpec(placeholder=placeholder)
spec = GraphSpec(placeholder=placeholder, directed=directed)
else:
raise RuntimeError("Only ArrayKey or GraphKey can be added.")

Expand Down
274 changes: 179 additions & 95 deletions gunpowder/contrib/nodes/add_vector_map.py

Large diffs are not rendered by default.

30 changes: 19 additions & 11 deletions gunpowder/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,19 @@ def spec(self, new_spec):

@property
def directed(self):
return self.spec.directed
return (
self.spec.directed
if self.spec.directed is not None
else self.__graph.is_directed()
)

def create_graph(self, nodes: Iterator[Node], edges: Iterator[Edge]):
if self.directed:
if self.__spec.directed is None:
logger.warning(
"Trying to create a Graph without specifying directionality. Using default Directed!"
)
graph = nx.DiGraph()
elif self.__spec.directed:
graph = nx.DiGraph()
else:
graph = nx.Graph()
Expand All @@ -238,6 +247,9 @@ def nodes(self):
def num_vertices(self):
return self.__graph.number_of_nodes()

def num_edges(self):
return self.__graph.number_of_edges()

@property
def edges(self):
for (u, v), attrs in self.__graph.edges.items():
Expand Down Expand Up @@ -297,7 +309,6 @@ def remove_node(self, node: Node, retain_connectivity=False):
self.add_edge(Edge(pred_id, succ_id))
self.__graph.remove_node(node.id)


def add_node(self, node: Node):
"""
Adds a node to the graph.
Expand Down Expand Up @@ -562,14 +573,11 @@ def from_nx_graph(cls, graph, spec):
"""
Create a gunpowder graph from a networkx graph
"""
nodes = [
Node(id=node, location=attrs["location"], attrs=attrs)
for node, attrs in graph.nodes().items()
]
edges = [Edge(u, v) for u, v in graph.edges]
directed = graph.is_directed()
spec.directed = directed
return cls(nodes, edges, spec)
if spec.directed is None:
spec.directed = graph.is_directed()
g = cls([], [], spec)
g.__graph = graph
return g

def relabel_connected_components(self):
"""
Expand Down
9 changes: 7 additions & 2 deletions gunpowder/graph_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,17 @@ class GraphSpec(Freezable):
The region of interested represented by this graph.
directed (``bool``):
directed (``bool``, optional):
Whether the graph is directed or not.
dtype (``dtype``, optional):
The data type of the "location" attribute.
Currently only supports np.float32.
"""

def __init__(self, roi=None, directed=True, dtype=np.float32, placeholder=False):
def __init__(self, roi=None, directed=None, dtype=np.float32, placeholder=False):

self.roi = roi
self.directed = directed
Expand Down
4 changes: 2 additions & 2 deletions gunpowder/nodes/balance_labels.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .batch_filter import BatchFilter
from gunpowder.array import Array
from gunpowder.batch_request import BatchRequest
import collections
from collections.abc import Iterable
import itertools
import logging
import numpy as np
Expand Down Expand Up @@ -69,7 +69,7 @@ def __init__(self, labels, scales, mask=None, slab=None, num_classes=2,
self.scales = scales
if mask is None:
self.masks = []
elif not isinstance(mask, collections.Iterable):
elif not isinstance(mask, Iterable):
self.masks = [mask]
else:
self.masks = mask
Expand Down
46 changes: 31 additions & 15 deletions gunpowder/nodes/batch_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,17 @@
logger = logging.getLogger(__name__)


class BatchFilterError(Exception):

def __init__(self, batch_filter, msg):
self.batch_filter = batch_filter
self.msg = msg

def __str__(self):

return f"Error in {self.batch_filter.name()}: {self.msg}"


class BatchFilter(BatchProvider):
"""Convenience wrapper for :class:`BatchProviders<BatchProvider>` with
exactly one input provider.
Expand Down Expand Up @@ -41,9 +52,12 @@ def remove_placeholders(self):
return self._remove_placeholders

def get_upstream_provider(self):
assert (
len(self.get_upstream_providers()) == 1
), "BatchFilters need to have exactly one upstream provider"
if len(self.get_upstream_providers()) != 1:
raise BatchFilterError(
self,
"BatchFilters need to have exactly one upstream provider, "
f"this one has {len(self.get_upstream_providers())}: "
f"({[b.name() for b in self.get_upstream_providers()]}")
return self.get_upstream_providers()[0]

def updates(self, key, spec):
Expand All @@ -63,10 +77,12 @@ def updates(self, key, spec):
The updated spec of the array or point set.
"""

assert key in self.spec, (
"Node %s is trying to change the spec for %s, but is not provided upstream."
% (type(self).__name__, key)
)
if key not in self.spec:
raise BatchFilterError(
self,
f"BatchFilter {self} is trying to change the spec for {key}, "
f"but {key} is not provided upstream. Upstream offers: "
f"{self.get_upstream_provider().spec}")
self.spec[key] = copy.deepcopy(spec)
self.updated_items.append(key)

Expand Down Expand Up @@ -139,11 +155,11 @@ def provide(self, request):
elif dependencies is None:
upstream_request = request.copy()
else:
raise Exception(
f"{self.__class__} returned a {type(dependencies)}! "
f"Supported return types are: `BatchRequest` containing your exact "
f"dependencies or `None`, indicating a dependency on the full request."
)
raise BatchFilterError(
self,
f"This BatchFilter returned a {type(dependencies)}! "
"Supported return types are: `BatchRequest` containing your exact "
"dependencies or `None`, indicating a dependency on the full request.")
self.remove_provided(upstream_request)
else:
upstream_request = request.copy()
Expand Down Expand Up @@ -233,6 +249,6 @@ def process(self, batch, request):
The request this node received. The updated batch should meet
this request.
"""
raise RuntimeError(
"Class %s does not implement 'process'" % type(self).__name__
)
raise BatchFilterError(
self,
"does not implement 'process'")
63 changes: 50 additions & 13 deletions gunpowder/nodes/batch_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,27 @@
from gunpowder.provider_spec import ProviderSpec
from gunpowder.array import ArrayKey
from gunpowder.array_spec import ArraySpec
from gunpowder.graph import GraphKey
from gunpowder.graph_spec import GraphSpec

logger = logging.getLogger(__name__)


class BatchRequestError(Exception):

def __init__(self, provider, request, batch):
self.provider = provider
self.request = request
self.batch = batch

def __str__(self):

return \
f"Exception in {self.provider.name()} while processing request" \
f"{self.request} \n" \
"Batch returned so far:\n" \
f"{self.batch}"

class BatchProvider(object):
'''Superclass for all nodes in a `gunpowder` graph.
Expand Down Expand Up @@ -154,26 +170,34 @@ def request_batch(self, request):
:class:`GraphSpecs<GraphSpec>`.
'''

logger.debug("%s got request %s", self.name(), request)
batch = None

try:

request._update_random_seed()

request._update_random_seed()
self.set_seeds(request)

self.set_seeds(request)
logger.debug("%s got request %s", self.name(), request)

self.check_request_consistency(request)
self.check_request_consistency(request)

upstream_request = request.copy()
if self.remove_placeholders:
upstream_request.remove_placeholders()
batch = self.provide(upstream_request)
upstream_request = request.copy()
if self.remove_placeholders:
upstream_request.remove_placeholders()
batch = self.provide(upstream_request)

request.remove_placeholders()
request.remove_placeholders()

self.check_batch_consistency(batch, request)
self.check_batch_consistency(batch, request)

self.remove_unneeded(batch, request)
self.remove_unneeded(batch, request)

logger.debug("%s provides %s", self.name(), batch)
logger.debug("%s provides %s", self.name(), batch)

except Exception as e:

raise BatchRequestError(self, request, batch) from e

return batch

Expand Down Expand Up @@ -225,6 +249,13 @@ def check_request_consistency(self, request):
key,
provided_spec.voxel_size[d])

if isinstance(key, GraphKey):

if request_spec.directed is not None:
assert request_spec.directed == provided_spec.directed, (
f"asked for {key}: directed={request_spec.directed} but "
f"{self.name()} provides directed={provided_spec.directed}"
)
def check_batch_consistency(self, batch, request):

for (array_key, request_spec) in request.array_specs.items():
Expand Down Expand Up @@ -274,9 +305,15 @@ def check_batch_consistency(self, batch, request):
graph.spec.roi,
self.name())

if request_spec.directed is not None:
assert request_spec.directed == graph.directed, (
f"Recieved {graph_key}: directed={graph.directed} but "
f"{self.name()} should provide directed={request_spec.directed}"
)

for node in graph.nodes:
contained = graph.spec.roi.contains(node.location)
dangling = not contained or all(
dangling = not contained and all(
[
graph.spec.roi.contains(v.location)
for v in graph.neighbors(node)
Expand Down
2 changes: 1 addition & 1 deletion gunpowder/nodes/pad.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def process(self, batch, request):

else:

points = batch.points[self.key]
points = batch.graphs[self.key]
points.spec.roi = request[self.key].roi

def __expand(self, a, from_roi, to_roi, value):
Expand Down
13 changes: 6 additions & 7 deletions gunpowder/nodes/scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from gunpowder.array import Array
from gunpowder.batch import Batch
from gunpowder.coordinate import Coordinate
from gunpowder.points import Points
from gunpowder.graph import Graph
from gunpowder.producer_pool import ProducerPool
from gunpowder.roi import Roi
from .batch_filter import BatchFilter
Expand Down Expand Up @@ -175,12 +175,11 @@ def __get_shift_roi(self, spec):

logger.debug("upstream ROI is %s", spec[key].roi)

for r, s in zip(
reference_spec.roi.get_shape(),
spec[key].roi.get_shape()):
assert r <= s, (
for r, s in zip(reference_spec.roi.get_shape(), spec[key].roi.get_shape()):
assert s is None or r <= s, (
"reference %s with ROI %s does not fit into provided "
"upstream %s"%(key, reference_spec.roi, spec[key].roi))
"upstream %s" % (key, reference_spec.roi, spec[key].roi)
)

# we have a reference ROI
#
Expand Down Expand Up @@ -349,7 +348,7 @@ def __setup_batch(self, batch_spec, chunk):
roi = spec.roi
spec = self.spec[graph_key].copy()
spec.roi = roi
batch.graphs[graph_key] = Points(data={}, spec=spec)
batch.graphs[graph_key] = Graph(nodes=[], edges=[], spec=spec)

logger.debug("setup batch to fill %s", batch)

Expand Down
2 changes: 1 addition & 1 deletion gunpowder/nodes/simple_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def process(self, batch, request):

# mirror and transpose ROIs of arrays & points in batch
total_roi = batch.get_total_roi().copy()
for collection_type in [batch.arrays, batch.points]:
for collection_type in [batch.arrays, batch.graphs]:
for (key, collector) in collection_type.items():
if key not in request:
continue
Expand Down
Loading

0 comments on commit ca778da

Please sign in to comment.