diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index e497b6d02388a..71c0bef8ee7ee 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -66,6 +66,7 @@ extern "C" { // "offset" refers to the offset of the tensor data for setting/getting data GGML_API GGML_CALL void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); GGML_API GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + GGML_API GGML_CALL void ggml_backend_tensor_memset( struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); GGML_API void ggml_backend_synchronize(ggml_backend_t backend); @@ -122,7 +123,7 @@ extern "C" { // The backend registry is a registry of all the available backends, and allows initializing backends in a generic way GGML_API size_t ggml_backend_reg_get_count(void); - GGML_API size_t ggml_backend_reg_find_by_name(const char * name); + GGML_API size_t ggml_backend_reg_find_by_name(const char * name); // returns index of backend with name, or SIZE_MAX if not found GGML_API ggml_backend_t ggml_backend_reg_init_backend_from_str(const char * backend_str); // str is backend_name:params (params is optional) GGML_API const char * ggml_backend_reg_get_name(size_t i); GGML_API ggml_backend_t ggml_backend_reg_init_backend(size_t i, const char * params); // params is backend-specific diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index a413df35750b1..2035001e97d7e 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -534,6 +534,7 @@ extern "C" { GGML_OP_CROSS_ENTROPY_LOSS, GGML_OP_CROSS_ENTROPY_LOSS_BACK, + GGML_OP_OPT_STEP_ADAMW, GGML_OP_COUNT, }; @@ -571,10 +572,12 @@ extern "C" { GGML_LOG_LEVEL_DEBUG = 4, }; + // this tensor... enum ggml_tensor_flag { - GGML_TENSOR_FLAG_INPUT = 1, - GGML_TENSOR_FLAG_OUTPUT = 2, - GGML_TENSOR_FLAG_PARAM = 4, + GGML_TENSOR_FLAG_INPUT = 1, // ...is an input for the GGML compute graph + GGML_TENSOR_FLAG_OUTPUT = 2, // ...is an output for the GGML compute graph + GGML_TENSOR_FLAG_PARAM = 4, // ...contains trainable parameters + GGML_TENSOR_FLAG_LOSS = 8, // ...defines loss for numerical optimization (multiple loss tensors add up) }; // n-dimensional tensor @@ -2037,23 +2040,44 @@ extern "C" { struct ggml_tensor * b, struct ggml_tensor * c); + // AdamW optimizer step + // Paper: https://arxiv.org/pdf/1711.05101v3.pdf + // PyTorch: https://pytorch.org/docs/stable/generated/torch.optim.AdamW.html + GGML_API struct ggml_tensor * ggml_opt_step_adamw( + struct ggml_context * ctx, + struct ggml_tensor * a, + float alpha, + float beta1, + float beta2, + float eps, + float wd); // weight decay + // // automatic differentiation // - GGML_API void ggml_set_param( - struct ggml_context * ctx, - struct ggml_tensor * tensor); + GGML_API void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor); + GGML_API void ggml_set_loss(struct ggml_tensor * tensor); GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); - GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep); + GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate, bool keep); + + GGML_API void ggml_build_opt_adamw( + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct ggml_cgraph * gb, + float alpha, + float beta1, + float beta2, + float eps, + float wd); // weight decay // graph allocation in a context GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false GGML_API struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t size, bool grads); GGML_API struct ggml_cgraph * ggml_graph_dup (struct ggml_context * ctx, struct ggml_cgraph * cgraph); GGML_API void ggml_graph_cpy (struct ggml_cgraph * src, struct ggml_cgraph * dst); - GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // zero grads + GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); // set regular grads + optimizer momenta to 0, set loss grad to 1 GGML_API void ggml_graph_clear (struct ggml_cgraph * cgraph); GGML_API int ggml_graph_size (struct ggml_cgraph * cgraph); diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h index 36ca370867c9e..b0d4141cc4363 100644 --- a/ggml/src/ggml-backend-impl.h +++ b/ggml/src/ggml-backend-impl.h @@ -38,15 +38,16 @@ extern "C" { typedef void * ggml_backend_buffer_context_t; struct ggml_backend_buffer_i { - const char * (*GGML_CALL get_name) (ggml_backend_buffer_t buffer); - void (*GGML_CALL free_buffer)(ggml_backend_buffer_t buffer); - void * (*GGML_CALL get_base) (ggml_backend_buffer_t buffer); - void (*GGML_CALL init_tensor)(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); - void (*GGML_CALL set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); - void (*GGML_CALL get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); - bool (*GGML_CALL cpy_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst); // dst is in the buffer, src may be in any buffer - void (*GGML_CALL clear) (ggml_backend_buffer_t buffer, uint8_t value); - void (*GGML_CALL reset) (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras + const char * (*GGML_CALL get_name) (ggml_backend_buffer_t buffer); + void (*GGML_CALL free_buffer) (ggml_backend_buffer_t buffer); + void * (*GGML_CALL get_base) (ggml_backend_buffer_t buffer); + void (*GGML_CALL init_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor); + void (*GGML_CALL memset_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); + void (*GGML_CALL set_tensor) (ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); + void (*GGML_CALL get_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); + bool (*GGML_CALL cpy_tensor) (ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst); // dst is in the buffer, src may be in any buffer + void (*GGML_CALL clear) (ggml_backend_buffer_t buffer, uint8_t value); + void (*GGML_CALL reset) (ggml_backend_buffer_t buffer); // reset any internal state due to tensor initialization, such as tensor extras }; struct ggml_backend_buffer { diff --git a/ggml/src/ggml-backend.c b/ggml/src/ggml-backend.c index b5d9301a78762..ba280e064141f 100644 --- a/ggml/src/ggml-backend.c +++ b/ggml/src/ggml-backend.c @@ -246,6 +246,22 @@ GGML_CALL void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * buf->iface.get_tensor(buf, tensor, data, offset, size); } +GGML_API GGML_CALL void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; + + GGML_ASSERT(buf != NULL && "tensor buffer not set"); + GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); + GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); + + if (!size) { + return; + } + + GGML_ASSERT(buf->iface.memset_tensor != NULL && "memset not supported by backend buffer"); + + buf->iface.memset_tensor(buf, tensor, value, offset, size); +} + void ggml_backend_synchronize(ggml_backend_t backend) { if (backend->iface.synchronize == NULL) { return; @@ -569,6 +585,12 @@ GGML_CALL static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t free(buffer->context); } +GGML_CALL static void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + memset((char *)tensor->data + offset, value, size); + + GGML_UNUSED(buffer); +} + GGML_CALL static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { memcpy((char *)tensor->data + offset, data, size); @@ -600,6 +622,7 @@ static struct ggml_backend_buffer_i cpu_backend_buffer_i = { /* .free_buffer = */ ggml_backend_cpu_buffer_free_buffer, /* .get_base = */ ggml_backend_cpu_buffer_get_base, /* .init_tensor = */ NULL, // no initialization required + /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor, /* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor, @@ -613,6 +636,7 @@ static struct ggml_backend_buffer_i cpu_backend_buffer_i_from_ptr = { /* .free_buffer = */ NULL, // ptr is not owned by the buffer, so it does not need to be freed /* .get_base = */ ggml_backend_cpu_buffer_get_base, /* .init_tensor = */ NULL, // no initialization required + /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_cpu_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cpu_buffer_get_tensor, /* .cpy_tensor = */ ggml_backend_cpu_buffer_cpy_tensor, @@ -980,6 +1004,7 @@ static struct ggml_backend_buffer_i ggml_backend_multi_buffer_context_interface( /* .free_buffer = */ ggml_backend_multi_buffer_free_buffer, /* .get_base = */ NULL, /* .init_tensor = */ NULL, + /* .memset_tensor = */ NULL, /* .set_tensor = */ NULL, /* .get_tensor = */ NULL, /* .cpy_tensor = */ NULL, diff --git a/ggml/src/ggml-cann.cpp b/ggml/src/ggml-cann.cpp index aa315b83f77aa..d3ab78006ee23 100644 --- a/ggml/src/ggml-cann.cpp +++ b/ggml/src/ggml-cann.cpp @@ -1037,6 +1037,7 @@ static ggml_backend_buffer_i ggml_backend_cann_buffer_interface = { /* .free_buffer = */ ggml_backend_cann_buffer_free_buffer, /* .get_base = */ ggml_backend_cann_buffer_get_base, /* .init_tensor = */ ggml_backend_cann_buffer_init_tensor, + /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_cann_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cann_buffer_get_tensor, /* .cpy_tensor = */ ggml_backend_cann_buffer_cpy_tensor, diff --git a/ggml/src/ggml-cuda.cu b/ggml/src/ggml-cuda.cu index 54f1a7c2d3075..b0843dc621cf5 100644 --- a/ggml/src/ggml-cuda.cu +++ b/ggml/src/ggml-cuda.cu @@ -21,6 +21,8 @@ #include "ggml-cuda/mmq.cuh" #include "ggml-cuda/mmvq.cuh" #include "ggml-cuda/norm.cuh" +#include "ggml-cuda/opt-step-adamw.cuh" +#include "ggml-cuda/out-prod.cuh" #include "ggml-cuda/pad.cuh" #include "ggml-cuda/pool2d.cuh" #include "ggml-cuda/quantize.cuh" @@ -493,6 +495,14 @@ GGML_CALL static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t } } +GGML_CALL static void ggml_backend_cuda_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; + + ggml_cuda_set_device(ctx->device); + CUDA_CHECK(cudaMemsetAsync((char *)tensor->data + offset, value, size, cudaStreamPerThread)); + CUDA_CHECK(cudaStreamSynchronize(cudaStreamPerThread)); +} + GGML_CALL static void ggml_backend_cuda_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context; @@ -544,6 +554,7 @@ static ggml_backend_buffer_i ggml_backend_cuda_buffer_interface = { /* .free_buffer = */ ggml_backend_cuda_buffer_free_buffer, /* .get_base = */ ggml_backend_cuda_buffer_get_base, /* .init_tensor = */ ggml_backend_cuda_buffer_init_tensor, + /* .memset_tensor = */ ggml_backend_cuda_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_cuda_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cuda_buffer_get_tensor, /* .cpy_tensor = */ ggml_backend_cuda_buffer_cpy_tensor, @@ -860,6 +871,7 @@ static struct ggml_backend_buffer_i ggml_backend_cuda_split_buffer_interface = { /* .free_buffer = */ ggml_backend_cuda_split_buffer_free_buffer, /* .get_base = */ ggml_backend_cuda_split_buffer_get_base, /* .init_tensor = */ ggml_backend_cuda_split_buffer_init_tensor, + /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_cuda_split_buffer_set_tensor, /* .get_tensor = */ ggml_backend_cuda_split_buffer_get_tensor, /* .cpy_tensor = */ NULL, @@ -2168,6 +2180,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_REPEAT: ggml_cuda_op_repeat(ctx, dst); break; + case GGML_OP_REPEAT_BACK: + ggml_cuda_op_repeat_back(ctx, dst); + break; case GGML_OP_GET_ROWS: ggml_cuda_op_get_rows(ctx, dst); break; @@ -2201,6 +2216,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_UNARY_OP_NEG: ggml_cuda_op_neg(ctx, dst); break; + case GGML_UNARY_OP_STEP: + ggml_cuda_op_step(ctx, dst); + break; case GGML_UNARY_OP_GELU: ggml_cuda_op_gelu(ctx, dst); break; @@ -2267,6 +2285,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_MUL_MAT_ID: ggml_cuda_mul_mat_id(ctx, dst); break; + case GGML_OP_OUT_PROD: + ggml_cuda_out_prod(ctx, dst); + break; case GGML_OP_SCALE: ggml_cuda_op_scale(ctx, dst); break; @@ -2324,6 +2345,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_CROSS_ENTROPY_LOSS: ggml_cuda_cross_entropy_loss(ctx, dst); break; + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + ggml_cuda_cross_entropy_loss_back(ctx, dst); + break; + case GGML_OP_OPT_STEP_ADAMW: + ggml_cuda_opt_step_adamw(ctx, dst); + break; default: return false; } @@ -2761,6 +2788,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { case GGML_UNARY_OP_NEG: + case GGML_UNARY_OP_STEP: case GGML_UNARY_OP_GELU: case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_RELU: @@ -2813,6 +2841,8 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons return false; } } break; + case GGML_OP_OUT_PROD: + return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->ne[2] == 1 && op->ne[3] == 1; case GGML_OP_GET_ROWS: { switch (op->src[0]->type) { @@ -2869,6 +2899,12 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons } break; case GGML_OP_DUP: case GGML_OP_REPEAT: + { + ggml_type src0_type = op->src[0]->type; + return src0_type != GGML_TYPE_I32 && src0_type != GGML_TYPE_I16; + } break; + case GGML_OP_REPEAT_BACK: + return op->type == GGML_TYPE_F32 && op->src[0]->ne[3] == 1; case GGML_OP_CONCAT: { ggml_type src0_type = op->src[0]->type; @@ -2935,9 +2971,11 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons } return ggml_cuda_info().devices[cuda_ctx->device].cc >= CC_VOLTA && op->src[1]->type == GGML_TYPE_F16 && op->src[2]->type == GGML_TYPE_F16; +#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) case GGML_OP_CROSS_ENTROPY_LOSS: + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + case GGML_OP_OPT_STEP_ADAMW: return true; -#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) default: return false; } diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index e1390a0414559..c7b6be4e2905c 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -1,4 +1,5 @@ #include "binbcast.cuh" +#include static __device__ __forceinline__ float op_repeat(const float a, const float b) { return b; @@ -90,6 +91,30 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * s dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); } +template +static __global__ void k_repeat_back( + const T * __restrict__ src, T * __restrict__ dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t ne0, const int64_t ne1, const int64_t ne2) { + + const int64_t tid0 = (int64_t) blockIdx.x*blockDim.x + threadIdx.x; + const int64_t tid1 = (int64_t) blockIdx.y*blockDim.y + threadIdx.y; + const int64_t tid2 = (int64_t) blockIdx.z*blockDim.z + threadIdx.z; + + if (tid0 >= ne0) { + return; + } + + T sum = 0; + for (int64_t i2 = tid2; i2 < ne02; i2 += ne2) { + for (int64_t i1 = tid1; i1 < ne01; i1 += ne1) { + for (int64_t i0 = tid0; i0 < ne00; i0 += ne0) { + sum += src[i2*ne01*ne00 + i1*ne00 + i0]; + } + } + } + dst[tid2*ne1*ne0 + tid1*ne0 + tid0] = sum; +} + template struct bin_bcast_cuda { template @@ -247,6 +272,16 @@ struct bin_bcast_cuda { } }; +template +static void repeat_back_cuda( + const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t ne0, const int64_t ne1, const int64_t ne2, cudaStream_t stream) { + + const dim3 block_dims(WARP_SIZE, 1, 1); + const dim3 block_nums((ne0 + WARP_SIZE - 1) / WARP_SIZE, ne1, ne2); + k_repeat_back<<>>(src, dst, ne00, ne01, ne02, ne0, ne1, ne2); +} + template static void ggml_cuda_op_bin_bcast( const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, @@ -286,3 +321,35 @@ void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { ggml_cuda_op_bin_bcast>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream()); } + +void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(src0->type == dst->type); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_can_repeat(dst, src0)); + + cudaStream_t stream = ctx.stream(); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + GGML_ASSERT(src0->ne[3] == 1); + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + const int64_t ne2 = dst->ne[2]; + GGML_ASSERT(dst->ne[3] == 1); + + switch (dst->type) { + case GGML_TYPE_F32: { + const float * src0_d = (const float *) src0->data; + float * dst_d = (float *) dst->data; + repeat_back_cuda(src0_d, dst_d, ne00, ne01, ne02, ne0, ne1, ne2, stream); + } break; + default: { + GGML_ASSERT(false); + } break; + } +} diff --git a/ggml/src/ggml-cuda/binbcast.cuh b/ggml/src/ggml-cuda/binbcast.cuh index 198c9ef6fd8ea..3ac1c9b03fcea 100644 --- a/ggml/src/ggml-cuda/binbcast.cuh +++ b/ggml/src/ggml-cuda/binbcast.cuh @@ -5,3 +5,5 @@ void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_sub(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_repeat_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/cross-entropy-loss.cu b/ggml/src/ggml-cuda/cross-entropy-loss.cu index 5575a90f64326..ed09406a88bac 100644 --- a/ggml/src/ggml-cuda/cross-entropy-loss.cu +++ b/ggml/src/ggml-cuda/cross-entropy-loss.cu @@ -71,6 +71,32 @@ static __global__ void cross_entropy_loss_f32(const float * logits, const float dst[blockIdx.x] = loss; } +static __global__ void cross_entropy_loss_back_f32(const float * logits, const float * labels, const float * loss, float * dst, const int nclasses) { + extern __shared__ float tmp[]; + + float maxval = -INFINITY; + for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { + const float val = logits[blockIdx.x*nclasses + i]; + maxval = fmaxf(maxval, val); + tmp[i] = val; + } + maxval = warp_reduce_max(maxval); + + float sum = 0.0f; + for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { + const float val = expf(tmp[i] - maxval); + sum += val; + tmp[i] = val; + } + sum = warp_reduce_sum(sum); + const float sm_scale = 1.0f/sum; + + const float d_by_nrows = *loss/gridDim.x; + for (int i = threadIdx.x; i < nclasses; i += WARP_SIZE) { + dst[blockIdx.x*nclasses + i] = (tmp[i]*sm_scale - labels[blockIdx.x*nclasses + i])*d_by_nrows; + } +} + void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; @@ -104,3 +130,37 @@ void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * // Combine results from individual blocks: sum_f32_cuda(pool, dst_tmp.ptr, dst_d, blocks_num.x, stream); } + +void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + const ggml_tensor * opt0 = dst->src[2]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(opt0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src1)); + GGML_ASSERT(ggml_is_contiguous(opt0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + GGML_ASSERT(ggml_are_same_shape(src0, src1)); + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + + const int64_t ne00 = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + const float * opt0_d = (const float *) opt0->data; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); + + const dim3 blocks_dim(WARP_SIZE, 1, 1); + const dim3 blocks_num(nrows, 1, 1); + const int shmem = ne00*sizeof(float); + + cross_entropy_loss_back_f32<<>>(src0_d, src1_d, opt0_d, dst_d, ne00); +} diff --git a/ggml/src/ggml-cuda/cross-entropy-loss.cuh b/ggml/src/ggml-cuda/cross-entropy-loss.cuh index 9d7b8b0f0082b..9ec7152ff4518 100644 --- a/ggml/src/ggml-cuda/cross-entropy-loss.cuh +++ b/ggml/src/ggml-cuda/cross-entropy-loss.cuh @@ -3,3 +3,5 @@ #define CUDA_CROSS_ENTROPY_LOSS_BLOCK_SIZE 256 void ggml_cuda_cross_entropy_loss(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_cross_entropy_loss_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/opt-step-adamw.cu b/ggml/src/ggml-cuda/opt-step-adamw.cu new file mode 100644 index 0000000000000..d6f13a9c62df2 --- /dev/null +++ b/ggml/src/ggml-cuda/opt-step-adamw.cu @@ -0,0 +1,80 @@ +#include "opt-step-adamw.cuh" + +#include + +static __global__ void opt_step_adamw_f32( + float * __restrict__ x, const float * __restrict__ g, float * __restrict__ g_m, float * __restrict__ g_v, const int64_t k, + const float alpha, const float beta1, const float beta2, const float eps, const float wd, + const float beta1h, const float beta2h) { + + const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x; + + if (i >= k) { + return; + } + + const float gi = g[i]; + const float gmi = g_m[i]*beta1 + gi*(1.0f - beta1); + const float gvi = g_v[i]*beta2 + gi*gi*(1.0f - beta2); + + g_m[i] = gmi; + g_v[i] = gvi; + + const float mh = gmi*beta1h; + const float vh = sqrtf(gvi*beta2h) + eps; + + x[i] = x[i]*(1.0f - alpha*wd) - mh/vh; +} + +static void opt_step_adamw_f32_cuda( + float * x, const float * g, float * g_m, float * g_v, const int64_t k, + const float alpha, const float beta1, const float beta2, const float eps, const float wd, + const float beta1h, const float beta2h, cudaStream_t stream) { + + const dim3 block_dims(CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1); + const dim3 block_nums((k + CUDA_OPT_STEP_ADAMW_BLOCK_SIZE - 1) / CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1); + opt_step_adamw_f32<<>>(x, g, g_m, g_v, k, alpha, beta1, beta2, eps, wd, beta1h, beta2h); +} + +void ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src0_grad = dst->src[1]; + const ggml_tensor * src0_grad_m = dst->src[2]; + const ggml_tensor * src0_grad_v = dst->src[3]; + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src0_grad->type == GGML_TYPE_F32); + GGML_ASSERT(src0_grad_m->type == GGML_TYPE_F32); + GGML_ASSERT(src0_grad_v->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(src0_grad)); + GGML_ASSERT(ggml_is_contiguous(src0_grad_m)); + GGML_ASSERT(ggml_is_contiguous(src0_grad_v)); + GGML_ASSERT(ggml_are_same_shape(src0, src0_grad)); + GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m)); + GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v)); + + float * src0_d = (float *) src0->data; + const float * src0_grad_d = (const float *) src0_grad->data; + float * src0_grad_m_d = (float *) src0_grad_m->data; + float * src0_grad_v_d = (float *) src0_grad_v->data; + + cudaStream_t stream = ctx.stream(); + + const int64_t ne = ggml_nelements(src0); + + int64_t iter; memcpy(&iter, &dst->op_params[0], sizeof(int64_t)); + float alpha; memcpy(&alpha, &dst->op_params[2], sizeof(float)); + float beta1; memcpy(&beta1, &dst->op_params[3], sizeof(float)); + float beta2; memcpy(&beta2, &dst->op_params[4], sizeof(float)); + float eps; memcpy(&eps, &dst->op_params[5], sizeof(float)); + float wd; memcpy(&wd, &dst->op_params[6], sizeof(float)); + + const float beta1h = alpha/(1.0f - powf(beta1, iter)); + const float beta2h = 1.0f/(1.0f - powf(beta2, iter)); + + opt_step_adamw_f32_cuda(src0_d, src0_grad_d, src0_grad_m_d, src0_grad_v_d, ne, alpha, beta1, beta2, eps, wd, beta1h, beta2h, stream); + + iter++; + memcpy(&dst->op_params[0], &iter, sizeof(int64_t)); +} diff --git a/ggml/src/ggml-cuda/opt-step-adamw.cuh b/ggml/src/ggml-cuda/opt-step-adamw.cuh new file mode 100644 index 0000000000000..58d6f6e5dfc55 --- /dev/null +++ b/ggml/src/ggml-cuda/opt-step-adamw.cuh @@ -0,0 +1,5 @@ +#include "common.cuh" + +#define CUDA_OPT_STEP_ADAMW_BLOCK_SIZE 256 + +void ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/out-prod.cu b/ggml/src/ggml-cuda/out-prod.cu new file mode 100644 index 0000000000000..619cfdcb5894a --- /dev/null +++ b/ggml/src/ggml-cuda/out-prod.cu @@ -0,0 +1,51 @@ +#include "out-prod.cuh" + +#include + +void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src1 = dst->src[1]; + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_contiguous(src0)); + GGML_ASSERT(ggml_is_contiguous(dst)); + + GGML_ASSERT(ne01 == ne11); + GGML_ASSERT(ne0 == ne00); + GGML_ASSERT(ne1 == ne10); + + GGML_ASSERT(ne2 == src0->ne[2]); + GGML_ASSERT(ne2 == src1->ne[2]); + GGML_ASSERT(ne3 == src0->ne[3]); + GGML_ASSERT(ne3 == src1->ne[3]); + + const float * src0_d = (const float *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + cudaStream_t stream = ctx.stream(); + cublasHandle_t handle = ctx.cublas_handle(); + + const float alpha = 1.0f; + const float beta = 0.0f; + + GGML_ASSERT(ne2 == 1); + GGML_ASSERT(ne3 == 1); + CUBLAS_CHECK(cublasSetStream(handle, stream)); + + const bool src1_T = ggml_is_transposed(src1); + const cublasOperation_t src1_cublas_op = src1_T ? CUBLAS_OP_N : CUBLAS_OP_T; + const int64_t ldb = (src1_T ? nb10 : nb11) / sizeof(float); + GGML_ASSERT( (src1_T ? nb11 : nb10) == sizeof(float)); + + CUBLAS_CHECK( + cublasSgemm(handle, CUBLAS_OP_N, src1_cublas_op, + ne0, ne1, ne01, + &alpha, src0_d, ne00, + src1_d, ldb, + &beta, dst_d, ne0)); +} diff --git a/ggml/src/ggml-cuda/out-prod.cuh b/ggml/src/ggml-cuda/out-prod.cuh new file mode 100644 index 0000000000000..a0046f5f8f484 --- /dev/null +++ b/ggml/src/ggml-cuda/out-prod.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_out_prod(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/unary.cu b/ggml/src/ggml-cuda/unary.cu index 8ac669f94e2de..163b5a8ffec6b 100644 --- a/ggml/src/ggml-cuda/unary.cu +++ b/ggml/src/ggml-cuda/unary.cu @@ -10,6 +10,16 @@ static __global__ void neg_f32(const float * x, float * dst, const int k) { dst[i] = -x[i]; } +static __global__ void step_f32(const float * x, float * dst, const int k) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + dst[i] = x[i] > 0.0f; +} + static __global__ void gelu_f32(const float * x, float * dst, const int k) { const float GELU_COEF_A = 0.044715f; const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; @@ -134,6 +144,11 @@ static void neg_f32_cuda(const float * x, float * dst, const int k, cudaStream_t neg_f32<<>>(x, dst, k); } +static void step_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_STEP_BLOCK_SIZE - 1) / CUDA_STEP_BLOCK_SIZE; + step_f32<<>>(x, dst, k); +} + static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) { const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE; gelu_f32<<>>(x, dst, k); @@ -213,6 +228,20 @@ void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { neg_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); } +void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *)src0->data; + float * dst_d = (float *)dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + step_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream); +} + void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; const float * src0_d = (const float *)src0->data; diff --git a/ggml/src/ggml-cuda/unary.cuh b/ggml/src/ggml-cuda/unary.cuh index ed2ffc461e810..fe519f6a232df 100644 --- a/ggml/src/ggml-cuda/unary.cuh +++ b/ggml/src/ggml-cuda/unary.cuh @@ -1,6 +1,7 @@ #include "common.cuh" #define CUDA_NEG_BLOCK_SIZE 256 +#define CUDA_STEP_BLOCK_SIZE 256 #define CUDA_GELU_BLOCK_SIZE 256 #define CUDA_SILU_BLOCK_SIZE 256 #define CUDA_TANH_BLOCK_SIZE 256 @@ -15,6 +16,8 @@ void ggml_cuda_op_neg(ggml_backend_cuda_context & ctx, ggml_tensor * dst); +void ggml_cuda_op_step(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index d0c377255968c..1f3c70c2e6934 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -30,6 +30,7 @@ #define cublasSetStream hipblasSetStream #define cublasSgemm hipblasSgemm #define cublasStatus_t hipblasStatus_t +#define cublasOperation_t hipblasOperation_t #define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6 #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess diff --git a/ggml/src/ggml-kompute.cpp b/ggml/src/ggml-kompute.cpp index 7f0bd82d5de92..9cbc57a647de5 100644 --- a/ggml/src/ggml-kompute.cpp +++ b/ggml/src/ggml-kompute.cpp @@ -1872,6 +1872,7 @@ static ggml_backend_buffer_i ggml_backend_kompute_buffer_i = { /* .free_buffer = */ ggml_backend_kompute_buffer_free_buffer, /* .get_base = */ ggml_backend_kompute_buffer_get_base, /* .init_tensor = */ NULL, + /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_kompute_buffer_set_tensor, /* .get_tensor = */ ggml_backend_kompute_buffer_get_tensor, /* .cpy_tensor = */ NULL, diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index f87181d19332e..ef3b7f0e824a9 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -3167,6 +3167,7 @@ GGML_CALL static void ggml_backend_metal_buffer_clear(ggml_backend_buffer_t buff /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer, /* .get_base = */ ggml_backend_metal_buffer_get_base, /* .init_tensor = */ NULL, + /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_metal_buffer_set_tensor, /* .get_tensor = */ ggml_backend_metal_buffer_get_tensor, /* .cpy_tensor = */ ggml_backend_metal_buffer_cpy_tensor, diff --git a/ggml/src/ggml-rpc.cpp b/ggml/src/ggml-rpc.cpp index a8a2eb85adc23..49b3fa91174e2 100644 --- a/ggml/src/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc.cpp @@ -469,6 +469,7 @@ static ggml_backend_buffer_i ggml_backend_rpc_buffer_interface = { /* .free_buffer = */ ggml_backend_rpc_buffer_free_buffer, /* .get_base = */ ggml_backend_rpc_buffer_get_base, /* .init_tensor = */ ggml_backend_rpc_buffer_init_tensor, + /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_rpc_buffer_set_tensor, /* .get_tensor = */ ggml_backend_rpc_buffer_get_tensor, /* .cpy_tensor = */ ggml_backend_rpc_buffer_cpy_tensor, diff --git a/ggml/src/ggml-sycl.cpp b/ggml/src/ggml-sycl.cpp index acef7c6d4e1ea..16e6be4a0a12b 100644 --- a/ggml/src/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl.cpp @@ -4323,6 +4323,7 @@ static struct ggml_backend_buffer_i ggml_backend_sycl_buffer_interface = { /* .free_buffer = */ ggml_backend_sycl_buffer_free_buffer, /* .get_base = */ ggml_backend_sycl_buffer_get_base, /* .init_tensor = */ ggml_backend_sycl_buffer_init_tensor, + /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_sycl_buffer_set_tensor, /* .get_tensor = */ ggml_backend_sycl_buffer_get_tensor, /* .cpy_tensor = */ ggml_backend_sycl_buffer_cpy_tensor, @@ -4734,6 +4735,7 @@ static struct ggml_backend_buffer_i ggml_backend_sycl_split_buffer_interface = { /* .free_buffer = */ ggml_backend_sycl_split_buffer_free_buffer, /* .get_base = */ ggml_backend_sycl_split_buffer_get_base, /* .init_tensor = */ ggml_backend_sycl_split_buffer_init_tensor, + /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_sycl_split_buffer_set_tensor, /* .get_tensor = */ ggml_backend_sycl_split_buffer_get_tensor, /* .cpy_tensor = */ NULL, diff --git a/ggml/src/ggml-vulkan.cpp b/ggml/src/ggml-vulkan.cpp index bad960510850e..f9da45881e9df 100644 --- a/ggml/src/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan.cpp @@ -6246,6 +6246,7 @@ static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = { /* .free_buffer = */ ggml_backend_vk_buffer_free_buffer, /* .get_base = */ ggml_backend_vk_buffer_get_base, /* .init_tensor = */ ggml_backend_vk_buffer_init_tensor, + /* .memset_tensor = */ NULL, /* .set_tensor = */ ggml_backend_vk_buffer_set_tensor, /* .get_tensor = */ ggml_backend_vk_buffer_get_tensor, /* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor, diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index c5c98dbe4bf68..201d5466a0e4b 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1,6 +1,7 @@ #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows #define _USE_MATH_DEFINES // For M_PI on MSVC +#include "ggml-backend.h" #include "ggml-impl.h" #include "ggml-cpu-impl.h" #include "ggml-quants.h" @@ -2997,9 +2998,10 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS", "CROSS_ENTROPY_LOSS_BACK", + "OPT_STEP_ADAMW", }; -static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79"); +static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -3090,9 +3092,10 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss(x,y)", "cross_entropy_loss_back(x,y)", + "adamw(x)", }; -static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79"); +static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -4094,7 +4097,11 @@ static void ggml_set_op_params_f32(struct ggml_tensor * tensor, uint32_t i, floa } struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) { - memset(tensor->data, 0, ggml_nbytes(tensor)); + if (tensor->buffer) { + ggml_backend_tensor_memset(tensor, 0, 0, ggml_nbytes(tensor)); + } else { + memset(tensor->data, 0, ggml_nbytes(tensor)); + } return tensor; } @@ -8320,11 +8327,46 @@ struct ggml_tensor * ggml_cross_entropy_loss_back( return result; } -//////////////////////////////////////////////////////////////////////////////// +// opt_step_adamw -void ggml_set_param( +struct ggml_tensor * ggml_opt_step_adamw( struct ggml_context * ctx, - struct ggml_tensor * tensor) { + struct ggml_tensor * a, + float alpha, + float beta1, + float beta2, + float eps, + float wd) { + GGML_ASSERT(a->grad); + GGML_ASSERT(alpha > 0.0f); + GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f); + GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f); + GGML_ASSERT(eps >= 0.0f); + GGML_ASSERT(wd >= 0.0f && wd <= 1.0f); + + struct ggml_tensor * result = ggml_view_tensor(ctx, a); + + result->op = GGML_OP_OPT_STEP_ADAMW; + result->grad = NULL; + result->src[0] = a; + result->src[1] = a->grad; + result->src[2] = ggml_dup_tensor(ctx, a->grad); + result->src[3] = ggml_dup_tensor(ctx, a->grad); + + const int64_t iter = 1; + memcpy(&result->op_params[0], &iter, sizeof(int64_t)); + ggml_set_op_params_f32(result, 2, alpha); + ggml_set_op_params_f32(result, 3, beta1); + ggml_set_op_params_f32(result, 4, beta2); + ggml_set_op_params_f32(result, 5, eps); + ggml_set_op_params_f32(result, 6, wd); + + return result; +} + +//////////////////////////////////////////////////////////////////////////////// + +void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor) { tensor->flags |= GGML_TENSOR_FLAG_PARAM; GGML_ASSERT(tensor->grad == NULL); @@ -8332,6 +8374,13 @@ void ggml_set_param( ggml_format_name(tensor->grad, "%s (grad)", tensor->name); } +void ggml_set_loss(struct ggml_tensor * tensor) { + GGML_ASSERT(ggml_is_scalar(tensor)); + GGML_ASSERT(tensor->type == GGML_TYPE_F32); + GGML_ASSERT(tensor->grad); + tensor->flags |= GGML_TENSOR_FLAG_LOSS; +} + // ggml_compute_forward_dup static void ggml_compute_forward_dup_same_cont( @@ -17406,7 +17455,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( const int64_t ir0 = dr*ith; const int64_t ir1 = MIN(ir0 + dr, nr); - float * d = (float *) opt0->data; + const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr; for (int64_t i1 = ir0; i1 < ir1; i1++) { float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]); @@ -17430,7 +17479,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32( // grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr ggml_vec_sub_f32(nc, ds0, ds0, s1); - ggml_vec_scale_f32(nc, ds0, d[0] / (float) nr); + ggml_vec_scale_f32(nc, ds0, d_by_nr); #ifndef NDEBUG for (int i = 0; i < nc; ++i) { @@ -17459,6 +17508,94 @@ static void ggml_compute_forward_cross_entropy_loss_back( } } +static void ggml_compute_forward_opt_step_adamw_f32( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src0_grad = dst->src[1]; + const struct ggml_tensor * src0_grad_m = dst->src[2]; + const struct ggml_tensor * src0_grad_v = dst->src[3]; + GGML_ASSERT(ggml_are_same_shape(src0, src0_grad)); + + const int ith = params->ith; + const int nth = params->nth; + + const int nr = ggml_nrows(src0); + + GGML_TENSOR_UNARY_OP_LOCALS + GGML_ASSERT(nb00 == sizeof(float)); + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + /* const float gnorm = 1.0f; */ + int64_t iter; memcpy(&iter, &dst->op_params[0], sizeof(int64_t)); + const float alpha = ggml_get_op_params_f32(dst, 2); + const float beta1 = ggml_get_op_params_f32(dst, 3); + const float beta2 = ggml_get_op_params_f32(dst, 4); + const float eps = ggml_get_op_params_f32(dst, 5); + const float wd = ggml_get_op_params_f32(dst, 6); + + const float beta1h = alpha/(1.0f - powf(beta1, iter)); + const float beta2h = 1.0f/(1.0f - powf(beta2, iter)); + + for (int ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + const size_t offset = i03*nb03 + i02*nb02 + i01*nb01; + + float * w = (float *) ((char *) src0->data + offset); // weight + const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad + float * m = (float *) ((char *) src0_grad_m->data + offset); + float * v = (float *) ((char *) src0_grad_v->data + offset); + + for (int i00 = 0; i00 < ne00; ++i00) { + m[i00] = m[i00]*beta1 + g[i00]*(1.0f - beta1); + v[i00] = v[i00]*beta2 + g[i00]*g[i00]*(1.0f - beta2); + + const float mh = m[i00]*beta1h; + const float vh = sqrtf(v[i00]*beta2h) + eps; + + // The weight decay is applied independently of the Adam momenta m and v. + // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss. + // See: https://arxiv.org/pdf/1711.05101v3.pdf + w[i00] = w[i00]*(1.0f - alpha*wd) - mh/vh; + } + } + + ggml_barrier(params->threadpool); + if (ith != 0) { + return; + } + + iter++; + memcpy(&dst->op_params[0], &iter, sizeof(int64_t)); +} + +static void ggml_compute_forward_opt_step_adamw( + const struct ggml_compute_params * params, + struct ggml_tensor * dst) { + + const struct ggml_tensor * src0 = dst->src[0]; + + switch (src0->type) { + case GGML_TYPE_F32: + { + ggml_compute_forward_opt_step_adamw_f32(params, dst); + } break; + default: + { + GGML_ABORT("fatal error"); + } + } +} ///////////////////////////////// static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) { @@ -17804,6 +17941,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm ggml_compute_forward_cross_entropy_loss_back(params, tensor); } break; + case GGML_OP_OPT_STEP_ADAMW: + { + ggml_compute_forward_opt_step_adamw(params, tensor); + } + break; case GGML_OP_NONE: { // nop @@ -17958,7 +18100,7 @@ void ggml_build_backward_gradient_checkpointing( struct ggml_tensor * * checkpoints, int n_checkpoints) { ggml_graph_cpy(gf, gb_tmp); - ggml_build_backward_expand(ctx, gf, gb_tmp, true); + ggml_build_backward_expand(ctx, gf, gb_tmp, false, true); if (n_checkpoints <= 0) { ggml_graph_cpy(gb_tmp, gb); @@ -17996,42 +18138,93 @@ void ggml_build_backward_gradient_checkpointing( ggml_hash_map_free(replacements); } -// functions to change gradients considering the case that input a might be initial gradient with zero value - -static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) { +// utility functions to change gradients +// if a is in acc_table, modify gradients in-place and mark result as gradient accumulator +// else if a is in zero_table, replace a +// else, just add/subtract/etc. the gradients + +static struct ggml_tensor * ggml_add_or_set( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_hash_set * zero_table, + struct ggml_hash_set * acc_table) { + if (ggml_hash_contains(acc_table, a)) { + struct ggml_tensor * ret = ggml_add_impl(ctx, a, b, true); + const size_t insert_result = ggml_hash_insert(acc_table, ret); + GGML_ASSERT(insert_result != GGML_HASHSET_FULL); + GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS); + return ret; + } if (ggml_hash_contains(zero_table, a)) { return b; - } else { - return ggml_add_impl(ctx, a, b, false); } + return ggml_add_impl(ctx, a, b, false); } -static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, struct ggml_hash_set * zero_table) { +static struct ggml_tensor * ggml_acc_or_set( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + const size_t nb1, + const size_t nb2, + const size_t nb3, + const size_t offset, + struct ggml_hash_set * zero_table, + struct ggml_hash_set * acc_table) { + if (ggml_hash_contains(acc_table, a)) { + struct ggml_tensor * ret = ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true); + const size_t insert_result = ggml_hash_insert(acc_table, ret); + GGML_ASSERT(insert_result != GGML_HASHSET_FULL); + GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS); + return ret; + } if (ggml_hash_contains(zero_table, a)) { - struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f); + struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false); - } else { - return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false); } + return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false); } -static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) { +static struct ggml_tensor * ggml_add1_or_set( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_hash_set * zero_table, + struct ggml_hash_set * acc_table) { + if (ggml_hash_contains(acc_table, a)) { + struct ggml_tensor * ret = ggml_add1_impl(ctx, a, b, true); + const size_t insert_result = ggml_hash_insert(acc_table, ret); + GGML_ASSERT(insert_result != GGML_HASHSET_FULL); + GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS); + return ret; + } if (ggml_hash_contains(zero_table, a)) { return ggml_repeat(ctx, b, a); - } else { - return ggml_add1_impl(ctx, a, b, false); } + return ggml_add1_impl(ctx, a, b, false); } -static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) { +static struct ggml_tensor * ggml_sub_or_set( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * b, + struct ggml_hash_set * zero_table, + struct ggml_hash_set * acc_table) { + if (ggml_hash_contains(acc_table, a)) { + struct ggml_tensor * ret = ggml_sub_impl(ctx, a, b, true); + const size_t insert_result = ggml_hash_insert(acc_table, ret); + GGML_ASSERT(insert_result != GGML_HASHSET_FULL); + GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS); + return ret; + } if (ggml_hash_contains(zero_table, a)) { return ggml_neg(ctx, b); - } else { - return ggml_sub_impl(ctx, a, b, false); } + return ggml_sub_impl(ctx, a, b, false); } -static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set * zero_table) { +static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set * zero_table, struct ggml_hash_set * acc_table) { struct ggml_tensor * src0 = tensor->src[0]; struct ggml_tensor * src1 = tensor->src[1]; struct ggml_tensor * src2 = tensor->src[2]; @@ -18040,38 +18233,38 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor case GGML_OP_DUP: { if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); + src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); } } break; case GGML_OP_ADD: { if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); + src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); } if (src1->grad) { if (ggml_are_same_shape(src0, src1)) { - src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table); + src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table); } else { - src1->grad = ggml_add_or_set(ctx, src1->grad, ggml_repeat_back(ctx, tensor->grad, src1), zero_table); + src1->grad = ggml_add_or_set(ctx, src1->grad, ggml_repeat_back(ctx, tensor->grad, src1), zero_table, acc_table); } } } break; case GGML_OP_ADD1: { if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); + src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); } if (src1->grad) { src1->grad = ggml_add_or_set(ctx, src1->grad, ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean - zero_table); + zero_table, acc_table); } } break; case GGML_OP_ACC: { if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); + src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); } if (src1->grad) { const size_t nb1 = ((int32_t *) tensor->op_params)[0]; @@ -18093,16 +18286,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor ggml_reshape(ctx, ggml_cont(ctx, tensor_grad_view), src1->grad), - zero_table); + zero_table, acc_table); } } break; case GGML_OP_SUB: { if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); + src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); } if (src1->grad) { - src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table); + src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table); } } break; case GGML_OP_MUL: @@ -18112,14 +18305,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor ggml_add_or_set(ctx, src0->grad, ggml_mul(ctx, src1, tensor->grad), - zero_table); + zero_table, acc_table); } if (src1->grad) { src1->grad = ggml_add_or_set(ctx, src1->grad, ggml_mul(ctx, src0, tensor->grad), - zero_table); + zero_table, acc_table); } } break; case GGML_OP_DIV: @@ -18129,7 +18322,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor ggml_add_or_set(ctx, src0->grad, ggml_div(ctx, tensor->grad, src1), - zero_table); + zero_table, acc_table); } if (src1->grad) { src1->grad = @@ -18138,7 +18331,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor ggml_mul(ctx, tensor->grad, ggml_div(ctx, tensor, src1)), - zero_table); + zero_table, acc_table); } } break; case GGML_OP_SQR: @@ -18150,7 +18343,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor ggml_scale(ctx, ggml_mul(ctx, src0, tensor->grad), 2.0f), - zero_table); + zero_table, acc_table); } } break; case GGML_OP_SQRT: @@ -18164,7 +18357,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor tensor->grad, tensor), 0.5f), - zero_table); + zero_table, acc_table); } } break; case GGML_OP_LOG: @@ -18176,7 +18369,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor ggml_div(ctx, tensor->grad, src0), - zero_table); + zero_table, acc_table); } } break; case GGML_OP_SIN: @@ -18188,7 +18381,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor ggml_mul(ctx, tensor->grad, ggml_cos(ctx, src0)), - zero_table); + zero_table, acc_table); } } break; case GGML_OP_COS: @@ -18200,7 +18393,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor ggml_mul(ctx, tensor->grad, ggml_sin(ctx, src0)), - zero_table); + zero_table, acc_table); } } break; case GGML_OP_SUM: @@ -18210,7 +18403,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor ggml_add1_or_set(ctx, src0->grad, tensor->grad, - zero_table); + zero_table, acc_table); } } break; case GGML_OP_SUM_ROWS: @@ -18222,7 +18415,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor ggml_repeat(ctx, tensor->grad, src0->grad), - zero_table); + zero_table, acc_table); } } break; case GGML_OP_MEAN: @@ -18237,7 +18430,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src0->grad = ggml_add_or_set(ctx, src0->grad, ggml_repeat_back(ctx, tensor->grad, src0->grad), - zero_table); + zero_table, acc_table); } } break; case GGML_OP_REPEAT_BACK: @@ -18247,7 +18440,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src0->grad = ggml_add_or_set(ctx, src0->grad, ggml_repeat(ctx, tensor->grad, src0->grad), - zero_table); + zero_table, acc_table); } } break; case GGML_OP_CONCAT: @@ -18272,7 +18465,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src0->grad = ggml_add_or_set(ctx, src0->grad, ggml_rms_norm_back(ctx, src0, tensor->grad, eps), - zero_table); + zero_table, acc_table); } } break; case GGML_OP_RMS_NORM_BACK: @@ -18320,7 +18513,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor ggml_add_or_set(ctx, src0->grad, // [n,m,q1,r1] s1_tg, // [n,m,q1,r1] - zero_table); + zero_table, acc_table); } if (src1->grad) { src1->grad = @@ -18338,7 +18531,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src0, // [n,m,q1,r1] ggml_transpose(ctx, // [p,m,qq,rr] tensor->grad)), // [m,p,qq,rr] - zero_table); + zero_table, acc_table); } } break; case GGML_OP_MUL_MAT_ID: @@ -18360,7 +18553,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor ggml_add_or_set(ctx, src0->grad, ggml_scale_impl(ctx, tensor->grad, s, false), - zero_table); + zero_table, acc_table); } } break; case GGML_OP_SET: @@ -18389,7 +18582,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor tensor->grad, ggml_neg(ctx, tensor_grad_view), nb1, nb2, nb3, offset, false), - zero_table); + zero_table, acc_table); } if (src1->grad) { @@ -18399,7 +18592,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor ggml_reshape(ctx, ggml_cont(ctx, tensor_grad_view), src1->grad), - zero_table); + zero_table, acc_table); } } break; case GGML_OP_CPY: @@ -18410,7 +18603,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor // tensor = src0 * 1 + src1 * 0 if (src0->grad) { // dsrc0 = dtensor * 1 - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); + src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); } if (src1->grad) { // dsrc1 = dtensor * 0 -> noop @@ -18422,7 +18615,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor if (src0->grad) { GGML_ASSERT(ggml_is_contiguous(src0->grad)); GGML_ASSERT(ggml_is_contiguous(tensor->grad)); - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table); + src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); } } break; case GGML_OP_RESHAPE: @@ -18436,7 +18629,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor ? tensor->grad : ggml_cont(ctx, tensor->grad), src0->grad), - zero_table); + zero_table, acc_table); } } break; case GGML_OP_VIEW: @@ -18465,7 +18658,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor nb3 = (nb3 / n0) * ng; } - src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table); + src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table, acc_table); } } break; case GGML_OP_PERMUTE: @@ -18490,7 +18683,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor axes_backward[1], axes_backward[2], axes_backward[3]), - zero_table); + zero_table, acc_table); } } break; case GGML_OP_TRANSPOSE: @@ -18500,7 +18693,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src0->grad = ggml_add_or_set(ctx, src0->grad, ggml_transpose(ctx, tensor->grad), - zero_table); + zero_table, acc_table); } } break; case GGML_OP_GET_ROWS: @@ -18512,7 +18705,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor // last ggml_get_rows_back argument src0->grad is only // necessary to setup correct output shape ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad), - zero_table); + zero_table, acc_table); } if (src1->grad) { // noop @@ -18536,7 +18729,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor /* ggml_diag_mask_inf_impl() shouldn't be here */ /* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */ ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), - zero_table); + zero_table, acc_table); } } break; case GGML_OP_DIAG_MASK_ZERO: @@ -18547,7 +18740,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src0->grad = ggml_add_or_set(ctx, src0->grad, ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), - zero_table); + zero_table, acc_table); } } break; case GGML_OP_SOFT_MAX: @@ -18557,7 +18750,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src0->grad = ggml_add_or_set(ctx, src0->grad, ggml_soft_max_back(ctx, tensor->grad, tensor), - zero_table); + zero_table, acc_table); } } break; @@ -18598,7 +18791,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor attn_factor, beta_fast, beta_slow), - zero_table); + zero_table, acc_table); } } break; case GGML_OP_ROPE_BACK: @@ -18634,7 +18827,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor beta_fast, beta_slow, false), - zero_table); + zero_table, acc_table); } } break; case GGML_OP_CLAMP: @@ -18659,7 +18852,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src1->grad = ggml_add_or_set(ctx, src1->grad, ggml_im2col_back(ctx, src0, tensor->grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D), - zero_table); + zero_table, acc_table); } } break; case GGML_OP_IM2COL_BACK: @@ -18688,7 +18881,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src0->grad = ggml_add_or_set(ctx, src0->grad, ggml_pool_2d_back(ctx, tensor->grad, src0, op, k0, k1, s0, s1, p0, p1), - zero_table); + zero_table, acc_table); } } break; case GGML_OP_POOL_2D_BACK: @@ -18753,7 +18946,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src0->grad = ggml_add_or_set(ctx, src0->grad, grad_q, - zero_table); + zero_table, acc_table); } if (src1->grad) { struct ggml_tensor * view_k = ggml_view_1d(ctx, flash_grad, elem_k, offs_k); @@ -18761,7 +18954,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src1->grad = ggml_add_or_set(ctx, src1->grad, grad_k, - zero_table); + zero_table, acc_table); } if (src2->grad) { struct ggml_tensor * view_v = ggml_view_1d(ctx, flash_grad, elem_v, offs_v); @@ -18769,7 +18962,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src2->grad = ggml_add_or_set(ctx, src2->grad, grad_v, - zero_table); + zero_table, acc_table); } } break; case GGML_OP_FLASH_ATTN_BACK: @@ -18795,7 +18988,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor ggml_mul(ctx, ggml_sgn(ctx, src0), tensor->grad), - zero_table); + zero_table, acc_table); } } break; case GGML_UNARY_OP_SGN: @@ -18807,7 +19000,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor case GGML_UNARY_OP_NEG: { if (src0->grad) { - src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table); + src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); } } break; case GGML_UNARY_OP_STEP: @@ -18832,7 +19025,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor ggml_mul(ctx, ggml_step(ctx, src0), tensor->grad), - zero_table); + zero_table, acc_table); } } break; case GGML_UNARY_OP_SIGMOID: @@ -18854,7 +19047,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src0->grad = ggml_add_or_set(ctx, src0->grad, ggml_silu_back(ctx, src0, tensor->grad), - zero_table); + zero_table, acc_table); } } break; case GGML_UNARY_OP_EXP: @@ -18863,7 +19056,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src0->grad = ggml_add_or_set(ctx, src0->grad, ggml_mul(ctx, tensor, tensor->grad), - zero_table); + zero_table, acc_table); } } break; default: @@ -18893,13 +19086,17 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor src0, src1, tensor->grad), - zero_table); + zero_table, acc_table); } } break; case GGML_OP_CROSS_ENTROPY_LOSS_BACK: { GGML_ABORT("fatal error"); // not supported } + case GGML_OP_OPT_STEP_ADAMW: + { + GGML_ABORT("fatal error"); // not supported + } case GGML_OP_NONE: { // nop @@ -18989,7 +19186,7 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * ggml_build_forward_impl(cgraph, tensor, true); } -void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep) { +void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate, bool keep) { GGML_ASSERT(gf->n_nodes > 0); GGML_ASSERT(gf->grads); @@ -19005,21 +19202,35 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * } } - // remember original gradients which start with zero values + // keep tables of original gradients for replacement/accumulation logic struct ggml_hash_set zero_table = ggml_hash_set_new(gf->size); + struct ggml_hash_set acc_table = ggml_hash_set_new(gf->size); for (int i = 0; i < gf->n_nodes; i++) { - if (gf->grads[i]) { - ggml_hash_insert(&zero_table, gf->grads[i]); + struct ggml_tensor * node = gf->nodes[i]; + + if (node->grad) { + { + const size_t insert_result = ggml_hash_insert(&zero_table, node->grad); + GGML_ASSERT(insert_result != GGML_HASHSET_FULL); + GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS); + } + + // only gradients of trainable parameters should be accumulated + if (accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) { + const size_t insert_result = ggml_hash_insert(&acc_table, node->grad); + GGML_ASSERT(insert_result != GGML_HASHSET_FULL); + GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS); + } } } for (int i = gf->n_nodes - 1; i >= 0; i--) { struct ggml_tensor * node = gf->nodes[i]; - // inplace operations to add gradients are not created by ggml_compute_backward + // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation // use allocator to automatically make inplace operations if (node->grad) { - ggml_compute_backward(ctx, node, &zero_table); + ggml_compute_backward(ctx, node, &zero_table, &acc_table); } } @@ -19033,8 +19244,30 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * } ggml_hash_set_free(&zero_table); + ggml_hash_set_free(&acc_table); +} + +void ggml_build_opt_adamw( + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct ggml_cgraph * gb, + float alpha, + float beta1, + float beta2, + float eps, + float wd) { + for (int i = 0; i < gf->n_nodes; i++) { + struct ggml_tensor * node = gf->nodes[i]; + + if (node->flags & GGML_TENSOR_FLAG_PARAM) { + GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node); + struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, alpha, beta1, beta2, eps, wd); + ggml_build_forward_expand(gb, opt_step); + } + } } + static void * incr_ptr_aligned(void ** p, size_t size, size_t align) { void * ptr = *p; ptr = (void *) GGML_PAD((uintptr_t) ptr, align); @@ -19162,10 +19395,28 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) { GGML_ASSERT(cgraph->grads != NULL); for (int i = 0; i < cgraph->n_nodes; i++) { - struct ggml_tensor * grad = cgraph->grads[i]; + struct ggml_tensor * node = cgraph->nodes[i]; + + // initial gradients of loss should be 1, 0 otherwise + if (node->grad) { + if (node->flags & GGML_TENSOR_FLAG_LOSS) { + GGML_ASSERT(node->grad->buffer); + GGML_ASSERT(node->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_scalar(node)); + + const float onef = 1.0f; + ggml_backend_tensor_set(node->grad, &onef, 0, ggml_nbytes(node->grad)); + } else { + ggml_set_zero(node->grad); + } + } - if (grad) { - ggml_set_zero(grad); + GGML_ASSERT(node); + if (node->op == GGML_OP_OPT_STEP_ADAMW) { + // set iteration to 1 and clear momenta + ggml_set_op_params_i32(node, 0, 1); + ggml_set_zero(node->src[2]); + ggml_set_zero(node->src[3]); } } } @@ -19458,6 +19709,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { } break; case GGML_OP_CROSS_ENTROPY_LOSS: case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + case GGML_OP_OPT_STEP_ADAMW: { n_tasks = n_threads; } break; @@ -19753,8 +20005,8 @@ void ggml_threadpool_resume(struct ggml_threadpool * threadpool) { struct ggml_cplan ggml_graph_plan( const struct ggml_cgraph * cgraph, - int n_threads, - struct ggml_threadpool * threadpool) { + int n_threads, + struct ggml_threadpool * threadpool) { if (threadpool == NULL) { GGML_PRINT_DEBUG("Threadpool is not specified. Will create a disposable threadpool : n_threads %d\n", n_threads); @@ -21851,7 +22103,7 @@ enum ggml_opt_result ggml_opt_resume( ggml_build_forward_expand(gf, f); struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf); - ggml_build_backward_expand(ctx, gf, gb, true); + ggml_build_backward_expand(ctx, gf, gb, false, true); return ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL); } diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 3d2dfb41329fc..cf7b97d453966 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -10e83a412717c20d57ba19f025248e18e43addf3 +e7b23907cb2816e9951fe9b524d7127ab777297a diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index aa7896defdad0..889a199448a6d 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -799,10 +799,11 @@ struct test_case { out = ggml_sum(ctx, out); ggml_set_name(out, "sum_of_out"); } + ggml_set_loss(out); ggml_build_forward_expand(gf, out); ggml_graph_cpy(gf, gb); - ggml_build_backward_expand(ctx, gf, gb, false); + ggml_build_backward_expand(ctx, gf, gb, false, false); if (expect.size() != 1 || expect[0] != 0.0f) { GGML_ASSERT(ggml_graph_n_nodes(gb) > ggml_graph_n_nodes(gf)); for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { @@ -837,22 +838,11 @@ struct test_case { return false; } - // randomize tensors - initialize_tensors(ctx); - - for (struct ggml_tensor * t = ggml_get_first_tensor(ctx); t != nullptr; t = ggml_get_next_tensor(ctx, t)) { - if (!t->grad) { - continue; - } - std::vector tmp(ggml_nelements(t->grad)); - ggml_backend_tensor_set(t->grad, tmp.data(), 0, ggml_nbytes(t->grad)); - } + initialize_tensors(ctx); // Randomizes all tensors (including gradients). + ggml_graph_reset(gb); // Sets gradients to 1 if loss, 0 otherwise. - // build graphs - const float onef = 1.0f; ggml_backend_graph_compute(backend, gf); - ggml_backend_tensor_set(out->grad, &onef, 0, ggml_nbytes(out->grad)); ggml_backend_graph_compute(backend, gb); bool ok = true; @@ -1681,6 +1671,50 @@ struct test_mul_mat_id : public test_case { } }; +// GGML_OP_OUT_PROD +struct test_out_prod : public test_case { + const ggml_type type_a; + const ggml_type type_b; + const int64_t m; + const int64_t n; + const int64_t k; + const std::array bs; // dims 3 and 4 + const bool trans_b; + + std::string vars() override { + return VARS_TO_STR7(type_a, type_b, m, n, k, bs, trans_b); + } + + double max_nmse_err() override { + return 5e-4; + } + + test_out_prod(ggml_type type_a = GGML_TYPE_F32, ggml_type type_b = GGML_TYPE_F32, + int64_t m = 32, int64_t n = 32, int64_t k = 32, + std::array bs = {10, 10}, + bool trans_b = false) + : type_a(type_a), type_b(type_b), m(m), n(n), k(k), bs(bs), trans_b(trans_b) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor_4d(ctx, type_a, m, k, bs[0], bs[1]); + ggml_set_name(a, "a"); + + ggml_tensor * b; + if (trans_b) { + b = ggml_new_tensor_4d(ctx, type_b, k, n, bs[0], bs[1]); + b = ggml_transpose(ctx, b); + } else { + b = ggml_new_tensor_4d(ctx, type_b, n, k, bs[0], bs[1]); + } + ggml_set_name(b, "b"); + + ggml_tensor * out = ggml_out_prod(ctx, a, b); + ggml_set_name(out, "out"); + + return out; + } +}; + // GGML_OP_SQR struct test_sqr : public test_case { const ggml_type type; @@ -2666,6 +2700,51 @@ struct test_cross_entropy_loss : public test_case { } }; +// GGML_OP_OPT_STEP_ADAMW +struct test_opt_step_adamw : public test_case { + const ggml_type type; + const std::array ne; + const float alpha; + const float beta1; + const float beta2; + const float eps; + const float wd; + + std::string vars() override { + return VARS_TO_STR7(type, ne, alpha, beta1, beta2, eps, wd); + } + + test_opt_step_adamw(ggml_type type = GGML_TYPE_F32, + std::array ne = {10, 5, 4, 3}, + float alpha = 1e-3f, + float beta1 = 0.9f, + float beta2 = 0.999f, + float eps = 1e-8f, + float wd = 0.0f) + : type(type), ne(ne), alpha(alpha), beta1(beta1), beta2(beta2), eps(eps), wd(wd) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]); + ggml_set_param(ctx, a); // Despite tensor a having gradients the output tensor will not. + ggml_set_name(a, "a"); + + ggml_tensor * out = ggml_opt_step_adamw(ctx, a, alpha, beta1, beta2, eps, wd); + ggml_set_name(out, "out"); + + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + init_tensor_uniform(t, 0.0f, 1.0f); // grad_v needs non-negative values. + } + } + + bool grad_precise() override { + return true; + } +}; + enum llm_norm_type { LLM_NORM, LLM_NORM_RMS, @@ -3159,14 +3238,15 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_conv_transpose_1d({3,2,1,1}, {3,1,2,1}, 1, 0, 1)); test_cases.emplace_back(new test_conv_transpose_1d({2,1,1,1}, {3,1,1,1}, 1, 0, 1)); - - test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 1, 1})); - test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1})); - test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 2, 1, 1})); - test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 2, 1})); - test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, 3}, {1, 1, 1, 2})); - test_cases.emplace_back(new test_repeat(GGML_TYPE_I32, {10, 5, 4, 3}, {2, 1, 1, 1})); - test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 5, 4, 3}, {1, 1, 1, 2})); + for (int ne3 : {1, 3}) { // CUDA backwards pass only supports ne3 == 1 + test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 1})); + test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {2, 1, 1, 1})); + test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 2, 1, 1})); + test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 2, 1})); + test_cases.emplace_back(new test_repeat(GGML_TYPE_F32, {10, 5, 4, ne3}, {1, 1, 1, 2})); + test_cases.emplace_back(new test_repeat(GGML_TYPE_I32, {10, 5, 4, ne3}, {2, 1, 1, 1})); + test_cases.emplace_back(new test_repeat(GGML_TYPE_I16, {10, 5, 4, ne3}, {1, 1, 1, 2})); + } test_cases.emplace_back(new test_dup(GGML_TYPE_F32)); test_cases.emplace_back(new test_dup(GGML_TYPE_F16)); @@ -3350,6 +3430,27 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } } + for (ggml_type type_a : base_types) { + for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) { + test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, { 1, 1})); + test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 1})); + test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 1})); + test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10})); + test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10})); + test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10})); + test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 1, 16, {10, 10})); + + test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, { 1, 1})); + test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, { 1, 1}, true)); + test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 1})); + test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 1})); + test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10})); + test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10})); + test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10})); + test_cases.emplace_back(new test_out_prod(type_a, type_b, 256, 16, 16, {10, 10})); + } + } + test_cases.emplace_back(new test_sqr()); test_cases.emplace_back(new test_sqrt()); test_cases.emplace_back(new test_log()); @@ -3476,6 +3577,9 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op } test_cases.emplace_back(new test_cross_entropy_loss()); + for (float wd : {0.0f, 1e-2f}) { + test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}, 1.0f, 1e-3f, 0.9f, 0.999f, wd)); + } // these tests are disabled to save execution time, but they can be handy for debugging #if 0 diff --git a/tests/test-grad0.cpp b/tests/test-grad0.cpp index 1834c11d894b4..2ef606d2c3591 100644 --- a/tests/test-grad0.cpp +++ b/tests/test-grad0.cpp @@ -240,7 +240,7 @@ static bool check_gradient( struct ggml_cgraph * gb = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true); ggml_build_forward_expand(gf, f); ggml_graph_cpy(gf, gb); - ggml_build_backward_expand(ctx0, gf, gb, false); + ggml_build_backward_expand(ctx0, gf, gb, false, false); ggml_graph_compute_with_ctx(ctx0, gf, n_threads);