Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tests] Unit tests non-tunable conv asm solvers #3494

Merged
merged 15 commits into from
Feb 12, 2025
2 changes: 1 addition & 1 deletion src/hip/handlehip.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -782,7 +782,7 @@ bool Handle::CooperativeLaunchSupported() const

std::string Handle::GetDeviceNameImpl() const { return this->impl->get_device_name(); }

std::string Handle::GetDeviceName() const { return this->impl->target_properties.Name(); }
std::string Handle::GetDeviceName() const { return this->GetTargetProperties().Name(); }

const TargetProperties& Handle::GetTargetProperties() const
{
Expand Down
30 changes: 15 additions & 15 deletions src/include/miopen/conv/solvers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -331,37 +331,37 @@ struct ConvAsm1x1UV2 final : ConvTunableSolver<PerformanceConfigConvAsm1x1UV2>
const PerformanceConfigConvAsm1x1UV2&) const override;
};

struct ConvAsm5x10u2v2f1 final : ConvSolver
struct MIOPEN_INTERNALS_EXPORT ConvAsm5x10u2v2f1 final : ConvSolver
{
const std::string& SolverDbId() const override { return GetSolverDbId<ConvAsm5x10u2v2f1>(); }

MIOPEN_INTERNALS_EXPORT bool
IsApplicable(const ExecutionContext&, const miopen::conv::ProblemDescription&) const override;
MIOPEN_INTERNALS_EXPORT ConvSolution
GetSolution(const ExecutionContext&, const miopen::conv::ProblemDescription&) const override;
bool IsApplicable(const ExecutionContext&,
const miopen::conv::ProblemDescription&) const override;
ConvSolution GetSolution(const ExecutionContext&,
const miopen::conv::ProblemDescription&) const override;
};

struct ConvAsm5x10u2v2b1 final : ConvSolver
struct MIOPEN_INTERNALS_EXPORT ConvAsm5x10u2v2b1 final : ConvSolver
{
const std::string& SolverDbId() const override { return GetSolverDbId<ConvAsm5x10u2v2b1>(); }

MIOPEN_INTERNALS_EXPORT bool
IsApplicable(const ExecutionContext&, const miopen::conv::ProblemDescription&) const override;
MIOPEN_INTERNALS_EXPORT ConvSolution
GetSolution(const ExecutionContext&, const miopen::conv::ProblemDescription&) const override;
bool IsApplicable(const ExecutionContext&,
const miopen::conv::ProblemDescription&) const override;
ConvSolution GetSolution(const ExecutionContext&,
const miopen::conv::ProblemDescription&) const override;
};

struct ConvAsm7x7c3h224w224k64u2v2p3q3f1 final : ConvSolver
struct MIOPEN_INTERNALS_EXPORT ConvAsm7x7c3h224w224k64u2v2p3q3f1 final : ConvSolver
{
const std::string& SolverDbId() const override
{
return GetSolverDbId<ConvAsm7x7c3h224w224k64u2v2p3q3f1>();
}

MIOPEN_INTERNALS_EXPORT bool
IsApplicable(const ExecutionContext&, const miopen::conv::ProblemDescription&) const override;
MIOPEN_INTERNALS_EXPORT ConvSolution
GetSolution(const ExecutionContext&, const miopen::conv::ProblemDescription&) const override;
bool IsApplicable(const ExecutionContext&,
const miopen::conv::ProblemDescription&) const override;
ConvSolution GetSolution(const ExecutionContext&,
const miopen::conv::ProblemDescription&) const override;
};

