From 2658ada8e1e8f0d765df91b300a1bd82ef743ed3 Mon Sep 17 00:00:00 2001 From: John DeNero Date: Tue, 8 Mar 2016 11:05:21 -0800 Subject: [PATCH] tests with start values --- datascience/util.py | 2 ++ tests/test_util.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/datascience/util.py b/datascience/util.py index 7b1d3b6d4..fa8459133 100644 --- a/datascience/util.py +++ b/datascience/util.py @@ -130,6 +130,8 @@ def minimize(f, start=None, **vargs): arg_count = f.__code__.co_argcount assert arg_count > 0, "Please pass starting values explicitly" start = [0] * arg_count + if not hasattr(start, '__len__'): + start = [start] @functools.wraps(f) def wrapper(args): diff --git a/tests/test_util.py b/tests/test_util.py index 5488477bd..291dc1f2c 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -44,3 +44,5 @@ def test_table_apply(): def test_minimize(): assert (2 == ds.minimize(lambda x: (x-2)**2)) == True assert [2, 1] == list(ds.minimize(lambda x, y: (x-2)**2 + (y-1)**2)) + assert (2 == ds.minimize(lambda x: (x-2)**2, 1)) == True + assert [2, 1] == list(ds.minimize(lambda x, y: (x-2)**2 + (y-1)**2, [1, 1]))