Kokkos Core Kernels Package Version of the Day
Loading...
Searching...
No Matches
Kokkos_Half.hpp
1//@HEADER
2// ************************************************************************
3//
4// Kokkos v. 4.0
5// Copyright (2022) National Technology & Engineering
6// Solutions of Sandia, LLC (NTESS).
7//
8// Under the terms of Contract DE-NA0003525 with NTESS,
9// the U.S. Government retains certain rights in this software.
10//
11// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions.
12// See https://kokkos.org/LICENSE for license information.
13// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
14//
15//@HEADER
16
17#ifndef KOKKOS_HALF_HPP_
18#define KOKKOS_HALF_HPP_
19#ifndef KOKKOS_IMPL_PUBLIC_INCLUDE
20#define KOKKOS_IMPL_PUBLIC_INCLUDE
21#define KOKKOS_IMPL_PUBLIC_INCLUDE_NOTDEFINED_HALF
22#endif
23
24#include <type_traits>
25#include <Kokkos_Macros.hpp>
26#include <iosfwd> // istream & ostream for extraction and insertion ops
27#include <string>
28
29#ifdef KOKKOS_IMPL_HALF_TYPE_DEFINED
30
// KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH: A macro to select which
// floating_point_wrapper operator paths should be used. For CUDA, let the
// compiler conditionally select when device ops are used. For SYCL, we have a
// full half type on both host and device.
35#if defined(__CUDA_ARCH__) || defined(KOKKOS_ENABLE_SYCL)
36#define KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
37#endif
38
39/************************* BEGIN forward declarations *************************/
40namespace Kokkos {
41namespace Experimental {
42namespace Impl {
43template <class FloatType>
44class floating_point_wrapper;
45}
46
47// Declare half_t (binary16)
48using half_t = Kokkos::Experimental::Impl::floating_point_wrapper<
49 Kokkos::Impl::half_impl_t ::type>;
50KOKKOS_INLINE_FUNCTION
51half_t cast_to_half(float val);
52KOKKOS_INLINE_FUNCTION
53half_t cast_to_half(bool val);
54KOKKOS_INLINE_FUNCTION
55half_t cast_to_half(double val);
56KOKKOS_INLINE_FUNCTION
57half_t cast_to_half(short val);
58KOKKOS_INLINE_FUNCTION
59half_t cast_to_half(int val);
60KOKKOS_INLINE_FUNCTION
61half_t cast_to_half(long val);
62KOKKOS_INLINE_FUNCTION
63half_t cast_to_half(long long val);
64KOKKOS_INLINE_FUNCTION
65half_t cast_to_half(unsigned short val);
66KOKKOS_INLINE_FUNCTION
67half_t cast_to_half(unsigned int val);
68KOKKOS_INLINE_FUNCTION
69half_t cast_to_half(unsigned long val);
70KOKKOS_INLINE_FUNCTION
71half_t cast_to_half(unsigned long long val);
72KOKKOS_INLINE_FUNCTION
73half_t cast_to_half(half_t);
74
75template <class T>
76KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, float>::value, T>
77 cast_from_half(half_t);
78template <class T>
79KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, bool>::value, T>
80 cast_from_half(half_t);
81template <class T>
82KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, double>::value, T>
83 cast_from_half(half_t);
84template <class T>
85KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, short>::value, T>
86 cast_from_half(half_t);
87template <class T>
88KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, int>::value, T>
89 cast_from_half(half_t);
90template <class T>
91KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, long>::value, T>
92 cast_from_half(half_t);
93template <class T>
94KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, long long>::value, T>
95 cast_from_half(half_t);
96template <class T>
97KOKKOS_INLINE_FUNCTION
98 std::enable_if_t<std::is_same<T, unsigned short>::value, T>
99 cast_from_half(half_t);
100template <class T>
101KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, unsigned int>::value, T>
102 cast_from_half(half_t);
103template <class T>
104KOKKOS_INLINE_FUNCTION
105 std::enable_if_t<std::is_same<T, unsigned long>::value, T>
106 cast_from_half(half_t);
107template <class T>
108KOKKOS_INLINE_FUNCTION
109 std::enable_if_t<std::is_same<T, unsigned long long>::value, T>
110 cast_from_half(half_t);
111
112// declare bhalf_t
113#ifdef KOKKOS_IMPL_BHALF_TYPE_DEFINED
114using bhalf_t = Kokkos::Experimental::Impl::floating_point_wrapper<
115 Kokkos::Impl ::bhalf_impl_t ::type>;
116
117KOKKOS_INLINE_FUNCTION
118bhalf_t cast_to_bhalf(float val);
119KOKKOS_INLINE_FUNCTION
120bhalf_t cast_to_bhalf(bool val);
121KOKKOS_INLINE_FUNCTION
122bhalf_t cast_to_bhalf(double val);
123KOKKOS_INLINE_FUNCTION
124bhalf_t cast_to_bhalf(short val);
125KOKKOS_INLINE_FUNCTION
126bhalf_t cast_to_bhalf(int val);
127KOKKOS_INLINE_FUNCTION
128bhalf_t cast_to_bhalf(long val);
129KOKKOS_INLINE_FUNCTION
130bhalf_t cast_to_bhalf(long long val);
131KOKKOS_INLINE_FUNCTION
132bhalf_t cast_to_bhalf(unsigned short val);
133KOKKOS_INLINE_FUNCTION
134bhalf_t cast_to_bhalf(unsigned int val);
135KOKKOS_INLINE_FUNCTION
136bhalf_t cast_to_bhalf(unsigned long val);
137KOKKOS_INLINE_FUNCTION
138bhalf_t cast_to_bhalf(unsigned long long val);
139KOKKOS_INLINE_FUNCTION
140bhalf_t cast_to_bhalf(bhalf_t val);
141
142template <class T>
143KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, float>::value, T>
144 cast_from_bhalf(bhalf_t);
145template <class T>
146KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, bool>::value, T>
147 cast_from_bhalf(bhalf_t);
148template <class T>
149KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, double>::value, T>
150 cast_from_bhalf(bhalf_t);
151template <class T>
152KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, short>::value, T>
153 cast_from_bhalf(bhalf_t);
154template <class T>
155KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, int>::value, T>
156 cast_from_bhalf(bhalf_t);
157template <class T>
158KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, long>::value, T>
159 cast_from_bhalf(bhalf_t);
160template <class T>
161KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, long long>::value, T>
162 cast_from_bhalf(bhalf_t);
163template <class T>
164KOKKOS_INLINE_FUNCTION
165 std::enable_if_t<std::is_same<T, unsigned short>::value, T>
166 cast_from_bhalf(bhalf_t);
167template <class T>
168KOKKOS_INLINE_FUNCTION std::enable_if_t<std::is_same<T, unsigned int>::value, T>
169 cast_from_bhalf(bhalf_t);
170template <class T>
171KOKKOS_INLINE_FUNCTION
172 std::enable_if_t<std::is_same<T, unsigned long>::value, T>
173 cast_from_bhalf(bhalf_t);
174template <class T>
175KOKKOS_INLINE_FUNCTION
176 std::enable_if_t<std::is_same<T, unsigned long long>::value, T>
177 cast_from_bhalf(bhalf_t);
178#endif // KOKKOS_IMPL_BHALF_TYPE_DEFINED
179
180template <class T>
181static KOKKOS_INLINE_FUNCTION Kokkos::Experimental::half_t cast_to_wrapper(
182 T x, const volatile Kokkos::Impl::half_impl_t::type&);
183
184#ifdef KOKKOS_IMPL_BHALF_TYPE_DEFINED
185template <class T>
186static KOKKOS_INLINE_FUNCTION Kokkos::Experimental::bhalf_t cast_to_wrapper(
187 T x, const volatile Kokkos::Impl::bhalf_impl_t::type&);
188#endif // KOKKOS_IMPL_BHALF_TYPE_DEFINED
189
190template <class T>
191static KOKKOS_INLINE_FUNCTION T
192cast_from_wrapper(const Kokkos::Experimental::half_t& x);
193
194#ifdef KOKKOS_IMPL_BHALF_TYPE_DEFINED
195template <class T>
196static KOKKOS_INLINE_FUNCTION T
197cast_from_wrapper(const Kokkos::Experimental::bhalf_t& x);
198#endif // KOKKOS_IMPL_BHALF_TYPE_DEFINED
199/************************** END forward declarations **************************/
200
201namespace Impl {
202template <class FloatType>
203class alignas(FloatType) floating_point_wrapper {
204 public:
205 using impl_type = FloatType;
206
207 private:
208 impl_type val;
209 using fixed_width_integer_type = std::conditional_t<
210 sizeof(impl_type) == 2, uint16_t,
211 std::conditional_t<
212 sizeof(impl_type) == 4, uint32_t,
213 std::conditional_t<sizeof(impl_type) == 8, uint64_t, void>>>;
214 static_assert(!std::is_void<fixed_width_integer_type>::value,
215 "Invalid impl_type");
216
217 public:
  // Default constructor: initializes the stored value from the literal 0.0F.
  // In-class initialization and defaulted default constructors not used
  // since Cuda supports half precision initialization via the below constructor
  KOKKOS_FUNCTION
  floating_point_wrapper() : val(0.0F) {}

// Copy constructors
// Getting "C2580: multiple versions of a defaulted special
// member function are not allowed" with VS 16.11.3 and CUDA 11.4.2
// so that configuration gets a hand-written member-wise copy; every other
// configuration uses the defaulted, noexcept copy constructor.
#if defined(_WIN32) && defined(KOKKOS_ENABLE_CUDA)
  KOKKOS_FUNCTION
  floating_point_wrapper(const floating_point_wrapper& rhs) : val(rhs.val) {}
#else
  KOKKOS_DEFAULTED_FUNCTION
  floating_point_wrapper(const floating_point_wrapper&) noexcept = default;
#endif
233
234 KOKKOS_INLINE_FUNCTION
235 floating_point_wrapper(const volatile floating_point_wrapper& rhs) {
236#if defined(KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH) && !defined(KOKKOS_ENABLE_SYCL)
237 val = rhs.val;
238#else
239 const volatile fixed_width_integer_type* rv_ptr =
240 reinterpret_cast<const volatile fixed_width_integer_type*>(&rhs.val);
241 const fixed_width_integer_type rv_val = *rv_ptr;
242 val = reinterpret_cast<const impl_type&>(rv_val);
243#endif // KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
244 }
245
  // Explicit conversion operators.
  // Don't support implicit conversion back to impl_type.
  // impl_type is a storage only type on host.
  // The impl_type conversion returns the stored value as-is; every other
  // conversion forwards to cast_from_wrapper<T>(*this).
  KOKKOS_FUNCTION
  explicit operator impl_type() const { return val; }
  KOKKOS_FUNCTION
  explicit operator float() const { return cast_from_wrapper<float>(*this); }
  KOKKOS_FUNCTION
  explicit operator bool() const { return cast_from_wrapper<bool>(*this); }
  KOKKOS_FUNCTION
  explicit operator double() const { return cast_from_wrapper<double>(*this); }
  KOKKOS_FUNCTION
  explicit operator short() const { return cast_from_wrapper<short>(*this); }
  KOKKOS_FUNCTION
  explicit operator int() const { return cast_from_wrapper<int>(*this); }
  KOKKOS_FUNCTION
  explicit operator long() const { return cast_from_wrapper<long>(*this); }
  KOKKOS_FUNCTION
  explicit operator long long() const {
    return cast_from_wrapper<long long>(*this);
  }
  KOKKOS_FUNCTION
  explicit operator unsigned short() const {
    return cast_from_wrapper<unsigned short>(*this);
  }
  KOKKOS_FUNCTION
  explicit operator unsigned int() const {
    return cast_from_wrapper<unsigned int>(*this);
  }
  KOKKOS_FUNCTION
  explicit operator unsigned long() const {
    return cast_from_wrapper<unsigned long>(*this);
  }
  KOKKOS_FUNCTION
  explicit operator unsigned long long() const {
    return cast_from_wrapper<unsigned long long>(*this);
  }
282
  // Converting constructors.
  // All of these are implicit except the bool overload, which is explicit.
  // The impl_type overload stores its argument unchanged. Every other overload
  // converts through cast_to_wrapper(rhs, val) and copies the result's stored
  // value; cast_to_wrapper's second parameter is unnamed, so `val` is passed
  // here only to select the half_t/bhalf_t overload.
  KOKKOS_FUNCTION
  constexpr floating_point_wrapper(impl_type rhs) : val(rhs) {}
  KOKKOS_FUNCTION
  floating_point_wrapper(float rhs) : val(cast_to_wrapper(rhs, val).val) {}
  KOKKOS_FUNCTION
  floating_point_wrapper(double rhs) : val(cast_to_wrapper(rhs, val).val) {}
  KOKKOS_FUNCTION
  explicit floating_point_wrapper(bool rhs)
      : val(cast_to_wrapper(rhs, val).val) {}
  KOKKOS_FUNCTION
  floating_point_wrapper(short rhs) : val(cast_to_wrapper(rhs, val).val) {}
  KOKKOS_FUNCTION
  floating_point_wrapper(int rhs) : val(cast_to_wrapper(rhs, val).val) {}
  KOKKOS_FUNCTION
  floating_point_wrapper(long rhs) : val(cast_to_wrapper(rhs, val).val) {}
  KOKKOS_FUNCTION
  floating_point_wrapper(long long rhs) : val(cast_to_wrapper(rhs, val).val) {}
  KOKKOS_FUNCTION
  floating_point_wrapper(unsigned short rhs)
      : val(cast_to_wrapper(rhs, val).val) {}
  KOKKOS_FUNCTION
  floating_point_wrapper(unsigned int rhs)
      : val(cast_to_wrapper(rhs, val).val) {}
  KOKKOS_FUNCTION
  floating_point_wrapper(unsigned long rhs)
      : val(cast_to_wrapper(rhs, val).val) {}
  KOKKOS_FUNCTION
  floating_point_wrapper(unsigned long long rhs)
      : val(cast_to_wrapper(rhs, val).val) {}
326
327 // Unary operators
328 KOKKOS_FUNCTION
329 floating_point_wrapper operator+() const {
330 floating_point_wrapper tmp = *this;
331#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
332 tmp.val = +tmp.val;
333#else
334 tmp.val = cast_to_wrapper(+cast_from_wrapper<float>(tmp), val).val;
335#endif
336 return tmp;
337 }
338
339 KOKKOS_FUNCTION
340 floating_point_wrapper operator-() const {
341 floating_point_wrapper tmp = *this;
342#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
343 tmp.val = -tmp.val;
344#else
345 tmp.val = cast_to_wrapper(-cast_from_wrapper<float>(tmp), val).val;
346#endif
347 return tmp;
348 }
349
350 // Prefix operators
351 KOKKOS_FUNCTION
352 floating_point_wrapper& operator++() {
353#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
354 val = val + impl_type(1.0F); // cuda has no operator++ for __nv_bfloat
355#else
356 float tmp = cast_from_wrapper<float>(*this);
357 ++tmp;
358 val = cast_to_wrapper(tmp, val).val;
359#endif
360 return *this;
361 }
362
363 KOKKOS_FUNCTION
364 floating_point_wrapper& operator--() {
365#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
366 val = val - impl_type(1.0F); // cuda has no operator-- for __nv_bfloat
367#else
368 float tmp = cast_from_wrapper<float>(*this);
369 --tmp;
370 val = cast_to_wrapper(tmp, val).val;
371#endif
372 return *this;
373 }
374
375 // Postfix operators
376 KOKKOS_FUNCTION
377 floating_point_wrapper operator++(int) {
378 floating_point_wrapper tmp = *this;
379 operator++();
380 return tmp;
381 }
382
383 KOKKOS_FUNCTION
384 floating_point_wrapper operator--(int) {
385 floating_point_wrapper tmp = *this;
386 operator--();
387 return tmp;
388 }
389
390 // Binary operators
391 KOKKOS_FUNCTION
392 floating_point_wrapper& operator=(impl_type rhs) {
393 val = rhs;
394 return *this;
395 }
396
397 template <class T>
398 KOKKOS_FUNCTION floating_point_wrapper& operator=(T rhs) {
399 val = cast_to_wrapper(rhs, val).val;
400 return *this;
401 }
402
403 template <class T>
404 KOKKOS_FUNCTION void operator=(T rhs) volatile {
405 impl_type new_val = cast_to_wrapper(rhs, val).val;
406 volatile fixed_width_integer_type* val_ptr =
407 reinterpret_cast<volatile fixed_width_integer_type*>(
408 const_cast<impl_type*>(&val));
409 *val_ptr = reinterpret_cast<fixed_width_integer_type&>(new_val);
410 }
411
412 // Compound operators
413 KOKKOS_FUNCTION
414 floating_point_wrapper& operator+=(floating_point_wrapper rhs) {
415#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
416 val = val + rhs.val; // cuda has no operator+= for __nv_bfloat
417#else
418 val = cast_to_wrapper(
419 cast_from_wrapper<float>(*this) + cast_from_wrapper<float>(rhs),
420 val)
421 .val;
422#endif
423 return *this;
424 }
425
426 KOKKOS_FUNCTION
427 void operator+=(const volatile floating_point_wrapper& rhs) volatile {
428 floating_point_wrapper tmp_rhs = rhs;
429 floating_point_wrapper tmp_lhs = *this;
430
431 tmp_lhs += tmp_rhs;
432 *this = tmp_lhs;
433 }
434
435 // Compound operators: upcast overloads for +=
436 template <class T>
437 KOKKOS_FUNCTION friend std::enable_if_t<
438 std::is_same<T, float>::value || std::is_same<T, double>::value, T>
439 operator+=(T& lhs, floating_point_wrapper rhs) {
440 lhs += static_cast<T>(rhs);
441 return lhs;
442 }
443
444 KOKKOS_FUNCTION
445 floating_point_wrapper& operator+=(float rhs) {
446 float result = static_cast<float>(val) + rhs;
447 val = static_cast<impl_type>(result);
448 return *this;
449 }
450
451 KOKKOS_FUNCTION
452 floating_point_wrapper& operator+=(double rhs) {
453 double result = static_cast<double>(val) + rhs;
454 val = static_cast<impl_type>(result);
455 return *this;
456 }
457
458 KOKKOS_FUNCTION
459 floating_point_wrapper& operator-=(floating_point_wrapper rhs) {
460#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
461 val = val - rhs.val; // cuda has no operator-= for __nv_bfloat
462#else
463 val = cast_to_wrapper(
464 cast_from_wrapper<float>(*this) - cast_from_wrapper<float>(rhs),
465 val)
466 .val;
467#endif
468 return *this;
469 }
470
471 KOKKOS_FUNCTION
472 void operator-=(const volatile floating_point_wrapper& rhs) volatile {
473 floating_point_wrapper tmp_rhs = rhs;
474 floating_point_wrapper tmp_lhs = *this;
475
476 tmp_lhs -= tmp_rhs;
477 *this = tmp_lhs;
478 }
479
480 // Compund operators: upcast overloads for -=
481 template <class T>
482 KOKKOS_FUNCTION friend std::enable_if_t<
483 std::is_same<T, float>::value || std::is_same<T, double>::value, T>
484 operator-=(T& lhs, floating_point_wrapper rhs) {
485 lhs -= static_cast<T>(rhs);
486 return lhs;
487 }
488
489 KOKKOS_FUNCTION
490 floating_point_wrapper& operator-=(float rhs) {
491 float result = static_cast<float>(val) - rhs;
492 val = static_cast<impl_type>(result);
493 return *this;
494 }
495
496 KOKKOS_FUNCTION
497 floating_point_wrapper& operator-=(double rhs) {
498 double result = static_cast<double>(val) - rhs;
499 val = static_cast<impl_type>(result);
500 return *this;
501 }
502
503 KOKKOS_FUNCTION
504 floating_point_wrapper& operator*=(floating_point_wrapper rhs) {
505#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
506 val = val * rhs.val; // cuda has no operator*= for __nv_bfloat
507#else
508 val = cast_to_wrapper(
509 cast_from_wrapper<float>(*this) * cast_from_wrapper<float>(rhs),
510 val)
511 .val;
512#endif
513 return *this;
514 }
515
516 KOKKOS_FUNCTION
517 void operator*=(const volatile floating_point_wrapper& rhs) volatile {
518 floating_point_wrapper tmp_rhs = rhs;
519 floating_point_wrapper tmp_lhs = *this;
520
521 tmp_lhs *= tmp_rhs;
522 *this = tmp_lhs;
523 }
524
525 // Compund operators: upcast overloads for *=
526 template <class T>
527 KOKKOS_FUNCTION friend std::enable_if_t<
528 std::is_same<T, float>::value || std::is_same<T, double>::value, T>
529 operator*=(T& lhs, floating_point_wrapper rhs) {
530 lhs *= static_cast<T>(rhs);
531 return lhs;
532 }
533
534 KOKKOS_FUNCTION
535 floating_point_wrapper& operator*=(float rhs) {
536 float result = static_cast<float>(val) * rhs;
537 val = static_cast<impl_type>(result);
538 return *this;
539 }
540
541 KOKKOS_FUNCTION
542 floating_point_wrapper& operator*=(double rhs) {
543 double result = static_cast<double>(val) * rhs;
544 val = static_cast<impl_type>(result);
545 return *this;
546 }
547
548 KOKKOS_FUNCTION
549 floating_point_wrapper& operator/=(floating_point_wrapper rhs) {
550#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
551 val = val / rhs.val; // cuda has no operator/= for __nv_bfloat
552#else
553 val = cast_to_wrapper(
554 cast_from_wrapper<float>(*this) / cast_from_wrapper<float>(rhs),
555 val)
556 .val;
557#endif
558 return *this;
559 }
560
561 KOKKOS_FUNCTION
562 void operator/=(const volatile floating_point_wrapper& rhs) volatile {
563 floating_point_wrapper tmp_rhs = rhs;
564 floating_point_wrapper tmp_lhs = *this;
565
566 tmp_lhs /= tmp_rhs;
567 *this = tmp_lhs;
568 }
569
570 // Compund operators: upcast overloads for /=
571 template <class T>
572 KOKKOS_FUNCTION friend std::enable_if_t<
573 std::is_same<T, float>::value || std::is_same<T, double>::value, T>
574 operator/=(T& lhs, floating_point_wrapper rhs) {
575 lhs /= static_cast<T>(rhs);
576 return lhs;
577 }
578
579 KOKKOS_FUNCTION
580 floating_point_wrapper& operator/=(float rhs) {
581 float result = static_cast<float>(val) / rhs;
582 val = static_cast<impl_type>(result);
583 return *this;
584 }
585
586 KOKKOS_FUNCTION
587 floating_point_wrapper& operator/=(double rhs) {
588 double result = static_cast<double>(val) / rhs;
589 val = static_cast<impl_type>(result);
590 return *this;
591 }
592
593 // Binary Arithmetic
594 KOKKOS_FUNCTION
595 friend floating_point_wrapper operator+(floating_point_wrapper lhs,
596 floating_point_wrapper rhs) {
597#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
598 lhs += rhs;
599#else
600 lhs.val = cast_to_wrapper(
601 cast_from_wrapper<float>(lhs) + cast_from_wrapper<float>(rhs),
602 lhs.val)
603 .val;
604#endif
605 return lhs;
606 }
607
608 // Binary Arithmetic upcast operators for +
609 template <class T>
610 KOKKOS_FUNCTION friend std::enable_if_t<
611 std::is_same<T, float>::value || std::is_same<T, double>::value, T>
612 operator+(floating_point_wrapper lhs, T rhs) {
613 return T(lhs) + rhs;
614 }
615
616 template <class T>
617 KOKKOS_FUNCTION friend std::enable_if_t<
618 std::is_same<T, float>::value || std::is_same<T, double>::value, T>
619 operator+(T lhs, floating_point_wrapper rhs) {
620 return lhs + T(rhs);
621 }
622
623 KOKKOS_FUNCTION
624 friend floating_point_wrapper operator-(floating_point_wrapper lhs,
625 floating_point_wrapper rhs) {
626#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
627 lhs -= rhs;
628#else
629 lhs.val = cast_to_wrapper(
630 cast_from_wrapper<float>(lhs) - cast_from_wrapper<float>(rhs),
631 lhs.val)
632 .val;
633#endif
634 return lhs;
635 }
636
637 // Binary Arithmetic upcast operators for -
638 template <class T>
639 KOKKOS_FUNCTION friend std::enable_if_t<
640 std::is_same<T, float>::value || std::is_same<T, double>::value, T>
641 operator-(floating_point_wrapper lhs, T rhs) {
642 return T(lhs) - rhs;
643 }
644
645 template <class T>
646 KOKKOS_FUNCTION friend std::enable_if_t<
647 std::is_same<T, float>::value || std::is_same<T, double>::value, T>
648 operator-(T lhs, floating_point_wrapper rhs) {
649 return lhs - T(rhs);
650 }
651
652 KOKKOS_FUNCTION
653 friend floating_point_wrapper operator*(floating_point_wrapper lhs,
654 floating_point_wrapper rhs) {
655#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
656 lhs *= rhs;
657#else
658 lhs.val = cast_to_wrapper(
659 cast_from_wrapper<float>(lhs) * cast_from_wrapper<float>(rhs),
660 lhs.val)
661 .val;
662#endif
663 return lhs;
664 }
665
666 // Binary Arithmetic upcast operators for *
667 template <class T>
668 KOKKOS_FUNCTION friend std::enable_if_t<
669 std::is_same<T, float>::value || std::is_same<T, double>::value, T>
670 operator*(floating_point_wrapper lhs, T rhs) {
671 return T(lhs) * rhs;
672 }
673
674 template <class T>
675 KOKKOS_FUNCTION friend std::enable_if_t<
676 std::is_same<T, float>::value || std::is_same<T, double>::value, T>
677 operator*(T lhs, floating_point_wrapper rhs) {
678 return lhs * T(rhs);
679 }
680
681 KOKKOS_FUNCTION
682 friend floating_point_wrapper operator/(floating_point_wrapper lhs,
683 floating_point_wrapper rhs) {
684#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
685 lhs /= rhs;
686#else
687 lhs.val = cast_to_wrapper(
688 cast_from_wrapper<float>(lhs) / cast_from_wrapper<float>(rhs),
689 lhs.val)
690 .val;
691#endif
692 return lhs;
693 }
694
695 // Binary Arithmetic upcast operators for /
696 template <class T>
697 KOKKOS_FUNCTION friend std::enable_if_t<
698 std::is_same<T, float>::value || std::is_same<T, double>::value, T>
699 operator/(floating_point_wrapper lhs, T rhs) {
700 return T(lhs) / rhs;
701 }
702
703 template <class T>
704 KOKKOS_FUNCTION friend std::enable_if_t<
705 std::is_same<T, float>::value || std::is_same<T, double>::value, T>
706 operator/(T lhs, floating_point_wrapper rhs) {
707 return lhs / T(rhs);
708 }
709
710 // Logical operators
711 KOKKOS_FUNCTION
712 bool operator!() const {
713#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
714 return static_cast<bool>(!val);
715#else
716 return !cast_from_wrapper<float>(*this);
717#endif
718 }
719
720 // NOTE: Loses short-circuit evaluation
721 KOKKOS_FUNCTION
722 bool operator&&(floating_point_wrapper rhs) const {
723#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
724 return static_cast<bool>(val && rhs.val);
725#else
726 return cast_from_wrapper<float>(*this) && cast_from_wrapper<float>(rhs);
727#endif
728 }
729
730 // NOTE: Loses short-circuit evaluation
731 KOKKOS_FUNCTION
732 bool operator||(floating_point_wrapper rhs) const {
733#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
734 return static_cast<bool>(val || rhs.val);
735#else
736 return cast_from_wrapper<float>(*this) || cast_from_wrapper<float>(rhs);
737#endif
738 }
739
740 // Comparison operators
741 KOKKOS_FUNCTION
742 bool operator==(floating_point_wrapper rhs) const {
743#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
744 return static_cast<bool>(val == rhs.val);
745#else
746 return cast_from_wrapper<float>(*this) == cast_from_wrapper<float>(rhs);
747#endif
748 }
749
750 KOKKOS_FUNCTION
751 bool operator!=(floating_point_wrapper rhs) const {
752#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
753 return static_cast<bool>(val != rhs.val);
754#else
755 return cast_from_wrapper<float>(*this) != cast_from_wrapper<float>(rhs);
756#endif
757 }
758
759 KOKKOS_FUNCTION
760 bool operator<(floating_point_wrapper rhs) const {
761#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
762 return static_cast<bool>(val < rhs.val);
763#else
764 return cast_from_wrapper<float>(*this) < cast_from_wrapper<float>(rhs);
765#endif
766 }
767
768 KOKKOS_FUNCTION
769 bool operator>(floating_point_wrapper rhs) const {
770#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
771 return static_cast<bool>(val > rhs.val);
772#else
773 return cast_from_wrapper<float>(*this) > cast_from_wrapper<float>(rhs);
774#endif
775 }
776
777 KOKKOS_FUNCTION
778 bool operator<=(floating_point_wrapper rhs) const {
779#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
780 return static_cast<bool>(val <= rhs.val);
781#else
782 return cast_from_wrapper<float>(*this) <= cast_from_wrapper<float>(rhs);
783#endif
784 }
785
786 KOKKOS_FUNCTION
787 bool operator>=(floating_point_wrapper rhs) const {
788#ifdef KOKKOS_HALF_IS_FULL_TYPE_ON_ARCH
789 return static_cast<bool>(val >= rhs.val);
790#else
791 return cast_from_wrapper<float>(*this) >= cast_from_wrapper<float>(rhs);
792#endif
793 }
794
795 KOKKOS_FUNCTION
796 friend bool operator==(const volatile floating_point_wrapper& lhs,
797 const volatile floating_point_wrapper& rhs) {
798 floating_point_wrapper tmp_lhs = lhs, tmp_rhs = rhs;
799 return tmp_lhs == tmp_rhs;
800 }
801
802 KOKKOS_FUNCTION
803 friend bool operator!=(const volatile floating_point_wrapper& lhs,
804 const volatile floating_point_wrapper& rhs) {
805 floating_point_wrapper tmp_lhs = lhs, tmp_rhs = rhs;
806 return tmp_lhs != tmp_rhs;
807 }
808
809 KOKKOS_FUNCTION
810 friend bool operator<(const volatile floating_point_wrapper& lhs,
811 const volatile floating_point_wrapper& rhs) {
812 floating_point_wrapper tmp_lhs = lhs, tmp_rhs = rhs;
813 return tmp_lhs < tmp_rhs;
814 }
815
816 KOKKOS_FUNCTION
817 friend bool operator>(const volatile floating_point_wrapper& lhs,
818 const volatile floating_point_wrapper& rhs) {
819 floating_point_wrapper tmp_lhs = lhs, tmp_rhs = rhs;
820 return tmp_lhs > tmp_rhs;
821 }
822
823 KOKKOS_FUNCTION
824 friend bool operator<=(const volatile floating_point_wrapper& lhs,
825 const volatile floating_point_wrapper& rhs) {
826 floating_point_wrapper tmp_lhs = lhs, tmp_rhs = rhs;
827 return tmp_lhs <= tmp_rhs;
828 }
829
830 KOKKOS_FUNCTION
831 friend bool operator>=(const volatile floating_point_wrapper& lhs,
832 const volatile floating_point_wrapper& rhs) {
833 floating_point_wrapper tmp_lhs = lhs, tmp_rhs = rhs;
834 return tmp_lhs >= tmp_rhs;
835 }
836
837 // Insertion and extraction operators
838 friend std::ostream& operator<<(std::ostream& os,
839 const floating_point_wrapper& x) {
840 const std::string out = std::to_string(static_cast<double>(x));
841 os << out;
842 return os;
843 }
844
845 friend std::istream& operator>>(std::istream& is, floating_point_wrapper& x) {
846 std::string in;
847 is >> in;
848 x = std::stod(in);
849 return is;
850 }
851};
852} // namespace Impl
853
// Define the wrapper overloads now that floating_point_wrapper is defined.
// cast_to_wrapper's unnamed second parameter is never read; its type alone
// selects between the half_t and bhalf_t overloads.

// Converts x to half_t by forwarding to cast_to_half.
template <class T>
static KOKKOS_INLINE_FUNCTION Kokkos::Experimental::half_t cast_to_wrapper(
    T x, const volatile Kokkos::Impl::half_impl_t::type&) {
  return Kokkos::Experimental::cast_to_half(x);
}

#ifdef KOKKOS_IMPL_BHALF_TYPE_DEFINED
// Converts x to bhalf_t by forwarding to cast_to_bhalf.
template <class T>
static KOKKOS_INLINE_FUNCTION Kokkos::Experimental::bhalf_t cast_to_wrapper(
    T x, const volatile Kokkos::Impl::bhalf_impl_t::type&) {
  return Kokkos::Experimental::cast_to_bhalf(x);
}
#endif  // KOKKOS_IMPL_BHALF_TYPE_DEFINED

// Converts a half_t to T by forwarding to cast_from_half<T>.
template <class T>
static KOKKOS_INLINE_FUNCTION T
cast_from_wrapper(const Kokkos::Experimental::half_t& x) {
  return Kokkos::Experimental::cast_from_half<T>(x);
}

#ifdef KOKKOS_IMPL_BHALF_TYPE_DEFINED
// Converts a bhalf_t to T by forwarding to cast_from_bhalf<T>.
template <class T>
static KOKKOS_INLINE_FUNCTION T
cast_from_wrapper(const Kokkos::Experimental::bhalf_t& x) {
  return Kokkos::Experimental::cast_from_bhalf<T>(x);
}
#endif  // KOKKOS_IMPL_BHALF_TYPE_DEFINED
881#endif // KOKKOS_IMPL_BHALF_TYPE_DEFINED
882
883} // namespace Experimental
884} // namespace Kokkos
885
886#endif // KOKKOS_IMPL_HALF_TYPE_DEFINED
887
888// If none of the above actually did anything and defined a half precision type
889// define a fallback implementation here using float
890#ifndef KOKKOS_IMPL_HALF_TYPE_DEFINED
891#define KOKKOS_IMPL_HALF_TYPE_DEFINED
892#define KOKKOS_HALF_T_IS_FLOAT true
namespace Kokkos {
namespace Impl {
// Fallback storage selector: half precision is represented by plain float.
struct half_impl_t {
  using type = float;
};
}  // namespace Impl
namespace Experimental {

// In this fallback half_t is an alias for float, not a wrapper class.
using half_t = Kokkos::Impl::half_impl_t::type;

// cast_to_half: one overload per supported source type; each converts its
// argument to half_t.

// Floating-point and bool sources
KOKKOS_INLINE_FUNCTION
half_t cast_to_half(float val) { return static_cast<half_t>(val); }
KOKKOS_INLINE_FUNCTION
half_t cast_to_half(double val) { return static_cast<half_t>(val); }
KOKKOS_INLINE_FUNCTION
half_t cast_to_half(bool val) { return static_cast<half_t>(val); }

// Signed integer sources
KOKKOS_INLINE_FUNCTION
half_t cast_to_half(short val) { return static_cast<half_t>(val); }
KOKKOS_INLINE_FUNCTION
half_t cast_to_half(int val) { return static_cast<half_t>(val); }
KOKKOS_INLINE_FUNCTION
half_t cast_to_half(long val) { return static_cast<half_t>(val); }
KOKKOS_INLINE_FUNCTION
half_t cast_to_half(long long val) { return static_cast<half_t>(val); }

// Unsigned integer sources
KOKKOS_INLINE_FUNCTION
half_t cast_to_half(unsigned short val) { return static_cast<half_t>(val); }
KOKKOS_INLINE_FUNCTION
half_t cast_to_half(unsigned int val) { return static_cast<half_t>(val); }
KOKKOS_INLINE_FUNCTION
half_t cast_to_half(unsigned long val) { return static_cast<half_t>(val); }
KOKKOS_INLINE_FUNCTION
half_t cast_to_half(unsigned long long val) {
  return static_cast<half_t>(val);
}

// cast_from_half<T>: convert a half_t to T.
// Only the explicitly listed target types participate in overload
// resolution, mirroring the explicit cast_to_half list above (so, for
// example, char is not accepted).
template <class T>
KOKKOS_INLINE_FUNCTION std::enable_if_t<
    std::is_same<T, float>::value || std::is_same<T, bool>::value ||
        std::is_same<T, double>::value || std::is_same<T, short>::value ||
        std::is_same<T, unsigned short>::value || std::is_same<T, int>::value ||
        std::is_same<T, unsigned int>::value || std::is_same<T, long>::value ||
        std::is_same<T, unsigned long>::value ||
        std::is_same<T, long long>::value ||
        std::is_same<T, unsigned long long>::value,
    T>
cast_from_half(half_t val) {
  return static_cast<T>(val);
}

}  // namespace Experimental
}  // namespace Kokkos
946
947#else
948#define KOKKOS_HALF_T_IS_FLOAT false
949#endif // KOKKOS_IMPL_HALF_TYPE_DEFINED
950
951#ifndef KOKKOS_IMPL_BHALF_TYPE_DEFINED
952#define KOKKOS_IMPL_BHALF_TYPE_DEFINED
953#define KOKKOS_BHALF_T_IS_FLOAT true
namespace Kokkos {
namespace Impl {
// Fallback storage selector: bhalf precision is represented by plain float.
struct bhalf_impl_t {
  using type = float;
};
}  // namespace Impl

namespace Experimental {

// In this fallback bhalf_t is an alias for float, not a wrapper class.
using bhalf_t = Kokkos::Impl::bhalf_impl_t::type;

// cast_to_bhalf: one overload per supported source type; each converts its
// argument to bhalf_t.

// Floating-point and bool sources
KOKKOS_INLINE_FUNCTION
bhalf_t cast_to_bhalf(float val) { return static_cast<bhalf_t>(val); }
KOKKOS_INLINE_FUNCTION
bhalf_t cast_to_bhalf(double val) { return static_cast<bhalf_t>(val); }
KOKKOS_INLINE_FUNCTION
bhalf_t cast_to_bhalf(bool val) { return static_cast<bhalf_t>(val); }

// Signed integer sources
KOKKOS_INLINE_FUNCTION
bhalf_t cast_to_bhalf(short val) { return static_cast<bhalf_t>(val); }
KOKKOS_INLINE_FUNCTION
bhalf_t cast_to_bhalf(int val) { return static_cast<bhalf_t>(val); }
KOKKOS_INLINE_FUNCTION
bhalf_t cast_to_bhalf(long val) { return static_cast<bhalf_t>(val); }
KOKKOS_INLINE_FUNCTION
bhalf_t cast_to_bhalf(long long val) { return static_cast<bhalf_t>(val); }

// Unsigned integer sources
KOKKOS_INLINE_FUNCTION
bhalf_t cast_to_bhalf(unsigned short val) { return static_cast<bhalf_t>(val); }
KOKKOS_INLINE_FUNCTION
bhalf_t cast_to_bhalf(unsigned int val) { return static_cast<bhalf_t>(val); }
KOKKOS_INLINE_FUNCTION
bhalf_t cast_to_bhalf(unsigned long val) { return static_cast<bhalf_t>(val); }
KOKKOS_INLINE_FUNCTION
bhalf_t cast_to_bhalf(unsigned long long val) {
  return static_cast<bhalf_t>(val);
}

// cast_from_bhalf<T>: convert a bhalf_t to T.
// Only the explicitly listed target types participate in overload
// resolution.
template <class T>
KOKKOS_INLINE_FUNCTION std::enable_if_t<
    std::is_same<T, float>::value || std::is_same<T, bool>::value ||
        std::is_same<T, double>::value || std::is_same<T, short>::value ||
        std::is_same<T, unsigned short>::value || std::is_same<T, int>::value ||
        std::is_same<T, unsigned int>::value || std::is_same<T, long>::value ||
        std::is_same<T, unsigned long>::value ||
        std::is_same<T, long long>::value ||
        std::is_same<T, unsigned long long>::value,
    T>
cast_from_bhalf(bhalf_t val) {
  return static_cast<T>(val);
}
}  // namespace Experimental
}  // namespace Kokkos
1005#else
1006#define KOKKOS_BHALF_T_IS_FLOAT false
1007#endif // KOKKOS_IMPL_BHALF_TYPE_DEFINED
1008#ifdef KOKKOS_IMPL_PUBLIC_INCLUDE_NOTDEFINED_HALF
1009#undef KOKKOS_IMPL_PUBLIC_INCLUDE
1010#undef KOKKOS_IMPL_PUBLIC_INCLUDE_NOTDEFINED_HALF
1011#endif
1012#endif // KOKKOS_HALF_HPP_
KOKKOS_FORCEINLINE_FUNCTION constexpr bool operator<(const pair< T1, T2 > &lhs, const pair< T1, T2 > &rhs)
Less-than operator for Kokkos::pair.
KOKKOS_FORCEINLINE_FUNCTION constexpr bool operator<=(const pair< T1, T2 > &lhs, const pair< T1, T2 > &rhs)
Less-than-or-equal-to operator for Kokkos::pair.
KOKKOS_FORCEINLINE_FUNCTION constexpr bool operator>=(const pair< T1, T2 > &lhs, const pair< T1, T2 > &rhs)
Greater-than-or-equal-to operator for Kokkos::pair.
KOKKOS_FORCEINLINE_FUNCTION constexpr bool operator>(const pair< T1, T2 > &lhs, const pair< T1, T2 > &rhs)
Greater-than operator for Kokkos::pair.