struct ConvOclDirectFwd11x11 final : ConvSolver
Expand Down
2 changes: 1 addition & 1 deletion src/include/miopen/fusion/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ inline bool WinoCommonIsApplicable(const FusionContext& context, const FusionDes
return false;
if(!conv_problem.IsDirectionForward())
return false;
const auto target = conv_ctx.GetStream().GetTargetProperties();
const auto& target = conv_ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

Expand Down
4 changes: 2 additions & 2 deletions src/include/miopen/handle.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,8 @@ struct MIOPEN_EXPORT Handle : miopenHandle
virtual std::size_t GetMaxMemoryAllocSize() const;
virtual bool CooperativeLaunchSupported() const;

virtual std::string GetDeviceName() const;
const TargetProperties& GetTargetProperties() const;
std::string GetDeviceName() const;
virtual const TargetProperties& GetTargetProperties() const;

private:
std::string GetDeviceNameImpl() const;
Expand Down
6 changes: 4 additions & 2 deletions src/include/miopen/target_properties.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,11 @@ struct Handle;

struct TargetProperties
{
const std::string& Name() const { return name; }
virtual ~TargetProperties() = default;

virtual const std::string& Name() const { return name; }
const std::string& DbId() const { return dbId; }
boost::optional<bool> Xnack() const { return xnack; }
virtual boost::optional<bool> Xnack() const { return xnack; }
boost::optional<bool> Sramecc() const { return sramecc; }
boost::optional<bool> SrameccReported() const { return sramecc_reported; }
static std::size_t GetMaxWaveScratchSize() { return MaxWaveScratchSize; }
Expand Down
2 changes: 1 addition & 1 deletion src/nogpu/handle.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ const TargetProperties& Handle::GetTargetProperties() const
}

std::string Handle::GetDeviceNameImpl() const { return this->impl->device_name; }
std::string Handle::GetDeviceName() const { return this->impl->target_properties.Name(); }
std::string Handle::GetDeviceName() const { return this->GetTargetProperties().Name(); }

std::ostream& Handle::Print(std::ostream& os) const
{
Expand Down
2 changes: 1 addition & 1 deletion src/ocl/handleocl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ std::size_t Handle::GetGlobalMemorySize() const

std::string Handle::GetDeviceNameImpl() const { return this->impl->get_device_name(); }

std::string Handle::GetDeviceName() const { return this->impl->target_properties.Name(); }
std::string Handle::GetDeviceName() const { return this->GetTargetProperties().Name(); }

const TargetProperties& Handle::GetTargetProperties() const
{
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv/conv_MP_bidirectional_winograd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ static bool IsApplicableTransform(const ExecutionContext& ctx, const ProblemDesc
if(!(problem.IsFp32() || problem.IsFp16()))
return false;

const auto target = ctx.GetStream().GetTargetProperties();
const auto& target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv/conv_asm_1x1u.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ bool ConvAsm1x1U::IsApplicable(const ExecutionContext& ctx, const ProblemDescrip
if(problem.IsTensorsCasted())
return false;

const auto target = ctx.GetStream().GetTargetProperties();
const auto& target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv/conv_asm_1x1u_stride2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,7 @@ bool ConvAsm1x1UV2::IsApplicable(const ExecutionContext& ctx,
if(problem.IsTensorsCasted())
return false;

const auto target = ctx.GetStream().GetTargetProperties();
const auto& target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv/conv_asm_3x3u.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ bool ConvAsm3x3U::IsApplicable(const ExecutionContext& ctx, const ProblemDescrip
if(!ctx.rmv.IsV2orV3())
return false;

const auto target = ctx.GetStream().GetTargetProperties();
const auto& target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv/conv_asm_5x10u2v2b1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ bool ConvAsm5x10u2v2b1::IsApplicable(const ExecutionContext& ctx,
if(!ctx.rmv.IsV2orV3())
return false;

const auto target = ctx.GetStream().GetTargetProperties();
const auto& target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv/conv_asm_5x10u2v2f1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ bool ConvAsm5x10u2v2f1::IsApplicable(const ExecutionContext& ctx,
if(!ctx.rmv.IsV2orV3())
return false;

const auto target = ctx.GetStream().GetTargetProperties();
const auto& target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv/conv_asm_7x7c3h224w224k64u2v2p3q3f1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ bool ConvAsm7x7c3h224w224k64u2v2p3q3f1::IsApplicable(const ExecutionContext& ctx
if(problem.IsTensorsCasted())
return false;

const auto target = ctx.GetStream().GetTargetProperties();
const auto& target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv/conv_asm_dir_BwdWrW1x1.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -496,7 +496,7 @@ bool ConvAsmBwdWrW1x1::IsApplicable(const ExecutionContext& ctx,
if(problem.IsTensorsCasted())
return false;

const auto target = ctx.GetStream().GetTargetProperties();
const auto& target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv/conv_asm_dir_BwdWrW3x3.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ bool ConvAsmBwdWrW3x3::IsApplicable(const ExecutionContext& ctx,
if(!ctx.rmv.IsV2orV3())
return false;

const auto target = ctx.GetStream().GetTargetProperties();
const auto& target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ bool ConvAsmImplicitGemmV4R1DynamicBwd::IsApplicable(const ExecutionContext& ctx
if(!problem.IsLayoutDefault())
return false;

const auto target = ctx.GetStream().GetTargetProperties();
const auto& target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv/conv_asm_implicit_gemm_gtc_bwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1017,7 +1017,7 @@ bool ConvAsmImplicitGemmGTCDynamicBwdXdlops::IsApplicable(const ExecutionContext
if(!problem.IsLayoutDefault())
return false;

const auto target = ctx.GetStream().GetTargetProperties();
const auto& target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv/conv_asm_implicit_gemm_gtc_bwd_nhwc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -969,7 +969,7 @@ bool ConvAsmImplicitGemmGTCDynamicBwdXdlopsNHWC::IsApplicable(
if(!ctx.rmv.IsV3())
return false;

const auto target = ctx.GetStream().GetTargetProperties();
const auto& target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false; // NOLINT (readability-simplify-boolean-expr)

Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv/conv_asm_implicit_gemm_gtc_fwd.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1550,7 +1550,7 @@ bool ConvAsmImplicitGemmGTCDynamicFwdXdlops::IsApplicable(const ExecutionContext
}
#endif

const auto target = ctx.GetStream().GetTargetProperties();
const auto& target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv/conv_asm_implicit_gemm_gtc_fwd_nchwc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -585,7 +585,7 @@ bool ConvAsmImplicitGemmGTCDynamicFwdDlopsNCHWC::IsApplicable(
if(!ctx.rmv.IsV3())
return false;

const auto target = ctx.GetStream().GetTargetProperties();
const auto& target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false; // NOLINT (readability-simplify-boolean-expr)

Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv/conv_asm_implicit_gemm_gtc_fwd_nhwc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -912,7 +912,7 @@ bool ConvAsmImplicitGemmGTCDynamicFwdXdlopsNHWC::IsApplicable(
if(!ctx.rmv.IsV3())
return false;

const auto target = ctx.GetStream().GetTargetProperties();
const auto& target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false; // NOLINT (readability-simplify-boolean-expr)

Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv/conv_asm_implicit_gemm_gtc_wrw_nhwc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -918,7 +918,7 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlopsNHWC::IsApplicable(
if(!ctx.rmv.IsV3())
return false;

const auto target = ctx.GetStream().GetTargetProperties();
const auto& target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false; // NOLINT (readability-simplify-boolean-expr)

Expand Down
4 changes: 2 additions & 2 deletions src/solver/conv/conv_asm_implicit_gemm_v4r1_dynamic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -320,7 +320,7 @@ bool ConvAsmImplicitGemmV4R1DynamicFwd::IsApplicable(const ExecutionContext& ctx
if(!problem.IsLayoutDefault())
return false;

const auto target = ctx.GetStream().GetTargetProperties();
const auto& target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;
auto tunables = GetImplicitGemmV4R1DynamicTunables();
Expand Down Expand Up @@ -366,7 +366,7 @@ bool ConvAsmImplicitGemmV4R1DynamicFwd_1x1::IsApplicable(const ExecutionContext&
if(!problem.IsLayoutDefault())
return false;

const auto target = ctx.GetStream().GetTargetProperties();
const auto& target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;
auto tunables = GetImplicitGemmV4R1DynamicTunables();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -861,7 +861,7 @@ bool ConvAsmImplicitGemmGTCDynamicWrwXdlops::IsApplicable(const ExecutionContext
if(!problem.IsLayoutDefault())
return false;

const auto target = ctx.GetStream().GetTargetProperties();
const auto& target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;
bool is_valid;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,7 @@ bool ConvAsmImplicitGemmV4R1DynamicWrw::IsApplicable(const ExecutionContext& ctx
if(!problem.IsLayoutDefault())
return false;

const auto target = ctx.GetStream().GetTargetProperties();
const auto& target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;
std::string kernel_name;
Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv/conv_bin_wino3x3U.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ bool ConvBinWinograd3x3U::IsApplicable(const ExecutionContext& ctx,
if(!(ctx.rmv.IsV2orV3() && ctx.use_asm_kernels))
return false;

const auto target = ctx.GetStream().GetTargetProperties();
const auto& target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv/conv_bin_winoRxS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ bool ConvBinWinogradRxS::IsApplicable(const ExecutionContext& ctx,
if(!ctx.rmv.IsV2orV3())
return false;

const auto target = ctx.GetStream().GetTargetProperties();
const auto& target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv/conv_multipass_wino3x3WrW.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,7 @@ bool ConvWinograd3x3MultipassWrW<WinoDataH, WinoFilterH, WinoDataW, WinoFilterW>
if(!problem.IsLayoutDefault())
return false;

const auto target = ctx.GetStream().GetTargetProperties();
const auto& target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

Expand Down
2 changes: 1 addition & 1 deletion src/solver/conv/conv_winoRxS.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -671,7 +671,7 @@ static bool IsApplicableBase(const ExecutionContext& ctx, const ProblemDescripti
if(!ctx.rmv.IsV3())
return false;

const auto target = ctx.GetStream().GetTargetProperties();
const auto& target = ctx.GetStream().GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;

Expand Down
2 changes: 1 addition & 1 deletion test/gtest/conv_igemm_dynamic_xdlops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ using TestCase = decltype(GetTestCases())::value_type;

bool IsTestSupportedForDevice(const miopen::Handle& handle)
{
const auto target = handle.GetTargetProperties();
const auto& target = handle.GetTargetProperties();
if(target.Xnack() && *target.Xnack())
return false;
using e_mask = enabled<Gpu::Default>;
Expand Down
2 changes: 1 addition & 1 deletion test/gtest/conv_igemm_dynamic_xdlops_half.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ void Run2dDriver(miopenDataType_t prec)

bool IsTestSupportedForDevice(const miopen::Handle& handle)
{
const auto target = handle.GetTargetProperties();
const auto& target = handle.GetTargetProperties();
std::string devName = handle.GetDeviceName();
if(target.Xnack() && *target.Xnack())
return false;
Expand Down
2 changes: 1 addition & 1 deletion test/gtest/conv_igemm_dynamic_xdlops_nhwc_bf16.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ void Run2dDriver(miopenDataType_t prec)

bool IsTestSupportedForDevice(const miopen::Handle& handle)
{
const auto target = handle.GetTargetProperties();
const auto& target = handle.GetTargetProperties();
std::string devName = handle.GetDeviceName();
if(target.Xnack() && *target.Xnack())
return false;
Expand Down
2 changes: 1 addition & 1 deletion test/gtest/conv_igemm_dynamic_xdlops_nhwc_nchw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ void Run2dDriver(miopenDataType_t prec)

bool IsTestSupportedForDevice(const miopen::Handle& handle)
{
const auto target = handle.GetTargetProperties();
const auto& target = handle.GetTargetProperties();
std::string devName = handle.GetDeviceName();
if(target.Xnack() && *target.Xnack())
return false;
Expand Down
25 changes: 23 additions & 2 deletions test/gtest/gtest_common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,30 @@ std::ostream& operator<<(std::ostream& os, const DevDescription& dd)
return os << dd.name << "(" << dd.cu_cnt << ")";
}

MockHandle::MockHandle(const DevDescription& dev_description) : dev_descr{dev_description} {}
MockTargetProperties::MockTargetProperties(const TargetProperties& target_properties,
const DevDescription& dev_description,
bool disable_xnack)
: TargetProperties{target_properties}, name{dev_description.name}, xnack_disabled{disable_xnack}
{
}

const std::string& MockTargetProperties::Name() const { return name; }

boost::optional<bool> MockTargetProperties::Xnack() const
{
return xnack_disabled ? boost::none : TargetProperties::Xnack();
}

std::string MockHandle::GetDeviceName() const { return std::string{dev_descr.name}; }
MockHandle::MockHandle(const DevDescription& dev_description, bool disable_xnack)
: dev_descr{dev_description},
target_properties{Handle::GetTargetProperties(), dev_description, disable_xnack}
{
}

const miopen::TargetProperties& MockHandle::GetTargetProperties() const
{
return target_properties;
}

std::size_t MockHandle::GetMaxComputeUnits() const { return dev_descr.cu_cnt; }

Expand Down
Loading
Loading