package endpoint

import (
  "encoding/json"
  "fmt"
  "net/http"
  "runtime/debug"

  "github.com/go-chi/chi/middleware"
  "github.com/rs/zerolog/log"
)

// Endpoint is an HTTP handler that reports failure by returning an error
// instead of writing the error response itself.
type Endpoint func(http.ResponseWriter, *http.Request) error

// ErrResponse is the JSON body written to the client when a request fails.
// Every field is omitted from the output when it holds its zero value.
type ErrResponse struct {
  // Status is the HTTP status code of the response.
  Status int `json:"status,omitempty"`
  // Err is the standard status text for Status.
  Err string `json:"error,omitempty"`
  // Message is the message carried by a HandlerError.
  Message string `json:"message,omitempty"`
  // Details is a textual rendering of the underlying error value.
  Details string `json:"details,omitempty"`
  // Code is serialized as "code"; errResponse does not set it.
  Code string `json:"code,omitempty"`
  // RequestID is serialized under the "request" key.
  RequestID string `json:"request,omitempty"`
}

// Handle adapts an Endpoint to an http.HandlerFunc. When the endpoint returns
// a non-nil error, that error is passed to WriteError; otherwise nothing more
// is done.
func Handle(handler Endpoint) http.HandlerFunc {
  return func(w http.ResponseWriter, r *http.Request) {
    err := handler(w, r)
    if err == nil {
      return
    }
    WriteError(w, r, err)
  }
}

// nonErrorsCodes is the set of HTTP status codes that WriteError answers
// without logging an error; 404 (Not Found) is the only member.
var nonErrorsCodes = map[int]bool{404: true}

// errResponse converts an arbitrary error value into the payload that is
// sent to the client.
//
// A non-nil *HandlerError supplies its own Status and Message, and its Err
// field (if set) is used for Details. Any other input, including a typed-nil
// *HandlerError, is reported as 500 Internal Server Error; the input itself
// is then used for Details.
//
// Details is Error() for values implementing error, and the "%+v" rendering
// for everything else. It is left empty when there is no underlying value.
func errResponse(input interface{}) *ErrResponse {
  // Start from the generic 500 response and specialize it below.
  res := &ErrResponse{
    Status: http.StatusInternalServerError,
    Err:    http.StatusText(http.StatusInternalServerError),
  }

  var cause interface{}

  switch e := input.(type) {
  case *HandlerError:
    // A typed-nil pointer still arrives here as a non-nil interface; keep
    // the generic 500 instead of dereferencing it.
    if e != nil {
      res.Status = e.Status
      res.Err = http.StatusText(e.Status)
      res.Message = e.Message
      cause = e.Err
    }
  default:
    cause = input
  }

  switch c := cause.(type) {
  case nil:
    // No underlying value: leave Details empty.
  case error:
    // Match the error interface itself (not *error) so the client sees the
    // plain message rather than a verbose "%+v" rendering.
    res.Details = c.Error()
  default:
    res.Details = fmt.Sprintf("%+v", c)
  }

  return res
}

// WriteError renders err as a JSON error response on w and, unless the
// resulting status is listed in nonErrorsCodes, logs it together with a stack
// trace.
//
// The payload is built by errResponse, and the request ID reported by the chi
// middleware is attached when it is non-empty. Logging goes through the
// request's chi log entry when one exists; otherwise it goes to a zerolog
// logger tagged module=http.
func WriteError(w http.ResponseWriter, r *http.Request, err interface{}) {
  logger := log.With().Str("module", "http").Logger()

  payload := errResponse(err)
  if id := middleware.GetReqID(r.Context()); id != "" {
    payload.RequestID = id
  }

  // Headers have to be set before the status line is written.
  header := w.Header()
  header.Set("Content-Type", "application/json")
  w.WriteHeader(payload.Status)

  encodeErr := json.NewEncoder(w).Encode(payload)
  if encodeErr != nil {
    logger.Warn().Err(encodeErr).Msg("Failed writing json error response")
  }

  // Statuses in nonErrorsCodes are answered but not logged.
  if nonErrorsCodes[payload.Status] {
    return
  }

  if entry := middleware.GetLogEntry(r); entry != nil {
    entry.Panic(err, debug.Stack())
    return
  }

  logger.Error().Str("stack", string(debug.Stack())).Msgf("%+v", err)
}