/*
 *
 * Copyright 2018 gRPC authors.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 */

// Package handshaker provides ALTS handshaking functionality for GCP.
package handshaker

import (
	"errors"
	"fmt"
	"io"
	"net"
	"sync"

	"golang.org/x/net/context"
	grpc "google.golang.org/grpc"
	"google.golang.org/grpc/codes"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/credentials/alts/core"
	"google.golang.org/grpc/credentials/alts/core/authinfo"
	"google.golang.org/grpc/credentials/alts/core/conn"
	altsgrpc "google.golang.org/grpc/credentials/alts/core/proto/grpc_gcp"
	altspb "google.golang.org/grpc/credentials/alts/core/proto/grpc_gcp"
)

const (
	// frameLimit is the maximum byte size (64 KB) of a frame read from the
	// peer; it sizes the buffers used for every read on the peer connection.
	frameLimit = 64 << 10
	// rekeyRecordProtocolName names the ALTS record protocol offered to the
	// handshaker service.
	rekeyRecordProtocolName = "ALTSRP_GCM_AES128_REKEY"
	// maxPendingHandshakes caps how many handshakes may be in progress at
	// the same time (see acquire).
	maxPendingHandshakes = 100
)

// Package-level handshake configuration and state:
//   - hsProtocol, appProtocols and recordProtocols are the values offered to
//     the handshaker service in the client/server start requests.
//   - keyLength maps each record protocol to the number of key bytes taken
//     from the handshaker result when building the secure connection.
//   - altsRecordFuncs maps each record protocol to the constructor of its
//     record crypto; init registers every entry with the conn package.
var (
	hsProtocol      = altspb.HandshakeProtocol_ALTS
	appProtocols    = []string{"grpc"}
	recordProtocols = []string{rekeyRecordProtocolName}
	keyLength       = map[string]int{
		rekeyRecordProtocolName: 44,
	}
	altsRecordFuncs = map[string]conn.ALTSRecordFunc{
		// ALTS record protocols supported by this package.
		rekeyRecordProtocolName: func(s core.Side, keyData []byte) (conn.ALTSRecordCrypto, error) {
			return conn.NewAES128GCMRekey(s, keyData)
		},
	}
	// mu guards concurrentHandshakes, the number of ClientHandshake and
	// ServerHandshake calls currently in progress (see acquire and release).
	mu                   sync.Mutex
	concurrentHandshakes = int64(0)
	// errDropped is returned by ClientHandshake and ServerHandshake when
	// maxPendingHandshakes handshakes are already in progress.
	errDropped = errors.New("maximum number of concurrent ALTS handshakes is reached")
)

// init registers every record protocol listed in altsRecordFuncs with the
// conn package. A registration failure means the package is misconfigured,
// so it panics.
func init() {
	for name, newCrypto := range altsRecordFuncs {
		err := conn.RegisterProtocol(name, newCrypto)
		if err != nil {
			panic(err)
		}
	}
}

// acquire tries to reserve n slots from the global budget of
// maxPendingHandshakes concurrent handshakes. It reports whether the
// reservation succeeded; on failure the counter is left unchanged.
func acquire(n int64) bool {
	mu.Lock()
	defer mu.Unlock()
	if maxPendingHandshakes-concurrentHandshakes < n {
		return false
	}
	concurrentHandshakes += n
	return true
}

// release returns n previously acquired slots to the global handshake budget.
// Releasing more slots than are held is an acquire/release mismatch and
// panics with "bad release"; the mutex is unlocked before the panic escapes.
func release(n int64) {
	mu.Lock()
	defer mu.Unlock()
	concurrentHandshakes -= n
	if concurrentHandshakes < 0 {
		panic("bad release")
	}
}

// ClientHandshakerOptions contains the client handshaker options that can be
// provided by the caller.
type ClientHandshakerOptions struct {
	// ClientIdentity is the handshaker client local identity. It is sent as
	// the local identity of the client start request.
	ClientIdentity *altspb.Identity
	// TargetName is the server service account name for secure name
	// checking.
	TargetName string
	// TargetServiceAccounts contains a list of expected target service
	// accounts. One of these accounts should match one of the accounts in
	// the handshaker results. Otherwise, the handshake fails. Each entry is
	// sent to the handshaker service as a service-account target identity.
	TargetServiceAccounts []string
	// RPCVersions specifies the gRPC versions accepted by the client.
	RPCVersions *altspb.RpcProtocolVersions
}

