-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.go
676 lines (556 loc) · 18.1 KB
/
main.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
package main
import (
	"bufio"
	"context"
	_ "embed"
	"flag"
	"fmt"
	"io"
	"log"
	"net"
	"os"
	"os/exec"
	"os/signal"
	"path/filepath"
	"regexp"
	"runtime"
	"strconv"
	"strings"
	"syscall"
	"time"

	"github.com/manifoldco/promptui"
	"golang.org/x/crypto/ssh"
	metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
	"k8s.io/client-go/kubernetes"
	"k8s.io/client-go/tools/clientcmd"
	"k8s.io/client-go/util/retry"

	"phasing/scripts"
)
// Phasing is the configuration of one interception session: the Service being
// taken over and the ports used at each end of the tunnel.
type Phasing struct {
	// ServiceName is the Service whose selector is redirected to the agent.
	ServiceName string
	// Namespace is the namespace that holds ServiceName.
	Namespace string
	// Port is the Service port, copied from the cluster by updateService.
	Port int
	// LocalPort is where the local replacement for the Service listens.
	LocalPort int
	// AgentLocalPort is the local end of the kubectl port-forward to the
	// agent's port 22; it stays 0 until kubectl reports the chosen port.
	AgentLocalPort int
}

// loopback formats a localhost address for the given port.
func loopback(port int) string {
	return fmt.Sprintf("localhost:%d", port)
}

// RemoteEndpoint is the address the agent is asked to listen on, on the far
// side of the SSH connection.
func (p *Phasing) RemoteEndpoint() string { return loopback(p.Port) }

// LocalEndpoint is the address of the locally running service that
// intercepted traffic is delivered to.
func (p *Phasing) LocalEndpoint() string { return loopback(p.LocalPort) }

// AgentEndpoint is the local address of the port-forward to the agent's SSH port.
func (p *Phasing) AgentEndpoint() string { return loopback(p.AgentLocalPort) }
// Hosts-file location and the tag that identifies lines written by phasing.
const (
	// hostsFilePath is the system hosts file that local DNS names are written to.
	hostsFilePath = "/etc/hosts"
	// marker is appended to every line phasing writes, so those lines can be
	// found and removed again.
	marker = "# added by phasing"
)

// Process-wide state shared by main and the helpers in this file.
var (
	// phasing holds the session configuration filled from flags and the cluster.
	phasing Phasing
	// oldSelector keeps the Service's original selector while it is redirected
	// to the agent, so updateService can put it back.
	oldSelector map[string]string
)
// AddHostEntries appends "ip hostname" lines to /etc/hosts, one line per
// hostname, each tagged with marker so RemoveHostEntries can strip them later.
// entries maps an IP address to all hostnames that should resolve to it.
// A pair the file already declares (same address field and same hostname
// field, outside a comment) is not written again, and neither is a pair that
// appears twice in entries.
func AddHostEntries(entries map[string][]string) (err error) {
	log.Println("Adding host entries to /etc/hosts...")

	input, err := os.ReadFile(hostsFilePath)
	if err != nil {
		return fmt.Errorf("error reading /etc/hosts: %w", err)
	}
	present := hostEntrySet(string(input))

	file, err := os.OpenFile(hostsFilePath, os.O_APPEND|os.O_WRONLY, 0600)
	if err != nil {
		return fmt.Errorf("error opening /etc/hosts: %w", err)
	}
	// Close errors on a file opened for writing are reported, not dropped.
	defer func() {
		if cerr := file.Close(); cerr != nil && err == nil {
			err = fmt.Errorf("error closing /etc/hosts: %w", cerr)
		}
	}()

	// Without a trailing newline our first line would be glued onto the
	// file's last line, corrupting that entry.
	if len(input) > 0 && input[len(input)-1] != '\n' {
		if _, werr := file.WriteString("\n"); werr != nil {
			return fmt.Errorf("error writing to /etc/hosts: %w", werr)
		}
	}

	for ip, hostnames := range entries {
		for _, hostname := range hostnames {
			key := ip + " " + hostname
			if _, ok := present[key]; ok {
				continue
			}
			line := fmt.Sprintf("%s %s %s\n", ip, hostname, marker)
			if _, werr := file.WriteString(line); werr != nil {
				return fmt.Errorf("error writing to /etc/hosts: %w", werr)
			}
			present[key] = struct{}{}
		}
	}
	return nil
}

