Skip to content

Commit

Permalink
Datafit - Genetic algorithm (#35)
Browse files Browse the repository at this point in the history
* Update variable names

* Unfortunately, I think this method is just too good

* Update function docs

* Add types for single or multi variable problems

* Add dataset type

* Use dataset type

* Update changelog;

* Release v0.1.0
  • Loading branch information
nicfv authored Mar 9, 2024
1 parent de12d9f commit 5c6e6b6
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 79 deletions.
6 changes: 6 additions & 0 deletions datafit/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Changelog

## 0.1.0

- Update code base to use a genetic-style algorithm that randomly mutates the set of function parameters
- Support both single and multi-variable curve fitting
- Add several helpful type definitions

## 0.0.2

- Set up basic single-variable curve-fitting algorithm using a "search range" method. This method is not ideal because it is extremely prone to getting stuck in local minima and has a difficult time zeroing in on the true best fit.
Expand Down
2 changes: 1 addition & 1 deletion datafit/README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
Simple curve-fitting algorithm
Simple curve-fitting algorithm: Curve fitting with single-variable or multivariate problems, using a genetic algorithm.

![NPM Version](https://img.shields.io/npm/v/datafit)
![NPM Downloads](https://img.shields.io/npm/dt/datafit)
2 changes: 1 addition & 1 deletion datafit/package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "datafit",
"version": "0.0.2",
"version": "0.1.0",
"description": "Simple curve-fitting algorithm",
"main": "dist/index.js",
"types": "types/index.d.ts",
Expand Down
107 changes: 39 additions & 68 deletions datafit/src/CurveFit.ts
Original file line number Diff line number Diff line change
@@ -1,28 +1,18 @@
import { SMath } from 'smath';
import { Fit, Point, fx, Range } from './types';
import { Fit, fx, Config, X, Dataset } from './types';

export abstract class CurveFit {
/**
* Anything beyond this takes several minutes of computation time.
*/
private static readonly MAX_PARAMS: number = 18;
/**
*
* Minimize the sum of squared errors to fit a set of data
* points to a curve with a set of unknown parameters.
* @param f The model function for curve fitting.
* @param data The entire dataset, as an array of points.
* @param a_initial The initial guess for function parameters,
* which defaults to an array filled of zeroes.
* @param distance The distance to vary when checking parameter sets.
* Each iteration will vary parameters by an offset of `distance` and
* will save the best fit for that search range, to use for the next
* iteration. `distance.max` is used for the first iteration, and the
* distance is halved in succession until `distance.min` is reached. For
* a wider searching range, increase the starting distance, however,
* this will increase the number of iterations to solution, which will
* increase compute time and resources.
* @returns The set of parameters for the best fit and sum of squared errors.
* which defaults to an array filled with zeroes.
* @param config Configuration options for curve fitting.
* @returns The set of parameters and error for the best fit.
*/
public static fit(f: fx, data: Array<Point>, a_initial: Array<number> = [], distance: Range = { max: 10, min: 0.1 }): Fit {
public static fit<T = X>(f: fx<T>, data: Dataset<T>, a_initial: Array<number> = [], config: Config = { generations: 100, population: 100, survivors: 10, initialDeviation: 10, finalDeviation: 1 }): Fit {
const N_params: number = f.length - 1;
if (a_initial.length === 0) {
a_initial.length = N_params;
Expand All @@ -31,58 +21,21 @@ export abstract class CurveFit {
if (a_initial.length !== N_params) {
throw new Error('The initial guess should contain ' + N_params + ' parameters.');
}
if (N_params > this.MAX_PARAMS) {
throw new Error('Your function includes too many unknown parameters.');
}
if (distance.max < distance.min) {
throw new Error('Starting distance should be greater than ending distance.');
}
let bestFit: Fit = this.fitStep(f, a_initial, data, distance.max);
for (let dist_curr = distance.max / 2; dist_curr > distance.min; dist_curr /= 2) {
bestFit = this.fitStep(f, bestFit.a, data, dist_curr);
}
return bestFit;
}
/**
* Determine the set of parameters that make the best fit
* for the model function for all combinations of "adjacent"
* parameter sets. (One distance step in all directions.)
* @param f The model function for curve fitting.
* @param a The set of parameters to originate from.
* @param data The entire dataset, as an array of points.
* @param distance The distance from `a` to deviate.
* @returns The best fit for the model up to the distance specified.
*/
private static fitStep(f: fx, a: Array<number>, data: Array<Point>, distance: number): Fit {
const d: number = 5,
n: number = a.length;
let err_min: number = this.sumSquares(f, a, data),
a_min: Array<number> = a.slice();
for (let i: number = 0; i < d ** n; i++) {
// `str` contains a string of characters in [012]
// and is guaranteed to include all combinations.
// 0 = subtract distance from corresponding parameter
// 1 = no change to corresponding parameter
// 2 = add distance to corresponding parameter
const str: string = i.toString(d).padStart(n, '0'),
a_new: Array<number> = a.slice();
for (let i: number = 0; i < str.length; i++) {
const val: number = Number.parseInt(str[i]);
a_new[i] += SMath.translate(val, 0, d - 1, -distance, distance);
}
// Check the sum of squared errors. If the new
// set of parameters yields a lower error, save
// those as the new minimum parameters and error.
const err_new: number = this.sumSquares(f, a_new, data);
if (err_new < err_min) {
err_min = err_new;
a_min = a_new;
const census: Array<Fit> = [];
for (let generation = 0; generation < config.generations; generation++) {
for (let i = 0; i < config.population; i++) {
// Mutate a random parent from the prior generation of survivors
const a: Array<number> = this.mutate(
census[this.randInt(0, config.survivors)]?.a ?? a_initial,
SMath.translate(generation, 0, config.generations, config.initialDeviation, config.finalDeviation)
);
census.push({ a: a, err: this.err(f, a, data) });
}
// Sort by increasing error and only keep the survivors
census.sort((x, y) => x.err - y.err);
census.splice(config.survivors);
}
return {
a: a_min,
err: err_min,
};
return census[0];
}
/**
* Calculate the sum of squared errors for a set of function parameters.
Expand All @@ -91,9 +44,27 @@ export abstract class CurveFit {
* @param data The entire dataset, as an array of points.
* @returns The sum of squared errors.
*/
private static sumSquares(f: fx, a: Array<number>, data: Array<Point>): number {
private static err<T = X>(f: fx<T>, a: Array<number>, data: Dataset<T>): number {
    // Fold the dataset into a single number: for every point, square the
    // gap between the observed `y` and the model output `f(x, ...a)`, and
    // accumulate those squares in dataset order starting from zero.
    return data.reduce<number>(
        (total, point) => total + (point.y - f(point.x, ...a)) ** 2,
        0
    );
}
/**
 * Produce a randomly perturbed copy of a set of function parameters.
 * Every parameter is shifted independently by `(Math.random() - 0.5) * deviation`,
 * so for a positive `deviation` each offset lies in `[-deviation / 2, deviation / 2)`.
 * In other words, `deviation` is the full width of the mutation window,
 * not the largest step in a single direction.
 * @param a The set of function parameters to mutate. This array is not modified.
 * @param deviation The total width of the random offset window.
 * @returns A new array containing the mutated parameters.
 */
private static mutate(a: Array<number>, deviation: number): Array<number> {
    return a.map(param => {
        // Center the unit random sample on zero, then scale it.
        const offset: number = (Math.random() - 0.5) * deviation;
        return param + offset;
    });
}
/**
 * Draw a pseudo-random integer by flooring a uniform sample that has been
 * scaled by `max - min` and shifted by `min`. For integer bounds with
 * `min < max` the result is expected to fall in the half-open range
 * `[min, max)`, i.e. `max` itself is excluded.
 * @param min Lower bound (inclusive)
 * @param max Upper bound (exclusive)
 * @returns A random integer
 */
private static randInt(min: number, max: number): number {
    const span: number = max - min;
    return Math.floor(Math.random() * span + min);
}
}
46 changes: 37 additions & 9 deletions datafit/src/types.ts
Original file line number Diff line number Diff line change
@@ -1,20 +1,36 @@
/**
* Use this type to define a single-variable curve and dataset
*/
export type SingleVariable = number;
/**
* Use this type to define a multi-variable curve and dataset
*/
export type MultiVariable = Array<number>;
/**
* Defines whether this is a single- or multi-variable problem.
*/
export type X = SingleVariable | MultiVariable;
/**
* Represents a mathematical function y = f(x) with unknown constants `a`
*/
export type fx = (x: number, ...a: Array<number>) => number;
export type fx<T = X> = (x: T, ...a: Array<number>) => number;
/**
* Stores a cartesian (x,y) coordinate pair.
*/
export interface Point {
export interface Point<T = X> {
/**
* X-coordinate
*/
readonly x: number;
readonly x: T;
/**
* Y-coordinate
*/
readonly y: number;
}
/**
* Contains a set of data points.
*/
export type Dataset<T = X> = Array<Point<T>>;
/**
* Includes information about a best-fit for a curve.
*/
Expand All @@ -29,15 +45,27 @@ export interface Fit {
readonly err: number;
}
/**
* Represents a number range.
* Configuration options for `CurveFit`
*/
export interface Range {
export interface Config {
/**
* Determines the number of generations, or iterations.
*/
readonly generations: number;
/**
* Determines the number of parameter sets to generate in each generation.
*/
readonly population: number;
/**
* Determines how many survivors remain after every generation.
*/
readonly survivors: number;
/**
* The minimum value of the range.
* Determines how much a set of parameters can mutate on the first generation.
*/
readonly min: number;
readonly initialDeviation: number;
/**
* The maximum value of the range.
* Determines how much a set of parameters can mutate on the final generation.
*/
readonly max: number;
readonly finalDeviation: number;
}

0 comments on commit 5c6e6b6

Please sign in to comment.