diff --git a/src/Flowfusion.jl b/src/Flowfusion.jl
index 5dfc5b5..a40e118 100644
--- a/src/Flowfusion.jl
+++ b/src/Flowfusion.jl
@@ -39,10 +39,14 @@ export
     NoisyInterpolatingDiscreteFlow,
     DoobMatchingFlow,
     OUFlow,
+    VPFlow,
+    CosineVPSchedule,
     MaskedState,
     Guide,
     tangent_guide,
     bridge,
+    vp_alpha_bar,
+    vp_bridge_coefficients,
     scalefloss,
     gen,
     Tracker,
diff --git a/src/loss.jl b/src/loss.jl
index 2dcc8c2..0b01e7a 100644
--- a/src/loss.jl
+++ b/src/loss.jl
@@ -31,6 +31,7 @@ floss(P::fbu(Deterministic), X̂₁, X₁::msu(ContinuousState),
 floss(P::fbu(BrownianMotion), X̂₁, X₁::msu(ContinuousState), c) = scaledmaskedmean(mse(X̂₁, X₁), c, getlmask(X₁))
 floss(P::OUFlow, X̂₁, X₁::msu(ContinuousState), c) = scaledmaskedmean(mse(X̂₁, X₁), c, getlmask(X₁))
 floss(P::OUBridgeExpVar, X̂₁, X₁::msu(ContinuousState), c) = scaledmaskedmean(mse(X̂₁, X₁), c, getlmask(X₁))
+floss(P::fbu(VPFlow), X̂₁, X₁::msu(ContinuousState), c) = scaledmaskedmean(mse(X̂₁, X₁), c, getlmask(X₁)) #<-Same masked, scaled MSE on X₁ as the continuous-state methods above.
 floss(P::fbu(ManifoldProcess{<:Euclidean}), X̂₁, X₁::msu(ContinuousState), c) = scaledmaskedmean(mse(X̂₁, X₁), c, getlmask(X₁))
 #floss(P::fbu(OrnsteinUhlenbeck), X̂₁, X₁::msu(ContinuousState), c) = scaledmaskedmean(mse(X̂₁, X₁), c, getlmask(X₁)) #<- I'm not sure MSE on X1 works for this process. We need to pull X1 back to Xt and get the generator.
 #For a discrete process, X̂₁ will be a distribution, and X₁ will have to be a onehot before going onto the gpu.
diff --git a/src/processes.jl b/src/processes.jl
index ef82f8b..c5bbf60 100644
--- a/src/processes.jl
+++ b/src/processes.jl
@@ -2,6 +2,137 @@
 #For processes that aren't used elsewhere
 ##########################################
 
+#Validate that a scalar flow time lies in [0, 1].
+function _vp_check_flow_time(t::Real)
+    0 <= t <= 1 || throw(ArgumentError("flow time must be in [0, 1], got $t"))
+    return nothing
+end
+
+#Array method: every entry must lie in [0, 1].
+function _vp_check_flow_time(t::AbstractArray)
+    all((0 .<= t) .& (t .<= 1)) ||
+        throw(ArgumentError("all flow times must be in [0, 1]"))
+    return nothing
+end
+
+#The bridge is conditioned on the clean endpoint, so the endpoint time(s) must equal 1
+#up to an absolute tolerance of sqrt(eps).
+function _vp_check_clean_endpoint(t)
+    vals = t isa Real ? (t,) : t
+    all(x -> isapprox(float(x), 1.0; atol=sqrt(eps(float(x)))), vals) ||
+        throw(ArgumentError("VPFlow expects endpoint time 1"))
+    return nothing
+end
+
+#Machine epsilon of the float type associated with a scalar, an array, or a pair.
+_vp_eps(x::Real) = eps(typeof(float(one(typeof(x)))))
+_vp_eps(x::AbstractArray) = eps(typeof(float(one(eltype(x)))))
+_vp_eps(x, y) = max(_vp_eps(x), _vp_eps(y))
+
+"""
+    (S::CosineVPSchedule)(t)
+
+Evaluate the cosine schedule at flow time `t`, elementwise for array `t`:
+`cos(((1 - t) * n_timestep / (n_timestep + 1)) * pi / 2)^2`.
+"""
+function (S::CosineVPSchedule)(t)
+    diffusion_index = (1 .- t) .* S.n_timestep
+    angle = diffusion_index ./ (S.n_timestep + 1) .* (pi / 2)
+    return cos.(angle) .^ 2
+end
+
+"""
+    vp_alpha_bar(P::VPFlow, t)
+
+Return the cumulative signal power of the VP schedule at flow time `t`.
+Flow time is oriented so that `t=0` is noisiest and `t=1` is the clean endpoint.
+
+Throws an `ArgumentError` if `t` is outside `[0, 1]` or if the schedule returns a
+value outside `[0, 1]`.
+"""
+function vp_alpha_bar(P::VPFlow, t)
+    _vp_check_flow_time(t)
+    alpha_bar = P.alpha_bar.(t)
+    all((0 .<= alpha_bar) .& (alpha_bar .<= 1)) ||
+        throw(ArgumentError("VP alpha_bar schedule must return values in [0, 1]"))
+    return alpha_bar
+end
+
+"""
+    vp_bridge_coefficients(P::VPFlow, s, t)
+
+Coefficients for the exact endpoint-conditioned transition `x_t | x_s, x_1`
+under the VP schedule, for flow times `0 <= s <= t <= 1`.
+
+Returns `(coef_x1, coef_xs, variance)` such that
+`x_t = coef_x1 * x_1 + coef_xs * x_s + sqrt(variance) * z`.
+The inputs may be scalars or broadcast-compatible arrays.
+
+Times with `abs(t - s) <= sqrt(eps)` are treated as equal and yield the identity
+transition `(0, 1, 0)`, returned in the same floating-point type as the other
+coefficients.
+"""
+function vp_bridge_coefficients(P::VPFlow, s, t)
+    _vp_check_flow_time(s)
+    _vp_check_flow_time(t)
+    eps_t = _vp_eps(s, t)
+    all(s .<= t .+ sqrt(eps_t)) ||
+        throw(ArgumentError("expected s <= t for x_t | x_s, x_1"))
+
+    A_s = vp_alpha_bar(P, s)
+    A_t = vp_alpha_bar(P, t)
+    eps_a = _vp_eps(A_s, A_t)
+    same_time = abs.(t .- s) .<= sqrt(eps_t)
+    all(same_time .| (A_s .<= A_t)) ||
+        throw(ArgumentError("VP alpha_bar schedule must be nondecreasing"))
+    denom = max.(1 .- A_s, eps_a)
+    ratio_raw = A_s ./ max.(A_t, eps_a)
+    #Build the s == t branch with one./zero. so it has the element type of the
+    #generic branch; integer literals in ifelse. give Int scalars or Real arrays.
+    ratio = ifelse.(same_time, one.(ratio_raw), clamp.(ratio_raw, 0, 1))
+
+    raw_x1 = sqrt.(A_t) .* (1 .- ratio) ./ denom
+    raw_xs = sqrt.(ratio) .* (1 .- A_t) ./ denom
+    raw_var = (1 .- A_t) .* (1 .- ratio) ./ denom
+    coef_x1 = ifelse.(same_time, zero.(raw_x1), raw_x1)
+    coef_xs = ifelse.(same_time, one.(raw_xs), raw_xs)
+    variance = ifelse.(same_time, zero.(raw_var), max.(raw_var, 0))
+    return coef_x1, coef_xs, variance
+end
+
+#Draw X(t_b) from the Gaussian bridge x_t | x_s, x_1 with s = t_a and t = t_b;
+#Xc is the clean endpoint, so t_c must be 1.
+function ForwardBackward.endpoint_conditioned_sample(
+    Xa::ContinuousState,
+    Xc::ContinuousState,
+    P::VPFlow,
+    t_a,
+    t_b,
+    t_c,
+)::ContinuousState
+    size(Xa.state) == size(Xc.state) ||
+        throw(DimensionMismatch("Xa and Xc must have the same state shape"))
+    _vp_check_clean_endpoint(t_c)
+
+    xa = Xa.state
+    xc = Xc.state
+    nd = ndims(xa)
+    ta = expand(t_a, nd)
+    tb = expand(t_b, nd)
+
+    #vp_bridge_coefficients validates range and ordering of ta/tb, using the same
+    #sqrt(eps) tolerance that defines its same-time branch.
+    coef_x1, coef_xa, variance = vp_bridge_coefficients(P, ta, tb)
+    T = eltype(xa)
+    coef_x1 = T.(coef_x1)
+    coef_xa = T.(coef_xa)
+    variance = max.(T.(variance), zero(T))
+    mu = coef_x1 .* xc .+ coef_xa .* xa
+    noise = randn(T, size(xa))
+    xb = mu .+ sqrt.(variance) .* noise
+    return ContinuousState(xb)
+end
+
 ##########################################
 #https://arxiv.org/pdf/2407.15595
 ##########################################
@@ -134,4 +265,3 @@ function bridge(P::OUFlow, X0, X1, t0, t)
     OU = OrnsteinUhlenbeckExpVar(tensor(X1), P.θ, P.v_at_0, P.v_at_1, dec = P.dec) #<-Note X1 as mean
     endpoint_conditioned_sample(X0, X1, OU, t0, t, eltype(t)(1))
 end
-
diff --git a/src/types.jl b/src/types.jl
index e3c30b1..704e75d 100644
--- a/src/types.jl
+++ b/src/types.jl
@@ -73,4 +73,37 @@ struct OUFlow{T} <: Process
     dec::T
 end
 
-OUFlow(θ::T, v_at_0::T) where T = OUFlow(θ, v_at_0, T(1e-2), T(-0.1))
\ No newline at end of file
+OUFlow(θ::T, v_at_0::T) where T = OUFlow(θ, v_at_0, T(1e-2), T(-0.1))
+
+"""
+    CosineVPSchedule(n_timestep)
+    CosineVPSchedule()
+
+Cosine cumulative signal-power schedule for `VPFlow`. Flow time runs from `0`
+(maximally noised) to `1` (clean endpoint), with
+`alpha_bar(t) = cos(((1 - t) * n_timestep / (n_timestep + 1)) * pi / 2)^2`.
+"""
+struct CosineVPSchedule
+    n_timestep::Int
+    function CosineVPSchedule(n_timestep::Integer=1000)
+        n_timestep > 0 || throw(ArgumentError("n_timestep must be positive"))
+        return new(Int(n_timestep))
+    end
+end
+
+"""
+    VPFlow(alpha_bar)
+    VPFlow()
+
+Endpoint-conditioned flow induced by a variance-preserving diffusion schedule.
+`alpha_bar` is any callable cumulative signal-power schedule with values in
+`[0, 1]`, increasing from noisy flow time `0` to clean flow time `1`.
+
+`VPFlow()` uses `CosineVPSchedule()` and recovers the cosine VP bridge currently
+used by Branching Genie.
+"""
+struct VPFlow{S} <: ForwardBackward.ContinuousProcess
+    alpha_bar::S
+end
+
+VPFlow() = VPFlow(CosineVPSchedule())
diff --git a/test/runtests.jl b/test/runtests.jl
index 5a4cb67..a75fac4 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -53,6 +53,7 @@ using ForwardBackward
         XR() = ManifoldState(MR, [rand(MR) for _ in zeros(siz...)])
 
         for (f,p) in [(XC, BrownianMotion()),
+                      (XC, VPFlow()),
                       (XT, ManifoldProcess(1)),
                       (XR, ManifoldProcess(1)),
                       (XD, InterpolatingDiscreteFlow())]
@@ -84,4 +85,62 @@ using ForwardBackward
 
         end
     end
+
+    @testset "VP flow" begin
+        P = VPFlow(CosineVPSchedule(1000))
+
+        @test isapprox(vp_alpha_bar(P, 1.0), 1.0; atol=1e-12)
+        @test vp_alpha_bar(P, 0.0) < 3e-6
+        @test 0.49 < vp_alpha_bar(P, 0.5) < 0.51
+        @test_throws ArgumentError vp_alpha_bar(P, -0.01)
+        @test_throws ArgumentError vp_alpha_bar(P, 1.01)
+        @test vp_alpha_bar(VPFlow(t -> t), 0.25) == 0.25
+
+        for (s, t, u) in ((0.0, 0.2, 0.7), (0.05, 0.4, 0.95), (0.33, 0.66, 1.0))
+            a_st, b_st, v_st = vp_bridge_coefficients(P, s, t)
+            a_tu, b_tu, v_tu = vp_bridge_coefficients(P, t, u)
+            a_su, b_su, v_su = vp_bridge_coefficients(P, s, u)
+
+            @test isapprox(b_tu * b_st, b_su; rtol=1e-10, atol=1e-10)
+            @test isapprox(a_tu + b_tu * a_st, a_su; rtol=1e-10, atol=1e-10)
+            @test isapprox((b_tu^2) * v_st + v_tu, v_su; rtol=1e-10, atol=1e-10)
+        end
+
+        a_eq, b_eq, v_eq = vp_bridge_coefficients(P, 0.4, 0.4)
+        @test (a_eq, b_eq, v_eq) == (0.0, 1.0, 0.0)
+        @test all(c -> c isa Float64, (a_eq, b_eq, v_eq))
+        mixed = vp_bridge_coefficients(P, [0.2, 0.4], [0.4, 0.4])
+        @test all(c -> eltype(c) == Float64, mixed)
+
+        Xa = ContinuousState(randn(Float32, 2, 3, 4))
+        X1 = ContinuousState(randn(Float32, 2, 3, 4))
+
+        Xt = bridge(P, Xa, X1, 0.2f0)
+        @test size(tensor(Xt)) == size(tensor(Xa))
+
+        Xsame = bridge(P, Xt, X1, 0.4f0, 0.4f0)
+        @test tensor(Xsame) ≈ tensor(Xt)
+
+        t0 = Float32[0.0, 0.1, 0.2, 0.3]
+        t1 = Float32[0.5, 0.6, 0.7, 0.8]
+        Xvec = bridge(P, Xa, X1, t0, t1)
+        @test size(tensor(Xvec)) == size(tensor(Xa))
+
+        Xclean = bridge(P, Xa, X1, 0.3, 1.0)
+        @test tensor(Xclean) ≈ tensor(X1)
+
+        @test_throws ArgumentError bridge(P, Xa, X1, 0.5, 0.4)
+        @test_throws ArgumentError ForwardBackward.endpoint_conditioned_sample(Xa, X1, P, 0.0, 0.5, 0.9)
+
+        P_linear = VPFlow(t -> t)
+        Xt0 = bridge(P_linear, Xa, X1, 0.0, 0.0)
+        @test tensor(Xt0) ≈ tensor(Xa)
+        Xlin = bridge(P_linear, Xa, X1, 0.0, 0.6)
+        @test size(tensor(Xlin)) == size(tensor(Xa))
+        Xlin_clean = bridge(P_linear, Xa, X1, 0.25, 1.0)
+        @test tensor(Xlin_clean) ≈ tensor(X1)
+
+        P_bad = VPFlow(t -> 1 - t)
+        @test_throws ArgumentError bridge(P_bad, Xa, X1, 0.2, 0.4)
+    end
 end