Eigen  5.0.1-dev
Loading...
Searching...
No Matches
GeneralMatrixVector.h
1// This file is part of Eigen, a lightweight C++ template library
2// for linear algebra.
3//
4// Copyright (C) 2008-2016 Gael Guennebaud <gael.guennebaud@inria.fr>
5//
6// This Source Code Form is subject to the terms of the Mozilla
7// Public License v. 2.0. If a copy of the MPL was not distributed
8// with this file, You can obtain one at http://mozilla.org/MPL/2.0/.
9
10#ifndef EIGEN_GENERAL_MATRIX_VECTOR_H
11#define EIGEN_GENERAL_MATRIX_VECTOR_H
12
13// IWYU pragma: private
14#include "../InternalHeaderCheck.h"
15
16namespace Eigen {
17
18namespace internal {
19
/* Packet width requested from gemv_traits / gemv_packet_cond:
 * the full packet, its half, or the half of that half (quarter). */
enum GEMVPacketSizeType {
  GEMVPacketFull = 0,
  GEMVPacketHalf = 1,
  GEMVPacketQuarter = 2
};
21
22template <int N, typename T1, typename T2, typename T3>
23struct gemv_packet_cond {
24 typedef T3 type;
25};
26
27template <typename T1, typename T2, typename T3>
28struct gemv_packet_cond<GEMVPacketFull, T1, T2, T3> {
29 typedef T1 type;
30};
31
32template <typename T1, typename T2, typename T3>
33struct gemv_packet_cond<GEMVPacketHalf, T1, T2, T3> {
34 typedef T2 type;
35};
36
37template <typename LhsScalar, typename RhsScalar, int PacketSize_ = GEMVPacketFull>
38class gemv_traits {
39 typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
40
41#define PACKET_DECL_COND_POSTFIX(postfix, name, packet_size) \
42 typedef typename gemv_packet_cond< \
43 packet_size, typename packet_traits<name##Scalar>::type, typename packet_traits<name##Scalar>::half, \
44 typename unpacket_traits<typename packet_traits<name##Scalar>::half>::half>::type name##Packet##postfix
45
46 PACKET_DECL_COND_POSTFIX(_, Lhs, PacketSize_);
47 PACKET_DECL_COND_POSTFIX(_, Rhs, PacketSize_);
48 PACKET_DECL_COND_POSTFIX(_, Res, PacketSize_);
49#undef PACKET_DECL_COND_POSTFIX
50
51 public:
52 enum {
53 Vectorizable = unpacket_traits<LhsPacket_>::vectorizable && unpacket_traits<RhsPacket_>::vectorizable &&
54 int(unpacket_traits<LhsPacket_>::size) == int(unpacket_traits<RhsPacket_>::size),
55 LhsPacketSize = Vectorizable ? unpacket_traits<LhsPacket_>::size : 1,
56 RhsPacketSize = Vectorizable ? unpacket_traits<RhsPacket_>::size : 1,
57 ResPacketSize = Vectorizable ? unpacket_traits<ResPacket_>::size : 1
58 };
59
60 typedef std::conditional_t<Vectorizable, LhsPacket_, LhsScalar> LhsPacket;
61 typedef std::conditional_t<Vectorizable, RhsPacket_, RhsScalar> RhsPacket;
62 typedef std::conditional_t<Vectorizable, ResPacket_, ResScalar> ResPacket;
63};
64
/* Optimized col-major matrix * vector product:
 * This algorithm processes the matrix in vertical panels,
 * which are then processed horizontally in chunks of 8*PacketSize x 1 vertical segments.
 *
 * Mixing type logic: C += alpha * A * B
 *  |  A  |  B  |alpha| comments
 *  |real |cplx |cplx | no vectorization
 *  |real |cplx |real | alpha is converted to a cplx when calling the run function, no vectorization
 *  |cplx |real |cplx | invalid, the caller has to do tmp := A * B; C += alpha*tmp
 *  |cplx |real |real | optimal case, vectorization possible via real-cplx mul
 *
 * The same reasoning applies to the transposed case.
 */
78template <typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar,
79 typename RhsMapper, bool ConjugateRhs, int Version>
80struct general_matrix_vector_product<Index, LhsScalar, LhsMapper, ColMajor, ConjugateLhs, RhsScalar, RhsMapper,
81 ConjugateRhs, Version> {
82 typedef gemv_traits<LhsScalar, RhsScalar> Traits;
83 typedef gemv_traits<LhsScalar, RhsScalar, GEMVPacketHalf> HalfTraits;
84 typedef gemv_traits<LhsScalar, RhsScalar, GEMVPacketQuarter> QuarterTraits;
85
86 typedef typename ScalarBinaryOpTraits<LhsScalar, RhsScalar>::ReturnType ResScalar;
87
88 typedef typename Traits::LhsPacket LhsPacket;
89 typedef typename Traits::RhsPacket RhsPacket;
90 typedef typename Traits::ResPacket ResPacket;
91
92 typedef typename HalfTraits::LhsPacket LhsPacketHalf;
93 typedef typename HalfTraits::RhsPacket RhsPacketHalf;
94 typedef typename HalfTraits::ResPacket ResPacketHalf;
95
96 typedef typename QuarterTraits::LhsPacket LhsPacketQuarter;
97 typedef typename QuarterTraits::RhsPacket RhsPacketQuarter;
98 typedef typename QuarterTraits::ResPacket ResPacketQuarter;
99
100 EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE static void run(Index rows, Index cols, const LhsMapper& lhs,
101 const RhsMapper& rhs, ResScalar* res, Index resIncr,
102 RhsScalar alpha);
103};
104
105template <typename Index, typename LhsScalar, typename LhsMapper, bool ConjugateLhs, typename RhsScalar,
106 typename RhsMapper, bool ConjugateRhs, int Version>
107EIGEN_DEVICE_FUNC EIGEN_DONT_INLINE void
108general_matrix_vector_product<Index, LhsScalar, LhsMapper, ColMajor, ConjugateLhs, RhsScalar, RhsMapper, ConjugateRhs,
109 Version>::run(Index rows, Index cols, const LhsMapper& alhs, const RhsMapper& rhs,
110 ResScalar* res, Index resIncr, RhsScalar alpha) {
111 EIGEN_UNUSED_VARIABLE(resIncr);
112 eigen_internal_assert(resIncr == 1);
113
114 // The following copy tells the compiler that lhs's attributes are not modified outside this function
115 // This helps GCC to generate proper code.
116 LhsMapper lhs(alhs);
117
118 conj_helper<LhsScalar, RhsScalar, ConjugateLhs, ConjugateRhs> cj;
119 conj_helper<LhsPacket, RhsPacket, ConjugateLhs, ConjugateRhs> pcj;
120 conj_helper<LhsPacketHalf, RhsPacketHalf, ConjugateLhs, ConjugateRhs> pcj_half;
121 conj_helper<LhsPacketQuarter, RhsPacketQuarter, ConjugateLhs, ConjugateRhs> pcj_quarter;
122
123 const Index lhsStride = lhs.stride();
124 // TODO: for padded aligned inputs, we could enable aligned reads
125 enum {
126 LhsAlignment = Unaligned,