Skip to content

Commit 53a1b51

Browse files
committed
Made writes to output stream synchronous
This ensures that the discovery outputs the `quit` message before quitting.
1 parent 410dd27 commit 53a1b51

File tree

4 files changed

+765
-53
lines changed

4 files changed

+765
-53
lines changed

discovery.go

+49-47
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ import (
3434
"regexp"
3535
"strconv"
3636
"strings"
37+
"sync"
3738

3839
"github.com/arduino/go-properties-orderedmap"
3940
)
@@ -85,23 +86,23 @@ type ErrorCallback func(err string)
8586
// it must be created using the NewServer function.
8687
type Server struct {
8788
impl Discovery
88-
outputChan chan *message
8989
userAgent string
9090
reqProtocolVersion int
9191
initialized bool
9292
started bool
9393
syncStarted bool
9494
cachedPorts map[string]*Port
9595
cachedErr string
96+
output io.Writer
97+
outputMutex sync.Mutex
9698
}
9799

98100
// NewServer creates a new discovery server backed by the
99101
// provided pluggable discovery implementation. To start the server
100102
// use the Run method.
101103
func NewServer(impl Discovery) *Server {
102104
return &Server{
103-
impl: impl,
104-
outputChan: make(chan *message),
105+
impl: impl,
105106
}
106107
}
107108

