2026. június 15., hétfő

Heurisztikát is tartalmazó veszteségfüggvény

import torch
import torch.nn as nn
import torch.optim as optim

# 1. Hálózat definiálása
class XAINeuralNetwork(nn.Module):
    def __init__(self, input_size):
        super(XAINeuralNetwork, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(input_size, 16),
            nn.ReLU(),
            nn.Linear(16, 8),
            nn.ReLU(),
            nn.Linear(8, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.fc(x)

# 2. Heurisztikát is tartalmazó veszteségfüggvény
def combined_loss(outputs, targets, inputs, lambda_heuristic=0.5):
    # Alapvető veszteség (pl. bináris keresztentrópia)
    criterion = nn.BCELoss()
    base_loss = criterion(outputs, targets)
    
    # MATEMATIKAI HEURISZTIKA INTEGRÁLÁSA:
    # Feltételezzük, hogy az 1. bemeneti változó (inputs[:, 0]) egy kritikus szabályozó.
    # Szabály: Ha az inputs[:, 0] > 0.8, akkor az outputnak is nagynak kell lennie (> 0.5).
    # Büntetjük azt az esetet, ha ez a szabály sérül.
    critical_feature_idx = 0
    rule_violation = torch.relu(0.5 - outputs) * (inputs[:, critical_feature_idx] > 0.8).float()
    
    # A heurisztikus büntetés átlagos mértéke
    heuristic_penalty = torch.mean(rule_violation)
    
    # Végső veszteség a heurisztikus súly (lambda) figyelembevételével
    total_loss = base_loss + lambda_heuristic * heuristic_penalty
    return total_loss

# 3. Tanítási folyamat szimulációja
input_size = 5
model = XAINeuralNetwork(input_size)
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Dummy adatok: (batch_size=10, features=5)
X_train = torch.rand(10, input_size)
y_train = torch.randint(0, 2, (10, 1)).float()

# Tanítási ciklus
epochs = 50
for epoch in range(epochs):
    optimizer.zero_grad()
    predictions = model(X_train)
    
    # Veszteség kalkuláció a beépített heurisztikával
    loss = combined_loss(predictions, y_train, X_train, lambda_heuristic=0.8)
    
    loss.backward()
    optimizer.step()
    
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

Nincsenek megjegyzések:

Megjegyzés küldése