Skip to content

Commit

Permalink
black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
jdebacker committed Sep 1, 2023
1 parent 5413f0e commit 9db3f30
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 13 deletions.
8 changes: 4 additions & 4 deletions ogcore/firm.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,7 @@ def solve_L(Y, K, K_g, p, method, m=-1):


def adj_cost(K, Kp1, p, method):
r'''
r"""
Firm capital adjstment costs
..math::
Expand All @@ -697,13 +697,13 @@ def adj_cost(K, Kp1, p, method):
Returns
Psi (array-like): Capital adjustment costs per unit of investment
'''
if method == 'SS':
"""
if method == "SS":
ac_method = "total_ss"
else:
ac_method = "total_tpi"
Inv = aggr.get_I(None, Kp1, K, p, ac_method)

Psi = ((p.psi / 2) * (Inv / K - p.mu) ** 2) / (Inv / K)

return Psi
return Psi
21 changes: 13 additions & 8 deletions tests/test_firm.py
Original file line number Diff line number Diff line change
Expand Up @@ -939,6 +939,7 @@ def test_solve_L(Y, K, Kg, p, method, expected):
L = firm.solve_L(Y, K, Kg, p, method)
assert np.allclose(L, expected, atol=1e-6)


p1 = Specifications()
p1.psi = 4.0
p1.g_n_ss = 0.01
Expand Down Expand Up @@ -979,14 +980,18 @@ def test_solve_L(Y, K, Kg, p, method, expected):
expected_dPsidKp1_3 = np.array([0.479061039, 0.43588367, -62.31580895])


@pytest.mark.parametrize('K,Kp1,p,method,expected',
[(K_1, Kp1_1, p1, 'SS', expected_Psi_1),
(K_2, Kp1_2, p2, 'SS', expected_Psi_2),
(K_3, Kp1_3, p3, 'TPI', expected_Psi_3)],
ids=['Zero cost', 'Non-zero cost', 'TPI'])
@pytest.mark.parametrize(
"K,Kp1,p,method,expected",
[
(K_1, Kp1_1, p1, "SS", expected_Psi_1),
(K_2, Kp1_2, p2, "SS", expected_Psi_2),
(K_3, Kp1_3, p3, "TPI", expected_Psi_3),
],
ids=["Zero cost", "Non-zero cost", "TPI"],
)
def test_adj_cost(K, Kp1, p, method, expected):
'''
"""
Test of the firm capital adjustment cost function.
'''
"""
test_val = firm.adj_cost(K, Kp1, p, method)
assert np.allclose(test_val, expected)
assert np.allclose(test_val, expected)
2 changes: 1 addition & 1 deletion tests/test_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,4 +194,4 @@ def test_expand_taxfunc_params():
specs.update_specifications(new_specs)
assert len(specs.etr_params) == specs.T + specs.S
assert len(specs.etr_params[0]) == specs.S
assert specs.etr_params[0][0][0] == 0.35
assert specs.etr_params[0][0][0] == 0.35

0 comments on commit 9db3f30

Please sign in to comment.