Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

app.py 3.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
  1. import os
  2. import sys
  3. # Set the root directory as the working directory
  4. root = os.path.dirname(os.path.abspath(__file__))
  5. sys.path.append(root)
  6. from fastapi import FastAPI, UploadFile, File, HTTPException
  7. from fastapi.responses import StreamingResponse
  8. from fastapi.middleware.cors import CORSMiddleware
  9. import shutil
  10. import tempfile
  11. import torch
  12. import mlflow.pytorch
  13. from models import StyleTransferNetwork
  14. from utils.image_utils import imload, imsave
  15. from configs import config
  16. import boto3
  17. from botocore.exceptions import ClientError
  18. # Initialize FastAPI app
  19. app = FastAPI()
  20. mlflow.set_tracking_uri("https://dagshub.com/shatter-star/musical-octo-dollop.mlflow")
  21. # Initialize CORS middleware
  22. app.add_middleware(
  23. CORSMiddleware,
  24. allow_origins=["*"], # Allow requests from all origins
  25. allow_credentials=True,
  26. allow_methods=["GET", "POST"],
  27. allow_headers=["*"],
  28. )
  29. # Initialize StyleTransferNetwork model
  30. device = torch.device('cpu')
  31. model_uri = "mlflow-artifacts:/366666ce4dc8413383fd5d9a1ce802f9/8c9c0df67b1d4151886eec4a77c36417/artifacts/model"
  32. model = mlflow.pytorch.load_model(model_uri, map_location=device)
  33. model.eval()
  34. # Configure S3 client using the IAM role assigned to the Lambda function
  35. s3_client = boto3.client(
  36. 's3',
  37. aws_access_key_id=os.environ.get('AWS_ACCESS_KEY_ID'),
  38. aws_secret_access_key=os.environ.get('AWS_SECRET_ACCESS_KEY'),
  39. region_name=os.environ.get('AWS_DEFAULT_REGION')
  40. )
  41. S3_BUCKET_NAME = 'neural-images'
  42. S3_STYLIZED_IMAGE_PREFIX = 'images/'
  43. # Define the endpoint for stylizing images
  44. @app.post("/stylize")
  45. async def stylize(content_image: UploadFile = File(...), style_index: int = 0):
  46. try:
  47. # Validate the style index
  48. if style_index != -1 and style_index not in range(config.NUM_STYLE):
  49. raise HTTPException(status_code=400, detail="Invalid style index")
  50. # Save uploaded content image to a temporary file
  51. with tempfile.NamedTemporaryFile(delete=False) as temp_content:
  52. shutil.copyfileobj(content_image.file, temp_content)
  53. content_path = temp_content.name
  54. # Generate filename for the stylized image
  55. content_filename, content_extension = os.path.splitext(content_image.filename)
  56. output_filename = f"stylized_{content_filename}{content_extension}"
  57. output_path = os.path.join('/tmp', output_filename)
  58. # Load content image and apply style transfer
  59. content_image = imload(content_path, imsize=config.IMSIZE)
  60. if style_index == -1:
  61. style_code = torch.eye(config.NUM_STYLE).unsqueeze(-1)
  62. content_image = content_image.repeat(config.NUM_STYLE, 1, 1, 1)
  63. else:
  64. style_code = torch.zeros(1, config.NUM_STYLE, 1)
  65. style_code[:, style_index, :] = 1
  66. stylized_image = model(content_image, style_code)
  67. imsave(stylized_image, output_path)
  68. # Upload the stylized image to S3
  69. s3_key = f"{S3_STYLIZED_IMAGE_PREFIX}{output_filename}"
  70. try:
  71. s3_client.upload_file(output_path, S3_BUCKET_NAME, s3_key)
  72. except ClientError as e:
  73. raise HTTPException(status_code=500, detail=f"Error uploading image to S3: {e}")
  74. # Return the stylized image as a streaming response
  75. file_like = open(output_path, mode="rb")
  76. return StreamingResponse(file_like, media_type='image/jpeg')
  77. finally:
  78. # Remove temporary content file and stylized image file
  79. os.unlink(content_path)
  80. os.unlink(output_path)
  81. @app.get("/")
  82. def root():
  83. return {"message": "Style Transfer API is running!"}
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...