Processing a database activity stream using the AWS SDK - Amazon Aurora

Processing a database activity stream using the AWS SDK

You can programmatically process an activity stream by using the AWS SDK. The following are fully functioning Java and Python examples of how you might process the Kinesis data stream.

Java
import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.net.InetAddress; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.security.NoSuchAlgorithmException; import java.security.NoSuchProviderException; import java.security.Security; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.UUID; import java.util.zip.GZIPInputStream; import javax.crypto.Cipher; import javax.crypto.NoSuchPaddingException; import javax.crypto.spec.SecretKeySpec; import com.amazonaws.auth.AWSStaticCredentialsProvider; import com.amazonaws.auth.BasicAWSCredentials; import com.amazonaws.encryptionsdk.AwsCrypto; import com.amazonaws.encryptionsdk.CryptoInputStream; import com.amazonaws.encryptionsdk.jce.JceMasterKey; import com.amazonaws.services.kinesis.clientlibrary.exceptions.InvalidStateException; import com.amazonaws.services.kinesis.clientlibrary.exceptions.ShutdownException; import com.amazonaws.services.kinesis.clientlibrary.exceptions.ThrottlingException; import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessor; import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorCheckpointer; import com.amazonaws.services.kinesis.clientlibrary.interfaces.IRecordProcessorFactory; import com.amazonaws.services.kinesis.clientlibrary.lib.worker.InitialPositionInStream; import com.amazonaws.services.kinesis.clientlibrary.lib.worker.KinesisClientLibConfiguration; import com.amazonaws.services.kinesis.clientlibrary.lib.worker.ShutdownReason; import com.amazonaws.services.kinesis.clientlibrary.lib.worker.Worker; import com.amazonaws.services.kinesis.clientlibrary.lib.worker.Worker.Builder; import com.amazonaws.services.kinesis.model.Record; import com.amazonaws.services.kms.AWSKMS; import com.amazonaws.services.kms.AWSKMSClientBuilder; import com.amazonaws.services.kms.model.DecryptRequest; import 
com.amazonaws.services.kms.model.DecryptResult; import com.amazonaws.util.Base64; import com.amazonaws.util.IOUtils; import com.google.gson.Gson; import com.google.gson.GsonBuilder; import com.google.gson.annotations.SerializedName; import org.bouncycastle.jce.provider.BouncyCastleProvider; public class DemoConsumer { private static final String STREAM_NAME = "aws-rds-das-[cluster-external-resource-id]"; private static final String APPLICATION_NAME = "AnyApplication"; //unique application name for dynamo table generation that holds kinesis shard tracking private static final String AWS_ACCESS_KEY = "[AWS_ACCESS_KEY_TO_ACCESS_KINESIS]"; private static final String AWS_SECRET_KEY = "[AWS_SECRET_KEY_TO_ACCESS_KINESIS]"; private static final String DBC_RESOURCE_ID = "[cluster-external-resource-id]"; private static final String REGION_NAME = "[region-name]"; //us-east-1, us-east-2... private static final BasicAWSCredentials CREDENTIALS = new BasicAWSCredentials(AWS_ACCESS_KEY, AWS_SECRET_KEY); private static final AWSStaticCredentialsProvider CREDENTIALS_PROVIDER = new AWSStaticCredentialsProvider(CREDENTIALS); private static final AwsCrypto CRYPTO = new AwsCrypto(); private static final AWSKMS KMS = AWSKMSClientBuilder.standard() .withRegion(REGION_NAME) .withCredentials(CREDENTIALS_PROVIDER).build(); class Activity { String type; String version; String databaseActivityEvents; String key; } class ActivityEvent { @SerializedName("class") String _class; String clientApplication; String command; String commandText; String databaseName; String dbProtocol; String dbUserName; String endTime; String errorMessage; String exitCode; String logTime; String netProtocol; String objectName; String objectType; List<String> paramList; String pid; String remoteHost; String remotePort; String rowCount; String serverHost; String serverType; String serverVersion; String serviceName; String sessionId; String startTime; String statementId; String substatementId; String transactionId; String 
type; } class ActivityRecords { String type; String clusterId; String instanceId; List<ActivityEvent> databaseActivityEventList; } static class RecordProcessorFactory implements IRecordProcessorFactory { @Override public IRecordProcessor createProcessor() { return new RecordProcessor(); } } static class RecordProcessor implements IRecordProcessor { private static final long BACKOFF_TIME_IN_MILLIS = 3000L; private static final int PROCESSING_RETRIES_MAX = 10; private static final long CHECKPOINT_INTERVAL_MILLIS = 60000L; private static final Gson GSON = new GsonBuilder().serializeNulls().create(); private static final Cipher CIPHER; static { Security.insertProviderAt(new BouncyCastleProvider(), 1); try { CIPHER = Cipher.getInstance("AES/GCM/NoPadding", "BC"); } catch (NoSuchAlgorithmException | NoSuchPaddingException | NoSuchProviderException e) { throw new ExceptionInInitializerError(e); } } private long nextCheckpointTimeInMillis; @Override public void initialize(String shardId) { } @Override public void processRecords(final List<Record> records, final IRecordProcessorCheckpointer checkpointer) { for (final Record record : records) { processSingleBlob(record.getData()); } if (System.currentTimeMillis() > nextCheckpointTimeInMillis) { checkpoint(checkpointer); nextCheckpointTimeInMillis = System.currentTimeMillis() + CHECKPOINT_INTERVAL_MILLIS; } } @Override public void shutdown(IRecordProcessorCheckpointer checkpointer, ShutdownReason reason) { if (reason == ShutdownReason.TERMINATE) { checkpoint(checkpointer); } } private void processSingleBlob(final ByteBuffer bytes) { try { // JSON $Activity final Activity activity = GSON.fromJson(new String(bytes.array(), StandardCharsets.UTF_8), Activity.class); // Base64.Decode final byte[] decoded = Base64.decode(activity.databaseActivityEvents); final byte[] decodedDataKey = Base64.decode(activity.key); Map<String, String> context = new HashMap<>(); context.put("aws:rds:dbc-id", DBC_RESOURCE_ID); // Decrypt final 
DecryptRequest decryptRequest = new DecryptRequest() .withCiphertextBlob(ByteBuffer.wrap(decodedDataKey)).withEncryptionContext(context); final DecryptResult decryptResult = KMS.decrypt(decryptRequest); final byte[] decrypted = decrypt(decoded, getByteArray(decryptResult.getPlaintext())); // GZip Decompress final byte[] decompressed = decompress(decrypted); // JSON $ActivityRecords final ActivityRecords activityRecords = GSON.fromJson(new String(decompressed, StandardCharsets.UTF_8), ActivityRecords.class); // Iterate throught $ActivityEvents for (final ActivityEvent event : activityRecords.databaseActivityEventList) { System.out.println(GSON.toJson(event)); } } catch (Exception e) { // Handle error. e.printStackTrace(); } } private static byte[] decompress(final byte[] src) throws IOException { ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(src); GZIPInputStream gzipInputStream = new GZIPInputStream(byteArrayInputStream); return IOUtils.toByteArray(gzipInputStream); } private void checkpoint(IRecordProcessorCheckpointer checkpointer) { for (int i = 0; i < PROCESSING_RETRIES_MAX; i++) { try { checkpointer.checkpoint(); break; } catch (ShutdownException se) { // Ignore checkpoint if the processor instance has been shutdown (fail over). System.out.println("Caught shutdown exception, skipping checkpoint." + se); break; } catch (ThrottlingException e) { // Backoff and re-attempt checkpoint upon transient failures if (i >= (PROCESSING_RETRIES_MAX - 1)) { System.out.println("Checkpoint failed after " + (i + 1) + "attempts." + e); break; } else { System.out.println("Transient issue when checkpointing - attempt " + (i + 1) + " of " + PROCESSING_RETRIES_MAX + e); } } catch (InvalidStateException e) { // This indicates an issue with the DynamoDB table (check for table, provisioned IOPS). System.out.println("Cannot save checkpoint to the DynamoDB table used by the Amazon Kinesis Client Library." 
+ e); break; } try { Thread.sleep(BACKOFF_TIME_IN_MILLIS); } catch (InterruptedException e) { System.out.println("Interrupted sleep" + e); } } } } private static byte[] decrypt(final byte[] decoded, final byte[] decodedDataKey) throws IOException { // Create a JCE master key provider using the random key and an AES-GCM encryption algorithm final JceMasterKey masterKey = JceMasterKey.getInstance(new SecretKeySpec(decodedDataKey, "AES"), "BC", "DataKey", "AES/GCM/NoPadding"); try (final CryptoInputStream<JceMasterKey> decryptingStream = CRYPTO.createDecryptingStream(masterKey, new ByteArrayInputStream(decoded)); final ByteArrayOutputStream out = new ByteArrayOutputStream()) { IOUtils.copy(decryptingStream, out); return out.toByteArray(); } } public static void main(String[] args) throws Exception { final String workerId = InetAddress.getLocalHost().getCanonicalHostName() + ":" + UUID.randomUUID(); final KinesisClientLibConfiguration kinesisClientLibConfiguration = new KinesisClientLibConfiguration(APPLICATION_NAME, STREAM_NAME, CREDENTIALS_PROVIDER, workerId); kinesisClientLibConfiguration.withInitialPositionInStream(InitialPositionInStream.LATEST); kinesisClientLibConfiguration.withRegionName(REGION_NAME); final Worker worker = new Builder() .recordProcessorFactory(new RecordProcessorFactory()) .config(kinesisClientLibConfiguration) .build(); System.out.printf("Running %s to process stream %s as worker %s...\n", APPLICATION_NAME, STREAM_NAME, workerId); try { worker.run(); } catch (Throwable t) { System.err.println("Caught throwable while processing data."); t.printStackTrace(); System.exit(1); } System.exit(0); } private static byte[] getByteArray(final ByteBuffer b) { byte[] byteArray = new byte[b.remaining()]; b.get(byteArray); return byteArray; } }
Python
import base64
import json
import time
import zlib

import aws_encryption_sdk
from aws_encryption_sdk import CommitmentPolicy
from aws_encryption_sdk.internal.crypto import WrappingKey
from aws_encryption_sdk.key_providers.raw import RawMasterKeyProvider
from aws_encryption_sdk.identifiers import WrappingAlgorithm, EncryptionKeyType
import boto3

REGION_NAME = '<region>'                    # us-east-1
RESOURCE_ID = '<external-resource-id>'      # cluster-ABCD123456
STREAM_NAME = 'aws-rds-das-' + RESOURCE_ID  # aws-rds-das-cluster-ABCD123456

# Pause between polling passes. Each pass issues one GetRecords call per shard, and
# Kinesis limits each shard to five read transactions per second.
POLL_INTERVAL_SECONDS = 0.2

enc_client = aws_encryption_sdk.EncryptionSDKClient(commitment_policy=CommitmentPolicy.FORBID_ENCRYPT_ALLOW_DECRYPT)


class MyRawMasterKeyProvider(RawMasterKeyProvider):
    """Raw master key provider that serves one symmetric AES-GCM wrapping key for any key id."""

    provider_id = "BC"

    def __new__(cls, *args, **kwargs):
        # Starts the __new__ lookup *after* RawMasterKeyProvider in the MRO, so that class's
        # own __new__ is skipped.
        # NOTE(review): presumably to avoid the SDK's config-based construction - confirm.
        obj = super(RawMasterKeyProvider, cls).__new__(cls)
        return obj

    def __init__(self, plain_key):
        """Wrap ``plain_key`` (the plaintext data key bytes) as a symmetric wrapping key."""
        RawMasterKeyProvider.__init__(self)
        self.wrapping_key = WrappingKey(wrapping_algorithm=WrappingAlgorithm.AES_256_GCM_IV12_TAG16_NO_PADDING,
                                        wrapping_key=plain_key, wrapping_key_type=EncryptionKeyType.SYMMETRIC)

    def _get_raw_key(self, key_id):
        """Return the single wrapping key regardless of ``key_id``."""
        return self.wrapping_key


def decrypt_payload(payload, data_key):
    """Decrypt ``payload`` with the Encryption SDK using ``data_key`` as the raw master key.

    Returns the decrypted plaintext; the message header is discarded.
    """
    my_key_provider = MyRawMasterKeyProvider(data_key)
    my_key_provider.add_master_key("DataKey")
    decrypted_plaintext, header = enc_client.decrypt(
        source=payload,
        materials_manager=aws_encryption_sdk.materials_managers.default.DefaultCryptoMaterialsManager(
            master_key_provider=my_key_provider))
    return decrypted_plaintext


def decrypt_decompress(payload, key):
    """Decrypt ``payload`` with ``key`` and gunzip the result."""
    decrypted = decrypt_payload(payload, key)
    # wbits = MAX_WBITS + 16 tells zlib to expect a gzip header and trailer.
    return zlib.decompress(decrypted, zlib.MAX_WBITS + 16)


def main():
    """Poll every shard of the activity stream from LATEST and print each decrypted payload.

    Runs until no shard has a next iterator left.
    """
    session = boto3.session.Session()
    kms = session.client('kms', region_name=REGION_NAME)
    kinesis = session.client('kinesis', region_name=REGION_NAME)

    response = kinesis.describe_stream(StreamName=STREAM_NAME)
    shard_iters = []
    for shard in response['StreamDescription']['Shards']:
        shard_iter_response = kinesis.get_shard_iterator(StreamName=STREAM_NAME, ShardId=shard['ShardId'],
                                                         ShardIteratorType='LATEST')
        shard_iters.append(shard_iter_response['ShardIterator'])

    while len(shard_iters) > 0:
        next_shard_iters = []
        for shard_iter in shard_iters:
            response = kinesis.get_records(ShardIterator=shard_iter, Limit=10000)
            for record in response['Records']:
                record_data = json.loads(record['Data'])
                payload_decoded = base64.b64decode(record_data['databaseActivityEvents'])
                data_key_decoded = base64.b64decode(record_data['key'])
                data_key_decrypt_result = kms.decrypt(CiphertextBlob=data_key_decoded,
                                                      EncryptionContext={'aws:rds:dbc-id': RESOURCE_ID})
                print(decrypt_decompress(payload_decoded, data_key_decrypt_result['Plaintext']))
            # A missing or null iterator means there is nothing further to read from this shard.
            next_shard_iter = response.get('NextShardIterator')
            if next_shard_iter:
                next_shard_iters.append(next_shard_iter)
        shard_iters = next_shard_iters
        if shard_iters:
            # Throttle the loop so repeated GetRecords calls stay within the per-shard limit.
            time.sleep(POLL_INTERVAL_SECONDS)


if __name__ == '__main__':
    main()