Merge pull request #33 from LuxDL/ap/common_layers
Add Common layers
avik-pal authored Jun 9, 2024
2 parents e8ab38f + fb6766c commit de0f105
Showing 18 changed files with 621 additions and 40 deletions.
20 changes: 18 additions & 2 deletions Project.toml
@@ -1,9 +1,10 @@
name = "Boltz"
uuid = "4544d5e4-abc5-4dea-817f-29e4c205d9c8"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "0.3.5"
version = "0.3.6"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Artifacts = "56f22d72-fd6d-98f1-02f0-08ddc0907c33"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
@@ -13,6 +14,8 @@ JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
LazyArtifacts = "4af54fe1-eca0-43a8-85a7-787d91b784e3"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -21,27 +24,36 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d"

[weakdeps]
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
BoltzForwardDiffExt = "ForwardDiff"
BoltzMetalheadExt = "Metalhead"
BoltzZygoteExt = "Zygote"
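For readers new to Julia's package extensions: each entry in [weakdeps] pairs with an entry in [extensions], and the named extension module loads automatically once the weak dependency is imported alongside Boltz. A minimal sketch of the mechanism (hypothetical session, Julia 1.9 or later):

using Boltz
Base.get_extension(Boltz, :BoltzZygoteExt)  # nothing; Zygote not loaded yet

using Zygote                                # loading the weak dep activates the extension
Base.get_extension(Boltz, :BoltzZygoteExt)  # now returns the extension module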

[compat]
ADTypes = "1.3"
Aqua = "0.8.7"
ArgCheck = "2.3"
Artifacts = "1.10"
ChainRulesCore = "1.24"
ComponentArrays = "0.15.13"
ConcreteStructs = "0.2.3"
ExplicitImports = "1.5"
ForwardDiff = "0.10.36"
GPUArraysCore = "0.1.6"
JLD2 = "0.4.48"
LazyArtifacts = "1.10"
Lux = "0.5.50"
LuxAMDGPU = "0.2.3"
LuxCUDA = "0.3.2"
LuxCore = "0.1.15"
LuxDeviceUtils = "0.1.21"
LuxLib = "0.3.26"
LuxTestUtils = "0.1.15"
Markdown = "1.10"
Metalhead = "0.9"
NNlib = "0.9.17"
Pkg = "1.10"
@@ -52,11 +64,14 @@ Reexport = "1.2.2"
Statistics = "1.10"
Test = "1.10"
WeightInitializers = "0.1.7"
Zygote = "0.6.70"
julia = "1.10"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
LuxLib = "82251201-b29d-42c6-8e01-566dec8acb11"
@@ -65,6 +80,7 @@ Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "ExplicitImports", "LuxAMDGPU", "LuxCUDA", "LuxLib", "LuxTestUtils", "Metalhead", "Pkg", "ReTestItems", "Test"]
test = ["Aqua", "ComponentArrays", "ExplicitImports", "ForwardDiff", "LuxAMDGPU", "LuxCUDA", "LuxLib", "LuxTestUtils", "Metalhead", "Pkg", "ReTestItems", "Test", "Zygote"]
17 changes: 17 additions & 0 deletions ext/BoltzForwardDiffExt.jl
@@ -0,0 +1,17 @@
module BoltzForwardDiffExt

using ADTypes: AutoForwardDiff
using Boltz: Boltz, Layers
using ForwardDiff: ForwardDiff

@inline Boltz._is_extension_loaded(::Val{:ForwardDiff}) = true

@inline Boltz._should_type_assert(::AbstractArray{<:ForwardDiff.Dual}) = false
@inline Boltz._should_type_assert(::ForwardDiff.Dual) = false

# Hamiltonian NN
function Layers.hamiltonian_forward(::AutoForwardDiff, model, x)
    return ForwardDiff.gradient(sum ∘ model, x)
end

end
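A usage sketch for the method above (hypothetical values; a plain closure stands in for the Lux model this internal helper normally receives):

using ADTypes: AutoForwardDiff
using Boltz, ForwardDiff

model = x -> x .^ 2                 # stand-in model; maps an array to an array
x = [1.0, 2.0, 3.0]
# gradient of x -> sum(model(x)), computed with dual numbers
Boltz.Layers.hamiltonian_forward(AutoForwardDiff(), model, x)  # ≈ [2.0, 4.0, 6.0]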
2 changes: 1 addition & 1 deletion ext/BoltzMetalheadExt.jl
@@ -5,7 +5,7 @@ using Boltz: Boltz, __maybe_initialize_model, Vision
using Lux: Lux, FromFluxAdaptor
using Metalhead: Metalhead

Boltz._is_extension_loaded(::Val{:Metalhead}) = true
@inline Boltz._is_extension_loaded(::Val{:Metalhead}) = true

function Vision.__AlexNet(; pretrained=false, kwargs...)
model = FromFluxAdaptor()(Metalhead.AlexNet().layers)
14 changes: 14 additions & 0 deletions ext/BoltzZygoteExt.jl
@@ -0,0 +1,14 @@
module BoltzZygoteExt

using ADTypes: AutoZygote
using Boltz: Boltz, Layers
using Zygote: Zygote

@inline Boltz._is_extension_loaded(::Val{:Zygote}) = true

# Hamiltonian NN
function Layers.hamiltonian_forward(::AutoZygote, model, x)
    return only(Zygote.gradient(sum ∘ model, x))
end

end
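Note the `only`: `Zygote.gradient` returns a tuple with one entry per argument, while `ForwardDiff.gradient` returns the array directly. The two backends should agree (hypothetical session mirroring the ForwardDiff sketch above):

using ADTypes: AutoZygote
using Boltz, Zygote

model = x -> x .^ 2
x = [1.0, 2.0, 3.0]
Boltz.Layers.hamiltonian_forward(AutoZygote(), model, x)  # ≈ [2.0, 4.0, 6.0]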
5 changes: 4 additions & 1 deletion src/Boltz.jl
@@ -23,6 +23,9 @@ include("utils.jl")
include("initialize.jl")
include("patch.jl")

# Basis Functions
include("basis.jl")

# Layers
include("layers/Layers.jl")

@@ -32,6 +35,6 @@ include("vision/Vision.jl")
# deprecated
include("deprecated.jl")

export Layers, Vision
export Basis, Layers, Vision

end
151 changes: 151 additions & 0 deletions src/basis.jl
@@ -0,0 +1,151 @@
module Basis

using ..Boltz: _unsqueeze1
using ChainRulesCore: ChainRulesCore, NoTangent
using ConcreteStructs: @concrete
using Markdown: @doc_str

const CRC = ChainRulesCore

# The rrules in this file are hardcoded to be used exclusively with GeneralBasisFunction
@concrete struct GeneralBasisFunction{name}
f
n::Int
end

function Base.show(io::IO, basis::GeneralBasisFunction{name}) where {name}
print(io, "Basis.$(name)(order=$(basis.n))")
end

@inline function (basis::GeneralBasisFunction{name, F})(x::AbstractArray) where {name, F}
return basis.f.(1:(basis.n), _unsqueeze1(x))
end
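The broadcast adds a leading basis dimension: assuming `_unsqueeze1` reshapes `x` from size `(d1, ...)` to `(1, d1, ...)` (inferred from its name), broadcasting `f` against the column `1:n` yields an output of size `(n, d1, ...)`. For example, with the `Chebyshev` basis defined below:

x = rand(4, 3)                  # (features, batch)
y = Boltz.Basis.Chebyshev(5)(x)
size(y)                         # (5, 4, 3): basis index first, original dims after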

@doc doc"""
    Chebyshev(n)

Constructs a Chebyshev basis of the form $[T_{0}(x), T_{1}(x), \dots, T_{n-1}(x)]$ where
$T_j(.)$ is the $j^{th}$ Chebyshev polynomial of the first kind.

## Arguments

  - `n`: number of terms in the polynomial expansion.
"""
Chebyshev(n) = GeneralBasisFunction{:Chebyshev}(__chebyshev, n)

@inline __chebyshev(i, x) = @fastmath cos(i * acos(x))
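`__chebyshev` relies on the identity $T_i(\cos\theta) = \cos(i\theta)$, valid for $x \in [-1, 1]$. Note that the call operator broadcasts over `i = 1:n`, so the first entry is $\cos(\arccos x) = x = T_1(x)$. A small check (hypothetical values):

basis = Boltz.Basis.Chebyshev(3)
basis([0.5])   # 3×1 matrix ≈ [0.5, -0.5, -1.0], i.e. T₁, T₂, T₃ at x = 0.5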

@doc doc"""
    Sin(n)

Constructs a sine basis of the form $[\sin(x), \sin(2x), \dots, \sin(nx)]$.

## Arguments

  - `n`: number of terms in the sine expansion.
"""
Sin(n) = GeneralBasisFunction{:Sin}(@fastmath(sin∘*), n)

@doc doc"""
    Cos(n)

Constructs a cosine basis of the form $[\cos(x), \cos(2x), \dots, \cos(nx)]$.

## Arguments

  - `n`: number of terms in the cosine expansion.
"""
Cos(n) = GeneralBasisFunction{:Cos}(@fastmath(cos∘*), n)
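In both constructors, `sin∘*` (resp. `cos∘*`) composes the trig function with multiplication, so the `i`-th basis entry is `sin(i * x)` (resp. `cos(i * x)`). For illustration (hypothetical values):

basis = Boltz.Basis.Sin(2)
basis([pi / 2])   # 2×1 matrix ≈ [1.0, 0.0]: sin(π/2), sin(π)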

@doc doc"""
    Fourier(n)

Constructs a Fourier basis of the form

$$F_j(x) = \begin{cases}
    \cos\left(\frac{j}{2} x\right) & \text{if } j \text{ is even} \\
    \sin\left(\frac{j}{2} x\right) & \text{if } j \text{ is odd}
\end{cases}$$

giving $[F_1(x), F_2(x), \dots, F_n(x)]$.

## Arguments

  - `n`: number of terms in the Fourier expansion.
"""
Fourier(n) = GeneralBasisFunction{:Fourier}(__fourier, n)

@inline @fastmath function __fourier(i, x::AbstractFloat)
s, c = sincos(i * x / 2)
return ifelse(iseven(i), c, s)
end

@inline function __fourier(i, x) # No FastMath for non abstract floats
s, c = sincos(i * x / 2)
return ifelse(iseven(i), c, s)
end
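So for `i = 1:4` the entries are $\sin(x/2), \cos(x), \sin(3x/2), \cos(2x)$; the separate method without `@fastmath` presumably keeps the function usable with non-`AbstractFloat` numbers such as `ForwardDiff.Dual`. A quick check (hypothetical values):

basis = Boltz.Basis.Fourier(4)
basis([float(π)])   # 4×1 matrix ≈ [1.0, -1.0, -1.0, 1.0]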

@fastmath function CRC.rrule(::typeof(Broadcast.broadcasted), ::typeof(__fourier), i, x)
ix_by_2 = @. i * x / 2
s = @. sin(ix_by_2)
c = @. cos(ix_by_2)
y = @. ifelse(iseven(i), c, s)

∇fourier = let s = s, c = c, i = i
Δ -> begin
return (NoTangent(), NoTangent(), NoTangent(),
dropdims(sum((i / 2) .* ifelse.(iseven.(i), -s, c) .* Δ; dims=1); dims=1))
end
end

return y, ∇fourier
end
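The pullback comes from differentiating each branch with respect to `x`:

$$\frac{\partial F_j}{\partial x} = \begin{cases} -\tfrac{j}{2}\,\sin\left(\tfrac{j}{2}x\right) & \text{if } j \text{ is even} \\ \tfrac{j}{2}\,\cos\left(\tfrac{j}{2}x\right) & \text{if } j \text{ is odd} \end{cases}$$

which is exactly `(i / 2) .* ifelse.(iseven.(i), -s, c)`. The `sum(...; dims=1)` accumulates the cotangent over the leading basis dimension that broadcasting introduced, and `sincos` lets the forward pass compute both branches in one call.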

@doc doc"""
    Legendre(n)

Constructs a Legendre basis of the form $[P_{0}(x), P_{1}(x), \dots, P_{n-1}(x)]$ where
$P_j(.)$ is the $j^{th}$ Legendre polynomial.

## Arguments

  - `n`: number of terms in the polynomial expansion.
"""
Legendre(n) = GeneralBasisFunction{:Legendre}(__legendre_poly, n)

## Source: https://github.com/ranocha/PolynomialBases.jl/blob/master/src/legendre.jl
@inline function __legendre_poly(i, x)
p = i - 1
a = one(x)
b = x

    p ≤ 0 && return a
p == 1 && return b

for j in 2:p
a, b = b, @fastmath(((2j - 1) * x * b - (j - 1) * a)/j)
end

return b
end
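The loop is Bonnet's recurrence for the Legendre polynomials,

$$j\,P_j(x) = (2j - 1)\,x\,P_{j-1}(x) - (j - 1)\,P_{j-2}(x),$$

with `a` and `b` carrying $P_{j-2}$ and $P_{j-1}$. Since `p = i - 1` and the call operator broadcasts over `i = 1:n`, the basis evaluates $P_0$ through $P_{n-1}$, matching the docstring.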

@doc doc"""
    Polynomial(n)

Constructs a Polynomial basis of the form $[1, x, \dots, x^{n-1}]$.

## Arguments

  - `n`: number of terms in the polynomial expansion.
"""
Polynomial(n) = GeneralBasisFunction{:Polynomial}(__polynomial, n)

@inline __polynomial(i, x) = x^(i - 1)

function CRC.rrule(::typeof(Broadcast.broadcasted), ::typeof(__polynomial), i, x)
y_m1 = x .^ (i .- 2)
y = y_m1 .* x
∇polynomial = let y_m1 = y_m1, i = i
Δ -> begin
return (NoTangent(), NoTangent(), NoTangent(),
dropdims(sum((i .- 1) .* y_m1 .* Δ; dims=1); dims=1))
end
end
return y, ∇polynomial
end
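Here the primal is $y_j = x^{j-1}$ with derivative $\partial y_j / \partial x = (j - 1)\,x^{j-2}$, so the rule caches `y_m1 = x^(i-2)` once and reuses it for both the primal (`y_m1 .* x`) and the pullback. One caveat worth noting: for `i = 1` at `x = 0`, `y_m1` is `Inf` and the primal becomes `Inf * 0 = NaN`, whereas the plain `__polynomial` broadcast returns `1` there.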

end
22 changes: 18 additions & 4 deletions src/layers/Layers.jl
@@ -4,21 +4,35 @@ using PrecompileTools: @recompile_invalidations

@recompile_invalidations begin
using ArgCheck: @argcheck
using ..Boltz: _fast_chunk
using ADTypes: AutoForwardDiff, AutoZygote
using ..Boltz: Boltz, _fast_chunk, _should_type_assert, _stack
using ConcreteStructs: @concrete
using ChainRulesCore: ChainRulesCore
using Lux: Lux
using LuxCore: LuxCore, AbstractExplicitLayer
using Lux: Lux, StatefulLuxLayer
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
using LuxDeviceUtils: get_device, LuxCPUDevice, LuxCUDADevice
using Markdown: @doc_str
using NNlib: NNlib
using Random: AbstractRNG
using WeightInitializers: zeros32, randn32
end

const CRC = ChainRulesCore

include("conv_norm_act.jl")
const NORM_LAYER_DOC = "Function with signature `f(i::Integer, dims::Integer, act::F; \
                        kwargs...)`. `i` is the location of the layer in the model, \
                        `dims` is the channel dimension of the input, and `act` is the \
                        activation function. `kwargs` are forwarded from the `norm_kwargs` \
                        input. The function should return a normalization layer. Defaults \
                        to `nothing`, which means no normalization layer is used."
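A norm-layer constructor matching this signature might look like the following (hypothetical sketch built on Lux's `BatchNorm`):

# `i` is unused here, but the signature lets callers vary the layer by position
norm_layer(i::Integer, dims::Integer, act=identity; kwargs...) = Lux.BatchNorm(dims, act; kwargs...)

which would then be supplied through a layer's `norm_layer` option, with `kwargs` coming from `norm_kwargs`.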

include("attention.jl")
include("conv_norm_act.jl")
include("encoder.jl")
include("embeddings.jl")
include("hamiltonian.jl")
include("mlp.jl")
include("spline.jl")
include("tensor_product.jl")

end