Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test.py 2.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
  1. import argparse
  2. import json
  3. import logging
  4. import os
  5. import boto3
  6. from botocore.exceptions import ClientError
  7. logger = logging.getLogger(__name__)
  8. sm_client = boto3.client("sagemaker")
  9. def invoke_endpoint(endpoint_name):
  10. """
  11. Add custom logic here to invoke the endpoint and validate reponse
  12. """
  13. return {"endpoint_name": endpoint_name, "success": True}
  14. def test_endpoint(endpoint_name):
  15. """
  16. Describe the endpoint and ensure InSerivce, then invoke endpoint. Raises exception on error.
  17. """
  18. error_message = None
  19. try:
  20. # Ensure endpoint is in service
  21. response = sm_client.describe_endpoint(EndpointName=endpoint_name)
  22. status = response["EndpointStatus"]
  23. if status != "InService":
  24. error_message = f"SageMaker endpoint: {endpoint_name} status: {status} not InService"
  25. logger.error(error_message)
  26. raise Exception(error_message)
  27. # Output if endpoint has data capture enbaled
  28. endpoint_config_name = response["EndpointConfigName"]
  29. response = sm_client.describe_endpoint_config(EndpointConfigName=endpoint_config_name)
  30. if "DataCaptureConfig" in response and response["DataCaptureConfig"]["EnableCapture"]:
  31. logger.info(f"data capture enabled for endpoint config {endpoint_config_name}")
  32. # Call endpoint to handle
  33. return invoke_endpoint(endpoint_name)
  34. except ClientError as e:
  35. error_message = e.response["Error"]["Message"]
  36. logger.error(error_message)
  37. raise Exception(error_message)
  38. if __name__ == "__main__":
  39. parser = argparse.ArgumentParser()
  40. parser.add_argument("--log-level", type=str, default=os.environ.get("LOGLEVEL", "INFO").upper())
  41. parser.add_argument("--import-build-config", type=str, required=True)
  42. parser.add_argument("--export-test-results", type=str, required=True)
  43. args, _ = parser.parse_known_args()
  44. # Configure logging to output the line number and message
  45. log_format = "%(levelname)s: [%(filename)s:%(lineno)s] %(message)s"
  46. logging.basicConfig(format=log_format, level=args.log_level)
  47. # Load the build config
  48. with open(args.import_build_config, "r") as f:
  49. config = json.load(f)
  50. # Get the endpoint name from sagemaker project name
  51. endpoint_name = "{}-{}".format(
  52. config["Parameters"]["SageMakerProjectName"], config["Parameters"]["StageName"]
  53. )
  54. results = test_endpoint(endpoint_name)
  55. # Print results and write to file
  56. logger.debug(json.dumps(results, indent=4))
  57. with open(args.export_test_results, "w") as f:
  58. json.dump(results, f, indent=4)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...