Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TPU Provisioner: Add support for v6e & cross project reservations #851

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tpu-provisioner/go.mod
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module github.com/GoogleCloudPlatform/ai-on-gke/tpu-provisioner

go 1.22.0
go 1.23.0

require (
cloud.google.com/go/compute/metadata v0.3.0
Expand Down
5 changes: 5 additions & 0 deletions tpu-provisioner/internal/cloud/common.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ const (

// AnnotationCopyLabels is a comma-separated list of labels to copy from the Pod to the node pool config (Nodes).
AnnotationCopyLabels = "tpu-provisioner.cloud.google.com/copy-labels"
// AnnotationAdditionalNodeNetworks is a comma-separated list of additional networks and subnets to attach to the node pool.
// Format: "<network-name>:<subnet-name>, ..."
AnnotationAdditionalNodeNetworks = "tpu-provisioner.cloud.google.com/additional-node-networks"
// AnnotationNodeServiceAccount is the GCP service account to use for the node pool.
AnnotationNodeServiceAccount = "tpu-provisioner.cloud.google.com/node-service-account"

EventNodePoolCreationStarted = "NodePoolCreationStarted"
EventNodePoolCreationSucceeded = "NodePoolCreationSucceeded"
Expand Down
38 changes: 36 additions & 2 deletions tpu-provisioner/internal/cloud/gke.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ const (
V4PodSliceAccelerator = "tpu-v4-podslice"
V5ePodSliceAccelerator = "tpu-v5-lite-podslice"
V5pPodSliceAccelerator = "tpu-v5p-slice"
V6eSliceAccelerator = "tpu-v6e-slice"

// Resource type labels
GoogleTPUResource = "google.com/tpu"
Expand Down Expand Up @@ -355,10 +356,40 @@ func (g *GKE) nodePoolForPod(name string, p *corev1.Pod) (*containerv1beta1.Node
}
}

var networkConfig *containerv1beta1.NodeNetworkConfig
var additionalNodeNetworks []*containerv1beta1.AdditionalNodeNetworkConfig
// additional-node-networks: "vpc1:subnet1, vpc2:subnet2"
for _, pair := range strings.Split(getAnnotation(p, AnnotationAdditionalNodeNetworks), ",") {
pair = strings.TrimSpace(pair)
if pair == "" {
continue
}

netAndSubnet := strings.SplitN(pair, ":", 2)
if len(netAndSubnet) != 2 {
return nil, fmt.Errorf("invalid additional network annotation: %v", pair)
}

additionalNodeNetworks = append(additionalNodeNetworks, &containerv1beta1.AdditionalNodeNetworkConfig{
Network: strings.TrimSpace(netAndSubnet[0]),
Subnetwork: strings.TrimSpace(netAndSubnet[1]),
})
}
if len(additionalNodeNetworks) > 0 {
networkConfig = &containerv1beta1.NodeNetworkConfig{
AdditionalNodeNetworkConfigs: additionalNodeNetworks,
}
}

nodeServiceAccount := g.ClusterContext.NodeServiceAccount
if sa, ok := p.Annotations[AnnotationNodeServiceAccount]; ok {
nodeServiceAccount = sa
}

return &containerv1beta1.NodePool{
Name: name,
Config: &containerv1beta1.NodeConfig{
ServiceAccount: g.ClusterContext.NodeServiceAccount,
ServiceAccount: nodeServiceAccount,
ShieldedInstanceConfig: &containerv1beta1.ShieldedInstanceConfig{
EnableIntegrityMonitoring: true,
EnableSecureBoot: g.ClusterContext.NodeSecureBoot,
Expand Down Expand Up @@ -387,6 +418,7 @@ func (g *GKE) nodePoolForPod(name string, p *corev1.Pod) (*containerv1beta1.Node
MaxSurge: 1,
},
MaxPodsConstraint: &containerv1beta1.MaxPodsConstraint{MaxPodsPerNode: maxPodsPerNode},
NetworkConfig: networkConfig,
}, nil
}

Expand Down Expand Up @@ -438,7 +470,7 @@ func tpuTopologyToNodeCount(accelerator, topo string) (int, error) {
switch accelerator {
case V4PodSliceAccelerator, V5pPodSliceAccelerator:
expectedDims = 3
case V5ePodSliceAccelerator:
case V5ePodSliceAccelerator, V6eSliceAccelerator:
expectedDims = 2
default:
return 0, fmt.Errorf("invalid accelerator: %v", accelerator)
Expand Down Expand Up @@ -475,6 +507,8 @@ func tpuMachineType(accel string, tpuRequest int) (string, error) {
return fmt.Sprintf("ct5lp-hightpu-%vt", tpuRequest), nil
case V5pPodSliceAccelerator: // v5p
return fmt.Sprintf("ct5p-hightpu-%vt", tpuRequest), nil
case V6eSliceAccelerator: // v6e
return fmt.Sprintf("ct6e-standard-%vt", tpuRequest), nil
}

return "", fmt.Errorf("invalid accelerator: %v", accel)
Expand Down
10 changes: 10 additions & 0 deletions tpu-provisioner/internal/cloud/gke_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,16 @@ func Test_tpuTopologyToNodeCount(t *testing.T) {
topo: "not-a-topo",
err: true,
},
{
accel: "tpu-v6e-slice",
topo: "16x16",
count: 64,
},
{
accel: "tpu-v6e-slice",
topo: "1x1x1",
err: true,
},
}

for _, c := range cases {
Expand Down
Loading