Skip to content

Commit 6135482

Browse files
Fix complex number interaction with scalars
1 parent 53492cc commit 6135482

File tree

3 files changed

+110
-80
lines changed

3 files changed

+110
-80
lines changed

pythran/pythonic/include/types/complex.hpp

Lines changed: 48 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -5,42 +5,54 @@
55

66
namespace std
77
{
8-
template <class T>
9-
std::complex<T> operator+(std::complex<T> self, long other);
10-
template <class T>
11-
std::complex<T> operator+(long self, std::complex<T> other);
12-
template <class T>
13-
std::complex<T> operator-(std::complex<T> self, long other);
14-
template <class T>
15-
std::complex<T> operator-(long self, std::complex<T> other);
16-
template <class T>
17-
std::complex<T> operator*(std::complex<T> self, long other);
18-
template <class T>
19-
std::complex<T> operator*(long self, std::complex<T> other);
20-
template <class T>
21-
std::complex<T> operator/(std::complex<T> self, long other);
22-
template <class T>
23-
std::complex<T> operator/(long self, std::complex<T> other);
24-
template <class T>
25-
bool operator==(std::complex<T> self, long other);
26-
template <class T>
27-
bool operator==(long self, std::complex<T> other);
28-
template <class T>
29-
bool operator!=(std::complex<T> self, long other);
30-
template <class T>
31-
bool operator!=(long self, std::complex<T> other);
32-
template <class T>
33-
bool operator<(std::complex<T> self, std::complex<T> other);
34-
template <class T>
35-
bool operator<=(std::complex<T> self, std::complex<T> other);
36-
template <class T>
37-
bool operator>(std::complex<T> self, std::complex<T> other);
38-
template <class T>
39-
bool operator>=(std::complex<T> self, std::complex<T> other);
40-
template <class T>
41-
bool operator&&(std::complex<T> self, std::complex<T> other);
42-
template <class T>
43-
bool operator||(std::complex<T> self, std::complex<T> other);
8+
9+
template <class T, class S>
10+
using complex_broadcast_t = typename std::enable_if<
11+
std::is_scalar<S>::value && !std::is_same<T, S>::value,
12+
std::complex<typename std::common_type<T, S>::type>>::type;
13+
template <class T, class S>
14+
using complex_bool_t = typename std::enable_if<
15+
std::is_scalar<S>::value && !std::is_same<T, S>::value, bool>::type;
16+
17+
template <class T, class S>
18+
complex_broadcast_t<T, S> operator+(std::complex<T> self, S other);
19+
template <class T, class S>
20+
complex_broadcast_t<T, S> operator+(S self, std::complex<T> other);
21+
template <class T, class S>
22+
complex_broadcast_t<T, S> operator-(std::complex<T> self, S other);
23+
template <class T, class S>
24+
complex_broadcast_t<T, S> operator-(S self, std::complex<T> other);
25+
template <class T, class S>
26+
complex_broadcast_t<T, S> operator*(std::complex<T> self, S other);
27+
template <class T, class S>
28+
complex_broadcast_t<T, S> operator*(S self, std::complex<T> other);
29+
template <class T, class S>
30+
complex_broadcast_t<T, S> operator/(std::complex<T> self, S other);
31+
template <class T, class S>
32+
complex_broadcast_t<T, S> operator/(S self, std::complex<T> other);
33+
34+
template <class T, class S>
35+
complex_bool_t<T, S> operator==(std::complex<T> self, S other);
36+
template <class T, class S>
37+
complex_bool_t<T, S> operator==(S self, std::complex<T> other);
38+
template <class T, class S>
39+
complex_bool_t<T, S> operator!=(std::complex<T> self, S other);
40+
template <class T, class S>
41+
complex_bool_t<T, S> operator!=(S self, std::complex<T> other);
42+
43+
template <class T, class S>
44+
bool operator<(std::complex<T> self, std::complex<S> other);
45+
template <class T, class S>
46+
bool operator<=(std::complex<T> self, std::complex<S> other);
47+
template <class T, class S>
48+
bool operator>(std::complex<T> self, std::complex<S> other);
49+
template <class T, class S>
50+
bool operator>=(std::complex<T> self, std::complex<S> other);
51+
template <class T, class S>
52+
bool operator&&(std::complex<T> self, std::complex<S> other);
53+
template <class T, class S>
54+
bool operator||(std::complex<T> self, std::complex<S> other);
55+
4456
template <class T>
4557
bool operator!(std::complex<T> self);
4658

pythran/pythonic/types/complex.hpp

Lines changed: 52 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -7,114 +7,122 @@
77

88
namespace std
99
{
10-
template <class T>
11-
std::complex<T> operator+(std::complex<T> self, long other)
10+
template <class T, class S>
11+
complex_broadcast_t<T, S> operator+(std::complex<T> self, S other)
1212
{
13-
return self + T(other);
13+
return (complex_broadcast_t<T, S>)self +
14+
(typename std::common_type<T, S>::type)(other);
1415
}
1516

16-
template <class T>
17-
std::complex<T> operator+(long self, std::complex<T> other)
17+
template <class T, class S>
18+
complex_broadcast_t<T, S> operator+(S self, std::complex<T> other)
1819
{
19-
return T(self) + other;
20+
return (typename std::common_type<T, S>::type)(self) +
21+
(complex_broadcast_t<T, S>)other;
2022
}
2123

22-
template <class T>
23-
std::complex<T> operator-(std::complex<T> self, long other)
24+
template <class T, class S>
25+
complex_broadcast_t<T, S> operator-(std::complex<T> self, S other)
2426
{
25-
return self - T(other);
27+
return (complex_broadcast_t<T, S>)self -
28+
(typename std::common_type<T, S>::type)(other);
2629
}
2730

28-
template <class T>
29-
std::complex<T> operator-(long self, std::complex<T> other)
31+
template <class T, class S>
32+
complex_broadcast_t<T, S> operator-(S self, std::complex<T> other)
3033
{
31-
return T(self) - other;
34+
return (typename std::common_type<T, S>::type)(self) -
35+
(complex_broadcast_t<T, S>)other;
3236
}
3337

34-
template <class T>
35-
std::complex<T> operator*(std::complex<T> self, long other)
38+
template <class T, class S>
39+
complex_broadcast_t<T, S> operator*(std::complex<T> self, S other)
3640
{
37-
return self * T(other);
41+
return (complex_broadcast_t<T, S>)self *
42+
(typename std::common_type<T, S>::type)(other);
3843
}
3944

40-
template <class T>
41-
std::complex<T> operator*(long self, std::complex<T> other)
45+
template <class T, class S>
46+
complex_broadcast_t<T, S> operator*(S self, std::complex<T> other)
4247
{
43-
return T(self) * other;
48+
return (typename std::common_type<T, S>::type)(self) *
49+
(complex_broadcast_t<T, S>)other;
4450
}
4551

46-
template <class T>
47-
std::complex<T> operator/(std::complex<T> self, long other)
52+
template <class T, class S>
53+
complex_broadcast_t<T, S> operator/(std::complex<T> self, S other)
4854
{
49-
return self / T(other);
55+
return (complex_broadcast_t<T, S>)self /
56+
(typename std::common_type<T, S>::type)(other);
5057
}
5158

52-
template <class T>
53-
std::complex<T> operator/(long self, std::complex<T> other)
59+
template <class T, class S>
60+
complex_broadcast_t<T, S> operator/(S self, std::complex<T> other)
5461
{
55-
return T(self) / other;
62+
return (typename std::common_type<T, S>::type)(self) /
63+
(complex_broadcast_t<T, S>)other;
5664
}
5765

58-
template <class T>
59-
bool operator==(std::complex<T> self, long other)
66+
template <class T, class S>
67+
complex_bool_t<T, S> operator==(std::complex<T> self, S other)
6068
{
6169
return self == T(other);
6270
}
6371

64-
template <class T>
65-
bool operator==(long self, std::complex<T> other)
72+
template <class T, class S>
73+
complex_bool_t<T, S> operator==(S self, std::complex<T> other)
6674
{
6775
return T(self) == other;
6876
}
6977

70-
template <class T>
71-
bool operator!=(std::complex<T> self, long other)
78+
template <class T, class S>
79+
complex_bool_t<T, S> operator!=(std::complex<T> self, S other)
7280
{
7381
return self != T(other);
7482
}
7583

76-
template <class T>
77-
bool operator!=(long self, std::complex<T> other)
84+
template <class T, class S>
85+
complex_bool_t<T, S> operator!=(S self, std::complex<T> other)
7886
{
7987
return T(self) != other;
8088
}
8189

82-
template <class T>
83-
bool operator<(std::complex<T> self, std::complex<T> other)
90+
template <class T, class S>
91+
bool operator<(std::complex<T> self, std::complex<S> other)
8492
{
8593
return self.real() == other.real() ? self.imag() < other.imag()
8694
: self.real() < other.real();
8795
}
8896

89-
template <class T>
90-
bool operator<=(std::complex<T> self, std::complex<T> other)
97+
template <class T, class S>
98+
bool operator<=(std::complex<T> self, std::complex<S> other)
9199
{
92100
return self.real() == other.real() ? self.imag() <= other.imag()
93101
: self.real() <= other.real();
94102
}
95103

96-
template <class T>
97-
bool operator>(std::complex<T> self, std::complex<T> other)
104+
template <class T, class S>
105+
bool operator>(std::complex<T> self, std::complex<S> other)
98106
{
99107
return self.real() == other.real() ? self.imag() > other.imag()
100108
: self.real() > other.real();
101109
}
102110

103-
template <class T>
104-
bool operator>=(std::complex<T> self, std::complex<T> other)
111+
template <class T, class S>
112+
bool operator>=(std::complex<T> self, std::complex<S> other)
105113
{
106114
return self.real() == other.real() ? self.imag() >= other.imag()
107115
: self.real() >= other.real();
108116
}
109117

110-
template <class T>
111-
bool operator&&(std::complex<T> self, std::complex<T> other)
118+
template <class T, class S>
119+
bool operator&&(std::complex<T> self, std::complex<S> other)
112120
{
113121
return (self.real() || self.imag()) && (other.real() || other.imag());
114122
}
115123

116-
template <class T>
117-
bool operator||(std::complex<T> self, std::complex<T> other)
124+
template <class T, class S>
125+
bool operator||(std::complex<T> self, std::complex<S> other)
118126
{
119127
return (self.real() || self.imag()) || (other.real() || other.imag());
120128
}

pythran/tests/test_complex.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,16 @@ def test_complex_array_iexpr_real_assign(self):
102102
np.array([[3 + 2j, 2, 1, 0]] * 3,dtype=np.complex64),
103103
test_complex_array_iexpr_real_assign=[NDArray[np.complex64, :, :]])
104104

105+
def test_complex_broadcast_scalar0(self):
106+
self.run_test('def complex_broadcast_scalar0(x): return x + 1.5, 1.3 +x, 3.1 - x, x - 3.7, x * 5.4, 7.6 * x',
107+
5.1 + 3j,
108+
complex_broadcast_scalar0=[complex])
109+
110+
def test_complex_broadcast_scalar1(self):
111+
self.run_test('def complex_broadcast_scalar1(x): return x + 1.5, 1.3 +x, 3.1 - x, x - 3.7, x * 5.4, 7.6 * x',
112+
np.complex64(5.1 + 3j),
113+
complex_broadcast_scalar1=[np.complex64])
114+
105115
def test_complex_array_imag_assign(self):
106116
self.run_test('def test_complex_array_imag_assign(a): a.imag = 1; return a',
107117
np.array([[3 + 2j, 2, 1, 0]] * 3,dtype=np.complex64),

0 commit comments

Comments
 (0)