@@ -111,21 +112,20 @@ func NewServer(impl Discovery) *Server {
111112
// the input stream is closed. In case of IO error the error is
112113
// returned.
113114
func (d *Server) Run(in io.Reader, out io.Writer) error {
114-
go d.outputProcessor(out)
115-
defer close(d.outputChan)
115+
d.output = out
116116
reader := bufio.NewReader(in)
117117
for {
118118
fullCmd, err := reader.ReadString('\n')
119119
if err != nil {
120-
d.outputChan <- messageError("command_error", err.Error())
120+
d.send(messageError("command_error", err.Error()))
121121
return err
122122
}
123123
fullCmd = strings.TrimSpace(fullCmd)
124124
split := strings.Split(fullCmd, " ")
125125
cmd := strings.ToUpper(split[0])
126126

127127
if !d.initialized && cmd != "HELLO" && cmd != "QUIT" {
128-
d.outputChan <- messageError("command_error", fmt.Sprintf("First command must be HELLO, but got '%s'", cmd))
128+
d.send(messageError("command_error", fmt.Sprintf("First command must be HELLO, but got '%s'", cmd)))
129129
continue
130130
}
131131

@@ -142,61 +142,61 @@ func (d *Server) Run(in io.Reader, out io.Writer) error {
142142
d.stop()
143143
case "QUIT":
144144
d.impl.Quit()
145-
d.outputChan <- messageOk("quit")
145+
d.send(messageOk("quit"))
146146
return nil
147147
default:
148-
d.outputChan <- messageError("command_error", fmt.Sprintf("Command %s not supported", cmd))
148+
d.send(messageError("command_error", fmt.Sprintf("Command %s not supported", cmd)))
149149
}
150150
}
151151
}
152152

153153
func (d *Server) hello(cmd string) {
154154
if d.initialized {
155-
d.outputChan <- messageError("hello", "HELLO already called")
155+
d.send(messageError("hello", "HELLO already called"))
156156
return
157157
}
158158
re := regexp.MustCompile(`^(\d+) "([^"]+)"$`)
159159
matches := re.FindStringSubmatch(cmd)
160160
if len(matches) != 3 {
161-
d.outputChan <- messageError("hello", "Invalid HELLO command")
161+
d.send(messageError("hello", "Invalid HELLO command"))
162162
return
163163
}
164164
d.userAgent = matches[2]
165165
v, err := strconv.ParseInt(matches[1], 10, 64)
166166
if err != nil {
167-
d.outputChan <- messageError("hello", "Invalid protocol version: "+matches[2])
167+
d.send(messageError("hello", "Invalid protocol version: "+matches[2]))
168168
return
169169
}
170170
d.reqProtocolVersion = int(v)
171171
if err := d.impl.Hello(d.userAgent, 1); err != nil {
172-
d.outputChan <- messageError("hello", err.Error())
172+
d.send(messageError("hello", err.Error()))
173173
return
174174
}
175-
d.outputChan <- &message{
175+
d.send(&message{
176176
EventType: "hello",
177177
ProtocolVersion: 1, // Protocol version 1 is the only supported for now...
178178
Message: "OK",
179-
}
179+
})
180180
d.initialized = true
181181
}
182182

183183
func (d *Server) start() {
184184
if d.started {
185-
d.outputChan <- messageError("start", "Discovery already STARTed")
185+
d.send(messageError("start", "Discovery already STARTed"))
186186
return
187187
}
188188
if d.syncStarted {
189-
d.outputChan <- messageError("start", "Discovery already START_SYNCed, cannot START")
189+
d.send(messageError("start", "Discovery already START_SYNCed, cannot START"))
190190
return
191191
}
192192
d.cachedPorts = map[string]*Port{}
193193
d.cachedErr = ""
194194
if err := d.impl.StartSync(d.eventCallback, d.errorCallback); err != nil {
195-
d.outputChan <- messageError("start", "Cannot START: "+err.Error())
195+
d.send(messageError("start", "Cannot START: "+err.Error()))
196196
return
197197
}
198198
d.started = true
199-
d.outputChan <- messageOk("start")
199+
d.send(messageOk("start"))
200200
}
201201

202202
func (d *Server) eventCallback(event string, port *Port) {
@@ -215,82 +215,84 @@ func (d *Server) errorCallback(msg string) {
215215

216216
func (d *Server) list() {
217217
if !d.started {
218-
d.outputChan <- messageError("list", "Discovery not STARTed")
218+
d.send(messageError("list", "Discovery not STARTed"))
219219
return
220220
}
221221
if d.syncStarted {
222-
d.outputChan <- messageError("list", "discovery already START_SYNCed, LIST not allowed")
222+
d.send(messageError("list", "discovery already START_SYNCed, LIST not allowed"))
223223
return
224224
}
225225
if d.cachedErr != "" {
226-
d.outputChan <- messageError("list", d.cachedErr)
226+
d.send(messageError("list", d.cachedErr))
227227
return
228228
}
229229
ports := []*Port{}
230230
for _, port := range d.cachedPorts {
231231
ports = append(ports, port)
232232
}
233-
d.outputChan <- &message{
233+
d.send(&message{
234234
EventType: "list",
235235
Ports: &ports,
236-
}
236+
})
237237
}
238238

239239
func (d *Server) startSync() {
240240
if d.syncStarted {
241-
d.outputChan <- messageError("start_sync", "Discovery already START_SYNCed")
241+
d.send(messageError("start_sync", "Discovery already START_SYNCed"))
242242
return
243243
}
244244
if d.started {
245-
d.outputChan <- messageError("start_sync", "Discovery already STARTed, cannot START_SYNC")
245+
d.send(messageError("start_sync", "Discovery already STARTed, cannot START_SYNC"))
246246
return
247247
}
248248
if err := d.impl.StartSync(d.syncEvent, d.errorEvent); err != nil {
249-
d.outputChan <- messageError("start_sync", "Cannot START_SYNC: "+err.Error())
249+
d.send(messageError("start_sync", "Cannot START_SYNC: "+err.Error()))
250250
return
251251
}
252252
d.syncStarted = true
253-
d.outputChan <- messageOk("start_sync")
253+
d.send(messageOk("start_sync"))
254254
}
255255

256256
func (d *Server) stop() {
257257
if !d.syncStarted && !d.started {
258-
d.outputChan <- messageError("stop", "Discovery already STOPped")
258+
d.send(messageError("stop", "Discovery already STOPped"))
259259
return
260260
}
261261
if err := d.impl.Stop(); err != nil {
262-
d.outputChan <- messageError("stop", "Cannot STOP: "+err.Error())
262+
d.send(messageError("stop", "Cannot STOP: "+err.Error()))
263263
return
264264
}
265265
d.started = false
266266
if d.syncStarted {
267267
d.syncStarted = false
268268
}
269-
d.outputChan <- messageOk("stop")
269+
d.send(messageOk("stop"))
270270
}
271271

272272
func (d *Server) syncEvent(event string, port *Port) {
273-
d.outputChan <- &message{
273+
d.send(&message{
274274
EventType: event,
275275
Port: port,
276-
}
276+
})
277277
}
278278

279279
func (d *Server) errorEvent(msg string) {
280-
d.outputChan <- messageError("start_sync", msg)
280+
d.send(messageError("start_sync", msg))
281281
}
282282

283-
func (d *Server) outputProcessor(outWriter io.Writer) {
284-
// Start go routine to serialize messages printing
285-
go func() {
286-
for msg := range d.outputChan {
287-
data, err := json.MarshalIndent(msg, "", " ")
288-
if err != nil {
289-
// We are certain that this will be marshalled correctly
290-
// so we don't handle the error
291-
data, _ = json.MarshalIndent(messageError("command_error", err.Error()), "", " ")
292-
}
293-
fmt.Fprintln(outWriter, string(data))
294-
}
295-
}()
283+
func (d *Server) send(msg *message) {
284+
data, err := json.MarshalIndent(msg, "", " ")
285+
if err != nil {
286+
// We are certain that this will be marshalled correctly
287+
// so we don't handle the error
288+
data, _ = json.MarshalIndent(messageError("command_error", err.Error()), "", " ")
289+
}
290+
data = append(data, '\n')
291+
292+
d.outputMutex.Lock()
293+
defer d.outputMutex.Unlock()
294+
n, err := d.output.Write(data)
295+
if n != len(data) || err != nil {
296+
panic("ERROR")
297+
}
296298
}

discovery_test.go

+36
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
package discovery
2+
3+
import (
4+
"testing"
5+
6+
"github.com/arduino/arduino-cli/executils"
7+
"github.com/stretchr/testify/require"
8+
)
9+
10+
func TestDisc(t *testing.T) {
11+
builder, err := executils.NewProcess(nil, "go", "build")
12+
require.NoError(t, err)
13+
builder.SetDir("dummy-discovery")
14+
require.NoError(t, builder.Run())
15+
16+
discovery, err := executils.NewProcess(nil, "./dummy-discovery")
17+
require.NoError(t, err)
18+
discovery.SetDir("dummy-discovery")
19+
20+
stdout, err := discovery.StdoutPipe()
21+
require.NoError(t, err)
22+
stdin, err := discovery.StdinPipe()
23+
require.NoError(t, err)
24+
25+
require.NoError(t, discovery.Start())
26+
27+
n, err := stdin.Write([]byte("quit\n"))
28+
require.NoError(t, err)
29+
require.Greater(t, n, 0)
30+
output := [1024]byte{}
31+
n, err = stdout.Read(output[:])
32+
require.Greater(t, n, 0)
33+
require.NoError(t, err)
34+
35+
require.Equal(t, "{\n \"eventType\": \"quit\",\n \"message\": \"OK\"\n}\n", string(output[:n]))
36+
}

go.mod

+5-1
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,8 @@ module github.com/arduino/pluggable-discovery-protocol-handler/v2
22

33
go 1.16
44

5-
require github.com/arduino/go-properties-orderedmap v1.4.0
5+
require (
6+
github.com/arduino/arduino-cli v0.0.0-20230206132931-4c0aaa876d1b
7+
github.com/arduino/go-properties-orderedmap v1.7.1
8+
github.com/stretchr/testify v1.8.0
9+
)

0 commit comments

Comments
 (0)