Skip to content

Commit

Permalink
fix(storage): Fix SocketTimeoutException when executing a long multi-…
Browse files Browse the repository at this point in the history
…part upload (#2973)
  • Loading branch information
vincetran authored Jan 15, 2025
1 parent 3e568e7 commit 7581848
Show file tree
Hide file tree
Showing 11 changed files with 288 additions and 156 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ internal class AbortMultiPartUploadWorker(
private val transferStatusUpdater: TransferStatusUpdater,
context: Context,
workerParameters: WorkerParameters
) : BaseTransferWorker(transferStatusUpdater, transferDB, context, workerParameters) {
) : SuspendingTransferWorker(transferStatusUpdater, transferDB, context, workerParameters) {

override suspend fun performWork(): Result {
val s3: S3Client = clientProvider.getStorageTransferClient(transferRecord.region, transferRecord.bucketName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,10 @@

package com.amplifyframework.storage.s3.transfer.worker

import android.app.NotificationChannel
import android.app.NotificationManager
import android.content.Context
import android.net.ConnectivityManager
import android.net.NetworkCapabilities
import android.os.Build
import android.util.Log
import androidx.annotation.RequiresApi
import androidx.core.app.NotificationCompat
import androidx.work.CoroutineWorker
import androidx.work.Data
import androidx.work.ForegroundInfo
import androidx.work.WorkerParameters
import androidx.work.workDataOf
import aws.sdk.kotlin.services.s3.model.ObjectCannedAcl
import aws.sdk.kotlin.services.s3.model.PutObjectRequest
import aws.sdk.kotlin.services.s3.model.RequestPayer
Expand All @@ -37,40 +27,15 @@ import aws.sdk.kotlin.services.s3.model.StorageClass
import aws.smithy.kotlin.runtime.content.ByteStream
import aws.smithy.kotlin.runtime.content.fromFile
import aws.smithy.kotlin.runtime.time.Instant
import com.amplifyframework.core.Amplify
import com.amplifyframework.core.category.CategoryType
import com.amplifyframework.storage.ObjectMetadata
import com.amplifyframework.storage.TransferState
import com.amplifyframework.storage.s3.AWSS3StoragePlugin
import com.amplifyframework.storage.s3.R
import com.amplifyframework.storage.s3.transfer.ProgressListener
import com.amplifyframework.storage.s3.transfer.TransferDB
import com.amplifyframework.storage.s3.transfer.TransferRecord
import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater
import java.io.File
import java.lang.Exception
import java.net.SocketException
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.currentCoroutineContext
import kotlinx.coroutines.isActive

/**
* Base worker to perform transfer file task.
*/
internal abstract class BaseTransferWorker(
private val transferStatusUpdater: TransferStatusUpdater,
private val transferDB: TransferDB,
context: Context,
workerParameters: WorkerParameters
) : CoroutineWorker(context, workerParameters) {

internal lateinit var transferRecord: TransferRecord
internal lateinit var outputData: Data
private val logger =
Amplify.Logging.logger(
CategoryType.STORAGE,
AWSS3StoragePlugin.AWS_S3_STORAGE_LOG_NAMESPACE.format(this::class.java.simpleName)
)
internal interface BaseTransferWorker {

companion object {
internal const val PART_RECORD_ID = "PART_RECORD_ID"
Expand All @@ -86,91 +51,7 @@ internal abstract class BaseTransferWorker(
internal const val MULTIPART_UPLOAD: String = "MULTIPART_UPLOAD"
}

override suspend fun doWork(): Result {
// Foreground task is disabled until the foreground notification behavior and the recent customer feedback,
// it will be enabled in future based on the customer request.
val isForegroundTask: Boolean = (inputData.keyValueMap[RUN_AS_FOREGROUND_TASK] ?: false) as Boolean
if (isForegroundTask) {
setForegroundAsync(getForegroundInfo())
}
val result = runCatching {
val transferRecordId =
inputData.keyValueMap[PART_RECORD_ID] as? Int ?: inputData.keyValueMap[TRANSFER_RECORD_ID] as Int
outputData = workDataOf(OUTPUT_TRANSFER_RECORD_ID to inputData.keyValueMap[TRANSFER_RECORD_ID] as Int)
transferDB.getTransferRecordById(transferRecordId)?.let { tr ->
transferRecord = tr
performWork()
} ?: return run {
Result.failure(outputData)
}
}

return when {
result.isSuccess -> {
result.getOrThrow()
}
else -> {
val ex = result.exceptionOrNull()
if (currentCoroutineContext().isActive) {
logger.error("${this.javaClass.simpleName} failed with exception: ${Log.getStackTraceString(ex)}")
}
if (!currentCoroutineContext().isActive && isRetryableError(ex)) {
Result.retry()
} else {
transferStatusUpdater.updateOnError(transferRecord.id, Exception(ex))
transferStatusUpdater.updateTransferState(
transferRecord.id,
TransferState.FAILED
)
Result.failure(outputData)
}
}
}
}

abstract suspend fun performWork(): Result

internal open var maxRetryCount = 0

override suspend fun getForegroundInfo(): ForegroundInfo {
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.O) {
createChannel()
}
val appIcon = R.drawable.amplify_storage_transfer_notification_icon
return ForegroundInfo(
1,
NotificationCompat.Builder(
applicationContext,
applicationContext.getString(R.string.amplify_storage_notification_channel_id)
)
.setSmallIcon(appIcon)
.setContentTitle(applicationContext.getString(R.string.amplify_storage_notification_title))
.build()
)
}

private fun isRetryableError(e: Throwable?): Boolean {
return !isNetworkAvailable(applicationContext) ||
runAttemptCount < maxRetryCount ||
e is CancellationException ||
// SocketException is thrown when download is terminated due to network disconnection.
e is SocketException
}

@RequiresApi(Build.VERSION_CODES.O)
private fun createChannel() {
val notificationManager =
applicationContext.getSystemService(Context.NOTIFICATION_SERVICE) as NotificationManager
notificationManager.createNotificationChannel(
NotificationChannel(
applicationContext.getString(R.string.amplify_storage_notification_channel_id),
applicationContext.getString(R.string.amplify_storage_notification_channel_name),
NotificationManager.IMPORTANCE_DEFAULT
)
)
}

private fun isNetworkAvailable(context: Context): Boolean {
fun isNetworkAvailable(context: Context): Boolean {
val connectivityManager =
context.getSystemService(Context.CONNECTIVITY_SERVICE) as ConnectivityManager
if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.Q) {
Expand Down Expand Up @@ -198,7 +79,7 @@ internal abstract class BaseTransferWorker(
return false
}

internal fun createPutObjectRequest(
fun createPutObjectRequest(
transferRecord: TransferRecord,
progressListener: ProgressListener?
): PutObjectRequest {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* Copyright 2025 Amazon.com, Inc. or its affiliates. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License").
* You may not use this file except in compliance with the License.
* A copy of the License is located at
*
* http://aws.amazon.com/apache2.0
*
* or in the "license" file accompanying this file. This file is distributed
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
* express or implied. See the License for the specific language governing
* permissions and limitations under the License.
*/

package com.amplifyframework.storage.s3.transfer.worker

import android.content.Context
import android.util.Log
import androidx.work.Data
import androidx.work.Worker
import androidx.work.WorkerParameters
import androidx.work.workDataOf
import com.amplifyframework.core.Amplify
import com.amplifyframework.core.category.CategoryType
import com.amplifyframework.storage.TransferState
import com.amplifyframework.storage.s3.AWSS3StoragePlugin
import com.amplifyframework.storage.s3.transfer.TransferDB
import com.amplifyframework.storage.s3.transfer.TransferRecord
import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater
import com.amplifyframework.storage.s3.transfer.worker.BaseTransferWorker.Companion.OUTPUT_TRANSFER_RECORD_ID
import com.amplifyframework.storage.s3.transfer.worker.BaseTransferWorker.Companion.PART_RECORD_ID
import com.amplifyframework.storage.s3.transfer.worker.BaseTransferWorker.Companion.TRANSFER_RECORD_ID
import java.lang.Exception
import java.net.SocketException

/**
* Base worker to perform transfer file task.
*/
internal abstract class BlockingTransferWorker(
private val transferStatusUpdater: TransferStatusUpdater,
private val transferDB: TransferDB,
context: Context,
workerParameters: WorkerParameters
) : Worker(context, workerParameters), BaseTransferWorker {

internal lateinit var transferRecord: TransferRecord
internal lateinit var outputData: Data

private val logger =
Amplify.Logging.logger(
CategoryType.STORAGE,
AWSS3StoragePlugin.AWS_S3_STORAGE_LOG_NAMESPACE.format(this::class.java.simpleName)
)

override fun doWork(): Result {
val result = runCatching {
val transferRecordId =
inputData.keyValueMap[PART_RECORD_ID] as? Int ?: inputData.keyValueMap[TRANSFER_RECORD_ID] as Int
outputData = workDataOf(OUTPUT_TRANSFER_RECORD_ID to inputData.keyValueMap[TRANSFER_RECORD_ID] as Int)
transferDB.getTransferRecordById(transferRecordId)?.let { tr ->
transferRecord = tr
performWork()
} ?: return run {
Result.failure(outputData)
}
}

return when {
result.isSuccess -> {
result.getOrThrow()
}
else -> {
val ex = result.exceptionOrNull()
logger.error("${this.javaClass.simpleName} failed with exception: ${Log.getStackTraceString(ex)}")
if (isRetryableError(ex)) {
Result.retry()
} else {
transferStatusUpdater.updateOnError(transferRecord.id, Exception(ex))
transferStatusUpdater.updateTransferState(
transferRecord.id,
TransferState.FAILED
)
Result.failure(outputData)
}
}
}
}

abstract fun performWork(): Result

internal open var maxRetryCount = 0

private fun isRetryableError(e: Throwable?): Boolean {
return !isNetworkAvailable(applicationContext) ||
runAttemptCount < maxRetryCount ||
// SocketException is thrown when download is terminated due to network disconnection.
e is SocketException
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ internal class CompleteMultiPartUploadWorker(
private val transferStatusUpdater: TransferStatusUpdater,
context: Context,
workerParameters: WorkerParameters
) : BaseTransferWorker(transferStatusUpdater, transferDB, context, workerParameters) {
) : SuspendingTransferWorker(transferStatusUpdater, transferDB, context, workerParameters) {

override suspend fun performWork(): Result {
val completedParts = transferDB.queryPartETagsOfUpload(transferRecord.id)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ internal class DownloadWorker(
private val transferStatusUpdater: TransferStatusUpdater,
context: Context,
workerParameters: WorkerParameters
) : BaseTransferWorker(transferStatusUpdater, transferDB, context, workerParameters) {
) : SuspendingTransferWorker(transferStatusUpdater, transferDB, context, workerParameters) {

private lateinit var downloadProgressListener: DownloadProgressListener
private val defaultBufferSize = 8192L
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import com.amplifyframework.storage.TransferState
import com.amplifyframework.storage.s3.transfer.StorageTransferClientProvider
import com.amplifyframework.storage.s3.transfer.TransferDB
import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater
import com.amplifyframework.storage.s3.transfer.worker.BaseTransferWorker.Companion.MULTI_PART_UPLOAD_ID
import com.amplifyframework.storage.s3.transfer.worker.BaseTransferWorker.Companion.TRANSFER_RECORD_ID

/**
* Worker to initiate multipart upload
Expand All @@ -34,7 +36,7 @@ internal class InitiateMultiPartUploadTransferWorker(
private val transferStatusUpdater: TransferStatusUpdater,
context: Context,
workerParameters: WorkerParameters
) : BaseTransferWorker(transferStatusUpdater, transferDB, context, workerParameters) {
) : SuspendingTransferWorker(transferStatusUpdater, transferDB, context, workerParameters) {

override suspend fun performWork(): Result {
val s3: S3Client = clientProvider.getStorageTransferClient(transferRecord.region, transferRecord.bucketName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ import com.amplifyframework.storage.s3.transfer.StorageTransferClientProvider
import com.amplifyframework.storage.s3.transfer.TransferDB
import com.amplifyframework.storage.s3.transfer.TransferStatusUpdater
import com.amplifyframework.storage.s3.transfer.UploadProgressListenerInterceptor
import com.amplifyframework.storage.s3.transfer.worker.BaseTransferWorker.Companion.MULTI_PART_UPLOAD_ID
import java.io.File
import kotlinx.coroutines.currentCoroutineContext
import kotlinx.coroutines.isActive
import kotlinx.coroutines.runBlocking

/**
* Worker to upload a part for multipart upload
Expand All @@ -39,41 +39,39 @@ internal class PartUploadTransferWorker(
private val transferStatusUpdater: TransferStatusUpdater,
context: Context,
workerParameters: WorkerParameters
) : BaseTransferWorker(transferStatusUpdater, transferDB, context, workerParameters) {
) : BlockingTransferWorker(transferStatusUpdater, transferDB, context, workerParameters) {

private lateinit var multiPartUploadId: String
private lateinit var partUploadProgressListener: PartUploadProgressListener
override var maxRetryCount = 3

override suspend fun performWork(): Result {
if (!currentCoroutineContext().isActive) {
return Result.retry()
}
override fun performWork(): Result {
transferStatusUpdater.updateTransferState(transferRecord.mainUploadId, TransferState.IN_PROGRESS)
multiPartUploadId = inputData.keyValueMap[MULTI_PART_UPLOAD_ID] as String
partUploadProgressListener = PartUploadProgressListener(transferRecord, transferStatusUpdater)
val s3: S3Client = clientProvider.getStorageTransferClient(transferRecord.region, transferRecord.bucketName)
return s3.withConfig {
interceptors += UploadProgressListenerInterceptor(partUploadProgressListener)
enableAccelerate = transferRecord.useAccelerateEndpoint == 1
}.uploadPart {
bucket = transferRecord.bucketName
key = transferRecord.key
uploadId = multiPartUploadId
body = File(transferRecord.file).asByteStream(
start = transferRecord.fileOffset,
transferRecord.fileOffset + transferRecord.bytesTotal - 1
)
partNumber = transferRecord.partNumber
}.let { response ->
response.eTag?.let { tag ->
transferDB.updateETag(transferRecord.id, tag)
transferDB.updateState(transferRecord.id, TransferState.PART_COMPLETED)
updateProgress()
Result.success(outputData)
} ?: run {
throw IllegalStateException("Etag is empty")

return runBlocking {
s3.withConfig {
interceptors += UploadProgressListenerInterceptor(partUploadProgressListener)
enableAccelerate = transferRecord.useAccelerateEndpoint == 1
}.uploadPart {
bucket = transferRecord.bucketName
key = transferRecord.key
uploadId = multiPartUploadId
body = File(transferRecord.file).asByteStream(
start = transferRecord.fileOffset,
transferRecord.fileOffset + transferRecord.bytesTotal - 1
)
partNumber = transferRecord.partNumber
}
}.eTag?.let { tag ->
transferDB.updateETag(transferRecord.id, tag)
transferDB.updateState(transferRecord.id, TransferState.PART_COMPLETED)
updateProgress()
return Result.success(outputData)
} ?: run {
throw IllegalStateException("Etag is empty")
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ internal class RouterWorker(
?: throw IllegalArgumentException("Worker class name is missing")
private val workerId = parameter.inputData.getString(BaseTransferWorker.WORKER_ID)

private var delegateWorker: BaseTransferWorker? = null
private var delegateWorker: ListenableWorker? = null

companion object {
internal const val WORKER_CLASS_NAME = "WORKER_CLASS_NAME"
Expand Down
Loading

0 comments on commit 7581848

Please sign in to comment.