diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml deleted file mode 100644 index 105b27d37..000000000 --- a/.github/ISSUE_TEMPLATE/bug_report.yml +++ /dev/null @@ -1,62 +0,0 @@ -name: Bug report -description: File a bug report -title: "[BUG]: " -labels: ["bug"] -assignees: - - MilesCranmer - -body: - - type: markdown - attributes: - value: | - Thanks for taking the time to fill out this bug report! - - type: textarea - id: what-happened - attributes: - label: What happened? - description: Also tell us, what did you expect to happen? - placeholder: Tell us what you see! - value: "A bug happened!" - validations: - required: true - - type: input - id: version - attributes: - label: Version - description: What version of SymbolicRegression.jl are you running? `julia -e 'import SymbolicRegression; println(SymbolicRegression.PACKAGE_VERSION)'`. - validations: - required: true - - type: dropdown - id: os - attributes: - label: Operating System - description: What operating system do you see the issue on? - multiple: true - options: - - Windows - - macOS - - Linux - - type: dropdown - id: display - attributes: - label: Interface - description: How are you running SymbolicRegression? - multiple: true - options: - - Script (i.e., `python my_script.py`) - - Jupyter Notebook - - Julia REPL - - Other (specify below) - validations: - required: true - - type: textarea - id: logs - attributes: - label: Relevant log output - description: Please copy and paste any log output or error messages. This will be automatically formatted into code, so no need for backticks. - render: shell - - type: textarea - id: extra - attributes: - label: Extra Info - description: Please tell us any other information that you think might help. For example, what are your search settings? What dataset are you running on? If possible, please share a minimal code example that produces the error. diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml deleted file mode 100644 index 77efaf370..000000000 --- a/.github/ISSUE_TEMPLATE/config.yml +++ /dev/null @@ -1,5 +0,0 @@ -blank_issues_enabled: true -contact_links: - - name: Forums - url: https://github.com/MilesCranmer/PySR/discussions - about: Please ask and answer questions about how to use SymbolicRegression.jl and PySR here. diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml deleted file mode 100644 index 4636eb3c5..000000000 --- a/.github/ISSUE_TEMPLATE/feature_request.yml +++ /dev/null @@ -1,21 +0,0 @@ -name: Feature request -description: Suggest an idea for this project -title: "[Feature]: " -labels: ["enhancement"] - -body: - - type: markdown - attributes: - value: | - Thanks for taking the time to suggest a feature for SymbolicRegression.jl! Your interest in the project helps improve the software for everyone ๐Ÿš€ - - type: textarea - id: feature - attributes: - label: Feature Request - description: Describe your desired feature request here! The more details the better. - validations: - required: true - - type: markdown - attributes: - value: | - Be sure to check out the [forums](https://github.com/MilesCranmer/PySR/discussions) to chat with other users about use-cases! 
diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 713a0b168..cb1c9ea86 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -26,8 +26,9 @@ jobs: fail-fast: false matrix: test: - - "unit" - - "integration" + - "part1" + - "part2" + - "part3" julia-version: - "1.6" - "1.8" @@ -37,22 +38,31 @@ jobs: include: - os: windows-latest julia-version: "1" - test: "unit" + test: "part1" - os: windows-latest julia-version: "1" - test: "integration" + test: "part2" + - os: windows-latest + julia-version: "1" + test: "part3" + - os: macOS-latest + julia-version: "1" + test: "part1" - os: macOS-latest julia-version: "1" - test: "unit" + test: "part2" - os: macOS-latest julia-version: "1" - test: "integration" + test: "part3" - os: ubuntu-latest julia-version: "~1.11.0-0" - test: "unit" + test: "part1" - os: ubuntu-latest julia-version: "~1.11.0-0" - test: "integration" + test: "part2" + - os: ubuntu-latest + julia-version: "~1.11.0-0" + test: "part3" steps: - uses: actions/checkout@v4 @@ -62,6 +72,8 @@ jobs: version: ${{ matrix.julia-version }} - name: "Cache dependencies" uses: julia-actions/cache@v2 + with: + cache-name: julia-cache;workflow=${{ github.workflow }};job=${{ github.job }};os=${{ matrix.os }};julia=${{ matrix.julia-version }};project=${{ hashFiles('**/Project.toml') }} - name: "Build package" uses: julia-actions/julia-buildpkg@v1 - name: "Run tests" diff --git a/.gitignore b/.gitignore index d362e9fe5..2cb9c5d85 100644 --- a/.gitignore +++ b/.gitignore @@ -20,3 +20,4 @@ docs/src/index.md *.code-workspace .vscode **/*.json +LocalPreferences.toml diff --git a/Project.toml b/Project.toml index 4927ffae0..9e3f01ea4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,15 +1,19 @@ -name = "SymbolicRegression" -uuid = "8254be44-1295-4e6a-a16d-46603ac705cb" -authors = ["MilesCranmer "] -version = "0.24.5" +name = "LibraryAugmentedSymbolicRegression" +uuid = "158930c3-947c-4174-974b-74b39e64a28f" +authors = ["AryaGrayeli ", "AtharvaSehgal ", "MilesCranmer "] +version = "0.0.1" [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" +ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b" DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821" +JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" LossFunctions = "30fc2ffe-d236-52d8-8643-a9d8f7c094a7" MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea" @@ -20,6 +24,7 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" ProgressBars = "49802e3a-d2f1-5c88-81d8-b72133a6f568" +PromptingTools = "670122d1-24a8-4d70-bfce-740807c42192" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" @@ -27,20 +32,26 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" TOML = "fa267f1f-6049-4f14-aa54-33bafae1ed76" [weakdeps] +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" [extensions] -SymbolicRegressionJSON3Ext = "JSON3" 
-SymbolicRegressionSymbolicUtilsExt = "SymbolicUtils" +LaSREnzymeExt = "Enzyme" +LaSRJSON3Ext = "JSON3" +LaSRSymbolicUtilsExt = "SymbolicUtils" [compat] +ADTypes = "^1.4.0" Compat = "^4.2" +ConstructionBase = "<1.5.7" Dates = "1" +DifferentiationInterface = "0.5" DispatchDoctor = "0.4" -Distributed = "1" -DynamicExpressions = "0.16" -DynamicQuantities = "0.10, 0.11, 0.12, 0.13, 0.14" +Distributed = "<0.0.1, 1" +DynamicExpressions = "0.19.3" +DynamicQuantities = "0.10, 0.11, 0.12, 0.13, 0.14, 1" +Enzyme = "0.12" JSON3 = "1" LineSearches = "7" LossFunctions = "0.10, 0.11" @@ -48,18 +59,19 @@ MLJModelInterface = "~1.5, ~1.6, ~1.7, ~1.8, ~1.9, ~1.10, ~1.11" MacroTools = "0.4, 0.5" Optim = "~1.8, ~1.9" PackageExtensionCompat = "1" -Pkg = "1" +Pkg = "<0.0.1, 1" PrecompileTools = "1" -Printf = "1" +Printf = "<0.0.1, 1" ProgressBars = "~1.4, ~1.5" -Random = "1" +Random = "<0.0.1, 1" Reexport = "1" SpecialFunctions = "0.10.1, 1, 2" StatsBase = "0.33, 0.34" -SymbolicUtils = "0.19, ^1.0.5" -TOML = "1" +SymbolicUtils = "0.19, ^1.0.5, 2, 3" +TOML = "<0.0.1, 1" julia = "1.6" [extras] +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" diff --git a/README.md b/README.md index 8c5361e77..62d46a4d4 100644 --- a/README.md +++ b/README.md @@ -1,21 +1,20 @@
-SymbolicRegression.jl searches for symbolic expressions which optimize a particular objective. +LaSR.jl accelerates the search for symbolic expressions using language guidance. -https://github.com/MilesCranmer/SymbolicRegression.jl/assets/7593028/f5b68f1f-9830-497f-a197-6ae332c94ee0 - -| Latest release | Documentation | Forums | Paper | +| Latest release | Website | Forums | Paper | | :---: | :---: | :---: | :---: | -| [![version](https://juliahub.com/docs/SymbolicRegression/version.svg)](https://juliahub.com/ui/Packages/SymbolicRegression/X2eIS) | [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://astroautomata.com/SymbolicRegression.jl/dev/) | [![Discussions](https://img.shields.io/badge/discussions-github-informational)](https://github.com/MilesCranmer/PySR/discussions) | [![Paper](https://img.shields.io/badge/arXiv-2305.01582-b31b1b)](https://arxiv.org/abs/2305.01582) | +| [![version](https://juliahub.com/docs/LaSR/version.svg)](https://juliahub.com/ui/Packages/LaSR/X2eIS) | [![Dev](https://img.shields.io/badge/docs-dev-blue.svg)](https://trishullab.github.io/lasr-web/) | [![Discussions](https://img.shields.io/badge/discussions-github-informational)](https://github.com/trishullab/LaSR.jl/discussions) | [![Paper](https://img.shields.io/badge/arXiv-????.?????-b31b1b)](https://atharvas.net/static/lasr.pdf) | | Build status | Coverage | | :---: | :---: | -| [![CI](https://github.com/MilesCranmer/SymbolicRegression.jl/workflows/CI/badge.svg)](.github/workflows/CI.yml) | [![Coverage Status](https://coveralls.io/repos/github/MilesCranmer/SymbolicRegression.jl/badge.svg?branch=master)](https://coveralls.io/github/MilesCranmer/SymbolicRegression.jl?branch=master) | +| [![CI](https://github.com/trishullab/LaSR.jl/workflows/CI/badge.svg)](.github/workflows/CI.yml) | [![Coverage Status](https://coveralls.io/repos/github/trishullab/LaSR.jl/badge.svg?branch=master)](https://coveralls.io/github/trishullab/LaSR.jl?branch=master) | -Check out [PySR](https://github.com/MilesCranmer/PySR) for +LaSR is integrated with [SymbolicRegression.jl](https://github.com/MilesCranmer/SymbolicRegression.jl). Check out [PySR](https://github.com/MilesCranmer/PySR) for a Python frontend. -[Cite this software](https://arxiv.org/abs/2305.01582) + +[Cite this software](https://arxiv.org/abs/????.?????)
@@ -23,13 +22,8 @@ a Python frontend. **Contents**: - [Quickstart](#quickstart) - - [MLJ Interface](#mlj-interface) - - [Low-Level Interface](#low-level-interface) -- [Constructing expressions](#constructing-expressions) -- [Exporting to SymbolicUtils.jl](#exporting-to-symbolicutilsjl) -- [Contributors โœจ](#contributors-) -- [Code structure](#code-structure) -- [Search options](#search-options) +- [Organization](#organization) +- [LLM Utilities](#llm-utilities) ## Quickstart @@ -37,17 +31,15 @@ Install in Julia with: ```julia using Pkg -Pkg.add("SymbolicRegression") +Pkg.add("LaSR") ``` -### MLJ Interface - -The easiest way to use SymbolicRegression.jl +LaSR uses the same interface as [SymbolicRegression.jl](https://github.com/MilesCranmer/SymbolicRegression.jl). The easiest way to use LaSR.jl is with [MLJ](https://github.com/alan-turing-institute/MLJ.jl). Let's see an example: ```julia -import SymbolicRegression: SRRegressor +import LaSR: LaSRRegressor, LLMOptions import MLJ: machine, fit!, predict, report # Dataset with two named features: @@ -59,10 +51,13 @@ y = @. 2 * cos(X.a * 23.5) - X.b ^ 2 # with some noise: y = y .+ randn(500) .* 1e-3 -model = SRRegressor( +model = LaSRRegressor( niterations=50, binary_operators=[+, -, *], unary_operators=[cos], + llm_options=LLMOptions( + ... + ) ) ``` @@ -108,196 +103,29 @@ predict(mach, (data=X, idx=2)) where here we choose to evaluate the second equation. -For fitting multiple outputs, one can use `MultitargetSRRegressor` +For fitting multiple outputs, one can use `MultitargetLaSRRegressor` (and pass an array of indices to `idx` in `predict` for selecting specific equations). -For a full list of options available to each regressor, see the [API page](https://astroautomata.com/SymbolicRegression.jl/dev/api/). - -### Low-Level Interface - -The heart of SymbolicRegression.jl is the -`equation_search` function. -This takes a 2D array and attempts -to model a 1D array using analytic functional forms. -**Note:** unlike the MLJ interface, -this assumes column-major input of shape [features, rows]. - -```julia -import SymbolicRegression: Options, equation_search - -X = randn(2, 100) -y = 2 * cos.(X[2, :]) + X[1, :] .^ 2 .- 2 - -options = Options( - binary_operators=[+, *, /, -], - unary_operators=[cos, exp], - populations=20 -) - -hall_of_fame = equation_search( - X, y, niterations=40, options=options, - parallelism=:multithreading -) -``` - -You can view the resultant equations in the dominating Pareto front (best expression -seen at each complexity) with: - -```julia -import SymbolicRegression: calculate_pareto_frontier - -dominating = calculate_pareto_frontier(hall_of_fame) -``` - -This is a vector of `PopMember` type - which contains the expression along with the score. -We can get the expressions with: - -```julia -trees = [member.tree for member in dominating] -``` - -Each of these equations is a `Node{T}` type for some constant type `T` (like `Float32`). - -You can evaluate a given tree with: - -```julia -import SymbolicRegression: eval_tree_array - -tree = trees[end] -output, did_succeed = eval_tree_array(tree, X, options) -``` - -The `output` array will contain the result of the tree at each of the 100 rows. -This `did_succeed` flag detects whether an evaluation was successful, or whether -encountered any NaNs or Infs during calculation (such as, e.g., `sqrt(-1)`). - -## Constructing expressions +For a full list of options available to each regressor, see the [API page](https://astroautomata.com/LaSR.jl/dev/api/). 
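> **Editor's note.** The `LLMOptions(...)` call in the Quickstart above is left elided in this diff. The sketch below shows one plausible way it could be filled in; treat it as a hypothetical illustration, not confirmed API. `LLMOptions` and `LLMWeights` are the types this diff introduces in `src/LLMOptions.jl` (re-exported via `src/Core.jl`), and the `prompts/` directory ships templates for mutate, crossover, and random-generation operations, but every keyword name below is an assumption to check against the actual struct definition.

```julia
# Hypothetical sketch only -- every keyword below is an assumed name, not confirmed API.
# `LLMOptions`/`LLMWeights` come from the src/LLMOptions.jl module added in this diff;
# the three weights correspond to the prompt templates shipped under prompts/.
import LaSR: LaSRRegressor, LLMOptions, LLMWeights  # LLMWeights export is assumed

llm_options = LLMOptions(;
    active=true,              # assumed master switch for LLM-guided search
    weights=LLMWeights(;      # assumed per-operation usage probabilities:
        llm_mutate=0.05,      #   prompts/mutate_{system,user}.txt
        llm_crossover=0.05,   #   prompts/crossover_{system,user}.txt
        llm_gen_random=0.05,  #   prompts/gen_random_{system,user}.txt
    ),
    api_key="...",            # assumed: credentials forwarded to PromptingTools.jl
    model="gpt-4o-mini",      # assumed: any chat model PromptingTools.jl supports
)

model = LaSRRegressor(;
    niterations=50,
    binary_operators=[+, -, *],
    unary_operators=[cos],
    llm_options=llm_options,
)
```

The general shape assumed here (a master switch plus per-operation weights) mirrors how `MutationWeights` already works in this codebase, which is why it is sketched this way.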
-Expressions are represented as the `Node` type which is developed -in the [DynamicExpressions.jl](https://github.com/SymbolicML/DynamicExpressions.jl/) package. - -You can manipulate and construct expressions directly. For example: +### LLM Options +LaSR uses PromptingTools.jl for zero shot prompting. If you wish to make changes to the prompting options, you can pass an `LLMOptions` object to the `LaSRRegressor` constructor. The options available are: ```julia -import SymbolicRegression: Options, Node, eval_tree_array - -options = Options(; - binary_operators=[+, -, *, ^, /], unary_operators=[cos, exp, sin] +llm_options = LLMOptions( + ... ) -x1, x2, x3 = [Node(; feature=i) for i=1:3] -tree = cos(x1 - 3.2 * x2) - x1^3.2 -``` - -This tree has `Float64` constants, so the type of the entire tree -will be promoted to `Node{Float64}`. - -We can convert all constants (recursively) to `Float32`: - -```julia -float32_tree = convert(Node{Float32}, tree) -``` - -We can then evaluate this tree on a dataset: - -```julia -X = rand(Float32, 3, 100) -output, did_succeed = eval_tree_array(tree, X, options) -``` - -## Exporting to SymbolicUtils.jl - -We can view the equations in the dominating -Pareto frontier with: - -```julia -dominating = calculate_pareto_frontier(hall_of_fame) -``` - -We can convert the best equation -to [SymbolicUtils.jl](https://github.com/JuliaSymbolics/SymbolicUtils.jl) -with the following function: - -```julia -import SymbolicRegression: node_to_symbolic - -eqn = node_to_symbolic(dominating[end].tree, options) -println(simplify(eqn*5 + 3)) ``` -We can also print out the full pareto frontier like so: - -```julia -import SymbolicRegression: compute_complexity, string_tree -println("Complexity\tMSE\tEquation") +## Organization -for member in dominating - complexity = compute_complexity(member, options) - loss = member.loss - string = string_tree(member.tree, options) - - println("$(complexity)\t$(loss)\t$(string)") -end -``` - -## Contributors โœจ - -We are eager to welcome new contributors! -If you have an idea for a new feature, don't hesitate to share it on the [issues](https://github.com/MilesCranmer/SymbolicRegression.jl/issues) page or [forums](https://github.com/MilesCranmer/PySR/discussions). - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
-<!-- all-contributors table (HTML grid) crediting Mark Kittisopikul, T Coxon, Dhananjay Ashok, Johan Blåbäck, JuliusMartensen, ngam, Kaze Wong, Christopher Rackauckas, Patrick Kidger, Okon Samuel, William Booth-Clibborn, Pablo Lemos, Jerry Ling, Charles Fox, Johann Brehmer, Marius Millea, Coba, Pietro Monticone, Mateusz Kubica, Jay Wadekar, Anthony Blaom, Jgmedina95, Michael Abbott, Oscar Smith, Eric Hanson, Henrique Becker, qwertyjl, Rik Huijzer, Hongyu Wang, and Saurav Maheshkar -->
- - - +LaSR.jl development is kept independent from the main codebase. However, to ensure LaSR can be used easily, it is integrated into SymbolicRegression.jl via the `ext/SymbolicRegressionLaSRExt` extension module. This, in turn, is loaded into PySR. This cartoon summarizes the interaction between the different packages: - +![LaSR.jl organization](https://raw.githubusercontent.com/trishullab/lasr-web/main/static/lasr-code-interactions.svg) ## Code structure -SymbolicRegression.jl is organized roughly as follows. +LaSR.jl is organized roughly as follows. Rounded rectangles indicate objects, and rectangles indicate functions. > (if you can't see this diagram being rendered, try pasting it into [mermaid-js.github.io/mermaid-live-editor](https://mermaid-js.github.io/mermaid-live-editor)) @@ -348,9 +176,9 @@ stateDiagram-v2 AdaptiveParsimony --> Population AdaptiveParsimony --> RegularizedEvolution AdaptiveParsimony --> SingleIteration - AdaptiveParsimony --> SymbolicRegression + AdaptiveParsimony --> LaSR CheckConstraints --> Mutate - CheckConstraints --> SymbolicRegression + CheckConstraints --> LaSR Complexity --> CheckConstraints Complexity --> HallOfFame Complexity --> LossFunctions @@ -358,7 +186,7 @@ stateDiagram-v2 Complexity --> Population Complexity --> SearchUtils Complexity --> SingleIteration - Complexity --> SymbolicRegression + Complexity --> LaSR ConstantOptimization --> Mutate ConstantOptimization --> SingleIteration Core --> AdaptiveParsimony @@ -377,24 +205,24 @@ stateDiagram-v2 Core --> RegularizedEvolution Core --> SearchUtils Core --> SingleIteration - Core --> SymbolicRegression + Core --> LaSR Dataset --> Core HallOfFame --> SearchUtils HallOfFame --> SingleIteration - HallOfFame --> SymbolicRegression + HallOfFame --> LaSR InterfaceDynamicExpressions --> LossFunctions - InterfaceDynamicExpressions --> SymbolicRegression + InterfaceDynamicExpressions --> LaSR LossFunctions --> ConstantOptimization LossFunctions --> HallOfFame LossFunctions --> Mutate LossFunctions --> PopMember LossFunctions --> Population - LossFunctions --> SymbolicRegression - Migration --> SymbolicRegression + LossFunctions --> LaSR + Migration --> LaSR Mutate --> RegularizedEvolution MutationFunctions --> Mutate MutationFunctions --> Population - MutationFunctions --> SymbolicRegression + MutationFunctions --> LaSR Operators --> Core Operators --> Options Options --> Core @@ -407,29 +235,29 @@ stateDiagram-v2 PopMember --> Population PopMember --> RegularizedEvolution PopMember --> SingleIteration - PopMember --> SymbolicRegression + PopMember --> LaSR Population --> Migration Population --> RegularizedEvolution Population --> SearchUtils Population --> SingleIteration - Population --> SymbolicRegression + Population --> LaSR ProgramConstants --> Core ProgramConstants --> Dataset ProgressBars --> SearchUtils - ProgressBars --> SymbolicRegression + ProgressBars --> LaSR Recorder --> Mutate Recorder --> RegularizedEvolution Recorder --> SingleIteration - Recorder --> SymbolicRegression + Recorder --> LaSR RegularizedEvolution --> SingleIteration - SearchUtils --> SymbolicRegression - SingleIteration --> SymbolicRegression + SearchUtils --> LaSR + SingleIteration --> LaSR Utils --> CheckConstraints Utils --> ConstantOptimization Utils --> Options Utils --> PopMember Utils --> SingleIteration - Utils --> SymbolicRegression + Utils --> LaSR ``` Bash command to generate dependency structure from `src` directory (requires `vim-stream`): @@ -446,4 +274,4 @@ done | vims -l 'f a--> ' | sort ## Search 
options -See https://astroautomata.com/SymbolicRegression.jl/stable/api/#Options +See https://astroautomata.com/LaSR.jl/stable/api/#Options diff --git a/benchmark/Project.toml b/benchmark/Project.toml index d6b434665..6e47cf260 100644 --- a/benchmark/Project.toml +++ b/benchmark/Project.toml @@ -3,5 +3,12 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b" LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[preferences.DynamicExpressions] +instability_check = "disable" + +[preferences.LibraryAugmentedSymbolicRegression] +instability_check = "disable" diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index 4f961302a..71c1404af 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -1,9 +1,9 @@ using BenchmarkTools -using SymbolicRegression, BenchmarkTools, Random -using SymbolicRegression.AdaptiveParsimonyModule: RunningSearchStatistics -using SymbolicRegression.PopulationModule: best_of_sample -using SymbolicRegression.ConstantOptimizationModule: optimize_constants -using SymbolicRegression.CheckConstraintsModule: check_constraints +using LibraryAugmentedSymbolicRegression, BenchmarkTools, Random +using LibraryAugmentedSymbolicRegression.AdaptiveParsimonyModule: RunningSearchStatistics +using LibraryAugmentedSymbolicRegression.PopulationModule: best_of_sample +using LibraryAugmentedSymbolicRegression.ConstantOptimizationModule: optimize_constants +using LibraryAugmentedSymbolicRegression.CheckConstraintsModule: check_constraints using Bumper, LoopVectorization function create_search_benchmark() diff --git a/benchmark/single_eval.jl b/benchmark/single_eval.jl index 36a0de116..d20db7310 100644 --- a/benchmark/single_eval.jl +++ b/benchmark/single_eval.jl @@ -1,5 +1,5 @@ using BenchmarkTools -using SymbolicRegression +using LibraryAugmentedSymbolicRegression nfeatures = 3 X = randn(nfeatures, 200) diff --git a/docs/make.jl b/docs/make.jl index f3ad5f756..ea1ad654c 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,13 +1,13 @@ using Documenter using SymbolicUtils -using SymbolicRegression -using SymbolicRegression: Dataset, update_baseline_loss! +using LibraryAugmentedSymbolicRegression +using LibraryAugmentedSymbolicRegression: Dataset, update_baseline_loss! 
DocMeta.setdocmeta!( - SymbolicRegression, :DocTestSetup, :(using LossFunctions); recursive=true + LibraryAugmentedSymbolicRegression, :DocTestSetup, :(using LossFunctions); recursive=true ) DocMeta.setdocmeta!( - SymbolicRegression, :DocTestSetup, :(using DynamicExpressions); recursive=true + LibraryAugmentedSymbolicRegression, :DocTestSetup, :(using DynamicExpressions); recursive=true ) readme = open(dirname(@__FILE__) * "/../README.md") do io @@ -20,17 +20,17 @@ readme = replace(readme, r"" => s"") # Then, we remove any line with " s"") -# We delete the https://github.com/MilesCranmer/SymbolicRegression.jl/assets/7593028/f5b68f1f-9830-497f-a197-6ae332c94ee0, +# We delete the https://github.com/MilesCranmer/LibraryAugmentedSymbolicRegression.jl/assets/7593028/f5b68f1f-9830-497f-a197-6ae332c94ee0, # and replace it with a video: readme = replace( readme, - r"https://github.com/MilesCranmer/SymbolicRegression.jl/assets/7593028/f5b68f1f-9830-497f-a197-6ae332c94ee0" => + r"https://github.com/MilesCranmer/LibraryAugmentedSymbolicRegression.jl/assets/7593028/f5b68f1f-9830-497f-a197-6ae332c94ee0" => ( """ ```@raw html
``` @@ -81,13 +81,13 @@ open(dirname(@__FILE__) * "/src/index.md", "w") do io end makedocs(; - sitename="SymbolicRegression.jl", + sitename="LibraryAugmentedSymbolicRegression.jl", authors="Miles Cranmer", doctest=true, strict=:doctest, clean=true, format=Documenter.HTML(; - canonical="https://astroautomata.com/SymbolicRegression.jl/stable" + canonical="https://astroautomata.com/LibraryAugmentedSymbolicRegression.jl/stable" ), pages=[ "Contents" => "index_base.md", @@ -129,4 +129,4 @@ open("docs/build/losses/index.html", "w") do io write(io, string(html)) end -deploydocs(; repo="github.com/MilesCranmer/SymbolicRegression.jl.git") +deploydocs(; repo="github.com/MilesCranmer/LibraryAugmentedSymbolicRegression.jl.git") diff --git a/example.jl b/example.jl index ef70096e5..1ee299f73 100644 --- a/example.jl +++ b/example.jl @@ -1,9 +1,9 @@ -using SymbolicRegression +using LibraryAugmentedSymbolicRegression X = randn(Float32, 5, 100) y = 2 * cos.(X[4, :]) + X[1, :] .^ 2 .- 2 -options = SymbolicRegression.Options(; +options = LibraryAugmentedSymbolicRegression.Options(; binary_operators=[+, *, /, -], unary_operators=[cos, exp], populations=20 ) diff --git a/examples/parameterized_function.jl b/examples/parameterized_function.jl new file mode 100644 index 000000000..e2f1d26ac --- /dev/null +++ b/examples/parameterized_function.jl @@ -0,0 +1,47 @@ +using LibraryAugmentedSymbolicRegression +using Random: MersenneTwister +using Zygote +using MLJBase: machine, fit!, predict, report +using Test + +rng = MersenneTwister(0) +X = NamedTuple{(:x1, :x2, :x3, :x4, :x5)}(ntuple(_ -> randn(rng, Float32, 30), Val(5))) +X = (; X..., classes=rand(rng, 1:2, 30)) +p1 = [0.0f0, 3.2f0] +p2 = [1.5f0, 0.5f0] + +y = [ + 2 * cos(X.x4[i] + p1[X.classes[i]]) + X.x1[i]^2 - p2[X.classes[i]] for + i in eachindex(X.classes) +] + +stop_at = Ref(1e-4) + +model = LaSRRegressor(; + niterations=100, + binary_operators=[+, *, /, -], + unary_operators=[cos, exp], + populations=30, + expression_type=ParametricExpression, + expression_options=(; max_parameters=2), + autodiff_backend=:Zygote, + parallelism=:multithreading, + early_stop_condition=(loss, _) -> loss < stop_at[], +) + +mach = machine(model, X, y) + +fit!(mach) +idx1 = lastindex(report(mach).equations) +ypred1 = predict(mach, (data=X, idx=idx1)) +loss1 = sum(i -> abs(ypred1[i] - y[i]), eachindex(y)) + +# Should keep all parameters +stop_at[] = 1e-5 +fit!(mach) +idx2 = lastindex(report(mach).equations) +ypred2 = predict(mach, (data=X, idx=idx2)) +loss2 = sum(i -> abs(ypred2[i] - y[i]), eachindex(y)) + +# Should get better: +@test loss1 >= loss2 diff --git a/ext/SymbolicRegressionEnzymeExt.jl b/ext/SymbolicRegressionEnzymeExt.jl new file mode 100644 index 000000000..1bba1a80f --- /dev/null +++ b/ext/SymbolicRegressionEnzymeExt.jl @@ -0,0 +1,61 @@ +module LaSREnzymeExt + +using LibraryAugmentedSymbolicRegression.LossFunctionsModule: eval_loss +using DynamicExpressions: + AbstractExpression, + AbstractExpressionNode, + get_scalar_constants, + set_scalar_constants!, + extract_gradient, + with_contents, + get_contents +using ADTypes: AutoEnzyme +using Enzyme: autodiff, Reverse, Active, Const, Duplicated + +import LibraryAugmentedSymbolicRegression.ConstantOptimizationModule: GradEvaluator + +# We prepare a copy of the tree and all arrays +function GradEvaluator(f::F, backend::AE) where {F,AE<:AutoEnzyme} + storage_tree = copy(f.tree) + _, storage_refs = get_scalar_constants(storage_tree) + storage_dataset = deepcopy(f.dataset) + # TODO: It is super inefficient to deepcopy; how can 
we skip this + return GradEvaluator(f, backend, (; storage_tree, storage_refs, storage_dataset)) +end + +function evaluator(tree, dataset, options, idx, output) + output[] = eval_loss(tree, dataset, options; regularization=false, idx=idx) + return nothing +end + +with_stacksize(f::F, n) where {F} = fetch(schedule(Task(f, n))) + +function (g::GradEvaluator{<:Any,<:AutoEnzyme})(_, G, x::AbstractVector{T}) where {T} + set_scalar_constants!(g.f.tree, x, g.f.refs) + set_scalar_constants!(g.extra.storage_tree, zero(x), g.extra.storage_refs) + fill!(g.extra.storage_dataset, 0) + + output = [zero(T)] + doutput = [one(T)] + + with_stacksize(32 * 1024 * 1024) do + autodiff( + Reverse, + evaluator, + Duplicated(g.f.tree, g.extra.storage_tree), + Duplicated(g.f.dataset, g.extra.storage_dataset), + Const(g.f.options), + Const(g.f.idx), + Duplicated(output, doutput), + ) + end + + if G !== nothing + # TODO: This is redundant since we already have the references. + # Should just be able to extract from the references directly. + G .= first(get_scalar_constants(g.extra.storage_tree)) + end + return output[] +end + +end diff --git a/ext/SymbolicRegressionJSON3Ext.jl b/ext/SymbolicRegressionJSON3Ext.jl index 1384f690f..b44ceb21e 100644 --- a/ext/SymbolicRegressionJSON3Ext.jl +++ b/ext/SymbolicRegressionJSON3Ext.jl @@ -1,7 +1,7 @@ -module SymbolicRegressionJSON3Ext +module LaSRJSON3Ext using JSON3: JSON3 -import SymbolicRegression.UtilsModule: json3_write +import LibraryAugmentedSymbolicRegression.UtilsModule: json3_write function json3_write(record, recorder_file) open(recorder_file, "w") do io diff --git a/ext/SymbolicRegressionSymbolicUtilsExt.jl b/ext/SymbolicRegressionSymbolicUtilsExt.jl index 2b0096ee0..5fbefeb9a 100644 --- a/ext/SymbolicRegressionSymbolicUtilsExt.jl +++ b/ext/SymbolicRegressionSymbolicUtilsExt.jl @@ -1,27 +1,32 @@ -module SymbolicRegressionSymbolicUtilsExt +module LaSRSymbolicUtilsExt using SymbolicUtils: Symbolic -using SymbolicRegression: AbstractExpressionNode, Node, Options -using SymbolicRegression.MLJInterfaceModule: AbstractSRRegressor, get_options +using LibraryAugmentedSymbolicRegression: AbstractExpressionNode, AbstractExpression, Node, Options +using LibraryAugmentedSymbolicRegression.MLJInterfaceModule: AbstractSRRegressor, get_options +using DynamicExpressions: get_tree, get_operators -import SymbolicRegression: node_to_symbolic, symbolic_to_node +import LibraryAugmentedSymbolicRegression: node_to_symbolic, symbolic_to_node """ node_to_symbolic(tree::AbstractExpressionNode, options::Options; kws...) Convert an expression to SymbolicUtils.jl form. """ -function node_to_symbolic(tree::AbstractExpressionNode, options::Options; kws...) - return node_to_symbolic(tree, options.operators; kws...) +function node_to_symbolic( + tree::Union{AbstractExpressionNode,AbstractExpression}, options::Options; kws... +) + return node_to_symbolic(get_tree(tree), get_operators(tree, options); kws...) end -function node_to_symbolic(tree::AbstractExpressionNode, m::AbstractSRRegressor; kws...) +function node_to_symbolic( + tree::Union{AbstractExpressionNode,AbstractExpression}, m::AbstractSRRegressor; kws... +) return node_to_symbolic(tree, get_options(m); kws...) end """ symbolic_to_node(eqn::Symbolic, options::Options; kws...) -Convert a SymbolicUtils.jl expression to SymbolicRegression.jl's `Node` type. +Convert a SymbolicUtils.jl expression to LibraryAugmentedSymbolicRegression.jl's `Node` type. """ function symbolic_to_node(eqn::Symbolic, options::Options; kws...) 
return symbolic_to_node(eqn, options.operators; kws...) @@ -31,24 +36,30 @@ function symbolic_to_node(eqn::Symbolic, m::AbstractSRRegressor; kws...) end function Base.convert( - ::Type{Symbolic}, tree::AbstractExpressionNode, options::Options; kws... + ::Type{Symbolic}, + tree::Union{AbstractExpressionNode,AbstractExpression}, + options::Union{Options,Nothing}=nothing; + kws..., ) - return convert(Symbolic, tree, options.operators; kws...) + return convert(Symbolic, get_tree(tree), get_operators(tree, options); kws...) end function Base.convert( - ::Type{Symbolic}, tree::AbstractExpressionNode, m::AbstractSRRegressor; kws... + ::Type{Symbolic}, + tree::Union{AbstractExpressionNode,AbstractExpression}, + m::AbstractSRRegressor; + kws..., ) return convert(Symbolic, tree, get_options(m); kws...) end function Base.convert( ::Type{N}, x::Union{Number,Symbolic}, options::Options; kws... -) where {N<:AbstractExpressionNode} +) where {N<:Union{AbstractExpressionNode,AbstractExpression}} return convert(N, x, options.operators; kws...) end function Base.convert( ::Type{N}, x::Union{Number,Symbolic}, m::AbstractSRRegressor; kws... -) where {N<:AbstractExpressionNode} +) where {N<:Union{AbstractExpressionNode,AbstractExpression}} return convert(N, x, get_options(m); kws...) end diff --git a/prompts/crossover_system.txt b/prompts/crossover_system.txt new file mode 100644 index 000000000..43f425aa7 --- /dev/null +++ b/prompts/crossover_system.txt @@ -0,0 +1,2 @@ +You are a helpful assistant that recombines two mathematical expressions by following a few provided suggestions. You will be given three suggestions and two reference expressions to recombine. +An expression must consist of the following variables: {{variables}}. All constants will be represented with the symbol C. Each expression will only use these operators: {{operators}}. \ No newline at end of file diff --git a/prompts/crossover_user.txt b/prompts/crossover_user.txt new file mode 100644 index 000000000..16cbf3806 --- /dev/null +++ b/prompts/crossover_user.txt @@ -0,0 +1,14 @@ +Suggestion 1: {{assump1}} +Suggestion 2: {{assump2}} +Suggestion 3: {{assump3}} +Reference Expression 1: {{expr1}} +Reference Expression 2: {{expr2}} + +Propose {{N}} expressions that would be appropriate given the suggestions and references. Provide short commentary for each of your decisions. End with a JSON list that enumerates the proposed expressions following this format: +```json +["expr1", + "expr2", + ... + "expr{{N}}" +] +``` \ No newline at end of file diff --git a/prompts/extract_idea_system.txt b/prompts/extract_idea_system.txt new file mode 100644 index 000000000..75745f7fa --- /dev/null +++ b/prompts/extract_idea_system.txt @@ -0,0 +1 @@ +You are a helpful assistant that hypothesizes about the underlying assumptions that generated a list of good and bad mathematical expressions in detailed ways. My ultimate goal is to discover what assumptions generated the observed good mathematical expressions and excludes the bad mathematical expressions. Focus more on the good expressions, their mathematical structure, and any relation to physical concepts. Note that capital C represents an arbitrary constant. 
\ No newline at end of file diff --git a/prompts/extract_idea_user.txt b/prompts/extract_idea_user.txt new file mode 100644 index 000000000..d2591a6dc --- /dev/null +++ b/prompts/extract_idea_user.txt @@ -0,0 +1,20 @@ +Good Expression 1: {{gexpr1}} +Good Expression 2: {{gexpr2}} +Good Expression 3: {{gexpr3}} +Good Expression 4: {{gexpr4}} +Good Expression 5: {{gexpr5}} + +Bad Expression 1: {{bexpr1}} +Bad Expression 2: {{bexpr2}} +Bad Expression 3: {{bexpr3}} +Bad Expression 4: {{bexpr4}} +Bad Expression 5: {{bexpr5}} + +Propose {{N}} hypotheses that would be appropriate given the expressions. Provide short commentary for each of your decisions. Do not talk about topics related to the simplicity or complexity of the expressions. I want ideas that are unique and interesting enough to amaze the world's best mathematicians. End with a JSON list that enumerates the proposed hypotheses following this format: +```json +["hyp1", + "hyp2", + ... + "hyp{{N}}" +] +``` \ No newline at end of file diff --git a/prompts/gen_random_system.txt b/prompts/gen_random_system.txt new file mode 100644 index 000000000..8a7f89aa7 --- /dev/null +++ b/prompts/gen_random_system.txt @@ -0,0 +1,2 @@ +You are a helpful assistant that proposes a mathematical expression by following three provided suggestions. +An expression must consist of the following variables: {{variables}}. All constants will be represented with the symbol C. Each expression will only use these operators: {{operators}}. \ No newline at end of file diff --git a/prompts/gen_random_user.txt b/prompts/gen_random_user.txt new file mode 100644 index 000000000..abd6540a5 --- /dev/null +++ b/prompts/gen_random_user.txt @@ -0,0 +1,12 @@ +Suggestion 1: {{assump1}} +Suggestion 2: {{assump2}} +Suggestion 3: {{assump3}} + +Propose {{N}} expressions that would be appropriate given the suggestions. Provide short commentary for each of your decisions. End with a JSON list that enumerates the proposed expressions following this format: +```json +["expr1", + "expr2", + ... + "expr{{N}}" +] +``` \ No newline at end of file diff --git a/prompts/mutate_system.txt b/prompts/mutate_system.txt new file mode 100644 index 000000000..1def8eef2 --- /dev/null +++ b/prompts/mutate_system.txt @@ -0,0 +1,2 @@ +You are a helpful assistant that mutates a mathematical expression by following a few provided suggestions. You will be given three suggestions and a single reference expression to mutate. +An expression must consist of the following variables: {{variables}}. All constants will be represented with the symbol C. Each expression will only use these operators: {{operators}}. \ No newline at end of file diff --git a/prompts/mutate_user.txt b/prompts/mutate_user.txt new file mode 100644 index 000000000..e2ae7c360 --- /dev/null +++ b/prompts/mutate_user.txt @@ -0,0 +1,13 @@ +Suggestion 1: {{assump1}} +Suggestion 2: {{assump2}} +Suggestion 3: {{assump3}} +Reference Expression: {{expr}} + +Propose {{N}} expressions that would be appropriate given the suggestions and references. Provide short commentary for each of your decisions. End with a JSON list that enumerates the proposed expressions following this format: +```json +["expr1", + "expr2", + ... 
+ "expr{{N}}"
+]
+```
\ No newline at end of file
diff --git a/prompts/new_extract_idea_0_system.txt b/prompts/new_extract_idea_0_system.txt
new file mode 100644
index 000000000..1217f202b
--- /dev/null
+++ b/prompts/new_extract_idea_0_system.txt
@@ -0,0 +1,29 @@
+You are a helpful assistant that identifies underlying functional relationships in a set of mathematical expressions. Your goal is to discover and describe the reductions, trends, and patterns that characterize the given good expressions, while distinguishing them from the bad expressions. The functional relationships you should consider include but are not limited to:
+- Power Law: [y = k * x^a]
+- Linear: [y = mx + b]
+- Exponential: [y = ke^(ax)]
+- Logarithmic: [y = k * log(x) + b]
+- Quadratic: [y = ax^2 + bx + c]
+- Polynomial: [y = a_n * x^n + a_(n-1) * x^(n-1) + ... + a_1 * x + a_0]
+- Inverse: [y = k / x]
+- Rational: [y = P(x) / Q(x)]
+- Trigonometric: [y = A * sin(Bx + C) + D]
+- Logistic: [y = L / (1 + e^(-k(x-x_0)))]
+- Hyperbolic: [y = k / x]
+- Piecewise Functions: [f(x) = {f_1(x) if x in A, f_2(x) if x in B, ..., f_n(x) if x in Z}]
+- Implicit Functions: [F(x, y) = 0]
+- Parametric Equations: [x = f(t), y = g(t)]
+- Complex Functions: [f(z) = u(x, y) + iv(x, y)]
+- Elliptic Functions: [y^2 = x^3 + ax + b]
+- Integral Equations: [∫_a^b f(x) dx = g(x)]
+- Differential Equations: [d^n y/dx^n + a_(n-1) d^(n-1) y/dx^(n-1) + ... + a_1 dy/dx + a_0 y = h(x)]
+- Generating Functions: [G(x) = ∑_(n=0)^∞ a_n x^n]
+- Special Functions: [Bessel Functions, Gamma Function, Legendre Polynomials]
+- Symmetric Functions: [f(x_1, x_2, ..., x_n) = f(x_σ(1), x_σ(2), ..., x_σ(n)) for any permutation σ]
+- Alternating Functions: [f(x_1, x_2, ..., x_n) = sgn(σ) f(x_σ(1), x_σ(2), ..., x_σ(n))]
+- Fractional Power Series: [f(x) = ∑_(n=0)^∞ a_n x^(n/m)]
+- Modular Forms: [f(z) = ∑_(n=0)^∞ a_n e^(2πi n z)]
+- Lagrange Multipliers: [∇f = λ∇g]
+Your task is to:
+1. Identify and describe the functional relationships in the good expressions.
+2. Highlight any distinguishing features that separate the good expressions from the bad expressions.
\ No newline at end of file
diff --git a/prompts/new_extract_idea_1_user.txt b/prompts/new_extract_idea_1_user.txt
new file mode 100644
index 000000000..77cfd13df
--- /dev/null
+++ b/prompts/new_extract_idea_1_user.txt
@@ -0,0 +1,19 @@
+Good Expression 1: x0 * x2^(3/2)
+Good Expression 2: x0 * (x8 * x10)^(-1)
+Good Expression 3: x2^2 * (x0 / x3)
+Good Expression 4: (x2 * x0) * x8^(1/2)
+Good Expression 5: x2 * (x0 * x10^3)
+Bad Expression 1: x2 * (x0 - (log(x8 * x10) / x3))
+Bad Expression 2: x2 * (log(x3) + (x0 - sqrt(x8)))
+Bad Expression 3: (x0 - (sqrt(x8) / x3)) * x2
+Bad Expression 4: x2 * (x0 - sin(log(x8)))
+Bad Expression 5: x2 * (x0 - log(x8))
+
+Comment on the patterns and trends of each equation and end with a JSON list that enumerates the ideas following this format:
+```json
+["Idea1",
+ "Idea2",
+ ...
+ "IdeaN"
+]
+```
\ No newline at end of file
diff --git a/prompts/new_extract_idea_2_system.txt b/prompts/new_extract_idea_2_system.txt
new file mode 100644
index 000000000..cd7a62de8
--- /dev/null
+++ b/prompts/new_extract_idea_2_system.txt
@@ -0,0 +1,53 @@
+### Analysis of Good Expressions
+
+#### Good Expression 1: \( x0 \cdot x2^{3/2} \)
+- **Functional Relationship**: This is a product of a polynomial term \( x0 \) and a power law term \( x2^{3/2} \).
+- **Pattern**: No subtraction, logarithms, or transcendental functions are involved. It's a straightforward multiplicative combination of variables with integer or fractional exponents.
Itโ€™s a straightforward multiplicative combination of variables with integer or fractional exponents. + +#### Good Expression 2: \( x0 \cdot (x8 \cdot x10)^{-1} \) +- **Functional Relationship**: This can be rewritten as \( x0 / (x8 \cdot x10) \), which is a rational function and also represents multiplicative inverse relationships. +- **Pattern**: It avoids addition or subtraction within the expression. It maintains a purely multiplicative or divisional structure. + +#### Good Expression 3: \( x2^2 \cdot (x0 / x3) \) +- **Functional Relationship**: This is a product of a polynomial term \( x2^2 \) and a rational function \( x0 / x3 \). +- **Pattern**: Similar to other good expressions, it maintains the pattern of multiplicative combinations without any additive, subtractive, logarithmic, or transcendental functions. + +#### Good Expression 4: \( (x2 \cdot x0) \cdot x8^{1/2} \) +- **Functional Relationship**: This is a product of polynomial terms \( x2 \cdot x0 \) and a power law term \( x8^{1/2} \). +- **Pattern**: Continues the trend of multiplicative relationships, only involving simple powers or polynomials. + +#### Good Expression 5: \( x2 \cdot (x0 \cdot x10^3) \) +- **Functional Relationship**: This is a product of polynomial terms \( x2 \) and \( x0 \), and a term in the form \( x10^3 \). +- **Pattern**: The expression maintains a purely multiplicative structure with integer powers without introducing any other operations. + +### Analysis of Bad Expressions + +#### Bad Expression 1: \( x2 \cdot (x0 - (\log(x8 \cdot x10) / x3)) \) +- **Functional Relationship**: Involves a logarithmic function nested inside a rational function, and then combined with subtraction. +- **Pattern**: Use of subtraction and logarithms, adding complexity and non-linearity that breaks the patterns seen in good expressions. + +#### Bad Expression 2: \( x2 \cdot (\log(x3) + (x0 - \sqrt{x8})) \) +- **Functional Relationship**: Combination of logarithmic and square root functions alongside addition and subtraction. +- **Pattern**: Introduction of logarithms and square roots, and the use of addition/subtraction makes it distinct from good expressions. + +#### Bad Expression 3: \( (x0 - (\sqrt{x8} / x3)) \cdot x2 \) +- **Functional Relationship**: Contains a square root function combined with divisions and subtraction. +- **Pattern**: Presence of non-multiplicative operations like subtraction and square roots marks it as different. + +#### Bad Expression 4: \( x2 \cdot (x0 - \sin(\log(x8))) \) +- **Functional Relationship**: Combination of a logarithmic function inside a sine function, followed by subtraction. +- **Pattern**: Use of transcendental functions like sine and logarithm combined with subtraction is a key differentiator from good expressions. + +#### Bad Expression 5: \( x2 \cdot (x0 - \log(x8)) \) +- **Functional Relationship**: Involves subtraction of a logarithmic function. +- **Pattern**: Introducing subtraction and logs shows a clear departure from the purely multiplicative nature of the good expressions. + +### JSON List of Identified Ideas +```json +[ + "Good expressions are characterized by purely multiplicative forms without addition, subtraction, or transcendental functions.", + "Good expressions include combinations of polynomial, power law, or rational forms.", + "Bad expressions introduce complexity with logarithms, square roots, or trigonometric functions.", + "Bad expressions often involve addition or subtraction, breaking the purely multiplicative pattern." 
+]
+```
\ No newline at end of file
diff --git a/prompts/new_extract_idea_3_user.txt b/prompts/new_extract_idea_3_user.txt
new file mode 100644
index 000000000..a298d279b
--- /dev/null
+++ b/prompts/new_extract_idea_3_user.txt
@@ -0,0 +1,21 @@
+Input:
+Good Expression 1: {{gexpr1}}
+Good Expression 2: {{gexpr2}}
+Good Expression 3: {{gexpr3}}
+Good Expression 4: {{gexpr4}}
+Good Expression 5: {{gexpr5}}
+
+Bad Expression 1: {{bexpr1}}
+Bad Expression 2: {{bexpr2}}
+Bad Expression 3: {{bexpr3}}
+Bad Expression 4: {{bexpr4}}
+Bad Expression 5: {{bexpr5}}
+
+Comment on the patterns and trends of each equation and end with a JSON list that enumerates the ideas following this format:
+```json
+["Idea1",
+ "Idea2",
+ ...
+ "IdeaN"
+]
+```
\ No newline at end of file
diff --git a/prompts/prompt_evol_system.txt b/prompts/prompt_evol_system.txt
new file mode 100644
index 000000000..b43f0d4ab
--- /dev/null
+++ b/prompts/prompt_evol_system.txt
@@ -0,0 +1 @@
+You are a helpful assistant that merges and refines ideas about a set of hidden mathematical expressions in new, interesting, and diverse ways. My ultimate goal is to discover the underlying properties of these hidden expressions. The resulting ideas should be nontrivial conclusions given the previous ideas.
\ No newline at end of file
diff --git a/prompts/prompt_evol_user.txt b/prompts/prompt_evol_user.txt
new file mode 100644
index 000000000..48821163b
--- /dev/null
+++ b/prompts/prompt_evol_user.txt
@@ -0,0 +1,14 @@
+Idea 1: {{idea1}}
+Idea 2: {{idea2}}
+Idea 3: {{idea3}}
+Idea 4: {{idea4}}
+Idea 5: {{idea5}}
+
+Propose {{N}} hypotheses that would be appropriate given the ideas. Provide short commentary for each of your decisions. Do not talk about topics related to the simplicity or complexity of the expressions. I want ideas that are unique and interesting enough to amaze the world's best mathematicians. End with a JSON list that enumerates the proposed hypotheses following this format:
+```json
+["hyp1",
+ "hyp2",
+ ...
+ "hyp{{N}}"
+]
+```
\ No newline at end of file
diff --git a/src/CheckConstraints.jl b/src/CheckConstraints.jl
index 6b60ff840..7f6093631 100644
--- a/src/CheckConstraints.jl
+++ b/src/CheckConstraints.jl
@@ -1,7 +1,7 @@
 module CheckConstraintsModule
 
-using DynamicExpressions: AbstractExpressionNode, count_depth, tree_mapreduce
-using ..UtilsModule: vals
+using DynamicExpressions:
+    AbstractExpressionNode, AbstractExpression, get_tree, count_depth, tree_mapreduce
 using ..CoreModule: Options
 using ..ComplexityModule: compute_complexity, past_complexity_limit
 
@@ -70,6 +70,15 @@ function flag_illegal_nests(tree::AbstractExpressionNode, options::Options)::Boo
 end
 
 """Check if user-passed constraints are violated or not"""
+function check_constraints(
+    ex::AbstractExpression,
+    options::Options,
+    maxsize::Int,
+    cursize::Union{Int,Nothing}=nothing,
+)::Bool
+    tree = get_tree(ex)
+    return check_constraints(tree, options, maxsize, cursize)
+end
 function check_constraints(
     tree::AbstractExpressionNode,
     options::Options,
@@ -93,7 +102,8 @@ function check_constraints(
     return true
 end
 
-check_constraints(tree::AbstractExpressionNode, options::Options)::Bool =
-    check_constraints(tree, options, options.maxsize)
+check_constraints(
+    ex::Union{AbstractExpression,AbstractExpressionNode}, options::Options
+)::Bool = check_constraints(ex, options, options.maxsize)
 
 end
diff --git a/src/Complexity.jl b/src/Complexity.jl
index f3b0bba91..dccb05bd3 100644
--- a/src/Complexity.jl
+++ b/src/Complexity.jl
@@ -1,9 +1,12 @@
 module ComplexityModule
 
-using DynamicExpressions: AbstractExpressionNode, count_nodes, tree_mapreduce
+using DynamicExpressions:
+    AbstractExpression, AbstractExpressionNode, get_tree, count_nodes, tree_mapreduce
 using ..CoreModule: Options, ComplexityMapping
 
-function past_complexity_limit(tree::AbstractExpressionNode, options::Options, limit)::Bool
+function past_complexity_limit(
+    tree::Union{AbstractExpression,AbstractExpressionNode}, options::Options, limit
+)::Bool
     return compute_complexity(tree, options) > limit
 end
 
@@ -14,6 +17,11 @@ By default, this is the number of nodes in a tree.
 However, it could use the custom settings in options.complexity_mapping
 if these are defined.
 """
+function compute_complexity(
+    tree::AbstractExpression, options::Options; break_sharing=Val(false)
+)
+    return compute_complexity(get_tree(tree), options; break_sharing)
+end
 function compute_complexity(
     tree::AbstractExpressionNode, options::Options; break_sharing=Val(false)
 )::Int
diff --git a/src/Configure.jl b/src/Configure.jl
index 5ccc08100..a256a1ee0 100644
--- a/src/Configure.jl
+++ b/src/Configure.jl
@@ -64,7 +64,7 @@ function test_option_configuration(
         if is_anonymous_function(op)
             throw(
                 AssertionError(
-                    "Anonymous functions can't be used as operators for SymbolicRegression.jl",
+                    "Anonymous functions can't be used as operators for LibraryAugmentedSymbolicRegression.jl",
                 ),
             )
         end
@@ -206,15 +206,15 @@ end
 function import_module_on_workers(procs, filename::String, options::Options, verbosity)
     loaded_modules_head_worker = [k.name for (k, _) in Base.loaded_modules]
-    included_as_local = "SymbolicRegression" ∉ loaded_modules_head_worker
+    included_as_local = "LibraryAugmentedSymbolicRegression" ∉ loaded_modules_head_worker
     expr = if included_as_local
         quote
             include($filename)
-            using .SymbolicRegression
+            using .LibraryAugmentedSymbolicRegression
         end
     else
         quote
-            using SymbolicRegression
+            using LibraryAugmentedSymbolicRegression
         end
     end
@@ -242,9 +242,9 @@ function import_module_on_workers(procs, filename::String, options::Options, ver
     end
     verbosity > 0 && if isempty(relevant_extensions)
-        @info "Importing SymbolicRegression on workers."
+        @info "Importing LibraryAugmentedSymbolicRegression on workers."
     else
-        @info "Importing SymbolicRegression on workers as well as extensions $(join(relevant_extensions, ',' * ' '))."
+        @info "Importing LibraryAugmentedSymbolicRegression on workers as well as extensions $(join(relevant_extensions, ',' * ' '))."
     end
     @everywhere procs Core.eval(Core.Main, $expr)
     verbosity > 0 && @info "Finished!"
@@ -257,7 +257,7 @@ function test_module_on_workers(procs, options::Options, verbosity) for proc in procs push!( futures, - @spawnat proc SymbolicRegression.gen_random_tree(3, options, 5, TEST_TYPE) + @spawnat proc LibraryAugmentedSymbolicRegression.gen_random_tree(3, options, 5, TEST_TYPE) ) end for future in futures diff --git a/src/ConstantOptimization.jl b/src/ConstantOptimization.jl index 7dc8471a3..fe66b4f5d 100644 --- a/src/ConstantOptimization.jl +++ b/src/ConstantOptimization.jl @@ -2,8 +2,16 @@ module ConstantOptimizationModule using LineSearches: LineSearches using Optim: Optim -using DynamicExpressions: Node, count_constants -using ..CoreModule: Options, Dataset, DATA_TYPE, LOSS_TYPE +using ADTypes: AbstractADType, AutoEnzyme +using DifferentiationInterface: value_and_gradient +using DynamicExpressions: + AbstractExpression, + Expression, + count_scalar_constants, + get_scalar_constants, + set_scalar_constants!, + extract_gradient +using ..CoreModule: Options, Dataset, DATA_TYPE, LOSS_TYPE, specialized_options using ..UtilsModule: get_birth_order using ..LossFunctionsModule: eval_loss, loss_to_score, batch_sample using ..PopMemberModule: PopMember @@ -22,45 +30,58 @@ end function dispatch_optimize_constants( dataset::Dataset{T,L}, member::P, options::Options, idx ) where {T<:DATA_TYPE,L<:LOSS_TYPE,P<:PopMember{T,L}} - nconst = count_constants(member.tree) + nconst = count_constants_for_optimization(member.tree) nconst == 0 && return (member, 0.0) if nconst == 1 && !(T <: Complex) algorithm = Optim.Newton(; linesearch=LineSearches.BackTracking()) return _optimize_constants( - dataset, member, options, algorithm, options.optimizer_options, idx + dataset, + member, + specialized_options(options), + algorithm, + options.optimizer_options, + idx, ) end return _optimize_constants( dataset, member, - options, + specialized_options(options), + # We use specialized options here due to Enzyme being + # more particular about dynamic dispatch options.optimizer_algorithm, options.optimizer_options, idx, ) end +"""How many constants will be optimized.""" +count_constants_for_optimization(ex::Expression) = count_scalar_constants(ex) + function _optimize_constants( dataset, member::P, options, algorithm, optimizer_options, idx )::Tuple{P,Float64} where {T,L,P<:PopMember{T,L}} tree = member.tree eval_fraction = options.batching ? (options.batch_size / dataset.n) : 1.0 - f(t) = eval_loss(t, dataset, options; regularization=false, idx=idx)::L - baseline = f(tree) - result = Optim.optimize(f, tree, algorithm, optimizer_options) + x0, refs = get_scalar_constants(tree) + @assert count_constants_for_optimization(tree) == length(x0) + f = Evaluator(tree, refs, dataset, options, idx) + fg! = GradEvaluator(f, options.autodiff_backend) + obj = if algorithm isa Optim.Newton || options.autodiff_backend === nothing + f + else + Optim.only_fg!(fg!) + end + baseline = f(x0) + result = Optim.optimize(obj, x0, algorithm, optimizer_options) num_evals = result.f_calls * eval_fraction # Try other initial conditions: for _ in 1:(options.optimizer_nrestarts) - tmptree = copy(tree) - foreach(tmptree) do node - if node.degree == 0 && node.constant - node.val = (node.val) * (T(1) + T(1//2) * randn(T)) - end - end - tmpresult = Optim.optimize( - f, tmptree, algorithm, optimizer_options; make_copy=false - ) + eps = randn(T, size(x0)...) + xt = @. 
x0 * (T(1) + T(1//2) * eps) + tmpresult = Optim.optimize(obj, xt, algorithm, optimizer_options) num_evals += tmpresult.f_calls * eval_fraction + # TODO: Does this need to take into account h_calls? if tmpresult.minimum < result.minimum result = tmpresult @@ -68,16 +89,48 @@ function _optimize_constants( end if result.minimum < baseline - member.tree = result.minimizer - member.loss = eval_loss(member.tree, dataset, options; regularization=true, idx=idx) + member.tree = tree + member.loss = f(result.minimizer; regularization=true) member.score = loss_to_score( member.loss, dataset.use_baseline, dataset.baseline_loss, member, options ) member.birth = get_birth_order(; deterministic=options.deterministic) num_evals += eval_fraction + else + set_scalar_constants!(member.tree, x0, refs) end return member, num_evals end +struct Evaluator{N<:AbstractExpression,R,D<:Dataset,O<:Options,I} <: Function + tree::N + refs::R + dataset::D + options::O + idx::I +end +function (e::Evaluator)(x::AbstractVector; regularization=false) + set_scalar_constants!(e.tree, x, e.refs) + return eval_loss(e.tree, e.dataset, e.options; regularization, e.idx) +end +struct GradEvaluator{F<:Evaluator,AD<:Union{Nothing,AbstractADType},EX} <: Function + f::F + backend::AD + extra::EX +end +GradEvaluator(f::F, backend::AD) where {F,AD} = GradEvaluator(f, backend, nothing) + +function (g::GradEvaluator{<:Any,AD})(_, G, x::AbstractVector) where {AD} + AD isa AutoEnzyme && error("Please load the `Enzyme.jl` package.") + set_scalar_constants!(g.f.tree, x, g.f.refs) + (val, grad) = value_and_gradient(g.backend, g.f.tree) do tree + eval_loss(tree, g.f.dataset, g.f.options; regularization=false, idx=g.f.idx) + end + if G !== nothing && grad !== nothing + G .= extract_gradient(grad, g.f.tree) + end + return val +end + end diff --git a/src/Core.jl b/src/Core.jl index 7a04b975f..63860b0f7 100644 --- a/src/Core.jl +++ b/src/Core.jl @@ -1,9 +1,12 @@ module CoreModule +function create_expression end + include("Utils.jl") include("ProgramConstants.jl") include("Dataset.jl") include("MutationWeights.jl") +include("LLMOptions.jl") include("OptionsStruct.jl") include("Operators.jl") include("Options.jl") @@ -12,8 +15,9 @@ using .ProgramConstantsModule: MAX_DEGREE, BATCH_DIM, FEATURE_DIM, RecordType, DATA_TYPE, LOSS_TYPE using .DatasetModule: Dataset using .MutationWeightsModule: MutationWeights, sample_mutation -using .OptionsStructModule: Options, ComplexityMapping -using .OptionsModule: Options +using .LLMOptionsModule: LLMOptions, LLMWeights +using .OptionsStructModule: Options, ComplexityMapping, specialized_options +using .OptionsModule: Options, binopmap, unaopmap using .OperatorsModule: plus, sub, diff --git a/src/Dataset.jl b/src/Dataset.jl index b2a7b42d8..99c31ee3d 100644 --- a/src/Dataset.jl +++ b/src/Dataset.jl @@ -1,13 +1,6 @@ module DatasetModule -using DynamicQuantities: - AbstractDimensions, - Dimensions, - SymbolicDimensions, - Quantity, - uparse, - sym_uparse, - DEFAULT_DIM_BASE_TYPE +using DynamicQuantities: Quantity using ..UtilsModule: subscriptify, get_base_type, @constfield using ..ProgramConstantsModule: BATCH_DIM, FEATURE_DIM, DATA_TYPE, LOSS_TYPE @@ -22,6 +15,8 @@ import ...deprecate_varmap - `X::AbstractMatrix{T}`: The input features, with shape `(nfeatures, n)`. - `y::AbstractVector{T}`: The desired output values, with shape `(n,)`. +- `index::Int`: The index of the output feature corresponding to this + dataset, if any. - `n::Int`: The number of samples. - `nfeatures::Int`: The number of features. 
- `weighted::Bool`: Whether the dataset is non-uniformly weighted. @@ -64,6 +59,7 @@ mutable struct Dataset{ } @constfield X::AX @constfield y::AY + @constfield index::Int @constfield n::Int @constfield nfeatures::Int @constfield weighted::Bool @@ -99,6 +95,7 @@ function Dataset( X::AbstractMatrix{T}, y::Union{AbstractVector{T},Nothing}=nothing, loss_type::Type{L}=Nothing; + index::Int=1, weights::Union{AbstractVector{T},Nothing}=nothing, variable_names::Union{Array{String,1},Nothing}=nothing, display_variable_names=variable_names, @@ -123,6 +120,7 @@ function Dataset( X, y, kws[:loss_type]; + index, weights, variable_names, display_variable_names, @@ -206,6 +204,7 @@ function Dataset( }( X, y, + index, n, nfeatures, weighted, @@ -260,4 +259,17 @@ function has_units(dataset::Dataset) return dataset.X_units !== nothing || dataset.y_units !== nothing end +# Used for Enzyme +function Base.fill!(d::Dataset, val) + _fill!(d.X, val) + _fill!(d.y, val) + _fill!(d.weights, val) + _fill!(d.extra, val) + return d +end +_fill!(x::AbstractArray, val) = fill!(x, val) +_fill!(x::NamedTuple, val) = foreach(v -> _fill!(v, val), values(x)) +_fill!(::Nothing, val) = nothing +_fill!(x, val) = x + end diff --git a/src/DimensionalAnalysis.jl b/src/DimensionalAnalysis.jl index d0e975fbf..cc9440db1 100644 --- a/src/DimensionalAnalysis.jl +++ b/src/DimensionalAnalysis.jl @@ -1,7 +1,7 @@ module DimensionalAnalysisModule -using DynamicExpressions: AbstractExpressionNode -using DynamicQuantities: Quantity, DimensionError, AbstractQuantity, uparse, constructorof +using DynamicExpressions: AbstractExpression, AbstractExpressionNode, get_tree +using DynamicQuantities: Quantity, DimensionError, AbstractQuantity, constructorof using ..CoreModule: Options, Dataset using ..UtilsModule: safe_call @@ -192,6 +192,11 @@ function violates_dimensional_constraints( tree, dataset.X_units, dataset.y_units, (@view X[:, 1]), options ) end +function violates_dimensional_constraints( + tree::AbstractExpression, dataset::Dataset, options::Options +) + return violates_dimensional_constraints(get_tree(tree), dataset, options) +end function violates_dimensional_constraints( tree::AbstractExpressionNode{T}, X_units::AbstractVector{<:Quantity}, diff --git a/src/ExpressionBuilder.jl b/src/ExpressionBuilder.jl new file mode 100644 index 000000000..bcad22575 --- /dev/null +++ b/src/ExpressionBuilder.jl @@ -0,0 +1,290 @@ +module ExpressionBuilderModule + +using DispatchDoctor: @unstable +using DynamicExpressions: + AbstractExpressionNode, + AbstractExpression, + Expression, + ParametricExpression, + ParametricNode, + constructorof, + get_tree, + get_contents, + get_metadata, + with_contents, + with_metadata, + count_scalar_constants, + eval_tree_array +using Random: default_rng, AbstractRNG +using StatsBase: StatsBase +using ..CoreModule: Options, Dataset, DATA_TYPE +using ..HallOfFameModule: HallOfFame +using ..LossFunctionsModule: maybe_getindex +using ..InterfaceDynamicExpressionsModule: expected_array_type +using ..PopulationModule: Population +using ..PopMemberModule: PopMember + +import DynamicExpressions: get_operators +import ..CoreModule: create_expression +import ..MutationFunctionsModule: + make_random_leaf, crossover_trees, mutate_constant, mutate_factor +import ..LossFunctionsModule: eval_tree_dispatch +import ..ConstantOptimizationModule: count_constants_for_optimization + +@unstable function create_expression( + t::T, options::Options, dataset::Dataset{T,L}, ::Val{embed}=Val(false) +) where {T,L,embed} + return create_expression( 
+ constructorof(options.node_type)(; val=t), options, dataset, Val(embed) + ) +end +@unstable function create_expression( + t::AbstractExpressionNode{T}, + options::Options, + dataset::Dataset{T,L}, + ::Val{embed}=Val(false), +) where {T,L,embed} + return constructorof(options.expression_type)( + t; init_params(options, dataset, nothing, Val(embed))... + ) +end +function create_expression( + ex::AbstractExpression{T}, ::Options, ::Dataset{T,L}, ::Val{embed}=Val(false) +) where {T,L,embed} + return ex +end +@unstable function init_params( + options::Options, + dataset::Dataset{T,L}, + prototype::Union{Nothing,AbstractExpression}, + ::Val{embed}, +) where {T,L,embed} + consistency_checks(options, prototype) + return (; + operators=embed ? options.operators : nothing, + variable_names=embed ? dataset.variable_names : nothing, + extra_init_params( + options.expression_type, prototype, options, dataset, Val(embed) + )..., + ) +end +function extra_init_params( + ::Type{E}, + prototype::Union{Nothing,AbstractExpression}, + options::Options, + dataset::Dataset{T}, + ::Val{embed}, +) where {T,embed,E<:AbstractExpression} + return (;) +end +function extra_init_params( + ::Type{E}, + prototype::Union{Nothing,ParametricExpression}, + options::Options, + dataset::Dataset{T}, + ::Val{embed}, +) where {T,embed,E<:ParametricExpression} + num_params = options.expression_options.max_parameters + num_classes = length(unique(dataset.extra.classes)) + parameter_names = embed ? ["p$i" for i in 1:num_params] : nothing + _parameters = if prototype === nothing + randn(T, (num_params, num_classes)) + else + copy(get_metadata(prototype).parameters) + end + return (; parameters=_parameters, parameter_names) +end + +consistency_checks(::Options, prototype::Nothing) = nothing +function consistency_checks(options::Options, prototype) + if prototype === nothing + return nothing + end + @assert( + prototype isa options.expression_type, + "Need prototype to be of type $(options.expression_type), but got $(prototype)::$(typeof(prototype))" + ) + if prototype isa ParametricExpression + if prototype.metadata.parameter_names !== nothing + @assert( + length(prototype.metadata.parameter_names) == + options.expression_options.max_parameters, + "Mismatch between options.expression_options.max_parameters=$(options.expression_options.max_parameters) and prototype.metadata.parameter_names=$(prototype.metadata.parameter_names)" + ) + end + @assert size(prototype.metadata.parameters, 1) == + options.expression_options.max_parameters + end + return nothing +end + +@unstable begin + function embed_metadata( + ex::AbstractExpression, options::Options, dataset::Dataset{T,L} + ) where {T,L} + return with_metadata(ex; init_params(options, dataset, ex, Val(true))...) 
+ end + function embed_metadata( + member::PopMember, options::Options, dataset::Dataset{T,L} + ) where {T,L} + return PopMember( + embed_metadata(member.tree, options, dataset), + member.score, + member.loss, + nothing; + member.ref, + member.parent, + deterministic=options.deterministic, + ) + end + function embed_metadata( + pop::Population, options::Options, dataset::Dataset{T,L} + ) where {T,L} + return Population( + map(member -> embed_metadata(member, options, dataset), pop.members) + ) + end + function embed_metadata( + hof::HallOfFame, options::Options, dataset::Dataset{T,L} + ) where {T,L} + return HallOfFame( + map(member -> embed_metadata(member, options, dataset), hof.members), hof.exists + ) + end + function embed_metadata( + vec::Vector{H}, options::Options, dataset::Dataset{T,L} + ) where {T,L,H<:Union{HallOfFame,Population,PopMember}} + return map(elem -> embed_metadata(elem, options, dataset), vec) + end +end + +"""Strips all metadata except for top-level information""" +function strip_metadata(ex::Expression, options::Options, dataset::Dataset{T,L}) where {T,L} + return with_metadata(ex; init_params(options, dataset, ex, Val(false))...) +end +function strip_metadata( + ex::ParametricExpression, options::Options, dataset::Dataset{T,L} +) where {T,L} + return with_metadata(ex; init_params(options, dataset, ex, Val(false))...) +end +function strip_metadata( + member::PopMember, options::Options, dataset::Dataset{T,L} +) where {T,L} + return PopMember( + strip_metadata(member.tree, options, dataset), + member.score, + member.loss, + nothing; + member.ref, + member.parent, + deterministic=options.deterministic, + ) +end +function strip_metadata( + pop::Population, options::Options, dataset::Dataset{T,L} +) where {T,L} + return Population(map(member -> strip_metadata(member, options, dataset), pop.members)) +end +function strip_metadata( + hof::HallOfFame, options::Options, dataset::Dataset{T,L} +) where {T,L} + return HallOfFame( + map(member -> strip_metadata(member, options, dataset), hof.members), hof.exists + ) +end + +function eval_tree_dispatch( + tree::ParametricExpression{T}, dataset::Dataset{T}, options::Options, idx +) where {T<:DATA_TYPE} + A = expected_array_type(dataset.X) + return eval_tree_array( + tree, + maybe_getindex(dataset.X, :, idx), + maybe_getindex(dataset.extra.classes, idx), + options.operators, + )::Tuple{A,Bool} +end + +function make_random_leaf( + nfeatures::Int, + ::Type{T}, + ::Type{N}, + rng::AbstractRNG=default_rng(), + options::Union{Options,Nothing}=nothing, +) where {T<:DATA_TYPE,N<:ParametricNode} + choice = rand(rng, 1:3) + if choice == 1 + return ParametricNode(; val=randn(rng, T)) + elseif choice == 2 + return ParametricNode(T; feature=rand(rng, 1:nfeatures)) + else + tree = ParametricNode{T}() + tree.val = zero(T) + tree.degree = 0 + tree.feature = 0 + tree.constant = false + tree.is_parameter = true + tree.parameter = rand( + rng, UInt16(1):UInt16(options.expression_options.max_parameters) + ) + return tree + end +end + +function crossover_trees( + ex1::ParametricExpression{T}, ex2::AbstractExpression{T}, rng::AbstractRNG=default_rng() +) where {T} + tree1 = get_contents(ex1) + tree2 = get_contents(ex2) + out1, out2 = crossover_trees(tree1, tree2, rng) + ex1 = with_contents(ex1, out1) + ex2 = with_contents(ex2, out2) + + # We also randomly share parameters + nparams1 = size(ex1.metadata.parameters, 1) + nparams2 = size(ex2.metadata.parameters, 1) + num_params_switch = min(nparams1, nparams2) + idx_to_switch = StatsBase.sample( + rng, 
1:num_params_switch, num_params_switch; replace=false + ) + for param_idx in idx_to_switch + ex2_params = ex2.metadata.parameters[param_idx, :] + ex2.metadata.parameters[param_idx, :] .= ex1.metadata.parameters[param_idx, :] + ex1.metadata.parameters[param_idx, :] .= ex2_params + end + + return ex1, ex2 +end + +function count_constants_for_optimization(ex::ParametricExpression) + return count_scalar_constants(get_tree(ex)) + length(ex.metadata.parameters) +end + +function mutate_constant( + ex::ParametricExpression{T}, + temperature, + options::Options, + rng::AbstractRNG=default_rng(), +) where {T<:DATA_TYPE} + if rand(rng, Bool) + # Normal mutation of inner constant + tree = get_contents(ex) + return with_contents(ex, mutate_constant(tree, temperature, options, rng)) + else + # Mutate parameters + parameter_index = rand(rng, 1:(options.expression_options.max_parameters)) + # We mutate all the parameters at once + factor = mutate_factor(T, temperature, options, rng) + ex.metadata.parameters[parameter_index, :] .*= factor + return ex + end +end + +@unstable function get_operators(ex::AbstractExpression, options::Options) + return get_operators(ex, options.operators) +end +@unstable function get_operators(ex::AbstractExpressionNode, options::Options) + return get_operators(ex, options.operators) +end + +end diff --git a/src/HallOfFame.jl b/src/HallOfFame.jl index 8224fe480..19c52f933 100644 --- a/src/HallOfFame.jl +++ b/src/HallOfFame.jl @@ -1,12 +1,11 @@ module HallOfFameModule -using DynamicExpressions: AbstractExpressionNode, Node, constructorof, string_tree -using DynamicExpressions.EquationModule: with_type_parameters +using DynamicExpressions: AbstractExpression, string_tree using ..UtilsModule: split_string -using ..CoreModule: MAX_DEGREE, Options, Dataset, DATA_TYPE, LOSS_TYPE, relu +using ..CoreModule: + MAX_DEGREE, Options, Dataset, DATA_TYPE, LOSS_TYPE, relu, create_expression using ..ComplexityModule: compute_complexity using ..PopMemberModule: PopMember -using ..LossFunctionsModule: eval_loss using ..InterfaceDynamicExpressionsModule: format_dimensions, WILDCARD_UNIT_STRING using Printf: @sprintf @@ -23,13 +22,33 @@ have been set, you can run `.members[exists]`. These are ordered by complexity, with `.members[1]` the member with complexity 1. - `exists::Array{Bool,1}`: Whether the member at the given complexity has been set. """ -struct HallOfFame{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpressionNode{T}} +struct HallOfFame{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} members::Array{PopMember{T,L,N},1} exists::Array{Bool,1} #Whether it has been set end +function Base.show(io::IO, mime::MIME"text/plain", hof::HallOfFame{T,L,N}) where {T,L,N} + println(io, "HallOfFame{...}:") + for i in eachindex(hof.members, hof.exists) + s_member, s_exists = if hof.exists[i] + sprint((io, m) -> show(io, mime, m), hof.members[i]), "true" + else + "undef", "false" + end + println(io, " "^4 * ".exists[$i] = $s_exists") + print(io, " "^4 * ".members[$i] =") + splitted = split(strip(s_member), '\n') + if length(splitted) == 1 + println(io, " " * s_member) + else + println(io) + foreach(line -> println(io, " "^8 * line), splitted) + end + end + return nothing +end """ - HallOfFame(options::Options, ::Type{T}, ::Type{L}) where {T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpressionNode} + HallOfFame(options::Options, dataset::Dataset{T,L}) where {T<:DATA_TYPE,L<:LOSS_TYPE} Create empty HallOfFame. 
The HallOfFame stores a list of `PopMember` objects in `.members`, which is enumerated @@ -39,18 +58,18 @@ has been instantiated or not. Arguments: - `options`: Options containing specification about deterministic. -- `T`: Type of Nodes to use in the population. e.g., `Float64`. -- `L`: Type of loss to use in the population. e.g., `Float64`. +- `dataset`: Dataset containing the input data. """ function HallOfFame( - options::Options, ::Type{T}, ::Type{L} + options::Options, dataset::Dataset{T,L} ) where {T<:DATA_TYPE,L<:LOSS_TYPE} actualMaxsize = options.maxsize + MAX_DEGREE - NT = with_type_parameters(options.node_type, T) - return HallOfFame{T,L,NT}( + base_tree = create_expression(zero(T), options, dataset) + + return HallOfFame{T,L,typeof(base_tree)}( [ PopMember( - constructorof(options.node_type)(T; val=convert(T, 1)), + copy(base_tree), L(0), L(Inf), options; diff --git a/src/InterfaceDynamicExpressions.jl b/src/InterfaceDynamicExpressions.jl index ebd88c7f2..d5cf52300 100644 --- a/src/InterfaceDynamicExpressions.jl +++ b/src/InterfaceDynamicExpressions.jl @@ -1,27 +1,24 @@ module InterfaceDynamicExpressionsModule using Printf: @sprintf -using DynamicExpressions: DynamicExpressions using DynamicExpressions: - OperatorEnum, GenericOperatorEnum, AbstractExpressionNode, Node, GraphNode -using DynamicExpressions.StringsModule: needs_brackets + DynamicExpressions as DE, + OperatorEnum, + GenericOperatorEnum, + AbstractExpression, + AbstractExpressionNode, + ParametricExpression, + Node, + GraphNode using DynamicQuantities: dimension, ustrip using ..CoreModule: Options using ..CoreModule.OptionsModule: inverse_binopmap, inverse_unaopmap using ..UtilsModule: subscriptify -import DynamicExpressions: - eval_tree_array, - eval_diff_tree_array, - eval_grad_tree_array, - print_tree, - string_tree, - differentiable_eval_tree_array - import ..deprecate_varmap """ - eval_tree_array(tree::AbstractExpressionNode, X::AbstractArray, options::Options; kws...) + eval_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, options::Options; kws...) Evaluate a binary tree (equation) over a given input data matrix. The operators contain all of the operators used. This function fuses doublets @@ -42,7 +39,7 @@ The bulk of the code is for optimizations and pre-emptive NaN/Inf checks, which speed up evaluation significantly. # Arguments -- `tree::AbstractExpressionNode`: The root node of the tree to evaluate. +- `tree::Union{AbstractExpression,AbstractExpressionNode}`: The root node of the tree to evaluate. - `X::AbstractArray`: The input data to evaluate the tree on. - `options::Options`: Options used to define the operators used in the tree. @@ -53,12 +50,38 @@ which speed up evaluation significantly. or nan was encountered, and a large loss should be assigned to the equation. """ -function eval_tree_array( - tree::AbstractExpressionNode, X::AbstractArray, options::Options; kws... +function DE.eval_tree_array( + tree::Union{AbstractExpressionNode,AbstractExpression}, + X::AbstractMatrix, + options::Options; + kws..., ) A = expected_array_type(X) - return eval_tree_array( - tree, X, options.operators; turbo=options.turbo, bumper=options.bumper, kws... 
+ return DE.eval_tree_array( + tree, + X, + DE.get_operators(tree, options); + turbo=options.turbo, + bumper=options.bumper, + kws..., + )::Tuple{A,Bool} +end +function DE.eval_tree_array( + tree::ParametricExpression, + X::AbstractMatrix, + classes::AbstractVector{<:Integer}, + options::Options; + kws..., +) + A = expected_array_type(X) + return DE.eval_tree_array( + tree, + X, + classes, + DE.get_operators(tree, options); + turbo=options.turbo, + bumper=options.bumper, + kws..., )::Tuple{A,Bool} end @@ -68,7 +91,7 @@ function expected_array_type(X::AbstractArray) end """ - eval_diff_tree_array(tree::AbstractExpressionNode, X::AbstractArray, options::Options, direction::Int) + eval_diff_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, options::Options, direction::Int) Compute the forward derivative of an expression, using a similar structure and optimization to eval_tree_array. `direction` is the index of a particular @@ -77,7 +100,7 @@ respect to `x1`. # Arguments -- `tree::AbstractExpressionNode`: The expression tree to evaluate. +- `tree::Union{AbstractExpression,AbstractExpressionNode}`: The expression tree to evaluate. - `X::AbstractArray`: The data matrix, with each column being a data point. - `options::Options`: The options containing the operators used to create the `tree`. - `direction::Int`: The index of the variable to take the derivative with respect to. @@ -87,15 +110,21 @@ respect to `x1`. - `(evaluation, derivative, complete)::Tuple{AbstractVector, AbstractVector, Bool}`: the normal evaluation, the derivative, and whether the evaluation completed as normal (or encountered a nan or inf). """ -function eval_diff_tree_array( - tree::AbstractExpressionNode, X::AbstractArray, options::Options, direction::Int +function DE.eval_diff_tree_array( + tree::Union{AbstractExpression,AbstractExpressionNode}, + X::AbstractArray, + options::Options, + direction::Int, ) A = expected_array_type(X) - return eval_diff_tree_array(tree, X, options.operators, direction)::Tuple{A,A,Bool} + # TODO: Add `AbstractExpression` implementation in `Expression.jl` + return DE.eval_diff_tree_array( + DE.get_tree(tree), X, DE.get_operators(tree, options), direction + )::Tuple{A,A,Bool} end """ - eval_grad_tree_array(tree::AbstractExpressionNode, X::AbstractArray, options::Options; variable::Bool=false) + eval_grad_tree_array(tree::Union{AbstractExpression,AbstractExpressionNode}, X::AbstractArray, options::Options; variable::Bool=false) Compute the forward-mode derivative of an expression, using a similar structure and optimization to eval_tree_array. `variable` specifies whether @@ -104,7 +133,7 @@ to every constant in the expression. # Arguments -- `tree::AbstractExpressionNode`: The expression tree to evaluate. +- `tree::Union{AbstractExpression,AbstractExpressionNode}`: The expression tree to evaluate. - `X::AbstractArray`: The data matrix, with each column being a data point. - `options::Options`: The options containing the operators used to create the `tree`. - `variable::Bool`: Whether to take derivatives with respect to features (i.e., `X` - with `variable=true`), @@ -115,12 +144,17 @@ to every constant in the expression. - `(evaluation, gradient, complete)::Tuple{AbstractVector, AbstractArray, Bool}`: the normal evaluation, the gradient, and whether the evaluation completed as normal (or encountered a nan or inf). """ -function eval_grad_tree_array( - tree::AbstractExpressionNode, X::AbstractArray, options::Options; kws... 
+function DE.eval_grad_tree_array( + tree::Union{AbstractExpression,AbstractExpressionNode}, + X::AbstractArray, + options::Options; + kws..., ) A = expected_array_type(X) M = typeof(X) # TODO: This won't work with StaticArrays! - return eval_grad_tree_array(tree, X, options.operators; kws...)::Tuple{A,M,Bool} + return DE.eval_grad_tree_array( + tree, X, DE.get_operators(tree, options); kws... + )::Tuple{A,M,Bool} end """ @@ -128,11 +162,16 @@ end Evaluate an expression tree in a way that can be auto-differentiated. """ -function differentiable_eval_tree_array( - tree::AbstractExpressionNode, X::AbstractArray, options::Options +function DE.differentiable_eval_tree_array( + tree::Union{AbstractExpression,AbstractExpressionNode}, + X::AbstractArray, + options::Options, ) A = expected_array_type(X) - return differentiable_eval_tree_array(tree, X, options.operators)::Tuple{A,Bool} + # TODO: Add `AbstractExpression` implementation in `Expression.jl` + return DE.differentiable_eval_tree_array( + DE.get_tree(tree), X, DE.get_operators(tree, options) + )::Tuple{A,Bool} end const WILDCARD_UNIT_STRING = "[?]" @@ -149,8 +188,8 @@ Convert an equation to a string. - `variable_names::Union{Array{String, 1}, Nothing}=nothing`: what variables to print for each feature. """ -@inline function string_tree( - tree::AbstractExpressionNode, +@inline function DE.string_tree( + tree::Union{AbstractExpression,AbstractExpressionNode}, options::Options; raw::Bool=true, X_sym_units=nothing, @@ -164,16 +203,19 @@ Convert an equation to a string. if raw tree = tree isa GraphNode ? convert(Node, tree) : tree - return string_tree( - tree, options.operators; f_variable=string_variable_raw, variable_names + return DE.string_tree( + tree, + DE.get_operators(tree, options); + f_variable=string_variable_raw, + variable_names, ) end vprecision = vals[options.print_precision] if X_sym_units !== nothing || y_sym_units !== nothing - return string_tree( + return DE.string_tree( tree, - options.operators; + DE.get_operators(tree, options); f_variable=(feature, vname) -> string_variable(feature, vname, X_sym_units), f_constant=let unit_placeholder = @@ -184,9 +226,9 @@ Convert an equation to a string. kws..., ) else - return string_tree( + return DE.string_tree( tree, - options.operators; + DE.get_operators(tree, options); f_variable=string_variable, f_constant=(val,) -> string_constant(val, vprecision, ""), variable_names=display_variable_names, @@ -252,22 +294,15 @@ Print an equation - `variable_names::Union{Array{String, 1}, Nothing}=nothing`: what variables to print for each feature. """ -function print_tree(tree::AbstractExpressionNode, options::Options; kws...) - return print_tree(tree, options.operators; kws...) -end -function print_tree(io::IO, tree::AbstractExpressionNode, options::Options; kws...) - return print_tree(io, tree, options.operators; kws...) +function DE.print_tree( + tree::Union{AbstractExpression,AbstractExpressionNode}, options::Options; kws... +) + return DE.print_tree(tree, DE.get_operators(tree, options); kws...) end - -""" - convert(::Type{<:AbstractExpressionNode{T}}, tree::AbstractExpressionNode, options::Options; kws...) where {T} - -Convert an equation to a different base type `T`. -""" -function Base.convert( - ::Type{N}, tree::AbstractExpressionNode, options::Options -) where {T,N<:AbstractExpressionNode{T}} - return convert(N, tree, options.operators) +function DE.print_tree( + io::IO, tree::Union{AbstractExpression,AbstractExpressionNode}, options::Options; kws... 
+)
+    return DE.print_tree(io, tree, DE.get_operators(tree, options); kws...)
 end
 
 """
@@ -283,13 +318,13 @@ defined.
 macro extend_operators(options)
     operators = :($(options).operators)
     type_requirements = Options
-    @gensym alias_operators
+    alias_operators = gensym("alias_operators")
     return quote
         if !isa($(options), $type_requirements)
             error("You must pass an options type to `@extend_operators`.")
         end
         $alias_operators = $define_alias_operators($operators)
-        $(DynamicExpressions).@extend_operators $alias_operators
+        $(DE).@extend_operators $alias_operators
     end |> esc
 end
 function define_alias_operators(operators)
@@ -304,14 +339,22 @@ function define_alias_operators(operators)
     )
 end
 
-function (tree::AbstractExpressionNode)(X, options::Options; kws...)
-    return tree(X, options.operators; turbo=options.turbo, bumper=options.bumper, kws...)
+function (tree::Union{AbstractExpression,AbstractExpressionNode})(
+    X, options::Options; kws...
+)
+    return tree(
+        X,
+        DE.get_operators(tree, options);
+        turbo=options.turbo,
+        bumper=options.bumper,
+        kws...,
+    )
 end
-function DynamicExpressions.EvaluationHelpersModule._grad_evaluator(
-    tree::AbstractExpressionNode, X, options::Options; kws...
+function DE.EvaluationHelpersModule._grad_evaluator(
+    tree::Union{AbstractExpression,AbstractExpressionNode}, X, options::Options; kws...
 )
-    return DynamicExpressions.EvaluationHelpersModule._grad_evaluator(
-        tree, X, options.operators; turbo=options.turbo, kws...
+    return DE.EvaluationHelpersModule._grad_evaluator(
+        tree, X, DE.get_operators(tree, options); turbo=options.turbo, kws...
     )
 end
diff --git a/src/LLMFunctions.jl b/src/LLMFunctions.jl
new file mode 100644
index 000000000..67ead9a8d
--- /dev/null
+++ b/src/LLMFunctions.jl
@@ -0,0 +1,943 @@
+module LLMFunctionsModule
+
+using Random: default_rng, AbstractRNG, rand, randperm
+using DynamicExpressions:
+    Node,
+    AbstractExpressionNode,
+    AbstractExpression,
+    ParametricExpression,
+    ParametricNode,
+    AbstractNode,
+    NodeSampler,
+    get_contents,
+    with_contents,
+    constructorof,
+    copy_node,
+    set_node!,
+    count_nodes,
+    has_constants,
+    has_operators,
+    string_tree,
+    AbstractOperatorEnum
+using Compat: Returns, @inline
+using ..CoreModule: Options, DATA_TYPE, binopmap, unaopmap, LLMOptions
+using ..MutationFunctionsModule: gen_random_tree_fixed_size
+
+using PromptingTools:
+    SystemMessage,
+    UserMessage,
+    AIMessage,
+    aigenerate,
+    CustomOpenAISchema,
+    OllamaSchema,
+    OpenAISchema
+using JSON: parse
+
+"""The LLM recorder logs LLM calls for debugging purposes."""
+function llm_recorder(options::LLMOptions, expr::String, mode::String="debug")
+    if options.active
+        if !isdir(options.llm_recorder_dir)
+            mkdir(options.llm_recorder_dir)
+        end
+        recorder = open(joinpath(options.llm_recorder_dir, "llm_calls.txt"), "a")
+        write(recorder, string("[", mode, "] ", expr, "\n[/", mode, "]\n"))
+        close(recorder)
+    end
+end
+
+function load_prompt(path::String)::String
+    # load the prompt template from a file
+    f = open(path, "r")
+    s = read(f, String)
+    close(f)
+    return s
+end
+
+function convertDict(d)::NamedTuple
+    return (; Dict(Symbol(k) => v for (k, v) in d)...)
+end + +function get_vars(options::Options)::String + variable_names = ["x", "y", "z", "k", "j", "l", "m", "n", "p", "a", "b"] + if !isnothing(options.llm_options.var_order) + variable_names = [ + options.llm_options.var_order[key] for + key in sort(collect(keys(options.llm_options.var_order))) + ] + end + return join(variable_names, ", ") +end + +function get_ops(options::Options)::String + binary_operators = map(v -> string(v), map(binopmap, options.operators.binops)) + unary_operators = map(v -> string(v), map(unaopmap, options.operators.unaops)) + # Binary Ops: +, *, -, /, safe_pow (^) + # Unary Ops: exp, safe_log, safe_sqrt, sin, cos + return replace( + replace( + "binary operators: " * + join(binary_operators, ", ") * + ", and unary operators: " * + join(unary_operators, ", "), + "safe_" => "", + ), + "pow" => "^", + ) +end + +""" +Constructs a prompt by replacing the element_id_tag with the corresponding element in the element_list. +If the element_list is longer than the number of occurrences of the element_id_tag, the missing elements are added after the last occurrence. +If the element_list is shorter than the number of occurrences of the element_id_tag, the extra ids are removed. +""" +function construct_prompt( + user_prompt::String, element_list::Vector, element_id_tag::String +)::String + # Split the user prompt into lines + lines = split(user_prompt, "\n") + + # Filter lines that match the pattern "... : {{element_id_tag[1-9]}} + pattern = r"^.*: \{\{" * element_id_tag * r"\d+\}\}$" + + # find all occurrences of the element_id_tag + n_occurrences = count(x -> occursin(pattern, x), lines) + + # if n_occurrences is less than |element_list|, add the missing elements after the last occurrence + if n_occurrences < length(element_list) + last_occurrence = findlast(x -> occursin(pattern, x), lines) + for i in reverse((n_occurrences + 1):length(element_list)) + new_line = replace(lines[last_occurrence], string(n_occurrences) => string(i)) + insert!(lines, last_occurrence + 1, new_line) + end + end + + new_prompt = "" + idx = 1 + for line in lines + # if the line matches the pattern + if occursin(pattern, line) + if idx > length(element_list) + continue + end + # replace the element_id_tag with the corresponding element + new_prompt *= + replace(line, r"\{\{" * element_id_tag * r"\d+\}\}" => element_list[idx]) * + "\n" + idx += 1 + else + new_prompt *= line * "\n" + end + end + return new_prompt +end + +function gen_llm_random_tree( + node_count::Int, + options::Options, + nfeatures::Int, + ::Type{T}, + idea_database::Union{Vector{String},Nothing}, +)::AbstractExpressionNode{T} where {T<:DATA_TYPE} + # Note that this base tree is just a placeholder; it will be replaced. 
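+    # Flow of this function: sample concept strings from the idea database,
+    # splice them into the prompt templates via `construct_prompt`, ask the LLM
+    # for up to N candidate expressions, and fall back to
+    # `gen_random_tree_fixed_size` whenever the call or the parse fails.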
+    N = 5
+    # LLM prompt
+    # conversation = [
+    #     SystemMessage(load_prompt(options.llm_options.prompts_dir * "gen_random_system.txt")),
+    #     UserMessage(load_prompt(options.llm_options.prompts_dir * "gen_random_user.txt"))]
+    assumptions = sample_context(
+        idea_database,
+        options.llm_options.num_pareto_context,
+        options.llm_options.idea_threshold,
+    )
+
+    if !options.llm_options.prompt_concepts
+        assumptions = []
+    end
+
+    conversation = [
+        UserMessage(
+            load_prompt(options.llm_options.prompts_dir * "gen_random_system.txt") *
+            "\n" *
+            construct_prompt(
+                load_prompt(options.llm_options.prompts_dir * "gen_random_user.txt"),
+                assumptions,
+                "assump",
+            ),
+        ),
+    ]
+    llm_recorder(options.llm_options, conversation[1].content, "llm_input|gen_random")
+
+    if options.llm_options.llm_context != ""
+        pushfirst!(assumptions, options.llm_options.llm_context)
+    end
+
+    msg = nothing
+    try
+        msg = aigenerate(
+            CustomOpenAISchema(),
+            conversation; #OllamaSchema(), conversation;
+            variables=get_vars(options),
+            operators=get_ops(options),
+            N=N,
+            api_key=options.llm_options.api_key,
+            model=options.llm_options.model,
+            api_kwargs=convertDict(options.llm_options.api_kwargs),
+            http_kwargs=convertDict(options.llm_options.http_kwargs),
+        )
+    catch e
+        llm_recorder(options.llm_options, "None", "gen_random|failed")
+        return gen_random_tree_fixed_size(node_count, options, nfeatures, T)
+    end
+    llm_recorder(options.llm_options, string(msg.content), "llm_output|gen_random")
+
+    gen_tree_options = parse_msg_content(msg.content)
+
+    N = min(size(gen_tree_options)[1], N)
+
+    if N == 0
+        llm_recorder(options.llm_options, "None", "gen_random|failed")
+        return gen_random_tree_fixed_size(node_count, options, nfeatures, T)
+    end
+
+    for i in 1:N
+        l = rand(1:N)
+        t = expr_to_tree(
+            T,
+            String(strip(gen_tree_options[l], [' ', '\n', '"', ',', '.', '[', ']'])),
+            options,
+        )
+        if t.val == 1 && t.constant
+            continue
+        end
+        llm_recorder(options.llm_options, tree_to_expr(t, options), "gen_random")
+
+        return t
+    end
+
+    out = expr_to_tree(
+        T, String(strip(gen_tree_options[1], [' ', '\n', '"', ',', '.', '[', ']'])), options
+    )
+
+    llm_recorder(options.llm_options, tree_to_expr(out, options), "gen_random")
+
+    if out.val == 1 && out.constant
+        return gen_random_tree_fixed_size(node_count, options, nfeatures, T)
+    end
+
+    return out
+end
+
+"""Crossover between two expressions"""
+function crossover_trees(
+    tree1::AbstractExpressionNode{T}, tree2::AbstractExpressionNode{T}
+)::Tuple{AbstractExpressionNode{T},AbstractExpressionNode{T}} where {T<:DATA_TYPE}
+    tree1 = copy_node(tree1)
+    tree2 = copy_node(tree2)
+
+    node1, parent1, side1 = random_node_and_parent(tree1)
+    node2, parent2, side2 = random_node_and_parent(tree2)
+
+    node1 = copy_node(node1)
+
+    if side1 == 'l'
+        parent1.l = copy_node(node2)
+        # tree1 now contains this.
+    elseif side1 == 'r'
+        parent1.r = copy_node(node2)
+        # tree1 now contains this.
+    else # 'n'
+        # This means that there is no parent1; node1 is the root of tree1.
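+        # With no parent to reattach to, the copied node2 subtree becomes
+        # the new root of tree1.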
+        tree1 = copy_node(node2)
+    end
+
+    if side2 == 'l'
+        parent2.l = node1
+    elseif side2 == 'r'
+        parent2.r = node1
+    else # 'n'
+        tree2 = node1
+    end
+    return tree1, tree2
+end
+
+function sketch_const(val)
+    does_not_need_brackets = (typeof(val) <: Union{Real,AbstractArray})
+
+    if does_not_need_brackets
+        if isinteger(val) && (abs(val) < 5) # don't abstract integer constants from -4 to 4, useful for exponents
+            string(val)
+        else
+            "C"
+        end
+    else
+        if isinteger(val) && (abs(val) < 5) # don't abstract integer constants from -4 to 4, useful for exponents
+            "(" * string(val) * ")"
+        else
+            "(C)"
+        end
+    end
+end
+
+function tree_to_expr(
+    ex::AbstractExpression{T}, options::Options
+)::String where {T<:DATA_TYPE}
+    return tree_to_expr(get_contents(ex), options)
+end
+
+function tree_to_expr(tree::AbstractExpressionNode{T}, options)::String where {T<:DATA_TYPE}
+    variable_names = ["x", "y", "z", "k", "j", "l", "m", "n", "p", "a", "b"]
+    if !isnothing(options.llm_options.var_order)
+        variable_names = [
+            options.llm_options.var_order[key] for
+            key in sort(collect(keys(options.llm_options.var_order)))
+        ]
+    end
+    return string_tree(
+        tree, options.operators; f_constant=sketch_const, variable_names=variable_names
+    )
+end
+
+function handle_not_expr(::Type{T}, x, var_names)::Node{T} where {T<:DATA_TYPE}
+    if x isa Real
+        Node{T}(; val=convert(T, x)) # old: Node(T, 0, true, convert(T,x))
+    elseif x isa Symbol
+        if x === :C # constant that got abstracted
+            Node{T}(; val=convert(T, 1)) # old: Node(T, 0, true, convert(T,1))
+        else
+            feature = findfirst(isequal(string(x)), var_names)
+            if isnothing(feature) # invalid variable name; fall back to the first feature
+                feature = 1
+            end
+            Node{T}(; feature=feature) # old: Node(T, 0, false, nothing, feature)
+        end
+    else
+        Node{T}(; val=convert(T, 1)) # old: Node(T, 0, true, convert(T,1)) # fall back to a constant 1
+    end
+end
+
+function expr_to_tree_recurse(
+    ::Type{T}, node::Expr, op::AbstractOperatorEnum, var_names
+)::Node{T} where {T<:DATA_TYPE}
+    args = node.args
+    x = args[1]
+    degree = length(args)
+
+    if degree == 1
+        handle_not_expr(T, x, var_names)
+    elseif degree == 2
+        unary_operators = map(v -> string(v), map(unaopmap, op.unaops))
+        idx = findfirst(isequal(string(x)), unary_operators)
+        if isnothing(idx) # if the operator is not in use, fall back to the first one
+            idx = findfirst(isequal("safe_" * string(x)), unary_operators)
+            if isnothing(idx)
+                idx = 1
+            end
+        end
+
+        left = if (args[2] isa Expr)
+            expr_to_tree_recurse(T, args[2], op, var_names)
+        else
+            handle_not_expr(T, args[2], var_names)
+        end
+
+        Node(; op=idx, l=left) # old: Node(1, false, nothing, 0, idx, left)
+    elseif degree == 3
+        if x === :^
+            x = :pow
+        end
+        binary_operators = map(v -> string(v), map(binopmap, op.binops))
+        idx = findfirst(isequal(string(x)), binary_operators)
+        if isnothing(idx) # if the operator is not in use, fall back to the first one
+            idx = findfirst(isequal("safe_" * string(x)), binary_operators)
+            if isnothing(idx)
+                idx = 1
+            end
+        end
+
+        left = if (args[2] isa Expr)
+            expr_to_tree_recurse(T, args[2], op, var_names)
+        else
+            handle_not_expr(T, args[2], var_names)
+        end
+        right = if (args[3] isa Expr)
+            expr_to_tree_recurse(T, args[3], op, var_names)
+        else
+            handle_not_expr(T, args[3], var_names)
+        end
+
+        Node(; op=idx, l=left, r=right) # old: Node(2, false, nothing, 0, idx, left, right)
+    else
+        Node{T}(; val=convert(T, 1)) # old: Node(T, 0, true, convert(T,1)) # return a constant being 1
+    end
+end
+
+function expr_to_tree_run(::Type{T}, x::String, options)::Node{T} where
{T<:DATA_TYPE} + try + expr = Meta.parse(x) + variable_names = ["x", "y", "z", "k", "j", "l", "m", "n", "p", "a", "b"] + if !isnothing(options.llm_options.var_order) + variable_names = [ + options.llm_options.var_order[key] for + key in sort(collect(keys(options.llm_options.var_order))) + ] + end + if expr isa Expr + expr_to_tree_recurse(T, expr, options.operators, variable_names) + else + handle_not_expr(T, expr, variable_names) + end + catch + Node{T}(; val=convert(T, 1)) # old: Node(T, 0, true, convert(T,1)) # return a constant being 1 + end +end + +function expr_to_tree(::Type{T}, x::String, options) where {T<:DATA_TYPE} + if options.llm_options.is_parametric + out = ParametricNode{T}(expr_to_tree_run(T, x, options)) + else + out = Node{T}(expr_to_tree_run(T, x, options)) + end + return out +end + +function format_pareto(dominating, options, num_pareto_context::Int)::Vector{String} + pareto = Vector{String}() + if !isnothing(dominating) && size(dominating)[1] > 0 + idx = randperm(size(dominating)[1]) + for i in 1:min(size(dominating)[1], num_pareto_context) + push!(pareto, tree_to_expr(dominating[idx[i]].tree, options)) + end + end + while size(pareto)[1] < num_pareto_context + push!(pareto, "None") + end + return pareto +end + +function sample_one_context(idea_database, idea_threshold)::String + if isnothing(idea_database) + return "None" + end + + N = size(idea_database)[1] + if N == 0 + return "None" + end + + try + idea_database[rand(1:min(idea_threshold, N))] + catch e + "None" + end +end + +function sample_context(idea_database, N, idea_threshold)::Vector{String} + assumptions = Vector{String}() + if isnothing(idea_database) + for _ in 1:N + push!(assumptions, "None") + end + return assumptions + end + + if size(idea_database)[1] < N + for i in 1:(size(idea_database)[1]) + push!(assumptions, idea_database[i]) + end + for i in (size(idea_database)[1] + 1):N + push!(assumptions, "None") + end + return assumptions + end + + while size(assumptions)[1] < N + chosen_idea = sample_one_context(idea_database, idea_threshold) + if chosen_idea in assumptions + continue + end + push!(assumptions, chosen_idea) + end + return assumptions +end + +function prompt_evol(idea_database, options::Options) + num_ideas = size(idea_database)[1] + if num_ideas <= options.llm_options.idea_threshold + return nothing + end + + idea1 = idea_database[rand((options.llm_options.idea_threshold + 1):num_ideas)] + idea2 = idea_database[rand((options.llm_options.idea_threshold + 1):num_ideas)] # they could be same (should be allowed) + idea3 = idea_database[rand((options.llm_options.idea_threshold + 1):num_ideas)] # they could be same (should be allowed) + idea4 = idea_database[rand((options.llm_options.idea_threshold + 1):num_ideas)] # they could be same (should be allowed) + idea5 = idea_database[rand((options.llm_options.idea_threshold + 1):num_ideas)] # they could be same (should be allowed) + + N = 5 + + # conversation = [ + # SystemMessage(load_prompt(options.llm_options.prompts_dir * "prompt_evol_system.txt")), + # UserMessage(load_prompt(options.llm_options.prompts_dir * "prompt_evol_user.txt"))] + conversation = [ + UserMessage( + load_prompt(options.llm_options.prompts_dir * "prompt_evol_system.txt") * + "\n" * + construct_prompt( + load_prompt(options.llm_options.prompts_dir * "prompt_evol_user.txt"), + [idea1, idea2, idea3, idea4, idea5], + "idea", + ), + ), + ] + llm_recorder(options.llm_options, conversation[1].content, "llm_input|ideas") + + msg = nothing + try + msg = aigenerate( + 
CustomOpenAISchema(), + conversation; #OllamaSchema(), conversation; + N=N, + api_key=options.llm_options.api_key, + model=options.llm_options.model, + api_kwargs=convertDict(options.llm_options.api_kwargs), + http_kwargs=convertDict(options.llm_options.http_kwargs), + ) + catch e + llm_recorder(options.llm_options, "None", "ideas|failed") + return nothing + end + llm_recorder(options.llm_options, string(msg.content), "llm_output|ideas") + + idea_options = parse_msg_content(msg.content) + + N = min(size(idea_options)[1], N) + + if N == 0 + llm_recorder(options.llm_options, "None", "ideas|failed") + return nothing + end + + # only choose one, merging ideas not really crossover + chosen_idea = String( + strip(idea_options[rand(1:N)], [' ', '\n', '"', ',', '.', '[', ']']) + ) + + llm_recorder(options.llm_options, chosen_idea, "ideas") + + return chosen_idea +end + +function parse_msg_content(msg_content) + content = msg_content + try + content = match(r"```json(.*?)```"s, msg_content).captures[1] + catch e + try + content = match(r"```(.*?)```"s, msg_content).captures[1] + catch e2 + try + content = match(r"\[(.*?)\]"s, msg_content).match + catch e3 + content = msg_content + end + end + end + + try + out = parse(content) # json parse + if out isa Dict + return [out[key] for key in keys(out)] + end + + if out isa Vector && all(x -> isa(x, String), out) + return out + end + catch e + try + content = strip(content, [' ', '\n', '"', ',', '.', '[', ']']) + content = replace(content, "\n" => " ") + out_list = split(content, "\", \"") + return out_list + catch e2 + return [] + end + end + + try + content = strip(content, [' ', '\n', '"', ',', '.', '[', ']']) + content = replace(content, "\n" => " ") + out_list = split(content, "\", \"") + return out_list + catch e3 + return [] + end + # old method: + # find first JSON list + # first_idx = findfirst('[', content) + # last_idx = findfirst(']', content) + # content = chop(content, head=first_idx, tail=length(content) - last_idx + 1) + + # out_list = split(content, ",") + # for i in 1:length(out_list) + # out_list[i] = replace(out_list[i], "//.*" => "") # filter comments + # end + + # new method (for Llama since it follows directions better): +end + +function update_idea_database(idea_database, dominating, worst_members, options::Options) + # turn dominating pareto curve into ideas as strings + if isnothing(dominating) + return nothing + end + + op = options.operators + num_pareto_context = 5 # options.mutation_weights.num_pareto_context # must be 5 right now for prompts + + gexpr = format_pareto(dominating, options, num_pareto_context) + bexpr = format_pareto(worst_members, options, num_pareto_context) + + N = 5 + + # conversation = [ + # SystemMessage(load_prompt(options.llm_options.prompts_dir * "extract_idea_system.txt")), + # UserMessage(load_prompt(options.llm_options.prompts_dir * "extract_idea_user.txt"))] + conversation = [ + UserMessage( + load_prompt(options.llm_options.prompts_dir * "extract_idea_system.txt") * + "\n" * + construct_prompt( + construct_prompt( + load_prompt(options.llm_options.prompts_dir * "extract_idea_user.txt"), + gexpr, + "gexpr", + ), + bexpr, + "bexpr", + ), + ), + ] + llm_recorder(options.llm_options, conversation[1].content, "llm_input|gen_random") + + msg = nothing + try + # msg = aigenerate(OpenAISchema(), conversation; #OllamaSchema(), conversation; + # variables=get_vars(options), + # operators=get_ops(options), + # N=N, + # gexpr1=gexpr[1], + # gexpr2=gexpr[2], + # gexpr3=gexpr[3], + # gexpr4=gexpr[4], + # 
gexpr5=gexpr[5], + # bexpr1=bexpr[1], + # bexpr2=bexpr[2], + # bexpr3=bexpr[3], + # bexpr4=bexpr[4], + # bexpr5=bexpr[5], + # model="gpt-3.5-turbo-0125" + # ) + msg = aigenerate( + CustomOpenAISchema(), + conversation; #OllamaSchema(), conversation; + variables=get_vars(options), + operators=get_ops(options), + N=N, + api_key=options.llm_options.api_key, + model=options.llm_options.model, + api_kwargs=convertDict(options.llm_options.api_kwargs), + http_kwargs=convertDict(options.llm_options.http_kwargs), + ) + catch e + llm_recorder(options.llm_options, "None", "ideas|failed") + return nothing + end + + llm_recorder(options.llm_options, string(msg.content), "llm_output|ideas") + + idea_options = parse_msg_content(msg.content) + + N = min(size(idea_options)[1], N) + + if N == 0 + llm_recorder(options.llm_options, "None", "ideas|failed") + return nothing + end + + a = rand(1:N) + + chosen_idea1 = String(strip(idea_options[a], [' ', '\n', '"', ',', '.', '[', ']'])) + + llm_recorder(options.llm_options, chosen_idea1, "ideas") + pushfirst!(idea_database, chosen_idea1) + + if N > 1 + b = rand(1:(N - 1)) + if a == b + b += 1 + end + chosen_idea2 = String(strip(idea_options[b], [' ', '\n', '"', ',', '.', '[', ']'])) + + llm_recorder(options.llm_options, chosen_idea2, "ideas") + + pushfirst!(idea_database, chosen_idea2) + end + + num_add = 2 + for _ in 1:num_add + out = prompt_evol(idea_database, options) + if !isnothing(out) + pushfirst!(idea_database, out) + end + end +end + +function llm_mutate_op( + ex::AbstractExpression{T}, options::Options, dominating, idea_database +)::AbstractExpression{T} where {T<:DATA_TYPE} + tree = get_contents(ex) + ex = with_contents(ex, llm_mutate_op(tree, options, dominating, idea_database)) + return ex +end + +"""LLM Mutation on a tree""" +function llm_mutate_op( + tree::AbstractExpressionNode{T}, options::Options, dominating, idea_database +)::AbstractExpressionNode{T} where {T<:DATA_TYPE} + expr = tree_to_expr(tree, options) # TODO: change global expr right now, could do it by subtree (weighted near root more) + N = 5 + # LLM prompt + # TODO: we can use async map to do concurrent requests (useful for trying multiple prompts), see: https://github.com/svilupp/PromptingTools.jl?tab=readme-ov-file#asynchronous-execution + + # conversation = [ + # SystemMessage(load_prompt(options.llm_options.prompts_dir * "mutate_system.txt")), + # UserMessage(load_prompt(options.llm_options.prompts_dir * "mutate_user.txt"))] + + assumptions = sample_context( + idea_database, + options.llm_options.num_pareto_context, + options.llm_options.idea_threshold, + ) + pareto = format_pareto(dominating, options, options.llm_options.num_pareto_context) + if !options.llm_options.prompt_concepts + assumptions = [] + pareto = [] + end + conversation = [ + UserMessage( + load_prompt(options.llm_options.prompts_dir * "mutate_system.txt") * + "\n" * + construct_prompt( + load_prompt(options.llm_options.prompts_dir * "mutate_user.txt"), + assumptions, + "assump", + ), + ), + ] + llm_recorder(options.llm_options, conversation[1].content, "llm_input|mutate") + + if options.llm_options.llm_context != "" + pushfirst!(assumptions, options.llm_options.llm_context) + end + + msg = nothing + try + msg = aigenerate( + CustomOpenAISchema(), + conversation; #OllamaSchema(), conversation; + variables=get_vars(options), + operators=get_ops(options), + N=N, + expr=expr, + api_key=options.llm_options.api_key, + model=options.llm_options.model, + api_kwargs=convertDict(options.llm_options.api_kwargs), + 
http_kwargs=convertDict(options.llm_options.http_kwargs), + ) + catch e + llm_recorder(options.llm_options, "None", "mutate|failed") + return tree + end + + llm_recorder(options.llm_options, string(msg.content), "llm_output|mutate") + + mut_tree_options = parse_msg_content(msg.content) + + N = min(size(mut_tree_options)[1], N) + + if N == 0 + llm_recorder(options.llm_options, "None", "mutate|failed") + return tree + end + + for i in 1:N + l = rand(1:N) + t = expr_to_tree( + T, + String(strip(mut_tree_options[l], [' ', '\n', '"', ',', '.', '[', ']'])), + options, + ) + if t.val == 1 && t.constant + continue + end + + llm_recorder(options.llm_options, tree_to_expr(t, options), "mutate") + + return t + end + + out = expr_to_tree( + T, String(strip(mut_tree_options[1], [' ', '\n', '"', ',', '.', '[', ']'])), options + ) + + llm_recorder(options.llm_options, tree_to_expr(out, options), "mutate") + + return out +end + +function llm_crossover_trees( + ex1::E, ex2::E, options::Options, dominating, idea_database +)::Tuple{E,E} where {T,E<:AbstractExpression{T}} + tree1 = get_contents(ex1) + tree2 = get_contents(ex2) + tree1, tree2 = llm_crossover_trees(tree1, tree2, options, dominating, idea_database) + ex1 = with_contents(ex1, tree1) + ex2 = with_contents(ex2, tree2) + return ex1, ex2 +end + +"""LLM Crossover between two expressions""" +function llm_crossover_trees( + tree1::AbstractExpressionNode{T}, + tree2::AbstractExpressionNode{T}, + options::Options, + dominating, + idea_database, +)::Tuple{AbstractExpressionNode{T},AbstractExpressionNode{T}} where {T<:DATA_TYPE} + expr1 = tree_to_expr(tree1, options) + expr2 = tree_to_expr(tree2, options) + N = 5 + + # LLM prompt + # conversation = [ + # SystemMessage(load_prompt(options.llm_options.prompts_dir * "crossover_system.txt")), + # UserMessage(load_prompt(options.llm_options.prompts_dir * "crossover_user.txt"))] + assumptions = sample_context( + idea_database, + options.llm_options.num_pareto_context, + options.llm_options.idea_threshold, + ) + pareto = format_pareto(dominating, options, options.llm_options.num_pareto_context) + if !options.llm_options.prompt_concepts + assumptions = [] + pareto = [] + end + + conversation = [ + UserMessage( + load_prompt(options.llm_options.prompts_dir * "crossover_system.txt") * + "\n" * + construct_prompt( + load_prompt(options.llm_options.prompts_dir * "crossover_user.txt"), + assumptions, + "assump", + ), + ), + ] + + if options.llm_options.llm_context != "" + pushfirst!(assumptions, options.llm_options.llm_context) + end + + llm_recorder(options.llm_options, conversation[1].content, "llm_input|crossover") + + msg = nothing + try + msg = aigenerate( + CustomOpenAISchema(), + conversation; #OllamaSchema(), conversation; + variables=get_vars(options), + operators=get_ops(options), + N=N, + # pareto1=pareto[1], + # pareto2=pareto[2], + # pareto3=pareto[3], + expr1=expr1, + expr2=expr2, + api_key=options.llm_options.api_key, + model=options.llm_options.model, + api_kwargs=convertDict(options.llm_options.api_kwargs), + http_kwargs=convertDict(options.llm_options.http_kwargs), + ) + catch e + llm_recorder(options.llm_options, "None", "crossover|failed") + return tree1, tree2 + end + + llm_recorder(options.llm_options, string(msg.content), "llm_output|crossover") + + cross_tree_options = parse_msg_content(msg.content) + + cross_tree1 = nothing + cross_tree2 = nothing + + N = min(size(cross_tree_options)[1], N) + + if N == 0 + llm_recorder(options.llm_options, "None", "crossover|failed") + return tree1, tree2 + end 
+
+    if N == 1
+        t = expr_to_tree(
+            T,
+            String(strip(cross_tree_options[1], [' ', '\n', '"', ',', '.', '[', ']'])),
+            options,
+        )
+
+        llm_recorder(options.llm_options, tree_to_expr(t, options), "crossover")
+
+        return t, tree2
+    end
+
+    for i in 1:(2 * N)
+        l = rand(1:N)
+        t = expr_to_tree(
+            T,
+            String(strip(cross_tree_options[l], [' ', '\n', '"', ',', '.', '[', ']'])),
+            options,
+        )
+        if t.val == 1 && t.constant
+            continue
+        end
+
+        if isnothing(cross_tree1)
+            cross_tree1 = t
+        elseif isnothing(cross_tree2)
+            cross_tree2 = t
+            break
+        end
+    end
+
+    if isnothing(cross_tree1)
+        cross_tree1 = expr_to_tree(
+            T,
+            String(strip(cross_tree_options[1], [' ', '\n', '"', ',', '.', '[', ']'])),
+            options,
+        )
+    end
+
+    if isnothing(cross_tree2)
+        cross_tree2 = expr_to_tree(
+            T,
+            String(strip(cross_tree_options[2], [' ', '\n', '"', ',', '.', '[', ']'])),
+            options,
+        )
+    end
+
+    recording_str = tree_to_expr(cross_tree1, options) * " && " * tree_to_expr(cross_tree2, options)
+    llm_recorder(options.llm_options, recording_str, "crossover")
+
+    return cross_tree1, cross_tree2
+end
+
+end
\ No newline at end of file
diff --git a/src/LLMOptions.jl b/src/LLMOptions.jl
new file mode 100644
index 000000000..5f7a01432
--- /dev/null
+++ b/src/LLMOptions.jl
@@ -0,0 +1,125 @@
+module LLMOptionsModule
+
+using StatsBase: StatsBase
+using Base: isvalid
+
+"""
+    LLMWeights(;kws...)
+
+Defines the probabilities of the different LLM-based mutation operations. Follows the
+same pattern as MutationWeights. These weights will be normalized to sum to 1.0 after
+initialization.
+
+# Arguments
+- `llm_mutate::Float64`: Probability of calling the LLM version of mutation.
+  The LLM operations are significantly slower than their symbolic counterparts,
+  so higher probabilities will result in slower searches.
+- `llm_crossover::Float64`: Probability of calling the LLM version of crossover.
+  Same limitation as `llm_mutate`.
+- `llm_gen_random::Float64`: Probability of calling the LLM version of gen_random.
+  Same limitation as `llm_mutate`.
+"""
+Base.@kwdef mutable struct LLMWeights
+    llm_mutate::Float64 = 0.0
+    llm_crossover::Float64 = 0.0
+    llm_gen_random::Float64 = 0.0
+end
+
+"""
+    LLMOptions(;kws...)
+
+Defines how to call the LLM inference functions. LLM inference is managed by
+PromptingTools.jl; this module serves as the entry point for defining new options
+for the LLM inference.
+
+# Arguments
+- `active::Bool`: Whether to use LLM inference or not.
+- `weights::LLMWeights`: Weights for the different LLM operations.
+- `num_pareto_context::Int64`: Number of equations to sample from the Pareto frontier.
+- `prompt_concepts::Bool`: Use natural-language concepts in the LLM prompts.
+- `prompt_evol::Bool`: Evolve natural-language concepts through successive LLM calls.
+- `api_key::String`: OpenAI API key. Required.
+- `model::String`: OpenAI model to use. Required.
+- `api_kwargs::Dict`: Additional keyword arguments to pass to the OpenAI API.
+    - `url::String`: URL to send the request to. Required.
+    - `max_tokens::Int`: Maximum number of tokens to generate (default: 1000).
+- `http_kwargs::Dict`: Additional keyword arguments for the HTTP request.
+    - `retries::Int`: Number of retries to attempt (default: 3).
+    - `readtimeout::Int`: Read timeout for the HTTP request, in seconds (default: 3600, i.e., 1 hour).
+- `llm_recorder_dir::String`: Directory in which to save LLM logs. Useful for debugging.
+- `llm_context::AbstractString`: Context string for the LLM.
+- `var_order::Union{Dict,Nothing}`: Variable order for the LLM
+  (default: nothing).
+- `prompts_dir::String`: Directory containing the prompt templates; must be a
+  valid directory when `active` is true (default: "prompts/").
+- `idea_threshold::UInt32`: How many ideas at the front of the idea database are
+  eligible to be sampled as context (default: 30).
+- `is_parametric::Bool`: Whether LLM output is parsed into `ParametricNode`s
+  rather than `Node`s (default: false).
+"""
+Base.@kwdef mutable struct LLMOptions
+    active::Bool = false
+    weights::LLMWeights = LLMWeights()
+    num_pareto_context::Int64 = 0
+    prompt_concepts::Bool = false
+    prompt_evol::Bool = false
+    api_key::String = ""
+    model::String = ""
+    api_kwargs::Dict = Dict(
+        "max_tokens" => 1000
+    )
+    http_kwargs::Dict = Dict("retries" => 3, "readtimeout" => 3600)
+    llm_recorder_dir::String = "lasr_runs/"
+    prompts_dir::String = "prompts/"
+    llm_context::AbstractString = ""
+    var_order::Union{Dict,Nothing} = nothing
+    idea_threshold::UInt32 = 30
+    is_parametric::Bool = false
+end
+
+const llm_mutations = fieldnames(LLMWeights)
+const v_llm_mutations = Symbol[llm_mutations...]
+
+"""Validate that the options are set correctly.
+Specifically, we need to check:
+- If `active` is true, then `api_key` and `model` must be set.
+- If `active` is true, then `api_kwargs` must have a `url` key.
+- If `active` is true, then `prompts_dir` must be a valid directory.
+"""
+function validate_llm_options(options::LLMOptions)
+    if options.active
+        if options.api_key == ""
+            throw(ArgumentError("api_key must be set if LLM is active."))
+        end
+        if options.model == ""
+            throw(ArgumentError("model must be set if LLM is active."))
+        end
+        if !haskey(options.api_kwargs, "url")
+            throw(ArgumentError("api_kwargs must have a 'url' key."))
+        end
+        if !isdir(options.prompts_dir)
+            throw(ArgumentError("prompts_dir must be a valid directory."))
+        end
+    end
+end
+
+# """Sample LLM mutation, given the weightings."""
+# function sample_llm_mutation(w::LLMWeights)
+#     weights = convert(Vector, w)
+#     return StatsBase.sample(v_llm, StatsBase.Weights(weights))
+# end
+
+end # module
+
+# sample invocation following:
+# python -m experiments.main --use_llm --use_prompt_evol --model "meta-llama/Meta-Llama-3-8B-Instruct" --api_key "vllm_api.key" --model_url "http://localhost:11440/v1" --exp_idx 0 --dataset_path FeynmanEquations.csv --start_idx 0
+# options = LLMOptions(
+#     active=true,
+#     weights=LLMWeights(llm_mutate=0.5, llm_crossover=0.3, llm_gen_random=0.2),
+#     num_pareto_context=5,
+#     prompt_evol=true,
+#     prompt_concepts=true,
+#     api_key="vllm_api.key",
+#     model="meta-llama/Meta-Llama-3-8B-Instruct",
+#     api_kwargs=Dict("url" => "http://localhost:11440/v1"),
+#     http_kwargs=Dict("retries" => 3, "readtimeout" => 3600),
+#     llm_recorder_dir="lasr_runs/",
+#     llm_context="",
+#     var_order=nothing,
+#     idea_threshold=30
+# )
\ No newline at end of file
diff --git a/src/LibraryAugmentedSymbolicRegression.jl b/src/LibraryAugmentedSymbolicRegression.jl
new file mode 100644
index 000000000..c91d6717a
--- /dev/null
+++ b/src/LibraryAugmentedSymbolicRegression.jl
@@ -0,0 +1,1241 @@
+module LibraryAugmentedSymbolicRegression
+
+# Types
+export Population,
+    PopMember,
+    HallOfFame,
+    Options,
+    Dataset,
+    MutationWeights,
+    LLMWeights,
+    LLMOptions,
+    Node,
+    GraphNode,
+    ParametricNode,
+    Expression,
+    ParametricExpression,
+    StructuredExpression,
+    NodeSampler,
+    AbstractExpression,
+    AbstractExpressionNode,
+    LaSRRegressor,
+    MultitargetLaSRRegressor,
+    LOSS_TYPE,
+    DATA_TYPE,
+
+    #Functions:
+    equation_search,
+    s_r_cycle,
+    calculate_pareto_frontier,
+    count_nodes,
+    compute_complexity,
+    @parse_expression,
+    parse_expression,
+    print_tree,
+    string_tree,
+    eval_tree_array,
+    eval_diff_tree_array,
+    eval_grad_tree_array,
+    differentiable_eval_tree_array,
+    set_node!,
+    copy_node,
+    node_to_symbolic,
+    node_type,
+    symbolic_to_node,
+    simplify_tree!,
+    tree_mapreduce,
+    combine_operators,
+    gen_random_tree,
+    gen_random_tree_fixed_size,
+    @extend_operators,
+    get_tree,
+    get_contents,
+    get_metadata,
+
+    # Operators
+    plus,
+    sub,
+    mult,
+    square,
+    cube,
+    pow,
+    safe_pow,
+    safe_log,
+    safe_log2,
+    safe_log10,
+    safe_log1p,
+    safe_acosh,
+    safe_sqrt,
+    neg,
+    greater,
+    cond,
+    relu,
+    logical_or,
+    logical_and,
+
+    # special operators
+    gamma,
+    erf,
+    erfc,
+    atanh_clip
+
+using Distributed
+using Printf: @printf, @sprintf
+using PackageExtensionCompat: @require_extensions
+using Pkg: Pkg
+using TOML: parsefile
+using Random: seed!, shuffle!
+using Reexport
+using DynamicExpressions:
+    Node,
+    GraphNode,
+    ParametricNode,
+    Expression,
+    ParametricExpression,
+    StructuredExpression,
+    NodeSampler,
+    AbstractExpression,
+    AbstractExpressionNode,
+    @parse_expression,
+    parse_expression,
+    copy_node,
+    set_node!,
+    string_tree,
+    print_tree,
+    count_nodes,
+    get_constants,
+    get_scalar_constants,
+    set_constants!,
+    set_scalar_constants!,
+    index_constants,
+    NodeIndex,
+    eval_tree_array,
+    differentiable_eval_tree_array,
+    eval_diff_tree_array,
+    eval_grad_tree_array,
+    node_to_symbolic,
+    symbolic_to_node,
+    combine_operators,
+    simplify_tree!,
+    tree_mapreduce,
+    set_default_variable_names!,
+    node_type,
+    get_tree,
+    get_contents,
+    get_metadata
+using DynamicExpressions: with_type_parameters
+@reexport using LossFunctions:
+    MarginLoss,
+    DistanceLoss,
+    SupervisedLoss,
+    ZeroOneLoss,
+    LogitMarginLoss,
+    PerceptronLoss,
+    HingeLoss,
+    L1HingeLoss,
+    L2HingeLoss,
+    SmoothedL1HingeLoss,
+    ModifiedHuberLoss,
+    L2MarginLoss,
+    ExpLoss,
+    SigmoidLoss,
+    DWDMarginLoss,
+    LPDistLoss,
+    L1DistLoss,
+    L2DistLoss,
+    PeriodicLoss,
+    HuberLoss,
+    EpsilonInsLoss,
+    L1EpsilonInsLoss,
+    L2EpsilonInsLoss,
+    LogitDistLoss,
+    QuantileLoss,
+    LogCoshLoss
+
+# https://discourse.julialang.org/t/how-to-find-out-the-version-of-a-package-from-its-module/37755/15
+const PACKAGE_VERSION = try
+    root = pkgdir(@__MODULE__)
+    if root isa String
+        let project = parsefile(joinpath(root, "Project.toml"))
+            VersionNumber(project["version"])
+        end
+    else
+        VersionNumber(0, 0, 0)
+    end
+catch
+    VersionNumber(0, 0, 0)
+end
+
+function deprecate_varmap(variable_names, varMap, func_name)
+    if varMap !== nothing
+        Base.depwarn("`varMap` is deprecated; use `variable_names` instead", func_name)
+        @assert variable_names === nothing "Cannot pass both `varMap` and `variable_names`"
+        variable_names = varMap
+    end
+    return variable_names
+end
+
+using DispatchDoctor: @stable
+
+@stable default_mode = "disable" begin
+    include("Utils.jl")
+    include("InterfaceDynamicQuantities.jl")
+    include("Core.jl")
+    include("InterfaceDynamicExpressions.jl")
+    include("Recorder.jl")
+    include("Complexity.jl")
+    include("DimensionalAnalysis.jl")
+    include("CheckConstraints.jl")
+    include("AdaptiveParsimony.jl")
+    include("MutationFunctions.jl")
+    include("LLMFunctions.jl")
+    include("LossFunctions.jl")
+    include("PopMember.jl")
+    include("ConstantOptimization.jl")
+    include("Population.jl")
+    include("HallOfFame.jl")
+    include("Mutate.jl")
+    include("RegularizedEvolution.jl")
+    include("SingleIteration.jl")
+    include("ProgressBars.jl")
+    include("Migration.jl")
+    include("SearchUtils.jl")
+    include("ExpressionBuilder.jl")
+end
+
+using .CoreModule:
+    MAX_DEGREE,
+    BATCH_DIM,
+    FEATURE_DIM,
+    DATA_TYPE,
+    LOSS_TYPE,
+    RecordType,
+    Dataset,
+    Options,
+    MutationWeights,
+    LLMOptions,
+    LLMWeights,
+    plus,
+    sub,
+    mult,
+    square,
+    cube,
+    pow,
+    safe_pow,
+    safe_log,
+    safe_log2,
+
safe_log10, + safe_log1p, + safe_sqrt, + safe_acosh, + neg, + greater, + cond, + relu, + logical_or, + logical_and, + gamma, + erf, + erfc, + atanh_clip, + create_expression +using .UtilsModule: is_anonymous_function, recursive_merge, json3_write, @ignore +using .ComplexityModule: compute_complexity +using .CheckConstraintsModule: check_constraints +using .AdaptiveParsimonyModule: + RunningSearchStatistics, update_frequencies!, move_window!, normalize_frequencies! +using .MutationFunctionsModule: + gen_random_tree, + gen_random_tree_fixed_size, + random_node, + random_node_and_parent, + crossover_trees +using .LLMFunctionsModule: update_idea_database, llm_recorder + +using .InterfaceDynamicExpressionsModule: @extend_operators +using .LossFunctionsModule: eval_loss, score_func, update_baseline_loss! +using .PopMemberModule: PopMember, reset_birth! +using .PopulationModule: Population, best_sub_pop, record_population, best_of_sample +using .HallOfFameModule: + HallOfFame, calculate_pareto_frontier, string_dominating_pareto_curve +using .SingleIterationModule: s_r_cycle, optimize_and_simplify_population +using .ProgressBarsModule: WrappedProgressBar +using .RecorderModule: @recorder, find_iteration_from_record +using .MigrationModule: migrate! +using .SearchUtilsModule: + SearchState, + RuntimeOptions, + WorkerAssignments, + DefaultWorkerOutputType, + assign_next_worker!, + get_worker_output_type, + extract_from_worker, + @sr_spawner, + StdinReader, + watch_stream, + close_reader!, + check_for_user_quit, + check_for_loss_threshold, + check_for_timeout, + check_max_evals, + ResourceMonitor, + record_channel_state!, + estimate_work_fraction, + update_progress_bar!, + print_search_state, + init_dummy_pops, + load_saved_hall_of_fame, + load_saved_population, + construct_datasets, + save_to_file, + get_cur_maxsize, + update_hall_of_fame! +using .ExpressionBuilderModule: embed_metadata, strip_metadata + +@stable default_mode = "disable" begin + include("deprecates.jl") + include("Configure.jl") +end + +""" + equation_search(X, y[; kws...]) + +Perform a distributed equation search for functions `f_i` which +describe the mapping `f_i(X[:, j]) โ‰ˆ y[i, j]`. Options are +configured using LibraryAugmentedSymbolicRegression.Options(...), +which should be passed as a keyword argument to options. +One can turn off parallelism with `numprocs=0`, +which is useful for debugging and profiling. + +# Arguments +- `X::AbstractMatrix{T}`: The input dataset to predict `y` from. + The first dimension is features, the second dimension is rows. +- `y::Union{AbstractMatrix{T}, AbstractVector{T}}`: The values to predict. The first dimension + is the output feature to predict with each equation, and the + second dimension is rows. +- `niterations::Int=10`: The number of iterations to perform the search. + More iterations will improve the results. +- `weights::Union{AbstractMatrix{T}, AbstractVector{T}, Nothing}=nothing`: Optionally + weight the loss for each `y` by this value (same shape as `y`). +- `options::Options=Options()`: The options for the search, such as + which operators to use, evolution hyperparameters, etc. +- `variable_names::Union{Vector{String}, Nothing}=nothing`: The names + of each feature in `X`, which will be used during printing of equations. +- `display_variable_names::Union{Vector{String}, Nothing}=variable_names`: Names + to use when printing expressions during the search, but not when saving + to an equation file. 
+- `y_variable_names::Union{String,AbstractVector{String},Nothing}=nothing`: The
+    names of each output feature in `y`, which will be used during printing
+    of equations.
+- `parallelism=:multithreading`: What parallelism mode to use.
+    The options are `:multithreading`, `:multiprocessing`, and `:serial`.
+    By default, multithreading will be used. Multithreading uses less memory,
+    but multiprocessing can handle multi-node compute. If using `:multithreading`
+    mode, the number of threads available to Julia is used. If using
+    `:multiprocessing`, `numprocs` processes will be created dynamically if
+    `procs` is unset. If you have already allocated processes, pass them
+    to the `procs` argument and they will be used.
+    You may also pass a string instead of a symbol, like `"multithreading"`.
+- `numprocs::Union{Int, Nothing}=nothing`: The number of processes to use,
+    if you want `equation_search` to set this up automatically. By default
+    this will be `4`, but can be any number (you should pick a number <=
+    the number of cores available).
+- `procs::Union{Vector{Int}, Nothing}=nothing`: If you have set up
+    a distributed run manually with `procs = addprocs()` and `@everywhere`,
+    pass the `procs` to this keyword argument.
+- `addprocs_function::Union{Function, Nothing}=nothing`: If using multiprocessing
+    (`parallelism=:multiprocessing`), and you are not passing `procs` manually,
+    then processes will be allocated dynamically using `addprocs`. However,
+    you may also pass a custom function to use instead of `addprocs`.
+    This function should take a single positional argument,
+    which is the number of processes to use, as well as the `lazy` keyword argument.
+    For example, if set up on a slurm cluster, you could pass
+    `addprocs_function = addprocs_slurm`, which will set up slurm processes.
+- `heap_size_hint_in_bytes::Union{Integer,Nothing}=nothing`: On Julia 1.9+, you may set the `--heap-size-hint`
+    flag on Julia processes, recommending garbage collection once a process
+    is close to the recommended size. This is important for long-running distributed
+    jobs where each process has its own independent memory, and can help avoid
+    out-of-memory errors. By default, this is set to `Sys.free_memory() / numprocs`.
+- `runtests::Bool=true`: Whether to run (quick) tests before starting the
+    search, to see if there will be any problems during the equation search
+    related to the host environment.
+- `saved_state=nothing`: If you have already
+    run `equation_search` and want to resume it, pass the state here.
+    To get this to work, you need to have set `return_state=true`,
+    which will cause `equation_search` to return the state. The second
+    element of the state is the regular return value with the hall of fame.
+    Note that you cannot change the operators or dataset, but most other options
+    should be changeable.
+- `return_state::Union{Bool, Nothing}=nothing`: Whether to return the
+    state of the search for warm starts. By default this is false.
+- `loss_type::Type=Nothing`: If you would like to use a different type
+    for the loss than for the data you passed, specify the type here.
+    Note that if you pass complex data `::Complex{L}`, then the loss
+    type will automatically be set to `L`.
+- `verbosity`: Whether to print debugging statements or not.
+- `progress`: Whether to use a progress bar output. Only available for
+    single target output.
+- `X_units::Union{AbstractVector,Nothing}=nothing`: The units of the dataset,
+    to be used for dimensional constraints.
+    For example, if `X_units=["kg", "m"]`,
+    then the first feature will have units of kilograms, and the second will
+    have units of meters.
+- `y_units=nothing`: The units of the output, to be used for dimensional constraints.
+    If `y` is a matrix, then this can be a vector of units, in which case
+    each element corresponds to each output feature.
+
+# Returns
+- `hallOfFame::HallOfFame`: The best equations seen during the search.
+    `hallOfFame.members` gives an array of `PopMember` objects, which
+    have their tree (equation) stored in `.tree`. Their score (loss)
+    is given in `.score`. The array of `PopMember` objects
+    is enumerated by size from `1` to `options.maxsize`.
+"""
+function equation_search(
+    X::AbstractMatrix{T},
+    y::AbstractMatrix{T};
+    niterations::Int=10,
+    weights::Union{AbstractMatrix{T},AbstractVector{T},Nothing}=nothing,
+    options::Options=Options(),
+    variable_names::Union{AbstractVector{String},Nothing}=nothing,
+    display_variable_names::Union{AbstractVector{String},Nothing}=variable_names,
+    y_variable_names::Union{String,AbstractVector{String},Nothing}=nothing,
+    parallelism=:multithreading,
+    numprocs::Union{Int,Nothing}=nothing,
+    procs::Union{Vector{Int},Nothing}=nothing,
+    addprocs_function::Union{Function,Nothing}=nothing,
+    heap_size_hint_in_bytes::Union{Integer,Nothing}=nothing,
+    runtests::Bool=true,
+    saved_state=nothing,
+    return_state::Union{Bool,Nothing,Val}=nothing,
+    loss_type::Type{L}=Nothing,
+    verbosity::Union{Integer,Nothing}=nothing,
+    progress::Union{Bool,Nothing}=nothing,
+    X_units::Union{AbstractVector,Nothing}=nothing,
+    y_units=nothing,
+    extra::NamedTuple=NamedTuple(),
+    v_dim_out::Val{DIM_OUT}=Val(nothing),
+    # Deprecated:
+    multithreaded=nothing,
+    varMap=nothing,
+) where {T<:DATA_TYPE,L,DIM_OUT}
+    if multithreaded !== nothing
+        error(
+            "`multithreaded` is deprecated. Use the `parallelism` argument instead. " *
+            "Choose one of :multithreading, :multiprocessing, or :serial.",
+        )
+    end
+    variable_names = deprecate_varmap(variable_names, varMap, :equation_search)
+
+    if weights !== nothing
+        @assert length(weights) == length(y)
+        weights = reshape(weights, size(y))
+    end
+
+    datasets = construct_datasets(
+        X,
+        y,
+        weights,
+        variable_names,
+        display_variable_names,
+        y_variable_names,
+        X_units,
+        y_units,
+        extra,
+        L,
+    )
+
+    return equation_search(
+        datasets;
+        niterations=niterations,
+        options=options,
+        parallelism=parallelism,
+        numprocs=numprocs,
+        procs=procs,
+        addprocs_function=addprocs_function,
+        heap_size_hint_in_bytes=heap_size_hint_in_bytes,
+        runtests=runtests,
+        saved_state=saved_state,
+        return_state=return_state,
+        verbosity=verbosity,
+        progress=progress,
+        v_dim_out=Val(DIM_OUT),
+    )
+end
+
+function equation_search(
+    X::AbstractMatrix{T1}, y::AbstractMatrix{T2}; kw...
+) where {T1<:DATA_TYPE,T2<:DATA_TYPE}
+    U = promote_type(T1, T2)
+    return equation_search(
+        convert(AbstractMatrix{U}, X), convert(AbstractMatrix{U}, y); kw...
+    )
+end
+
+function equation_search(
+    X::AbstractMatrix{T1}, y::AbstractVector{T2}; kw...
+) where {T1<:DATA_TYPE,T2<:DATA_TYPE}
+    return equation_search(X, reshape(y, (1, size(y, 1))); kw..., v_dim_out=Val(1))
+end
+
+function equation_search(dataset::Dataset; kws...)
+ return equation_search([dataset]; kws..., v_dim_out=Val(1)) +end + +function equation_search( + datasets::Vector{D}; + niterations::Int=10, + options::Options=Options(), + parallelism=:multithreading, + numprocs::Union{Int,Nothing}=nothing, + procs::Union{Vector{Int},Nothing}=nothing, + addprocs_function::Union{Function,Nothing}=nothing, + heap_size_hint_in_bytes::Union{Integer,Nothing}=nothing, + runtests::Bool=true, + saved_state=nothing, + return_state::Union{Bool,Nothing,Val}=nothing, + verbosity::Union{Int,Nothing}=nothing, + progress::Union{Bool,Nothing}=nothing, + v_dim_out::Val{DIM_OUT}=Val(nothing), +) where {DIM_OUT,T<:DATA_TYPE,L<:LOSS_TYPE,D<:Dataset{T,L}} + concurrency = if parallelism in (:multithreading, "multithreading") + :multithreading + elseif parallelism in (:multiprocessing, "multiprocessing") + :multiprocessing + elseif parallelism in (:serial, "serial") + :serial + else + error( + "Invalid parallelism mode: $parallelism. " * + "You must choose one of :multithreading, :multiprocessing, or :serial.", + ) + :serial + end + not_distributed = concurrency in (:multithreading, :serial) + not_distributed && + procs !== nothing && + error( + "`procs` should not be set when using `parallelism=$(parallelism)`. Please use `:multiprocessing`.", + ) + not_distributed && + numprocs !== nothing && + error( + "`numprocs` should not be set when using `parallelism=$(parallelism)`. Please use `:multiprocessing`.", + ) + + _return_state = if return_state isa Val + first(typeof(return_state).parameters) + else + if options.return_state === Val(nothing) + return_state === nothing ? false : return_state + else + @assert( + return_state === nothing, + "You cannot set `return_state` in both the `Options` and in the passed arguments." + ) + first(typeof(options.return_state).parameters) + end + end + + dim_out = if DIM_OUT === nothing + length(datasets) > 1 ? 2 : 1 + else + DIM_OUT + end + _numprocs::Int = if numprocs === nothing + if procs === nothing + 4 + else + length(procs) + end + else + if procs === nothing + numprocs + else + @assert length(procs) == numprocs + numprocs + end + end + + _verbosity = if verbosity === nothing && options.verbosity === nothing + 1 + elseif verbosity === nothing && options.verbosity !== nothing + options.verbosity + elseif verbosity !== nothing && options.verbosity === nothing + verbosity + else + error( + "You cannot set `verbosity` in both the search parameters `Options` and the call to `equation_search`.", + ) + 1 + end + _progress::Bool = if progress === nothing && options.progress === nothing + (_verbosity > 0) && length(datasets) == 1 + elseif progress === nothing && options.progress !== nothing + options.progress + elseif progress !== nothing && options.progress === nothing + progress + else + error( + "You cannot set `progress` in both the search parameters `Options` and the call to `equation_search`.", + ) + false + end + + _addprocs_function = addprocs_function === nothing ? addprocs : addprocs_function + + exeflags = if VERSION >= v"1.9" && concurrency == :multiprocessing + heap_size_hint_in_megabytes = floor( + Int, ( + if heap_size_hint_in_bytes === nothing + (Sys.free_memory() / _numprocs) + else + heap_size_hint_in_bytes + end + ) / 1024^2 + ) + _verbosity > 0 && + heap_size_hint_in_bytes === nothing && + @info "Automatically setting `--heap-size-hint=$(heap_size_hint_in_megabytes)M` on each Julia process. You can configure this with the `heap_size_hint_in_bytes` parameter." 
+
+        `--heap-size-hint=$(heap_size_hint_in_megabytes)M`
+    else
+        ``
+    end
+
+    # Underscores here mean that we have mutated the variable
+    return _equation_search(
+        datasets,
+        RuntimeOptions(;
+            niterations=niterations,
+            total_cycles=options.populations * niterations,
+            numprocs=_numprocs,
+            init_procs=procs,
+            addprocs_function=_addprocs_function,
+            exeflags=exeflags,
+            runtests=runtests,
+            verbosity=_verbosity,
+            progress=_progress,
+            parallelism=Val(concurrency),
+            dim_out=Val(dim_out),
+            return_state=Val(_return_state),
+        ),
+        options,
+        saved_state,
+    )
+end
+
+@noinline function _equation_search(
+    datasets::Vector{D}, ropt::RuntimeOptions, options::Options, saved_state
+) where {D<:Dataset}
+    # Prompt evolution: one idea database per output dataset
+    idea_database_all = [Vector{String}() for j in 1:length(datasets)]
+
+    _validate_options(datasets, ropt, options)
+    state = _create_workers(datasets, ropt, options)
+    _initialize_search!(state, datasets, ropt, options, saved_state, idea_database_all)
+    _warmup_search!(state, datasets, ropt, options, idea_database_all)
+    _main_search_loop!(state, datasets, ropt, options, idea_database_all)
+    _tear_down!(state, ropt, options)
+    return _format_output(state, datasets, ropt, options)
+end
+
+function _validate_options(
+    datasets::Vector{D}, ropt::RuntimeOptions, options::Options
+) where {T,L,D<:Dataset{T,L}}
+    example_dataset = first(datasets)
+    nout = length(datasets)
+    @assert nout >= 1
+    @assert (nout == 1 || ropt.dim_out == 2)
+    @assert options.populations >= 1
+    if ropt.progress
+        @assert(nout == 1, "You cannot display a progress bar for multi-output searches.")
+        @assert(ropt.verbosity > 0, "You cannot display a progress bar with `verbosity=0`.")
+    end
+    if options.node_type <: GraphNode && ropt.verbosity > 0
+        @warn "The `GraphNode` interface and mutation operators are experimental and will change in future versions."
+    end
+    if ropt.runtests
+        test_option_configuration(ropt.parallelism, datasets, options, ropt.verbosity)
+        test_dataset_configuration(example_dataset, options, ropt.verbosity)
+    end
+    for dataset in datasets
+        update_baseline_loss!(dataset, options)
+    end
+    if options.define_helper_functions
+        set_default_variable_names!(first(datasets).variable_names)
+    end
+    if options.seed !== nothing
+        seed!(options.seed)
+    end
+    return nothing
+end
+@stable default_mode = "disable" function _create_workers(
+    datasets::Vector{D}, ropt::RuntimeOptions, options::Options
+) where {T,L,D<:Dataset{T,L}}
+    stdin_reader = watch_stream(stdin)
+
+    record = RecordType()
+    @recorder record["options"] = "$(options)"
+
+    nout = length(datasets)
+    example_dataset = first(datasets)
+    example_ex = create_expression(zero(T), options, example_dataset)
+    NT = typeof(example_ex)
+    PopType = Population{T,L,NT}
+    HallOfFameType = HallOfFame{T,L,NT}
+    WorkerOutputType = get_worker_output_type(
+        Val(ropt.parallelism), PopType, HallOfFameType
+    )
+    ChannelType = ropt.parallelism == :multiprocessing ?
RemoteChannel : Channel + + # Pointers to populations on each worker: + worker_output = Vector{WorkerOutputType}[WorkerOutputType[] for j in 1:nout] + # Initialize storage for workers + tasks = [Task[] for j in 1:nout] + # Set up a channel to send finished populations back to head node + channels = [[ChannelType(1) for i in 1:(options.populations)] for j in 1:nout] + (procs, we_created_procs) = if ropt.parallelism == :multiprocessing + configure_workers(; + procs=ropt.init_procs, + ropt.numprocs, + ropt.addprocs_function, + options, + project_path=splitdir(Pkg.project().path)[1], + file=@__FILE__, + ropt.exeflags, + ropt.verbosity, + example_dataset, + ropt.runtests, + ) + else + Int[], false + end + # Get the next worker process to give a job: + worker_assignment = WorkerAssignments() + # Randomly order which order to check populations: + # This is done so that we do work on all nout equally. + task_order = [(j, i) for j in 1:nout for i in 1:(options.populations)] + shuffle!(task_order) + + # Persistent storage of last-saved population for final return: + last_pops = init_dummy_pops(options.populations, datasets, options) + # Best 10 members from each population for migration: + best_sub_pops = init_dummy_pops(options.populations, datasets, options) + # TODO: Should really be one per population too. + all_running_search_statistics = [ + RunningSearchStatistics(; options=options) for j in 1:nout + ] + # Records the number of evaluations: + # Real numbers indicate use of batching. + num_evals = [[0.0 for i in 1:(options.populations)] for j in 1:nout] + + halls_of_fame = Vector{HallOfFameType}(undef, nout) + + cycles_remaining = [ropt.total_cycles for j in 1:nout] + cur_maxsizes = [ + get_cur_maxsize(; options, ropt.total_cycles, cycles_remaining=cycles_remaining[j]) + for j in 1:nout + ] + + return SearchState{T,L,typeof(example_ex),WorkerOutputType,ChannelType}(; + procs=procs, + we_created_procs=we_created_procs, + worker_output=worker_output, + tasks=tasks, + channels=channels, + worker_assignment=worker_assignment, + task_order=task_order, + halls_of_fame=halls_of_fame, + last_pops=last_pops, + best_sub_pops=best_sub_pops, + all_running_search_statistics=all_running_search_statistics, + num_evals=num_evals, + cycles_remaining=cycles_remaining, + cur_maxsizes=cur_maxsizes, + stdin_reader=stdin_reader, + record=Ref(record), + ) +end +function _initialize_search!( + state::SearchState{T,L,N}, + datasets, + ropt::RuntimeOptions, + options::Options, + saved_state, + idea_database_all, +) where {T,L,N} + nout = length(datasets) + + init_hall_of_fame = load_saved_hall_of_fame(saved_state) + if init_hall_of_fame === nothing + for j in 1:nout + state.halls_of_fame[j] = HallOfFame(options, datasets[j]) + end + else + # Recompute losses for the hall of fame, in + # case the dataset changed: + for j in eachindex(init_hall_of_fame, datasets, state.halls_of_fame) + hof = strip_metadata(init_hall_of_fame[j], options, datasets[j]) + for member in hof.members[hof.exists] + score, result_loss = score_func(datasets[j], member, options) + member.score = score + member.loss = result_loss + end + state.halls_of_fame[j] = hof + end + end + + for j in 1:nout, i in 1:(options.populations) + worker_idx = assign_next_worker!( + state.worker_assignment; out=j, pop=i, parallelism=ropt.parallelism, state.procs + ) + saved_pop = load_saved_population(saved_state; out=j, pop=i) + new_pop = + if saved_pop !== nothing && length(saved_pop.members) == options.population_size + _saved_pop = strip_metadata(saved_pop, 
options, datasets[j]) + ## Update losses: + for member in _saved_pop.members + score, result_loss = score_func(datasets[j], member, options) + member.score = score + member.loss = result_loss + end + copy_pop = copy(_saved_pop) + @sr_spawner( + begin + (copy_pop, HallOfFame(options, datasets[j]), RecordType(), 0.0) + end, + parallelism = ropt.parallelism, + worker_idx = worker_idx + ) + else + if saved_pop !== nothing && ropt.verbosity > 0 + @warn "Recreating population (output=$(j), population=$(i)), as the saved one doesn't have the correct number of members." + end + @sr_spawner( + begin + ( + Population( + datasets[j]; + population_size=options.population_size, + nlength=3, + options=options, + nfeatures=datasets[j].nfeatures, + idea_database=idea_database_all[j], + ), + HallOfFame(options, datasets[j]), + RecordType(), + Float64(options.population_size), + ) + end, + parallelism = ropt.parallelism, + worker_idx = worker_idx + ) + # This involves population_size evaluations, on the full dataset: + end + push!(state.worker_output[j], new_pop) + end + return nothing +end +function _warmup_search!( + state::SearchState{T,L,N}, + datasets, + ropt::RuntimeOptions, + options::Options, + idea_database_all, +) where {T,L,N} + nout = length(datasets) + for j in 1:nout, i in 1:(options.populations) + dataset = datasets[j] + running_search_statistics = state.all_running_search_statistics[j] + cur_maxsize = state.cur_maxsizes[j] + @recorder state.record[]["out$(j)_pop$(i)"] = RecordType() + worker_idx = assign_next_worker!( + state.worker_assignment; out=j, pop=i, parallelism=ropt.parallelism, state.procs + ) + + # TODO - why is this needed?? + # Multi-threaded doesn't like to fetch within a new task: + c_rss = deepcopy(running_search_statistics) + last_pop = state.worker_output[j][i] + updated_pop = @sr_spawner( + begin + in_pop = first( + extract_from_worker(last_pop, Population{T,L,N}, HallOfFame{T,L,N}) + ) + _dispatch_s_r_cycle( + in_pop, + dataset, + options; + pop=i, + out=j, + iteration=0, + ropt.verbosity, + cur_maxsize, + running_search_statistics=c_rss, + idea_database=idea_database_all[j], + )::DefaultWorkerOutputType{Population{T,L,N},HallOfFame{T,L,N}} + end, + parallelism = ropt.parallelism, + worker_idx = worker_idx + ) + state.worker_output[j][i] = updated_pop + end + return nothing +end +function _main_search_loop!( + state::SearchState{T,L,N}, + datasets, + ropt::RuntimeOptions, + options::Options, + idea_database_all, +) where {T,L,N} + ropt.verbosity > 0 && @info "Started!" + nout = length(datasets) + start_time = time() + if ropt.progress + #TODO: need to iterate this on the max cycles remaining! 
+        sum_cycle_remaining = sum(state.cycles_remaining)
+        progress_bar = WrappedProgressBar(
+            1:sum_cycle_remaining; width=options.terminal_width
+        )
+    end
+    last_print_time = time()
+    last_speed_recording_time = time()
+    num_evals_last = sum(sum, state.num_evals)
+    num_evals_since_last = sum(sum, state.num_evals) - num_evals_last # i.e., start at 0
+    print_every_n_seconds = 5
+    equation_speed = Float32[]
+
+    if ropt.parallelism in (:multiprocessing, :multithreading)
+        for j in 1:nout, i in 1:(options.populations)
+            # Start listening for each population to finish:
+            t = @async put!(state.channels[j][i], fetch(state.worker_output[j][i]))
+            push!(state.tasks[j], t)
+        end
+    end
+    kappa = 0
+    resource_monitor = ResourceMonitor(;
+        # Storing n times as many monitoring intervals as populations seems like it will
+        # help get accurate resource estimates:
+        max_recordings=options.populations * 100 * nout,
+        start_reporting_at=options.populations * 3 * nout,
+        window_size=options.populations * 2 * nout,
+    )
+    n_iterations = 0
+    llm_recorder(options.llm_options, string(div(n_iterations, options.populations)), "n_iterations")
+    worst_members = Vector{PopMember}()
+    while sum(state.cycles_remaining) > 0
+        kappa += 1
+        if kappa > options.populations * nout
+            kappa = 1
+        end
+        # nout, populations:
+        j, i = state.task_order[kappa]
+        idea_database = idea_database_all[j]
+
+        # Check if error on population:
+        if ropt.parallelism in (:multiprocessing, :multithreading)
+            if istaskfailed(state.tasks[j][i])
+                fetch(state.tasks[j][i])
+                error("Task failed for population")
+            end
+        end
+        # Non-blocking check if a population is ready:
+        population_ready = if ropt.parallelism in (:multiprocessing, :multithreading)
+            # TODO: Implement type assertions based on parallelism.
+            isready(state.channels[j][i])
+        else
+            true
+        end
+        record_channel_state!(resource_monitor, population_ready)
+
+        # Don't start more if this output has finished its cycles:
+        # TODO - this might skip extra cycles?
+        population_ready &= (state.cycles_remaining[j] > 0)
+        if population_ready
+            if n_iterations % options.populations == 0
+                worst_members = Vector{PopMember}()
+            end
+            n_iterations += 1
+            # Take the result from the channel since it's ready
+            (cur_pop, best_seen, cur_record, cur_num_evals) = if ropt.parallelism in
+                (
+                :multiprocessing, :multithreading
+            )
+                take!(
+                    state.channels[j][i]
+                )
+            else
+                state.worker_output[j][i]
+            end::DefaultWorkerOutputType{Population{T,L,N},HallOfFame{T,L,N}}
+            state.last_pops[j][i] = copy(cur_pop)
+            state.best_sub_pops[j][i] = best_sub_pop(cur_pop; topn=options.topn)
+            @recorder state.record[] = recursive_merge(state.record[], cur_record)
+            state.num_evals[j][i] += cur_num_evals
+            dataset = datasets[j]
+            cur_maxsize = state.cur_maxsizes[j]
+
+            worst_member = nothing
+            for member in cur_pop.members
+                if worst_member === nothing || worst_member.loss < member.loss
+                    worst_member = member
+                end
+                size = compute_complexity(member, options)
+                update_frequencies!(state.all_running_search_statistics[j]; size)
+            end
+            # Only track the worst member if its loss is actually bad (> 100);
+            # if even the worst of the population is good, don't record it.
+            if worst_member !== nothing && worst_member.loss > 100
+                push!(worst_members, worst_member)
+            end
+            #! format: off
+            update_hall_of_fame!(state.halls_of_fame[j], cur_pop.members, options)
+            update_hall_of_fame!(state.halls_of_fame[j], best_seen.members[best_seen.exists], options)
+            #!
format: on + + # Dominating pareto curve - must be better than all simpler equations + dominating = calculate_pareto_frontier(state.halls_of_fame[j]) + if options.llm_options.active && + options.llm_options.prompt_evol && + (n_iterations % options.populations == 0) + update_idea_database(idea_database, dominating, worst_members, options) + end + + if options.save_to_file + save_to_file(dominating, nout, j, dataset, options) + end + ################################################################### + # Migration ####################################################### + if options.migration + best_of_each = Population([ + member for pop in state.best_sub_pops[j] for member in pop.members + ]) + migrate!( + best_of_each.members => cur_pop, options; frac=options.fraction_replaced + ) + end + if options.hof_migration && length(dominating) > 0 + migrate!(dominating => cur_pop, options; frac=options.fraction_replaced_hof) + end + ################################################################### + + state.cycles_remaining[j] -= 1 + if state.cycles_remaining[j] == 0 + break + end + worker_idx = assign_next_worker!( + state.worker_assignment; + out=j, + pop=i, + parallelism=ropt.parallelism, + state.procs, + ) + iteration = if options.use_recorder + key = "out$(j)_pop$(i)" + find_iteration_from_record(key, state.record[]) + 1 + else + 0 + end + + c_rss = deepcopy(state.all_running_search_statistics[j]) + in_pop = copy(cur_pop::Population{T,L,N}) + state.worker_output[j][i] = @sr_spawner( + begin + _dispatch_s_r_cycle( + in_pop, + dataset, + options; + pop=i, + out=j, + iteration, + ropt.verbosity, + cur_maxsize, + running_search_statistics=c_rss, + dominating=dominating, + idea_database=idea_database, + ) + end, + parallelism = ropt.parallelism, + worker_idx = worker_idx + ) + if ropt.parallelism in (:multiprocessing, :multithreading) + state.tasks[j][i] = @async put!( + state.channels[j][i], fetch(state.worker_output[j][i]) + ) + end + + state.cur_maxsizes[j] = get_cur_maxsize(; + options, ropt.total_cycles, cycles_remaining=state.cycles_remaining[j] + ) + move_window!(state.all_running_search_statistics[j]) + if ropt.progress + head_node_occupation = estimate_work_fraction(resource_monitor) + update_progress_bar!( + progress_bar, + only(state.halls_of_fame), + only(datasets), + options, + equation_speed, + head_node_occupation, + ropt.parallelism, + ) + end + end + yield() + + ################################################################ + ## Search statistics + elapsed_since_speed_recording = time() - last_speed_recording_time + if elapsed_since_speed_recording > 1.0 + num_evals_since_last, num_evals_last = let s = sum(sum, state.num_evals) + s - num_evals_last, s + end + current_speed = num_evals_since_last / elapsed_since_speed_recording + push!(equation_speed, current_speed) + average_over_m_measurements = 20 # 20 second running average + if length(equation_speed) > average_over_m_measurements + deleteat!(equation_speed, 1) + end + last_speed_recording_time = time() + end + ################################################################ + + ################################################################ + ## Printing code + elapsed = time() - last_print_time + # Update if time has passed + if elapsed > print_every_n_seconds + if ropt.verbosity > 0 && !ropt.progress && length(equation_speed) > 0 + + # Dominating pareto curve - must be better than all simpler equations + head_node_occupation = estimate_work_fraction(resource_monitor) + print_search_state( + state.halls_of_fame, + 
datasets; + options, + equation_speed, + ropt.total_cycles, + state.cycles_remaining, + head_node_occupation, + parallelism=ropt.parallelism, + width=options.terminal_width, + ) + end + last_print_time = time() + end + ################################################################ + + ################################################################ + ## Early stopping code + if any(( + check_for_loss_threshold(state.halls_of_fame, options), + check_for_user_quit(state.stdin_reader), + check_for_timeout(start_time, options), + check_max_evals(state.num_evals, options), + )) + break + end + ################################################################ + end + llm_recorder(options.llm_options, string(div(n_iterations, options.populations)), "n_iterations") + return nothing +end +function _tear_down!(state::SearchState, ropt::RuntimeOptions, options::Options) + close_reader!(state.stdin_reader) + # Safely close all processes or threads + if ropt.parallelism == :multiprocessing + state.we_created_procs && rmprocs(state.procs) + elseif ropt.parallelism == :multithreading + nout = length(state.worker_output) + for j in 1:nout, i in eachindex(state.worker_output[j]) + wait(state.worker_output[j][i]) + end + end + @recorder json3_write(state.record[], options.recorder_file) + return nothing +end +function _format_output( + state::SearchState, datasets, ropt::RuntimeOptions, options::Options +) + nout = length(datasets) + out_hof = if ropt.dim_out == 1 + embed_metadata(only(state.halls_of_fame), options, only(datasets)) + else + map(j -> embed_metadata(state.halls_of_fame[j], options, datasets[j]), 1:nout) + end + if ropt.return_state + return ( + map(j -> embed_metadata(state.last_pops[j], options, datasets[j]), 1:nout), + out_hof, + ) + else + return out_hof + end +end + +@stable default_mode = "disable" function _dispatch_s_r_cycle( + in_pop::Population{T,L,N}, + dataset::Dataset, + options::Options; + pop::Int, + out::Int, + iteration::Int, + verbosity, + cur_maxsize::Int, + running_search_statistics, + dominating=nothing, + idea_database=nothing, +) where {T,L,N} + record = RecordType() + @recorder record["out$(out)_pop$(pop)"] = RecordType( + "iteration$(iteration)" => record_population(in_pop, options) + ) + num_evals = 0.0 + normalize_frequencies!(running_search_statistics) + out_pop, best_seen, evals_from_cycle = s_r_cycle( + dataset, + in_pop, + options.ncycles_per_iteration, + cur_maxsize, + running_search_statistics; + verbosity=verbosity, + options=options, + record=record, + dominating=dominating, + idea_database=idea_database, + ) + num_evals += evals_from_cycle + out_pop, evals_from_optimize = optimize_and_simplify_population( + dataset, out_pop, options, cur_maxsize, record + ) + num_evals += evals_from_optimize + if options.batching + for i_member in 1:(options.maxsize + MAX_DEGREE) + score, result_loss = score_func(dataset, best_seen.members[i_member], options) + best_seen.members[i_member].score = score + best_seen.members[i_member].loss = result_loss + num_evals += 1 + end + end + return (out_pop, best_seen, record, num_evals) +end + +include("MLJInterface.jl") +using .MLJInterfaceModule: LaSRRegressor, MultitargetLaSRRegressor + +function __init__() + @require_extensions +end + +# Hack to get static analysis to work from within tests: +@ignore include("../test/runtests.jl") + +# TODO: Hack to force ConstructionBase version +using ConstructionBase: ConstructionBase as _ + +include("precompile.jl") +redirect_stdout(devnull) do + redirect_stderr(devnull) do + 
do_precompilation(Val(:precompile))
+    end
+end
+
+end # module LibraryAugmentedSymbolicRegression
diff --git a/src/LossFunctions.jl b/src/LossFunctions.jl
index ac5493234..a84218879 100644
--- a/src/LossFunctions.jl
+++ b/src/LossFunctions.jl
@@ -1,12 +1,12 @@
 module LossFunctionsModule
 
-using Random: MersenneTwister
 using StatsBase: StatsBase
-using DynamicExpressions: AbstractExpressionNode, Node, constructorof
+using DynamicExpressions:
+    AbstractExpression, AbstractExpressionNode, get_tree, eval_tree_array
 using LossFunctions: LossFunctions
 using LossFunctions: SupervisedLoss
-using ..InterfaceDynamicExpressionsModule: eval_tree_array
-using ..CoreModule: Options, Dataset, DATA_TYPE, LOSS_TYPE
+using ..InterfaceDynamicExpressionsModule: expected_array_type
+using ..CoreModule: Options, Dataset, create_expression, DATA_TYPE, LOSS_TYPE
 using ..ComplexityModule: compute_complexity
 using ..DimensionalAnalysisModule: violates_dimensional_constraints
@@ -25,7 +25,7 @@ function _weighted_loss(
     x::AbstractArray{T}, y::AbstractArray{T}, w::AbstractArray{T}, loss::LT
 ) where {T<:DATA_TYPE,LT<:Union{Function,SupervisedLoss}}
     if loss isa SupervisedLoss
-        return LossFunctions.sum(loss, x, y, w; normalize=true)
+        return sum(loss, x, y, w; normalize=true)
     else
         l(i) = loss(x[i], y[i], w[i])
         return sum(l, eachindex(x)) / sum(w)
@@ -41,17 +41,25 @@ end
     end
 end
 
+function eval_tree_dispatch(
+    tree::Union{AbstractExpression{T},AbstractExpressionNode{T}},
+    dataset::Dataset{T},
+    options::Options,
+    idx,
+) where {T<:DATA_TYPE}
+    A = expected_array_type(dataset.X)
+    return eval_tree_array(tree, maybe_getindex(dataset.X, :, idx), options)::Tuple{A,Bool}
+end
+
 # Evaluate the loss of a particular expression on the input dataset.
 function _eval_loss(
-    tree::AbstractExpressionNode{T},
+    tree::Union{AbstractExpression{T},AbstractExpressionNode{T}},
     dataset::Dataset{T,L},
     options::Options,
     regularization::Bool,
     idx,
 )::L where {T<:DATA_TYPE,L<:LOSS_TYPE}
-    (prediction, completion) = eval_tree_array(
-        tree, maybe_getindex(dataset.X, :, idx), options
-    )
+    (prediction, completion) = eval_tree_dispatch(tree, dataset, options, idx)
     if !completion
         return L(Inf)
     end
@@ -95,7 +103,7 @@ end
 
 # Evaluate the loss of a particular expression on the input dataset.
 function eval_loss(
-    tree::AbstractExpressionNode{T},
+    tree::Union{AbstractExpression{T},AbstractExpressionNode{T}},
     dataset::Dataset{T,L},
     options::Options;
     regularization::Bool=true,
@@ -105,14 +113,14 @@ function eval_loss(
         _eval_loss(tree, dataset, options, regularization, idx)
     else
         f = options.loss_function::Function
-        evaluator(f, tree, dataset, options, idx)
+        evaluator(f, get_tree(tree), dataset, options, idx)
     end
     return loss_val
 end
 
 function eval_loss_batched(
-    tree::AbstractExpressionNode{T},
+    tree::Union{AbstractExpression{T},AbstractExpressionNode{T}},
     dataset::Dataset{T,L},
     options::Options;
    regularization::Bool=true,
@@ -127,8 +135,8 @@ function batch_sample(dataset, options)
 end
 
 # Just so we can pass either PopMember or Node here:
-get_tree(t::AbstractExpressionNode) = t
-get_tree(m) = m.tree
+get_tree_from_member(t::Union{AbstractExpression,AbstractExpressionNode}) = t
+get_tree_from_member(m) = m.tree
 # Beware: this is a circular dependency situation...
 # PopMember is using losses, but then we also want
 # losses to use the PopMember's cached complexity for trees.
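Since `eval_loss` above now unwraps wrapped expressions with `get_tree` before calling a user-supplied `options.loss_function`, custom losses keep receiving a bare tree even when the search itself operates on `AbstractExpression` objects. A minimal sketch of a compatible custom loss in the non-batching case (the function name and MSE body here are illustrative, not part of this patch):

```julia
using LibraryAugmentedSymbolicRegression: Options, Dataset, eval_tree_array

# Hypothetical example: mean squared error, returning an infinite loss
# whenever evaluation fails, mirroring how `_eval_loss` treats `!completion`.
function example_mse(tree, dataset::Dataset{T,L}, options)::L where {T,L}
    prediction, complete = eval_tree_array(tree, dataset.X, options)
    complete || return L(Inf)
    return sum(abs2, prediction .- dataset.y) / dataset.n
end

# options = Options(; binary_operators=[+, -, *], loss_function=example_mse)
```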
@@ -161,7 +169,7 @@ end function score_func( dataset::Dataset{T,L}, member, options::Options; complexity::Union{Int,Nothing}=nothing )::Tuple{L,L} where {T<:DATA_TYPE,L<:LOSS_TYPE} - result_loss = eval_loss(get_tree(member), dataset, options) + result_loss = eval_loss(get_tree_from_member(member), dataset, options) score = loss_to_score( result_loss, dataset.use_baseline, @@ -181,7 +189,7 @@ function score_func_batched( complexity::Union{Int,Nothing}=nothing, idx=nothing, )::Tuple{L,L} where {T<:DATA_TYPE,L<:LOSS_TYPE} - result_loss = eval_loss_batched(get_tree(member), dataset, options; idx=idx) + result_loss = eval_loss_batched(get_tree_from_member(member), dataset, options; idx=idx) score = loss_to_score( result_loss, dataset.use_baseline, @@ -201,7 +209,8 @@ Update the baseline loss of the dataset using the loss function specified in `op function update_baseline_loss!( dataset::Dataset{T,L}, options::Options ) where {T<:DATA_TYPE,L<:LOSS_TYPE} - example_tree = constructorof(options.node_type)(T; val=dataset.avg_y) + example_tree = create_expression(zero(T), options, dataset) + # constructorof(options.node_type)(T; val=dataset.avg_y) # TODO: It could be that the loss function is not defined for this example type? baseline_loss = eval_loss(example_tree, dataset, options) if isfinite(baseline_loss) @@ -215,7 +224,9 @@ function update_baseline_loss!( end function dimensional_regularization( - tree::AbstractExpressionNode{T}, dataset::Dataset{T,L}, options::Options + tree::Union{AbstractExpression{T},AbstractExpressionNode{T}}, + dataset::Dataset{T,L}, + options::Options, ) where {T<:DATA_TYPE,L<:LOSS_TYPE} if !violates_dimensional_constraints(tree, dataset, options) return zero(L) diff --git a/src/MLJInterface.jl b/src/MLJInterface.jl index f0f9943b1..04204d0d9 100644 --- a/src/MLJInterface.jl +++ b/src/MLJInterface.jl @@ -3,7 +3,16 @@ module MLJInterfaceModule using Optim: Optim using LineSearches: LineSearches using MLJModelInterface: MLJModelInterface as MMI -using DynamicExpressions: eval_tree_array, string_tree, AbstractExpressionNode, Node +using ADTypes: AbstractADType +using DynamicExpressions: + eval_tree_array, + string_tree, + AbstractExpressionNode, + AbstractExpression, + Node, + Expression, + default_node_type, + get_tree using DynamicQuantities: QuantityArray, UnionAbstractQuantity, @@ -16,20 +25,28 @@ using DynamicQuantities: using LossFunctions: SupervisedLoss using Compat: allequal, stack using ..InterfaceDynamicQuantitiesModule: get_dimensions_type -using ..CoreModule: Options, Dataset, MutationWeights, LOSS_TYPE +using ..CoreModule: Options, Dataset, MutationWeights, LLMOptions, LOSS_TYPE using ..CoreModule.OptionsModule: DEFAULT_OPTIONS, OPTION_DESCRIPTIONS using ..ComplexityModule: compute_complexity using ..HallOfFameModule: HallOfFame, format_hall_of_fame -using ..UtilsModule: subscriptify +using ..UtilsModule: subscriptify, @ignore import ..equation_search abstract type AbstractSRRegressor <: MMI.Deterministic end +# For static analysis tools: +@ignore mutable struct LaSRRegressor <: AbstractSRRegressor + selection_method::Function +end +@ignore mutable struct MultitargetLaSRRegressor <: AbstractSRRegressor + selection_method::Function +end + # TODO: To reduce code re-use, we could forward these defaults from # `equation_search`, similar to what we do for `Options`. 
-"""Generate an `SRRegressor` struct containing all the fields in `Options`.""" +"""Generate an `LaSRRegressor` struct containing all the fields in `Options`.""" function modelexpr(model_name::Symbol) struct_def = :(Base.@kwdef mutable struct $(model_name){D<:AbstractDimensions,L} <: AbstractSRRegressor @@ -82,8 +99,8 @@ end """Get an equivalent `Options()` object for a particular regressor.""" function get_options(::AbstractSRRegressor) end -eval(modelexpr(:SRRegressor)) -eval(modelexpr(:MultitargetSRRegressor)) +eval(modelexpr(:LaSRRegressor)) +eval(modelexpr(:MultitargetLaSRRegressor)) # Cleaning already taken care of by `Options` and `equation_search` function full_report( @@ -122,11 +139,27 @@ function MMI.update( m::AbstractSRRegressor, verbosity, old_fitresult, old_cache, X, y, w=nothing ) options = old_fitresult === nothing ? get_options(m) : old_fitresult.options - return _update(m, verbosity, old_fitresult, old_cache, X, y, w, options) + return _update(m, verbosity, old_fitresult, old_cache, X, y, w, options, nothing) end -function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options) +function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options, classes) + if isnothing(classes) && MMI.istable(X) && haskey(X, :classes) + if !(X isa NamedTuple) + error("Classes can only be specified with named tuples.") + end + new_X = Base.structdiff(X, (; X.classes)) + new_classes = X.classes + return _update( + m, verbosity, old_fitresult, old_cache, new_X, y, w, options, new_classes + ) + end + if !isnothing(old_fitresult) + @assert( + old_fitresult.has_classes == !isnothing(classes), + "If the first fit used classes, the second fit must also use classes." + ) + end # To speed up iterative fits, we cache the types: - types = if old_fitresult === nothing + types = if isnothing(old_fitresult) (; T=Any, X_t=Any, @@ -149,7 +182,7 @@ function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options) ) X_units_clean::types.X_units_clean = clean_units(X_units) y_units_clean::types.y_units_clean = clean_units(y_units) - w_t::types.w_t = if w !== nothing && isa(m, MultitargetSRRegressor) + w_t::types.w_t = if w !== nothing && isa(m, MultitargetLaSRRegressor) @assert(isa(w, AbstractVector) && ndims(w) == 1, "Unexpected input for `w`.") repeat(w', size(y_t, 1)) else @@ -174,16 +207,18 @@ function _update(m, verbosity, old_fitresult, old_cache, X, y, w, options) X_units=X_units_clean, y_units=y_units_clean, verbosity=verbosity, + extra=isnothing(classes) ? (;) : (; classes), # Help out with inference: - v_dim_out=isa(m, SRRegressor) ? Val(1) : Val(2), + v_dim_out=isa(m, LaSRRegressor) ? Val(1) : Val(2), ) fitresult = (; state=search_state, - num_targets=isa(m, SRRegressor) ? 1 : size(y_t, 1), + num_targets=isa(m, LaSRRegressor) ? 1 : size(y_t, 1), options=options, variable_names=variable_names, y_variable_names=y_variable_names, y_is_table=MMI.istable(y), + has_classes=!isnothing(classes), X_units=X_units_clean, y_units=y_units_clean, types=( @@ -228,10 +263,10 @@ function get_matrix_and_info(X, ::Type{D}) where {D} return Xm_t_strip, colnames, X_units end -function format_input_for(::SRRegressor, y, ::Type{D}) where {D} +function format_input_for(::LaSRRegressor, y, ::Type{D}) where {D} @assert( !(MMI.istable(y) || (length(size(y)) == 2 && size(y, 2) > 1)), - "For multi-output regression, please use `MultitargetSRRegressor`." + "For multi-output regression, please use `MultitargetLaSRRegressor`." 
) y_t = vec(y) colnames = nothing @@ -239,10 +274,10 @@ function format_input_for(::SRRegressor, y, ::Type{D}) where {D} y_t_strip, y_units = unwrap_units_single(y_t, D_promoted) return y_t_strip, colnames, y_units end -function format_input_for(::MultitargetSRRegressor, y, ::Type{D}) where {D} +function format_input_for(::MultitargetLaSRRegressor, y, ::Type{D}) where {D} @assert( MMI.istable(y) || (length(size(y)) == 2 && size(y, 2) > 1), - "For single-output regression, please use `SRRegressor`." + "For single-output regression, please use `LaSRRegressor`." ) return get_matrix_and_info(y, D) end @@ -277,13 +312,13 @@ wrap_units(v, ::Nothing, ::Nothing) = v wrap_units(v, y_units, i::Integer) = (yi -> Quantity(yi, y_units[i])).(v) wrap_units(v, y_units, ::Nothing) = (yi -> Quantity(yi, y_units)).(v) -function prediction_fallback(::Type{T}, ::SRRegressor, Xnew_t, fitresult, _) where {T} +function prediction_fallback(::Type{T}, ::LaSRRegressor, Xnew_t, fitresult, _) where {T} prediction_warn() out = fill!(similar(Xnew_t, T, axes(Xnew_t, 2)), zero(T)) return wrap_units(out, fitresult.y_units, nothing) end function prediction_fallback( - ::Type{T}, ::MultitargetSRRegressor, Xnew_t, fitresult, prototype + ::Type{T}, ::MultitargetLaSRRegressor, Xnew_t, fitresult, prototype ) where {T} prediction_warn() out_cols = [ @@ -333,9 +368,20 @@ function MMI.fitted_params(m::AbstractSRRegressor, fitresult) end function eval_tree_mlj( - tree::Node, X_t, m::AbstractSRRegressor, ::Type{T}, fitresult, i, prototype + tree::AbstractExpression, + X_t, + classes, + m::AbstractSRRegressor, + ::Type{T}, + fitresult, + i, + prototype, ) where {T} - out, completed = eval_tree_array(tree, X_t, fitresult.options) + out, completed = if isnothing(classes) + eval_tree_array(tree, X_t, fitresult.options) + else + eval_tree_array(tree, X_t, classes, fitresult.options) + end if completed return wrap_units(out, fitresult.y_units, i) else @@ -343,13 +389,32 @@ function eval_tree_mlj( end end -function MMI.predict(m::M, fitresult, Xnew; idx=nothing) where {M<:AbstractSRRegressor} +function MMI.predict( + m::M, fitresult, Xnew; idx=nothing, classes=nothing +) where {M<:AbstractSRRegressor} + return _predict(m, fitresult, Xnew, idx, classes) +end +function _predict(m::M, fitresult, Xnew, idx, classes) where {M<:AbstractSRRegressor} if Xnew isa NamedTuple && (haskey(Xnew, :idx) || haskey(Xnew, :data)) @assert( haskey(Xnew, :idx) && haskey(Xnew, :data) && length(keys(Xnew)) == 2, "If specifying an equation index during prediction, you must use a named tuple with keys `idx` and `data`." ) - return MMI.predict(m, fitresult, Xnew.data; idx=Xnew.idx) + return _predict(m, fitresult, Xnew.data, Xnew.idx, classes) + end + if isnothing(classes) && MMI.istable(Xnew) && haskey(Xnew, :classes) + if !(Xnew isa NamedTuple) + error("Classes can only be specified with named tuples.") + end + Xnew2 = Base.structdiff(Xnew, (; Xnew.classes)) + return _predict(m, fitresult, Xnew2, idx, Xnew.classes) + end + + if fitresult.has_classes + @assert( + !isnothing(classes), + "Classes must be specified if the model was fit with classes." + ) end params = full_report(m, fitresult; v_with_strings=Val(false)) @@ -368,14 +433,14 @@ function MMI.predict(m::M, fitresult, Xnew; idx=nothing) where {M<:AbstractSRReg idx = idx === nothing ? 
params.best_idx : idx
-    if M <: SRRegressor
+    if M <: LaSRRegressor
         return eval_tree_mlj(
-            params.equations[idx], Xnew_t, m, T, fitresult, nothing, prototype
+            params.equations[idx], Xnew_t, classes, m, T, fitresult, nothing, prototype
         )
-    elseif M <: MultitargetSRRegressor
+    elseif M <: MultitargetLaSRRegressor
         outs = [
             eval_tree_mlj(
-                params.equations[i][idx[i]], Xnew_t, m, T, fitresult, i, prototype
+                params.equations[i][idx[i]], Xnew_t, classes, m, T, fitresult, i, prototype
             ) for i in eachindex(idx, params.equations)
         ]
         out_matrix = reduce(hcat, outs)
@@ -387,10 +452,10 @@ end
     end
 end
 
-function get_equation_strings_for(::SRRegressor, trees, options, variable_names)
+function get_equation_strings_for(::LaSRRegressor, trees, options, variable_names)
     return (t -> string_tree(t, options; variable_names=variable_names)).(trees)
 end
-function get_equation_strings_for(::MultitargetSRRegressor, trees, options, variable_names)
+function get_equation_strings_for(::MultitargetLaSRRegressor, trees, options, variable_names)
     return [
         (t -> string_tree(t, options; variable_names=variable_names)).(ts) for ts in trees
     ]
@@ -407,14 +472,14 @@ function choose_best(; trees, losses::Vector{L}, scores, complexities) where {L<
     ])
 end
 
-function dispatch_selection_for(m::SRRegressor, trees, losses, scores, complexities)::Int
+function dispatch_selection_for(m::LaSRRegressor, trees, losses, scores, complexities)::Int
     length(trees) == 0 && return 0
     return m.selection_method(;
         trees=trees, losses=losses, scores=scores, complexities=complexities
     )
 end
 function dispatch_selection_for(
-    m::MultitargetSRRegressor, trees, losses, scores, complexities
+    m::MultitargetLaSRRegressor, trees, losses, scores, complexities
 )
     any(t -> length(t) == 0, trees) && return fill(0, length(trees))
     return [
@@ -426,31 +491,37 @@ end
 
 MMI.metadata_pkg(
     AbstractSRRegressor;
-    name="SymbolicRegression",
-    uuid="8254be44-1295-4e6a-a16d-46603ac705cb",
-    url="https://github.com/MilesCranmer/SymbolicRegression.jl",
+    name="LibraryAugmentedSymbolicRegression",
+    uuid="158930c3-947c-4174-974b-74b39e64a28f",
+    url="https://github.com/MilesCranmer/LibraryAugmentedSymbolicRegression.jl",
     julia=true,
     license="Apache-2.0",
     is_wrapper=false,
 )
 
+const input_scitype = Union{
+    MMI.Table(MMI.Continuous),
+    AbstractMatrix{<:MMI.Continuous},
+    MMI.Table(MMI.Continuous, MMI.Count),
+}
+
 # TODO: Allow for Count data, and coerce it into Continuous as needed.
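Because `_update` and `_predict` above both peel a `classes` column off a named-tuple table (via `Base.structdiff`), class labels can ride alongside the features. A rough usage sketch, assuming a model configured for parametric expressions (that configuration is not shown in this hunk, and the data here are purely illustrative):

```julia
using MLJ

X = (; x1=randn(64), x2=randn(64), classes=rand(1:2, 64))  # `classes` is split off internally
y = @. 2.0 * X.x1 - 0.5 * X.x2
model = LaSRRegressor(; binary_operators=[+, -, *])
mach = machine(model, X, y)
fit!(mach)
# Prediction must also supply `classes`, since the model was fit with them:
predict(mach, X)
```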
MMI.metadata_model( - SRRegressor; - input_scitype=Union{MMI.Table(MMI.Continuous),AbstractMatrix{<:MMI.Continuous}}, + LaSRRegressor; + input_scitype, target_scitype=AbstractVector{<:MMI.Continuous}, supports_weights=true, reports_feature_importances=false, - load_path="SymbolicRegression.MLJInterfaceModule.SRRegressor", + load_path="LibraryAugmentedSymbolicRegression.MLJInterfaceModule.LaSRRegressor", human_name="Symbolic Regression via Evolutionary Search", ) MMI.metadata_model( - MultitargetSRRegressor; - input_scitype=Union{MMI.Table(MMI.Continuous),AbstractMatrix{<:MMI.Continuous}}, + MultitargetLaSRRegressor; + input_scitype, target_scitype=Union{MMI.Table(MMI.Continuous),AbstractMatrix{<:MMI.Continuous}}, supports_weights=true, reports_feature_importances=false, - load_path="SymbolicRegression.MLJInterfaceModule.MultitargetSRRegressor", + load_path="LibraryAugmentedSymbolicRegression.MLJInterfaceModule.MultitargetLaSRRegressor", human_name="Multi-Target Symbolic Regression via Evolutionary Search", ) @@ -503,7 +574,7 @@ function tag_with_docstring(model_name::Symbol, description::String, bottom_matt type will automatically be set to `L`. - `selection_method::Function`: Function to selection expression from the Pareto frontier for use in `predict`. - See `SymbolicRegression.MLJInterfaceModule.choose_best` for an example. + See `LibraryAugmentedSymbolicRegression.MLJInterfaceModule.choose_best` for an example. This function should return a single integer specifying the index of the expression to use. By default, this maximizes the score (a pound-for-pound rating) of expressions reaching the threshold @@ -518,11 +589,11 @@ function tag_with_docstring(model_name::Symbol, description::String, bottom_matt # Operations - `predict(mach, Xnew)`: Return predictions of the target given features `Xnew`, which - should have same scitype as `X` above. The expression used for prediction is defined - by the `selection_method` function, which can be seen by viewing `report(mach).best_idx`. + should have same scitype as `X` above. The expression used for prediction is defined + by the `selection_method` function, which can be seen by viewing `report(mach).best_idx`. - `predict(mach, (data=Xnew, idx=i))`: Return predictions of the target given features - `Xnew`, which should have same scitype as `X` above. By passing a named tuple with keys - `data` and `idx`, you are able to specify the equation you wish to evaluate in `idx`. + `Xnew`, which should have same scitype as `X` above. By passing a named tuple with keys + `data` and `idx`, you are able to specify the equation you wish to evaluate in `idx`. $(bottom_matter) """ @@ -544,10 +615,10 @@ end #https://arxiv.org/abs/2305.01582 eval( tag_with_docstring( - :SRRegressor, + :LaSRRegressor, replace( """ - Single-target Symbolic Regression regressor (`SRRegressor`) searches + Single-target Symbolic Regression regressor (`LaSRRegressor`) searches for symbolic expressions that predict a single target variable from a set of input variables. All data is assumed to be `Continuous`. The search is performed using an evolutionary algorithm. 
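The `selection_method` documented above is invoked by `dispatch_selection_for` with the keyword arguments `trees`, `losses`, `scores`, and `complexities`, and must return a single index into the Pareto frontier. A hedged sketch of a custom selector (the function name and the 10% tolerance are arbitrary choices for illustration):

```julia
# Prefer the simplest expression whose loss is within 10% of the best loss:
function select_parsimonious(; trees, losses, scores, complexities)
    candidates = findall(<=(1.1 * minimum(losses)), losses)
    return candidates[argmin(complexities[candidates])]
end

# model = LaSRRegressor(; selection_method=select_parsimonious)
```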
@@ -632,9 +703,9 @@ eval( ```julia using MLJ - SRRegressor = @load SRRegressor pkg=SymbolicRegression + LaSRRegressor = @load LaSRRegressor pkg=LibraryAugmentedSymbolicRegression X, y = @load_boston - model = SRRegressor(binary_operators=[+, -, *], unary_operators=[exp], niterations=100) + model = LaSRRegressor(binary_operators=[+, -, *], unary_operators=[exp], niterations=100) mach = machine(model, X, y) fit!(mach) y_hat = predict(mach, X) @@ -648,11 +719,11 @@ eval( ```julia using MLJ using DynamicQuantities - SRegressor = @load SRRegressor pkg=SymbolicRegression + SRegressor = @load LaSRRegressor pkg=LibraryAugmentedSymbolicRegression X = (; x1=rand(32) .* us"km/h", x2=rand(32) .* us"km") y = @. X.x2 / X.x1 + 0.5us"h" - model = SRRegressor(binary_operators=[+, -, *, /]) + model = LaSRRegressor(binary_operators=[+, -, *, /]) mach = machine(model, X, y) fit!(mach) y_hat = predict(mach, X) @@ -661,7 +732,7 @@ eval( println("Equation used:", r.equation_strings[r.best_idx]) ``` - See also [`MultitargetSRRegressor`](@ref). + See also [`MultitargetLaSRRegressor`](@ref). """, r"^ " => "", ), @@ -669,10 +740,10 @@ eval( ) eval( tag_with_docstring( - :MultitargetSRRegressor, + :MultitargetLaSRRegressor, replace( """ - Multi-target Symbolic Regression regressor (`MultitargetSRRegressor`) + Multi-target Symbolic Regression regressor (`MultitargetLaSRRegressor`) conducts several searches for expressions that predict each target variable from a set of input variables. All data is assumed to be `Continuous`. The search is performed using an evolutionary algorithm. @@ -759,10 +830,10 @@ eval( ```julia using MLJ - MultitargetSRRegressor = @load MultitargetSRRegressor pkg=SymbolicRegression + MultitargetLaSRRegressor = @load MultitargetLaSRRegressor pkg=LibraryAugmentedSymbolicRegression X = (a=rand(100), b=rand(100), c=rand(100)) Y = (y1=(@. cos(X.c) * 2.1 - 0.9), y2=(@. X.a * X.b + X.c)) - model = MultitargetSRRegressor(binary_operators=[+, -, *], unary_operators=[exp], niterations=100) + model = MultitargetLaSRRegressor(binary_operators=[+, -, *], unary_operators=[exp], niterations=100) mach = machine(model, X, Y) fit!(mach) y_hat = predict(mach, X) @@ -773,7 +844,7 @@ eval( end ``` - See also [`SRRegressor`](@ref). + See also [`LaSRRegressor`](@ref). """, r"^ " => "", ), diff --git a/src/Migration.jl b/src/Migration.jl index fb6e99d02..daab9255f 100644 --- a/src/Migration.jl +++ b/src/Migration.jl @@ -1,7 +1,7 @@ module MigrationModule using StatsBase: StatsBase -using ..CoreModule: Options, DATA_TYPE, LOSS_TYPE +using ..CoreModule: Options using ..PopulationModule: Population using ..PopMemberModule: PopMember, reset_birth! 
using ..UtilsModule: poisson_sample diff --git a/src/Mutate.jl b/src/Mutate.jl index 8820359fc..f377174f2 100644 --- a/src/Mutate.jl +++ b/src/Mutate.jl @@ -2,15 +2,16 @@ module MutateModule using DynamicExpressions: AbstractExpressionNode, - Node, + AbstractExpression, + ParametricExpression, + with_contents, + get_tree, preserve_sharing, copy_node, - count_nodes, - count_constants, + count_scalar_constants, simplify_tree!, combine_operators -using ..CoreModule: - Options, MutationWeights, Dataset, RecordType, sample_mutation, DATA_TYPE, LOSS_TYPE +using ..CoreModule: Options, MutationWeights, Dataset, RecordType, sample_mutation using ..ComplexityModule: compute_complexity using ..LossFunctionsModule: score_func, score_func_batched using ..CheckConstraintsModule: check_constraints @@ -28,38 +29,54 @@ using ..MutationFunctionsModule: crossover_trees, form_random_connection!, break_random_connection! + +using ..LLMFunctionsModule: + llm_mutate_op, llm_crossover_trees, tree_to_expr, gen_llm_random_tree, llm_recorder + using ..ConstantOptimizationModule: optimize_constants using ..RecorderModule: @recorder +function check_constant(tree::AbstractExpressionNode)::Bool + return (tree.degree == 0) && tree.constant +end + +function check_constant(tree::AbstractExpression)::Bool + return check_constant(get_tree(tree)) +end + function condition_mutation_weights!( weights::MutationWeights, member::PopMember, options::Options, curmaxsize::Int ) + tree = get_tree(member.tree) if !preserve_sharing(typeof(member.tree)) weights.form_connection = 0.0 weights.break_connection = 0.0 end - if member.tree.degree == 0 + if tree.degree == 0 # If equation is too small, don't delete operators # or simplify weights.mutate_operator = 0.0 weights.swap_operands = 0.0 weights.delete_node = 0.0 weights.simplify = 0.0 - if !member.tree.constant + if !tree.constant weights.optimize = 0.0 weights.mutate_constant = 0.0 end return nothing end - if !any(node -> node.degree == 2, member.tree) + if !any(node -> node.degree == 2, tree) # swap is implemented only for binary ops weights.swap_operands = 0.0 end - #More constants => more likely to do constant mutation - n_constants = count_constants(member.tree) - weights.mutate_constant *= min(8, n_constants) / 8.0 + if !(member.tree isa ParametricExpression) # TODO: HACK + #More constants => more likely to do constant mutation + let n_constants = count_scalar_constants(member.tree) + weights.mutate_constant *= min(8, n_constants) / 8.0 + end + end complexity = compute_complexity(member, options) if complexity >= curmaxsize @@ -85,9 +102,11 @@ function next_generation( running_search_statistics::RunningSearchStatistics, options::Options; tmp_recorder::RecordType, + dominating=nothing, + idea_database=nothing, )::Tuple{ P,Bool,Float64 -} where {T,L,D<:Dataset{T,L},N<:AbstractExpressionNode{T},P<:PopMember{T,L,N}} +} where {T,L,D<:Dataset{T,L},N<:AbstractExpression{T},P<:PopMember{T,L,N}} parent_ref = member.ref mutation_accepted = false num_evals = 0.0 @@ -118,6 +137,28 @@ function next_generation( # Mutations ############################################# local tree + if options.llm_options.active && (rand() < options.llm_options.weights.llm_mutate) + tree = copy_node(member.tree) + if check_constant(tree) + tree = with_contents( + tree, gen_random_tree_fixed_size(rand(1:curmaxsize), options, nfeatures, T) + ) + end + tree = llm_mutate_op(tree, options, dominating, idea_database) + tree = simplify_tree!(tree, options.operators) + tree = combine_operators(tree, options.operators) + 
@recorder tmp_recorder["type"] = "llm_mutate" + + successful_mutation = + (!check_constant(tree)) && check_constraints(tree, options, curmaxsize) + + if successful_mutation + llm_recorder(options.llm_options, tree_to_expr(tree, options), "mutate") + else + llm_recorder(options.llm_options, tree_to_expr(tree, options), "mutate|failed") + end + end + while (!successful_mutation) && attempts < max_attempts tree = copy_node(member.tree) successful_mutation = true @@ -158,11 +199,10 @@ function next_generation( elseif mutation_choice == :simplify @assert options.should_simplify simplify_tree!(tree, options.operators) - if tree isa Node - tree = combine_operators(tree, options.operators) - end + tree = combine_operators(tree, options.operators) @recorder tmp_recorder["type"] = "partial_simplify" mutation_accepted = true + is_success_always_possible = true return ( PopMember( tree, @@ -175,19 +215,50 @@ function next_generation( mutation_accepted, num_evals, ) - - is_success_always_possible = true # Simplification shouldn't hurt complexity; unless some non-symmetric constraint # to commutative operator... - elseif mutation_choice == :randomize # We select a random size, though the generated tree # may have fewer nodes than we request. tree_size_to_generate = rand(1:curmaxsize) - tree = gen_random_tree_fixed_size(tree_size_to_generate, options, nfeatures, T) - @recorder tmp_recorder["type"] = "regenerate" + if options.llm_options.active && + (rand() < options.llm_options.weights.llm_gen_random) + tree = with_contents( + tree, + combine_operators( + simplify_tree!( + gen_llm_random_tree( + tree_size_to_generate, options, nfeatures, T, idea_database + ), + options.operators, + ), + options.operators, + ), + ) + @recorder tmp_recorder["type"] = "regenerate_llm" + + is_success_always_possible = false + + if check_constant(tree) # don't allow constant outputs + tree = with_contents( + tree, + gen_random_tree_fixed_size( + tree_size_to_generate, options, nfeatures, T + ), + ) + is_success_always_possible = true + end + else + tree = with_contents( + tree, + gen_random_tree_fixed_size( + tree_size_to_generate, options, nfeatures, T + ), + ) + @recorder tmp_recorder["type"] = "regenerate" - is_success_always_possible = true + is_success_always_possible = true + end elseif mutation_choice == :optimize cur_member = PopMember( tree, @@ -202,9 +273,8 @@ function next_generation( num_evals += new_num_evals @recorder tmp_recorder["type"] = "optimize" mutation_accepted = true - return (cur_member, mutation_accepted, num_evals) - is_success_always_possible = true + return (cur_member, mutation_accepted, num_evals) elseif mutation_choice == :do_nothing @recorder begin tmp_recorder["type"] = "identity" @@ -212,6 +282,7 @@ function next_generation( tmp_recorder["reason"] = "identity" end mutation_accepted = true + is_success_always_possible = true return ( PopMember( tree, @@ -243,6 +314,7 @@ function next_generation( attempts += 1 end ############################################# + tree::AbstractExpression if !successful_mutation @recorder begin @@ -359,20 +431,79 @@ end """Generate a generation via crossover of two members.""" function crossover_generation( - member1::P, member2::P, dataset::D, curmaxsize::Int, options::Options -)::Tuple{P,P,Bool,Float64} where {T,L,D<:Dataset{T,L},P<:PopMember{T,L}} + member1::P, + member2::P, + dataset::D, + curmaxsize::Int, + options::Options; + dominating=nothing, + idea_database=nothing, +)::Tuple{P,P,Bool,Float64} where {T,L,D<:Dataset{T,L},N,P<:PopMember{T,L,N}} tree1 = 
member1.tree tree2 = member2.tree + + # add simplification for crossover + tree1 = simplify_tree!(tree1, options.operators) + tree1 = combine_operators(tree1, options.operators) + tree2 = simplify_tree!(tree2, options.operators) + tree2 = combine_operators(tree2, options.operators) + crossover_accepted = false + nfeatures = dataset.nfeatures + + if check_constant(tree1) + tree1 = with_contents( + tree1, gen_random_tree_fixed_size(rand(1:curmaxsize), options, nfeatures, T) + ) + end + if check_constant(tree2) + tree2 = with_contents( + tree2, gen_random_tree_fixed_size(rand(1:curmaxsize), options, nfeatures, T) + ) + end + + child_tree1 = nothing + child_tree2 = nothing + llm_skip = false + if options.llm_options.active && (rand() < options.llm_options.weights.llm_crossover) + child_tree1, child_tree2 = llm_crossover_trees( + tree1, tree2, options, dominating, idea_database + ) + + child_tree1 = simplify_tree!(child_tree1, options.operators) + child_tree1 = combine_operators(child_tree1, options.operators) + child_tree2 = simplify_tree!(child_tree2, options.operators) + child_tree2 = combine_operators(child_tree2, options.operators) + + afterSize1 = compute_complexity(child_tree1, options) + afterSize2 = compute_complexity(child_tree2, options) + + successful_crossover = + (!check_constant(child_tree1)) && + (!check_constant(child_tree2)) && + check_constraints(child_tree1, options, curmaxsize, afterSize1) && + check_constraints(child_tree2, options, curmaxsize, afterSize2) + + if successful_crossover + recorder_str = tree_to_expr(child_tree1, options) * " && " * tree_to_expr(child_tree2, options) + llm_recorder(options.llm_options, recorder_str, "crossover") + llm_skip = true + else + recorder_str = tree_to_expr(child_tree1, options) * " && " * tree_to_expr(child_tree2, options) + llm_recorder(options.llm_options, recorder_str, "crossover|failed") + child_tree1, child_tree2 = crossover_trees(tree1, tree2) + end + else + child_tree1, child_tree2 = crossover_trees(tree1, tree2) + end # We breed these until constraints are no longer violated: - child_tree1, child_tree2 = crossover_trees(tree1, tree2) num_tries = 1 max_tries = 10 num_evals = 0.0 afterSize1 = -1 afterSize2 = -1 - while true + while !llm_skip afterSize1 = compute_complexity(child_tree1, options) afterSize2 = compute_complexity(child_tree2, options) # Both trees satisfy constraints @@ -413,7 +544,7 @@ function crossover_generation( afterSize1; parent=member1.ref, deterministic=options.deterministic, - ) + )::P baby2 = PopMember( child_tree2, afterScore2, @@ -422,7 +553,7 @@ function crossover_generation( afterSize2; parent=member2.ref, deterministic=options.deterministic, - ) + )::P crossover_accepted = true return baby1, baby2, crossover_accepted, num_evals diff --git a/src/MutationFunctions.jl b/src/MutationFunctions.jl index a45a32cf7..31534054f 100644 --- a/src/MutationFunctions.jl +++ b/src/MutationFunctions.jl @@ -3,8 +3,11 @@ module MutationFunctionsModule using Random: default_rng, AbstractRNG using DynamicExpressions: AbstractExpressionNode, + AbstractExpression, AbstractNode, NodeSampler, + get_contents, + with_contents, constructorof, copy_node, set_node!, @@ -31,6 +34,11 @@ function random_node( end """Swap operands in binary operator for ops like pow and divide""" +function swap_operands(ex::AbstractExpression, rng::AbstractRNG=default_rng()) + tree = get_contents(ex) + ex = with_contents(ex, swap_operands(tree, rng)) + return ex +end function swap_operands(tree::AbstractNode, rng::AbstractRNG=default_rng()) if 
!any(node -> node.degree == 2, tree) return tree @@ -41,6 +49,13 @@ function swap_operands(tree::AbstractNode, rng::AbstractRNG=default_rng()) end """Randomly convert an operator into another one (binary->binary; unary->unary)""" +function mutate_operator( + ex::AbstractExpression{T}, options::Options, rng::AbstractRNG=default_rng() +) where {T<:DATA_TYPE} + tree = get_contents(ex) + ex = with_contents(ex, mutate_operator(tree, options, rng)) + return ex +end function mutate_operator( tree::AbstractExpressionNode{T}, options::Options, rng::AbstractRNG=default_rng() ) where {T} @@ -57,6 +72,13 @@ function mutate_operator( end """Randomly perturb a constant""" +function mutate_constant( + ex::AbstractExpression{T}, temperature, options::Options, rng::AbstractRNG=default_rng() +) where {T<:DATA_TYPE} + tree = get_contents(ex) + ex = with_contents(ex, mutate_constant(tree, temperature, options, rng)) + return ex +end function mutate_constant( tree::AbstractExpressionNode{T}, temperature, @@ -70,25 +92,39 @@ function mutate_constant( end node = rand(rng, NodeSampler(; tree, filter=t -> (t.degree == 0 && t.constant))) + node.val *= mutate_factor(T, temperature, options, rng) + + return tree +end + +function mutate_factor(::Type{T}, temperature, options, rng) where {T<:DATA_TYPE} bottom = 1//10 maxChange = options.perturbation_factor * temperature + 1 + bottom factor = T(maxChange^rand(rng, T)) makeConstBigger = rand(rng, Bool) - if makeConstBigger - node.val *= factor - else - node.val /= factor - end + factor = makeConstBigger ? factor : 1 / factor if rand(rng) > options.probability_negate_constant - node.val *= -1 + factor *= -1 end - - return tree + return factor end +# TODO: Shouldn't we add a mutate_feature here? + """Add a random unary/binary operation to the end of a tree""" +function append_random_op( + ex::AbstractExpression{T}, + options::Options, + nfeatures::Int, + rng::AbstractRNG=default_rng(); + makeNewBinOp::Union{Bool,Nothing}=nothing, +) where {T<:DATA_TYPE} + tree = get_contents(ex) + ex = with_contents(ex, append_random_op(tree, options, nfeatures, rng; makeNewBinOp)) + return ex +end function append_random_op( tree::AbstractExpressionNode{T}, options::Options, @@ -104,14 +140,15 @@ function append_random_op( end if makeNewBinOp - newnode = constructorof(typeof(tree))( - rand(rng, 1:(options.nbin)), - make_random_leaf(nfeatures, T, typeof(tree), rng), - make_random_leaf(nfeatures, T, typeof(tree), rng), + newnode = constructorof(typeof(tree))(; + op=rand(rng, 1:(options.nbin)), + l=make_random_leaf(nfeatures, T, typeof(tree), rng, options), + r=make_random_leaf(nfeatures, T, typeof(tree), rng, options), ) else - newnode = constructorof(typeof(tree))( - rand(rng, 1:(options.nuna)), make_random_leaf(nfeatures, T, typeof(tree), rng) + newnode = constructorof(typeof(tree))(; + op=rand(rng, 1:(options.nuna)), + l=make_random_leaf(nfeatures, T, typeof(tree), rng, options), ) end @@ -121,6 +158,16 @@ function append_random_op( end """Insert random node""" +function insert_random_op( + ex::AbstractExpression{T}, + options::Options, + nfeatures::Int, + rng::AbstractRNG=default_rng(), +) where {T<:DATA_TYPE} + tree = get_contents(ex) + ex = with_contents(ex, insert_random_op(tree, options, nfeatures, rng)) + return ex +end function insert_random_op( tree::AbstractExpressionNode{T}, options::Options, @@ -133,16 +180,28 @@ function insert_random_op( left = copy_node(node) if makeNewBinOp - right = make_random_leaf(nfeatures, T, typeof(tree), rng) - newnode = 
constructorof(typeof(tree))(rand(rng, 1:(options.nbin)), left, right) + right = make_random_leaf(nfeatures, T, typeof(tree), rng, options) + newnode = constructorof(typeof(tree))(; + op=rand(rng, 1:(options.nbin)), l=left, r=right + ) else - newnode = constructorof(typeof(tree))(rand(rng, 1:(options.nuna)), left) + newnode = constructorof(typeof(tree))(; op=rand(rng, 1:(options.nuna)), l=left) end set_node!(node, newnode) return tree end """Add random node to the top of a tree""" +function prepend_random_op( + ex::AbstractExpression{T}, + options::Options, + nfeatures::Int, + rng::AbstractRNG=default_rng(), +) where {T<:DATA_TYPE} + tree = get_contents(ex) + ex = with_contents(ex, prepend_random_op(tree, options, nfeatures, rng)) + return ex +end function prepend_random_op( tree::AbstractExpressionNode{T}, options::Options, @@ -155,17 +214,23 @@ function prepend_random_op( left = copy_node(tree) if makeNewBinOp - right = make_random_leaf(nfeatures, T, typeof(tree), rng) - newnode = constructorof(typeof(tree))(rand(rng, 1:(options.nbin)), left, right) + right = make_random_leaf(nfeatures, T, typeof(tree), rng, options) + newnode = constructorof(typeof(tree))(; + op=rand(rng, 1:(options.nbin)), l=left, r=right + ) else - newnode = constructorof(typeof(tree))(rand(rng, 1:(options.nuna)), left) + newnode = constructorof(typeof(tree))(; op=rand(rng, 1:(options.nuna)), l=left) end set_node!(node, newnode) return node end function make_random_leaf( - nfeatures::Int, ::Type{T}, ::Type{N}, rng::AbstractRNG=default_rng() + nfeatures::Int, + ::Type{T}, + ::Type{N}, + rng::AbstractRNG=default_rng(), + ::Union{Options,Nothing}=nothing, ) where {T<:DATA_TYPE,N<:AbstractExpressionNode} if rand(rng, Bool) return constructorof(N)(; val=randn(rng, T)) @@ -188,6 +253,16 @@ function random_node_and_parent(tree::AbstractNode, rng::AbstractRNG=default_rng end """Select a random node, and splice it out of the tree.""" +function delete_random_op!( + ex::AbstractExpression{T}, + options::Options, + nfeatures::Int, + rng::AbstractRNG=default_rng(), +) where {T<:DATA_TYPE} + tree = get_contents(ex) + ex = with_contents(ex, delete_random_op!(tree, options, nfeatures, rng)) + return ex +end function delete_random_op!( tree::AbstractExpressionNode{T}, options::Options, @@ -199,7 +274,7 @@ function delete_random_op!( if node.degree == 0 # Replace with new constant - newnode = make_random_leaf(nfeatures, T, typeof(tree), rng) + newnode = make_random_leaf(nfeatures, T, typeof(tree), rng, options) set_node!(node, newnode) elseif node.degree == 1 # Join one of the children with the parent @@ -253,7 +328,7 @@ function gen_random_tree_fixed_size( ::Type{T}, rng::AbstractRNG=default_rng(), ) where {T<:DATA_TYPE} - tree = make_random_leaf(nfeatures, T, options.node_type, rng) + tree = make_random_leaf(nfeatures, T, options.node_type, rng, options) cur_size = count_nodes(tree) while cur_size < node_count if cur_size == node_count - 1 # only unary operator allowed. 
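One pattern worth noting across the MutationFunctions.jl hunks above: each node-level mutation gains an `AbstractExpression` method that unwraps the expression with `get_contents`, mutates the raw tree, and rewraps it with `with_contents`, so expression metadata (operators, variable names) survives the mutation. A minimal sketch of that pattern, where `negate_constants!` is a hypothetical transform standing in for `swap_operands`, `mutate_operator`, and friends:

```julia
using DynamicExpressions: AbstractExpression, get_contents, with_contents

# Hypothetical node-level transform: flip the sign of every constant leaf.
function negate_constants!(tree)
    for node in tree
        if node.degree == 0 && node.constant
            node.val = -node.val
        end
    end
    return tree
end

# The wrapper pattern used throughout this diff: unwrap, transform, rewrap.
function negate_constants!(ex::AbstractExpression)
    tree = get_contents(ex)
    return with_contents(ex, negate_constants!(tree))
end
```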
@@ -267,12 +342,21 @@ function gen_random_tree_fixed_size( return tree end +function crossover_trees( + ex1::E, ex2::E, rng::AbstractRNG=default_rng() +) where {T,E<:AbstractExpression{T}} + tree1 = get_contents(ex1) + tree2 = get_contents(ex2) + out1, out2 = crossover_trees(tree1, tree2, rng) + ex1 = with_contents(ex1, out1) + ex2 = with_contents(ex2, out2) + return ex1, ex2 +end + """Crossover between two expressions""" function crossover_trees( - tree1::AbstractExpressionNode{T}, - tree2::AbstractExpressionNode{T}, - rng::AbstractRNG=default_rng(), -) where {T} + tree1::N, tree2::N, rng::AbstractRNG=default_rng() +) where {T,N<:AbstractExpressionNode{T}} tree1 = copy_node(tree1) tree2 = copy_node(tree2) @@ -315,6 +399,10 @@ function get_two_nodes_without_loop(tree::AbstractNode, rng::AbstractRNG; max_at return (tree, tree, true) end +function form_random_connection!(ex::AbstractExpression, rng::AbstractRNG=default_rng()) + tree = get_contents(ex) + return with_contents(ex, form_random_connection!(tree, rng)) +end function form_random_connection!(tree::AbstractNode, rng::AbstractRNG=default_rng()) if length(tree) < 5 return tree @@ -334,6 +422,11 @@ function form_random_connection!(tree::AbstractNode, rng::AbstractRNG=default_rn end return tree end + +function break_random_connection!(ex::AbstractExpression, rng::AbstractRNG=default_rng()) + tree = get_contents(ex) + return with_contents(ex, break_random_connection!(tree, rng)) +end function break_random_connection!(tree::AbstractNode, rng::AbstractRNG=default_rng()) tree.degree == 0 && return tree parent = rand(rng, NodeSampler(; tree, filter=t -> t.degree != 0)) diff --git a/src/MutationWeights.jl b/src/MutationWeights.jl index 8549dd810..1f3f7369f 100644 --- a/src/MutationWeights.jl +++ b/src/MutationWeights.jl @@ -18,14 +18,14 @@ will be normalized to sum to 1.0 after initialization. - `randomize::Float64`: How often to create a random tree. - `do_nothing::Float64`: How often to do nothing. - `optimize::Float64`: How often to optimize the constants in the tree, as a mutation. - Note that this is different from `optimizer_probability`, which is - performed at the end of an iteration for all individuals. + Note that this is different from `optimizer_probability`, which is + performed at the end of an iteration for all individuals. - `form_connection::Float64`: **Only used for `GraphNode`, not regular `Node`**. - Otherwise, this will automatically be set to 0.0. How often to form a - connection between two nodes. + Otherwise, this will automatically be set to 0.0. How often to form a + connection between two nodes. - `break_connection::Float64`: **Only used for `GraphNode`, not regular `Node`**. - Otherwise, this will automatically be set to 0.0. How often to break a - connection between two nodes. + Otherwise, this will automatically be set to 0.0. How often to break a + connection between two nodes. """ Base.@kwdef mutable struct MutationWeights mutate_constant::Float64 = 0.048 diff --git a/src/Operators.jl b/src/Operators.jl index cc756f0d6..e7b99ea10 100644 --- a/src/Operators.jl +++ b/src/Operators.jl @@ -1,10 +1,12 @@ module OperatorsModule +using DynamicExpressions: DynamicExpressions as DE using SpecialFunctions: SpecialFunctions using DynamicQuantities: UnionAbstractQuantity using SpecialFunctions: erf, erfc using Base: @deprecate using ..ProgramConstantsModule: DATA_TYPE +using ...UtilsModule: @ignore #TODO - actually add these operators to the module! # TODO: Should this be limited to AbstractFloat instead? 
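Since the `MutationWeights` docstring above shows only the tail of the field list, here is a hedged configuration sketch. It uses only fields referenced elsewhere in this diff (`mutate_constant`, `mutate_operator`, `swap_operands`); the weights are relative and are normalized to sum to 1.0 after initialization:

```julia
using LibraryAugmentedSymbolicRegression: MutationWeights, Options

# Relative probabilities, normalized after initialization:
weights = MutationWeights(; mutate_constant=0.05, mutate_operator=0.5, swap_operands=0.1)
options = Options(; binary_operators=[+, -, *], mutation_weights=weights)

# A NamedTuple also works, via `create_mutation_weights(w::NamedTuple)` shown
# later in this diff:
options2 = Options(; binary_operators=[+, -, *], mutation_weights=(mutate_constant=0.05,))
```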
@@ -95,8 +97,21 @@ function logical_and(x, y) return ((x > zero(x)) & (y > zero(y))) * one(x) end +# Strings +DE.get_op_name(::typeof(safe_pow)) = "^" +DE.get_op_name(::typeof(safe_log)) = "log" +DE.get_op_name(::typeof(safe_log2)) = "log2" +DE.get_op_name(::typeof(safe_log10)) = "log10" +DE.get_op_name(::typeof(safe_log1p)) = "log1p" +DE.get_op_name(::typeof(safe_acosh)) = "acosh" +DE.get_op_name(::typeof(safe_sqrt)) = "sqrt" + # Deprecated operations: @deprecate pow(x, y) safe_pow(x, y) @deprecate pow_abs(x, y) safe_pow(x, y) +# For static analysis tools: +@ignore pow(x, y) = safe_pow(x, y) +@ignore pow_abs(x, y) = safe_pow(x, y) + end diff --git a/src/Options.jl b/src/Options.jl index 4131a40bf..5268b5955 100644 --- a/src/Options.jl +++ b/src/Options.jl @@ -4,8 +4,8 @@ using DispatchDoctor: @unstable using Optim: Optim using Dates: Dates using StatsBase: StatsBase -using DynamicExpressions: OperatorEnum, Node -using Distributed: nworkers +using DynamicExpressions: OperatorEnum, Node, Expression, default_node_type +using ADTypes: AbstractADType, ADTypes using LossFunctions: L2DistLoss, SupervisedLoss using Optim: Optim using LineSearches: LineSearches @@ -26,67 +26,118 @@ using ..OperatorsModule: safe_acosh, atanh_clip using ..MutationWeightsModule: MutationWeights, mutations +using ..LLMOptionsModule: LLMOptions, validate_llm_options import ..OptionsStructModule: Options using ..OptionsStructModule: ComplexityMapping, operator_specialization -using ..UtilsModule: max_ops, @save_kwargs - -""" - build_constraints(una_constraints, bin_constraints, - unary_operators, binary_operators) - -Build constraints on operator-level complexity from a user-passed dict. -""" -function build_constraints( - una_constraints, bin_constraints, unary_operators, binary_operators, nuna, nbin -)::Tuple{Array{Int,1},Array{Tuple{Int,Int},1}} +using ..UtilsModule: max_ops, @save_kwargs, @ignore + +"""Build constraints on operator-level complexity from a user-passed dict.""" +@unstable function build_constraints(; + una_constraints, + bin_constraints, + @nospecialize(unary_operators), + @nospecialize(binary_operators) +)::Tuple{Vector{Int},Vector{Tuple{Int,Int}}} # Expect format ((*)=>(-1, 3)), etc. # TODO: Need to disable simplification if (*, -, +, /) are constrained? # Or, just quit simplification if constraints are violated. - is_bin_constraints_already_done = typeof(bin_constraints) <: Array{Tuple{Int,Int},1} - is_una_constraints_already_done = typeof(una_constraints) <: Array{Int,1} - - if typeof(bin_constraints) <: Array && !is_bin_constraints_already_done - bin_constraints = Dict(bin_constraints) + is_una_constraints_already_done = una_constraints isa Vector{Int} + _una_constraints1 = if una_constraints isa Array && !is_una_constraints_already_done + Dict(una_constraints) + else + una_constraints end - if typeof(una_constraints) <: Array && !is_una_constraints_already_done - una_constraints = Dict(una_constraints) + _una_constraints2 = if _una_constraints1 === nothing + fill(-1, length(unary_operators)) + elseif !is_una_constraints_already_done + [ + haskey(_una_constraints1, op) ? 
_una_constraints1[op]::Int : -1 for + op in unary_operators + ] + else + _una_constraints1 end - if una_constraints === nothing - una_constraints = [-1 for i in 1:nuna] - elseif !is_una_constraints_already_done - una_constraints::Dict - _una_constraints = Int[] - for (i, op) in enumerate(unary_operators) - did_user_declare_constraints = haskey(una_constraints, op) - if did_user_declare_constraints - constraint::Int = una_constraints[op] - push!(_una_constraints, constraint) - else - push!(_una_constraints, -1) - end - end - una_constraints = _una_constraints + is_bin_constraints_already_done = bin_constraints isa Vector{Tuple{Int,Int}} + _bin_constraints1 = if bin_constraints isa Array && !is_bin_constraints_already_done + Dict(bin_constraints) + else + bin_constraints end - if bin_constraints === nothing - bin_constraints = [(-1, -1) for i in 1:nbin] + _bin_constraints2 = if _bin_constraints1 === nothing + fill((-1, -1), length(binary_operators)) elseif !is_bin_constraints_already_done - bin_constraints::Dict - _bin_constraints = Tuple{Int,Int}[] - for (i, op) in enumerate(binary_operators) - did_user_declare_constraints = haskey(bin_constraints, op) - if did_user_declare_constraints - constraint::Tuple{Int,Int} = bin_constraints[op] - push!(_bin_constraints, constraint) + [ + if haskey(_bin_constraints1, op) + _bin_constraints1[op]::Tuple{Int,Int} else - push!(_bin_constraints, (-1, -1)) + (-1, -1) + end for op in binary_operators + ] + else + _bin_constraints1 + end + + return _una_constraints2, _bin_constraints2 +end + +@unstable function build_nested_constraints(; + @nospecialize(binary_operators), @nospecialize(unary_operators), nested_constraints +) + nested_constraints === nothing && return nested_constraints + # Check that intersection of binary operators and unary operators is empty: + for op in binary_operators + if op โˆˆ unary_operators + error( + "Operator $(op) is both a binary and unary operator. " * + "You can't use nested constraints.", + ) + end + end + + # Convert to dict: + _nested_constraints = if nested_constraints isa Dict + nested_constraints + else + # Convert to dict: + nested_constraints = Dict( + [cons[1] => Dict(cons[2]...) for cons in nested_constraints]... + ) + end + for (op, nested_constraint) in _nested_constraints + if !(op โˆˆ binary_operators || op โˆˆ unary_operators) + error("Operator $(op) is not in the operator set.") + end + for (nested_op, max_nesting) in nested_constraint + if !(nested_op โˆˆ binary_operators || nested_op โˆˆ unary_operators) + error("Operator $(nested_op) is not in the operator set.") end + @assert nested_op โˆˆ binary_operators || nested_op โˆˆ unary_operators + @assert max_nesting >= -1 && typeof(max_nesting) <: Int end - bin_constraints = _bin_constraints end - return una_constraints, bin_constraints + # Lastly, we clean it up into a dict of (degree,op_idx) => max_nesting. 
+ return [ + let (degree, idx) = if op โˆˆ binary_operators + 2, findfirst(isequal(op), binary_operators)::Int + else + 1, findfirst(isequal(op), unary_operators)::Int + end, + new_max_nesting_dict = [ + let (nested_degree, nested_idx) = if nested_op โˆˆ binary_operators + 2, findfirst(isequal(nested_op), binary_operators)::Int + else + 1, findfirst(isequal(nested_op), unary_operators)::Int + end + (nested_degree, nested_idx, max_nesting) + end for (nested_op, max_nesting) in nested_constraint + ] + + (degree, idx, new_max_nesting_dict) + end for (op, nested_constraint) in _nested_constraints + ] end function binopmap(op::F) where {F} @@ -152,6 +203,9 @@ end create_mutation_weights(w::MutationWeights) = w create_mutation_weights(w::NamedTuple) = MutationWeights(; w...) +create_llm_options(w::LLMOptions) = w +create_llm_options(w::NamedTuple) = LLMOptions(; w...) + const deprecated_options_mapping = Base.ImmutableDict( :mutationWeights => :mutation_weights, :hofMigration => :hof_migration, @@ -177,6 +231,9 @@ const deprecated_options_mapping = Base.ImmutableDict( :loss => :elementwise_loss, ) +# For static analysis tools: +@ignore const DEFAULT_OPTIONS = () + const OPTION_DESCRIPTIONS = """- `binary_operators`: Vector of binary operators (functions) to use. Each operator should be defined for two input scalars, and one output scalar. All operators @@ -248,8 +305,10 @@ const OPTION_DESCRIPTIONS = """- `binary_operators`: Vector of binary operators return sum((prediction .- dataset.y) .^ 2) / dataset.n end -- `node_type::Type{N}=Node`: The type of node to use for the search. - For example, `Node` or `GraphNode`. +- `expression_type::Type{E}=Expression`: The type of expression to use. + For example, `Expression`. +- `node_type::Type{N}=default_node_type(Expression)`: The type of node to use for the search. + For example, `Node` or `GraphNode`. The default is computed by `default_node_type(expression_type)`. - `populations`: How many populations of equations to use. - `population_size`: How many equations in each population. - `ncycles_per_iteration`: How many generations to consider per iteration. @@ -312,14 +371,21 @@ const OPTION_DESCRIPTIONS = """- `binary_operators`: Vector of binary operators - `optimizer_probability`: Probability of performing optimization of constants at the end of a given iteration. - `optimizer_iterations`: How many optimization iterations to perform. This gets - passed to `Optim.Options` as `iterations`. The default is 8. + passed to `Optim.Options` as `iterations`. The default is 8. - `optimizer_f_calls_limit`: How many function calls to allow during optimization. This gets passed to `Optim.Options` as `f_calls_limit`. The default is - `0` which means no limit. + `10_000`. - `optimizer_options`: General options for the constant optimization. For details we refer to the documentation on `Optim.Options` from the `Optim.jl` package. Options can be provided here as `NamedTuple`, e.g. `(iterations=16,)`, as a `Dict`, e.g. Dict(:x_tol => 1.0e-32,), or as an `Optim.Options` instance. +- `autodiff_backend`: The backend to use for differentiation, which should be + an instance of `AbstractADType` (see `DifferentiationInterface.jl`). + Default is `nothing`, which means `Optim.jl` will estimate gradients (likely + with finite differences). You can also pass a symbolic version of the backend + type, such as `:Zygote` for Zygote, `:Enzyme`, etc. Most backends will not + work, and many will never work due to incompatibilities, though support for some + is gradually being added. 
- `output_file`: What file to store equations to, as a backup. - `perturbation_factor`: When mutating a constant, either multiply or divide by (1+perturbation_factor)^(rand()+1). @@ -328,6 +394,8 @@ const OPTION_DESCRIPTIONS = """- `binary_operators`: Vector of binary operators - `mutation_weights`: Relative probabilities of the mutations. The struct `MutationWeights` should be passed to these options. See its documentation on `MutationWeights` for the different weights. +- `llm_options`: Options for LLM inference. Managed through struct + `LLMOptions`. See its documentation for more details. - `crossover_probability`: Probability of performing crossover. - `annealing`: Whether to use simulated annealing. - `warmup_maxsize_by`: Whether to slowly increase the max size from 5 up to @@ -401,13 +469,16 @@ $(OPTION_DESCRIPTIONS) should_simplify::Union{Nothing,Bool}=nothing, should_optimize_constants::Bool=true, output_file::Union{Nothing,AbstractString}=nothing, - node_type::Type=Node, + expression_type::Type=Expression, + node_type::Type=default_node_type(expression_type), + expression_options::NamedTuple=NamedTuple(), populations::Integer=15, perturbation_factor::Real=0.076, annealing::Bool=false, batching::Bool=false, batch_size::Integer=50, mutation_weights::Union{MutationWeights,AbstractVector,NamedTuple}=MutationWeights(), + llm_options::LLMOptions=LLMOptions(), crossover_probability::Real=0.066, warmup_maxsize_by::Real=0.0, use_frequency::Bool=true, @@ -434,6 +505,7 @@ $(OPTION_DESCRIPTIONS) optimizer_iterations::Union{Nothing,Integer}=nothing, optimizer_f_calls_limit::Union{Nothing,Integer}=nothing, optimizer_options::Union{Dict,NamedTuple,Optim.Options,Nothing}=nothing, + autodiff_backend::Union{AbstractADType,Symbol,Nothing}=nothing, use_recorder::Bool=false, recorder_file::AbstractString="pysr_recorder.json", early_stop_condition::Union{Function,Real,Nothing}=nothing, @@ -489,6 +561,10 @@ $(OPTION_DESCRIPTIONS) k == :enable_autodiff && continue k == :ns && (tournament_selection_n = kws[k]; true) && continue k == :loss && (elementwise_loss = kws[k]; true) && continue + if k == :llm_options + llm_options = kws[k] + continue + end if k == :mutationWeights if typeof(kws[k]) <: AbstractVector _mutation_weights = kws[k] @@ -561,69 +637,15 @@ $(OPTION_DESCRIPTIONS) end end - nuna = length(unary_operators) - nbin = length(binary_operators) @assert maxsize > 3 @assert warmup_maxsize_by >= 0.0f0 - @assert nuna <= max_ops && nbin <= max_ops + @assert length(unary_operators) <= max_ops + @assert length(binary_operators) <= max_ops # Make sure nested_constraints contains functions within our operator set: - if nested_constraints !== nothing - # Check that intersection of binary operators and unary operators is empty: - for op in binary_operators - if op โˆˆ unary_operators - error( - "Operator $(op) is both a binary and unary operator. " * - "You can't use nested constraints.", - ) - end - end - - # Convert to dict: - if !(typeof(nested_constraints) <: Dict) - # Convert to dict: - nested_constraints = Dict( - [cons[1] => Dict(cons[2]...) for cons in nested_constraints]... 
- ) - end - for (op, nested_constraint) in nested_constraints - if !(op โˆˆ binary_operators || op โˆˆ unary_operators) - error("Operator $(op) is not in the operator set.") - end - for (nested_op, max_nesting) in nested_constraint - if !(nested_op โˆˆ binary_operators || nested_op โˆˆ unary_operators) - error("Operator $(nested_op) is not in the operator set.") - end - @assert nested_op โˆˆ binary_operators || nested_op โˆˆ unary_operators - @assert max_nesting >= -1 && typeof(max_nesting) <: Int - end - end - - # Lastly, we clean it up into a dict of (degree,op_idx) => max_nesting. - new_nested_constraints = [] - # Dict() - for (op, nested_constraint) in nested_constraints - (degree, idx) = if op โˆˆ binary_operators - 2, findfirst(isequal(op), binary_operators) - else - 1, findfirst(isequal(op), unary_operators) - end - new_max_nesting_dict = [] - # Dict() - for (nested_op, max_nesting) in nested_constraint - (nested_degree, nested_idx) = if nested_op โˆˆ binary_operators - 2, findfirst(isequal(nested_op), binary_operators) - else - 1, findfirst(isequal(nested_op), unary_operators) - end - # new_max_nesting_dict[(nested_degree, nested_idx)] = max_nesting - push!(new_max_nesting_dict, (nested_degree, nested_idx, max_nesting)) - end - # new_nested_constraints[(degree, idx)] = new_max_nesting_dict - push!(new_nested_constraints, (degree, idx, new_max_nesting_dict)) - end - nested_constraints = new_nested_constraints - end + _nested_constraints = build_nested_constraints(; + binary_operators, unary_operators, nested_constraints + ) if typeof(constraints) <: Tuple constraints = collect(constraints) @@ -642,8 +664,8 @@ $(OPTION_DESCRIPTIONS) una_constraints = constraints end - una_constraints, bin_constraints = build_constraints( - una_constraints, bin_constraints, unary_operators, binary_operators, nuna, nbin + _una_constraints, _bin_constraints = build_constraints(; + una_constraints, bin_constraints, unary_operators, binary_operators ) complexity_mapping = ComplexityMapping( @@ -692,7 +714,7 @@ $(OPTION_DESCRIPTIONS) if !isa(optimizer_options, Optim.Options) optimizer_iterations = isnothing(optimizer_iterations) ? 8 : optimizer_iterations optimizer_f_calls_limit = if isnothing(optimizer_f_calls_limit) - 0 + 10_000 else optimizer_f_calls_limit end @@ -710,35 +732,35 @@ $(OPTION_DESCRIPTIONS) @warn "Optimizer warnings are turned on. 
This might result in a lot of warnings being printed from NaNs, as these are common during symbolic regression" end - ## Create tournament weights: - tournament_selection_weights = - let n = tournament_selection_n, p = tournament_selection_p - k = collect(0:(n - 1)) - prob_each = p * ((1 - p) .^ k) - - StatsBase.Weights(prob_each, sum(prob_each)) - end - set_mutation_weights = create_mutation_weights(mutation_weights) + set_llm_options = create_llm_options(llm_options) + validate_llm_options(set_llm_options) @assert print_precision > 0 + _autodiff_backend = if autodiff_backend isa Union{Nothing,AbstractADType} + autodiff_backend + else + ADTypes.Auto(autodiff_backend) + end + options = Options{ typeof(complexity_mapping), operator_specialization(typeof(operators)), node_type, + expression_type, + typeof(expression_options), turbo, bumper, deprecated_return_state, - typeof(tournament_selection_weights), + typeof(_autodiff_backend), }( operators, - bin_constraints, - una_constraints, + _bin_constraints, + _una_constraints, complexity_mapping, tournament_selection_n, tournament_selection_p, - tournament_selection_weights, parsimony, dimensional_constraint_penalty, dimensionless_constants_only, @@ -758,6 +780,7 @@ $(OPTION_DESCRIPTIONS) batching, batch_size, set_mutation_weights, + set_llm_options, crossover_probability, warmup_maxsize_by, use_frequency, @@ -772,18 +795,21 @@ $(OPTION_DESCRIPTIONS) print_precision, save_to_file, probability_negate_constant, - nuna, - nbin, + length(unary_operators), + length(binary_operators), seed, elementwise_loss, loss_function, node_type, + expression_type, + expression_options, progress, terminal_width, optimizer_algorithm, optimizer_probability, optimizer_nrestarts, optimizer_options, + _autodiff_backend, recorder_file, tournament_selection_p, early_stop_condition, @@ -791,7 +817,7 @@ $(OPTION_DESCRIPTIONS) timeout_in_seconds, max_evals, skip_mutation_failures, - nested_constraints, + _nested_constraints, deterministic, define_helper_functions, use_recorder, diff --git a/src/OptionsStruct.jl b/src/OptionsStruct.jl index f89a284d3..c0bb73130 100644 --- a/src/OptionsStruct.jl +++ b/src/OptionsStruct.jl @@ -1,11 +1,13 @@ module OptionsStructModule +using DispatchDoctor: @unstable using Optim: Optim using DynamicExpressions: - AbstractOperatorEnum, AbstractExpressionNode, OperatorEnum, GenericOperatorEnum + AbstractOperatorEnum, AbstractExpressionNode, AbstractExpression, OperatorEnum using LossFunctions: SupervisedLoss import ..MutationWeightsModule: MutationWeights +import ..LLMOptionsModule: LLMOptions """ This struct defines how complexity is calculated. 
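The `create_llm_options`/`validate_llm_options` plumbing above normalizes and checks the LLM settings before the `Options` struct is built. The full `LLMOptions` struct is defined in LLMOptions.jl, outside this diff, so the sketch below uses only the field names visible at call sites here (`active` plus the `llm_mutate`, `llm_crossover`, and `llm_gen_random` weights); treat the constructor keywords as an assumption.

```julia
using LibraryAugmentedSymbolicRegression: Options, LLMOptions, LLMWeights

# Gate each LLM-backed operation with a per-operation probability, matching
# the checks in Mutate.jl (e.g. rand() < options.llm_options.weights.llm_mutate):
llm_options = LLMOptions(;
    active=true,
    weights=LLMWeights(; llm_mutate=0.05, llm_crossover=0.05, llm_gen_random=0.05),
)
options = Options(; binary_operators=[+, -, *], llm_options=llm_options)
```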
@@ -124,10 +126,12 @@ struct Options{ CM<:ComplexityMapping, OP<:AbstractOperatorEnum, N<:AbstractExpressionNode, + E<:AbstractExpression, + EO<:NamedTuple, _turbo, _bumper, _return_state, - W, + AD, } operators::OP bin_constraints::Vector{Tuple{Int,Int}} @@ -135,7 +139,6 @@ struct Options{ complexity_mapping::CM tournament_selection_n::Int tournament_selection_p::Float32 - tournament_selection_weights::W parsimony::Float32 dimensional_constraint_penalty::Union{Float32,Nothing} dimensionless_constants_only::Bool @@ -155,6 +158,7 @@ struct Options{ batching::Bool batch_size::Int mutation_weights::MutationWeights + llm_options::LLMOptions crossover_probability::Float32 warmup_maxsize_by::Float32 use_frequency::Bool @@ -175,12 +179,15 @@ struct Options{ elementwise_loss::Union{SupervisedLoss,Function} loss_function::Union{Nothing,Function} node_type::Type{N} + expression_type::Type{E} + expression_options::EO progress::Union{Bool,Nothing} terminal_width::Union{Int,Nothing} optimizer_algorithm::Optim.AbstractOptimizer optimizer_probability::Float32 optimizer_nrestarts::Int optimizer_options::Optim.Options + autodiff_backend::AD recorder_file::String prob_pick_first::Float32 early_stop_condition::Union{Function,Nothing} @@ -204,7 +211,7 @@ function Base.print(io::IO, options::Options) * join( [ - if fieldname in (:optimizer_options, :mutation_weights) + if fieldname in (:optimizer_options, :mutation_weights, :llm_options) "$(fieldname)=..." else "$(fieldname)=$(getfield(options, fieldname))" @@ -218,4 +225,19 @@ function Base.print(io::IO, options::Options) end Base.show(io::IO, ::MIME"text/plain", options::Options) = Base.print(io, options) +@unstable function specialized_options(options::Options) + return _specialized_options(options) +end +@generated function _specialized_options(options::O) where {O<:Options} + # Return an options struct with concrete operators + type_parameters = O.parameters + fields = Any[:(getfield(options, $(QuoteNode(k)))) for k in fieldnames(O)] + quote + operators = getfield(options, :operators) + Options{$(type_parameters[1]),typeof(operators),$(type_parameters[3:end]...)}( + $(fields...) 
+ ) + end +end + end diff --git a/src/PopMember.jl b/src/PopMember.jl index d47042c20..84f29f451 100644 --- a/src/PopMember.jl +++ b/src/PopMember.jl @@ -1,15 +1,14 @@ module PopMemberModule using DispatchDoctor: @unstable - -using DynamicExpressions: AbstractExpressionNode, copy_node, count_nodes -using ..CoreModule: Options, Dataset, DATA_TYPE, LOSS_TYPE +using DynamicExpressions: AbstractExpression, AbstractExpressionNode, string_tree +using ..CoreModule: Options, Dataset, DATA_TYPE, LOSS_TYPE, create_expression import ..ComplexityModule: compute_complexity using ..UtilsModule: get_birth_order using ..LossFunctionsModule: score_func # Define a member of population by equation, score, and age -mutable struct PopMember{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpressionNode{T}} +mutable struct PopMember{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} tree::N score::L # Includes complexity penalty, normalization loss::L # Raw loss @@ -33,11 +32,20 @@ end ) return getfield(member, field) end +function Base.show(io::IO, p::PopMember{T,L,N}) where {T,L,N} + shower(x) = sprint(show, x) + print(io, "PopMember(") + print(io, "tree = (", string_tree(p.tree), "), ") + print(io, "loss = ", shower(p.loss), ", ") + print(io, "score = ", shower(p.score)) + print(io, ")") + return nothing +end generate_reference() = abs(rand(Int)) """ - PopMember(t::AbstractExpressionNode{T}, score::L, loss::L) + PopMember(t::AbstractExpression{T}, score::L, loss::L) Create a population member with a birth date at the current time. The type of the `Node` may be different from the type of the score and loss. # Arguments -- `t::AbstractExpressionNode{T}`: The tree for the population member. +- `t::AbstractExpression{T}`: The tree for the population member. - `score::L`: The score (normalized to a baseline, and offset by a complexity penalty) - `loss::L`: The raw loss to assign. """ function PopMember( - t::AbstractExpressionNode{T}, + t::AbstractExpression{T}, score::L, loss::L, - options::Options, + options::Union{Options,Nothing}=nothing, complexity::Union{Int,Nothing}=nothing; ref::Int=-1, parent::Int=-1, - deterministic=false, + deterministic=nothing, ) where {T<:DATA_TYPE,L<:LOSS_TYPE} if ref == -1 ref = generate_reference() end + if !(deterministic isa Bool) + throw( + ArgumentError( + "You must declare `deterministic` as `true` or `false`, it cannot be left undefined.", + ), + ) + end complexity = complexity === nothing ? -1 : complexity return PopMember{T,L,typeof(t)}( t, @@ -75,8 +90,11 @@ end """ - PopMember(dataset::Dataset{T,L}, - t::AbstractExpressionNode{T}, options::Options) + PopMember( + dataset::Dataset{T,L}, + t::AbstractExpression{T}, + options::Options + ) Create a population member with a birth date at the current time. Automatically compute the score for this tree. # Arguments - `dataset::Dataset{T,L}`: The dataset to evaluate the tree on. -- `t::AbstractExpressionNode{T}`: The tree for the population member. +- `t::AbstractExpression{T}`: The tree for the population member. - `options::Options`: What options to use. """ function PopMember( dataset::Dataset{T,L}, - t::AbstractExpressionNode{T}, + tree::Union{AbstractExpressionNode{T},AbstractExpression{T}}, options::Options, complexity::Union{Int,Nothing}=nothing; ref::Int=-1, parent::Int=-1, deterministic=nothing, ) where {T<:DATA_TYPE,L<:LOSS_TYPE} - set_complexity = complexity === nothing ? 
compute_complexity(t, options) : complexity + ex = create_expression(tree, options, dataset) + set_complexity = complexity === nothing ? compute_complexity(ex, options) : complexity @assert set_complexity != -1 - score, loss = score_func(dataset, t, options; complexity=set_complexity) + score, loss = score_func(dataset, ex, options; complexity=set_complexity) return PopMember( - t, + ex, score, loss, options, diff --git a/src/Population.jl b/src/Population.jl index 67e90b3a9..547c8b81e 100644 --- a/src/Population.jl +++ b/src/Population.jl @@ -1,19 +1,19 @@ module PopulationModule using StatsBase: StatsBase -using Random: randperm using DispatchDoctor: @unstable -using DynamicExpressions: AbstractExpressionNode, Node, string_tree +using DynamicExpressions: AbstractExpression, string_tree using ..CoreModule: Options, Dataset, RecordType, DATA_TYPE, LOSS_TYPE using ..ComplexityModule: compute_complexity using ..LossFunctionsModule: score_func, update_baseline_loss! using ..AdaptiveParsimonyModule: RunningSearchStatistics using ..MutationFunctionsModule: gen_random_tree +using ..LLMFunctionsModule: gen_llm_random_tree using ..PopMemberModule: PopMember -using ..UtilsModule: bottomk_fast, argmin_fast +using ..UtilsModule: bottomk_fast, argmin_fast, PerThreadCache # A list of members of the population, with easy constructors, # which allow for random generation of new populations -struct Population{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpressionNode{T}} +struct Population{T<:DATA_TYPE,L<:LOSS_TYPE,N<:AbstractExpression{T}} members::Array{PopMember{T,L,N},1} n::Int end @@ -26,6 +26,14 @@ function Population(pop::Vector{<:PopMember}) return Population(pop, size(pop, 1)) end +function gen_random_tree_pop(nlength, options, nfeatures, T, idea_database) + if options.llm_options.active && (rand() < options.llm_options.weights.llm_gen_random) + gen_llm_random_tree(nlength, options, nfeatures, T, idea_database) + else + gen_random_tree(nlength, options, nfeatures, T) + end +end + """ Population(dataset::Dataset{T,L}; population_size, nlength::Int=3, options::Options, @@ -40,6 +48,7 @@ function Population( nlength::Int=3, nfeatures::Int, npop=nothing, + idea_database=nothing, ) where {T,L} @assert (population_size !== nothing) โŠป (npop !== nothing) population_size = if npop === nothing @@ -51,7 +60,7 @@ function Population( [ PopMember( dataset, - gen_random_tree(nlength, options, nfeatures, T), + gen_random_tree_pop(nlength, options, nfeatures, T, idea_database), options; parent=-1, deterministic=options.deterministic, @@ -147,7 +156,7 @@ function _best_of_sample( argmin_fast(scores) else # First, decide what place we take (usually 1st place wins): - tournament_winner = StatsBase.sample(options.tournament_selection_weights) + tournament_winner = StatsBase.sample(get_tournament_selection_weights(options)) # Then, find the member that won that place, given # their fitness: if tournament_winner == 1 @@ -159,6 +168,26 @@ function _best_of_sample( return members[chosen_idx] end +const CACHED_WEIGHTS = + let init_k = collect(0:5), + init_prob_each = 0.5f0 * (1 - 0.5f0) .^ init_k, + test_weights = StatsBase.Weights(init_prob_each, sum(init_prob_each)) + + PerThreadCache{Dict{Tuple{Int,Float32},typeof(test_weights)}}() + end + +@unstable function get_tournament_selection_weights(@nospecialize(options::Options)) + n = options.tournament_selection_n + p = options.tournament_selection_p + # Computing the weights for the tournament becomes quite expensive, + return get!(CACHED_WEIGHTS, (n, p)) do + k = 
collect(0:(n - 1)) + prob_each = p * ((1 - p) .^ k) + + return StatsBase.Weights(prob_each, sum(prob_each)) + end +end + function finalize_scores( dataset::Dataset{T,L}, pop::P, options::Options )::Tuple{P,Float64} where {T,L,P<:Population{T,L}} diff --git a/src/Recorder.jl b/src/Recorder.jl index d7bf1f668..a25ac0e78 100644 --- a/src/Recorder.jl +++ b/src/Recorder.jl @@ -1,6 +1,6 @@ module RecorderModule -using ..CoreModule: RecordType, Options +using ..CoreModule: RecordType "Assumes that `options` holds the user options::Options" macro recorder(ex) diff --git a/src/RegularizedEvolution.jl b/src/RegularizedEvolution.jl index f00de16b9..913df0d89 100644 --- a/src/RegularizedEvolution.jl +++ b/src/RegularizedEvolution.jl @@ -18,7 +18,9 @@ function reg_evol_cycle( curmaxsize::Int, running_search_statistics::RunningSearchStatistics, options::Options, - record::RecordType, + record::RecordType; + dominating=nothing, + idea_database=nothing, )::Tuple{P,Float64} where {T<:DATA_TYPE,L<:LOSS_TYPE,P<:Population{T,L}} # Batch over each subsample. Can give 15% improvement in speed; probably moreso for large pops. # but is ultimately a different algorithm than regularized evolution, and might not be @@ -42,6 +44,8 @@ function reg_evol_cycle( running_search_statistics, options; tmp_recorder=mutation_recorder, + dominating=dominating, + idea_database=idea_database, ) num_evals += tmp_num_evals @@ -89,7 +93,9 @@ function reg_evol_cycle( allstar2 = best_of_sample(pop, running_search_statistics, options) baby1, baby2, crossover_accepted, tmp_num_evals = crossover_generation( - allstar1, allstar2, dataset, curmaxsize, options + allstar1, allstar2, dataset, curmaxsize, options; + dominating=dominating, + idea_database=idea_database, ) num_evals += tmp_num_evals diff --git a/src/SearchUtils.jl b/src/SearchUtils.jl index fb2da4876..d8f5bd382 100644 --- a/src/SearchUtils.jl +++ b/src/SearchUtils.jl @@ -1,23 +1,22 @@ -"""Functions to help with the main loop of SymbolicRegression.jl. +"""Functions to help with the main loop of LibraryAugmentedSymbolicRegression.jl. This includes: process management, stdin reading, checking for early stops.""" module SearchUtilsModule using Printf: @printf, @sprintf -using Distributed +using Distributed: Distributed, @spawnat, Future, procs using StatsBase: mean using DispatchDoctor: @unstable -using DynamicExpressions: AbstractExpressionNode, string_tree +using DynamicExpressions: AbstractExpression, string_tree using ..UtilsModule: subscriptify using ..CoreModule: Dataset, Options, MAX_DEGREE, RecordType using ..ComplexityModule: compute_complexity using ..PopulationModule: Population using ..PopMemberModule: PopMember -using ..HallOfFameModule: - HallOfFame, calculate_pareto_frontier, string_dominating_pareto_curve +using ..HallOfFameModule: HallOfFame, string_dominating_pareto_curve using ..ProgressBarsModule: WrappedProgressBar, set_multiline_postfix!, manually_iterate! 
-using ..AdaptiveParsimonyModule: update_frequencies!, RunningSearchStatistics +using ..AdaptiveParsimonyModule: RunningSearchStatistics """ RuntimeOptions{N,PARALLELISM,DIM_OUT,RETURN_STATE} @@ -129,11 +128,26 @@ end function init_dummy_pops( npops::Int, datasets::Vector{D}, options::Options ) where {T,L,D<:Dataset{T,L}} + prototype = Population( + first(datasets); + population_size=1, + options=options, + nfeatures=first(datasets).nfeatures, + ) + # ^ Due to occasional inference issue, we manually specify the return type return [ - [ - Population(d; population_size=1, options=options, nfeatures=d.nfeatures) for - _ in 1:npops - ] for d in datasets + typeof(prototype)[ + if (i == 1 && j == 1) + prototype + else + Population( + datasets[j]; + population_size=1, + options=options, + nfeatures=datasets[j].nfeatures, + ) + end for i in 1:npops + ] for j in 1:length(datasets) ] end @@ -211,76 +225,52 @@ function check_max_evals(num_evals, options::Options)::Bool return options.max_evals !== nothing && options.max_evals::Int <= sum(sum, num_evals) end -const TIME_TYPE = Float64 +""" +This struct is used to monitor resources. -"""This struct is used to monitor resources.""" +Whenever we check a channel, we record if it was empty or not. +This gives us a measure for how much of a bottleneck there is +at the head worker. +""" Base.@kwdef mutable struct ResourceMonitor - """The time the search started.""" - absolute_start_time::TIME_TYPE = time() - """The time the head worker started doing work.""" - start_work::TIME_TYPE = Inf - """The time the head worker finished doing work.""" - stop_work::TIME_TYPE = Inf - - num_starts::UInt = 0 - num_stops::UInt = 0 - work_intervals::Vector{TIME_TYPE} = TIME_TYPE[] - rest_intervals::Vector{TIME_TYPE} = TIME_TYPE[] - - """Number of intervals to store.""" - num_intervals_to_store::Int -end - -function start_work_monitor!(monitor::ResourceMonitor) - monitor.start_work = time() - monitor.num_starts += 1 - if monitor.num_stops > 0 - push!(monitor.rest_intervals, monitor.start_work - monitor.stop_work) - if length(monitor.rest_intervals) > monitor.num_intervals_to_store - popfirst!(monitor.rest_intervals) - end - end - return nothing + population_ready::Vector{Bool} = Bool[] + max_recordings::Int + start_reporting_at::Int + window_size::Int end -function stop_work_monitor!(monitor::ResourceMonitor) - monitor.stop_work = time() - push!(monitor.work_intervals, monitor.stop_work - monitor.start_work) - monitor.num_stops += 1 - @assert monitor.num_stops == monitor.num_starts - if length(monitor.work_intervals) > monitor.num_intervals_to_store - popfirst!(monitor.work_intervals) +function record_channel_state!(monitor::ResourceMonitor, state) + push!(monitor.population_ready, state) + if length(monitor.population_ready) > monitor.max_recordings + popfirst!(monitor.population_ready) end return nothing end function estimate_work_fraction(monitor::ResourceMonitor)::Float64 - if monitor.num_stops <= 1 + if length(monitor.population_ready) <= monitor.start_reporting_at return 0.0 # Can't estimate from only one interval, due to JIT. end - work_intervals = monitor.work_intervals - rest_intervals = monitor.rest_intervals - # Trim 1st, in case we are still in the first interval. 
- if monitor.num_stops <= monitor.num_intervals_to_store + 1 - work_intervals = work_intervals[2:end] - rest_intervals = rest_intervals[2:end] - end - return mean(work_intervals) / (mean(work_intervals) + mean(rest_intervals)) + return mean(monitor.population_ready[(end - (monitor.window_size - 1)):end]) end function get_load_string(; head_node_occupation::Float64, parallelism=:serial) - parallelism == :serial && return "" - out = @sprintf("Head worker occupation: %.1f%%", head_node_occupation * 100) - - raise_usage_warning = head_node_occupation > 0.4 - if raise_usage_warning - out *= "." - out *= " This is high, and will prevent efficient resource usage." - out *= " Increase `ncycles_per_iteration` to reduce load on head worker." + if parallelism == :serial || head_node_occupation == 0.0 + return "" end + return "" + ## TODO: Debug why populations are always ready + # out = @sprintf("Head worker occupation: %.1f%%", head_node_occupation * 100) + + # raise_usage_warning = head_node_occupation > 0.4 + # if raise_usage_warning + # out *= "." + # out *= " This is high, and will prevent efficient resource usage." + # out *= " Increase `ncycles_per_iteration` to reduce load on head worker." + # end - out *= "\n" - return out + # out *= "\n" + # return out end function update_progress_bar!( @@ -386,9 +376,7 @@ The state of a search, including the populations, worker outputs, tasks, and channels. This is used to manage the search and keep track of runtime variables in a single struct. """ -Base.@kwdef struct SearchState{ - T,L,N<:AbstractExpressionNode{T},WorkerOutputType,ChannelType -} +Base.@kwdef struct SearchState{T,L,N<:AbstractExpression{T},WorkerOutputType,ChannelType} procs::Vector{Int} we_created_procs::Bool worker_output::Vector{Vector{WorkerOutputType}} @@ -478,6 +466,7 @@ function construct_datasets( y_variable_names, X_units, y_units, + extra, ::Type{L}, ) where {L} nout = size(y, 1) @@ -486,6 +475,7 @@ function construct_datasets( X, y[j, :], L; + index=j, weights=(weights === nothing ? weights : weights[j, :]), variable_names=variable_names, display_variable_names=display_variable_names, @@ -506,6 +496,7 @@ function construct_datasets( end, X_units=X_units, y_units=isa(y_units, AbstractVector) ? 
y_units[j] : y_units, + extra=extra, ) for j in 1:nout ] end diff --git a/src/SingleIteration.jl b/src/SingleIteration.jl index 2ae3428ee..ce420e5b0 100644 --- a/src/SingleIteration.jl +++ b/src/SingleIteration.jl @@ -1,17 +1,12 @@ module SingleIterationModule -using DynamicExpressions: - AbstractExpressionNode, - Node, - constructorof, - string_tree, - simplify_tree!, - combine_operators +using ADTypes: AutoEnzyme +using DynamicExpressions: AbstractExpression, string_tree, simplify_tree!, combine_operators using ..UtilsModule: @threads_if -using ..CoreModule: Options, Dataset, RecordType, DATA_TYPE, LOSS_TYPE +using ..CoreModule: Options, Dataset, RecordType, create_expression using ..ComplexityModule: compute_complexity -using ..PopMemberModule: PopMember, generate_reference -using ..PopulationModule: Population, finalize_scores, best_sub_pop +using ..PopMemberModule: generate_reference +using ..PopulationModule: Population, finalize_scores using ..HallOfFameModule: HallOfFame using ..AdaptiveParsimonyModule: RunningSearchStatistics using ..RegularizedEvolutionModule: reg_evol_cycle @@ -30,24 +25,24 @@ function s_r_cycle( verbosity::Int=0, options::Options, record::RecordType, + dominating=nothing, + idea_database=nothing, )::Tuple{ P,HallOfFame{T,L,N},Float64 -} where {T,L,D<:Dataset{T,L},N<:AbstractExpressionNode{T},P<:Population{T,L,N}} +} where {T,L,D<:Dataset{T,L},N<:AbstractExpression{T},P<:Population{T,L,N}} max_temp = 1.0 min_temp = 0.0 if !options.annealing min_temp = max_temp end all_temperatures = LinRange(max_temp, min_temp, ncycles) - best_examples_seen = HallOfFame(options, T, L) + best_examples_seen = HallOfFame(options, dataset) num_evals = 0.0 # For evaluating on a fixed batch (for batching) idx = options.batching ? batch_sample(dataset, options) : Int[] - loss_cache = [ - (oid=constructorof(typeof(member.tree))(T; val=zero(T)), score=zero(L)) for - member in pop.members - ] + example_tree = create_expression(zero(T), options, dataset) + loss_cache = [(oid=example_tree, score=zero(L)) for member in pop.members] first_loop = true for temperature in all_temperatures @@ -59,6 +54,8 @@ function s_r_cycle( running_search_statistics, options, record, + dominating=dominating, + idea_database=idea_database, ) num_evals += tmp_num_evals for (i, member) in enumerate(pop.members) @@ -109,13 +106,14 @@ function optimize_and_simplify_population( )::Tuple{P,Float64} where {T,L,D<:Dataset{T,L},P<:Population{T,L}} array_num_evals = zeros(Float64, pop.n) do_optimization = rand(pop.n) .< options.optimizer_probability - @threads_if !(options.deterministic) for j in 1:(pop.n) + # Note: we have to turn off this threading loop due to Enzyme, since we need + # to manually allocate a new task with a larger stack for Enzyme. 
+ should_thread = !(options.deterministic) && !(isa(options.autodiff_backend, AutoEnzyme)) + @threads_if should_thread for j in 1:(pop.n) if options.should_simplify tree = pop.members[j].tree tree = simplify_tree!(tree, options.operators) - if tree isa Node - tree = combine_operators(tree, options.operators) - end + tree = combine_operators(tree, options.operators) pop.members[j].tree = tree end if options.should_optimize_constants && do_optimization[j] diff --git a/src/SymbolicRegression.jl b/src/SymbolicRegression.jl index f051c25da..d08d49780 100644 --- a/src/SymbolicRegression.jl +++ b/src/SymbolicRegression.jl @@ -1,4 +1,4 @@ -module SymbolicRegression +module LibraryAugmentedSymbolicRegression # Types export Population, @@ -7,12 +7,19 @@ export Population, Options, Dataset, MutationWeights, + LLMWeights, + LLMOptions, Node, GraphNode, + ParametricNode, + Expression, + ParametricExpression, + StructuredExpression, NodeSampler, + AbstractExpression, AbstractExpressionNode, - SRRegressor, - MultitargetSRRegressor, + LaSRRegressor, + MultitargetLaSRRegressor, LOSS_TYPE, DATA_TYPE, @@ -22,6 +29,8 @@ export Population, calculate_pareto_frontier, count_nodes, compute_complexity, + @parse_expression, + parse_expression, print_tree, string_tree, eval_tree_array, @@ -31,6 +40,7 @@ export Population, set_node!, copy_node, node_to_symbolic, + node_type, symbolic_to_node, simplify_tree!, tree_mapreduce, @@ -38,6 +48,9 @@ export Population, gen_random_tree, gen_random_tree_fixed_size, @extend_operators, + get_tree, + get_contents, + get_metadata, #Operators plus, @@ -76,15 +89,24 @@ using Reexport using DynamicExpressions: Node, GraphNode, + ParametricNode, + Expression, + ParametricExpression, + StructuredExpression, NodeSampler, + AbstractExpression, AbstractExpressionNode, + @parse_expression, + parse_expression, copy_node, set_node!, string_tree, print_tree, count_nodes, get_constants, - set_constants, + get_scalar_constants, + set_constants!, + set_scalar_constants!, index_constants, NodeIndex, eval_tree_array, @@ -96,8 +118,12 @@ using DynamicExpressions: combine_operators, simplify_tree!, tree_mapreduce, - set_default_variable_names! 
-using DynamicExpressions.EquationModule: with_type_parameters + set_default_variable_names!, + node_type, + get_tree, + get_contents, + get_metadata +using DynamicExpressions: with_type_parameters @reexport using LossFunctions: MarginLoss, DistanceLoss, @@ -162,6 +188,7 @@ using DispatchDoctor: @stable include("CheckConstraints.jl") include("AdaptiveParsimony.jl") include("MutationFunctions.jl") + include("LLMFunctions.jl") include("LossFunctions.jl") include("PopMember.jl") include("ConstantOptimization.jl") @@ -173,6 +200,7 @@ using DispatchDoctor: @stable include("ProgressBars.jl") include("Migration.jl") include("SearchUtils.jl") + include("ExpressionBuilder.jl") end using .CoreModule: @@ -185,6 +213,8 @@ using .CoreModule: Dataset, Options, MutationWeights, + LLMOptions, + LLMWeights, plus, sub, mult, @@ -207,8 +237,9 @@ using .CoreModule: gamma, erf, erfc, - atanh_clip -using .UtilsModule: is_anonymous_function, recursive_merge, json3_write + atanh_clip, + create_expression +using .UtilsModule: is_anonymous_function, recursive_merge, json3_write, @ignore using .ComplexityModule: compute_complexity using .CheckConstraintsModule: check_constraints using .AdaptiveParsimonyModule: @@ -219,6 +250,8 @@ using .MutationFunctionsModule: random_node, random_node_and_parent, crossover_trees +using .LLMFunctionsModule: update_idea_database + using .InterfaceDynamicExpressionsModule: @extend_operators using .LossFunctionsModule: eval_loss, score_func, update_baseline_loss! using .PopMemberModule: PopMember, reset_birth! @@ -246,8 +279,7 @@ using .SearchUtilsModule: check_for_timeout, check_max_evals, ResourceMonitor, - start_work_monitor!, - stop_work_monitor!, + record_channel_state!, estimate_work_fraction, update_progress_bar!, print_search_state, @@ -258,6 +290,7 @@ using .SearchUtilsModule: save_to_file, get_cur_maxsize, update_hall_of_fame! +using .ExpressionBuilderModule: embed_metadata, strip_metadata @stable default_mode = "disable" begin include("deprecates.jl") @@ -269,7 +302,7 @@ end Perform a distributed equation search for functions `f_i` which describe the mapping `f_i(X[:, j]) โ‰ˆ y[i, j]`. Options are -configured using SymbolicRegression.Options(...), +configured using LibraryAugmentedSymbolicRegression.Options(...), which should be passed as a keyword argument to options. One can turn off parallelism with `numprocs=0`, which is useful for debugging and profiling. 
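A minimal usage sketch of the renamed entry point, for orientation (the data, operators, and settings below are illustrative only, mirroring this diff's own tests):

    using LibraryAugmentedSymbolicRegression

    # Toy problem: recover y = 2*cos(x4) from five random features.
    X = randn(Float32, 5, 100)
    y = 2 .* cos.(X[4, :])

    options = LibraryAugmentedSymbolicRegression.Options(;
        binary_operators=(+, *), unary_operators=(cos,), populations=20
    )
    # Serial mode, as the docstring above notes, is the easiest to debug and profile:
    hall_of_fame = equation_search(X, y; niterations=5, options=options, parallelism=:serial)
    dominating = calculate_pareto_frontier(hall_of_fame)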
@@ -379,6 +412,7 @@ function equation_search( progress::Union{Bool,Nothing}=nothing, X_units::Union{AbstractVector,Nothing}=nothing, y_units=nothing, + extra::NamedTuple=NamedTuple(), v_dim_out::Val{DIM_OUT}=Val(nothing), # Deprecated: multithreaded=nothing, @@ -406,6 +440,7 @@ function equation_search( y_variable_names, X_units, y_units, + extra, L, ) @@ -589,16 +624,19 @@ function equation_search( ) end -@stable default_mode = "disable" @noinline function _equation_search( +@noinline function _equation_search( datasets::Vector{D}, ropt::RuntimeOptions, options::Options, saved_state ) where {D<:Dataset} + # PROMPT EVOLUTION + idea_database_all = [Vector{String}() for j in 1:length(datasets)] + _validate_options(datasets, ropt, options) state = _create_workers(datasets, ropt, options) - _initialize_search!(state, datasets, ropt, options, saved_state) - _warmup_search!(state, datasets, ropt, options) - _main_search_loop!(state, datasets, ropt, options) + _initialize_search!(state, datasets, ropt, options, saved_state, idea_database_all) + _warmup_search!(state, datasets, ropt, options, idea_database_all) + _main_search_loop!(state, datasets, ropt, options, idea_database_all) _tear_down!(state, ropt, options) - return _format_output(state, ropt) + return _format_output(state, datasets, ropt, options) end function _validate_options( @@ -641,7 +679,8 @@ end nout = length(datasets) example_dataset = first(datasets) - NT = with_type_parameters(options.node_type, T) + example_ex = create_expression(zero(T), options, example_dataset) + NT = typeof(example_ex) PopType = Population{T,L,NT} HallOfFameType = HallOfFame{T,L,NT} WorkerOutputType = get_worker_output_type( @@ -698,9 +737,7 @@ end for j in 1:nout ] - return SearchState{ - T,L,with_type_parameters(options.node_type, T),WorkerOutputType,ChannelType - }(; + return SearchState{T,L,typeof(example_ex),WorkerOutputType,ChannelType}(; procs=procs, we_created_procs=we_created_procs, worker_output=worker_output, @@ -720,20 +757,25 @@ end ) end function _initialize_search!( - state::SearchState{T,L,N}, datasets, ropt::RuntimeOptions, options::Options, saved_state + state::SearchState{T,L,N}, + datasets, + ropt::RuntimeOptions, + options::Options, + saved_state, + idea_database_all, ) where {T,L,N} nout = length(datasets) init_hall_of_fame = load_saved_hall_of_fame(saved_state) if init_hall_of_fame === nothing for j in 1:nout - state.halls_of_fame[j] = HallOfFame(options, T, L) + state.halls_of_fame[j] = HallOfFame(options, datasets[j]) end else # Recompute losses for the hall of fame, in # case the dataset changed: for j in eachindex(init_hall_of_fame, datasets, state.halls_of_fame) - hof = init_hall_of_fame[j] + hof = strip_metadata(init_hall_of_fame[j], options, datasets[j]) for member in hof.members[hof.exists] score, result_loss = score_func(datasets[j], member, options) member.score = score @@ -750,17 +792,17 @@ function _initialize_search!( saved_pop = load_saved_population(saved_state; out=j, pop=i) new_pop = if saved_pop !== nothing && length(saved_pop.members) == options.population_size - saved_pop::Population{T,L,N} + _saved_pop = strip_metadata(saved_pop, options, datasets[j]) ## Update losses: - for member in saved_pop.members + for member in _saved_pop.members score, result_loss = score_func(datasets[j], member, options) member.score = score member.loss = result_loss end - copy_pop = copy(saved_pop) + copy_pop = copy(_saved_pop) @sr_spawner( begin - (copy_pop, HallOfFame(options, T, L), RecordType(), 0.0) + (copy_pop, HallOfFame(options, 
datasets[j]), RecordType(), 0.0) end, parallelism = ropt.parallelism, worker_idx = worker_idx @@ -778,8 +820,9 @@ nlength=3, options=options, nfeatures=datasets[j].nfeatures, + idea_database=idea_database_all[j], ), - HallOfFame(options, T, L), + HallOfFame(options, datasets[j]), RecordType(), Float64(options.population_size), ) @@ -794,7 +837,11 @@ return nothing end function _warmup_search!( - state::SearchState{T,L,N}, datasets, ropt::RuntimeOptions, options::Options + state::SearchState{T,L,N}, + datasets, + ropt::RuntimeOptions, + options::Options, + idea_database_all, ) where {T,L,N} nout = length(datasets) for j in 1:nout, i in 1:(options.populations) @@ -825,6 +872,7 @@ ropt.verbosity, cur_maxsize, running_search_statistics=c_rss, + idea_database=idea_database_all[j], )::DefaultWorkerOutputType{Population{T,L,N},HallOfFame{T,L,N}} end, parallelism = ropt.parallelism, @@ -835,7 +883,11 @@ return nothing end function _main_search_loop!( - state::SearchState{T,L,N}, datasets, ropt::RuntimeOptions, options::Options + state::SearchState{T,L,N}, + datasets, + ropt::RuntimeOptions, + options::Options, + idea_database_all, ) where {T,L,N} ropt.verbosity > 0 && @info "Started!" nout = length(datasets) @@ -863,11 +915,19 @@ end kappa = 0 resource_monitor = ResourceMonitor(; - absolute_start_time=time(), # Storing n times as many monitoring intervals as populations seems like it will # help get accurate resource estimates: - num_intervals_to_store=options.populations * 100 * nout, + max_recordings=options.populations * 100 * nout, + start_reporting_at=options.populations * 3 * nout, + window_size=options.populations * 2 * nout, ) + n_iterations = 0 + if options.llm_options.active + open(options.llm_options.llm_recorder_dir * "n_iterations.txt", "a") do file + write(file, "- " * string(div(n_iterations, options.populations)) * "\n") + end + end + worst_members = Vector{PopMember}() while sum(state.cycles_remaining) > 0 kappa += 1 if kappa > options.populations * nout @@ -875,6 +935,7 @@ end # nout, populations: j, i = state.task_order[kappa] + idea_database = idea_database_all[j] # Check if error on population: if ropt.parallelism in (:multiprocessing, :multithreading) @@ -890,11 +951,16 @@ else true end + record_channel_state!(resource_monitor, population_ready) + # Don't start more if this output has finished its cycles: # TODO - this might skip extra cycles? population_ready &= (state.cycles_remaining[j] > 0) if population_ready - start_work_monitor!(resource_monitor) + if n_iterations % options.populations == 0 + worst_members = Vector{PopMember}() + end + n_iterations += 1 # Take the fetch operation from the channel since it's ready (cur_pop, best_seen, cur_record, cur_num_evals) = if ropt.parallelism in ( @@ -913,10 +979,17 @@ dataset = datasets[j] cur_maxsize = state.cur_maxsizes[j] + worst_member = nothing for member in cur_pop.members + if worst_member == nothing || worst_member.loss < member.loss + worst_member = member + end size = compute_complexity(member, options) update_frequencies!(state.all_running_search_statistics[j]; size) end + if worst_member != nothing && worst_member.loss > 100 # only record the worst member as a negative example when it is genuinely bad (loss > 100); if even the worst of the population is good, skip it + push!(worst_members, worst_member) + end #!
format: off update_hall_of_fame!(state.halls_of_fame[j], cur_pop.members, options) update_hall_of_fame!(state.halls_of_fame[j], best_seen.members[best_seen.exists], options) @@ -924,6 +997,11 @@ function _main_search_loop!( # Dominating pareto curve - must be better than all simpler equations dominating = calculate_pareto_frontier(state.halls_of_fame[j]) + if options.llm_options.active && + options.llm_options.prompt_evol && + (n_iterations % options.populations == 0) + update_idea_database(idea_database, dominating, worst_members, options) + end if options.save_to_file save_to_file(dominating, nout, j, dataset, options) @@ -975,6 +1053,8 @@ function _main_search_loop!( ropt.verbosity, cur_maxsize, running_search_statistics=c_rss, + dominating=dominating, + idea_database=idea_database, ) end, parallelism = ropt.parallelism, @@ -989,7 +1069,6 @@ function _main_search_loop!( state.cur_maxsizes[j] = get_cur_maxsize(; options, ropt.total_cycles, cycles_remaining=state.cycles_remaining[j] ) - stop_work_monitor!(resource_monitor) move_window!(state.all_running_search_statistics[j]) if ropt.progress head_node_occupation = estimate_work_fraction(resource_monitor) @@ -1060,6 +1139,11 @@ function _main_search_loop!( end ################################################################ end + if options.llm_options.active + open(options.llm_options.llm_recorder_dir * "n_iterations.txt", "a") do file + write(file, "- " * string(div(n_iterations, options.populations)) * "\n") + end + end return nothing end function _tear_down!(state::SearchState, ropt::RuntimeOptions, options::Options) @@ -1076,10 +1160,20 @@ function _tear_down!(state::SearchState, ropt::RuntimeOptions, options::Options) @recorder json3_write(state.record[], options.recorder_file) return nothing end -function _format_output(state::SearchState, ropt::RuntimeOptions) - out_hof = (ropt.dim_out == 1 ? only(state.halls_of_fame) : state.halls_of_fame) +function _format_output( + state::SearchState, datasets, ropt::RuntimeOptions, options::Options +) + nout = length(datasets) + out_hof = if ropt.dim_out == 1 + embed_metadata(only(state.halls_of_fame), options, only(datasets)) + else + map(j -> embed_metadata(state.halls_of_fame[j], options, datasets[j]), 1:nout) + end if ropt.return_state - return (state.last_pops, out_hof) + return ( + map(j -> embed_metadata(state.last_pops[j], options, datasets[j]), 1:nout), + out_hof, + ) else return out_hof end @@ -1095,6 +1189,8 @@ end verbosity, cur_maxsize::Int, running_search_statistics, + dominating=nothing, + idea_database=nothing, ) where {T,L,N} record = RecordType() @recorder record["out$(out)_pop$(pop)"] = RecordType( @@ -1111,6 +1207,8 @@ end verbosity=verbosity, options=options, record=record, + dominating=dominating, + idea_database=idea_database, ) num_evals += evals_from_cycle out_pop, evals_from_optimize = optimize_and_simplify_population( @@ -1129,16 +1227,18 @@ end end include("MLJInterface.jl") -using .MLJInterfaceModule: SRRegressor, MultitargetSRRegressor +using .MLJInterfaceModule: LaSRRegressor, MultitargetLaSRRegressor function __init__() @require_extensions end -macro ignore(args...) 
end # Hack to get static analysis to work from within tests: @ignore include("../test/runtests.jl") +# TODO: Hack to force ConstructionBase version +using ConstructionBase: ConstructionBase as _ + include("precompile.jl") redirect_stdout(devnull) do redirect_stderr(devnull) do diff --git a/src/Utils.jl b/src/Utils.jl index 9636252e6..a667b6987 100644 --- a/src/Utils.jl +++ b/src/Utils.jl @@ -2,7 +2,9 @@ module UtilsModule using Printf: @printf -using MacroTools: splitdef, combinedef +using MacroTools: splitdef + +macro ignore(args...) end const pseudo_time = Ref(0) @@ -190,6 +192,49 @@ macro constfield(ex) return esc(VERSION < v"1.8.0" ? ex : Expr(:const, ex)) end +json3_write(args...) = error("Please load the JSON3.jl package.") + +""" + PerThreadCache{T} + +A cache that is efficient for multithreaded code, and works +by having a separate cache for each thread. This allows +us to avoid repeated locking. We only need to lock the cache +when resizing to the number of threads. +""" +struct PerThreadCache{T} + x::Vector{T} + num_threads::Ref{Int} + lock::Threads.SpinLock + + PerThreadCache{T}() where {T} = new(Vector{T}(undef, 1), Ref(1), Threads.SpinLock()) +end + +function _get_thread_cache(cache::PerThreadCache{T}) where {T} + if cache.num_threads[] < Threads.nthreads() + Base.@lock cache.lock begin + # The reason we have this extra `.num_threads[]` field is to avoid + # a race condition between a thread resizing the array concurrently + # with the check above. Basically we want to make sure the array is + # always big enough by the time we get to using it. Since `.num_threads[]` + # is set last, we can safely use the array. + if cache.num_threads[] < Threads.nthreads() + resize!(cache.x, Threads.nthreads()) + cache.num_threads[] = Threads.nthreads() + end + end + end + threadid = Threads.threadid() + if !isassigned(cache.x, threadid) + cache.x[threadid] = eltype(cache.x)() + end + return cache.x[threadid] +end +function Base.get!(f::F, cache::PerThreadCache, key) where {F<:Function} + thread_cache = _get_thread_cache(cache) + return get!(f, thread_cache, key) +end + # https://discourse.julialang.org/t/performance-of-hasmethod-vs-try-catch-on-methoderror/99827/14 # Faster way to catch method errors: @enum IsGood::Int8 begin @@ -197,29 +242,26 @@ end Bad Undefined end -const SafeFunctions = Dict{Type,IsGood}() -const SafeFunctionsLock = Threads.SpinLock() +const SafeFunctions = PerThreadCache{Dict{Type,IsGood}}() function safe_call(f::F, x::T, default::D) where {F,T<:Tuple,D} - status = get(SafeFunctions, Tuple{F,T}, Undefined) + thread_cache = _get_thread_cache(SafeFunctions) + status = get(thread_cache, Tuple{F,T}, Undefined) status == Good && return (f(x...)::D, true) status == Bad && return (default, false) - return lock(SafeFunctionsLock) do - output = try - (f(x...)::D, true) - catch e - !isa(e, MethodError) && rethrow(e) - (default, false) - end - if output[2] - SafeFunctions[Tuple{F,T}] = Good - else - SafeFunctions[Tuple{F,T}] = Bad - end - return output + + output = try + (f(x...)::D, true) + catch e + !isa(e, MethodError) && rethrow(e) + (default, false) + end + if output[2] + thread_cache[Tuple{F,T}] = Good + else + thread_cache[Tuple{F,T}] = Bad end + return output end -json3_write(args...) = error("Please load the JSON3.jl package.") - end diff --git a/src/precompile.jl b/src/precompile.jl index 87442695d..df736dac0 100644 --- a/src/precompile.jl +++ b/src/precompile.jl @@ -41,7 +41,7 @@ function do_precompilation(::Val{mode}) where {mode} X = randn(T, 3, N) y = start ?
randn(T, N) : randn(T, nout, N) @maybe_compile_workload mode begin - options = SymbolicRegression.Options(; + options = LibraryAugmentedSymbolicRegression.Options(; binary_operators=[+, *, /, -, ^], unary_operators=[sin, cos, exp, log, sqrt, abs], populations=3, diff --git a/test/LocalPreferences.toml b/test/LocalPreferences.toml index 78d43b5c7..ac5752f4a 100644 --- a/test/LocalPreferences.toml +++ b/test/LocalPreferences.toml @@ -2,6 +2,6 @@ instability_check = "error" instability_check_codegen = "min" -[SymbolicRegression] +[LaSR] instability_check = "error" instability_check_codegen = "min" diff --git a/test/Project.toml b/test/Project.toml index b6f96c2d9..fb83b7c8d 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,10 +1,12 @@ [deps] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" Bumper = "8ce10254-0962-460f-a3d8-1f77fea1446e" +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8" Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b" DynamicExpressions = "a40a106e-89c9-4ca8-8020-a735e8728b6b" DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1" @@ -19,9 +21,12 @@ Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" Suppressor = "fd094767-a336-5f1f-9728-57cf17d0bbfb" -SymbolicRegression = "8254be44-1295-4e6a-a16d-46603ac705cb" +LibraryAugmentedSymbolicRegression = "158930c3-947c-4174-974b-74b39e64a28f" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a" TestItems = "1c621080-faea-4a02-84b6-bbd5e436b8fe" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[preferences.LibraryAugmentedSymbolicRegression] +instability_check = "error" diff --git a/test/manual_distributed.jl b/test/manual_distributed.jl index 0d6490c6d..2bfe6aeeb 100644 --- a/test/manual_distributed.jl +++ b/test/manual_distributed.jl @@ -12,12 +12,12 @@ project_path = splitdir(Pkg.project().path)[1] end, ) end -@everywhere using SymbolicRegression +@everywhere using LibraryAugmentedSymbolicRegression @everywhere _inv(x::Float32)::Float32 = 1.0f0 / x X = rand(Float32, 5, 100) .+ 1 y = 1.2f0 .+ 2 ./ X[3, :] -options = SymbolicRegression.Options(; +options = LibraryAugmentedSymbolicRegression.Options(; default_params..., binary_operators=(+, *), unary_operators=(_inv,), populations=8 ) hallOfFame = equation_search( diff --git a/test/runtests.jl b/test/runtests.jl index dd2209874..52a941def 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,167 +2,180 @@ using TestItems: @testitem using TestItemRunner: @run_package_tests ENV["SYMBOLIC_REGRESSION_TEST"] = "true" -tags_to_run = let t = get(ENV, "SYMBOLIC_REGRESSION_TEST_SUITE", "unit,integration") +tags_to_run = let t = get(ENV, "SYMBOLIC_REGRESSION_TEST_SUITE", "part1,part2,part3") t = split(t, ",") t = map(Symbol, t) t end -@eval @run_package_tests filter = ti -> !isdisjoint(ti.tags, $tags_to_run) +@eval @run_package_tests filter = ti -> !isdisjoint(ti.tags, $tags_to_run) verbose = true -@testitem "JET tests" tags = [:integration, :jet] begin - test_jet_file = joinpath((@__DIR__), "test_jet.jl") - run(`$(Base.julia_cmd()) --startup-file=no $test_jet_file`) -end - -@testitem "Test custom operators and 
additional types" tags = [:unit] begin +# TODO: This is a very slow test +@testitem "Test custom operators and additional types" tags = [:part2] begin include("test_operators.jl") end -@testitem "Test tree construction and scoring" tags = [:unit] begin +@testitem "Test tree construction and scoring" tags = [:part3] begin include("test_tree_construction.jl") end -@testitem "Test SymbolicUtils interface" tags = [:unit] begin +include("test_graph_nodes.jl") + +@testitem "Test SymbolicUtils interface" tags = [:part1] begin include("test_symbolic_utils.jl") end -@testitem "Test constraints interface" tags = [:unit] begin +@testitem "Test constraints interface" tags = [:part2] begin include("test_constraints.jl") end -@testitem "Test custom losses" tags = [:unit] begin +@testitem "Test custom losses" tags = [:part1] begin include("test_losses.jl") end -@testitem "Test derivatives" tags = [:unit] begin +@testitem "Test derivatives" tags = [:part2] begin include("test_derivatives.jl") end +include("test_expression_derivatives.jl") -@testitem "Test simplification" tags = [:unit] begin +@testitem "Test simplification" tags = [:part3] begin include("test_simplification.jl") end -@testitem "Test printing" tags = [:unit] begin +@testitem "Test printing" tags = [:part1] begin include("test_print.jl") end -@testitem "Test validity of expression evaluation" tags = [:unit] begin +@testitem "Test validity of expression evaluation" tags = [:part2] begin include("test_evaluation.jl") end -@testitem "Test turbo mode with NaN" tags = [:unit] begin +@testitem "Test turbo mode with NaN" tags = [:part3] begin include("test_turbo_nan.jl") end -@testitem "Test validity of integer expression evaluation" tags = [:unit] begin +@testitem "Test validity of integer expression evaluation" tags = [:part1] begin include("test_integer_evaluation.jl") end -@testitem "Test tournament selection" tags = [:unit] begin +@testitem "Test tournament selection" tags = [:part2] begin include("test_prob_pick_first.jl") end -@testitem "Test crossover mutation" tags = [:unit] begin +@testitem "Test crossover mutation" tags = [:part3] begin include("test_crossover.jl") end -@testitem "Test NaN detection in evaluator" tags = [:unit] begin +# TODO: This is another very slow test +@testitem "Test NaN detection in evaluator" tags = [:part1] begin include("test_nan_detection.jl") end -@testitem "Test nested constraint checking" tags = [:unit] begin +@testitem "Test nested constraint checking" tags = [:part2] begin include("test_nested_constraints.jl") end -@testitem "Test complexity evaluation" tags = [:unit] begin +@testitem "Test complexity evaluation" tags = [:part3] begin include("test_complexity.jl") end -@testitem "Test options" tags = [:unit] begin +@testitem "Test options" tags = [:part1] begin include("test_options.jl") end -@testitem "Test hash of tree" tags = [:unit] begin +@testitem "Test hash of tree" tags = [:part2] begin include("test_hash.jl") end -@testitem "Test migration" tags = [:unit] begin +@testitem "Test migration" tags = [:part3] begin include("test_migration.jl") end -@testitem "Test deprecated options" tags = [:unit] begin +@testitem "Test deprecated options" tags = [:part1] begin include("test_deprecation.jl") end -@testitem "Test optimization mutation" tags = [:unit] begin +@testitem "Test optimization mutation" tags = [:part2] begin include("test_optimizer_mutation.jl") end -@testitem "Test RunningSearchStatistics" tags = [:unit] begin +@testitem "Test RunningSearchStatistics" tags = [:part3] begin 
include("test_search_statistics.jl") end -@testitem "Test utils" tags = [:unit] begin +@testitem "Test utils" tags = [:part1] begin include("test_utils.jl") end -@testitem "Test units" tags = [:integration] begin - include("test_units.jl") -end +include("test_units.jl") -@testitem "Dataset" tags = [:unit] begin +@testitem "Dataset" tags = [:part3] begin include("test_dataset.jl") end -@testitem "Test mixed settings." tags = [:integration] begin - include("test_mixed.jl") -end +include("test_mixed.jl") -@testitem "Testing fast-cycle and custom variable names" tags = [:integration] begin +@testitem "Testing fast-cycle and custom variable names" tags = [:part2] begin include("test_fast_cycle.jl") end -@testitem "Testing whether we can stop based on clock time." tags = [:integration] begin +@testitem "Testing whether we can stop based on clock time." tags = [:part3] begin include("test_stop_on_clock.jl") end -@testitem "Running README example." tags = [:integration] begin +@testitem "Running README example." tags = [:part1] begin + ENV["SYMBOLIC_REGRESSION_IS_TESTING"] = "true" include("../example.jl") end -@testitem "Testing whether the recorder works." tags = [:integration] begin +# TODO: This is the slowest test. +@testitem "Running parameterized function example." tags = [:part2] begin + ENV["SYMBOLIC_REGRESSION_IS_TESTING"] = "true" + include("../examples/parameterized_function.jl") +end + +@testitem "Testing whether the recorder works." tags = [:part3] begin include("test_recorder.jl") end -@testitem "Testing whether deterministic mode works." tags = [:integration] begin +@testitem "Testing whether deterministic mode works." tags = [:part1] begin include("test_deterministic.jl") end -@testitem "Testing whether early stop criteria works." tags = [:integration] begin +@testitem "Testing whether early stop criteria works." tags = [:part2] begin include("test_early_stop.jl") end -@testitem "Test MLJ integration" tags = [:integration] begin - include("test_mlj.jl") -end +include("test_mlj.jl") -@testitem "Testing whether we can move operators to workers." tags = [:integration] begin +@testitem "Testing whether we can move operators to workers." tags = [:part1] begin include("test_custom_operators_multiprocessing.jl") end -@testitem "Test whether the precompilation script works." tags = [:integration] begin +@testitem "Test whether the precompilation script works." tags = [:part2] begin include("test_precompilation.jl") end -@testitem "Test whether custom objectives work." tags = [:integration] begin +@testitem "Test whether custom objectives work." 
tags = [:part3] begin include("test_custom_objectives.jl") end -@testitem "Test abstract numbers" tags = [:integration] begin +@testitem "Test abstract numbers" tags = [:part1] begin include("test_abstract_numbers.jl") end -@testitem "Aqua tests" tags = [:integration, :aqua] begin +include("test_pretty_printing.jl") +include("test_expression_builder.jl") + +@testitem "Aqua tests" tags = [:part2, :aqua] begin include("test_aqua.jl") end + +@testitem "JET tests" tags = [:part1, :jet] begin + test_jet_file = joinpath((@__DIR__), "test_jet.jl") + run(`$(Base.julia_cmd()) --startup-file=no $test_jet_file`) +end + +@testitem "LLM Integration tests" tags = [:part3, :llm] begin + include("test_lasr_integration.jl") +end \ No newline at end of file diff --git a/test/test_abstract_numbers.jl b/test/test_abstract_numbers.jl index d1ca4412b..373cea5a7 100644 --- a/test/test_abstract_numbers.jl +++ b/test/test_abstract_numbers.jl @@ -1,4 +1,4 @@ -using SymbolicRegression +using LibraryAugmentedSymbolicRegression using Random include("test_params.jl") @@ -12,7 +12,7 @@ for T in (ComplexF16, ComplexF32, ComplexF64) early_stop(loss::L, c) where {L} = ((loss <= L(1e-2)) && (c <= 15)) - options = SymbolicRegression.Options(; + options = LibraryAugmentedSymbolicRegression.Options(; binary_operators=[+, *, -, /], unary_operators=[cos], populations=20, diff --git a/test/test_aqua.jl b/test/test_aqua.jl index 6a4153f66..6c8ef5442 100644 --- a/test/test_aqua.jl +++ b/test/test_aqua.jl @@ -1,6 +1,6 @@ -using SymbolicRegression +using LibraryAugmentedSymbolicRegression using Aqua -Aqua.test_all(SymbolicRegression; ambiguities=false) +Aqua.test_all(LibraryAugmentedSymbolicRegression; ambiguities=false) -VERSION >= v"1.9" && Aqua.test_ambiguities(SymbolicRegression) +VERSION >= v"1.9" && Aqua.test_ambiguities(LibraryAugmentedSymbolicRegression) diff --git a/test/test_complexity.jl b/test/test_complexity.jl index deaad6813..dc63aa302 100644 --- a/test/test_complexity.jl +++ b/test/test_complexity.jl @@ -1,5 +1,5 @@ println("Testing custom complexities.") -using SymbolicRegression, Test +using LibraryAugmentedSymbolicRegression, Test x1, x2, x3 = Node("x1"), Node("x2"), Node("x3") diff --git a/test/test_constraints.jl b/test/test_constraints.jl index 0a6306923..90ed581f1 100644 --- a/test/test_constraints.jl +++ b/test/test_constraints.jl @@ -1,6 +1,6 @@ using DynamicExpressions: count_depth -using SymbolicRegression -using SymbolicRegression: check_constraints +using LibraryAugmentedSymbolicRegression +using LibraryAugmentedSymbolicRegression: check_constraints include("test_params.jl") _inv(x) = 1 / x diff --git a/test/test_crossover.jl b/test/test_crossover.jl index c1da48bb0..2f72fc471 100644 --- a/test/test_crossover.jl +++ b/test/test_crossover.jl @@ -1,9 +1,9 @@ println("Testing crossover function.") -using SymbolicRegression -using SymbolicRegression: crossover_trees +using LibraryAugmentedSymbolicRegression +using LibraryAugmentedSymbolicRegression: crossover_trees include("test_params.jl") -options = SymbolicRegression.Options(; +options = LibraryAugmentedSymbolicRegression.Options(; default_params..., binary_operators=(+, *, /, -), unary_operators=(cos, exp), diff --git a/test/test_custom_objectives.jl b/test/test_custom_objectives.jl index 3696f1750..9b3a43760 100644 --- a/test/test_custom_objectives.jl +++ b/test/test_custom_objectives.jl @@ -1,4 +1,4 @@ -using SymbolicRegression +using LibraryAugmentedSymbolicRegression include("test_params.jl") def = quote diff --git a/test/test_custom_operators.jl 
b/test/test_custom_operators.jl index 50e72cd16..d0d6d4af0 100644 --- a/test/test_custom_operators.jl +++ b/test/test_custom_operators.jl @@ -1,4 +1,4 @@ -using SymbolicRegression +using LibraryAugmentedSymbolicRegression using Random # Test that we can work with custom operators: @@ -25,7 +25,7 @@ X = randn(MersenneTwister(0), Float32, 2, 10); # Now, test that we can work with operators defined in modules module A -using SymbolicRegression +using LibraryAugmentedSymbolicRegression using Random function my_func_a(x::T, y::T) where {T<:Real} diff --git a/test/test_custom_operators_multiprocessing.jl b/test/test_custom_operators_multiprocessing.jl index 2fca2298e..1cb481f88 100644 --- a/test/test_custom_operators_multiprocessing.jl +++ b/test/test_custom_operators_multiprocessing.jl @@ -1,4 +1,4 @@ -using SymbolicRegression +using LibraryAugmentedSymbolicRegression defs = quote _plus(x, y) = x + y @@ -22,7 +22,7 @@ end X = randn(Float32, 5, 100) y = _mult.(2, _cos.(X[4, :])) + _mult.(X[1, :], X[1, :]) -options = SymbolicRegression.Options(; +options = LibraryAugmentedSymbolicRegression.Options(; binary_operators=(_plus, _mult, _div, _min), unary_operators=(_cos, _exp), populations=20, diff --git a/test/test_dataset.jl b/test/test_dataset.jl index 9fdbcfd74..ededef6c3 100644 --- a/test/test_dataset.jl +++ b/test/test_dataset.jl @@ -1,4 +1,4 @@ -using SymbolicRegression +using LibraryAugmentedSymbolicRegression using DispatchDoctor: allow_unstable @testset "Dataset construction" begin diff --git a/test/test_deprecation.jl b/test/test_deprecation.jl index dda563741..036925c0c 100644 --- a/test/test_deprecation.jl +++ b/test/test_deprecation.jl @@ -1,4 +1,4 @@ -using SymbolicRegression +using LibraryAugmentedSymbolicRegression # Deprecated kwargs should still work: options = Options(; diff --git a/test/test_derivatives.jl b/test/test_derivatives.jl index ee2d4ddfd..8ee4d1cc9 100644 --- a/test/test_derivatives.jl +++ b/test/test_derivatives.jl @@ -1,5 +1,5 @@ -using SymbolicRegression -using SymbolicRegression: eval_diff_tree_array, eval_grad_tree_array +using LibraryAugmentedSymbolicRegression +using LibraryAugmentedSymbolicRegression: eval_diff_tree_array, eval_grad_tree_array using Random using Zygote using LinearAlgebra @@ -122,7 +122,7 @@ end println("Testing NodeIndex.") -using SymbolicRegression: get_constants, NodeIndex, index_constants +using LibraryAugmentedSymbolicRegression: get_scalar_constants, NodeIndex, index_constants options = Options(; binary_operators=(+, *, -, /, pow_abs2), unary_operators=(custom_cos, exp, sin) @@ -144,6 +144,6 @@ function check_tree( end end -@test check_tree(tree, index_constants(tree), get_constants(tree)) +@test check_tree(tree, index_constants(tree), first(get_scalar_constants(tree))) println("Done.") diff --git a/test/test_deterministic.jl b/test/test_deterministic.jl index d541eea70..deee1e8f7 100644 --- a/test/test_deterministic.jl +++ b/test/test_deterministic.jl @@ -1,53 +1,33 @@ -using SymbolicRegression +using LibraryAugmentedSymbolicRegression using Random -macro maybe_inferred(ex) - # Only get stable inference on Julia 1.10+ - return if VERSION >= v"1.10.0-DEV.0" - quote - @inferred $ex - end - else - quote - try - # Still want to test for any bugs (JuliaLang/julia#53761) - @inferred $ex - catch - $ex - end - end - end |> esc -end +X = 2 .* randn(MersenneTwister(0), Float32, 2, 1000) +y = 3 * cos.(X[2, :]) + X[1, :] .^ 2 .- 2 -begin - X = 2 .* randn(MersenneTwister(0), Float32, 2, 1000) - y = 3 * cos.(X[2, :]) + X[1, :] .^ 2 .- 2 +options = 
LibraryAugmentedSymbolicRegression.Options(; + binary_operators=(+, *, /, -), + unary_operators=(cos,), + crossover_probability=0.0, # required for recording, as not set up to track crossovers. + max_evals=10000, + deterministic=true, + seed=0, + verbosity=0, + progress=false, +) - options = SymbolicRegression.Options(; - binary_operators=(+, *, /, -), - unary_operators=(cos,), - crossover_probability=0.0, # required for recording, as not set up to track crossovers. - max_evals=10000, - deterministic=true, - seed=0, - verbosity=0, - progress=false, +all_outputs = [] +for i in 1:2 + hall_of_fame = equation_search( + X, + y; + niterations=5, + options=options, + parallelism=:serial, + v_dim_out=Val(1), + return_state=Val(false), ) - - all_outputs = [] - for i in 1:2 - hall_of_fame = @maybe_inferred equation_search( - X, - y; - niterations=5, - options=options, - parallelism=:serial, - v_dim_out=Val(1), - return_state=Val(false), - ) - dominating = calculate_pareto_frontier(hall_of_fame) - push!(all_outputs, dominating[end].tree) - end - - @test string(all_outputs[1]) == string(all_outputs[2]) + dominating = calculate_pareto_frontier(hall_of_fame) + push!(all_outputs, dominating[end].tree) end + +@test string(all_outputs[1]) == string(all_outputs[2]) diff --git a/test/test_early_stop.jl b/test/test_early_stop.jl index 3ba36e555..8f64689b3 100644 --- a/test/test_early_stop.jl +++ b/test/test_early_stop.jl @@ -1,11 +1,11 @@ -using SymbolicRegression +using LibraryAugmentedSymbolicRegression X = randn(Float32, 5, 100) y = 2 * cos.(X[4, :]) + X[1, :] .^ 2 early_stop(loss, c) = ((loss <= 1e-10) && (c <= 10)) -options = SymbolicRegression.Options(; +options = LibraryAugmentedSymbolicRegression.Options(; binary_operators=(+, *, /, -), unary_operators=(cos, exp), populations=20, diff --git a/test/test_evaluation.jl b/test/test_evaluation.jl index a3c05dd0c..0c38f0e58 100644 --- a/test/test_evaluation.jl +++ b/test/test_evaluation.jl @@ -1,4 +1,4 @@ -using SymbolicRegression +using LibraryAugmentedSymbolicRegression using Random include("test_params.jl") diff --git a/test/test_expression_builder.jl b/test/test_expression_builder.jl new file mode 100644 index 000000000..9de5299db --- /dev/null +++ b/test/test_expression_builder.jl @@ -0,0 +1,66 @@ +# This file tests particular functionality of ExpressionBuilderModule +@testitem "ParametricExpression" tags = [:part3] begin + using LibraryAugmentedSymbolicRegression + using LibraryAugmentedSymbolicRegression.ExpressionBuilderModule: + strip_metadata, embed_metadata, init_params + + options = Options() + ex = parse_expression( + :(x1 * p1); + expression_type=ParametricExpression, + operators=options.operators, + parameters=ones(2, 1) * 3, + parameter_names=["p1", "p2"], + variable_names=["x1"], + ) + X = ones(1, 1) * 2 + y = ones(1) + dataset = Dataset(X, y; extra=(; classes=[1])) + + @test ex isa ParametricExpression + @test ex(dataset.X, dataset.extra.classes) โ‰ˆ ones(1, 1) * 6 + + # Mistake in that we gave the wrong options! + @test_throws( + AssertionError( + "Need prototype to be of type $(options.expression_type), but got $(ex)::$(typeof(ex))", + ), + init_params(options, dataset, ex, Val(true)) + ) + + options = Options(; + expression_type=ParametricExpression, expression_options=(; max_parameters=2) + ) + + # Mistake in that we also gave the wrong number of parameter names! 
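+    # (`pop!` below drops "p2", leaving one parameter name while
+    # `expression_options.max_parameters` is 2, which triggers the tested assertion.)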
+ pop!(ex.metadata.parameter_names) + @test_throws( + AssertionError( + "Mismatch between options.expression_options.max_parameters=$(options.expression_options.max_parameters) and prototype.metadata.parameter_names=$(ex.metadata.parameter_names)", + ), + init_params(options, dataset, ex, Val(true)) + ) + # So, we fix it: + push!(ex.metadata.parameter_names, "p2") + + @test ex.metadata.parameter_names == ["p1", "p2"] + @test keys(init_params(options, dataset, ex, Val(true))) == + (:operators, :variable_names, :parameters, :parameter_names) + + @test sprint(show, ex) == "x1 * p1" + stripped_ex = strip_metadata(ex, options, dataset) + # Stripping the metadata means that operations like `show` + # do not know what binary operator to use: + @test sprint(show, stripped_ex) == "binary_operator[4](x1, p1)" + + # However, it's important that parametric expressions are still parametric: + @test stripped_ex isa ParametricExpression + # And, that they still have the right parameters: + @test haskey(getfield(stripped_ex.metadata, :_data), :parameters) + @test stripped_ex.metadata.parameters โ‰ˆ ones(2, 1) * 3 + + # Now, test that we can embed metadata back in: + embedded_ex = embed_metadata(stripped_ex, options, dataset) + @test embedded_ex isa ParametricExpression + @test ex == embedded_ex +end diff --git a/test/test_expression_derivatives.jl b/test/test_expression_derivatives.jl new file mode 100644 index 000000000..78cfddbef --- /dev/null +++ b/test/test_expression_derivatives.jl @@ -0,0 +1,142 @@ +@testitem "Test derivatives" tags = [:part1] begin + using LibraryAugmentedSymbolicRegression + using Zygote: Zygote + using Random: MersenneTwister + + ex = @parse_expression( + x * x - cos(2.5 * y), + unary_operators = [cos], + binary_operators = [*, -, +], + variable_names = [:x, :y] + ) + + rng = MersenneTwister(0) + X = rand(rng, 2, 32) + + (ฮดy,) = Zygote.gradient(X) do X + x = @view X[1, :] + y = @view X[2, :] + + sum(i -> x[i] * x[i] - cos(2.5 * y[i]), eachindex(x)) + end + ฮดy_hat = ex'(X) + + @test ฮดy โ‰ˆ ฮดy_hat + + options2 = Options(; unary_operators=[sin], binary_operators=[+, *, -]) + (ฮดy2,) = Zygote.gradient(X) do X + x = @view X[1, :] + y = @view X[2, :] + + sum(i -> (x[i] + x[i]) * sin(2.5 + y[i]), eachindex(x)) + end + ฮดy2_hat = ex'(X, options2) + + @test ฮดy2 โ‰ˆ ฮดy2_hat +end + +@testitem "Test derivatives during optimization" tags = [:part1] begin + using LibraryAugmentedSymbolicRegression + using LibraryAugmentedSymbolicRegression.ConstantOptimizationModule: Evaluator, GradEvaluator + using DynamicExpressions + using Zygote: Zygote + using Random: MersenneTwister + using DifferentiationInterface: value_and_gradient + + rng = MersenneTwister(0) + X = rand(rng, 2, 32) + y = @. X[1, :] * X[1, :] - cos(2.6 * X[2, :]) + dataset = Dataset(X, y) + + options = Options(; + unary_operators=[cos], binary_operators=[+, *, -], autodiff_backend=:Zygote + ) + + ex = @parse_expression( + x * x - cos(2.5 * y), operators = options.operators, variable_names = [:x, :y] + ) + f = Evaluator(ex, last(get_scalar_constants(ex)), dataset, options, nothing) + fg! 
= GradEvaluator(f, options.autodiff_backend) + + @test f(first(get_scalar_constants(ex))) isa Float64 + + x = first(get_scalar_constants(ex)) + G = zero(x) + fg!(nothing, G, x) + @test G[] != 0 +end + +@testitem "Test derivatives of parametric expression during optimization" tags = [:part3] begin + using LibraryAugmentedSymbolicRegression + using LibraryAugmentedSymbolicRegression.ConstantOptimizationModule: + Evaluator, GradEvaluator, optimize_constants, specialized_options + using DynamicExpressions + using Zygote: Zygote + using Random: MersenneTwister + using DifferentiationInterface: value_and_gradient, AutoZygote, AutoEnzyme + enzyme_compatible = VERSION >= v"1.10.0" && VERSION < v"1.11.0-DEV.0" + @static if enzyme_compatible + using Enzyme: Enzyme + end + + rng = MersenneTwister(0) + X = rand(rng, 2, 32) + true_params = [0.5 2.0] + init_params = [0.1 0.2] + init_constants = [2.5, -0.5] + classes = rand(rng, 1:2, 32) + y = [ + X[1, i] * X[1, i] - cos(2.6 * X[2, i] - 0.2) + true_params[1, classes[i]] for + i in 1:32 + ] + + dataset = Dataset(X, y; extra=(; classes)) + + (true_val, (true_d_params, true_d_constants)) = + value_and_gradient(AutoZygote(), (init_params, init_constants)) do (params, c) + pred = [ + X[1, i] * X[1, i] - cos(c[1] * X[2, i] + c[2]) + params[1, classes[i]] for + i in 1:32 + ] + sum(abs2, pred .- y) / length(y) + end + + options = Options(; + unary_operators=[cos], binary_operators=[+, *, -], autodiff_backend=:Zygote + ) + + ex = @parse_expression( + x * x - cos(2.5 * y + -0.5) + p1, + operators = options.operators, + expression_type = ParametricExpression, + variable_names = ["x", "y"], + extra_metadata = (parameter_names=["p1"], parameters=init_params) + ) + + function test_backend(ex, @nospecialize(backend); allow_failure=false) + x0, refs = get_scalar_constants(ex) + G = zero(x0) + + f = Evaluator(ex, refs, dataset, specialized_options(options), nothing) + fg! 
= GradEvaluator(f, backend) + + @test f(x0) โ‰ˆ true_val + + try + val = fg!(nothing, G, x0) + @test val โ‰ˆ true_val + @test G โ‰ˆ vcat(true_d_constants[:], true_d_params[:]) + catch e + if allow_failure + @warn "Expected failure" e + else + rethrow(e) + end + end + end + + test_backend(ex, AutoZygote(); allow_failure=false) + @static if enzyme_compatible + test_backend(ex, AutoEnzyme(); allow_failure=true) + end +end diff --git a/test/test_fast_cycle.jl b/test/test_fast_cycle.jl index e7d2de845..91c7bd147 100644 --- a/test/test_fast_cycle.jl +++ b/test/test_fast_cycle.jl @@ -1,8 +1,8 @@ -using SymbolicRegression +using LibraryAugmentedSymbolicRegression using Random include("test_params.jl") -options = SymbolicRegression.Options(; +options = LibraryAugmentedSymbolicRegression.Options(; default_params..., binary_operators=(+, *), unary_operators=(cos,), @@ -46,7 +46,7 @@ print_tree(best.tree, options) println("Testing whether state saving works with changed loss function.") previous_loss = best.loss new_loss(x, y) = sum(abs2, x - y) * 0.1 -options = SymbolicRegression.Options(; +options = LibraryAugmentedSymbolicRegression.Options(; default_params..., binary_operators=(+, *), unary_operators=(cos,), diff --git a/test/test_graph_nodes.jl b/test/test_graph_nodes.jl index 639546dc0..82a612d60 100644 --- a/test/test_graph_nodes.jl +++ b/test/test_graph_nodes.jl @@ -1,26 +1,133 @@ -using SymbolicRegression +@testitem "GraphNode evaluation" tags = [:part1] begin + using LibraryAugmentedSymbolicRegression -options = Options(; binary_operators=[+, -, *, /], unary_operators=[cos, sin], maxsize=30) + options = Options(; + binary_operators=[+, -, *, /], unary_operators=[cos, sin], maxsize=30 + ) -x1, x2, x3 = [GraphNode(Float64; feature=i) for i in 1:3] + x1, x2, x3 = [GraphNode(Float64; feature=i) for i in 1:3] -base_tree = cos(x1 - 3.2) * x2 - x3 * copy(x3) -tree = sin(base_tree) + base_tree + base_tree = cos(x1 - 3.2) * x2 - x3 * copy(x3) + tree = sin(base_tree) + base_tree -X = randn(3, 50) -z = @. cos(X[1, :] - 3.2) * X[2, :] - X[3, :] * X[3, :] -y = @. sin(z) + z -dataset = Dataset(X, y) + X = randn(3, 50) + z = @. cos(X[1, :] - 3.2) * X[2, :] - X[3, :] * X[3, :] + y = @. sin(z) + z + dataset = Dataset(X, y) -tree(dataset.X, options) + tree(dataset.X, options) -eval_tree_array(tree, dataset.X, options) + eval_tree_array(tree, dataset.X, options) +end -@test compute_complexity(tree, options) == 12 -@test compute_complexity(tree, options; break_sharing=Val(true)) == 22 +@testitem "GraphNode complexity" tags = [:part1] begin + using LibraryAugmentedSymbolicRegression -pop = Population( - dataset, GraphNode{Float64}; nlength=3, options, nfeatures=3, population_size=100 -) + options = Options(; + binary_operators=[+, -, *, /], unary_operators=[cos, sin], maxsize=30 + ) + x1, x2, x3 = [GraphNode(Float64; feature=i) for i in 1:3] -equation_search([dataset], GraphNode; niterations=10000, options) + base_tree = cos(x1 - 3.2) * x2 - x3 * copy(x3) + tree = sin(base_tree) + base_tree + + @test compute_complexity(tree, options) == 12 + @test compute_complexity(tree, options; break_sharing=Val(true)) == 22 +end + +@testitem "GraphNode population" tags = [:part1] begin + using LibraryAugmentedSymbolicRegression + + options = Options(; + binary_operators=[+, -, *, /], + unary_operators=[cos, sin], + maxsize=30, + node_type=GraphNode, + ) + + X = randn(3, 50) + z = @. cos(X[1, :] - 3.2) * X[2, :] - X[3, :] * X[3, :] + y = @. 
sin(z) + z + dataset = Dataset(X, y) + + pop = Population(dataset; options, nlength=3, nfeatures=3, population_size=100) + @test pop isa Population{T,T,<:Expression{T,<:GraphNode{T}}} where {T} + + # Seems to not work yet: + # equation_search([dataset]; niterations=10, options) +end + +@testitem "GraphNode break connection mutation" tags = [:part1] begin + using LibraryAugmentedSymbolicRegression + using LibraryAugmentedSymbolicRegression.MutationFunctionsModule: break_random_connection! + using Random: MersenneTwister + + options = Options(; + binary_operators=[+, -, *, /], + unary_operators=[cos, sin], + maxsize=30, + node_type=GraphNode, + ) + + x1, x2, x3 = [GraphNode(Float64; feature=i) for i in 1:3] + base_tree = cos(x1 - 3.2) * x2 + tree = sin(base_tree) + base_tree + + ex = Expression(tree; operators=options.operators, variable_names=["x1", "x2", "x3"]) + + s = strip(sprint(print_tree, ex)) + @test s == "sin(cos(x1 - 3.2) * x2) + {(cos(x1 - 3.2) * x2)}" + + rng = MersenneTwister(0) + expressions = [copy(ex) for _ in 1:1000] + expressions = [break_random_connection!(ex, rng) for ex in expressions] + strings = [strip(sprint(print_tree, ex)) for ex in expressions] + strings = unique(strings) + @test Set(strings) == Set([ + "sin(cos(x1 - 3.2) * x2) + {(cos(x1 - 3.2) * x2)}", + "sin(cos(x1 - 3.2) * x2) + (cos(x1 - 3.2) * x2)", + ]) + # Either it breaks the connection or not +end + +@testitem "GraphNode form connection mutation" tags = [:part1] begin + using LibraryAugmentedSymbolicRegression + using LibraryAugmentedSymbolicRegression.MutationFunctionsModule: form_random_connection! + using Random: MersenneTwister + + options = Options(; + binary_operators=[+, -, *, /], + unary_operators=[cos, sin], + maxsize=30, + node_type=GraphNode, + ) + + x1, x2 = [GraphNode{Float64}(; feature=i) for i in 1:2] + + tree = cos(x1 * x2 + 1.5) + ex = Expression(tree; operators=options.operators, variable_names=["x1", "x2"]) + rng = MersenneTwister(0) + expressions = [copy(ex) for _ in 1:3_000] + expressions = [form_random_connection!(ex, rng) for ex in expressions] + strings = [strip(sprint(print_tree, ex)) for ex in expressions] + strings = sort(unique(strings); by=length) + + # All possible connections that can be made + @test Set(strings) == Set([ "cos(x1)", "cos(x2)", "cos(1.5)", "cos(x1 * x2)", "cos(x2 + 1.5)", "cos(x1 + 1.5)", "cos(1.5 + {1.5})", "cos((x1 * x2) + 1.5)", "cos((x1 * x2) + {x2})", "cos((x1 * x2) + {x1})", "cos((x2 * {x2}) + 1.5)", "cos((x1 * {x1}) + 1.5)", "cos((x1 * 1.5) + {1.5})", "cos((1.5 * x2) + {1.5})", "cos((x1 * x2) + {(x1 * x2)})", ]) +end diff --git a/test/test_hash.jl b/test/test_hash.jl index 5013efea9..a8c86db4b 100644 --- a/test/test_hash.jl +++ b/test/test_hash.jl @@ -1,4 +1,4 @@ -using SymbolicRegression +using LibraryAugmentedSymbolicRegression options = Options(; binary_operators=(+, *, ^, /, greater), unary_operators=(cos,)) @extend_operators options diff --git a/test/test_integer_evaluation.jl b/test/test_integer_evaluation.jl index 4b84f8be7..bd313ff0d 100644 --- a/test/test_integer_evaluation.jl +++ b/test/test_integer_evaluation.jl @@ -1,4 +1,4 @@ -using SymbolicRegression +using LibraryAugmentedSymbolicRegression using Random include("test_params.jl") diff --git a/test/test_jet.jl b/test/test_jet.jl index 28476febb..720a11aa2 100644 --- a/test/test_jet.jl +++ b/test/test_jet.jl @@ -1,6 +1,7 @@ -if VERSION < v"1.10.0" +if !(VERSION >= v"1.10.0" && VERSION < v"1.11.0-DEV.0") exit(0) end +# TODO: Check why this is breaking on 1.11.0 dir
= mktempdir() @@ -10,18 +11,23 @@ using Pkg @info "Creating environment..." Pkg.activate(dir; io=devnull) Pkg.develop(; path=dirname(@__DIR__), io=devnull) -Pkg.add(["JET", "Preferences"]; io=devnull) +Pkg.add(["JET", "Preferences", "DynamicExpressions"]; io=devnull) @info "Done!" using Preferences cd(dir) -Preferences.set_preferences!("SymbolicRegression", "instability_check" => "disable") +Preferences.set_preferences!( + "LibraryAugmentedSymbolicRegression", "instability_check" => "disable"; force=true +) +Preferences.set_preferences!( + "DynamicExpressions", "instability_check" => "disable"; force=true +) -using SymbolicRegression +using LibraryAugmentedSymbolicRegression using JET @info "Running tests..." -JET.test_package(SymbolicRegression; target_defined_modules=true) +JET.test_package(LibraryAugmentedSymbolicRegression; target_defined_modules=true) @info "Done!" @info "test_jet.jl finished" diff --git a/test/test_lasr_integration.jl b/test/test_lasr_integration.jl new file mode 100644 index 000000000..0274bc103 --- /dev/null +++ b/test/test_lasr_integration.jl @@ -0,0 +1,29 @@ +using LibraryAugmentedSymbolicRegression: LLMOptions, Options + +# test that we can partially specify LLMOptions +op1 = LLMOptions(active=false) +@test op1.active == false + +# test that we can fully specify LLMOptions +op2 = LLMOptions( + active=true, + weights=LLMWeights(llm_mutate=0.5, llm_crossover=0.3, llm_gen_random=0.2), + num_pareto_context=5, + prompt_evol=true, + prompt_concepts=true, + api_key="vllm_api.key", + model="modelx", + api_kwargs=Dict("url" => "http://localhost:11440/v1"), + http_kwargs=Dict("retries" => 3, "readtimeout" => 3600), + llm_recorder_dir="test/", + llm_context="test", + var_order=nothing, + idea_threshold=30 +) +@test op2.active == true + +# test that we can pass LLMOptions to Options +llm_opt = LLMOptions(active=false) +op = Options(; optimizer_options=(iterations=16, f_calls_limit=100, x_tol=1e-16), llm_options=llm_opt) +@test isa(op.llm_options, LLMOptions) +println("Passed.") diff --git a/test/test_losses.jl b/test/test_losses.jl index 9277ed23c..f6c4cbc60 100644 --- a/test/test_losses.jl +++ b/test/test_losses.jl @@ -1,10 +1,10 @@ -using SymbolicRegression -using SymbolicRegression: eval_loss +using LibraryAugmentedSymbolicRegression +using LibraryAugmentedSymbolicRegression: eval_loss using Random include("test_params.jl") -_loss = SymbolicRegression.LossFunctionsModule._loss -_weighted_loss = SymbolicRegression.LossFunctionsModule._weighted_loss +_loss = LibraryAugmentedSymbolicRegression.LossFunctionsModule._loss +_weighted_loss = LibraryAugmentedSymbolicRegression.LossFunctionsModule._weighted_loss customloss(x, y) = abs(x - y)^2.5 customloss(x, y, w) = w * (abs(x - y)^2.5) diff --git a/test/test_migration.jl b/test/test_migration.jl index 13c69753e..785115f38 100644 --- a/test/test_migration.jl +++ b/test/test_migration.jl @@ -1,4 +1,7 @@ -using SymbolicRegression +using LibraryAugmentedSymbolicRegression +using LibraryAugmentedSymbolicRegression: strip_metadata +using DynamicExpressions: get_tree +using Test using Random: seed! 
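# Seed the global RNG so that this migration test is reproducible run-to-run.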
seed!(0) @@ -10,14 +13,20 @@ options = Options(); population1 = Population( X, y; population_size=100, options=options, nfeatures=5, nlength=10 ) +dataset = Dataset(X, y) tree = Node(1, Node(; val=1.0), Node(; feature=2) * 3.2) @test !(hash(tree) in [hash(p.tree) for p in population1.members]) -SymbolicRegression.MigrationModule.migrate!( - [PopMember(tree, 0.0, Inf, options)] => population1, options; frac=0.5 +ex = @parse_expression($tree, operators = options.operators, variable_names = [:x1, :x2],) +ex = strip_metadata(ex, options, dataset) + +LibraryAugmentedSymbolicRegression.MigrationModule.migrate!( + [PopMember(ex, 0.0, Inf, options; deterministic=false)] => population1, + options; + frac=0.5, ) # Now we see that the tree is in the population: -@test tree in [p.tree for p in population1.members] +@test tree in [get_tree(p.tree) for p in population1.members] diff --git a/test/test_mixed.jl b/test/test_mixed.jl index 1da2b1a3e..a2d047166 100644 --- a/test/test_mixed.jl +++ b/test/test_mixed.jl @@ -1,150 +1,39 @@ -using SymbolicRegression -using SymbolicRegression: string_tree -using Random, Bumper, LoopVectorization -include("test_params.jl") +@testitem "Search with batching & weighted & serial & progress bar & warmup & BFGS" tags = [ + :part1 +] begin + include("test_mixed_utils.jl") + test_mixed(0, true, true, :serial) +end -for i in 0:5 - local options, X, y, tree - batching = i in [0, 1] - weighted = i in [0, 2] +@testitem "Search with multiprocessing & batching & multi-output & use_frequency & string-specified parallelism" tags = [ + :part2 +] begin + include("test_mixed_utils.jl") + test_mixed(1, true, false, :multiprocessing) +end - numprocs = 2 - progress = false - warmup_maxsize_by = 0.0f0 - optimizer_algorithm = "NelderMead" - multi = false - tournament_selection_p = 1.0 - parallelism = :multiprocessing - crossover_probability = 0.0f0 - skip_mutation_failures = false - use_frequency = false - use_frequency_in_tournament = false - turbo = false - bumper = false - T = Float32 - print("Testing with batching=$(batching) and weighted=$(weighted), ") - if i == 0 - println("with serial & progress bar & warmup & BFGS") - numprocs = nothing #Try serial computation here. - parallelism = :serial - progress = true #Also try the progress bar. 
- warmup_maxsize_by = 0.5f0 #Smaller maxsize at first, build up slowly - optimizer_algorithm = "BFGS" - tournament_selection_p = 0.8 - elseif i == 1 - println("with multi-output and use_frequency and string-specified parallelism.") - multi = true - use_frequency = true - parallelism = "multiprocessing" - elseif i == 3 - println( - "with multi-threading and crossover and use_frequency_in_tournament and bumper=true", - ) - parallelism = :multithreading - numprocs = nothing - crossover_probability = 0.02f0 - use_frequency_in_tournament = true - bumper = true - elseif i == 4 - println( - "with crossover and skip mutation failures and both frequencies options, and Float16 type", - ) - crossover_probability = 0.02f0 - skip_mutation_failures = true - use_frequency = true - use_frequency_in_tournament = true - T = Float16 - elseif i == 5 - println("with default hyperparameters, Float64 type, and turbo=true") - T = Float64 - turbo = true - end - if i == 5 - options = SymbolicRegression.Options(; - unary_operators=(cos,), - batching=batching, - parsimony=0.0f0, # Required for scoring - ) - else - options = SymbolicRegression.Options(; - default_params..., - binary_operators=(+, *), - unary_operators=(cos,), - populations=4, - batching=batching, - crossover_probability=crossover_probability, - skip_mutation_failures=skip_mutation_failures, - seed=0, - progress=progress, - warmup_maxsize_by=warmup_maxsize_by, - optimizer_algorithm=optimizer_algorithm, - tournament_selection_p=tournament_selection_p, - parsimony=0.0f0, - use_frequency=use_frequency, - use_frequency_in_tournament=use_frequency_in_tournament, - turbo=turbo, - bumper=bumper, - ) - end +@testitem "Search with multi-threading & default settings" tags = [:part3] begin + include("test_mixed_utils.jl") + test_mixed(2, false, true, :multithreading) +end - X = randn(MersenneTwister(0), T, 5, 100) - if weighted - mask = rand(100) .> 0.5 - weights = map(x -> convert(T, x), mask) - # Completely different function superimposed - need - # to use correct weights to figure it out! 
-        y = (2 .* cos.(X[4, :])) .* weights .+ (1 .- weights) .* (5 .* X[2, :])
-        hallOfFame = equation_search(
-            X,
-            y;
-            weights=weights,
-            niterations=2,
-            options=options,
-            parallelism=parallelism,
-            numprocs=numprocs,
-        )
-        dominating = [calculate_pareto_frontier(hallOfFame)]
-    else
-        y = 2 * cos.(X[4, :])
-        niterations = 2
-        if multi
-            # Copy the same output twice; make sure we can find it twice
-            y = repeat(y, 1, 2)
-            y = transpose(y)
-            niterations = 20
-        end
-        hallOfFame = equation_search(
-            X,
-            y;
-            niterations=niterations,
-            options=options,
-            parallelism=parallelism,
-            numprocs=numprocs,
-        )
-        dominating = if multi
-            [calculate_pareto_frontier(hallOfFame[j]) for j in 1:2]
-        else
-            [calculate_pareto_frontier(hallOfFame)]
-        end
-    end
+@testitem "Search with multi-threading & crossover & use_frequency_in_tournament & bumper" tags = [
+    :part1
+] begin
+    include("test_mixed_utils.jl")
+    test_mixed(3, false, false, :multithreading)
+end
-    # For brevity, always assume multi-output in this test:
-    for dom in dominating
-        @test length(dom) > 0
-        best = dom[end]
-        # Assert we created the correct type of trees:
-        @test typeof(best.tree) == Node{T}
+@testitem "Search with multi-threading & crossover & skip mutation failures & both frequencies options & Float16 type" tags = [
+    :part2
+] begin
+    include("test_mixed_utils.jl")
+    test_mixed(4, false, false, :multithreading)
+end
-        # Test the score
-        @test best.loss < maximum_residual
-        # Test the actual equation found:
-        testX = randn(MersenneTwister(1), T, 5, 100)
-        true_y = 2 * cos.(testX[4, :])
-        predicted_y, flag = eval_tree_array(best.tree, testX, options)
-        @test flag
-        @test sum(abs, true_y .- predicted_y) < maximum_residual
-        # eval evaluates inside global
-    end
-
-    println("Passed.")
-end # for i=1...
+@testitem "Search with multiprocessing & default hyperparameters & Float64 type & turbo" tags = [
+    :part3
+] begin
+    include("test_mixed_utils.jl")
+    test_mixed(5, false, false, :multiprocessing)
+end
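Note: the `:part1`/`:part2`/`:part3` tags on the `@testitem` blocks above are what allow the CI matrix to shard the suite into three jobs. A minimal sketch of a runner that consumes these tags, assuming TestItemRunner.jl; the `TEST_PART` environment variable here is hypothetical:

    # test/runtests.jl (sketch): run only the test items tagged for this shard
    using TestItemRunner

    test_part = Symbol(get(ENV, "TEST_PART", "part1"))
    @run_package_tests filter = ti -> (test_part in ti.tags)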
diff --git a/test/test_mixed_utils.jl b/test/test_mixed_utils.jl
new file mode 100644
index 000000000..8021e5295
--- /dev/null
+++ b/test/test_mixed_utils.jl
@@ -0,0 +1,142 @@
+using LibraryAugmentedSymbolicRegression
+using LibraryAugmentedSymbolicRegression: string_tree
+using Random, Bumper, LoopVectorization
+
+include("test_params.jl")
+
+function test_mixed(i, batching::Bool, weighted::Bool, parallelism)
+    progress = false
+    warmup_maxsize_by = 0.0f0
+    optimizer_algorithm = "NelderMead"
+    multi = false
+    tournament_selection_p = 1.0
+    crossover_probability = 0.0f0
+    skip_mutation_failures = false
+    use_frequency = false
+    use_frequency_in_tournament = false
+    turbo = false
+    bumper = false
+    T = Float32
+
+    if i == 0
+        progress = true #Also try the progress bar.
+        warmup_maxsize_by = 0.5f0 #Smaller maxsize at first, build up slowly
+        optimizer_algorithm = "BFGS"
+        tournament_selection_p = 0.8
+    elseif i == 1
+        multi = true
+        use_frequency = true
+    elseif i == 3
+        crossover_probability = 0.02f0
+        use_frequency_in_tournament = true
+        bumper = true
+    elseif i == 4
+        crossover_probability = 0.02f0
+        skip_mutation_failures = true
+        use_frequency = true
+        use_frequency_in_tournament = true
+        T = Float16
+    elseif i == 5
+        T = Float64
+        turbo = true
+    end
+
+    numprocs = parallelism == :multiprocessing ? 2 : nothing
+
+    options = if i == 5
+        LibraryAugmentedSymbolicRegression.Options(;
+            unary_operators=(cos,),
+            batching=batching,
+            parsimony=0.0f0, # Required for scoring
+            early_stop_condition=1e-6,
+        )
+    else
+        LibraryAugmentedSymbolicRegression.Options(;
+            default_params...,
+            binary_operators=(+, *),
+            unary_operators=(cos,),
+            populations=4,
+            batching=batching,
+            crossover_probability=crossover_probability,
+            skip_mutation_failures=skip_mutation_failures,
+            seed=0,
+            progress=progress,
+            warmup_maxsize_by=warmup_maxsize_by,
+            optimizer_algorithm=optimizer_algorithm,
+            tournament_selection_p=tournament_selection_p,
+            parsimony=0.0f0,
+            use_frequency=use_frequency,
+            use_frequency_in_tournament=use_frequency_in_tournament,
+            turbo=turbo,
+            bumper=bumper,
+            early_stop_condition=1e-6,
+        )
+    end
+
+    X = randn(MersenneTwister(0), T, 5, 100)
+
+    (y, hallOfFame, dominating) = if weighted
+        mask = rand(100) .> 0.5
+        weights = map(x -> convert(T, x), mask)
+        # Completely different function superimposed - need
+        # to use correct weights to figure it out!
+        y = (2 .* cos.(X[4, :])) .* weights .+ (1 .- weights) .* (5 .* X[2, :])
+        hallOfFame = equation_search(
+            X,
+            y;
+            weights=weights,
+            niterations=2,
+            options=options,
+            parallelism=parallelism,
+            numprocs=numprocs,
+        )
+        dominating = [calculate_pareto_frontier(hallOfFame)]
+
+        (y, hallOfFame, dominating)
+    else
+        y = 2 * cos.(X[4, :])
+        niterations = 2
+        if multi
+            # Copy the same output twice; make sure we can find it twice
+            y = repeat(y, 1, 2)
+            y = transpose(y)
+            niterations = 20
+        end
+        hallOfFame = equation_search(
+            X,
+            y;
+            niterations=niterations,
+            options=options,
+            parallelism=parallelism,
+            numprocs=numprocs,
+        )
+        dominating = if multi
+            [calculate_pareto_frontier(hallOfFame[j]) for j in 1:2]
+        else
+            [calculate_pareto_frontier(hallOfFame)]
+        end
+
+        (y, hallOfFame, dominating)
+    end
+
+    # For brevity, always assume multi-output in this test:
+    for dom in dominating
+        @test length(dom) > 0
+        best = dom[end]
+        # Assert we created the correct type of trees:
+        @test node_type(typeof(best.tree)) == Node{T}
+
+        # Test the score
+        @test best.loss < maximum_residual
+        # Test the actual equation found:
+        testX = randn(MersenneTwister(1), T, 5, 100)
+        true_y = 2 * cos.(testX[4, :])
+        predicted_y, flag = eval_tree_array(best.tree, testX, options)
+
+        @test flag
+        @test sum(abs, true_y .- predicted_y) < maximum_residual
+        # eval evaluates inside global
+    end
+
+    return println("Passed.")
+end
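The helper above uses the scalar form `early_stop_condition=1e-6`, which stops the search once any expression's loss drops below that threshold. `Options` also accepts a two-argument predicate over loss and complexity, as the MLJ and unit tests elsewhere in this diff do. A short sketch contrasting the two forms; the option names are taken from this diff, the variable names are mine:

    using LibraryAugmentedSymbolicRegression

    # Scalar shorthand: stop as soon as loss < 1e-6, at any complexity.
    opts_loss_only = Options(; unary_operators=(cos,), early_stop_condition=1e-6)

    # Predicate form: additionally require the winning expression to stay simple.
    opts_strict = Options(;
        binary_operators=(+, *),
        unary_operators=(cos,),
        early_stop_condition=(loss, complexity) -> loss < 1e-6 && complexity <= 10,
    )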
diff --git a/test/test_mlj.jl b/test/test_mlj.jl
index ca0416209..62f5d6ea8 100644
--- a/test/test_mlj.jl
+++ b/test/test_mlj.jl
@@ -1,24 +1,10 @@
-using SymbolicRegression: SymbolicRegression
-using SymbolicRegression:
-    Node, SRRegressor, MultitargetSRRegressor, node_to_symbolic, symbolic_to_node
-using MLJTestInterface: MLJTestInterface as MTI
-using MLJBase: machine, fit!, report, predict
-using SymbolicUtils: SymbolicUtils
-using Suppressor: @capture_err
-
-macro quiet(ex)
-    return quote
-        redirect_stderr(devnull) do
-            $ex
-        end
-    end |> esc
-end
-
-stop_kws = (; early_stop_condition=(loss, complexity) -> loss < 1e-7)
+@testitem "Generic interface tests" tags = [:part1] begin
+    using LibraryAugmentedSymbolicRegression
+    using MLJTestInterface: MLJTestInterface as MTI
+    include("test_params.jl")
-@testset "Generic interface tests" begin
     failures, summary = MTI.test(
-        [SRRegressor], MTI.make_regression()...; mod=@__MODULE__, verbosity=0, throw=true
+        [LaSRRegressor], MTI.make_regression()...; mod=@__MODULE__, verbosity=0, throw=true
     )
     @test isempty(failures)
@@ -27,126 +13,170 @@ stop_kws = (; early_stop_condition=(loss, complexity) -> loss < 1e-7)
     (X, Y) = MTI.table.((X, Y))
     w = ones(100)
     failures, summary = MTI.test(
-        [MultitargetSRRegressor], X, Y, w; mod=@__MODULE__, verbosity=0, throw=true
+        [MultitargetLaSRRegressor], X, Y, w; mod=@__MODULE__, verbosity=0, throw=true
     )
     @test isempty(failures)
 end
-@testset "Variable names" begin
-    @testset "Single outputs" begin
-        X = (a=rand(32), b=rand(32))
-        y = X.a .^ 2.1
-        # We also make sure the deprecated npop and npopulations still work:
-        model = SRRegressor(; niterations=10, npop=33, npopulations=15, stop_kws...)
-        mach = machine(model, X, y)
-        fit!(mach)
-        rep = report(mach)
-        @test occursin("a", rep.equation_strings[rep.best_idx])
-        ypred_good = predict(mach, X)
-        @test sum(abs2, predict(mach, X) .- y) / length(y) < 1e-5
-
-        @testset "Check that we can choose the equation" begin
-            ypred_same = predict(mach, (data=X, idx=rep.best_idx))
-            @test ypred_good == ypred_same
-
-            ypred_bad = predict(mach, (data=X, idx=1))
-            @test ypred_good != ypred_bad
-        end
-
-        @testset "Smoke test SymbolicUtils" begin
-            eqn = node_to_symbolic(rep.equations[rep.best_idx], model)
-            n = symbolic_to_node(eqn, model)
-            eqn2 = convert(SymbolicUtils.Symbolic, n, model)
-            n2 = convert(Node, eqn2, model)
-        end
-    end
+@testitem "Variable names - single outputs" tags = [:part3] begin
+    using LibraryAugmentedSymbolicRegression
+    using LibraryAugmentedSymbolicRegression: Node
+    using MLJBase
+    using SymbolicUtils
+    using Random: MersenneTwister
-    @testset "Multiple outputs" begin
-        X = (a=rand(32), b=rand(32))
-        y = X.a .^ 2.1
-        model = MultitargetSRRegressor(; niterations=10, stop_kws...)
-        mach = machine(model, X, reduce(hcat, [reshape(y, :, 1) for i in 1:3]))
-        fit!(mach)
-        rep = report(mach)
-        @test all(
-            eq -> occursin("a", eq), [rep.equation_strings[i][rep.best_idx[i]] for i in 1:3]
-        )
-        ypred_good = predict(mach, X)
-
-        @testset "Test that we can choose the equation" begin
-            ypred_same = predict(mach, (data=X, idx=rep.best_idx))
-            @test ypred_good == ypred_same
-
-            ypred_bad = predict(mach, (data=X, idx=[1, 1, 1]))
-            @test ypred_good != ypred_bad
-
-            ypred_mixed = predict(mach, (data=X, idx=[rep.best_idx[1], 1, rep.best_idx[3]]))
-            @test ypred_mixed == hcat(ypred_good[:, 1], ypred_bad[:, 2], ypred_good[:, 3])
-
-            @test_throws AssertionError predict(mach, (data=X,))
-            VERSION >= v"1.8" &&
-                @test_throws "If specifying an equation index during" predict(
-                    mach, (data=X,)
-                )
-            VERSION >= v"1.8" &&
-                @test_throws "If specifying an equation index during" predict(
-                    mach, (X=X, idx=1)
-                )
-        end
-    end
+    include("test_params.jl")
-    @testset "Named outputs" begin
-        X = (b1=randn(32), b2=randn(32))
-        Y = (c1=X.b1 .* X.b2, c2=X.b1 .+ X.b2)
-        w = ones(32)
-        model = MultitargetSRRegressor(; niterations=10, stop_kws...)
-        mach = machine(model, X, Y, w)
-        fit!(mach)
-        test_outs = predict(mach, X)
-        @test isempty(setdiff((:c1, :c2), keys(test_outs)))
-        @test_throws AssertionError predict(mach, (a1=randn(32), b2=randn(32)))
-        VERSION >= v"1.8" && @test_throws "Variable names do not match fitted" predict(
-            mach, (b1=randn(32), a2=randn(32))
-        )
-    end
+    stop_kws = (; early_stop_condition=(loss, complexity) -> loss < 1e-5)
+
+    rng = MersenneTwister(0)
+    X = (a=rand(rng, 32), b=rand(rng, 32))
+    y = X.a .^ 2.1
+    # We also make sure the deprecated npop and npopulations still work:
+    model = LaSRRegressor(; niterations=10, npop=1000, npopulations=15, stop_kws...)
+ mach = machine(model, X, y) + fit!(mach) + rep = report(mach) + @test occursin("a", rep.equation_strings[rep.best_idx]) + ypred_good = predict(mach, X) + @test sum(abs2, predict(mach, X) .- y) / length(y) < 1e-5 + + # Check that we can choose the equation + ypred_same = predict(mach, (data=X, idx=rep.best_idx)) + @test ypred_good == ypred_same + + ypred_bad = predict(mach, (data=X, idx=1)) + @test ypred_good != ypred_bad + + # Smoke test SymbolicUtils + eqn = node_to_symbolic(rep.equations[rep.best_idx], model) + n = symbolic_to_node(eqn, model) + eqn2 = convert(SymbolicUtils.Symbolic, n, model) + n2 = convert(Node, eqn2, model) end -@testset "Good predictions" begin - X = randn(100, 3) +@testitem "Variable names - multiple outputs" tags = [:part1] begin + using LibraryAugmentedSymbolicRegression + using MLJBase + using Random: MersenneTwister + + include("test_params.jl") + + stop_kws = (; early_stop_condition=(loss, complexity) -> loss < 1e-5) + + rng = MersenneTwister(0) + X = (a=rand(rng, 32), b=rand(rng, 32)) + y = X.a .^ 2.1 + model = MultitargetLaSRRegressor(; niterations=10, stop_kws...) + mach = machine(model, X, reduce(hcat, [reshape(y, :, 1) for i in 1:3])) + fit!(mach) + rep = report(mach) + @test all( + eq -> occursin("a", eq), [rep.equation_strings[i][rep.best_idx[i]] for i in 1:3] + ) + ypred_good = predict(mach, X) + + # Test that we can choose the equation + ypred_same = predict(mach, (data=X, idx=rep.best_idx)) + @test ypred_good == ypred_same + + ypred_bad = predict(mach, (data=X, idx=[1, 1, 1])) + @test ypred_good != ypred_bad + + ypred_mixed = predict(mach, (data=X, idx=[rep.best_idx[1], 1, rep.best_idx[3]])) + @test ypred_mixed == hcat(ypred_good[:, 1], ypred_bad[:, 2], ypred_good[:, 3]) + + @test_throws AssertionError predict(mach, (data=X,)) + VERSION >= v"1.8" && + @test_throws "If specifying an equation index during" predict(mach, (data=X,)) + VERSION >= v"1.8" && + @test_throws "If specifying an equation index during" predict(mach, (X=X, idx=1)) +end + +@testitem "Variable names - named outputs" tags = [:part1] begin + using LibraryAugmentedSymbolicRegression + using MLJBase + using Random: MersenneTwister + + include("test_params.jl") + + stop_kws = (; early_stop_condition=(loss, complexity) -> loss < 1e-5) + + rng = MersenneTwister(0) + X = (b1=randn(rng, 32), b2=randn(rng, 32)) + Y = (c1=X.b1 .* X.b2, c2=X.b1 .+ X.b2) + w = ones(32) + model = MultitargetLaSRRegressor(; niterations=10, stop_kws...) + mach = machine(model, X, Y, w) + fit!(mach) + test_outs = predict(mach, X) + @test isempty(setdiff((:c1, :c2), keys(test_outs))) + @test_throws AssertionError predict(mach, (a1=randn(32), b2=randn(32))) + VERSION >= v"1.8" && @test_throws "Variable names do not match fitted" predict( + mach, (b1=randn(32), a2=randn(32)) + ) +end + +@testitem "Good predictions" tags = [:part1] begin + using LibraryAugmentedSymbolicRegression + using MLJBase + using Random: MersenneTwister + + include("test_params.jl") + + stop_kws = (; early_stop_condition=(loss, complexity) -> loss < 1e-5) + + rng = MersenneTwister(0) + X = randn(rng, 100, 3) Y = X - model = MultitargetSRRegressor(; niterations=10, stop_kws...) + model = MultitargetLaSRRegressor(; niterations=10, stop_kws...) 
mach = machine(model, X, Y) fit!(mach) @test sum(abs2, predict(mach, X) .- Y) / length(X) < 1e-6 end -@testset "Helpful errors" begin - model = MultitargetSRRegressor() - mach = machine(model, randn(32, 3), randn(32); scitype_check_level=0) +@testitem "Helpful errors" tags = [:part3] begin + using LibraryAugmentedSymbolicRegression + using MLJBase + using Random: MersenneTwister + + include("test_params.jl") + + model = MultitargetLaSRRegressor() + rng = MersenneTwister(0) + mach = machine(model, randn(rng, 32, 3), randn(rng, 32); scitype_check_level=0) @test_throws AssertionError @quiet(fit!(mach)) VERSION >= v"1.8" && @test_throws "For single-output regression, please" @quiet(fit!(mach)) - model = SRRegressor() - mach = machine(model, randn(32, 3), randn(32, 2); scitype_check_level=0) + model = LaSRRegressor() + rng = MersenneTwister(0) + mach = machine(model, randn(rng, 32, 3), randn(rng, 32, 2); scitype_check_level=0) @test_throws AssertionError @quiet(fit!(mach)) VERSION >= v"1.8" && @test_throws "For multi-output regression, please" @quiet(fit!(mach)) - model = SRRegressor(; verbosity=0) - mach = machine(model, randn(32, 3), randn(32)) + model = LaSRRegressor(; verbosity=0) + rng = MersenneTwister(0) + mach = machine(model, randn(rng, 32, 3), randn(rng, 32)) @test_throws ErrorException @quiet(fit!(mach; verbosity=0)) end -@testset "Unfinished search" begin - model = SRRegressor(; timeout_in_seconds=1e-10) - mach = machine(model, randn(32, 3), randn(32)) +@testitem "Unfinished search" tags = [:part3] begin + using LibraryAugmentedSymbolicRegression + using MLJBase + using Suppressor + using Random: MersenneTwister + + model = LaSRRegressor(; timeout_in_seconds=1e-10) + rng = MersenneTwister(0) + mach = machine(model, randn(rng, 32, 3), randn(rng, 32)) fit!(mach) # Ensure that the hall of fame is empty: _, hof = mach.fitresult.state hof.exists .= false # Recompute the report: - mach.report[:fit] = SymbolicRegression.MLJInterfaceModule.full_report( + mach.report[:fit] = LibraryAugmentedSymbolicRegression.MLJInterfaceModule.full_report( model, mach.fitresult ) @test report(mach).best_idx == 0 @@ -156,15 +186,16 @@ end end @test occursin("Evaluation failed either due to", msg) - model = MultitargetSRRegressor(; timeout_in_seconds=1e-10) - mach = machine(model, randn(32, 3), randn(32, 3)) + model = MultitargetLaSRRegressor(; timeout_in_seconds=1e-10) + rng = MersenneTwister(0) + mach = machine(model, randn(rng, 32, 3), randn(rng, 32, 3)) fit!(mach) # Ensure that the hall of fame is empty: _, hofs = mach.fitresult.state foreach(hofs) do hof hof.exists .= false end - mach.report[:fit] = SymbolicRegression.MLJInterfaceModule.full_report( + mach.report[:fit] = LibraryAugmentedSymbolicRegression.MLJInterfaceModule.full_report( model, mach.fitresult ) @test report(mach).best_idx == [0, 0, 0] diff --git a/test/test_nan_detection.jl b/test/test_nan_detection.jl index fb2a21d1f..c1a104d1a 100644 --- a/test/test_nan_detection.jl +++ b/test/test_nan_detection.jl @@ -1,5 +1,6 @@ println("Testing NaN detection.") -using SymbolicRegression +using LibraryAugmentedSymbolicRegression +using LoopVectorization for T in [Float16, Float32, Float64], turbo in [true, false] T == Float16 && turbo && continue diff --git a/test/test_nested_constraints.jl b/test/test_nested_constraints.jl index 8d3932c01..59e89863a 100644 --- a/test/test_nested_constraints.jl +++ b/test/test_nested_constraints.jl @@ -1,5 +1,5 @@ println("Test operator nesting and flagging.") -using SymbolicRegression +using 
LibraryAugmentedSymbolicRegression function create_options(nested_constraints) return Options(; @@ -14,19 +14,19 @@ options = create_options(nothing) tree = cos(exp(exp(exp(exp(Node("x1")))))) degree_of_exp = 1 index_of_exp = findfirst(isequal(exp), options.operators.unaops) -@test 4 == SymbolicRegression.CheckConstraintsModule.count_max_nestedness( +@test 4 == LibraryAugmentedSymbolicRegression.CheckConstraintsModule.count_max_nestedness( tree, degree_of_exp, index_of_exp ) tree = cos(exp(Node("x1")) + exp(exp(exp(exp(Node("x1")))))) -@test 4 == SymbolicRegression.CheckConstraintsModule.count_max_nestedness( +@test 4 == LibraryAugmentedSymbolicRegression.CheckConstraintsModule.count_max_nestedness( tree, degree_of_exp, index_of_exp ) degree_of_plus = 2 index_of_plus = findfirst(isequal(+), options.operators.binops) tree = cos(exp(Node("x1")) + exp(exp(Node("x1") + exp(exp(exp(Node("x1"))))))) -@test 2 == SymbolicRegression.CheckConstraintsModule.count_max_nestedness( +@test 2 == LibraryAugmentedSymbolicRegression.CheckConstraintsModule.count_max_nestedness( tree, degree_of_plus, index_of_plus ) @@ -34,21 +34,21 @@ tree = cos(exp(Node("x1")) + exp(exp(Node("x1") + exp(exp(exp(Node("x1"))))))) x1 = Node("x1") options = create_options(nothing) tree = cos(cos(x1)) + cos(x1) + exp(cos(x1)) -@test !SymbolicRegression.CheckConstraintsModule.flag_illegal_nests(tree, options) +@test !LibraryAugmentedSymbolicRegression.CheckConstraintsModule.flag_illegal_nests(tree, options) options = create_options([cos => [cos => 0]]) -@test SymbolicRegression.CheckConstraintsModule.flag_illegal_nests(tree, options) +@test LibraryAugmentedSymbolicRegression.CheckConstraintsModule.flag_illegal_nests(tree, options) options = create_options([cos => [cos => 1]]) -@test !SymbolicRegression.CheckConstraintsModule.flag_illegal_nests(tree, options) +@test !LibraryAugmentedSymbolicRegression.CheckConstraintsModule.flag_illegal_nests(tree, options) options = create_options([cos => [exp => 0]]) -@test !SymbolicRegression.CheckConstraintsModule.flag_illegal_nests(tree, options) +@test !LibraryAugmentedSymbolicRegression.CheckConstraintsModule.flag_illegal_nests(tree, options) options = create_options([exp => [cos => 0]]) -@test SymbolicRegression.CheckConstraintsModule.flag_illegal_nests(tree, options) +@test LibraryAugmentedSymbolicRegression.CheckConstraintsModule.flag_illegal_nests(tree, options) options = create_options([(+) => [(+) => 0]]) -@test SymbolicRegression.CheckConstraintsModule.flag_illegal_nests(tree, options) +@test LibraryAugmentedSymbolicRegression.CheckConstraintsModule.flag_illegal_nests(tree, options) println("Passed.") diff --git a/test/test_operators.jl b/test/test_operators.jl index 47b83418a..1d3bf614e 100644 --- a/test/test_operators.jl +++ b/test/test_operators.jl @@ -1,5 +1,5 @@ -using SymbolicRegression -using SymbolicRegression: +using LibraryAugmentedSymbolicRegression +using LibraryAugmentedSymbolicRegression: plus, sub, mult, @@ -79,7 +79,7 @@ end ], ) for T in types_to_test - @test_nowarn SymbolicRegression.assert_operators_well_defined(T, options) + @test_nowarn LibraryAugmentedSymbolicRegression.assert_operators_well_defined(T, options) end end @@ -90,17 +90,17 @@ end unary_operators=[square, cube, log, log2, log10, log1p, sqrt, acosh, neg], ) for T in types_to_test - @test_nowarn SymbolicRegression.assert_operators_well_defined(T, options) + @test_nowarn LibraryAugmentedSymbolicRegression.assert_operators_well_defined(T, options) end end @testset "Test incompatibilities are caught" begin 
options = Options(; binary_operators=[greater]) - @test_throws ErrorException SymbolicRegression.assert_operators_well_defined( + @test_throws ErrorException LibraryAugmentedSymbolicRegression.assert_operators_well_defined( ComplexF64, options ) VERSION >= v"1.8" && - @test_throws "complex plane" SymbolicRegression.assert_operators_well_defined( + @test_throws "complex plane" LibraryAugmentedSymbolicRegression.assert_operators_well_defined( ComplexF64, options ) end @@ -108,14 +108,14 @@ end @testset "Operators which return the wrong type should fail" begin my_bad_op(x) = 1.0f0 options = Options(; binary_operators=[], unary_operators=[my_bad_op]) - @test_throws ErrorException SymbolicRegression.assert_operators_well_defined( + @test_throws ErrorException LibraryAugmentedSymbolicRegression.assert_operators_well_defined( Float64, options ) VERSION >= v"1.8" && - @test_throws "returned an output of type" SymbolicRegression.assert_operators_well_defined( + @test_throws "returned an output of type" LibraryAugmentedSymbolicRegression.assert_operators_well_defined( Float64, options ) - @test_nowarn SymbolicRegression.assert_operators_well_defined(Float32, options) + @test_nowarn LibraryAugmentedSymbolicRegression.assert_operators_well_defined(Float32, options) end @testset "Turbo mode should be the same" begin diff --git a/test/test_optimizer_mutation.jl b/test/test_optimizer_mutation.jl index 41f3bd67d..a5bad0e3c 100644 --- a/test/test_optimizer_mutation.jl +++ b/test/test_optimizer_mutation.jl @@ -1,9 +1,9 @@ -using SymbolicRegression -using SymbolicRegression: SymbolicRegression -using SymbolicRegression: Dataset, RunningSearchStatistics, RecordType +using LibraryAugmentedSymbolicRegression +using LibraryAugmentedSymbolicRegression: LibraryAugmentedSymbolicRegression +using LibraryAugmentedSymbolicRegression: Dataset, RunningSearchStatistics, RecordType using Optim: Optim -using SymbolicRegression.MutateModule: next_generation -using DynamicExpressions: get_constants +using LibraryAugmentedSymbolicRegression.MutateModule: next_generation +using DynamicExpressions: get_scalar_constants mutation_weights = (; optimize=1e30) # We also test whether a named tuple works. options = Options(; @@ -35,7 +35,7 @@ new_member, _, _ = next_generation( tmp_recorder=RecordType(), ) -resultant_constants = get_constants(new_member.tree) +resultant_constants, refs = get_scalar_constants(new_member.tree) for k in [0.0, 0.2, 0.5, 1.0] @test sin(resultant_constants[1] * k + resultant_constants[2]) โ‰ˆ sin(2.1 * k + 0.8) atol = 1e-3 diff --git a/test/test_options.jl b/test/test_options.jl index 9c7bc0d99..c60f94c15 100644 --- a/test/test_options.jl +++ b/test/test_options.jl @@ -1,4 +1,4 @@ -using SymbolicRegression +using LibraryAugmentedSymbolicRegression using Optim: Optim # testing types diff --git a/test/test_params.jl b/test/test_params.jl index 7aafd7f48..a9c4871bd 100644 --- a/test/test_params.jl +++ b/test/test_params.jl @@ -1,11 +1,15 @@ -using SymbolicRegression: L2DistLoss, MutationWeights +using LibraryAugmentedSymbolicRegression: L2DistLoss, MutationWeights +using DynamicExpressions.OperatorEnumConstructionModule: empty_all_globals! 
using Optim: Optim using LineSearches: LineSearches using Test: Test ENV["SYMBOLIC_REGRESSION_IS_TESTING"] = "true" +empty_all_globals!() + const maximum_residual = 1e-2 + if !@isdefined(custom_cos) || !hasmethod(custom_cos, (String,)) @eval custom_cos(x) = cos(x) end @@ -69,3 +73,11 @@ const default_params = ( test_info(_, x) = error("Test failed: $x") test_info(_, ::Test.Pass) = nothing test_info(f::F, ::Test.Fail) where {F} = f() + +macro quiet(ex) + return quote + redirect_stderr(devnull) do + $ex + end + end |> esc +end diff --git a/test/test_precompilation.jl b/test/test_precompilation.jl index 6fb4cf572..89ed9f1b4 100644 --- a/test/test_precompilation.jl +++ b/test/test_precompilation.jl @@ -1,3 +1,3 @@ -using SymbolicRegression +using LibraryAugmentedSymbolicRegression -SymbolicRegression.do_precompilation(Val(:compile)) +LibraryAugmentedSymbolicRegression.do_precompilation(Val(:compile)) diff --git a/test/test_pretty_printing.jl b/test/test_pretty_printing.jl new file mode 100644 index 000000000..3d7f334a9 --- /dev/null +++ b/test/test_pretty_printing.jl @@ -0,0 +1,111 @@ +@testitem "pretty print member" tags = [:part3] begin + using LibraryAugmentedSymbolicRegression + + options = Options(; binary_operators=[+, ^]) + + ex = @parse_expression(x^2.0 + 1.5, binary_operators = [+, ^], variable_names = [:x]) + shower(x) = sprint((io, e) -> show(io, MIME"text/plain"(), e), x) + s = shower(ex) + @test s == "(x ^ 2.0) + 1.5" + + X = [1.0 2.0 3.0] + y = [2.0, 3.0, 4.0] + dataset = Dataset(X, y) + member = PopMember(dataset, ex, options; deterministic=false) + member.score = 1.0 + @test member isa PopMember{Float64,Float64,<:Expression{Float64,Node{Float64}}} + s_member = shower(member) + @test s_member == "PopMember(tree = ((x ^ 2.0) + 1.5), loss = 16.25, score = 1.0)" + + # New options shouldn't change this + options = Options(; binary_operators=[-, /]) + s_member = shower(member) + @test s_member == "PopMember(tree = ((x ^ 2.0) + 1.5), loss = 16.25, score = 1.0)" +end + +@testitem "pretty print hall of fame" tags = [:part1] begin + using LibraryAugmentedSymbolicRegression + using LibraryAugmentedSymbolicRegression: embed_metadata + using LibraryAugmentedSymbolicRegression.CoreModule: safe_pow + + options = Options(; binary_operators=[+, safe_pow], maxsize=7) + + ex = @parse_expression( + $safe_pow(x, 2.0) + 1.5, binary_operators = [+, safe_pow], variable_names = [:x] + ) + shower(x) = sprint((io, e) -> show(io, MIME"text/plain"(), e), x) + s = shower(ex) + @test s == "(x ^ 2.0) + 1.5" + + X = [1.0 2.0 3.0] + y = [2.0, 3.0, 4.0] + dataset = Dataset(X, y) + member = PopMember(dataset, ex, options; deterministic=false) + member.score = 1.0 + @test member isa PopMember{Float64,Float64,<:Expression{Float64,Node{Float64}}} + + hof = HallOfFame(options, dataset) + hof = embed_metadata(hof, options, dataset) + hof.members[5] = member + hof.exists[5] = true + s_hof = strip(shower(hof)) + true_s = "HallOfFame{...}: + .exists[1] = false + .members[1] = undef + .exists[2] = false + .members[2] = undef + .exists[3] = false + .members[3] = undef + .exists[4] = false + .members[4] = undef + .exists[5] = true + .members[5] = PopMember(tree = ((x ^ 2.0) + 1.5), loss = 16.25, score = 1.0) + .exists[6] = false + .members[6] = undef + .exists[7] = false + .members[7] = undef + .exists[8] = false + .members[8] = undef + .exists[9] = false + .members[9] = undef" + + @test s_hof == true_s +end + +@testitem "pretty print expression" tags = [:part2] begin + using LibraryAugmentedSymbolicRegression + using 
Suppressor: @capture_out + + options = Options(; binary_operators=[+, -, *, /], unary_operators=[cos]) + ex = @parse_expression( + cos(x) + y * y, operators = options.operators, variable_names = [:x, :y] + ) + + s = sprint((io, ex) -> print_tree(io, ex, options), ex) + @test strip(s) == "cos(x) + (y * y)" + + s = @capture_out begin + print_tree(ex, options) + end + @test strip(s) == "cos(x) + (y * y)" + + # Works with the tree itself too + s = @capture_out begin + print_tree(get_tree(ex), options) + end + @test strip(s) == "cos(x1) + (x2 * x2)" + s = sprint((io, ex) -> print_tree(io, ex, options), get_tree(ex)) + @test strip(s) == "cos(x1) + (x2 * x2)" + + # Updating options won't change printout, UNLESS + # we pass the options. + options = Options(; binary_operators=[/, *, -, +], unary_operators=[sin]) + + s = @capture_out begin + print_tree(ex) + end + @test strip(s) == "cos(x) + (y * y)" + + s = sprint((io, ex) -> print_tree(io, ex, options), ex) + @test strip(s) == "sin(x) / (y - y)" +end diff --git a/test/test_print.jl b/test/test_print.jl index fff027504..2c3b34b29 100644 --- a/test/test_print.jl +++ b/test/test_print.jl @@ -1,5 +1,5 @@ -using SymbolicRegression -using SymbolicRegression.UtilsModule: split_string +using LibraryAugmentedSymbolicRegression +using LibraryAugmentedSymbolicRegression.UtilsModule: split_string include("test_params.jl") diff --git a/test/test_prob_pick_first.jl b/test/test_prob_pick_first.jl index 8967f0e04..dbd5cc1cc 100644 --- a/test/test_prob_pick_first.jl +++ b/test/test_prob_pick_first.jl @@ -1,6 +1,7 @@ println("Testing whether tournament_selection_p works.") -using SymbolicRegression -using DynamicExpressions.EquationModule: with_type_parameters +using LibraryAugmentedSymbolicRegression +using DynamicExpressions: with_type_parameters, @parse_expression +using Test include("test_params.jl") n = 10 @@ -15,27 +16,29 @@ options = Options(; for reverse in [false, true] T = Float32 - NT = with_type_parameters(options.node_type, T) - members = PopMember{T,T,NT}[] # Generate members with scores from 0 to 1: - for i in 1:n - tree = Node("x1") * 3.2f0 - score = Float32(i - 1) / (n - 1) - if reverse - score = 1 - score - end - test_loss = 1.0f0 # (arbitrary for this test) - push!(members, PopMember(tree, score, test_loss, options)) - end + members = [ + let + ex = @parse_expression( + x1 * 3.2, operators = options.operators, variable_names = [:x1], + ) + score = Float32(i - 1) / (n - 1) + if reverse + score = 1 - score + end + test_loss = 1.0f0 # (arbitrary for this test) + PopMember(ex, score, test_loss, options; deterministic=false) + end for i in 1:n + ] pop = Population(members) - dummy_running_stats = SymbolicRegression.AdaptiveParsimonyModule.RunningSearchStatistics(; + dummy_running_stats = LibraryAugmentedSymbolicRegression.AdaptiveParsimonyModule.RunningSearchStatistics(; options=options ) best_pop_member = [ - SymbolicRegression.best_of_sample(pop, dummy_running_stats, options).score for + LibraryAugmentedSymbolicRegression.best_of_sample(pop, dummy_running_stats, options).score for j in 1:100 ] diff --git a/test/test_recorder.jl b/test/test_recorder.jl index 95d7b7348..054454fe1 100644 --- a/test/test_recorder.jl +++ b/test/test_recorder.jl @@ -1,5 +1,5 @@ -using SymbolicRegression -using SymbolicRegression.UtilsModule: recursive_merge +using LibraryAugmentedSymbolicRegression +using LibraryAugmentedSymbolicRegression.UtilsModule: recursive_merge using JSON3 include("test_params.jl") @@ -8,7 +8,7 @@ recorder_file = joinpath(base_dir, 
"pysr_recorder.json") X = 2 .* randn(Float32, 2, 1000) y = 3 * cos.(X[2, :]) + X[1, :] .^ 2 .- 2 -options = SymbolicRegression.Options(; +options = LibraryAugmentedSymbolicRegression.Options(; binary_operators=(+, *, /, -), unary_operators=(cos,), use_recorder=true, diff --git a/test/test_search_statistics.jl b/test/test_search_statistics.jl index 35a9b0175..770803427 100644 --- a/test/test_search_statistics.jl +++ b/test/test_search_statistics.jl @@ -1,5 +1,5 @@ -using SymbolicRegression -using SymbolicRegression.AdaptiveParsimonyModule: +using LibraryAugmentedSymbolicRegression +using LibraryAugmentedSymbolicRegression.AdaptiveParsimonyModule: RunningSearchStatistics, update_frequencies!, move_window!, normalize_frequencies! using Random diff --git a/test/test_simplification.jl b/test/test_simplification.jl index ad6c0a562..fa6764c4a 100644 --- a/test/test_simplification.jl +++ b/test/test_simplification.jl @@ -1,8 +1,11 @@ include("test_params.jl") -using SymbolicRegression, Test +using LibraryAugmentedSymbolicRegression, Test using SymbolicUtils: simplify, Symbolic -using Random: MersenneTwister -using Base: โ‰ˆ +using DynamicExpressions.OperatorEnumConstructionModule: empty_all_globals! +#! format: off +using Base: โ‰ˆ; using Random: MersenneTwister +#! format: on +# ^ Can't end line with โ‰ˆ due to JuliaSyntax.jl bug function Base.:โ‰ˆ(a::String, b::String) a = replace(a, r"\s+" => "") @@ -10,6 +13,8 @@ function Base.:โ‰ˆ(a::String, b::String) return a == b end +empty_all_globals!() + binary_operators = (+, -, /, *) index_of_mult = [i for (i, op) in enumerate(binary_operators) if op == *][1] diff --git a/test/test_stop_on_clock.jl b/test/test_stop_on_clock.jl index 295586e57..67beb7e0a 100644 --- a/test/test_stop_on_clock.jl +++ b/test/test_stop_on_clock.jl @@ -1,13 +1,25 @@ -using SymbolicRegression +using LibraryAugmentedSymbolicRegression using Random +using Distributed: rmprocs include("test_params.jl") X = randn(MersenneTwister(0), Float32, 5, 100) y = 2 * cos.(X[4, :]) -options = Options(; default_params..., timeout_in_seconds=1) +# Ensure is precompiled: +options = Options(; + default_params..., + population_size=10, + ncycles_per_iteration=100, + maxsize=15, + timeout_in_seconds=1, +) +equation_search(X, y; niterations=1, options=options, parallelism=:serial) + +# Ensure nothing might prevent slow checking of the clock: +rmprocs() +GC.gc(true) # full=true start_time = time() -# With multithreading: -equation_search(X, y; niterations=10000000, options=options, parallelism=:multithreading) +equation_search(X, y; niterations=10000000, options=options, parallelism=:serial) end_time = time() @test end_time - start_time < 100 diff --git a/test/test_symbolic_utils.jl b/test/test_symbolic_utils.jl index 3e074abaa..408c4e65e 100644 --- a/test/test_symbolic_utils.jl +++ b/test/test_symbolic_utils.jl @@ -1,5 +1,5 @@ using SymbolicUtils -using SymbolicRegression +using LibraryAugmentedSymbolicRegression include("test_params.jl") _inv(x) = 1 / x diff --git a/test/test_tree_construction.jl b/test/test_tree_construction.jl index f233bf118..49af73d46 100644 --- a/test/test_tree_construction.jl +++ b/test/test_tree_construction.jl @@ -1,6 +1,6 @@ -using SymbolicRegression +using LibraryAugmentedSymbolicRegression using Random -using SymbolicRegression: eval_loss, score_func, Dataset +using LibraryAugmentedSymbolicRegression: eval_loss, score_func, Dataset using ForwardDiff include("test_params.jl") diff --git a/test/test_turbo_nan.jl b/test/test_turbo_nan.jl index 2447b8253..e57e73a0a 100644 
--- a/test/test_turbo_nan.jl +++ b/test/test_turbo_nan.jl @@ -1,4 +1,5 @@ -using SymbolicRegression +using LibraryAugmentedSymbolicRegression +using LoopVectorization bad_op(x::T) where {T} = (x >= 0) ? x : T(0) diff --git a/test/test_units.jl b/test/test_units.jl index 0e58173e1..c7f85130e 100644 --- a/test/test_units.jl +++ b/test/test_units.jl @@ -1,52 +1,22 @@ -using SymbolicRegression -using SymbolicRegression: - square, - cube, - plus, - sub, - mult, - greater, - cond, - relu, - logical_or, - logical_and, - safe_pow, - atanh_clip -using SymbolicRegression.InterfaceDynamicQuantitiesModule: get_units, get_dimensions_type -using SymbolicRegression.MLJInterfaceModule: unwrap_units_single -using SymbolicRegression.DimensionalAnalysisModule: - violates_dimensional_constraints, @maybe_return_call, WildcardQuantity -using DynamicQuantities: - DEFAULT_DIM_BASE_TYPE, - RealQuantity, - Quantity, - QuantityArray, - SymbolicDimensions, - Dimensions, - DimensionError, - @u_str, - @us_str, - uparse, - sym_uparse, - ustrip, - dimension -using MLJBase: MLJBase as MLJ -using MLJModelInterface: MLJModelInterface as MMI -include("utils.jl") - -custom_op(x, y) = x + y - -options = Options(; - binary_operators=[-, *, /, custom_op, ^], unary_operators=[cos, cbrt, sqrt, abs, inv] -) -@extend_operators options - -(x1, x2, x3) = (i -> Node(Float64; feature=i)).(1:3) - -@testset "Dimensional analysis" begin +@testitem "Dimensional analysis" tags = [:part3] begin + using LibraryAugmentedSymbolicRegression + using LibraryAugmentedSymbolicRegression.InterfaceDynamicQuantitiesModule: get_units + using LibraryAugmentedSymbolicRegression.DimensionalAnalysisModule: violates_dimensional_constraints + using DynamicQuantities + using DynamicQuantities: DEFAULT_DIM_BASE_TYPE + X = randn(3, 100) y = @. cos(X[3, :] * 2.1 - 0.2) + 0.5 + custom_op(x, y) = x + y + options = Options(; + binary_operators=[-, *, /, custom_op, ^], + unary_operators=[cos, cbrt, sqrt, abs, inv], + ) + @extend_operators options + + (x1, x2, x3) = (i -> Node(Float64; feature=i)).(1:3) + D = Dimensions{DEFAULT_DIM_BASE_TYPE} SD = SymbolicDimensions{DEFAULT_DIM_BASE_TYPE} @@ -130,15 +100,24 @@ options = Options(; end end -options = Options(; binary_operators=[-, *, /, custom_op], unary_operators=[cos]) -@extend_operators options +@testitem "Search with dimensional constraints" tags = [:part3] begin + using LibraryAugmentedSymbolicRegression + using LibraryAugmentedSymbolicRegression.DimensionalAnalysisModule: violates_dimensional_constraints + using Random: MersenneTwister -@testset "Search with dimensional constraints" begin - X = rand(1, 128) .* 10 + rng = MersenneTwister(0) + X = rand(rng, 1, 128) .* 20 y = @. cos(X[1, :]) + X[1, :] dataset = Dataset(X, y; X_units=["kg"], y_units="1") + custom_op(x, y) = x + y + options = Options(; + binary_operators=[-, *, /, custom_op], + unary_operators=[cos], + early_stop_condition=(loss, complexity) -> (loss < 1e-7 && complexity <= 8), + ) + @extend_operators options - hof = EquationSearch(dataset; options) + hof = equation_search(dataset; niterations=1000, options) # Solutions should be like cos([cons] * X[1]) + [cons]*X[1] dominating = calculate_pareto_frontier(hof) @@ -151,26 +130,27 @@ options = Options(; binary_operators=[-, *, /, custom_op], unary_operators=[cos] # Check that every cos(...) 
which contains x1 also has complexity has_cos(tree) = - any(tree) do t + any(get_tree(tree)) do t t.degree == 1 && options.operators.unaops[t.op] == cos end valid_trees = [ - !has_cos(member.tree) || any(member.tree) do t - if ( + !has_cos(member.tree) || any( + t -> t.degree == 1 && - options.operators.unaops[t.op] == cos && - Node(Float64; feature=1) in t - ) - return compute_complexity(t, options) > 1 - end - return false - end for member in dominating + options.operators.unaops[t.op] == cos && + Node(Float64; feature=1) in t && + compute_complexity(t, options) > 1, + get_tree(member.tree), + ) for member in dominating ] @test all(valid_trees) @test length(valid_trees) > 0 end -@testset "Operator compatibility" begin +@testitem "Operator compatibility" tags = [:part3] begin + using LibraryAugmentedSymbolicRegression + using DynamicQuantities + ## square cube plus sub mult greater cond relu logical_or logical_and safe_pow atanh_clip # Want to ensure these operators perform correctly in the context of units @test square(1.0u"m") == 1.0u"m^2" @@ -210,11 +190,24 @@ end @test_throws DimensionError atanh_clip(1.0u"m") end -options = Options(; binary_operators=[-, *, /, custom_op], unary_operators=[cos]) -@extend_operators options +@testitem "Search with dimensional constraints on output" tags = [:part3] begin + using LibraryAugmentedSymbolicRegression + using MLJBase: MLJBase as MLJ + using DynamicQuantities + using Random: MersenneTwister + + include("utils.jl") -@testset "Search with dimensional constraints on output" begin - X = randn(2, 128) + custom_op(x, y) = x + y + options = Options(; + binary_operators=[-, *, /, custom_op], + unary_operators=[cos], + early_stop_condition=(loss, complexity) -> (loss < 1e-7 && complexity == 3), + ) + @extend_operators options + + rng = MersenneTwister(0) + X = randn(rng, 2, 128) X[2, :] .= X[1, :] y = X[1, :] .^ 2 @@ -224,7 +217,7 @@ options = Options(; binary_operators=[-, *, /, custom_op], unary_operators=[cos] # Solution should be x2 * x2 dominating = calculate_pareto_frontier(hof) - best = first(filter(m::PopMember -> m.loss < 1e-7, dominating)).tree + best = get_tree(first(filter(m::PopMember -> m.loss < 1e-7, dominating)).tree) x2 = Node(Float64; feature=2) @@ -236,24 +229,29 @@ options = Options(; binary_operators=[-, *, /, custom_op], unary_operators=[cos] @warn "Complexity of best solution is not 3; search with units might have failed" end - X = randn(2, 128) + rng = MersenneTwister(0) + X = randn(rng, 2, 128) y = @. 
cbrt(X[1, :]) .+ sqrt(abs(X[2, :])) - options2 = Options(; binary_operators=[+, *], unary_operators=[sqrt, cbrt, abs]) + options2 = Options(; + binary_operators=[+, *], + unary_operators=[sqrt, cbrt, abs], + early_stop_condition=(loss, complexity) -> (loss < 1e-7 && complexity == 6), + ) hof = EquationSearch(X, y; options=options2, X_units=["kg^3", "kg^2"], y_units="kg") dominating = calculate_pareto_frontier(hof) best = first(filter(m::PopMember -> m.loss < 1e-7, dominating)).tree @test compute_complexity(best, options2) == 6 - @test any(best) do t + @test any(get_tree(best)) do t t.degree == 1 && options2.operators.unaops[t.op] == cbrt end - @test any(best) do t + @test any(get_tree(best)) do t t.degree == 1 && options2.operators.unaops[t.op] == safe_sqrt end @testset "With MLJ" begin for as_quantity_array in (false, true) - model = SRRegressor(; + model = LaSRRegressor(; binary_operators=[+, *], unary_operators=[sqrt, cbrt, abs], early_stop_condition=(loss, complexity) -> (loss < 1e-7 && complexity <= 6), @@ -269,10 +267,10 @@ options = Options(; binary_operators=[-, *, /, custom_op], unary_operators=[cos] report = MLJ.report(mach) best_idx = findfirst(report.losses .< 1e-7)::Int @test report.complexities[best_idx] <= 6 - @test any(report.equations[best_idx]) do t + @test any(get_tree(report.equations[best_idx])) do t t.degree == 1 && t.op == 2 # cbrt end - @test any(report.equations[best_idx]) do t + @test any(get_tree(report.equations[best_idx])) do t t.degree == 1 && t.op == 1 # safe_sqrt end @@ -282,7 +280,7 @@ options = Options(; binary_operators=[-, *, /, custom_op], unary_operators=[cos] end # Multiple outputs, and with RealQuantity - model = MultitargetSRRegressor(; + model = MultitargetLaSRRegressor(; binary_operators=[+, *], unary_operators=[sqrt, cbrt, abs], early_stop_condition=(loss, complexity) -> (loss < 1e-7 && complexity <= 8), @@ -318,14 +316,20 @@ options = Options(; binary_operators=[-, *, /, custom_op], unary_operators=[cos] end end -@testset "Should error on mismatched units" begin +@testitem "Should error on mismatched units" tags = [:part3] begin + using LibraryAugmentedSymbolicRegression + using DynamicQuantities + X = randn(11, 50) y = randn(50) VERSION >= v"1.8.0" && @test_throws("Number of features", Dataset(X, y; X_units=["m", "1"], y_units="kg")) end -@testset "Should print units" begin +@testitem "Should print units" tags = [:part3] begin + using LibraryAugmentedSymbolicRegression + using DynamicQuantities + X = randn(5, 64) y = randn(64) dataset = Dataset(X, y; X_units=["m^3", "km/s", "kg", "1", "1"], y_units="kg") @@ -386,7 +390,13 @@ end ) == "xโ‚…[5.0 m] * 3.2" end -@testset "Dimensionless constants" begin +@testitem "Dimensionless constants" tags = [:part3] begin + using LibraryAugmentedSymbolicRegression + using LibraryAugmentedSymbolicRegression.DimensionalAnalysisModule: violates_dimensional_constraints + using DynamicQuantities + + include("utils.jl") + options = Options(; binary_operators=[+, -, *, /, square, cube], unary_operators=[cos, sin], @@ -422,7 +432,14 @@ end end end -@testset "Miscellaneous" begin +@testitem "Miscellaneous tests of unit interface" tags = [:part3] begin + using LibraryAugmentedSymbolicRegression + using DynamicQuantities + using LibraryAugmentedSymbolicRegression.DimensionalAnalysisModule: @maybe_return_call, WildcardQuantity + using LibraryAugmentedSymbolicRegression.MLJInterfaceModule: unwrap_units_single + using LibraryAugmentedSymbolicRegression.InterfaceDynamicQuantitiesModule: get_dimensions_type + using 
MLJModelInterface: MLJModelInterface as MMI
+    function test_return_call(op::Function, w...)
+        @maybe_return_call(typeof(first(w)), op, w)
+        return nothing
+    end
diff --git a/test/test_utils.jl b/test/test_utils.jl
index 67ceb0dcc..d6155c5f7 100644
--- a/test/test_utils.jl
+++ b/test/test_utils.jl
@@ -1,5 +1,5 @@
-using SymbolicRegression
-using SymbolicRegression.UtilsModule:
+using LibraryAugmentedSymbolicRegression
+using LibraryAugmentedSymbolicRegression.UtilsModule:
     findmin_fast, argmin_fast, bottomk_fast, is_anonymous_function
 using Random
diff --git a/test/user_defined_operator.jl b/test/user_defined_operator.jl
index a236a5089..f96f4d994 100644
--- a/test/user_defined_operator.jl
+++ b/test/user_defined_operator.jl
@@ -1,11 +1,11 @@
-using SymbolicRegression, Test
+using LibraryAugmentedSymbolicRegression, Test
 include("test_params.jl")
 _inv(x::Float32)::Float32 = 1.0f0 / x
 X = rand(Float32, 5, 100) .+ 1
 y = 1.2f0 .+ 2 ./ X[3, :]
-options = SymbolicRegression.Options(;
+options = LibraryAugmentedSymbolicRegression.Options(;
     default_params..., binary_operators=(+, *), unary_operators=(_inv,), populations=8
 )
 hallOfFame = equation_search(
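A pattern that recurs throughout these test updates is wrapping a raw `Node` in an `Expression` (which carries the operator and variable-name metadata) and unwrapping it again with `get_tree` for tree-level checks. A minimal sketch of that round trip, assuming the same `@parse_expression`/`get_tree` API from DynamicExpressions that the tests above rely on:

    using LibraryAugmentedSymbolicRegression
    using DynamicExpressions: @parse_expression, get_tree

    options = Options(; binary_operators=[+, *], unary_operators=[cos])
    ex = @parse_expression(
        cos(x1) * 3.2, operators = options.operators, variable_names = [:x1]
    )
    tree = get_tree(ex)  # the underlying Node{Float64}, e.g. for compute_complexity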