-
Notifications
You must be signed in to change notification settings - Fork 8
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
5f6f6d3
commit f8f9248
Showing
12 changed files
with
242 additions
and
57 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
# Linear model | ||
This example shows how to train a mixed model with linear and tree components. The dataset is a linear function with added noise and a step in the middle. The step makes it challenging for a linear model, and the slope makes it challenging for a tree model (note the characteristic axis-aligned step function of the tree model's fit). We create a combined model by first boosting 5 iterations of a linear model and then 15 iterations of a tree model. The resulting model fits both the linear trend and the step better than either component alone. | ||
|
||
<img src="linear_model.png" alt="drawing" width="800"/> |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from pathlib import Path | ||
|
||
import seaborn as sns | ||
from matplotlib import pyplot as plt | ||
from matplotlib.ticker import FuncFormatter | ||
|
||
import cunumeric as cn | ||
import legateboost as lb | ||
|
||
# Global plot styling.
sns.set()
plt.rcParams["font.family"] = "serif"

# Synthetic 1-D regression problem: y = x, plus a +3.0 step over the second
# half of the samples, observed with Gaussian noise (sigma = 0.25).
rs = cn.random.RandomState(42)
X = cn.linspace(0, 10, 200)[:, cn.newaxis]
y_true = X[:, 0].copy()
y_true[X.shape[0] // 2 :] += 3.0
y = y_true + rs.normal(0, 0.25, X.shape[0])

params = {
    "n_estimators": 20,
    "learning_rate": 0.5,
    "verbose": True,
    "random_state": 20,
}
# Shared across all fits; read immediately after each fit.
eval_result = {}


def _fit_and_score(base_models):
    """Fit an LBRegressor with ``base_models`` and return ``(model, rmse)``.

    The model is evaluated against the noise-free target ``y_true``; ``rmse``
    is the square root of the "mse" series recorded for that evaluation set.
    """
    fitted = lb.LBRegressor(base_models=base_models, **params).fit(
        X, y, eval_set=[(X, y_true)], eval_result=eval_result
    )
    return fitted, cn.sqrt(eval_result["eval-0"]["mse"])


# Fit order matters only in that each score is read right after its own fit.
linear_model, linear_test_error = _fit_and_score((lb.models.Linear(),))
tree_model, tree_test_error = _fit_and_score((lb.models.Tree(max_depth=1),))
model, mixed_test_error = _fit_and_score(
    (lb.models.Linear(),) * 5 + (lb.models.Tree(max_depth=1),) * 15
)

# plot
fig, ax = plt.subplots(1, 2, figsize=(12, 6))
# Render x tick labels on the current axes as integers.
plt.gca().xaxis.set_major_formatter(FuncFormatter(lambda x, _: int(x)))
fits_ax, error_ax = ax
features = X[:, 0]

# (fitted model, error series, legend label), in plotting order.
curves = (
    (linear_model, linear_test_error, "linear model"),
    (tree_model, tree_test_error, "tree model"),
    (model, mixed_test_error, "linear + tree model"),
)

# Left panel: noisy data and each model's prediction.
sns.scatterplot(
    x=features, y=y, color=".2", alpha=0.5, label="f(x)+noise", ax=fits_ax
)
for fitted, _, label in curves:
    sns.lineplot(x=features, y=fitted.predict(X), label=label, ax=fits_ax)
fits_ax.set_xlabel("X")

# Right panel: error series plotted against the estimator index.
iterations = range(params["n_estimators"])
for _, error, label in curves:
    sns.lineplot(x=iterations, y=error, label=label, ax=error_ax)
error_ax.set_xlabel("n_estimators")
error_ax.set_ylabel("test error")

plt.suptitle("Linear Models + Tree Models")
plt.tight_layout()
# Save the figure next to this script.
image_dir = Path(__file__).parent
plt.savefig(image_dir / "linear_model.png")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
from .tree import Tree | ||
from .linear import Linear | ||
from .base_model import BaseModel |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from abc import ABC, abstractmethod | ||
|
||
import numpy as np | ||
|
||
import cunumeric as cn | ||
|
||
from ..utils import PickleCunumericMixin | ||
|
||
|
||
class BaseModel(PickleCunumericMixin, ABC):
    """Abstract interface that every base model must implement.

    Concrete subclasses must provide ``fit``, ``update``, ``predict``,
    ``__str__`` and ``__eq__``.
    """

    def set_random_state(self, random_state: np.random.RandomState) -> "BaseModel":
        """Store ``random_state`` on the instance and return ``self``."""
        self.random_state = random_state
        return self

    @abstractmethod
    def fit(
        self,
        X: cn.ndarray,
        g: cn.ndarray,
        h: cn.ndarray,
    ) -> "BaseModel":
        """Fit the model to ``X`` given ``g`` and ``h``.

        NOTE(review): ``g`` and ``h`` are presumably the gradient and hessian
        of the loss -- confirm against the caller.

        Returns the fitted model.
        """

    @abstractmethod
    def update(
        self,
        X: cn.ndarray,
        g: cn.ndarray,
        h: cn.ndarray,
    ) -> "BaseModel":
        """Update the model from ``X``, ``g`` and ``h``.

        Takes the same arguments as ``fit`` and returns the updated model.
        """

    @abstractmethod
    def predict(self, X: cn.ndarray) -> cn.ndarray:
        """Return the model's predictions for ``X``."""

    @abstractmethod
    def __str__(self) -> str:
        """Return a human-readable description of the model."""

    @abstractmethod
    def __eq__(self, other: object) -> bool:
        """Return whether ``other`` is equal to this model."""
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
import cunumeric as cn | ||
|
||
from .base_model import BaseModel | ||
|
||
|
||
class Linear(BaseModel):
    """Linear base model: a per-output bias plus a coefficient matrix.

    Attributes set by ``fit``:
        bias: one intercept per output, shape ``(num_outputs,)``.
        betas: coefficients, shape ``(X.shape[1], num_outputs)``.
    """

    def fit(
        self,
        X: cn.ndarray,
        g: cn.ndarray,
        h: cn.ndarray,
    ) -> "Linear":
        """Fit the bias and coefficients from ``X``, ``g`` and ``h``.

        ``g`` and ``h`` are indexed as ``(row, output)``; the number of
        outputs is taken from ``g.shape[1]``.
        NOTE(review): presumably these are the per-sample gradient and
        hessian of the loss -- confirm against the caller.

        Returns ``self``.
        """
        num_outputs = g.shape[1]
        # Intercept first: the Newton step -sum(g) / sum(h) for each output.
        self.bias = -g.sum(axis=0) / h.sum(axis=0)
        # Shift g to account for the bias just fitted (g + h * bias).
        g = g + self.bias[cn.newaxis, :] * h
        self.betas = cn.zeros((X.shape[1], num_outputs))
        for k in range(num_outputs):
            # Weighted least squares with weights h: scaling rows by sqrt(h)
            # turns it into an ordinary least-squares problem on -g / h.
            W = cn.sqrt(h[:, k])
            Xw = X * W[:, cn.newaxis]
            yw = W * (-g[:, k] / h[:, k])
            self.betas[:, k] = cn.linalg.lstsq(Xw, yw)[0]
        return self

    def clear(self) -> None:
        """Zero the fitted bias and coefficients in place."""
        self.bias.fill(0)
        self.betas.fill(0)

    def update(
        self,
        X: cn.ndarray,
        g: cn.ndarray,
        h: cn.ndarray,
    ) -> "Linear":
        """Refit from scratch; equivalent to calling ``fit(X, g, h)``."""
        return self.fit(X, g, h)

    def predict(self, X: cn.ndarray) -> cn.ndarray:
        """Return ``bias + X @ betas`` for the rows of ``X``."""
        return self.bias + X.dot(self.betas)

    def __str__(self) -> str:
        """Return the bias and coefficients as printable text."""
        return "Bias: " + str(self.bias) + "\nCoefficients: " + str(self.betas) + "\n"

    def __eq__(self, other: object) -> bool:
        """Return True when ``other`` is a ``Linear`` with identical parameters.

        Both ``bias`` and ``betas`` are compared; models that differ only in
        their intercept are not equal. Non-``Linear`` operands yield
        ``NotImplemented`` so Python can fall back to its default handling.
        """
        if not isinstance(other, Linear):
            return NotImplemented
        # Guard shapes first so mismatched models compare unequal instead of
        # relying on broadcasting in the element-wise comparison.
        if other.bias.shape != self.bias.shape:
            return False
        if other.betas.shape != self.betas.shape:
            return False
        return bool(
            (other.bias == self.bias).all() and (other.betas == self.betas).all()
        )
Oops, something went wrong.