논문 링크: [2412.18185] TextMatch: Enhancing Image-Text Consistency Through Multimodal Optimization
TextMatch: Enhancing Image-Text Consistency Through Multimodal Optimization
Text-to-image generative models excel in creating images from text but struggle with ensuring alignment and consistency between outputs and prompts. This paper introduces TextMatch, a novel framework that leverages multimodal optimization to address image-
arxiv.org
Abstract
Semi-supervised learning은 소량의 labeled data와 다량의 unlabeled data를 함께 사용하는 학습 방식으로, 의료 영상 segmentation 분야에서 큰 진전을 이루어왔다. 그러나 라벨이 없는 상태에서 학습을 진행하면 노이즈가 유입되기 쉬워지며, 이는 segmentation에 필수적인 명확하게 클러스터링된 feature space를 구성하는 데 방해가 된다.
한편, 최근 Vision-Language (VL) models는 자연 영상 처리 분야에서 큰 주목을 받으며, 특히 text prompts를 활용한 object localization 성능 향상에서 그 가능성을 보여주었다. 이러한 VL 모델의 강점은 라벨 부족 문제를 해결할 수 있는 새로운 방안이 될 수 있다.
이에 착안하여, 본 논문에서는 Textmatch라는 새로운 프레임워크를 제안한다. 이 프레임워크는 text prompts를 활용하여 semi-supervised medical image segmentation 성능을 향상시키는 데 초점을 둔다.
Proposed Methods
1. Bilateral Prompt Decoder (BPD)
- 시각적 feature와 언어적 feature 사이의 modality 차이를 조정하기 위해 사용된다.
- 두 모달리티 간의 보완적인 정보를 효과적으로 추출할 수 있도록 돕는다.
2. Multi-views Consistency Regularization (MCR)
- 이미지와 텍스트 두 도메인에서 각각 perturbation (변형)을 주어 여러 view를 생성한다.
- 이렇게 생성된 여러 view 간의 일관성을 유지하도록 학습시켜, 노이즈의 영향을 줄이고 더 신뢰성 있는 pseudo-label을 생성할 수 있도록 한다.
3. Pseudo-Label Guided Contrastive Learning (PGCL)
- 위에서 얻은 pseudo-label을 바탕으로 feature space에서 contrastive learning을 수행한다.
- 같은 클래스 내 feature는 서로 가깝게, 다른 클래스는 멀게 학습시킴으로써 segmentation에 유리한 구별 가능한 representation을 형성한다.
Experiments
- 공개된 두 개의 데이터셋에서 광범위한 실험을 수행한 결과, image-only 또는 기존의 multi-modal 방식보다 뛰어난 성능을 보였다.
- 따라서, 본 프레임워크는 현재 SOTA 성능을 달성하였음을 입증한다.
Keywords
- Medical image segmentation
- Semi-supervised learning
- Bilateral Prompt
- Multi-views Consistency
- Contrastive learning
1. Introduction
1.1 배경 및 필요성
의료 영상 segmentation은 임상 진단, 질병 경과 추적, 치료 계획 수립 등에서 중요한 정보를 제공한다. 최근에는 딥러닝 기술의 발달로 다양한 deep neural network가 등장하였고, 대규모 주석 데이터를 기반으로 탁월한 성능을 달성하였다.
그러나 pixel-level annotation은 시간이 많이 들고, 전문 지식이 필요하며, 높은 비용이 소모된다. 따라서, 이 제약을 완화할 수 있는 방법이 필수적이다.
1.2 Semi-supervised Learning의 등장과 방식
Semi-supervised learning은 소량의 labeled data와 대량의 unlabeled data를 함께 활용하여 성능을 높이는 방식으로 주목받고 있다.
현재 semi-supervised segmentation 방법은 다음 세 가지로 나눌 수 있다:
- Pseudo-labeling 방법:
unlabeled 데이터에 대해 예측한 결과를 pseudo-label로 삼아 self-training 방식으로 학습을 반복하여 데이터셋을 확장한다. - Consistency Regularization 방법:
입력에 변형(perturbation)을 가한 후에도 예측 결과가 일관되도록 유도하여, 모델의 일반화 성능을 높인다. - Hybrid 방법:
위의 두 가지를 결합하여 성능 향상을 도모한다.
1.3 기존 방식의 한계점
- Pseudo-label은 supervised data가 적기 때문에 노이즈가 포함될 가능성이 높다.
- 대부분의 방법은 logit(pixel) space에 한정된 supervision만 제공하며, feature space에서는 명확한 guidance가 부족하다.
- 이로 인해 잘 클러스터링된 feature space를 형성하지 못하고, segmentation 성능을 떨어뜨리는 결과를 초래한다.
1.4 Vision-Language (VL) 모델의 가능성
최근 Vision-Language 모델은 자연 영상 처리뿐만 아니라 의료 영상 분석에서도 활용되며, 주로 fully-supervised 환경에서 좋은 성능을 보이고 있다.
특히 text prompt를 이용하면 object localization 능력이 향상되고, 노이즈의 영향을 줄일 수 있음이 확인되었다. 그러나, semi-supervised 학습에서는 아직 텍스트 정보를 제대로 활용하지 못하고 있으며, 대부분이 이미지 정보에만 의존하고 있다.
1.5 Textmatch Framework
이러한 문제를 해결하기 위해 본 논문에서는 Textmatch라는 새로운 semi-supervised medical image segmentation 프레임워크를 제안한다.
- Bilateral Prompt Decoder (BPD)
- 시각 feature와 언어 feature의 차이를 조정
- 서로 보완적인 multi-modal feature를 추출
- Multi-views Consistency Regularization (MCR)
- 이미지와 텍스트 모두에 perturbation을 적용
- 다양한 view를 생성하고, 이들 간의 일관성 제약을 부여
- 노이즈를 줄이고 더 견고한 pseudo-label을 생성
- Pseudo-label Guided Contrastive Learning (PGCL)
- 생성된 pseudo-label을 활용해 feature space에서 contrastive learning 수행
- 동일 클래스는 가까이, 이질 클래스는 멀어지도록 학습
- 클래스-구별 가능한 feature 표현 학습을 강화
1.6 Contributions
- BPD 설계를 통해 visual-linguistic feature의 차이를 조율하고, 서로 보완적인 multi-modal 표현을 추출.
- MCR 전략을 통해 이미지/텍스트 perturbation을 모두 활용해 노이즈에 강한 pseudo-label 생성.
- PGCL 전략을 통해 feature space의 구조적 지도(supervision) 제공.
- 2개의 공개 데이터셋 실험을 통해 기존 방법보다 뛰어난 성능을 입증.
2. Method
2.1 Architecture
먼저 두 종류의 입력 데이터를 받는다. 첫째는 labeled 데이터로, 이미지, 텍스트 프롬프트, 그리고 GT segmentation 마스크가 포함되어 있으며, 둘째는 unlabeled 데이터로, 이미지와 텍스트 프롬프트만 포함된다. 입력된 이미지는 Visual Encoder를 통해 CNN 기반으로 세 단계에 걸쳐 다양한 레벨의 시각적 feature $f_I$가 추출되고, 텍스트 프롬프트는 BERT와 같은 Text Encoder를 통해 언어적 feature $f_T$가 생성된다. 이 두 feature는 Projection Head를 통해 동일한 차원으로 정렬되어 통합 가능하게 되며, 이후 BPD를 통과한다. 이 BPD는 이미지와 텍스트가 서로를 프롬프트로 삼아 쌍방향으로 attention을 주고받으며 feature를 강화하는 구조로, multi-head attention을 통해 이미지와 텍스트 feature를 동시에 업데이트하며, 업샘플링과 skip connection을 통해 풍부한 정보를 유지한 채 최종 출력 feature $f_O$를 생성한다. 이 feature는 Segmentation Head를 통해 segmentation mask로 출력되며, labeled 데이터에 대해서는 GT와 비교하여 supervised loss인 $\mathcal{L}_{sup}$ 계산한다.
한편, unlabeled 데이터에 대해서는 MCR 전략이 적용된다. 이 단계에서는 이미지와 텍스트에 각각 다양한 형태의 augmentation을 적용해 multi-view 쌍 $V_t = \{ (x_i^j, t_i^j) \}$을 생성하고, 원본 view는 Student 모델이, augmented views는 Teacher 모델이 사용한다. Teacher 모델은 Student 모델의 EMA로 업데이트되어 더 안정적인 예측을 수행하며, 여러 view에서의 예측 평균을 pseudo-label로 사용하고, Student 모델은 원본 view에 대한 예측과 얼마나 일치하는지를 기준으로 consistency loss $\mathcal{L}_{reg}$를 계산한다. 이어서 적용되는 PGCL은 pseudo-label을 기반으로 각 픽셀 feature가 어떤 클래스에 속하는지 결정하고, 이를 바탕으로 foreground와 background에 대한 prototype을 생성한다. 이 prototype은 해당 클래스의 여러 픽셀 feature를 평균 내어 만든 대표 벡터이며, 학습 중에는 EMA 방식으로 업데이트된다. 그런 다음 InfoNCE 기반 contrastive loss $\mathcal{L}_c$를 적용하여 같은 클래스는 feature 공간에서 가까워지고, 다른 클래스는 멀어지도록 학습한다. 최종적으로 모델은 supervised loss$\mathcal{L}_{sup}$, consistency loss $\mathcal{L}_{reg}$, pseudo-label loss $\mathcal{L}_{pse}$, contrastive loss $\mathcal{L}_c$의 네 가지 손실을 가중치 $\lambda_1, \lambda_2, \lambda_3$와 함께 조합한 total objective $\mathcal{L} = \mathcal{L}_{sup} + \lambda_1 \mathcal{L}_{reg} + \lambda_2 \mathcal{L}_{pse} + \lambda_3 \mathcal{L}_c$로 학습되며, 라벨이 있는 경우에는 $\mathcal{L}_{sup}$만, 라벨이 없는 경우에는 나머지 세 loss를 조합하여 학습이 진행된다.
이 프레임워크는 Mean Teacher 구조를 기반으로 하며, student 모델과 teacher 모델 모두 다음으로 구성된다:
Mean Teacher: semi-supervised에서 자주 쓰이는 방법으로, Student 모델은 현재 학습을 수행하는 역할이고, Teacher 모델은 그동안의 학습 결과를 부드럽게 평균(EMA)한 모델로, 더 안정적인 pseudo-label을 생성하는 데 사용된다. 둘 다 같은 구조를 가지며, Student는 원본 입력을, Teacher는 변형된 입력을 사용해 서로 예측이 일치하도록 훈련함으로써, 모델의 일반화 성능을 높인다.
- Visual Encoder
- Text Encoder
- 여러 개의 Bilateral Prompt Decoder (BPD)
- Segmentation Head
Process
- Labeled data는 ground truth를 이용한 supervised learning에 사용된다.
- Unlabeled data는:
- Augmentation된 multi-view 생성
- Pseudo-label 생성 및 학습 supervision
- Feature space 상의 contrastive learning에 활용된다.
2.2 문제 정의 및 Feature 추출
학습 데이터셋은 다음과 같이 구성된다:
- 작은 labeled subset: $N$개의 샘플 보유
- 큰 unlabeled subset: $M$개의 샘플 보유 ($M \gg N$)
모든 데이터는 이미지와 이에 대응하는 text prompt를 포함한다.
- Labeled set:$\mathcal{D}_l = \{(x_i^l, t_i^l, y_i^l)\}_{i=1}^N$
- Unlabeled set:$\mathcal{D}_u = \{(x_i^u, t_i^u)\}_{i=1}^M$
여기서,
- $x_i \in \mathbb{R}^{C \times H \times W}$: 이미지
- $y_i \in \mathbb{R}^{H \times W}$: segmentation label
- $t_i \in \mathbb{R}^L$: 텍스트 프롬프트 (L개의 단어)
Visual Feature 추출
입력 이미지 $x_i \in \mathbb{R}^{H \times W \times D}$에 대해, Visual Encoder의 각 stage에서 다음과 같은 feature들을 추출한다:
$f_I = \{f_i^I \in \mathbb{R}^{\frac{H}{d_i} \times \frac{W}{d_i} \times C_i} \}_{i=1}^{4}$
- $d_i$: downsampling 비율
- $C_i$: feature dimension
Text Feature 추출
텍스트 프롬프트 $t_i \in \mathbb{R}^{L}$에 대해, text encoder는 다음과 같은 언어 feature를 생성한다:
$$f_T \in \mathbb{R}^{L \times C}$$
2.3 Bilateral Prompt Decoder (BPD)
기존의 multi-modal fusion 기법들과 달리, 제안된 BPD는 시각적 feature와 언어적 feature를 동시에 상호 보완적으로 강화한다.
입력:
- 시각적 feature: $f_I \in \mathbb{R}^{H \times W \times C_I}$
- 언어적 feature:
Step 1: Token 정렬 (Dimension alignment)
시각/언어 feature의 차원을 맞추기 위해 1×1 convolution과 activation 함수 $\sigma$를 적용한다.
$$\begin{aligned} f_I^d &\in \mathbb{R}^{(H \times W) \times C_d} = \sigma(\text{Conv}_{1 \times 1}(f_I)) \\ f_T^d &\in \mathbb{R}^{L \times C_d} = \sigma(\text{Conv}_{1 \times 1}(f_T)) \end{aligned} \tag{1}$$
- $\sigma$: activation function
- $C_d$: aligned dimension
Step 2: Bilateral prompting
$$f_I' \in \mathbb{R}^{(H \times W) \times C_d} = f_I^d + \alpha \cdot \text{MHSA}(f_I^d, f_T^d, f_T^d) \tag{2}$$
$$f_T' \in \mathbb{R}^{L \times C_d} = f_T^d + \alpha \cdot \text{MHSA}(f_T^d, f_I^d, f_I^d) \tag{3}$$
- $\text{MHSA}$: Multi-Head Self-Attention
- $alpha$: learnable residual 가중치
이 구조는 text-to-image, image-to-text attention을 동시에 수행함으로써 cross-modal 이해를 높인다.
Step 3: Feature 복원 및 결합
$$f_I'' = \text{Upsample}(\text{Reshape}(f_I')) \tag{4}$$
$f_I'$는 토큰 시퀀스이므로 spatial shape로 reshape → upsample하여 decoder 단계 크기로 맞춘다.
$$f_O = \sigma(\text{Conv}([f_I''; f_C])) \tag{5}$$
- $f_C$: Decoder의 skip connection에서 가져온 low-level feature
- $[ \cdot ; \cdot ]$: channel 축 결합
2.4 Multi-views Consistency Regularization (MCR)
- 이미지와 텍스트에 각각 perturbation을 적용해, 여러 개의 augmented views를 생성한다:
$$\mathcal{V}_i = \{(\mathcal{A}(x_i^u), \mathcal{T}(t_i^u))_j\}_{j=0}^{n} \tag{6}$$
- $\mathcal{A}(\cdot)$: 이미지 augmentation
- $\mathcal{T}(\cdot)$: 텍스트 variation 생성 모델
- $\mathcal{V}_i^0$: student model에 입력
- $\mathcal{V}_i^{j \geq 1}$: teacher model로 예측
각 view의 예측이 서로 일관성 있게 되도록 모델을 regularization한다:
Regularization Loss
$$\mathcal{L}_{reg} = \frac{1}{B \times n} \sum_{i=1}^{B} \sum_{j=1}^{n} \text{MSE}(M(\mathcal{V}_i^0), T(\mathcal{V}_i^j)) \tag{7}$$
- $M$: student model
- $T$: teacher model(EMA로 update됨)
- $B$: batch size
- $n$: view 수
Pseudo-label Supervision Loss (Dice Loss)
$$\mathcal{L}_{pse} = \frac{1}{B} \sum_{i=1}^{B} \text{DiceLoss}\left( M(\mathcal{V}_i^0), \frac{1}{n} \sum_{j=1}^{n} T(\mathcal{V}_i^j) \right) \tag{8}$$
- 여러 view에서 teacher의 예측 결과 평균 → pseudo-label로 사용
- student는 이 평균 label을 정답처럼 학습
2.5 Pseudo-label Guided Contrastive Learning (PGCL)
Feature space에서 intra-class는 가깝게, inter-class는 멀게 학습하도록 contrastive loss를 적용한다.
Ground Truth 기반 Prototype 생성 (labeled data 기준)
prototype:
- 하나의 class(예: foreground)의 여러 픽셀 feature들을 평균내어 만든 대표 벡터
- 이 prototype과 비교해서 feature가 잘 군집되도록 유도함
Labeled image의 시각 feature: $f_l \in \mathbb{R}^{H \times W \times C}$
- Ground truth label: $y \in \mathbb{R}^{H \times W}$
$$P^f = \frac{\sum_{i,j} y_{ij} \cdot f^l_{ij}}{\sum_{i,j} y_{ij}}, \quad P^b = \frac{\sum_{i,j} (1 - y_{ij}) \cdot f^l_{ij}}{\sum_{i,j} (1 - y_{ij})} \tag{9}$$
- $y \in \mathbb{R}^{H \times W}$: foreground (1) / background (0) label
- $f^l \in \mathbb{R}^{H \times W \times C}$: student 모델의 feature
- $P_f$, $P_b$: foreground, background prototype
InfoNCE 기반 Contrastive Loss
$$\mathcal{H}(f_{ij}^u, P_x, P_y) = \frac{\exp(\text{sim}(f_{ij}^u, P_x)/\tau)}{\exp(\text{sim}(f_{ij}^u, P_x)/\tau) + \exp(\text{sim}(f_{ij}^u, P_y)/\tau)} \tag{10}$$
- $f_{ij}^u$: unlabeled 데이터의 픽셀 feature
- $P_x$: 해당 클래스 prototype
- $P_y$: 다른 클래스 prototype
- $\tau$: temperature (보통 0.9) -> sharpness
- sim: cosine similarity
픽셀 feature가 정답 클래스 prototype과는 가까워지고, 다른 클래스 prototype과는 멀어지게 유도하는 loss이다.
이 식은 이 feature가 정답 class에 속할 확률처럼 계산된다.
최종 Contrastive Loss:
$$\mathcal{L}_c = -\frac{1}{H \times W} \sum_{i=1}^{H} \sum_{j=1}^{W} \left( \hat{y}_{ij} \cdot \log \mathcal{H}(f_{ij}^u, P_f, P_b) + (1 - \hat{y}_{ij}) \cdot \log \mathcal{H}(f_{ij}^u, P_b, P_f) \right) \tag{11}$$
- $\hat{y}_{ij}$: pseudo-label에 따른 pixel-level label (0 or 1)
- $H$: soft similarity 확률
픽셀이 foreground(1)라고 예측되면: → $\mathcal{H}(f, P_f, P_b)$ 값을 크게 (→ $\log$값을 덜 마이너스) 만들어야 하고,
픽셀이 background(0)라고 예측되면: → $\mathcal{H}(f, P_b, P_f)$ 값을 크게 만들어야 한다.
즉, 정답 클래스 prototype과의 유사도는 높이고, 다른 클래스 prototype과는 멀게 만드는 방향으로 loss를 설계한 것이다.
2.6 Total Objective (Loss Function)
- Supervised loss for labeled data(라벨이 있는 데이터에 대한 지도학습 손실): $\mathcal{L}_{sup}$ = Cross Entropy + Dice Loss
- Final loss:
$$\mathcal{L} = \mathcal{L}_{sup} + \lambda_1 \mathcal{L}_{reg} + \lambda_2 \mathcal{L}_{pse} + \lambda_3 \mathcal{L}_c \tag{12}$$
- $\lambda_1, \lambda_2, \lambda_3$: 각 항목의 가중치
라벨이 있는 경우에는 cross-entropy + Dice loss로 학습하며, 라벨이 없는 경우 아래 loss들을 조합하여 학습한다.
3. Experiments and Results
3.1 Datasets & Metrics
제안된 Textmatch 모델은 두 개의 공개 의료 영상 데이터셋에서 성능을 평가하였다.
- QaTa-COV19
- 총 9,258개 X-ray 영상
- COVID-19 병변에 대한 수작업 마스크 포함
- MosMedData+
- 총 2,729개 CT 영상 슬라이스
- 폐 감염 정보 포함
두 데이터셋에 맞는 text prompt를 생성하였으며, 프롬프트 내용은 다음과 같다:
- 양쪽 폐 감염 여부
- 병변의 개수
- 감염 부위의 위치
Metrics
- Dice Similarity Coefficient (DSC, Dice Loss 기반)
- IoU (Intersection over Union)
두 지표 모두 예측 마스크와 정답 마스크 사이의 겹치는 영역의 비율을 기반으로 계산된다.
3.2 Implementation Details
- 프레임워크: PyTorch
- GPU 환경: NVIDIA GeForce RTX 2080Ti 4장 (각 10GB)
모델 구성
- Image Encoder: ConvNeXt-Tiny
- Text Encoder: BERT
- Projection Head: 얕은 Fully Connected laye
Data Augmentation
- 이미지: Random scaling, morphological 변화, 밝기 변화 기반
- 텍스트: Generative Pre-trained Transformer 사용하여 의미 유사한 텍스트 생성
Training Conditions
- Labeled 비율: 5%, 15%, 25% (fair comparison 위해 기준 유지)
- Optimizer: Adam
- Batch Size: 48
- Learning Rate: $3 \times 10^{-4}$
Hyperparameter
- View 개수 $C = 3$
- Temperature $\tau = 0.9$
- 손실 함수 가중치 (Equation 12의 λ 계수):
- $\lambda_1 = \lambda_3 = 0.1$ (모든 데이터셋 공통)
- $\lambda_2 = 0.1$ (QaTa-COV19), $\lambda_2 = 0.5$ (MosMedData+, 높은 난이도 대응)
3.3 기존 기법과의 성능 비교
Results
- Textmatch는 기존 image-only 방식 대비 평균 성능이:
- QaTa-COV19에서 10.19%
- MosMedData+에서 5.03% 개선
- 기존 multi-modal 방법 대비 성능도:
- QaTa-COV19에서 5.94%
- MosMedData+에서 2.92% 더 우수
특히, 단 5%의 labeled data만 사용한 경우에도, 기존 multi-modal 기법이 25%의 라벨을 사용할 때와 동등하거나 초과하는 성능을 보였다.
→ 이는 Textmatch가 unlabeled 데이터로부터 더 discriminative feature를 잘 학습하고, 노이즈 영향도 줄인다는 것을 보여준다.
3.4 Visual Comparison
각 방법의 segmentation 결과를 시각적으로 보여준다.
- X-ray와 CT 모두에서 Textmatch는 다른 방법보다 더 정확한 병변 영역과 명확한 경계선을 보여준다.
- 이는 text prompt의 보조 정보가 시각적 특징을 강화시키는 데 기여했음을 의미한다.
3.5 Ablation Study
각 모듈의 기여도를 분석하기 위해 수행한 ablation 실험 결과를 보여준다 (QaTa-COV19 기준).
3.6 Feature Embedding 분석 (T-SNE)
T-SNE를 통해 feature space에서의 embedding clustering을 시각화한 것이다.
- PGCL 적용 시: 클래스 간 경계가 선명해지고, intra-class compactness가 증가
- PGCL 미적용 시: 클래스 간 경계가 뒤엉켜 구별이 어렵다
→ PGCL은 feature space에 대한 명시적 guidance 부족 문제를 효과적으로 해결하며, segmentation에 유리한 구별 가능한 표현 학습을 유도한다.
4. Conclusion
본 논문에서는 Textmatch라는 새로운 semi-supervised medical image segmentation 프레임워크를 제안하였다. 이 프레임워크는 text prompt의 잠재력을 활용하여, 소량의 라벨 정보로도 정확한 segmentation 결과를 도출하고자 한다.
제안한 핵심 구성 요소:
- Bilateral Prompt Decoder (BPD)
- 시각 정보와 언어 정보를 동시에 활용하여, multi-modal feature로부터 상호 보완적 정보를 효과적으로 추출함.
- Multi-views Consistency Regularization (MCR)
- 이미지와 텍스트 perturbation을 동시에 적용하여 다양한 시점(view)을 생성
- 이를 통해 노이즈에 강건한 학습이 가능해짐.
- Pseudo-label Guided Contrastive Learning (PGCL)
- 클래스 간 구분이 가능한 feature 표현을 학습하도록 유도
- pixel-level feature들이 각 클래스의 prototype과 더 잘 정렬되도록 함
실험 결과
- QaTa-COV19와 MosMedData+ 두 가지 공개 데이터셋에서 수행된 광범위한 실험을 통해, Textmatch는 기존의 image-only 및 multi-modal 방법들보다 우수한 성능을 달성하였다.
- 특히 text prompt를 활용한 semi-supervised 학습 전략이 의료 영상 segmentation에서 효과적임을 실증적으로 확인하였다.