// hostEntrySet returns every "ip hostname" pair declared in hosts-file
// content. Text after '#' is ignored, and fields may be separated by any mix
// of spaces and tabs. Exact field comparison avoids the false match a
// substring test gives for e.g. "ip foo" against "ip foobar".
func hostEntrySet(content string) map[string]struct{} {
	set := make(map[string]struct{})
	for _, line := range strings.Split(content, "\n") {
		if i := strings.IndexByte(line, '#'); i >= 0 {
			line = line[:i]
		}
		fields := strings.Fields(line)
		if len(fields) < 2 {
			continue
		}
		for _, host := range fields[1:] {
			set[fields[0]+" "+host] = struct{}{}
		}
	}
	return set
}
// RemoveHostEntries deletes lines that phasing added to /etc/hosts. Only lines
// containing marker are ever removed. With an empty ips every marked line is
// removed; otherwise only marked lines whose address field equals one of ips.
// The filtered content is written to a sibling temporary file, which then
// replaces /etc/hosts via rename. The file's permission bits are preserved.
func RemoveHostEntries(ips []string) error {
	fmt.Println("Removing host entries from /etc/hosts...")

	input, err := os.ReadFile(hostsFilePath)
	if err != nil {
		return fmt.Errorf("error reading /etc/hosts: %w", err)
	}

	// Keep the current permissions; fall back to the usual hosts-file mode.
	mode := os.FileMode(0644)
	if info, statErr := os.Stat(hostsFilePath); statErr == nil {
		mode = info.Mode().Perm()
	}

	removeAllMarked := len(ips) == 0
	ipSet := make(map[string]struct{}, len(ips))
	for _, ip := range ips {
		ipSet[ip] = struct{}{}
	}

	lines := strings.Split(string(input), "\n")
	kept := make([]string, 0, len(lines))
	for _, line := range lines {
		if strings.Contains(line, marker) {
			if removeAllMarked {
				continue
			}
			// Compare the whole address field: a substring test would let
			// "127.1.1.1" also remove "127.1.1.10".
			if fields := strings.Fields(line); len(fields) > 0 {
				if _, ok := ipSet[fields[0]]; ok {
					continue
				}
			}
		}
		kept = append(kept, line)
	}

	tempFilePath := hostsFilePath + ".tmp"
	if err := os.WriteFile(tempFilePath, []byte(strings.Join(kept, "\n")), mode); err != nil {
		return fmt.Errorf("error writing to temporary file: %w", err)
	}
	// WriteFile applies mode only when it creates the file (and the umask may
	// narrow it), so set it explicitly before the swap.
	if err := os.Chmod(tempFilePath, mode); err != nil {
		_ = os.Remove(tempFilePath) // best-effort cleanup
		return fmt.Errorf("error setting permissions on temporary file: %w", err)
	}
	// Atomically replace the original file with the temporary file.
	if err := os.Rename(tempFilePath, hostsFilePath); err != nil {
		_ = os.Remove(tempFilePath) // best-effort cleanup
		return fmt.Errorf("error renaming temporary file: %w", err)
	}
	fmt.Println("Successfully removed the specified host entries.")
	return nil
}
// selectorSaved reports whether updateService has already redirected the
// Service and stored its original selector in oldSelector. It is needed on top
// of oldSelector because the original selector may legitimately be nil
// (a Service without a selector).
var selectorSaved bool

