Skip to content

Commit

Permalink
fix: add origin checking
Browse files Browse the repository at this point in the history
By design, the Same-Origin Policy is not applied to WebSockets, so the browser allows requests from any origin to rtty's WebSockets.
Fixed a vulnerability where arbitrary commands could be executed if a malicious website was opened on the browser while rtty was running.
  • Loading branch information
skanehira committed Aug 15, 2024
1 parent 9b8dedf commit 2ce6901
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 20 deletions.
20 changes: 11 additions & 9 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ Terminal on browser via websocket
```sh
# Run server
$ rtty run zsh -p 8080 -v --font "Cica Regular" --font-size 20
2021/05/10 14:08:11 running command: zsh
2021/05/10 14:08:11 running http://localhost:8080
2024/08/15 23:39:37 allowed origins [localhost:8080]
2024/08/15 23:39:37 running command: zsh
2024/08/15 23:39:37 running http://localhost:8080
# Help
$ rtty run -h
Expand All @@ -36,15 +37,16 @@ Usage:
rtty run [command] [flags]
Command
Execute specified command (default $SHELL)
Execute specified command (default "/bin/zsh")
Flags:
-a, --addr string server address
--font string font
--font-size string font size
-h, --help help for run
-p, --port int server port (default 9999)
-v, --view open browser
-a, --addr string server address (default "localhost")
--allow-origin stringArray allow origin (default ["localhost:9999"])
--font string font
--font-size string font size
-h, --help help for run
-p, --port int server port (default 9999)
-v, --view open browser
```

## Author
Expand Down
50 changes: 39 additions & 11 deletions cmd/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,28 @@ func (ws *wsConn) Write(b []byte) (i int, err error) {
return n, e
}

func checkOrigin(allowOrigins []string) func(config *websocket.Config, req *http.Request) error {
return func(config *websocket.Config, req *http.Request) (err error) {
config.Origin, err = websocket.Origin(config, req)
if err == nil && config.Origin == nil {
return fmt.Errorf("null origin")
}
if err != nil {
return err
}

for _, allowOrigin := range allowOrigins {
if config.Origin.Host == allowOrigin {
return nil
}
}

msg := fmt.Sprintf("not allowed origin: %s", config.Origin.Host)
log.Printf(msg)
return fmt.Errorf(msg)
}
}

var runCmd = &cobra.Command{
Use: "run",
Short: "Run command",
Expand Down Expand Up @@ -210,9 +232,6 @@ var runCmd = &cobra.Command{
log.Println(err)
return
}
if addr == "" {
addr = "localhost"
}

indexJS = strings.Replace(indexJS, "{addr}", template.JSEscapeString(addr), 1)
indexJS = strings.Replace(indexJS, "{port}", port, 1)
Expand All @@ -234,14 +253,21 @@ var runCmd = &cobra.Command{
mux.HandleFunc("/index.js", func(w http.ResponseWriter, r *http.Request) {
_, _ = w.Write([]byte(indexJS))
})
mux.Handle("/ws", websocket.Handler(run))

allowedOrigins, err := cmd.PersistentFlags().GetStringArray("allow-origin")
allowedOrigins = append(allowedOrigins, "localhost:"+port)
if err != nil {
log.Println(err)
}
mux.Handle("/ws", websocket.Server{Handler: run, Handshake: checkOrigin(allowedOrigins)})

server := &http.Server{
Addr: addr + ":" + port,
Handler: mux,
}

go func() {
log.Println("allowed origins", allowedOrigins)
log.Println("running command: " + command)
log.Printf("running http://%s:%s\n", addr, port)

Expand Down Expand Up @@ -301,10 +327,11 @@ var runCmd = &cobra.Command{

func init() {
runCmd.PersistentFlags().IntP("port", "p", 9999, "server port")
runCmd.PersistentFlags().StringP("addr", "a", "", "server address")
runCmd.PersistentFlags().StringP("addr", "a", "localhost", "server address")
runCmd.PersistentFlags().String("font", "", "font")
runCmd.PersistentFlags().String("font-size", "", "font size")
runCmd.PersistentFlags().BoolP("view", "v", false, "open browser")
runCmd.PersistentFlags().StringArray("allow-origin", []string{}, "allow origin")
runCmd.SetHelpFunc(func(cmd *cobra.Command, args []string) {
fmt.Printf(`Run command
Expand All @@ -315,12 +342,13 @@ Command
Execute specified command (default "%s")
Flags:
-a, --addr string server address
--font string font
--font-size string font size
-h, --help help for run
-p, --port int server port (default 9999)
-v, --view open browser
-a, --addr string server address (default "localhost")
--allow-origin stringArray allow origin (default ["localhost:9999"])
--font string font
--font-size string font size
-h, --help help for run
-p, --port int server port (default 9999)
-v, --view open browser
`, command)
})
rootCmd.AddCommand(runCmd)
Expand Down

0 comments on commit 2ce6901

Please sign in to comment.