Skip to content

Commit

Permalink
Plugins: Common configuration and http utilities
Browse files Browse the repository at this point in the history
A couple of common patterns around configuration and calling external http
services have been identified, so in order to reduce code duplication, a couple
of utility interfaces and methods have been introduced.

Additionally the go version in go.mod has been bumped to 1.18, which is required
for generics.
  • Loading branch information
kvalev committed Jun 2, 2023
1 parent c3836e6 commit 4ed7d35
Show file tree
Hide file tree
Showing 5 changed files with 125 additions and 84 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -128,4 +128,4 @@ require (
golang.org/x/arch v0.2.0 // indirect
)

go 1.17
go 1.18
25 changes: 24 additions & 1 deletion internal/plugin/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ package plugin

import (
"fmt"
"strings"
"os"
"strconv"
"strings"

"github.com/photoprism/photoprism/pkg/list"
)
Expand All @@ -25,6 +26,28 @@ func (c PluginConfig) Enabled() bool {
return false
}

// MandatoryStringParameter reads a mandatory string plugin parameter.
func (c PluginConfig) MandatoryStringParameter(name string) (string, error) {
if value, ok := c[name]; !ok {
return "", fmt.Errorf("%s parameter is mandatory", name)
} else {
return value, nil
}
}

// OptionalFloatParameter reads an optional float64 plugin parameter.
func (c PluginConfig) OptionalFloatParameter(name string, defaultValue float64) (float64, error) {
if value, ok := c[name]; ok {
if fValue, err := strconv.ParseFloat(value, 64); err != nil {
return 0, err
} else {
return fValue, nil
}
} else {
return defaultValue, nil
}
}

