diff --git a/examples/thrust/thrust.cu b/examples/thrust/thrust.cu index 1e52206e99..f5fb22277a 100644 --- a/examples/thrust/thrust.cu +++ b/examples/thrust/thrust.cu @@ -321,6 +321,27 @@ auto thrustDeviceAlloc = [](auto alignment, std::size_t size) return p; }; +template +struct IndexToViewIterator +{ + View view; + LLAMA_FN_HOST_ACC_INLINE auto operator()(std::size_t i) + { + return *(view.begin() + i); + } +}; + +template +auto make_view_it(View view, std::size_t i) +{ + auto ci = thrust::counting_iterator{0}; + return thrust::transform_iterator< + IndexToViewIterator, + decltype(ci), + typename View::iterator::reference, + typename View::iterator::value_type>{ci, IndexToViewIterator{std::move(view)}}; +} + template void run(std::ostream& plotFile) { @@ -375,8 +396,16 @@ void run(std::ostream& plotFile) auto view = llama::allocView(mapping, thrustDeviceAlloc); + auto b = make_view_it(view, 0); + auto e = make_view_it(view, N); + // auto b = view.begin(); + // auto e = view.end(); + + auto r = (*b); + r(tag::eventId{}) = 0; + // touch memory once before running benchmarks - thrust::fill(thrust::device, view.begin(), view.end(), 0); + thrust::fill(thrust::device, b, e, 0); syncWithCuda(); //#if THRUST_DEVICE_SYSTEM == THRUST_DEVICE_SYSTEM_CUDA @@ -427,7 +456,7 @@ void run(std::ostream& plotFile) } else { - thrust::tabulate(thrust::device, view.begin(), view.end(), InitOne{}); + thrust::tabulate(thrust::device, b, e, InitOne{}); syncWithCuda(); } tabulateTotal += stopwatch.printAndReset("tabulate", '\t'); @@ -453,7 +482,7 @@ void run(std::ostream& plotFile) { Stopwatch stopwatch; if constexpr(usePSTL) - std::for_each(exec, view.begin(), view.end(), NormalizeVel{}); + std::for_each(exec, b, e, NormalizeVel{}); else { thrust::for_each( @@ -471,10 +500,10 @@ void run(std::ostream& plotFile) thrust::device_vector dst(N); Stopwatch stopwatch; if constexpr(usePSTL) - std::transform(exec, view.begin(), view.end(), dst.begin(), GetMass{}); + std::transform(exec, b, e, dst.begin(), GetMass{}); else { - thrust::transform(thrust::device, view.begin(), view.end(), dst.begin(), GetMass{}); + thrust::transform(thrust::device, b, e, dst.begin(), GetMass{}); syncWithCuda(); } transformTotal += stopwatch.printAndReset("transform", '\t'); @@ -489,8 +518,8 @@ void run(std::ostream& plotFile) if constexpr(usePSTL) std::transform_exclusive_scan( exec, - view.begin(), - view.end(), + b, + e, scan_result.begin(), std::uint32_t{0}, std::plus<>{}, @@ -499,8 +528,8 @@ void run(std::ostream& plotFile) { thrust::transform_exclusive_scan( thrust::device, - view.begin(), - view.end(), + b, + e, scan_result.begin(), Predicate{}, std::uint32_t{0}, @@ -516,16 +545,10 @@ void run(std::ostream& plotFile) { Stopwatch stopwatch; if constexpr(usePSTL) - sink = std::transform_reduce(exec, view.begin(), view.end(), MassType{0}, std::plus<>{}, GetMass{}); + sink = std::transform_reduce(exec, b, e, MassType{0}, std::plus<>{}, GetMass{}); else { - sink = thrust::transform_reduce( - thrust::device, - view.begin(), - view.end(), - GetMass{}, - MassType{0}, - thrust::plus<>{}); + sink = thrust::transform_reduce(thrust::device, b, e, GetMass{}, MassType{0}, thrust::plus<>{}); syncWithCuda(); } transformReduceTotal += stopwatch.printAndReset("transform_reduce", '\t'); @@ -533,12 +556,13 @@ void run(std::ostream& plotFile) { auto dstView = llama::allocView(mapping, thrustDeviceAlloc); + auto db = make_view_it(dstView, 0); Stopwatch stopwatch; if constexpr(usePSTL) - std::copy(exec, view.begin(), view.end(), dstView.begin()); + std::copy(exec, b, e, db); else { - thrust::copy(thrust::device, view.begin(), view.end(), dstView.begin()); + thrust::copy(thrust::device, b, e, db); syncWithCuda(); } copyTotal += stopwatch.printAndReset("copy", '\t'); @@ -548,12 +572,13 @@ void run(std::ostream& plotFile) { auto dstView = llama::allocView(mapping, thrustDeviceAlloc); + auto db = make_view_it(dstView, 0); Stopwatch stopwatch; if constexpr(usePSTL) - std::copy_if(exec, view.begin(), view.end(), dstView.begin(), Predicate{}); + std::copy_if(exec, b, e, db, Predicate{}); else { - thrust::copy_if(thrust::device, view.begin(), view.end(), dstView.begin(), Predicate{}); + thrust::copy_if(thrust::device, b, e, db, Predicate{}); syncWithCuda(); } copyIfTotal += stopwatch.printAndReset("copy_if", '\t'); @@ -564,10 +589,10 @@ void run(std::ostream& plotFile) { Stopwatch stopwatch; if constexpr(usePSTL) - std::remove_if(exec, view.begin(), view.end(), Predicate{}); + std::remove_if(exec, b, e, Predicate{}); else { - thrust::remove_if(thrust::device, view.begin(), view.end(), Predicate{}); + thrust::remove_if(thrust::device, b, e, Predicate{}); syncWithCuda(); } removeIfTotal += stopwatch.printAndReset("remove_if", '\t'); @@ -576,14 +601,14 @@ void run(std::ostream& plotFile) //{ // Stopwatch stopwatch; // if constexpr(usePSTL) - // std::sort(std::execution::par, view.begin(), view.end(), Less{}); + // std::sort(std::execution::par, b, e, Less{}); // else // { - // thrust::sort(thrust::device, view.begin(), view.end(), Less{}); + // thrust::sort(thrust::device, b, e, Less{}); // syncWithCuda(); // } // sortTotal += stopwatch.printAndReset("sort", '\t'); - // if(!thrust::is_sorted(thrust::device, view.begin(), view.end(), Less{})) + // if(!thrust::is_sorted(thrust::device, b, e, Less{})) // std::cerr << "VALIDATION FAILED\n"; //} diff --git a/include/llama/ArrayIndexRange.hpp b/include/llama/ArrayIndexRange.hpp index 4d8c2c7d99..60cbd8d76c 100644 --- a/include/llama/ArrayIndexRange.hpp +++ b/include/llama/ArrayIndexRange.hpp @@ -120,11 +120,11 @@ namespace llama current[0] = static_cast(current[0]) + n; // current is either within bounds or at the end ([last + 1, 0, 0, ..., 0]) - //assert( - // (current[0] < extents[0] - // || (current[0] == extents[0] - // && std::all_of(std::begin(current) + 1, std::end(current), [](auto c) { return c == 0; }))) - // && "Iterator was moved past the end"); + assert( + (current[0] < extents[0] + || (current[0] == extents[0] + && std::all_of(std::begin(current) + 1, std::end(current), [](auto c) { return c == 0; }))) + && "Iterator was moved past the end"); return *this; } diff --git a/include/llama/RecordRef.hpp b/include/llama/RecordRef.hpp index a64c297ec2..207bd8de54 100644 --- a/include/llama/RecordRef.hpp +++ b/include/llama/RecordRef.hpp @@ -381,8 +381,7 @@ namespace llama using ArrayIndex = typename View::Mapping::ArrayIndex; using RecordDim = typename View::Mapping::RecordDim; - // std::conditional_t view; - View view; + std::conditional_t view; public: /// Subtree of the record dimension of View starting at BoundRecordCoord. If BoundRecordCoord is diff --git a/include/llama/View.hpp b/include/llama/View.hpp index dcefec8cb1..69a07c11d3 100644 --- a/include/llama/View.hpp +++ b/include/llama/View.hpp @@ -229,9 +229,9 @@ namespace llama constexpr Iterator() = default; - LLAMA_FN_HOST_ACC_INLINE constexpr Iterator(ArrayIndexIterator arrayIndex, View view) + LLAMA_FN_HOST_ACC_INLINE constexpr Iterator(ArrayIndexIterator arrayIndex, View* view) : arrayIndex(arrayIndex) - , view(std::move(view)) + , view(view) { } @@ -268,7 +268,7 @@ namespace llama LLAMA_FN_HOST_ACC_INLINE constexpr auto operator*() const -> reference { - return const_cast(view)(*arrayIndex); + return (*view)(*arrayIndex); } LLAMA_FN_HOST_ACC_INLINE @@ -363,7 +363,7 @@ namespace llama } ArrayIndexIterator arrayIndex; - View view; + View* view; }; /// Using a mapping, maps the given array index and record coordinate to a memory reference onto the given blobs. @@ -559,25 +559,25 @@ namespace llama LLAMA_FN_HOST_ACC_INLINE auto begin() -> iterator { - return {ArrayIndexRange{mapping().extents()}.begin(), *this}; + return {ArrayIndexRange{mapping().extents()}.begin(), this}; } LLAMA_FN_HOST_ACC_INLINE auto begin() const -> const_iterator { - return {ArrayIndexRange{mapping().extents()}.begin(), *this}; + return {ArrayIndexRange{mapping().extents()}.begin(), this}; } LLAMA_FN_HOST_ACC_INLINE auto end() -> iterator { - return {ArrayIndexRange{mapping().extents()}.end(), *this}; + return {ArrayIndexRange{mapping().extents()}.end(), this}; } LLAMA_FN_HOST_ACC_INLINE auto end() const -> const_iterator { - return {ArrayIndexRange{mapping().extents()}.end(), *this}; + return {ArrayIndexRange{mapping().extents()}.end(), this}; } Array storageBlobs;