Commit
Merge pull request #33 from LuxDL/ap/common_layers
Add Common layers
avik-pal committed Jun 9, 2024
2 parents e8ab38f + fb6766c commit cc56c00
Showing 18 changed files with 621 additions and 40 deletions.
20 changes: 18 additions & 2 deletions Project.toml
@@ -1,9 +1,10 @@
name = "Boltz"
uuid = "4544d5e4-abc5-4dea-817f-29e4c205d9c8"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.3.5"
version = "0.3.6"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -13,6 +14,8 @@ JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -21,27 +24,36 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"

[weakdeps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
BoltzForwardDiffExt = "ForwardDiff"
BoltzMetalheadExt = "Metalhead"
BoltzZygoteExt = "Zygote"

[compat]
ADTypes = "1.3"
Aqua = "0.8.7"
ArgCheck = "2.3"
Artifacts = "1.10"
ChainRulesCore = "1.24"
ComponentArrays = "0.15.13"
ConcreteStructs = "0.2.3"
ExplicitImports = "1.5"
ForwardDiff = "0.10.36"
GPUArraysCore = "0.1.6"
JLD2 = "0.4.48"
LazyArtifacts = "1.10"
Lux = "0.5.50"
LuxAMDGPU = "0.2.3"
LuxCUDA = "0.3.2"
LuxCore = "0.1.15"
LuxDeviceUtils = "0.1.21"
LuxLib = "0.3.26"
LuxTestUtils = "0.1.15"
Markdown = "1.10"
Metalhead = "0.9"
NNlib = "0.9.17"
Pkg = "1.10"
@@ -52,11 +64,14 @@ Reexport = "1.2.2"
Statistics = "1.10"
Test = "1.10"
WeightInitializers = "0.1.7"
Zygote = "0.6.70"
julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
@@ -65,6 +80,7 @@ Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "ExplicitImports", "LuxAMDGPU", "LuxCUDA", "LuxLib", "LuxTestUtils", "Metalhead", "Pkg", "ReTestItems", "Test"]
test = ["Aqua", "ComponentArrays", "ExplicitImports", "ForwardDiff", "LuxAMDGPU", "LuxCUDA", "LuxLib", "LuxTestUtils", "Metalhead", "Pkg", "ReTestItems", "Test", "Zygote"]
17 changes: 17 additions & 0 deletions ext/BoltzForwardDiffExt.jl
@@ -0,0 +1,17 @@
module BoltzForwardDiffExt

using ADTypes: AutoForwardDiff
using Boltz: Boltz, Layers
using ForwardDiff: ForwardDiff

@inline Boltz._is_extension_loaded(::Val{:ForwardDiff}) = true

@inline Boltz._should_type_assert(::AbstractArray{<:ForwardDiff.Dual}) = false
@inline Boltz._should_type_assert(::ForwardDiff.Dual) = false

# Hamiltonian NN
function Layers.hamiltonian_forward(::AutoForwardDiff, model, x)
return ForwardDiff.gradient(sum ∘ model, x)
end

end
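
For context, `Layers.hamiltonian_forward` with `AutoForwardDiff` computes the gradient of the scalar `sum(model(x))` with respect to `x`. A minimal standalone sketch of that computation, using a hypothetical toy function in place of an actual Lux layer:

using ForwardDiff

# Hypothetical stand-in for a Hamiltonian network: maps a state vector to a
# one-element array, so `sum` extracts the scalar energy.
toy_model(x) = [0.5 * sum(abs2, x)]

x = randn(4)
# Mirrors the extension above: the gradient of sum(toy_model(x)) w.r.t. x.
g = ForwardDiff.gradient(sum ∘ toy_model, x)
g ≈ x  # true: the gradient of 0.5 * ||x||^2 is x itself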
2 changes: 1 addition & 1 deletion ext/BoltzMetalheadExt.jl
@@ -5,7 +5,7 @@ using Boltz: Boltz, __maybe_initialize_model, Vision
using Lux: Lux, FromFluxAdaptor
using Metalhead: Metalhead

Boltz._is_extension_loaded(::Val{:Metalhead}) = true
@inline Boltz._is_extension_loaded(::Val{:Metalhead}) = true

function Vision.__AlexNet(; pretrained=false, kwargs...)
model = FromFluxAdaptor()(Metalhead.AlexNet().layers)
14 changes: 14 additions & 0 deletions ext/BoltzZygoteExt.jl
@@ -0,0 +1,14 @@
module BoltzZygoteExt

using ADTypes: AutoZygote
using Boltz: Boltz, Layers
using Zygote: Zygote

@inline Boltz._is_extension_loaded(::Val{:Zygote}) = true

# Hamiltonian NN
function Layers.hamiltonian_forward(::AutoZygote, model, x)
return only(Zygote.gradient(sum ∘ model, x))
end

end
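
The Zygote path fulfils the same contract with reverse-mode AD; `Zygote.gradient` returns a one-element tuple, hence the `only`. The same hypothetical toy model as in the ForwardDiff sketch above:

using Zygote

toy_model(x) = [0.5 * sum(abs2, x)]

x = randn(4)
# `only` unwraps the single-element tuple returned by Zygote.gradient.
g = only(Zygote.gradient(sum ∘ toy_model, x))
g ≈ x  # true, matching the ForwardDiff result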
5 changes: 4 additions & 1 deletion src/Boltz.jl
@@ -23,6 +23,9 @@ include("utils.jl")
include("initialize.jl")
include("patch.jl")

# Basis Functions
include("basis.jl")

# Layers
include("layers/Layers.jl")

@@ -32,6 +35,6 @@ include("vision/Vision.jl")
# deprecated
include("deprecated.jl")

export Layers, Vision
export Basis, Layers, Vision

end
151 changes: 151 additions & 0 deletions src/basis.jl
@@ -0,0 +1,151 @@
module Basis

using ..Boltz: _unsqueeze1
using ChainRulesCore: ChainRulesCore, NoTangent
using ConcreteStructs: @concrete
using Markdown: @doc_str

const CRC = ChainRulesCore

# The rrules in this file are hardcoded to be used exclusively with GeneralBasisFunction
@concrete struct GeneralBasisFunction{name}
f
n::Int
end

function Base.show(io::IO, basis::GeneralBasisFunction{name}) where {name}
print(io, "Basis.$(name)(order=$(basis.n))")
end

@inline function (basis::GeneralBasisFunction{name, F})(x::AbstractArray) where {name, F}
return basis.f.(1:(basis.n), _unsqueeze1(x))
end

@doc doc"""
Chebyshev(n)
Constructs a Chebyshev basis of the form $[T_{0}(x), T_{1}(x), \dots, T_{n-1}(x)]$ where
$T_j(.)$ is the $j^{th}$ Chebyshev polynomial of the first kind.
## Arguments
- `n`: number of terms in the polynomial expansion.
"""
Chebyshev(n) = GeneralBasisFunction{:Chebyshev}(__chebyshev, n)

@inline __chebyshev(i, x) = @fastmath cos(i * acos(x))

@doc doc"""
Sin(n)
Constructs a sine basis of the form $[\sin(x), \sin(2x), \dots, \sin(nx)]$.
## Arguments
- `n`: number of terms in the sine expansion.
"""
Sin(n) = GeneralBasisFunction{:Sin}(@fastmath(sin∘*), n)

@doc doc"""
Cos(n)
Constructs a cosine basis of the form $[\cos(x), \cos(2x), \dots, \cos(nx)]$.
## Arguments
- `n`: number of terms in the cosine expansion.
"""
Cos(n) = GeneralBasisFunction{:Cos}(@fastmath(cos∘*), n)

@doc doc"""
Fourier(n)
Constructs a Fourier basis of the form
$F_j(x) = \begin{cases} \cos\left(\frac{j}{2}x\right) & \text{if } j \text{ is even} \\ \sin\left(\frac{j}{2}x\right) & \text{if } j \text{ is odd} \end{cases}$, giving $[F_0(x), F_1(x), \dots, F_n(x)]$.
## Arguments
- `n`: number of terms in the Fourier expansion.
"""
Fourier(n) = GeneralBasisFunction{:Fourier}(__fourier, n)

@inline @fastmath function __fourier(i, x::AbstractFloat)
s, c = sincos(i * x / 2)
return ifelse(iseven(i), c, s)
end

@inline function __fourier(i, x) # No FastMath for non abstract floats
s, c = sincos(i * x / 2)
return ifelse(iseven(i), c, s)
end

@fastmath function CRC.rrule(::typeof(Broadcast.broadcasted), ::typeof(__fourier), i, x)
ix_by_2 = @. i * x / 2
s = @. sin(ix_by_2)
c = @. cos(ix_by_2)
y = @. ifelse(iseven(i), c, s)

∇fourier = let s = s, c = c, i = i
Δ -> begin
return (NoTangent(), NoTangent(), NoTangent(),
dropdims(sum((i / 2) .* ifelse.(iseven.(i), -s, c) .* Δ; dims=1); dims=1))
end
end

return y, ∇fourier
end

@doc doc"""
Legendre(n)
Constructs a Legendre basis of the form $[P_{0}(x), P_{1}(x), \dots, P_{n-1}(x)]$ where
$P_j(.)$ is the $j^{th}$ Legendre polynomial.
## Arguments
- `n`: number of terms in the polynomial expansion.
"""
Legendre(n) = GeneralBasisFunction{:Legendre}(__legendre_poly, n)

## Source: https://github.com/ranocha/PolynomialBases.jl/blob/master/src/legendre.jl
@inline function __legendre_poly(i, x)
p = i - 1
a = one(x)
b = x

p ≤ 0 && return a
p == 1 && return b

for j in 2:p
a, b = b, @fastmath(((2j - 1) * x * b - (j - 1) * a)/j)
end

return b
end

@doc doc"""
Polynomial(n)
Constructs a Polynomial basis of the form $[1, x, \dots, x^{n-1}]$.
## Arguments
- `n`: number of terms in the polynomial expansion.
"""
Polynomial(n) = GeneralBasisFunction{:Polynomial}(__polynomial, n)

@inline __polynomial(i, x) = x^(i - 1)

function CRC.rrule(::typeof(Broadcast.broadcasted), ::typeof(__polynomial), i, x)
y_m1 = x .^ (i .- 2)
y = y_m1 .* x
∇polynomial = let y_m1 = y_m1, i = i
Δ -> begin
return (NoTangent(), NoTangent(), NoTangent(),
dropdims(sum((i .- 1) .* y_m1 .* Δ; dims=1); dims=1))
end
end
return y, ∇polynomial
end

end
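
Since `GeneralBasisFunction` broadcasts `basis.f` over `1:n` against the input with a leading singleton dimension added by `_unsqueeze1`, applying a basis to a length-`m` vector yields an `n × m` array. A short usage sketch under that assumption (`Basis` is exported from Boltz per the change to `src/Boltz.jl` above):

using Boltz

cheb = Basis.Chebyshev(4)
x = collect(range(-1.0, 1.0; length=5))
y = cheb(x)            # 4×5 array: row i holds the i-th basis term evaluated at each x
size(y) == (4, 5)      # true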
22 changes: 18 additions & 4 deletions src/layers/Layers.jl
Expand Up @@ -4,21 +4,35 @@ using PrecompileTools: @recompile_invalidations

@recompile_invalidations begin
using ArgCheck: @argcheck
using ..Boltz: _fast_chunk
using ADTypes: AutoForwardDiff, AutoZygote
using ..Boltz: Boltz, _fast_chunk, _should_type_assert, _stack
using ConcreteStructs: @concrete
using ChainRulesCore: ChainRulesCore
using Lux: Lux
using LuxCore: LuxCore, AbstractExplicitLayer
using Lux: Lux, StatefulLuxLayer
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
using LuxDeviceUtils: get_device, LuxCPUDevice, LuxCUDADevice
using Markdown: @doc_str
using NNlib: NNlib
using Random: AbstractRNG
using WeightInitializers: zeros32, randn32
end

const CRC = ChainRulesCore

include("conv_norm_act.jl")
const NORM_LAYER_DOC = "Function with signature `f(i::Integer, dims::Integer, act::F; \
kwargs...)`. `i` is the location of the layer in the model, \
`dims` is the channel dimension of the input, and `act` is the \
activation function. `kwargs` are forwarded from the `norm_kwargs` \
input. The function should return a normalization layer. Defaults \
to `nothing`, which means no normalization layer is used."

include("attention.jl")
include("conv_norm_act.jl")
include("encoder.jl")
include("embeddings.jl")
include("hamiltonian.jl")
include("mlp.jl")
include("spline.jl")
include("tensor_product.jl")

end
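
The `NORM_LAYER_DOC` string above documents the `norm_layer` callback contract shared by the new layers. A hedged sketch of a callback matching that signature (the use of `BatchNorm` here is illustrative, not necessarily what any Boltz layer defaults to):

using Lux

# `i` (the position of the layer in the model) is unused in this example;
# `dims` is the channel dimension and `act` the activation, both forwarded
# to the normalization layer along with any keyword arguments.
my_norm_layer(i::Integer, dims::Integer, act; kwargs...) = BatchNorm(dims, act; kwargs...)

# Passed to a layer constructor via its `norm_layer` keyword, e.g. norm_layer=my_norm_layer.
my_norm_layer(1, 16, relu)   # BatchNorm(16, relu)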

2 comments on commit cc56c00

@avik-pal
Member Author

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/108604

Tip: Release Notes

Did you know you can add release notes too? Just add markdown formatted text underneath the comment after the text
"Release notes:" and it will be added to the registry PR, and if TagBot is installed it will also be added to the
release that TagBot creates. i.e.

@JuliaRegistrator register

Release notes:

## Breaking changes

- blah

To add them here just re-invoke and the PR will be updated.

Tagging

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the GitHub interface, or via:

git tag -a v0.3.6 -m "<description of version>" cc56c005af4071aeb84d8575a9c86d47fcf2113f
git push origin v0.3.6
