Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/Flowfusion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,14 @@ export
NoisyInterpolatingDiscreteFlow,
DoobMatchingFlow,
OUFlow,
VPFlow,
CosineVPSchedule,
MaskedState,
Guide,
tangent_guide,
bridge,
vp_alpha_bar,
vp_bridge_coefficients,
scalefloss,
gen,
Tracker,
Expand Down
1 change: 1 addition & 0 deletions src/loss.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ floss(P::fbu(Deterministic), X̂₁, X₁::msu(ContinuousState),
# Continuous-state losses. Each method below evaluates the same expression:
# `mse(X̂₁, X₁)` passed to `scaledmaskedmean` together with the scale `c` and the
# mask returned by `getlmask(X₁)`. The methods differ only in the process type they
# dispatch on. (`fbu`, `msu`, `mse`, `scaledmaskedmean` and `getlmask` are defined
# elsewhere in the package.)
floss(P::fbu(BrownianMotion), X̂₁, X₁::msu(ContinuousState), c) = scaledmaskedmean(mse(X̂₁, X₁), c, getlmask(X₁))
floss(P::OUFlow, X̂₁, X₁::msu(ContinuousState), c) = scaledmaskedmean(mse(X̂₁, X₁), c, getlmask(X₁))
floss(P::OUBridgeExpVar, X̂₁, X₁::msu(ContinuousState), c) = scaledmaskedmean(mse(X̂₁, X₁), c, getlmask(X₁))
# VP flow: the prediction X̂₁ is compared against X₁ with the same masked MSE.
floss(P::fbu(VPFlow), X̂₁, X₁::msu(ContinuousState), c) = scaledmaskedmean(mse(X̂₁, X₁), c, getlmask(X₁))
floss(P::fbu(ManifoldProcess{<:Euclidean}), X̂₁, X₁::msu(ContinuousState), c) = scaledmaskedmean(mse(X̂₁, X₁), c, getlmask(X₁))
#floss(P::fbu(OrnsteinUhlenbeck), X̂₁, X₁::msu(ContinuousState), c) = scaledmaskedmean(mse(X̂₁, X₁), c, getlmask(X₁)) #<- I'm not sure MSE on X1 works for this process. We need to pull X1 back to Xt and get the generator.
#For a discrete process, X̂₁ will be a distribution, and X₁ will have to be a onehot before going onto the gpu.
Expand Down
108 changes: 107 additions & 1 deletion src/processes.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,113 @@
#For processes that aren't used elsewhere
##########################################

# Scalar flow-time guard: accept `t` only when it lies in the closed interval
# [0, 1]; otherwise throw an `ArgumentError` that reports the offending value.
# A NaN fails the chained comparison and is therefore rejected as well.
function _vp_check_flow_time(t::Real)
    if !(0 <= t <= 1)
        throw(ArgumentError("flow time must be in [0, 1], got $t"))
    end
    return nothing
end

# Array flow-time guard: every element of `t` must lie in the closed interval
# [0, 1]. Throws a single `ArgumentError` if any element falls outside it
# (an empty array passes).
function _vp_check_flow_time(t::AbstractArray)
    within_unit = all(x -> 0 <= x <= 1, t)
    if !within_unit
        throw(ArgumentError("all flow times must be in [0, 1]"))
    end
    return nothing
end

# Endpoint guard: the conditioning time must be (numerically) the clean time 1.
# A value passes when, after `float` conversion, it is within `sqrt(eps)` of 1.0,
# with `eps` taken at that converted value. `t` may be a scalar or a collection;
# for a collection every element has to pass.
function _vp_check_clean_endpoint(t)
    function near_one(x)
        fx = float(x)
        return isapprox(fx, 1.0; atol=sqrt(eps(fx)))
    end
    at_endpoint = t isa Real ? near_one(t) : all(near_one, t)
    if !at_endpoint
        throw(ArgumentError("VPFlow expects endpoint time 1"))
    end
    return nothing
end

# Machine epsilon of the floating-point type associated with the argument:
# the type of `float(one(T))`, where `T` is the scalar's type or the array's
# element type. The two-argument form returns the larger (coarser) of the two.
_vp_eps(::T) where {T<:Real} = eps(typeof(float(one(T))))
_vp_eps(::AbstractArray{T}) where {T} = eps(typeof(float(one(T))))
function _vp_eps(x, y)
    return max(_vp_eps(x), _vp_eps(y))
end

"""
vp_alpha_bar(P::VPFlow, t)

Return the cumulative signal power of the VP schedule at flow time `t`.
Flow time is oriented so that `t=0` is noisiest and `t=1` is the clean endpoint.
"""
function (S::CosineVPSchedule)(t)
diffusion_index = (1 .- t) .* S.n_timestep
angle = diffusion_index ./ (S.n_timestep + 1) .* (pi / 2)
return cos.(angle) .^ 2
end

function vp_alpha_bar(P::VPFlow, t)
_vp_check_flow_time(t)
alpha_bar = P.alpha_bar.(t)
all((0 .<= alpha_bar) .& (alpha_bar .<= 1)) ||
throw(ArgumentError("VP alpha_bar schedule must return values in [0, 1]"))
return alpha_bar
end

"""
vp_bridge_coefficients(P::VPFlow, s, t)

Coefficients for the exact endpoint-conditioned transition `x_t | x_s, x_1`
under the VP schedule, for flow times `0 <= s <= t <= 1`.

Returns `(coef_x1, coef_xs, variance)` such that
`x_t = coef_x1 * x_1 + coef_xs * x_s + sqrt(variance) * z`.
The inputs may be scalars or broadcast-compatible arrays.
"""
function vp_bridge_coefficients(P::VPFlow, s, t)
_vp_check_flow_time(s)
_vp_check_flow_time(t)
eps_t = _vp_eps(s, t)
all(s .<= t .+ sqrt(eps_t)) ||
throw(ArgumentError("expected s <= t for x_t | x_s, x_1"))

A_s = vp_alpha_bar(P, s)
A_t = vp_alpha_bar(P, t)
eps_a = _vp_eps(A_s, A_t)
same_time = abs.(t .- s) .<= sqrt(eps_t)
all(same_time .| (A_s .<= A_t)) ||
throw(ArgumentError("VP alpha_bar schedule must be nondecreasing"))
denom = max.(1 .- A_s, eps_a)
ratio_raw = A_s ./ max.(A_t, eps_a)
ratio = ifelse.(same_time, 1, clamp.(ratio_raw, 0, 1))

coef_x1 = ifelse.(same_time, 0, sqrt.(A_t) .* (1 .- ratio) ./ denom)
coef_xs = ifelse.(same_time, 1, sqrt.(ratio) .* (1 .- A_t) ./ denom)
variance = ifelse.(same_time, 0, (1 .- A_t) .* (1 .- ratio) ./ denom)
return coef_x1, coef_xs, max.(variance, 0)
end

# Draw X_b ~ p(x_{t_b} | x_{t_a} = Xa, x_{t_c} = Xc) for a `VPFlow`.
#
# `Xc` is the clean endpoint, so `t_c` must be (numerically) 1. `t_a` and `t_b`
# are passed through `expand` to the rank of the state before being validated
# and handed to `vp_bridge_coefficients`. The coefficients are converted to the
# state's element type, and the sample is `mean + sqrt(variance) * z` with
# standard-normal `z`.
#
# Throws `DimensionMismatch` when the two states differ in shape and
# `ArgumentError` for an endpoint time other than 1, out-of-range times, or
# `t_a > t_b`.
function ForwardBackward.endpoint_conditioned_sample(
    Xa::ContinuousState,
    Xc::ContinuousState,
    P::VPFlow,
    t_a,
    t_b,
    t_c,
)::ContinuousState
    if size(Xa.state) != size(Xc.state)
        throw(DimensionMismatch("Xa and Xc must have the same state shape"))
    end
    _vp_check_clean_endpoint(t_c)

    start_state = Xa.state
    clean_state = Xc.state
    rank = ndims(start_state)
    t_from = expand(t_a, rank)
    t_to = expand(t_b, rank)
    _vp_check_flow_time(t_from)
    _vp_check_flow_time(t_to)
    if !all(t_from .<= t_to)
        throw(ArgumentError("expected t_a <= t_b"))
    end

    w_clean, w_start, var_b = vp_bridge_coefficients(P, t_from, t_to)
    F = eltype(start_state)
    w_clean = F.(w_clean)
    w_start = F.(w_start)
    # Clip again after the element-type conversion.
    var_b = max.(F.(var_b), zero(F))
    mean_b = w_clean .* clean_state .+ w_start .* start_state
    # NOTE(review): `randn(F, dims...)` returns a plain `Array`; if states can
    # live on an accelerator, this draw happens on the host — confirm.
    z = randn(F, size(start_state)...)
    sample_b = mean_b .+ sqrt.(var_b) .* z
    return ContinuousState(sample_b)
end

##########################################
#https://arxiv.org/pdf/2407.15595
##########################################
Expand Down Expand Up @@ -134,4 +241,3 @@ function bridge(P::OUFlow, X0, X1, t0, t)
OU = OrnsteinUhlenbeckExpVar(tensor(X1), P.θ, P.v_at_0, P.v_at_1, dec = P.dec) #<-Note X1 as mean
endpoint_conditioned_sample(X0, X1, OU, t0, t, eltype(t)(1))
end

35 changes: 34 additions & 1 deletion src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,4 +73,37 @@ struct OUFlow{T} <: Process
dec::T
end

OUFlow(θ::T, v_at_0::T) where T = OUFlow(θ, v_at_0, T(1e-2), T(-0.1))
OUFlow(θ::T, v_at_0::T) where T = OUFlow(θ, v_at_0, T(1e-2), T(-0.1))

"""
CosineVPSchedule(n_timestep)
CosineVPSchedule()

Cosine cumulative signal-power schedule for `VPFlow`. Flow time runs from `0`
(maximally noised) to `1` (clean endpoint), with
`alpha_bar(t) = cos(((1 - t) * n_timestep / (n_timestep + 1)) * pi / 2)^2`.
"""
struct CosineVPSchedule
n_timestep::Int
function CosineVPSchedule(n_timestep::Integer=1000)
n_timestep > 0 || throw(ArgumentError("n_timestep must be positive"))
return new(Int(n_timestep))
end
end

"""
VPFlow(alpha_bar)
VPFlow()

Endpoint-conditioned flow induced by a variance-preserving diffusion schedule.
`alpha_bar` is any callable cumulative signal-power schedule with values in
`[0, 1]`, increasing from noisy flow time `0` to clean flow time `1`.

`VPFlow()` uses `CosineVPSchedule()` and recovers the cosine VP bridge currently
used by Branching Genie.
"""
struct VPFlow{S} <: ForwardBackward.ContinuousProcess
alpha_bar::S
end

VPFlow() = VPFlow(CosineVPSchedule())
53 changes: 53 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ using ForwardBackward
XR() = ManifoldState(MR, [rand(MR) for _ in zeros(siz...)])

for (f,p) in [(XC, BrownianMotion()),
(XC, VPFlow()),
(XT, ManifoldProcess(1)),
(XR, ManifoldProcess(1)),
(XD, InterpolatingDiscreteFlow())]
Expand Down Expand Up @@ -84,4 +85,56 @@ using ForwardBackward
end

end

@testset "VP flow" begin
P = VPFlow(CosineVPSchedule(1000))

@test isapprox(vp_alpha_bar(P, 1.0), 1.0; atol=1e-12)
@test vp_alpha_bar(P, 0.0) < 3e-6
@test 0.49 < vp_alpha_bar(P, 0.5) < 0.51
@test_throws ArgumentError vp_alpha_bar(P, -0.01)
@test_throws ArgumentError vp_alpha_bar(P, 1.01)
@test vp_alpha_bar(VPFlow(t -> t), 0.25) == 0.25

for (s, t, u) in ((0.0, 0.2, 0.7), (0.05, 0.4, 0.95), (0.33, 0.66, 1.0))
a_st, b_st, v_st = vp_bridge_coefficients(P, s, t)
a_tu, b_tu, v_tu = vp_bridge_coefficients(P, t, u)
a_su, b_su, v_su = vp_bridge_coefficients(P, s, u)

@test isapprox(b_tu * b_st, b_su; rtol=1e-10, atol=1e-10)
@test isapprox(a_tu + b_tu * a_st, a_su; rtol=1e-10, atol=1e-10)
@test isapprox((b_tu^2) * v_st + v_tu, v_su; rtol=1e-10, atol=1e-10)
end

Xa = ContinuousState(randn(Float32, 2, 3, 4))
X1 = ContinuousState(randn(Float32, 2, 3, 4))

Xt = bridge(P, Xa, X1, 0.2f0)
@test size(tensor(Xt)) == size(tensor(Xa))

Xsame = bridge(P, Xt, X1, 0.4f0, 0.4f0)
@test tensor(Xsame) ≈ tensor(Xt)

t0 = Float32[0.0, 0.1, 0.2, 0.3]
t1 = Float32[0.5, 0.6, 0.7, 0.8]
Xvec = bridge(P, Xa, X1, t0, t1)
@test size(tensor(Xvec)) == size(tensor(Xa))

Xclean = bridge(P, Xa, X1, 0.3, 1.0)
@test tensor(Xclean) ≈ tensor(X1)

@test_throws ArgumentError bridge(P, Xa, X1, 0.5, 0.4)
@test_throws ArgumentError ForwardBackward.endpoint_conditioned_sample(Xa, X1, P, 0.0, 0.5, 0.9)

P_linear = VPFlow(t -> t)
Xt0 = bridge(P_linear, Xa, X1, 0.0, 0.0)
@test tensor(Xt0) ≈ tensor(Xa)
Xlin = bridge(P_linear, Xa, X1, 0.0, 0.6)
@test size(tensor(Xlin)) == size(tensor(Xa))
Xlin_clean = bridge(P_linear, Xa, X1, 0.25, 1.0)
@test tensor(Xlin_clean) ≈ tensor(X1)

P_bad = VPFlow(t -> 1 - t)
@test_throws ArgumentError bridge(P_bad, Xa, X1, 0.2, 0.4)
end
end
Loading