128 changes: 128 additions & 0 deletions pkg/backend/cloudflare/cloudflare.go
@@ -0,0 +1,128 @@
package cloudflareai

import (
	"bytes"
	"encoding/json"
	"fmt"
	"log/slog"
	"net/http"
	"strings"

	"github.com/labstack/echo/v4"
	"github.com/moeru-ai/unspeech/pkg/apierrors"
	"github.com/moeru-ai/unspeech/pkg/backend/types"
	"github.com/moeru-ai/unspeech/pkg/utils"
	"github.com/samber/lo"
	"github.com/samber/mo"
)

// CloudflareSpeechRequest defines the payload for the Workers AI TTS API.
// Based on official documentation for models like @cf/myshell-ai/melotts.
type CloudflareSpeechRequest struct {
	Text string `json:"text"`
	// 'lang' is another potential field for models that support it.
	// Lang string `json:"lang,omitempty"`
}

// HandleSpeechCloudflare processes a TTS request using the Cloudflare Workers AI API.
// It requires the Cloudflare Account ID to be passed in.
func HandleSpeechCloudflare(c echo.Context, accountID string, options mo.Option[types.SpeechRequestOptions]) mo.Result[any] {

critical

The function signature for HandleSpeechCloudflare takes an accountID parameter, which is inconsistent with other backend handlers and makes it difficult to integrate. Configuration like account IDs should typically be managed via environment variables for better portability and consistency. I'd recommend changing the function signature to remove the accountID parameter.

You can then add the following code at the beginning of the function (you'll need to import os):

accountID := os.Getenv("CLOUDFLARE_ACCOUNT_ID")
if accountID == "" {
    return mo.Err[any](apierrors.NewErrInternal().WithDetail("CLOUDFLARE_ACCOUNT_ID environment variable not set").WithCaller())
}

Also, remember to register this new backend in pkg/backend/backend.go so it can be used.

Suggested change
- func HandleSpeechCloudflare(c echo.Context, accountID string, options mo.Option[types.SpeechRequestOptions]) mo.Result[any] {
+ func HandleSpeechCloudflare(c echo.Context, options mo.Option[types.SpeechRequestOptions]) mo.Result[any] {

	// Extract options once. Note that MustGet panics if options is absent,
	// so the caller must guarantee that a value is present.
	opt := options.MustGet()
Comment thread
nekomeowww marked this conversation as resolved.

	// --- 1. Select Model ---
	// Choose a Cloudflare TTS model.
	// You could make this dynamic based on opt.Model if you map your internal
	// model names (e.g., "tts-1") to Cloudflare's model names.
	//
	// Available models include:
	// - @cf/myshell-ai/melotts (supports 'text' and 'lang' params)
	// - @cf/deepgram/aura-1 (supports 'text' param)
	//
	// We'll use @cf/myshell-ai/melotts as an example.
	const modelName = "@cf/myshell-ai/melotts"

medium

The modelName is hardcoded, which limits the flexibility of this backend. As the comment on the preceding lines suggests, this should be dynamic. You can use the opt.Model field from the request options and fall back to a default model if it's not provided.

Suggested change
- const modelName = "@cf/myshell-ai/melotts"
+ modelName := opt.Model
+ if modelName == "" {
+ 	modelName = "@cf/myshell-ai/melotts"
+ }

Member

Let's follow this.
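
One way the fallback could be combined with the name mapping mentioned in the code comment is sketched below. The `resolveModel` helper and the `modelAliases` table are hypothetical; in particular, mapping `"tts-1"` to melotts is an illustrative assumption, not something this PR or the project defines.

```go
package main

import (
	"fmt"
	"strings"
)

// defaultModel is the fallback used when the request does not name a model.
const defaultModel = "@cf/myshell-ai/melotts"

// modelAliases is a hypothetical table translating OpenAI-style model names
// into Cloudflare model names. The single entry is an assumption for
// illustration only.
var modelAliases = map[string]string{
	"tts-1": "@cf/myshell-ai/melotts",
}

// resolveModel returns the Cloudflare model for a requested name: the default
// for an empty name, the mapped name for a known alias, and otherwise the
// requested name unchanged.
func resolveModel(requested string) string {
	requested = strings.TrimSpace(requested)
	if requested == "" {
		return defaultModel
	}
	if mapped, ok := modelAliases[requested]; ok {
		return mapped
	}
	return requested
}

func main() {
	for _, name := range []string{"", "tts-1", "@cf/deepgram/aura-1"} {
		fmt.Printf("%q -> %s\n", name, resolveModel(name))
	}
}
```

Passing unknown names through unchanged lets callers request any Cloudflare model directly, and an invalid name then surfaces as an upstream error rather than being silently replaced.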


	// --- 2. Build Cloudflare Payload ---
	// Note: Cloudflare's TTS models (like melotts) do not support the
	// 'voice', 'speed', or 'response_format' parameters from the OpenAI API.
	// The input text field is 'text' (or 'prompt' for some models), not 'input'.
	values := CloudflareSpeechRequest{
		Text: opt.Input,
	}
	payload := lo.Must(json.Marshal(values))

	// --- 3. Build HTTP Request ---
	// The endpoint format is:
	// https://api.cloudflare.com/client/v4/accounts/{ACCOUNT_ID}/ai/run/{MODEL_NAME}
	endpoint := fmt.Sprintf(
		"https://api.cloudflare.com/client/v4/accounts/%s/ai/run/%s",
		accountID,
		modelName,
	)

	req, err := http.NewRequestWithContext(
		c.Request().Context(),
		http.MethodPost,
		endpoint,
		bytes.NewBuffer(payload),
	)
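	if err != nil {
		// Include the underlying error so that request-construction failures are diagnosable.
		return mo.Err[any](apierrors.NewErrInternal().WithDetail(err.Error()).WithCaller())
	}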

	// Set Headers
	// The Authorization header must contain a Cloudflare API Token (Bearer)
	req.Header.Set("Authorization", c.Request().Header.Get("Authorization"))
	req.Header.Set("Content-Type", "application/json")
	// Requesting a specific audio format is good practice, though models
	// often default to mp3.
	req.Header.Set("Accept", "audio/mpeg")

	// --- 4. Execute Request ---
	res, err := http.DefaultClient.Do(req)
Comment thread
nekomeowww marked this conversation as resolved.
	if err != nil {
		return mo.Err[any](
			apierrors.NewErrBadGateway().
				WithDetail(err.Error()).
				WithError(err).
				WithCaller(),
		)
	}

	defer func() { _ = res.Body.Close() }()

	// --- 5. Handle Errors ---
	if res.StatusCode >= http.StatusBadRequest {
		ct := res.Header.Get("Content-Type")

		switch {
		case strings.HasPrefix(ct, "application/json"):
			// Cloudflare errors are returned as JSON
			return mo.Err[any](
				apierrors.NewUpstreamError(res.StatusCode).
					WithDetail(utils.NewJSONResponseError(res.StatusCode, res.Body).OrEmpty().Error()),
			)
		case strings.HasPrefix(ct, "text/"):
			return mo.Err[any](
				apierrors.NewUpstreamError(res.StatusCode).
					WithDetail(utils.NewTextResponseError(res.StatusCode, res.Body).OrEmpty().Error()),
			)
		default:
			slog.Warn("unknown upstream error",
				slog.Int("status", res.StatusCode),
				slog.String("content_type", ct),
				slog.String("content_length", res.Header.Get("Content-Length")),
			)

			return mo.Err[any](
				apierrors.NewUpstreamError(res.StatusCode).
					WithDetail("unknown Content-Type: " + ct),
			)
		}
	}

	// --- 6. Stream Successful Audio Response ---
	// On success, Cloudflare returns the raw audio stream directly in the body.
	// The Content-Type (e.g., "audio/mpeg") is correctly proxied.
	return mo.Ok[any](c.Stream(http.StatusOK, res.Header.Get("Content-Type"), res.Body))

high

The return value of c.Stream is an error. By wrapping it in mo.Ok[any](...), you are treating a potential error as a success case, which will hide streaming failures. The error should be checked and handled. Since headers will have already been sent at this point, you cannot send a JSON error response, but logging the error is important for debugging.

	if err := c.Stream(http.StatusOK, res.Header.Get("Content-Type"), res.Body); err != nil {
		slog.ErrorContext(c.Request().Context(), "failed to stream response", "err", err)
	}
	return mo.Ok[any](nil)

}