/*
 * Copyright (c) 2013 IBM Corp.
 *
 * All rights reserved. This program and the accompanying materials
 * are made available under the terms of the Eclipse Public License v1.0
 * which accompanies this distribution, and is available at
 * http://www.eclipse.org/legal/epl-v10.html
 *
 * Contributors:
 *    Seth Hoenig
 *    Allan Stockdill-Mander
 *    Mike Robertson
 */

package mqtt

import (
	"crypto/tls"
	"errors"
	"fmt"
	"net"
	"net/url"
	"reflect"
	"time"

	"github.com/eclipse/paho.mqtt.golang/packets"
	"golang.org/x/net/websocket"
)

// signalError offers err to c without ever blocking the caller: the error is
// delivered when c can accept it immediately and is discarded otherwise.
func signalError(c chan<- error, err error) {
	select {
	case c <- err:
		return
	default:
		// Nobody is ready to take the error; drop it instead of waiting.
		return
	}
}

// openConnection dials the broker described by uri and returns the resulting
// connection. The transport is selected by uri.Scheme:
//
//   - "ws": websocket with the "mqtt" subprotocol
//   - "wss": websocket with the "mqtt" subprotocol and tlsc as TLS configuration
//   - "tcp": plain TCP
//   - "ssl", "tls", "tcps": TLS over TCP using tlsc
//
// timeout is applied to the TCP and TLS dials only; the websocket dials in
// this function are issued without it. Websocket connections are switched to
// binary frames before being returned. Any other scheme yields an
// "Unknown protocol" error, and on every error the returned net.Conn is nil.
func openConnection(uri *url.URL, tlsc *tls.Config, timeout time.Duration) (net.Conn, error) {
	switch uri.Scheme {
	case "ws":
		conn, err := websocket.Dial(uri.String(), "mqtt", fmt.Sprintf("http://%s", uri.Host))
		if err != nil {
			return nil, err
		}
		conn.PayloadType = websocket.BinaryFrame
		return conn, nil
	case "wss":
		// NewConfig can fail and then returns a nil config, so the error must
		// be checked before the config is touched.
		config, err := websocket.NewConfig(uri.String(), fmt.Sprintf("https://%s", uri.Host))
		if err != nil {
			return nil, err
		}
		config.Protocol = []string{"mqtt"}
		config.TlsConfig = tlsc
		conn, err := websocket.DialConfig(config)
		if err != nil {
			return nil, err
		}
		conn.PayloadType = websocket.BinaryFrame
		return conn, nil
	case "tcp":
		conn, err := net.DialTimeout("tcp", uri.Host, timeout)
		if err != nil {
			return nil, err
		}
		return conn, nil
	case "ssl", "tls", "tcps":
		conn, err := tls.DialWithDialer(&net.Dialer{Timeout: timeout}, "tcp", uri.Host, tlsc)
		if err != nil {
			// Return an untyped nil so callers never see a non-nil net.Conn
			// wrapping a nil *tls.Conn.
			return nil, err
		}
		return conn, nil
	}
	return nil, errors.New("Unknown protocol")
}

// incoming reads control packets off the wire and forwards each one to
// c.ibound until a read fails or c.stop becomes ready.
//
// When KeepAlive is enabled, every forwarded packet is additionally signalled
// on c.packetResp. A read error is reported through c.errors unless c.stop is
// already ready, in which case the error is swallowed. c.workers.Done is
// called when the function returns.
func incoming(c *client) {
	defer c.workers.Done()

	DEBUG.Println(NET, "incoming started")

	for {
		cp, err := packets.ReadPacket(c.conn)
		if err != nil {
			// We received an error on read.
			// If disconnect is in progress, swallow the error and return.
			select {
			case <-c.stop:
				DEBUG.Println(NET, "incoming stopped")
			default:
				// Not trying to disconnect, send the error to the errors channel.
				ERROR.Println(NET, "incoming stopped with error", err)
				signalError(c.errors, err)
			}
			return
		}
		DEBUG.Println(NET, "Received Message")

		select {
		case c.ibound <- cp:
			// Notify keepalive logic that we recently received a packet.
			if c.options.KeepAlive != 0 {
				// Guard the notification with c.stop as well: a bare send
				// could block forever if nothing receives from c.packetResp
				// any more during shutdown.
				select {
				case c.packetResp <- struct{}{}:
				case <-c.stop:
					DEBUG.Println(NET, "incoming stopped")
					return
				}
			}
		case <-c.stop:
			// This avoids a deadlock should a message arrive while shutting down.
			// In that case the "reader" of c.ibound might already be gone.
			// Return instead of using a bare break, which would only leave
			// the select and keep this goroutine reading.
			WARN.Println(NET, "incoming dropped a received message during shutdown")
			DEBUG.Println(NET, "incoming stopped")
			return
		}
	}
}

