
ggml: fix gradient allocation logic #966

Merged
examples/mnist/mnist-common.cpp (2 changes: 1 addition & 1 deletion)
@@ -536,7 +536,7 @@ void mnist_model_train(mnist_model & model, const float * images, const float *

     // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients.
     struct ggml_cgraph * gb_grad = ggml_graph_dup(model.ctx_compute, gf);
-    ggml_build_backward_expand(model.ctx_compute, gf, gb_grad, /*accumulate =*/ true, false);
+    ggml_build_backward_expand(model.ctx_compute, gf, gb_grad, /*accumulate =*/ true);

     // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step.
     struct ggml_cgraph * gb_opt = ggml_graph_dup(model.ctx_compute, gb_grad);
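For context, the changed call sits in the training-graph setup of mnist_model_train. A minimal sketch of that setup under the new signature; using model.loss as the loss tensor is an assumption here, and the ggml_build_opt_adamw arguments are elided rather than guessed:

```c
// Forward graph: allocate with gradient slots, then expand up to the loss.
struct ggml_cgraph * gf = ggml_new_graph_custom(model.ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true);
ggml_build_forward_expand(gf, model.loss);

// gb_grad: forward pass + backward pass; with accumulate == true the
// gradients are summed across evaluations instead of overwritten.
struct ggml_cgraph * gb_grad = ggml_graph_dup(model.ctx_compute, gf);
ggml_build_backward_expand(model.ctx_compute, gf, gb_grad, /*accumulate =*/ true);

// gb_opt: same as gb_grad plus an AdamW optimizer step appended via
// ggml_build_opt_adamw (arguments omitted here; see its declaration below).
struct ggml_cgraph * gb_opt = ggml_graph_dup(model.ctx_compute, gb_grad);
```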
include/ggml.h (40 changes: 20 additions & 20 deletions)
@@ -574,10 +574,10 @@ extern "C" {

     // this tensor...
     enum ggml_tensor_flag {
-        GGML_TENSOR_FLAG_INPUT  = 1, // ...is an input for the GGML compute graph
-        GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph
-        GGML_TENSOR_FLAG_PARAM  = 4, // ...contains trainable parameters
-        GGML_TENSOR_FLAG_LOSS   = 8, // ...defines loss for numerical optimization (multiple loss tensors add up)
+        GGML_TENSOR_FLAG_INPUT  =  1, // ...is an input for the GGML compute graph
+        GGML_TENSOR_FLAG_OUTPUT =  2, // ...is an output for the GGML compute graph
+        GGML_TENSOR_FLAG_PARAM  =  4, // ...contains trainable parameters
+        GGML_TENSOR_FLAG_LOSS   =  8, // ...defines loss for numerical optimization (multiple loss tensors add up)
     };

     // n-dimensional tensor
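These flags are normally set through the ggml setter helpers rather than written by hand. A minimal sketch, assuming the usual ggml_set_* functions (ggml_set_loss is declared later in this same diff); tensor names and shapes are hypothetical:

```c
// Mark each tensor's role so graph construction and the backward pass
// can treat it correctly.
struct ggml_tensor * images = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 784, n_batch);
ggml_set_input(images);      // GGML_TENSOR_FLAG_INPUT

struct ggml_tensor * weight = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 784, 10);
ggml_set_param(ctx, weight); // GGML_TENSOR_FLAG_PARAM: a gradient will be allocated for it

struct ggml_tensor * logits = ggml_mul_mat(ctx, weight, images); // [10, n_batch]
ggml_set_output(logits);     // GGML_TENSOR_FLAG_OUTPUT

struct ggml_tensor * labels = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 10, n_batch);
struct ggml_tensor * loss   = ggml_cross_entropy_loss(ctx, logits, labels);
ggml_set_loss(loss);         // GGML_TENSOR_FLAG_LOSS
```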
@@ -1407,14 +1407,14 @@ extern "C" {
     // supports 3D: a->ne[2] == b->ne[1]
     GGML_API struct ggml_tensor * ggml_get_rows(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
+            struct ggml_tensor  * a,  // data
+            struct ggml_tensor  * b); // row indices

     GGML_API struct ggml_tensor * ggml_get_rows_back(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            struct ggml_tensor  * c);
+            struct ggml_tensor  * a,  // gradients of ggml_get_rows result
+            struct ggml_tensor  * b,  // row indices
+            struct ggml_tensor  * c); // data for ggml_get_rows, only used for its shape

     GGML_API struct ggml_tensor * ggml_diag(
             struct ggml_context * ctx,
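With the argument roles now documented, the pair reads as a gather and its matching scatter. A small sketch (names and shapes hypothetical; grad_rows stands for the incoming gradients of rows, as supplied by the autodiff machinery):

```c
// Gather: rows[:, i] = table[:, ids[i]], e.g. an embedding lookup.
struct ggml_tensor * table = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab); // data
struct ggml_tensor * ids   = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens);        // row indices
struct ggml_tensor * rows  = ggml_get_rows(ctx, table, ids); // [n_embd, n_tokens]

// Scatter (backward): accumulate grad_rows back into a table-shaped tensor;
// `table` is passed only so the result gets the right shape.
struct ggml_tensor * grad_table = ggml_get_rows_back(ctx, grad_rows, ids, table);
```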
@@ -1565,9 +1565,9 @@ extern "C" {
     // a - dy
     GGML_API struct ggml_tensor * ggml_rope_back(
             struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            struct ggml_tensor  * c,
+            struct ggml_tensor  * a, // gradients of ggml_rope result
+            struct ggml_tensor  * b, // positions
+            struct ggml_tensor  * c, // freq factors
             int                   n_dims,
             int                   mode,
             int                   n_ctx_orig,
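For orientation, ggml_rope_back consumes the gradient of a forward ggml_rope result. A sketch of the matching forward call (shapes hypothetical; mode 0 is the standard rotation):

```c
// Forward RoPE: rotate the first n_dims components of each head by position.
struct ggml_tensor * x   = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, head_dim, n_head, n_tokens);
struct ggml_tensor * pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n_tokens); // one position per token
struct ggml_tensor * xr  = ggml_rope(ctx, x, pos, /*n_dims =*/ head_dim, /*mode =*/ 0);
// In the backward pass, ggml_rope_back receives a = d(loss)/d(xr), the same
// positions b = pos, the freq factors c (if the forward call used them), and
// the forward call's remaining hyperparameters.
```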
@@ -2033,15 +2033,15 @@ extern "C" {
     // loss function

     GGML_API struct ggml_tensor * ggml_cross_entropy_loss(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b);
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a, // logits
+            struct ggml_tensor  * b); // labels

     GGML_API struct ggml_tensor * ggml_cross_entropy_loss_back(
-            struct ggml_context * ctx,
-            struct ggml_tensor  * a,
-            struct ggml_tensor  * b,
-            struct ggml_tensor  * c);
+            struct ggml_context * ctx,
+            struct ggml_tensor  * a, // logits
+            struct ggml_tensor  * b, // labels
+            struct ggml_tensor  * c); // gradients of cross_entropy_loss result
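A sketch tying the two declarations together; logits, labels, and grad_loss are placeholders, and the backward call is normally emitted by ggml_build_backward_expand rather than written by hand:

```c
// Forward: scalar cross-entropy loss over the batch from raw logits.
struct ggml_tensor * loss = ggml_cross_entropy_loss(ctx, logits, labels);
ggml_set_loss(loss); // seeds the gradient of `loss` in the backward pass

// Backward: given grad_loss = d(total_loss)/d(loss), produce d(total_loss)/d(logits).
struct ggml_tensor * grad_logits = ggml_cross_entropy_loss_back(ctx, logits, labels, grad_loss);
```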

     // AdamW optimizer step
     // Paper: https://arxiv.org/pdf/1711.05101v3.pdf
@@ -2063,7 +2063,7 @@ extern "C" {
     GGML_API void ggml_set_loss(struct ggml_tensor * tensor);

     GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
-    GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate, bool keep);
+    GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate);

     GGML_API void ggml_build_opt_adamw(
             struct ggml_context * ctx,
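With the keep flag gone, the one remaining choice when expanding the backward graph is whether gradients accumulate across evaluations. A minimal sketch of both modes, assuming gf already holds the forward pass and the parameters were registered with ggml_set_param:

```c
// Overwrite mode: every evaluation of gb recomputes gradients from scratch.
struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf);
ggml_build_backward_expand(ctx, gf, gb, /*accumulate =*/ false);

// Accumulate mode: evaluations add into the gradient tensors, so several
// micro-batches can be summed before a single optimizer step (as in the
// mnist change above).
struct ggml_cgraph * gb_acc = ggml_graph_dup(ctx, gf);
ggml_build_backward_expand(ctx, gf, gb_acc, /*accumulate =*/ true);
```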