Skip to content

Commit

Permalink
Try fix lgbm
Browse files Browse the repository at this point in the history
  • Loading branch information
wjsi committed Oct 3, 2023
1 parent 2750b48 commit 624fb5a
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion mars/learn/contrib/lightgbm/_predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def __call__(self):
elif hasattr(self.model, "classes_"):
dtype = np.array(self.model.classes_).dtype
else:
dtype = getattr(self.model, "out_dtype_", np.dtype("float"))
dtype = getattr(self.model, "out_dtype_", [np.dtype("float")])[0]

if self.output_types[0] == OutputType.tensor:
# tensor
Expand Down
6 changes: 3 additions & 3 deletions mars/learn/contrib/lightgbm/_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,11 +406,11 @@ def execute(cls, ctx, op: "LGBMTrain"):
op.model_type == LGBMModelType.RANKER
or op.model_type == LGBMModelType.REGRESSOR
):
model.set_params(out_dtype_=np.dtype("float"))
model.set_params(out_dtype_=[np.dtype("float")])
elif hasattr(label_val, "dtype"):
model.set_params(out_dtype_=label_val.dtype)
model.set_params(out_dtype_=[label_val.dtype])
else:
model.set_params(out_dtype_=label_val.dtypes[0])
model.set_params(out_dtype_=[label_val.dtypes[0]])

Check warning on line 413 in mars/learn/contrib/lightgbm/_train.py

View check run for this annotation

Codecov / codecov/patch

mars/learn/contrib/lightgbm/_train.py#L413

Added line #L413 was not covered by tests

ctx[op.outputs[0].key] = pickle.dumps(model)
finally:
Expand Down

0 comments on commit 624fb5a

Please sign in to comment.