// outgoing takes packets from c.obound (publishes) and c.oboundP (priority
// control packets) and writes them to the connection. It returns when c.stop
// becomes ready, when a write fails (the error is offered to c.errors), or
// right after a disconnect packet has been written.
func outgoing(c *client) {
	defer c.workers.Done()
	DEBUG.Println(NET, "outgoing started")

	for {
		DEBUG.Println(NET, "outgoing waiting for an outbound message")
		select {
		case <-c.stop:
			DEBUG.Println(NET, "outgoing stopped")
			return

		case out := <-c.obound:
			publish := out.p.(*packets.PublishPacket)

			// Bound the write when a write timeout has been configured.
			if c.options.WriteTimeout > 0 {
				deadline := time.Now().Add(c.options.WriteTimeout)
				c.conn.SetWriteDeadline(deadline)
			}

			writeErr := publish.Write(c.conn)
			if writeErr != nil {
				ERROR.Println(NET, "outgoing stopped with error", writeErr)
				signalError(c.errors, writeErr)
				return
			}

			// The write succeeded: clear the deadline so it cannot fire while
			// the connection is merely idle.
			if c.options.WriteTimeout > 0 {
				c.conn.SetWriteDeadline(time.Time{})
			}

			// QoS 0 publishes are complete as soon as they are on the wire.
			if publish.Qos == 0 {
				out.t.flowComplete()
			}
			DEBUG.Println(NET, "obound wrote msg, id:", publish.MessageID)

		case ctl := <-c.oboundP:
			// Subscribe and unsubscribe packets get their message id assigned
			// just before they are written.
			switch p := ctl.p.(type) {
			case *packets.SubscribePacket:
				p.MessageID = c.getID(ctl.t)
			case *packets.UnsubscribePacket:
				p.MessageID = c.getID(ctl.t)
			}
			DEBUG.Println(NET, "obound priority msg to write, type", reflect.TypeOf(ctl.p))

			writeErr := ctl.p.Write(c.conn)
			if writeErr != nil {
				ERROR.Println(NET, "outgoing stopped with error", writeErr)
				signalError(c.errors, writeErr)
				return
			}

			// A disconnect is the last packet this goroutine ever writes.
			if _, isDisconnect := ctl.p.(*packets.DisconnectPacket); isDisconnect {
				ctl.t.(*DisconnectToken).flowComplete()
				DEBUG.Println(NET, "outbound wrote disconnect, stopping")
				return
			}
		}

		// Reset ping timer after sending control packet.
		if c.options.KeepAlive == 0 {
			continue
		}
		select {
		case c.keepaliveReset <- struct{}{}:
		default:
			DEBUG.Println(NET, "couldn't send keepalive signal in outbound as channel full")
		}
	}
}

