trying to do rice partitioning on gpu

This commit is contained in:
chudov
2009-10-04 17:28:52 +00:00
parent 609f160457
commit d96ba46ad3
3 changed files with 222 additions and 166 deletions

View File

@@ -1192,11 +1192,12 @@ namespace CUETools.Codecs.FlaCuda
cuda.SetParameter(cudaCalcPartition, 0, (uint)task.cudaPartitions.Pointer);
cuda.SetParameter(cudaCalcPartition, 1 * sizeof(uint), (uint)task.cudaResidual.Pointer);
cuda.SetParameter(cudaCalcPartition, 2 * sizeof(uint), (uint)task.cudaBestResidualTasks.Pointer);
cuda.SetParameter(cudaCalcPartition, 3 * sizeof(uint), (uint)max_porder);
cuda.SetParameter(cudaCalcPartition, 4 * sizeof(uint), (uint)calcPartitionPartSize);
cuda.SetParameter(cudaCalcPartition, 5 * sizeof(uint), (uint)calcPartitionPartCount);
cuda.SetParameterSize(cudaCalcPartition, 6U * sizeof(uint));
cuda.SetParameter(cudaCalcPartition, 2 * sizeof(uint), (uint)task.cudaSamples.Pointer);
cuda.SetParameter(cudaCalcPartition, 3 * sizeof(uint), (uint)task.cudaBestResidualTasks.Pointer);
cuda.SetParameter(cudaCalcPartition, 4 * sizeof(uint), (uint)max_porder);
cuda.SetParameter(cudaCalcPartition, 5 * sizeof(uint), (uint)calcPartitionPartSize);
cuda.SetParameter(cudaCalcPartition, 6 * sizeof(uint), (uint)calcPartitionPartCount);
cuda.SetParameterSize(cudaCalcPartition, 7U * sizeof(uint));
cuda.SetFunctionBlockShape(cudaCalcPartition, 16, 16, 1);
cuda.SetParameter(task.cudaSumPartition, 0, (uint)task.cudaPartitions.Pointer);
@@ -1237,7 +1238,8 @@ namespace CUETools.Codecs.FlaCuda
if (!encode_on_cpu)
{
int bsz = calcPartitionPartCount * calcPartitionPartSize;
cuda.LaunchAsync(task.cudaEncodeResidual, residualPartCount, channels * task.frameCount, task.stream);
if (cudaCalcPartition.Pointer != task.cudaCalcPartition.Pointer)
cuda.LaunchAsync(task.cudaEncodeResidual, residualPartCount, channels * task.frameCount, task.stream);
cuda.LaunchAsync(cudaCalcPartition, (task.frameSize + bsz - 1) / bsz, channels * task.frameCount, task.stream);
if (max_porder > 0)
cuda.LaunchAsync(task.cudaSumPartition, Flake.MAX_RICE_PARAM + 1, channels * task.frameCount, task.stream);