// ServerHandshakerOptions contains the server handshaker options that can be
// provided by the caller.
type ServerHandshakerOptions struct {
	// RPCVersions specifies the gRPC versions accepted by the server. It is
	// sent as part of the server start request.
	RPCVersions *altspb.RpcProtocolVersions
}

// DefaultClientHandshakerOptions returns a newly allocated, zero-valued set of
// client handshaker options.
func DefaultClientHandshakerOptions() *ClientHandshakerOptions {
	return new(ClientHandshakerOptions)
}

// DefaultServerHandshakerOptions returns a newly allocated, zero-valued set of
// server handshaker options.
func DefaultServerHandshakerOptions() *ServerHandshakerOptions {
	return new(ServerHandshakerOptions)
}

// TODO: add support for local and remote endpoints in both the client options
//       and the server options. Once callers can provide endpoints, add the
//       corresponding fields to ClientHandshakerOptions and
//       ServerHandshakerOptions.

// altsHandshaker is used to complete an ALTS handshake between client and
// server. This handshaker talks to the ALTS handshaker service in the metadata
// server. Exactly one of clientOpts and serverOpts is set, matching side.
type altsHandshaker struct {
	// RPC stream used to access the ALTS Handshaker service.
	stream altsgrpc.HandshakerService_DoHandshakeClient
	// the connection to the peer.
	conn net.Conn
	// client handshake options; set only by NewClientHandshaker.
	clientOpts *ClientHandshakerOptions
	// server handshake options; set only by NewServerHandshaker.
	serverOpts *ServerHandshakerOptions
	// defines the side doing the handshake, client or server.
	side core.Side
}

// NewClientHandshaker creates an ALTS handshaker for GCP which contains an RPC
// stub created using the passed conn and used to talk to the ALTS Handshaker
// service in the metadata server. c is the connection to the peer over which
// the handshake is performed, and ctx is used to open the handshaker stream.
// A nil opts is treated as DefaultClientHandshakerOptions().
func NewClientHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ClientHandshakerOptions) (core.Handshaker, error) {
	// ClientHandshake dereferences the options unconditionally; fall back to
	// the defaults instead of panicking later on a nil pointer.
	if opts == nil {
		opts = DefaultClientHandshakerOptions()
	}
	stream, err := altsgrpc.NewHandshakerServiceClient(conn).DoHandshake(ctx, grpc.FailFast(false))
	if err != nil {
		return nil, err
	}
	return &altsHandshaker{
		stream:     stream,
		conn:       c,
		clientOpts: opts,
		side:       core.ClientSide,
	}, nil
}

// NewServerHandshaker creates an ALTS handshaker for GCP which contains an RPC
// stub created using the passed conn and used to talk to the ALTS Handshaker
// service in the metadata server. c is the connection to the peer over which
// the handshake is performed, and ctx is used to open the handshaker stream.
// A nil opts is treated as DefaultServerHandshakerOptions().
func NewServerHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, opts *ServerHandshakerOptions) (core.Handshaker, error) {
	// ServerHandshake dereferences the options unconditionally; fall back to
	// the defaults instead of panicking later on a nil pointer.
	if opts == nil {
		opts = DefaultServerHandshakerOptions()
	}
	stream, err := altsgrpc.NewHandshakerServiceClient(conn).DoHandshake(ctx, grpc.FailFast(false))
	if err != nil {
		return nil, err
	}
	return &altsHandshaker{
		stream:     stream,
		conn:       c,
		serverOpts: opts,
		side:       core.ServerSide,
	}, nil
}

