diff --git a/cmd/sshproxyctl/sshproxyctl.go b/cmd/sshproxyctl/sshproxyctl.go index 95343ad8..09f4125f 100644 --- a/cmd/sshproxyctl/sshproxyctl.go +++ b/cmd/sshproxyctl/sshproxyctl.go @@ -309,11 +309,15 @@ func (fu flatUsers) getAllUsers(allFlag bool, passthrough bool) ([][]string, map totalOut := 0 for i, v := range fu { + groups := v.Groups + if !passthrough && groups == "" { + groups = "\u274C" + } if allFlag { rows[i] = []string{ v.User, v.Service, - v.Groups, + groups, fmt.Sprintf("%d", v.N), byteToHuman(v.BwIn, passthrough), byteToHuman(v.BwOut, passthrough), @@ -323,7 +327,7 @@ func (fu flatUsers) getAllUsers(allFlag bool, passthrough bool) ([][]string, map } else { rows[i] = []string{ v.User, - v.Groups, + groups, fmt.Sprintf("%d", v.N), byteToHuman(v.BwIn, passthrough), byteToHuman(v.BwOut, passthrough), @@ -422,9 +426,13 @@ type flatGroups []*utils.FlatGroup func (fg flatGroups) getAllGroups(allFlag bool, passthrough bool) [][]string { rows := make([][]string, len(fg)) for i, v := range fg { + group := v.Group + if !passthrough && group == "" { + group = "\u274C" + } if allFlag { rows[i] = []string{ - v.Group, + group, v.Service, v.Users, fmt.Sprintf("%d", v.N), @@ -433,7 +441,7 @@ func (fg flatGroups) getAllGroups(allFlag bool, passthrough bool) [][]string { } } else { rows[i] = []string{ - v.Group, + group, v.Users, fmt.Sprintf("%d", v.N), byteToHuman(v.BwIn, passthrough), diff --git a/pkg/utils/etcd.go b/pkg/utils/etcd.go index 9229dda1..3edce185 100644 --- a/pkg/utils/etcd.go +++ b/pkg/utils/etcd.go @@ -744,16 +744,7 @@ func (c *Client) GetAllUsers(allFlag bool) ([]*FlatUser, error) { } if users[key] == nil { v := &FlatUser{} - groups, err := GetGroupList(connection.User) - if err != nil { - return nil, err - } - g := make([]string, 0, len(groups)) - for group := range groups { - g = append(g, group) - } - sort.Strings(g) - v.Groups = strings.Join(g, " ") + v.Groups = GetSortedGroups(connection.User) v.N = 1 v.BwIn = connection.BwIn v.BwOut = connection.BwOut @@ -774,16 +765,7 @@ func (c *Client) GetAllUsers(allFlag bool) ([]*FlatUser, error) { key := hist.User if users[key] == nil { v := &FlatUser{} - groups, err := GetGroupList(strings.Split(hist.User, "@")[0]) - if err != nil { - return nil, err - } - g := make([]string, 0, len(groups)) - for group := range groups { - g = append(g, group) - } - sort.Strings(g) - v.Groups = strings.Join(g, " ") + v.Groups = GetSortedGroups(strings.Split(hist.User, "@")[0]) v.Dest = hist.Dest v.TTL = hist.TTL users[key] = v diff --git a/pkg/utils/utils.go b/pkg/utils/utils.go index a8da2199..a91de072 100644 --- a/pkg/utils/utils.go +++ b/pkg/utils/utils.go @@ -14,8 +14,11 @@ import ( "crypto/sha1" "fmt" "net" + "os" "os/user" + "sort" "strconv" + "strings" "time" ) @@ -111,6 +114,26 @@ func GetGroupList(username string) (map[string]bool, error) { return groups, nil } +// GetSortedGroups returns a string of sorted space-separated groups for the +// specified user. +// +// It displays a warning when a user has no group (happens when a user has been +// deleted, but still has an open connection +func GetSortedGroups(username string) string { + groups, err := GetGroupList(username) + if err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + return "" + } else { + g := make([]string, 0, len(groups)) + for group := range groups { + g = append(g, group) + } + sort.Strings(g) + return strings.Join(g, " ") + } +} + // Mocking net.LookupHost for testing. var netLookupHost = net.LookupHost diff --git a/pkg/utils/utils_test.go b/pkg/utils/utils_test.go index 59578b85..865585a8 100644 --- a/pkg/utils/utils_test.go +++ b/pkg/utils/utils_test.go @@ -11,8 +11,11 @@ package utils import ( + "bytes" "errors" "fmt" + "io" + "os" "os/user" "reflect" "sort" @@ -174,7 +177,7 @@ var getGroupListTests = []struct { {"testuser", "testgroup", ""}, {"userwithnogroupid", "", "user: list groups for userwithnogroupid: invalid gid \"\""}, {"userwithinvalidgroup", "nonexistentgroup", "group: unknown group ID 1002"}, - {"nonexistentuser", "nonexistentgroup", "user: unknown user nonexistentuser"}, + {"nonexistentuser", "", "user: unknown user nonexistentuser"}, } func TestGetGroupList(t *testing.T) { @@ -211,6 +214,57 @@ func BenchmarkGetGroupList(b *testing.B) { } } +var getSortedGroupsTests = []struct { + user, groups, err string +}{ + {"root", "root", ""}, + {"testuser", "testgroup", ""}, + {"userwithnogroupid", "", "user: list groups for userwithnogroupid: invalid gid \"\"\n"}, + {"userwithinvalidgroup", "", "group: unknown group ID 1002\n"}, + {"nonexistentuser", "", "user: unknown user nonexistentuser\n"}, +} + +func TestGetSortedGroups(t *testing.T) { + userLookup = mockUserLookup + userLookupGroupId = mockUserLookupGroupId + for _, tt := range getSortedGroupsTests { + // Save the original stderr + originalStderr := os.Stderr + + // Create a new buffer and redirect stderr + var buf bytes.Buffer + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("Failed to create pipe: %v", err) + } + os.Stderr = w + + groups := GetSortedGroups(tt.user) + + // Stop writing and restore stderr + w.Close() + os.Stderr = originalStderr + io.Copy(&buf, r) + + // Verify the output + errStr := buf.String() + if errStr != tt.err { + t.Errorf("GetSortedGroups err = %s, want %s", errStr, tt.err) + } + if groups != tt.groups { + t.Errorf("GetSortedGroups groups = %s, want %s", groups, tt.groups) + } + } +} + +func BenchmarkGetSortedGroups(b *testing.B) { + b.Run("root", func(b *testing.B) { + for i := 0; i < b.N; i++ { + GetSortedGroups("root") + } + }) +} + func mockNetLookupHost(host string) ([]string, error) { if host == "err" { return nil, errors.New("LookupHost error") diff --git a/test/fedora-image/Dockerfile b/test/fedora-image/Dockerfile index 0a5437ac..0264ed3a 100644 --- a/test/fedora-image/Dockerfile +++ b/test/fedora-image/Dockerfile @@ -5,14 +5,16 @@ RUN set -ex \ && yum -y update \ && yum -y install asciidoc etcd git golang hostname iproute make openssh-server rpm-build procps -# Create fedora, user1 and user2 users ; fedora and user1 groups +# Create fedora, user1, user2 and tmpuser users ; fedora and user1 groups RUN set -ex \ && useradd fedora \ && install -d -m0755 -o fedora -g fedora /home/fedora/.ssh \ && useradd user1 \ && install -d -m0755 -o user1 -g user1 /home/user1/.ssh \ && useradd -g user1 user2 \ - && install -d -m0755 -o user2 -g user1 /home/user2/.ssh + && install -d -m0755 -o user2 -g user1 /home/user2/.ssh \ + && useradd -g user1 tmpuser \ + && install -d -m0755 -o tmpuser -g user1 /home/tmpuser/.ssh # Copy fedora public key to root authorized_keys RUN set -ex && install -d -m0700 /root/.ssh @@ -40,6 +42,10 @@ COPY --chown=user2:user1 ./ssh/id_ed25519.pub /home/user2/.ssh/authorized_keys COPY --chown=user2:user1 ./ssh/id_ed25519* ./ssh/known_hosts /home/user2/.ssh/ RUN chmod 0600 /home/user2/.ssh/id_ed25519 /home/user2/.ssh/authorized_keys +# Copy tmpuser ssh keys +COPY --chown=tmpuser:user1 ./ssh/id_ed25519.pub /home/tmpuser/.ssh/authorized_keys +COPY --chown=tmpuser:user1 ./ssh/id_ed25519* ./ssh/known_hosts /home/tmpuser/.ssh/ +RUN chmod 0600 /home/tmpuser/.ssh/id_ed25519 /home/tmpuser/.ssh/authorized_keys # Copy etcd certificates and keys COPY ./etcd/*.pem /etc/etcd/ diff --git a/test/fedora-image/sshproxy_test.go b/test/fedora-image/sshproxy_test.go index 0b81c2ec..a6da6d68 100644 --- a/test/fedora-image/sshproxy_test.go +++ b/test/fedora-image/sshproxy_test.go @@ -42,31 +42,23 @@ var ( ) func addLineSSHProxyConf(line string) { - ctx := context.Background() - for _, gateway := range gateways { - _, _, _, err := runCommand(ctx, "ssh", []string{fmt.Sprintf("root@%s", gateway), "--", fmt.Sprintf("echo \"%s\" >> %s", line, SSHPROXYCONFIG)}, nil, nil) - if err != nil { - log.Fatal(err) - } - } + runRootCommand(fmt.Sprintf("echo \"%s\" >> %s", line, SSHPROXYCONFIG)) } func removeLineSSHProxyConf(line string) { - ctx := context.Background() line = strings.ReplaceAll(line, "/", "\\/") - for _, gateway := range gateways { - _, _, _, err := runCommand(ctx, "ssh", []string{fmt.Sprintf("root@%s", gateway), "--", fmt.Sprintf("sed -i 's/^%s$//' %s", line, SSHPROXYCONFIG)}, nil, nil) - if err != nil { - log.Fatal(err) - } - } + runRootCommand(fmt.Sprintf("sed -i 's/^%s$//' %s", line, SSHPROXYCONFIG)) } func updateLineSSHProxyConf(key string, value string) { - ctx := context.Background() value = strings.ReplaceAll(value, "/", "\\/") + runRootCommand(fmt.Sprintf("sed -i '/%s:/s/: .*$/: %s/' %s", key, value, SSHPROXYCONFIG)) +} + +func runRootCommand(cmd string) { + ctx := context.Background() for _, gateway := range gateways { - _, _, _, err := runCommand(ctx, "ssh", []string{fmt.Sprintf("root@%s", gateway), "--", fmt.Sprintf("sed -i '/%s:/s/: .*$/: %s/' %s", key, value, SSHPROXYCONFIG)}, nil, nil) + _, _, _, err := runCommand(ctx, "ssh", []string{fmt.Sprintf("root@%s", gateway), "--", cmd}, nil, nil) if err != nil { log.Fatal(err) } @@ -662,6 +654,36 @@ func TestForgetPersist(t *testing.T) { } } +func TestMissingUser(t *testing.T) { + // remove old connections stored in etcd + time.Sleep(4 * time.Second) + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + args, _ := prepareCommand("tmpuser@gateway1", 2022, "sleep 20") + ch := make(chan *os.Process) + go func() { + runCommand(ctx, "ssh", args, nil, ch) + }() + process1 := <-ch + + time.Sleep(time.Second) + users, _ := getEtcdAllUsers() + if len(users) != 1 { + t.Errorf("Want 1 user, got %d", len(users)) + } else if users[0].Groups != "user1" { + t.Errorf("Want Groups=\"user1\", got \"%s\"", users[0].Groups) + } + runRootCommand("userdel -f tmpuser") + users, _ = getEtcdAllUsers() + process1.Kill() + if len(users) != 1 { + t.Errorf("Want 1 user, got %d", len(users)) + } else if users[0].Groups != "" { + t.Errorf("Want Groups=\"\", got \"%s\"", users[0].Groups) + } +} + func TestBalancedConnections(t *testing.T) { // remove old connections stored in etcd time.Sleep(4 * time.Second)