End-to-end workflow
After you've set up Amazon Bedrock Marketplace, you can use the following example code in your end-to-end workflow. If you need more context, you can read the sections that follow the code.
import argparse
import json
import pprint
import sys
import time
from datetime import datetime

import boto3
from botocore.exceptions import ClientError

SM_HUB_NAME = 'SageMakerPublicHub'
DELIMITER = "\n\n\n\n================================================================================================"


class Bedrock:
    """Wrapper around the SageMaker, Bedrock, and Bedrock Runtime clients used by this example.

    Model discovery goes through existing SageMaker hub APIs; endpoint management goes
    through the Bedrock Marketplace APIs; inference goes through Bedrock Runtime.
    """

    def __init__(self, region_name) -> None:
        self.region_name = region_name
        self.boto3_session = boto3.session.Session()
        self.sagemaker_client = self.boto3_session.client(
            service_name='sagemaker',
            region_name=self.region_name,
        )
        self.bedrock_client = self.boto3_session.client(
            service_name='bedrock',
            region_name=self.region_name,
        )
        self.endpoint_paginator = self.bedrock_client.get_paginator('list_marketplace_model_endpoints')
        self.bedrock_runtime_client = self.boto3_session.client(
            service_name='bedrock-runtime',
            region_name=self.region_name,
        )

    def list_models(self):
        """Return hub content summaries for models that can be used with Bedrock Marketplace.

        Pages through ListHubContents on the public SageMaker hub and keeps only entries
        tagged '@capability:bedrock_console' (each kept model is also printed).
        """
        SM_RESPONSE_FIELD_NAME = 'HubContentSummaries'
        SM_HUB_CONTENT_TYPE = 'Model'
        response = self.sagemaker_client.list_hub_contents(
            MaxResults=100,
            HubName=SM_HUB_NAME,
            HubContentType=SM_HUB_CONTENT_TYPE,
        )
        all_models = Bedrock.extract_bedrock_models(response[SM_RESPONSE_FIELD_NAME])
        while response.get('NextToken'):
            response = self.sagemaker_client.list_hub_contents(
                MaxResults=100,
                HubName=SM_HUB_NAME,
                HubContentType=SM_HUB_CONTENT_TYPE,
                NextToken=response['NextToken'],
            )
            extracted_models = Bedrock.extract_bedrock_models(response[SM_RESPONSE_FIELD_NAME])
            if not extracted_models:
                # Bedrock enabled models always appear first, so an empty page means
                # there are no more Bedrock enabled models to collect.
                return all_models
            all_models.extend(extracted_models)
            # Pause between pages (presumably to avoid API throttling).
            time.sleep(1)
        return all_models

    def describe_model(self, hub_name: str, hub_content_name: str):
        """Return the DescribeHubContent response for a model in the given SageMaker hub."""
        return self.sagemaker_client.describe_hub_content(
            HubName=hub_name,
            HubContentType='Model',
            HubContentName=hub_content_name,
        )

    def list_endpoints(self):
        """Yield every Bedrock Marketplace model endpoint, across all result pages."""
        for page in self.endpoint_paginator.paginate():
            yield from page['marketplaceModelEndpoints']

    def list_endpoints_for_model(self, hub_content_arn: str):
        """Yield the Bedrock Marketplace endpoints whose model source is ``hub_content_arn``."""
        for page in self.endpoint_paginator.paginate(modelSourceEquals=hub_content_arn):
            yield from page['marketplaceModelEndpoints']

    def create_endpoint(self, model, endpoint_config, endpoint_name: str, tags=None):
        """Deploy ``model`` to a new Bedrock Marketplace endpoint.

        Args:
            model: DescribeHubContent response for the model (see ``describe_model``).
            endpoint_config: The ``endpointConfig`` structure, e.g. ``{"sageMaker": {...}}``.
            endpoint_name: Name of the endpoint to create.
            tags: Optional tag list; an empty list is sent when omitted.

        Returns:
            The raw CreateMarketplaceModelEndpoint response.
        """
        request = {
            'modelSourceIdentifier': model['HubContentArn'],
            'endpointConfig': endpoint_config,
            'endpointName': endpoint_name,
            # A fresh list per call instead of a shared mutable default.
            'tags': [] if tags is None else tags,
        }
        # acceptEula needed only for gated models (those whose document declares a hosting EULA).
        if self._requires_eula(model=model):
            request['acceptEula'] = True
        return self.bedrock_client.create_marketplace_model_endpoint(**request)

    def delete_endpoint(self, endpoint_arn: str):
        """Delete the Bedrock Marketplace endpoint identified by ``endpoint_arn``."""
        return self.bedrock_client.delete_marketplace_model_endpoint(endpointArn=endpoint_arn)

    def describe_endpoint(self, endpoint_arn: str):
        """Return the ``marketplaceModelEndpoint`` description for ``endpoint_arn``."""
        return self.bedrock_client.get_marketplace_model_endpoint(endpointArn=endpoint_arn)['marketplaceModelEndpoint']

    def update_endpoint(self, endpoint_arn: str, endpoint_config):
        """Apply a new ``endpointConfig`` to an existing endpoint and return the raw response."""
        return self.bedrock_client.update_marketplace_model_endpoint(
            endpointArn=endpoint_arn,
            endpointConfig=endpoint_config,
        )

    def register_endpoint(self, endpoint_arn: str, model_arn: str):
        """Register an existing endpoint with Bedrock for ``model_arn``; return its endpoint ARN."""
        response = self.bedrock_client.register_marketplace_model_endpoint(
            endpointIdentifier=endpoint_arn,
            modelSourceIdentifier=model_arn,
        )
        return response['marketplaceModelEndpoint']['endpointArn']

    def deregister_endpoint(self, endpoint_arn: str):
        """De-register the endpoint from Bedrock Marketplace."""
        return self.bedrock_client.deregister_marketplace_model_endpoint(endpointArn=endpoint_arn)

    def invoke(self, endpoint_arn: str, body):
        """Call InvokeModel with a JSON ``body`` string and return the decoded JSON response."""
        response = self.bedrock_runtime_client.invoke_model(
            modelId=endpoint_arn,
            body=body,
            contentType='application/json',
        )
        return json.loads(response["body"].read())

    def invoke_with_stream(self, endpoint_arn: str, body):
        """Call InvokeModelWithResponseStream and return the raw streaming response."""
        return self.bedrock_runtime_client.invoke_model_with_response_stream(modelId=endpoint_arn, body=body)

    def converse(self, endpoint_arn: str, conversation):
        """Send ``conversation`` (a list of Converse messages) and return the raw response."""
        return self.bedrock_runtime_client.converse(modelId=endpoint_arn, messages=conversation)

    def converse_with_stream(self, endpoint_arn: str, conversation):
        """Send ``conversation`` through ConverseStream with fixed inference settings."""
        return self.bedrock_runtime_client.converse_stream(
            modelId=endpoint_arn,
            messages=conversation,
            inferenceConfig={"maxTokens": 4096, "temperature": 0.5, "topP": 0.9},
        )

    def wait_for_endpoint(self, endpoint_arn: str, poll_seconds: int = 10):
        """Block until the endpoint's ``endpointStatus`` is no longer Creating or Updating.

        Args:
            endpoint_arn: ARN of the endpoint to poll.
            poll_seconds: Delay between describe calls (default 10 seconds).

        Returns:
            The last endpoint description that was fetched.
        """
        endpoint = self.describe_endpoint(endpoint_arn=endpoint_arn)
        while endpoint['endpointStatus'] in ('Creating', 'Updating'):
            print(
                f"Endpoint {endpoint_arn} status is still {endpoint['endpointStatus']}. "
                f"Waiting {poll_seconds} seconds before continuing...")
            time.sleep(poll_seconds)
            endpoint = self.describe_endpoint(endpoint_arn=endpoint_arn)
        # Report the field the loop actually polled. 'status' is a separate field
        # (presumably the Bedrock registration status - confirm in the API reference).
        print(f"Endpoint status: {endpoint['endpointStatus']} (registration status: {endpoint.get('status')})")
        return endpoint

    def _requires_eula(self, model):
        """Return True when the model's hub content document declares a non-empty HostingEulaUri."""
        if 'HubContentDocument' not in model:
            return False
        hub_content_document = json.loads(model['HubContentDocument'])
        return bool(hub_content_document.get('HostingEulaUri'))

    @staticmethod
    def extract_bedrock_models(hub_content_summaries):
        """Return (and print) the hub content summaries tagged '@capability:bedrock_console'."""
        models = []
        for content in hub_content_summaries:
            keywords = content.get('HubContentSearchKeywords') or []
            if '@capability:bedrock_console' in keywords:
                print(f"ModelName: {content['HubContentDisplayName']}, modelSourceIdentifier: {content['HubContentArn']}")
                models.append(content)
        return models