// updateService toggles the Service's selector between its original value and
// app=phasing. The first successful call backs up the current selector into
// oldSelector, records the Service's first port in phasing.Port and points the
// Service at app=phasing. Every later call restores the backed-up selector.
// This is what lets main make the call once to start intercepting and defer the
// same call to undo it.
//
// The direction of the toggle is decided before the update is attempted and
// the backup is committed only after success, so a conflict retry can never
// turn a patch into a restore. Errors are logged and returned rather than
// fatal, because this also runs as deferred cleanup.
func updateService(namespace, serviceName, kubeconfigPath string) (err error) {
	config, err := clientcmd.BuildConfigFromFlags("", kubeconfigPath)
	if err != nil {
		log.Printf("Error loading kubeconfig: %v", err)
		return err
	}
	clientset, err := kubernetes.NewForConfig(config)
	if err != nil {
		log.Printf("Error creating Kubernetes clientset: %v", err)
		return err
	}

	newSelector := map[string]string{"app": "phasing"}
	restoring := selectorSaved
	services := clientset.CoreV1().Services(namespace)

	// Values captured by the attempt that finally succeeds.
	var backup map[string]string
	var servicePort int

	retryErr := retry.RetryOnConflict(retry.DefaultRetry, func() error {
		service, getErr := services.Get(context.TODO(), serviceName, metav1.GetOptions{})
		if getErr != nil {
			return getErr
		}
		if restoring {
			service.Spec.Selector = oldSelector
		} else {
			if len(service.Spec.Ports) == 0 {
				return fmt.Errorf("service %s in namespace %s exposes no ports", serviceName, namespace)
			}
			servicePort = int(service.Spec.Ports[0].Port)
			backup = service.Spec.Selector
			service.Spec.Selector = newSelector
		}
		_, updateErr := services.Update(context.TODO(), service, metav1.UpdateOptions{})
		return updateErr
	})
	if retryErr != nil {
		log.Printf("Error patching Service: %v", retryErr)
		return retryErr
	}

	if !restoring {
		// Commit the backup only now that the Service really was redirected.
		oldSelector = backup
		selectorSaved = true
		phasing.Port = servicePort
	}
	fmt.Printf("Service %s in namespace %s patched successfully.\r\n", serviceName, namespace)
	return nil
}
func forwardServices(namespace, kubeconfigPath string) (ips []string, err error) {
// Load the kubeconfig from the specified file
config, err := clientcmd.BuildConfigFromFlags("", kubeconfigPath)
if err != nil {
log.Fatalf("Error loading kubeconfig: %v", err)
}
// Create a Kubernetes clientset
clientset, err := kubernetes.NewForConfig(config)
if err != nil {
log.Fatalf("Error creating Kubernetes clientset: %v", err)
}
services, err := clientset.CoreV1().Services(namespace).List(context.TODO(), metav1.ListOptions{})
if err != nil {
panic(err.Error())
}
hosts := make(map[string][]string)
for i, svc := range services.Items {
if len(svc.Spec.Ports) == 0 {
// log.Println("Service has no ports! ", svc.Name)
continue
}
what := "svc/" + svc.Name
// use service port by default
port := svc.Spec.Ports[0].Port
// sometimes we cannot forward ports to service, because service is missing a selector
// and example of this kind of service is zalando postgres operator master service
// we need to get the service endpoint ip and find the matching pod and forward to pod directly while faking dns for the service
if svc.Spec.Selector == nil {
// log.Println("WARNING: Cannot forward ports to services without selector for now :-( bummer! Service: ", svc.Name)
// Get the endpoints associated with the service
endpoints, err := clientset.CoreV1().Endpoints(namespace).Get(context.TODO(), svc.Name, metav1.GetOptions{})
if err != nil {
log.Fatalf("Error getting endpoints: %v", err)
}
if len(endpoints.Subsets) == 0 {
log.Fatalf("No endpoints nor selectors found for service: %s", svc.Name)
continue
}
// Extract the pods from the endpoints
for _, subset := range endpoints.Subsets {
for _, address := range subset.Addresses {
// Get the pod name from the endpoint addresses
if address.TargetRef != nil && address.TargetRef.Kind == "Pod" {
podName := address.TargetRef.Name
// log.Printf("Found pod: %s\n", podName)
what = "pod/" + podName
port = svc.Spec.Ports[0].TargetPort.IntVal
}
}
}
}
go RawPortForward(namespace, what, strconv.Itoa(int(port)), "127.1.1."+strconv.Itoa(i))
hosts["127.1.1."+strconv.Itoa(i)] = []string{svc.Name, svc.Name + "." + namespace,
svc.Name + "." + namespace + ".svc.cluster.local",
svc.Name + "." + namespace + ".svc",
svc.Name + "." + namespace + ".svc.cluster",
svc.Name + ".svc"}
log.Printf("Service %s will be available locally as %s:%s\r\n", svc.Name, svc.Name, strconv.Itoa(int(port)))
}
AddHostEntries(hosts)
// Get the IPs (keys) from the entries map
// is thre really no way to get keys from a map in go?
for ip := range hosts {
ips = append(ips, ip)
}
return ips, nil
}
// selectService lists the Services in namespace and lets the user pick one
// with an interactive prompt. It returns the chosen name. It returns an error,
// instead of exiting, if the cluster cannot be queried, the namespace has no
// Services, or the prompt is aborted.
func selectService(namespace, kubeconfigPath string) (serviceName string, err error) {
	config, err := clientcmd.BuildConfigFromFlags("", kubeconfigPath)
	if err != nil {
		return "", fmt.Errorf("loading kubeconfig: %w", err)
	}
	clientset, err := kubernetes.NewForConfig(config)
	if err != nil {
		return "", fmt.Errorf("creating Kubernetes clientset: %w", err)
	}
	services, err := clientset.CoreV1().Services(namespace).List(context.TODO(), metav1.ListOptions{})
	if err != nil {
		return "", fmt.Errorf("listing services in namespace %q: %w", namespace, err)
	}
	if len(services.Items) == 0 {
		return "", fmt.Errorf("no services found in namespace %q", namespace)
	}

	serviceNames := make([]string, 0, len(services.Items))
	for i := range services.Items {
		serviceNames = append(serviceNames, services.Items[i].Name)
	}

	prompt := promptui.Select{
		Label: "Select service to forward",
		Items: serviceNames,
		Size:  20,
	}
	idx, _, err := prompt.Run()
	if err != nil {
		return "", err
	}
	return serviceNames[idx], nil
}
// getCurrentNamespace returns the namespace of the kubeconfig's current
// context. It returns "default" together with the load error when the file
// cannot be read. It also returns "default" (with a nil error) when the
// current context is missing from the file or sets no namespace, so callers
// never get an empty namespace.
func getCurrentNamespace(kubeconfigPath string) (string, error) {
	config, err := clientcmd.LoadFromFile(kubeconfigPath)
	if err != nil {
		return "default", err
	}
	// Named kubeContext to avoid shadowing the imported context package.
	kubeContext, ok := config.Contexts[config.CurrentContext]
	if !ok || kubeContext == nil || kubeContext.Namespace == "" {
		return "default", nil
	}
	return kubeContext.Namespace, nil
}
// tunnelTraffic copies bytes in both directions between remote and local until
// either direction stops, then closes both connections. If either connection
// is nil it returns immediately and closes nothing.
func tunnelTraffic(remote, local net.Conn) {
	if remote == nil || local == nil {
		return
	}
	defer remote.Close()
	defer local.Close()

	// Room for both copiers: only the first result is awaited. The second
	// copier must still be able to send after we have returned; with an
	// unbuffered channel it would block forever and leak.
	done := make(chan struct{}, 2)
	pipe := func(dst, src net.Conn, direction string) {
		if _, err := io.Copy(dst, src); err != nil {
			log.Printf("error while copy %s: %s", direction, err)
		}
		done <- struct{}{}
	}
	go pipe(remote, local, "local->remote")
	go pipe(local, remote, "remote->local")
	<-done
}
// sshKeyFile loads a private key from file and returns it as an SSH public-key
// auth method. The returned error names the file and whether reading or
// parsing failed, and wraps the underlying cause. Reporting the error is left
// to the caller.
func sshKeyFile(file string) (ssh.AuthMethod, error) {
	buffer, err := os.ReadFile(file)
	if err != nil {
		return nil, fmt.Errorf("cannot read SSH key file %s: %w", file, err)
	}
	key, err := ssh.ParsePrivateKey(buffer)
	if err != nil {
		return nil, fmt.Errorf("cannot parse SSH key file %s: %w", file, err)
	}
	return ssh.PublicKeys(key), nil
}
// RawPortForward runs `kubectl port-forward` for the resource what (for
// example "svc/name" or "pod/name") and blocks until kubectl exits. The
// arguments behave as follows:
//   - port is passed through unchanged, e.g. "5432" or "5432:5432".
//   - address is the local bind address and defaults to 0.0.0.0.
//   - namespace is left off the command line when empty.
//
// kubectl's stderr is shown to the user, and any failure is logged and
// returned.
func RawPortForward(namespace string, what string, port string, address string) (err error) {
	cmd := exec.Command("kubectl", portForwardArgs(namespace, what, port, address)...)
	cmd.Stderr = os.Stderr
	// Run both starts and waits. A second Wait would fail, and killing a
	// process that never started (cmd.Process == nil) would panic.
	if err := cmd.Run(); err != nil {
		log.Println("Error forwarding ports for:", what, err)
		return err
	}
	return nil
}

