@@ -206,6 +206,62 @@ func TestMarshalMultiTag(t *testing.T) {
206
206
}
207
207
}
208
208
209
+ func TestDecode (t * testing.T ) {
210
+ rnd := rand .New (rand .NewSource (0 ))
211
+ kexInit := new (kexInitMsg ).Generate (rnd , 10 ).Interface ()
212
+ kexDHInit := new (kexDHInitMsg ).Generate (rnd , 10 ).Interface ()
213
+ kexDHReply := new (kexDHReplyMsg )
214
+ kexDHReply .Y = randomInt (rnd )
215
+ // Note: userAuthSuccessMsg can't be tested directly since it
216
+ // doesn't have a field for sshtype. So it's tested separately
217
+ // at the end.
218
+ decodeMessageTypes := []interface {}{
219
+ new (disconnectMsg ),
220
+ new (serviceRequestMsg ),
221
+ new (serviceAcceptMsg ),
222
+ new (extInfoMsg ),
223
+ kexInit ,
224
+ kexDHInit ,
225
+ kexDHReply ,
226
+ new (userAuthRequestMsg ),
227
+ new (userAuthFailureMsg ),
228
+ new (userAuthBannerMsg ),
229
+ new (userAuthPubKeyOkMsg ),
230
+ new (globalRequestMsg ),
231
+ new (globalRequestSuccessMsg ),
232
+ new (globalRequestFailureMsg ),
233
+ new (channelOpenMsg ),
234
+ new (channelDataMsg ),
235
+ new (channelOpenConfirmMsg ),
236
+ new (channelOpenFailureMsg ),
237
+ new (windowAdjustMsg ),
238
+ new (channelEOFMsg ),
239
+ new (channelCloseMsg ),
240
+ new (channelRequestMsg ),
241
+ new (channelRequestSuccessMsg ),
242
+ new (channelRequestFailureMsg ),
243
+ new (userAuthGSSAPIToken ),
244
+ new (userAuthGSSAPIMIC ),
245
+ new (userAuthGSSAPIErrTok ),
246
+ new (userAuthGSSAPIError ),
247
+ }
248
+ for _ , msg := range decodeMessageTypes {
249
+ decoded , err := decode (Marshal (msg ))
250
+ if err != nil {
251
+ t .Errorf ("error decoding %T" , msg )
252
+ } else if reflect .TypeOf (msg ) != reflect .TypeOf (decoded ) {
253
+ t .Errorf ("error decoding %T, unexpected %T" , msg , decoded )
254
+ }
255
+ }
256
+
257
+ userAuthSuccess , err := decode ([]byte {msgUserAuthSuccess })
258
+ if err != nil {
259
+ t .Errorf ("error decoding userAuthSuccessMsg" )
260
+ } else if reflect .TypeOf (userAuthSuccess ) != reflect .TypeOf ((* userAuthSuccessMsg )(nil )) {
261
+ t .Errorf ("error decoding userAuthSuccessMsg, unexpected %T" , userAuthSuccess )
262
+ }
263
+ }
264
+
209
265
func randomBytes (out []byte , rand * rand.Rand ) {
210
266
for i := 0 ; i < len (out ); i ++ {
211
267
out [i ] = byte (rand .Int31 ())
0 commit comments