/*******************************************************************************
* Copyright 2023 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

//@HEADER
// ***************************************************
//
// HPCG: High Performance Conjugate Gradient Benchmark
//
// Contact:
// Michael A. Heroux ( maherou@sandia.gov)
// Jack Dongarra     (dongarra@eecs.utk.edu)
// Piotr Luszczek    (luszczek@eecs.utk.edu)
//
// ***************************************************
//@HEADER


#include "CustomKernels.hpp"

#include "kernels/trsv_kernels.hpp"

namespace custom {

    // Unfused sparse triangular solve: picks the ESIMD ESB4 kernel that matches `mode`.
    //   uplo::lower_diagonal -> forward kernel,  solve (L+D) * y = x
    //   uplo::upper_diagonal -> backward kernel, solve (D+U) * y = x
    //
    // queue, x, y and dependencies are forwarded unchanged to the selected kernel
    // launcher together with the ESB arrays stored in matA (presumably `dependencies`
    // are the events the kernel has to wait on -- confirm in kernels/trsv_kernels.hpp).
    //
    // Returns the sycl::event produced by the selected kernel launcher.
    // Throws std::runtime_error for any other value of `mode`.
    sycl::event SpTRSV(sycl::queue &queue, sparseMatrix *matA, uplo mode, double *x, double *y, const std::vector<sycl::event>& dependencies)
    {
        constexpr local_int_t block_size = 16;
        constexpr bool isFused = false;

        // The kernels are instantiated for one fixed block size; the matrix must have
        // been built with the same value (only checked when asserts are enabled).
        assert(matA->block_size == block_size);

        if (mode == uplo::lower_diagonal) {
            // Forward sweep over the lower + diagonal part.
            return sparse_esb4_trsv_fwd_esimd<block_size, isFused>(
                queue,
                matA->nrows, matA->nBlocks,
                matA->esbblockptr, matA->esblastLower,
                matA->esbcolind, matA->esbvalues,
                matA->diags, matA->nColors, matA->xcolors_host,
                x, y, dependencies);
        }

        if (mode == uplo::upper_diagonal) {
            // Backward sweep over the diagonal + upper part.
            return sparse_esb4_trsv_bwd_esimd<block_size, isFused>(
                queue,
                matA->nrows, matA->nBlocks,
                matA->esbfirstUpper, matA->esblastUpper,
                matA->esbcolind, matA->esbvalues,
                matA->diags, matA->nColors, matA->xcolors_host,
                x, y, dependencies);
        }

        // Neither supported triangle was requested.
        throw std::runtime_error("unsupported uplo parameter called in SpTRSV");
    }

    // Select between kernel modes for SpTRSV_FUSED (triangular solve fused with an
    // extra diagonal update step):
    //   t = y, solve (L+D) * y = x, x = t + diag * y -- uplo::lower_diagonal
    //   x = diag * x, solve (D+U) * y = x,           -- uplo::upper_diagonal
    //
    // Calls the same ESIMD ESB4 forward/backward kernel launchers as SpTRSV, but
    // instantiated with isFused = true. queue, x, y and dependencies are forwarded
    // unchanged together with the ESB arrays stored in matA.
    //
    // Returns the sycl::event produced by the selected kernel launcher.
    // Throws std::runtime_error for any other value of `mode`.
    sycl::event SpTRSV_FUSED(sycl::queue &queue, sparseMatrix *matA, uplo mode, double *x, double *y, const std::vector<sycl::event>& dependencies)
    {
        constexpr bool isFused = true;
        constexpr local_int_t block_size = 16;

        // The kernels are instantiated for one fixed block size; the matrix must have
        // been built with the same value (only checked when asserts are enabled).
        assert(matA->block_size == block_size);

        // Every branch either returns the kernel's event or throws, so no fall-through
        // return is needed after the switch.
        switch (mode) {
        case uplo::lower_diagonal:
            return sparse_esb4_trsv_fwd_esimd<block_size, isFused>(
                queue, matA->nrows, matA->nBlocks,
                matA->esbblockptr, matA->esblastLower, matA->esbcolind, matA->esbvalues, matA->diags, matA->nColors,
                matA->xcolors_host, x, y, dependencies);
        case uplo::upper_diagonal:
            return sparse_esb4_trsv_bwd_esimd<block_size, isFused>(
                queue, matA->nrows, matA->nBlocks,
                matA->esbfirstUpper, matA->esblastUpper, matA->esbcolind, matA->esbvalues, matA->diags, matA->nColors,
                matA->xcolors_host, x, y, dependencies);
        default:
            // Name this function (not the unfused SpTRSV) so the failure is attributed correctly.
            throw std::runtime_error("unsupported uplo parameter called in SpTRSV_FUSED");
        }
    }


} // namespace custom
