/*
 * Minio Cloud Storage, (C) 2019 Minio, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package sql

import (
	"fmt"
	"strings"

	"github.com/xwb1989/sqlparser"
)

func getColumnName(colName *sqlparser.ColName) string {
	columnName := colName.Qualifier.Name.String()
	if qualifier := colName.Qualifier.Qualifier.String(); qualifier != "" {
		columnName = qualifier + "." + columnName
	}

	if columnName == "" {
		columnName = colName.Name.String()
	} else {
		columnName = columnName + "." + colName.Name.String()
	}

	return columnName
}

func newLiteralExpr(parserExpr sqlparser.Expr, tableAlias string) (Expr, error) {
	switch parserExpr.(type) {
	case *sqlparser.NullVal:
		return newValueExpr(NewNull()), nil
	case sqlparser.BoolVal:
		return newValueExpr(NewBool((bool(parserExpr.(sqlparser.BoolVal))))), nil
	case *sqlparser.SQLVal:
		sqlValue := parserExpr.(*sqlparser.SQLVal)
		value, err := NewValue(sqlValue)
		if err != nil {
			return nil, err
		}
		return newValueExpr(value), nil
	case *sqlparser.ColName:
		columnName := getColumnName(parserExpr.(*sqlparser.ColName))
		if tableAlias != "" {
			if !strings.HasPrefix(columnName, tableAlias+".") {
				err := fmt.Errorf("column name %v does not start with table alias %v", columnName, tableAlias)
				return nil, errInvalidKeyPath(err)
			}
			columnName = strings.TrimPrefix(columnName, tableAlias+".")
		}

		return newColumnExpr(columnName), nil
	case sqlparser.ValTuple:
		var valueType Type
		var values []*Value
		for i, valExpr := range parserExpr.(sqlparser.ValTuple) {
			sqlVal, ok := valExpr.(*sqlparser.SQLVal)
			if !ok {
				return nil, errParseInvalidTypeParam(fmt.Errorf("value %v in Tuple should be primitive value", i+1))
			}

			val, err := NewValue(sqlVal)
			if err != nil {
				return nil, err
			}

			if i == 0 {
				valueType = val.Type()
			} else if valueType != val.Type() {
				return nil, errParseInvalidTypeParam(fmt.Errorf("mixed value type is not allowed in Tuple"))
			}

			values = append(values, val)
		}

		return newValueExpr(NewArray(values)), nil
	}

	return nil, nil
}

func isExprToComparisonExpr(parserExpr *sqlparser.IsExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
	leftExpr, err := newExpr(parserExpr.Expr, tableAlias, isSelectExpr)
	if err != nil {
		return nil, err
	}

	f, err := newComparisonExpr(ComparisonOperator(parserExpr.Operator), leftExpr)
	if err != nil {
		return nil, err
	}

	if !leftExpr.Type().isBase() {
		return f, nil
	}

	value, err := f.Eval(nil)
	if err != nil {
		return nil, err
	}

	return newValueExpr(value), nil
}

func rangeCondToComparisonFunc(parserExpr *sqlparser.RangeCond, tableAlias string, isSelectExpr bool) (Expr, error) {
	leftExpr, err := newExpr(parserExpr.Left, tableAlias, isSelectExpr)
	if err != nil {
		return nil, err
	}

	fromExpr, err := newExpr(parserExpr.From, tableAlias, isSelectExpr)
	if err != nil {
		return nil, err
	}

	toExpr, err := newExpr(parserExpr.To, tableAlias, isSelectExpr)
	if err != nil {
		return nil, err
	}

	f, err := newComparisonExpr(ComparisonOperator(parserExpr.Operator), leftExpr, fromExpr, toExpr)
	if err != nil {
		return nil, err
	}

	if !leftExpr.Type().isBase() || !fromExpr.Type().isBase() || !toExpr.Type().isBase() {
		return f, nil
	}

	value, err := f.Eval(nil)
	if err != nil {
		return nil, err
	}

	return newValueExpr(value), nil
}

func toComparisonExpr(parserExpr *sqlparser.ComparisonExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
	leftExpr, err := newExpr(parserExpr.Left, tableAlias, isSelectExpr)
	if err != nil {
		return nil, err
	}

	rightExpr, err := newExpr(parserExpr.Right, tableAlias, isSelectExpr)
	if err != nil {
		return nil, err
	}

	f, err := newComparisonExpr(ComparisonOperator(parserExpr.Operator), leftExpr, rightExpr)
	if err != nil {
		return nil, err
	}

	if !leftExpr.Type().isBase() || !rightExpr.Type().isBase() {
		return f, nil
	}

	value, err := f.Eval(nil)
	if err != nil {
		return nil, err
	}

	return newValueExpr(value), nil
}

func toArithExpr(parserExpr *sqlparser.BinaryExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
	leftExpr, err := newExpr(parserExpr.Left, tableAlias, isSelectExpr)
	if err != nil {
		return nil, err
	}

	rightExpr, err := newExpr(parserExpr.Right, tableAlias, isSelectExpr)
	if err != nil {
		return nil, err
	}

	f, err := newArithExpr(ArithOperator(parserExpr.Operator), leftExpr, rightExpr)
	if err != nil {
		return nil, err
	}

	if !leftExpr.Type().isBase() || !rightExpr.Type().isBase() {
		return f, nil
	}

	value, err := f.Eval(nil)
	if err != nil {
		return nil, err
	}

	return newValueExpr(value), nil
}

func toFuncExpr(parserExpr *sqlparser.FuncExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
	funcName := strings.ToUpper(parserExpr.Name.String())
	if !isSelectExpr && isAggregateFuncName(funcName) {
		return nil, errUnsupportedSQLOperation(fmt.Errorf("%v() must be used in select expression", funcName))
	}
	funcs, aggregatedExprFound, err := newSelectExprs(parserExpr.Exprs, tableAlias)
	if err != nil {
		return nil, err
	}

	if aggregatedExprFound {
		return nil, errIncorrectSQLFunctionArgumentType(fmt.Errorf("%v(): aggregated expression must not be used as argument", funcName))
	}

	return newFuncExpr(FuncName(funcName), funcs...)
}

func toAndExpr(parserExpr *sqlparser.AndExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
	leftExpr, err := newExpr(parserExpr.Left, tableAlias, isSelectExpr)
	if err != nil {
		return nil, err
	}

	rightExpr, err := newExpr(parserExpr.Right, tableAlias, isSelectExpr)
	if err != nil {
		return nil, err
	}

	f, err := newAndExpr(leftExpr, rightExpr)
	if err != nil {
		return nil, err
	}

	if leftExpr.Type() != Bool || rightExpr.Type() != Bool {
		return f, nil
	}

	value, err := f.Eval(nil)
	if err != nil {
		return nil, err
	}

	return newValueExpr(value), nil
}

func toOrExpr(parserExpr *sqlparser.OrExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
	leftExpr, err := newExpr(parserExpr.Left, tableAlias, isSelectExpr)
	if err != nil {
		return nil, err
	}

	rightExpr, err := newExpr(parserExpr.Right, tableAlias, isSelectExpr)
	if err != nil {
		return nil, err
	}

	f, err := newOrExpr(leftExpr, rightExpr)
	if err != nil {
		return nil, err
	}

	if leftExpr.Type() != Bool || rightExpr.Type() != Bool {
		return f, nil
	}

	value, err := f.Eval(nil)
	if err != nil {
		return nil, err
	}

	return newValueExpr(value), nil
}

func toNotExpr(parserExpr *sqlparser.NotExpr, tableAlias string, isSelectExpr bool) (Expr, error) {
	rightExpr, err := newExpr(parserExpr.Expr, tableAlias, isSelectExpr)
	if err != nil {
		return nil, err
	}

	f, err := newNotExpr(rightExpr)
	if err != nil {
		return nil, err
	}

	if rightExpr.Type() != Bool {
		return f, nil
	}

	value, err := f.Eval(nil)
	if err != nil {
		return nil, err
	}

	return newValueExpr(value), nil
}

func newExpr(parserExpr sqlparser.Expr, tableAlias string, isSelectExpr bool) (Expr, error) {
	f, err := newLiteralExpr(parserExpr, tableAlias)
	if err != nil {
		return nil, err
	}

	if f != nil {
		return f, nil
	}

	switch parserExpr.(type) {
	case *sqlparser.ParenExpr:
		return newExpr(parserExpr.(*sqlparser.ParenExpr).Expr, tableAlias, isSelectExpr)
	case *sqlparser.IsExpr:
		return isExprToComparisonExpr(parserExpr.(*sqlparser.IsExpr), tableAlias, isSelectExpr)
	case *sqlparser.RangeCond:
		return rangeCondToComparisonFunc(parserExpr.(*sqlparser.RangeCond), tableAlias, isSelectExpr)
	case *sqlparser.ComparisonExpr:
		return toComparisonExpr(parserExpr.(*sqlparser.ComparisonExpr), tableAlias, isSelectExpr)
	case *sqlparser.BinaryExpr:
		return toArithExpr(parserExpr.(*sqlparser.BinaryExpr), tableAlias, isSelectExpr)
	case *sqlparser.FuncExpr:
		return toFuncExpr(parserExpr.(*sqlparser.FuncExpr), tableAlias, isSelectExpr)
	case *sqlparser.AndExpr:
		return toAndExpr(parserExpr.(*sqlparser.AndExpr), tableAlias, isSelectExpr)
	case *sqlparser.OrExpr:
		return toOrExpr(parserExpr.(*sqlparser.OrExpr), tableAlias, isSelectExpr)
	case *sqlparser.NotExpr:
		return toNotExpr(parserExpr.(*sqlparser.NotExpr), tableAlias, isSelectExpr)
	}

	return nil, errParseUnsupportedSyntax(fmt.Errorf("unknown expression type %T; %v", parserExpr, parserExpr))
}

func newSelectExprs(parserSelectExprs []sqlparser.SelectExpr, tableAlias string) ([]Expr, bool, error) {
	var funcs []Expr
	starExprFound := false
	aggregatedExprFound := false

	for _, selectExpr := range parserSelectExprs {
		switch selectExpr.(type) {
		case *sqlparser.AliasedExpr:
			if starExprFound {
				return nil, false, errParseAsteriskIsNotAloneInSelectList(nil)
			}

			aliasedExpr := selectExpr.(*sqlparser.AliasedExpr)
			f, err := newExpr(aliasedExpr.Expr, tableAlias, true)
			if err != nil {
				return nil, false, err
			}

			if f.Type() == aggregateFunction {
				if !aggregatedExprFound {
					aggregatedExprFound = true
					if len(funcs) > 0 {
						return nil, false, errParseUnsupportedSyntax(fmt.Errorf("expression must not mixed with aggregated expression"))
					}
				}
			} else if aggregatedExprFound {
				return nil, false, errParseUnsupportedSyntax(fmt.Errorf("expression must not mixed with aggregated expression"))
			}

			alias := aliasedExpr.As.String()
			if alias != "" {
				f = newAliasExpr(alias, f)
			}

			funcs = append(funcs, f)
		case *sqlparser.StarExpr:
			if starExprFound {
				err := fmt.Errorf("only single star expression allowed")
				return nil, false, errParseInvalidContextForWildcardInSelectList(err)
			}
			starExprFound = true
			funcs = append(funcs, newStarExpr())
		default:
			return nil, false, errParseUnsupportedSyntax(fmt.Errorf("unknown select expression %v", selectExpr))
		}
	}

	return funcs, aggregatedExprFound, nil
}

// Select - SQL Select statement.
type Select struct {
	tableName           string
	tableAlias          string
	selectExprs         []Expr
	aggregatedExprFound bool
	whereExpr           Expr
}

// TableAlias - returns table alias name.
func (statement *Select) TableAlias() string {
	return statement.tableAlias
}

// IsSelectAll - returns whether '*' is used in select expression or not.
func (statement *Select) IsSelectAll() bool {
	if len(statement.selectExprs) == 1 {
		_, ok := statement.selectExprs[0].(*starExpr)
		return ok
	}

	return false
}

// IsAggregated - returns whether aggregated functions are used in select expression or not.
func (statement *Select) IsAggregated() bool {
	return statement.aggregatedExprFound
}

// AggregateResult - returns aggregate result as record.
func (statement *Select) AggregateResult(output Record) error {
	if !statement.aggregatedExprFound {
		return nil
	}

	for i, expr := range statement.selectExprs {
		value, err := expr.AggregateValue()
		if err != nil {
			return err
		}
		if value == nil {
			return errInternalError(fmt.Errorf("%v returns <nil> for AggregateValue()", expr))
		}

		name := fmt.Sprintf("_%v", i+1)
		if _, ok := expr.(*aliasExpr); ok {
			name = expr.(*aliasExpr).alias
		}

		if err = output.Set(name, value); err != nil {
			return errInternalError(fmt.Errorf("error occurred to store value %v for %v; %v", value, name, err))
		}
	}

	return nil
}

// Eval - evaluates this Select expressions for given record.
func (statement *Select) Eval(input, output Record) (Record, error) {
	if statement.whereExpr != nil {
		value, err := statement.whereExpr.Eval(input)
		if err != nil {
			return nil, err
		}

		if value == nil || value.valueType != Bool {
			err = fmt.Errorf("WHERE expression %v returns invalid bool value %v", statement.whereExpr, value)
			return nil, errInternalError(err)
		}

		if !value.BoolValue() {
			return nil, nil
		}
	}

	// Call selectExprs
	for i, expr := range statement.selectExprs {
		value, err := expr.Eval(input)
		if err != nil {
			return nil, err
		}

		if statement.aggregatedExprFound {
			continue
		}

		name := fmt.Sprintf("_%v", i+1)
		switch expr.(type) {
		case *starExpr:
			return value.recordValue(), nil
		case *aliasExpr:
			name = expr.(*aliasExpr).alias
		case *columnExpr:
			name = expr.(*columnExpr).name
		}

		if err = output.Set(name, value); err != nil {
			return nil, errInternalError(fmt.Errorf("error occurred to store value %v for %v; %v", value, name, err))
		}
	}

	return output, nil
}

// NewSelect - creates new Select by parsing sql.
func NewSelect(sql string) (*Select, error) {
	stmt, err := sqlparser.Parse(sql)
	if err != nil {
		return nil, errUnsupportedSQLStructure(err)
	}

	selectStmt, ok := stmt.(*sqlparser.Select)
	if !ok {
		return nil, errParseUnsupportedSelect(fmt.Errorf("unsupported SQL statement %v", sql))
	}

	var tableName, tableAlias string
	for _, fromExpr := range selectStmt.From {
		tableExpr := fromExpr.(*sqlparser.AliasedTableExpr)
		tableName = tableExpr.Expr.(sqlparser.TableName).Name.String()
		tableAlias = tableExpr.As.String()
	}

	selectExprs, aggregatedExprFound, err := newSelectExprs(selectStmt.SelectExprs, tableAlias)
	if err != nil {
		return nil, err
	}

	var whereExpr Expr
	if selectStmt.Where != nil {
		whereExpr, err = newExpr(selectStmt.Where.Expr, tableAlias, false)
		if err != nil {
			return nil, err
		}
	}

	return &Select{
		tableName:           tableName,
		tableAlias:          tableAlias,
		selectExprs:         selectExprs,
		aggregatedExprFound: aggregatedExprFound,
		whereExpr:           whereExpr,
	}, nil
}