Skip to content

Commit

Permalink
Allow to run.attach()
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergey Mezentsev committed May 9, 2024
1 parent ce8b5a5 commit 8dba74e
Show file tree
Hide file tree
Showing 20 changed files with 387 additions and 43 deletions.
2 changes: 1 addition & 1 deletion runner/internal/runner/api/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ func (ds DummyRunner) GetState() (shim.RunnerStatus, shim.ContainerStatus, strin
return ds.State, ds.ContainerStatus, "", ds.JobResult
}

func (ds DummyRunner) Run(context.Context, shim.DockerImageConfig) error {
func (ds DummyRunner) Run(context.Context, shim.TaskConfig) error {
return nil
}

Expand Down
8 changes: 4 additions & 4 deletions runner/internal/shim/api/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,18 @@ func (s *ShimServer) SubmitPostHandler(w http.ResponseWriter, r *http.Request) (
return nil, &api.Error{Status: http.StatusConflict}
}

var body DockerTaskBody
var body TaskConfigBody
if err := api.DecodeJSONBody(w, r, &body, true); err != nil {
log.Println("Failed to decode submit body", "err", err)
return nil, err
}

go func(taskParams shim.DockerImageConfig) {
err := s.runner.Run(context.Background(), taskParams)
go func(taskConfig shim.TaskConfig) {
err := s.runner.Run(context.Background(), taskConfig)
if err != nil {
fmt.Printf("failed Run %v\n", err)
}
}(body.GetDockerImageConfig())
}(body.GetTaskConfig())

return nil, nil
}
Expand Down
10 changes: 7 additions & 3 deletions runner/internal/shim/api/schemas.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,15 @@ package api

import "github.com/dstackai/dstack/runner/internal/shim"

type DockerTaskBody struct {
type TaskConfigBody struct {
Username string `json:"username"`
Password string `json:"password"`
ImageName string `json:"image_name"`
ContainerName string `json:"container_name"`
ShmSize int64 `json:"shm_size"`
PublicKeys []string `json:"public_keys"`
SshUser string `json:"ssh_user"`
SshKey string `json:"ssh_key"`
}

type StopBody struct {
Expand Down Expand Up @@ -37,14 +39,16 @@ type StopResponse struct {
State string `json:"state"`
}

func (ra DockerTaskBody) GetDockerImageConfig() shim.DockerImageConfig {
res := shim.DockerImageConfig{
func (ra TaskConfigBody) GetTaskConfig() shim.TaskConfig {
res := shim.TaskConfig{
ImageName: ra.ImageName,
Username: ra.Username,
Password: ra.Password,
ContainerName: ra.ContainerName,
ShmSize: ra.ShmSize,
PublicKeys: ra.PublicKeys,
SshUser: ra.SshUser,
SshKey: ra.SshKey,
}
return res
}
2 changes: 1 addition & 1 deletion runner/internal/shim/api/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
)

type TaskRunner interface {
Run(context.Context, shim.DockerImageConfig) error
Run(context.Context, shim.TaskConfig) error
GetState() (shim.RunnerStatus, shim.ContainerStatus, string, shim.JobResult)
Stop(bool)
}
Expand Down
141 changes: 141 additions & 0 deletions runner/internal/shim/authorized_keys.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package shim

import (
"bufio"
"fmt"
"io"
"os"
"slices"

"github.com/ztrue/tracerr"
"golang.org/x/crypto/ssh"
)

func PublicKeyFingerprint(key string) (string, error) {
pk, _, _, _, err := ssh.ParseAuthorizedKey([]byte(key))
if err != nil {
return "", tracerr.Wrap(err)
}
keyFingerprint := ssh.FingerprintSHA256(pk)
return keyFingerprint, nil
}

func IsPublicKeysEqual(left string, right string) bool {
leftFingerprint, err := PublicKeyFingerprint(left)
if err != nil {
return false
}

rightFingerprint, err := PublicKeyFingerprint(right)
if err != nil {
return false
}

return leftFingerprint == rightFingerprint
}

func RemovePublicKeys(fileKeys []string, keysToRemove []string) []string {
newKeys := slices.DeleteFunc(fileKeys, func(fileKey string) bool {
delete := slices.ContainsFunc(keysToRemove, func(removeKey string) bool {
return IsPublicKeysEqual(fileKey, removeKey)
})
return delete
})
return newKeys
}

func AppendPublicKeys(fileKeys []string, keysToAppend []string) []string {
newKeys := []string{}
newKeys = append(newKeys, fileKeys...)
newKeys = append(newKeys, keysToAppend...)
return newKeys
}

type AuthorizedKeys struct {
user string
rootPath string
}

func (ak AuthorizedKeys) AppendPublicKeys(publicKeys []string) error {
return ak.transformAuthorizedKeys(AppendPublicKeys, publicKeys)
}

func (ak AuthorizedKeys) RemovePublicKeys(publicKeys []string) error {
return ak.transformAuthorizedKeys(RemovePublicKeys, publicKeys)
}

func (ak AuthorizedKeys) read(r io.Reader) ([]string, error) {
lines := []string{}
scanner := bufio.NewScanner(r)
for scanner.Scan() {
text := scanner.Text()
lines = append(lines, text)
}
if err := scanner.Err(); err != nil {
return []string{}, tracerr.Wrap(err)
}
return lines, nil
}

func (ak AuthorizedKeys) write(w io.Writer, lines []string) error {
wr := bufio.NewWriter(w)
for _, line := range lines {
_, err := fmt.Fprintln(wr, line)
if err != nil {
return tracerr.Wrap(err)

}
}
return wr.Flush()
}

func (ak AuthorizedKeys) GetAuthorizedKeysPath() string {
return fmt.Sprintf("%s/home/%s/.ssh/authorized_keys", ak.rootPath, ak.user)
}

func (ak AuthorizedKeys) transformAuthorizedKeys(transform func([]string, []string) []string, publicKeys []string) error {
authorizedKeysPath := ak.GetAuthorizedKeysPath()
info, err := os.Stat(authorizedKeysPath)
if err != nil {
return tracerr.Wrap(err)
}
fileMode := info.Mode().Perm()

authorizedKeysFile, err := os.OpenFile(authorizedKeysPath, os.O_RDWR, fileMode)
if err != nil {
return tracerr.Wrap(err)
}
defer authorizedKeysFile.Close()

lines, err := ak.read(authorizedKeysFile)
if err != nil {
return tracerr.Wrap(err)
}

// write backup
authorizedKeysPathBackup := ak.GetAuthorizedKeysPath() + ".bak"
authorizedKeysBackup, err := os.OpenFile(authorizedKeysPathBackup, os.O_RDWR|os.O_CREATE|os.O_TRUNC, fileMode)
if err != nil {
return tracerr.Wrap(err)
}
defer authorizedKeysBackup.Close()
if err := ak.write(authorizedKeysBackup, lines); err != nil {
return tracerr.Wrap(err)
}

// transform lines
newLines := transform(lines, publicKeys)

// write authorized_keys
if err := authorizedKeysFile.Truncate(0); err != nil {
return tracerr.Wrap(err)
}
if _, err := authorizedKeysFile.Seek(0, 0); err != nil {
return tracerr.Wrap(err)
}
if err := ak.write(authorizedKeysFile, newLines); err != nil {
return tracerr.Wrap(err)
}

return nil
}
Loading

0 comments on commit 8dba74e

Please sign in to comment.