fix(seed): generate random seed per-request if -1 is set #1952

Merged
merged 2 commits into from
Apr 3, 2024
19 changes: 19 additions & 0 deletions .github/labeler.yml
@@ -0,0 +1,19 @@
enhancements:
  - head-branch: ['^feature', 'feature']

kind/documentation:
- any:
  - changed-files:
    - any-glob-to-any-file: 'docs/*'
  - changed-files:
    - any-glob-to-any-file: '*.md'

examples:
- any:
  - changed-files:
    - any-glob-to-any-file: 'examples/*'

ci:
- any:
  - changed-files:
    - any-glob-to-any-file: '.github/*'
12 changes: 12 additions & 0 deletions .github/workflows/labeler.yml
@@ -0,0 +1,12 @@
name: "Pull Request Labeler"
on:
- pull_request_target

jobs:
  labeler:
    permissions:
      contents: read
      pull-requests: write
    runs-on: ubuntu-latest
    steps:
    - uses: actions/labeler@v5
27 changes: 27 additions & 0 deletions .github/workflows/secscan.yaml
@@ -0,0 +1,27 @@
name: "Security Scan"

# Run workflow each time code is pushed to your repository and on a schedule.
# The scheduled workflow runs every Sunday at 00:00 UTC.
on:
  push:
  schedule:
    - cron: '0 0 * * 0'

jobs:
  tests:
    runs-on: ubuntu-latest
    env:
      GO111MODULE: on
    steps:
      - name: Checkout Source
        uses: actions/checkout@v3
      - name: Run Gosec Security Scanner
        uses: securego/gosec@master
        with:
          # Let the report content trigger a failure through the GitHub Security features.
          args: '-no-fail -fmt sarif -out results.sarif ./...'
      - name: Upload SARIF file
        uses: github/codeql-action/upload-sarif@v2
        with:
          # Path to SARIF file relative to the root of the repository
          sarif_file: results.sarif
15 changes: 12 additions & 3 deletions core/backend/options.go
@@ -1,6 +1,7 @@
package backend

import (
	"math/rand"
	"os"
	"path/filepath"

@@ -33,12 +34,20 @@
	return opts
}

func getSeed(c config.BackendConfig) int32 {
	seed := int32(*c.Seed)
	if seed == config.RAND_SEED {
		seed = rand.Int31()
	}

	return seed
}

func gRPCModelOpts(c config.BackendConfig) *pb.ModelOptions {
	b := 512
	if c.Batch != 0 {
		b = c.Batch
	}

	return &pb.ModelOptions{
		CUDA: c.CUDA || c.Diffusers.CUDA,
		SchedulerType: c.Diffusers.SchedulerType,
@@ -54,7 +63,7 @@
		CLIPSkip: int32(c.Diffusers.ClipSkip),
		ControlNet: c.Diffusers.ControlNet,
		ContextSize: int32(*c.ContextSize),
		Seed: int32(*c.Seed),
		Seed: getSeed(c),
		NBatch: int32(b),
		NoMulMatQ: c.NoMulMatQ,
		DraftModel: c.DraftModel,
@@ -129,7 +138,7 @@
		NKeep: int32(c.Keep),
		Batch: int32(c.Batch),
		IgnoreEOS: c.IgnoreEOS,
		Seed: int32(*c.Seed),
		Seed: getSeed(c),
		FrequencyPenalty: float32(c.FrequencyPenalty),
		MLock: *c.MMlock,
		MMap: *c.MMap,
7 changes: 5 additions & 2 deletions core/config/backend_config.go
@@ -4,7 +4,6 @@ import (
	"errors"
	"fmt"
	"io/fs"
	"math/rand"
	"os"
	"path/filepath"
	"sort"
@@ -20,6 +19,10 @@ import (
	"github.com/charmbracelet/glamour"
)

const (
	RAND_SEED = -1
)

type BackendConfig struct {
	schema.PredictionOptions `yaml:"parameters"`
	Name string `yaml:"name"`
@@ -218,7 +221,7 @@ func (cfg *BackendConfig) SetDefaults(opts ...ConfigLoaderOption) {

	if cfg.Seed == nil {
		// random number generator seed
		defaultSeed := int(rand.Int31())
		defaultSeed := RAND_SEED
		cfg.Seed = &defaultSeed
	}
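
The effect of the two hunks above (store the `-1` sentinel at config-load time, resolve it when options are built) can be sketched in isolation. This is a minimal, standalone illustration rather than the project's code: `config`, `setDefaults`, `resolveSeed`, and `randSeed` are hypothetical stand-ins for `BackendConfig`, `SetDefaults`, `getSeed`, and `RAND_SEED`.

```go
package main

import (
	"fmt"
	"math/rand"
)

// randSeed mirrors the RAND_SEED sentinel from the diff: -1 means "pick a fresh seed".
const randSeed = -1

// config is a hypothetical stand-in for BackendConfig, reduced to the seed field.
type config struct {
	Seed *int
}

// setDefaults stores the sentinel instead of a concrete random value,
// so the actual draw is deferred until a request is built.
func setDefaults(c *config) {
	if c.Seed == nil {
		s := randSeed
		c.Seed = &s
	}
}

// resolveSeed returns the configured seed, or a newly drawn one when the sentinel is set.
func resolveSeed(c config) int32 {
	seed := int32(*c.Seed)
	if seed == randSeed {
		seed = rand.Int31() // non-negative, so it can never collide with the sentinel
	}
	return seed
}

func main() {
	var c config
	setDefaults(&c)
	// Each call draws independently, which is the per-request behaviour the PR title describes.
	fmt.Println(resolveSeed(c), resolveSeed(c))

	fixed := 42
	c.Seed = &fixed
	fmt.Println(resolveSeed(c)) // an explicit seed passes through unchanged
}
```

With the previous default (`int(rand.Int31())` inside `SetDefaults`), the random value was drawn once when defaults were applied and then reused; keeping the sentinel in the config is what lets every call draw again.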

32 changes: 32 additions & 0 deletions docs/content/docs/features/text-generation.md
@@ -304,6 +304,7 @@ The backend will automatically download the required files in order to run the m
| Type | Description |
| --- | --- |
| `AutoModelForCausalLM` | `AutoModelForCausalLM` is a model that can be used to generate sequences. |
| `OVModelForCausalLM` | for OpenVINO models |
| N/A | Defaults to `AutoModel` |


@@ -324,4 +325,35 @@ curl http://localhost:8080/v1/completions -H "Content-Type: application/json" -d
"prompt": "Hello, my name is",
"temperature": 0.1, "top_p": 0.1
}'
```

#### Examples

##### OpenVINO

A model configuration file for OpenVINO and the Starling model:

```yaml
name: starling-openvino
backend: transformers
parameters:
  model: fakezeta/Starling-LM-7B-beta-openvino-int8
context_size: 8192
threads: 6
f16: true
type: OVModelForCausalLM
stopwords:
- <|end_of_turn|>
- <|endoftext|>
prompt_cache_path: "cache"
prompt_cache_all: true
template:
  chat_message: |
    {{if eq .RoleName "system"}}{{.Content}}<|end_of_turn|>{{end}}{{if eq .RoleName "assistant"}}<|end_of_turn|>GPT4 Correct Assistant: {{.Content}}<|end_of_turn|>{{end}}{{if eq .RoleName "user"}}GPT4 Correct User: {{.Content}}{{end}}

  chat: |
    {{.Input}}<|end_of_turn|>GPT4 Correct Assistant:

  completion: |
    {{.Input}}
```
5 changes: 3 additions & 2 deletions tests/e2e-aio/e2e_suite_test.go
@@ -23,6 +23,7 @@ var containerImageTag = os.Getenv("LOCALAI_IMAGE_TAG")
var modelsDir = os.Getenv("LOCALAI_MODELS_DIR")
var apiPort = os.Getenv("LOCALAI_API_PORT")
var apiEndpoint = os.Getenv("LOCALAI_API_ENDPOINT")
var apiKey = os.Getenv("LOCALAI_API_KEY")

func TestLocalAI(t *testing.T) {
	RegisterFailHandler(Fail)
@@ -38,11 +39,11 @@ var _ = BeforeSuite(func() {
	var defaultConfig openai.ClientConfig
	if apiEndpoint == "" {
		startDockerImage()
		defaultConfig = openai.DefaultConfig("")
		defaultConfig = openai.DefaultConfig(apiKey)
		defaultConfig.BaseURL = "http://localhost:" + apiPort + "/v1"
	} else {
		fmt.Println("Default ", apiEndpoint)
		defaultConfig = openai.DefaultConfig("")
		defaultConfig = openai.DefaultConfig(apiKey)
		defaultConfig.BaseURL = apiEndpoint
	}
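
What passing `LOCALAI_API_KEY` into the client changes on the wire can be sketched with the standard library alone. This assumes the key travels as an `Authorization: Bearer` header, which is how OpenAI-compatible clients typically authenticate; `newRequest` and the `/models` path are hypothetical and are not part of the test suite.

```go
package main

import (
	"fmt"
	"io"
	"net/http"
	"net/http/httptest"
	"os"
)

// newRequest is a hypothetical helper: it attaches the key as a bearer token
// when one is set, and sends no Authorization header otherwise.
func newRequest(baseURL, apiKey string) (*http.Request, error) {
	req, err := http.NewRequest(http.MethodGet, baseURL+"/models", nil)
	if err != nil {
		return nil, err
	}
	if apiKey != "" {
		req.Header.Set("Authorization", "Bearer "+apiKey)
	}
	return req, nil
}

func main() {
	// Stand-in server that echoes back the Authorization header it received.
	srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		fmt.Fprint(w, r.Header.Get("Authorization"))
	}))
	defer srv.Close()

	req, err := newRequest(srv.URL+"/v1", os.Getenv("LOCALAI_API_KEY"))
	if err != nil {
		panic(err)
	}
	resp, err := http.DefaultClient.Do(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()

	body, _ := io.ReadAll(resp.Body)
	fmt.Printf("server saw Authorization=%q\n", body)
}
```

When the variable is unset the request carries no header, which matches the old `DefaultConfig("")` behaviour against an endpoint that does not require a key.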