// portForwardArgs builds the kubectl argument list used by RawPortForward.
func portForwardArgs(namespace, what, port, address string) []string {
	if address == "" {
		address = "0.0.0.0"
	}
	args := []string{"port-forward"}
	// Only add the flag when set: an empty string would otherwise reach
	// kubectl as a stray positional argument.
	if namespace != "" {
		args = append(args, "--namespace="+namespace)
	}
	return append(args, "--address="+address, what, port)
}
// agentForwardPattern matches the line kubectl port-forward prints once the
// local end of the agent tunnel is listening; group 1 is the local port.
var agentForwardPattern = regexp.MustCompile(`Forwarding from 0\.0\.0\.0:(\d+) -> 22`)

// PortForward runs `kubectl port-forward` from a local port to port 22 of
// deployment/phasing and blocks until kubectl exits. The deployment is looked
// up in phasing.Namespace, or in kubectl's current namespace when that is
// empty. phasing.AgentLocalPort is passed as the local port; with 0 the local
// port is selected by the OS.
//
// kubectl's stdout is read through a pipe because the command is long
// running. Once kubectl reports the port it bound, that port is stored in
// phasing.AgentLocalPort for main to pick up. If kubectl's output ends before
// a port is reported, the process exits via log.Fatalln, since nothing else
// would ever set the port.
//
// TODO: rewrite this to use k8s API instead of kubectl binary
func PortForward() (err error) {
	args := []string{"port-forward"}
	// Keep the agent in the same namespace as the Service being patched.
	if phasing.Namespace != "" {
		args = append(args, "--namespace="+phasing.Namespace)
	}
	args = append(args, "--address=0.0.0.0", "deployment/phasing", strconv.Itoa(phasing.AgentLocalPort)+":22")

	cmd := exec.Command("kubectl", args...)
	cmd.Stderr = os.Stderr
	stdout, err := cmd.StdoutPipe()
	if err != nil {
		return err
	}
	if err := cmd.Start(); err != nil {
		return fmt.Errorf("error connecting to agent: %w", err)
	}

	reader := bufio.NewReader(stdout)
	for {
		line, readErr := reader.ReadString('\n')
		// Check the line before the error: the final line may arrive with io.EOF.
		if found := agentForwardPattern.FindStringSubmatch(line); found != nil {
			if port, convErr := strconv.Atoi(found[1]); convErr == nil {
				phasing.AgentLocalPort = port
				break
			}
		}
		if readErr != nil {
			log.Fatalln("Error reading agent port from kubectl:", readErr)
		}
	}

	// kubectl may keep writing to stdout after the forwarding line. Keep
	// draining so it never blocks on a full pipe. All reads must finish
	// before Wait, which closes the pipe.
	_, _ = io.Copy(io.Discard, reader)
	if err := cmd.Wait(); err != nil {
		return fmt.Errorf("error connecting to agent: %w", err)
	}
	return nil
}
// Init prepares the cluster and this machine for phasing. It takes three steps:
//  1. Write the embedded scripts.InitScript and scripts.PhasingYAML to
//     temporary files.
//  2. Run the script with sh, passing the YAML file's path as its only
//     argument.
//  3. Add the loopback addresses 127.1.1.0 through 127.1.1.254 that
//     forwardServices binds to.
//
// The address step supports only darwin and linux. Other operating systems get
// an error.
//
// TODO: rewrite this to use k8s API instead of kubectl binary
func Init() (err error) {
	scriptPath, err := writeTempFile([]byte(scripts.InitScript))
	if err != nil {
		return fmt.Errorf("error initializing phasing: writing init script: %w", err)
	}
	defer os.Remove(scriptPath)

	// and now write a yaml file, omg, not very proud of this
	manifestPath, err := writeTempFile([]byte(scripts.PhasingYAML))
	if err != nil {
		return fmt.Errorf("error initializing phasing: writing manifest: %w", err)
	}
	defer os.Remove(manifestPath)

	// The script is run through sh, so it does not need the executable bit.
	cmd := exec.Command("sh", scriptPath, manifestPath)
	cmd.Stdout = os.Stdout
	cmd.Stderr = os.Stderr
	if err := cmd.Run(); err != nil {
		return fmt.Errorf("error initializing phasing: %w", err)
	}

	// Reject other systems up front. Without this check the loop below would
	// re-run the already finished sh command.
	if runtime.GOOS != "darwin" && runtime.GOOS != "linux" {
		return fmt.Errorf("error adding required ip addresses: unsupported OS %q", runtime.GOOS)
	}
	for i := 0; i < 255; i++ {
		cidr := "127.1.1." + strconv.Itoa(i) + "/8"
		var alias *exec.Cmd
		if runtime.GOOS == "darwin" {
			alias = exec.Command("ifconfig", "lo0", "alias", cidr, "up")
		} else {
			alias = exec.Command("ip", "a", "a", cidr, "dev", "lo")
		}
		if err := alias.Run(); err != nil {
			return fmt.Errorf("error adding required ip addresses: %w", err)
		}
	}
	return nil
}