func loadConfig(p Plugin) PluginConfig {
var config = make(PluginConfig)

Expand Down
7 changes: 7 additions & 0 deletions internal/plugin/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,13 @@ type Plugin interface {
OnIndex(*entity.File, *entity.Photo) error
}

// HttpPlugin provides an interface for plugins calling external http services.
type HttpPlugin interface {
Plugin
Hostname() string
Port() string
}

// OnIndex calls the [OnIndex] hook method for all enabled plugins.
func OnIndex(file *entity.File, photo *entity.Photo) (changed bool) {
for _, p := range getPlugins() {
Expand Down
69 changes: 69 additions & 0 deletions internal/plugin/utils.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package plugin

import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"net/http"
"path/filepath"
"time"

"github.com/disintegration/imaging"
"github.com/photoprism/photoprism/pkg/fs"
)

// ReadImageAsBase64 reads an image, rotates it if needed and returns it as base64-encoded string.
func ReadImageAsBase64(filePath string) (string, error) {
if !fs.FileExists(filePath) {
return "", fmt.Errorf("file %s is missing", filepath.Base(filePath))
}

img, err := imaging.Open(filePath, imaging.AutoOrientation(true))
if err != nil {
return "", err
}

buffer := &bytes.Buffer{}
err = imaging.Encode(buffer, img, imaging.JPEG)
if err != nil {
return "", err
}

encoded := base64.StdEncoding.EncodeToString(buffer.Bytes())

return encoded, nil
}

// PostJson sends a post request with a json payload to a plugin endpoint and returns a deserialized json output.
func PostJson[T any](p HttpPlugin, endpoint string, payload map[string]interface{}) (T, error) {
client := &http.Client{Timeout: 60 * time.Second}
url := fmt.Sprintf("http://%s:%s/%s", p.Hostname(), p.Port(), endpoint)

var empty T

var req *http.Request
var output *T

if j, err := json.Marshal(payload); err != nil {
return empty, err
} else if req, err = http.NewRequest(http.MethodPost, url, bytes.NewReader(j)); err != nil {
return empty, err
}

// Add Content-Type header.
req.Header.Add("Content-Type", "application/json")

if resp, err := client.Do(req); err != nil {
return empty, err
} else if resp.StatusCode != 200 {
return empty, fmt.Errorf("%s server running at %s:%s, bad status %d", p.Name(), p.Hostname(), p.Port(), resp.StatusCode)
} else if body, err := io.ReadAll(resp.Body); err != nil {
return empty, err
} else if err := json.Unmarshal(body, &output); err != nil {
return empty, err
} else {
return *output, nil
}
}
106 changes: 24 additions & 82 deletions internal/plugin/yolo8/yolo8.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,12 @@
package main

import (
"bytes"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"io/ioutil"
"net/http"
"strconv"
"time"

"github.com/photoprism/photoprism/internal/classify"
"github.com/photoprism/photoprism/internal/entity"
"github.com/photoprism/photoprism/internal/photoprism"
"github.com/photoprism/photoprism/internal/plugin"
"github.com/photoprism/photoprism/pkg/clean"
"github.com/photoprism/photoprism/pkg/fs"
)

type ClassifyResults map[string]float64
Expand All @@ -38,24 +28,28 @@ func (p Yolo8Plugin) Name() string {
return "yolo8"
}

func (p Yolo8Plugin) Hostname() string {
return p.hostname
}

func (p Yolo8Plugin) Port() string {
return p.port
}

func (p *Yolo8Plugin) Configure(config plugin.PluginConfig) error {
hostname, ok := config["hostname"]
if !ok {
return fmt.Errorf("hostname parameter is mandatory")
hostname, err := config.MandatoryStringParameter("hostname")
if err != nil {
return err
}

port, ok := config["port"]
if !ok {
return fmt.Errorf("port parameter is mandatory")
port, err := config.MandatoryStringParameter("port")
if err != nil {
return err
}

threshold := 0.5
var err error

if t, ok := config["confidence_threshold"]; ok {
if threshold, err = strconv.ParseFloat(t, 64); err != nil {
return err
}
threshold, err := config.OptionalFloatParameter("confidence_threshold", 0.5)
if err != nil {
return err
}

p.hostname = hostname
Expand Down Expand Up @@ -85,78 +79,26 @@ func (p *Yolo8Plugin) OnIndex(file *entity.File, photo *entity.Photo) error {
func (p *Yolo8Plugin) image(f *entity.File) (string, error) {
filePath := photoprism.FileName(f.FileRoot, f.FileName)

if !fs.FileExists(filePath) {
return "", fmt.Errorf("file %s is missing", clean.Log(f.FileName))
}

data, err := ioutil.ReadFile(filePath)

if err != nil {
return "", err
}

encoded := base64.StdEncoding.EncodeToString(data)

return encoded, nil
return plugin.ReadImageAsBase64(filePath)
}

func (p *Yolo8Plugin) detect(image string) (classify.Labels, error) {
client := &http.Client{Timeout: 60 * time.Second}
url := fmt.Sprintf("http://%s:%s/detect", p.hostname, p.port)
payload := map[string]string{"image": image}

var req *http.Request
var output *DetectResults
payload := map[string]interface{}{"image": image}

if j, err := json.Marshal(payload); err != nil {
return nil, err
} else if req, err = http.NewRequest(http.MethodPost, url, bytes.NewReader(j)); err != nil {
return nil, err
}

// Add Content-Type header.
req.Header.Add("Content-Type", "application/json")

if resp, err := client.Do(req); err != nil {
return nil, err
} else if resp.StatusCode != 200 {
return nil, fmt.Errorf("yolo8 server running at %s:%s, bad status %d\n", p.hostname, p.port, resp.StatusCode)
} else if body, err := io.ReadAll(resp.Body); err != nil {
return nil, err
} else if err := json.Unmarshal(body, &output); err != nil {
if output, err := plugin.PostJson[DetectResults](p, "detect", payload); err != nil {
return nil, err
} else {
return (*output).toLabels(p.confThreshold), nil
return output.toLabels(p.confThreshold), nil
}
}

func (p *Yolo8Plugin) classify(image string) (classify.Labels, error) {
client := &http.Client{Timeout: 60 * time.Second}
url := fmt.Sprintf("http://%s:%s/classify", p.hostname, p.port)
payload := map[string]string{"image": image}
payload := map[string]interface{}{"image": image}

var req *http.Request
var output *ClassifyResults

if j, err := json.Marshal(payload); err != nil {
return nil, err
} else if req, err = http.NewRequest(http.MethodPost, url, bytes.NewReader(j)); err != nil {
return nil, err
}

// Add Content-Type header.
req.Header.Add("Content-Type", "application/json")

if resp, err := client.Do(req); err != nil {
return nil, err
} else if resp.StatusCode != 200 {
return nil, fmt.Errorf("yolo8 server running at %s:%s, bad status %d\n", p.hostname, p.port, resp.StatusCode)
} else if body, err := io.ReadAll(resp.Body); err != nil {
return nil, err
} else if err := json.Unmarshal(body, &output); err != nil {
if output, err := plugin.PostJson[ClassifyResults](p, "classify", payload); err != nil {
return nil, err
} else {
return (*output).toLabels(p.confThreshold), nil
return output.toLabels(p.confThreshold), nil
}
}

Expand Down

0 comments on commit 4ed7d35

Please sign in to comment.