Skip to content

Commit

Permalink
$PAGER environment variable sets custom pager
Browse files Browse the repository at this point in the history
  • Loading branch information
konradreiche committed May 17, 2024
1 parent 66901a4 commit 0bec65f
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 0 deletions.
27 changes: 27 additions & 0 deletions command.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"bytes"
"flag"
"io"
"os"
"os/exec"
"strings"

Expand Down Expand Up @@ -37,9 +38,35 @@ func command(stdin io.Reader, stdout io.Writer) error {
if err != nil {
return err
}

if pager := os.ExpandEnv("$PAGER"); pager != "" {
return usePager(pager, stdout, func(w io.Writer) error {
return printDiff(w, profiles, moduleInfo)
})
}
return printDiff(stdout, profiles, moduleInfo)
}

func usePager(
name string,
stdout io.Writer,
printDiff func(io.Writer) error,
) error {
cmd := exec.Command(name)
cmd.Stdout = stdout
w, err := cmd.StdinPipe()
if err != nil {
return err
}

go func() {
defer w.Close()
printDiff(w)
}()

return cmd.Run()
}

func runGoTests() (*bytes.Buffer, error) {
cmd := exec.Command(
"go",
Expand Down
27 changes: 27 additions & 0 deletions command_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@ import (
)

func TestCommand(t *testing.T) {
// ensure only specific pager tests execute with custom pager
t.Setenv("PAGER", "")

t.Run("from-stdin-pipe", func(t *testing.T) {
args := os.Args
t.Cleanup(func() { os.Args = args })
Expand Down Expand Up @@ -119,6 +122,30 @@ func TestCommand(t *testing.T) {
t.Errorf("got len=%d, want len=%d", len(got), len(want))
}
})

t.Run("print-with-pager", func(t *testing.T) {
args := os.Args
t.Setenv("PAGER", "less")
t.Cleanup(func() { os.Args = args })

// override os.Args which contains flags from the test binary
os.Args = []string{
"coverdiff",
}
flag.Parse()

var stdout bytes.Buffer
stdin := bytes.NewBufferString(readFile(t, "testdata/coverage.out"))
if err := command(stdin, &stdout); err != nil {
t.Fatal(err)
}

got := stdout.String()
want := readFile(t, "testdata/coverdiff.out")
if got != want {
t.Errorf("got len=%d, want len=%d", len(got), len(want))
}
})
}

func readFile(tb testing.TB, name string) string {
Expand Down

0 comments on commit 0bec65f

Please sign in to comment.