// writeTempFile stores content in a new temporary file and returns its path.
// The file is closed before returning; removing it is the caller's job.
func writeTempFile(content []byte) (string, error) {
	f, err := os.CreateTemp("", "phasing")
	if err != nil {
		return "", err
	}
	if _, err := f.Write(content); err != nil {
		f.Close()
		_ = os.Remove(f.Name()) // best-effort cleanup
		return "", err
	}
	if err := f.Close(); err != nil {
		_ = os.Remove(f.Name()) // best-effort cleanup
		return "", err
	}
	return f.Name(), nil
}
// checkRoot terminates the process with exit status 1 unless it runs with an
// effective UID of 0. Root is required because phasing edits /etc/hosts and
// adds loopback address aliases.
func checkRoot() {
	if isRoot := os.Geteuid() == 0; isRoot {
		return
	}
	fmt.Println("Please run this program as root. It manipulates network interfaces and dns settings.")
	os.Exit(1)
}
// main wires a phasing session together:
//  1. Parse flags and positional args (service name, local port, namespace,
//     kubeconfig, -init).
//  2. Start the kubectl port-forward to the agent and wait for its local port.
//  3. If no service was named, let the user pick one.
//  4. Forward every Service in the namespace to a local 127.1.1.x address.
//  5. Point the chosen Service at the agent. This is undone on return and on
//     SIGINT/SIGTERM.
//  6. Open an SSH listener on the agent for the Service port and tunnel each
//     accepted connection to the local endpoint.
//
// After the Service has been patched, every failure path returns instead of
// calling log.Fatal, so the deferred restore of the Service and of /etc/hosts
// always runs.
func main() {
	checkRoot()

	kubeconfigPath := filepath.Join(os.Getenv("HOME"), ".kube", "config")
	currentNamespace, err := getCurrentNamespace(kubeconfigPath)
	if err != nil || currentNamespace == "" {
		currentNamespace = "default"
	}

	var runInit bool
	flag.StringVar(&phasing.ServiceName, "service", "phasing", "Service name")
	flag.IntVar(&phasing.LocalPort, "port", 7777, "Local port to forward remote service to")
	flag.StringVar(&phasing.Namespace, "namespace", currentNamespace, "Namespace name")
	flag.StringVar(&kubeconfigPath, "kubeconfig", kubeconfigPath, "Path to kube .config file")
	flag.BoolVar(&runInit, "init", false, "Run Phasing initialization")
	flag.Parse()

	// The namespace default came from the default kubeconfig. If another
	// kubeconfig was chosen and -namespace was not given, use that file's.
	explicit := map[string]bool{}
	flag.Visit(func(f *flag.Flag) { explicit[f.Name] = true })
	if explicit["kubeconfig"] && !explicit["namespace"] {
		if ns, nsErr := getCurrentNamespace(kubeconfigPath); nsErr == nil && ns != "" {
			phasing.Namespace = ns
		} else {
			phasing.Namespace = "default"
		}
	}

	args := flag.Args()
	if len(args) > 0 {
		phasing.ServiceName = args[0]
	}
	if len(args) > 1 {
		phasing.LocalPort, err = strconv.Atoi(args[1])
		if err != nil {
			log.Println("Please provide valid local port number", err)
			return
		}
	}

	if runInit {
		if err := Init(); err != nil {
			log.Println("Could not initialize phasing", err)
		}
		return
	}

	// start port forwarding to agent in the background
	agentErr := make(chan error, 1)
	go func() { agentErr <- PortForward() }()
	if err := waitForAgentPort(agentErr); err != nil {
		log.Printf("Cannot forward to the phasing agent: %v\r\n", err)
		return
	}

	if phasing.ServiceName == "phasing" {
		phasing.ServiceName, err = selectService(phasing.Namespace, kubeconfigPath)
		if err != nil {
			log.Println("No service selected:", err)
			return
		}
	}

	// forward ports to all services in the namespace
	ips, err := forwardServices(phasing.Namespace, kubeconfigPath)
	if err != nil {
		// Not fatal: intercepting the selected Service still works without them.
		log.Printf("Could not forward namespace services: %v\r\n", err)
	}
	defer RemoveHostEntries(ips)

	// patched is closed once the Service has really been redirected. Before
	// that, calling updateService from the signal handler would patch the
	// Service instead of restoring it.
	patched := make(chan struct{})
	signalCh := make(chan os.Signal, 1)
	signal.Notify(signalCh, syscall.SIGINT, syscall.SIGTERM)
	go func() {
		<-signalCh
		log.Println("Signal received. Exiting...")
		// os.Exit skips main's defers, so undo everything here.
		select {
		case <-patched:
			updateService(phasing.Namespace, phasing.ServiceName, kubeconfigPath) // to restore old service config
		default:
		}
		RemoveHostEntries(ips)
		os.Exit(0)
	}()

	err = updateService(phasing.Namespace, phasing.ServiceName, kubeconfigPath)
	if err != nil {
		log.Printf("Could not update service %s in %s\r\n", phasing.ServiceName, phasing.Namespace)
		return
	}
	close(patched)
	defer updateService(phasing.Namespace, phasing.ServiceName, kubeconfigPath) // to restore old service config

	fmt.Printf("Starting phasing\r\nRemote endpoint is %s.%s:%d\r\nLocal endpoint for intercepting %s is localhost:%d\r\n",
		phasing.ServiceName, phasing.Namespace, phasing.Port, phasing.ServiceName, phasing.LocalPort)

	sshKey, err := sshKeyFile(filepath.Join(os.Getenv("HOME"), ".ssh", "phasing_key"))
	if err != nil {
		log.Printf("Cannot load SSH key: %s\r\n", err)
		return
	}
	sshConfig := &ssh.ClientConfig{
		User: "root",
		Auth: []ssh.AuthMethod{sshKey},
		// NOTE(review): the host key is not verified; the agent is reached
		// only through the local kubectl port-forward.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
	}

	// Connect to remote agent server using agentEndpoint
	serverConn, err := ssh.Dial("tcp", phasing.AgentEndpoint(), sshConfig)
	if err != nil {
		log.Printf("Cannot connect to the remote agent. Make sure that Phasing has been initialized properly. %s \r\n", err)
		return
	}
	defer serverConn.Close()

	// Listen on remote server port
	listener, err := serverConn.Listen("tcp", phasing.RemoteEndpoint())
	if err != nil {
		log.Printf("Listen open port ON remote server error: %s\r\n", err)
		return
	}
	defer listener.Close()

	serveTunnel(listener)
}