// alllogic consumes control packets from c.ibound and handles each of them
// until c.stop becomes ready:
//
//   - every packet is first passed to persistInbound
//   - PINGRESP is signalled on c.pingResp
//   - SUBACK, UNSUBACK, PUBACK and PUBCOMP complete the token returned by
//     c.getToken for the packet's message id and then call c.freeID for it;
//     SUBACK also records the returned codes per subscription in the token
//   - PUBLISH is forwarded to c.incomingPubChan and, for QoS 1 and 2,
//     answered with a PUBACK or PUBREC queued on c.oboundP
//   - PUBREC and PUBREL are answered with PUBREL and PUBCOMP respectively,
//     giving up when c.oboundP does not accept the reply within one second
//
// Message ids and return codes arrive from the network, so an acknowledgement
// without a token of the expected kind, or with more return codes than
// subscriptions, is logged and ignored instead of panicking this goroutine.
func alllogic(c *client) {
	defer c.workers.Done()
	DEBUG.Println(NET, "logic started")

	for {
		DEBUG.Println(NET, "logic waiting for msg on ibound")

		select {
		case msg := <-c.ibound:
			DEBUG.Println(NET, "logic got msg on ibound")
			persistInbound(c.persist, msg)
			switch m := msg.(type) {
			case *packets.PingrespPacket:
				DEBUG.Println(NET, "received pingresp")
				c.pingResp <- struct{}{}
			case *packets.SubackPacket:
				DEBUG.Println(NET, "received suback, id:", m.MessageID)
				token, ok := c.getToken(m.MessageID).(*SubscribeToken)
				if !ok {
					WARN.Println(NET, "ignoring suback without a matching subscribe token, id:", m.MessageID)
					continue
				}
				DEBUG.Println(NET, "granted qoss", m.ReturnCodes)
				for i, qos := range m.ReturnCodes {
					if i >= len(token.subs) {
						// More return codes than subscriptions in the token;
						// the surplus cannot be attributed to anything.
						WARN.Println(NET, "suback carries more return codes than subscriptions, id:", m.MessageID)
						break
					}
					token.subResult[token.subs[i]] = qos
				}
				token.flowComplete()
				c.freeID(m.MessageID)
			case *packets.UnsubackPacket:
				DEBUG.Println(NET, "received unsuback, id:", m.MessageID)
				token, ok := c.getToken(m.MessageID).(*UnsubscribeToken)
				if !ok {
					WARN.Println(NET, "ignoring unsuback without a matching unsubscribe token, id:", m.MessageID)
					continue
				}
				token.flowComplete()
				c.freeID(m.MessageID)
			case *packets.PublishPacket:
				DEBUG.Println(NET, "received publish, msgId:", m.MessageID)
				DEBUG.Println(NET, "putting msg on onPubChan")
				switch m.Qos {
				case 2:
					c.incomingPubChan <- m
					DEBUG.Println(NET, "done putting msg on incomingPubChan")
					pr := packets.NewControlPacket(packets.Pubrec).(*packets.PubrecPacket)
					pr.MessageID = m.MessageID
					DEBUG.Println(NET, "putting pubrec msg on obound")
					c.oboundP <- &PacketAndToken{p: pr, t: nil}
					DEBUG.Println(NET, "done putting pubrec msg on obound")
				case 1:
					c.incomingPubChan <- m
					DEBUG.Println(NET, "done putting msg on incomingPubChan")
					pa := packets.NewControlPacket(packets.Puback).(*packets.PubackPacket)
					pa.MessageID = m.MessageID
					DEBUG.Println(NET, "putting puback msg on obound")
					c.oboundP <- &PacketAndToken{p: pa, t: nil}
					DEBUG.Println(NET, "done putting puback msg on obound")
				case 0:
					c.incomingPubChan <- m
					DEBUG.Println(NET, "done putting msg on incomingPubChan")
				}
			case *packets.PubackPacket:
				DEBUG.Println(NET, "received puback, id:", m.MessageID)
				token := c.getToken(m.MessageID)
				if token == nil {
					WARN.Println(NET, "ignoring puback without a matching token, id:", m.MessageID)
					continue
				}
				token.flowComplete()
				c.freeID(m.MessageID)
			case *packets.PubrecPacket:
				DEBUG.Println(NET, "received pubrec, id:", m.MessageID)
				prel := packets.NewControlPacket(packets.Pubrel).(*packets.PubrelPacket)
				prel.MessageID = m.MessageID
				select {
				case c.oboundP <- &PacketAndToken{p: prel, t: nil}:
				case <-time.After(time.Second):
					// Best effort only, but leave a trace that the reply was dropped.
					WARN.Println(NET, "timed out queuing pubrel, id:", m.MessageID)
				}
			case *packets.PubrelPacket:
				DEBUG.Println(NET, "received pubrel, id:", m.MessageID)
				pc := packets.NewControlPacket(packets.Pubcomp).(*packets.PubcompPacket)
				pc.MessageID = m.MessageID
				select {
				case c.oboundP <- &PacketAndToken{p: pc, t: nil}:
				case <-time.After(time.Second):
					// Best effort only, but leave a trace that the reply was dropped.
					WARN.Println(NET, "timed out queuing pubcomp, id:", m.MessageID)
				}
			case *packets.PubcompPacket:
				DEBUG.Println(NET, "received pubcomp, id:", m.MessageID)
				token := c.getToken(m.MessageID)
				if token == nil {
					WARN.Println(NET, "ignoring pubcomp without a matching token, id:", m.MessageID)
					continue
				}
				token.flowComplete()
				c.freeID(m.MessageID)
			}
		case <-c.stop:
			WARN.Println(NET, "logic stopped")
			return
		}
	}
}

// errorWatch waits for exactly one event and then returns: either c.stop
// becomes ready, or an error arrives on c.errors, in which case the error is
// passed to c.internalConnLost.
func errorWatch(c *client) {
	select {
	case cause := <-c.errors:
		ERROR.Println(NET, "error triggered, stopping")
		c.internalConnLost(cause)
	case <-c.stop:
		WARN.Println(NET, "errorWatch stopped")
	}
}