diff --git a/plugin/yaml/yaml.go b/plugin/yaml/yaml.go index 83b093a8c..fc21141d8 100644 --- a/plugin/yaml/yaml.go +++ b/plugin/yaml/yaml.go @@ -10,6 +10,8 @@ import ( "path/filepath" "regexp" "time" + "slices" + "os/user" "github.com/patrickmn/go-cache" "github.com/tg123/sshpiper/libplugin" @@ -18,7 +20,8 @@ import ( ) type pipeConfigFrom struct { - Username string `yaml:"username"` + Username string `yaml:"username,omitempty"` + Groupname string `yaml:"groupname,omitempty"` UsernameRegexMatch bool `yaml:"username_regex_match,omitempty"` AuthorizedKeys string `yaml:"authorized_keys,omitempty"` AuthorizedKeysData string `yaml:"authorized_keys_data,omitempty"` @@ -220,20 +223,30 @@ func (p *plugin) createUpstream(conn libplugin.ConnMetadata, to pipeConfigTo, or func (p *plugin) findAndCreateUpstream(conn libplugin.ConnMetadata, password string, publicKey []byte) (*libplugin.Upstream, error) { user := conn.User() + userGroups, err := getUserGroups(user) + + if err != nil { + return nil, err + } config, err := p.loadConfig() + if err != nil { return nil, err } for _, pipe := range config.Pipes { for _, from := range pipe.From { - matched := from.Username == user - - if from.UsernameRegexMatch { - matched, _ = regexp.MatchString(from.Username, user) + var matched bool + if from.Username != "" { + matched = from.Username == user + if from.UsernameRegexMatch { + matched, _ = regexp.MatchString(from.Username, user) + } + } else { + fromPipeGroup := from.Groupname + matched = slices.Contains(userGroups, fromPipeGroup) } - if !matched { continue } @@ -265,3 +278,27 @@ func (p *plugin) findAndCreateUpstream(conn libplugin.ConnMetadata, password str return nil, fmt.Errorf("no matching pipe for username [%v] found", user) } + +func getUserGroups(userName string) ([]string, error) { + usr, err := user.Lookup(userName) + if err != nil { + return nil, err + } + + groupIds, err := usr.GroupIds() + if err != nil { + return nil, err + } + + var groups []string + for _, groupId := range groupIds { + grp, err := user.LookupGroupId(groupId) + if err != nil { + return nil, err + } + groups = append(groups, grp.Name) + } + + return groups, nil +} +