Skip to content

Commit effeb72

Browse files
committed
Support Cors handling for HTTP endpoint
Signed-off-by: Hiroshi Hatake <[email protected]>
1 parent 3f23f30 commit effeb72

File tree

5 files changed

+228
-1
lines changed

5 files changed

+228
-1
lines changed

config.go

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,22 @@ func loadTLSConfig(cfg TLSSettings) (*tls.Config, error) {
3333
return tlsConfig, nil
3434
}
3535

36+
func parseCorsConfig(cfg *Config, conf plugin.ConfigLoader) (*Config, error) {
37+
if cfg == nil {
38+
return nil, errors.New("cfg must not nil")
39+
}
40+
if corsStr := conf.String("server.cors.allowed_origins"); corsStr != "" {
41+
for _, origin := range strings.Split(corsStr, ",") {
42+
trimmedOrigin := strings.TrimSpace(origin)
43+
if trimmedOrigin != "" { // Only append non-empty origins
44+
cfg.ServerCors.AllowedOrigins = append(cfg.ServerCors.AllowedOrigins, trimmedOrigin)
45+
}
46+
}
47+
}
48+
49+
return cfg, nil
50+
}
51+
3652
func loadConfig(fbit *plugin.Fluentbit) (*Config, error) {
3753
log := fbit.Logger
3854
cfg := &Config{
@@ -45,6 +61,12 @@ func loadConfig(fbit *plugin.Fluentbit) (*Config, error) {
4561
ServerGrpcListenAddr: fbit.Conf.String("server.grpc.listen_addr"),
4662
ServerHeaders: parseHeaders(fbit.Conf.String("server.headers")),
4763
}
64+
65+
cfg, err := parseCorsConfig(cfg, fbit.Conf)
66+
if err != nil {
67+
return nil, err
68+
}
69+
4870
if cfg.Mode == "" {
4971
cfg.Mode = "all"
5072
}

config_test.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,69 @@ func Test_loadConfig_ModeServer(t *testing.T) {
129129
})
130130
}
131131

