codermert committed on
Commit ff8c21e · verified · 1 Parent(s): f4f5632

Create infer.py

Files changed (1): infer.py +86 -0
infer.py ADDED
@@ -0,0 +1,86 @@
+ from PIL import Image
+ import cv2 as cv
+ import torch
+ from RealESRGAN import RealESRGAN
+ import tempfile
+ import numpy as np
+ import tqdm
+ import ffmpeg
+ import spaces
+
+
+ # Use the GPU when available, otherwise fall back to CPU
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+ @spaces.GPU(duration=60)
+ def infer_image(img: Image.Image, size_modifier: int) -> Image.Image:
+     if img is None:
+         raise Exception("Image not uploaded")
+
+     width, height = img.size
+
+     if width >= 100000 or height >= 100000:
+         raise Exception("The image is too large.")
+
+     model = RealESRGAN(device, scale=size_modifier)
+     model.load_weights(f'weights/RealESRGAN_x{size_modifier}.pth', download=False)
+
+     result = model.predict(img.convert('RGB'))
+     print(f"Image size ({device}): {size_modifier} ... OK")
+     return result
+
+ @spaces.GPU(duration=300)
+ def infer_video(video_filepath: str, size_modifier: int) -> str:
+     model = RealESRGAN(device, scale=size_modifier)
+     model.load_weights(f'weights/RealESRGAN_x{size_modifier}.pth', download=False)
+
+     cap = cv.VideoCapture(video_filepath)
+
+     tmpfile = tempfile.NamedTemporaryFile(suffix='.mp4', delete=False)
+     vid_output = tmpfile.name
+     tmpfile.close()
+
+     # Check if the input video has an audio stream
+     probe = ffmpeg.probe(video_filepath)
+     has_audio = any(stream['codec_type'] == 'audio' for stream in probe['streams'])
+
+     if has_audio:
+         # Extract the audio track so it can be muxed back in after upscaling
+         audio_file = video_filepath.replace(".mp4", ".wav")
+         ffmpeg.input(video_filepath).output(audio_file, format='wav', ac=1).run(overwrite_output=True)
+
+     vid_writer = cv.VideoWriter(
+         vid_output,
+         fourcc=cv.VideoWriter.fourcc(*'mp4v'),
+         fps=cap.get(cv.CAP_PROP_FPS),
+         frameSize=(int(cap.get(cv.CAP_PROP_FRAME_WIDTH)) * size_modifier, int(cap.get(cv.CAP_PROP_FRAME_HEIGHT)) * size_modifier)
+     )
+
+     n_frames = int(cap.get(cv.CAP_PROP_FRAME_COUNT))
+
+     for _ in tqdm.tqdm(range(n_frames)):
+         ret, frame = cap.read()
+         if not ret:
+             break
+
+         # OpenCV delivers BGR frames; convert to RGB for the model
+         frame = cv.cvtColor(frame, cv.COLOR_BGR2RGB)
+         frame = Image.fromarray(frame)
+
+         upscaled_frame = model.predict(frame.convert('RGB'))
+
+         # Convert back to BGR before writing with OpenCV
+         upscaled_frame = np.array(upscaled_frame)
+         upscaled_frame = cv.cvtColor(upscaled_frame, cv.COLOR_RGB2BGR)
+
+         vid_writer.write(upscaled_frame)
+
+     cap.release()
+     vid_writer.release()
+
+     if has_audio:
+         # Mux the upscaled frames with the original audio track
+         upscaled_path = video_filepath.replace(".mp4", "_upscaled.mp4")
+         video_stream = ffmpeg.input(vid_output)
+         audio_stream = ffmpeg.input(audio_file)
+         ffmpeg.output(video_stream, audio_stream, upscaled_path, vcodec='libx264', acodec='aac', audio_bitrate='320k').run(overwrite_output=True)
+
+     print(f"Video file : {video_filepath}")
+
+     return upscaled_path if has_audio else vid_output
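
For context, a minimal usage sketch of how these two entry points might be wired into a Gradio demo, the usual front end for a Hugging Face Space. The Gradio components, the scale choices, and the assumption that this module is importable as infer are illustrative and not part of this commit.

# Hypothetical wiring, not part of this commit: a Gradio front end for the two functions.
import gradio as gr
from infer import infer_image, infer_video  # assumes this file is saved as infer.py

image_tab = gr.Interface(
    fn=infer_image,
    inputs=[gr.Image(type="pil", label="Input image"), gr.Radio([2, 4, 8], value=2, label="Scale")],
    outputs=gr.Image(type="pil", label="Upscaled image"),
)

video_tab = gr.Interface(
    fn=infer_video,
    inputs=[gr.Video(label="Input video"), gr.Radio([2, 4, 8], value=2, label="Scale")],
    outputs=gr.Video(label="Upscaled video"),
)

gr.TabbedInterface([image_tab, video_tab], ["Image", "Video"]).launch()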