From 9883b39075d90d40dac5915820d885579e7c9f86 Mon Sep 17 00:00:00 2001 From: dboyliao <6830390+dboyliao@users.noreply.github.com> Date: Fri, 8 May 2026 23:28:46 +0800 Subject: [PATCH 1/7] =?UTF-8?q?=E2=9C=A8=20feat(xfundingv2):=20add=20isFut?= =?UTF-8?q?ures=20flag=20to=20TWAPExecutor=20and=20implement=20JSON=20mars?= =?UTF-8?q?haling/unmarshaling?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../xfundingv2/twap_order_executor.go | 83 +++++----- .../xfundingv2/twap_order_executor_sync.go | 78 +++++++++ .../twap_order_executor_sync_test.go | 148 ++++++++++++++++++ .../xfundingv2/twap_order_executor_test.go | 9 +- 4 files changed, 274 insertions(+), 44 deletions(-) create mode 100644 pkg/strategy/xfundingv2/twap_order_executor_sync.go create mode 100644 pkg/strategy/xfundingv2/twap_order_executor_sync_test.go diff --git a/pkg/strategy/xfundingv2/twap_order_executor.go b/pkg/strategy/xfundingv2/twap_order_executor.go index 6986a50ca1..fb6a8c11bc 100644 --- a/pkg/strategy/xfundingv2/twap_order_executor.go +++ b/pkg/strategy/xfundingv2/twap_order_executor.go @@ -13,38 +13,37 @@ import ( ) type TWAPExecutor struct { - TWAPWorkerConfig - - mu sync.Mutex - ctx context.Context + mu sync.Mutex + ctx context.Context exchange types.ExchangeOrderQueryService - market types.Market executor *bbgo.GeneralOrderExecutor - orders map[uint64]types.OrderQuery - trades map[uint64]types.Trade + syncState TWAPExecutorSyncState logger logrus.FieldLogger } -func NewTWAPOrderExecutor( +func NewTWAPExecutor( ctx context.Context, exchange types.ExchangeOrderQueryService, + isFutures bool, market types.Market, executor *bbgo.GeneralOrderExecutor, config TWAPWorkerConfig, ) *TWAPExecutor { return &TWAPExecutor{ - TWAPWorkerConfig: config, - ctx: ctx, exchange: exchange, executor: executor, - market: market, - orders: make(map[uint64]types.OrderQuery), - trades: make(map[uint64]types.Trade), + syncState: TWAPExecutorSyncState{ + Config: config, + Market: market, + IsFutures: isFutures, + Orders: make(map[uint64]types.OrderQuery), + Trades: make(map[uint64]types.Trade), + }, } } @@ -52,12 +51,16 @@ func (o *TWAPExecutor) SetLogger(logger logrus.FieldLogger) { o.logger = logger } +func (o *TWAPExecutor) Market() types.Market { + return o.syncState.Market +} + func (o *TWAPExecutor) Start() { if o.logger == nil { o.logger = logrus.WithFields( logrus.Fields{ "component": "TWAPOrderExecutor", - "symbol": o.market.Symbol, + "symbol": o.syncState.Market.Symbol, }, ) } @@ -69,8 +72,8 @@ func (o *TWAPExecutor) AddTrade(trade types.Trade) { o.mu.Lock() defer o.mu.Unlock() - if _, exists := o.orders[trade.OrderID]; exists { - o.trades[trade.ID] = trade + if _, exists := o.syncState.Orders[trade.OrderID]; exists { + o.syncState.Trades[trade.ID] = trade } } @@ -121,7 +124,7 @@ func (o *TWAPExecutor) SyncOrder(order types.Order) error { } func (o *TWAPExecutor) GetOrder(orderID uint64) (types.Order, bool) { - if _, exists := o.orders[orderID]; !exists { + if _, exists := o.syncState.Orders[orderID]; !exists { return types.Order{}, false } return o.executor.OrderStore().Get(orderID) @@ -130,7 +133,7 @@ func (o *TWAPExecutor) GetOrder(orderID uint64) (types.Order, bool) { func (o *TWAPExecutor) AllOrders() []types.Order { var orders []types.Order - for orderID := range o.orders { + for orderID := range o.syncState.Orders { if order, exists := o.executor.OrderStore().Get(orderID); exists { orders = append(orders, order) } @@ -140,7 +143,7 @@ func (o *TWAPExecutor) AllOrders() []types.Order { func (o *TWAPExecutor) AllTrades() []types.Trade { var trades []types.Trade - for _, trade := range o.trades { + for _, trade := range o.syncState.Trades { trades = append(trades, trade) } return trades @@ -149,15 +152,15 @@ func (o *TWAPExecutor) AllTrades() []types.Trade { // place order func (o *TWAPExecutor) PlaceOrder(quantity fixedpoint.Value, side types.SideType, orderBook types.OrderBook, deadlineExceeded bool) (*types.Order, error) { // find the better price and submit new order - quantity = o.market.TruncateQuantity(quantity) + quantity = o.syncState.Market.TruncateQuantity(quantity) price, err := o.GetPrice(side, orderBook) if err != nil { o.logger.WithError(err).Warn("[TWAP tick] failed to get price for active order update") return nil, err } - price = o.market.TruncatePrice(price) + price = o.syncState.Market.TruncatePrice(price) order := o.buildSubmitOrder(quantity, price, side, deadlineExceeded) - if o.market.IsDustQuantity(order.Quantity, order.Price) { + if o.syncState.Market.IsDustQuantity(order.Quantity, order.Price) { return nil, fmt.Errorf("order is of dust quantity: %s", quantity) } @@ -168,15 +171,15 @@ func (o *TWAPExecutor) PlaceOrder(quantity fixedpoint.Value, side types.SideType if err != nil || len(createdOrders) == 0 { return nil, fmt.Errorf("failed to submit order: %+v, %v", order, err) } - o.orders[createdOrders[0].OrderID] = createdOrders[0].AsQuery() + o.syncState.Orders[createdOrders[0].OrderID] = createdOrders[0].AsQuery() return &createdOrders[0], nil } func (o *TWAPExecutor) buildSubmitOrder(quantity, price fixedpoint.Value, side types.SideType, deadlineExceeded bool) types.SubmitOrder { if deadlineExceeded { return types.SubmitOrder{ - Symbol: o.market.Symbol, - Market: o.market, + Symbol: o.syncState.Market.Symbol, + Market: o.syncState.Market, Side: side, Type: types.OrderTypeMarket, Quantity: quantity, @@ -185,14 +188,14 @@ func (o *TWAPExecutor) buildSubmitOrder(quantity, price fixedpoint.Value, side t orderType := types.OrderTypeLimitMaker timeInForce := types.TimeInForceGTC - if o.OrderType == TWAPOrderTypeTaker { + if o.syncState.Config.OrderType == TWAPOrderTypeTaker { orderType = types.OrderTypeLimit timeInForce = types.TimeInForceIOC } return types.SubmitOrder{ - Symbol: o.market.Symbol, - Market: o.market, + Symbol: o.syncState.Market.Symbol, + Market: o.syncState.Market, Side: side, Type: orderType, Quantity: quantity, @@ -204,11 +207,11 @@ func (o *TWAPExecutor) buildSubmitOrder(quantity, price fixedpoint.Value, side t func (o *TWAPExecutor) GetPrice(side types.SideType, orderBook types.OrderBook) (price fixedpoint.Value, err error) { defer func() { if err == nil { - price = o.market.TruncatePrice(price) + price = o.syncState.Market.TruncatePrice(price) } }() - switch o.OrderType { + switch o.syncState.Config.OrderType { case TWAPOrderTypeTaker: return o.getTakerPrice(side, orderBook) case TWAPOrderTypeMaker: @@ -223,11 +226,11 @@ func (o *TWAPExecutor) getTakerPrice(side types.SideType, orderBook types.OrderB case types.SideTypeBuy: ask, ok := orderBook.BestAsk() if !ok { - return fixedpoint.Zero, fmt.Errorf("no ask price available for %s", o.market.Symbol) + return fixedpoint.Zero, fmt.Errorf("no ask price available for %s", o.syncState.Market.Symbol) } price := ask.Price - if o.MaxSlippage.Sign() > 0 { - maxPrice := price.Mul(fixedpoint.One.Add(o.MaxSlippage)) + if o.syncState.Config.MaxSlippage.Sign() > 0 { + maxPrice := price.Mul(fixedpoint.One.Add(o.syncState.Config.MaxSlippage)) price = fixedpoint.Min(price, maxPrice) } return price, nil @@ -235,11 +238,11 @@ func (o *TWAPExecutor) getTakerPrice(side types.SideType, orderBook types.OrderB case types.SideTypeSell: bid, ok := orderBook.BestBid() if !ok { - return fixedpoint.Zero, fmt.Errorf("no bid price available for %s", o.market.Symbol) + return fixedpoint.Zero, fmt.Errorf("no bid price available for %s", o.syncState.Market.Symbol) } price := bid.Price - if o.MaxSlippage.Sign() > 0 { - minPrice := price.Mul(fixedpoint.One.Sub(o.MaxSlippage)) + if o.syncState.Config.MaxSlippage.Sign() > 0 { + minPrice := price.Mul(fixedpoint.One.Sub(o.syncState.Config.MaxSlippage)) price = fixedpoint.Max(price, minPrice) } return price, nil @@ -250,15 +253,15 @@ func (o *TWAPExecutor) getTakerPrice(side types.SideType, orderBook types.OrderB } func (o *TWAPExecutor) getMakerPrice(side types.SideType, orderBook types.OrderBook) (fixedpoint.Value, error) { - tickSize := o.market.TickSize - numOfTicks := fixedpoint.NewFromInt(int64(o.NumOfTicks)) + tickSize := o.syncState.Market.TickSize + numOfTicks := fixedpoint.NewFromInt(int64(o.syncState.Config.NumOfTicks)) tickImprovement := tickSize.Mul(numOfTicks) switch side { case types.SideTypeBuy: bid, ok := orderBook.BestBid() if !ok { - return fixedpoint.Zero, fmt.Errorf("no bid price available for %s", o.market.Symbol) + return fixedpoint.Zero, fmt.Errorf("no bid price available for %s", o.syncState.Market.Symbol) } // improve price by moving closer to spread price := bid.Price.Add(tickImprovement) @@ -272,7 +275,7 @@ func (o *TWAPExecutor) getMakerPrice(side types.SideType, orderBook types.OrderB case types.SideTypeSell: ask, ok := orderBook.BestAsk() if !ok { - return fixedpoint.Zero, fmt.Errorf("no ask price available for %s", o.market.Symbol) + return fixedpoint.Zero, fmt.Errorf("no ask price available for %s", o.syncState.Market.Symbol) } price := ask.Price.Sub(tickImprovement) bid, hasBid := orderBook.BestBid() diff --git a/pkg/strategy/xfundingv2/twap_order_executor_sync.go b/pkg/strategy/xfundingv2/twap_order_executor_sync.go new file mode 100644 index 0000000000..d611975767 --- /dev/null +++ b/pkg/strategy/xfundingv2/twap_order_executor_sync.go @@ -0,0 +1,78 @@ +package xfundingv2 + +import ( + "encoding/json" + "errors" + "fmt" + + "github.com/c9s/bbgo/pkg/bbgo" + "github.com/c9s/bbgo/pkg/types" +) + +type TWAPExecutorSyncState struct { + Config TWAPWorkerConfig `json:"config"` + Market types.Market `json:"market"` + IsFutures bool `json:"isFutures"` + Orders map[uint64]types.OrderQuery `json:"orders,omitempty"` + Trades map[uint64]types.Trade `json:"trades,omitempty"` +} + +func (o *TWAPExecutor) LoadStrategy(s *Strategy) error { + o.SetLogger(s.logger) + var session *bbgo.ExchangeSession + var executor *bbgo.GeneralOrderExecutor + if o.syncState.IsFutures { + executor = s.futuresGeneralOrderExecutors[o.syncState.Market.Symbol] + session = s.futuresSession + } else { + executor = s.spotGeneralOrderExecutors[o.syncState.Market.Symbol] + session = s.spotSession + } + if executor == nil { + return errors.New("[TWAPExecutor] futures general order executor not found for market: " + o.syncState.Market.Symbol) + } + o.executor = executor + + // sync market + if market, ok := session.Market(o.syncState.Market.Symbol); !ok { + return errors.New("[TWAPExecutor] market not found in session: " + o.syncState.Market.Symbol) + } else { + o.syncState.Market = market + } + // set order query service + if service, ok := session.Exchange.(types.ExchangeOrderQueryService); ok { + o.exchange = service + } else { + return errors.New("[TWAPExecutor] session exchange does not implement ExchangeOrderQueryService") + } + // sync orders + for _, query := range o.syncState.Orders { + order, err := o.exchange.QueryOrder(o.ctx, query) + if err != nil || order == nil { + return fmt.Errorf("[TWAPExecutor] failed to query order %v: %w", query, err) + } + o.executor.OrderStore().Add(*order) + } + return nil +} + +func (o *TWAPExecutor) MarshalJSON() ([]byte, error) { + return json.Marshal(o.syncState) +} + +func (o *TWAPExecutor) UnmarshalJSON(b []byte) error { + stateData := TWAPExecutorSyncState{} + if err := json.Unmarshal(b, &stateData); err != nil { + return err + } + + o.syncState = stateData + if o.syncState.Orders == nil { + o.syncState.Orders = make(map[uint64]types.OrderQuery) + } + if o.syncState.Trades == nil { + o.syncState.Trades = make(map[uint64]types.Trade) + } + + return nil +} diff --git a/pkg/strategy/xfundingv2/twap_order_executor_sync_test.go b/pkg/strategy/xfundingv2/twap_order_executor_sync_test.go new file mode 100644 index 0000000000..fd8c40c670 --- /dev/null +++ b/pkg/strategy/xfundingv2/twap_order_executor_sync_test.go @@ -0,0 +1,148 @@ +package xfundingv2 + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "github.com/c9s/bbgo/pkg/types" + + . "github.com/c9s/bbgo/pkg/testing/testhelper" +) + +func TestTWAPExecutor_MarshalJSON(t *testing.T) { + market := Market("BTCUSDT") + market.Exchange = types.ExchangeBinance + + config := TWAPWorkerConfig{ + Duration: 5 * time.Minute, + NumSlices: 10, + OrderType: TWAPOrderTypeMaker, + NumOfTicks: 3, + MaxSlippage: Number(0.001), + } + + executor := &TWAPExecutor{ + syncState: TWAPExecutorSyncState{ + Config: config, + Market: market, + IsFutures: true, + Orders: map[uint64]types.OrderQuery{ + 123: {Symbol: "BTCUSDT", OrderID: "123"}, + }, + Trades: map[uint64]types.Trade{}, + }, + } + + data, err := json.Marshal(executor) + assert.NoError(t, err) + + // Verify JSON structure has expected top-level keys + var raw map[string]json.RawMessage + err = json.Unmarshal(data, &raw) + assert.NoError(t, err) + + assert.Contains(t, raw, "config") + assert.Contains(t, raw, "isFutures") + assert.Contains(t, raw, "market") + + // Verify isFutures=true and orders at top level + var stateData TWAPExecutorSyncState + err = json.Unmarshal(data, &stateData) + assert.NoError(t, err) + assert.True(t, stateData.IsFutures) + assert.Len(t, stateData.Orders, 1) + + // Round-trip: unmarshal back and verify fields + var restored TWAPExecutor + err = json.Unmarshal(data, &restored) + assert.NoError(t, err) + + assert.Equal(t, config, restored.syncState.Config) + assert.Equal(t, market.Symbol, restored.syncState.Market.Symbol) + assert.Equal(t, types.ExchangeBinance, restored.syncState.Market.Exchange) + assert.True(t, restored.syncState.IsFutures) + assert.Len(t, restored.syncState.Orders, 1) + assert.Equal(t, "123", restored.syncState.Orders[123].OrderID) +} + +func TestTWAPExecutor_MarshalJSON_NotFutures(t *testing.T) { + market := Market("BTCUSDT") + market.Exchange = types.ExchangeBinance + + executor := &TWAPExecutor{ + syncState: TWAPExecutorSyncState{ + Config: TWAPWorkerConfig{ + OrderType: TWAPOrderTypeTaker, + }, + Market: market, + IsFutures: false, + Orders: map[uint64]types.OrderQuery{}, + Trades: map[uint64]types.Trade{}, + }, + } + + data, err := json.Marshal(executor) + assert.NoError(t, err) + + var restored TWAPExecutor + err = json.Unmarshal(data, &restored) + assert.NoError(t, err) + + assert.False(t, restored.syncState.IsFutures) + assert.Equal(t, TWAPOrderTypeTaker, restored.syncState.Config.OrderType) +} + +func TestTWAPExecutor_UnmarshalJSON(t *testing.T) { + t.Run("valid JSON", func(t *testing.T) { + // time.Duration marshals as nanoseconds (int64) + jsonData := `{ + "config": { + "duration": 300000000000, + "numSlices": 10, + "orderType": "maker", + "numOfTicks": 3 + }, + "market": {"symbol": "BTCUSDT", "exchange": "binance"}, + "isFutures": false, + "orders": { + "100": {"symbol": "BTCUSDT", "orderID": "100"} + }, + "trades": {} + }` + + var executor TWAPExecutor + err := json.Unmarshal([]byte(jsonData), &executor) + assert.NoError(t, err) + assert.Equal(t, 5*time.Minute, executor.syncState.Config.Duration) + assert.Equal(t, 10, executor.syncState.Config.NumSlices) + assert.Equal(t, TWAPOrderTypeMaker, executor.syncState.Config.OrderType) + assert.Equal(t, "BTCUSDT", executor.syncState.Market.Symbol) + assert.False(t, executor.syncState.IsFutures) + assert.Len(t, executor.syncState.Orders, 1) + }) + + t.Run("futures flag", func(t *testing.T) { + jsonData := `{ + "config": {}, + "market": {"symbol": "ETHUSDT", "exchange": "binance"}, + "isFutures": true, + "orders": {}, + "trades": {} + }` + + var executor TWAPExecutor + err := json.Unmarshal([]byte(jsonData), &executor) + assert.NoError(t, err) + assert.True(t, executor.syncState.IsFutures) + assert.Equal(t, "ETHUSDT", executor.syncState.Market.Symbol) + }) + + t.Run("invalid JSON returns error", func(t *testing.T) { + var executor TWAPExecutor + err := json.Unmarshal([]byte(`{invalid`), &executor) + assert.Error(t, err) + }) +} diff --git a/pkg/strategy/xfundingv2/twap_order_executor_test.go b/pkg/strategy/xfundingv2/twap_order_executor_test.go index ce2a3565bd..3ecc3a20e2 100644 --- a/pkg/strategy/xfundingv2/twap_order_executor_test.go +++ b/pkg/strategy/xfundingv2/twap_order_executor_test.go @@ -40,9 +40,10 @@ func testExecutorSetup( generalExecutor := bbgo.NewGeneralOrderExecutor(session, "BTCUSDT", "test", "test-instance", position) ctx := context.Background() - executor := NewTWAPOrderExecutor( + executor := NewTWAPExecutor( ctx, mockOrderQuery, + false, market, generalExecutor, config, @@ -63,8 +64,8 @@ func TestNewTWAPOrderExecutor(t *testing.T) { executor, _, _ := testExecutorSetup(t, ctrl, config) assert.NotNil(t, executor) - assert.Equal(t, TWAPOrderTypeMaker, executor.OrderType) - assert.Equal(t, 2, executor.NumOfTicks) + assert.Equal(t, TWAPOrderTypeMaker, executor.syncState.Config.OrderType) + assert.Equal(t, 2, executor.syncState.Config.NumOfTicks) } func TestTWAPOrderExecutor_Start(t *testing.T) { @@ -683,7 +684,7 @@ func TestTWAPOrderExecutor_GetOrder(t *testing.T) { Symbol: "BTCUSDT", }, } - executor.orders[order.OrderID] = order.AsQuery() // Add to ordersMap to simulate tracking + executor.syncState.Orders[order.OrderID] = order.AsQuery() // Add to ordersMap to simulate tracking executor.executor.OrderStore().Add(order) // Test GetOrder From 3b8531f68f84098f4a77a0508302c07f05e1415b Mon Sep 17 00:00:00 2001 From: dboyliao <6830390+dboyliao@users.noreply.github.com> Date: Fri, 8 May 2026 23:29:33 +0800 Subject: [PATCH 2/7] =?UTF-8?q?=F0=9F=94=A7=20fix(types):=20add=20JSON=20t?= =?UTF-8?q?ags=20to=20OrderQuery=20for=20proper=20marshaling?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/types/order.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/types/order.go b/pkg/types/order.go index 5bfcb9347a..7387ffb352 100644 --- a/pkg/types/order.go +++ b/pkg/types/order.go @@ -315,10 +315,10 @@ func (o *SubmitOrder) amountField() *slack.AttachmentField { } type OrderQuery struct { - Symbol string - OrderID string - ClientOrderID string - OrderUUID string + Symbol string `json:"symbol"` + OrderID string `json:"orderID"` + ClientOrderID string `json:"clientOrderID,omitempty"` + OrderUUID string `json:"orderUUID,omitempty"` } type Order struct { From 6309bd7fee6a9eb620a184470ded90f19de642e2 Mon Sep 17 00:00:00 2001 From: dboyliao <6830390+dboyliao@users.noreply.github.com> Date: Fri, 8 May 2026 23:29:56 +0800 Subject: [PATCH 3/7] =?UTF-8?q?=E2=9C=A8=20feat(xfundingv2):=20implement?= =?UTF-8?q?=20LoadStrategy=20and=20JSON=20marshaling/unmarshaling=20for=20?= =?UTF-8?q?TWAPWorker?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/strategy/xfundingv2/twap.go | 219 ++++++++++------------ pkg/strategy/xfundingv2/twap_sync.go | 61 ++++++ pkg/strategy/xfundingv2/twap_sync_test.go | 189 +++++++++++++++++++ pkg/strategy/xfundingv2/twap_test.go | 8 +- 4 files changed, 356 insertions(+), 121 deletions(-) create mode 100644 pkg/strategy/xfundingv2/twap_sync.go create mode 100644 pkg/strategy/xfundingv2/twap_sync_test.go diff --git a/pkg/strategy/xfundingv2/twap.go b/pkg/strategy/xfundingv2/twap.go index cee4b32051..7fab8f8a50 100644 --- a/pkg/strategy/xfundingv2/twap.go +++ b/pkg/strategy/xfundingv2/twap.go @@ -59,27 +59,9 @@ type TWAPWorker struct { // filledQuantity, activeOrder, trades, state mu sync.Mutex - config TWAPWorkerConfig - - targetPosition fixedpoint.Value // positive = buy/long, negative = sell/short - - // state - state TWAPWorkerState - startTime time.Time - endTime time.Time - - placeOrderInterval time.Duration - currentIntervalStart time.Time - currentIntervalEnd time.Time - lastCheckTime time.Time - - symbol string - - ctx context.Context - - activeOrder *types.Order - twapExecutor *TWAPExecutor + syncState TWAPWorkerSyncState + ctx context.Context logger logrus.FieldLogger } @@ -99,15 +81,18 @@ func NewTWAPWorker( return nil, fmt.Errorf("exchange does not support OrderQueryService: %s", session.Exchange.Name()) } w := &TWAPWorker{ - config: config, - symbol: symbol, - state: TWAPWorkerStatePending, - targetPosition: fixedpoint.Zero, + syncState: TWAPWorkerSyncState{ + Symbol: symbol, + Config: config, + State: TWAPWorkerStatePending, + TargetPosition: fixedpoint.Zero, + }, } w.ctx = ctx - w.twapExecutor = NewTWAPOrderExecutor( + w.syncState.TWAPExecutor = NewTWAPExecutor( w.ctx, service, + session.Futures, market, generalExecutor, config, @@ -117,7 +102,7 @@ func NewTWAPWorker( // SetTargetPosition sets the target position for the TWAP worker. func (w *TWAPWorker) SetTargetPosition(targetPosition fixedpoint.Value) { - w.targetPosition = targetPosition + w.syncState.TargetPosition = targetPosition } func (w *TWAPWorker) SetLogger(logger logrus.FieldLogger) { @@ -125,36 +110,36 @@ func (w *TWAPWorker) SetLogger(logger logrus.FieldLogger) { } func (w *TWAPWorker) Symbol() string { - return w.symbol + return w.syncState.Symbol } func (w *TWAPWorker) Market() types.Market { - return w.twapExecutor.market + return w.syncState.TWAPExecutor.Market() } func (w *TWAPWorker) Executor() *TWAPExecutor { - return w.twapExecutor + return w.syncState.TWAPExecutor } func (w *TWAPWorker) State() TWAPWorkerState { w.mu.Lock() defer w.mu.Unlock() - return w.state + return w.syncState.State } func (w *TWAPWorker) IsDone() bool { w.mu.Lock() defer w.mu.Unlock() - return w.state == TWAPWorkerStateDone + return w.syncState.State == TWAPWorkerStateDone } func (w *TWAPWorker) AveragePrice() fixedpoint.Value { w.mu.Lock() defer w.mu.Unlock() - trades := w.twapExecutor.AllTrades() + trades := w.syncState.TWAPExecutor.AllTrades() return tradingutil.AveragePriceFromTrades(trades) } @@ -170,11 +155,11 @@ func (w *TWAPWorker) AddTrade(trade types.Trade) { w.mu.Lock() defer w.mu.Unlock() - w.twapExecutor.AddTrade(trade) + w.syncState.TWAPExecutor.AddTrade(trade) } func (w *TWAPWorker) filledPosition() fixedpoint.Value { - trades := w.twapExecutor.AllTrades() + trades := w.syncState.TWAPExecutor.AllTrades() position := fixedpoint.Zero for _, t := range trades { if t.Side == types.SideTypeBuy { @@ -190,7 +175,7 @@ func (w *TWAPWorker) TotalFee() map[string]fixedpoint.Value { w.mu.Lock() defer w.mu.Unlock() - trades := w.twapExecutor.AllTrades() + trades := w.syncState.TWAPExecutor.AllTrades() feeMap := make(map[string]fixedpoint.Value) for _, t := range trades { if t.FeeCurrency == "" || t.Fee.IsZero() { @@ -205,7 +190,7 @@ func (w *TWAPWorker) ActiveOrder() *types.Order { w.mu.Lock() defer w.mu.Unlock() - return w.activeOrder + return w.syncState.ActiveOrder } func (w *TWAPWorker) RemainingQuantity() fixedpoint.Value { @@ -218,14 +203,14 @@ func (w *TWAPWorker) RemainingQuantity() fixedpoint.Value { func (w *TWAPWorker) remainingQuantity() fixedpoint.Value { // remaining = target - filled // NOTE: the remaining quantity can be positive or negative. - return w.targetPosition.Sub(w.filledPosition()) + return w.syncState.TargetPosition.Sub(w.filledPosition()) } func (w *TWAPWorker) TargetPosition() fixedpoint.Value { w.mu.Lock() defer w.mu.Unlock() - return w.targetPosition + return w.syncState.TargetPosition } func (w *TWAPWorker) Start(ctx context.Context, currentTime time.Time) error { @@ -233,28 +218,28 @@ func (w *TWAPWorker) Start(ctx context.Context, currentTime time.Time) error { w.mu.Lock() defer w.mu.Unlock() - if w.state != TWAPWorkerStatePending { - return fmt.Errorf("cannot start TWAPWorker: expected state Pending, got %s", w.state) + if w.syncState.State != TWAPWorkerStatePending { + return fmt.Errorf("cannot start TWAPWorker: expected state Pending, got %s", w.syncState.State) } if w.logger == nil { w.logger = logrus.WithFields(logrus.Fields{ "component": "twap", - "symbol": w.symbol, + "symbol": w.syncState.Symbol, }) } // start the executor - w.twapExecutor.SetLogger(w.logger) - w.twapExecutor.Start() + w.syncState.TWAPExecutor.SetLogger(w.logger) + w.syncState.TWAPExecutor.Start() - w.resetTime(currentTime, w.config.Duration) + w.resetTime(currentTime, w.syncState.Config.Duration) w.logger.Infof( "[TWAP Start] started: targetPosition=%s, duration=%s, interval=%s", - w.targetPosition, - w.config.Duration, - w.placeOrderInterval, + w.syncState.TargetPosition, + w.syncState.Config.Duration, + w.syncState.PlaceOrderInterval, ) return nil } @@ -263,10 +248,10 @@ func (w *TWAPWorker) RemainingDuration(currentTime time.Time) time.Duration { w.mu.Lock() defer w.mu.Unlock() - if currentTime.After(w.endTime) { + if currentTime.After(w.syncState.EndTime) { return 0 } - return w.endTime.Sub(currentTime) + return w.syncState.EndTime.Sub(currentTime) } // ResetTime resets the start and end time of the TWAP execution. @@ -280,21 +265,21 @@ func (w *TWAPWorker) ResetTime(currentTime time.Time, duration time.Duration) { } func (w *TWAPWorker) resetTime(currentTime time.Time, duration time.Duration) { - w.state = TWAPWorkerStateRunning - w.startTime = currentTime - w.config.Duration = duration - w.endTime = currentTime.Add(w.config.Duration) + w.syncState.State = TWAPWorkerStateRunning + w.syncState.StartTime = currentTime + w.syncState.Config.Duration = duration + w.syncState.EndTime = currentTime.Add(w.syncState.Config.Duration) - numSlices := w.config.NumSlices + numSlices := w.syncState.Config.NumSlices if numSlices <= 0 { numSlices = 1 } - w.placeOrderInterval = w.config.Duration / time.Duration(numSlices) - w.currentIntervalStart = currentTime - w.currentIntervalEnd = w.currentIntervalStart.Add(w.placeOrderInterval) - if w.currentIntervalEnd.After(w.endTime) { - w.currentIntervalEnd = w.endTime + w.syncState.PlaceOrderInterval = w.syncState.Config.Duration / time.Duration(numSlices) + w.syncState.CurrentIntervalStart = currentTime + w.syncState.CurrentIntervalEnd = w.syncState.CurrentIntervalStart.Add(w.syncState.PlaceOrderInterval) + if w.syncState.CurrentIntervalEnd.After(w.syncState.EndTime) { + w.syncState.CurrentIntervalEnd = w.syncState.EndTime } w.syncAndResetActiveOrder() @@ -305,24 +290,24 @@ func (w *TWAPWorker) Stop() { w.mu.Lock() defer w.mu.Unlock() - if w.state == TWAPWorkerStateRunning || w.state == TWAPWorkerStatePending { - if w.activeOrder != nil { - err := w.twapExecutor.CancelOrder(w.ctx, *w.activeOrder) + if w.syncState.State == TWAPWorkerStateRunning || w.syncState.State == TWAPWorkerStatePending { + if w.syncState.ActiveOrder != nil { + err := w.syncState.TWAPExecutor.CancelOrder(w.ctx, *w.syncState.ActiveOrder) if err != nil { - w.logger.WithError(err).Warnf("[TWAP Stop] failed to cancel active order: %s", w.activeOrder) + w.logger.WithError(err).Warnf("[TWAP Stop] failed to cancel active order: %s", w.syncState.ActiveOrder) } } // stop executor - if err := w.twapExecutor.Stop(); err != nil { + if err := w.syncState.TWAPExecutor.Stop(); err != nil { w.logger.WithError(err).Warn("[TWAP Stop] failed to stop TWAP executor") } - w.state = TWAPWorkerStateDone - w.activeOrder = nil + w.syncState.State = TWAPWorkerStateDone + w.syncState.ActiveOrder = nil w.logger.Infof( "[TWAP Stop] stopped: filled=%s / target=%s", - w.filledPosition(), w.targetPosition, + w.filledPosition(), w.syncState.TargetPosition, ) } } @@ -331,18 +316,18 @@ func (w *TWAPWorker) Stop() { // its trades via REST API, updates ordersMap and tradesMap accordingly, then // resets activeOrder to nil. Must be called under lock. func (w *TWAPWorker) syncAndResetActiveOrder() *types.Order { - if w.activeOrder == nil { + if w.syncState.ActiveOrder == nil { return nil } - if err := w.twapExecutor.SyncOrder(*w.activeOrder); err != nil { - w.logger.WithError(err).Warnf("[TWAP syncAndResetActiveOrder] fail to sync active order, resetting: %s", w.activeOrder.String()) - w.activeOrder = nil + if err := w.syncState.TWAPExecutor.SyncOrder(*w.syncState.ActiveOrder); err != nil { + w.logger.WithError(err).Warnf("[TWAP syncAndResetActiveOrder] fail to sync active order, resetting: %s", w.syncState.ActiveOrder.String()) + w.syncState.ActiveOrder = nil return nil } - oriActiveOrder, _ := w.twapExecutor.GetOrder(w.activeOrder.OrderID) - w.activeOrder = nil + oriActiveOrder, _ := w.syncState.TWAPExecutor.GetOrder(w.syncState.ActiveOrder.OrderID) + w.syncState.ActiveOrder = nil return &oriActiveOrder } @@ -358,27 +343,27 @@ func (w *TWAPWorker) Tick(currentTime time.Time, orderBook types.OrderBook) erro func (w *TWAPWorker) tick(currentTime time.Time, orderBook types.OrderBook) error { defer func() { - if currentTime.After(w.currentIntervalEnd) { - w.currentIntervalStart = w.currentIntervalEnd - w.currentIntervalEnd = w.currentIntervalStart.Add(w.placeOrderInterval) - if w.currentIntervalStart.After(w.endTime) { - w.currentIntervalStart = w.endTime + if currentTime.After(w.syncState.CurrentIntervalEnd) { + w.syncState.CurrentIntervalStart = w.syncState.CurrentIntervalEnd + w.syncState.CurrentIntervalEnd = w.syncState.CurrentIntervalStart.Add(w.syncState.PlaceOrderInterval) + if w.syncState.CurrentIntervalStart.After(w.syncState.EndTime) { + w.syncState.CurrentIntervalStart = w.syncState.EndTime } - if w.currentIntervalEnd.After(w.endTime) { - w.currentIntervalEnd = w.endTime + if w.syncState.CurrentIntervalEnd.After(w.syncState.EndTime) { + w.syncState.CurrentIntervalEnd = w.syncState.EndTime } } - if currentTime.After(w.endTime) { - w.state = TWAPWorkerStateDone + if currentTime.After(w.syncState.EndTime) { + w.syncState.State = TWAPWorkerStateDone } }() - if w.state != TWAPWorkerStateRunning { + if w.syncState.State != TWAPWorkerStateRunning { // the worker is not running return nil } - if currentTime.Before(w.currentIntervalStart) { + if currentTime.Before(w.syncState.CurrentIntervalStart) { // not time for the next order yet return nil } @@ -394,17 +379,17 @@ func (w *TWAPWorker) tick(currentTime time.Time, orderBook types.OrderBook) erro } // check if deadline exceeded - deadlineExceeded := !currentTime.Before(w.endTime) + deadlineExceeded := !currentTime.Before(w.syncState.EndTime) // if deadline exceeded, we want to place a final order for the remaining quantity if deadlineExceeded { - if w.activeOrder != nil { - if err := w.twapExecutor.CancelOrder(w.ctx, *w.activeOrder); err != nil { + if w.syncState.ActiveOrder != nil { + if err := w.syncState.TWAPExecutor.CancelOrder(w.ctx, *w.syncState.ActiveOrder); err != nil { w.logger.WithError(err).Warn("[TWAP tick] failed to cancel active order when deadline exceeded") return nil } w.syncAndResetActiveOrder() } - createdOrder, err := w.twapExecutor.PlaceOrder( + createdOrder, err := w.syncState.TWAPExecutor.PlaceOrder( remaining.Abs(), orderSide(remaining), orderBook, @@ -413,39 +398,39 @@ func (w *TWAPWorker) tick(currentTime time.Time, orderBook types.OrderBook) erro if err != nil || createdOrder == nil { return fmt.Errorf("failed to place final order when deadline exceeded: %w", err) } - w.activeOrder = createdOrder + w.syncState.ActiveOrder = createdOrder return nil } // from here, deadline not exceeded // we don't have an active order, place a new one - if w.activeOrder == nil { + if w.syncState.ActiveOrder == nil { sliceQty := w.calculateSliceQuantity(currentTime, remaining, false) - createdOrder, err := w.twapExecutor.PlaceOrder(sliceQty, orderSide(remaining), orderBook, false) + createdOrder, err := w.syncState.TWAPExecutor.PlaceOrder(sliceQty, orderSide(remaining), orderBook, false) if err != nil || createdOrder == nil { return fmt.Errorf("failed to place order: %w", err) } - w.activeOrder = createdOrder + w.syncState.ActiveOrder = createdOrder return nil } // from here, active order is not nil // we are within current interval and we have a better price - if w.shouldUpdateActiveOrder(orderBook) && currentTime.Before(w.currentIntervalEnd) { + if w.shouldUpdateActiveOrder(orderBook) && currentTime.Before(w.syncState.CurrentIntervalEnd) { // throttle order updates to avoid excessive cancel-and-replace - if !w.lastCheckTime.IsZero() && currentTime.Sub(w.lastCheckTime) < w.config.CheckInterval { + if !w.syncState.LastCheckTime.IsZero() && currentTime.Sub(w.syncState.LastCheckTime) < w.syncState.Config.CheckInterval { return nil } - w.lastCheckTime = currentTime + w.syncState.LastCheckTime = currentTime - if err := w.twapExecutor.CancelOrder(w.ctx, *w.activeOrder); err != nil { + if err := w.syncState.TWAPExecutor.CancelOrder(w.ctx, *w.syncState.ActiveOrder); err != nil { w.logger.WithError(err).Warn("[TWAP tick] failed to cancel active order") return nil } // find the better price and submit new order - createdOrder, err := w.twapExecutor.PlaceOrder( - w.activeOrder.GetRemainingQuantity(), - w.activeOrder.Side, + createdOrder, err := w.syncState.TWAPExecutor.PlaceOrder( + w.syncState.ActiveOrder.GetRemainingQuantity(), + w.syncState.ActiveOrder.Side, orderBook, deadlineExceeded, ) @@ -453,7 +438,7 @@ func (w *TWAPWorker) tick(currentTime time.Time, orderBook types.OrderBook) erro return fmt.Errorf("failed to place replacement order: %w", err) } oriActiveOrder := w.syncAndResetActiveOrder() - w.activeOrder = createdOrder + w.syncState.ActiveOrder = createdOrder w.logger.Infof("[TWAP tick] active order updated: %s %s qty=%s(executed: %s)->%s price=%s->%s", createdOrder.Side, createdOrder.Type, @@ -467,18 +452,18 @@ func (w *TWAPWorker) tick(currentTime time.Time, orderBook types.OrderBook) erro } // we are within the current interval, just wait for the next tick - if currentTime.Before(w.currentIntervalEnd) { + if currentTime.Before(w.syncState.CurrentIntervalEnd) { return nil } // currentTime is after current interval end, time to place the next slice order // calculate slice quantity sliceQty := w.calculateSliceQuantity(currentTime, remaining, deadlineExceeded) - createdOrder, err := w.twapExecutor.PlaceOrder(sliceQty, orderSide(remaining), orderBook, deadlineExceeded) + createdOrder, err := w.syncState.TWAPExecutor.PlaceOrder(sliceQty, orderSide(remaining), orderBook, deadlineExceeded) if err != nil || createdOrder == nil { return fmt.Errorf("failed to place order for next slice: %w", err) } - w.activeOrder = createdOrder + w.syncState.ActiveOrder = createdOrder return nil } @@ -491,12 +476,12 @@ func (w *TWAPWorker) calculateSliceQuantity(currentTime time.Time, remaining fix } // dynamic slice: remaining / remaining_slices - timeLeft := w.endTime.Sub(currentTime) + timeLeft := w.syncState.EndTime.Sub(currentTime) if timeLeft <= 0 { return remaining } - remainingSlices := int(timeLeft / w.placeOrderInterval) + remainingSlices := int(timeLeft / w.syncState.PlaceOrderInterval) if remainingSlices <= 0 { remainingSlices = 1 } @@ -504,15 +489,15 @@ func (w *TWAPWorker) calculateSliceQuantity(currentTime time.Time, remaining fix sliceQty := remaining.Div(fixedpoint.NewFromInt(int64(remainingSlices))) // apply min/max slice size constraints - if w.config.MaxSliceSize.Sign() > 0 && sliceQty.Compare(w.config.MaxSliceSize) > 0 { - sliceQty = w.config.MaxSliceSize + if w.syncState.Config.MaxSliceSize.Sign() > 0 && sliceQty.Compare(w.syncState.Config.MaxSliceSize) > 0 { + sliceQty = w.syncState.Config.MaxSliceSize } - if w.config.MinSliceSize.Sign() > 0 && sliceQty.Compare(w.config.MinSliceSize) < 0 { + if w.syncState.Config.MinSliceSize.Sign() > 0 && sliceQty.Compare(w.syncState.Config.MinSliceSize) < 0 { // if remaining is less than min, just use remaining - if remaining.Compare(w.config.MinSliceSize) <= 0 { + if remaining.Compare(w.syncState.Config.MinSliceSize) <= 0 { sliceQty = remaining } else { - sliceQty = w.config.MinSliceSize + sliceQty = w.syncState.Config.MinSliceSize } } @@ -528,30 +513,30 @@ func (w *TWAPWorker) calculateSliceQuantity(currentTime time.Time, remaining fix // with a better price. For taker orders (IOC), always update. For maker orders, // compare the current order price against the best computed maker price. func (w *TWAPWorker) shouldUpdateActiveOrder(orderBook types.OrderBook) bool { - if w.activeOrder == nil { + if w.syncState.ActiveOrder == nil { return false } // taker orders are IOC — always refresh - if w.config.OrderType == TWAPOrderTypeTaker { + if w.syncState.Config.OrderType == TWAPOrderTypeTaker { return true } - newPrice, err := w.twapExecutor.GetPrice(w.activeOrder.Side, orderBook) + newPrice, err := w.syncState.TWAPExecutor.GetPrice(w.syncState.ActiveOrder.Side, orderBook) if err != nil { w.logger.WithError(err).Warn("[TWAP shouldUpdateOrder] failed to get price for order update check") return false } newPriceBtter := false - switch w.activeOrder.Side { + switch w.syncState.ActiveOrder.Side { case types.SideTypeBuy: - newPriceBtter = newPrice.Compare(w.activeOrder.Price) > 0 + newPriceBtter = newPrice.Compare(w.syncState.ActiveOrder.Price) > 0 case types.SideTypeSell: - newPriceBtter = newPrice.Compare(w.activeOrder.Price) < 0 + newPriceBtter = newPrice.Compare(w.syncState.ActiveOrder.Price) < 0 } w.logger.Infof("[TWAP shouldUpdateOrder] order update check: current price=%s, new price=%s, better=%t", - w.activeOrder.Price.String(), newPrice.String(), newPriceBtter) + w.syncState.ActiveOrder.Price.String(), newPrice.String(), newPriceBtter) return newPriceBtter } diff --git a/pkg/strategy/xfundingv2/twap_sync.go b/pkg/strategy/xfundingv2/twap_sync.go new file mode 100644 index 0000000000..c6cb6b8365 --- /dev/null +++ b/pkg/strategy/xfundingv2/twap_sync.go @@ -0,0 +1,61 @@ +package xfundingv2 + +import ( + "context" + "encoding/json" + "fmt" + "time" + + "github.com/c9s/bbgo/pkg/fixedpoint" + "github.com/c9s/bbgo/pkg/types" +) + +func (w *TWAPWorker) LoadStrategy(ctx context.Context, s *Strategy) error { + if w.syncState.TWAPExecutor == nil { + // should not happen + return fmt.Errorf("[TWAPWorker] TWAPExecutor is nil") + } + + w.ctx = ctx + w.SetLogger(s.logger) + if err := w.syncState.TWAPExecutor.LoadStrategy(s); err != nil { + return fmt.Errorf("[TWAPWorker] failed to load TWAPExecutor: %w", err) + } + if w.syncState.ActiveOrder != nil { + return w.syncState.TWAPExecutor.SyncOrder(*w.syncState.ActiveOrder) + } + return nil +} + +type TWAPWorkerSyncState struct { + Config TWAPWorkerConfig `json:"config"` + + // TargetPosition: positive = buy/long, negative = sell/short + TargetPosition fixedpoint.Value `json:"targetPosition"` + State TWAPWorkerState `json:"state"` + StartTime time.Time `json:"startTime"` + EndTime time.Time `json:"endTime"` + CurrentIntervalStart time.Time `json:"currentIntervalStart"` + CurrentIntervalEnd time.Time `json:"currentIntervalEnd"` + LastCheckTime time.Time `json:"lastCheckTime"` + PlaceOrderInterval time.Duration `json:"placeOrderInterval"` + + Symbol string `json:"symbol"` + ActiveOrder *types.Order `json:"activeOrder,omitempty"` + TWAPExecutor *TWAPExecutor `json:"executor,omitempty"` +} + +func (w *TWAPWorker) MarshalJSON() ([]byte, error) { + return json.Marshal(w.syncState) +} + +func (w *TWAPWorker) UnmarshalJSON(b []byte) error { + stateData := TWAPWorkerSyncState{} + if err := json.Unmarshal(b, &stateData); err != nil { + return fmt.Errorf("failed to unmarshal TWAPWorker: %w", err) + } + + w.syncState = stateData + + return nil +} diff --git a/pkg/strategy/xfundingv2/twap_sync_test.go b/pkg/strategy/xfundingv2/twap_sync_test.go new file mode 100644 index 0000000000..181189062d --- /dev/null +++ b/pkg/strategy/xfundingv2/twap_sync_test.go @@ -0,0 +1,189 @@ +package xfundingv2 + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/c9s/bbgo/pkg/fixedpoint" + . "github.com/c9s/bbgo/pkg/testing/testhelper" + "github.com/c9s/bbgo/pkg/types" +) + +func TestTWAPWorker_MarshalJSON(t *testing.T) { + t.Run("all fields present", func(t *testing.T) { + now := time.Now().Truncate(time.Second) + order := &types.Order{ + SubmitOrder: types.SubmitOrder{ + Symbol: "BTCUSDT", + Side: types.SideTypeBuy, + Quantity: fixedpoint.NewFromFloat(0.1), + }, + OrderID: 12345, + } + + w := &TWAPWorker{ + syncState: TWAPWorkerSyncState{ + Config: TWAPWorkerConfig{ + Duration: 10 * time.Minute, + NumSlices: 5, + OrderType: TWAPOrderTypeTaker, + MaxSlippage: fixedpoint.NewFromFloat(0.001), + }, + TargetPosition: fixedpoint.NewFromFloat(1.5), + State: TWAPWorkerStateRunning, + StartTime: now, + EndTime: now.Add(10 * time.Minute), + PlaceOrderInterval: 2 * time.Minute, + CurrentIntervalStart: now, + CurrentIntervalEnd: now.Add(2 * time.Minute), + LastCheckTime: now.Add(1 * time.Minute), + Symbol: "BTCUSDT", + ActiveOrder: order, + }, + } + + data, err := json.Marshal(w) + require.NoError(t, err) + + var raw map[string]json.RawMessage + err = json.Unmarshal(data, &raw) + require.NoError(t, err) + + assert.Contains(t, raw, "config") + assert.Contains(t, raw, "symbol") + assert.Contains(t, raw, "activeOrder") + }) + + t.Run("nil active order omitted", func(t *testing.T) { + w := &TWAPWorker{ + syncState: TWAPWorkerSyncState{ + Config: TWAPWorkerConfig{ + Duration: 5 * time.Minute, + NumSlices: 3, + }, + TargetPosition: fixedpoint.NewFromFloat(2.0), + State: TWAPWorkerStatePending, + Symbol: "ETHUSDT", + }, + } + + data, err := json.Marshal(w) + require.NoError(t, err) + + var raw map[string]json.RawMessage + err = json.Unmarshal(data, &raw) + require.NoError(t, err) + + assert.NotContains(t, raw, "activeOrder") + assert.Contains(t, raw, "symbol") + }) +} + +func TestTWAPWorker_UnmarshalJSON(t *testing.T) { + t.Run("round trip", func(t *testing.T) { + now := time.Now().Truncate(time.Second) + order := &types.Order{ + SubmitOrder: types.SubmitOrder{ + Symbol: "BTCUSDT", + Side: types.SideTypeBuy, + Quantity: fixedpoint.NewFromFloat(0.1), + }, + OrderID: 12345, + Exchange: types.ExchangeBinance, + } + + market := Market("BTCUSDT") + market.Exchange = types.ExchangeBinance + + executor := &TWAPExecutor{ + syncState: TWAPExecutorSyncState{ + Config: TWAPWorkerConfig{ + Duration: 10 * time.Minute, + NumSlices: 5, + OrderType: TWAPOrderTypeTaker, + }, + Market: market, + IsFutures: true, + Orders: map[uint64]types.OrderQuery{}, + Trades: map[uint64]types.Trade{}, + }, + } + + original := &TWAPWorker{ + syncState: TWAPWorkerSyncState{ + Config: TWAPWorkerConfig{ + Duration: 10 * time.Minute, + NumSlices: 5, + OrderType: TWAPOrderTypeTaker, + MaxSlippage: fixedpoint.NewFromFloat(0.001), + CheckInterval: 30 * time.Second, + }, + TargetPosition: fixedpoint.NewFromFloat(1.5), + State: TWAPWorkerStateRunning, + StartTime: now, + EndTime: now.Add(10 * time.Minute), + PlaceOrderInterval: 2 * time.Minute, + CurrentIntervalStart: now, + CurrentIntervalEnd: now.Add(2 * time.Minute), + LastCheckTime: now.Add(1 * time.Minute), + Symbol: "BTCUSDT", + ActiveOrder: order, + TWAPExecutor: executor, + }, + } + + data, err := json.Marshal(original) + require.NoError(t, err) + + restored := &TWAPWorker{} + err = json.Unmarshal(data, restored) + require.NoError(t, err) + + assert.Equal(t, original.syncState.Config, restored.syncState.Config) + assert.Equal(t, original.syncState.TargetPosition, restored.syncState.TargetPosition) + assert.Equal(t, original.syncState.State, restored.syncState.State) + assert.Equal(t, original.syncState.Symbol, restored.syncState.Symbol) + require.NotNil(t, restored.syncState.ActiveOrder) + assert.Equal(t, original.syncState.ActiveOrder.OrderID, restored.syncState.ActiveOrder.OrderID) + require.NotNil(t, restored.syncState.TWAPExecutor) + assert.True(t, restored.syncState.TWAPExecutor.syncState.IsFutures) + }) + + t.Run("nil active order", func(t *testing.T) { + original := &TWAPWorker{ + syncState: TWAPWorkerSyncState{ + Config: TWAPWorkerConfig{ + Duration: 5 * time.Minute, + NumSlices: 3, + }, + TargetPosition: fixedpoint.NewFromFloat(2.0), + State: TWAPWorkerStateDone, + Symbol: "ETHUSDT", + }, + } + + data, err := json.Marshal(original) + require.NoError(t, err) + + restored := &TWAPWorker{} + err = json.Unmarshal(data, restored) + require.NoError(t, err) + + assert.Equal(t, original.syncState.Config, restored.syncState.Config) + assert.Equal(t, original.syncState.TargetPosition, restored.syncState.TargetPosition) + assert.Equal(t, original.syncState.State, restored.syncState.State) + assert.Equal(t, original.syncState.Symbol, restored.syncState.Symbol) + assert.Nil(t, restored.syncState.ActiveOrder) + }) + + t.Run("invalid JSON", func(t *testing.T) { + w := &TWAPWorker{} + err := w.UnmarshalJSON([]byte(`{invalid json`)) + require.Error(t, err) + assert.Contains(t, err.Error(), "failed to unmarshal TWAPWorker") + }) +} diff --git a/pkg/strategy/xfundingv2/twap_test.go b/pkg/strategy/xfundingv2/twap_test.go index ba3e1b17ce..f6a72a45c2 100644 --- a/pkg/strategy/xfundingv2/twap_test.go +++ b/pkg/strategy/xfundingv2/twap_test.go @@ -814,7 +814,7 @@ func TestTWAPWorker_Misc(t *testing.T) { }) t.Run("better buy price triggers update", func(t *testing.T) { - worker.activeOrder = &types.Order{ + worker.syncState.ActiveOrder = &types.Order{ OrderID: 1, SubmitOrder: types.SubmitOrder{ Side: types.SideTypeBuy, @@ -828,7 +828,7 @@ func TestTWAPWorker_Misc(t *testing.T) { }) t.Run("worse buy price does not trigger update", func(t *testing.T) { - worker.activeOrder = &types.Order{ + worker.syncState.ActiveOrder = &types.Order{ OrderID: 1, SubmitOrder: types.SubmitOrder{ Side: types.SideTypeBuy, @@ -843,7 +843,7 @@ func TestTWAPWorker_Misc(t *testing.T) { t.Run("better sell price triggers update", func(t *testing.T) { worker.SetTargetPosition(Number(-1.0)) - worker.activeOrder = &types.Order{ + worker.syncState.ActiveOrder = &types.Order{ OrderID: 1, SubmitOrder: types.SubmitOrder{ Side: types.SideTypeSell, @@ -878,7 +878,7 @@ func TestTWAPWorker_Misc(t *testing.T) { err := worker.Start(ctx, startTime) assert.NoError(t, err) - worker.activeOrder = &types.Order{ + worker.syncState.ActiveOrder = &types.Order{ OrderID: 1, SubmitOrder: types.SubmitOrder{ Side: types.SideTypeBuy, From a577ca2335a721ce5dce810b511cbb3fe37cc6b1 Mon Sep 17 00:00:00 2001 From: dboyliao <6830390+dboyliao@users.noreply.github.com> Date: Fri, 8 May 2026 23:30:20 +0800 Subject: [PATCH 4/7] =?UTF-8?q?=E2=9C=A8=20feat(xfundingv2):=20add=20LoadS?= =?UTF-8?q?trategy=20and=20JSON=20marshaling/unmarshaling=20for=20Arbitrag?= =?UTF-8?q?eRound=20and=20PendingRound?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/strategy/xfundingv2/arb_round.go | 286 +++++++++--------- pkg/strategy/xfundingv2/arb_round_fee.go | 10 +- pkg/strategy/xfundingv2/arb_round_pnl.go | 14 +- pkg/strategy/xfundingv2/arb_round_pnl_test.go | 16 +- pkg/strategy/xfundingv2/arb_round_sync.go | 111 +++++++ .../xfundingv2/arb_round_sync_test.go | 103 +++++++ pkg/strategy/xfundingv2/arb_round_test.go | 2 +- 7 files changed, 373 insertions(+), 169 deletions(-) create mode 100644 pkg/strategy/xfundingv2/arb_round_sync.go create mode 100644 pkg/strategy/xfundingv2/arb_round_sync_test.go diff --git a/pkg/strategy/xfundingv2/arb_round.go b/pkg/strategy/xfundingv2/arb_round.go index 2d99cef14d..b68757ab0b 100644 --- a/pkg/strategy/xfundingv2/arb_round.go +++ b/pkg/strategy/xfundingv2/arb_round.go @@ -34,46 +34,21 @@ type FuturesService interface { } type transferRetry struct { - Trade types.Trade - LastTried time.Time + Trade types.Trade `json:"trade"` + LastTried time.Time `json:"lastTried"` } type ArbitrageRound struct { mu sync.Mutex - triggeredFundingRate fixedpoint.Value - triggeredSpotTargetPosition fixedpoint.Value - minHoldingIntervals int - fundingIntervalHours int - fundingIntervalStart, fundingIntervalEnd time.Time - fundingFeeRecords map[int64]FundingFee + syncState ArbitrageRoundSyncState - // TWAP workers - spotWorker *TWAPWorker - futuresWorker *TWAPWorker - futuresService FuturesService - asset string // base asset, e.g. "BTC" - - spotFeeAssetAmount, futuresFeeAssetAmount fixedpoint.Value + futuresService FuturesService spotExchangeFeeRates, futuresExchangeFeeRates map[types.ExchangeName]types.ExchangeFee - feeSymbol string - avgFeeCost fixedpoint.Value - retryDuration time.Duration retryTransferTickC chan time.Time - retryTransfers map[uint64]transferRetry - - state RoundState logger logrus.FieldLogger - - // startTime is the time when the round is started - startTime time.Time - // closingTime is the time when the round is entered closing state - closingTime time.Time - closingDuration time.Duration - // lastUpdateTime is the last time when the round is updated - lastUpdateTime time.Time } func NewArbitrageRound( @@ -90,22 +65,24 @@ func NewArbitrageRound( fundingIntervalStart := fundingRate.NextFundingTime.Add(-time.Duration(fundingIntervalHours) * time.Hour) fundingIntervalEnd := fundingRate.NextFundingTime.Add(-time.Second) return &ArbitrageRound{ - triggeredFundingRate: fundingRate.LastFundingRate, - triggeredSpotTargetPosition: spotTwap.TargetPosition(), - minHoldingIntervals: minHoldingIntervals, - fundingIntervalHours: fundingIntervalHours, - fundingIntervalStart: fundingIntervalStart, - fundingIntervalEnd: fundingIntervalEnd, - fundingFeeRecords: make(map[int64]FundingFee), - - spotWorker: spotTwap, - futuresWorker: futuresTwap, - - futuresService: futuresService, - asset: asset, - - state: RoundPending, - retryTransfers: make(map[uint64]transferRetry), + syncState: ArbitrageRoundSyncState{ + Symbol: spotTwap.Symbol(), + TriggeredFundingRate: fundingRate.LastFundingRate, + TriggeredSpotTargetPosition: spotTwap.TargetPosition(), + MinHoldingIntervals: minHoldingIntervals, + FundingIntervalHours: fundingIntervalHours, + FundingIntervalStart: fundingIntervalStart, + FundingIntervalEnd: fundingIntervalEnd, + FundingFeeRecords: make(map[int64]FundingFee), + + SpotWorker: spotTwap, + FuturesWorker: futuresTwap, + Asset: asset, + State: RoundPending, + RetryTransfers: make(map[uint64]transferRetry), + }, + + futuresService: futuresService, retryTransferTickC: make(chan time.Time, 1), } } @@ -119,36 +96,36 @@ func (r *ArbitrageRound) SetFuturesExchangeFeeRates(rates map[types.ExchangeName } func (r *ArbitrageRound) SetAvgFeeCost(feeSymbol string, cost fixedpoint.Value) { - r.feeSymbol = feeSymbol - r.avgFeeCost = cost + r.syncState.FeeSymbol = feeSymbol + r.syncState.AvgFeeCost = cost } func (r *ArbitrageRound) SetSpotFeeAssetAmount(amount fixedpoint.Value) { r.mu.Lock() defer r.mu.Unlock() - r.spotFeeAssetAmount = amount + r.syncState.SpotFeeAssetAmount = amount } func (r *ArbitrageRound) SpotFeeAssetAmount() fixedpoint.Value { r.mu.Lock() defer r.mu.Unlock() - return r.spotFeeAssetAmount + return r.syncState.SpotFeeAssetAmount } func (r *ArbitrageRound) SetFuturesFeeAssetAmount(amount fixedpoint.Value) { r.mu.Lock() defer r.mu.Unlock() - r.futuresFeeAssetAmount = amount + r.syncState.FuturesFeeAssetAmount = amount } func (r *ArbitrageRound) FuturesFeeAssetAmount() fixedpoint.Value { r.mu.Lock() defer r.mu.Unlock() - return r.futuresFeeAssetAmount + return r.syncState.FuturesFeeAssetAmount } // RequiredFeeAssetAmount returns the required fee asset amount for the round based on its current state and position. @@ -157,16 +134,18 @@ func (r *ArbitrageRound) RequiredFeeAssetAmounts() (fixedpoint.Value, fixedpoint r.mu.Lock() defer r.mu.Unlock() - halfSpotFee := r.spotFeeAssetAmount.Div(fixedpoint.Two) - halfFuturesFee := r.futuresFeeAssetAmount.Div(fixedpoint.Two) - switch r.state { + halfSpotFee := r.syncState.SpotFeeAssetAmount.Div(fixedpoint.Two) + halfFuturesFee := r.syncState.FuturesFeeAssetAmount.Div(fixedpoint.Two) + switch r.syncState.State { case RoundPending: - return r.spotFeeAssetAmount, r.futuresFeeAssetAmount + return r.syncState.SpotFeeAssetAmount, r.syncState.FuturesFeeAssetAmount case RoundOpening: // calculate the executed ratio executedRatio := fixedpoint.Zero - if !r.spotWorker.TargetPosition().IsZero() { - executedRatio = r.spotWorker.FilledPosition().Abs().Div(r.spotWorker.TargetPosition().Abs()) + if !r.syncState.SpotWorker.TargetPosition().IsZero() { + executedRatio = r.syncState.SpotWorker.FilledPosition(). + Abs(). + Div(r.syncState.SpotWorker.TargetPosition().Abs()) } remainRatio := fixedpoint.Max( fixedpoint.One.Sub(executedRatio), @@ -177,8 +156,8 @@ func (r *ArbitrageRound) RequiredFeeAssetAmounts() (fixedpoint.Value, fixedpoint return halfSpotFee.Mul(remainRatio), halfFuturesFee.Mul(remainRatio) case RoundReady, RoundClosing: executedRatio := fixedpoint.Zero - if !r.triggeredSpotTargetPosition.IsZero() { - executedRatio = r.spotWorker.FilledPosition().Abs().Div(r.triggeredSpotTargetPosition.Abs()) + if !r.syncState.TriggeredSpotTargetPosition.IsZero() { + executedRatio = r.syncState.SpotWorker.FilledPosition().Abs().Div(r.syncState.TriggeredSpotTargetPosition.Abs()) } remainRatio := fixedpoint.Max( fixedpoint.One.Sub(executedRatio), @@ -191,58 +170,62 @@ func (r *ArbitrageRound) RequiredFeeAssetAmounts() (fixedpoint.Value, fixedpoint } func (r *ArbitrageRound) StartTime() time.Time { - return r.startTime + return r.syncState.StartTime +} + +func (r *ArbitrageRound) HasStarted() bool { + return !r.syncState.StartTime.IsZero() } func (r *ArbitrageRound) TriggeredFundingRate() fixedpoint.Value { - return r.triggeredFundingRate + return r.syncState.TriggeredFundingRate } func (r *ArbitrageRound) NumHoldingIntervals(currentTime time.Time) int { - if r.startTime.IsZero() { + if r.syncState.StartTime.IsZero() { return 0 } // the funding rate has not flipped, check if the minimum holding time has passed - intervalDuration := time.Duration(r.fundingIntervalHours) * time.Hour + intervalDuration := time.Duration(r.syncState.FundingIntervalHours) * time.Hour lastIntervalEnd := currentTime.Truncate(intervalDuration) - return int(lastIntervalEnd.Sub(r.fundingIntervalStart) / intervalDuration) + return int(lastIntervalEnd.Sub(r.syncState.FundingIntervalStart) / intervalDuration) } func (r *ArbitrageRound) MinHoldingIntervals() int { - return r.minHoldingIntervals + return r.syncState.MinHoldingIntervals } func (r *ArbitrageRound) TargetPosition() fixedpoint.Value { - return r.spotWorker.TargetPosition() + return r.syncState.SpotWorker.TargetPosition() } func (r *ArbitrageRound) LastUpdateTime() time.Time { - return r.lastUpdateTime + return r.syncState.LastUpdateTime } func (r *ArbitrageRound) SetUpdateTime(t time.Time) { - r.lastUpdateTime = t + r.syncState.LastUpdateTime = t } func (r *ArbitrageRound) String() string { - if r.state != RoundClosing { + if r.syncState.State != RoundClosing { return fmt.Sprintf( "ArbitrageRound(symbol=%s, state=%s, spot=%s, futures=%s, startTime=%s)", - r.spotWorker.Symbol(), - r.state, - r.spotWorker.FilledPosition(), - r.futuresWorker.FilledPosition(), - r.startTime.Format(time.RFC3339), + r.syncState.SpotWorker.Symbol(), + r.syncState.State, + r.syncState.SpotWorker.FilledPosition(), + r.syncState.FuturesWorker.FilledPosition(), + r.syncState.StartTime.Format(time.RFC3339), ) } return fmt.Sprintf( "ArbitrageRound(symbol=%s, state=%s, spot=%s, futures=%s, closingTime=%s, expectedCloseTime=%s)", - r.spotWorker.Symbol(), - r.state, - r.spotWorker.FilledPosition(), - r.futuresWorker.FilledPosition(), - r.closingTime.Format(time.RFC3339), - r.closingTime.Add(r.closingDuration).Format(time.RFC3339), + r.syncState.SpotWorker.Symbol(), + r.syncState.State, + r.syncState.SpotWorker.FilledPosition(), + r.syncState.FuturesWorker.FilledPosition(), + r.syncState.ClosingTime.Format(time.RFC3339), + r.syncState.ClosingTime.Add(r.syncState.ClosingDuration).Format(time.RFC3339), ) } @@ -250,7 +233,7 @@ func (r *ArbitrageRound) CollectedFunding(ctx context.Context, currentTime time. r.mu.Lock() defer r.mu.Unlock() - if r.startTime.IsZero() { + if r.syncState.StartTime.IsZero() { return fixedpoint.Zero } return r.collectedFunding(ctx, currentTime) @@ -260,7 +243,7 @@ func (r *ArbitrageRound) collectedFunding(ctx context.Context, currentTime time. r.syncFundingFeeRecords(ctx, currentTime) var totalFunding fixedpoint.Value - for _, fee := range r.fundingFeeRecords { + for _, fee := range r.syncState.FundingFeeRecords { totalFunding = totalFunding.Add(fee.Amount) } return totalFunding @@ -278,8 +261,8 @@ func (r *ArbitrageRound) Orders() map[string][]types.Order { defer r.mu.Unlock() orders := map[string][]types.Order{ - "spot": r.spotWorker.Executor().AllOrders(), - "futures": r.futuresWorker.Executor().AllOrders(), + "spot": r.syncState.SpotWorker.Executor().AllOrders(), + "futures": r.syncState.FuturesWorker.Executor().AllOrders(), } return orders @@ -290,8 +273,8 @@ func (r *ArbitrageRound) Trades() map[string][]types.Trade { defer r.mu.Unlock() trades := map[string][]types.Trade{ - "spot": r.spotWorker.Executor().AllTrades(), - "futures": r.futuresWorker.Executor().AllTrades(), + "spot": r.syncState.SpotWorker.Executor().AllTrades(), + "futures": r.syncState.FuturesWorker.Executor().AllTrades(), } return trades @@ -301,22 +284,22 @@ func (r *ArbitrageRound) HasOrder(orderID uint64) bool { r.mu.Lock() defer r.mu.Unlock() - _, spotExists := r.spotWorker.Executor().GetOrder(orderID) - _, futuresExists := r.futuresWorker.Executor().GetOrder(orderID) + _, spotExists := r.syncState.SpotWorker.Executor().GetOrder(orderID) + _, futuresExists := r.syncState.FuturesWorker.Executor().GetOrder(orderID) return spotExists || futuresExists } func (r *ArbitrageRound) syncFundingFeeRecords(ctx context.Context, currentTime time.Time) { - if r.startTime.IsZero() || r.startTime.After(currentTime) { + if r.syncState.StartTime.IsZero() || r.syncState.StartTime.After(currentTime) { return } q := batch.BinanceFuturesIncomeBatchQuery{ BinanceFuturesIncomeHistoryService: r.futuresService, } - symbol := r.futuresWorker.Symbol() - dataC, errC := q.Query(ctx, symbol, binanceapi.FuturesIncomeFundingFee, r.startTime, currentTime) + symbol := r.syncState.FuturesWorker.Symbol() + dataC, errC := q.Query(ctx, symbol, binanceapi.FuturesIncomeFundingFee, r.syncState.StartTime, currentTime) for { select { case <-ctx.Done(): @@ -334,7 +317,7 @@ func (r *ArbitrageRound) syncFundingFeeRecords(ctx context.Context, currentTime Txn: income.TranId, Time: income.Time.Time(), } - r.fundingFeeRecords[income.TranId] = record + r.syncState.FundingFeeRecords[income.TranId] = record } case err, ok := <-errC: if !ok { @@ -349,58 +332,57 @@ func (r *ArbitrageRound) syncFundingFeeRecords(ctx context.Context, currentTime } func (r *ArbitrageRound) Start(ctx context.Context, currentTime time.Time) error { - if r.startTime.IsZero() { - if currentTime.After(r.fundingIntervalEnd) { + if r.syncState.StartTime.IsZero() { + if currentTime.After(r.syncState.FundingIntervalEnd) { // the round is triggered after the funding interval -> error return fmt.Errorf( "the round is triggered after the funding interval end (%s): %s", - r.fundingIntervalEnd.Format(time.RFC3339), + r.syncState.FundingIntervalEnd.Format(time.RFC3339), currentTime.Format(time.RFC3339), ) } - if err := r.spotWorker.Start(ctx, currentTime); err != nil { + if err := r.syncState.SpotWorker.Start(ctx, currentTime); err != nil { return fmt.Errorf("failed to start spot worker: %w", err) } - if err := r.futuresWorker.Start(ctx, currentTime); err != nil { + if err := r.syncState.FuturesWorker.Start(ctx, currentTime); err != nil { return fmt.Errorf("failed to start futures worker: %w", err) } - go r.retryTransferWorker(ctx) + go r.retryTransferWorker(ctx, r.retryTransferTickC) - r.startTime = currentTime - r.state = RoundOpening + r.syncState.StartTime = currentTime + r.syncState.State = RoundOpening } return nil } func (r *ArbitrageRound) Stop() { - r.spotWorker.Stop() - r.futuresWorker.Stop() + r.syncState.SpotWorker.Stop() + r.syncState.FuturesWorker.Stop() close(r.retryTransferTickC) } -func (r *ArbitrageRound) retryTransferWorker(ctx context.Context) { +func (r *ArbitrageRound) retryTransferWorker(ctx context.Context, tickC <-chan time.Time) { for { select { case <-ctx.Done(): return - case currentTime, ok := <-r.retryTransferTickC: + case currentTime, ok := <-tickC: if !ok { return } // retry failed transfers if any r.mu.Lock() - for tradeID, transfer := range r.retryTransfers { - retryDuration := r.retryDuration - if retryDuration == 0 { + for tradeID, transfer := range r.syncState.RetryTransfers { + if r.syncState.RetryDuration == 0 { // default retry duration is 10 minutes - retryDuration = 10 * time.Minute + r.syncState.RetryDuration = 10 * time.Minute } - if currentTime.Sub(transfer.LastTried) < retryDuration { + if currentTime.Sub(transfer.LastTried) < r.syncState.RetryDuration { continue } + r.logger.Infof("retry transfer (trade: %d): %s %s", tradeID, transfer.Trade.Quantity.String(), r.syncState.Asset) r.HandleSpotTrade(transfer.Trade, currentTime) - r.logger.Infof("retry transfer succeeded (trade: %d): %s %s", tradeID, transfer.Trade.Quantity.String(), r.asset) } r.mu.Unlock() } @@ -408,7 +390,7 @@ func (r *ArbitrageRound) retryTransferWorker(ctx context.Context) { } func (r *ArbitrageRound) SetRetryDuration(d time.Duration) { - r.retryDuration = d + r.syncState.RetryDuration = d } func (r *ArbitrageRound) HandleSpotTrade(trade types.Trade, currentTime time.Time) { @@ -416,45 +398,45 @@ func (r *ArbitrageRound) HandleSpotTrade(trade types.Trade, currentTime time.Tim r.mu.Lock() defer r.mu.Unlock() - if trade.Symbol != r.spotWorker.Symbol() || trade.IsFutures { + if trade.Symbol != r.syncState.SpotWorker.Symbol() || trade.IsFutures { return } - r.spotWorker.AddTrade(trade) + r.syncState.SpotWorker.AddTrade(trade) // try to transfer asset from spot to futures. // if transfer fails, retry in the next tick until it succeeds if err := r.futuresService.TransferFuturesAccountAsset( - r.spotWorker.ctx, r.asset, trade.Quantity, types.TransferIn, + r.syncState.SpotWorker.ctx, r.syncState.Asset, trade.Quantity, types.TransferIn, ); err != nil { r.logger.WithError(err).Errorf("failed to transfer %s %s from futures to spot", - trade.Quantity, r.asset) - if _, found := r.retryTransfers[trade.ID]; !found { + trade.Quantity, r.syncState.Asset) + if _, found := r.syncState.RetryTransfers[trade.ID]; !found { bbgo.Notify( fmt.Errorf("transfer failed (%s %s), retrying: %w", trade.Quantity.String(), - r.asset, + r.syncState.Asset, err, ), ) } - r.retryTransfers[trade.ID] = transferRetry{ + r.syncState.RetryTransfers[trade.ID] = transferRetry{ Trade: trade, LastTried: currentTime, } return } // transfer succeeded, remove from retry list if exists - delete(r.retryTransfers, trade.ID) + delete(r.syncState.RetryTransfers, trade.ID) r.syncFuturesPosition(trade) } func (r *ArbitrageRound) HandleFuturesTrade(trade types.Trade, currentTime time.Time) { - if trade.Symbol != r.futuresWorker.Symbol() || !trade.IsFutures { + if trade.Symbol != r.syncState.FuturesWorker.Symbol() || !trade.IsFutures { return } r.logger.Infof("handling future trade: %s", trade) - r.futuresWorker.AddTrade(trade) + r.syncState.FuturesWorker.AddTrade(trade) } func (r *ArbitrageRound) SetLogger(logger logrus.FieldLogger) { @@ -462,33 +444,37 @@ func (r *ArbitrageRound) SetLogger(logger logrus.FieldLogger) { } func (r *ArbitrageRound) SpotSymbol() string { - return r.spotWorker.Symbol() + return r.syncState.SpotWorker.Symbol() } func (r *ArbitrageRound) FuturesSymbol() string { - return r.futuresWorker.Symbol() + return r.syncState.FuturesWorker.Symbol() +} + +func (r *ArbitrageRound) FuturesMarket() types.Market { + return r.syncState.FuturesWorker.Market() } func (r *ArbitrageRound) State() RoundState { - return r.state + return r.syncState.State } func (r *ArbitrageRound) SetClosing(currentTime time.Time, duration time.Duration) { r.mu.Lock() defer r.mu.Unlock() - r.spotWorker.SetTargetPosition(fixedpoint.Zero) - r.spotWorker.ResetTime(currentTime, duration) - r.futuresWorker.SetTargetPosition(fixedpoint.Zero) - r.futuresWorker.ResetTime(currentTime, duration) + r.syncState.SpotWorker.SetTargetPosition(fixedpoint.Zero) + r.syncState.SpotWorker.ResetTime(currentTime, duration) + r.syncState.FuturesWorker.SetTargetPosition(fixedpoint.Zero) + r.syncState.FuturesWorker.ResetTime(currentTime, duration) - r.state = RoundClosing - r.closingTime = currentTime - r.closingDuration = duration + r.syncState.State = RoundClosing + r.syncState.ClosingTime = currentTime + r.syncState.ClosingDuration = duration } func (r *ArbitrageRound) AnnualizedRate() fixedpoint.Value { - return AnnualizedRate(r.triggeredFundingRate, r.fundingIntervalHours) + return AnnualizedRate(r.syncState.TriggeredFundingRate, r.syncState.FundingIntervalHours) } func (r *ArbitrageRound) Tick(currentTime time.Time, spotOrderBook types.OrderBook, futuresOrderBook types.OrderBook) { @@ -498,7 +484,7 @@ func (r *ArbitrageRound) Tick(currentTime time.Time, spotOrderBook types.OrderBo defer func() { // the state is PositionOpening // check if the spot and futures positions are fully filled -> PositionReady - if r.state == RoundOpening { + if r.syncState.State == RoundOpening { // get mid price spotBid, _ := spotOrderBook.BestBid() spotAsk, _ := spotOrderBook.BestAsk() @@ -507,29 +493,29 @@ func (r *ArbitrageRound) Tick(currentTime time.Time, spotOrderBook types.OrderBo spotMidPrice := spotBid.Price.Add(spotAsk.Price).Div(fixedpoint.Two) futuresMidPrice := futuresBid.Price.Add(futuresAsk.Price).Div(fixedpoint.Two) - spotRemaining := r.spotWorker.RemainingQuantity() - futuresRemaining := r.futuresWorker.RemainingQuantity() - spotIsDust := r.spotWorker.Market().IsDustQuantity(spotRemaining.Abs(), spotMidPrice) - futuresIsDust := r.futuresWorker.Market().IsDustQuantity(futuresRemaining.Abs(), futuresMidPrice) + spotRemaining := r.syncState.SpotWorker.RemainingQuantity() + futuresRemaining := r.syncState.FuturesWorker.RemainingQuantity() + spotIsDust := r.syncState.SpotWorker.Market().IsDustQuantity(spotRemaining.Abs(), spotMidPrice) + futuresIsDust := r.syncState.FuturesWorker.Market().IsDustQuantity(futuresRemaining.Abs(), futuresMidPrice) if spotIsDust && futuresIsDust { - r.state = RoundReady + r.syncState.State = RoundReady return } } // the state is PositionClosing // check if the spot and futures positions are fully closed -> PositionClosed - if r.state == RoundClosing { - if r.spotWorker.FilledPosition().IsZero() && r.futuresWorker.FilledPosition().IsZero() { - r.state = RoundClosed - r.logger.Infof("positions closed, arbitrage round completed: %s", r.spotWorker.Symbol()) + if r.syncState.State == RoundClosing { + if r.syncState.SpotWorker.FilledPosition().IsZero() && r.syncState.FuturesWorker.FilledPosition().IsZero() { + r.syncState.State = RoundClosed + r.logger.Infof("positions closed, arbitrage round completed: %s", r.syncState.SpotWorker.Symbol()) } return } }() - if r.state == RoundPending { + if r.syncState.State == RoundPending { // not started yet, do nothing return } @@ -537,28 +523,28 @@ func (r *ArbitrageRound) Tick(currentTime time.Time, spotOrderBook types.OrderBo if r.logger == nil { r.logger = logrus.WithFields(logrus.Fields{ "component": "ArbitrageRound", - "symbol": r.spotWorker.Symbol(), + "symbol": r.syncState.SpotWorker.Symbol(), }) } - if r.state == RoundClosed || r.state == RoundReady { + if r.syncState.State == RoundClosed || r.syncState.State == RoundReady { return } r.retryTransferTickC <- currentTime // it's opening or closing, tick the workers - r.spotWorker.Tick(currentTime, spotOrderBook) - r.futuresWorker.Tick(currentTime, futuresOrderBook) + r.syncState.SpotWorker.Tick(currentTime, spotOrderBook) + r.syncState.FuturesWorker.Tick(currentTime, futuresOrderBook) } func (r *ArbitrageRound) syncFuturesPosition(trade types.Trade) { - futureTargetPosition := r.futuresWorker.TargetPosition() - if r.spotWorker.TargetPosition().Sign() > 0 { + futureTargetPosition := r.syncState.FuturesWorker.TargetPosition() + if r.syncState.SpotWorker.TargetPosition().Sign() > 0 { futureTargetPosition = futureTargetPosition.Sub(trade.Quantity) } else { futureTargetPosition = futureTargetPosition.Add(trade.Quantity) } r.logger.Infof("syncing futures position to %s", futureTargetPosition) - r.futuresWorker.SetTargetPosition(futureTargetPosition) + r.syncState.FuturesWorker.SetTargetPosition(futureTargetPosition) } diff --git a/pkg/strategy/xfundingv2/arb_round_fee.go b/pkg/strategy/xfundingv2/arb_round_fee.go index 68ba013b0a..9f5c8c7836 100644 --- a/pkg/strategy/xfundingv2/arb_round_fee.go +++ b/pkg/strategy/xfundingv2/arb_round_fee.go @@ -11,9 +11,13 @@ import ( ) type PendingRound struct { - Round *ArbitrageRound - RetryCount int - LastRetryTime time.Time + Round *ArbitrageRound `json:"round"` + RetryCount int `json:"retryCount"` + LastRetryTime time.Time `json:"lastRetryTime"` +} + +func (r *PendingRound) LoadStrategy(ctx context.Context, s *Strategy) error { + return r.Round.LoadStrategy(ctx, s) } func (s *Strategy) processPendingRounds(ctx context.Context, currentTime time.Time) { diff --git a/pkg/strategy/xfundingv2/arb_round_pnl.go b/pkg/strategy/xfundingv2/arb_round_pnl.go index b0eaf58f95..63db2f2d1b 100644 --- a/pkg/strategy/xfundingv2/arb_round_pnl.go +++ b/pkg/strategy/xfundingv2/arb_round_pnl.go @@ -43,8 +43,8 @@ func (r *ArbitrageRound) PnL(ctx context.Context, currentTime time.Time) RoundPn fundingIncome := r.collectedFunding(ctx, currentTime) - spotMarket := r.spotWorker.Market() - futuresMarket := r.futuresWorker.Market() + spotMarket := r.syncState.SpotWorker.Market() + futuresMarket := r.syncState.FuturesWorker.Market() spotPosition := types.NewPositionFromMarket(spotMarket) futuresPosition := types.NewPositionFromMarket(futuresMarket) @@ -54,19 +54,19 @@ func (r *ArbitrageRound) PnL(ctx context.Context, currentTime time.Time) RoundPn if r.futuresExchangeFeeRates != nil { futuresPosition.ExchangeFeeRates = r.futuresExchangeFeeRates } - if !r.avgFeeCost.IsZero() { + if !r.syncState.AvgFeeCost.IsZero() { spotPosition.FeeAverageCosts = map[string]fixedpoint.Value{ - r.feeSymbol: r.avgFeeCost, + r.syncState.FeeSymbol: r.syncState.AvgFeeCost, } futuresPosition.FeeAverageCosts = map[string]fixedpoint.Value{ - r.feeSymbol: r.avgFeeCost, + r.syncState.FeeSymbol: r.syncState.AvgFeeCost, } } spotProfitStats := types.NewProfitStats(spotMarket) futuresProfitStats := types.NewProfitStats(futuresMarket) - spotTrades := r.spotWorker.Executor().AllTrades() + spotTrades := r.syncState.SpotWorker.Executor().AllTrades() for _, trade := range spotTrades { profit, netProfit, madeProfit := spotPosition.AddTrade(trade) if madeProfit { @@ -75,7 +75,7 @@ func (r *ArbitrageRound) PnL(ctx context.Context, currentTime time.Time) RoundPn } } - futuresTrades := r.futuresWorker.Executor().AllTrades() + futuresTrades := r.syncState.FuturesWorker.Executor().AllTrades() for _, trade := range futuresTrades { profit, netProfit, madeProfit := futuresPosition.AddTrade(trade) if madeProfit { diff --git a/pkg/strategy/xfundingv2/arb_round_pnl_test.go b/pkg/strategy/xfundingv2/arb_round_pnl_test.go index 4c72bbebbd..b22a18f954 100644 --- a/pkg/strategy/xfundingv2/arb_round_pnl_test.go +++ b/pkg/strategy/xfundingv2/arb_round_pnl_test.go @@ -32,11 +32,11 @@ func TestArbitrageRound_TradePnL(t *testing.T) { t.Run("returns zero profit when position is only opened", func(t *testing.T) { // Add orders first so AddTrade accepts them - spotExecutor := round.spotWorker.Executor() - spotExecutor.orders[1] = types.OrderQuery{OrderID: "1"} + spotExecutor := round.syncState.SpotWorker.Executor() + spotExecutor.syncState.Orders[1] = types.OrderQuery{OrderID: "1"} - futuresExecutor := round.futuresWorker.Executor() - futuresExecutor.orders[2] = types.OrderQuery{OrderID: "2"} + futuresExecutor := round.syncState.FuturesWorker.Executor() + futuresExecutor.syncState.Orders[2] = types.OrderQuery{OrderID: "2"} // Opening trades: buy spot at 40000, sell futures at 40100 spotExecutor.AddTrade(types.Trade{ @@ -73,11 +73,11 @@ func TestArbitrageRound_TradePnL(t *testing.T) { }) t.Run("calculates realized profit after closing trades", func(t *testing.T) { - spotExecutor := round.spotWorker.Executor() - spotExecutor.orders[3] = types.OrderQuery{OrderID: "3"} + spotExecutor := round.syncState.SpotWorker.Executor() + spotExecutor.syncState.Orders[3] = types.OrderQuery{OrderID: "3"} - futuresExecutor := round.futuresWorker.Executor() - futuresExecutor.orders[4] = types.OrderQuery{OrderID: "4"} + futuresExecutor := round.syncState.FuturesWorker.Executor() + futuresExecutor.syncState.Orders[4] = types.OrderQuery{OrderID: "4"} // Closing trades: sell spot at 41000 (profit), buy futures at 39900 (profit) spotExecutor.AddTrade(types.Trade{ diff --git a/pkg/strategy/xfundingv2/arb_round_sync.go b/pkg/strategy/xfundingv2/arb_round_sync.go new file mode 100644 index 0000000000..6e12f43794 --- /dev/null +++ b/pkg/strategy/xfundingv2/arb_round_sync.go @@ -0,0 +1,111 @@ +package xfundingv2 + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "time" + + "github.com/c9s/bbgo/pkg/fixedpoint" + "github.com/c9s/bbgo/pkg/types" +) + +func (r *ArbitrageRound) LoadStrategy(ctx context.Context, s *Strategy) error { + r.SetLogger(s.logger) + r.SetFuturesExchangeFeeRates( + map[types.ExchangeName]types.ExchangeFee{ + s.futuresSession.Exchange.Name(): { + MakerFeeRate: s.futuresSession.MakerFeeRate, + TakerFeeRate: s.futuresSession.TakerFeeRate, + }, + }, + ) + r.SetSpotExchangeFeeRates( + map[types.ExchangeName]types.ExchangeFee{ + s.spotSession.Exchange.Name(): { + MakerFeeRate: s.spotSession.MakerFeeRate, + TakerFeeRate: s.spotSession.TakerFeeRate, + }, + }, + ) + r.retryTransferTickC = make(chan time.Time) + if !r.HasStarted() { + // the round has been started before, we need to start the retry worker + go r.retryTransferWorker(ctx, r.retryTransferTickC) + } + if service, ok := s.futuresSession.Exchange.(FuturesService); ok { + r.futuresService = service + } else { + return errors.New("[ArbitrageRound] futures exchange does not implement FuturesService") + } + if r.syncState.SpotWorker != nil { + if err := r.syncState.SpotWorker.LoadStrategy(ctx, s); err != nil { + return fmt.Errorf("[ArbitrageRound] spot load strategy error: %w", err) + } + } else { + // should not happend + // by the time we create the round, the spot worker is never nil + // the restored round should always have the spot worker restored as well. + return errors.New("[ArbitrageRound] spot worker is nil") + } + if r.syncState.FuturesWorker != nil { + if err := r.syncState.FuturesWorker.LoadStrategy(ctx, s); err != nil { + return fmt.Errorf("[ArbitrageRound] futures load strategy error: %w", err) + } + } else { + // should not happend + // by the time we create the round, the futures worker is never nil + // the restored round should always have the futures worker restored as well. + return errors.New("[ArbitrageRound] futures worker is nil") + } + + return nil +} + +type ArbitrageRoundSyncState struct { + TriggeredFundingRate fixedpoint.Value `json:"triggeredFundingRate"` + TriggeredSpotTargetPosition fixedpoint.Value `json:"triggeredSpotTargetPosition"` + MinHoldingIntervals int `json:"minHoldingIntervals"` + FundingIntervalHours int `json:"fundingIntervalHours"` + FundingIntervalStart time.Time `json:"fundingIntervalStart"` + FundingIntervalEnd time.Time `json:"fundingIntervalEnd"` + FundingFeeRecords map[int64]FundingFee `json:"fundingFeeRecords"` + + Symbol string `json:"symbol"` + SpotWorker *TWAPWorker `json:"spotWorker,omitempty"` + FuturesWorker *TWAPWorker `json:"futuresWorker,omitempty"` + Asset string `json:"asset"` // base asset, e.g. "BTC" + + SpotFeeAssetAmount fixedpoint.Value `json:"spotFeeAssetAmount"` + FuturesFeeAssetAmount fixedpoint.Value `json:"futuresFeeAssetAmount"` + FeeSymbol string `json:"feeSymbol"` + AvgFeeCost fixedpoint.Value `json:"avgFeeCost"` + + RetryDuration time.Duration `json:"retryDuration"` + RetryTransfers map[uint64]transferRetry `json:"retryTransfers"` + + State RoundState `json:"state"` + + // StartTime is the time when the round is started + StartTime time.Time `json:"startTime"` + // ClosingTime is the time when the round is entered closing state + ClosingTime time.Time `json:"closingTime"` + ClosingDuration time.Duration `json:"closingDuration"` + // LastUpdateTime is the last time when the round is updated + LastUpdateTime time.Time `json:"lastUpdateTime"` +} + +func (r *ArbitrageRound) MarshalJSON() ([]byte, error) { + return json.Marshal(r.syncState) +} + +func (r *ArbitrageRound) UnmarshalJSON(b []byte) error { + syncState := ArbitrageRoundSyncState{} + if err := json.Unmarshal(b, &syncState); err != nil { + return err + } + + r.syncState = syncState + return nil +} diff --git a/pkg/strategy/xfundingv2/arb_round_sync_test.go b/pkg/strategy/xfundingv2/arb_round_sync_test.go new file mode 100644 index 0000000000..46ed7ec776 --- /dev/null +++ b/pkg/strategy/xfundingv2/arb_round_sync_test.go @@ -0,0 +1,103 @@ +package xfundingv2 + +import ( + "encoding/json" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/c9s/bbgo/pkg/fixedpoint" + "github.com/c9s/bbgo/pkg/types" +) + +func TestArbitrageRound_MarshalUnmarshalJSON(t *testing.T) { + t.Run("round_trip_preserves_all_fields", func(t *testing.T) { + startTime := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + closingTime := time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC) + lastUpdateTime := time.Date(2025, 1, 2, 1, 0, 0, 0, time.UTC) + fundingStart := time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC) + fundingEnd := time.Date(2025, 1, 1, 7, 59, 59, 0, time.UTC) + + round := &ArbitrageRound{ + syncState: ArbitrageRoundSyncState{ + TriggeredFundingRate: fixedpoint.NewFromFloat(0.001), + TriggeredSpotTargetPosition: fixedpoint.NewFromFloat(0.5), + MinHoldingIntervals: 3, + FundingIntervalHours: 8, + FundingIntervalStart: fundingStart, + FundingIntervalEnd: fundingEnd, + FundingFeeRecords: map[int64]FundingFee{ + 100: { + Asset: "BTC", + Amount: fixedpoint.NewFromFloat(0.0001), + Txn: 100, + Time: startTime, + }, + }, + Asset: "BTC", + + SpotFeeAssetAmount: fixedpoint.NewFromFloat(0.01), + FuturesFeeAssetAmount: fixedpoint.NewFromFloat(0.02), + FeeSymbol: "BNB", + AvgFeeCost: fixedpoint.NewFromFloat(600.0), + + RetryDuration: 5 * time.Minute, + RetryTransfers: map[uint64]transferRetry{ + 1: { + Trade: types.Trade{ + ID: 1, + Exchange: types.ExchangeBinance, + Side: types.SideTypeBuy, + }, + LastTried: startTime, + }, + }, + + State: RoundReady, + StartTime: startTime, + ClosingTime: closingTime, + ClosingDuration: 30 * time.Minute, + LastUpdateTime: lastUpdateTime, + }, + } + + data, err := json.Marshal(round) + require.NoError(t, err) + + var restored ArbitrageRound + err = json.Unmarshal(data, &restored) + require.NoError(t, err) + + assert.Equal(t, round.syncState, restored.syncState) + }) + + t.Run("nil_workers_are_preserved", func(t *testing.T) { + round := &ArbitrageRound{ + syncState: ArbitrageRoundSyncState{ + TriggeredFundingRate: fixedpoint.NewFromFloat(0.0005), + State: RoundPending, + Asset: "ETH", + FundingFeeRecords: make(map[int64]FundingFee), + }, + } + + data, err := json.Marshal(round) + require.NoError(t, err) + + var restored ArbitrageRound + err = json.Unmarshal(data, &restored) + require.NoError(t, err) + + assert.Nil(t, restored.syncState.SpotWorker) + assert.Nil(t, restored.syncState.FuturesWorker) + assert.Equal(t, round.syncState, restored.syncState) + }) + + t.Run("unmarshal_invalid_json_returns_error", func(t *testing.T) { + var round ArbitrageRound + err := json.Unmarshal([]byte(`{invalid`), &round) + assert.Error(t, err) + }) +} diff --git a/pkg/strategy/xfundingv2/arb_round_test.go b/pkg/strategy/xfundingv2/arb_round_test.go index be3c291e4d..0154f41d80 100644 --- a/pkg/strategy/xfundingv2/arb_round_test.go +++ b/pkg/strategy/xfundingv2/arb_round_test.go @@ -76,7 +76,7 @@ func TestArbitrageRound_CollectedFunding(t *testing.T) { }) t.Run("sums funding fee records", func(t *testing.T) { - round.startTime = time.Date(2024, 1, 1, 1, 0, 0, 0, time.UTC) + round.syncState.StartTime = time.Date(2024, 1, 1, 1, 0, 0, 0, time.UTC) // Simulate funding fee income returned by the service mockService.incomeHistory = []binanceapi.FuturesIncome{ From 9953ce4486553d72336c1e5cab581aeed6983575 Mon Sep 17 00:00:00 2001 From: dboyliao <6830390+dboyliao@users.noreply.github.com> Date: Fri, 8 May 2026 23:30:40 +0800 Subject: [PATCH 5/7] =?UTF-8?q?=E2=9C=A8=20feat(xfundingv2):=20implement?= =?UTF-8?q?=20persistence=20for=20strategy=20state?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/strategy/xfundingv2/strategy.go | 84 ++++++++++++++++++----------- 1 file changed, 52 insertions(+), 32 deletions(-) diff --git a/pkg/strategy/xfundingv2/strategy.go b/pkg/strategy/xfundingv2/strategy.go index 0f6f3e79d7..751ac18b44 100644 --- a/pkg/strategy/xfundingv2/strategy.go +++ b/pkg/strategy/xfundingv2/strategy.go @@ -68,15 +68,15 @@ type Strategy struct { costEstimator *CostEstimator preliminaryMarketSelector *MarketSelector - pendingRounds map[string]*PendingRound - activeRounds map[string]*ArbitrageRound - coinmarketcapClient *coinmarketcap.DataSource - // persist the positions + // persistence states + // pending rounds and active rounds + pendingRounds map[string]*PendingRound `persistence:"pendingRounds"` + activeRounds map[string]*ArbitrageRound `persistence:"activeRounds"` // the positions are shared across rounds and the executors of the same symbol. - spotPositions map[string]*types.Position `persistence:"spot_positions"` - futuresPositions map[string]*types.Position `persistence:"futures_positions"` + spotPositions map[string]*types.Position `persistence:"spotPositions,omitempty"` + futuresPositions map[string]*types.Position `persistence:"futuresPositions,omitempty"` // order executors for each symbol // we need to cache the executors as map at startup since the executors are bound to the user data stream (via `.Bind()`). @@ -146,29 +146,15 @@ func (s *Strategy) Initialize() error { s.futuresOrderBooks = make(map[string]*types.StreamOrderBook) s.spotOrderBooks = make(map[string]*types.StreamOrderBook) - // Initialize position maps (may be populated by LoadState if persisted state exists) - if s.spotPositions == nil { - s.spotPositions = make(map[string]*types.Position) - } - if s.futuresPositions == nil { - s.futuresPositions = make(map[string]*types.Position) - } - // Initialize executor maps - if s.spotGeneralOrderExecutors == nil { - s.spotGeneralOrderExecutors = make(map[string]*bbgo.GeneralOrderExecutor) - } - if s.futuresGeneralOrderExecutors == nil { - s.futuresGeneralOrderExecutors = make(map[string]*bbgo.GeneralOrderExecutor) - } + s.spotGeneralOrderExecutors = make(map[string]*bbgo.GeneralOrderExecutor) + s.futuresGeneralOrderExecutors = make(map[string]*bbgo.GeneralOrderExecutor) if !bbgo.IsBackTesting { s.logLimiter = rate.NewLimiter(rate.Every(time.Minute*10), 1) } if s.MaxPositionExposure == nil { s.MaxPositionExposure = make(map[string]fixedpoint.Value) } - s.activeRounds = make(map[string]*ArbitrageRound) - s.pendingRounds = make(map[string]*PendingRound) return nil } @@ -202,8 +188,24 @@ func (s *Strategy) CrossSubscribe(sessions map[string]*bbgo.ExchangeSession) { } func (s *Strategy) CrossRun( - ctx context.Context, orderExecutionRouter bbgo.OrderExecutionRouter, sessions map[string]*bbgo.ExchangeSession, + ctx context.Context, _ bbgo.OrderExecutionRouter, sessions map[string]*bbgo.ExchangeSession, ) error { + // Initialize position maps (may be populated by LoadState if persisted state exists) + if s.spotPositions == nil { + s.spotPositions = make(map[string]*types.Position) + } + if s.futuresPositions == nil { + s.futuresPositions = make(map[string]*types.Position) + } + + // Initialize round maps (may be populated by LoadState if persisted state exists) + if s.activeRounds == nil { + s.activeRounds = make(map[string]*ArbitrageRound) + } + if s.pendingRounds == nil { + s.pendingRounds = make(map[string]*PendingRound) + } + s.spotSession = sessions[s.SpotSession] s.futuresSession = sessions[s.FuturesSession] @@ -377,6 +379,19 @@ func (s *Strategy) CrossRun( binanceEx, _ := s.futuresSession.Exchange.(*binance.Exchange) s.preliminaryMarketSelector = NewMarketSelector(*s.MarketSelectionConfig, binanceEx, s.logger) + // runtime init done, load pending and active rounds + for symbol, pendingRound := range s.pendingRounds { + if err := pendingRound.LoadStrategy(ctx, s); err != nil { + return fmt.Errorf("failed to restore pending round (%s): %w", symbol, err) + } + } + for symbol, activeRound := range s.activeRounds { + if err := activeRound.LoadStrategy(ctx, s); err != nil { + return fmt.Errorf("failed to restore active round (%s): %w", symbol, err) + } + } + + // setup callbacks for _, sess := range []*bbgo.ExchangeSession{s.spotSession, s.futuresSession} { sess.MarketDataStream.OnKLineClosed(types.KLineWith(s.TickSymbol, types.Interval1m, func(kline types.KLine) { s.tick(ctx, kline.EndTime.Time()) @@ -449,8 +464,8 @@ func (s *Strategy) tick(ctx context.Context, tickTime time.Time) { // 4. tick existing active rounds for _, round := range s.activeRounds { - spotOrderBook := s.spotOrderBooks[round.spotWorker.Symbol()].Copy() - futuresOrderBook := s.futuresOrderBooks[round.futuresWorker.Symbol()].Copy() + spotOrderBook := s.spotOrderBooks[round.SpotSymbol()].Copy() + futuresOrderBook := s.futuresOrderBooks[round.FuturesSymbol()].Copy() round.Tick(tickTime, spotOrderBook, futuresOrderBook) } } @@ -544,8 +559,13 @@ func (s *Strategy) transitClosingRound(ctx context.Context, round *ArbitrageRoun func (s *Strategy) checkOpenNewRound(ctx context.Context, currentTime time.Time) { var lastOpenTime time.Time for _, round := range s.activeRounds { + startTime := round.StartTime() if lastOpenTime.IsZero() { - lastOpenTime = round.StartTime() + lastOpenTime = startTime + continue + } + if startTime.After(lastOpenTime) { + lastOpenTime = startTime } } if !lastOpenTime.IsZero() && currentTime.Sub(lastOpenTime) < s.OpenPositionInterval.Duration() { @@ -553,9 +573,9 @@ func (s *Strategy) checkOpenNewRound(ctx context.Context, currentTime time.Time) return } + // Only open new round when there is no active round + // TODO: support multiple active rounds for different symbols concurrently (e.g BTCUSDT and ETHUSDT) if len(s.activeRounds) == 0 { - // Only open new round when there is no active round - // TODO: support multiple active rounds for different symbols concurrently (e.g BTCUSDT and ETHUSDT) candidates, err := s.preliminaryMarketSelector.SelectMarkets(ctx, s.candidateSymbols) if err != nil { s.logger.WithError(err).Error("failed to select market candidates") @@ -581,14 +601,14 @@ func (s *Strategy) checkOpenNewRound(ctx context.Context, currentTime time.Time) if selectedCandidate.MinHoldingDuration <= s.MarketSelectionConfig.MaxHoldingHours.Duration() { spotExecutor := s.spotGeneralOrderExecutors[selectedCandidate.Symbol] spotTwap, err := NewTWAPWorker(ctx, selectedCandidate.Symbol, s.spotSession, spotExecutor, s.TWAPWorkerConfig) - if err != nil { + if err != nil || spotTwap == nil { s.logger.WithError(err).Errorf("failed to create TWAP worker for spot %s", selectedCandidate.Symbol) return } spotTwap.SetTargetPosition(selectedCandidate.TargetFuturesPosition.Neg()) futuresExecutor := s.futuresGeneralOrderExecutors[selectedCandidate.Symbol] futuresTwap, err := NewTWAPWorker(ctx, selectedCandidate.Symbol, s.futuresSession, futuresExecutor, s.TWAPWorkerConfig) - if err != nil { + if err != nil || futuresTwap == nil { s.logger.WithError(err).Errorf("failed to create TWAP worker for futures %s", selectedCandidate.Symbol) return } @@ -827,10 +847,10 @@ func (s *Strategy) handleRoundExit(ctx context.Context, round *ArbitrageRound, t switch s.MarketSelectionConfig.FuturesDirection { case types.PositionShort: // short futures -> transfer base currency - asset = round.futuresWorker.Market().BaseCurrency + asset = round.FuturesMarket().BaseCurrency case types.PositionLong: // long futures -> transfer quote currency - asset = round.futuresWorker.Market().QuoteCurrency + asset = round.FuturesMarket().QuoteCurrency } account := s.futuresSession.GetAccount() balance, ok := account.Balance(asset) From 52eed00527948b408025ac6cb49b174d42c00202 Mon Sep 17 00:00:00 2001 From: dboyliao <6830390+dboyliao@users.noreply.github.com> Date: Mon, 11 May 2026 20:31:47 +0800 Subject: [PATCH 6/7] =?UTF-8?q?=F0=9F=94=84=20refactor(xfundingv2):=20extr?= =?UTF-8?q?act=20spot/futures=20workers=20out=20from=20the=20sync=20state?= =?UTF-8?q?=20struct=20for=20arbitrage=20round?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/strategy/xfundingv2/arb_round.go | 100 +++++++++--------- pkg/strategy/xfundingv2/arb_round_pnl.go | 8 +- pkg/strategy/xfundingv2/arb_round_pnl_test.go | 8 +- pkg/strategy/xfundingv2/arb_round_sync.go | 40 ++++--- .../xfundingv2/arb_round_sync_test.go | 17 +-- 5 files changed, 97 insertions(+), 76 deletions(-) diff --git a/pkg/strategy/xfundingv2/arb_round.go b/pkg/strategy/xfundingv2/arb_round.go index b68757ab0b..21b3b91da0 100644 --- a/pkg/strategy/xfundingv2/arb_round.go +++ b/pkg/strategy/xfundingv2/arb_round.go @@ -41,7 +41,9 @@ type transferRetry struct { type ArbitrageRound struct { mu sync.Mutex - syncState ArbitrageRoundSyncState + syncState ArbitrageRoundSyncState + spotWorker *TWAPWorker + futuresWorker *TWAPWorker futuresService FuturesService spotExchangeFeeRates, futuresExchangeFeeRates map[types.ExchangeName]types.ExchangeFee @@ -75,13 +77,13 @@ func NewArbitrageRound( FundingIntervalEnd: fundingIntervalEnd, FundingFeeRecords: make(map[int64]FundingFee), - SpotWorker: spotTwap, - FuturesWorker: futuresTwap, Asset: asset, State: RoundPending, RetryTransfers: make(map[uint64]transferRetry), }, + spotWorker: spotTwap, + futuresWorker: futuresTwap, futuresService: futuresService, retryTransferTickC: make(chan time.Time, 1), } @@ -142,10 +144,10 @@ func (r *ArbitrageRound) RequiredFeeAssetAmounts() (fixedpoint.Value, fixedpoint case RoundOpening: // calculate the executed ratio executedRatio := fixedpoint.Zero - if !r.syncState.SpotWorker.TargetPosition().IsZero() { - executedRatio = r.syncState.SpotWorker.FilledPosition(). + if !r.spotWorker.TargetPosition().IsZero() { + executedRatio = r.spotWorker.FilledPosition(). Abs(). - Div(r.syncState.SpotWorker.TargetPosition().Abs()) + Div(r.spotWorker.TargetPosition().Abs()) } remainRatio := fixedpoint.Max( fixedpoint.One.Sub(executedRatio), @@ -157,7 +159,7 @@ func (r *ArbitrageRound) RequiredFeeAssetAmounts() (fixedpoint.Value, fixedpoint case RoundReady, RoundClosing: executedRatio := fixedpoint.Zero if !r.syncState.TriggeredSpotTargetPosition.IsZero() { - executedRatio = r.syncState.SpotWorker.FilledPosition().Abs().Div(r.syncState.TriggeredSpotTargetPosition.Abs()) + executedRatio = r.spotWorker.FilledPosition().Abs().Div(r.syncState.TriggeredSpotTargetPosition.Abs()) } remainRatio := fixedpoint.Max( fixedpoint.One.Sub(executedRatio), @@ -196,7 +198,7 @@ func (r *ArbitrageRound) MinHoldingIntervals() int { } func (r *ArbitrageRound) TargetPosition() fixedpoint.Value { - return r.syncState.SpotWorker.TargetPosition() + return r.spotWorker.TargetPosition() } func (r *ArbitrageRound) LastUpdateTime() time.Time { @@ -211,19 +213,19 @@ func (r *ArbitrageRound) String() string { if r.syncState.State != RoundClosing { return fmt.Sprintf( "ArbitrageRound(symbol=%s, state=%s, spot=%s, futures=%s, startTime=%s)", - r.syncState.SpotWorker.Symbol(), + r.spotWorker.Symbol(), r.syncState.State, - r.syncState.SpotWorker.FilledPosition(), - r.syncState.FuturesWorker.FilledPosition(), + r.spotWorker.FilledPosition(), + r.futuresWorker.FilledPosition(), r.syncState.StartTime.Format(time.RFC3339), ) } return fmt.Sprintf( "ArbitrageRound(symbol=%s, state=%s, spot=%s, futures=%s, closingTime=%s, expectedCloseTime=%s)", - r.syncState.SpotWorker.Symbol(), + r.spotWorker.Symbol(), r.syncState.State, - r.syncState.SpotWorker.FilledPosition(), - r.syncState.FuturesWorker.FilledPosition(), + r.spotWorker.FilledPosition(), + r.futuresWorker.FilledPosition(), r.syncState.ClosingTime.Format(time.RFC3339), r.syncState.ClosingTime.Add(r.syncState.ClosingDuration).Format(time.RFC3339), ) @@ -261,8 +263,8 @@ func (r *ArbitrageRound) Orders() map[string][]types.Order { defer r.mu.Unlock() orders := map[string][]types.Order{ - "spot": r.syncState.SpotWorker.Executor().AllOrders(), - "futures": r.syncState.FuturesWorker.Executor().AllOrders(), + "spot": r.spotWorker.Executor().AllOrders(), + "futures": r.futuresWorker.Executor().AllOrders(), } return orders @@ -273,8 +275,8 @@ func (r *ArbitrageRound) Trades() map[string][]types.Trade { defer r.mu.Unlock() trades := map[string][]types.Trade{ - "spot": r.syncState.SpotWorker.Executor().AllTrades(), - "futures": r.syncState.FuturesWorker.Executor().AllTrades(), + "spot": r.spotWorker.Executor().AllTrades(), + "futures": r.futuresWorker.Executor().AllTrades(), } return trades @@ -284,8 +286,8 @@ func (r *ArbitrageRound) HasOrder(orderID uint64) bool { r.mu.Lock() defer r.mu.Unlock() - _, spotExists := r.syncState.SpotWorker.Executor().GetOrder(orderID) - _, futuresExists := r.syncState.FuturesWorker.Executor().GetOrder(orderID) + _, spotExists := r.spotWorker.Executor().GetOrder(orderID) + _, futuresExists := r.futuresWorker.Executor().GetOrder(orderID) return spotExists || futuresExists } @@ -298,7 +300,7 @@ func (r *ArbitrageRound) syncFundingFeeRecords(ctx context.Context, currentTime q := batch.BinanceFuturesIncomeBatchQuery{ BinanceFuturesIncomeHistoryService: r.futuresService, } - symbol := r.syncState.FuturesWorker.Symbol() + symbol := r.futuresWorker.Symbol() dataC, errC := q.Query(ctx, symbol, binanceapi.FuturesIncomeFundingFee, r.syncState.StartTime, currentTime) for { select { @@ -341,10 +343,10 @@ func (r *ArbitrageRound) Start(ctx context.Context, currentTime time.Time) error currentTime.Format(time.RFC3339), ) } - if err := r.syncState.SpotWorker.Start(ctx, currentTime); err != nil { + if err := r.spotWorker.Start(ctx, currentTime); err != nil { return fmt.Errorf("failed to start spot worker: %w", err) } - if err := r.syncState.FuturesWorker.Start(ctx, currentTime); err != nil { + if err := r.futuresWorker.Start(ctx, currentTime); err != nil { return fmt.Errorf("failed to start futures worker: %w", err) } @@ -357,8 +359,8 @@ func (r *ArbitrageRound) Start(ctx context.Context, currentTime time.Time) error } func (r *ArbitrageRound) Stop() { - r.syncState.SpotWorker.Stop() - r.syncState.FuturesWorker.Stop() + r.spotWorker.Stop() + r.futuresWorker.Stop() close(r.retryTransferTickC) } @@ -398,16 +400,16 @@ func (r *ArbitrageRound) HandleSpotTrade(trade types.Trade, currentTime time.Tim r.mu.Lock() defer r.mu.Unlock() - if trade.Symbol != r.syncState.SpotWorker.Symbol() || trade.IsFutures { + if trade.Symbol != r.spotWorker.Symbol() || trade.IsFutures { return } - r.syncState.SpotWorker.AddTrade(trade) + r.spotWorker.AddTrade(trade) // try to transfer asset from spot to futures. // if transfer fails, retry in the next tick until it succeeds if err := r.futuresService.TransferFuturesAccountAsset( - r.syncState.SpotWorker.ctx, r.syncState.Asset, trade.Quantity, types.TransferIn, + r.spotWorker.ctx, r.syncState.Asset, trade.Quantity, types.TransferIn, ); err != nil { r.logger.WithError(err).Errorf("failed to transfer %s %s from futures to spot", trade.Quantity, r.syncState.Asset) @@ -432,11 +434,11 @@ func (r *ArbitrageRound) HandleSpotTrade(trade types.Trade, currentTime time.Tim } func (r *ArbitrageRound) HandleFuturesTrade(trade types.Trade, currentTime time.Time) { - if trade.Symbol != r.syncState.FuturesWorker.Symbol() || !trade.IsFutures { + if trade.Symbol != r.futuresWorker.Symbol() || !trade.IsFutures { return } r.logger.Infof("handling future trade: %s", trade) - r.syncState.FuturesWorker.AddTrade(trade) + r.futuresWorker.AddTrade(trade) } func (r *ArbitrageRound) SetLogger(logger logrus.FieldLogger) { @@ -444,15 +446,15 @@ func (r *ArbitrageRound) SetLogger(logger logrus.FieldLogger) { } func (r *ArbitrageRound) SpotSymbol() string { - return r.syncState.SpotWorker.Symbol() + return r.spotWorker.Symbol() } func (r *ArbitrageRound) FuturesSymbol() string { - return r.syncState.FuturesWorker.Symbol() + return r.futuresWorker.Symbol() } func (r *ArbitrageRound) FuturesMarket() types.Market { - return r.syncState.FuturesWorker.Market() + return r.futuresWorker.Market() } func (r *ArbitrageRound) State() RoundState { @@ -463,10 +465,10 @@ func (r *ArbitrageRound) SetClosing(currentTime time.Time, duration time.Duratio r.mu.Lock() defer r.mu.Unlock() - r.syncState.SpotWorker.SetTargetPosition(fixedpoint.Zero) - r.syncState.SpotWorker.ResetTime(currentTime, duration) - r.syncState.FuturesWorker.SetTargetPosition(fixedpoint.Zero) - r.syncState.FuturesWorker.ResetTime(currentTime, duration) + r.spotWorker.SetTargetPosition(fixedpoint.Zero) + r.spotWorker.ResetTime(currentTime, duration) + r.futuresWorker.SetTargetPosition(fixedpoint.Zero) + r.futuresWorker.ResetTime(currentTime, duration) r.syncState.State = RoundClosing r.syncState.ClosingTime = currentTime @@ -493,10 +495,10 @@ func (r *ArbitrageRound) Tick(currentTime time.Time, spotOrderBook types.OrderBo spotMidPrice := spotBid.Price.Add(spotAsk.Price).Div(fixedpoint.Two) futuresMidPrice := futuresBid.Price.Add(futuresAsk.Price).Div(fixedpoint.Two) - spotRemaining := r.syncState.SpotWorker.RemainingQuantity() - futuresRemaining := r.syncState.FuturesWorker.RemainingQuantity() - spotIsDust := r.syncState.SpotWorker.Market().IsDustQuantity(spotRemaining.Abs(), spotMidPrice) - futuresIsDust := r.syncState.FuturesWorker.Market().IsDustQuantity(futuresRemaining.Abs(), futuresMidPrice) + spotRemaining := r.spotWorker.RemainingQuantity() + futuresRemaining := r.futuresWorker.RemainingQuantity() + spotIsDust := r.spotWorker.Market().IsDustQuantity(spotRemaining.Abs(), spotMidPrice) + futuresIsDust := r.futuresWorker.Market().IsDustQuantity(futuresRemaining.Abs(), futuresMidPrice) if spotIsDust && futuresIsDust { r.syncState.State = RoundReady @@ -507,9 +509,9 @@ func (r *ArbitrageRound) Tick(currentTime time.Time, spotOrderBook types.OrderBo // the state is PositionClosing // check if the spot and futures positions are fully closed -> PositionClosed if r.syncState.State == RoundClosing { - if r.syncState.SpotWorker.FilledPosition().IsZero() && r.syncState.FuturesWorker.FilledPosition().IsZero() { + if r.spotWorker.FilledPosition().IsZero() && r.futuresWorker.FilledPosition().IsZero() { r.syncState.State = RoundClosed - r.logger.Infof("positions closed, arbitrage round completed: %s", r.syncState.SpotWorker.Symbol()) + r.logger.Infof("positions closed, arbitrage round completed: %s", r.spotWorker.Symbol()) } return } @@ -523,7 +525,7 @@ func (r *ArbitrageRound) Tick(currentTime time.Time, spotOrderBook types.OrderBo if r.logger == nil { r.logger = logrus.WithFields(logrus.Fields{ "component": "ArbitrageRound", - "symbol": r.syncState.SpotWorker.Symbol(), + "symbol": r.spotWorker.Symbol(), }) } @@ -534,17 +536,17 @@ func (r *ArbitrageRound) Tick(currentTime time.Time, spotOrderBook types.OrderBo r.retryTransferTickC <- currentTime // it's opening or closing, tick the workers - r.syncState.SpotWorker.Tick(currentTime, spotOrderBook) - r.syncState.FuturesWorker.Tick(currentTime, futuresOrderBook) + r.spotWorker.Tick(currentTime, spotOrderBook) + r.futuresWorker.Tick(currentTime, futuresOrderBook) } func (r *ArbitrageRound) syncFuturesPosition(trade types.Trade) { - futureTargetPosition := r.syncState.FuturesWorker.TargetPosition() - if r.syncState.SpotWorker.TargetPosition().Sign() > 0 { + futureTargetPosition := r.futuresWorker.TargetPosition() + if r.spotWorker.TargetPosition().Sign() > 0 { futureTargetPosition = futureTargetPosition.Sub(trade.Quantity) } else { futureTargetPosition = futureTargetPosition.Add(trade.Quantity) } r.logger.Infof("syncing futures position to %s", futureTargetPosition) - r.syncState.FuturesWorker.SetTargetPosition(futureTargetPosition) + r.futuresWorker.SetTargetPosition(futureTargetPosition) } diff --git a/pkg/strategy/xfundingv2/arb_round_pnl.go b/pkg/strategy/xfundingv2/arb_round_pnl.go index 63db2f2d1b..c7b784bd1e 100644 --- a/pkg/strategy/xfundingv2/arb_round_pnl.go +++ b/pkg/strategy/xfundingv2/arb_round_pnl.go @@ -43,8 +43,8 @@ func (r *ArbitrageRound) PnL(ctx context.Context, currentTime time.Time) RoundPn fundingIncome := r.collectedFunding(ctx, currentTime) - spotMarket := r.syncState.SpotWorker.Market() - futuresMarket := r.syncState.FuturesWorker.Market() + spotMarket := r.spotWorker.Market() + futuresMarket := r.futuresWorker.Market() spotPosition := types.NewPositionFromMarket(spotMarket) futuresPosition := types.NewPositionFromMarket(futuresMarket) @@ -66,7 +66,7 @@ func (r *ArbitrageRound) PnL(ctx context.Context, currentTime time.Time) RoundPn spotProfitStats := types.NewProfitStats(spotMarket) futuresProfitStats := types.NewProfitStats(futuresMarket) - spotTrades := r.syncState.SpotWorker.Executor().AllTrades() + spotTrades := r.spotWorker.Executor().AllTrades() for _, trade := range spotTrades { profit, netProfit, madeProfit := spotPosition.AddTrade(trade) if madeProfit { @@ -75,7 +75,7 @@ func (r *ArbitrageRound) PnL(ctx context.Context, currentTime time.Time) RoundPn } } - futuresTrades := r.syncState.FuturesWorker.Executor().AllTrades() + futuresTrades := r.futuresWorker.Executor().AllTrades() for _, trade := range futuresTrades { profit, netProfit, madeProfit := futuresPosition.AddTrade(trade) if madeProfit { diff --git a/pkg/strategy/xfundingv2/arb_round_pnl_test.go b/pkg/strategy/xfundingv2/arb_round_pnl_test.go index b22a18f954..fe7efd82ad 100644 --- a/pkg/strategy/xfundingv2/arb_round_pnl_test.go +++ b/pkg/strategy/xfundingv2/arb_round_pnl_test.go @@ -32,10 +32,10 @@ func TestArbitrageRound_TradePnL(t *testing.T) { t.Run("returns zero profit when position is only opened", func(t *testing.T) { // Add orders first so AddTrade accepts them - spotExecutor := round.syncState.SpotWorker.Executor() + spotExecutor := round.spotWorker.Executor() spotExecutor.syncState.Orders[1] = types.OrderQuery{OrderID: "1"} - futuresExecutor := round.syncState.FuturesWorker.Executor() + futuresExecutor := round.futuresWorker.Executor() futuresExecutor.syncState.Orders[2] = types.OrderQuery{OrderID: "2"} // Opening trades: buy spot at 40000, sell futures at 40100 @@ -73,10 +73,10 @@ func TestArbitrageRound_TradePnL(t *testing.T) { }) t.Run("calculates realized profit after closing trades", func(t *testing.T) { - spotExecutor := round.syncState.SpotWorker.Executor() + spotExecutor := round.spotWorker.Executor() spotExecutor.syncState.Orders[3] = types.OrderQuery{OrderID: "3"} - futuresExecutor := round.syncState.FuturesWorker.Executor() + futuresExecutor := round.futuresWorker.Executor() futuresExecutor.syncState.Orders[4] = types.OrderQuery{OrderID: "4"} // Closing trades: sell spot at 41000 (profit), buy futures at 39900 (profit) diff --git a/pkg/strategy/xfundingv2/arb_round_sync.go b/pkg/strategy/xfundingv2/arb_round_sync.go index 6e12f43794..1d312bbd12 100644 --- a/pkg/strategy/xfundingv2/arb_round_sync.go +++ b/pkg/strategy/xfundingv2/arb_round_sync.go @@ -39,8 +39,8 @@ func (r *ArbitrageRound) LoadStrategy(ctx context.Context, s *Strategy) error { } else { return errors.New("[ArbitrageRound] futures exchange does not implement FuturesService") } - if r.syncState.SpotWorker != nil { - if err := r.syncState.SpotWorker.LoadStrategy(ctx, s); err != nil { + if r.spotWorker != nil { + if err := r.spotWorker.LoadStrategy(ctx, s); err != nil { return fmt.Errorf("[ArbitrageRound] spot load strategy error: %w", err) } } else { @@ -49,8 +49,8 @@ func (r *ArbitrageRound) LoadStrategy(ctx context.Context, s *Strategy) error { // the restored round should always have the spot worker restored as well. return errors.New("[ArbitrageRound] spot worker is nil") } - if r.syncState.FuturesWorker != nil { - if err := r.syncState.FuturesWorker.LoadStrategy(ctx, s); err != nil { + if r.futuresWorker != nil { + if err := r.futuresWorker.LoadStrategy(ctx, s); err != nil { return fmt.Errorf("[ArbitrageRound] futures load strategy error: %w", err) } } else { @@ -72,10 +72,10 @@ type ArbitrageRoundSyncState struct { FundingIntervalEnd time.Time `json:"fundingIntervalEnd"` FundingFeeRecords map[int64]FundingFee `json:"fundingFeeRecords"` - Symbol string `json:"symbol"` - SpotWorker *TWAPWorker `json:"spotWorker,omitempty"` - FuturesWorker *TWAPWorker `json:"futuresWorker,omitempty"` - Asset string `json:"asset"` // base asset, e.g. "BTC" + Symbol string `json:"symbol"` + SpotExchangeName types.ExchangeName `json:"spotExchangeName"` + FuturesExchangeName types.ExchangeName `json:"futuresExchangeName"` + Asset string `json:"asset"` // base asset, e.g. "BTC" SpotFeeAssetAmount fixedpoint.Value `json:"spotFeeAssetAmount"` FuturesFeeAssetAmount fixedpoint.Value `json:"futuresFeeAssetAmount"` @@ -97,15 +97,29 @@ type ArbitrageRoundSyncState struct { } func (r *ArbitrageRound) MarshalJSON() ([]byte, error) { - return json.Marshal(r.syncState) + v := struct { + SyncState ArbitrageRoundSyncState `json:"syncState"` + SpotWorker *TWAPWorker `json:"spotWorker,omitempty"` + FuturesWorker *TWAPWorker `json:"futuresWorker,omitempty"` + }{ + SyncState: r.syncState, + SpotWorker: r.spotWorker, + FuturesWorker: r.futuresWorker, + } + return json.Marshal(v) } func (r *ArbitrageRound) UnmarshalJSON(b []byte) error { - syncState := ArbitrageRoundSyncState{} - if err := json.Unmarshal(b, &syncState); err != nil { + v := struct { + SyncState ArbitrageRoundSyncState `json:"syncState"` + SpotWorker *TWAPWorker `json:"spotWorker,omitempty"` + FuturesWorker *TWAPWorker `json:"futuresWorker,omitempty"` + }{} + if err := json.Unmarshal(b, &v); err != nil { return err } - - r.syncState = syncState + r.syncState = v.SyncState + r.spotWorker = v.SpotWorker + r.futuresWorker = v.FuturesWorker return nil } diff --git a/pkg/strategy/xfundingv2/arb_round_sync_test.go b/pkg/strategy/xfundingv2/arb_round_sync_test.go index 46ed7ec776..1110a75cb8 100644 --- a/pkg/strategy/xfundingv2/arb_round_sync_test.go +++ b/pkg/strategy/xfundingv2/arb_round_sync_test.go @@ -36,7 +36,10 @@ func TestArbitrageRound_MarshalUnmarshalJSON(t *testing.T) { Time: startTime, }, }, - Asset: "BTC", + Symbol: "BTCUSDT", + SpotExchangeName: types.ExchangeBinance, + FuturesExchangeName: types.ExchangeBinance, + Asset: "BTC", SpotFeeAssetAmount: fixedpoint.NewFromFloat(0.01), FuturesFeeAssetAmount: fixedpoint.NewFromFloat(0.02), @@ -77,9 +80,11 @@ func TestArbitrageRound_MarshalUnmarshalJSON(t *testing.T) { round := &ArbitrageRound{ syncState: ArbitrageRoundSyncState{ TriggeredFundingRate: fixedpoint.NewFromFloat(0.0005), - State: RoundPending, - Asset: "ETH", - FundingFeeRecords: make(map[int64]FundingFee), + SpotExchangeName: types.ExchangeBinance, + FuturesExchangeName: types.ExchangeBinance, + State: RoundPending, + Asset: "ETH", + FundingFeeRecords: make(map[int64]FundingFee), }, } @@ -90,8 +95,8 @@ func TestArbitrageRound_MarshalUnmarshalJSON(t *testing.T) { err = json.Unmarshal(data, &restored) require.NoError(t, err) - assert.Nil(t, restored.syncState.SpotWorker) - assert.Nil(t, restored.syncState.FuturesWorker) + assert.Nil(t, restored.spotWorker) + assert.Nil(t, restored.futuresWorker) assert.Equal(t, round.syncState, restored.syncState) }) From 4d594792496c7797ee24c2114d50c2b891f9c9c6 Mon Sep 17 00:00:00 2001 From: dboyliao <6830390+dboyliao@users.noreply.github.com> Date: Mon, 11 May 2026 20:36:38 +0800 Subject: [PATCH 7/7] =?UTF-8?q?=F0=9F=94=84=20refactor(xfundingv2):=20rena?= =?UTF-8?q?me=20LoadStrategy=20to=20Initialize?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pkg/strategy/xfundingv2/arb_round_fee.go | 4 ++-- pkg/strategy/xfundingv2/arb_round_sync.go | 6 +++--- pkg/strategy/xfundingv2/strategy.go | 4 ++-- pkg/strategy/xfundingv2/twap_order_executor_sync.go | 2 +- pkg/strategy/xfundingv2/twap_sync.go | 4 ++-- 5 files changed, 10 insertions(+), 10 deletions(-) diff --git a/pkg/strategy/xfundingv2/arb_round_fee.go b/pkg/strategy/xfundingv2/arb_round_fee.go index 9f5c8c7836..0fea8fc93a 100644 --- a/pkg/strategy/xfundingv2/arb_round_fee.go +++ b/pkg/strategy/xfundingv2/arb_round_fee.go @@ -16,8 +16,8 @@ type PendingRound struct { LastRetryTime time.Time `json:"lastRetryTime"` } -func (r *PendingRound) LoadStrategy(ctx context.Context, s *Strategy) error { - return r.Round.LoadStrategy(ctx, s) +func (r *PendingRound) Initialize(ctx context.Context, s *Strategy) error { + return r.Round.Initialize(ctx, s) } func (s *Strategy) processPendingRounds(ctx context.Context, currentTime time.Time) { diff --git a/pkg/strategy/xfundingv2/arb_round_sync.go b/pkg/strategy/xfundingv2/arb_round_sync.go index 1d312bbd12..d28c1180d3 100644 --- a/pkg/strategy/xfundingv2/arb_round_sync.go +++ b/pkg/strategy/xfundingv2/arb_round_sync.go @@ -11,7 +11,7 @@ import ( "github.com/c9s/bbgo/pkg/types" ) -func (r *ArbitrageRound) LoadStrategy(ctx context.Context, s *Strategy) error { +func (r *ArbitrageRound) Initialize(ctx context.Context, s *Strategy) error { r.SetLogger(s.logger) r.SetFuturesExchangeFeeRates( map[types.ExchangeName]types.ExchangeFee{ @@ -40,7 +40,7 @@ func (r *ArbitrageRound) LoadStrategy(ctx context.Context, s *Strategy) error { return errors.New("[ArbitrageRound] futures exchange does not implement FuturesService") } if r.spotWorker != nil { - if err := r.spotWorker.LoadStrategy(ctx, s); err != nil { + if err := r.spotWorker.Initialize(ctx, s); err != nil { return fmt.Errorf("[ArbitrageRound] spot load strategy error: %w", err) } } else { @@ -50,7 +50,7 @@ func (r *ArbitrageRound) LoadStrategy(ctx context.Context, s *Strategy) error { return errors.New("[ArbitrageRound] spot worker is nil") } if r.futuresWorker != nil { - if err := r.futuresWorker.LoadStrategy(ctx, s); err != nil { + if err := r.futuresWorker.Initialize(ctx, s); err != nil { return fmt.Errorf("[ArbitrageRound] futures load strategy error: %w", err) } } else { diff --git a/pkg/strategy/xfundingv2/strategy.go b/pkg/strategy/xfundingv2/strategy.go index 751ac18b44..995a30b893 100644 --- a/pkg/strategy/xfundingv2/strategy.go +++ b/pkg/strategy/xfundingv2/strategy.go @@ -381,12 +381,12 @@ func (s *Strategy) CrossRun( // runtime init done, load pending and active rounds for symbol, pendingRound := range s.pendingRounds { - if err := pendingRound.LoadStrategy(ctx, s); err != nil { + if err := pendingRound.Initialize(ctx, s); err != nil { return fmt.Errorf("failed to restore pending round (%s): %w", symbol, err) } } for symbol, activeRound := range s.activeRounds { - if err := activeRound.LoadStrategy(ctx, s); err != nil { + if err := activeRound.Initialize(ctx, s); err != nil { return fmt.Errorf("failed to restore active round (%s): %w", symbol, err) } } diff --git a/pkg/strategy/xfundingv2/twap_order_executor_sync.go b/pkg/strategy/xfundingv2/twap_order_executor_sync.go index d611975767..4f8ab4de73 100644 --- a/pkg/strategy/xfundingv2/twap_order_executor_sync.go +++ b/pkg/strategy/xfundingv2/twap_order_executor_sync.go @@ -17,7 +17,7 @@ type TWAPExecutorSyncState struct { Trades map[uint64]types.Trade `json:"trades,omitempty"` } -func (o *TWAPExecutor) LoadStrategy(s *Strategy) error { +func (o *TWAPExecutor) Initialize(s *Strategy) error { o.SetLogger(s.logger) var session *bbgo.ExchangeSession var executor *bbgo.GeneralOrderExecutor diff --git a/pkg/strategy/xfundingv2/twap_sync.go b/pkg/strategy/xfundingv2/twap_sync.go index c6cb6b8365..9d56eebc38 100644 --- a/pkg/strategy/xfundingv2/twap_sync.go +++ b/pkg/strategy/xfundingv2/twap_sync.go @@ -10,7 +10,7 @@ import ( "github.com/c9s/bbgo/pkg/types" ) -func (w *TWAPWorker) LoadStrategy(ctx context.Context, s *Strategy) error { +func (w *TWAPWorker) Initialize(ctx context.Context, s *Strategy) error { if w.syncState.TWAPExecutor == nil { // should not happen return fmt.Errorf("[TWAPWorker] TWAPExecutor is nil") @@ -18,7 +18,7 @@ func (w *TWAPWorker) LoadStrategy(ctx context.Context, s *Strategy) error { w.ctx = ctx w.SetLogger(s.logger) - if err := w.syncState.TWAPExecutor.LoadStrategy(s); err != nil { + if err := w.syncState.TWAPExecutor.Initialize(s); err != nil { return fmt.Errorf("[TWAPWorker] failed to load TWAPExecutor: %w", err) } if w.syncState.ActiveOrder != nil {