10#ifndef EIGEN_GENERAL_MATRIX_VECTOR_H
11#define EIGEN_GENERAL_MATRIX_VECTOR_H
14#include "../InternalHeaderCheck.h"
// Selector naming the packet-width variant (full, half, quarter) requested
// from the GEMV traits; it is passed as the integer template argument of
// gemv_packet_cond and as the PacketSize_ argument of the traits class.
20enum GEMVPacketSizeType { GEMVPacketFull = 0, GEMVPacketHalf, GEMVPacketQuarter };
// Primary template of a compile-time selector over three candidate types,
// keyed by the integer N. Users read the result through the nested `type`
// member (see the PACKET_DECL_COND_POSTFIX macro in this file).
// NOTE(review): the body is not visible in this excerpt. Since GEMVPacketFull
// and GEMVPacketHalf have dedicated specializations, this general case
// presumably handles the remaining selector value (GEMVPacketQuarter) -- confirm.
22template <
int N,
typename T1,
typename T2,
typename T3>
23struct gemv_packet_cond {
// Specialization chosen when the selector is GEMVPacketFull.
// NOTE(review): the body is not visible in this excerpt; presumably it exposes
// T1 (the first candidate) as `type` -- confirm.
27template <
typename T1,
typename T2,
typename T3>
28struct gemv_packet_cond<GEMVPacketFull, T1, T2, T3> {
// Specialization chosen when the selector is GEMVPacketHalf.
// NOTE(review): the body is not visible in this excerpt; presumably it exposes
// T2 (the second candidate) as `type` -- confirm.
32template <
typename T1,
typename T2,
typename T3>
33struct gemv_packet_cond<GEMVPacketHalf, T1, T2, T3> {
// Packet-type traits for the matrix-vector product kernels, instantiated
// elsewhere in this file as gemv_traits<LhsScalar, RhsScalar[, selector]>.
// PacketSize_ takes a GEMVPacketSizeType value and defaults to GEMVPacketFull.
37template <
typename LhsScalar,
typename RhsScalar,
int PacketSize_ = GEMVPacketFull>
// Result scalar type reported by ScalarBinaryOpTraits for the
// (LhsScalar, RhsScalar) pair.
39 typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
// Helper macro: declares the typedef `<name>Packet<postfix>` as the type that
// gemv_packet_cond picks for `packet_size`, offering it three candidates built
// from `<name>Scalar`: its full packet type, that packet's half, and the half
// of that half.
41#define PACKET_DECL_COND_POSTFIX(postfix, name, packet_size) \
42 typedef typename gemv_packet_cond< \
43 packet_size, typename packet_traits<name##Scalar>::type, typename packet_traits<name##Scalar>::half, \
44 typename unpacket_traits<typename packet_traits<name##Scalar>::half>::half>::type name##Packet##postfix
// Candidate packet types LhsPacket_, RhsPacket_ and ResPacket_; these are the
// types before the scalar fallback applied further down.
46 PACKET_DECL_COND_POSTFIX(_, Lhs, PacketSize_);
47 PACKET_DECL_COND_POSTFIX(_, Rhs, PacketSize_);
48 PACKET_DECL_COND_POSTFIX(_, Res, PacketSize_);
// The helper macro is only needed for the three declarations above.
49#undef PACKET_DECL_COND_POSTFIX
// Vectorizable holds only when both the Lhs and the Rhs candidate packets are
// vectorizable and contain the same number of elements. When it does not hold,
// every packet size below is reported as 1.
53 Vectorizable = unpacket_traits<LhsPacket_>::vectorizable && unpacket_traits<RhsPacket_>::vectorizable &&
54 int(unpacket_traits<LhsPacket_>::size) == int(unpacket_traits<RhsPacket_>::size),
55 LhsPacketSize = Vectorizable ? unpacket_traits<LhsPacket_>::size : 1,
56 RhsPacketSize = Vectorizable ? unpacket_traits<RhsPacket_>::size : 1,
57 ResPacketSize = Vectorizable ? unpacket_traits<ResPacket_>::size : 1
// Final packet types: the candidate packets when Vectorizable, otherwise the
// plain scalar types.
60 typedef std::conditional_t<Vectorizable, LhsPacket_, LhsScalar> LhsPacket;
61 typedef std::conditional_t<Vectorizable, RhsPacket_, RhsScalar> RhsPacket;
62 typedef std::conditional_t<Vectorizable, ResPacket_, ResScalar> ResPacket;
// Partial specialization of general_matrix_vector_product with ColMajor as the
// fourth template argument (presumably the storage order of the left-hand
// side -- confirm against the primary template, which is not visible here).
78template <
typename Index,
typename LhsScalar,
typename LhsMapper,
bool ConjugateLhs,
typename RhsScalar,
79 typename RhsMapper,
bool ConjugateRhs,
int Version>
80struct general_matrix_vector_product<
Index, LhsScalar, LhsMapper,
ColMajor, ConjugateLhs, RhsScalar, RhsMapper,
81 ConjugateRhs, Version> {
// Traits instantiated with the default (full), half and quarter selectors.
82 typedef gemv_traits<LhsScalar, RhsScalar> Traits;
83 typedef gemv_traits<LhsScalar, RhsScalar, GEMVPacketHalf> HalfTraits;
84 typedef gemv_traits<LhsScalar, RhsScalar, GEMVPacketQuarter> QuarterTraits;
// Scalar type of the result, as reported by ScalarBinaryOpTraits.
86 typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
// Packet types taken from the default (full) traits.
88 typedef typename Traits::LhsPacket LhsPacket;
89 typedef typename Traits::RhsPacket RhsPacket;
90 typedef typename Traits::ResPacket ResPacket;
// Packet types taken from the half traits.
92 typedef typename HalfTraits::LhsPacket LhsPacketHalf;
93 typedef typename HalfTraits::RhsPacket RhsPacketHalf;
94 typedef typename HalfTraits::ResPacket ResPacketHalf;
// Packet types taken from the quarter traits.
96 typedef typename QuarterTraits::LhsPacket LhsPacketQuarter;
97 typedef typename QuarterTraits::RhsPacket RhsPacketQuarter;
98 typedef typename QuarterTraits::ResPacket ResPacketQuarter;
// Kernel entry point, defined out of line later in this file. Both the
// declaration and the definition carry EIGEN_DEVICE_FUNC and EIGEN_DONT_INLINE.
// NOTE(review): the parameter list is cut off in this excerpt; the out-of-line
// definition additionally takes `RhsScalar alpha` after resIncr.
100 EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE
static void run(
Index rows,
Index cols,
const LhsMapper& lhs,
101 const RhsMapper& rhs, ResScalar* res,
Index resIncr,
105template <
typename Index,
typename LhsScalar,
typename LhsMapper,
bool ConjugateLhs,
typename RhsScalar,
106 typename RhsMapper,
bool ConjugateRhs,
int Version>
107EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE
void
108general_matrix_vector_product<
Index, LhsScalar, LhsMapper,
ColMajor, ConjugateLhs, RhsScalar, RhsMapper, ConjugateRhs,
109 Version>::run(
Index rows,
Index cols,
const LhsMapper& alhs,
const RhsMapper& rhs,
110 ResScalar* res,
Index resIncr, RhsScalar alpha) {
111 EIGEN_UNUSED_VARIABLE(resIncr);
112 eigen_internal_assert(resIncr == 1);
118 conj_helper<LhsScalar, RhsScalar, ConjugateLhs, ConjugateRhs> cj;
119 conj_helper<LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs> pcj;
120 conj_helper<LhsPacketHalf, RhsPacketHalf, ConjugateLhs, ConjugateRhs> pcj_half;
121 conj_helper<LhsPacketQuarter, RhsPacketQuarter, ConjugateLhs, ConjugateRhs> pcj_quarter;
123 const Index lhsStride = lhs.stride();