132+
func Test_loadConfig_CORS(t *testing.T) {
133+
testCases := []struct {
134+
name string
135+
inputConf map[string]string
136+
expectedOrigins []string
137+
}{
138+
{
139+
name: "multiple origins with spaces",
140+
inputConf: map[string]string{
141+
"server.cors.allowed_origins": "http://localhost:3000, https://my-app.com",
142+
},
143+
expectedOrigins: []string{"http://localhost:3000", "https://my-app.com"},
144+
},
145+
{
146+
name: "single origin",
147+
inputConf: map[string]string{
148+
"server.cors.allowed_origins": "https://my-app.com",
149+
},
150+
expectedOrigins: []string{"https://my-app.com"},
151+
},
152+
{
153+
name: "wildcard origin",
154+
inputConf: map[string]string{
155+
"server.cors.allowed_origins": "*",
156+
},
157+
expectedOrigins: []string{"*"},
158+
},
159+
{
160+
name: "config key not present",
161+
inputConf: map[string]string{}, // The key is missing entirely
162+
expectedOrigins: nil, // Expect a nil slice, not an empty one
163+
},
164+
{
165+
name: "config key is an empty string",
166+
inputConf: map[string]string{
167+
"server.cors.allowed_origins": "",
168+
},
169+
expectedOrigins: nil,
170+
},
171+
{
172+
name: "malformed string with extra spaces and commas",
173+
inputConf: map[string]string{
174+
"server.cors.allowed_origins": " http://a.com, ,https://b.com ,",
175+
},
176+
expectedOrigins: []string{"http://a.com", "https://b.com"},
177+
},
178+
}
179+
180+
for _, tc := range testCases {
181+
t.Run(tc.name, func(t *testing.T) {
182+
// Arrange: Create a fresh config and mock loader for each test.
183+
cfg := &Config{}
184+
conf := mapConfigLoader(tc.inputConf)
185+
186+
// Act: Call the function we are testing.
187+
parseCorsConfig(cfg, conf)
188+
189+
// Assert: Check if the result matches the expectation.
190+
assert.Equal(t, tc.expectedOrigins, cfg.ServerCors.AllowedOrigins)
191+
})
192+
}
193+
}
194+
132195
func Test_loadConfig_Server_Defaults(t *testing.T) {
133196
t.Run("server mode applies defaults for retry and keepalive", func(t *testing.T) {
134197
fbit := &plugin.Fluentbit{

custom_jaeger_remote.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ type clientComponent struct {
5050
tracerProvider *sdktrace.TracerProvider
5151
}
5252

53+
type CorsSettings struct {
54+
AllowedOrigins []string
55+
}
56+
5357
type Config struct {
5458
Mode string // "client", "server", or "all"
5559
Headers map[string]string
@@ -61,6 +65,7 @@ type Config struct {
6165
ServerStrategyFile string
6266
ServerHttpListenAddr string
6367
ServerGrpcListenAddr string
68+
ServerCors CorsSettings
6469
ServerServiceNames []string
6570
ServerTLS TLSSettings
6671
ServerHeaders map[string]string

jaeger_services.go

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,11 +212,50 @@ func (plug *jaegerRemotePlugin) loadStrategiesFromFile() error {
212212
return nil
213213
}
214214

215+
func (plug *jaegerRemotePlugin) corsMiddleware(next http.Handler) http.Handler {
216+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
217+
if len(plug.config.ServerCors.AllowedOrigins) == 0 {
218+
next.ServeHTTP(w, r)
219+
return
220+
}
221+
222+
origin := r.Header.Get("Origin")
223+
isAllowed := false
224+
for _, allowed := range plug.config.ServerCors.AllowedOrigins {
225+
if allowed == "*" || allowed == origin {
226+
isAllowed = true
227+
break
228+
}
229+
}
230+
231+
if isAllowed {
232+
w.Header().Set("Access-Control-Allow-Origin", origin)
233+
w.Header().Set("Access-Control-Allow-Methods", "GET, OPTIONS")
234+
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
235+
}
236+
237+
// Handle pre-flight OPTIONS request
238+
if r.Method == "OPTIONS" {
239+
if isAllowed {
240+
w.WriteHeader(http.StatusNoContent)
241+
return
242+
}
243+
// If not an allowed origin, forbid the request.
244+
w.WriteHeader(http.StatusForbidden)
245+
return
246+
}
247+
248+
// Serve the actual request for GET, etc.
249+
next.ServeHTTP(w, r)
250+
})
251+
}
252+
215253
func (plug *jaegerRemotePlugin) startHttpServer() *http.Server {
216254
mux := http.NewServeMux()
217255
mux.HandleFunc("/sampling", plug.handleSampling)
218256
mux.HandleFunc("/strategies", plug.handleGetStrategies)
219-
server := &http.Server{Addr: plug.config.ServerHttpListenAddr, Handler: mux}
257+
258+
server := &http.Server{Addr: plug.config.ServerHttpListenAddr, Handler: plug.corsMiddleware(mux)} //
220259
go func() {
221260
if err := server.ListenAndServe(); err != http.ErrServerClosed {
222261
plug.log.Error("HTTP server error: %v", err)

jaeger_services_test.go

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,6 +280,104 @@ func Test_InitServer_Failure(t *testing.T) {
280280
})
281281
}
282282

283+
func Test_corsMiddleware(t *testing.T) {
284+
dummyHandler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
285+
w.WriteHeader(http.StatusOK)
286+
w.Write([]byte("OK"))
287+
})
288+
289+
plug := &jaegerRemotePlugin{
290+
log: newTestLogger(t),
291+
config: &Config{
292+
ServerCors: CorsSettings{
293+
AllowedOrigins: []string{"http://localhost:3000"},
294+
},
295+
},
296+
}
297+
testHandler := plug.corsMiddleware(dummyHandler) //
298+
299+
t.Run("GET request from allowed origin", func(t *testing.T) {
300+
req := httptest.NewRequest(http.MethodGet, "/sampling", nil)
301+
req.Header.Set("Origin", "http://localhost:3000")
302+
rr := httptest.NewRecorder()
303+
304+
testHandler.ServeHTTP(rr, req)
305+
306+
assert.Equal(t, http.StatusOK, rr.Code)
307+
assert.Equal(t, "http://localhost:3000", rr.Header().Get("Access-Control-Allow-Origin"))
308+
})
309+
310+
t.Run("pre-flight OPTIONS request from allowed origin", func(t *testing.T) {
311+
req := httptest.NewRequest(http.MethodOptions, "/sampling", nil)
312+
req.Header.Set("Origin", "http://localhost:3000")
313+
rr := httptest.NewRecorder()
314+
315+
testHandler.ServeHTTP(rr, req)
316+
317+
assert.Equal(t, http.StatusNoContent, rr.Code)
318+
assert.Equal(t, "http://localhost:3000", rr.Header().Get("Access-Control-Allow-Origin"))
319+
})
320+
321+
t.Run("GET request from disallowed origin", func(t *testing.T) {
322+
req := httptest.NewRequest(http.MethodGet, "/sampling", nil)
323+
req.Header.Set("Origin", "https://evil-site.com")
324+
rr := httptest.NewRecorder()
325+
326+
testHandler.ServeHTTP(rr, req)
327+
328+
assert.Equal(t, http.StatusOK, rr.Code)
329+
assert.Equal(t, "", rr.Header().Get("Access-Control-Allow-Origin"))
330+
})
331+
332+
t.Run("pre-flight OPTIONS request from disallowed origin", func(t *testing.T) {
333+
req := httptest.NewRequest(http.MethodOptions, "/sampling", nil)
334+
req.Header.Set("Origin", "https://evil-site.com")
335+
rr := httptest.NewRecorder()
336+
337+
testHandler.ServeHTTP(rr, req)
338+
339+
assert.Equal(t, http.StatusForbidden, rr.Code)
340+
})
341+
342+
t.Run("request without CORS config should pass through", func(t *testing.T) {
343+
plugWithoutCors := &jaegerRemotePlugin{
344+
log: newTestLogger(t),
345+
config: &Config{},
346+
}
347+
handlerWithoutCors := plugWithoutCors.corsMiddleware(dummyHandler)
348+
349+
req := httptest.NewRequest(http.MethodGet, "/sampling", nil)
350+
req.Header.Set("Origin", "http://localhost:3000")
351+
rr := httptest.NewRecorder()
352+
353+
handlerWithoutCors.ServeHTTP(rr, req)
354+
355+
assert.Equal(t, http.StatusOK, rr.Code)
356+
assert.Equal(t, "", rr.Header().Get("Access-Control-Allow-Origin"))
357+
})
358+
359+
t.Run("wildcard origin allows any origin", func(t *testing.T) {
360+
plugWithWildcard := &jaegerRemotePlugin{
361+
log: newTestLogger(t),
362+
config: &Config{
363+
ServerCors: CorsSettings{
364+
AllowedOrigins: []string{"*"},
365+
},
366+
},
367+
}
368+
handlerWithWildcard := plugWithWildcard.corsMiddleware(dummyHandler)
369+
370+
req := httptest.NewRequest(http.MethodGet, "/sampling", nil)
371+
req.Header.Set("Origin", "https://any-site.com")
372+
rr := httptest.NewRecorder()
373+
374+
handlerWithWildcard.ServeHTTP(rr, req)
375+
376+
assert.Equal(t, http.StatusOK, rr.Code)
377+
assert.Equal(t, "https://any-site.com", rr.Header().Get("Access-Control-Allow-Origin"))
378+
})
379+
}
380+
283381
func Test_ServerHandlers(t *testing.T) {
284382
mockJaeger := &samplingServer{
285383
err: status.Error(codes.NotFound, "strategy not found for service"),

0 commit comments

Comments
 (0)