fix(report): prevent state change loss during in-flight ReportState

- Consume stateChanged atomically with the state snapshot under a single Lock
- Restore stateChanged on UpdateTask error so the change is not silently lost
- Collapse the early-return check into the same Lock to avoid triple locking
- Add tests covering the in-flight Fire race and the error-restore path

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
Bo-Yi Wu
2026-04-12 11:24:22 +08:00
parent 7031b3507d
commit 2931fe9e48
2 changed files with 116 additions and 13 deletions

View File

@@ -464,18 +464,16 @@ func (r *Reporter) ReportState(reportResult bool) error {
return true return true
}) })
r.stateMu.RLock() // Consume stateChanged atomically with the snapshot; restored on error
changed := r.stateChanged // below so a concurrent Fire() during UpdateTask isn't silently lost.
r.stateMu.RUnlock() r.stateMu.Lock()
if !reportResult && !r.stateChanged && len(outputs) == 0 {
// Early return avoids the expensive proto.Clone on the common no-op path. r.stateMu.Unlock()
if !reportResult && !changed && len(outputs) == 0 {
return nil return nil
} }
r.stateMu.RLock()
state := proto.Clone(r.state).(*runnerv1.TaskState) state := proto.Clone(r.state).(*runnerv1.TaskState)
r.stateMu.RUnlock() r.stateChanged = false
r.stateMu.Unlock()
if !reportResult { if !reportResult {
state.Result = runnerv1.Result_RESULT_UNSPECIFIED state.Result = runnerv1.Result_RESULT_UNSPECIFIED
@@ -486,13 +484,12 @@ func (r *Reporter) ReportState(reportResult bool) error {
Outputs: outputs, Outputs: outputs,
})) }))
if err != nil { if err != nil {
r.stateMu.Lock()
r.stateChanged = true
r.stateMu.Unlock()
return err return err
} }
r.stateMu.Lock()
r.stateChanged = false
r.stateMu.Unlock()
for _, k := range resp.Msg.SentOutputs { for _, k := range resp.Msg.SentOutputs {
r.outputs.Store(k, struct{}{}) r.outputs.Store(k, struct{}{})
} }

View File