// ClientHandshake runs the client side of an ALTS handshake for GCP. Once it
// completes, it returns the secure connection wrapping the peer connection,
// along with the authentication info built from the handshaker result.
//
// It returns errDropped if maxPendingHandshakes handshakes are already in
// progress. It returns an error if the handshaker was not created by
// NewClientHandshaker. ctx is currently not used.
func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
	if ok := acquire(1); !ok {
		return nil, nil, errDropped
	}
	defer release(1)

	if h.side != core.ClientSide {
		return nil, nil, errors.New("only handshakers created using NewClientHandshaker can perform a client handshaker")
	}

	opts := h.clientOpts
	// Each expected target service account is sent as a target identity.
	targets := make([]*altspb.Identity, len(opts.TargetServiceAccounts))
	for i, account := range opts.TargetServiceAccounts {
		targets[i] = &altspb.Identity{
			IdentityOneof: &altspb.Identity_ServiceAccount{ServiceAccount: account},
		}
	}

	start := &altspb.StartClientHandshakeReq{
		HandshakeSecurityProtocol: hsProtocol,
		ApplicationProtocols:      appProtocols,
		RecordProtocols:           recordProtocols,
		TargetIdentities:          targets,
		LocalIdentity:             opts.ClientIdentity,
		TargetName:                opts.TargetName,
		RpcVersions:               opts.RPCVersions,
	}
	secureConn, result, err := h.doHandshake(&altspb.HandshakerReq{
		ReqOneof: &altspb.HandshakerReq_ClientStart{ClientStart: start},
	})
	if err != nil {
		return nil, nil, err
	}
	return secureConn, authinfo.New(result), nil
}

// ServerHandshake runs the server side of an ALTS handshake for GCP. It reads
// the peer's first message (at most frameLimit bytes), hands it to the
// handshaker service in a server start request, and finishes the handshake.
// Once done, it returns the secure connection along with the authentication
// info built from the handshaker result.
//
// It returns errDropped if maxPendingHandshakes handshakes are already in
// progress. It returns an error if the handshaker was not created by
// NewServerHandshaker or if the first read fails. ctx is currently not used.
func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
	if ok := acquire(1); !ok {
		return nil, nil, errDropped
	}
	defer release(1)

	if h.side != core.ServerSide {
		return nil, nil, errors.New("only handshakers created using NewServerHandshaker can perform a server handshaker")
	}

	firstMsg := make([]byte, frameLimit)
	n, err := h.conn.Read(firstMsg)
	if err != nil {
		return nil, nil, err
	}

	// TODO: currently only ALTS parameters are provided. Might need to use
	//       more options in the future.
	params := map[int32]*altspb.ServerHandshakeParameters{
		int32(altspb.HandshakeProtocol_ALTS): {RecordProtocols: recordProtocols},
	}
	start := &altspb.StartServerHandshakeReq{
		ApplicationProtocols: appProtocols,
		HandshakeParameters:  params,
		InBytes:              firstMsg[:n],
		RpcVersions:          h.serverOpts.RPCVersions,
	}
	secureConn, result, err := h.doHandshake(&altspb.HandshakerReq{
		ReqOneof: &altspb.HandshakerReq_ServerStart{ServerStart: start},
	})
	if err != nil {
		return nil, nil, err
	}
	return secureConn, authinfo.New(result), nil
}

// doHandshake sends the initial request req to the handshaker service, then
// drives the rest of the handshake with the peer via processUntilDone. Finally
// it wraps h.conn in a secure connection built from the negotiated record
// protocol and key.
//
// For a server start request, the bytes of the peer's first message that the
// service did not consume are carried forward as leftover input. It returns
// an error if the service reports a non-OK status, reports an out-of-range
// consumed byte count, returns too little key material, or negotiates a
// record protocol with no entry in keyLength.
func (h *altsHandshaker) doHandshake(req *altspb.HandshakerReq) (net.Conn, *altspb.HandshakerResult, error) {
	resp, err := h.accessHandshakerService(req)
	if err != nil {
		return nil, nil, err
	}
	// Check if the returned status is an error.
	if resp.GetStatus() != nil {
		if got, want := resp.GetStatus().Code, uint32(codes.OK); got != want {
			return nil, nil, fmt.Errorf("%v", resp.GetStatus().Details)
		}
	}

	var extra []byte
	if req.GetServerStart() != nil {
		inBytes := req.GetServerStart().GetInBytes()
		// The consumed count comes from a remote service; validate it rather
		// than letting an out-of-range value panic when slicing.
		consumed := resp.GetBytesConsumed()
		if consumed > uint32(len(inBytes)) {
			return nil, nil, fmt.Errorf("handshaker service consumed %d bytes, but only %d bytes were sent", consumed, len(inBytes))
		}
		extra = inBytes[consumed:]
	}
	result, extra, err := h.processUntilDone(resp, extra)
	if err != nil {
		return nil, nil, err
	}
	// The handshaker returns a 128 bytes key. It should be truncated based
	// on the returned record protocol.
	keyLen, ok := keyLength[result.GetRecordProtocol()]
	if !ok {
		return nil, nil, fmt.Errorf("unknown resulted record protocol %v", result.GetRecordProtocol())
	}
	// Guard the truncation below against a short key from the service.
	if len(result.KeyData) < keyLen {
		return nil, nil, fmt.Errorf("handshaker service returned %d bytes of key data, want at least %d", len(result.KeyData), keyLen)
	}
	sc, err := conn.NewConn(h.conn, h.side, result.GetRecordProtocol(), result.KeyData[:keyLen], extra)
	if err != nil {
		return nil, nil, err
	}
	return sc, result, nil
}