def _confirm_endpoint_removed(bedrock, endpoint_arn: str, action: str):
    """Describe the endpoint and confirm it is reported as not found after ``action``.

    Pretty-prints the description if the endpoint is still visible. Any ClientError
    other than ResourceNotFoundException is re-raised unchanged.
    """
    try:
        pprint.pprint(bedrock.describe_endpoint(endpoint_arn=endpoint_arn))
    except ClientError as err:
        if err.response['Error']['Code'] != 'ResourceNotFoundException':
            raise
        print(f"Confirmed that endpoint {endpoint_arn} was {action}")


def run_script(sagemaker_execution_role: str, region: str):
    """Run the end-to-end Bedrock Marketplace example.

    Discovers models, deploys the configured model to a new ml.g5.2xlarge endpoint,
    invokes it, scales it to two instances, de-registers and re-registers it, and
    finally deletes it.

    Args:
        sagemaker_execution_role: IAM role passed as ``endpointConfig.sageMaker.executionRole``.
        region: AWS Region for all clients.
    """
    # Script params: replace example-model-name with the hub content name of the model to
    # deploy (for example, huggingface-llm-gemma-2b-instruct).
    model_name = 'example-model-name'
    sample_endpoint_name = f'test-ep-{datetime.now().strftime("%Y-%m-%d%H%M%S")}'
    conversation = [
        {
            "role": "user",
            "content": [
                {
                    "text": "whats the best park in the US?"
                }
            ]
        }
    ]

    bedrock = Bedrock(region_name=region)

    ###
    ### Model discovery
    ###

    # List all models - no new Bedrock Marketplace API here. Uses existing SageMaker APIs.
    # list_models prints each Bedrock enabled model it finds.
    print(DELIMITER)
    print("All models:")
    bedrock.list_models()

    # Describe a model - no new Bedrock Marketplace API here. Uses existing SageMaker APIs
    # Examples:
    # bedrock.describe_model("SageMakerPublicHub", "huggingface-llm-amazon-mistrallite")
    # bedrock.describe_model("SageMakerPublicHub", "huggingface-llm-gemma-2b-instruct")
    print(DELIMITER)
    print(f'Describing model: {model_name}')
    model = bedrock.describe_model(SM_HUB_NAME, model_name)
    pprint.pprint(model)

    # The hub content ARN
    # (arn:aws:sagemaker:<region>:aws:hub-content/SageMakerPublicHub/Model/<model-name>/<version>)
    # is taken from the description so it always matches model_name and the client region.
    model_arn = model['HubContentArn']

    ## To use a proprietary model, subscribe to it first.
    ## To use a gated model, accept its EULA. Note: EULA acceptance happens on endpoint creation
    ## and must be provided on every create call; an accepted EULA cannot be un-accepted.
    ## create_endpoint sets acceptEula automatically when the model declares a HostingEulaUri.
    ## To use an open-weight model, you can proceed to deploy directly.

    ###
    ### Model deployment to create endpoints
    ###

    # # Create endpoint - uses Bedrock Marketplace API
    endpoint_arn = bedrock.create_endpoint(
        endpoint_name=sample_endpoint_name,
        endpoint_config={
            "sageMaker": {
                "initialInstanceCount": 1,
                "instanceType": "ml.g5.2xlarge",
                "executionRole": sagemaker_execution_role,
                # Other fields:
                # kmsEncryptionKey: KmsKeyId
                # vpc: VpcConfig
            }
        },
        # Optional:
        # tags: TagList
        model=model,
    )['marketplaceModelEndpoint']['endpointArn']

    # # Describe endpoint - uses Bedrock Marketplace API
    endpoint = bedrock.describe_endpoint(endpoint_arn=endpoint_arn)
    print(DELIMITER)
    print('Created endpoint:')
    pprint.pprint(endpoint)

    # Wait while endpoint is being created
    print(DELIMITER)
    bedrock.wait_for_endpoint(endpoint_arn=endpoint_arn)

    ###
    ### Bedrock Marketplace lets you use self-hosted endpoints through the existing Bedrock Runtime
    ### APIs and tools by passing the endpoint ARN as the model ID.
    ### The examples below call invoke_model and converse; invoke_model_with_response_stream and
    ### converse_stream are wrapped by invoke_with_stream and converse_with_stream.
    ### Endpoints can also be used with Bedrock dev tools (Guardrails, Model eval, Agents,
    ### Knowledge bases, Prompt flows, Prompt management) - examples not shown below.
    ###

    # Prepare sample data for invoke calls from the default payload in the model metadata.
    # Reuses the description fetched above rather than calling describe_model again.
    model_data = json.loads(model['HubContentDocument'])
    payload = list(model_data["DefaultPayloads"])[0]
    invoke_body = model_data["DefaultPayloads"][payload]["Body"]
    invoke_content_field_name = 'generated_text'

    # Invoke model (text) - without stream - uses existing Bedrock Runtime API
    print(DELIMITER)
    print(f'Invoking model with body: {invoke_body}')
    invoke_generated_response = bedrock.invoke(endpoint_arn=endpoint_arn, body=json.dumps(invoke_body))
    print('Generated text:')
    print(invoke_generated_response[invoke_content_field_name])
    sys.stdout.flush()

    # Converse with model (chat) - without stream - uses existing Bedrock Runtime API
    print(DELIMITER)
    print(f'Converse model with conversation: {conversation}')
    print(bedrock.converse(endpoint_arn=endpoint_arn, conversation=conversation)['output'])

    ###
    ## Other endpoint management operations
    ###

    # List all endpoints - uses Bedrock Marketplace API
    print(DELIMITER)
    print('Listing all endpoints')
    for listed_endpoint in bedrock.list_endpoints():
        pprint.pprint(listed_endpoint)

    # List endpoints for a model
    # Example: bedrock.list_endpoints_for_model(hub_content_arn='arn:aws:sagemaker:us-west-2:aws:hub-content/SageMakerPublicHub/Model/huggingface-textgeneration1-mpt-7b-storywriter-bf16/3.2.0')
    print(DELIMITER)
    print(f"Listing all endpoints for model: {model_arn}")
    for listed_endpoint in bedrock.list_endpoints_for_model(hub_content_arn=model_arn):
        pprint.pprint(listed_endpoint)

    # # Update endpoint - uses new API provided by Bedrock Marketplace
    updated_endpoint_arn = bedrock.update_endpoint(
        endpoint_arn=endpoint_arn,
        endpoint_config={
            "sageMaker": {
                "initialInstanceCount": 2,  # update to increase instance count
                "instanceType": "ml.g5.2xlarge",
                "executionRole": sagemaker_execution_role,
                # Other fields:
                # kmsEncryptionKey: KmsKeyId
                # vpc: VpcConfig
            }
        },
    )['marketplaceModelEndpoint']['endpointArn']

    # Wait while endpoint is being updated
    print(DELIMITER)
    bedrock.wait_for_endpoint(endpoint_arn=updated_endpoint_arn)

    # Confirm endpoint update. An explicit check (not assert) so it still runs under python -O.
    updated_endpoint = bedrock.describe_endpoint(endpoint_arn=updated_endpoint_arn)
    print(f'Updated endpoint: {updated_endpoint}')
    updated_count = updated_endpoint['endpointConfig']['sageMaker']['initialInstanceCount']
    if updated_count != 2:
        raise RuntimeError(f'Expected initialInstanceCount 2 after update, got {updated_count}')
    print(DELIMITER)
    print("Confirmed that updated endpoint's initialInstanceCount config changed from 1 to 2")

    # Deregister endpoint - uses Bedrock Marketplace API
    print(DELIMITER)
    print(f'De-registering endpoint: {updated_endpoint_arn}')
    bedrock.deregister_endpoint(endpoint_arn=updated_endpoint_arn)
    _confirm_endpoint_removed(bedrock, updated_endpoint_arn, 'de-registered')

    # Re-register endpoint - uses Bedrock Marketplace API
    print(DELIMITER)
    print(f'Registered endpoint: {bedrock.register_endpoint(endpoint_arn=updated_endpoint_arn, model_arn=model_arn)}')
    pprint.pprint(bedrock.describe_endpoint(endpoint_arn=updated_endpoint_arn))

    # Delete endpoint - uses Bedrock Marketplace API
    print(DELIMITER)
    print(f'Deleting endpoint: {updated_endpoint_arn}')
    bedrock.delete_endpoint(endpoint_arn=updated_endpoint_arn)
    _confirm_endpoint_removed(bedrock, updated_endpoint_arn, 'deleted')


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--sagemaker-execution-role',
        required=True,
        help='IAM role passed as endpointConfig.sageMaker.executionRole',
    )
    parser.add_argument('--region', required=True, help='AWS Region for all clients')
    args = parser.parse_args()
    run_script(args.sagemaker_execution_role, args.region)