@@ -442,6 +442,112 @@ func TestReporter_BatchSizeFlush(t *testing.T) {
"batch size threshold should have triggered immediate flush") "batch size threshold should have triggered immediate flush")
} }
// TestReporter_StateChangedNotLostDuringReport asserts that a Fire() arriving
// while UpdateTask is in flight leaves stateChanged set, so the change is
// picked up by the next report instead of being skipped by the early return.
//
// The test orchestrates the interleaving with two channels:
//   - inFlight is closed by the mock when the first UpdateTask call starts;
//   - release is closed by the test to let that first call return.
//
// Between those two events the test calls Fire(), then checks that the flag
// is still true after ReportState returns and that a second ReportState
// issues a second UpdateTask.
func TestReporter_StateChangedNotLostDuringReport(t *testing.T) {
	var updateTaskCalls atomic.Int64
	inFlight := make(chan struct{})
	release := make(chan struct{})

	client := mocks.NewClient(t)
	client.On("UpdateTask", mock.Anything, mock.Anything).Return(
		func(_ context.Context, _ *connect_go.Request[runnerv1.UpdateTaskRequest]) (*connect_go.Response[runnerv1.UpdateTaskResponse], error) {
			n := updateTaskCalls.Add(1)
			if n == 1 {
				// Signal that the first UpdateTask is in flight, then block until released.
				close(inFlight)
				<-release
			}
			return connect_go.NewResponse(&runnerv1.UpdateTaskResponse{}), nil
		},
	)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	taskCtx, err := structpb.NewStruct(map[string]any{})
	require.NoError(t, err)
	cfg, err := config.LoadDefault("")
	require.NoError(t, err)

	reporter := NewReporter(ctx, cancel, client, &runnerv1.Task{Context: taskCtx}, cfg)
	reporter.ResetSteps(2)

	// Mark stateChanged=true so the first ReportState proceeds to UpdateTask.
	reporter.stateMu.Lock()
	reporter.stateChanged = true
	reporter.stateMu.Unlock()

	// Kick off the first ReportState in a goroutine — it will block in UpdateTask.
	// done is buffered so the goroutine can always deliver its result and exit.
	done := make(chan error, 1)
	go func() {
		done <- reporter.ReportState(false)
	}()

	// Wait until UpdateTask is in flight (snapshot taken, flag consumed).
	// If ReportState returns without ever reaching UpdateTask, fail right away
	// instead of blocking until the global test timeout.
	select {
	case <-inFlight:
	case err := <-done:
		t.Fatalf("ReportState returned before UpdateTask was in flight: %v", err)
	}

	// Concurrent Fire() modifies state — must re-flip stateChanged so the
	// change is not lost when the in-flight ReportState finishes.
	require.NoError(t, reporter.Fire(&log.Entry{
		Message: "step starts",
		Data:    log.Fields{"stage": "Main", "stepNumber": 1, "raw_output": true},
	}))

	// Release the in-flight UpdateTask and wait for it to return.
	close(release)
	require.NoError(t, <-done)

	// stateChanged must still be true so the next ReportState picks up the
	// concurrent Fire()'s change instead of skipping via the early-return path.
	reporter.stateMu.RLock()
	changed := reporter.stateChanged
	reporter.stateMu.RUnlock()
	assert.True(t, changed, "stateChanged must remain true after a concurrent Fire() during in-flight ReportState")

	// And the next ReportState must actually send a second UpdateTask.
	require.NoError(t, reporter.ReportState(false))
	assert.Equal(t, int64(2), updateTaskCalls.Load(), "concurrent Fire() change must trigger a second UpdateTask, not be silently lost")
}
// TestReporter_StateChangedRestoredOnError verifies that when UpdateTask fails,
// the dirty flag is restored so the snapshotted change isn't silently lost.
//
// The mock fails the first UpdateTask call and succeeds on every later one.
// The test checks that stateChanged reads true after the failed ReportState
// and that the following ReportState issues a second UpdateTask.
func TestReporter_StateChangedRestoredOnError(t *testing.T) {
	var updateTaskCalls atomic.Int64

	client := mocks.NewClient(t)
	client.On("UpdateTask", mock.Anything, mock.Anything).Return(
		func(_ context.Context, _ *connect_go.Request[runnerv1.UpdateTaskRequest]) (*connect_go.Response[runnerv1.UpdateTaskResponse], error) {
			n := updateTaskCalls.Add(1)
			if n == 1 {
				return nil, errors.New("transient network error")
			}
			return connect_go.NewResponse(&runnerv1.UpdateTaskResponse{}), nil
		},
	)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	taskCtx, err := structpb.NewStruct(map[string]any{})
	require.NoError(t, err)
	cfg, err := config.LoadDefault("")
	require.NoError(t, err)

	reporter := NewReporter(ctx, cancel, client, &runnerv1.Task{Context: taskCtx}, cfg)
	reporter.ResetSteps(1)

	// Start dirty so the first ReportState does not take the early-return path.
	reporter.stateMu.Lock()
	reporter.stateChanged = true
	reporter.stateMu.Unlock()

	// First ReportState fails — flag must be restored to true.
	require.Error(t, reporter.ReportState(false))

	reporter.stateMu.RLock()
	changed := reporter.stateChanged
	reporter.stateMu.RUnlock()
	assert.True(t, changed, "stateChanged must be restored to true after UpdateTask error so the change is retried")

	// The next ReportState should still issue a request because the flag was restored.
	require.NoError(t, reporter.ReportState(false))
	assert.Equal(t, int64(2), updateTaskCalls.Load())
}
// TestReporter_StateNotifyFlush verifies that step transitions trigger // TestReporter_StateNotifyFlush verifies that step transitions trigger
// an immediate state flush via the stateNotify channel. // an immediate state flush via the stateNotify channel.
func TestReporter_StateNotifyFlush(t *testing.T) { func TestReporter_StateNotifyFlush(t *testing.T) {