// accessHandshakerService performs one round trip with the handshaker
// service. It sends req on the stream and waits for the service's reply,
// returning a nil response together with the first error from either step.
func (h *altsHandshaker) accessHandshakerService(req *altspb.HandshakerReq) (*altspb.HandshakerResp, error) {
	var (
		resp *altspb.HandshakerResp
		err  error
	)
	if err = h.stream.Send(req); err == nil {
		resp, err = h.stream.Recv()
	}
	if err != nil {
		return nil, err
	}
	return resp, nil
}

// processUntilDone relays handshake messages between the peer and the
// handshaker service until the service returns a result. The handshaker
// service takes care of frame parsing, so whatever is read from the network is
// forwarded to it as-is, prefixed by any bytes it left unconsumed in the
// previous round.
//
// resp is the service's most recent response, and extra holds the bytes of
// the previous request that the service did not consume. On success it
// returns the handshake result together with the received bytes the service
// left unconsumed; doHandshake passes those to the secure connection. It
// returns core.PeerNotRespondingError when there is nothing to send and the
// peer sent nothing.
func (h *altsHandshaker) processUntilDone(resp *altspb.HandshakerResp, extra []byte) (*altspb.HandshakerResult, []byte, error) {
	// buf is reused across rounds. append below always copies its contents
	// into p, so no request sent to the service ever aliases buf.
	buf := make([]byte, frameLimit)
	for {
		if len(resp.OutFrames) > 0 {
			if _, err := h.conn.Write(resp.OutFrames); err != nil {
				return nil, nil, err
			}
		}
		if resp.Result != nil {
			return resp.Result, extra, nil
		}
		n, err := h.conn.Read(buf)
		if err != nil && err != io.EOF {
			return nil, nil, err
		}
		// If there is nothing to send to the handshaker service, and
		// nothing is received from the peer, then we are stuck.
		// This covers the case when the peer is not responding. Note
		// that handshaker service connection issues are caught in
		// accessHandshakerService before we even get here.
		if len(resp.OutFrames) == 0 && n == 0 {
			return nil, nil, core.PeerNotRespondingError
		}
		// Append extra bytes from the previous interaction with the
		// handshaker service with the current buffer read from conn.
		p := append(extra, buf[:n]...)
		resp, err = h.accessHandshakerService(&altspb.HandshakerReq{
			ReqOneof: &altspb.HandshakerReq_Next{
				Next: &altspb.NextHandshakeMessageReq{
					InBytes: p,
				},
			},
		})
		if err != nil {
			return nil, nil, err
		}
		// Surface service-side failures on later rounds too, not only on the
		// initial response checked in doHandshake.
		if st := resp.GetStatus(); st != nil && st.Code != uint32(codes.OK) {
			return nil, nil, fmt.Errorf("%v", st.Details)
		}
		// The consumed count refers to p (previous leftover plus the newly
		// read bytes), so the new leftover must be sliced from p, not buf.
		consumed := resp.GetBytesConsumed()
		if consumed > uint32(len(p)) {
			return nil, nil, fmt.Errorf("handshaker service consumed %d bytes, but only %d bytes were sent", consumed, len(p))
		}
		extra = p[consumed:]
	}
}

// Close terminates the Handshaker by closing the sending direction of its
// stream to the handshaker service. It should be called once the caller has
// obtained the secure connection; it does not touch the peer connection.
func (h *altsHandshaker) Close() {
	// Close has no error result, so a CloseSend failure cannot be reported
	// and is deliberately ignored.
	_ = h.stream.CloseSend()
}