// waitForAgentPort blocks until PortForward has stored the agent's local port
// in phasing.AgentLocalPort, polling instead of spinning a CPU core. It returns
// an error if PortForward finishes first, as reported through agentErr.
//
// NOTE(review): AgentLocalPort is written by another goroutine without
// synchronization. Sleep-polling works in practice, but it is still a data
// race under the Go memory model.
func waitForAgentPort(agentErr <-chan error) error {
	ticker := time.NewTicker(50 * time.Millisecond)
	defer ticker.Stop()
	for phasing.AgentLocalPort == 0 {
		select {
		case err := <-agentErr:
			if err == nil {
				err = fmt.Errorf("kubectl port-forward exited before reporting a port")
			}
			return err
		case <-ticker.C:
		}
	}
	return nil
}

// serveTunnel accepts connections arriving at the agent-side listener and
// pipes each one to phasing.LocalEndpoint(). It returns when Accept fails,
// e.g. because the listener or SSH connection was closed.
func serveTunnel(listener net.Listener) {
	for {
		remote, err := listener.Accept()
		if err != nil {
			if err != io.EOF {
				log.Printf("Cannot setup remote endpoint. Make sure that Phasing has been initialized properly. %s \r\n", err)
			}
			return
		}
		local, err := net.Dial("tcp", phasing.LocalEndpoint())
		if err != nil {
			log.Println("Dial INTO local service error: ", err)
			// close remote connection given we couldn't fully establish the tunnel
			remote.Close()
			continue
		}
		go tunnelTraffic(remote, local)
	}
}